DEV Community

Prashant Nigam


The Magic of LoRA Fine-Tuning with MLX (Part 4)

This is where the magic happens! In this part, we will take a deep dive into LoRA (Low-Rank Adaptation) fine-tuning and use MLX to train our model efficiently on Apple Silicon.

Understanding LoRA: The Game-Changing Technique

Imagine you are a master chef who wants to learn a new cuisine. Instead of forgetting everything you know and starting from scratch, you add new techniques and flavor profiles to your existing knowledge. That's exactly what LoRA (Low-Rank Adaptation) does for language models.

The Traditional Fine-Tuning Problem

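Traditional (full) fine-tuning updates all 1.7 billion parameters of our model. This means:

  • ❌ Massive memory requirements (gradients and optimizer state for every parameter, on top of the weights themselves)
  • ❌ Slow training
  • ❌ Risk of "catastrophic forgetting" (losing general knowledge)
  • ❌ Large model files (a full copy of the model for every fine-tuned variant)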

The LoRA Solution

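LoRA keeps the original model frozen and trains small "adapters" (pairs of low-rank matrices attached to selected weight matrices) that learn the new behavior:

  • ✅ Much lower memory usage (gradients and optimizer state are needed only for the adapters)
  • ✅ Faster training
  • ✅ Less risk of forgetting general knowledge, because the base weights never change
  • ✅ Tiny adapter files compared with a full model copy
  • ✅ Adapters can be combined or switched out easily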
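To put "minimal memory" and "tiny adapter" into numbers, here is a back-of-the-envelope sketch that counts trainable parameters for one weight matrix. It assumes a single square 2048 × 2048 projection (an illustrative size, not a measured property of our model) and the rank of 16 used in our configuration; the helper names are hypothetical.

```python
def full_params(d_out: int, d_in: int) -> int:
    """Trainable parameters when a d_out x d_in weight matrix is fully fine-tuned."""
    return d_out * d_in


def lora_params(d_out: int, d_in: int, rank: int) -> int:
    """Trainable parameters for a LoRA adapter: B (d_out x rank) plus A (rank x d_in)."""
    return d_out * rank + rank * d_in


if __name__ == "__main__":
    d = 2048  # illustrative hidden size for a square projection matrix
    full = full_params(d, d)
    lora = lora_params(d, d, rank=16)
    print(f"Full fine-tuning: {full:,} trainable parameters")
    print(f"LoRA (r=16):      {lora:,} trainable parameters ({lora / full:.2%})")
```

For that matrix, LoRA trains 65,536 parameters instead of 4,194,304, about 1.6% of the original. Gradients and optimizer state shrink by the same factor, which is where most of the memory savings come from.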

How LoRA Works Under the Hood

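Think of the original model as a Swiss Army knife with all its tools welded in place. LoRA adds new attachments that can be snapped on or off.

Concretely, LoRA leaves each chosen weight matrix W (shape d × k) frozen and learns two small matrices alongside it: B (d × r) and A (r × k), where the rank r is much smaller than d and k. The layer then behaves as if its weight were W + (α / r) · BA, where α (lora_alpha) scales the update. Only A and B are trained, and B starts at zero, so training begins from exactly the original model. Because BA has the same shape as W, the adapter can either stay separate (the snap-on attachment) or be merged into W after training.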
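A toy, pure-Python sketch of the LoRA update for a single linear layer helps make the formula concrete. It is not MLX's implementation; `LoRALinear`, `matvec`, and `matmul` are hypothetical names, and the Gaussian initialization of A is one common choice assumed here for illustration.

```python
import random
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


class LoRALinear:
    """Toy LoRA layer: y = W x + (alpha / r) * B (A x), with W frozen."""

    def __init__(self, weight: Matrix, rank: int, alpha: float, seed: int = 0):
        rng = random.Random(seed)
        d_out, d_in = len(weight), len(weight[0])
        self.weight = weight       # frozen base weight, never updated
        self.scale = alpha / rank  # alpha / r scaling of the update
        # A (r x d_in) starts small and random; B (d_out x r) starts at zero,
        # so the layer initially behaves exactly like the base layer
        self.A = [[rng.gauss(0.0, 0.01) for _ in range(d_in)] for _ in range(rank)]
        self.B = [[0.0] * rank for _ in range(d_out)]

    def __call__(self, x: Vector) -> Vector:
        base = matvec(self.weight, x)               # W x
        update = matvec(self.B, matvec(self.A, x))  # B (A x): the full-size BA is never formed
        return [b + self.scale * u for b, u in zip(base, update)]

    def merged_weight(self) -> Matrix:
        """Fold the adapter into the base weight: W + (alpha / r) * B A."""
        delta = matmul(self.B, self.A)
        return [[w + self.scale * d for w, d in zip(w_row, d_row)]
                for w_row, d_row in zip(self.weight, delta)]


if __name__ == "__main__":
    W = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
    layer = LoRALinear(W, rank=1, alpha=2.0)
    x = [1.0, 0.0, -1.0]
    print(layer(x) == matvec(W, x))  # B is zero, so output matches the frozen layer
    layer.B = [[0.5], [-0.5]]        # pretend training has updated B
    print(layer(x), matvec(layer.merged_weight(), x))  # adapter vs. merged weight
```

Computing B (A x) instead of forming BA keeps the extra work proportional to the small rank r, while `merged_weight` shows why a trained adapter can be folded back into W with no extra cost at inference time.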

MLX: Apple's Secret Weapon for AI

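MLX is an open-source array framework for machine learning, developed by Apple machine learning research and designed specifically for Apple Silicon. It is what makes fine-tuning a 1.7B-parameter model on our own Mac practical and fast.

Why MLX Is Good for Local AI

  1. Unified Memory Architecture: M-series chips share one memory pool between the CPU and GPU, and MLX arrays live in that shared memory, so operations can run on either device without copying data back and forth
  2. Optimized Computation: GPU work runs through Metal kernels tuned for Apple Silicon
  3. Memory Efficiency: the GPU can use most of the machine's RAM, and lazy evaluation means arrays are only computed when their values are actually needed
  4. Python Integration: a NumPy-like Python API that is easy to use while still running fast native code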

Setting Up Our Fine-Tuning Pipeline

Let us build our fine-tuning system step by step, understanding each component.

Step 1: Configuration and Setup

First, let's create a comprehensive configuration system:

touch fine_tuning_config.py

```python
# fine_tuning_config.py
from pathlib import Path

import mlx.core as mx


class FineTuningConfig:
    """Centralized configuration for fine-tuning"""

    def __init__(self):
        # Model configuration
        self.base_model = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
        self.adapter_path = "./adapters/email_sentiment"

        # Data paths (mlx_lm expects train.jsonl and valid.jsonl in the same directory)
        self.train_data_path = "./data/mlx_format/train.jsonl"
        self.valid_data_path = "./data/mlx_format/valid.jsonl"

        # LoRA parameters
        self.lora_layers = 16  # Number of transformer layers to add LoRA to
        self.lora_rank = 16    # The 'r' in LoRA: higher = more capacity, but slower and larger adapters
        self.lora_alpha = 32   # Scaling factor: the LoRA update is scaled by lora_alpha / lora_rank

        # Training parameters
        self.batch_size = 2         # Batch size (reduce if you run out of memory)
        self.learning_rate = 5e-5   # Learning rate
        self.max_iters = 1000       # Maximum training iterations
        self.steps_per_report = 10  # How often to print training loss
        self.steps_per_eval = 200   # How often to run validation
        self.save_every = 400       # How often to save adapter checkpoints

        # Hardware optimization
        self.use_gpu = mx.metal.is_available()
        self.max_sequence_length = 2048

        # Create the adapter output directory
        Path(self.adapter_path).mkdir(parents=True, exist_ok=True)

    def print_config(self):
        """Print current configuration"""
        print("🔧 Fine-tuning Configuration:")
        print(f"  Base model: {self.base_model}")
        print(f"  GPU available: {self.use_gpu}")
        print(f"  LoRA rank: {self.lora_rank}")
        print(f"  LoRA alpha: {self.lora_alpha}")
        print(f"  LoRA layers: {self.lora_layers}")
        print(f"  Batch size: {self.batch_size}")
        print(f"  Learning rate: {self.learning_rate}")
        print(f"  Max iterations: {self.max_iters}")
        print(f"  Adapter path: {self.adapter_path}")


# Create and test config
if __name__ == "__main__":
    config = FineTuningConfig()
    config.print_config()
```

Step 2: Memory and Performance Monitoring

Before we start fine-tuning, let's create tools to monitor our system:

touch monitoring.py

```python
# monitoring.py
import time
from typing import Dict, Optional

import mlx.core as mx
import psutil  # pip install psutil


class PerformanceMonitor:
    """Monitor memory usage and training performance.

    Note: MLX memory counters only cover the current Python process. If training
    runs in a separate process (as with the mlx_lm command-line trainer), they do
    not reflect that process's memory.
    """

    def __init__(self):
        self.start_time = time.time()
        self.metrics = []

    def log_memory_usage(self, step: int, loss: Optional[float] = None):
        """Log current memory and performance metrics"""
        # GPU memory (if available); newer MLX releases also expose these
        # counters as mx.get_active_memory() / mx.get_peak_memory()
        gpu_memory = {}
        if mx.metal.is_available():
            gpu_memory = {
                "active_mb": mx.metal.get_active_memory() / 1e6,
                "peak_mb": mx.metal.get_peak_memory() / 1e6,
            }

        # System memory
        system_memory = psutil.virtual_memory()

        # Training metrics
        elapsed = time.time() - self.start_time
        metrics = {
            "step": step,
            "elapsed_seconds": elapsed,
            "loss": loss,
            "gpu_active_mb": gpu_memory.get("active_mb", 0),
            "gpu_peak_mb": gpu_memory.get("peak_mb", 0),
            "system_memory_percent": system_memory.percent,
            "system_memory_available_gb": system_memory.available / 1e9,
        }
        self.metrics.append(metrics)

        if step % 50 == 0:  # Print every 50 steps
            self.print_status(metrics)

        return metrics

    def print_status(self, metrics: Dict):
        """Print current training status"""
        # loss is optional, so avoid formatting None as a float
        loss = metrics["loss"]
        loss_str = f"{loss:.4f}" if loss is not None else "n/a"
        print(f"Step {metrics['step']:4d} | "
              f"Loss: {loss_str} | "
              f"GPU: {metrics['gpu_active_mb']:.0f}MB | "
              f"Time: {metrics['elapsed_seconds']:.1f}s")

    def get_training_summary(self):
        """Get summary of training run"""
        if not self.metrics:
            return {}

        return {
            "total_training_time": self.metrics[-1]["elapsed_seconds"],
            "peak_gpu_memory_mb": max(m["gpu_peak_mb"] for m in self.metrics),
            "final_loss": self.metrics[-1]["loss"],
            # Last logged step (the number of log entries can be smaller)
            "steps_completed": self.metrics[-1]["step"],
        }
```
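Because Step 3 launches training in a separate mlx_lm process, the MLX memory counters in `PerformanceMonitor` cannot observe that run. One hedged alternative is to pull loss values out of the trainer's printed progress lines. The sketch below assumes lines shaped like `Iter 10: Train loss 2.345, ...` and `Iter 200: Val loss 1.876, ...`; check the exact format your mlx_lm version prints. `parse_losses` is a hypothetical helper, not part of MLX, and the sample log lines are illustrative.

```python
import re
from typing import Dict, Iterable, List, Tuple

# Assumed shape of mlx_lm LoRA progress lines (check your version's output):
#   "Iter 10: Train loss 2.345, Learning Rate 5.000e-05, ..."
#   "Iter 200: Val loss 1.876, Val took 3.210s"
LOSS_PATTERN = re.compile(r"Iter (\d+): (Train|Val) loss (\d+(?:\.\d+)?)")


def parse_losses(lines: Iterable[str]) -> Dict[str, List[Tuple[int, float]]]:
    """Collect (iteration, loss) pairs for training and validation from log lines."""
    history: Dict[str, List[Tuple[int, float]]] = {"train": [], "val": []}
    for line in lines:
        match = LOSS_PATTERN.search(line)
        if match:
            step, kind, loss = match.groups()
            history[kind.lower()].append((int(step), float(loss)))
    return history


if __name__ == "__main__":
    log = [  # illustrative sample lines
        "Iter 1: Val loss 2.500, Val took 1.200s",
        "Iter 10: Train loss 2.345, Learning Rate 5.000e-05",
        "Iter 20: Train loss 1.900, Learning Rate 5.000e-05",
    ]
    history = parse_losses(log)
    print(history["train"][-1])  # latest (iteration, training loss)
```

Lines that do not match the pattern (model loading messages, warnings) are simply ignored, so the parser can be fed the trainer's whole output.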

Step 3: The Fine-Tuning Engine

Now let's create our main fine-tuning script using MLX-LM:

touch fine_tune_model.py

```python
# fine_tune_model.py
import json
import os
import subprocess
import sys
import time
from pathlib import Path

from fine_tuning_config import FineTuningConfig
from monitoring import PerformanceMonitor


class MLXFineTuner:
    """Fine-tune models using MLX with LoRA (via the mlx_lm command-line trainer)"""

    def __init__(self, config: FineTuningConfig):
        self.config = config
        # Training runs in a child process, so this monitor only sees this process
        self.monitor = PerformanceMonitor()

    def validate_data(self):
        """Validate that training data exists and is properly formatted"""
        print("📊 Validating training data...")

        # mlx_lm needs both train.jsonl and valid.jsonl in the data directory
        for path in (self.config.train_data_path, self.config.valid_data_path):
            if not os.path.exists(path):
                raise FileNotFoundError(f"Data file not found: {path}")

        # Count training examples (one JSON object per non-empty line)
        with open(self.config.train_data_path, "r") as f:
            train_count = sum(1 for line in f if line.strip())
        print(f"✅ Found {train_count} training examples")

        # Validate the format of the first example
        with open(self.config.train_data_path, "r") as f:
            first_line = f.readline()
        try:
            example = json.loads(first_line)
        except json.JSONDecodeError:
            raise ValueError("Training data must be valid JSONL format")
        if "text" not in example:
            raise ValueError("Training data must have a 'text' field")
        print("✅ Data format validated")

        return train_count

    def write_lora_config(self):
        """Write the LoRA rank/scale to a YAML file for mlx_lm's --config option.

        The mlx_lm trainer reads LoRA hyperparameters from the 'lora_parameters'
        section of a config file rather than from command-line flags. Its 'scale'
        multiplies the low-rank update directly, so lora_alpha / lora_rank is
        passed to get the usual alpha / r scaling.
        """
        config_path = Path(self.config.adapter_path) / "lora_config.yaml"
        scale = self.config.lora_alpha / self.config.lora_rank
        config_path.write_text(
            "lora_parameters:\n"
            f"  rank: {self.config.lora_rank}\n"
            f"  scale: {scale}\n"
            "  dropout: 0.0\n"
        )
        return str(config_path)

    def build_training_command(self, lora_config_path: str):
        """Build the MLX-LM training command"""
        data_dir = str(Path(self.config.train_data_path).parent)  # Holds train.jsonl and valid.jsonl
        cmd = [
            sys.executable, "-m", "mlx_lm", "lora",  # Same Python environment as this script
            "--model", self.config.base_model,
            "--train",
            "--data", data_dir,
            "--config", lora_config_path,
            "--num-layers", str(self.config.lora_layers),
            "--batch-size", str(self.config.batch_size),
            "--iters", str(self.config.max_iters),
            "--learning-rate", str(self.config.learning_rate),
            "--steps-per-report", str(self.config.steps_per_report),
            "--steps-per-eval", str(self.config.steps_per_eval),
            "--adapter-path", self.config.adapter_path,
            "--save-every", str(self.config.save_every),
            "--max-seq-length", str(self.config.max_sequence_length),
        ]
        return cmd

    def run_fine_tuning(self):
        """Execute the fine-tuning process"""
        print("🚀 Starting LoRA fine-tuning with MLX...")
        print("=" * 60)

        # Validate everything is ready
        train_count = self.validate_data()
        self.config.print_config()

        # Build command
        lora_config_path = self.write_lora_config()
        cmd = self.build_training_command(lora_config_path)
        print(f"\n📝 Command: {' '.join(cmd)}")

        # Start training
        start_time = time.time()
        print(f"\n🏃 Training started at {time.strftime('%H:%M:%S')}")
        print(f"📚 Training on {train_count} examples")
        print("💡 This typically takes 3-10 minutes on Apple Silicon M3")
        print(f"⏰ Progress will be reported every {self.config.steps_per_report} steps\n")

        # Run the trainer, echoing its output live while keeping a copy to parse later.
        # PYTHONUNBUFFERED makes the child flush each line instead of buffering it.
        process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # merge errors into the same stream
            text=True,
            env={**os.environ, "PYTHONUNBUFFERED": "1"},
        )
        output_lines = []
        for line in process.stdout:
            print(line, end="")
            output_lines.append(line)
        return_code = process.wait()

        if return_code != 0:
            print("\n❌ Fine-tuning failed!")
            print(f"Error code: {return_code} (see the output above for details)")
            return False, None

        training_time = time.time() - start_time
        print("\n" + "=" * 60)
        print("🎉 Fine-tuning completed successfully!")
        print(f"⏱️ Total training time: {training_time:.1f} seconds")
        print(f"💾 Adapters saved to: {self.config.adapter_path}")

        # Save training metadata
        metadata = {
            "model_name": self.config.base_model,
            "training_time_seconds": training_time,
            "training_examples": train_count,
            "lora_rank": self.config.lora_rank,
            "lora_alpha": self.config.lora_alpha,
            "lora_layers": self.config.lora_layers,
            "batch_size": self.config.batch_size,
            "learning_rate": self.config.learning_rate,
            "max_iters": self.config.max_iters,
            "timestamp": time.time(),
            "command_used": " ".join(cmd),
        }
        metadata_path = Path(self.config.adapter_path) / "training_metadata.json"
        with open(metadata_path, "w") as f:
            json.dump(metadata, f, indent=2)
        print(f"📊 Training metadata saved to: {metadata_path}")

        # Parse and display training output
        self.parse_training_output("".join(output_lines))

        return True, metadata

    def parse_training_output(self, output: str):
        """Summarize the loss lines from the training output.

        mlx_lm reports lines such as "Iter 10: Train loss 2.345, ..." and
        "Iter 200: Val loss 1.876, ..." (the exact format may vary by version).
        """
        print("\n📈 Training Progress Summary:")
        print("-" * 40)

        lines = output.splitlines()

        # Look for key training metrics
        for line in lines:
            if "Train loss" in line or "Val loss" in line:
                print(f"  {line.strip()}")

        # Look for the last reported training loss
        for line in reversed(lines):
            if "Train loss" in line:
                final_loss = line.split("Train loss")[-1].split(",")[0].strip()
                print(f"\n🎯 Final training loss: {final_loss}")
                break

    def verify_training_output(self):
        """Verify that training produced the expected files"""
        print("\n🔍 Verifying training output...")

        adapter_path = Path(self.config.adapter_path)

        # Check for adapter files
        adapter_files = list(adapter_path.glob("*.safetensors")) + list(adapter_path.glob("*.npz"))
        if adapter_files:
            print(f"✅ Found adapter files: {[f.name for f in adapter_files]}")
        else:
            print("❌ No adapter files found")
            return False

        # Check for configuration
        config_file = adapter_path / "adapter_config.json"
        if config_file.exists():
            print(f"✅ Found adapter config: {config_file}")
            with open(config_file, "r") as f:
                config_data = json.load(f)
            # mlx_lm stores its LoRA settings under 'lora_parameters'
            lora_params = config_data.get("lora_parameters", {})
            print(f"  LoRA rank: {lora_params.get('rank', 'unknown')}")
            print(f"  LoRA scale: {lora_params.get('scale', 'unknown')}")
            print(f"  LoRA layers: {config_data.get('num_layers', 'unknown')}")
        else:
            print("⚠️ No adapter config found")

        # Calculate total size
        total_size = sum(f.stat().st_size for f in adapter_path.rglob("*") if f.is_file())
        print(f"📁 Total adapter size: {total_size / 1e6:.1f} MB")

        return True


def main():
    """Main fine-tuning execution"""
    print("🤖 MLX LoRA Fine-Tuning Pipeline")
    print("=" * 50)

    # Create configuration and fine-tuner
    config = FineTuningConfig()
    fine_tuner = MLXFineTuner(config)

    # Run fine-tuning
    success, metadata = fine_tuner.run_fine_tuning()

    if success:
        # Verify output
        fine_tuner.verify_training_output()
        print("\n✨ Fine-tuning pipeline completed successfully!")
        print("\n🎯 Next steps:")
        print("  1. Test your fine-tuned model")
        print("  2. Run evaluation to measure performance")
        print("  3. Build your application interface")
        return metadata

    print("\n💥 Fine-tuning failed. Please check the error messages above.")
    return None


if __name__ == "__main__":
    metadata = main()
```
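`validate_data` only parses the first line of train.jsonl, so a malformed line further down only surfaces once the trainer loads the file. A minimal pure-Python sketch that checks every line before training starts, assuming the same required 'text' field; `validate_lines` and `validate_jsonl` are hypothetical helpers, not part of MLX.

```python
import json
from typing import Iterable, List, Tuple


def validate_lines(lines: Iterable[str], required_key: str = "text") -> Tuple[int, List[str]]:
    """Check every non-empty JSONL line; return (valid_count, error_messages)."""
    valid, errors = 0, []
    for line_no, line in enumerate(lines, start=1):
        if not line.strip():
            continue  # blank lines are skipped, as in validate_data's count
        try:
            example = json.loads(line)
        except json.JSONDecodeError as exc:
            errors.append(f"line {line_no}: invalid JSON ({exc.msg})")
            continue
        if not isinstance(example, dict) or required_key not in example:
            errors.append(f"line {line_no}: missing '{required_key}' field")
            continue
        valid += 1
    return valid, errors


def validate_jsonl(path: str, required_key: str = "text") -> Tuple[int, List[str]]:
    """Validate a JSONL file on disk line by line."""
    with open(path, "r", encoding="utf-8") as f:
        return validate_lines(f, required_key)


if __name__ == "__main__":
    sample = ['{"text": "Great news!"}', "", "{not json}", '{"prompt": "hi"}']
    count, problems = validate_lines(sample)
    print(count, problems)
```

Calling `validate_jsonl` on both train.jsonl and valid.jsonl, and refusing to start when the error list is non-empty, reports every bad line at once with its line number.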
