# src/model_loader.py
from __future__ import annotations
import os
import time
from functools import lru_cache
from typing import Dict, List, Tuple
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
)
# --- Diagnostic print to confirm runtime versions ---
import transformers
print("[init]", "torch", torch.__version__, "transformers", transformers.__version__)
HF_CACHE = os.environ.get("HF_HOME") or os.environ.get("TRANSFORMERS_CACHE") or "/data/econsult/hf_cache"
# -------------------- Env normalization --------------------
def _resolve_model_ids() -> Tuple[str, str]:
"""
Resolve primary/fallback with precedence:
- Primary: Model_ID > MODEL_ID > MODEL_PRIMARY_ID > default
- Fallback: Model_Fallback_ID > MODEL_FALLBACK_ID > default
"""
env = os.environ
primary = (
env.get("Model_ID") or
env.get("MODEL_ID") or
env.get("MODEL_PRIMARY_ID") or
"google/medgemma-27b-text-it"
)
fallback = (
env.get("Model_Fallback_ID") or
env.get("MODEL_FALLBACK_ID") or
"google/medgemma-4b-it"
)
return primary.strip(), fallback.strip()
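
# Illustrative resolution (hypothetical env): with MODEL_ID="google/medgemma-27b-text-it"
# set and no fallback variables present, _resolve_model_ids() returns
# ("google/medgemma-27b-text-it", "google/medgemma-4b-it"); with no variables set at all,
# both defaults above are used.
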
def _force_cpu() -> bool:
return str(os.environ.get("FORCE_CPU_LLM", "")).strip().lower() in {"1", "true", "yes"}
# -------------------- Device & model selection --------------------
def _pick_device_and_quant() -> Dict[str, object]:
if torch.cuda.is_available() and not _force_cpu():
quant = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
return {"device_map": "auto", "quantization_config": quant, "torch_dtype": torch.bfloat16}
# CPU path
return {"device_map": {"": "cpu"}, "torch_dtype": torch.float32}
def _select_runtime_model_id() -> Tuple[str, bool, str]:
"""
Returns (selected_model_id, is_fallback, device_label)
device_label in {"GPU","CPU"}
"""
primary, fallback = _resolve_model_ids()
on_gpu = torch.cuda.is_available() and not _force_cpu()
if on_gpu:
return primary, False, "GPU"
return fallback, True, "CPU"
@lru_cache(maxsize=1)
def _load_tokenizer(model_id: str):
print(f"[model_loader] Loading tokenizer: {model_id} (cache={HF_CACHE})")
tok = AutoTokenizer.from_pretrained(model_id, use_fast=True, cache_dir=HF_CACHE)
if tok.pad_token is None and tok.eos_token is not None:
tok.pad_token = tok.eos_token
return tok
@lru_cache(maxsize=1)
def _load_model(model_id: str, use_quant: bool):
device_kwargs = _pick_device_and_quant() if use_quant else {"device_map": {"": "cpu"}, "torch_dtype": torch.float32}
print(f"[model_loader] Loading model: {model_id} | device_kwargs={list(device_kwargs.keys())}")
model = AutoModelForCausalLM.from_pretrained(
model_id,
low_cpu_mem_usage=True,
trust_remote_code=True,
cache_dir=HF_CACHE,
**device_kwargs,
)
model.eval()
return model
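
# Note: the 4-bit quantized path above requires the bitsandbytes package and a CUDA GPU;
# on the CPU path the full-precision (float32) weights are loaded instead.
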
# -------------------- Public helpers --------------------
def active_model_status() -> Dict[str, str | bool]:
primary, fallback = _resolve_model_ids()
selected, is_fallback, device = _select_runtime_model_id()
forced = _force_cpu()
return {
"primary_id": primary,
"fallback_id": fallback,
"selected_id": selected,
"device": device,
"is_fallback": bool(is_fallback or (device == "CPU")),
"forced_cpu": forced,
}
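
# Example payload from active_model_status() (illustrative values for a CPU-only run):
#   {"primary_id": "google/medgemma-27b-text-it", "fallback_id": "google/medgemma-4b-it",
#    "selected_id": "google/medgemma-4b-it", "device": "CPU",
#    "is_fallback": True, "forced_cpu": False}
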
def generate_chat(
messages: List[Dict[str, str]],
*,
max_new_tokens: int = 700,
temperature: float = 0.2,
top_p: float = 0.95,
) -> str:
selected_id, is_fallback, device = _select_runtime_model_id()
tok = _load_tokenizer(selected_id)
model = _load_model(selected_id, use_quant=(device == "GPU"))
    # Very simple plain-text chat prompt for the instruction-tuned (-it) models;
    # the tokenizer's chat template is not applied here.
sys_msgs = [m["content"] for m in messages if m.get("role") == "system"]
turns = []
for m in messages:
if m.get("role") == "user":
turns.append(f"User: {m['content']}")
elif m.get("role") == "assistant":
turns.append(f"Assistant: {m['content']}")
prompt = (sys_msgs[0] + "\n\n" if sys_msgs else "") + "\n".join(turns) + "\nAssistant:"
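    # The assembled prompt looks roughly like (illustrative):
    #   <system message>
    #
    #   User: <question>
    #   Assistant: <prior reply>
    #   User: <follow-up>
    #   Assistant: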
inputs = tok(prompt, return_tensors="pt").to(model.device)
gen_kwargs = dict(
max_new_tokens=max_new_tokens,
do_sample=True,
temperature=temperature,
top_p=top_p,
pad_token_id=tok.eos_token_id,
eos_token_id=tok.eos_token_id,
)
t0 = time.time()
with torch.no_grad():
out = model.generate(**inputs, **gen_kwargs)
dt = time.time() - t0
    # Decode only the newly generated tokens; slicing the decoded string by
    # len(prompt) can misalign once special tokens are stripped during decoding.
    prompt_len = inputs["input_ids"].shape[-1]
    generated = tok.decode(out[0][prompt_len:], skip_special_tokens=True).strip()
print(f"[model_loader] Generated <= {max_new_tokens} tokens in {dt:.2f}s (temp={temperature}, top_p={top_p}) | {selected_id} on {device} | fallback={is_fallback}")
return generated
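

if __name__ == "__main__":
    # Minimal manual smoke test (illustrative sketch; assumes the selected model's
    # weights are already cached under HF_CACHE or can be downloaded). The message
    # contents below are placeholders, not part of the application.
    print(active_model_status())
    demo_messages = [
        {"role": "system", "content": "You are a clinical e-consult assistant."},
        {"role": "user", "content": "Briefly list red flags that warrant urgent cardiology referral for chest pain."},
    ]
    print(generate_chat(demo_messages, max_new_tokens=128))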