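"""Flask web app for chest X-ray report generation and follow-up Q&A.

A Swin Transformer + T5 model generates a radiology-style report for an
uploaded image, and a Llama instruct model answers questions about the
generated report.
"""
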
import os
import io
import base64
import traceback

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, request, render_template, flash, redirect, url_for, jsonify
from dotenv import load_dotenv

from transformers import (
    AutoModel,
    AutoImageProcessor,
    T5ForConditionalGeneration,
    T5Tokenizer,
    AutoModelForCausalLM,
    AutoTokenizer,
)
from transformers.modeling_outputs import BaseModelOutput

load_dotenv()

MODEL_PATH = '/cluster/home/ammaa/Downloads/Projects/CheXpert-Report-Generation/swin-t5-model.pth'
SWIN_MODEL_NAME = "microsoft/swin-base-patch4-window7-224"
T5_MODEL_NAME = "t5-base"
LLAMA_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"
HF_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN")

if not HF_TOKEN:
    print("Warning: HUGGING_FACE_HUB_TOKEN environment variable not set. Llama model download might fail.")

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
UPLOAD_FOLDER = 'uploads'
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}

class ImageCaptioningModel(nn.Module):
    """Swin Transformer encoder + T5 decoder for chest X-ray report generation.

    Swin image features are linearly projected to the T5 hidden size and fed
    to T5 as precomputed encoder outputs.
    """

    def __init__(self,
                 swin_model_name=SWIN_MODEL_NAME,
                 t5_model_name=T5_MODEL_NAME):
        super().__init__()

        self.swin = AutoModel.from_pretrained(swin_model_name)
        self.t5 = T5ForConditionalGeneration.from_pretrained(t5_model_name)

        # Project Swin hidden states into the T5 embedding space.
        self.img_proj = nn.Linear(self.swin.config.hidden_size, self.t5.config.d_model)

    def forward(self, images, labels=None):
        swin_outputs = self.swin(images, return_dict=True)
        img_feats = swin_outputs.last_hidden_state
        img_feats_proj = self.img_proj(img_feats)
        # Wrap the projected features so T5 treats them as its encoder output.
        encoder_outputs = BaseModelOutput(last_hidden_state=img_feats_proj)
        if labels is not None:
            outputs = self.t5(encoder_outputs=encoder_outputs, labels=labels)
        else:
            outputs = self.t5(encoder_outputs=encoder_outputs)
        return outputs

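# Module-level handles for the loaded model components; they are populated by
# the load_* functions below and checked before use.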
swin_t5_model = None
swin_t5_tokenizer = None
transform = None
llama_model = None
llama_tokenizer = None

def load_swin_t5_model_components():
    """Loads the Swin-T5 model, tokenizer, and transformation pipeline."""
    global swin_t5_model, swin_t5_tokenizer, transform
    try:
        print(f"Loading Swin-T5 model components on device: {DEVICE}")

        swin_t5_model = ImageCaptioningModel(swin_model_name=SWIN_MODEL_NAME, t5_model_name=T5_MODEL_NAME)

        if not os.path.exists(MODEL_PATH):
            raise FileNotFoundError(f"Swin-T5 Model file not found at {MODEL_PATH}.")

        state = torch.load(MODEL_PATH, map_location=DEVICE)

        # Accept either a full training checkpoint ({"model_state_dict": ...})
        # or a bare state dict.
        if isinstance(state, dict) and "model_state_dict" in state and len(state) > 1:
            swin_t5_model.load_state_dict(state["model_state_dict"])
        else:
            swin_t5_model.load_state_dict(state)

        swin_t5_model.to(DEVICE)
        swin_t5_model.eval()
        print("Swin-T5 Model loaded successfully.")

        swin_t5_tokenizer = T5Tokenizer.from_pretrained(T5_MODEL_NAME)
        print("Swin-T5 Tokenizer loaded successfully.")

        # Standard ImageNet preprocessing expected by the Swin backbone.
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        print("Transforms defined.")

    except Exception as e:
        print(f"Error loading Swin-T5 model components: {e}")
        print(traceback.format_exc())
        raise

def load_llama_model_components():
    """Loads the Llama model and tokenizer."""
    global llama_model, llama_tokenizer
    if not HF_TOKEN:
        print("Skipping Llama model load: Hugging Face token not found.")
        return

    try:
        print(f"Loading Llama model ({LLAMA_MODEL_NAME}) components...")

        # Prefer bfloat16 on GPUs that support it, fall back to float16,
        # and use float32 on CPU.
        if torch.cuda.is_available():
            try:
                torch_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
            except Exception:
                torch_dtype = torch.float16
        else:
            torch_dtype = torch.float32

        llama_tokenizer = AutoTokenizer.from_pretrained(LLAMA_MODEL_NAME, token=HF_TOKEN)
        llama_model = AutoModelForCausalLM.from_pretrained(
            LLAMA_MODEL_NAME,
            torch_dtype=torch_dtype,
            device_map="auto",
            token=HF_TOKEN
        )
        llama_model.eval()
        print("Llama Model and Tokenizer loaded successfully.")
    except Exception as e:
        print(f"Error loading Llama model components: {e}")
        print(traceback.format_exc())
        llama_model = None
        llama_tokenizer = None
        print("WARNING: Chatbot functionality will be disabled due to loading error.")

def generate_report(image_bytes, selected_vlm, max_length=100):
    """Generates a report/caption for the given image bytes using Swin-T5."""
    global swin_t5_model, swin_t5_tokenizer, transform

    if swin_t5_model is None or swin_t5_tokenizer is None or transform is None:
        load_swin_t5_model_components()
        if swin_t5_model is None or swin_t5_tokenizer is None or transform is None:
            raise RuntimeError("Swin-T5 model components failed to load.")

    if selected_vlm != "swin_t5_chexpert":
        return "Error: Selected VLM is not supported."

    try:
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        input_image = transform(image).unsqueeze(0).to(DEVICE)

        with torch.no_grad():
            swin_outputs = swin_t5_model.swin(input_image, return_dict=True)
            img_feats = swin_outputs.last_hidden_state
            img_feats_proj = swin_t5_model.img_proj(img_feats)
            encoder_outputs = BaseModelOutput(last_hidden_state=img_feats_proj)

            generated_ids = swin_t5_model.t5.generate(
                encoder_outputs=encoder_outputs,
                max_length=max_length,
                num_beams=4,
                early_stopping=True
            )
        report = swin_t5_tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        return report

    except Exception as e:
        print(f"Error during Swin-T5 report generation: {e}")
        print(traceback.format_exc())
        return f"Error generating report: {e}"

def generate_chat_response(question, report_context, max_new_tokens=250):
    """Generates a chat response using Llama based on the report context."""
    global llama_model, llama_tokenizer
    if llama_model is None or llama_tokenizer is None:
        return "Chatbot is currently unavailable."

    system_prompt = ("You are a helpful medical assistant. The user is a medical student; "
                     "your task is to help them understand the following report.")
    prompt = (f"{system_prompt}\n\nBased on the following report:\n\n---\n{report_context}\n---\n\n"
              f"Please answer this question: {question}\n")

    try:
        inputs = llama_tokenizer(prompt, return_tensors="pt", truncation=True)
        input_ids = inputs["input_ids"].to(next(llama_model.parameters()).device)
        attention_mask = inputs.get("attention_mask", None)
        if attention_mask is not None:
            attention_mask = attention_mask.to(input_ids.device)

        with torch.no_grad():
            outputs = llama_model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                eos_token_id=llama_tokenizer.eos_token_id,
                do_sample=True,
                temperature=0.6,
                top_p=0.9,
                pad_token_id=llama_tokenizer.eos_token_id
            )

        # Decode only the newly generated tokens, not the echoed prompt.
        generated = outputs[0]
        response_ids = generated[input_ids.shape[-1]:]
        response_text = llama_tokenizer.decode(response_ids, skip_special_tokens=True).strip()
        return response_text

    except Exception as e:
        print(f"Error during Llama chat generation: {e}")
        print(traceback.format_exc())
        return f"Error generating chat response: {e}"

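# Note: the prompt above is passed to the Llama instruct model as plain text.
# A minimal alternative sketch (an assumption, not part of the original code)
# would build the inputs with the tokenizer's chat template instead, e.g.:
#
#     messages = [
#         {"role": "system", "content": system_prompt},
#         {"role": "user", "content": f"Report:\n{report_context}\n\nQuestion: {question}"},
#     ]
#     input_ids = llama_tokenizer.apply_chat_template(
#         messages, add_generation_prompt=True, return_tensors="pt"
#     ).to(llama_model.device)
#
# apply_chat_template is available in recent transformers releases; the
# decoding logic above would stay the same.
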

app = Flask(__name__)
app.secret_key = os.urandom(24)

print("Loading models on application startup...")
try:
    load_swin_t5_model_components()
    load_llama_model_components()
    print("Model loading complete.")
except Exception as e:
    print(f"FATAL ERROR during model loading: {e}")
    print(traceback.format_exc())

def allowed_file(filename):
    return '.' in filename and \
           filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS


def parse_patient_info(filename):
    """
    Parses a filename like '00069-34-Frontal-AP-63.0-Male-White.png'
    into a dictionary with 'view', 'age', 'gender', and 'ethnicity', e.g.
    {'view': 'Frontal-AP', 'age': 63, 'gender': 'Male', 'ethnicity': 'White'}.
    Returns None if parsing fails.
    """
    try:
        base_name = os.path.splitext(filename)[0]
        parts = base_name.split('-')
        if len(parts) < 5:
            print(f"Warning: Filename '{filename}' has fewer parts than expected.")
            return None

        ethnicity = parts[-1]
        gender = parts[-2]
        age_str = parts[-3]
        try:
            age = int(float(age_str))
        except ValueError:
            print(f"Warning: Could not parse age '{age_str}' from filename '{filename}'.")
            return None

        # Everything between the two leading identifiers and the trailing
        # age/gender/ethnicity fields is treated as the view (e.g. 'Frontal-AP').
        view_parts = parts[2:-3]
        view = '-'.join(view_parts) if view_parts else "Unknown"

        if gender.lower() not in ['male', 'female', 'other', 'unknown']:
            print(f"Warning: Unusual gender '{gender}' found in filename '{filename}'.")

        return {
            'view': view,
            'age': age,
            'gender': gender.capitalize(),
            'ethnicity': ethnicity.capitalize()
        }
    except IndexError:
        print(f"Error parsing filename '{filename}': Index out of bounds.")
        return None
    except Exception as e:
        print(f"Error parsing filename '{filename}': {e}")
        print(traceback.format_exc())
        return None

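# Example requests against the routes below (assuming the default host/port
# configured at the bottom of this file):
#
#   # Generate a report from an uploaded image:
#   curl -F "image=@00069-34-Frontal-AP-63.0-Male-White.png" \
#        -F "vlm_choice=swin_t5_chexpert" -F "max_length=100" \
#        http://localhost:5000/predict
#
#   # Ask the chatbot a question about a generated report:
#   curl -X POST -H "Content-Type: application/json" \
#        -d '{"question": "Is there evidence of edema?", "report_context": "..."}' \
#        http://localhost:5000/chat
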
@app.route('/', methods=['GET'])
def index():
    chatbot_available = bool(llama_model and llama_tokenizer)
    return render_template('index.html', chatbot_available=chatbot_available)


@app.route('/predict', methods=['POST'])
def predict():
    chatbot_available = bool(llama_model and llama_tokenizer)
    patient_info = None

    if 'image' not in request.files:
        flash('No image file part in the request.', 'danger')
        return redirect(url_for('index'))

    file = request.files['image']
    vlm_choice = request.form.get('vlm_choice', 'swin_t5_chexpert')
    try:
        max_length = int(request.form.get('max_length', 100))
        if not (10 <= max_length <= 512):
            raise ValueError("Max length must be between 10 and 512.")
    except ValueError as e:
        flash(f'Invalid Max Length value: {e}', 'danger')
        return redirect(url_for('index'))

    if file.filename == '':
        flash('No image selected for uploading.', 'warning')
        return redirect(url_for('index'))

    if file and allowed_file(file.filename):
        try:
            image_bytes = file.read()

            original_filename = file.filename
            patient_info = parse_patient_info(original_filename)
            if patient_info:
                print(f"Parsed Patient Info: {patient_info}")
            else:
                print(f"Could not parse patient info from filename: {original_filename}")

            report = generate_report(image_bytes, vlm_choice, max_length)

            if isinstance(report, str) and report.startswith("Error"):
                flash(f'Report Generation Failed: {report}', 'danger')
                image_data = base64.b64encode(image_bytes).decode('utf-8')
                return render_template('index.html',
                                       report=None,
                                       image_data=image_data,
                                       patient_info=patient_info,
                                       chatbot_available=chatbot_available)

            image_data = base64.b64encode(image_bytes).decode('utf-8')

            return render_template('index.html',
                                   report=report,
                                   image_data=image_data,
                                   patient_info=patient_info,
                                   chatbot_available=chatbot_available)

        except FileNotFoundError as fnf_error:
            flash(f'Model file not found: {fnf_error}. Please check server configuration.', 'danger')
            print(f"Model file error: {fnf_error}\n{traceback.format_exc()}")
            return redirect(url_for('index'))
        except RuntimeError as rt_error:
            flash(f'Model loading error: {rt_error}. Please check server logs.', 'danger')
            print(f"Runtime error during prediction: {rt_error}\n{traceback.format_exc()}")
            return redirect(url_for('index'))
        except Exception as e:
            flash(f'An unexpected error occurred during prediction: {e}', 'danger')
            print(f"Error during prediction: {e}\n{traceback.format_exc()}")
            return redirect(url_for('index'))
    else:
        flash('Invalid image file type. Allowed types: png, jpg, jpeg.', 'danger')
        return redirect(url_for('index'))

@app.route('/chat', methods=['POST'])
def chat():
    if not llama_model or not llama_tokenizer:
        return jsonify({"answer": "Chatbot is not available."}), 503

    data = request.get_json()
    if not data or 'question' not in data or 'report_context' not in data:
        return jsonify({"error": "Missing question or report context"}), 400

    question = data['question']
    report_context = data['report_context']

    try:
        answer = generate_chat_response(question, report_context)
        return jsonify({"answer": answer})
    except Exception as e:
        print(f"Error in /chat endpoint: {e}")
        print(traceback.format_exc())
        return jsonify({"error": "Failed to generate chat response"}), 500


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000, debug=False)
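# For deployment beyond Flask's built-in server, a single-worker WSGI command
# such as `gunicorn -w 1 -b 0.0.0.0:5000 app:app` (assuming this file is saved
# as app.py) keeps only one copy of the models in memory.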