Cardiosense-AG committed on
Commit 20f1bde · verified · 1 Parent(s): eae09d2

Update src/model_loader.py

Files changed (1)
  1. src/model_loader.py +80 -67
src/model_loader.py CHANGED
@@ -4,7 +4,7 @@ from __future__ import annotations
 import os
 import time
 from functools import lru_cache
-from typing import Dict, List
+from typing import Dict, List, Tuple
 
 import torch
 from transformers import (
@@ -12,23 +12,42 @@ from transformers import (
     AutoTokenizer,
     BitsAndBytesConfig,
 )
-from accelerate import init_empty_weights, load_checkpoint_and_dispatch
 
 # --- Diagnostic print to confirm runtime versions ---
 import transformers
 print("[init]", "torch", torch.__version__, "transformers", transformers.__version__)
 
-
 HF_CACHE = os.environ.get("HF_HOME") or os.environ.get("TRANSFORMERS_CACHE") or "/data/econsult/hf_cache"
 
-# Accept MODEL_ID (preferred) or fallback to MODEL_PRIMARY_ID to avoid env-name drift.
-MODEL_PRIMARY_ID = os.environ.get("MODEL_ID") or os.environ.get("MODEL_PRIMARY_ID", "google/medgemma-27b-text-it")
-MODEL_FALLBACK_ID = os.environ.get("MODEL_FALLBACK_ID", "google/medgemma-4b-text-it")
+# -------------------- Env normalization --------------------
+
+def _resolve_model_ids() -> Tuple[str, str]:
+    """
+    Resolve primary/fallback with precedence:
+      - Primary: Model_ID > MODEL_ID > MODEL_PRIMARY_ID > default
+      - Fallback: Model_Fallback_ID > MODEL_FALLBACK_ID > default
+    """
+    env = os.environ
+    primary = (
+        env.get("Model_ID") or
+        env.get("MODEL_ID") or
+        env.get("MODEL_PRIMARY_ID") or
+        "google/medgemma-27b-text-it"
+    )
+    fallback = (
+        env.get("Model_Fallback_ID") or
+        env.get("MODEL_FALLBACK_ID") or
+        "google/medgemma-4b-it"
+    )
+    return primary.strip(), fallback.strip()
+
+def _force_cpu() -> bool:
+    return str(os.environ.get("FORCE_CPU_LLM", "")).strip().lower() in {"1", "true", "yes"}
 
+# -------------------- Device & model selection --------------------
 
 def _pick_device_and_quant() -> Dict[str, object]:
-    cuda = torch.cuda.is_available()
-    if cuda:
+    if torch.cuda.is_available() and not _force_cpu():
         quant = BitsAndBytesConfig(
             load_in_4bit=True,
             bnb_4bit_use_double_quant=True,
@@ -36,20 +55,19 @@ def _pick_device_and_quant() -> Dict[str, object]:
             bnb_4bit_compute_dtype=torch.bfloat16,
         )
         return {"device_map": "auto", "quantization_config": quant, "torch_dtype": torch.bfloat16}
-    else:
-        return {"device_map": {"": "cpu"}, "torch_dtype": torch.float32}
-
-
-@lru_cache(maxsize=1)
-def _select_ids() -> str:
-    # Prefer explicit env override; else keep default.
-    model_id = (os.environ.get("MODEL_ID") or MODEL_PRIMARY_ID).strip()
-    fb = MODEL_FALLBACK_ID.strip()
-    # Simple sanity hints
-    if not model_id:
-        model_id = fb
-    return model_id
-
+    # CPU path
+    return {"device_map": {"": "cpu"}, "torch_dtype": torch.float32}
+
+def _select_runtime_model_id() -> Tuple[str, bool, str]:
+    """
+    Returns (selected_model_id, is_fallback, device_label)
+    device_label in {"GPU","CPU"}
+    """
+    primary, fallback = _resolve_model_ids()
+    on_gpu = torch.cuda.is_available() and not _force_cpu()
+    if on_gpu:
+        return primary, False, "GPU"
+    return fallback, True, "CPU"
 
 @lru_cache(maxsize=1)
 def _load_tokenizer(model_id: str):
@@ -59,39 +77,47 @@ def _load_tokenizer(model_id: str):
         tok.pad_token = tok.eos_token
     return tok
 
-
 @lru_cache(maxsize=1)
-def _load_model(model_id: str, use_quant: bool = True):
+def _load_model(model_id: str, use_quant: bool):
     device_kwargs = _pick_device_and_quant() if use_quant else {"device_map": {"": "cpu"}, "torch_dtype": torch.float32}
-    print(f"[model_loader] Loading model: {model_id} | quant={use_quant} | device_kwargs={list(device_kwargs.keys())}")
-    try:
-        model = AutoModelForCausalLM.from_pretrained(
-            model_id,
-            low_cpu_mem_usage=True,
-            trust_remote_code=True,
-            cache_dir=HF_CACHE,
-            **device_kwargs,
-        )
-        model.eval()
-        return model
-    except Exception as e:
-        # Fallback to smaller model on CPU
-        fb = MODEL_FALLBACK_ID
-        print(f"[model_loader] Primary load failed: {e}\nFalling back to: {fb}")
-        model = AutoModelForCausalLM.from_pretrained(
-            fb,
-            low_cpu_mem_usage=True,
-            trust_remote_code=True,
-            cache_dir=HF_CACHE,
-            device_map={"": "cpu"},
-            torch_dtype=torch.float32,
-        )
-        model.eval()
-        return model
+    print(f"[model_loader] Loading model: {model_id} | device_kwargs={list(device_kwargs.keys())}")
+    model = AutoModelForCausalLM.from_pretrained(
+        model_id,
+        low_cpu_mem_usage=True,
+        trust_remote_code=True,
+        cache_dir=HF_CACHE,
+        **device_kwargs,
+    )
+    model.eval()
+    return model
+
+# -------------------- Public helpers --------------------
+
+def active_model_status() -> Dict[str, str | bool]:
+    primary, fallback = _resolve_model_ids()
+    selected, is_fallback, device = _select_runtime_model_id()
+    forced = _force_cpu()
+    return {
+        "primary_id": primary,
+        "fallback_id": fallback,
+        "selected_id": selected,
+        "device": device,
+        "is_fallback": bool(is_fallback or (device == "CPU")),
+        "forced_cpu": forced,
+    }
 
+def generate_chat(
+    messages: List[Dict[str, str]],
+    *,
+    max_new_tokens: int = 700,
+    temperature: float = 0.2,
+    top_p: float = 0.95,
+) -> str:
+    selected_id, is_fallback, device = _select_runtime_model_id()
+    tok = _load_tokenizer(selected_id)
+    model = _load_model(selected_id, use_quant=(device == "GPU"))
 
-def _build_prompt(messages: List[Dict[str, str]]) -> str:
-    """Very simple chat prompt for IT models."""
+    # Very simple chat prompt for IT models.
     sys_msgs = [m["content"] for m in messages if m.get("role") == "system"]
     turns = []
     for m in messages:
@@ -99,21 +125,8 @@ def _build_prompt(messages: List[Dict[str, str]]) -> str:
            turns.append(f"User: {m['content']}")
         elif m.get("role") == "assistant":
            turns.append(f"Assistant: {m['content']}")
-    return (sys_msgs[0] + "\n\n" if sys_msgs else "") + "\n".join(turns) + "\nAssistant:"
-
-
-def generate_chat(
-    messages: List[Dict[str, str]],
-    *,
-    max_new_tokens: int = 700,
-    temperature: float = 0.2,
-    top_p: float = 0.95,
-) -> str:
-    model_id = _select_ids()
-    tok = _load_tokenizer(model_id)
-    model = _load_model(model_id, use_quant=torch.cuda.is_available())
+    prompt = (sys_msgs[0] + "\n\n" if sys_msgs else "") + "\n".join(turns) + "\nAssistant:"
 
-    prompt = _build_prompt(messages)
     inputs = tok(prompt, return_tensors="pt").to(model.device)
 
     gen_kwargs = dict(
@@ -133,8 +146,7 @@ def generate_chat(
     text = tok.decode(out[0], skip_special_tokens=True)
     generated = text[len(prompt):].strip()
 
-    print(f"[model_loader] Generated {max_new_tokens} tokens in {dt:.2f}s (temp={temperature}, top_p={top_p})")
-    print(f"[model_loader] Tokenizer loaded: {model_id} | cache={HF_CACHE}")
+    print(f"[model_loader] Generated <= {max_new_tokens} tokens in {dt:.2f}s (temp={temperature}, top_p={top_p}) | {selected_id} on {device} | fallback={is_fallback}")
    return generated
 
 
@@ -145,3 +157,4 @@ def generate_chat(
 
 
 
+
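
Usage sketch (not part of the commit): a minimal example of how a caller might exercise the new helpers introduced here. The import path, env values, and message text below are assumptions for illustration; only the function names, env variable names, and argument shapes come from the diff above.

# Illustrative usage sketch; import path and env values are assumptions.
import os

# Optional overrides read by _resolve_model_ids() / _force_cpu() at call time.
os.environ.setdefault("MODEL_ID", "google/medgemma-27b-text-it")      # primary
os.environ.setdefault("MODEL_FALLBACK_ID", "google/medgemma-4b-it")   # CPU fallback
os.environ.setdefault("FORCE_CPU_LLM", "0")                           # "1"/"true"/"yes" forces CPU

from src import model_loader  # assumed import path for this repo layout

# Cheap status check: resolves IDs and device label without loading any weights.
print(model_loader.active_model_status())

# First call loads tokenizer + model (cached via lru_cache), then generates.
reply = model_loader.generate_chat(
    [
        {"role": "system", "content": "You are a careful clinical triage assistant."},
        {"role": "user", "content": "Summarize the key question in this e-consult."},
    ],
    max_new_tokens=200,
    temperature=0.2,
    top_p=0.95,
)
print(reply)

Note that because both _load_tokenizer and _load_model are wrapped in lru_cache(maxsize=1), the model selected on the first generate_chat call stays pinned for the lifetime of the process.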