import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
# Configuration
NUM_DEVICES = 50
EPOCHS = 2
ROUNDS = 100
BATCH_SIZE = 64
LEARNING_RATE = 0.01
DEVICE_PER_ROUND = 10 # For round robin
# Energy, Bandwidth, SNR initialization
device_energy = np.random.uniform(30, 100, NUM_DEVICES)
device_bandwidth = np.random.uniform(1.0, 5.0, NUM_DEVICES)
device_snr = np.random.uniform(10, 40, NUM_DEVICES)
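# Illustrative only: these arrays are printed later but never drive scheduling.
# If a channel-quality metric were needed, one plausible choice (an assumption
# of this sketch, treating device_snr as dB) would be a per-device Shannon rate:
# device_rate = device_bandwidth * np.log2(1 + 10 ** (device_snr / 10))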
# Dataset loading (split among devices)
transform = transforms.Compose([transforms.ToTensor()])
full_dataset = datasets.MNIST('./data', train=True, download=True,
                              transform=transform)
datasets_split = torch.utils.data.random_split(
    full_dataset, [len(full_dataset) // NUM_DEVICES] * NUM_DEVICES)
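# The equal split above assumes len(full_dataset) is divisible by NUM_DEVICES
# (60000 / 50 = 1200 samples per device); fail fast if the config breaks that.
assert len(full_dataset) % NUM_DEVICES == 0, "dataset must split evenly across devices"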
# CNN Model with separable representation and classifier (for FedRep)
class CNN_FedRep(nn.Module):
    def __init__(self):
        super().__init__()
        self.representation = nn.Sequential(
            nn.Conv2d(1, 10, 5),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten()
        )
        self.classifier = nn.Linear(1440, 10)

    def forward(self, x):
        rep = self.representation(x)
        return self.classifier(rep)
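# Quick shape check: Conv2d(1, 10, 5) maps 28x28 to 24x24, MaxPool2d(2) halves
# that to 12x12, so the flattened representation has 10 * 12 * 12 = 1440
# features, matching the classifier's input size.
assert CNN_FedRep()(torch.zeros(1, 1, 28, 28)).shape == (1, 10)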
# Initialize models and optimizers
models = [CNN_FedRep() for _ in range(NUM_DEVICES)]
optimizers_rep = [optim.SGD(model.representation.parameters(), lr=LEARNING_RATE)
                  for model in models]
optimizers_cls = [optim.SGD(model.classifier.parameters(), lr=LEARNING_RATE)
                  for model in models]
loss_fn = nn.CrossEntropyLoss()
# Training function for FedRep
def train_rep(model, data_loader, optimizer_rep, optimizer_cls=None,
              train_classifier=False):
    """Locally train one device's model and return its classifier state dict."""
    model.train()
    for epoch in range(EPOCHS):
        for data, target in data_loader:
            if train_classifier:
                optimizer_cls.zero_grad()
            optimizer_rep.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            optimizer_rep.step()
            if train_classifier:
                optimizer_cls.step()
    return model.classifier.state_dict()
# Model averaging (FedRep: only classifier aggregated)
def average_classifiers(classifier_states):
    new_state = {}
    for key in classifier_states[0].keys():
        new_state[key] = (sum(state[key] for state in classifier_states)
                          / len(classifier_states))
    return new_state
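# Minimal usage sketch: averaging the state dicts of two throwaway heads gives
# the elementwise mean of each tensor. The Linear layers here are illustrative
# only and are not part of the training pipeline.
_demo = average_classifiers([nn.Linear(1440, 10).state_dict(),
                             nn.Linear(1440, 10).state_dict()])
assert _demo["weight"].shape == (10, 1440)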
# Evaluation
def evaluate_model(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            _, predicted = torch.max(output, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    return correct / total
# Test loader
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=False, download=True, transform=transform),
    batch_size=1000)
# === Federated Training ===
accuracy_fedrep = []
for round_idx in range(ROUNDS):
    print(f"\n--- Round {round_idx+1} ---")
    # Round Robin Selection (wrap with modulo so indices stay valid even when
    # NUM_DEVICES is not a multiple of DEVICE_PER_ROUND)
    start = (round_idx * DEVICE_PER_ROUND) % NUM_DEVICES
    selected_devices = [(start + k) % NUM_DEVICES for k in range(DEVICE_PER_ROUND)]
    print(f"Selected Devices (Round Robin): {selected_devices}")
    # === FedRep Training ===
    classifier_states = []
    for i in selected_devices:
        data_loader = torch.utils.data.DataLoader(
            datasets_split[i], batch_size=BATCH_SIZE, shuffle=True)
        state_dict = train_rep(models[i], data_loader, optimizers_rep[i],
                               optimizers_cls[i], train_classifier=True)
        classifier_states.append(state_dict)
    if classifier_states:
        avg_classifier = average_classifiers(classifier_states)
        for model in models:
            model.classifier.load_state_dict(avg_classifier)
    # Evaluate once per round so accuracy_fedrep lines up with the plot's x-axis
    acc_fedrep = evaluate_model(models[0], test_loader)
    accuracy_fedrep.append(acc_fedrep)
    print(f"FedRep Accuracy: {acc_fedrep*100:.2f}%")
# === Device Info ===
print("\n--- Device Initialization Info ---")
for i in range(NUM_DEVICES):
    print(f"Device {i}: Energy = {device_energy[i]:.2f}, "
          f"Bandwidth = {device_bandwidth[i]:.2f}, SNR = {device_snr[i]:.2f}")
# === Plotting ===
plt.plot(range(1, ROUNDS+1), [round(a*100, 2) for a in accuracy_fedrep],
         label='FedRep', marker='o')
plt.xlabel("Rounds")
plt.ylabel("Accuracy (%)")
plt.title("FedRep Accuracy (Round Robin Scheduling)")
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()