【报错】paddle.io.TensorDataset报错 __init__() takes 2 positional arguments but 3 were given

问题原因

一般是用户习惯了torch的传参,paddle的TensorDataset接收一个列表
所以像下面这么写

dataset = TensorDataset([data_X, data_Y])

扩展知识

torch.utils.data.TensorDataset的参数

  • torch.utils.data.TensorDataset是PyTorch中的一个数据集类,用于将一个或多个张量组合为一个数据集。它的使用方式如下:
dataset = TensorDataset(*tensors)

其中,*tensors是一个或多个张量。如果传递多个张量,它们的第一个维度必须相同,即它们必须具有相同的样本数。torch.utils.data.TensorDataset会将这些张量的对应元素组合为一个样本,并将这些样本组合成一个数据集对象。

以下是torch.utils.data.TensorDataset方法的参数说明:

  • *tensors:一个或多个张量,用于创建数据集对象。如果传递多个张量,它们的第一个维度必须相同,即它们必须具有相同的样本数。

该方法会根据传递的张量创建一个数据集对象,它的每个元素都是一个由这些张量的对应元素组成的元组。例如,如果传递两个张量xy,其中x的形状为(n, d)y的形状为(n, 1),那么数据集中的每个元素都是一个形状为(d + 1,)的一维张量,其中前d个元素来自x,最后一个元素来自y

该方法的返回值是一个torch.utils.data.TensorDataset对象,可以用于训练和评估模型。

Person:

  • 为什么有多个tensor输入的时候,paddle里的TensorDataset在参数里需要加[],而torch的TensorDatase不用[]

和paddle的差别

在PyTorch中,可以将多个张量直接传递给TensorDataset的构造函数,而在PaddlePaddle中,需要将多个张量放在一个列表中传递给TensorDataset的构造函数。

具体来说,在PyTorch中,TensorDataset的构造函数定义为:

def __init__(self, *tensors: Tensor) -> None:

可以看到,参数tensors使用了可变参数的语法,即可以接受任意个数的Tensor类型参数。因此,在使用时可以直接传递多个张量,不需要将它们放在列表中。

而在PaddlePaddle中,TensorDataset的构造函数定义为:


def __init__(self, tensors: List[Tensor], use_double_buffer: bool = False):

可以看到,参数tensors是一个列表类型,要求传递一个Tensor类型的列表。因此,在使用时需要将多个张量放在列表中传递给构造函数,即使只传递一个张量也需要放在列表中。

这是两个框架在API设计上的不同。

猜你喜欢

转载自blog.csdn.net/crazyjinks/article/details/130590491
今日推荐