python 获取torch tensor随机10个位置

    tensor = torch.rand((8, 3, 384, 640))  # 生成一个形状为(5,5,5,5)的随机Tensor作为示例

    # 获取tensor的大小
    n1, n2, n3, n4 = tensor.size()

    # 产生10个随机索引
    rand_indices = torch.randint(0, min(n1 * n2 * n3 * n4, 10), (10,))
    print(rand_indices)
    # 根据随机索引获取元素
    rand_elements = tensor.view(-1)[rand_indices]

    print(rand_elements)

猜你喜欢

转载自blog.csdn.net/jacke121/article/details/131553205
今日推荐