"""Model loader for the LLM backend.

Selects the primary MedGemma model when a GPU is available (loaded with 4-bit NF4
quantization via bitsandbytes) and a smaller fallback model when running on CPU,
caches the loaded tokenizer/model, and exposes a chat-style generation helper.
"""

from __future__ import annotations

import os
import time
from functools import lru_cache
from typing import Dict, List, Tuple

import torch
import transformers
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)

print("[init]", "torch", torch.__version__, "transformers", transformers.__version__)

# Hugging Face cache directory: prefer HF_HOME, then the legacy TRANSFORMERS_CACHE,
# then a project-local default.
HF_CACHE = os.environ.get("HF_HOME") or os.environ.get("TRANSFORMERS_CACHE") or "/data/econsult/hf_cache"


def _resolve_model_ids() -> Tuple[str, str]:
    """
    Resolve primary/fallback with precedence:
    - Primary: Model_ID > MODEL_ID > MODEL_PRIMARY_ID > default
    - Fallback: Model_Fallback_ID > MODEL_FALLBACK_ID > default
    """
    env = os.environ
    primary = (
        env.get("Model_ID") or
        env.get("MODEL_ID") or
        env.get("MODEL_PRIMARY_ID") or
        "google/medgemma-27b-text-it"
    )
    fallback = (
        env.get("Model_Fallback_ID") or
        env.get("MODEL_FALLBACK_ID") or
        "google/medgemma-4b-it"
    )
    return primary.strip(), fallback.strip()


def _force_cpu() -> bool:
    return str(os.environ.get("FORCE_CPU_LLM", "")).strip().lower() in {"1", "true", "yes"}


def _pick_device_and_quant() -> Dict[str, object]:
    if torch.cuda.is_available() and not _force_cpu():
        quant = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        return {"device_map": "auto", "quantization_config": quant, "torch_dtype": torch.bfloat16}
    return {"device_map": {"": "cpu"}, "torch_dtype": torch.float32}
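

# Runtime configuration read from the environment (illustrative shell syntax;
# see _resolve_model_ids() for the full precedence order):
#   export MODEL_ID="google/medgemma-27b-text-it"      # primary model override
#   export MODEL_FALLBACK_ID="google/medgemma-4b-it"   # fallback model override
#   export FORCE_CPU_LLM=1                             # force the CPU path even if CUDA is available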


def _select_runtime_model_id() -> Tuple[str, bool, str]:
    """
    Returns (selected_model_id, is_fallback, device_label)
    device_label in {"GPU", "CPU"}
    """
    primary, fallback = _resolve_model_ids()
    on_gpu = torch.cuda.is_available() and not _force_cpu()
    if on_gpu:
        return primary, False, "GPU"
    return fallback, True, "CPU"


@lru_cache(maxsize=1)
def _load_tokenizer(model_id: str):
    print(f"[model_loader] Loading tokenizer: {model_id} (cache={HF_CACHE})")
    tok = AutoTokenizer.from_pretrained(model_id, use_fast=True, cache_dir=HF_CACHE)
    # Some causal LMs ship without a pad token; reuse EOS so padding works.
    if tok.pad_token is None and tok.eos_token is not None:
        tok.pad_token = tok.eos_token
    return tok


@lru_cache(maxsize=1)
def _load_model(model_id: str, use_quant: bool):
    # use_quant=True enables the GPU path (device_map="auto" + 4-bit quantization);
    # otherwise the model is loaded in full float32 on CPU.
    device_kwargs = _pick_device_and_quant() if use_quant else {"device_map": {"": "cpu"}, "torch_dtype": torch.float32}
    print(f"[model_loader] Loading model: {model_id} | device_kwargs={list(device_kwargs.keys())}")
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
        cache_dir=HF_CACHE,
        **device_kwargs,
    )
    model.eval()
    return model


def active_model_status() -> Dict[str, str | bool]:
    primary, fallback = _resolve_model_ids()
    selected, is_fallback, device = _select_runtime_model_id()
    forced = _force_cpu()
    return {
        "primary_id": primary,
        "fallback_id": fallback,
        "selected_id": selected,
        "device": device,
        "is_fallback": bool(is_fallback or (device == "CPU")),
        "forced_cpu": forced,
    }
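
# Illustrative status payload on a CUDA host with the default model IDs and
# FORCE_CPU_LLM unset (values follow directly from the logic above):
# {"primary_id": "google/medgemma-27b-text-it", "fallback_id": "google/medgemma-4b-it",
#  "selected_id": "google/medgemma-27b-text-it", "device": "GPU",
#  "is_fallback": False, "forced_cpu": False}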


def generate_chat(
    messages: List[Dict[str, str]],
    *,
    max_new_tokens: int = 700,
    temperature: float = 0.2,
    top_p: float = 0.95,
) -> str:
    selected_id, is_fallback, device = _select_runtime_model_id()
    tok = _load_tokenizer(selected_id)
    model = _load_model(selected_id, use_quant=(device == "GPU"))

    # Build a plain-text prompt: the first system message (if any) as a preamble,
    # followed by the user/assistant turns. Note this does not use the tokenizer's
    # chat template.
    sys_msgs = [m["content"] for m in messages if m.get("role") == "system"]
    turns = []
    for m in messages:
        if m.get("role") == "user":
            turns.append(f"User: {m['content']}")
        elif m.get("role") == "assistant":
            turns.append(f"Assistant: {m['content']}")
    prompt = (sys_msgs[0] + "\n\n" if sys_msgs else "") + "\n".join(turns) + "\nAssistant:"

    inputs = tok(prompt, return_tensors="pt").to(model.device)

    gen_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        pad_token_id=tok.eos_token_id,
        eos_token_id=tok.eos_token_id,
    )

    t0 = time.time()
    with torch.no_grad():
        out = model.generate(**inputs, **gen_kwargs)
    dt = time.time() - t0

    # Decode only the newly generated tokens; slicing the decoded string by
    # len(prompt) is fragile because detokenization does not always round-trip
    # the prompt text exactly.
    prompt_len = inputs["input_ids"].shape[-1]
    new_tokens = out[0][prompt_len:]
    generated = tok.decode(new_tokens, skip_special_tokens=True).strip()

    print(f"[model_loader] Generated {new_tokens.shape[-1]} tokens in {dt:.2f}s (temp={temperature}, top_p={top_p}) | {selected_id} on {device} | fallback={is_fallback}")
    return generated
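

# Optional, minimal smoke test added for illustration; it is not part of the
# module's public surface. It assumes the selected model is reachable (network
# access or a warm HF_CACHE) and prints the runtime status plus a short
# generation for a trivial prompt.
if __name__ == "__main__":
    print("[model_loader] status:", active_model_status())
    demo_messages = [
        {"role": "system", "content": "You are a concise clinical assistant."},
        {"role": "user", "content": "In one sentence, what does this module do?"},
    ]
    print(generate_chat(demo_messages, max_new_tokens=64))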