test_dataloader = DataLoader(
    test_dataset,
    collate_fn=collate,
    batch_size=4,
)
DataLoader is a class provided by PyTorch (torch.utils.data.DataLoader) for data loading. It loads data from a given dataset in batches.
test_dataset is the dataset object that supplies the samples to load.
collate_fn=collate specifies how individual samples are combined into a batch during loading. collate is a function (or any callable) that receives a list of samples and returns the combined batch.
batch_size=4 sets the number of samples per batch. In this example each batch contains 4 samples, except that the last batch may contain fewer if the dataset size is not a multiple of 4.
So the code above creates a data loader named test_dataloader that loads data from test_dataset and groups it into batches of 4 samples.
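To make the roles of batch_size and collate_fn concrete, here is a simplified sketch of what the data loader does when iterated without shuffling. It walks the dataset indices in order, collects batch_size samples, and passes that list to collate_fn. The helper iterate_batches is hypothetical and only models the behavior. It is not DataLoader's actual implementation, which also handles samplers, worker processes and pinned memory. The drop_last flag mirrors DataLoader's real drop_last parameter, and the sketch checks its own output against the real DataLoader.

```python
import torch
from torch.utils.data import DataLoader

def collate(batch):
    # Same idea as the collate function used in this post: stack equal-length samples.
    return torch.tensor(batch)

def iterate_batches(dataset, batch_size, collate_fn, drop_last=False):
    """Hypothetical helper: a simplified, in-order model of DataLoader batching."""
    batch = []
    for idx in range(len(dataset)):
        batch.append(dataset[idx])          # fetch one sample by index
        if len(batch) == batch_size:
            yield collate_fn(batch)         # hand the list of samples to collate_fn
            batch = []
    if batch and not drop_last:
        yield collate_fn(batch)             # leftover samples form a smaller batch

# 11 samples of length 10, similar to test_data in the example below.
data = [list(range(1, 11)) for _ in range(11)]

manual = list(iterate_batches(data, 4, collate))
loader = list(DataLoader(data, batch_size=4, collate_fn=collate))
print([b.shape[0] for b in manual])         # [4, 4, 3]

manual_drop = list(iterate_batches(data, 4, collate, drop_last=True))
loader_drop = list(DataLoader(data, batch_size=4, collate_fn=collate, drop_last=True))
print([b.shape[0] for b in manual_drop])    # [4, 4]
```

With the default drop_last=False, the 3 leftover samples still form a final, smaller batch. With drop_last=True they are discarded.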
Here is a complete example showing how to use this DataLoader setup:
import torch
from torch.utils.data import DataLoader, Dataset

# Define a dataset class
class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# Custom collate function that combines samples into a batch
def collate(batch):
    # Simply convert the list of samples into a tensor
    return torch.tensor(batch)

# Create the dataset object
# test_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
test_data = [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
             [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
             [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
             [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
             [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
             [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
             [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
             [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
             [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
             [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
             [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
             ]
test_dataset = MyDataset(test_data)

# Create the data loader
test_dataloader = DataLoader(
    test_dataset,
    collate_fn=collate,
    batch_size=4,
)

# Iterate over the data loader and print each batch
for batch in test_dataloader:
    print(batch)
Output:
tensor([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
        [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
        [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
        [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10]])
tensor([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
        [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
        [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
        [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10]])
tensor([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
        [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
        [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10]])
In this example, a simple dataset class, MyDataset, is defined first. It implements __len__, which returns the length of the dataset, and __getitem__, which returns a single sample.
Next, a custom collate function is defined. It takes a list of samples and simply converts it to a tensor. This works here because every sample has the same length; torch.tensor cannot build a tensor from rows of different lengths.
Then the dataset object test_dataset is created. The data loader test_dataloader is built from it and the custom collate function, with 4 samples per batch.
Finally, iterating over test_dataloader shows the data of each batch. The dataset has 11 samples, so the first two batches contain 4 samples each and the last batch contains the remaining 3.
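The following sketch shows why a custom collate function is useful. If collate_fn is omitted, DataLoader falls back to its default collate function. For samples that are plain Python lists, the default function transposes the batch. It returns a list with one tensor per position inside the sample, rather than a single 2-D tensor. The toy data below is hypothetical and only illustrates the difference.

```python
import torch
from torch.utils.data import DataLoader

# Hypothetical toy data: 6 samples, each a Python list of 3 ints.
data = [[1, 2, 3] for _ in range(6)]

# Without collate_fn, DataLoader uses its default collate function.
# For list-valued samples it transposes the batch: one tensor per position.
default_batch = next(iter(DataLoader(data, batch_size=4)))
print(type(default_batch).__name__, len(default_batch))  # list 3
print(default_batch[0])                                   # tensor([1, 1, 1, 1])

# A custom collate_fn decides the batch layout itself: here, one (4, 3) tensor.
def collate(batch):
    return torch.tensor(batch)

custom_batch = next(iter(DataLoader(data, batch_size=4, collate_fn=collate)))
print(custom_batch.shape)                                 # torch.Size([4, 3])
```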
[torch.from_numpy(x.astype("uint8")) for x in labels]
Here labels is a list of boolean values. Explain in detail what the code above means.
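The code above uses a list comprehension and the PyTorch function torch.from_numpy() to convert each element of labels into a PyTorch Tensor object. The result is a list of tensors. Note that astype is a NumPy array method, so each element x must be a NumPy array (for example, a boolean array). A plain Python bool has no astype method.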
Let's explain step by step what the code means:
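torch.from_numpy(x.astype("uint8")):
x.astype("uint8") converts the boolean NumPy array x into a new NumPy array of unsigned 8-bit integers (True becomes 1, False becomes 0). astype changes only the element type; it does not turn a non-array into an array. x therefore has to be a NumPy array already, because torch.from_numpy() only accepts NumPy arrays (numpy.ndarray).
torch.from_numpy() converts a NumPy array into a PyTorch Tensor object. The returned tensor shares memory with the array it was created from.
Therefore, torch.from_numpy(x.astype("uint8")) converts the boolean array x into a PyTorch Tensor of dtype torch.uint8.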
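The sketch below illustrates these two steps, using a hypothetical boolean array in place of one element x of labels. It shows three things. Calling torch.from_numpy on a boolean array directly keeps the bool dtype. astype("uint8") produces a separate array of 0/1 values. The resulting tensor shares memory with that new uint8 array, not with x.

```python
import numpy as np
import torch

# Hypothetical boolean array standing in for one element x of labels.
x = np.array([True, False, True])

# Without astype, from_numpy keeps the boolean dtype (and shares memory with x).
t_bool = torch.from_numpy(x)
print(t_bool.dtype)        # torch.bool

# astype("uint8") creates a new uint8 array: True -> 1, False -> 0.
x_u8 = x.astype("uint8")
t_u8 = torch.from_numpy(x_u8)
print(t_u8)                # tensor([1, 0, 1], dtype=torch.uint8)

# The tensor shares memory with x_u8, so writing to x_u8 shows up in t_u8 ...
x_u8[1] = 7
print(t_u8)                # tensor([1, 7, 1], dtype=torch.uint8)

# ... but x_u8 is a copy, so changing x does not affect t_u8 (only t_bool).
x[0] = False
print(t_u8[0].item(), t_bool[0].item())   # 1 False
```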
for x in labels:
labels is the list being iterated over.
The list comprehension visits each element x of labels in turn and passes it through x.astype("uint8") and torch.from_numpy() for conversion.
The result is a list containing the converted Tensor objects, one per element of labels.
Here is an example:
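import numpy as np
import torch

# List of boolean values
labels = [True, False, True, True, False]
# Convert the list of booleans into a list of Tensor objects.
# A Python bool has no astype method, so each value is first wrapped
# in a 0-dimensional NumPy array with np.array().
tensor_list = [torch.from_numpy(np.array(x).astype("uint8")) for x in labels]
# Print the converted Tensor objects
for tensor in tensor_list:
    print(tensor)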
Output result:
tensor(1, dtype=torch.uint8)
tensor(0, dtype=torch.uint8)
tensor(1, dtype=torch.uint8)
tensor(1, dtype=torch.uint8)
tensor(0, dtype=torch.uint8)
In this example, labels is a list of booleans. A list comprehension and torch.from_numpy() convert it into tensor_list, a list of PyTorch Tensor objects. Each boolean is first wrapped in a 0-dimensional NumPy array, so each resulting tensor is a scalar (0-dimensional) uint8 tensor. Finally, the loop iterates over tensor_list and prints each Tensor object.
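The original comprehension is most natural when each element of labels is already a boolean NumPy array, such as a per-sample mask. The sketch below assumes hypothetical masks of different shapes; arrays of different shapes cannot be stacked into one tensor, which is one reason to keep a list. For a flat list of Python bools, torch.tensor with an explicit dtype builds a single tensor directly, as an alternative to the per-element conversion.

```python
import numpy as np
import torch

# Hypothetical labels: a list of boolean NumPy arrays with different shapes.
labels = [
    np.array([[True, False], [False, True]]),
    np.array([True, True, False]),
]
# The original comprehension works as written because each x is an ndarray.
tensors = [torch.from_numpy(x.astype("uint8")) for x in labels]
for t in tensors:
    print(t.dtype, tuple(t.shape))   # torch.uint8 (2, 2), then torch.uint8 (3,)

# For a flat list of Python bools, one tensor can be built directly.
flags = [True, False, True, True, False]
flag_tensor = torch.tensor(flags, dtype=torch.uint8)
print(flag_tensor)                   # tensor([1, 0, 1, 1, 0], dtype=torch.uint8)
```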