# LaunchLLM / fine_tuning/peft_trainer.py
# Last commit ec8f374 (verified) by Bmccloud22: "Deploy LaunchLLM - Production AI Training Platform"
"""
PEFT Trainer Module
General Parameter-Efficient Fine-Tuning trainer supporting multiple PEFT methods.
"""
from typing import Optional, List, Dict, Any
from pathlib import Path
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TrainingArguments,
Trainer
)
from peft import (
get_peft_model,
PeftConfig,
PeftModel,
prepare_model_for_kbit_training
)
class PEFTTrainer:
    """
    General PEFT Trainer supporting multiple parameter-efficient fine-tuning methods.

    Supports:
    - LoRA (Low-Rank Adaptation)
    - Prefix Tuning
    - P-Tuning
    - Prompt Tuning
    - IA3 (Infused Adapter by Inhibiting and Amplifying Inner Activations)

    The PEFT method is chosen entirely by the ``peft_config`` passed to the
    constructor; ``load_model`` applies it to the base model via
    ``get_peft_model``.
    """

    def __init__(
        self,
        model_name: str,
        peft_config: PeftConfig,
        output_dir: str = "./models/peft_output"
    ):
        """
        Initialize PEFT Trainer.

        Creates ``output_dir`` (including missing parents) immediately. No
        tokenizer or model is loaded until ``load_model`` is called.

        Args:
            model_name: HuggingFace model path or name
            peft_config: PEFT configuration (LoraConfig, PrefixTuningConfig, etc.)
            output_dir: Directory for saving checkpoints and final model
        """
        self.model_name = model_name
        self.peft_config = peft_config
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Populated by load_model(); None until then.
        self.model = None
        self.tokenizer = None
        # Placeholder for a transformers Trainer; presumably assigned by
        # training code elsewhere — TODO confirm.
        self.trainer = None

    def load_model(
        self,
        use_4bit: bool = False,
        use_8bit: bool = False,
        device_map: str = "auto",
        *,
        compute_dtype: torch.dtype = torch.float16,
        trust_remote_code: bool = True,
    ) -> None:
        """
        Load the tokenizer and base model, then wrap the model with PEFT.

        Args:
            use_4bit: Use 4-bit NF4 quantization with double quantization.
            use_8bit: Use 8-bit quantization. Mutually exclusive with
                ``use_4bit``.
            device_map: Device mapping strategy forwarded to
                ``AutoModelForCausalLM.from_pretrained``.
            compute_dtype: Compute dtype for 4-bit quantized layers. Defaults
                to ``torch.float16``; ``torch.bfloat16`` is often preferable on
                hardware that supports it. Only used when ``use_4bit`` is True.
            trust_remote_code: Forwarded to the tokenizer and model loaders.
                When True, custom code shipped with the model repository is
                executed locally, so only enable it for trusted sources.

        Raises:
            ValueError: If both ``use_4bit`` and ``use_8bit`` are True.
        """
        # Reject contradictory flags before doing any (slow) download/load work.
        if use_4bit and use_8bit:
            raise ValueError(
                "use_4bit and use_8bit are mutually exclusive; enable at most one."
            )

        print(f"Loading model: {self.model_name}")

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name,
            trust_remote_code=trust_remote_code,
        )
        # Some causal-LM tokenizers define no pad token; reuse EOS so that
        # batches can still be padded.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        quantized = use_4bit or use_8bit
        quantization_config = None
        if quantized:
            # Only needed when quantizing.
            from transformers import BitsAndBytesConfig

            if use_4bit:
                quantization_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=compute_dtype,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type="nf4",
                )
            else:
                quantization_config = BitsAndBytesConfig(load_in_8bit=True)

        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            quantization_config=quantization_config,
            device_map=device_map,
            trust_remote_code=trust_remote_code,
        )

        if quantized:
            # PEFT helper that readies a k-bit quantized model for training.
            self.model = prepare_model_for_kbit_training(self.model)

        # Wrap with PEFT; the adapter is registered under PEFT's default
        # adapter name ("default") because no adapter_name is passed.
        self.model = get_peft_model(self.model, self.peft_config)
        self.model.print_trainable_parameters()
        print("✅ Model loaded with PEFT")

    def save_model(self, save_path: Optional[str] = None) -> None:
        """
        Save PEFT adapter weights together with the tokenizer.

        Args:
            save_path: Directory to save adapters into. Defaults to
                ``<output_dir>/final_model``. Created if missing.

        Raises:
            RuntimeError: If called before ``load_model``.
        """
        if self.model is None or self.tokenizer is None:
            raise RuntimeError(
                "No model loaded; call load_model() before save_model()."
            )

        if save_path is None:
            save_path = str(self.output_dir / "final_model")

        Path(save_path).mkdir(parents=True, exist_ok=True)
        self.model.save_pretrained(save_path)
        self.tokenizer.save_pretrained(save_path)
        print(f"✅ PEFT model saved to: {save_path}")

    def load_adapter(self, adapter_path: str, is_trainable: bool = False) -> None:
        """
        Load pre-trained PEFT adapter weights into the current model.

        Args:
            adapter_path: Path (or Hub id) of the saved adapter.
            is_trainable: Whether the loaded adapter should be trainable.
                Defaults to False (inference mode), matching PEFT's default.

        Raises:
            RuntimeError: If no model has been loaded yet.
        """
        if self.model is None:
            raise RuntimeError(
                "No model loaded; call load_model() before load_adapter()."
            )

        print(f"Loading PEFT adapter from: {adapter_path}")
        if isinstance(self.model, PeftModel):
            # load_model() already wrapped the model with get_peft_model(),
            # whose adapter is named "default". Wrapping a PeftModel in a
            # second PeftModel nests adapters, so load the saved weights into
            # the existing adapter instead.
            self.model.load_adapter(
                adapter_path,
                adapter_name="default",
                is_trainable=is_trainable,
            )
        else:
            # A plain (non-PEFT) model was assigned to self.model: wrap it.
            self.model = PeftModel.from_pretrained(
                self.model,
                adapter_path,
                is_trainable=is_trainable,
            )
        print("✅ Adapter loaded")