This project demonstrates a distributed machine learning system that leverages model and data parallelism across multiple machines. It includes two primary implementations: a CNN-based inference system using model parallelism and a Linear Regression system using data parallelism with Dask.
The architecture is designed to run over a local area network. A master node coordinates tasks, distributes models or data, and aggregates results from multiple worker nodes. Communication is handled using UDP for worker registration and TCP for data exchange.
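A minimal sketch of the worker-side handshake under these conventions, assuming the master listens for UDP registrations on port 5000 and serves pickled payloads over TCP on port 5001 (the ports, the `REGISTER` message, and the helper names are illustrative, not taken from the project):

```python
import pickle
import socket

MASTER_UDP_PORT = 5000  # assumed registration port
MASTER_TCP_PORT = 5001  # assumed data-exchange port

def register_with_master(master_ip: str) -> None:
    """Announce this worker to the master over UDP (connectionless, cheap)."""
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
        sock.sendto(b"REGISTER", (master_ip, MASTER_UDP_PORT))

def fetch_task(master_ip: str):
    """Receive a pickled task (model shard or data chunk) over TCP.

    Assumes the master closes the connection once the payload is sent.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.connect((master_ip, MASTER_TCP_PORT))
        chunks = []
        while True:
            chunk = sock.recv(4096)
            if not chunk:
                break
            chunks.append(chunk)
    return pickle.loads(b"".join(chunks))
```

Unpickling data received from a socket is reasonable on a trusted LAN, which is the setting here, but it should not be exposed to untrusted networks.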
- Distributed CNN model using model parallelism (see the sketch after this list)
- Linear Regression training using Dask and data parallelism
- REST API built with Flask for initiating training and making predictions
- Lightweight and runs on low-spec machines over a LAN
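To make the model-parallel idea concrete, here is a rough sketch that splits a small CNN into two stages that could live on different workers; the layer sizes and the pipeline helper are hypothetical, not the project's actual partitioning:

```python
import torch
import torch.nn as nn

# Stage 1: convolutional feature extractor (would run on worker 1)
stage1 = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2),
)

# Stage 2: classifier head (would run on worker 2)
stage2 = nn.Sequential(
    nn.Flatten(),
    nn.Linear(16 * 14 * 14, 10),
)

def forward_pipeline(x: torch.Tensor) -> torch.Tensor:
    """In the distributed version, the intermediate activation would be
    pickled and shipped to the next worker over TCP rather than passed
    in-process as it is here."""
    activation = stage1(x)     # first machine's share of the forward pass
    return stage2(activation)  # second machine's share

logits = forward_pipeline(torch.randn(1, 1, 28, 28))  # MNIST-sized example input
```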
- Python 3.10
- PyTorch
- Dask
- Flask
- Scikit-learn
- NumPy, Pandas
- TCP/UDP socket communication
- Pickle for serialization
- `master.py` – Controls model/data distribution, API server
- `worker.py` – Handles assigned training or inference tasks
- `flask_server.py` – Hosts Flask routes for `/train` and `/predict`
- `utils/` – Utility functions for serialization and configuration
- `model.pkl` – Sample CNN model (PyTorch)
- `dataset.csv` – Input data for regression
- `requirements.txt` – List of dependencies
- `bankend_81.py` – Distributed Linear Regression model
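For the data-parallel side, here is a minimal sketch of how the regression could be fit on partitioned data with Dask; it uses closed-form least squares (normal equations) for brevity, and the `target` column name is an assumption about `dataset.csv`, not its documented schema:

```python
import numpy as np
import dask.array as da
import dask.dataframe as dd

# Read the CSV in partitions so each worker handles its own chunk.
df = dd.read_csv("dataset.csv")
X = df.drop(columns=["target"]).to_dask_array(lengths=True)  # assumed label column
y = df["target"].to_dask_array(lengths=True)

# Prepend a bias column so the intercept is learned with the weights.
ones = da.ones((X.shape[0], 1), chunks=(X.chunks[0], 1))
Xb = da.concatenate([ones, X], axis=1)

# The Gram-matrix reductions run partition-by-partition across workers;
# only the small (d+1) x (d+1) result travels back to the master.
XtX = (Xb.T @ Xb).compute()
Xty = (Xb.T @ y).compute()

w = np.linalg.solve(XtX, Xty)  # closed-form least squares
print("intercept:", w[0], "weights:", w[1:])
```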
- Use the `/train` API endpoint to start distributed training.
- Use the `/predict` API endpoint to send input data and get predictions.
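A minimal sketch of how `flask_server.py` might wire up those two routes; `run_distributed_training` and `run_prediction` are placeholder stubs standing in for the real master-side logic, and the JSON payload shape and port are assumptions:

```python
from flask import Flask, jsonify, request

app = Flask(__name__)

def run_distributed_training() -> str:
    """Placeholder for kicking off the master's training loop."""
    return "ok"

def run_prediction(features):
    """Placeholder for dispatching features to the workers."""
    return sum(features)  # stand-in computation

@app.route("/train", methods=["POST"])
def train():
    return jsonify({"status": "training started", "detail": run_distributed_training()})

@app.route("/predict", methods=["POST"])
def predict():
    features = request.get_json()["features"]  # assumed payload shape
    return jsonify({"prediction": run_prediction(features)})

if __name__ == "__main__":
    app.run(host="0.0.0.0", port=5002)  # bind to all interfaces for LAN access
```

A client on the same LAN could then call, for example, `requests.post("http://<master-ip>:5002/predict", json={"features": [1.0, 2.0]})`.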