In the previous sections you learned how to use gradient descent to find the minimum of a function, but to do that we needed the gradient. For our simple example, we could compute the gradient with pencil and paper. For deep learning models, doing this by hand is impractical, so we rely on libraries like PyTorch that provide automatic differentiation, which computes the gradients for us.
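As a quick check of what automatic differentiation buys us, the sketch below compares PyTorch's gradient with a hand-derived one for the simple logarithmic function f(x) = log(x^4 + x^3 + 2) used earlier in this appendix. The evaluation point x = 1.5 is an arbitrary choice for illustration.

```python
import torch

# f(x) = log(x^4 + x^3 + 2), the simple example function from earlier
def f(x):
    return torch.log(x**4 + x**3 + 2)

def df_by_hand(x):
    # derivative worked out with pencil and paper using the chain rule
    return (4 * x**3 + 3 * x**2) / (x**4 + x**3 + 2)

x = torch.tensor(1.5, requires_grad=True)  # ask PyTorch to track gradients for x
y = f(x)
y.backward()                                # automatic differentiation fills in x.grad

auto_grad = x.grad.item()
hand_grad = df_by_hand(torch.tensor(1.5)).item()
print(auto_grad, hand_grad)                 # the two values agree up to float precision
```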
The basic idea is that PyTorch builds a computational graph, similar to the diagrams we used in the previous section, that explicitly records the inputs, the outputs, and how the different functions are connected to one another. Because PyTorch keeps track of this graph, it can apply the chain rule automatically to compute gradients. Fortunately, switching from numpy to PyTorch is simple: most of the time we can just replace numpy with torch. Let’s translate our neural network from above into PyTorch.
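A tiny example makes the recorded graph visible. In this sketch (toy values chosen for illustration), w plays the role of a trainable weight and x the role of an input; only w is marked with requires_grad=True, so only w receives a gradient when backward() walks the graph with the chain rule.

```python
import torch

w = torch.tensor(2.0, requires_grad=True)  # tracked, like a trainable weight
x = torch.tensor(3.0)                       # plain input, not tracked

h = w * x                 # recorded in the graph: h = w * x = 6
y = torch.relu(h) ** 2    # recorded as well: y = relu(w * x)^2 = 36

print(y.grad_fn)          # the last operation PyTorch recorded for y
y.backward()              # chain rule: dy/dw = 2 * relu(w*x) * 1[w*x > 0] * x
print(w.grad)             # 2 * 6 * 1 * 3 = 36
print(x.grad)             # None, because x was never marked requires_grad
```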
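```python
import torch

def nn(x, w1, w2):
    l1 = x @ w1              # first layer: matrix multiplication
    l1 = torch.relu(l1)      # nonlinear activation
    l2 = l1 @ w2             # second layer: 10 raw scores, one per digit
    return l2

# requires_grad=True marks the weights as trainable, so PyTorch tracks their gradients
w1 = torch.randn(784, 200, requires_grad=True)
w2 = torch.randn(200, 10, requires_grad=True)
```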
This looks almost identical to the numpy version except that we use torch.relu instead of np.maximum; both compute the same rectified linear function, max(0, x). We also pass requires_grad=True when creating the weight matrices. This tells PyTorch that these are trainable parameters whose gradients we want to track, whereas x is an input, not a trainable parameter. We also dropped the final activation function, because the loss function we will use, torch.nn.CrossEntropyLoss, expects raw scores (logits) and applies a log-softmax internally. For this example, we will use the famous MNIST data set, which contains images of handwritten digits from 0 to 9, such as the one in figure A.2.
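The reason the network can end with raw scores is that torch.nn.CrossEntropyLoss combines a log-softmax with the negative log-likelihood loss. The sketch below checks this on random logits for a hypothetical batch of 4 images; adding a softmax layer to the network would therefore normalize the scores twice.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)              # raw scores for 4 images, 10 digit classes
targets = torch.tensor([3, 7, 0, 9])     # correct digit for each image

ce = torch.nn.CrossEntropyLoss()(logits, targets)
# the same quantity built from its two pieces: log-softmax, then negative log-likelihood
manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)
print(ce.item(), manual.item())          # identical up to float rounding
```

If class probabilities are needed after training, torch.softmax(logits, dim=1) can be applied to the network's output outside the loss.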
We want to train our neural network to recognize these images and classify them as digits 0 through 9. PyTorch has a companion library, torchvision (imported here as TV), that lets us easily download this data set.
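```python
import numpy as np
import torch
import torchvision as TV

mnist_data = TV.datasets.MNIST("MNIST", train=True, download=True)  # downloads on first run
lr = 0.001
epochs = 2000                           # training iterations, one random mini-batch each
batch_size = 100
lossfn = torch.nn.CrossEntropyLoss()    # expects raw logits and integer class labels
losses = []
for i in range(epochs):
    rid = np.random.randint(0, mnist_data.data.shape[0], size=batch_size)  # random mini-batch
    x = mnist_data.data[rid].float().flatten(start_dim=1)  # 28x28 images -> 784-dim vectors
    x /= x.max()                        # scale pixel values into [0, 1]
    pred = nn(x, w1, w2)                # forward pass
    target = mnist_data.targets[rid]    # true digit labels
    loss = lossfn(pred, target)
    loss.backward()                     # compute gradients for w1 and w2
    with torch.no_grad():               # don't track the update itself
        w1 -= lr * w1.grad
        w2 -= lr * w2.grad
        w1.grad.zero_()                 # gradients accumulate, so reset them
        w2.grad.zero_()
    losses.append(loss.item())          # record the loss for plotting
```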
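The preprocessing lines in the loop are easiest to follow on a small stand-in batch. This sketch uses randomly generated uint8 images in place of real MNIST digits (an assumption, so it runs without downloading anything) and repeats the two-layer nn function from above to show the shapes involved.

```python
import torch

def nn(x, w1, w2):
    l1 = torch.relu(x @ w1)   # hidden layer with ReLU
    return l1 @ w2            # 10 raw scores per image

torch.manual_seed(0)
# hypothetical stand-in for an MNIST batch: 100 grayscale 28x28 images with uint8 pixels
fake_images = torch.randint(0, 256, (100, 28, 28), dtype=torch.uint8)

x = fake_images.float().flatten(start_dim=1)  # (100, 28, 28) -> (100, 784)
x /= x.max()                                   # divide by the batch maximum: values in [0, 1]

w1 = torch.randn(784, 200)
w2 = torch.randn(200, 10)
pred = nn(x, w1, w2)
print(x.shape, pred.shape)                     # one row of 10 logits per image
```

Dividing by the batch maximum mirrors the listing; dividing by the fixed value 255 is a common alternative that gives every batch the same scale.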
You can tell that the neural network is training successfully because the loss decreases fairly steadily over training time (figure A.3). This short code snippet trains a complete neural network that classifies MNIST digits with around 70% accuracy. We implemented gradient descent exactly the same way we did with our simple logarithmic function f(x) = log(x^4 + x^3 + 2), except that PyTorch computed the gradients for us.

Because the gradient of the neural network’s parameters depends on the input data, the gradients are different each time we run the network “forward” on a new random sample of images. So we run the network forward on a random sample of data, PyTorch keeps track of the computations that occur, and when we’re done, we call the backward() method on the last output, which is usually the loss. The backward() method uses automatic differentiation to compute the gradients of all PyTorch tensors that have requires_grad=True set. Then we can update the model parameters using gradient descent. PyTorch adds each new gradient to whatever is already stored in a tensor’s .grad attribute, so the gradients must be reset to zero after every update. We wrap the actual gradient descent update in the torch.no_grad() context because we don’t want PyTorch to track those computations.
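The accumulation behavior is easy to demonstrate. In this sketch, a toy loss of 2 * w (an illustrative choice whose gradient is always 2) is backpropagated three times without resetting, and then once more after zeroing the gradient.

```python
import torch

w = torch.tensor(1.0, requires_grad=True)

for _ in range(3):
    loss = 2 * w          # d(loss)/dw = 2 every time
    loss.backward()
first = w.grad.item()     # 6.0: the three gradients were summed, not replaced

w.grad.zero_()            # reset, as the update step must do before the next backward()
loss = 2 * w
loss.backward()
second = w.grad.item()    # 2.0: only the latest gradient
print(first, second)
```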
We can easily achieve greater than 95% accuracy by improving the training algorithm with a more sophisticated version of gradient descent. In listing A.4 we implemented our own version of stochastic gradient descent. It is stochastic because each step computes the gradient on a random subset (a mini-batch) of the data set rather than on all of it, which gives a noisy estimate of the true gradient over the full data set.
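The "noisy estimate" can be measured directly. The sketch below uses a synthetic linear-regression problem (an assumption chosen because it is small and fast, not the MNIST network) and compares the full-data gradient with gradients computed on random mini-batches of 100 examples. Individual mini-batch gradients scatter noticeably, but on average they match the full gradient.

```python
import torch

torch.manual_seed(0)
N = 10_000
X = torch.randn(N, 5)                                # synthetic inputs
true_w = torch.tensor([1.0, -2.0, 0.5, 3.0, 0.0])
y = X @ true_w + 0.1 * torch.randn(N)                # synthetic targets with a little noise

w = torch.zeros(5, requires_grad=True)               # current parameters

def grad_on(idx):
    # mean squared error on the selected examples, differentiated with respect to w
    loss = ((X[idx] @ w - y[idx]) ** 2).mean()
    (g,) = torch.autograd.grad(loss, w)
    return g

full_grad = grad_on(torch.arange(N))                 # gradient over the whole data set
batch_grads = torch.stack(
    [grad_on(torch.randint(0, N, (100,))) for _ in range(500)]  # 500 random mini-batches
)

mean_batch_grad = batch_grads.mean(dim=0)  # close to full_grad
spread = batch_grads.std(dim=0)            # per-component noise of a single mini-batch estimate
print(full_grad, mean_batch_grad, spread)
```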
PyTorch includes built-in optimizers, one of which is stochastic gradient descent (SGD). A popular alternative is Adam, a more sophisticated variant of SGD that adapts the step size of each parameter based on running averages of its past gradients. To use one, we just instantiate the optimizer with the model parameters and a learning rate.
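To see that the built-in torch.optim.SGD does the same thing as our hand-written update, this sketch runs both side by side on a toy quadratic loss (an illustrative assumption, not part of the MNIST example) from the same starting point. It assumes plain SGD with PyTorch's defaults, that is, no momentum and no weight decay.

```python
import torch

torch.manual_seed(0)
lr = 0.1
target = torch.tensor([1.0, 2.0, 3.0])

w_manual = torch.randn(3, requires_grad=True)
w_optim = w_manual.detach().clone().requires_grad_(True)  # identical starting point
optim = torch.optim.SGD(params=[w_optim], lr=lr)

for _ in range(5):
    # hand-written update, as in the first training loop
    loss = ((w_manual - target) ** 2).sum()
    loss.backward()
    with torch.no_grad():
        w_manual -= lr * w_manual.grad
        w_manual.grad.zero_()

    # the same update through the built-in optimizer
    loss = ((w_optim - target) ** 2).sum()
    loss.backward()
    optim.step()        # w_optim -= lr * w_optim.grad
    optim.zero_grad()   # clear gradients before the next backward()

print(w_manual, w_optim)  # the two parameter vectors match
```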
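```python
import numpy as np
import torch
import torchvision as TV

mnist_data = TV.datasets.MNIST("MNIST", train=True, download=True)
lr = 0.001
epochs = 5000                           # training iterations, one random mini-batch each
batch_size = 500
lossfn = torch.nn.CrossEntropyLoss()
optim = torch.optim.Adam(params=[w1, w2], lr=lr)   # Adam will update w1 and w2 for us
losses = []
for i in range(epochs):
    rid = np.random.randint(0, mnist_data.data.shape[0], size=batch_size)
    x = mnist_data.data[rid].float().flatten(start_dim=1)
    x /= x.max()
    pred = nn(x, w1, w2)
    target = mnist_data.targets[rid]
    loss = lossfn(pred, target)
    loss.backward()                     # compute gradients
    optim.step()                        # apply one Adam update to w1 and w2
    optim.zero_grad()                   # reset gradients so they don't accumulate
    losses.append(loss.item())          # record the loss for plotting
```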
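Adam's update can also be written out by hand. This sketch follows the standard Adam rule with PyTorch's default betas (0.9 and 0.999) and eps (1e-8), uses the same learning rate as the listing, and compares the result with torch.optim.Adam on a toy quadratic loss (an assumption for illustration). Each parameter gets its own effective step size, scaled by the running average of its squared gradients.

```python
import torch

torch.manual_seed(0)
lr, beta1, beta2, eps = 0.001, 0.9, 0.999, 1e-8
target = torch.tensor([1.0, -1.0, 0.5])

w_manual = torch.randn(3, requires_grad=True)
w_adam = w_manual.detach().clone().requires_grad_(True)  # identical starting point
optim = torch.optim.Adam(params=[w_adam], lr=lr, betas=(beta1, beta2), eps=eps)

m = torch.zeros(3)  # running average of gradients (first moment)
v = torch.zeros(3)  # running average of squared gradients (second moment)

for t in range(1, 11):
    loss = ((w_manual - target) ** 2).sum()
    loss.backward()
    with torch.no_grad():
        g = w_manual.grad
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        m_hat = m / (1 - beta1 ** t)   # bias correction for the zero initialization
        v_hat = v / (1 - beta2 ** t)
        w_manual -= lr * m_hat / (v_hat.sqrt() + eps)
        w_manual.grad.zero_()

    loss = ((w_adam - target) ** 2).sum()
    loss.backward()
    optim.step()
    optim.zero_grad()

print(w_manual, w_adam)  # the hand-written and built-in updates agree
```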
You can see in figure A.4 that the loss curve is much smoother with the Adam optimizer, and the accuracy of our neural network classifier increases dramatically.
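The listings report accuracy but do not show how it is computed. A minimal sketch, assuming a hypothetical helper named accuracy: the predicted digit is the index of the largest logit in each row, and accuracy is the fraction of predictions that match the labels. The small hand-written logits below stand in for real network output.

```python
import torch

def accuracy(logits, targets):
    # hypothetical helper: predicted class = index of the largest logit in each row
    preds = logits.argmax(dim=1)
    return (preds == targets).float().mean().item()

logits = torch.tensor([[2.0, 0.1, -1.0],   # predicts 0
                       [0.0, 3.0, 0.5],    # predicts 1
                       [1.0, 0.2, 4.0],    # predicts 2
                       [0.3, 0.9, 0.1]])   # predicts 1
targets = torch.tensor([0, 1, 1, 1])
acc = accuracy(logits, targets)
print(acc)  # 3 of 4 correct -> 0.75
```

For a fair estimate, the helper would be applied under torch.no_grad() to images the network never trained on, such as the MNIST test split loaded with train=False.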