[pytorch、学习] - 5.3 多输入通道和多输出通道

参考

5.3 多输入通道和多输出通道

前面两节里我们用到的输入和输出都是二维数组,但真实数据的维度经常更高。例如,彩色图像在高和宽2个维度外还有RGB(红、绿、蓝)3个颜色通道。假设彩色图像的高和宽分别是h和w(像素),那么它可以表示为一个3 * h * w的多维数组。我们将大小为3的这一维称为通道(channel)维。本节将介绍含多个输入通道或多个输出通道的卷积核。

5.3.1 多输入通道

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-gZLOEKry-1594174473890)(attachment:image.png)]
接下来我们实现含多个输入通道的互相关运算。我们只需要对每个通道做互相关运算,然后通过add_n函数来进行累加

import torch
import torch.nn as nn
import sys
sys.path.append("..")
import d2lzh_pytorch  as d2l

def corr2d_multi_in(X, K):
    # 沿着X和K的第0维(通道维)分别计算再相加
    res = d2l.corr2d(X[0, :, :], K[0, :, :])
    print(res)
    for i in range(1, X.shape[0]):  # X.shape[0]代表多少个通道,此处为2个
        res += d2l.corr2d(X[i, :, :], K[i, :, :])
    return res
X = torch.tensor([[[0,1,2],[3,4,5],[6,7,8]],[[1,2,3], [4,5,6], [7,8,9]] ])

K = torch.tensor([[[0,1],[2,3]], [[1,2],[3,4]]])

corr2d_multi_in(X, K)

在这里插入图片描述

5.3.2 多输出通道

当输入通道有多个时,因为我们对各自通道的结果做了累加,所以不论输入通道数是多少,输出通道数总是为1。设卷积核输入通道数和输出通道数分别为c(i)和c(o),高和宽分别为k(h)和k(w)。如果希望得到含多个通道的输出,我们可以为每个输出通道分别创建形状为c(i) * k(k) * h(w)的核数组。将它们在输出通道维上连结,卷积核的形状即 c(o) * c(i) * k(h) * k(w)。在做互相关运算时,每个输出通道上的结果由卷积核在输出通道上的核数组与整个输入数组计算而来。

简单说就是,如果你想输出N个通道,你就需要创建N个 C * H * W的卷积核
下面实现一个互相关运算函数来计算多个通道的输出。

def corr2d_multi_in_out(X, K):
    # 对K的第0维遍历,每次同输入X做互相关计算。所有结果使用stack函数合并在一起
    return torch.stack([corr2d_multi_in(X, k) for k in K])

我们将核数组K同K+1(K中每个元素加一)和K+2连结在一起来构造一个输出通道数为3的卷积核

K = torch.tensor([[[0,1],[2,3]], [[1,2],[3,4]]])


# 构造3个卷积核
K = torch.stack([K, K+1, K+2])
K.shape

在这里插入图片描述
下面我们对输入数组X与核数组K做互相关运算。此时的输出含有3个通道。其中第一个通道的结果与之前输入数组X与多输入通道、单输出通道核的计算结果一致。

# 输入的规模为  2 * 3 * 3 输出的规模为 3 * (3 - 2+ 1) * (3 - 2 + 1)
corr2d_multi_in_out(X, K)

在这里插入图片描述

5.3.3 1 * 1卷积层

在这里插入图片描述

def corr2d_multi_in_out_1x1(X, K):
    c_i, h, w = X.shape
    c_o = K.shape[0]
    X = X.view(c_i, h * w)
    K = K.view(c_o, c_i)
    Y = torch.mm(K, X)  # 全连接层的矩阵乘法
    return Y.view(c_o, h, w)
X = torch.rand(3, 3, 3)
K = torch.rand(2, 3, 1, 1)

Y1 = corr2d_multi_in_out_1x1(X, K)
Y2 = corr2d_multi_in_out(X, K)

(Y1 - Y2).norm().item()  < 1e-6

在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/piano9425/article/details/107199288
5.3