This project implements a federated learning framework using PyTorch on the CIFAR-10 dataset. It simulates a scenario in which multiple clients (e.g., devices in a distributed system) collaboratively train a global neural network model without sharing their local data. Because raw data never leaves the clients, this approach reduces the exposure of sensitive data while still leveraging decentralized data sources.
- Federated Learning: A machine learning technique where multiple decentralized devices collaboratively train a model without sharing raw data.
- PyTorch: An open-source machine learning library used for training the neural network model.
- CIFAR-10: A widely used image-classification dataset of 60,000 32x32 color images in 10 classes (50,000 for training, 10,000 for testing).
- Multithreading and Parallel Processing: Techniques to handle concurrent training of client models.
- Logging and Exception Handling: To ensure robust and traceable code execution.
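To illustrate how concurrent client training can work, here is a minimal sketch using Python's `concurrent.futures.ThreadPoolExecutor`. It assumes each client trains its own deep copy of the global model with SGD and cross-entropy loss; the function names `train_on_client` and `train_clients_in_parallel`, the optimizer, and the learning rate are illustrative assumptions, not the project's actual `FederatedLearning` code.

```python
import copy
from concurrent.futures import ThreadPoolExecutor

import torch
import torch.nn as nn


def train_on_client(model, loader, epochs=1, lr=0.01):
    """Train `model` in place on one client's DataLoader and return its state_dict."""
    model.train()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    criterion = nn.CrossEntropyLoss()
    for _ in range(epochs):
        for inputs, labels in loader:
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()
    return model.state_dict()


def train_clients_in_parallel(global_model, client_loaders, epochs=1, max_workers=4):
    """Train one copy of `global_model` per client concurrently; results keep client order."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [
            # Deep-copy in the main thread so each worker owns an independent model
            # and the shared global model is never mutated.
            pool.submit(train_on_client, copy.deepcopy(global_model), loader, epochs)
            for loader in client_loaders
        ]
        return [future.result() for future in futures]
```

Threads are a reasonable fit here because PyTorch releases the GIL inside most tensor operations; for heavily CPU-bound workloads, a process pool may scale better.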
- Python 3.8 or higher
- PyTorch
- Torchvision
- NumPy
- Matplotlib
- tqdm
- Clone the repository:
  git clone https://github.com/yourusername/federated-learning-pytorch.git
  cd federated-learning-pytorch
- Create a virtual environment:
  python3 -m venv venv
  source venv/bin/activate
- Install the required packages:
  pip install -r requirements.txt
- Download and preprocess the CIFAR-10 dataset.
- Train the federated learning model:
  python src/main.py --num_clients 10 --epochs 20 --max_workers 10 --batch_size 4
- Evaluate the trained model on the test dataset.
src/
├── __init__.py
├── main.py
├── data.py
├── model.py
├── federated.py
└── utils.py
requirements.txt
README.md
main.py: The main script that orchestrates the federated learning process.
- Arguments:
  - --num_clients: Number of clients to simulate.
  - --epochs: Number of training epochs.
  - --max_workers: Maximum number of parallel workers.
  - --batch_size: Batch size for the DataLoader.
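A minimal `argparse` sketch of how `main.py` might declare these arguments. The defaults are assumptions taken from the example command (10 clients, 20 epochs, 10 workers, batch size 4); the real script may use different defaults or additional options.

```python
import argparse


def parse_args(argv=None):
    """Parse the command-line options; `argv=None` means read from sys.argv."""
    parser = argparse.ArgumentParser(description="Federated learning on CIFAR-10")
    # Defaults below are assumed from the example command, not read from main.py.
    parser.add_argument("--num_clients", type=int, default=10, help="Number of clients to simulate")
    parser.add_argument("--epochs", type=int, default=20, help="Number of training epochs")
    parser.add_argument("--max_workers", type=int, default=10, help="Maximum number of parallel workers")
    parser.add_argument("--batch_size", type=int, default=4, help="Batch size for the DataLoader")
    return parser.parse_args(argv)
```

Passing an explicit list, e.g. `parse_args(["--num_clients", "5"])`, makes the parser easy to exercise without touching `sys.argv`.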
data.py: Handles downloading, preprocessing, and splitting the dataset.
- Classes and Methods:
  - DataLoaderWrapper.download_dataset(): Downloads and transforms the CIFAR-10 dataset.
  - DataLoaderWrapper.split_dataset(dataset, num_clients): Splits the dataset into one subset per client.
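One plausible implementation of the split step, assuming an IID split: the dataset is shuffled and cut into `num_clients` disjoint shards of near-equal size with `torch.utils.data.random_split`. The `seed` parameter is a hypothetical addition for reproducibility; the project's `split_dataset` may partition the data differently (for example, non-IID by label).

```python
import torch
from torch.utils.data import random_split


def split_dataset(dataset, num_clients, seed=0):
    """Randomly partition `dataset` into `num_clients` disjoint, near-equal shards (IID)."""
    base, remainder = divmod(len(dataset), num_clients)
    # The first `remainder` clients get one extra sample so the sizes sum to len(dataset).
    lengths = [base + 1 if i < remainder else base for i in range(num_clients)]
    generator = torch.Generator().manual_seed(seed)
    return random_split(dataset, lengths, generator=generator)
```

For CIFAR-10's 50,000 training images and 10 clients, each client would receive 5,000 samples.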
model.py: Defines the neural network model used for classification.
- Classes:
  - Net: A convolutional neural network with two convolutional layers and three fully connected layers.
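A sketch of a network matching this description, assuming it mirrors the small CNN from the official PyTorch CIFAR-10 tutorial; the channel counts and layer widths below are assumptions, not read from `model.py`. Each 5x5 convolution is followed by ReLU and 2x2 max pooling, so a 3x32x32 input shrinks to 16x5x5 before the fully connected layers.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Two conv layers + three fully connected layers for 10-class CIFAR-10 images."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)      # 3x32x32 -> 6x28x28
        self.pool = nn.MaxPool2d(2, 2)       # halves height and width
        self.conv2 = nn.Conv2d(6, 16, 5)     # 6x14x14 -> 16x10x10
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)         # one logit per CIFAR-10 class

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # -> 6x14x14
        x = self.pool(F.relu(self.conv2(x)))  # -> 16x5x5
        x = torch.flatten(x, 1)               # keep the batch dimension
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
```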
federated.py: Contains the federated learning logic, including local model training and federated averaging.
- Methods:
  - FederatedLearning.train_local_model(client_id, client_loader, net, epochs, global_progress): Trains a local model on one client's data.
  - FederatedLearning.federated_averaging(global_model, client_models): Averages the model parameters from all clients into the global model.
  - FederatedLearning.train_client(client_id, client_loader, global_model_state_dict, results, epochs, global_progress): Trains a single client's model; invoked in parallel across clients.
  - FederatedLearning.train_federated(trainset, num_clients, epochs, path, max_workers, batch_size): Orchestrates federated learning across all clients.
  - FederatedLearning.check(testloader, path): Evaluates the global model on the test set.
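Federated averaging (FedAvg) replaces each global parameter with the mean of the corresponding client parameters. The sketch below assumes clients return `state_dict`s with the same keys as the global model; the optional `weights` argument (e.g., per-client sample counts, as in the original FedAvg formulation) is an assumption, since the description only says the parameters are averaged.

```python
import torch


def federated_averaging(global_model, client_state_dicts, weights=None):
    """Load the (weighted) mean of the client state_dicts into `global_model`.

    `weights` is an optional per-client weighting, e.g. local sample counts;
    when omitted, every client contributes equally.
    """
    if weights is None:
        weights = [1.0] * len(client_state_dicts)
    total = float(sum(weights))
    reference = global_model.state_dict()
    averaged = {}
    for key, ref_tensor in reference.items():
        # Average in float32, then cast back to the original dtype of this entry.
        mean = sum(sd[key].float() * (w / total) for sd, w in zip(client_state_dicts, weights))
        averaged[key] = mean.to(ref_tensor.dtype)
    global_model.load_state_dict(averaged)
    return global_model
```

With equal weights this reduces to a plain element-wise mean; weighting by sample count keeps clients with more data from being under-represented.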
utils.py: Provides utility functions for logging and exception handling.
- Functions:
  - setup_logging(): Sets up logging to a file and to stdout.
  - handle_exception(exc_type, exc_value, exc_traceback): Handles uncaught exceptions and logs them.
  - capture_warnings(): Captures warnings and routes them to the log.
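A minimal sketch of these three utilities using only the standard library. The log file name, the message format, and the choice to let `KeyboardInterrupt` fall through to the default hook are assumptions; the project's `utils.py` may differ.

```python
import logging
import sys


def setup_logging(log_file="federated.log", level=logging.INFO):
    """Send log records to both a file and stdout (file name is an assumption)."""
    logging.basicConfig(
        level=level,
        format="%(asctime)s %(levelname)s %(threadName)s: %(message)s",
        handlers=[logging.FileHandler(log_file), logging.StreamHandler(sys.stdout)],
        force=True,  # Python 3.8+: replace any handlers configured earlier
    )


def handle_exception(exc_type, exc_value, exc_traceback):
    """Replacement for sys.excepthook that logs uncaught exceptions with a traceback."""
    if issubclass(exc_type, KeyboardInterrupt):
        # Let Ctrl+C behave normally instead of logging it as a crash.
        sys.__excepthook__(exc_type, exc_value, exc_traceback)
        return
    logging.getLogger(__name__).critical(
        "Uncaught exception", exc_info=(exc_type, exc_value, exc_traceback)
    )


def capture_warnings():
    """Route warnings.warn() messages through logging (the 'py.warnings' logger)."""
    logging.captureWarnings(True)
```

In this sketch, the exception hook is activated by assigning `sys.excepthook = handle_exception` once at startup, right after `setup_logging()`. Including `%(threadName)s` in the format makes it easier to tell concurrent client workers apart in the log.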
- Prepare the environment:
  python3 -m venv venv
  source venv/bin/activate
  pip install -r requirements.txt
- Run the main script to start the federated learning process:
  python src/main.py --num_clients 10 --epochs 20 --max_workers 10 --batch_size 4
- Monitor the training progress through the logging output and progress bars.
- Evaluate the trained model on the provided test dataset.
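A sketch of test-set evaluation in the spirit of `FederatedLearning.check(testloader, path)`, assuming `path` points to a saved `state_dict` of the global model and that the metric is top-1 accuracy. The function name `evaluate` is hypothetical.

```python
import torch


def evaluate(model, testloader, path=None):
    """Return top-1 accuracy (in %) of `model` on `testloader`.

    If `path` is given, the weights are first loaded from a saved state_dict.
    """
    if path is not None:
        model.load_state_dict(torch.load(path, map_location="cpu"))
    model.eval()  # disable dropout / batch-norm updates, if any
    correct = total = 0
    with torch.no_grad():  # no gradients needed for evaluation
        for images, labels in testloader:
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100.0 * correct / total
```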