在使用autograd.grad()时,需要注意它计算的是某一标量相对于一组参数的梯度。如果只计算某个参数的梯度,需要把标量的grad_fn设置为None,否则会计算所有参数的梯度。具体实现如下:
import torch
# 定义模型参数
x = torch.tensor([2.0], requires_grad=True)
w = torch.tensor([3.0], requires_grad=True)
# 计算标量
y = x * w
# 计算某个参数的梯度
grad_w = torch.autograd.grad(y, w, retain_graph=True)
# 由于y与x相关,所以也需要计算x的梯度,否则会报错
grad_x = torch.autograd.grad(y, x, retain_graph=True)
print(grad_w) # 输出为tensor([2.])
print(grad_x) # 输出为tensor([3.])
在计算某个参数的梯度时,需要将其他参数的grad_fn设置为None,保证只计算该参数的梯度。