# FilmScriptSummary / train.py
# Fine-tune facebook/bart-large-cnn on the MovieSum dataset and push the result to the Hub.
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq,
)
import torch
# 1. Load dataset
dataset = load_dataset("rohitsaxena/MovieSum")
# Map the MovieSum column names onto the generic input/target names used below
dataset = dataset.rename_columns({"script": "input_text", "summary": "target_text"})
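# Optional sanity check (a minimal sketch): print the splits and a truncated
# example to confirm the renamed columns look as expected before training.
print(dataset)
print(dataset["train"][0]["input_text"][:300])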
# 2. Load model and tokenizer
model_checkpoint = "facebook/bart-large-cnn"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
# 3. Preprocessing
def preprocess_function(examples):
    # Tokenize scripts; BART's encoder accepts at most 1024 tokens, so longer
    # scripts are truncated. Padding is left to the data collator so batches
    # are padded dynamically and padded label positions are masked from the loss.
    inputs = tokenizer(
        examples["input_text"],
        max_length=1024,
        truncation=True,
    )
    # `text_target=` replaces the deprecated `as_target_tokenizer()` context manager.
    labels = tokenizer(
        text_target=examples["target_text"],
        max_length=128,
        truncation=True,
    )
    inputs["labels"] = labels["input_ids"]
    return inputs
tokenized_dataset = dataset.map(preprocess_function, batched=True)
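# Optional sanity check: with dynamic padding, sequence lengths vary per example;
# this prints the encoder and label lengths for the first training example.
example = tokenized_dataset["train"][0]
print(len(example["input_ids"]), len(example["labels"]))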
# 4. Training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir="./film-script-summarizer",
    evaluation_strategy="epoch",  # renamed to `eval_strategy` in recent transformers releases
    learning_rate=2e-5,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    num_train_epochs=3,
    weight_decay=0.01,
    save_total_limit=2,
    push_to_hub=True,
    hub_model_id="BhavyaSamhithaMallineni/FilmScriptSummarizer",
    hub_strategy="every_save",
    logging_dir="./logs",
    logging_steps=50,
    fp16=torch.cuda.is_available(),  # mixed precision only when a GPU is present
)
# 5. Trainer setup
# The collator pads each batch dynamically and pads labels with -100,
# so padded label positions are ignored by the loss.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
)
# 6. Train and push to hub
trainer.train()
trainer.push_to_hub()
tokenizer.push_to_hub("BhavyaSamhithaMallineni/FilmScriptSummarizer")
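# 7. Quick inference check (a minimal sketch; the generation settings here are
# illustrative, not tuned). Summarize one held-out script with the fine-tuned model.
sample = dataset["test"][0]["input_text"]
inputs = tokenizer(sample, max_length=1024, truncation=True, return_tensors="pt").to(model.device)
summary_ids = model.generate(**inputs, max_length=128, num_beams=4)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))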