When should with torch.no_grad() be used? When should .requires_grad = False be used?

A basic PyTorch question that I wanted to write up when I had some spare time.
Generally speaking, whether you use with torch.no_grad() or .requires_grad = False does not affect the algorithm itself, but it does affect the code's performance.

with torch.no_grad()

A tensor produced by operations inside a with torch.no_grad() block has no grad_fn, i.e. autograd records no upstream operation for it. The loss therefore cannot be backpropagated through such tensors, and the parameters of the network that produced them will never be updated.
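A quick way to check this (the tensors w and x below are just toy examples):

import torch

w = torch.randn(3, 3, requires_grad=True)
x = torch.randn(3)

y = w @ x                            # recorded by autograd
print(y.requires_grad, y.grad_fn)    # True <MvBackward0 ...>

with torch.no_grad():
    z = w @ x                        # nothing is recorded here
print(z.requires_grad, z.grad_fn)    # False None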
with torch.no_grad() is typically used in a situation like the following:
[Figure: the outputs of net1 and net2 both feed into the loss; only net1 should be trained]
Here the output of net2 is only used to compute the loss, and we do not want the loss to update net2's parameters, so net2's forward pass is wrapped in with torch.no_grad(); the backward pass is cut off there, while net1's parameter gradients are still computed normally by loss.backward(). Skipping with torch.no_grad() would not break anything, it would just waste time computing gradients for net2's parameters, since optimizer.step() only updates net1's parameters anyway. This also explains why optimizer.zero_grad() is needed before each update: without no_grad, net2's parameters carry gradients, and every later backward pass adds on top of them, which is usually not the result we want (deliberately accumulating gradients like this is also a thing, and it often goes together with retain_graph=True).
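For concreteness, here is a minimal sketch of this setup (the linear layers, the SGD optimizer, the MSE loss and the random data are all made up for illustration):

import torch
import torch.nn as nn

net1 = nn.Linear(10, 5)       # the network we actually want to train
net2 = nn.Linear(10, 5)       # its output is only used as a fixed target
optimizer = torch.optim.SGD(net1.parameters(), lr=0.1)
criterion = nn.MSELoss()

x = torch.randn(8, 10)

out1 = net1(x)                # tracked by autograd
with torch.no_grad():
    out2 = net2(x)            # no graph is recorded; the loss cannot reach net2

loss = criterion(out1, out2)
optimizer.zero_grad()
loss.backward()               # gradients are computed only for net1's parameters
optimizer.step()

Without the no_grad block the result would be the same, since the optimizer only holds net1's parameters; the difference is that net2's gradients would be computed and accumulated for nothing.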

.requires_grad = False

Take the following setup:
[Figure: net1's output is fed into net2, and net2's output goes into the loss; only net1 should be updated]
Here we only want the loss to update net1, while net2 should stay fixed. Can this be done with with torch.no_grad()? No: once net2's forward pass is wrapped in it, the gradient flow back to net1 is blocked as well. What to do instead:

for p in net2.parameters():
    p.requires_grad = False   # no gradients will be accumulated for net2's weights

This way the gradients of net2's weights are never computed; net2 is only used with its current values when the gradient of net1 is calculated, which keeps net2 frozen and saves computation. A full sketch of this case follows below.
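Putting the pieces together, a minimal sketch of this chained case (the layer sizes, loss and data are again just for illustration):

import torch
import torch.nn as nn

net1 = nn.Linear(10, 20)
net2 = nn.Linear(20, 5)

for p in net2.parameters():       # freeze net2's weights
    p.requires_grad = False

optimizer = torch.optim.SGD(net1.parameters(), lr=0.1)
criterion = nn.MSELoss()

x = torch.randn(8, 10)
target = torch.randn(8, 5)

loss = criterion(net2(net1(x)), target)
optimizer.zero_grad()
loss.backward()                   # gradients flow through net2 back to net1

print(net1.weight.grad is None)   # False: net1's gradient was computed
print(net2.weight.grad is None)   # True: net2's weights were skipped
optimizer.step()                  # updates net1 only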

About retain_graph=True

This is explained very well in another blog post.
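In short, retain_graph=True is needed when backward() has to run more than once through the same graph. A minimal sketch:

import torch

w = torch.randn(3, requires_grad=True)
x = torch.randn(3)

loss = (w * x).sum()
loss.backward(retain_graph=True)   # keep the graph alive for a second pass
loss.backward()                    # would raise an error without retain_graph above
print(w.grad)                      # the two passes accumulate: equals 2 * x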

Source: blog.csdn.net/weixin_43145941/article/details/114757673