Paper Code 1

The document outlines a federated learning framework that trains a convolutional neural network (CNN) on the MNIST dataset across multiple devices. Devices are selected for each training round by round-robin scheduling; every selected device trains a representation and a classifier locally, after which the classifier states are averaged across devices (a FedRep-style setup). Performance is evaluated over multiple rounds, and the results are plotted to show accuracy improving over time.
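
The aggregation step at the heart of the script is a plain elementwise average of the selected devices' classifier parameters. A minimal standalone sketch of that operation (the tensor shapes and key name here are chosen for illustration, not taken from the script's actual model):

import torch

# Two toy classifier states, as if returned by two devices
states = [{"weight": torch.full((2, 2), 1.0)},
          {"weight": torch.full((2, 2), 3.0)}]

# Average each parameter tensor elementwise across devices
avg = {key: sum(s[key] for s in states) / len(states) for key in states[0]}
print(avg["weight"])  # every entry is 2.0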


import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt

# Configuration
NUM_DEVICES = 50
EPOCHS = 2
ROUNDS = 100
BATCH_SIZE = 64
LEARNING_RATE = 0.01
DEVICE_PER_ROUND = 10  # devices selected per round (round robin)

# Energy, Bandwidth, SNR initialization
device_energy = np.random.uniform(30, 100, NUM_DEVICES)
device_bandwidth = np.random.uniform(1.0, 5.0, NUM_DEVICES)
device_snr = np.random.uniform(10, 40, NUM_DEVICES)

# Dataset loading (split evenly among devices)
transform = transforms.Compose([transforms.ToTensor()])
full_dataset = datasets.MNIST('./data', train=True, download=True,
                              transform=transform)
# random_split gives each device an IID shard of 60000 / NUM_DEVICES = 1200 images
datasets_split = torch.utils.data.random_split(
    full_dataset, [len(full_dataset) // NUM_DEVICES] * NUM_DEVICES)

# CNN model with a separable representation and classifier (for FedRep)
class CNN_FedRep(nn.Module):
    def __init__(self):
        super(CNN_FedRep, self).__init__()
        self.representation = nn.Sequential(
            nn.Conv2d(1, 10, 5),  # 28x28 -> 24x24, 10 channels
            nn.ReLU(),
            nn.MaxPool2d(2),      # 24x24 -> 12x12
            nn.Flatten()          # 10 * 12 * 12 = 1440 features
        )
        self.classifier = nn.Linear(1440, 10)

    def forward(self, x):
        rep = self.representation(x)
        return self.classifier(rep)

# Initialize models and optimizers
models = [CNN_FedRep() for _ in range(NUM_DEVICES)]
optimizers_rep = [optim.SGD(model.representation.parameters(), lr=LEARNING_RATE)
                  for model in models]
optimizers_cls = [optim.SGD(model.classifier.parameters(), lr=LEARNING_RATE)
                  for model in models]
loss_fn = nn.CrossEntropyLoss()

# Training function for FedRep: updates the representation and, when
# train_classifier=True, the classifier head as well
def train_rep(model, data_loader, optimizer_rep, optimizer_cls=None,
              train_classifier=False):
    model.train()
    for epoch in range(EPOCHS):
        for data, target in data_loader:
            if train_classifier:
                optimizer_cls.zero_grad()
            optimizer_rep.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            optimizer_rep.step()
            if train_classifier:
                optimizer_cls.step()
    # Only the classifier state is returned for aggregation
    return model.classifier.state_dict()

# Model averaging (FedRep: only the classifier is aggregated)
def average_classifiers(classifier_states):
    new_state = {}
    for key in classifier_states[0].keys():
        new_state[key] = sum(state[key] for state in classifier_states) / len(classifier_states)
    return new_state
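
# Note: this assumes every value in the classifier state dict is a float
# tensor (true for nn.Linear's weight and bias), so sum(...) / len(...)
# above is an elementwise mean over the selected devices' parameters.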

# Evaluation
def evaluate_model(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    return correct / total

# Test loader
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=False, download=True, transform=transform),
    batch_size=1000)

# === Federated Training ===
accuracy_fedrep = []

for round_idx in range(ROUNDS):
    print(f"\n--- Round {round_idx+1} ---")

    # Round-robin selection: devices take turns in fixed blocks, so each
    # device participates once every NUM_DEVICES / DEVICE_PER_ROUND rounds
    start = (round_idx * DEVICE_PER_ROUND) % NUM_DEVICES
    selected_devices = [(start + j) % NUM_DEVICES for j in range(DEVICE_PER_ROUND)]
    print(f"Selected Devices (Round Robin): {selected_devices}")

    # === FedRep Training ===
    classifier_states = []
    for i in selected_devices:
        data_loader = torch.utils.data.DataLoader(
            datasets_split[i], batch_size=BATCH_SIZE, shuffle=True)
        state_dict = train_rep(models[i], data_loader, optimizers_rep[i],
                               optimizers_cls[i], train_classifier=True)
        classifier_states.append(state_dict)

    # Broadcast the averaged classifier to every device
    if classifier_states:
        avg_classifier = average_classifiers(classifier_states)
        for model in models:
            model.classifier.load_state_dict(avg_classifier)

    # Evaluate device 0's model (its classifier is the shared average)
    acc_fedrep = evaluate_model(models[0], test_loader)
    accuracy_fedrep.append(acc_fedrep)
    print(f"FedRep Accuracy: {acc_fedrep*100:.2f}%")

# === Device Info ===
print("\n--- Device Initialization Info ---")
for i in range(NUM_DEVICES):
    print(f"Device {i}: Energy = {device_energy[i]:.2f}, "
          f"Bandwidth = {device_bandwidth[i]:.2f}, SNR = {device_snr[i]:.2f}")

# === Plotting ===
plt.plot(range(1, len(accuracy_fedrep) + 1),
         [round(a * 100, 2) for a in accuracy_fedrep],
         label='FedRep', marker='o')
plt.xlabel("Rounds")
plt.ylabel("Accuracy (%)")
plt.title("FedRep Accuracy (Round Robin Scheduling)")
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()
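
As a sanity check on the scheduling logic, here is a short standalone sketch (reusing the script's NUM_DEVICES = 50 and DEVICE_PER_ROUND = 10) that prints the first few round-robin selections:

NUM_DEVICES = 50
DEVICE_PER_ROUND = 10
for round_idx in range(6):
    start = (round_idx * DEVICE_PER_ROUND) % NUM_DEVICES
    print([(start + j) % NUM_DEVICES for j in range(DEVICE_PER_ROUND)])
# Prints devices 0-9, 10-19, 20-29, 30-39, 40-49, then 0-9 again: every
# device participates exactly once every NUM_DEVICES / DEVICE_PER_ROUND = 5 rounds.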
