Patch loading SparseEncoder from Hub

#2
opened by tomaarsen (HF Staff)
adapter_config.json → lora/adapter_config.json RENAMED (contents unchanged)
adapter_model.safetensors → lora/adapter_model.safetensors RENAMED (contents unchanged)
splade.py CHANGED
@@ -6,16 +6,24 @@ This file supports two loading paths:
 1. Sentence Transformers: `SparseEncoder("naver/splade-code-8B", trust_remote_code=True)` via AutoModelForMaskedLM -> Qwen3ForCausalLM
 2. Transformers: `AutoModelForCausalLM.from_pretrained("naver/splade-code-8B", trust_remote_code=True)` -> Splade
 
-The checkpoint is distributed as a LoRA adapter on top of Qwen/Qwen3-8B; `Qwen3ForCausalLM.from_pretrained`
-loads the base model and applies the adapter.
+The checkpoint is distributed as a LoRA adapter on top of Qwen/Qwen3-8B in the `lora/` subfolder;
+`Qwen3ForCausalLM.from_pretrained` loads the base model and applies the adapter.
 """
 
+import os
+
 import torch
 from transformers import Qwen3ForCausalLM as TransformersQwen3ForCausalLM
 from transformers import PretrainedConfig, PreTrainedModel, AutoConfig
 from transformers.utils import is_flash_attn_2_available
 from .utils import prepare_tokenizer, splade_max, similarity, encode
 
+# The adapter lives in this subfolder rather than at the repo root so that
+# `find_adapter_config_file` doesn't trigger transformers' auto-PEFT path,
+# which would otherwise redirect hub loads to `Qwen/Qwen3-8B` and lose the
+# `auto_map` routing to the classes in this file.
+ADAPTER_SUBFOLDER = "lora"
+
 
 class Qwen3ForCausalLM(TransformersQwen3ForCausalLM):
     def tie_weights(self, *args, **kwargs):
@@ -40,27 +48,33 @@ class Qwen3ForCausalLM(TransformersQwen3ForCausalLM):
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        from huggingface_hub import snapshot_download
         from peft import PeftConfig, PeftModel
 
-        try:
-            peft_config = PeftConfig.from_pretrained(
-                pretrained_model_name_or_path, token=kwargs.get("token")
+        token = kwargs.get("token")
+
+        # Resolve the adapter to a local path before handing it to PEFT.
+        # PEFT's `subfolder=` kwarg uses `os.path.join` on Windows, producing
+        # backslashed hub paths that break the safetensors-vs-bin fallback.
+        if os.path.isdir(pretrained_model_name_or_path):
+            adapter_path = os.path.join(pretrained_model_name_or_path, ADAPTER_SUBFOLDER)
+        else:
+            local_repo = snapshot_download(
+                pretrained_model_name_or_path,
+                allow_patterns=[f"{ADAPTER_SUBFOLDER}/*"],
+                token=token,
             )
-        except Exception:
-            peft_config = None
+            adapter_path = os.path.join(local_repo, ADAPTER_SUBFOLDER)
 
-        if peft_config is None:
+        if not os.path.isfile(os.path.join(adapter_path, "adapter_config.json")):
             return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
 
+        peft_config = PeftConfig.from_pretrained(adapter_path, token=token)
+
         # Use provided splade config (has is_causal=False) or load it from the adapter repo
         config = kwargs.pop("config", None)
         if config is None or not isinstance(config, PretrainedConfig):
-            config = AutoConfig.from_pretrained(
-                pretrained_model_name_or_path, token=kwargs.get("token")
-            )
-
-        # We apply the adapter manually below, so drop any auto-PEFT hints to avoid double loading
-        kwargs.pop("adapter_kwargs", None)
+            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, token=token)
 
         base_model = super().from_pretrained(
             peft_config.base_model_name_or_path,
@@ -69,9 +83,7 @@ class Qwen3ForCausalLM(TransformersQwen3ForCausalLM):
             **kwargs,
         )
 
-        return PeftModel.from_pretrained(
-            base_model, pretrained_model_name_or_path, token=kwargs.get("token")
-        )
+        return PeftModel.from_pretrained(base_model, adapter_path, token=token)
 
 
 class SpladeConfig(PretrainedConfig):
@@ -128,7 +140,7 @@ class Splade(PreTrainedModel):
         )
 
     def save_pretrained(self, save_directory, *args, **kwargs):
-        self.model.save_pretrained(save_directory)
+        self.model.save_pretrained(os.path.join(save_directory, ADAPTER_SUBFOLDER))
        self.config.save_pretrained(save_directory)
 
     @classmethod
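
For reference, a minimal usage sketch of the two loading paths named in the module docstring, assuming sentence-transformers, transformers, peft, and huggingface_hub are installed; the example input string is made up:

from sentence_transformers import SparseEncoder
from transformers import AutoModelForCausalLM

# Path 1: Sentence Transformers routes AutoModelForMaskedLM -> Qwen3ForCausalLM, which
# resolves the lora/ adapter via snapshot_download, loads the Qwen/Qwen3-8B base model,
# and applies the adapter with PEFT.
model = SparseEncoder("naver/splade-code-8B", trust_remote_code=True)
embeddings = model.encode(["def binary_search(arr, target): ..."])  # made-up example input

# Path 2: plain Transformers resolves to the Splade wrapper defined in splade.py.
splade = AutoModelForCausalLM.from_pretrained("naver/splade-code-8B", trust_remote_code=True)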