Diffstat (limited to 'Fundamentals_of_Deep_Learning/01_mnist.ipynb')
-rw-r--r--  Fundamentals_of_Deep_Learning/01_mnist.ipynb  2223
1 files changed, 2223 insertions, 0 deletions
diff --git a/Fundamentals_of_Deep_Learning/01_mnist.ipynb b/Fundamentals_of_Deep_Learning/01_mnist.ipynb
new file mode 100644
index 0000000..3d25d94
--- /dev/null
+++ b/Fundamentals_of_Deep_Learning/01_mnist.ipynb
@@ -0,0 +1,2223 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "hr61FqXw2zfz"
+ },
+ "source": [
+ "<center><a href=\"https://www.nvidia.com/dli\"> <img src=\"images/DLI_Header.png\" alt=\"Header\" style=\"width: 400px;\"/> </a></center>"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "dLa1EZq42zf0"
+ },
+ "source": [
+ "# 1. Image Classification with the MNIST Dataset"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "B-nDnmoC2zf1"
+ },
+ "source": [
+ "In this section we will do the \"Hello World\" of deep learning: training a deep learning model to correctly classify hand-written digits."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ZWd9Mjxm2zf1"
+ },
+ "source": [
+ "## 1.1 Objectives"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "qmkmljVQ2zf1"
+ },
+ "source": [
+ "* Understand how deep learning can solve problems traditional programming methods cannot\n",
+ "* Learn about the [MNIST handwritten digits dataset](http://yann.lecun.com/exdb/mnist/)\n",
+ "* Use the [torchvision](https://pytorch.org/vision/stable/index.html) to load the MNIST dataset and prepare it for training\n",
+ "* Create a simple neural network to perform image classification\n",
+ "* Train the neural network using the prepped MNIST dataset\n",
+ "* Observe the performance of the trained neural network"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let's begin by loading the libraries used in this notebook:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "id": "siboVxRJ5pes"
+ },
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "from torch.utils.data import Dataset, DataLoader\n",
+ "from torch.optim import Adam\n",
+ "\n",
+ "# Visualization tools\n",
+ "import torchvision\n",
+ "import torchvision.transforms.v2 as transforms\n",
+ "import torchvision.transforms.functional as F\n",
+ "import matplotlib.pyplot as plt"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "FQJHCNcy5_Mk"
+ },
+ "source": [
+ "In PyTorch, we can use our GPU in our operations by setting the [device](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) to `cuda`. The function `torch.cuda.is_available()` will confirm PyTorch can recognize the GPU."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "executionInfo": {
+ "elapsed": 4,
+ "status": "ok",
+ "timestamp": 1714980671199,
+ "user": {
+ "displayName": "Danielle Detering US",
+ "userId": "15432464718872067879"
+ },
+ "user_tz": 420
+ },
+ "id": "AgBjq0uu6CYm",
+ "outputId": "9a881df2-6d26-4c95-d44b-1730d4d5b150"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "True"
+ ]
+ },
+ "execution_count": 2,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+ "torch.cuda.is_available()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "XC3FFt-Z2zf1"
+ },
+ "source": [
+ "### 1.1.1 The Problem: Image Classification"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ygLT32E-2zf1"
+ },
+ "source": [
+ "In traditional programming, the programmer is able to articulate rules and conditions in their code that their program can then use to act in the correct way. This approach continues to work exceptionally well for a huge variety of problems.\n",
+ "\n",
+ "Image classification, which asks a program to correctly classify an image it has never seen before into its correct class, is near impossible to solve with traditional programming techniques. How could a programmer possibly define the rules and conditions to correctly classify a huge variety of images, especially taking into account images that they have never seen?"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ch148qeW2zf2"
+ },
+ "source": [
+ "### 1.1.2 The Solution: Deep Learning"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "cdR-fe3Y2zf2"
+ },
+ "source": [
+ "Deep learning excels at pattern recognition by trial and error. By training a deep neural network with sufficient data, and providing the network with feedback on its performance via training, the network can identify, though a huge amount of iteration, its own set of conditions by which it can act in the correct way."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "g8e6Tasm2zf2"
+ },
+ "source": [
+ "## 1.2 The MNIST Dataset"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "HUopNy9V2zf2"
+ },
+ "source": [
+ "In the history of deep learning, the accurate image classification of the [MNIST dataset](http://yann.lecun.com/exdb/mnist/), a collection of 70,000 grayscale images of handwritten digits from 0 to 9, was a major development. While today the problem is considered trivial, doing image classification with MNIST has become a kind of \"Hello World\" for deep learning."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "0fjGsFR_2zf2"
+ },
+ "source": [
+ "Here are 40 of the images included in the MNIST dataset:"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "6U1GdBU72zf2"
+ },
+ "source": [
+ "<img src=\"./images/mnist1.png\" style=\"width: 600px;\">"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "4ksf2B052zf2"
+ },
+ "source": [
+ "### 1.2.1 Training and Validation Data and Labels"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "8d-FSGUU2zf2"
+ },
+ "source": [
+ "When working with images for deep learning, we need both the images themselves, usually denoted as `X`, and also, correct [labels](https://developers.google.com/machine-learning/glossary#label) for these images, usually denoted as `Y`. Furthermore, we need `X` and `Y` values both for *training* the model, and then, a separate set of `X` and `Y` values for *validating* the performance of the model after it has been trained.\n",
+ "\n",
+ "We can imagine these `X` and `Y` pairs as a set of flash cards. A student can train with one set of flashcards, and to validate the student learned the correct concepts, a teacher might quiz the student with a different set of flash cards. \n",
+ "\n",
+ "Therefore, we need 4 segments of data for the MNIST dataset:\n",
+ "\n",
+ "1. `x_train`: Images used for training the neural network\n",
+ "2. `y_train`: Correct labels for the `x_train` images, used to evaluate the model's predictions during training\n",
+ "3. `x_valid`: Images set aside for validating the performance of the model after it has been trained\n",
+ "4. `y_valid`: Correct labels for the `x_valid` images, used to evaluate the model's predictions after it has been trained\n",
+ "\n",
+ "The process of preparing data for analysis is called [Data Engineering](https://medium.com/@rchang/a-beginners-guide-to-data-engineering-part-i-4227c5c457d7). To learn more about the differences between training data and validation data (as well as test data), check out [this article](https://machinelearningmastery.com/difference-test-validation-datasets/) by Jason Brownlee."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "exXV58Gn2zf2"
+ },
+ "source": [
+ "### 1.2.2 Loading the Data Into Memory (with TorchVision)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Rp-63AGt2zf2"
+ },
+ "source": [
+ "There are many [deep learning frameworks](https://developer.nvidia.com/deep-learning-frameworks), each with their own merits. In this workshop we will be working with [PyTorch 2](https://pytorch.org/get-started/pytorch-2.0/), and specifically with the [Sequential API](https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html). The Sequential API has many useful built in functions designed for constructing neural networks. It is also a legitimate choice for deep learning in a professional setting due to its [readability](https://blog.pragmaticengineer.com/readable-code/) and efficiency, though it is not alone in this regard, and it is worth investigating a variety of frameworks when beginning a deep learning project.\n",
+ "\n",
+ "We will also use the [TorchVision](https://pytorch.org/vision/stable/index.html) library. One of the many helpful features that it provides are modules containing helper methods for [many common datasets](https://pytorch.org/vision/main/datasets.html), including MNIST.\n",
+ "\n",
+ "We will begin by loading both the `train` and `valid` datasets for [MNIST](https://pytorch.org/vision/main/generated/torchvision.datasets.MNIST.html#torchvision.datasets.MNIST)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "executionInfo": {
+ "elapsed": 4382,
+ "status": "ok",
+ "timestamp": 1714980675579,
+ "user": {
+ "displayName": "Danielle Detering US",
+ "userId": "15432464718872067879"
+ },
+ "user_tz": 420
+ },
+ "id": "If_AfbCG2zf3",
+ "outputId": "8d7266b2-7b9f-4f9d-a0cd-ae38ed0840f9"
+ },
+ "outputs": [],
+ "source": [
+ "train_set = torchvision.datasets.MNIST(\"./data/\", train=True, download=True)\n",
+ "valid_set = torchvision.datasets.MNIST(\"./data/\", train=False, download=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We stated above that the MNIST dataset contained 70,000 grayscale images of handwritten digits. By executing the following cells, we can see that TorchVision has partitioned 60,000 of these [PIL Images](https://pillow.readthedocs.io/en/stable/reference/Image.html) for training, and 10,000 for validation (after training)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "executionInfo": {
+ "elapsed": 156,
+ "status": "ok",
+ "timestamp": 1714980675725,
+ "user": {
+ "displayName": "Danielle Detering US",
+ "userId": "15432464718872067879"
+ },
+ "user_tz": 420
+ },
+ "id": "tv9vHMDZNNs3",
+ "outputId": "6fe55662-e37a-427c-e055-08e5dcdff18b"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Dataset MNIST\n",
+ " Number of datapoints: 60000\n",
+ " Root location: ./data/\n",
+ " Split: Train"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "train_set"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "executionInfo": {
+ "elapsed": 3,
+ "status": "ok",
+ "timestamp": 1714980675726,
+ "user": {
+ "displayName": "Danielle Detering US",
+ "userId": "15432464718872067879"
+ },
+ "user_tz": 420
+ },
+ "id": "vEKzOpJpNPIS",
+ "outputId": "39715875-6a29-4429-fdfd-cbd95221fdd9"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Dataset MNIST\n",
+ " Number of datapoints: 10000\n",
+ " Root location: ./data/\n",
+ " Split: Test"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "valid_set"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "34dQ8RKFNf4z"
+ },
+ "source": [
+ "*Note*: The `Split` for `valid_set` is stated as `Test`, but we will be using the data for validation in our hands-on exercises. To learn more about the difference between `Train`, `Valid`, and `Test` datasets, please view [this article](https://kili-technology.com/training-data/training-validation-and-test-sets-how-to-split-machine-learning-data) by Kili."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "-C_XCsFl2zf3"
+ },
+ "source": [
+ "### 1.2.3 Exploring the MNIST Data"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ybZesgdF2zf3"
+ },
+ "source": [
+ "Let's take the first x, y pair from `train_set` and review the data structures:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {
+ "id": "PKVmYF9V2zf3"
+ },
+ "outputs": [],
+ "source": [
+ "x_0, y_0 = train_set[0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 45
+ },
+ "executionInfo": {
+ "elapsed": 268,
+ "status": "ok",
+ "timestamp": 1714980688936,
+ "user": {
+ "displayName": "Danielle Detering US",
+ "userId": "15432464718872067879"
+ },
+ "user_tz": 420
+ },
+ "id": "BQ7Gt-za2zf3",
+ "outputId": "70fa206b-5bf6-4d7b-bbcd-ce71aad28ff2"
+ },
+ "outputs": [
+ {
+ "data": {
+ "image/jpeg": "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAAcABwBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APAACzBVBJJwAO9dnp/wm8damu6Dw5dRjGf9IKw/+hkVPffCnWNJa7XVNV0Kxa1hErrNe/M2cnYqgElsAHpjkc1wlAODkV694W8c654t8M6n4TuvEctrrFw0cun3c0/lq+3AMJcDK5AyOeTkd+fPvGFn4gsvEtzF4m89tUG1ZJJjuMgUBVYN/EMKOe9YVXtK0bUtdvVs9LsZ7y4YgbIULYycZPoPc8V6lpfwh0/w7p66z8RdXj0y2z8llC4aWQ+mRn8lz9RXPfE3x1pvi46TYaPZTQadpMJghluWDSyrhQM9SMBe5Oc5NcBV7Tda1XRZJJNK1O8sXkG12tZ2iLD0JUjNQ3l9eahN517dT3MvTfNIXb16n6mq9Ff/2Q==",
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAAA90lEQVR4AWNgGMyAWUhIqK5jvdSy/9/rQe5kgTlWjs3KRiAYxHsyKfDzxYMgFiOIAALDvfwQBsO/pK8Mz97fhPLAlNDtvyBwbNv3j8jCUHbAnOy/f89yM2jPwiLJwMc4628UqgQTnPvp/0eGFAQXLg5lcO/764YuhuArf3y4IAfmfoQwlBX44e/fckkMYaiA7q6/f6dJ45IViP3zdzcuSQaGn39/OkBl4WEL4euFmLIwXDuETav6lKfAIPy1DYucRNFdUPCe9MOUE3e6CpI6FogZSEKrwbFyOIATQ5v5mkcgXV9auVGlwK4NDGRguL75b88HVDla8QBFF16ADQA8sQAAAABJRU5ErkJggg==",
+ "text/plain": [
+ "<PIL.Image.Image image mode=L size=28x28>"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "x_0"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 186
+ },
+ "executionInfo": {
+ "elapsed": 253,
+ "status": "ok",
+ "timestamp": 1714980690340,
+ "user": {
+ "displayName": "Danielle Detering US",
+ "userId": "15432464718872067879"
+ },
+ "user_tz": 420
+ },
+ "id": "XT8XStMyNB45",
+ "outputId": "a576e2ec-0847-43cc-efc8-6236d6e7bdc6"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "PIL.Image.Image"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "type(x_0)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "CgfAqJKmVvuH"
+ },
+ "source": [
+ "Is this a 5 or a poorly written 3? We can view the corresponding label to be sure."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "executionInfo": {
+ "elapsed": 156,
+ "status": "ok",
+ "timestamp": 1714980692683,
+ "user": {
+ "displayName": "Danielle Detering US",
+ "userId": "15432464718872067879"
+ },
+ "user_tz": 420
+ },
+ "id": "QSFkrqaQ2zf4",
+ "outputId": "1722c462-ba0f-487e-dca2-5c5ccb02442e"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "5"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "y_0"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "executionInfo": {
+ "elapsed": 167,
+ "status": "ok",
+ "timestamp": 1714980693942,
+ "user": {
+ "displayName": "Danielle Detering US",
+ "userId": "15432464718872067879"
+ },
+ "user_tz": 420
+ },
+ "id": "Ne6enO2sNE55",
+ "outputId": "01e9b3a7-bf8d-4d08-d8d8-0fded6885a67"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "int"
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "type(y_0)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "826TKJQaOOVs"
+ },
+ "source": [
+ "## 1.3 Tensors"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "DZ_zoaub2zf4"
+ },
+ "source": [
+ "If a vector is a 1-dimensional array, and a matrix is a 2-dimensional array, a tensor is an n-dimensional array representing any number of dimensions. Most modern neural network frameworks are powerful tensor processing tools.\n",
+ "\n",
+ "One example of a 3-dimensional tensor could be pixels on a computer screen. The different dimensions would be width, height, and color channel. Video games use matrix mathematics to calculate pixel values in a similar way to how neural networks calculate tensors. This is why GPUs are effective tensor processesing machines.\n",
+ "\n",
+ "Let's convert our images into tensors so we can later process them with a neural network. TorchVision has a useful function to convert [PIL Images](https://pillow.readthedocs.io/en/stable/reference/Image.html) into tensors with the [ToTensor](https://pytorch.org/vision/main/generated/torchvision.transforms.ToTensor.html) class:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {
+ "id": "uWnZxkiqUDp6"
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.10/dist-packages/torchvision/transforms/v2/_deprecated.py:41: UserWarning: The transform `ToTensor()` is deprecated and will be removed in a future release. Instead, please use `v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])`.\n",
+ " warnings.warn(\n"
+ ]
+ }
+ ],
+ "source": [
+ "trans = transforms.Compose([transforms.ToTensor()])\n",
+ "x_0_tensor = trans(x_0)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "TewgOfaYXbJo"
+ },
+ "source": [
+ "[PyTorch tensors](https://pytorch.org/docs/stable/tensors.html#torch.Tensor) have a number of useful properies and methods. We can verify the data type:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "executionInfo": {
+ "elapsed": 152,
+ "status": "ok",
+ "timestamp": 1714980703279,
+ "user": {
+ "displayName": "Danielle Detering US",
+ "userId": "15432464718872067879"
+ },
+ "user_tz": 420
+ },
+ "id": "F-jyN9_N2zf4",
+ "outputId": "ea7a13bb-0f4b-4792-df28-7a21812acecd"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.float32"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "x_0_tensor.dtype"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "yIDUGa4LYzYb"
+ },
+ "source": [
+ "We can verify the minimum and maximum values. PIL Images have a potential integer range of [0, 255], but the [ToTensor](https://pytorch.org/vision/main/generated/torchvision.transforms.ToTensor.html) class converts it to a float range of [0.0, 1.0]."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "executionInfo": {
+ "elapsed": 156,
+ "status": "ok",
+ "timestamp": 1714980704869,
+ "user": {
+ "displayName": "Danielle Detering US",
+ "userId": "15432464718872067879"
+ },
+ "user_tz": 420
+ },
+ "id": "XTTjlZVS2zf4",
+ "outputId": "28be071a-3d2c-496e-ae07-e788ae6a284b"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor(0.)"
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "x_0_tensor.min()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "executionInfo": {
+ "elapsed": 210,
+ "status": "ok",
+ "timestamp": 1714980706159,
+ "user": {
+ "displayName": "Danielle Detering US",
+ "userId": "15432464718872067879"
+ },
+ "user_tz": 420
+ },
+ "id": "H_denh2i2zf4",
+ "outputId": "98b3f276-cc73-46dc-cb21-0005d2756a28"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor(1.)"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "x_0_tensor.max()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "9A9EL-34ZpSS"
+ },
+ "source": [
+ "We can also view the size of each dimension. PyTorch uses a `C x H x W` convention, which means the first dimension is color channel, the second is height, and the third is width.\n",
+ "\n",
+ "Since these images are black and white, there is only 1 color channel. The images are square being 28 pixels tall and wide:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "executionInfo": {
+ "elapsed": 2,
+ "status": "ok",
+ "timestamp": 1714980707121,
+ "user": {
+ "displayName": "Danielle Detering US",
+ "userId": "15432464718872067879"
+ },
+ "user_tz": 420
+ },
+ "id": "cIxHEBDRZk-K",
+ "outputId": "e6fbc8d8-1bdb-42c1-8a30-6c022daa7d65"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([1, 28, 28])"
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "x_0_tensor.size()"
+ ]
+ },
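+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "For comparison, a larger 3-dimensional tensor, such as the computer screen example from the start of this section, follows the same `C x H x W` convention. The sketch below uses hypothetical 1920 x 1080 RGB dimensions purely as an illustration:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Hypothetical example: a 1080p RGB screen as a C x H x W tensor of zeros\n",
+ "screen = torch.zeros(3, 1080, 1920)\n",
+ "screen.size()"
+ ]
+ },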
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "cXZWU123bB9W"
+ },
+ "source": [
+ "We can also look at the values directly:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "executionInfo": {
+ "elapsed": 2,
+ "status": "ok",
+ "timestamp": 1714980708660,
+ "user": {
+ "displayName": "Danielle Detering US",
+ "userId": "15432464718872067879"
+ },
+ "user_tz": 420
+ },
+ "id": "CLU8kY332zf4",
+ "outputId": "7e8fbef2-d1a9-4910-a1a9-caad1cec750c"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000],\n",
+ " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000],\n",
+ " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000],\n",
+ " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000],\n",
+ " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000],\n",
+ " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000, 0.0118, 0.0706, 0.0706, 0.0706,\n",
+ " 0.4941, 0.5333, 0.6863, 0.1020, 0.6510, 1.0000, 0.9686, 0.4980,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000],\n",
+ " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.1176, 0.1412, 0.3686, 0.6039, 0.6667, 0.9922, 0.9922, 0.9922,\n",
+ " 0.9922, 0.9922, 0.8824, 0.6745, 0.9922, 0.9490, 0.7647, 0.2510,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000],\n",
+ " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1922,\n",
+ " 0.9333, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922,\n",
+ " 0.9922, 0.9843, 0.3647, 0.3216, 0.3216, 0.2196, 0.1529, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000],\n",
+ " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0706,\n",
+ " 0.8588, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.7765, 0.7137,\n",
+ " 0.9686, 0.9451, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000],\n",
+ " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.3137, 0.6118, 0.4196, 0.9922, 0.9922, 0.8039, 0.0431, 0.0000,\n",
+ " 0.1686, 0.6039, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000],\n",
+ " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0549, 0.0039, 0.6039, 0.9922, 0.3529, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000],\n",
+ " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.5451, 0.9922, 0.7451, 0.0078, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000],\n",
+ " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0431, 0.7451, 0.9922, 0.2745, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000],\n",
+ " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000, 0.1373, 0.9451, 0.8824, 0.6275,\n",
+ " 0.4235, 0.0039, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000],\n",
+ " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3176, 0.9412, 0.9922,\n",
+ " 0.9922, 0.4667, 0.0980, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000],\n",
+ " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1765, 0.7294,\n",
+ " 0.9922, 0.9922, 0.5882, 0.1059, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000],\n",
+ " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0627,\n",
+ " 0.3647, 0.9882, 0.9922, 0.7333, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000],\n",
+ " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.9765, 0.9922, 0.9765, 0.2510, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000],\n",
+ " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1804, 0.5098,\n",
+ " 0.7176, 0.9922, 0.9922, 0.8118, 0.0078, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000],\n",
+ " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000, 0.1529, 0.5804, 0.8980, 0.9922,\n",
+ " 0.9922, 0.9922, 0.9804, 0.7137, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000],\n",
+ " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0941, 0.4471, 0.8667, 0.9922, 0.9922, 0.9922,\n",
+ " 0.9922, 0.7882, 0.3059, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000],\n",
+ " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0902, 0.2588, 0.8353, 0.9922, 0.9922, 0.9922, 0.9922, 0.7765,\n",
+ " 0.3176, 0.0078, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000],\n",
+ " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0706, 0.6706,\n",
+ " 0.8588, 0.9922, 0.9922, 0.9922, 0.9922, 0.7647, 0.3137, 0.0353,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000],\n",
+ " [0.0000, 0.0000, 0.0000, 0.0000, 0.2157, 0.6745, 0.8863, 0.9922,\n",
+ " 0.9922, 0.9922, 0.9922, 0.9569, 0.5216, 0.0431, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000],\n",
+ " [0.0000, 0.0000, 0.0000, 0.0000, 0.5333, 0.9922, 0.9922, 0.9922,\n",
+ " 0.8314, 0.5294, 0.5176, 0.0627, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000],\n",
+ " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000],\n",
+ " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000],\n",
+ " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
+ " 0.0000, 0.0000, 0.0000, 0.0000]]])"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "x_0_tensor"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "By default, a tensor is processed with a [CPU](https://www.arm.com/glossary/cpu)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "executionInfo": {
+ "elapsed": 145,
+ "status": "ok",
+ "timestamp": 1714980711055,
+ "user": {
+ "displayName": "Danielle Detering US",
+ "userId": "15432464718872067879"
+ },
+ "user_tz": 420
+ },
+ "id": "MUDWAzz386wC",
+ "outputId": "e4610927-74da-4a7d-cac8-3836661aaf2c"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "device(type='cpu')"
+ ]
+ },
+ "execution_count": 17,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "x_0_tensor.device"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To move it to a GPU, we can use the `.cuda` method."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "executionInfo": {
+ "elapsed": 208,
+ "status": "ok",
+ "timestamp": 1714980712295,
+ "user": {
+ "displayName": "Danielle Detering US",
+ "userId": "15432464718872067879"
+ },
+ "user_tz": 420
+ },
+ "id": "PtVbrIK_8bwz",
+ "outputId": "7e9256b6-8cd1-4a14-e656-114298564a04"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "device(type='cuda', index=0)"
+ ]
+ },
+ "execution_count": 18,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "x_0_gpu = x_0_tensor.cuda()\n",
+ "x_0_gpu.device"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The `.cuda` method will fail if a GPU is not recognized by PyTorch. In order to make our code flexible, we can send our tensor `to` the `device` we identified at the start of this notebook. This way, our code will run much faster if a GPU is available, but the code will not break if there is no available GPU."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "executionInfo": {
+ "elapsed": 144,
+ "status": "ok",
+ "timestamp": 1714980713216,
+ "user": {
+ "displayName": "Danielle Detering US",
+ "userId": "15432464718872067879"
+ },
+ "user_tz": 420
+ },
+ "id": "_qSz_fcP8swy",
+ "outputId": "db53c12e-8809-4023-fd9b-bdd5c1b39b4b"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "device(type='cuda', index=0)"
+ ]
+ },
+ "execution_count": 19,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "x_0_tensor.to(device).device"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "lBqQfpwo2zf4"
+ },
+ "source": [
+ "Sometimes, it can be hard to interpret so many numbers. Thankfully, TorchVision can convert `C x H x W` tensors back into a PIL image with the [to_pil_image](https://pytorch.org/vision/main/generated/torchvision.transforms.functional.to_pil_image.html) function."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 447
+ },
+ "executionInfo": {
+ "elapsed": 394,
+ "status": "ok",
+ "timestamp": 1714980717832,
+ "user": {
+ "displayName": "Danielle Detering US",
+ "userId": "15432464718872067879"
+ },
+ "user_tz": 420
+ },
+ "id": "fsqamRxJ2zf4",
+ "outputId": "75700cd2-85ad-4866-efc4-f8ac35863237"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "<matplotlib.image.AxesImage at 0x7f24e90dac20>"
+ ]
+ },
+ "execution_count": 20,
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAbe0lEQVR4nO3df2xV9f3H8dflR6+I7e1KbW8rPyygsIlgxqDrVMRRKd1G5McWdS7BzWhwrRGYuNRM0W2uDqczbEz5Y4GxCSjJgEEWNi22ZLNgQBgxbg0l3VpGWyZb7y2FFmw/3z+I98uVFjyXe/u+vTwfySeh955378fjtU9vezn1OeecAADoZ4OsNwAAuDIRIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYGKI9QY+qaenR8eOHVN6erp8Pp/1dgAAHjnn1N7ervz8fA0a1PfrnKQL0LFjxzRq1CjrbQAALlNTU5NGjhzZ5/1J9y249PR06y0AAOLgUl/PExag1atX6/rrr9dVV12lwsJCvfvuu59qjm+7AUBquNTX84QE6PXXX9eyZcu0YsUKvffee5oyZYpKSkp0/PjxRDwcAGAgcgkwffp0V1ZWFvm4u7vb5efnu8rKykvOhkIhJ4nFYrFYA3yFQqGLfr2P+yugM2fOaP/+/SouLo7cNmjQIBUXF6u2tvaC47u6uhQOh6MWACD1xT1AH374obq7u5Wbmxt1e25urlpaWi44vrKyUoFAILJ4BxwAXBnM3wVXUVGhUCgUWU1NTdZbAgD0g7j/PaDs7GwNHjxYra2tUbe3trYqGAxecLzf75ff74/3NgAASS7ur4DS0tI0depUVVVVRW7r6elRVVWVioqK4v1wAIABKiFXQli2bJkWLVqkL3zhC5o+fbpefvlldXR06Nvf/nYiHg4AMAAlJED33HOP/vOf/+jpp59WS0uLbrnlFu3cufOCNyYAAK5cPuecs97E+cLhsAKBgPU2AACXKRQKKSMjo8/7zd8FBwC4MhEgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmhlhvAEgmgwcP9jwTCAQSsJP4KC8vj2nu6quv9jwzYcIEzzNlZWWeZ372s595nrnvvvs8z0hSZ2en55nnn3/e88yzzz7reSYV8AoIAGCCAAEATMQ9QM8884x8Pl/UmjhxYrwfBgAwwCXkZ0A33XST3nrrrf9/kCH8qAkAEC0hZRgyZIiCwWAiPjUAIEUk5GdAhw8fVn5+vsaOHav7779fjY2NfR7b1dWlcDgctQAAqS/uASosLNS6deu0c+dOvfLKK2poaNDtt9+u9vb2Xo+vrKxUIBCIrFGjRsV7SwCAJBT3AJWWluob3/iGJk+erJKSEv3xj39UW1ub3njjjV6Pr6ioUCgUiqympqZ4bwkAkIQS/u6AzMxM3Xjjjaqvr+/1fr/fL7/fn+htAACSTML/HtDJkyd15MgR5eXlJfqhAAADSNwD9Pjjj6umpkb//Oc/9c4772j+/PkaPHhwzJfCAACkprh/C+7o0aO67777dOLECV177bW67bbbtGfPHl177bXxfigAwAAW9wBt2rQp3p8SSWr06NGeZ9LS0jzPfOlLX/I8c9ttt3mekc79zNKrhQsXxvRYqebo0aOeZ1atWuV5Zv78+Z5n+noX7qX87W9/8zxTU1MT02NdibgWHADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgwuecc9abOF84HFYgELDexhXllltuiWlu165dnmf4dzsw9PT0eJ75zne+43nm5MmTnmdi0dzcHNPc//73P88zdXV1MT1WKgqFQsrIyOjzfl4BAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwMQQ6w3AXmNjY0xzJ06c8DzD1bDP2bt3r+eZtrY2zzN33nmn5xlJOnPmjOeZ3/72tzE9Fq5cvAICAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAExwMVLov//9b0xzy5cv9zzzta99zfPMgQMHPM+sWrXK80ysDh486Hnmrrvu8jzT0dHheeamm27yPCNJjz32WExzgBe8AgIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATPicc856E+cLh8MKBALW20CCZGRkeJ5pb2/3PLNmzRrPM5L04IMPep751re+5Xlm48aNnmeAgSYUCl30v3leAQEATBAgAIAJzwHavXu35s6dq/z8fPl8Pm3dujXqfuecnn76aeXl5WnYsGEqLi7W4cOH47VfAECK8Bygjo4OTZkyRatXr+71/pUrV2rVqlV69dVXtXfvXg0fPlwlJSXq7Oy87M0CAFKH59+IWlpaqtLS0l7vc87p5Zdf1g9+8APdfffdkqT169crNzdXW7du1b333nt5uwUApIy4/gyooaFBLS0tKi4ujtwWCARUWFio2traXme6uroUDoejFgAg9cU1QC0tLZKk3NzcqNtzc3Mj931SZWWlAoFAZI0aNSqeWwIAJCnzd8FVVFQoFApFVlNTk/WWAAD9IK4BCgaDkqTW1tao21tbWyP3fZLf71dGRkbUAgCkvrgGqKCgQMFgUFVVVZHbwuGw9u7dq6Kiong+FABggPP8LriTJ0+qvr4+8nFDQ4MOHjyorKwsjR49WkuWLNGPf/xj3XDDDSooKNBTTz2l/Px8zZs3L577BgAMcJ4DtG/fPt15552Rj5ctWyZJWrRokdatW6cnnnhCHR0devjhh9XW1qbbbrtNO3fu1FVXXRW/XQMABjwuRoqU9MILL8Q09/H/UHlRU1Pjeeb8v6rwafX09HieASxxMVIAQFIiQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACa6GjZQ0fPjwmOa2b9/ueeaOO+7wPFNaWup55s9//rPnGcASV8MGACQlAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEFyMFzjNu3DjPM++9957nmba2Ns8zb7/9tueZffv2eZ6RpNWrV3ueSbIvJUgCXIwUAJCUCBAAwAQBAgCYIEAAABMECABgggABAEwQIAC
ACQIEADBBgAAAJggQAMAEAQIAmCBAAAATXIwUuEzz58/3PLN27VrPM+np6Z5nYvXkk096nlm/fr3nmebmZs8zGDi4GCkAICkRIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACa4GClgYNKkSZ5nXnrpJc8zs2bN8jwTqzVr1nieee655zzP/Pvf//Y8AxtcjBQAkJQIEADAhOcA7d69W3PnzlV+fr58Pp+2bt0adf8DDzwgn88XtebMmROv/QIAUoTnAHV0dGjKlClavXp1n8fMmTNHzc3NkbVx48bL2iQAIPUM8TpQWlqq0tLSix7j9/sVDAZj3hQAIPUl5GdA1dXVysnJ0YQJE/TII4/oxIkTfR7b1dWlcDgctQAAqS/uAZozZ47Wr1+vqqoq/fSnP1VNTY1KS0vV3d3d6/GVlZUKBAKRNWrUqHhvCQCQhDx/C+5S7r333sifb775Zk2ePFnjxo1TdXV1r38noaKiQsuWLYt8HA6HiRAAXAES/jbssWPHKjs7W/X19b3e7/f7lZGREbUAAKkv4QE6evSoTpw4oby8vEQ/FABgAPH8LbiTJ09GvZppaGjQwYMHlZWVpaysLD377LNauHChgsGgjhw5oieeeELjx49XSUlJXDcOABjYPAdo3759uvPOOyMff/zzm0WLFumVV17RoUOH9Jvf/EZtbW3Kz8/X7Nmz9aMf/Uh+vz9+uwYADHhcjBQYIDIzMz3PzJ07N6bHWrt2recZn8/neWbXrl2eZ+666y7PM7DBxUgBAEmJAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJrgaNoALdHV1eZ4ZMsTzb3fRRx995Hkmlt8tVl1d7XkGl4+rYQMAkhIBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYML71QMBXLbJkyd7nvn617/ueWbatGmeZ6TYLiwaiw8++MDzzO7duxOwE1jgFRAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIKLkQLnmTBhgueZ8vJyzzMLFizwPBMMBj3P9Kfu7m7PM83NzZ5nenp6PM8gOfEKCABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwwcVIkfRiuQjnfffdF9NjxXJh0euvvz6mx0pm+/bt8zzz3HPPeZ75wx/+4HkGqYNXQAAAEwQIAGDCU4AqKys1bdo0paenKycnR/PmzVNdXV3UMZ2dnSorK9OIESN0zTXXaOHChWptbY3rpgEAA5+nANXU1KisrEx79uzRm2++qbNnz2r27Nnq6OiIHLN06VJt375dmzdvVk1NjY4dOxbTL98CAKQ2T29C2LlzZ9TH69atU05Ojvbv368ZM2YoFArp17/+tTZs2KAvf/nLkqS1a9fqs5/9rPbs2aMvfvGL8ds5AGBAu6yfAYVCIUlSVlaWJGn//v06e/asiouLI8dMnDhRo0ePVm1tba+fo6urS+FwOGoBAFJfzAHq6enRkiVLdOutt2rSpEmSpJaWFqWlpSkzMzPq2NzcXLW0tPT6eSorKxUIBCJr1KhRsW4JADCAxBygsrIyvf/++9q0adNlbaCiokKhUCiympqaLuvzAQAGhpj+Imp5ebl27Nih3bt3a+TIkZHbg8Ggzpw5o7a2tqhXQa2trX3+ZUK/3y+/3x/LNgAAA5inV0DOOZWXl2vLli3atWuXCgoKou6fOnWqhg4dqqqqqshtdXV1amxsVFFRUXx2DABICZ5eAZWVlWnDhg3atm2b0tPTIz/XCQQCGjZsmAKBgB588EEtW7ZMWVlZysjI0KOPPqqioiLeAQcAiOIpQK+88ookaebMmVG3r127Vg888IAk6ec//7kGDRqkhQsXqqurSyUlJfrVr34Vl80CAFKHzznnrDdxvnA4rEAgYL0NfAq5ubmeZz73uc95nvnlL3/peWbixImeZ5Ld3r17Pc+88MILMT3Wtm3bPM/09PTE9FhIXaFQSBkZGX3ez7XgAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYCKm34iK5JWVleV5Zs2aNTE91i233OJ5ZuzYsTE9VjJ75513PM+8+OKLnmf+9Kc/eZ45ffq05xmgv/AKCABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwwcVI+0lhYaHnmeXLl3uemT59uueZ6667zvNMsjt16lRMc6tWrfI885Of/MTzTEdHh+cZINXwCggAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMMHFSPvJ/Pnz+2WmP33wwQeeZ3bs2OF55qOPPvI88+KLL3qekaS2traY5gB4xysgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMCEzznnrDdxvnA4rEAgYL0NAMBlCoVCysjI6PN+XgEBAEwQIACACU8Bqqys1LRp05Senq6cnBzNmzdPdXV1UcfMnDlTPp8vai1evDiumwYADHyeAlRTU6OysjLt2bNHb775ps6ePavZs2ero6Mj6riHHnpIzc3NkbVy5cq4bhoAMPB5+o2oO3fujPp43bp1ysnJ0f79+zVjxozI7VdffbWCwWB8dggASEmX9TOgUCgkScrKyoq6/bXXXlN2drYmTZqkiooKnTp1qs/P0dXVpXA4HLUAAFcAF6Pu7m731a9+1d16661Rt69Zs8bt3LnTHTp0yP3ud79z1113nZs/f36fn2fFihVOEovFYrFSbIVCoYt2JOYALV682I0ZM8Y1NTVd9LiqqionydXX1/d6f2dnpwuFQpHV1NRkftJYLBaLdfnrUgHy9DOgj5WXl2vHjh3avXu3Ro4cedFjCwsLJUn19fUaN27cBff7/X75/f5YtgEAGMA8Bcg5p0cffVRbtmxRdXW1CgoKLjlz8OBBSVJeXl5MGwQApCZPASorK9OGDRu0bds2paenq6WlRZIUCAQ0bNgwHTlyRBs2bNBXvvIVjRgxQocOHdLSpUs1Y8YMTZ48OSH/AACAAcrLz33Ux/f51q5d65xzrrGx0c2YMcNlZWU5v9/vxo8f75YvX37J7wOeLxQKmX/fksVisViXvy71tZ+LkQIAEoKLkQIAkhIBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwETSBcg5Z70FAEAcXOrredIFqL293X
oLAIA4uNTXc59LspccPT09OnbsmNLT0+Xz+aLuC4fDGjVqlJqampSRkWG0Q3uch3M4D+dwHs7hPJyTDOfBOaf29nbl5+dr0KC+X+cM6cc9fSqDBg3SyJEjL3pMRkbGFf0E+xjn4RzOwzmch3M4D+dYn4dAIHDJY5LuW3AAgCsDAQIAmBhQAfL7/VqxYoX8fr/1VkxxHs7hPJzDeTiH83DOQDoPSfcmBADAlWFAvQICAKQOAgQAMEGAAAAmCBAAwMSACdDq1at1/fXX66qrrlJhYaHeffdd6y31u2eeeUY+ny9qTZw40XpbCbd7927NnTtX+fn58vl82rp1a9T9zjk9/fTTysvL07Bhw1RcXKzDhw/bbDaBLnUeHnjggQueH3PmzLHZbIJUVlZq2rRpSk9PV05OjubNm6e6urqoYzo7O1VWVqYRI0bommuu0cKFC9Xa2mq048T4NOdh5syZFzwfFi9ebLTj3g2IAL3++utatmyZVqxYoffee09TpkxRSUmJjh8/br21fnfTTTepubk5sv7yl79YbynhOjo6NGXKFK1evbrX+1euXKlVq1bp1Vdf1d69ezV8+HCVlJSos7Ozn3eaWJc6D5I0Z86cqOfHxo0b+3GHiVdTU6OysjLt2bNHb775ps6ePavZs2ero6MjcszSpUu1fft2bd68WTU1NTp27JgWLFhguOv4+zTnQZIeeuihqOfDypUrjXbcBzcATJ8+3ZWVlUU+7u7udvn5+a6ystJwV/1vxYoVbsqUKdbbMCXJbdmyJfJxT0+PCwaD7oUXXojc1tbW5vx+v9u4caPBDvvHJ8+Dc84tWrTI3X333Sb7sXL8+HEnydXU1Djnzv27Hzp0qNu8eXPkmL///e9OkqutrbXaZsJ98jw459wdd9zhHnvsMbtNfQpJ/wrozJkz2r9/v4qLiyO3DRo0SMXFxaqtrTXcmY3Dhw8rPz9fY8eO1f3336/GxkbrLZlqaGhQS0tL1PMjEAiosLDwinx+VFdXKycnRxMmTNAjjzyiEydOWG8poUKhkCQpKytLkrR//36dPXs26vkwceJEjR49OqWfD588Dx977bXXlJ2drUmTJqmiokKnTp2y2F6fku5ipJ/04Ycfqru7W7m5uVG35+bm6h//+IfRrmwUFhZq3bp1mjBhgpqbm/Xss8/q9ttv1/vvv6/09HTr7ZloaWmRpF6fHx/fd6WYM2eOFixYoIKCAh05ckRPPvmkSktLVVtbq8GDB1tvL+56enq0ZMkS3XrrrZo0aZKkc8+HtLQ0ZWZmRh2bys+H3s6DJH3zm9/UmDFjlJ+fr0OHDun73/++6urq9Pvf/95wt9GSPkD4f6WlpZE/T548WYWFhRozZozeeOMNPfjgg4Y7QzK49957I3+++eabNXnyZI0bN07V1dWaNWuW4c4So6ysTO+///4V8XPQi+nrPDz88MORP998883Ky8vTrFmzdOTIEY0bN66/t9mrpP8WXHZ2tgYPHnzBu1haW1sVDAaNdpUcMjMzdeONN6q+vt56K2Y+fg7w/LjQ2LFjlZ2dnZLPj/Lycu3YsUNvv/121K9vCQaDOnPmjNra2qKOT9XnQ1/noTeFhYWSlFTPh6QPUFpamqZOnaqqqqrIbT09PaqqqlJRUZHhzuydPHlSR44cUV5envVWzBQUFCgYDEY9P8LhsPbu3XvFPz+OHj2qEydOpNTzwzmn8vJybdmyRbt27VJBQUHU/VOnTtXQoUOjng91dXVqbGxMqefDpc5Dbw4ePChJyfV8sH4XxKexadMm5/f73bp169wHH3zgHn74YZeZmelaWlqst9avvve977nq6mrX0NDg/vrXv7ri4mKXnZ3tjh8/br21hGpvb3cHDhxwBw4ccJLcSy+95A4cOOD+9a9/Oeece/75511mZqbbtm2bO3TokLv77rtdQUGBO336tPHO4+ti56G9vd09/vjjrra21jU0NLi33nrLff7zn3c33HCD6+zstN563DzyyCMuEAi46upq19zcHFmnTp2KHLN48WI3evRot2vXLrdv3z5XVFTkioqKDHcdf5c6D/X19e6HP/yh27dvn2toaHDbtm1zY8eOdTNmzDDeebQBESDnnPvFL37hRo8e7dLS0tz06dPdnj17rLfU7+655x6Xl5fn0tLS3HXXXefuueceV19fb72thHv77bedpAvWokWLnHPn3or91FNPudzcXOf3+92sWbNcXV2d7aYT4GLn4dSpU2727Nnu2muvdUOHDnVjxoxxDz30UMr9T1pv//yS3Nq1ayPHnD592n33u991n/nMZ9zVV1/t5s+f75qbm+02nQCXOg+NjY1uxowZLisry/n9fjd+/Hi3fPlyFwqFbDf+Cfw6BgCAiaT/GRAAIDURIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACb+Dwuo74MxItlsAAAAAElFTkSuQmCC",
+ "text/plain": [
+ "<Figure size 640x480 with 1 Axes>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "image = F.to_pil_image(x_0_tensor)\n",
+ "plt.imshow(image, cmap='gray')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 1.4 Preparing the Data for Training"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Earlier, we created a `trans` variable to convert an image to a tensor. [Transforms](https://pytorch.org/vision/stable/transforms.html) are a group of torchvision functions that can be used to transform a dataset."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1.4.1 Transforms"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The [Compose](https://pytorch.org/vision/stable/generated/torchvision.transforms.v2.Compose.html#torchvision.transforms.v2.Compose) fuction combines a list of transforms. We will learn more about transforms in a later notebook, but have copied the `trans` definition below as an introduction."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "trans = transforms.Compose([transforms.ToTensor()])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Before, we only applied `trans` to one value. There are multiple ways we can apply our list of transforms to a dataset. One such way is to set it to a dataset's `transform` variable. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {
+ "id": "ue8Qs8RU0nF7"
+ },
+ "outputs": [],
+ "source": [
+ "train_set.transform = trans\n",
+ "valid_set.transform = trans"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1.4.2 DataLoaders"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "If our dataset is a deck of flash cards, a [DataLoader](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html#preparing-your-data-for-training-with-dataloaders) defines how we pull cards from the deck to train an AI model. We could show our models the entire dataset at once. Not only does this take a lot of computational resources, but [research shows](https://arxiv.org/pdf/1804.07612) using a smaller batch of data is more efficient for model training.\n",
+ "\n",
+ "For example, if our `batch_size` is 32, we will train our model by shuffling the deck and drawing 32 cards. We do not need to shuffle for validation as the model is not learning, but we will still use a `batch_size` to prevent memory errors.\n",
+ "\n",
+ "The batch size is something the model developer decides, and the best value will depend on the problem being solved. Research shows 32 or 64 is sufficient for many machine learning problems and is the default in some machine learning frameworks, so we will use 32 here."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {
+ "id": "vhfKSYKCpHxf"
+ },
+ "outputs": [],
+ "source": [
+ "batch_size = 32\n",
+ "\n",
+ "train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)\n",
+ "valid_loader = DataLoader(valid_set, batch_size=batch_size)"
+ ]
+ },
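+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As a quick check (a small sketch, not part of the original exercises), we can draw one batch from `train_loader` and confirm it contains 32 images of shape `1 x 28 x 28` along with 32 labels:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Sketch: inspect one batch from the training DataLoader\n",
+ "x_batch, y_batch = next(iter(train_loader))\n",
+ "x_batch.shape, y_batch.shape"
+ ]
+ },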
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 1.5 Creating the Model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "It's time to build the model! Neural networks are composed of layers where each layer performs a mathematical operation on the data it receives before passing it to the next layer. To start, we will create a \"Hello World\" level model made from 4 components:\n",
+ "\n",
+ "1. A [Flatten](https://pytorch.org/docs/stable/generated/torch.nn.Flatten.html) used to convert n-dimensional data into a vector. \n",
+ "2. An input layer, the first layer of neurons\n",
+ "3. A hidden layer, another layor of neurons \"hidden\" between the input and output\n",
+ "4. An output layer, the last set of neurons which returns the final prediction from the model\n",
+ "\n",
+ "More information about these layers is available in [this blog post](https://medium.com/@sarita_68521/basic-understanding-of-neural-network-structure-eecc8f149a23) by Sarita.\n",
+ "\n",
+ "Let's create a `layers` variable to hold our list of layers."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[]"
+ ]
+ },
+ "execution_count": 24,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "layers = []\n",
+ "layers"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1.5.1 Flattening the Image"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "When we looked at the shape of our data above, we saw the images had 3 dimensions: `C x H x W`. To flatten an image means to combine all of these images into 1 dimension. Let's say we have a tensor like the one below. Try running the code cell to see what it looks like before and after being flattened."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[1, 2, 3],\n",
+ " [4, 5, 6],\n",
+ " [7, 8, 9]])"
+ ]
+ },
+ "execution_count": 25,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "test_matrix = torch.tensor(\n",
+ " [[1, 2, 3],\n",
+ " [4, 5, 6],\n",
+ " [7, 8, 9]]\n",
+ ")\n",
+ "test_matrix"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[1, 2, 3],\n",
+ " [4, 5, 6],\n",
+ " [7, 8, 9]])"
+ ]
+ },
+ "execution_count": 26,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "nn.Flatten()(test_matrix)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Nothing happened? That's because neural networks expect to recieve a batch of data. Currently, the Flatten layer sees three vectors as opposed to one 2d matrix. To fix this, we can \"batch\" our data by adding an extra pair of brackets. Since `test_matrix` is now a tensor, we can do that with the shorthand below. `None` adds a new dimension where `:` selects all the data in a tensor."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[[1, 2, 3],\n",
+ " [4, 5, 6],\n",
+ " [7, 8, 9]]])"
+ ]
+ },
+ "execution_count": 27,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "batch_test_matrix = test_matrix[None, :]\n",
+ "batch_test_matrix"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9]])"
+ ]
+ },
+ "execution_count": 28,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "nn.Flatten()(batch_test_matrix)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Order matters! This is what happens when we do it the other way:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[1, 2, 3],\n",
+ " [4, 5, 6],\n",
+ " [7, 8, 9]])"
+ ]
+ },
+ "execution_count": 29,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "nn.Flatten()(test_matrix[:, None])"
+ ]
+ },
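+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As a sanity check (a small sketch reusing `x_0_tensor` from earlier), we can batch a real MNIST image the same way and flatten it. The result is a single vector of 1 * 28 * 28 = 784 values, which will become the input size of our model:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Sketch: batch the MNIST image from earlier and flatten it into a 784-value vector\n",
+ "nn.Flatten()(x_0_tensor[None, :]).size()"
+ ]
+ },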
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now that we've gotten the hang of the `Flatten` layer, let's add it to our list of `layers`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[Flatten(start_dim=1, end_dim=-1)]"
+ ]
+ },
+ "execution_count": 30,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "layers = [\n",
+ " nn.Flatten()\n",
+ "]\n",
+ "layers"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1.5.2 The Input Layer"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Our first layer of neurons connects our flattened image to the rest of our model. To do that, we will use a [Linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) layer. This layer will be *densely connected*, meaning that each neuron in it, and its weights, will affect every neuron in the next layer.\n",
+ "\n",
+ "In order to create these weights, Pytorch needs to know the size of our inputs and how many neurons we want to create. \n",
+ "Since we've flattened our images, the size of our inputs is the number of channels, number of pixels vertically, and number of pixels horizontally multiplied together."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "input_size = 1 * 28 * 28"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Choosing the correct number of neurons is what puts the \"science\" in \"data science\" as it is a matter of capturing the statistical complexity of the dataset. For now, we will use `512` neurons. Try playing around with this value later to see how it affects training and to start developing a sense for what this number means.\n",
+ "\n",
+ "We will learn more about activation functions later, but for now, we will use the [relu](https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html) activation function, which in short, will help our network to learn how to make more sophisticated guesses about data than if it were required to make guesses based on some strictly linear function."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[Flatten(start_dim=1, end_dim=-1),\n",
+ " Linear(in_features=784, out_features=512, bias=True),\n",
+ " ReLU()]"
+ ]
+ },
+ "execution_count": 32,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "layers = [\n",
+ " nn.Flatten(),\n",
+ " nn.Linear(input_size, 512), # Input\n",
+ " nn.ReLU(), # Activation for input\n",
+ "]\n",
+ "layers"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1.5.3 The Hidden Layer"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now we will add an additional densely connected linear layer. We will cover why adding another set of neurons can help improve learning in the next lesson. Just like how the input layer needed to know the shape of the data that was being passed to it, a hidden layer's [nn.Linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) needs to know the shape of the data being passed to it. Each neuron in the previous layer will compute one number, so the number of inputs into the hidden layer is the same as the number of neurons in the previous later."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[Flatten(start_dim=1, end_dim=-1),\n",
+ " Linear(in_features=784, out_features=512, bias=True),\n",
+ " ReLU(),\n",
+ " Linear(in_features=512, out_features=512, bias=True),\n",
+ " ReLU()]"
+ ]
+ },
+ "execution_count": 33,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "layers = [\n",
+ " nn.Flatten(),\n",
+ " nn.Linear(input_size, 512), # Input\n",
+ " nn.ReLU(), # Activation for input\n",
+ " nn.Linear(512, 512), # Hidden\n",
+ " nn.ReLU() # Activation for hidden\n",
+ "]\n",
+ "layers"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1.5.4 The Output Layer"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Finally, we will add an output layer. In this case, since the network is to make a guess about a single image belonging to 1 of 10 possible categories, there will be 10 outputs. Each output is assigned a neuron. The larger the value of the output neuron compared to the other neurons, the more the model predicts the input image belongs to the output neuron's assigned class.\n",
+ "\n",
+ "We will not assign the `relu` function to the output layer. Instead, we will apply a `loss function` covered in the next section."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[Flatten(start_dim=1, end_dim=-1),\n",
+ " Linear(in_features=784, out_features=512, bias=True),\n",
+ " ReLU(),\n",
+ " Linear(in_features=512, out_features=512, bias=True),\n",
+ " ReLU(),\n",
+ " Linear(in_features=512, out_features=10, bias=True)]"
+ ]
+ },
+ "execution_count": 34,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "n_classes = 10\n",
+ "\n",
+ "layers = [\n",
+ " nn.Flatten(),\n",
+ " nn.Linear(input_size, 512), # Input\n",
+ " nn.ReLU(), # Activation for input\n",
+ " nn.Linear(512, 512), # Hidden\n",
+ " nn.ReLU(), # Activation for hidden\n",
+ " nn.Linear(512, n_classes) # Output\n",
+ "]\n",
+ "layers"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1.5.5 Compiling the Model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "A [Sequential](https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html) model expects a sequence of arguments, not a list, so we can use the [* operator](https://docs.python.org/3/reference/expressions.html#expression-lists) to unpack our list of layers into a sequence. We can print the model to verify these layers loaded correctly."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 35,
+ "metadata": {
+ "id": "7lmQ-G_FuThy"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Sequential(\n",
+ " (0): Flatten(start_dim=1, end_dim=-1)\n",
+ " (1): Linear(in_features=784, out_features=512, bias=True)\n",
+ " (2): ReLU()\n",
+ " (3): Linear(in_features=512, out_features=512, bias=True)\n",
+ " (4): ReLU()\n",
+ " (5): Linear(in_features=512, out_features=10, bias=True)\n",
+ ")"
+ ]
+ },
+ "execution_count": 35,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model = nn.Sequential(*layers)\n",
+ "model"
+ ]
+ },
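+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As a quick, optional sanity check, we can also count the model's trainable parameters. Each `Linear` layer contributes a weight matrix plus a bias vector, so the total simply reflects the layer sizes we chose above. The sketch below only uses the `model` we just built."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Optional sanity check: each Linear layer contributes weights plus biases\n",
+ "# (784*512 + 512, then 512*512 + 512, then 512*10 + 10).\n",
+ "sum(p.numel() for p in model.parameters() if p.requires_grad)"
+ ]
+ },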
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "hGSKooDz1DH5"
+ },
+ "source": [
+ "Much like tensors, when the model is first initialized, it will be processed on a CPU. To have it process with a GPU, we can use `to(device)`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 36,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "executionInfo": {
+ "elapsed": 154,
+ "status": "ok",
+ "timestamp": 1714980725791,
+ "user": {
+ "displayName": "Danielle Detering US",
+ "userId": "15432464718872067879"
+ },
+ "user_tz": 420
+ },
+ "id": "ynuW56udz7C1",
+ "outputId": "25675910-e29a-4f3d-99c8-2254d5277e38"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Sequential(\n",
+ " (0): Flatten(start_dim=1, end_dim=-1)\n",
+ " (1): Linear(in_features=784, out_features=512, bias=True)\n",
+ " (2): ReLU()\n",
+ " (3): Linear(in_features=512, out_features=512, bias=True)\n",
+ " (4): ReLU()\n",
+ " (5): Linear(in_features=512, out_features=10, bias=True)\n",
+ ")"
+ ]
+ },
+ "execution_count": 36,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model.to(device)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To check which device a model is on, we can check which device the model parameters are on. Check out this [stack overflow](https://stackoverflow.com/questions/58926054/how-to-get-the-device-type-of-a-pytorch-module-conveniently) post for more information."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 37,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "executionInfo": {
+ "elapsed": 3,
+ "status": "ok",
+ "timestamp": 1714980916077,
+ "user": {
+ "displayName": "Danielle Detering US",
+ "userId": "15432464718872067879"
+ },
+ "user_tz": 420
+ },
+ "id": "VlCPs5KDz9Ii",
+ "outputId": "67e73ac9-139e-4ad4-9ff6-d344abd267c3"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "device(type='cuda', index=0)"
+ ]
+ },
+ "execution_count": 37,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "next(model.parameters()).device"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "[PyTorch 2.0](https://pytorch.org/get-started/pytorch-2.0/) introduced the ability to compile a model for faster performance. Learn more about it [here](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 38,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model = torch.compile(model)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 1.6 Training the Model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now that we have prepared training and validation data, and a model, it's time to train our model with our training data, and verify it with its validation data.\n",
+ "\n",
+ "\"Training a model with data\" is often also called \"fitting a model to data.\" Put another way, it highlights that the shape of the model changes over time to more accurately understand the data that it is being given."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1.6.1 Loss and Optimization"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Just like how teachers grade students, we need to provide the model a function in which to grade its answers. This is called a `loss function`. We will use a type of loss function called [CrossEntropy](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) which is designed to grade if a model predicted the correct category from a group of categories."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 39,
+ "metadata": {
+ "id": "jG8Oxrz0z_y2"
+ },
+ "outputs": [],
+ "source": [
+ "loss_function = nn.CrossEntropyLoss()"
+ ]
+ },
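+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To build some intuition for how this loss function grades an answer, here is a small, made-up example. The tensors below are invented purely for illustration and are not part of our dataset: a row whose largest value sits at the correct class index receives a low loss, while a row that confidently favors the wrong class receives a high loss."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Made-up outputs for a 3-class problem: the first row favors class 0,\n",
+ "# the second favors class 2. The true label for both rows is class 0.\n",
+ "fake_output = torch.tensor([[2.0, 0.1, 0.1], [0.1, 0.1, 2.0]])\n",
+ "fake_labels = torch.tensor([0, 0])\n",
+ "\n",
+ "# Low loss for the correct guess, high loss for the confident wrong guess\n",
+ "loss_function(fake_output[:1], fake_labels[:1]), loss_function(fake_output[1:], fake_labels[1:])"
+ ]
+ },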
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Next, we select an `optimizer` for our model. If the `loss_function` provides a grade, the optimizer tells the model how to learn from this grade to do better next time."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 40,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "optimizer = Adam(model.parameters())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1.6.2 Calculating Accuracy"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "While the results of the loss function are effective in helping our model learn, the values can be difficult to interpret for humans. This is why data scientists often include other metrics like accuracy.\n",
+ "\n",
+ "In order to accurately calculate accuracy, we should compare the number of correct classifications compared to the total number of predictions made. Since we're showing data to the model in batches, our accuracy can be calculated along with these batches.\n",
+ "\n",
+ "First, the total number of predictions is the same size as our dataset. Let's assign the size of our datasets to `N` where `n` is synonymous with the `batch size`. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 41,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "train_N = len(train_loader.dataset)\n",
+ "valid_N = len(valid_loader.dataset)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Next, we'll make a function to calculate the accuracy for each batch. The result is a fraction of the total accuracy, so we can add the accuracy of each batch together to get the total."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 42,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def get_batch_accuracy(output, y, N):\n",
+ " pred = output.argmax(dim=1, keepdim=True)\n",
+ " correct = pred.eq(y.view_as(pred)).sum().item()\n",
+ " return correct / N"
+ ]
+ },
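+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To see how this helper behaves, here is a small, made-up example (the tensors are invented for illustration): `argmax` picks the highest-scoring class for each row, the matches against the labels are counted, and that count is divided by `N`. With three correct guesses out of four rows and `N=4`, the result is 0.75."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Made-up batch: 4 fake prediction rows over 3 classes; only the last guess is wrong\n",
+ "fake_output = torch.tensor([[2.0, 0.1, 0.1], [0.1, 2.0, 0.1], [0.1, 0.1, 2.0], [2.0, 0.1, 0.1]])\n",
+ "fake_labels = torch.tensor([0, 1, 2, 1])\n",
+ "\n",
+ "# 3 correct out of N=4 predictions -> 0.75\n",
+ "get_batch_accuracy(fake_output, fake_labels, N=4)"
+ ]
+ },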
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1.6.3 The Train Function"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Here is where everything comes together. Below is the function we've defined to train our model based on the training data. We will walk through each line of code in more detail later, but take a moment to review how it is structured. Can you recognize the variables we created earlier?"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "8iPLK1V53w3R"
+ },
+ "outputs": [],
+ "source": [
+ "def train():\n",
+ " loss = 0\n",
+ " accuracy = 0\n",
+ "\n",
+ " model.train()\n",
+ " for x, y in train_loader:\n",
+ " x, y = x.to(device), y.to(device)\n",
+ " output = model(x)\n",
+ " optimizer.zero_grad()\n",
+ " batch_loss = loss_function(output, y)\n",
+ " batch_loss.backward()\n",
+ " optimizer.step()\n",
+ "\n",
+ " loss += batch_loss.item()\n",
+ " accuracy += get_batch_accuracy(output, y, train_N)\n",
+ " print('Train - Loss: {:.4f} Accuracy: {:.4f}'.format(loss, accuracy))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1.6.4 The Validate Function"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Similarly, this is the code for validating the model with data it did not train on. Can you spot some differences with the `train` function?"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 44,
+ "metadata": {
+ "id": "8WlmJPR5yZJ5"
+ },
+ "outputs": [],
+ "source": [
+ "def validate():\n",
+ " loss = 0\n",
+ " accuracy = 0\n",
+ "\n",
+ " model.eval()\n",
+ " with torch.no_grad():\n",
+ " for x, y in valid_loader:\n",
+ " x, y = x.to(device), y.to(device)\n",
+ " output = model(x)\n",
+ "\n",
+ " loss += loss_function(output, y).item()\n",
+ " accuracy += get_batch_accuracy(output, y, valid_N)\n",
+ " print('Valid - Loss: {:.4f} Accuracy: {:.4f}'.format(loss, accuracy))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1.6.5 The Training Loop"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To see how the model is progressing, we will alternated between training and validation. Just like how it might take a student a few times going through their deck of flash cards to learn all the concepts, the model will go through the training data multiple times to get a better and better understanding.\n",
+ "\n",
+ "An `epoch` is one complete pass through the entire dataset. Let's train and validate the model for 5 `epochs` to see how it learns."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 45,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "executionInfo": {
+ "elapsed": 57767,
+ "status": "ok",
+ "timestamp": 1714814236517,
+ "user": {
+ "displayName": "Danielle Detering US",
+ "userId": "15432464718872067879"
+ },
+ "user_tz": 420
+ },
+ "id": "SElV0J_aw-bW",
+ "outputId": "452e3897-934b-4823-8a75-43a20551f06f"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch: 0\n",
+ "Train - Loss: 378.8628 Accuracy: 0.9391\n",
+ "Valid - Loss: 36.4331 Accuracy: 0.9633\n",
+ "Epoch: 1\n",
+ "Train - Loss: 157.7399 Accuracy: 0.9739\n",
+ "Valid - Loss: 29.6794 Accuracy: 0.9720\n",
+ "Epoch: 2\n",
+ "Train - Loss: 111.7519 Accuracy: 0.9813\n",
+ "Valid - Loss: 21.7729 Accuracy: 0.9798\n",
+ "Epoch: 3\n",
+ "Train - Loss: 82.9790 Accuracy: 0.9856\n",
+ "Valid - Loss: 23.1030 Accuracy: 0.9792\n",
+ "Epoch: 4\n",
+ "Train - Loss: 66.1132 Accuracy: 0.9889\n",
+ "Valid - Loss: 24.2721 Accuracy: 0.9790\n"
+ ]
+ }
+ ],
+ "source": [
+ "epochs = 5\n",
+ "\n",
+ "for epoch in range(epochs):\n",
+ " print('Epoch: {}'.format(epoch))\n",
+ " train()\n",
+ " validate()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We're already close to 100%! Let's see if it's true by testing it on our original sample. We can use our model like a function:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 46,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "executionInfo": {
+ "elapsed": 315,
+ "status": "ok",
+ "timestamp": 1714815283682,
+ "user": {
+ "displayName": "Danielle Detering US",
+ "userId": "15432464718872067879"
+ },
+ "user_tz": 420
+ },
+ "id": "FPKjh3TN1_Sx",
+ "outputId": "729b2cfa-edcf-44f4-ebcc-c63538081e5d"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[-29.8087, -5.8030, -12.9307, 8.3254, -24.2137, 14.2001, -19.0615,\n",
+ " -12.6502, -15.9811, -10.2585]], device='cuda:0',\n",
+ " grad_fn=<CompiledFunctionBackward>)"
+ ]
+ },
+ "execution_count": 46,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "prediction = model(x_0_gpu)\n",
+ "prediction"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "There should be ten numbers, each corresponding to a different output neuron. Thanks to how the data is structured, the index of each number matches the corresponding handwritten number. The 0th index is a prediction for a handwritten 0, the 1st index is a prediction for a handwritten 1, and so on.\n",
+ "\n",
+ "We can use the `argmax` function to find the index of the highest value."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 47,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "executionInfo": {
+ "elapsed": 1101,
+ "status": "ok",
+ "timestamp": 1714815292555,
+ "user": {
+ "displayName": "Danielle Detering US",
+ "userId": "15432464718872067879"
+ },
+ "user_tz": 420
+ },
+ "id": "XrmW1TrN2OOr",
+ "outputId": "d0d8dc74-e88c-47db-f249-6e83d4f0dc44"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[5]], device='cuda:0')"
+ ]
+ },
+ "execution_count": 47,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "prediction.argmax(dim=1, keepdim=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Did it get it right?"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 48,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "5"
+ ]
+ },
+ "execution_count": 48,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "y_0"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "I8O7pADY2zgL"
+ },
+ "source": [
+ "## 1.7 Summary"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "p-YXMUiN2zgL"
+ },
+ "source": [
+ "The model did quite well! The accuracy quickly reached close to 100%, as did the validation accuracy. We now have a model that can be used to accurately detect and classify hand-written images.\n",
+ "\n",
+ "The next step would be to use this model to classify new not-yet-seen handwritten images. This is called [inference](https://blogs.nvidia.com/blog/2016/08/22/difference-deep-learning-training-inference-ai/). We'll explore the process of inference in a later exercise."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Yu5NNukp2zgL"
+ },
+ "source": [
+ "It's worth taking a moment to appreciate what we've done here. Historically, the expert systems that were built to do this kind of task were extremely complicated, and people spent their careers building them (check out the references on the [official MNIST page](http://yann.lecun.com/exdb/mnist/) and the years milestones were reached).\n",
+ "\n",
+ "MNIST is not only useful for its historical influence on Computer Vision, but it's also a great [benchmark](http://www.cs.toronto.edu/~serailhydra/publications/tbd-iiswc18.pdf) and debugging tool. Having trouble getting a fancy new machine learning architecture working? Check it against MNIST. If it can't learn on this dataset, chances are it won't learn on more complicated images and datasets."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "lV4QiOtg2zgL"
+ },
+ "source": [
+ "### 1.7.1 Clear the Memory"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "o3iQrnws2zgL"
+ },
+ "source": [
+ "Before moving on, please execute the following cell to clear up the GPU memory. This is required to move on to the next notebook."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "JCvpuSx-2zgM"
+ },
+ "outputs": [],
+ "source": [
+ "import IPython\n",
+ "app = IPython.Application.instance()\n",
+ "app.kernel.do_shutdown(True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "a7ybpLwy2zgM"
+ },
+ "source": [
+ "### 1.7.2 Next"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "_oCUyjG72zgM"
+ },
+ "source": [
+ "In this section you learned how to build and train a simple neural network for image classification. In the next section, you will be asked to build your own neural network and perform data preparation to solve a different image classification problem."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "zqcwglnk2zgM"
+ },
+ "source": [
+ "<center><a href=\"https://www.nvidia.com/dli\"> <img src=\"images/DLI_Header.png\" alt=\"Header\" style=\"width: 400px;\"/> </a></center>"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "gpuType": "T4",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}