# Hugging Face Spaces listing — reported build status: "Runtime error".
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq,
)
import torch

# 1. Load the MovieSum dataset (movie scripts paired with summaries).
dataset = load_dataset("rohitsaxena/MovieSum")
# Normalize column names to the generic input/target convention used below.
dataset = dataset.rename_columns({"script": "input_text", "summary": "target_text"})

# 2. Load the pretrained summarization checkpoint and its tokenizer.
model_checkpoint = "facebook/bart-large-cnn"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
# 3. Preprocessing
def preprocess_function(examples):
    """Tokenize a batch of script/summary pairs for seq2seq training.

    Inputs are truncated to BART's 1024-token encoder limit and targets
    to 128 tokens. Padding is deliberately NOT applied here: the
    DataCollatorForSeq2Seq configured below pads each batch dynamically
    and pads labels with -100, which the loss ignores. (The original
    `padding="max_length"` left real pad-token ids in `labels`, so the
    model was trained to predict padding.)
    """
    model_inputs = tokenizer(
        examples["input_text"],
        max_length=1024,
        truncation=True,
    )
    # `text_target=` replaces the deprecated tokenizer.as_target_tokenizer()
    # context manager, which has been removed in recent transformers releases.
    labels = tokenizer(
        text_target=examples["target_text"],
        max_length=128,
        truncation=True,
    )
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
tokenized_dataset = dataset.map(preprocess_function, batched=True)

# 4. Training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir="./film-script-summarizer",
    # Renamed from `evaluation_strategy`, which was deprecated in
    # transformers 4.41 and removed in 4.46 (passing it there raises a
    # TypeError at construction — a likely cause of the Space's runtime error).
    eval_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    num_train_epochs=3,
    weight_decay=0.01,
    save_total_limit=2,
    push_to_hub=True,
    hub_model_id="BhavyaSamhithaMallineni/FilmScriptSummarizer",
    hub_strategy="every_save",
    logging_dir="./logs",
    logging_steps=50,
    # Use mixed precision only when a CUDA device is available.
    fp16=torch.cuda.is_available(),
)
# 5. Trainer setup
# DataCollatorForSeq2Seq pads inputs/attention masks per batch and pads
# `labels` with -100 so padding positions are ignored by the loss.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    # NOTE(review): evaluating on the "test" split during training leaks the
    # test set; if MovieSum ships a "validation" split, prefer it here.
    eval_dataset=tokenized_dataset["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
)

# 6. Train, then push the fine-tuned model and the tokenizer to the Hub.
trainer.train()
trainer.push_to_hub()
tokenizer.push_to_hub("BhavyaSamhithaMallineni/FilmScriptSummarizer")