This project implements an Early Exit Convolutional Neural Network (EE-CNN) for efficient image classification on the CIFAR-10 dataset. The model uses reinforcement learning to dynamically decide when to exit the network early, optimizing the trade-off between computational efficiency and accuracy.
The system consists of two main components:
1. Early Exit CNN: A deep neural network with multiple exit points, allowing predictions at different depths:
   - 4 exit points of increasing complexity
   - Enhanced feature extraction at each stage
   - Batch normalization and dropout for regularization
   - Progressive increase in channel dimensions (64→128→256→512)

2. DQN Agent: A reinforcement learning agent that learns when to exit:
   - Makes exit decisions based on confidence scores
   - Balances accuracy and computational efficiency
   - Uses experience replay for stable training
   - Implements epsilon-greedy exploration
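A minimal sketch of how such an agent can make epsilon-greedy exit decisions. The state vector, Q-network shape, and two-action space (continue vs. exit) here are illustrative assumptions, not the repository's exact implementation:

```python
import random
import torch
import torch.nn as nn

class ExitPolicy(nn.Module):
    """Tiny Q-network: maps a confidence-based state to Q-values
    for two actions: 0 = continue to the next stage, 1 = exit now."""
    def __init__(self, state_dim=3, num_actions=2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 32), nn.ReLU(),
            nn.Linear(32, num_actions),
        )

    def forward(self, state):
        return self.net(state)

def select_action(policy, state, epsilon=0.1):
    """Epsilon-greedy: explore with probability epsilon,
    otherwise take the highest-Q action."""
    if random.random() < epsilon:
        return random.randrange(2)
    with torch.no_grad():
        return int(policy(state).argmax().item())

policy = ExitPolicy()
# Hypothetical state: [confidence, exit index / num_exits, entropy]
state = torch.tensor([0.81, 0.25, 0.42])
action = select_action(policy, state, epsilon=0.0)  # greedy choice
```

During training, the transitions `(state, action, reward, next_state)` would be stored in a replay buffer and sampled in minibatches for stable Q-learning updates, as described above.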
Our model achieves strong performance metrics:
- Overall Accuracy: 88.77%
- Compute Savings: 39.9%
- Effectiveness Score: 35.38
Exit Point Distribution:
- Exit 1: 4.8%
- Exit 2: 37.5%
- Exit 3: 30.3%
- Exit 4: 27.4%
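The compute savings can be approximated from this exit distribution as one minus the expected fraction of the network executed. The per-exit cost fractions below are assumptions for illustration, not measured FLOPs, which is why the result differs from the reported 39.9%:

```python
# Observed exit distribution from the evaluation above.
exit_probs = [0.048, 0.375, 0.303, 0.274]

# Assumed relative cost of stopping at each exit (fraction of a
# full forward pass); the true values depend on per-stage FLOPs.
exit_costs = [0.20, 0.45, 0.70, 1.00]

expected_cost = sum(p * c for p, c in zip(exit_probs, exit_costs))
savings = 1.0 - expected_cost  # ≈ 33.6% under these assumed costs
print(f"Expected cost: {expected_cost:.3f}, savings: {savings:.1%}")
```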
Per-Class Performance:
- Best Classes: Ship (94.3%), Car (94.2%), Truck (93.7%)
- Most Challenging: Dog (78.4%), Cat (80.3%), Bird (80.0%)
Per-Class Exit Performance:
First Exit (≈67% accuracy):
- Strong: Truck (81.4%), Horse (71.7%), Plane (75.4%)
- Weak: Bird (40.7%), Cat (42.1%)
Second Exit (≈78% accuracy):
- Strong: Ship (91.3%), Truck (92.4%), Car (90.8%)
- Weak: Bird (71.5%), Cat (70.9%)
Third Exit (≈88% accuracy):
- Strong: Ship (94.5%), Car (94.5%), Plane (92.5%)
- Weak: Dog (79.7%), Cat (79.3%)
Final Exit (≈89% accuracy):
- Strong: Car (94.6%), Ship (94.4%), Truck (93.7%)
- Weak: Dog (78.4%), Bird (80.1%)
Confidence Statistics:
- Exit 1: Mean=0.673, Std=0.211
- Exit 2: Mean=0.807, Std=0.206
- Exit 3: Mean=0.838, Std=0.184
- Exit 4: Mean=0.706, Std=0.181
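The confidence statistics above are presumably the maximum softmax probability at each exit head; a sketch of how such per-exit statistics could be collected (the batch size and class count are assumptions):

```python
import torch
import torch.nn.functional as F

def confidence_stats(logits):
    """logits: (batch, num_classes) raw outputs of one exit head.
    Returns mean and std of the max softmax probability, the usual
    'confidence' signal driving early-exit decisions."""
    probs = F.softmax(logits, dim=1)
    conf = probs.max(dim=1).values
    return conf.mean().item(), conf.std().item()

logits = torch.randn(256, 10)  # e.g. one batch at one exit head
mean, std = confidence_stats(logits)
```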
Below are some visualizations of the results from the experiments:

Per-Class Performance: (figure)

Training Progress: (figure)
Project structure:

```
src/
├── models/                 # Model architectures
│   ├── early_exit_cnn.py   # CNN implementation
│   ├── dqn_agent.py        # DQN agent
│   └── environment.py      # Training environment
├── training/               # Training implementations
│   ├── train_cnn.py        # CNN training
│   └── train_rl.py         # RL training
├── evaluation/             # Evaluation code
│   └── evaluate.py         # Evaluation metrics
├── inference/              # Inference implementation
│   └── inference.py        # Inference code
└── visualization/          # Visualization tools
    └── visualize.py        # Plotting functions
```

1. Clone the repository:

```bash
git clone https://github.com/Shikha-code36/early-exit-cnn.git
cd early-exit-cnn
```

2. Install dependencies:

```bash
pip install -r requirements.txt
```

3. Train the CNN model:

```python
from src.training.train_cnn import pretrain_cnn
from src.data.data_loader import load_cifar10_data

# Load data
train_loader, test_loader = load_cifar10_data(batch_size=128)

# Train model
losses, accuracies = pretrain_cnn(model, train_loader, num_epochs=50)
```

4. Train the RL agent:

```python
from src.training.train_rl import train_rl_agent

rewards, exit_counts = train_rl_agent(
    model, agent, env, train_loader, num_episodes=5000
)
```

Run inference on new images:

```python
from src.inference.inference import EarlyExitInference

# Initialize inference
inferencer = EarlyExitInference(model_path='models/')

# Process image
result = inferencer.process_image("path/to/image.jpg")
print(f"Prediction: {result['class']}")
print(f"Confidence: {result['confidence']:.2f}")
print(f"Exit Point: {result['exit_point']}")
```

Evaluate model performance:

```python
from src.evaluation.evaluate import evaluate_model

metrics = evaluate_model(model, agent, test_loader)
print(f"Accuracy: {metrics['accuracy']:.2f}%")
print(f"Compute Saved: {metrics['compute_saved']:.2f}%")
```

The Early Exit CNN employs a progressive architecture:
1. First Exit (64 channels):
   - Basic feature extraction
   - Early exit for simple cases
   - 68.4% accuracy for easy classes

2. Second Exit (128 channels):
   - Intermediate processing
   - Improved feature representation
   - 86.8% accuracy for moderate cases

3. Third Exit (256 channels):
   - Advanced feature processing
   - Enhanced classification capability
   - 92.9% accuracy for complex cases

4. Final Exit (512 channels):
   - Deep feature extraction
   - Comprehensive classification
   - 92.2% accuracy for challenging cases
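The progressive architecture above can be sketched as a CNN with one classifier head per stage. The channel progression (64→128→256→512) and the use of batch normalization follow the description in this README, but the exact block composition and head design are assumptions:

```python
import torch
import torch.nn as nn

def conv_block(in_ch, out_ch):
    """Conv → BN → ReLU → pool: one stage of the backbone."""
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, 3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2),
    )

class EarlyExitSketch(nn.Module):
    """Illustrative four-stage CNN with a classifier head at each
    stage; not the repository's exact early_exit_cnn.py."""
    def __init__(self, num_classes=10):
        super().__init__()
        chans = [3, 64, 128, 256, 512]
        self.stages = nn.ModuleList(
            [conv_block(chans[i], chans[i + 1]) for i in range(4)]
        )
        # One head per stage: global average pool + linear classifier.
        self.heads = nn.ModuleList(
            [nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(),
                           nn.Linear(c, num_classes))
             for c in chans[1:]]
        )

    def forward(self, x):
        """Returns logits from all four exits; at inference time the
        agent (or a confidence threshold) picks the first good exit."""
        outputs = []
        for stage, head in zip(self.stages, self.heads):
            x = stage(x)
            outputs.append(head(x))
        return outputs

model = EarlyExitSketch()
exits = model(torch.randn(2, 3, 32, 32))  # CIFAR-10-sized input
```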
The project includes several visualization tools:
1. Training Progress:
   - Loss curves
   - Accuracy per exit point
   - Exit distribution

2. Analysis Tools:
   - Confidence distributions
   - Class-wise exit patterns
   - Performance heatmaps
Example visualization code:
```python
from src.visualization.visualize import plot_training_metrics

plot_training_metrics(
    train_losses=losses,
    accuracies_per_exit=accuracies,
    exit_distributions=exit_dist
)
```

Contributions are welcome! Please feel free to submit a Pull Request.
This project is licensed under the MIT License - see the LICENSE file for details.
If you use this code in your research, please cite:
```bibtex
@misc{early-exit-cnn-2024,
  author    = {Shikha Pandey},
  title     = {Early Exit CNN with RL-based Decision Making},
  year      = {2025},
  publisher = {GitHub},
  url       = {https://github.com/Shikha-code36/early-exit-cnn}
}
```

- CIFAR-10 dataset
- PyTorch team for the deep learning framework
- Reinforcement learning community for DQN implementations