Os tipos de dados comuns no PyTorch são os seguintes
Tipo de dados | tipo d | Tensor CPU | Tensor GPU | Tamanho / bytes |
---|---|---|---|---|
Flutuante de 32 bits | torch.float32 ou torch.float |
torch.FloatTensor |
torch.cuda.FloatTensor |
4 |
Flutuante de 64 bits | torch.float64 ou torch.double |
torch.DoubleTensor |
torch.cuda.DoubleTensor |
8 |
Flutuante de 16 bits | torch.float16 ou torch.half |
torch.HalfTensor |
torch.cuda.HalfTensor |
- |
Inteiro de 8 bits (sem sinal) | torch.uint8 |
torch.ByteTensor |
torch.cuda.ByteTensor |
1 |
Inteiro de 8 bits (assinado) | torch.int8 |
torch.CharTensor |
torch.cuda.CharTensor |
- |
Inteiro de 16 bits (assinado) | torch.int16 ou torch.short |
torch.ShortTensor |
torch.cuda.ShortTensor |
2 |
Inteiro de 32 bits (assinado) | torch.int32 ou torch.int |
torch.IntTensor |
torch.cuda.IntTensor |
4 |
Inteiro de 64 bits (assinado) | torch.int64 ou torch.long |
torch.LongTensor |
torch.cuda.LongTensor |
8 |
Os tipos de dados acima em PyTorch correspondem àqueles em numpy, e o tamanho do byte também é o mesmo
tipo de conversão de dados
Referência
[1] https://pytorch.org/docs/stable/tensors.html
[2] https://blog.csdn.net/u010099080/article/details/53411703