Spaces:
Runtime error
Runtime error
import gradio as gr
import torch
from transformers import BertForSequenceClassification, BertTokenizer

# Hugging Face Hub repository that holds both the tokenizer vocab and the
# fine-tuned classification weights.
model_path = "LukeJacob2023/religion-classifier"

# Display names for the classifier's outputs: Christianity, Buddhism,
# no religious belief. predict() pairs these positionally with the model's
# output classes, so this order is assumed to match the model's class
# indices — NOTE(review): confirm against the model config's id2label.
labels = ["基督教", "佛教", "无信仰"]

# Load the tokenizer and the model from the Hub, then switch the model to
# evaluation mode (disables dropout) for inference. Module.eval() returns
# the module itself, so the chained call binds the same object.
tokenizer = BertTokenizer.from_pretrained(model_path)
model = BertForSequenceClassification.from_pretrained(model_path).eval()
def predict(text, temperature=5.0):
    """Classify ``text`` and return per-label probabilities for ``gr.Label``.

    Args:
        text: Input string from the Gradio textbox. ``None`` (e.g. from an
            API call with no value) is treated as an empty string.
        temperature: Positive divisor applied to the logits before softmax.
            Values above 1 flatten the distribution. The default of 5.0
            matches the scaling this app has always used.

    Returns:
        A dict mapping each name in the module-level ``labels`` list to its
        probability as a Python float. The values sum to 1, up to float
        rounding.

    Raises:
        ValueError: if ``temperature`` is not positive, or if the model's
            number of output classes differs from ``len(labels)``. Without
            that check, ``zip`` would silently drop or mislabel classes.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature!r}")
    if text is None:
        text = ""

    # Long inputs are truncated to 512 tokens instead of raising.
    # Padding is a no-op for a single string.
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    # Keep the encoded tensors on the same device as the model, so this
    # still works if the model is ever moved off the CPU.
    inputs = inputs.to(model.device)

    with torch.no_grad():
        logits = model(**inputs).logits

    # A single string was encoded, so take row 0 of the batch.
    probabilities = torch.nn.functional.softmax(logits / temperature, dim=-1)[0]

    num_classes = probabilities.shape[-1]
    if num_classes != len(labels):
        raise ValueError(
            f"model returned {num_classes} classes but {len(labels)} labels are defined"
        )
    return {label: float(prob) for label, prob in zip(labels, probabilities)}
# Build the Gradio interface: a text box in, a ranked probability label out.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=2, label="Input Text"),
    # Show every class rather than a hard-coded count, so the display stays
    # in sync with the `labels` list.
    outputs=gr.Label(num_top_classes=len(labels), label="Predictions"),
    title="Religion Classification",
    # User-facing text: "Please enter text (Traditional Chinese)".
    description="请输入内容(繁体中文)",
)
# Start the Gradio app.
iface.launch()