torch.view函数用法

view

一、手动调整size

view( )相当于reshape、resize,对Tensor的形状进行调整。
例:

import torch
x1 = torch.arange(0,16)
print("x1:",x1)
#a1: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15])
--------------------------------------------------------------------------------------------    
x2 = x1.view(8, 2)
x3 = x1.view(2, 8)
x4 = x1.view(4, 4)
print("x2:",x2)
print("x3:",x3)
print("x4:",x4)
x2: tensor([[ 0,  1],
        [ 2,  3],
        [ 4,  5],
        [ 6,  7],
        [ 8,  9],
        [10, 11],
        [12, 13],
        [14, 15]])
x3: tensor([[ 0,  1,  2,  3,  4,  5,  6,  7],
        [ 8,  9, 10, 11, 12, 13, 14, 15]])
x4: tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])

二、自动调整size (参数-1)

例:
view中一个参数指定为-1,代表自动调整这个维度上的元素个数,以保证元素的总数不变。

import torch
x1 = torch.arange(0,16)
print(x1)
#a1: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15])
------------------------------------------------------------------------------------------------------   
x2 = x1.view(-1, 16)
x3 = x1.view(-1, 8)
x4 = x1.view(-1, 4)
x5 = x1.view(-1, 2)
x6 = x1.view(4*4, -1)
x7 = x1.view(1*4, -1)
x8 = x1.view(2*4, -1)		#-1自动调整,8行有几列自动调整

print(x2)
print(x3)
print(x4)
print(x5)
print(x6)
print(x7)
print(x8)

x2: tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
x3: tensor([[ 0,  1,  2,  3,  4,  5,  6,  7],
        [ 8,  9, 10, 11, 12, 13, 14, 15]])
x4: tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
x5: tensor([[ 0,  1],
        [ 2,  3],
        [ 4,  5],
        [ 6,  7],
        [ 8,  9],
        [10, 11],
        [12, 13],
        [14, 15]])
x6: tensor([[ 0],
        [ 1],
        [ 2],
        [ 3],
        [ 4],
        [ 5],
        [ 6],
        [ 7],
        [ 8],
        [ 9],
        [10],
        [11],
        [12],
        [13],
        [14],
        [15]])
x7: tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
x8: tensor([[ 0,  1],
        [ 2,  3],
        [ 4,  5],
        [ 6,  7],
        [ 8,  9],
        [10, 11],
        [12, 13],
        [14, 15]])

分类器就是一个简单的nn.Linear()结构,输入输出都是一维的值,x = x.view(x.size(0), -1) 是为了将多维度的tensor展平成一维。

x = x.view(x.size(0), -1) 
print(x.size(), '*'*100)
print(x, '*'*100)

x4.size():        torch.Size([1, 2048]) 
x4:tensor([[0.3893, 0.5719, 0.5537,  ..., 0.3605, 0.4108, 0.3296]],device='cuda:0')  # 拉平了

猜你喜欢

转载自blog.csdn.net/wahahaha116/article/details/126103893