What is optimizer.zero_grad()?

Asked by DanielBAKER in Data Science, asked on Sep 26, 2024

What is meant by optimizer.zero_grad()? Take SGD as an example:

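$W_{t+1} = W_t - \lambda g_t$, where $\lambda$ is the learning rate and $g_t$ is the gradient of the loss with respect to $W_t$.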

Which one becomes zero for each batch: $g_t$ and not $W_t$, right? More generally, for any optimizer, does it reset everything except $W_t$ and $W_{t+1}$?

Answered by Daniel Cameron

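"optimizer.zero_grad()" is a PyTorch method that resets the gradients (the ".grad" attribute) of every parameter the optimizer manages, usually "model.parameters()", so that the next backward pass starts from a clean slate. Since PyTorch 2.0 the default is "set_to_none=True", which sets each ".grad" to None instead of filling it with zeros; for the following "backward()" call the effect is the same.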
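A minimal sketch, assuming PyTorch 2.0 or later (where "set_to_none=True" is the default), showing that "zero_grad()" clears the gradient and leaves the weight itself alone. The tensor "w" and the toy loss "(w ** 2).sum()" are made up for illustration.

```python
import torch

# Hypothetical trainable weight vector and an SGD optimizer that manages it.
w = torch.nn.Parameter(torch.tensor([1.0, 2.0]))
opt = torch.optim.SGD([w], lr=0.1)

# Forward + backward: d(sum(w**2))/dw = 2 * w, stored in w.grad.
loss = (w ** 2).sum()
loss.backward()
print(w.grad)   # tensor([2., 4.])

# Reset the gradient. With set_to_none=True (default since PyTorch 2.0)
# w.grad becomes None; the weights themselves are not modified.
opt.zero_grad()
print(w.grad)
print(w.data)   # tensor([1., 2.])

# set_to_none=False fills the existing gradient tensor with zeros instead.
(w ** 2).sum().backward()
opt.zero_grad(set_to_none=False)
print(w.grad)   # tensor([0., 0.])
```

Either way, only $g_t$ is reset; $W_t$ keeps its value until "optimizer.step()" is called.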


So, to answer the first part: when you call "optimizer.zero_grad()", it is $g_t$ that is reset, not the weights $W_t$.

After calling "zero_grad()", we run the forward pass and then call "loss.backward()", which fills in $g_t$ again.

Finally, we call "optimizer.step()" to update the weights: $W_{t+1} = W_t - \lambda g_t$.
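A sketch of one training loop in the order described above: zero_grad, forward, backward, step. The random data, the "nn.Linear" model and "lr = 0.1" are hypothetical; the assert checks that plain SGD (no momentum, no weight decay) applies $W_{t+1} = W_t - \lambda g_t$ to the weight matrix.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy problem: 3-feature linear regression on random data.
x = torch.randn(8, 3)
y = torch.randn(8, 1)
model = torch.nn.Linear(3, 1)
lr = 0.1
opt = torch.optim.SGD(model.parameters(), lr=lr)
loss_fn = torch.nn.MSELoss()

for step in range(3):
    opt.zero_grad()                  # 1. clear g_t left over from the previous step
    loss = loss_fn(model(x), y)      # 2. forward pass
    loss.backward()                  # 3. fill each p.grad with g_t
    # Snapshot W_t and g_t so the update rule can be checked by hand.
    w_t = model.weight.detach().clone()
    g_t = model.weight.grad.detach().clone()
    opt.step()                       # 4. W_{t+1} = W_t - lr * g_t
    assert torch.allclose(model.weight.detach(), w_t - lr * g_t)
```

"torch.allclose" is used rather than exact equality because the in-place update and the hand-written expression may round differently in the last bit.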

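"zero_grad()" is needed because "loss.backward()" accumulates gradients: it adds the new values to whatever is already stored in each parameter's ".grad" instead of overwriting them. Without it, $g_t$ would also contain the gradients from previous steps.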
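A small sketch of that accumulation, using a made-up one-element parameter: calling "backward()" twice without "zero_grad()" in between sums the two gradients.

```python
import torch

# Hypothetical single-element parameter, w = 3.
w = torch.nn.Parameter(torch.tensor([3.0]))
opt = torch.optim.SGD([w], lr=0.1)

def backward_once():
    # d(w**2)/dw = 2 * w, which is 6 for w = 3
    (w ** 2).sum().backward()

backward_once()
backward_once()                  # no zero_grad() in between
accumulated = w.grad.clone()     # 6 + 6 = 12: the gradients were summed

opt.zero_grad()
backward_once()
fresh = w.grad.clone()           # 6: only the current gradient

print(accumulated, fresh)        # tensor([12.]) tensor([6.])
```

This summing is also why some training scripts deliberately skip "zero_grad()" for a few batches to accumulate gradients over a larger effective batch.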

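Each optimizer implements its own update rule in "step()". Optimizers such as SGD with momentum or Adam also keep internal state (momentum buffers, running averages), and "zero_grad()" does not touch that state or the weights; it only clears the gradients.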
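A sketch for the second part of the question, assuming SGD with a hypothetical "momentum=0.9" and PyTorch's current SGD, which keeps its velocity in "opt.state[w]["momentum_buffer"]". After "zero_grad()" both the weight and that buffer are unchanged.

```python
import torch

# Hypothetical single-element parameter and SGD with momentum.
w = torch.nn.Parameter(torch.tensor([1.0]))
opt = torch.optim.SGD([w], lr=0.1, momentum=0.9)

# One ordinary step so that SGD creates its momentum buffer.
opt.zero_grad()
(w ** 2).sum().backward()
opt.step()

# Snapshot the weight and the optimizer's internal state.
w_before = w.detach().clone()
buf_before = opt.state[w]["momentum_buffer"].clone()

opt.zero_grad()

# Only the gradient was cleared (None with set_to_none=True, zeros otherwise).
print(w.grad)
print(torch.equal(w.detach(), w_before))                          # True
print(torch.equal(opt.state[w]["momentum_buffer"], buf_before))   # True
```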


