From 683a7181f4c85a47b79e77b873fcd90274aa9389 Mon Sep 17 00:00:00 2001 From: clewis7 Date: Mon, 17 Feb 2025 12:49:34 -0500 Subject: [PATCH 1/2] CNN weights example --- examples/machine_learning/neural_net.ipynb | 647 +++++++++++++++++++++ 1 file changed, 647 insertions(+) create mode 100644 examples/machine_learning/neural_net.ipynb diff --git a/examples/machine_learning/neural_net.ipynb b/examples/machine_learning/neural_net.ipynb new file mode 100644 index 000000000..9221de94f --- /dev/null +++ b/examples/machine_learning/neural_net.ipynb @@ -0,0 +1,647 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "8a7e6b7d-5d6e-4ffc-ae58-6dd8857af672", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Detected skylake derivative running on mesa i915. Clears to srgb textures will use manual shader clears.\n", + "Detected skylake derivative running on mesa i915. Clears to srgb textures will use manual shader clears.\n", + "Detected skylake derivative running on mesa i915. Clears to srgb textures will use manual shader clears.\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "3d596c043e1c49bda39dea86e10aad5a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Image(value=b'version https://git-lfs.github.com/spec/...', height='55', width='300')" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Available devices:
ValidDeviceTypeBackendDriver
Intel(R) Arc(tm) Graphics (MTL)IntegratedGPUVulkanMesa 24.3.2
✅ (default) NVIDIA GeForce RTX 4060 Laptop GPUDiscreteGPUVulkan565.77
❗ limitedllvmpipe (LLVM 19.1.5, 256 bits)CPUVulkanMesa 24.3.2 (LLVM 19.1.5)
Mesa Intel(R) Arc(tm) Graphics (MTL)IntegratedGPUOpenGL4.6 (Core Profile) Mesa 24.3.2
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Detected skylake derivative running on mesa i915. Clears to srgb textures will use manual shader clears.\n", + "Detected skylake derivative running on mesa i915. Clears to srgb textures will use manual shader clears.\n" + ] + } + ], + "source": [ + "import fastplotlib as fpl\n", + "import torch\n", + "\n", + "\n", + "import numpy as np # to handle matrix and data operation\n", + "#import pandas as pd # to read csv and handle dataframe\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import torch.utils.data\n", + "from torch.autograd import Variable\n", + "\n", + "from sklearn.model_selection import train_test_split\n", + "from torchvision import datasets, transforms\n", + "from torch.optim.lr_scheduler import StepLR\n", + "import tqdm" + ] + }, + { + "cell_type": "markdown", + "id": "ee5e1554-96a4-4611-b873-d6db10b68def", + "metadata": {}, + "source": [ + "## Get the device" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "5630c378-f60c-40b4-9c2f-e2d44a6ec31c", + "metadata": {}, + "outputs": [], + "source": [ + "# check if GPU with cuda is available\n", + "if torch.cuda.is_available():\n", + " device = torch.device(\"cuda\")\n", + "# if not, use CPU\n", + "else:\n", + " device = torch.device(\"cpu\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "edf80fec-4c34-455e-a0ba-fcf78c8c798d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "device(type='cuda')" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "device" + ] + }, + { + "cell_type": "markdown", + "id": "f2690be6-9319-4038-aec6-18a15bd0196d", + "metadata": {}, + "source": [ + "## Define model architecture" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "ddc53e5f-c354-4810-bf4a-cc8b9dfe7d13", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CNN(\n", + " (conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))\n", + " (conv2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1))\n", + " (conv3): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))\n", + " (fc1): Linear(in_features=576, out_features=256, bias=True)\n", + " (fc2): Linear(in_features=256, out_features=10, bias=True)\n", + ")\n" + ] + } + ], + "source": [ + "class CNN(nn.Module):\n", + " def __init__(self):\n", + " super(CNN, self).__init__()\n", + " self.conv1 = nn.Conv2d(1, 32, kernel_size=5)\n", + " self.conv2 = nn.Conv2d(32, 32, kernel_size=5)\n", + " self.conv3 = nn.Conv2d(32,64, kernel_size=5)\n", + " self.fc1 = nn.Linear(3*3*64, 256)\n", + " self.fc2 = nn.Linear(256, 10)\n", + "\n", + " def forward(self, x):\n", + " x = F.relu(self.conv1(x))\n", + " #x = F.dropout(x, p=0.5, training=self.training)\n", + " x = F.relu(F.max_pool2d(self.conv2(x), 2))\n", + " x = F.dropout(x, p=0.5, training=self.training)\n", + " x = F.relu(F.max_pool2d(self.conv3(x),2))\n", + " x = F.dropout(x, p=0.5, training=self.training)\n", + " x = x.view(-1,3*3*64 )\n", + " x = F.relu(self.fc1(x))\n", + " x = F.dropout(x, training=self.training)\n", + " x = self.fc2(x)\n", + " return F.log_softmax(x, dim=1)\n", + " \n", + "model = CNN().to(device)\n", + "print(model)" + ] + }, + { + "cell_type": "markdown", + "id": "524a5d18-49d2-4d29-b7a2-691f12463b5d", + "metadata": {}, + "source": [ + "## Load the dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "9d169693-74c3-4bbe-8c80-284db1b994c8", + "metadata": {}, + "outputs": [], + "source": [ + "# tranform to apply to images\n", + "transform=transforms.Compose([\n", + " transforms.ToTensor(), # convert to tensor\n", + " transforms.Normalize((0.1307,), (0.3081,)) # normalize with specified mean and sd\n", + " ])\n", + "\n", + "data = datasets.MNIST('../data', train=True, download=True,\n", + " transform=transform)\n", + "\n", + "train_loader = torch.utils.data.DataLoader(data, batch_size=32, num_workers=1, shuffle=True)" + ] + }, + { + "cell_type": "markdown", + "id": "f73586c4-5f2a-4cb7-833c-a2d90bedbf80", + "metadata": {}, + "source": [ + "## Sample visual of inputs" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "d224516a-591e-4ad4-9439-97582f0afa36", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "6c1c9f79ad6f4b3ea6630802bc6ae015", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "RFBOutputContext()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Detected skylake derivative running on mesa i915. Clears to srgb textures will use manual shader clears.\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "3bcfa7ff3a7f4bf28368298dc6aea39c", + "version_major": 2, + "version_minor": 0 + }, + "text/html": [ + "
snapshot
" + ], + "text/plain": [ + "JupyterRenderCanvas(css_height='300.0px', css_width='900.0px')" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "fig_data = fpl.Figure(shape=(1,5), size=(900,300))\n", + "\n", + "# Print the first few images in a row\n", + "for j, (image, label) in enumerate(train_loader):\n", + " for i in range(5):\n", + " fig_data[0, i].add_image(image[i].squeeze().numpy(), cmap=\"gray\")\n", + " fig_data[0, i].set_title(f\"Label: {label[i].item()}\")\n", + " fig_data[0, i].axes.visible = False\n", + " fig_data[0, i].toolbar = False\n", + "\n", + " break # Exit the loop after printing 5 samples\n", + "\n", + "fig_data.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "144a3f9f-5e2f-4d26-88fd-d4812e2cbdcb", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "c44e6182-2d70-4142-9e85-8153f0fb47b5", + "metadata": {}, + "source": [ + "## Plot the initial weights" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "aef0a3a1-2a98-4dcf-a090-71760984ab19", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "60b2995820024a0fbec325c51062bcde", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "RFBOutputContext()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "'weight': ImageGraphic @ 0x7ff0c1ef96d0" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "fig_weight = fpl.Figure()\n", + "\n", + "a = model.state_dict()[\"conv1.weight\"].squeeze().reshape(20, 40)\n", + "\n", + "fig_weight[0,0].add_image(a.cpu().numpy(), \"viridis\", name=\"weight\")\n", + "\n", + "# for i, subplot in enumerate(fig_weight):\n", + "# subplot.axes.visible = False\n", + "# subplot.add_image(data=a[i].cpu().numpy(), cmap=\"viridis\", name=\"weight\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "614da99c-7cd8-43b2-99e2-6da75a94c93d", + "metadata": {}, + "outputs": [], + "source": [ + "fig_weight.show(sidecar=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "40ded155-1c97-4b03-beff-0b85d4bcfd07", + "metadata": {}, + "outputs": [], + "source": [ + "fig_weight[0,0].axes.visible = False" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7a42de65-a8f7-4753-8d80-4b2b55e9089e", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "3fab73d2-dd20-40ec-8989-11cb95b2cdb9", + "metadata": {}, + "source": [ + "## Train the model" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "b81aecd8-4b53-4d34-8a63-06a237c7de3f", + "metadata": {}, + "outputs": [], + "source": [ + "def train(model, device, train_loader, optimizer, epoch):\n", + " global fig_weight\n", + " model.train()\n", + " correct = 0\n", + " for batch_idx, (data, target) in enumerate(train_loader):\n", + " data, target = data.to(device), target.to(device)\n", + " optimizer.zero_grad()\n", + " output = model(data)\n", + " loss = F.nll_loss(output, target)\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " predicted = torch.max(output.data, 1)[1]\n", + " correct += (predicted == target).sum()\n", + " if batch_idx % 1000 == 0:\n", + " print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}\\t Accuracy:{:.3f}%'.format(\n", + " epoch, batch_idx * len(data), len(train_loader.dataset),\n", + " 100. * batch_idx / len(train_loader), loss.item(), float(correct*100) / float(32 * (batch_idx + 1))))" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "7da55be5-937d-44a6-abc8-95faff0ef90f", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/5 [00:00 Date: Mon, 17 Feb 2025 14:35:40 -0500 Subject: [PATCH 2/2] better impl --- examples/machine_learning/CNN/CNN.ipynb | 143 +++++ examples/machine_learning/CNN/cnn.py | 160 +++++ examples/machine_learning/neural_net.ipynb | 647 --------------------- 3 files changed, 303 insertions(+), 647 deletions(-) create mode 100644 examples/machine_learning/CNN/CNN.ipynb create mode 100644 examples/machine_learning/CNN/cnn.py delete mode 100644 examples/machine_learning/neural_net.ipynb diff --git a/examples/machine_learning/CNN/CNN.ipynb b/examples/machine_learning/CNN/CNN.ipynb new file mode 100644 index 000000000..ebd206781 --- /dev/null +++ b/examples/machine_learning/CNN/CNN.ipynb @@ -0,0 +1,143 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "8a7e6b7d-5d6e-4ffc-ae58-6dd8857af672", + "metadata": {}, + "outputs": [], + "source": [ + "import fastplotlib as fpl\n", + "import torch\n", + "import numpy as np\n", + "import zmq" + ] + }, + { + "cell_type": "markdown", + "id": "bc76c29a-a47e-4a5e-9ac3-0c829b50dad8", + "metadata": {}, + "source": [ + "# Set up zmq as client subscriber" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b6174ff4-d2ea-4904-a635-0804faf9c1f1", + "metadata": {}, + "outputs": [], + "source": [ + "context = zmq.Context()\n", + "sub = context.socket(zmq.SUB)\n", + "sub.setsockopt(zmq.SUBSCRIBE, b\"\")\n", + "\n", + "# keep only the most recent message\n", + "sub.setsockopt(zmq.CONFLATE, 1)\n", + "\n", + "# address must match publisher in Processor actor\n", + "sub.connect(\"tcp://127.0.0.1:5555\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0cbfaf88-a474-44ec-996a-c626f71a428e", + "metadata": {}, + "outputs": [], + "source": [ + "def get_buffer():\n", + " \"\"\"Gets the buffer from the publisher.\"\"\"\n", + " try:\n", + " b = sub.recv(zmq.NOBLOCK)\n", + " except zmq.Again:\n", + " pass\n", + " else:\n", + " return b\n", + " \n", + " return None" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a574e0c1-5c9c-450e-a91b-63f0ebcfc584", + "metadata": {}, + "outputs": [], + "source": [ + "# Create the figure\n", + "figure = fpl.Figure(names=[[\"conv1 weights\"]], size=(700, 500))\n", + "\n", + "figure[0,0].axes.visible = False" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7ef6d573-624e-4880-b343-f42d303885c2", + "metadata": {}, + "outputs": [], + "source": [ + "def update_frame(p):\n", + " \"\"\"Update the frame using data received from the socket.\"\"\"\n", + " buff = get_buffer()\n", + " if buff is not None:\n", + " # Deserialize the buffer into a NumPy array\n", + " data = np.frombuffer(buff, dtype=np.float64)\n", + "\n", + " data = data.reshape(20, 40) \n", + "\n", + " if len(p.graphics) == 0:\n", + " p.add_image(data, name=\"weights\", cmap=\"viridis\")\n", + " else:\n", + " # Update the line plot\n", + " p[\"weights\"].data = data\n", + " p[\"weights\"].cmap = gray\n", + "\n", + " p.auto_scale()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fd59022f-b0fe-4f9f-939b-9c2e9f3dc4a2", + "metadata": {}, + "outputs": [], + "source": [ + "# Add the animation update function\n", + "figure[0, 0].add_animations(update_frame)\n", + "\n", + "figure.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d45053a7-465f-4d1f-9d68-f93d313db935", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.1" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/machine_learning/CNN/cnn.py b/examples/machine_learning/CNN/cnn.py new file mode 100644 index 000000000..6e06e214a --- /dev/null +++ b/examples/machine_learning/CNN/cnn.py @@ -0,0 +1,160 @@ +""" +Convolutional Neural Network Model Weights +========================================== + +Example showing how you can explore the model weights of a simple Convolutional Neural Network (CNN) +during training. +""" + +# test_example = false +# sphinx_gallery_pygfx_docs = false + +import fastplotlib as fpl +import numpy as np +import torch +import zmq + +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.data + +from torchvision import datasets, transforms + +from torch.optim.lr_scheduler import StepLR +import tqdm + +# set up zmq connection to notebook +context = zmq.Context() +socket = context.socket(zmq.PUB) +socket.bind("tcp://127.0.0.1:5555") + +# check if GPU with cuda is available +if torch.cuda.is_available(): + device = torch.device("cuda") +# if not, use CPU +else: + device = torch.device("cpu") +print(f"Device: {device}") + +# define a simple CNN architecture +class CNN(nn.Module): + def __init__(self): + super(CNN, self).__init__() + self.conv1 = nn.Conv2d(1, 32, kernel_size=5) + self.conv2 = nn.Conv2d(32, 32, kernel_size=5) + self.conv3 = nn.Conv2d(32,64, kernel_size=5) + self.fc1 = nn.Linear(3*3*64, 256) + self.fc2 = nn.Linear(256, 10) + + def forward(self, x): + x = F.relu(self.conv1(x)) + x = F.relu(F.max_pool2d(self.conv2(x), 2)) + x = F.dropout(x, p=0.5, training=self.training) + x = F.relu(F.max_pool2d(self.conv3(x),2)) + x = F.dropout(x, p=0.5, training=self.training) + x = x.view(-1,3*3*64 ) + x = F.relu(self.fc1(x)) + x = F.dropout(x, training=self.training) + x = self.fc2(x) + return F.log_softmax(x, dim=1) + +# create model and put to device +model = CNN().to(device) +print(f"Model Architecture:\n {model}") + +# load the dataset +# transform to apply to the images +transform=transforms.Compose([ + transforms.ToTensor(), # convert to tensor + transforms.Normalize((0.1307,), (0.3081,)) # normalize with specified mean and sd + ]) + +data = datasets.MNIST('./data', train=True, download=True, + transform=transform) + +train_loader = torch.utils.data.DataLoader(data, batch_size=32, num_workers=1, shuffle=True) + +# sample visual of inputs +#fig_data = fpl.Figure(shape=(1,5), size=(900,300)) + +# Print the first few images in a row +# for j, (image, label) in enumerate(train_loader): +# for i in range(5): +# fig_data[0, i].add_image(np.asarray(image[i].squeeze()), cmap="gray") +# fig_data[0, i].set_title(f"Label: {label[i].item()}") +# fig_data[0, i].axes.visible = False +# fig_data[0, i].toolbar = False +# +# break # Exit the loop after printing 5 samples +# +# fig_data.show() + + +# send the initial weights via zmq to notebook +weights = model.state_dict()["conv1.weight"].squeeze() +socket.send(np.asarray(weights.cpu()).ravel().astype(np.float64)) + + +# train the model +def train(model, device, train_loader, optimizer, epoch): + global socket + + # make sure model is in train mode + model.train() + + correct = 0 + for batch_idx, (data, target) in enumerate(train_loader): + data, target = data.to(device), target.to(device) + optimizer.zero_grad() + output = model(data) + loss = F.nll_loss(output, target) + loss.backward() + optimizer.step() + + predicted = torch.max(output.data, 1)[1] + correct += (predicted == target).sum() + if batch_idx % 1000 == 0: + print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\t Accuracy:{:.3f}%'.format( + epoch, batch_idx * len(data), len(train_loader.dataset), + 100. * batch_idx / len(train_loader), loss.item(), float(correct*100) / float(32 * (batch_idx + 1)))) + + +# define optimizer +optimizer = torch.optim.Adam(model.parameters() ,lr=0.001) +# define scheduler for learning rate +scheduler = StepLR(optimizer, step_size=1) + +# train the model +# for epoch in tqdm.tqdm(range(0, 5)): + +# epoch 0 +train(model, device, train_loader, optimizer, 0) +scheduler.step() + +# send current model weights +weights = model.state_dict()["conv1.weight"].squeeze() +socket.send(np.asarray(weights.cpu()).ravel().astype(np.float64)) + +# epoch 1 +train(model, device, train_loader, optimizer, 1) +scheduler.step() + +# send current model weights +weights = model.state_dict()["conv1.weight"].squeeze() +socket.send(np.asarray(weights.cpu()).ravel().astype(np.float64)) + +# epoch 2 +train(model, device, train_loader, optimizer, 2) +scheduler.step() + +# send current model weights +weights = model.state_dict()["conv1.weight"].squeeze() +socket.send(np.asarray(weights.cpu()).ravel().astype(np.float64)) + +#socket.close() + +# NOTE: `if __name__ == "__main__"` is NOT how to use fastplotlib interactively +# please see our docs for using fastplotlib interactively in ipython and jupyter +if __name__ == "__main__": + print(__doc__) + fpl.loop.run() diff --git a/examples/machine_learning/neural_net.ipynb b/examples/machine_learning/neural_net.ipynb deleted file mode 100644 index 9221de94f..000000000 --- a/examples/machine_learning/neural_net.ipynb +++ /dev/null @@ -1,647 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "8a7e6b7d-5d6e-4ffc-ae58-6dd8857af672", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Detected skylake derivative running on mesa i915. Clears to srgb textures will use manual shader clears.\n", - "Detected skylake derivative running on mesa i915. Clears to srgb textures will use manual shader clears.\n", - "Detected skylake derivative running on mesa i915. Clears to srgb textures will use manual shader clears.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "3d596c043e1c49bda39dea86e10aad5a", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Image(value=b'version https://git-lfs.github.com/spec/...', height='55', width='300')" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "Available devices:
ValidDeviceTypeBackendDriver
Intel(R) Arc(tm) Graphics (MTL)IntegratedGPUVulkanMesa 24.3.2
✅ (default) NVIDIA GeForce RTX 4060 Laptop GPUDiscreteGPUVulkan565.77
❗ limitedllvmpipe (LLVM 19.1.5, 256 bits)CPUVulkanMesa 24.3.2 (LLVM 19.1.5)
Mesa Intel(R) Arc(tm) Graphics (MTL)IntegratedGPUOpenGL4.6 (Core Profile) Mesa 24.3.2
" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Detected skylake derivative running on mesa i915. Clears to srgb textures will use manual shader clears.\n", - "Detected skylake derivative running on mesa i915. Clears to srgb textures will use manual shader clears.\n" - ] - } - ], - "source": [ - "import fastplotlib as fpl\n", - "import torch\n", - "\n", - "\n", - "import numpy as np # to handle matrix and data operation\n", - "#import pandas as pd # to read csv and handle dataframe\n", - "\n", - "import torch\n", - "import torch.nn as nn\n", - "import torch.nn.functional as F\n", - "import torch.utils.data\n", - "from torch.autograd import Variable\n", - "\n", - "from sklearn.model_selection import train_test_split\n", - "from torchvision import datasets, transforms\n", - "from torch.optim.lr_scheduler import StepLR\n", - "import tqdm" - ] - }, - { - "cell_type": "markdown", - "id": "ee5e1554-96a4-4611-b873-d6db10b68def", - "metadata": {}, - "source": [ - "## Get the device" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "5630c378-f60c-40b4-9c2f-e2d44a6ec31c", - "metadata": {}, - "outputs": [], - "source": [ - "# check if GPU with cuda is available\n", - "if torch.cuda.is_available():\n", - " device = torch.device(\"cuda\")\n", - "# if not, use CPU\n", - "else:\n", - " device = torch.device(\"cpu\")" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "edf80fec-4c34-455e-a0ba-fcf78c8c798d", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "device(type='cuda')" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "device" - ] - }, - { - "cell_type": "markdown", - "id": "f2690be6-9319-4038-aec6-18a15bd0196d", - "metadata": {}, - "source": [ - "## Define model architecture" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "ddc53e5f-c354-4810-bf4a-cc8b9dfe7d13", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CNN(\n", - " (conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))\n", - " (conv2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1))\n", - " (conv3): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))\n", - " (fc1): Linear(in_features=576, out_features=256, bias=True)\n", - " (fc2): Linear(in_features=256, out_features=10, bias=True)\n", - ")\n" - ] - } - ], - "source": [ - "class CNN(nn.Module):\n", - " def __init__(self):\n", - " super(CNN, self).__init__()\n", - " self.conv1 = nn.Conv2d(1, 32, kernel_size=5)\n", - " self.conv2 = nn.Conv2d(32, 32, kernel_size=5)\n", - " self.conv3 = nn.Conv2d(32,64, kernel_size=5)\n", - " self.fc1 = nn.Linear(3*3*64, 256)\n", - " self.fc2 = nn.Linear(256, 10)\n", - "\n", - " def forward(self, x):\n", - " x = F.relu(self.conv1(x))\n", - " #x = F.dropout(x, p=0.5, training=self.training)\n", - " x = F.relu(F.max_pool2d(self.conv2(x), 2))\n", - " x = F.dropout(x, p=0.5, training=self.training)\n", - " x = F.relu(F.max_pool2d(self.conv3(x),2))\n", - " x = F.dropout(x, p=0.5, training=self.training)\n", - " x = x.view(-1,3*3*64 )\n", - " x = F.relu(self.fc1(x))\n", - " x = F.dropout(x, training=self.training)\n", - " x = self.fc2(x)\n", - " return F.log_softmax(x, dim=1)\n", - " \n", - "model = CNN().to(device)\n", - "print(model)" - ] - }, - { - "cell_type": "markdown", - "id": "524a5d18-49d2-4d29-b7a2-691f12463b5d", - "metadata": {}, - "source": [ - "## Load the dataset" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "9d169693-74c3-4bbe-8c80-284db1b994c8", - "metadata": {}, - "outputs": [], - "source": [ - "# tranform to apply to images\n", - "transform=transforms.Compose([\n", - " transforms.ToTensor(), # convert to tensor\n", - " transforms.Normalize((0.1307,), (0.3081,)) # normalize with specified mean and sd\n", - " ])\n", - "\n", - "data = datasets.MNIST('../data', train=True, download=True,\n", - " transform=transform)\n", - "\n", - "train_loader = torch.utils.data.DataLoader(data, batch_size=32, num_workers=1, shuffle=True)" - ] - }, - { - "cell_type": "markdown", - "id": "f73586c4-5f2a-4cb7-833c-a2d90bedbf80", - "metadata": {}, - "source": [ - "## Sample visual of inputs" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "d224516a-591e-4ad4-9439-97582f0afa36", - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "6c1c9f79ad6f4b3ea6630802bc6ae015", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "RFBOutputContext()" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Detected skylake derivative running on mesa i915. Clears to srgb textures will use manual shader clears.\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "3bcfa7ff3a7f4bf28368298dc6aea39c", - "version_major": 2, - "version_minor": 0 - }, - "text/html": [ - "
snapshot
" - ], - "text/plain": [ - "JupyterRenderCanvas(css_height='300.0px', css_width='900.0px')" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "fig_data = fpl.Figure(shape=(1,5), size=(900,300))\n", - "\n", - "# Print the first few images in a row\n", - "for j, (image, label) in enumerate(train_loader):\n", - " for i in range(5):\n", - " fig_data[0, i].add_image(image[i].squeeze().numpy(), cmap=\"gray\")\n", - " fig_data[0, i].set_title(f\"Label: {label[i].item()}\")\n", - " fig_data[0, i].axes.visible = False\n", - " fig_data[0, i].toolbar = False\n", - "\n", - " break # Exit the loop after printing 5 samples\n", - "\n", - "fig_data.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "144a3f9f-5e2f-4d26-88fd-d4812e2cbdcb", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "markdown", - "id": "c44e6182-2d70-4142-9e85-8153f0fb47b5", - "metadata": {}, - "source": [ - "## Plot the initial weights" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "aef0a3a1-2a98-4dcf-a090-71760984ab19", - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "60b2995820024a0fbec325c51062bcde", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "RFBOutputContext()" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "'weight': ImageGraphic @ 0x7ff0c1ef96d0" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "fig_weight = fpl.Figure()\n", - "\n", - "a = model.state_dict()[\"conv1.weight\"].squeeze().reshape(20, 40)\n", - "\n", - "fig_weight[0,0].add_image(a.cpu().numpy(), \"viridis\", name=\"weight\")\n", - "\n", - "# for i, subplot in enumerate(fig_weight):\n", - "# subplot.axes.visible = False\n", - "# subplot.add_image(data=a[i].cpu().numpy(), cmap=\"viridis\", name=\"weight\")" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "614da99c-7cd8-43b2-99e2-6da75a94c93d", - "metadata": {}, - "outputs": [], - "source": [ - "fig_weight.show(sidecar=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "40ded155-1c97-4b03-beff-0b85d4bcfd07", - "metadata": {}, - "outputs": [], - "source": [ - "fig_weight[0,0].axes.visible = False" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7a42de65-a8f7-4753-8d80-4b2b55e9089e", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "markdown", - "id": "3fab73d2-dd20-40ec-8989-11cb95b2cdb9", - "metadata": {}, - "source": [ - "## Train the model" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "b81aecd8-4b53-4d34-8a63-06a237c7de3f", - "metadata": {}, - "outputs": [], - "source": [ - "def train(model, device, train_loader, optimizer, epoch):\n", - " global fig_weight\n", - " model.train()\n", - " correct = 0\n", - " for batch_idx, (data, target) in enumerate(train_loader):\n", - " data, target = data.to(device), target.to(device)\n", - " optimizer.zero_grad()\n", - " output = model(data)\n", - " loss = F.nll_loss(output, target)\n", - " loss.backward()\n", - " optimizer.step()\n", - "\n", - " predicted = torch.max(output.data, 1)[1]\n", - " correct += (predicted == target).sum()\n", - " if batch_idx % 1000 == 0:\n", - " print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}\\t Accuracy:{:.3f}%'.format(\n", - " epoch, batch_idx * len(data), len(train_loader.dataset),\n", - " 100. * batch_idx / len(train_loader), loss.item(), float(correct*100) / float(32 * (batch_idx + 1))))" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "7da55be5-937d-44a6-abc8-95faff0ef90f", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 0%| | 0/5 [00:00