Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from main import tokenizer, model, device | |
| import torch | |
| import pandas as pd | |
# Load the QA dataset from the CSV file shipped alongside this script
df = pd.read_csv("QazSynt_train.csv")
def get_random_row():
    """Return a single uniformly random row of ``df`` as a pandas Series."""
    return df.sample(n=1).iloc[0]
def qa_pipeline(text, question):
    """Extract an answer span for *question* from *text* with the QA model.

    The question/context pair is tokenized, moved to ``device``, and fed to
    ``model``; the argmax of the start/end logits selects the answer span,
    which is decoded back to a string.

    Parameters:
        text: the context passage to search for the answer.
        question: the question to answer.

    Returns:
        The decoded answer string. May be empty (or include special tokens)
        when the predicted end index precedes the start index.
    """
    # Prepare model inputs: tokenizer pairs the question with the context.
    inputs = tokenizer(question, text, return_tensors="pt")
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)
    batch = {
        "input_ids": input_ids,
        "attention_mask": attention_mask
    }
    # Inference only: no_grad avoids building an autograd graph.
    # NOTE(review): model is assumed to return (start_logits, end_logits, loss)
    # for a dict batch — confirm against main.py.
    with torch.no_grad():
        start_logits, end_logits, loss = model(batch)
    # Answer span = argmax of start/end logits.
    # (The original computed these two indices twice; the duplicate was removed.)
    start_index = torch.argmax(start_logits, dim=-1).item()
    end_index = torch.argmax(end_logits, dim=-1).item()
    # Decode the predicted answer tokens from the (single-example) batch.
    predict_answer_tokens = input_ids[0, start_index : end_index + 1]
    return tokenizer.decode(predict_answer_tokens)
def answer_question(context, question):
    """Thin convenience wrapper: delegate to ``qa_pipeline``."""
    return qa_pipeline(context, question)
def get_random_example():
    """Sample one dataset row and run the model on it.

    Returns a 4-tuple: (context, question, real_answer, predicted_answer).
    """
    row = get_random_row()
    context_text = row['context']
    question_text = row['question']
    gold_answer = row['answer']
    model_answer = answer_question(context_text, question_text)
    return context_text, question_text, gold_answer, model_answer
# Gradio interface: left column shows the sampled context/question/gold
# answer, right column shows the model's prediction.
with gr.Blocks() as iface:
    with gr.Row():
        with gr.Column():
            context = gr.Textbox(lines=10, label="Context")
            question = gr.Textbox(lines=2, label="Question")
            real_answer = gr.Textbox(lines=2, label="Real Answer")
        with gr.Column():
            predicted_answer = gr.Textbox(lines=2, label="Predicted Answer")
            generate_button = gr.Button("Get Random Example")

    # get_random_example already returns the four values in the exact output
    # order, so it is wired to the button directly — the original nested
    # update_example() was a redundant pass-through wrapper.
    generate_button.click(
        get_random_example,
        outputs=[context, question, real_answer, predicted_answer],
    )

iface.launch()