Welcome to OStack Knowledge Sharing Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others


python - Why do we need to call zero_grad() in PyTorch?

The method zero_grad() needs to be called during training, but the documentation is not very helpful:

|  zero_grad(self)
|      Sets gradients of all model parameters to zero.

Why do we need to call this method?



1 Answer


In PyTorch, for every mini-batch during the training phase, we typically want to explicitly set the gradients to zero before starting backpropagation (i.e., computing the gradients that are then used to update the weights and biases), because PyTorch accumulates the gradients on subsequent backward passes. This accumulating behavior is convenient when training RNNs, or when we want the gradient of a loss summed over multiple mini-batches. So, the default action is to accumulate (i.e., sum) the gradients on every loss.backward() call.
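The accumulation described above can be seen on a tiny tensor. This is a minimal sketch; the tensor `w` and the toy function `3 * w` are made up purely for illustration and are not part of the question's model.

```python
import torch

# Hypothetical toy parameter, used only to observe .grad
w = torch.tensor([1.0, 2.0], requires_grad=True)

# First backward pass: the gradient of sum(3 * w) w.r.t. w is 3 per element
(3 * w).sum().backward()
first = w.grad.clone()

# Second backward pass without zeroing: the new gradient is added to the old
(3 * w).sum().backward()
second = w.grad.clone()

# Zero the gradient, then run backward again: back to a single pass's gradient
w.grad.zero_()
(3 * w).sum().backward()
third = w.grad.clone()

print(first, second, third)
```

Here `second` is twice `first`, while `third` matches `first` again because the gradient was cleared in between.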

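Because of this, when you start your training loop, ideally you should zero out the gradients so that the parameter update is done correctly. Otherwise, the gradient would be a combination of the old gradient, which you have already used to update your model parameters, and the newly computed gradient. It would therefore point in some other direction than the intended direction towards the minimum (or maximum, in the case of maximization objectives).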
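To illustrate the effect of stale gradients on the update, here is a small sketch that minimizes (w - 3)^2 with plain SGD, once with zero_grad() and once without. The helper name `train`, the objective, the step count, and the learning rate are all assumptions chosen for the demonstration.

```python
import torch

def train(zero_grads, steps=20, lr=0.1):
    """Minimize (w - 3)^2 with SGD, optionally skipping zero_grad()."""
    w = torch.tensor(0.0, requires_grad=True)
    opt = torch.optim.SGD([w], lr=lr)
    for _ in range(steps):
        if zero_grads:
            opt.zero_grad()      # start each step from a clean gradient
        loss = (w - 3.0) ** 2
        loss.backward()          # adds to w.grad if it was not cleared
        opt.step()
    return w.item()

with_zeroing = train(zero_grads=True)
without_zeroing = train(zero_grads=False)
print(with_zeroing, without_zeroing)
```

With zeroing, w should end up close to 3. Without it, each step applies the sum of all past gradients, so in this setup w keeps overshooting the minimum instead of settling.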

Here is a simple example:

import torch
import torch.optim as optim

def linear_model(x, W, b):
    return torch.matmul(x, W) + b

# your samples and targets (left elided, as in the original answer)
data, targets = ...

# Variable is deprecated since PyTorch 0.4; plain tensors track gradients
W = torch.randn(4, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)

optimizer = optim.Adam([W, b])

for sample, target in zip(data, targets):
    # clear out the gradients of all tensors
    # in this optimizer (i.e. W, b)
    optimizer.zero_grad()
    output = linear_model(sample, W, b)
    # reduce to a scalar: backward() without arguments needs a scalar loss
    loss = ((output - target) ** 2).sum()
    loss.backward()
    optimizer.step()

Alternatively, if you're doing vanilla gradient descent:

W = torch.randn(4, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
learning_rate = 0.01  # assumed value; the original answer leaves it undefined

for sample, target in zip(data, targets):
    # clear out the gradients of W and b; .grad is None until the
    # first backward() call, so guard against that
    if W.grad is not None:
        W.grad.zero_()
        b.grad.zero_()

    output = linear_model(sample, W, b)
    loss = ((output - target) ** 2).sum()
    loss.backward()

    # update in place without recording the update in the autograd graph
    with torch.no_grad():
        W -= learning_rate * W.grad
        b -= learning_rate * b.grad

Note:

  • The accumulation (i.e., sum) of gradients happens when .backward() is called on the loss tensor.
  • As of v1.7.0, there is an option of resetting the gradients to None with optimizer.zero_grad(set_to_none=True) instead of filling them with a tensor of zeros. The docs claim that this setting results in lower memory requirements and a slight improvement in performance, but it can be error-prone if not handled carefully, since any code that reads .grad then has to cope with None. In PyTorch 2.0 and later, set_to_none=True is the default.
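The difference between the two reset modes can be sketched as follows. The parameter `w` and the helper `grad_norm` are hypothetical names introduced only for this illustration; `set_to_none` is the actual keyword argument of optimizer.zero_grad().

```python
import torch

w = torch.ones(2, requires_grad=True)
opt = torch.optim.SGD([w], lr=0.1)

# Reset by filling with zeros: .grad stays a tensor
(w * 2).sum().backward()
opt.zero_grad(set_to_none=False)
grad_after_zero_fill = w.grad.clone()

# Reset to None: .grad is dropped until the next backward() call
(w * 2).sum().backward()
opt.zero_grad(set_to_none=True)
grad_after_none = w.grad

# Code that reads .grad directly has to handle the None case
def grad_norm(p):
    return 0.0 if p.grad is None else p.grad.norm().item()

print(grad_after_zero_fill, grad_after_none, grad_norm(w))
```

This is where the "error-prone" caveat comes from: after a reset to None, an expression such as w.grad.norm() would raise an AttributeError unless guarded as in grad_norm.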
