pytorch-张量-张量的操作

张量的操作:

import torch

# 改变张量的形状大小
a = torch.arange(12.0).reshape(3, 4)
print(a)

# 使用torch.reshape()函数来修改张量的形状和大小
print(torch.reshape(input=a, shape=(2, -1)))

# 改变张量形状的resize_()函数
print(a.resize_(2, 6))

# resize_as_()方法,复制其他张量的形状与尺寸
b = torch.arange(10.0, 19.0).reshape(3, 3)
print(b)
print(a.resize_as_(b))

# torch.unsqueeze()函数可以在张量的指定维度插入新的维度得到维度提升的张量
# torch.squeeze()函数可以在移除指定或者所以维度大小为1的维度,从而得到维度减小的新张量
a = torch.arange(12.0).reshape(2, 6)
# 函数在指定维度插入尺寸为1的新张量
b = torch.unsqueeze(a, dim=0)
print(a.shape, "\n", b.shape)
c = b.unsqueeze(dim=3)
print(c.shape)
d = torch.squeeze(c, dim=3)
print(d.shape)

# 可以使用.expand()方法对张量的维度进行拓展,从而张量的形状大小进行修改
# .expand_as()方法则会将张量根据另一个张量的形状大小进行拓展,得到新的张量
a = torch.arange(3)
b = a.expand(3, -1)
print(a, "\n", b)
c = torch.arange(6).reshape(2,3)
b = a.expand_as(c)
print(c, "\n", b)

# 使用张量的.repeat()方法,可以将张量看作一个整体,然后根据指定的形状进行重复填充,得到新的张量
print(b, b.shape)
d = b.repeat(1, 3, 4)
print(d, d.shape)

# 获取张量中的元素
# 利用张量中切片和索引提取元素的方法
a = torch.arange(12).reshape(1, 3, 4)
print(a)
print(a[0])
print(a[0, 0:2, :])
print(a[0, -1, -4:-1])

# 可按需将索引设置为相应的布尔值,然后提取为真条件下的内容
b = -a
print(torch.where(a > 5, a, b))
print(a[a > 5])

# torch.tril()函数可以获取张量的下三角部分内容,上三角部分的元素填充为0
# torch.triu()函数可以获取张量的上三角部分内容,下三角部分的元素填充为0
# torch.diag()函数可以获取矩阵张量的对角线元素,或者提供一个向量生成的一个矩阵张量
print(torch.tril(a, diagonal=0))
print(torch.tril(a, diagonal=1))
print(torch.triu(a, diagonal=0))
print(torch.triu(a, diagonal=1))

# 获取矩阵张量的对角线元素,input需要一个二维的张量
c = a.reshape(3, 4)
print(torch.diag(c, diagonal=0))
print(torch.diag(c, diagonal=1))

# 提供对角线元素生成矩阵张量
print(torch.diag(torch.tensor([1, 2, 3])))

# 拼接和拆分
# 将多个张量拼接为一个张量,将一个大的张量拆分为几个小的张量,其中torch.cat()函数,可以将多个张量在指定的维度进行拼接,得到新的张量
a = torch.arange(6.0).reshape(2, 3)
b = torch.linspace(0, 10, 6).reshape(2, 3)
# 在维度为0处拼接张量
c = torch.cat((a, b), dim=0)
print(c)
# 在维度为1处拼接张量
d = torch.cat((a, b), dim=1)
print(d)

# 在维度为1处连接3个张量
e = torch.cat((a[:, 1:2], a, b), dim=1)
print(e)

# torch.stack()函数,也将多个张量按照指定的维度进行拼接
f = torch.stack((a, b), dim=0)
print(f)
print(f.shape)

g = torch.stack((a, b), dim=2)
print(g)
print(g.shape)

# torch.chunk()函数可以将张量分割为特定数量的块
# torch.split()函数在将张量分割为特定数量的块时,可以指定每个块的大小
print(torch.chunk(e, 2, dim=0))

d1, d2 = torch.chunk(d, 2, dim=1)
print(d, "\n", d1, "\n", d2)

# 如果沿给定的维度dim的张量大小不能整除,则最后一个块将最小
e1, e2, e3, e4= torch.chunk(e, 4, dim=1)
print(e, "\n", e1, "\n", e2, "\n", e3, "\n", e4)

# 将张量切分成块,指定每个块的大小
d1, d2, d3 = torch.split(d, [1, 2, 3], dim=1)
print(d, "\n", d1, "\n", d2, "\n", d3)

猜你喜欢

转载自blog.csdn.net/weixin_45955630/article/details/111668025