# CASS2.0 / app.py
import os
from datasets import load_dataset, Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    TrainingArguments,
    Trainer,
)
import evaluate
import numpy as np
import gradio as gr
# -----------------------------
# Load Datasets
# -----------------------------
pii_ds = load_dataset("ai4privacy/pii-masking-300k")
cnn_ds = load_dataset("abisee/cnn_dailymail", "1.0.0")
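# Each load_dataset call returns a DatasetDict keyed by split; only the "train"
# splits are used below. "1.0.0" selects the CNN/DailyMail dataset config.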
try:
    docqa_ds = load_dataset("vidore/syntheticDocQA_energy_train")
except Exception as e:
    print("⚠️ Skipping docQA dataset (requires login):", e)
    docqa_ds = None
# -----------------------------
# Build Pairs from Datasets (Safe Version)
# -----------------------------
pairs = []
def safe_map(dataset, input_keys, output_keys, name, limit=1000):
    """
    dataset: Hugging Face dataset split
    input_keys: list of possible input column names
    output_keys: list of possible output column names
    name: dataset name (for logs)
    limit: maximum number of samples to keep
    """
    available = dataset.column_names
    chosen_in = next((k for k in input_keys if k in available), None)
    chosen_out = next((k for k in output_keys if k in available), None)
    if not chosen_in or not chosen_out:
        print(f"⚠️ Skipping {name} (no matching columns). Available: {available}")
        return []
    print(f"✅ Using {name}: input='{chosen_in}', output='{chosen_out}'")

    def make_pairs(example):
        return {"input": example[chosen_in], "output": example[chosen_out]}

    # Select the subset first so only `limit` rows are mapped, and drop the
    # original columns so every source yields the same {"input", "output"} schema.
    subset = dataset.select(range(min(limit, len(dataset))))
    return subset.map(make_pairs, remove_columns=subset.column_names)
# Column names differ across dataset versions; "source_text"/"target_text" are
# added as likely candidates for the ai4privacy dataset (a best-effort guess;
# safe_map simply skips the source if none of the candidates match).
pii_pairs = safe_map(pii_ds["train"], ["original", "text", "source_text"], ["masked", "masked_text", "target_text"], "PII")
cnn_pairs = safe_map(cnn_ds["train"], ["article"], ["highlights", "summary"], "CNN/DailyMail")
if docqa_ds is not None:
    docqa_pairs = safe_map(docqa_ds["train"], ["question"], ["answer"], "DocQA")
else:
    docqa_pairs = []
pairs.extend(pii_pairs)
pairs.extend(cnn_pairs)
pairs.extend(docqa_pairs)
dataset = Dataset.from_list(pairs)
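# `dataset` now holds a single flat set of {"input", "output"} records pooled
# from whichever sources were successfully loaded and matched above.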
# -----------------------------
# Model + Tokenizer
# -----------------------------
model_name = "google/flan-t5-small" # small, fast model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
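# flan-t5-small is an encoder-decoder (seq2seq) model, hence AutoModelForSeq2SeqLM.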
def tokenize_function(example):
    # Tokenize inputs and targets separately; the target token ids become the labels.
    model_inputs = tokenizer(example["input"], max_length=512, truncation=True)
    labels = tokenizer(example["output"], max_length=128, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
tokenized_datasets = dataset.map(tokenize_function, batched=True)
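# The raw "input"/"output" string columns are still present at this point;
# Trainer's default remove_unused_columns=True drops them before batching.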
# -----------------------------
# Training
# -----------------------------
metric = evaluate.load("rouge")
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 (ignored label positions) with the pad token id before decoding.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    return {k: round(v * 100, 4) for k, v in result.items()}
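# Note: compute_metrics only runs when evaluation is enabled. With the plain
# Trainer, eval_pred.predictions would be logits rather than generated token ids,
# so ROUGE evaluation would also need Seq2SeqTrainer with predict_with_generate=True.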
training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="no",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    num_train_epochs=1,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10,
    save_strategy="no",
)
# DataCollatorForSeq2Seq pads inputs and labels per batch; without it the default
# collator can fail on the variable-length sequences produced by tokenize_function.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets,
    eval_dataset=None,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)
trainer.train()
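# Optionally persist the fine-tuned weights for reuse, e.g.:
# model.save_pretrained("./cass2-flan-t5-small")
# tokenizer.save_pretrained("./cass2-flan-t5-small")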
# -----------------------------
# Gradio App
# -----------------------------
def generate_response(input_text):
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512)
    outputs = model.generate(**inputs, max_new_tokens=128)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
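# Generation uses the model's default (greedy) decoding, capped at 128 new tokens.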
demo = gr.Interface(fn=generate_response, inputs="text", outputs="text", title="Cass 2.0 Model")
if __name__ == "__main__":
    demo.launch()