Hayme committed on
Commit
4932f40
·
1 Parent(s): e88dc29

Improve model loading with better error handling and SafeTensors support

Browse files
Files changed (1) hide show
  1. app.py +162 -69
app.py CHANGED
@@ -12,81 +12,87 @@ import uvicorn
12
  # Initialize FastAPI
13
  app = FastAPI()
14
 
15
- # Download model files
16
  model_name = "Hayme/agrisago-bert"
17
- print("Downloading model files...")
18
 
19
  try:
20
- # Download model files
21
- model_dir = "./model"
22
- os.makedirs(model_dir, exist_ok=True)
23
 
24
- # Download necessary files
25
- files_to_download = [
26
- "config.json",
27
- "pytorch_model.bin",
28
- "tokenizer_config.json",
29
- "tokenizer.json",
30
- "vocab.txt"
31
- ]
32
-
33
- for file in files_to_download:
34
- try:
35
- local_path = hf_hub_download(
36
- repo_id=model_name,
37
- filename=file,
38
- local_dir=model_dir
39
- )
40
- print(f"Downloaded {file}")
41
- except Exception as e:
42
- print(f"Warning: Could not download {file}: {e}")
43
-
44
- # Load tokenizer and model
45
- print("Loading tokenizer and model...")
46
- tokenizer = AutoTokenizer.from_pretrained(model_dir)
47
- model = AutoModel.from_pretrained(model_dir)
48
  model.eval()
49
-
50
- print("Model loaded successfully!")
51
  model_loaded = True
52
 
 
 
 
 
53
  except Exception as e:
54
- print(f"Error loading model: {e}")
55
- print("Using fallback dummy model for testing")
56
  tokenizer = None
57
  model = None
58
  model_loaded = False
59
 
60
  def get_bert_embedding(text):
61
  """Get BERT embedding for text"""
62
- if not model_loaded:
 
63
  # Return a dummy embedding for testing
64
- return [0.1] * 768
 
65
 
66
  try:
 
 
67
  # Tokenize and encode
68
- inputs = tokenizer(text,
69
- return_tensors="pt",
70
- truncation=True,
71
- padding=True,
72
- max_length=512)
 
 
 
 
 
 
73
 
74
  # Get embeddings
75
  with torch.no_grad():
76
  outputs = model(**inputs)
77
  # Use [CLS] token embedding (first token)
78
- embedding = outputs.last_hidden_state[:, 0, :].squeeze().numpy()
79
 
 
 
 
 
 
 
80
  return embedding.tolist()
81
 
82
  except Exception as e:
83
- print(f"Error in get_bert_embedding: {e}")
84
  # Return dummy embedding on error
85
- return [0.1] * 768
 
86
 
87
  def calculate_similarity(text1, text2):
88
  """Calculate cosine similarity between two texts"""
89
  try:
 
 
90
  # Get embeddings
91
  emb1 = np.array(get_bert_embedding(text1))
92
  emb2 = np.array(get_bert_embedding(text2))
@@ -100,10 +106,12 @@ def calculate_similarity(text1, text2):
100
  return 0.0
101
 
102
  similarity = dot_product / (norm1 * norm2)
103
- return float(similarity)
 
 
104
 
105
  except Exception as e:
106
- print(f"Error calculating similarity: {e}")
107
  return 0.0
108
 
109
  # FastAPI endpoints
@@ -119,8 +127,14 @@ async def get_embedding_endpoint(request: TextRequest):
119
  """Get embedding for text"""
120
  try:
121
  embedding = get_bert_embedding(request.text)
122
- return {"embedding": embedding, "success": True}
 
 
 
 
 
123
  except Exception as e:
 
124
  raise HTTPException(status_code=500, detail=str(e))
125
 
126
  @app.post("/similarity")
@@ -128,15 +142,36 @@ async def get_similarity_endpoint(request: SimilarityRequest):
128
  """Get similarity between two texts"""
129
  try:
130
  similarity = calculate_similarity(request.text1, request.text2)
131
- return {"similarity": similarity, "success": True}
 
 
 
 
132
  except Exception as e:
 
133
  raise HTTPException(status_code=500, detail=str(e))
134
 
135
  @app.get("/")
136
  async def root():
137
- return {"message": "AgriSagot BERT Model API", "status": "running", "model_loaded": model_loaded}
 
 
 
 
 
 
 
138
 
139
- # Gradio interface
 
 
 
 
 
 
 
 
 
140
  def gradio_embedding(text):
141
  """Gradio interface for embeddings"""
142
  if not text.strip():
@@ -144,7 +179,8 @@ def gradio_embedding(text):
144
 
145
  try:
146
  embedding = get_bert_embedding(text)
147
- return f"Generated embedding vector of length {len(embedding)}\nFirst 10 values: {embedding[:10]}"
 
148
  except Exception as e:
149
  return f"Error: {str(e)}"
150
 
@@ -155,40 +191,97 @@ def gradio_similarity(text1, text2):
155
 
156
  try:
157
  similarity = calculate_similarity(text1, text2)
158
- return f"Similarity score: {similarity:.4f}"
 
159
  except Exception as e:
160
  return f"Error: {str(e)}"
161
 
162
  # Create Gradio interface
163
- with gr.Blocks(title="AgriSagot BERT Model") as demo:
164
- gr.Markdown("# AgriSagot BERT Model")
165
- gr.Markdown("Agricultural text processing with BERT embeddings")
166
 
167
- if not model_loaded:
168
- gr.Markdown("⚠️ **Warning**: Model not loaded properly. Using dummy responses for testing.")
 
169
  else:
170
- gr.Markdown("✅ Model loaded successfully!")
 
 
 
171
 
172
- with gr.Tab("Text Embedding"):
173
- text_input = gr.Textbox(label="Enter text", placeholder="e.g., 'Cabbage fungal treatment'")
174
- embedding_output = gr.Textbox(label="Embedding Info", lines=3)
175
- embedding_btn = gr.Button("Get Embedding")
 
 
 
 
 
176
  embedding_btn.click(gradio_embedding, inputs=text_input, outputs=embedding_output)
177
 
178
- with gr.Tab("Text Similarity"):
179
- text1_input = gr.Textbox(label="Text 1", placeholder="e.g., 'Cabbage disease treatment'")
180
- text2_input = gr.Textbox(label="Text 2", placeholder="e.g., 'Fungicide for cabbage'")
181
- similarity_output = gr.Textbox(label="Similarity Score")
182
- similarity_btn = gr.Button("Calculate Similarity")
 
 
 
 
 
 
 
 
 
183
  similarity_btn.click(gradio_similarity, inputs=[text1_input, text2_input], outputs=similarity_output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
 
185
  # Mount Gradio app to FastAPI
186
  app = gr.mount_gradio_app(app, demo, path="/")
187
 
188
  if __name__ == "__main__":
189
- print("Starting server...")
190
- print("FastAPI endpoints available at:")
 
 
 
 
 
191
  print("- POST /embedding")
192
- print("- POST /similarity")
193
  print("- Gradio interface at /")
 
194
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
12
  # Initialize FastAPI
13
  app = FastAPI()
14
 
15
+ # Model configuration
16
  model_name = "Hayme/agrisago-bert"
17
+ print("Loading AgriSagot BERT model...")
18
 
19
  try:
20
+ print("Attempting to load tokenizer...")
21
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
22
+ print("✅ Tokenizer loaded successfully")
23
 
24
+ print("Attempting to load model...")
25
+ # Try to load model - transformers will automatically handle safetensors vs pytorch format
26
+ model = AutoModel.from_pretrained(
27
+ model_name,
28
+ trust_remote_code=True,
29
+ torch_dtype=torch.float32, # Ensure compatibility
30
+ device_map="auto" if torch.cuda.is_available() else None
31
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  model.eval()
33
+ print("✅ Model loaded successfully")
 
34
  model_loaded = True
35
 
36
+ # Get model info
37
+ print(f"Model type: {type(model)}")
38
+ print(f"Model device: {next(model.parameters()).device}")
39
+
40
  except Exception as e:
41
+ print(f"❌ Error loading model: {e}")
42
+ print("Will use dummy responses for testing")
43
  tokenizer = None
44
  model = None
45
  model_loaded = False
46
 
47
  def get_bert_embedding(text):
48
  """Get BERT embedding for text"""
49
+ if not model_loaded or model is None or tokenizer is None:
50
+ print("⚠️ Model not loaded, returning dummy embedding")
51
  # Return a dummy embedding for testing
52
+ np.random.seed(hash(text) % 2**32) # Consistent dummy based on text
53
+ return np.random.normal(0, 0.1, 768).tolist()
54
 
55
  try:
56
+ print(f"Getting embedding for: {text[:50]}...")
57
+
58
  # Tokenize and encode
59
+ inputs = tokenizer(
60
+ text,
61
+ return_tensors="pt",
62
+ truncation=True,
63
+ padding=True,
64
+ max_length=512
65
+ )
66
+
67
+ # Move to same device as model
68
+ device = next(model.parameters()).device
69
+ inputs = {k: v.to(device) for k, v in inputs.items()}
70
 
71
  # Get embeddings
72
  with torch.no_grad():
73
  outputs = model(**inputs)
74
  # Use [CLS] token embedding (first token)
75
+ embedding = outputs.last_hidden_state[:, 0, :].squeeze()
76
 
77
+ # Move back to CPU and convert to numpy
78
+ if embedding.device != torch.device('cpu'):
79
+ embedding = embedding.cpu()
80
+ embedding = embedding.numpy()
81
+
82
+ print(f"✅ Generated embedding of shape: {embedding.shape}")
83
  return embedding.tolist()
84
 
85
  except Exception as e:
86
+ print(f"❌ Error in get_bert_embedding: {e}")
87
  # Return dummy embedding on error
88
+ np.random.seed(hash(text) % 2**32)
89
+ return np.random.normal(0, 0.1, 768).tolist()
90
 
91
  def calculate_similarity(text1, text2):
92
  """Calculate cosine similarity between two texts"""
93
  try:
94
+ print(f"Calculating similarity between texts...")
95
+
96
  # Get embeddings
97
  emb1 = np.array(get_bert_embedding(text1))
98
  emb2 = np.array(get_bert_embedding(text2))
 
106
  return 0.0
107
 
108
  similarity = dot_product / (norm1 * norm2)
109
+ result = float(similarity)
110
+ print(f"✅ Similarity calculated: {result:.4f}")
111
+ return result
112
 
113
  except Exception as e:
114
+ print(f"❌ Error calculating similarity: {e}")
115
  return 0.0
116
 
117
  # FastAPI endpoints
 
127
  """Get embedding for text"""
128
  try:
129
  embedding = get_bert_embedding(request.text)
130
+ return {
131
+ "embedding": embedding,
132
+ "success": True,
133
+ "model_loaded": model_loaded,
134
+ "embedding_length": len(embedding)
135
+ }
136
  except Exception as e:
137
+ print(f"API Error: {e}")
138
  raise HTTPException(status_code=500, detail=str(e))
139
 
140
  @app.post("/similarity")
 
142
  """Get similarity between two texts"""
143
  try:
144
  similarity = calculate_similarity(request.text1, request.text2)
145
+ return {
146
+ "similarity": similarity,
147
+ "success": True,
148
+ "model_loaded": model_loaded
149
+ }
150
  except Exception as e:
151
+ print(f"API Error: {e}")
152
  raise HTTPException(status_code=500, detail=str(e))
153
 
154
  @app.get("/")
155
  async def root():
156
+ return {
157
+ "message": "AgriSagot BERT Model API",
158
+ "status": "running",
159
+ "model_loaded": model_loaded,
160
+ "model_name": model_name,
161
+ "torch_version": torch.__version__,
162
+ "device": str(next(model.parameters()).device) if model_loaded else "N/A"
163
+ }
164
 
165
+ @app.get("/health")
166
+ async def health_check():
167
+ """Health check endpoint"""
168
+ return {
169
+ "status": "healthy",
170
+ "model_loaded": model_loaded,
171
+ "endpoints": ["/", "/embedding", "/similarity", "/health"]
172
+ }
173
+
174
+ # Gradio interface functions
175
  def gradio_embedding(text):
176
  """Gradio interface for embeddings"""
177
  if not text.strip():
 
179
 
180
  try:
181
  embedding = get_bert_embedding(text)
182
+ status = "✅ Real BERT embedding" if model_loaded else "⚠️ Dummy embedding (model not loaded)"
183
+ return f"{status}\nEmbedding length: {len(embedding)}\nFirst 10 values: {embedding[:10]}"
184
  except Exception as e:
185
  return f"Error: {str(e)}"
186
 
 
191
 
192
  try:
193
  similarity = calculate_similarity(text1, text2)
194
+ status = "✅ Real BERT similarity" if model_loaded else "⚠️ Dummy similarity (model not loaded)"
195
+ return f"{status}\nSimilarity score: {similarity:.4f}"
196
  except Exception as e:
197
  return f"Error: {str(e)}"
198
 
199
  # Create Gradio interface
200
+ with gr.Blocks(title="AgriSagot BERT Model", theme=gr.themes.Soft()) as demo:
201
+ gr.Markdown("# 🌾 AgriSagot BERT Model")
202
+ gr.Markdown("Agricultural text processing with BERT embeddings for crop disease recommendations")
203
 
204
+ # Status display
205
+ if model_loaded:
206
+ gr.Markdown("✅ **Status**: Model loaded successfully! Using real BERT embeddings.")
207
  else:
208
+ gr.Markdown("⚠️ **Status**: Model not loaded. Using dummy responses for API testing.")
209
+
210
+ gr.Markdown(f"**Model**: {model_name}")
211
+ gr.Markdown(f"**PyTorch Version**: {torch.__version__}")
212
 
213
+ with gr.Tab("🔍 Text Embedding"):
214
+ gr.Markdown("Generate BERT embeddings for agricultural text")
215
+ text_input = gr.Textbox(
216
+ label="Enter agricultural text",
217
+ placeholder="e.g., 'Cabbage fungal treatment with copper-based fungicide'",
218
+ lines=2
219
+ )
220
+ embedding_output = gr.Textbox(label="Embedding Info", lines=4)
221
+ embedding_btn = gr.Button("Get Embedding", variant="primary")
222
  embedding_btn.click(gradio_embedding, inputs=text_input, outputs=embedding_output)
223
 
224
+ with gr.Tab("🔄 Text Similarity"):
225
+ gr.Markdown("Compare similarity between two agricultural texts")
226
+ text1_input = gr.Textbox(
227
+ label="Text 1",
228
+ placeholder="e.g., 'Cabbage disease treatment'",
229
+ lines=2
230
+ )
231
+ text2_input = gr.Textbox(
232
+ label="Text 2",
233
+ placeholder="e.g., 'Fungicide for cabbage crops'",
234
+ lines=2
235
+ )
236
+ similarity_output = gr.Textbox(label="Similarity Result", lines=3)
237
+ similarity_btn = gr.Button("Calculate Similarity", variant="primary")
238
  similarity_btn.click(gradio_similarity, inputs=[text1_input, text2_input], outputs=similarity_output)
239
+
240
+ with gr.Tab("📚 API Documentation"):
241
+ gr.Markdown("""
242
+ ## API Endpoints
243
+
244
+ ### POST /embedding
245
+ Get BERT embedding for text
246
+ ```json
247
+ {
248
+ "text": "your agricultural text here"
249
+ }
250
+ ```
251
+
252
+ ### POST /similarity
253
+ Get similarity between two texts
254
+ ```json
255
+ {
256
+ "text1": "first text",
257
+ "text2": "second text"
258
+ }
259
+ ```
260
+
261
+ ### GET /health
262
+ Check API health status
263
+
264
+ ## Example Usage
265
+ ```bash
266
+ curl -X POST "https://hayme-agrisagot-bert.hf.space/embedding" \\
267
+ -H "Content-Type: application/json" \\
268
+ -d '{"text":"cabbage fungal disease treatment"}'
269
+ ```
270
+ """)
271
 
272
  # Mount Gradio app to FastAPI
273
  app = gr.mount_gradio_app(app, demo, path="/")
274
 
275
  if __name__ == "__main__":
276
+ print("\n" + "="*50)
277
+ print("🚀 Starting AgriSagot BERT API Server")
278
+ print("="*50)
279
+ print(f"Model loaded: {model_loaded}")
280
+ print("FastAPI endpoints:")
281
+ print("- GET /")
282
+ print("- GET /health")
283
  print("- POST /embedding")
284
+ print("- POST /similarity")
285
  print("- Gradio interface at /")
286
+ print("="*50)
287
  uvicorn.run(app, host="0.0.0.0", port=7860)