torchとnumpyの次元変換

トーチ寸法変換

PIL で読み込まれるピクチャの形式は(H,W, C)
numpy で、保存されるピクチャの形式は (batch_size, H, W, C) です。
通常、畳み込みに必要なものは(batch_size,C,H, W)

したがって、次元変換が必要になります。

1. numpyでの次元変換

形状変換のためにnumpy で使用されますreshape
transposeの役割は、座標軸を変換し、問題を見る角度を切り替えることです。

n = np.random.randn(2, 3, 4)
reshape_n = n.reshape(-1, 12)
print(reshape_n.shape) # (2, 12)
transpose_n = n.transpose(1, 2, 0)
print(transpose_n.shape) # (3, 4, 12)

2. トーチ寸法変換

変形、サイズ変換、合計数量変更に使用しますtorch.view

t = torch.randn(2, 3, 4)
view_t = t.view(-1, 12)
print(view_t.shape) # torch.Size([2, 12])

torch.squeeze()/torch.unsqueeze()これは、寸法を圧縮または追加するために使用されます。
squeeze(n)n 次元の次元のみを圧縮できます1

t = torch.randn(1, 3, 4)
squeeze_t = t.squeeze(0)
print(squeeze_t.shape) # torch.Size([3, 4])

torch.unsqueeze(n)n 番目の次元の前に次元を追加します1

t = torch.randn(1, 3, 4)
unsqueeze_t = t.unsqueeze(0)
print(unsqueeze_t.shape) # torch.Size([1, 1, 3, 4])

torch.permute()次元変換。
numpy の転置と同様に、次元を並べ替えます。

t = torch.randn(1, 3, 4)
permute_t = t.permute(1, 2, 0)
print(permute_t.shape) # torch.Size([3, 4, 1])

おすすめ

転載: blog.csdn.net/m0_59967951/article/details/126532068