|
|
import os

import evaluate
import gradio as gr
import numpy as np
from datasets import load_dataset, Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Trainer,
    TrainingArguments,
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Load raw datasets from the Hugging Face Hub.
# ---------------------------------------------------------------------------

# PII-masking corpus: text paired with its PII-masked version.
pii_ds = load_dataset("ai4privacy/pii-masking-300k")

# CNN/DailyMail summarization corpus (article -> highlights).
cnn_ds = load_dataset("abisee/cnn_dailymail", "1.0.0")

# The DocQA dataset may be gated; skip it gracefully when loading fails
# (e.g. no Hugging Face login) instead of aborting the whole script.
try:
    docqa_ds = load_dataset("vidore/syntheticDocQA_energy_train")
except Exception as e:
    print("⚠️ Skipping docQA dataset (requires login):", e)
    docqa_ds = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Accumulator for unified {"input": ..., "output": ...} training examples.
pairs = []
|
|
|
|
|
def safe_map(dataset, input_keys, output_keys, name, limit=1000):
    """Convert a dataset split into uniform {"input", "output"} pairs.

    Picks the first column name from each candidate list that exists in the
    split, keeps at most `limit` rows, and maps them to a two-column schema.

    Args:
        dataset: Hugging Face dataset split.
        input_keys: candidate input column names, in priority order.
        output_keys: candidate output column names, in priority order.
        name: dataset name (for logs).
        limit: maximum number of samples to keep.

    Returns:
        A dataset whose rows contain exactly "input" and "output", or an
        empty list when no matching columns are found.
    """
    available = dataset.column_names
    chosen_in = next((k for k in input_keys if k in available), None)
    chosen_out = next((k for k in output_keys if k in available), None)

    if not chosen_in or not chosen_out:
        print(f"⚠️ Skipping {name} (no matching columns). Available: {available}")
        return []

    print(f"✅ Using {name}: input='{chosen_in}', output='{chosen_out}'")

    def make_pairs(example):
        return {"input": example[chosen_in], "output": example[chosen_out]}

    # Select the subset *before* mapping so rows that are immediately
    # discarded are never transformed, and drop the original columns so
    # pairs from different datasets share an identical {"input", "output"}
    # schema (Dataset.from_list would otherwise merge mismatched columns).
    subset = dataset.select(range(min(limit, len(dataset))))
    return subset.map(make_pairs, remove_columns=available)
|
|
|
|
|
# Build unified training pairs from each source dataset. The candidate
# column lists cover known naming variants of each dataset's schema.
pii_pairs = safe_map(pii_ds["train"], ["original", "text"], ["masked", "masked_text"], "PII")
cnn_pairs = safe_map(cnn_ds["train"], ["article"], ["highlights", "summary"], "CNN/DailyMail")

# DocQA may have been skipped at load time (gated dataset) — see above.
if docqa_ds is not None:
    docqa_pairs = safe_map(docqa_ds["train"], ["question"], ["answer"], "DocQA")
else:
    docqa_pairs = []

pairs.extend(pii_pairs)
pairs.extend(cnn_pairs)
pairs.extend(docqa_pairs)

# Merge everything into one in-memory Dataset.
# NOTE(review): safe_map's mapped rows may still carry each source's
# original columns — confirm the merged schemas stay compatible here.
dataset = Dataset.from_list(pairs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Base checkpoint: FLAN-T5 small, a lightweight instruction-tuned
# encoder-decoder checkpoint suitable for cheap fine-tuning.
model_name = "google/flan-t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
|
|
|
|
def tokenize_function(example):
    """Tokenize a batch of {"input", "output"} examples for seq2seq training.

    Inputs are truncated to 512 tokens, targets to 128; the target token
    ids are attached under "labels" as expected by the Trainer.
    """
    model_inputs = tokenizer(example["input"], max_length=512, truncation=True)
    # Tokenize targets in target mode via text_target= so tokenizers with
    # distinct source/target handling produce correct label ids (plain
    # tokenization treats the targets as if they were model inputs).
    labels = tokenizer(text_target=example["output"], max_length=128, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
|
|
|
|
|
# Tokenize the full dataset in batches; the string columns remain here,
# but the Trainer drops columns the model doesn't accept by default.
tokenized_datasets = dataset.map(tokenize_function, batched=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ROUGE: n-gram overlap metric, standard for summarization-style outputs.
metric = evaluate.load("rouge")
|
|
|
|
|
def compute_metrics(eval_pred):
    """Compute ROUGE scores (as percentages) for an evaluation batch.

    Args:
        eval_pred: (predictions, labels) from the Trainer. Predictions may
            be token ids, raw logits, or a tuple containing them.

    Returns:
        Dict mapping each ROUGE key to a percentage rounded to 4 decimals.
    """
    predictions, labels = eval_pred
    # A plain Trainer (no predict_with_generate) yields raw logits, sometimes
    # wrapped in a tuple of outputs; reduce to token ids before decoding.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predictions = np.asarray(predictions)
    if predictions.ndim == 3:
        predictions = predictions.argmax(axis=-1)
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Label padding is stored as -100 (ignored by the loss); swap it back to
    # the tokenizer's pad id so the labels can be decoded.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    return {k: round(v * 100, 4) for k, v in result.items()}
|
|
|
|
|
# Training configuration: one cheap epoch, with evaluation and
# checkpointing both disabled (so compute_metrics never actually runs).
training_args = TrainingArguments(
    output_dir="./results",           # where outputs would be written
    eval_strategy="no",               # evaluation disabled
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    num_train_epochs=1,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10,
    save_strategy="no"                # no checkpoints saved
)
|
|
|
|
|
# The default collator cannot pad the variable-length "labels" lists
# produced by tokenization, which crashes batch collation at train time;
# DataCollatorForSeq2Seq pads inputs AND labels (using -100 for label
# padding so padded positions are ignored by the loss).
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets,
    eval_dataset=None,                  # evaluation disabled in the args
    data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
    tokenizer=tokenizer,
    compute_metrics=compute_metrics     # unused while eval_strategy="no"
)
|
|
|
|
|
# Fine-tune on the combined dataset (single epoch).
trainer.train()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_response(input_text):
    """Run the fine-tuned model on raw user text and return the decoded output.

    The input is truncated to 512 tokens; generation is capped at 128 new
    tokens. Returns the generated text with special tokens stripped.
    """
    encoded = tokenizer(
        input_text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )
    generated_ids = model.generate(**encoded, max_new_tokens=128)
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
|
|
|
|
# Minimal Gradio UI: one text box in, one text box out.
demo = gr.Interface(fn=generate_response, inputs="text", outputs="text", title="Cass 2.0 Model")
|
|
|
|
|
# Launch the web demo only when executed as a script (not on import).
if __name__ == "__main__":
    demo.launch()
|
|
|