import os import io import base64 import traceback import torch import torch.nn as nn import torchvision.transforms as transforms from PIL import Image from flask import Flask, request, render_template, flash, redirect, url_for, jsonify from dotenv import load_dotenv # Use the auto classes to avoid version-specific direct imports from transformers import ( AutoModel, # used for vision (Swin) AutoImageProcessor, # optional: if you want processor instead of torchvision T5ForConditionalGeneration, T5Tokenizer, AutoModelForCausalLM, AutoTokenizer, ) from transformers.modeling_outputs import BaseModelOutput load_dotenv() # Load environment variables from .env file # --- Configuration --- MODEL_PATH = '/cluster/home/ammaa/Downloads/Projects/CheXpert-Report-Generation/swin-t5-model.pth' SWIN_MODEL_NAME = "microsoft/swin-base-patch4-window7-224" T5_MODEL_NAME = "t5-base" LLAMA_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct" HF_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN") # Hugging Face token (optional) if not HF_TOKEN: print("Warning: HUGGING_FACE_HUB_TOKEN environment variable not set. Llama model download might fail.") DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") UPLOAD_FOLDER = 'uploads' ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'} # --- Swin-T5 Model Definition --- class ImageCaptioningModel(nn.Module): def __init__(self, swin_model_name=SWIN_MODEL_NAME, t5_model_name=T5_MODEL_NAME): super().__init__() # Use AutoModel for the vision backbone (works across transformer versions) self.swin = AutoModel.from_pretrained(swin_model_name) self.t5 = T5ForConditionalGeneration.from_pretrained(t5_model_name) # Project swin hidden states to T5 d_model self.img_proj = nn.Linear(self.swin.config.hidden_size, self.t5.config.d_model) def forward(self, images, labels=None): # images: expected shape (batch, channels, height, width) swin_outputs = self.swin(images, return_dict=True) img_feats = swin_outputs.last_hidden_state # (batch, seq_len, hidden) img_feats_proj = self.img_proj(img_feats) # project to T5 d_model encoder_outputs = BaseModelOutput(last_hidden_state=img_feats_proj) if labels is not None: outputs = self.t5(encoder_outputs=encoder_outputs, labels=labels) else: outputs = self.t5(encoder_outputs=encoder_outputs) return outputs # --- Global Variables for Model Components --- swin_t5_model = None swin_t5_tokenizer = None transform = None llama_model = None llama_tokenizer = None def load_swin_t5_model_components(): """Loads the Swin-T5 model, tokenizer, and transformation pipeline.""" global swin_t5_model, swin_t5_tokenizer, transform try: print(f"Loading Swin-T5 model components on device: {DEVICE}") # Initialize model structure swin_t5_model = ImageCaptioningModel(swin_model_name=SWIN_MODEL_NAME, t5_model_name=T5_MODEL_NAME) # Load state dictionary if provided if not os.path.exists(MODEL_PATH): raise FileNotFoundError(f"Swin-T5 Model file not found at {MODEL_PATH}.") # Load state dict into model (map_location ensures correct device) state = torch.load(MODEL_PATH, map_location=DEVICE) # If the saved state is a dict containing model key (common), attempt to pull it if isinstance(state, dict) and "model_state_dict" in state and len(state) > 1: # typical saved checkpoint structure { 'epoch':..., 'model_state_dict':..., ... } swin_t5_model.load_state_dict(state["model_state_dict"]) else: swin_t5_model.load_state_dict(state) swin_t5_model.to(DEVICE) swin_t5_model.eval() print("Swin-T5 Model loaded successfully.") # Load tokenizer for T5 swin_t5_tokenizer = T5Tokenizer.from_pretrained(T5_MODEL_NAME) print("Swin-T5 Tokenizer loaded successfully.") # Define (simple) image transformations transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) print("Transforms defined.") except Exception as e: print(f"Error loading Swin-T5 model components: {e}") print(traceback.format_exc()) # Re-raise so startup knows loading failed (your code caught it) raise def load_llama_model_components(): """Loads the Llama model and tokenizer.""" global llama_model, llama_tokenizer if not HF_TOKEN: print("Skipping Llama model load: Hugging Face token not found.") return try: print(f"Loading Llama model ({LLAMA_MODEL_NAME}) components...") # Choose an appropriate dtype for loading if torch.cuda.is_available(): # prefer bf16 if supported to save memory on modern GPUs try: torch_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 except Exception: torch_dtype = torch.float16 else: torch_dtype = torch.float32 # Use use_auth_token parameter for private models / gated access llama_tokenizer = AutoTokenizer.from_pretrained(LLAMA_MODEL_NAME, use_auth_token=HF_TOKEN) llama_model = AutoModelForCausalLM.from_pretrained( LLAMA_MODEL_NAME, torch_dtype=torch_dtype, device_map="auto", use_auth_token=HF_TOKEN ) llama_model.eval() print("Llama Model and Tokenizer loaded successfully.") except Exception as e: print(f"Error loading Llama model components: {e}") print(traceback.format_exc()) llama_model = None llama_tokenizer = None print("WARNING: Chatbot functionality will be disabled due to loading error.") # --- Inference Function (Swin-T5) --- def generate_report(image_bytes, selected_vlm, max_length=100): """Generates a report/caption for the given image bytes using Swin-T5.""" global swin_t5_model, swin_t5_tokenizer, transform # Ensure components are loaded (attempt to load if missing) if swin_t5_model is None or swin_t5_tokenizer is None or transform is None: load_swin_t5_model_components() if swin_t5_model is None or swin_t5_tokenizer is None or transform is None: raise RuntimeError("Swin-T5 model components failed to load.") if selected_vlm != "swin_t5_chexpert": return "Error: Selected VLM is not supported." try: image = Image.open(io.BytesIO(image_bytes)).convert("RGB") input_image = transform(image).unsqueeze(0).to(DEVICE) # Perform inference with torch.no_grad(): swin_outputs = swin_t5_model.swin(input_image, return_dict=True) img_feats = swin_outputs.last_hidden_state img_feats_proj = swin_t5_model.img_proj(img_feats) encoder_outputs = BaseModelOutput(last_hidden_state=img_feats_proj) generated_ids = swin_t5_model.t5.generate( encoder_outputs=encoder_outputs, max_length=max_length, num_beams=4, early_stopping=True ) report = swin_t5_tokenizer.decode(generated_ids[0], skip_special_tokens=True) return report except Exception as e: print(f"Error during Swin-T5 report generation: {e}") print(traceback.format_exc()) return f"Error generating report: {e}" # --- Chat Function (Llama) --- def generate_chat_response(question, report_context, max_new_tokens=250): """Generates a chat response using Llama based on the report context.""" global llama_model, llama_tokenizer if llama_model is None or llama_tokenizer is None: return "Chatbot is currently unavailable." system_prompt = "You are a helpful medical assistant. I'm a medical student, your task is to help me understand the following report." prompt = (f"{system_prompt}\n\nBased on the following report:\n\n---\n{report_context}\n---\n\n" f"Please answer this question: {question}\n") try: # Tokenize and move to model device inputs = llama_tokenizer(prompt, return_tensors="pt", truncation=True) input_ids = inputs["input_ids"].to(next(llama_model.parameters()).device) attention_mask = inputs.get("attention_mask", None) if attention_mask is not None: attention_mask = attention_mask.to(input_ids.device) with torch.no_grad(): outputs = llama_model.generate( input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens, eos_token_id=llama_tokenizer.eos_token_id, do_sample=True, temperature=0.6, top_p=0.9, pad_token_id=llama_tokenizer.eos_token_id ) # Returned outputs: (batch, seq_len). We want the newly generated part after the prompt. generated = outputs[0] # Remove input prompt tokens to keep only the response response_ids = generated[input_ids.shape[-1]:] response_text = llama_tokenizer.decode(response_ids, skip_special_tokens=True).strip() return response_text except Exception as e: print(f"Error during Llama chat generation: {e}") print(traceback.format_exc()) return f"Error generating chat response: {e}" # --- Flask Application Setup --- app = Flask(__name__) app.secret_key = os.urandom(24) # Load models when the application starts print("Loading models on application startup...") try: load_swin_t5_model_components() load_llama_model_components() print("Model loading complete.") except Exception as e: print(f"FATAL ERROR during model loading: {e}") print(traceback.format_exc()) # Continue with limited functionality (report generation may fail if swin-t5 didn't load) # Optionally: exit(1) def allowed_file(filename): return '.' in filename and \ filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS def parse_patient_info(filename): """ Parses a filename like '00069-34-Frontal-AP-63.0-Male-White.png' Returns a dictionary with 'view', 'age', 'gender', 'ethnicity'. Returns None if parsing fails. """ try: base_name = os.path.splitext(filename)[0] parts = base_name.split('-') if len(parts) < 5: print(f"Warning: Filename '{filename}' has fewer parts than expected.") return None ethnicity = parts[-1] gender = parts[-2] age_str = parts[-3] try: age = int(float(age_str)) except ValueError: print(f"Warning: Could not parse age '{age_str}' from filename '{filename}'.") return None view_parts = parts[2:-3] view = '-'.join(view_parts) if view_parts else "Unknown" if gender.lower() not in ['male', 'female', 'other', 'unknown']: print(f"Warning: Unusual gender '{gender}' found in filename '{filename}'.") return { 'view': view, 'age': age, 'gender': gender.capitalize(), 'ethnicity': ethnicity.capitalize() } except IndexError: print(f"Error parsing filename '{filename}': Index out of bounds.") return None except Exception as e: print(f"Error parsing filename '{filename}': {e}") print(traceback.format_exc()) return None # --- Routes --- @app.route('/', methods=['GET']) def index(): chatbot_available = bool(llama_model and llama_tokenizer) return render_template('index.html', chatbot_available=chatbot_available) @app.route('/predict', methods=['POST']) def predict(): chatbot_available = bool(llama_model and llama_tokenizer) patient_info = None if 'image' not in request.files: flash('No image file part in the request.', 'danger') return redirect(url_for('index')) file = request.files['image'] vlm_choice = request.form.get('vlm_choice', 'swin_t5_chexpert') try: max_length = int(request.form.get('max_length', 100)) if not (10 <= max_length <= 512): raise ValueError("Max length must be between 10 and 512.") except ValueError as e: flash(f'Invalid Max Length value: {e}', 'danger') return redirect(url_for('index')) if file.filename == '': flash('No image selected for uploading.', 'warning') return redirect(url_for('index')) if file and allowed_file(file.filename): try: image_bytes = file.read() original_filename = file.filename patient_info = parse_patient_info(original_filename) if patient_info: print(f"Parsed Patient Info: {patient_info}") else: print(f"Could not parse patient info from filename: {original_filename}") report = generate_report(image_bytes, vlm_choice, max_length) if isinstance(report, str) and report.startswith("Error"): flash(f'Report Generation Failed: {report}', 'danger') image_data = base64.b64encode(image_bytes).decode('utf-8') return render_template('index.html', report=None, image_data=image_data, patient_info=patient_info, chatbot_available=chatbot_available) image_data = base64.b64encode(image_bytes).decode('utf-8') return render_template('index.html', report=report, image_data=image_data, patient_info=patient_info, chatbot_available=chatbot_available) except FileNotFoundError as fnf_error: flash(f'Model file not found: {fnf_error}. Please check server configuration.', 'danger') print(f"Model file error: {fnf_error}\n{traceback.format_exc()}") return redirect(url_for('index')) except RuntimeError as rt_error: flash(f'Model loading error: {rt_error}. Please check server logs.', 'danger') print(f"Runtime error during prediction: {rt_error}\n{traceback.format_exc()}") return redirect(url_for('index')) except Exception as e: flash(f'An unexpected error occurred during prediction: {e}', 'danger') print(f"Error during prediction: {e}\n{traceback.format_exc()}") return redirect(url_for('index')) else: flash('Invalid image file type. Allowed types: png, jpg, jpeg.', 'danger') return redirect(url_for('index')) @app.route('/chat', methods=['POST']) def chat(): if not llama_model or not llama_tokenizer: return jsonify({"answer": "Chatbot is not available."}), 503 data = request.get_json() if not data or 'question' not in data or 'report_context' not in data: return jsonify({"error": "Missing question or report context"}), 400 question = data['question'] report_context = data['report_context'] try: answer = generate_chat_response(question, report_context) return jsonify({"answer": answer}) except Exception as e: print(f"Error in /chat endpoint: {e}") print(traceback.format_exc()) return jsonify({"error": "Failed to generate chat response"}), 500 if __name__ == '__main__': app.run(host='0.0.0.0', port=5000, debug=False)