Cardiosense-AG committed
Commit 4bb50d6 · verified · 1 Parent(s): ced1708

Update src/model_loader.py

Files changed (1)
  1. src/model_loader.py +93 -134
src/model_loader.py CHANGED
@@ -1,186 +1,145 @@
  # src/model_loader.py
  
  from __future__ import annotations
  
- import json
  import os
  import time
  from functools import lru_cache
- from typing import List, Dict, Tuple
  
  import torch
- from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
- 
- try:
-     from transformers import BitsAndBytesConfig
-     _HAS_BNB = True
- except Exception:
-     _HAS_BNB = False
- 
- from .paths import hf_cache_dir
  
- _LOG_PREFIX = "[model_loader]"
  
  
- def _env_flag(name: str) -> bool:
-     return os.getenv(name, "").strip().lower() in {"1", "true", "yes", "on"}
  
- 
- def _select_model() -> Tuple[str, bool, bool]:
-     """
-     Returns (model_id, use_cuda, use_4bit)
-     - Primary when CUDA available
-     - Fallback to CPU model otherwise or when FORCE_CPU_LLM=1
-     """
-     primary = os.getenv("MODEL_ID", "google/medgemma-27b-text-it")
-     fallback = os.getenv("MODEL_FALLBACK_ID", "google/medgemma-4b-it")
-     force_cpu = _env_flag("FORCE_CPU_LLM")
-     quant4 = (os.getenv("QUANT_MODE", "4bit").lower() == "4bit")
- 
-     if not force_cpu and torch.cuda.is_available():
-         print(f"{_LOG_PREFIX} CUDA available. Selecting primary model: {primary} (4-bit={quant4 and _HAS_BNB})")
-         return primary, True, (quant4 and _HAS_BNB)
      else:
-         print(f"{_LOG_PREFIX} Using CPU fallback model: {fallback}")
-         return fallback, False, False
  
  
  @lru_cache(maxsize=1)
  def _load_tokenizer(model_id: str):
-     cache = str(hf_cache_dir())
-     tok = AutoTokenizer.from_pretrained(
-         model_id,
-         cache_dir=cache,
-         use_fast=True,
-         trust_remote_code=True,
-     )
-     if tok.pad_token_id is None and tok.eos_token_id is not None:
-         tok.pad_token = tok.eos_token
-     print(f"{_LOG_PREFIX} Tokenizer loaded: {model_id} | cache={cache}")
      return tok
  
  
- @lru_cache(maxsize=1)
- def _load_model(model_id: str, use_cuda: bool, use_4bit: bool):
-     cache = str(hf_cache_dir())
-     t0 = time.perf_counter()
  
-     if use_cuda and use_4bit and _HAS_BNB:
-         bnb_config = BitsAndBytesConfig(
-             load_in_4bit=True,
-             bnb_4bit_quant_type="nf4",
-             bnb_4bit_compute_dtype=torch.bfloat16,
-             bnb_4bit_use_double_quant=True,
-         )
-         model = AutoModelForCausalLM.from_pretrained(
-             model_id,
-             cache_dir=cache,
-             device_map="auto",
-             torch_dtype=torch.bfloat16,
-             quantization_config=bnb_config,
-             trust_remote_code=True,
-         )
-         quant_txt = "4-bit (bnb, nf4)"
      else:
-         # CPU or non-quantized path
-         device_map = "auto" if use_cuda else {"": "cpu"}
-         model = AutoModelForCausalLM.from_pretrained(
-             model_id,
-             cache_dir=cache,
-             device_map=device_map,
-             torch_dtype=torch.bfloat16 if use_cuda else torch.float32,
-             low_cpu_mem_usage=True,
-             trust_remote_code=True,
-         )
-         quant_txt = "none"
- 
-     dt = time.perf_counter() - t0
-     print(f"{_LOG_PREFIX} Model loaded: {model_id} | quant={quant_txt} | time={dt:.2f}s")
      return model
  
  
- def _format_messages(tokenizer, messages: List[Dict[str, str]]):
-     """
-     Bypass chat template and build a simple instruction prompt.
-     This avoids models that output placeholder JSON schemas.
-     """
-     sys_text = "\n".join([m["content"] for m in messages if m["role"] == "system"])
-     usr_text = "\n".join([m["content"] for m in messages if m["role"] == "user"])
-     prompt = (
-         f"{sys_text.strip()}\n\n"
-         f"---\n"
-         f"{usr_text.strip()}\n\n"
-         f"Respond only with valid JSON for the SOAP draft as described above."
-     )
-     tokens = tokenizer(prompt, return_tensors="pt")
-     return tokens["input_ids"]
- 
  
  
- def _stub_json_response() -> str:
      """
-     Deterministic JSON for end-to-end UI tests.
      """
-     obj = {
-         "subjective": "Patient reports intermittent exertional chest tightness for 2 months, no rest pain.",
-         "objective": "BP 132/78, HR 72, BMI 29. No murmurs. LDL 155 mg/dL, A1C 7.8%. eGFR 52.",
-         "assessment": [
-             "Stable angina symptoms with ASCVD risk factors (DM2, hyperlipidemia).",
-             "No red flags on history/exam today."
-         ],
-         "plan": [
-             "Start/continue high‑intensity statin; consider ezetimibe if LDL >70 on maximally tolerated statin.",
-             "Low‑dose aspirin for secondary prevention if established ASCVD; otherwise not routine for primary prevention.",
-             "Cardiology referral if symptoms persist or worsen; consider stress testing."
-         ]
-     }
-     return json.dumps(obj, ensure_ascii=False)
  
  
  def generate_chat(
      messages: List[Dict[str, str]],
      *,
-     max_new_tokens: int = 512,
      temperature: float = 0.2,
      top_p: float = 0.95,
  ) -> str:
-     """
-     Main text generation entry point.
-     - Honors E2E_STUB=1 for deterministic JSON (no model load).
-     - Otherwise loads tokenizer/model (GPU-first, CPU fallback) and generates.
-     """
-     if _env_flag("E2E_STUB"):
-         print(f"{_LOG_PREFIX} E2E_STUB=1 — returning deterministic JSON without model load.")
-         return _stub_json_response()
- 
-     model_id, use_cuda, use_4bit = _select_model()
      tok = _load_tokenizer(model_id)
-     model = _load_model(model_id, use_cuda, use_4bit)
  
-     inputs = _format_messages(tok, messages)
-     input_ids = inputs.to(model.device)
  
-     gen_cfg = dict(
          max_new_tokens=max_new_tokens,
          do_sample=True,
          temperature=temperature,
          top_p=top_p,
          eos_token_id=tok.eos_token_id,
-         pad_token_id=tok.pad_token_id,
      )
  
-     t0 = time.perf_counter()
      with torch.no_grad():
-         output_ids = model.generate(
-             input_ids=input_ids,
-             **gen_cfg,
-         )
-     dt = time.perf_counter() - t0
  
-     # Return only the newly generated tokens
-     generated = output_ids[0, input_ids.shape[-1]:]
-     text = tok.decode(generated, skip_special_tokens=True)
-     print(f"{_LOG_PREFIX} Generated {generated.shape[-1]} tokens in {dt:.2f}s (temp={temperature}, top_p={top_p})")
-     return text
  
  # src/model_loader.py
+ # -----------------------------------------------------------------------------
+ # Why this change
+ # -----------------------------------------------------------------------------
+ # - Fix fallback model id → 'google/medgemma-4b-text-it' (previous typo caused
+ #   CPU-only runs to fail).
+ # - Keep primary on GPU in 4-bit (bnb, nf4) when available; otherwise fallback.
+ # - Provide a single generate_chat(messages, **gen_kwargs) entry point with
+ #   consistent logging and without relying on chat templates (manual prompt).
+ # - Lightweight logs show model choice, cache path, and generation time.
+ # -----------------------------------------------------------------------------
+ 
  from __future__ import annotations
  
  import os
  import time
  from functools import lru_cache
+ from typing import Dict, List
  
  import torch
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     BitsAndBytesConfig,
+ )
  
+ HF_CACHE = os.environ.get("HF_HOME") or os.environ.get("TRANSFORMERS_CACHE") or "/data/econsult/hf_cache"
  
+ MODEL_PRIMARY_ID = os.environ.get("MODEL_PRIMARY_ID", "google/medgemma-27b-text-it")
+ MODEL_FALLBACK_ID = os.environ.get("MODEL_FALLBACK_ID", "google/medgemma-4b-text-it")  # <-- fixed
  
  
+ def _pick_device_and_quant() -> Dict[str, object]:
+     cuda = torch.cuda.is_available()
+     if cuda:
+         # Prefer 4-bit NF4 on GPU for the primary model
+         quant = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_use_double_quant=True,
+             bnb_4bit_quant_type="nf4",
+             bnb_4bit_compute_dtype=torch.bfloat16,
+         )
+         return {"device_map": "auto", "quantization_config": quant, "torch_dtype": torch.bfloat16}
      else:
+         # CPU path: no bnb quantization. Load smaller fallback model in fp32.
+         return {"device_map": "auto", "torch_dtype": torch.float32}
  
  
  @lru_cache(maxsize=1)
  def _load_tokenizer(model_id: str):
+     tok = AutoTokenizer.from_pretrained(model_id, cache_dir=HF_CACHE, trust_remote_code=True)
+     print(f"[model_loader] Tokenizer loaded: {model_id} | cache={HF_CACHE}")
      return tok
  
  
+ @lru_cache(maxsize=2)
+ def _load_model(model_id: str, use_quant: bool):
+     # use_quant only differentiates the lru_cache key; the actual quantization
+     # decision is made in _pick_device_and_quant() from torch.cuda.is_available().
+     opts = _pick_device_and_quant()
+     if not torch.cuda.is_available():
+         # CPU: avoid quantization args that require CUDA
+         opts.pop("quantization_config", None)
  
+     t0 = time.time()
+     model = AutoModelForCausalLM.from_pretrained(
+         model_id,
+         cache_dir=HF_CACHE,
+         trust_remote_code=True,
+         **opts,
+     )
+     dt = time.time() - t0
+     if torch.cuda.is_available() and "quantization_config" in opts:
+         print(f"[model_loader] Model loaded: {model_id} | quant=4-bit (bnb, nf4) | time={dt:.2f}s")
      else:
+         dtype = "fp32" if opts.get("torch_dtype") == torch.float32 else str(opts.get("torch_dtype"))
+         print(f"[model_loader] Model loaded: {model_id} | dtype={dtype} | time={dt:.2f}s")
      return model
  
  
+ def _select_ids() -> str:
+     # Prefer primary if CUDA; otherwise fallback
+     if torch.cuda.is_available():
+         print(f"[model_loader] CUDA available. Selecting primary model: {MODEL_PRIMARY_ID} (4-bit=True)")
+         return MODEL_PRIMARY_ID
+     else:
+         print(f"[model_loader] CUDA not available. Selecting fallback model: {MODEL_FALLBACK_ID} (CPU)")
+         return MODEL_FALLBACK_ID
  
  
+ def _build_prompt(messages: List[Dict[str, str]]) -> str:
      """
+     Manual prompt (avoid chat templates). We keep it simple and instructive.
      """
+     sys = ""
+     turns = []
+     for m in messages:
+         role = m.get("role", "user")
+         content = m.get("content", "")
+         if role == "system":
+             sys = content.strip()
+         elif role == "user":
+             turns.append(f"User: {content.strip()}")
+         elif role == "assistant":
+             turns.append(f"Assistant: {content.strip()}")
+     prompt = (sys + "\n\n" if sys else "") + "\n".join(turns) + "\nAssistant:"
+     return prompt
  
  
  def generate_chat(
      messages: List[Dict[str, str]],
      *,
+     max_new_tokens: int = 700,
      temperature: float = 0.2,
      top_p: float = 0.95,
  ) -> str:
+     model_id = _select_ids()
      tok = _load_tokenizer(model_id)
+     model = _load_model(model_id, use_quant=torch.cuda.is_available())
  
+     prompt = _build_prompt(messages)
+     inputs = tok(prompt, return_tensors="pt").to(model.device)
  
+     gen_kwargs = dict(
          max_new_tokens=max_new_tokens,
          do_sample=True,
          temperature=temperature,
          top_p=top_p,
+         pad_token_id=tok.eos_token_id,
          eos_token_id=tok.eos_token_id,
      )
  
+     t0 = time.time()
      with torch.no_grad():
+         out = model.generate(**inputs, **gen_kwargs)
+     dt = time.time() - t0
+ 
+     text = tok.decode(out[0], skip_special_tokens=True)
+     # Strip the prompt; keep only the newly generated continuation
+     generated = text[len(prompt) :].strip()
+ 
+     n_new = out.shape[-1] - inputs["input_ids"].shape[-1]
+     print(f"[model_loader] Generated {n_new} tokens in {dt:.2f}s (temp={temperature}, top_p={top_p})")
  
+     return generated
  
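Below is a minimal usage sketch for the entry point this commit settles on. It is illustrative and not part of the commit: the script name and the example messages are assumptions, while generate_chat, MODEL_FALLBACK_ID, and HF_HOME/TRANSFORMERS_CACHE are the names defined in src/model_loader.py above. Because HF_CACHE and the model ids are resolved at import time, any environment overrides have to be set before the module is imported.

# smoke_generate.py (hypothetical caller, not part of this commit)
import os

# Set overrides before importing the module; its config is read at import time.
os.environ.setdefault("HF_HOME", "/data/econsult/hf_cache")
os.environ.setdefault("MODEL_FALLBACK_ID", "google/medgemma-4b-text-it")

from src.model_loader import generate_chat

messages = [
    {"role": "system", "content": "You are a clinical documentation assistant. Draft a SOAP note as JSON."},
    {"role": "user", "content": "55-year-old with exertional chest tightness; LDL 155 mg/dL, A1C 7.8%."},
]

# On a CUDA host this selects the 27B primary in 4-bit (bnb, nf4);
# on CPU it falls back to the 4B model in fp32.
draft = generate_chat(messages, max_new_tokens=400, temperature=0.2, top_p=0.95)
print(draft)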
 
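Since the change deliberately avoids chat templates, here is a short sketch of the prompt layout that _build_prompt produces for a system plus user pair. The message texts are made up; the helper and the resulting layout come from the diff above.

from src.model_loader import _build_prompt  # private helper, imported here only for illustration

msgs = [
    {"role": "system", "content": "You draft SOAP notes as JSON."},
    {"role": "user", "content": "Summarize today's visit."},
]
print(_build_prompt(msgs))
# Prints:
# You draft SOAP notes as JSON.
#
# User: Summarize today's visit.
# Assistant: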