Sayed223 committed (verified)
Commit 0550c3e · 1 Parent(s): 0136649

Update app.py

Files changed (1)
  app.py (+115, -128)
app.py CHANGED
@@ -1,58 +1,58 @@
 import os
 import io
 import base64
+import traceback
 import torch
 import torch.nn as nn
 import torchvision.transforms as transforms
 from PIL import Image
 from flask import Flask, request, render_template, flash, redirect, url_for, jsonify
-from dotenv import load_dotenv  # Import dotenv
+from dotenv import load_dotenv

-
-# Import necessary classes from your original script / transformers
+# Use the auto classes to avoid version-specific direct imports
 from transformers import (
-    SwinModel,
+    AutoModel,           # used for vision (Swin)
+    AutoImageProcessor,  # optional: if you want processor instead of torchvision
     T5ForConditionalGeneration,
     T5Tokenizer,
-    AutoModelForCausalLM,  # Added for Llama
-    AutoTokenizer,  # Added for Llama
+    AutoModelForCausalLM,
+    AutoTokenizer,
 )
 from transformers.modeling_outputs import BaseModelOutput

-load_dotenv()  # Load environment variables from .env file
+load_dotenv()  # Load environment variables from .env file

 # --- Configuration ---
-MODEL_PATH = '/cluster/home/ammaa/Downloads/Projects/CheXpert-Report-Generation/swin-t5-model.pth'  # Path to your trained model weights
+MODEL_PATH = '/cluster/home/ammaa/Downloads/Projects/CheXpert-Report-Generation/swin-t5-model.pth'
 SWIN_MODEL_NAME = "microsoft/swin-base-patch4-window7-224"
 T5_MODEL_NAME = "t5-base"
-LLAMA_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"  # Llama model
-HF_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN")  # Get token from env
+LLAMA_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+HF_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN")  # Hugging Face token (optional)

 if not HF_TOKEN:
     print("Warning: HUGGING_FACE_HUB_TOKEN environment variable not set. Llama model download might fail.")

 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-UPLOAD_FOLDER = 'uploads'  # Optional: If you want to save uploads temporarily
+UPLOAD_FOLDER = 'uploads'
 ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}

-# Ensure the upload folder exists if you use it
-# if not os.path.exists(UPLOAD_FOLDER):
-#     os.makedirs(UPLOAD_FOLDER)
-
 # --- Swin-T5 Model Definition ---
 class ImageCaptioningModel(nn.Module):
     def __init__(self,
                  swin_model_name=SWIN_MODEL_NAME,
                  t5_model_name=T5_MODEL_NAME):
         super().__init__()
-        self.swin = SwinModel.from_pretrained(swin_model_name)
+        # Use AutoModel for the vision backbone (works across transformer versions)
+        self.swin = AutoModel.from_pretrained(swin_model_name)
         self.t5 = T5ForConditionalGeneration.from_pretrained(t5_model_name)
+        # Project swin hidden states to T5 d_model
         self.img_proj = nn.Linear(self.swin.config.hidden_size, self.t5.config.d_model)

     def forward(self, images, labels=None):
-        swin_outputs = self.swin(images)
-        img_feats = swin_outputs.last_hidden_state
-        img_feats_proj = self.img_proj(img_feats)
+        # images: expected shape (batch, channels, height, width)
+        swin_outputs = self.swin(images, return_dict=True)
+        img_feats = swin_outputs.last_hidden_state  # (batch, seq_len, hidden)
+        img_feats_proj = self.img_proj(img_feats)   # project to T5 d_model
         encoder_outputs = BaseModelOutput(last_hidden_state=img_feats_proj)
         if labels is not None:
             outputs = self.t5(encoder_outputs=encoder_outputs, labels=labels)
@@ -75,20 +75,28 @@ def load_swin_t5_model_components():
         # Initialize model structure
         swin_t5_model = ImageCaptioningModel(swin_model_name=SWIN_MODEL_NAME, t5_model_name=T5_MODEL_NAME)

-        # Load state dictionary
+        # Load state dictionary if provided
         if not os.path.exists(MODEL_PATH):
-            raise FileNotFoundError(f"Swin-T5 Model file not found at {MODEL_PATH}.")
-        # Load Swin-T5 model to the primary DEVICE (can be CPU or GPU)
-        swin_t5_model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
+            raise FileNotFoundError(f"Swin-T5 Model file not found at {MODEL_PATH}.")
+
+        # Load state dict into model (map_location ensures correct device)
+        state = torch.load(MODEL_PATH, map_location=DEVICE)
+        # If the saved state is a dict containing model key (common), attempt to pull it
+        if isinstance(state, dict) and "model_state_dict" in state and len(state) > 1:
+            # typical saved checkpoint structure { 'epoch':..., 'model_state_dict':..., ... }
+            swin_t5_model.load_state_dict(state["model_state_dict"])
+        else:
+            swin_t5_model.load_state_dict(state)
+
         swin_t5_model.to(DEVICE)
-        swin_t5_model.eval()  # Set to evaluation mode
+        swin_t5_model.eval()
         print("Swin-T5 Model loaded successfully.")

-        # Load tokenizer
+        # Load tokenizer for T5
         swin_t5_tokenizer = T5Tokenizer.from_pretrained(T5_MODEL_NAME)
         print("Swin-T5 Tokenizer loaded successfully.")

-        # Define image transformations
+        # Define (simple) image transformations
         transform = transforms.Compose([
             transforms.Resize((224, 224)),
             transforms.ToTensor(),
@@ -99,6 +107,8 @@ def load_swin_t5_model_components():

     except Exception as e:
         print(f"Error loading Swin-T5 model components: {e}")
+        print(traceback.format_exc())
+        # Re-raise so startup knows loading failed (your code caught it)
         raise

 def load_llama_model_components():
@@ -106,57 +116,58 @@ def load_llama_model_components():
     global llama_model, llama_tokenizer
     if not HF_TOKEN:
         print("Skipping Llama model load: Hugging Face token not found.")
-        return  # Don't attempt to load if no token
+        return

     try:
         print(f"Loading Llama model ({LLAMA_MODEL_NAME}) components...")
-        # Use bfloat16 for memory efficiency if available, otherwise float16/32
-        torch_dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16

-        llama_tokenizer = AutoTokenizer.from_pretrained(LLAMA_MODEL_NAME, token=HF_TOKEN)
+        # Choose an appropriate dtype for loading
+        if torch.cuda.is_available():
+            # prefer bf16 if supported to save memory on modern GPUs
+            try:
+                torch_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+            except Exception:
+                torch_dtype = torch.float16
+        else:
+            torch_dtype = torch.float32
+
+        # Use use_auth_token parameter for private models / gated access
+        llama_tokenizer = AutoTokenizer.from_pretrained(LLAMA_MODEL_NAME, use_auth_token=HF_TOKEN)
         llama_model = AutoModelForCausalLM.from_pretrained(
             LLAMA_MODEL_NAME,
             torch_dtype=torch_dtype,
-            device_map="auto",  # Automatically distribute across GPUs/CPU RAM if needed
-            token=HF_TOKEN
-            # Add quantization config here if needed (e.g., load_in_4bit=True with bitsandbytes)
-            # quantization_config=BitsAndBytesConfig(...)
+            device_map="auto",
+            use_auth_token=HF_TOKEN
         )
-        llama_model.eval()  # Set to evaluation mode
+        llama_model.eval()
         print("Llama Model and Tokenizer loaded successfully.")
-
     except Exception as e:
         print(f"Error loading Llama model components: {e}")
-        # Decide if the app should run without the chat feature or crash
+        print(traceback.format_exc())
        llama_model = None
        llama_tokenizer = None
        print("WARNING: Chatbot functionality will be disabled due to loading error.")
-        # raise  # Uncomment this if the chat feature is critical

 # --- Inference Function (Swin-T5) ---
 def generate_report(image_bytes, selected_vlm, max_length=100):
     """Generates a report/caption for the given image bytes using Swin-T5."""
     global swin_t5_model, swin_t5_tokenizer, transform
-    if not all([swin_t5_model, swin_t5_tokenizer, transform]):
-        # Check if loading failed or wasn't called
+    # Ensure components are loaded (attempt to load if missing)
+    if swin_t5_model is None or swin_t5_tokenizer is None or transform is None:
+        load_swin_t5_model_components()
     if swin_t5_model is None or swin_t5_tokenizer is None or transform is None:
-        load_swin_t5_model_components()  # Attempt to load again if missing
-        if not all([swin_t5_model, swin_t5_tokenizer, transform]):
-            raise RuntimeError("Swin-T5 model components failed to load.")
-        else:
-            raise RuntimeError("Swin-T5 model components not loaded properly.")
-
+        raise RuntimeError("Swin-T5 model components failed to load.")

     if selected_vlm != "swin_t5_chexpert":
         return "Error: Selected VLM is not supported."

     try:
         image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
-        input_image = transform(image).unsqueeze(0).to(DEVICE)  # Add batch dimension and send to device
+        input_image = transform(image).unsqueeze(0).to(DEVICE)

         # Perform inference
         with torch.no_grad():
-            swin_outputs = swin_t5_model.swin(input_image)
+            swin_outputs = swin_t5_model.swin(input_image, return_dict=True)
             img_feats = swin_outputs.last_hidden_state
             img_feats_proj = swin_t5_model.img_proj(img_feats)
             encoder_outputs = BaseModelOutput(last_hidden_state=img_feats_proj)
@@ -172,60 +183,52 @@ def generate_report(image_bytes, selected_vlm, max_length=100):

     except Exception as e:
         print(f"Error during Swin-T5 report generation: {e}")
+        print(traceback.format_exc())
         return f"Error generating report: {e}"

-# --- Chat Function (Llama 3.1) ---
+# --- Chat Function (Llama) ---
 def generate_chat_response(question, report_context, max_new_tokens=250):
     """Generates a chat response using Llama based on the report context."""
     global llama_model, llama_tokenizer
-    if not llama_model or not llama_tokenizer:
+    if llama_model is None or llama_tokenizer is None:
         return "Chatbot is currently unavailable."

-    # System prompt to guide the LLM
     system_prompt = "You are a helpful medical assistant. I'm a medical student, your task is to help me understand the following report."
-    # Construct the prompt using the chat template
-    messages = [
-        {"role": "system", "content": system_prompt},
-        {"role": "user", "content": f"Based on the following report:\n\n---\n{report_context}\n---\n\nPlease answer this question: {question}"}
-    ]
+    prompt = (f"{system_prompt}\n\nBased on the following report:\n\n---\n{report_context}\n---\n\n"
+              f"Please answer this question: {question}\n")

-    # Prepare input for the model
     try:
-        # Use the tokenizer's chat template
-        input_ids = llama_tokenizer.apply_chat_template(
-            messages,
-            add_generation_prompt=True,
-            return_tensors="pt"
-        ).to(llama_model.device)  # Move input IDs to the same device as the model
-
-        # Set terminators for generation
-        # Common terminators for Llama 3 Instruct
-        terminators = [
-            llama_tokenizer.eos_token_id,
-            llama_tokenizer.convert_tokens_to_ids("<|eot_id|>")
-        ]
+        # Tokenize and move to model device
+        inputs = llama_tokenizer(prompt, return_tensors="pt", truncation=True)
+        input_ids = inputs["input_ids"].to(next(llama_model.parameters()).device)
+        attention_mask = inputs.get("attention_mask", None)
+        if attention_mask is not None:
+            attention_mask = attention_mask.to(input_ids.device)

         with torch.no_grad():
             outputs = llama_model.generate(
-                input_ids,
+                input_ids=input_ids,
+                attention_mask=attention_mask,
                 max_new_tokens=max_new_tokens,
-                eos_token_id=terminators,
-                do_sample=True,  # Use sampling for more natural responses
+                eos_token_id=llama_tokenizer.eos_token_id,
+                do_sample=True,
                 temperature=0.6,
                 top_p=0.9,
-                pad_token_id=llama_tokenizer.eos_token_id  # Avoid warning, set pad_token_id
+                pad_token_id=llama_tokenizer.eos_token_id
             )

-        # Decode the response, skipping the input prompt part
-        response_ids = outputs[0][input_ids.shape[-1]:]
-        response_text = llama_tokenizer.decode(response_ids, skip_special_tokens=True)
-        return response_text.strip()
+        # Returned outputs: (batch, seq_len). We want the newly generated part after the prompt.
+        generated = outputs[0]
+        # Remove input prompt tokens to keep only the response
+        response_ids = generated[input_ids.shape[-1]:]
+        response_text = llama_tokenizer.decode(response_ids, skip_special_tokens=True).strip()
+        return response_text

     except Exception as e:
         print(f"Error during Llama chat generation: {e}")
+        print(traceback.format_exc())
         return f"Error generating chat response: {e}"

-
 # --- Flask Application Setup ---
 app = Flask(__name__)
 app.secret_key = os.urandom(24)
@@ -234,18 +237,18 @@ app.secret_key = os.urandom(24)
 print("Loading models on application startup...")
 try:
     load_swin_t5_model_components()
-    load_llama_model_components()  # Load Llama
+    load_llama_model_components()
     print("Model loading complete.")
 except Exception as e:
     print(f"FATAL ERROR during model loading: {e}")
-    # Depending on requirements, you might want to exit or continue with limited functionality
-    # exit(1)  # Example: Exit if models are critical
+    print(traceback.format_exc())
+    # Continue with limited functionality (report generation may fail if swin-t5 didn't load)
+    # Optionally: exit(1)

 def allowed_file(filename):
     return '.' in filename and \
            filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS

-# ---- NEW: Function to Parse Filename ----
 def parse_patient_info(filename):
     """
     Parses a filename like '00069-34-Frontal-AP-63.0-Male-White.png'
@@ -255,56 +258,49 @@ def parse_patient_info(filename):
     try:
         base_name = os.path.splitext(filename)[0]
         parts = base_name.split('-')
-        # Expected structure based on example: ... - ViewPart1 - ViewPartN - Age - Gender - Ethnicity
-        if len(parts) < 5:  # Need at least initial parts, age, gender, ethnicity
+        if len(parts) < 5:
             print(f"Warning: Filename '{filename}' has fewer parts than expected.")
             return None

         ethnicity = parts[-1]
         gender = parts[-2]
         age_str = parts[-3]
-        # Handle potential '.0' in age and convert to int
         try:
             age = int(float(age_str))
         except ValueError:
             print(f"Warning: Could not parse age '{age_str}' from filename '{filename}'.")
-            return None  # Or set age to None/default
+            return None

-        # Assume view is everything between the second part (index 1) and the age part (index -3)
         view_parts = parts[2:-3]
-        view = '-'.join(view_parts) if view_parts else "Unknown"  # Handle cases with missing view
+        view = '-'.join(view_parts) if view_parts else "Unknown"

-        # Basic validation
-        if gender.lower() not in ['male', 'female', 'other', 'unknown']:  # Be flexible
-            print(f"Warning: Unusual gender '{gender}' found in filename '{filename}'.")
-            # Decide whether to return None or keep it
+        if gender.lower() not in ['male', 'female', 'other', 'unknown']:
+            print(f"Warning: Unusual gender '{gender}' found in filename '{filename}'.")

         return {
             'view': view,
             'age': age,
-            'gender': gender.capitalize(),  # Capitalize for display
-            'ethnicity': ethnicity.capitalize()  # Capitalize for display
+            'gender': gender.capitalize(),
+            'ethnicity': ethnicity.capitalize()
         }
     except IndexError:
         print(f"Error parsing filename '{filename}': Index out of bounds.")
         return None
     except Exception as e:
         print(f"Error parsing filename '{filename}': {e}")
+        print(traceback.format_exc())
         return None

 # --- Routes ---
-
 @app.route('/', methods=['GET'])
 def index():
-    """Renders the main page."""
     chatbot_available = bool(llama_model and llama_tokenizer)
     return render_template('index.html', chatbot_available=chatbot_available)

 @app.route('/predict', methods=['POST'])
 def predict():
-    """Handles image upload and prediction."""
-    chatbot_available = bool(llama_model and llama_tokenizer)  # Check again
-    patient_info = None  # Initialize patient_info
+    chatbot_available = bool(llama_model and llama_tokenizer)
+    patient_info = None

     if 'image' not in request.files:
         flash('No image file part in the request.', 'danger')
@@ -317,8 +313,8 @@ def predict():
         if not (10 <= max_length <= 512):
             raise ValueError("Max length must be between 10 and 512.")
     except ValueError as e:
-        flash(f'Invalid Max Length value: {e}', 'danger')
-        return redirect(url_for('index'))
+        flash(f'Invalid Max Length value: {e}', 'danger')
+        return redirect(url_for('index'))

     if file.filename == '':
         flash('No image selected for uploading.', 'warning')
@@ -328,46 +324,39 @@ def predict():
         try:
             image_bytes = file.read()

-            # ---- ADDED: Parse filename ----
             original_filename = file.filename
             patient_info = parse_patient_info(original_filename)
             if patient_info:
                 print(f"Parsed Patient Info: {patient_info}")
             else:
                 print(f"Could not parse patient info from filename: {original_filename}")
-            # ---- END ADDED ----

-            # Generate report using Swin-T5
             report = generate_report(image_bytes, vlm_choice, max_length)

-            # Check for errors in report generation
-            if report.startswith("Error"):
-                flash(f'Report Generation Failed: {report}', 'danger')
-                # Still render with image if possible, but show error
-                image_data = base64.b64encode(image_bytes).decode('utf-8')
-                return render_template('index.html',
-                                       report=None,  # Or pass the error message
-                                       image_data=image_data,
-                                       patient_info=patient_info,  # Pass parsed info even if report failed
-                                       chatbot_available=chatbot_available)
-
+            if isinstance(report, str) and report.startswith("Error"):
+                flash(f'Report Generation Failed: {report}', 'danger')
+                image_data = base64.b64encode(image_bytes).decode('utf-8')
+                return render_template('index.html',
+                                       report=None,
+                                       image_data=image_data,
+                                       patient_info=patient_info,
+                                       chatbot_available=chatbot_available)

             image_data = base64.b64encode(image_bytes).decode('utf-8')

-            # Render the page with results AND the report text for JS/Chat
             return render_template('index.html',
                                    report=report,
                                    image_data=image_data,
-                                   patient_info=patient_info,  # Pass the parsed info
-                                   chatbot_available=chatbot_available)  # Pass availability again
+                                   patient_info=patient_info,
+                                   chatbot_available=chatbot_available)

         except FileNotFoundError as fnf_error:
-            flash(f'Model file not found: {fnf_error}. Please check server configuration.', 'danger')
-            print(f"Model file error: {fnf_error}\n{traceback.format_exc()}")
-            return redirect(url_for('index'))
+            flash(f'Model file not found: {fnf_error}. Please check server configuration.', 'danger')
+            print(f"Model file error: {fnf_error}\n{traceback.format_exc()}")
+            return redirect(url_for('index'))
         except RuntimeError as rt_error:
             flash(f'Model loading error: {rt_error}. Please check server logs.', 'danger')
-            print(f"Runtime error during prediction (model loading?): {rt_error}\n{traceback.format_exc()}")
+            print(f"Runtime error during prediction: {rt_error}\n{traceback.format_exc()}")
             return redirect(url_for('index'))
         except Exception as e:
             flash(f'An unexpected error occurred during prediction: {e}', 'danger')
@@ -377,12 +366,10 @@ def predict():
         flash('Invalid image file type. Allowed types: png, jpg, jpeg.', 'danger')
         return redirect(url_for('index'))

-# --- New Chat Endpoint ---
 @app.route('/chat', methods=['POST'])
 def chat():
-    """Handles chat requests based on the generated report."""
     if not llama_model or not llama_tokenizer:
-        return jsonify({"answer": "Chatbot is not available."}), 503  # Service unavailable
+        return jsonify({"answer": "Chatbot is not available."}), 503

     data = request.get_json()
     if not data or 'question' not in data or 'report_context' not in data:
@@ -396,8 +383,8 @@ def chat():
         return jsonify({"answer": answer})
     except Exception as e:
         print(f"Error in /chat endpoint: {e}")
+        print(traceback.format_exc())
         return jsonify({"error": "Failed to generate chat response"}), 500

 if __name__ == '__main__':
-    # Make sure to set debug=False for production/sharing
-    app.run(host='0.0.0.0', port=5000, debug=False)
+    app.run(host='0.0.0.0', port=5000, debug=False)