diff --git a/learning.ipynb b/learning.ipynb index f049810f2..b81fa3ca8 100644 --- a/learning.ipynb +++ b/learning.ipynb @@ -374,27 +374,21 @@ "source": [ "def load_MNIST(path=\"aima-data/MNIST\"):\n", " \"helper function to load MNIST data\"\n", - " train_img_file = open(os.path.join(path, \"train-images-idx3-ubyte\"), \"rb\")\n", - " train_lbl_file = open(os.path.join(path, \"train-labels-idx1-ubyte\"), \"rb\")\n", - " test_img_file = open(os.path.join(path, \"t10k-images-idx3-ubyte\"), \"rb\")\n", - " test_lbl_file = open(os.path.join(path, 't10k-labels-idx1-ubyte'), \"rb\")\n", + " with open(os.path.join(path, \"train-images-idx3-ubyte\"), \"rb\") as train_img_file:\n", + " magic_nr, tr_size, tr_rows, tr_cols = struct.unpack(\">IIII\", train_img_file.read(16))\n", + " tr_img = array.array(\"B\", train_img_file.read())\n", " \n", - " magic_nr, tr_size, tr_rows, tr_cols = struct.unpack(\">IIII\", train_img_file.read(16))\n", - " tr_img = array.array(\"B\", train_img_file.read())\n", - " train_img_file.close() \n", - " magic_nr, tr_size = struct.unpack(\">II\", train_lbl_file.read(8))\n", - " tr_lbl = array.array(\"b\", train_lbl_file.read())\n", - " train_lbl_file.close()\n", + " with open(os.path.join(path, \"train-labels-idx1-ubyte\"), \"rb\") as train_lbl_file:\n", + " magic_nr, tr_size = struct.unpack(\">II\", train_lbl_file.read(8))\n", + " tr_lbl = array.array(\"b\", train_lbl_file.read())\n", " \n", - " magic_nr, te_size, te_rows, te_cols = struct.unpack(\">IIII\", test_img_file.read(16))\n", - " te_img = array.array(\"B\", test_img_file.read())\n", - " test_img_file.close()\n", - " magic_nr, te_size = struct.unpack(\">II\", test_lbl_file.read(8))\n", - " te_lbl = array.array(\"b\", test_lbl_file.read())\n", - " test_lbl_file.close()\n", - "\n", - "# print(len(tr_img), len(tr_lbl), tr_size)\n", - "# print(len(te_img), len(te_lbl), te_size)\n", + " with open(os.path.join(path, \"t10k-images-idx3-ubyte\"), \"rb\") as test_img_file:\n", + " magic_nr, te_size, te_rows, te_cols = struct.unpack(\">IIII\", test_img_file.read(16))\n", + " te_img = array.array(\"B\", test_img_file.read())\n", + " \n", + " with open(os.path.join(path, \"t10k-labels-idx1-ubyte\"), \"rb\") as test_lbl_file:\n", + " magic_nr, te_size = struct.unpack(\">II\", test_lbl_file.read(8))\n", + " te_lbl = array.array(\"b\", test_lbl_file.read())\n", " \n", " train_img = np.zeros((tr_size, tr_rows*tr_cols), dtype=np.int16)\n", " train_lbl = np.zeros((tr_size,), dtype=np.int8)\n",