优雅的维度转换_rearrange函数

前言

最近我在看一些开源项目代码的时候经常看见这样一个函数rearrange, 来进行维度转换,而不是使用permute。虽然有些时候permute 可以与rearrange替换, 但是可读性不如后者。

写这篇文章的时候看见这篇博文 非科班的他,凭借什么拿到了DeepMind的offer?, 这个故事告诉我们坚持不懈学习,并且多输出是有作用滴。



转换channel

import torch
from einops import rearrange 


# H, W, C
a = torch.randn(2, 2, 3)

# H, W, C -> C, H, W
a_permute = a.permute(2, 0, 1)
print('a_permute.shape:  ', a_permute.shape)

a_rearrange = rearrange(a, 'h w c -> c h w')
print('a_rearrange.shape:', a_rearrange.shape)

print('逐元素进行判断是否相等: ', a_permute.equal(a_rearrange))

a_permute.shape: torch.Size([3, 2, 2])
a_rearrange.shape: torch.Size([3, 2, 2])
逐元素进行判断是否相等: True



维度合并

import torch
from einops import rearrange


# B, C, H, W
a = torch.arange(9 * 2 * 2).view(1, 9, 2, 2)
# print(a)

b = rearrange(a, 'b c h w -> b c (h w)')
print(b.shape)

torch.Size([1, 9, 4])



高级用法

这里其实好像还和pixelshuffle结果不太一样,虽然维度是一样的。有空再去倒腾倒腾…(todo)

import torch
from einops import rearrange


# B, C, H, W
a = torch.arange(36).view(1, 9, 2, 2)
# print(a)

# 建议在 torch1.12.x 测试 PixleShuffle这个类
# ps = torch.nn.PixleShuffle(3)

b = rearrange(a, 'b (c h1 w2) h w -> b c (h1 h) (w2 w)', h1=3, w2=3)
# print(b)
# b_ps = ps(a)
# print('b.equal(b_ps): ', b.equal(b_ps))

c = rearrange(b, 'b c (h1 h) (w2 w) -> b (c h1 w2) h w', h1=3, w2=3)
print('a.equal(c):', a.equal(c))

a.equal©: True

猜你喜欢

转载自blog.csdn.net/weixin_43850253/article/details/126275912