PyTorch gradient update issues and a possible solution
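The snippet below references device, x_train_tensor, and y_train_tensor, which come from an earlier data-preparation step that is not shown here. A minimal setup sketch, assuming synthetic linear data (the true line y = 1 + 2x and the noise level are purely illustrative choices), could look like this:

import torch
import numpy as np

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Synthetic data for a linear model y = 1 + 2x, plus some noise (illustrative assumption)
np.random.seed(42)
x_train = np.random.rand(100, 1)
y_train = 1 + 2 * x_train + 0.1 * np.random.randn(100, 1)

x_train_tensor = torch.from_numpy(x_train).float().to(device)
y_train_tensor = torch.from_numpy(y_train).float().to(device)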
lr = 1e-1
n_epochs = 1000

torch.manual_seed(42)
a = torch.randn(1, requires_grad=True, dtype=torch.float, device=device)
b = torch.randn(1, requires_grad=True, dtype=torch.float, device=device)

for epoch in range(n_epochs):
    yhat = a + b * x_train_tensor
    error = y_train_tensor - yhat
    loss = (error ** 2).mean()

    # No more manual computation of gradients!
    # a_grad = -2 * error.mean()
    # b_grad = -2 * (x_train_tensor * error).mean()

    # We just tell PyTorch to work its way BACKWARDS from the specified loss!
    loss.backward()
    # Let's check the computed gradients...
    print(a.grad)
    print(b.grad)

    # What about UPDATING the parameters? Not so fast...

    # FIRST ATTEMPT
    # AttributeError: 'NoneType' object has no attribute 'zero_'
    # The assignment replaces a with a new, non-leaf tensor, so a.grad is None afterwards.
    # a = a - lr * a.grad
    # b = b - lr * b.grad
    # print(a)

    # SECOND ATTEMPT
    # RuntimeError: a leaf Variable that requires grad has been used in an in-place operation.
    # In-place updates on leaf tensors that require grad are not allowed while autograd is tracking.
    # a -= lr * a.grad
    # b -= lr * b.grad

    # THIRD ATTEMPT
    # We need to use NO_GRAD to keep the update out of the gradient computation.
    # Why is that? It boils down to the DYNAMIC GRAPH that PyTorch uses...
    with torch.no_grad():
        a -= lr * a.grad
        b -= lr * b.grad

    # PyTorch is "clingy" to its computed gradients; we need to tell it to let them go...
    a.grad.zero_()
    b.grad.zero_()

print(a, b)
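For comparison, the manual update inside no_grad plus the explicit grad zeroing is exactly what an optimizer automates. A sketch of the same training loop using torch.optim.SGD (the optimizer choice is illustrative; re-initialize a and b first if you want to train from scratch):

optimizer = torch.optim.SGD([a, b], lr=lr)

for epoch in range(n_epochs):
    yhat = a + b * x_train_tensor
    error = y_train_tensor - yhat
    loss = (error ** 2).mean()

    loss.backward()
    # step() performs the parameter update, so no manual no_grad block is needed
    optimizer.step()
    # zero_grad() replaces the manual a.grad.zero_() / b.grad.zero_() calls
    optimizer.zero_grad()

print(a, b)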