From 44189d58a44eb9f0c29c4849e23a54e9be14fa96 Mon Sep 17 00:00:00 2001 From: Anthony Marakis Date: Thu, 7 Sep 2017 21:07:57 +0300 Subject: [PATCH 1/2] Update notebook.py --- notebook.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/notebook.py b/notebook.py index 2894a8bfb..529307ee0 100644 --- a/notebook.py +++ b/notebook.py @@ -95,12 +95,15 @@ def show_iris(i=0, j=1, k=2): # MNIST -def load_MNIST(path="aima-data/MNIST"): +def load_MNIST(path="aima-data/MNIST/Digits", fashion=False): import os, struct import array import numpy as np from collections import Counter + if fashion: + path = "aima-data/MNIST/Fashion" + plt.rcParams.update(plt.rcParamsDefault) plt.rcParams['figure.figsize'] = (10.0, 8.0) plt.rcParams['image.interpolation'] = 'nearest' From bee6e3bcb66fe4297c1b27f9e3933f50d1638481 Mon Sep 17 00:00:00 2001 From: Anthony Marakis Date: Thu, 7 Sep 2017 21:54:30 +0300 Subject: [PATCH 2/2] Update notebook.py --- notebook.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/notebook.py b/notebook.py index 529307ee0..3fe64de2d 100644 --- a/notebook.py +++ b/notebook.py @@ -146,8 +146,17 @@ def load_MNIST(path="aima-data/MNIST/Digits", fashion=False): return(train_img, train_lbl, test_img, test_lbl) -def show_MNIST(labels, images, samples=8): - classes = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"] +digit_classes = [str(i) for i in range(10)] +fashion_classes = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", + "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"] + + +def show_MNIST(labels, images, samples=8, fashion=False): + if not fashion: + classes = digit_classes + else: + classes = fashion_classes + num_classes = len(classes) for y, cls in enumerate(classes): @@ -164,13 +173,19 @@ def show_MNIST(labels, images, samples=8): plt.show() -def show_ave_MNIST(labels, images): - classes = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"] +def show_ave_MNIST(labels, images, fashion=False): + if not fashion: + item_type = "Digit" + classes = digit_classes + else: + item_type = "Apparel" + classes = fashion_classes + num_classes = len(classes) for y, cls in enumerate(classes): idxs = np.nonzero([i == y for i in labels]) - print("Digit", y, ":", len(idxs[0]), "images.") + print(item_type, y, ":", len(idxs[0]), "images.") ave_img = np.mean(np.vstack([images[i] for i in idxs[0]]), axis = 0) #print(ave_img.shape)