BhavyaSamhithaMallineni committed on
Commit
d65c2ac
·
verified ·
1 Parent(s): 9e6784b

Create train.py

Browse files
Files changed (1) hide show
  1. train.py +75 -0
train.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset
2
+ from transformers import (
3
+ AutoTokenizer,
4
+ AutoModelForSeq2SeqLM,
5
+ Seq2SeqTrainer,
6
+ Seq2SeqTrainingArguments,
7
+ DataCollatorForSeq2Seq,
8
+ )
9
+ import torch
10
+
11
# 1. Load dataset
dataset = load_dataset("rohitsaxena/MovieSum")

# Normalize column names so the rest of the pipeline can rely on
# "input_text" / "target_text" regardless of the dataset's own schema.
dataset = dataset.rename_columns(
    {"script": "input_text", "summary": "target_text"}
)

# 2. Load model and tokenizer
# BART fine-tuned on CNN/DailyMail is a standard summarization starting point.
model_checkpoint = "facebook/bart-large-cnn"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
22
# 3. Preprocessing
def preprocess_function(examples):
    """Tokenize a batch of scripts and summaries for seq2seq training.

    Args:
        examples: a batch dict with "input_text" (movie scripts) and
            "target_text" (summaries), as produced by `dataset.map(batched=True)`.

    Returns:
        The tokenized inputs dict with an added "labels" key; padded label
        positions are set to -100 so the loss ignores them.
    """
    model_inputs = tokenizer(
        examples["input_text"],
        max_length=1024,
        padding="max_length",
        truncation=True,
    )
    # `text_target=` replaces the deprecated `as_target_tokenizer()`
    # context manager for tokenizing labels.
    labels = tokenizer(
        text_target=examples["target_text"],
        max_length=128,
        padding="max_length",
        truncation=True,
    )
    # Bug fix: labels are padded to max_length with the pad token id, and
    # DataCollatorForSeq2Seq only masks padding it adds itself — so the
    # original code computed loss on pad tokens. Replace pad ids with -100,
    # the ignore index of the cross-entropy loss.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [(tok if tok != pad_id else -100) for tok in seq]
        for seq in labels["input_ids"]
    ]
    return model_inputs
39
+
40
# Tokenize every split once up front; batched mapping processes examples
# in chunks for speed.
tokenized_dataset = dataset.map(preprocess_function, batched=True)

# 4. Training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir="./film-script-summarizer",
    # NOTE(review): newer transformers releases rename this kwarg to
    # `eval_strategy` — confirm against the pinned transformers version.
    evaluation_strategy="epoch",
    # Optimization hyperparameters.
    learning_rate=2e-5,
    weight_decay=0.01,
    num_train_epochs=3,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    # Checkpointing / Hub publishing.
    save_total_limit=2,
    push_to_hub=True,
    hub_model_id="BhavyaSamhithaMallineni/FilmScriptSummarizer",
    hub_strategy="every_save",
    # Logging.
    logging_dir="./logs",
    logging_steps=50,
    # Mixed precision only when a CUDA device is present.
    fp16=torch.cuda.is_available(),
)
59
+
60
# 5. Trainer setup
# Seq2seq collator batches the (already max_length-padded) features and
# prepares decoder inputs from the labels.
collator = DataCollatorForSeq2Seq(tokenizer, model=model)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    tokenizer=tokenizer,
    data_collator=collator,
)

# 6. Train and push to hub
trainer.train()
trainer.push_to_hub()
tokenizer.push_to_hub("BhavyaSamhithaMallineni/FilmScriptSummarizer")