PyTorch中tensor[..., 2:4]的解析

1. 动机

在看YOLO v3-SPP源码时,看到tensor[..., a: b]的切片方式比较新奇,接下来进行分析:

        p = p.view(bs, self.na, self.no, self.ny, self.nx).permute(0, 1, 3, 4, 2).contiguous()  # prediction

        if self.training:
            return p

            p = p.view(m, self.no)

            p[:, :2] = (torch.sigmoid(p[:, 0:2]) + grid) * ng  # x, y
            p[:, 2:4] = torch.exp(p[:, 2:4]) * anchor_wh  # width, height
            p[:, 4:] = torch.sigmoid(p[:, 4:])
            p[:, 5:] = p[:, 5:self.no] * p[:, 4:5]
            return p

2. 分析问题

2.1 list数组使用[..., a:b]方式切片

list_simple = [1, 2, 3, 4, 5]
list_complex = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]

# 对list数组进行切片
try:  # ① list_simple
    print(f"...: {
      
      list_simple[..., 2:4]}")
except Exception as e:
    print(f"a_list_simple切片报错,错误为: {
      
      e}")

try:  # ② list_complex
    print(f"...: {
      
      list_complex[..., 2:4]}")
except Exception as e:
    print(f"a_list_complex切片报错,错误为: {
      
      e}")
    

"""
    a_list_simple切片报错,错误为: list indices must be integers or slices, not tuple
    a_list_complex切片报错,错误为: list indices must be integers or slices, not tuple

"""

很明显,Python的基础数据类型list并不支持这样的切片方式。

2.2 numpy array使用[..., a:b]方式切片

import numpy as np

numpy_simple = np.array([1, 2, 3, 4, 5])
numpy_complex = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])

print(f"numpy_simple: \n{
      
      numpy_simple}")
print(f"numpy_complex: \n{
      
      numpy_complex}")

print("\n-------------------------\n")


# 对numpy array进行切片
try:  # ① list_simple
    print(f"...切片没有报错,结果为: {
      
      numpy_simple[..., 2:4]}")
except Exception as e:
    print(f"numpy_simple切片报错,错误为: {
      
      e}")

try:  # ② list_complex
    print(f"...切片没有报错,结果为: {
      
      numpy_complex[..., 2:4]}")
except Exception as e:
    print(f"numpy_complex切片报错,错误为: {
      
      e}")
    

"""
numpy_simple: 
[1 2 3 4 5]
numpy_complex: 
[[ 1  2  3]
 [ 4  5  6]
 [ 7  8  9]
 [10 11 12]]

-------------------------

...切片没有报错,结果为: [3 4]
...切片没有报错,结果为: [[ 3]
                     [ 6]
                     [ 9]
                     [12]]
"""

说明使用[..., a:b]方式是可以对numpy array进行切片的。

2.3 PyTorch tensor使用[..., a:b]方式切片

我们直接创建一个tensor进行分析:

import torch

a = torch.rand([3, 112, 112])

print(f"...: {
      
      a[..., :2].shape}")  # ...: torch.Size([3, 112, 2])

可以看到[...,a:b]中的...表示前n-1个维度,a:b表示直接对最后一个维度进行切片

3. 总结

[..., a:b]是array/tensor特有的切片方式,表示直接对最后一个维度进行切片

猜你喜欢

转载自blog.csdn.net/weixin_44878336/article/details/124847855