PyTorch Tutorial de Autograd

En PyTorch, autograd es el contenido central de todas las redes neuronales y proporciona métodos de derivación automáticos para todas las operaciones de Tensor.

Es un marco definido por la forma en que se ejecuta, lo que significa que el backprop se define por la forma en que se ejecuta el código.

 一 、 Variable

autograd.Variable es la clase principal en autograd. Envuelve un Tensor y admite casi todas las operaciones definidas en él. Una vez que se complete su cálculo, puede llamar a .backward () para calcular automáticamente todos los gradientes.

La variable tiene tres atributos: datos, graduación y creador.

Acceda al tensor original utilizando el atributo .data; el gradiente de esta Variable se centra en .grad; .creator refleja al creador e identifica si el usuario lo creó directamente usando .Variable (Ninguno).

También hay una clase que es muy importante para la implementación de la función autograd. Los números de variables y funciones están relacionados entre sí y establecen un gráfico acíclico para codificar el proceso de cálculo completo. Cada variable tiene un atributo .grad_fn que hace referencia a la función que creó la variable (excepto las variables creadas por el usuario cuyo grad_fn es None).

import torch
from torch.autograd import Variable

Crea la variable x:

x = Variable(torch.ones(2, 2), requires_grad=True)
print(x)

Resultado de salida:

Variable containing:
 1  1
 1  1
[torch.FloatTensor of size 2x2]

Operar sobre la base de x:

y = x + 2 
print(y)

Resultado de salida:

Variable containing:
 3  3
 3  3
[torch.FloatTensor of size 2x2]

查看x的grad_fn :

print(x.grad_fn)

Resultado de salida:

None

查看y的grad_fn :

print(y.grad_fn)

Resultado de salida:

<torch.autograd.function.AddConstantBackward object at 0x7f603f6ab318>

可以看到y是作为运算的结果产生的,所以y有grad_fn,而x是直接创建的,所以x没有grad_fn。

 Calcular basado en y: 

z = y * y * 3
out = z.mean()
print(z, out)

Resultado de salida:

Variable containing:
 27  27
 27  27
[torch.FloatTensor of size 2x2]
 Variable containing:
 27
[torch.FloatTensor of size 1]
 

二 、 Gradientes

如果Variable是一个标量(例如它包含一个单元素数据),你无需对backward()指定任何参数.

out.backward()Es equivalente aout.backward(torch.Tensor([1.0])).

out.backward()
print(x.grad)

Resultado de salida:

Variable containing:
 4.5000  4.5000
 4.5000  4.5000
[torch.FloatTensor of size 2x2]

Si tiene más elementos (vector), debe especificar un parámetro grad_output que coincida con la forma del tensor (la derivada de y proyectada a x en la dirección especificada)

x = torch.randn(3)
x = Variable(x, requires_grad=True)

y = x * 2
while y.data.norm() < 1000:
    y = y * 2

print(y)

 

Resultado de salida:

Variable containing:
-1296.5227
  499.0783
  778.8971
[torch.FloatTensor of size 3]x = torch.randn(3)
x = Variable(x, requires_grad=True)

y = x * 2
while y.data.norm() < 1000:
    y = y * 2

print(y)

 

Resultado de salida:

Variable containing:
-1296.5227
  499.0783
  778.8971
[torch.FloatTensor of size 3]

Sin parámetros:

 

y.backward()
print(x.grad)

 

Resultado de salida:

RuntimeError: grad can be implicitly created only for scalar outputs
None

Parámetros entrantes:

gradients = torch.FloatTensor([0.1, 1.0, 0.0001])
y.backward(gradients)
print(x.grad)

Resultado de salida:

Variable containing:
  102.4000
 1024.0000
    0.1024
[torch.FloatTensor of size 3]

Simplemente pruebe el efecto de diferentes parámetros:

Parámetro 1: [1,1,1]

 

 

x=torch.FloatTensor([1,2,3])
x = Variable(x, requires_grad=True)
y = x * x
print(y)

gradients = torch.FloatTensor([1,1,1])
y.backward(gradients)  
print(x.grad)

 

 Resultado de salida:

Variable containing:
 1
 4
 9
[torch.FloatTensor of size 3]
Variable containing:
 2
 4
 6
[torch.FloatTensor of size 3]
 

 

 Parámetro 2: [3,2,1] 

 

x=torch.FloatTensor([1,2,3])
x = Variable(x, requires_grad=True)
y = x * x
print(y)

gradients = torch.FloatTensor([3,2,1])
y.backward(gradients)  
print(x.grad)

 

 Resultado de salida:

Variable containing:
 1
 4
 9
[torch.FloatTensor of size 3]
Variable containing:
 6
 8
 6
[torch.FloatTensor of size 3]
 

 

943 artículos originales publicados · Me gusta 136 · Visita 330,000+

Supongo que te gusta

Origin blog.csdn.net/weixin_36670529/article/details/105299062
Recomendado
Clasificación