diff --git a/examples/machine_learning/CNN/CNN.ipynb b/examples/machine_learning/CNN/CNN.ipynb
new file mode 100644
index 000000000..ebd206781
--- /dev/null
+++ b/examples/machine_learning/CNN/CNN.ipynb
@@ -0,0 +1,143 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "8a7e6b7d-5d6e-4ffc-ae58-6dd8857af672",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import fastplotlib as fpl\n",
+    "import torch\n",
+    "import numpy as np\n",
+    "import zmq"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "bc76c29a-a47e-4a5e-9ac3-0c829b50dad8",
+   "metadata": {},
+   "source": [
+    "# Set up the zmq subscriber client"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b6174ff4-d2ea-4904-a635-0804faf9c1f1",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "context = zmq.Context()\n",
+    "sub = context.socket(zmq.SUB)\n",
+    "sub.setsockopt(zmq.SUBSCRIBE, b\"\")\n",
+    "\n",
+    "# keep only the most recent message\n",
+    "sub.setsockopt(zmq.CONFLATE, 1)\n",
+    "\n",
+    "# address must match the publisher in cnn.py\n",
+    "sub.connect(\"tcp://127.0.0.1:5555\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "0cbfaf88-a474-44ec-996a-c626f71a428e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_buffer():\n",
+    "    \"\"\"Get the most recent buffer from the publisher, or None if nothing has arrived.\"\"\"\n",
+    "    try:\n",
+    "        b = sub.recv(zmq.NOBLOCK)\n",
+    "    except zmq.Again:\n",
+    "        pass\n",
+    "    else:\n",
+    "        return b\n",
+    "\n",
+    "    return None"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a574e0c1-5c9c-450e-a91b-63f0ebcfc584",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Create the figure\n",
+    "figure = fpl.Figure(names=[[\"conv1 weights\"]], size=(700, 500))\n",
+    "\n",
+    "figure[0, 0].axes.visible = False"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7ef6d573-624e-4880-b343-f42d303885c2",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def update_frame(p):\n",
+    "    \"\"\"Update the frame using data received from the socket.\"\"\"\n",
+    "    buff = get_buffer()\n",
+    "    if buff is not None:\n",
+    "        # Deserialize the buffer into a NumPy array\n",
+    "        data = np.frombuffer(buff, dtype=np.float64)\n",
+    "\n",
+    "        data = data.reshape(20, 40)\n",
+    "\n",
+    "        if len(p.graphics) == 0:\n",
+    "            p.add_image(data, name=\"weights\", cmap=\"viridis\")\n",
+    "        else:\n",
+    "            # Update the existing image graphic\n",
+    "            p[\"weights\"].data = data\n",
+    "            p[\"weights\"].cmap = \"gray\"\n",
+    "\n",
+    "        p.auto_scale()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "fd59022f-b0fe-4f9f-939b-9c2e9f3dc4a2",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Add the animation update function\n",
+    "figure[0, 0].add_animations(update_frame)\n",
+    "\n",
+    "figure.show()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d45053a7-465f-4d1f-9d68-f93d313db935",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.13.1"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
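
The notebook above only receives and redraws whatever arrives on port 5555, so it can be sanity-checked without running the full training script that follows. Below is a minimal stand-in publisher, a sketch rather than part of this PR (the file name test_publisher.py and the 0.5 s interval are made up); it sends random arrays with the same layout cnn.py uses, i.e. 32 conv1 filters of shape 5x5 flattened to 800 float64 values.

    # test_publisher.py -- hypothetical helper, not part of this PR
    # Publishes random arrays shaped like the conv1 weights sent by cnn.py,
    # so CNN.ipynb's subscriber and animation callback can be tested on their own.
    import time

    import numpy as np
    import zmq

    context = zmq.Context()
    pub = context.socket(zmq.PUB)
    # same address the notebook's SUB socket connects to
    pub.bind("tcp://127.0.0.1:5555")

    while True:
        # 32 filters of shape 5x5 -> 800 float64 values, same layout as cnn.py
        fake_weights = np.random.rand(32, 5, 5).astype(np.float64)
        pub.send(fake_weights.ravel())
        time.sleep(0.5)

The image will just look like noise, but it confirms that update_frame receives, reshapes, and redraws each message.
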
diff --git a/examples/machine_learning/CNN/cnn.py b/examples/machine_learning/CNN/cnn.py
new file mode 100644
index 000000000..6e06e214a
--- /dev/null
+++ b/examples/machine_learning/CNN/cnn.py
@@ -0,0 +1,160 @@
+"""
+Convolutional Neural Network Model Weights
+==========================================
+
+Example showing how you can explore the model weights of a simple Convolutional Neural Network (CNN)
+during training.
+"""
+
+# test_example = false
+# sphinx_gallery_pygfx_docs = false
+
+import fastplotlib as fpl
+import numpy as np
+import torch
+import zmq
+
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.data
+
+from torchvision import datasets, transforms
+
+from torch.optim.lr_scheduler import StepLR
+import tqdm
+
+# set up zmq connection to notebook
+context = zmq.Context()
+socket = context.socket(zmq.PUB)
+socket.bind("tcp://127.0.0.1:5555")
+
+# check if GPU with cuda is available
+if torch.cuda.is_available():
+    device = torch.device("cuda")
+# if not, use CPU
+else:
+    device = torch.device("cpu")
+print(f"Device: {device}")
+
+# define a simple CNN architecture
+class CNN(nn.Module):
+    def __init__(self):
+        super(CNN, self).__init__()
+        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
+        self.conv2 = nn.Conv2d(32, 32, kernel_size=5)
+        self.conv3 = nn.Conv2d(32, 64, kernel_size=5)
+        self.fc1 = nn.Linear(3*3*64, 256)
+        self.fc2 = nn.Linear(256, 10)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        x = F.relu(F.max_pool2d(self.conv2(x), 2))
+        x = F.dropout(x, p=0.5, training=self.training)
+        x = F.relu(F.max_pool2d(self.conv3(x), 2))
+        x = F.dropout(x, p=0.5, training=self.training)
+        x = x.view(-1, 3*3*64)
+        x = F.relu(self.fc1(x))
+        x = F.dropout(x, training=self.training)
+        x = self.fc2(x)
+        return F.log_softmax(x, dim=1)
+
+# create model and put to device
+model = CNN().to(device)
+print(f"Model Architecture:\n {model}")
+
+# load the dataset
+# transform to apply to the images
+transform = transforms.Compose([
+    transforms.ToTensor(),  # convert to tensor
+    transforms.Normalize((0.1307,), (0.3081,))  # normalize with specified mean and sd
+    ])
+
+data = datasets.MNIST('./data', train=True, download=True,
+                      transform=transform)
+
+train_loader = torch.utils.data.DataLoader(data, batch_size=32, num_workers=1, shuffle=True)
+
+# sample visual of inputs
+# fig_data = fpl.Figure(shape=(1, 5), size=(900, 300))
+
+# Print the first few images in a row
+# for j, (image, label) in enumerate(train_loader):
+#     for i in range(5):
+#         fig_data[0, i].add_image(np.asarray(image[i].squeeze()), cmap="gray")
+#         fig_data[0, i].set_title(f"Label: {label[i].item()}")
+#         fig_data[0, i].axes.visible = False
+#         fig_data[0, i].toolbar = False
+#
+#     break  # Exit the loop after printing 5 samples
+#
+# fig_data.show()
+
+
+# send the initial weights via zmq to notebook
+weights = model.state_dict()["conv1.weight"].squeeze()
+socket.send(np.asarray(weights.cpu()).ravel().astype(np.float64))
+
+
+# train the model
+def train(model, device, train_loader, optimizer, epoch):
+    global socket
+
+    # make sure model is in train mode
+    model.train()
+
+    correct = 0
+    for batch_idx, (data, target) in enumerate(train_loader):
+        data, target = data.to(device), target.to(device)
+        optimizer.zero_grad()
+        output = model(data)
+        loss = F.nll_loss(output, target)
+        loss.backward()
+        optimizer.step()
+
+        predicted = torch.max(output.data, 1)[1]
+        correct += (predicted == target).sum()
+        if batch_idx % 1000 == 0:
+            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\t Accuracy:{:.3f}%'.format(
+                epoch, batch_idx * len(data), len(train_loader.dataset),
+                100. * batch_idx / len(train_loader), loss.item(), float(correct*100) / float(32 * (batch_idx + 1))))
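
A small aside on the architecture above: the 3*3*64 input size of fc1 follows from MNIST's 28x28 inputs. conv1 (kernel 5) gives 24x24, conv2 gives 20x20 and the first max-pool halves it to 10x10, conv3 gives 6x6 and the second max-pool leaves 3x3 over 64 channels. A quick way to verify this kind of arithmetic (a sketch, not part of the PR; it reuses the CNN class and the F alias defined above):

    # sanity-check the flattened feature size that fc1 expects
    x = torch.zeros(1, 1, 28, 28)            # one dummy MNIST-sized image
    m = CNN()
    feats = F.relu(m.conv1(x))               # -> (1, 32, 24, 24)
    feats = F.max_pool2d(m.conv2(feats), 2)  # -> (1, 32, 10, 10)
    feats = F.max_pool2d(m.conv3(feats), 2)  # -> (1, 64, 3, 3)
    print(feats.shape)                       # torch.Size([1, 64, 3, 3]), i.e. 3*3*64 features
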
+
+
+# define optimizer
+optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+# define scheduler for learning rate
+scheduler = StepLR(optimizer, step_size=1)
+
+# train the model
+# for epoch in tqdm.tqdm(range(0, 5)):
+
+# epoch 0
+train(model, device, train_loader, optimizer, 0)
+scheduler.step()
+
+# send current model weights
+weights = model.state_dict()["conv1.weight"].squeeze()
+socket.send(np.asarray(weights.cpu()).ravel().astype(np.float64))
+
+# epoch 1
+train(model, device, train_loader, optimizer, 1)
+scheduler.step()
+
+# send current model weights
+weights = model.state_dict()["conv1.weight"].squeeze()
+socket.send(np.asarray(weights.cpu()).ravel().astype(np.float64))
+
+# epoch 2
+train(model, device, train_loader, optimizer, 2)
+scheduler.step()
+
+# send current model weights
+weights = model.state_dict()["conv1.weight"].squeeze()
+socket.send(np.asarray(weights.cpu()).ravel().astype(np.float64))
+
+# socket.close()
+
+# NOTE: `if __name__ == "__main__"` is NOT how to use fastplotlib interactively
+# please see our docs for using fastplotlib interactively in ipython and jupyter
+if __name__ == "__main__":
+    print(__doc__)
+    fpl.loop.run()
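
The three unrolled epoch blocks at the end of cnn.py repeat the same train / scheduler.step() / publish-weights sequence, so they could equivalently be written as the loop hinted at by the commented-out tqdm line. A sketch under that assumption, using the objects defined in cnn.py:

    # equivalent loop form of the unrolled epochs above (sketch)
    for epoch in tqdm.tqdm(range(3)):
        train(model, device, train_loader, optimizer, epoch)
        scheduler.step()

        # publish the current conv1 weights so CNN.ipynb can redraw them
        weights = model.state_dict()["conv1.weight"].squeeze()
        socket.send(np.asarray(weights.cpu()).ravel().astype(np.float64))

Either form sends the same per-epoch weight snapshots that the notebook's update_frame callback consumes.
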