DSDUDEd committed
Commit 31e0539 · verified · 1 Parent(s): 1833ee7

Create app.py

Files changed (1)
  1. app.py +120 -0
app.py ADDED
@@ -0,0 +1,120 @@
+ import gradio as gr
+ import torch  # used below to check whether a GPU is available
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     Trainer,
+     TrainingArguments,
+     DataCollatorForSeq2Seq,
+ )
+ from datasets import load_dataset, Dataset
+
+ # -----------------------------
+ # Load Base Model
+ # -----------------------------
+ model_name = "PerceptronAI/Isaac-0.1"
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
+
+ # Some causal-LM tokenizers ship without a pad token; fall back to EOS so padding below works
+ if tokenizer.pad_token is None:
+     tokenizer.pad_token = tokenizer.eos_token
+
+ # -----------------------------
+ # Load Datasets
+ # -----------------------------
+ print("📥 Loading datasets...")
+
+ pii_ds = load_dataset("ai4privacy/pii-masking-300k")
+ cnn_ds = load_dataset("abisee/cnn_dailymail", "1.0.0")
+
+ try:
+     docqa_ds = load_dataset("vidore/syntheticDocQA_energy_train")
+ except Exception as e:
+     print("⚠️ Skipping docQA dataset (requires login):", e)
+     docqa_ds = None
+
+ # -----------------------------
+ # Build Training Samples
+ # -----------------------------
+ def make_pairs_pii(example):
+     return {"input": example["text"], "output": example["masked_text"]}
+
+ def make_pairs_cnn(example):
+     return {"input": example["article"], "output": example["highlights"]}
+
+ # Take a small subset first so map() only touches the rows we actually train on
+ pii_pairs = pii_ds["train"].select(range(1000)).map(make_pairs_pii)
+ cnn_pairs = cnn_ds["train"].select(range(1000)).map(make_pairs_cnn)
+
+ # Keep only the input/output fields so rows from different datasets share one schema
+ pairs = []
+ pairs.extend({"input": ex["input"], "output": ex["output"]} for ex in pii_pairs)
+ pairs.extend({"input": ex["input"], "output": ex["output"]} for ex in cnn_pairs)
+
+ if docqa_ds is not None:
+     def make_pairs_docqa(example):
+         return {"input": example["question"], "output": example["answer"]}
+     docqa_pairs = docqa_ds["train"].select(range(1000)).map(make_pairs_docqa)
+     pairs.extend({"input": ex["input"], "output": ex["output"]} for ex in docqa_pairs)
+
+ dataset = Dataset.from_list(pairs)
+
+ # -----------------------------
+ # Tokenization
+ # -----------------------------
+ def tokenize(batch):
+     inputs = tokenizer(batch["input"], truncation=True, padding="max_length", max_length=256)
+     outputs = tokenizer(batch["output"], truncation=True, padding="max_length", max_length=256)
+     # Use the tokenized outputs as labels, masking padding with -100 so it is ignored by the loss
+     inputs["labels"] = [
+         [(tok if tok != tokenizer.pad_token_id else -100) for tok in seq]
+         for seq in outputs["input_ids"]
+     ]
+     return inputs
+
+ tokenized_dataset = dataset.map(tokenize, batched=True)
+
+ # -----------------------------
+ # Training
+ # -----------------------------
+ training_args = TrainingArguments(
+     output_dir="./cass2.0",
+     overwrite_output_dir=True,
+     num_train_epochs=1,
+     per_device_train_batch_size=2,
+     save_steps=100,
+     save_total_limit=2,
+     logging_steps=20,
+     learning_rate=5e-5,
+     fp16=torch.cuda.is_available(),  # mixed precision only when a GPU is present
+ )
+
+ data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
+
+ trainer = Trainer(
+     model=model,
+     args=training_args,
+     train_dataset=tokenized_dataset,
+     tokenizer=tokenizer,
+     data_collator=data_collator,
+ )
+
+ print("🚀 Training Cass2.0...")
+ trainer.train()
+ print("✅ Training complete!")
+
+ # -----------------------------
+ # Simple Chat UI
+ # -----------------------------
+ from transformers import pipeline
+ pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
+
+ def chat(message, history):
+     prompt = "".join([f"User: {m[0]}\nCass2.0: {m[1]}\n" for m in history])
+     prompt += f"User: {message}\nCass2.0:"
+     # Cap new tokens so the reply length stays bounded as the prompt grows with history
+     output = pipe(prompt, max_new_tokens=256, do_sample=True, temperature=0.7)[0]["generated_text"]
+     reply = output.split("Cass2.0:")[-1].strip()
+     history.append((message, reply))
+     return history, history
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# 🤖 Cass2.0 – Trained AI Assistant")
+     chatbot = gr.Chatbot()
+     msg = gr.Textbox(label="Type your message")
+     clear = gr.Button("Clear")
+
+     msg.submit(chat, [msg, chatbot], [chatbot, chatbot])
+     clear.click(lambda: None, None, chatbot)
+
+ demo.launch()