【deep_thoughts】16_Operator Fusion of Convolutional Residual Modules in PyTorch


Video link: 16、PyTorch中进行卷积残差模块算子融合_哔哩哔哩_bilibili

Conv2d official API: Conv2d — PyTorch 2.0 documentation

Original paper: RepVGG: Making VGG-style ConvNets Great Again (arxiv.org)

Theory

The paper describes how to convert a multi-branch module used during training into a single $3\times 3$ convolution, which speeds up inference. As shown in the figure below:

[Figure: re-parameterizing the multi-branch block into a single 3×3 convolution; (A) shows the structural view, (B) the corresponding kernel parameters]
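The key property behind this is that convolution is linear in its parameters: for a fixed input, summing the outputs of several convolutions with matching output shapes is the same as applying a single convolution whose weight and bias are the element-wise sums of the individual weights and biases. Once every branch is rewritten as a $3\times 3$ convolution, the whole block collapses into one layer:

$$
\mathrm{Conv}(x; W_1, b_1) + \mathrm{Conv}(x; W_2, b_2) + \mathrm{Conv}(x; W_3, b_3) = \mathrm{Conv}(x;\ W_1 + W_2 + W_3,\ b_1 + b_2 + b_3)
$$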

Code verification

The code presented in the video does not take BN layers into account.
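For reference, the RepVGG paper itself first folds each branch's BN layer into the preceding convolution and only then merges the branches. A standard statement of that folding (not part of the video's code; $\mu$, $\sigma^2$, $\gamma$, $\beta$, $\epsilon$ denote the BN running mean, running variance, scale, shift, and epsilon) is:

$$
W' = \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}}\, W, \qquad b' = \beta + \frac{\gamma\,(b - \mu)}{\sqrt{\sigma^2 + \epsilon}}
$$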

Naive implementation

This corresponds to the first sub-figure of (A) in the figure above:

import torch
import torch.nn as nn
import torch.nn.functional as F
import time


in_c = 2
out_c = 2
k = 3  # kernel_size
w = h = 9

x = torch.ones(1, in_c, h, w)  # input tensor [batch_size, channels, h, w]

conv_2d = nn.Conv2d(in_c, out_c, k, padding="same")  # in_channels must equal out_channels, otherwise the branch outputs cannot be added
conv_2d_pointwise = nn.Conv2d(in_c, out_c, 1)
result1 = conv_2d(x) + conv_2d_pointwise(x) + x
print(conv_2d.weight.size())  # 3*3 conv layer parameter shapes
print(conv_2d.bias.size())
print(conv_2d_pointwise.weight.size())  # 1*1 conv layer parameter shapes
print(conv_2d_pointwise.bias.size())

The printed parameter shapes are:

torch.Size([2, 2, 3, 3])  # 3*3 conv weight [out_c, in_c, k, k]
torch.Size([2])           # 3*3 conv bias [out_c]
torch.Size([2, 2, 1, 1])  # 1*1 conv weight [out_c, in_c, 1, 1]
torch.Size([2])           # 1*1 conv bias [out_c]

Operator fusion

1. Conversion

The code below uses torch.nn.functional.pad; the examples in the official documentation make its behavior clear.

Official pad API: torch.nn.functional.pad — PyTorch 2.0 documentation
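A minimal illustration of the padding convention, reusing the imports above (the toy tensor shape is chosen purely for this example): the pad list is consumed from the last dimension inward, one (left, right) pair per dimension, and dimensions without a pair are left untouched.

# Toy F.pad example: only the last two dims receive a (1, 1) pair each
t = torch.ones(2, 2, 1, 1)
print(F.pad(t, [1, 1, 1, 1]).shape)  # torch.Size([2, 2, 3, 3])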

This corresponds to the second sub-figure of (A) above; the kernel parameters correspond to the second sub-figure of (B).

First, the $1\times 1$ convolution is converted into a $3\times 3$ convolution:

# conv_2d_pointwise.weight.size() is originally [2, 2, 1, 1]
# to turn the 1*1 convolution into a 3*3 convolution, the weight must become [2, 2, 3, 3]
# F.pad zero-pads the last two (h and w) dimensions: [2, 2, 1, 1] -> [2, 2, 3, 3]
pointwise_to_conv_weight = F.pad(conv_2d_pointwise.weight, [1, 1, 1, 1])  # pad the last dim by (1, 1) and the 2nd-to-last by (1, 1)
conv_2d_for_pointwise = nn.Conv2d(in_c, out_c, k, padding="same")
conv_2d_for_pointwise.weight = nn.Parameter(pointwise_to_conv_weight)  # replace the layer's weights
conv_2d_for_pointwise.bias = conv_2d_pointwise.bias
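As a quick sanity check (not in the original code), the padded $3\times 3$ layer should reproduce the original $1\times 1$ convolution, since the added kernel entries are all zero:

# Sanity check: the zero-padded 3*3 kernel behaves like the original 1*1 conv
print(torch.all(torch.isclose(conv_2d_pointwise(x), conv_2d_for_pointwise(x))))  # expected: tensor(True)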

Next, the identity mapping is converted into a $3\times 3$ convolution. Since it must behave as an identity, this convolution may not mix neighboring pixels or different channels. The weight of such a layer therefore still has size 2*2*3*3 ([out_c, in_c, k, k]), but almost all of its entries are zero.

# ignore correlations between neighboring pixels and across channels
# an all-zero kernel: the other channel contributes nothing
zeros = torch.unsqueeze(torch.zeros(k, k), 0)  # [1,3,3]
# a single 1 at the center: only the pixel itself contributes
stars = torch.unsqueeze(F.pad(torch.ones(1, 1), [1, 1, 1, 1]), 0)  # [1,3,3]

The effect: stars is a $3\times 3$ kernel whose only nonzero entry is a 1 at the center, and zeros is an all-zero $3\times 3$ kernel.

# concatenate along dim 0, then unsqueeze to add a leading dim: [1,2,3,3]
stars_zeros = torch.unsqueeze(torch.cat([stars, zeros], 0), 0)  
zeros_stars = torch.unsqueeze(torch.cat([zeros, stars], 0), 0)

identity_to_conv_weight = torch.cat([stars_zeros, zeros_stars], 0)  # [2,2,3,3]
identity_to_conv_bias = torch.zeros([out_c])

conv_2d_for_identity = nn.Conv2d(in_c, out_c, k, padding="same")
conv_2d_for_identity.weight = nn.Parameter(identity_to_conv_weight)
conv_2d_for_identity.bias = nn.Parameter(identity_to_conv_bias)

result2 = conv_2d(x) + conv_2d_for_pointwise(x) + conv_2d_for_identity(x)
# print(result2)
print(torch.all(torch.isclose(result1, result2)))  # check that the results match

Because result1 and result2 are floating-point tensors, they cannot be compared directly with torch.equal; torch.isclose compares them element-wise instead, and torch.all reduces the element-wise result to a single boolean. The output is:

tensor(True)
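Similarly, one can check (again not in the original code) that the identity branch rewritten as a convolution simply returns its input:

# Sanity check: the identity-as-convolution layer leaves its input unchanged
print(torch.all(torch.isclose(conv_2d_for_identity(x), x)))  # expected: tensor(True)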

2. Fusion

Finally, the three $3\times 3$ convolutions are fused into one. This corresponds to the third sub-figure of (A) above; following the parameter diagram in (B), the weights of the three layers are added together, and so are the biases.

conv_2d_for_fusion = nn.Conv2d(in_c, out_c, k, padding="same")
conv_2d_for_fusion.weight = nn.Parameter(conv_2d.weight.data +
                                         conv_2d_for_pointwise.weight.data +
                                         conv_2d_for_identity.weight.data)  # sum the weights of all three layers
conv_2d_for_fusion.bias = nn.Parameter(conv_2d.bias.data +
                                       conv_2d_for_pointwise.bias.data +
                                       conv_2d_for_identity.bias.data)
result3 = conv_2d_for_fusion(x)
# print(result3)
print(torch.all(torch.isclose(result3, result2)))  # check that the results match

Comparing result3 with result2 gives:

tensor(True)
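The conversion and fusion steps above can be wrapped into a small helper. The sketch below keeps the same assumptions as this note (no BN, in_channels equal to out_channels, stride 1); the function name fuse_residual_branches is made up here for illustration:

def fuse_residual_branches(conv3x3: nn.Conv2d, conv1x1: nn.Conv2d) -> nn.Conv2d:
    # Fuse a 3*3 conv branch, a 1*1 conv branch and an identity branch
    # (in_channels == out_channels, stride 1, no BN) into one 3*3 conv.
    out_ch, in_ch, ks, _ = conv3x3.weight.shape
    fused = nn.Conv2d(in_ch, out_ch, ks, padding="same")
    # pad the 1*1 kernel to 3*3
    w_1x1 = F.pad(conv1x1.weight.data, [1, 1, 1, 1])
    # identity kernel: a single 1 at the center of each channel's own slice
    w_id = torch.zeros(out_ch, in_ch, ks, ks)
    for c in range(out_ch):
        w_id[c, c, ks // 2, ks // 2] = 1.0
    fused.weight = nn.Parameter(conv3x3.weight.data + w_1x1 + w_id)
    fused.bias = nn.Parameter(conv3x3.bias.data + conv1x1.bias.data)
    return fused

Applied to conv_2d and conv_2d_pointwise from above, fuse_residual_branches(conv_2d, conv_2d_pointwise)(x) should match result1 up to floating-point tolerance.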

Timing comparison

Import the time module and use time.time() to measure how long each approach takes. The complete code is as follows:

import torch
import torch.nn as nn
import torch.nn.functional as F
import time

in_c = 2
out_c = 2
k = 3  # kernel_size
w = h = 9

x = torch.ones(1, in_c, h, w)  # input tensor [batch_size, channels, h, w]

# res_block = 3*3 conv + 1*1 conv + input

# Method 1: naive implementation
t1 = time.time()
conv_2d = nn.Conv2d(in_c, out_c, k, padding="same")  # in_channels must equal out_channels so that the branch outputs can be added
conv_2d_pointwise = nn.Conv2d(in_c, out_c, 1)
result1 = conv_2d(x) + conv_2d_pointwise(x) + x
t2 = time.time()
print(conv_2d.weight.size())
print(conv_2d.bias.size())
print(conv_2d_pointwise.weight.size())
print(conv_2d_pointwise.bias.size())
print(t2-t1)

# Method 2: operator fusion
# rewrite the point-wise convolution and the identity (x itself) as 3*3 convolutions,
# then merge the three convolutions into a single one
# 1) conversion
pointwise_to_conv_weight = F.pad(conv_2d_pointwise.weight, [1, 1, 1, 1, 0, 0, 0, 0])  # pad pairs apply from the last dim inward: one column left/right, one row top/bottom, channel dims padded by 0
conv_2d_for_pointwise = nn.Conv2d(in_c, out_c, k, padding="same")
conv_2d_for_pointwise.weight = nn.Parameter(pointwise_to_conv_weight)
conv_2d_for_pointwise.bias = conv_2d_pointwise.bias

# ignore correlations between neighboring pixels and across channels
# each output channel should depend only on its own input channel
zeros = torch.unsqueeze(torch.zeros(k, k), 0)
# only the center pixel itself should contribute
stars = torch.unsqueeze(F.pad(torch.ones(1, 1), [1, 1, 1, 1]), 0)
stars_zeros = torch.unsqueeze(torch.cat([stars, zeros], 0), 0)
zeros_stars = torch.unsqueeze(torch.cat([zeros, stars], 0), 0)
identity_to_conv_weight = torch.cat([stars_zeros, zeros_stars], 0)
identity_to_conv_bias = torch.zeros([out_c])

conv_2d_for_identity = nn.Conv2d(in_c, out_c, k, padding="same")
conv_2d_for_identity.weight = nn.Parameter(identity_to_conv_weight)
conv_2d_for_identity.bias = nn.Parameter(identity_to_conv_bias)

result2 = conv_2d(x) + conv_2d_for_pointwise(x) + conv_2d_for_identity(x)
# print(result2)
print(torch.all(torch.isclose(result1, result2)))  # check that the results match

# 2) fusion
t3 = time.time()
conv_2d_for_fusion = nn.Conv2d(in_c, out_c, k, padding="same")
conv_2d_for_fusion.weight = nn.Parameter(conv_2d.weight.data +
                                         conv_2d_for_pointwise.weight.data +
                                         conv_2d_for_identity.weight.data)  # sum the weights of all three layers
conv_2d_for_fusion.bias = nn.Parameter(conv_2d.bias.data +
                                       conv_2d_for_pointwise.bias.data +
                                       conv_2d_for_identity.bias.data)
result3 = conv_2d_for_fusion(x)
t4 = time.time()
print("原生写法耗时: ", t2 - t1, "\n算子融合写法耗时: ", t4 - t3)
# print(result3)

print(torch.all(torch.isclose(result3, result2)))  # check that the results match

The output is:

Naive implementation time:  0.0029582977294921875 
Fused implementation time:  0.0009975433349609375

As can be seen, operator fusion is indeed faster.
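Note that a single time.time() measurement is quite noisy, and the timed regions above also include constructing the nn.Conv2d modules. A steadier comparison (a sketch assuming CPU execution, so no CUDA synchronization is needed) averages many forward passes:

# Average many forward passes to reduce timing noise (CPU assumed)
n_runs = 1000
with torch.no_grad():
    t_start = time.time()
    for _ in range(n_runs):
        _ = conv_2d(x) + conv_2d_pointwise(x) + x
    t_naive = (time.time() - t_start) / n_runs

    t_start = time.time()
    for _ in range(n_runs):
        _ = conv_2d_for_fusion(x)
    t_fused = (time.time() - t_start) / n_runs
print("Average naive time: ", t_naive, "\nAverage fused time: ", t_fused)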

Reposted from blog.csdn.net/qq_45670134/article/details/129883225