卷积-李沐老师

import torch
from torch import nn
from d2l import torch as d2l

def corr2d(X, K):
    """Compute the 2-D cross-correlation of ``X`` with the kernel ``K``.

    The kernel is placed at every position where it fits entirely inside
    ``X`` (no padding, stride 1). Each output element is the sum of the
    element-wise product between the kernel and the window it covers.

    Args:
        X: Input tensor indexed as ``(rows, cols)``.
        K: 2-D kernel tensor of shape ``(h, w)``.

    Returns:
        Tensor of shape ``(X.shape[0] - h + 1, X.shape[1] - w + 1)``
        allocated on ``X``'s device. Floating-point or complex inputs keep
        their promoted dtype; integer/bool inputs yield the default float
        dtype.
    """
    h, w = K.shape
    # Allocate the result alongside the input rather than always on the CPU
    # in the default dtype; otherwise float64 inputs are silently downcast
    # and GPU inputs produce a CPU result.
    dtype = torch.result_type(X, K)
    if not (dtype.is_floating_point or dtype.is_complex):
        dtype = torch.get_default_dtype()
    Y = torch.zeros(
        (X.shape[0] - h + 1, X.shape[1] - w + 1),
        dtype=dtype,
        device=X.device,
    )
    for i in range(Y.shape[0]):
        for j in range(Y.shape[1]):
            # Window whose top-left corner is (i, j), same size as the kernel.
            Y[i, j] = (X[i:i + h, j:j + w] * K).sum()
    return Y

# Worked example: a 3x3 input holding 0..8 (row-major) and a 2x2 kernel
# holding 0..3, cross-correlated with corr2d.
X = torch.arange(9.0).reshape(3, 3)
K = torch.arange(4.0).reshape(2, 2)
Q = corr2d(X, K)
print(Q)

输出:

tensor([[19., 25.],
        [37., 43.]])
class Conv2D(nn.Module):
    """2-D cross-correlation layer with a learnable kernel and scalar bias."""

    def __init__(self, kernel_size):
        """Create the layer.

        Args:
            kernel_size: Shape of the kernel, passed to ``torch.rand``.
        """
        super(Conv2D, self).__init__()
        # Kernel starts as uniform random values in [0, 1); bias starts at 0.
        initial_kernel = torch.rand(kernel_size)
        initial_bias = torch.zeros(1)
        self.weight = nn.Parameter(initial_kernel)
        self.bias = nn.Parameter(initial_bias)

    def forward(self, x):
        """Return the cross-correlation of ``x`` with the kernel, plus the bias."""
        correlated = corr2d(x, self.weight)
        return correlated + self.bias

# Build a 6x8 "image" of ones whose middle four columns (2..5) are zero,
# giving two vertical edges: 1 -> 0 and 0 -> 1.
X = torch.ones(6, 8)
X[:, 2:6].zero_()
print(X)

输出:

tensor([[1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.],
        [1., 1., 0., 0., 0., 0., 1., 1.]])
# 1x2 difference kernel: the output is non-zero only where two horizontally
# adjacent elements differ.
K = torch.tensor([1.0, -1.0]).reshape(1, 2)
Y = corr2d(X, K)
print(Y)
# Apply the same kernel to the transposed image for comparison.
print(corr2d(X.T, K))

输出:

tensor([[ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  0.,  0., -1.,  0.]])
tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])
# Given the input X and the target output Y, learn the kernel K.
# One input channel and one output channel; the bias is disabled so only the
# 1x2 kernel is learned.
conv2d = nn.Conv2d(1, 1, kernel_size=(1, 2), bias=False)
# Reshape to 4-D tensors with two leading singleton dimensions, as
# nn.Conv2d takes (batch, channels, height, width).
X = X.reshape((1, 1, 6, 8))
Y = Y.reshape((1, 1, 6, 7))

lr = 3e-2  # learning rate
for i in range(20):
    Y_hat = conv2d(X)
    # Element-wise squared error; its sum is the scalar loss.
    l = (Y_hat - Y) ** 2
    loss = l.sum()
    conv2d.zero_grad()
    loss.backward()
    # Plain gradient-descent step. Use no_grad() instead of mutating
    # `.data`, so the in-place update is not recorded by autograd while
    # still going through the normal tensor API.
    with torch.no_grad():
        conv2d.weight -= lr * conv2d.weight.grad
    if (i + 1) % 2 == 0:
        print(f'batch {i+1},loss {loss:.3f}')

# Show the learned kernel as a 1x2 tensor.
print(conv2d.weight.detach().reshape((1, 2)))

输出:

batch 2,loss 8.712
batch 4,loss 1.614
batch 6,loss 0.333
batch 8,loss 0.082
batch 10,loss 0.024
batch 12,loss 0.008
batch 14,loss 0.003
batch 16,loss 0.001
batch 18,loss 0.001
batch 20,loss 0.000
tensor([[ 0.9984, -1.0013]])

猜你喜欢

转载自blog.csdn.net/qq_45828494/article/details/126652420