Problemas de sumas y restas/operaciones lógicas entre tensores con dimensiones inconsistentes en pytorch

Regla 1: Si las dimensiones de los dos tensores a sumar son inconsistentes, primero alinee el tensor con la dimensión más baja con el tensor con la dimensión más alta desde la derecha

Por ejemplo, en el siguiente código, la dimensión de b es más baja, por lo que cuando se agrega a a, la dimensión de b primero se expandirá a [1,1,5,6].

a = torch.ones([8, 4, 5, 6])
b = torch.ones([5, 6])
c = a+b
print('c =',c.size())
# c = torch.Size([8, 4, 5, 6])

Después de la alineación, puede saltar a la regla 2.

​Regla 2: Cuando dos tensores tienen la misma dimensión, los valores de los ejes correspondientes deben ser iguales, o 1.

Al sumar, copie y expanda todos los ejes que son 1, para obtener dos tensores con exactamente las mismas dimensiones. Luego agregue las posiciones correspondientes.

Ejemplos que se pueden agregar:

1. Dado que cada eje correspondiente es igual o uno de ellos es 1, se pueden sumar. De lo contrario no se puede agregar.

a = torch.ones([8, 4, 5, 6])
b = torch.ones([1, 1, 5, 6])
c = a+b
# c = torch.Size([8, 4, 5, 6])

más:

a = torch.ones([5, 1, 1, 5])
b = torch.ones([5, 5])
c = a+b
print('c =', c.size())
# c = torch.Size([5, 1, 5, 5])

Aquí las dimensiones de a son (5,1,1,5), las dimensiones de b son (5,5), y las dimensiones de c del resultado final son (5,1,5,5), donde la suma la operación se realiza tanto en a como en b Cuando
se realiza la operación de expansión de dimensión, primero alinee b con a desde la derecha de acuerdo con la regla 1, y la dimensión de b se convierte en (1,1,5,5),
y luego de acuerdo con la regla 2 , las dimensiones de a y b son las mismas, y la operación se puede realizar. En la operación real, tanto a como b copiarán y expandirán el eje con una dimensión de 1, y las dimensiones se convertirán en (5,1, 5,5), por lo que la operación de suma se puede realizar

2. Una dimensión es todo 1, que se puede sumar

a = torch.ones([8, 4, 5, 6])
print('a =',a.size())
b = torch.ones([1, 1, 1, 1])
print('b =',b.size())
c = a+b
print('c =',c.size())
# c = torch.Size([8, 4, 5, 6])

3. Las dimensiones son exactamente iguales y se pueden sumar

a = torch.ones([8, 4, 5, 6])
print('a =',a.size())
b = torch.ones([8, 4, 5, 6])
print('b =',b.size())
c = a+b
print('c =',c.size())
# c = torch.Size([8, 4, 5, 6])

Ejemplos que no se pueden agregar :

1. Como 4 no es igual a 2, no se puede sumar

a = torch.ones([8, 4, 5, 6])
print('a =',a.size())
b = torch.ones([1, 2, 1, 6])
print('b =',b.size())
c = a+b
print('c =',c.size())

2. Como 3 no es igual a 6, no se puede sumar

a = torch.ones([8, 4, 5, 6])
print('a =',a.size())
b = torch.ones([1, 4, 1, 3])
print('b =',b.size())
c = a+b
print('c =',c.size())

Supongo que te gusta

Origin blog.csdn.net/ytusdc/article/details/128106265
Recomendado
Clasificación