This is where the magic happens! In this part, we'll take a deep dive into LoRA (Low-Rank Adaptation) fine-tuning and use MLX to train our model efficiently on Apple Silicon.
Understanding LoRA: The Game-Changing Technique
Imagine you are a master chef who wants to learn a new cuisine. Instead of forgetting everything you know and starting from scratch, you add new techniques and flavor profiles to your existing knowledge. That's exactly what LoRA (Low-Rank Adaptation) does for language models.
The Traditional Fine-Tuning Problem
Traditional fine-tuning updates all 1.7 billion parameters of our model. This means:
- ❌ Massive memory requirements
- ❌ Slow training
- ❌ Risk of "catastrophic forgetting" (losing general knowledge)
- ❌ Large model files
The LoRA Solution
LoRA keeps the original model weights frozen and trains a small set of extra "adapter" weights that learn the new behavior:
- ✅ Minimal memory usage
- ✅ Fast training
- ✅ Preserves general knowledge
- ✅ Tiny adapter file size
- ✅ Can be combined or switched out easily
How LoRA Works Under the Hood
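Think of the original model as a Swiss Army knife with all its tools welded in place. LoRA adds new attachments that can be snapped on or off.

Concretely, for a frozen weight matrix $W \in \mathbb{R}^{d \times k}$, LoRA learns two small matrices $B \in \mathbb{R}^{d \times r}$ and $A \in \mathbb{R}^{r \times k}$ with rank $r \ll \min(d, k)$. The adapted layer computes $h = Wx + \frac{\alpha}{r} BAx$.

Only $A$ and $B$ are trained. For a $2048 \times 2048$ matrix with $r = 16$, that is $2 \cdot 2048 \cdot 16 = 65{,}536$ trainable values instead of about 4.2 million. Because $B$ starts at zero, training begins from exactly the original model, and removing the adapter restores it.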
MLX: Apple's Secret Weapon for AI
MLX is Apple's machine learning framework designed specifically for Apple Silicon. It's what makes our local fine-tuning possible and incredibly fast.
Why MLX Is Good for Local AI
- Unified Memory Architecture: M-series chips share memory between CPU and GPU, eliminating data transfer bottlenecks
- Optimized Computation: Hand-tuned for Apple Silicon's specific capabilities
- Memory Efficiency: Lazy evaluation and a single shared memory pool help larger models fit in your Mac's RAM
- Python Integration: Easy to use while being incredibly fast
Setting Up Our Fine-Tuning Pipeline
Let's build our fine-tuning system step by step, explaining each component as we go.
Step 1: Configuration and Setup
First, let's create a comprehensive configuration system:
touch fine_tuning_config.py
# Create fine_tuning_config.py
import os
from pathlib import Path
import mlx.core as mx


class FineTuningConfig:
    """Single place that holds every knob used by the fine-tuning scripts.

    Constructing an instance also makes sure the adapter output directory
    exists (it is created, including parents, if missing).
    """

    def __init__(self):
        # ---- Model ---------------------------------------------------------
        self.base_model = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
        self.adapter_path = "./adapters/email_sentiment"

        # ---- Data (JSONL files) --------------------------------------------
        self.train_data_path = "./data/mlx_format/train.jsonl"
        self.valid_data_path = "./data/mlx_format/valid.jsonl"

        # ---- LoRA ----------------------------------------------------------
        # NOTE(review): confirm each of these is actually forwarded to the
        # trainer; values that are not passed on are informational only.
        self.lora_layers = 16   # transformer layers that receive adapters
        self.lora_rank = 16     # 'r': adapter capacity (higher = bigger, slower)
        self.lora_alpha = 32    # adapter scaling factor

        # ---- Training loop -------------------------------------------------
        self.batch_size = 2          # lower this if you run out of memory
        self.learning_rate = 5e-5
        self.max_iters = 1000        # total training iterations
        self.steps_per_report = 10   # progress-print interval (iterations)
        self.steps_per_eval = 200    # validation interval (iterations)
        self.save_every = 400        # checkpoint interval (iterations)

        # ---- Hardware ------------------------------------------------------
        self.use_gpu = mx.metal.is_available()
        self.max_sequence_length = 2048

        adapter_dir = Path(self.adapter_path)
        adapter_dir.mkdir(parents=True, exist_ok=True)

    def print_config(self):
        """Print the most important settings, one labelled line each."""
        rows = (
            ("Base model", self.base_model),
            ("GPU available", self.use_gpu),
            ("LoRA rank", self.lora_rank),
            ("LoRA layers", self.lora_layers),
            ("Batch size", self.batch_size),
            ("Learning rate", self.learning_rate),
            ("Max iterations", self.max_iters),
            ("Adapter path", self.adapter_path),
        )
        print("🔧 Fine-tuning Configuration:")
        for label, value in rows:
            print(f" {label}: {value}")


# Build a config and show it when run directly.
if __name__ == "__main__":
    FineTuningConfig().print_config()

# Step 2: Memory and Performance Monitoring
Before we start fine-tuning, let's create tools to monitor our system:
touch monitoring.py
# Create monitoring.py
import time
from typing import Dict, List, Optional

import mlx.core as mx
import psutil


def _memory_reader(name: str):
    """Return MLX's memory query function *name*.

    Prefers the top-level ``mx.<name>`` and falls back to ``mx.metal.<name>``.
    Some MLX releases expose the top-level form and deprecate the
    ``mx.metal`` alias, while older ones only provide the latter.
    """
    return getattr(mx, name, None) or getattr(mx.metal, name)


class PerformanceMonitor:
    """Monitor memory usage and training performance.

    Each call to ``log_memory_usage`` appends one metrics dict to
    ``self.metrics``.

    Args:
        print_every: print a status line whenever ``step`` is a multiple of
            this value. ``0`` or a negative value disables printing.
            Defaults to 50.
    """

    def __init__(self, print_every: int = 50):
        self.start_time = time.time()
        self.metrics: List[Dict] = []
        self.print_every = print_every

    def log_memory_usage(self, step: int, loss: Optional[float] = None) -> Dict:
        """Record memory and timing metrics for ``step`` and return them.

        Args:
            step: training step number this sample belongs to.
            loss: optional loss value for the step. ``None`` means "not known".

        Returns:
            The metrics dict that was appended to ``self.metrics``.
        """
        # GPU memory in MB (decimal, 1e6 bytes); 0 when Metal is unavailable.
        gpu_active_mb = 0
        gpu_peak_mb = 0
        if mx.metal.is_available():
            gpu_active_mb = _memory_reader("get_active_memory")() / 1e6
            gpu_peak_mb = _memory_reader("get_peak_memory")() / 1e6

        system_memory = psutil.virtual_memory()

        metrics = {
            'step': step,
            'elapsed_seconds': time.time() - self.start_time,
            'loss': loss,
            'gpu_active_mb': gpu_active_mb,
            'gpu_peak_mb': gpu_peak_mb,
            'system_memory_percent': system_memory.percent,
            'system_memory_available_gb': system_memory.available / 1e9,
        }
        self.metrics.append(metrics)

        # Guard against a zero interval (step % 0 would raise).
        if self.print_every > 0 and step % self.print_every == 0:
            self.print_status(metrics)

        return metrics

    def print_status(self, metrics: Dict):
        """Print a one-line status summary for a metrics dict.

        A missing/``None`` loss is shown as ``n/a`` instead of crashing the
        ``:.4f`` format. That crash would otherwise happen for the default
        ``loss=None``.
        """
        loss = metrics.get('loss')
        loss_text = f"{loss:.4f}" if loss is not None else "n/a"
        print(f"Step {metrics['step']:4d} | "
              f"Loss: {loss_text} | "
              f"GPU: {metrics['gpu_active_mb']:.0f}MB | "
              f"Time: {metrics['elapsed_seconds']:.1f}s")

    def get_training_summary(self) -> Dict:
        """Summarize the recorded metrics; returns ``{}`` if nothing was logged.

        The returned dict has these keys:

        - ``total_training_time``: elapsed seconds at the last log call.
        - ``peak_gpu_memory_mb``: maximum peak GPU memory observed.
        - ``final_loss``: last loss that was not ``None``, or ``None`` if none
          was ever logged.
        - ``steps_completed``: number of log calls (kept for compatibility).
        - ``last_step``: the ``step`` value of the last log call.
        """
        if not self.metrics:
            return {}

        last = self.metrics[-1]
        losses = [m['loss'] for m in self.metrics if m['loss'] is not None]

        return {
            'total_training_time': last['elapsed_seconds'],
            'peak_gpu_memory_mb': max(m['gpu_peak_mb'] for m in self.metrics),
            'final_loss': losses[-1] if losses else None,
            'steps_completed': len(self.metrics),
            'last_step': last['step'],
        }

# Step 3: The Fine-Tuning Engine
Now let's create our main fine-tuning script using MLX-LM:
touch fine_tune_model.py
# Create fine_tune_model.py
import json
import os
import re
import subprocess
import sys
import time
from pathlib import Path

from fine_tuning_config import FineTuningConfig
from monitoring import PerformanceMonitor

# Extracts the training-loss number from a progress line.
# - Matches mlx_lm-style "Train loss 1.234" and the generic "Loss: 1.234".
# - Does not match validation lines ("Val loss ..."), because they contain
#   neither "train loss" nor "loss:".
_TRAIN_LOSS_RE = re.compile(
    r"(?:train loss|loss:)\s*([-+]?\d*\.?\d+(?:[eE][-+]?\d+)?)",
    re.IGNORECASE,
)


class MLXFineTuner:
    """Fine-tune models using MLX with LoRA.

    Wraps the ``mlx_lm lora`` command-line trainer. It validates the data,
    builds and runs the command, records metadata next to the adapters, and
    summarizes the trainer's output.
    """

    def __init__(self, config: FineTuningConfig):
        self.config = config
        # NOTE(review): the monitor is created but never fed. Training runs in
        # a child process, so in-process MLX memory counters would presumably
        # not reflect it — wire it up or remove it.
        self.monitor = PerformanceMonitor()

    def validate_data(self):
        """Validate that the training/validation data exist and look right.

        Returns:
            Number of non-blank lines (examples) in the training file.

        Raises:
            FileNotFoundError: the training or validation file is missing.
            ValueError: the training file has no examples, its first example
                is not valid JSON, or that example lacks a ``text`` field.
        """
        print("📊 Validating training data...")

        if not os.path.exists(self.config.train_data_path):
            raise FileNotFoundError(f"Training data not found: {self.config.train_data_path}")
        # mlx_lm refuses to train without a validation set, so fail early.
        if not os.path.exists(self.config.valid_data_path):
            raise FileNotFoundError(f"Validation data not found: {self.config.valid_data_path}")

        # Single pass: count examples and remember the first one for the
        # format check (blank lines are skipped for both).
        train_count = 0
        first_example = None
        with open(self.config.train_data_path, 'r', encoding='utf-8') as f:
            for line in f:
                if line.strip():
                    train_count += 1
                    if first_example is None:
                        first_example = line

        if train_count == 0:
            raise ValueError(f"Training data is empty: {self.config.train_data_path}")
        print(f"✅ Found {train_count} training examples")

        try:
            example = json.loads(first_example)
        except json.JSONDecodeError as err:
            raise ValueError("Training data must be valid JSONL format") from err
        if not isinstance(example, dict) or 'text' not in example:
            raise ValueError("Training data must have 'text' field")
        print("✅ Data format validated")

        return train_count

    def build_training_command(self):
        """Build the ``mlx_lm lora`` training command.

        Returns:
            An argv list for ``subprocess``; no shell is involved.
        """
        # --data takes a directory that must contain train.jsonl/valid.jsonl,
        # so derive it from the configured training file instead of
        # hard-coding it.
        data_dir = str(Path(self.config.train_data_path).parent)

        # NOTE(review): mlx_lm has no CLI flags for LoRA rank/alpha. They come
        # from a --config YAML (``lora_parameters``), so without one mlx_lm's
        # own defaults apply. The adapter_config.json it writes is the source
        # of truth (see verify_training_output).
        cmd = [
            # The current interpreter guarantees we run the mlx_lm installed
            # in this environment, not whatever "python3" is on PATH.
            sys.executable, "-m", "mlx_lm", "lora",
            "--model", self.config.base_model,
            "--train",
            "--data", data_dir,
            "--num-layers", str(self.config.lora_layers),
            "--batch-size", str(self.config.batch_size),
            "--iters", str(self.config.max_iters),
            "--learning-rate", str(self.config.learning_rate),
            "--steps-per-report", str(self.config.steps_per_report),
            "--steps-per-eval", str(self.config.steps_per_eval),
            "--adapter-path", self.config.adapter_path,
            "--save-every", str(self.config.save_every),
            "--max-seq-length", str(self.config.max_sequence_length),
        ]
        return cmd

    def _run_streaming(self, cmd):
        """Run *cmd*, echoing its merged stdout/stderr live.

        Returns:
            A ``(returncode, captured_output)`` tuple.
        """
        # Make the child Python flush each line. Otherwise its stdout is
        # block-buffered when connected to a pipe, and progress arrives late.
        env = {**os.environ, "PYTHONUNBUFFERED": "1"}
        captured = []
        with subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            encoding="utf-8",
            errors="replace",
            env=env,
        ) as proc:
            for line in proc.stdout:
                print(line, end="", flush=True)
                captured.append(line)
        # Popen.__exit__ has waited for the process, so returncode is set.
        return proc.returncode, "".join(captured)

    def run_fine_tuning(self):
        """Execute the fine-tuning process.

        Trainer output is shown live and also captured for the summary.

        Returns:
            ``(True, metadata)`` on success, or ``(False, None)`` if the
            trainer exits with a non-zero status.

        Raises:
            FileNotFoundError, ValueError: propagated from ``validate_data``.
        """
        print("🚀 Starting LoRA fine-tuning with MLX...")
        print("=" * 60)

        # Validate everything is ready
        train_count = self.validate_data()
        self.config.print_config()

        cmd = self.build_training_command()
        print(f"\n📝 Command: {' '.join(cmd)}")

        start_time = time.time()
        print(f"\n🏃 Training started at {time.strftime('%H:%M:%S')}")
        print(f"📚 Training on {train_count} examples")
        print("💡 This typically takes 3-10 minutes on Apple Silicon M3")
        print(f"⏰ Progress will be reported every {self.config.steps_per_report} steps\n")

        returncode, output = self._run_streaming(cmd)
        training_time = time.time() - start_time

        if returncode != 0:
            print("\n❌ Fine-tuning failed!")
            print(f"Error code: {returncode}")
            print("Error output: see the trainer output above.")
            return False, None

        print("\n" + "=" * 60)
        print("🎉 Fine-tuning completed successfully!")
        print(f"⏱️ Total training time: {training_time:.1f} seconds")
        print(f"💾 Adapters saved to: {self.config.adapter_path}")

        # Save training metadata. Note that lora_rank records the configured
        # value, which may differ from what the trainer used (see
        # build_training_command).
        metadata = {
            'model_name': self.config.base_model,
            'training_time_seconds': training_time,
            'training_examples': train_count,
            'lora_rank': self.config.lora_rank,
            'lora_layers': self.config.lora_layers,
            'batch_size': self.config.batch_size,
            'learning_rate': self.config.learning_rate,
            'max_iters': self.config.max_iters,
            'timestamp': time.time(),
            'command_used': ' '.join(cmd),
        }
        metadata_path = f"{self.config.adapter_path}/training_metadata.json"
        with open(metadata_path, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, indent=2)
        print(f"📊 Training metadata saved to: {metadata_path}")

        self.parse_training_output(output)
        return True, metadata

    def parse_training_output(self, output: str):
        """Print loss-related lines from the trainer output.

        Returns:
            The last training loss found, as a float, or ``None`` if no
            training-loss line was found.
        """
        print("\n📈 Training Progress Summary:")
        print("-" * 40)

        lines = output.splitlines()

        # Loss lines (train and validation), matched case-insensitively so
        # both "Train loss ..." and "Loss: ..." styles are picked up.
        for line in lines:
            lowered = line.lower()
            if 'loss' in lowered or 'validation' in lowered:
                print(f" {line.strip()}")

        # The final training loss is the last line that matches the regex.
        for line in reversed(lines):
            match = _TRAIN_LOSS_RE.search(line)
            if match:
                print(f"\n🎯 Final training loss: {match.group(1)}")
                return float(match.group(1))
        return None

    def verify_training_output(self):
        """Verify that training produced adapter weights.

        Returns:
            ``True`` if at least one ``*.safetensors`` or ``*.npz`` file exists
            in the adapter directory, else ``False``.
        """
        print("\n🔍 Verifying training output...")

        adapter_path = Path(self.config.adapter_path)

        adapter_files = list(adapter_path.glob("*.safetensors")) + list(adapter_path.glob("*.npz"))
        if adapter_files:
            print(f"✅ Found adapter files: {[f.name for f in adapter_files]}")
        else:
            print("❌ No adapter files found")
            return False

        config_file = adapter_path / "adapter_config.json"
        if config_file.exists():
            print(f"✅ Found adapter config: {config_file}")
            with open(config_file, 'r', encoding='utf-8') as f:
                config_data = json.load(f)
            # mlx_lm stores LoRA settings under "lora_parameters"
            # ({"rank", "scale", ...}). PEFT-style configs use top-level
            # "r"/"lora_alpha", so both layouts are supported.
            lora_params = config_data.get('lora_parameters')
            if not isinstance(lora_params, dict):
                lora_params = {}
            print(f" LoRA rank: {lora_params.get('rank', config_data.get('r', 'unknown'))}")
            if 'lora_alpha' in config_data:
                print(f" LoRA alpha: {config_data['lora_alpha']}")
            else:
                print(f" LoRA scale: {lora_params.get('scale', 'unknown')}")
            if 'num_layers' in config_data:
                print(f" LoRA layers: {config_data['num_layers']}")
        else:
            print("⚠️ No adapter config found")

        # Total size of everything in the directory, including checkpoints
        # and metadata.
        total_size = sum(f.stat().st_size for f in adapter_path.rglob('*') if f.is_file())
        print(f"📁 Total adapter size: {total_size / 1e6:.1f} MB")

        return True


def main():
    """Run the full pipeline.

    Returns:
        The training metadata dict, or ``None`` if training failed.
    """
    print("🤖 MLX LoRA Fine-Tuning Pipeline")
    print("=" * 50)

    config = FineTuningConfig()
    fine_tuner = MLXFineTuner(config)

    success, metadata = fine_tuner.run_fine_tuning()

    if success:
        fine_tuner.verify_training_output()
        print("\n✨ Fine-tuning pipeline completed successfully!")
        print("\n🎯 Next steps:")
        print(" 1. Test your fine-tuned model")
        print(" 2. Run evaluation to measure performance")
        print(" 3. Build your application interface")
        return metadata

    print("\n💥 Fine-tuning failed. Please check the error messages above.")
    return None


if __name__ == "__main__":
    metadata = main()
Top comments (0)