"""Wilai model package.

Exposes the Wilai config, model and tokenizer classes. When ``transformers``
is importable, it also registers them with the ``Auto*`` factories at import
time. After that, ``AutoConfig`` / ``AutoModelForCausalLM`` / ``AutoTokenizer``
``.from_pretrained(...)`` can resolve a checkpoint whose config declares
``"model_type": "wilai"``.
"""

import logging

from .configuration_wilai import WilaiConfig
from .modeling_wilai import WilaiModel, WilaiForCausalLM
from .tokenization_wilai import WilaiTokenizer, WilaiFastTokenizer
from .auto import from_pretrained_tokenizer

logger = logging.getLogger(__name__)

# The "model_type" string that config.json files must carry for the Auto*
# factories to route them to the Wilai classes.
_MODEL_TYPE = "wilai"

# Errors a registration call may raise without it being a bug in this package:
#   AttributeError - the installed transformers lacks the ``register`` method;
#   TypeError      - the installed transformers has an older/newer signature;
#   ValueError     - transformers rejected the registration (name clash,
#                    mismatched ``model_type``/``config_class``, etc.).
_REGISTRATION_ERRORS = (AttributeError, TypeError, ValueError)


def _attempt(description, register):
    """Run the zero-argument callable ``register``; log and report failure.

    Returns True if it completed and False if it raised one of
    ``_REGISTRATION_ERRORS``. Any other exception propagates, because it
    indicates a genuine bug rather than an environment mismatch.
    """
    try:
        register()
    except _REGISTRATION_ERRORS as err:
        logger.warning("Failed to register %s: %s", description, err)
        return False
    return True


def _register_with_transformers():
    """Register the Wilai classes with the transformers ``Auto*`` factories.

    This is best-effort: every failure is logged and swallowed, so importing
    the package never fails just because registration could not be performed.
    """
    try:
        from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
    except ImportError as err:
        logger.warning(
            "transformers Auto classes unavailable; skipping Wilai registration: %s",
            err,
        )
        return

    # AutoConfig maps the "model_type" string to a config class.
    _attempt(
        "WilaiConfig with AutoConfig",
        lambda: AutoConfig.register(_MODEL_TYPE, WilaiConfig),
    )

    # Model factories are keyed by config *class*, not by the model_type string.
    _attempt(
        "WilaiForCausalLM with AutoModelForCausalLM",
        lambda: AutoModelForCausalLM.register(WilaiConfig, WilaiForCausalLM),
    )

    # Tokenizer mappings hold a (slow, fast) pair per config class. transformers
    # validates the pair before mutating anything, so if it rejects the pair
    # (e.g. the fast class does not point back at the slow one), retrying
    # with the slow tokenizer alone is safe.
    registered_pair = _attempt(
        "WilaiTokenizer/WilaiFastTokenizer with AutoTokenizer",
        lambda: AutoTokenizer.register(
            WilaiConfig,
            slow_tokenizer_class=WilaiTokenizer,
            fast_tokenizer_class=WilaiFastTokenizer,
        ),
    )
    if not registered_pair:
        _attempt(
            "WilaiTokenizer with AutoTokenizer",
            lambda: AutoTokenizer.register(
                WilaiConfig, slow_tokenizer_class=WilaiTokenizer
            ),
        )


_register_with_transformers()

__all__ = [
    "WilaiConfig",
    "WilaiModel",
    "WilaiForCausalLM",
    "WilaiTokenizer",
    "WilaiFastTokenizer",
    "from_pretrained_tokenizer",
]