The common PyTorch data types are listed below.
Data type | dtype | CPU tensor | GPU tensor | Size (bytes) |
---|---|---|---|---|
32-bit floating point | torch.float32 or torch.float | torch.FloatTensor | torch.cuda.FloatTensor | 4 |
64-bit floating point | torch.float64 or torch.double | torch.DoubleTensor | torch.cuda.DoubleTensor | 8 |
16-bit floating point | torch.float16 or torch.half | torch.HalfTensor | torch.cuda.HalfTensor | 2 |
8-bit integer (unsigned) | torch.uint8 | torch.ByteTensor | torch.cuda.ByteTensor | 1 |
8-bit integer (signed) | torch.int8 | torch.CharTensor | torch.cuda.CharTensor | 1 |
16-bit integer (signed) | torch.int16 or torch.short | torch.ShortTensor | torch.cuda.ShortTensor | 2 |
32-bit integer (signed) | torch.int32 or torch.int | torch.IntTensor | torch.cuda.IntTensor | 4 |
64-bit integer (signed) | torch.int64 or torch.long | torch.LongTensor | torch.cuda.LongTensor | 8 |
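
As a quick check of the table, the sketch below builds a small CPU tensor for each dtype and reads back its type name with `Tensor.type()` and its per-element size with `Tensor.element_size()`. The names `TABLE` and `check_table` are made up for this illustration, and the GPU column is not exercised because that would need a CUDA device.

```python
import torch

# The short names are aliases for the same dtypes.
assert torch.float == torch.float32
assert torch.double == torch.float64
assert torch.half == torch.float16
assert torch.short == torch.int16
assert torch.int == torch.int32
assert torch.long == torch.int64

# (dtype, CPU tensor type name, bytes per element), copied from the table.
TABLE = [
    (torch.float32, "torch.FloatTensor", 4),
    (torch.float64, "torch.DoubleTensor", 8),
    (torch.float16, "torch.HalfTensor", 2),
    (torch.uint8, "torch.ByteTensor", 1),
    (torch.int8, "torch.CharTensor", 1),
    (torch.int16, "torch.ShortTensor", 2),
    (torch.int32, "torch.IntTensor", 4),
    (torch.int64, "torch.LongTensor", 8),
]


def check_table():
    """Return (dtype, type name, element size) as reported by PyTorch."""
    rows = []
    for dtype, _, _ in TABLE:
        t = torch.zeros(3, dtype=dtype)  # small CPU tensor of this dtype
        rows.append((dtype, t.type(), t.element_size()))
    return rows


if __name__ == "__main__":
    for dtype, type_name, size in check_table():
        print(dtype, type_name, size)
```

The total memory used by a tensor's data is then roughly `t.element_size() * t.numel()` bytes.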
Each of these PyTorch data types corresponds to the NumPy type of the same name (for example, torch.float32 and numpy.float32), and the size in bytes is the same.
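
The correspondence with NumPy can be checked with a short sketch like the one below. It assumes NumPy is installed alongside PyTorch. The names `PAIRS` and `compare_with_numpy` are hypothetical helpers introduced only for this example.

```python
import numpy as np
import torch

# PyTorch dtype paired with the NumPy type of the same name.
PAIRS = [
    (torch.float32, np.float32),
    (torch.float64, np.float64),
    (torch.float16, np.float16),
    (torch.uint8, np.uint8),
    (torch.int8, np.int8),
    (torch.int16, np.int16),
    (torch.int32, np.int32),
    (torch.int64, np.int64),
]


def compare_with_numpy():
    """For each pair, report whether dtype and byte size agree after conversion."""
    results = []
    for torch_dtype, np_type in PAIRS:
        t = torch.zeros(2, dtype=torch_dtype)
        a = t.numpy()  # CPU tensor -> NumPy array sharing the same memory
        same_dtype = a.dtype == np.dtype(np_type)
        same_size = t.element_size() == a.itemsize
        results.append((torch_dtype, same_dtype, same_size))
    return results


if __name__ == "__main__":
    for torch_dtype, same_dtype, same_size in compare_with_numpy():
        print(torch_dtype, same_dtype, same_size)
```

Going the other way, `torch.from_numpy(a)` keeps the array's dtype, so a `numpy.float64` array becomes a `torch.float64` tensor.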
Data type conversion
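
A minimal sketch of some common ways to convert a tensor from one data type to another, using `Tensor.to()`, the shorthand methods such as `.long()` and `.half()`, and `Tensor.type()`. The function name `convert_examples` and the sample values are chosen only for illustration.

```python
import torch


def convert_examples():
    """Convert one float32 tensor to several other dtypes."""
    x = torch.tensor([1.7, -2.3, 3.0])  # float32 by default

    as_double = x.to(torch.float64)  # general form: pass the target dtype
    as_long = x.long()               # shorthand for x.to(torch.int64)
    as_half = x.half()               # shorthand for x.to(torch.float16)
    as_int = x.type(torch.int32)     # Tensor.type() also accepts a dtype

    return x, as_double, as_long, as_half, as_int


if __name__ == "__main__":
    for t in convert_examples():
        print(t.dtype, t)
```

Conversions return a new tensor and leave the original unchanged. Converting a floating-point tensor to an integer type drops the fractional part, so the values above become 1, -2 and 3.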