PyTorch DataLoader usage example

test_dataloader = DataLoader(
    test_dataset,
    collate_fn=collate,
    batch_size=4,
)

DataLoader is a class provided by PyTorch (torch.utils.data.DataLoader) for loading data. It draws samples from a given dataset and groups them into batches.

test_dataset is the dataset object that supplies the data to be loaded.

collate_fn=collate specifies how the samples drawn for one batch are combined into a single batch during loading. collate is a function (or any other callable) that takes a list of samples as input and returns the combined batch.
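To see why a custom collate_fn is useful here, the sketch below compares PyTorch's default collation with a collate function like the one in the example. It assumes samples shaped like the example's (lists of 10 integers) and uses a plain Python list as the dataset, which DataLoader accepts because a list has __len__ and __getitem__. The transposing behavior of the default collation noted in the comments reflects recent PyTorch releases.

```python
import torch
from torch.utils.data import DataLoader

# Samples shaped like the example's: 11 lists of 10 ints each.
test_data = [list(range(1, 11)) for _ in range(11)]

# 1) No collate_fn: DataLoader falls back to its default collation.
default_loader = DataLoader(test_data, batch_size=4)
first = next(iter(default_loader))
# For list samples the default collation transposes the batch, giving a
# list of 10 tensors; first[i] holds element i of each of the 4 samples.
print(type(first).__name__, len(first), first[0])

# 2) Custom collate as in the example: the whole batch becomes one 4 x 10 tensor.
def collate(batch):
    # `batch` is a list of 4 samples (each a list of 10 ints)
    return torch.tensor(batch)

custom_loader = DataLoader(test_data, batch_size=4, collate_fn=collate)
print(next(iter(custom_loader)).shape)
```

In other words, the custom collate keeps each sample as a row, which is usually what a model expects.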

batch_size=4 sets the number of samples per batch. In this example, each batch contains 4 samples (the last batch may contain fewer if the dataset size is not a multiple of 4).

So the code above creates a data loader named test_dataloader that loads data from test_dataset and groups it into batches of 4 samples each.

Here is an example showing how to use the code:

import torch
from torch.utils.data import DataLoader, Dataset

# Define a dataset class (a map-style dataset: implements __len__ and __getitem__)
class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# Custom collate function that combines a list of samples into one batch
def collate(batch):
    # Here we simply convert the list of samples into a tensor
    return torch.tensor(batch)

# Create the dataset object (11 samples, each a list of 10 ints)

# test_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

test_data = [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
            ]

test_dataset = MyDataset(test_data)

# Create the data loader
test_dataloader = DataLoader(
    test_dataset,
    collate_fn=collate,
    batch_size=4,
)

# Iterate over the data loader and print each batch
for batch in test_dataloader:
    print(batch)

Output:

tensor([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
        [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
        [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
        [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10]])
tensor([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
        [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
        [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
        [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10]])
tensor([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
        [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10],
        [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10]])

In this example, we first define a simple dataset class, MyDataset, which implements __len__ (returns the size of the dataset) and __getitem__ (returns a single sample by index).

We then define a custom collate function that takes a list of samples and simply converts it into a tensor.

Next, we create the dataset object test_dataset and, using it together with the custom collate function, create the data loader test_dataloader with 4 samples per batch.

Finally, iterating over test_dataloader prints the data in each batch. Because test_data contains 11 samples, the first two batches contain 4 samples each and the last batch contains the remaining 3 (DataLoader keeps an incomplete final batch by default).
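The batch count follows directly from the dataset size: with N = 11 samples and batch_size = 4 there are ⌈11/4⌉ = 3 batches, the last holding 11 − 2 × 4 = 3 samples. Setting the DataLoader parameter drop_last=True (default False) discards that incomplete batch instead, leaving ⌊11/4⌋ = 2 batches. A minimal sketch; identity is a hypothetical helper used as collate_fn only so that each batch's size is easy to inspect:

```python
from torch.utils.data import DataLoader

# 11 samples, as in the example (contents don't matter for counting).
test_data = [list(range(1, 11)) for _ in range(11)]

def identity(batch):
    # Hypothetical collate: return the list of samples unchanged.
    return batch

keep = DataLoader(test_data, batch_size=4, collate_fn=identity)  # drop_last=False (default)
drop = DataLoader(test_data, batch_size=4, collate_fn=identity, drop_last=True)

print(len(keep), [len(b) for b in keep])  # 3 [4, 4, 3]
print(len(drop), [len(b) for b in drop])  # 2 [4, 4]
```

Dropping the last batch is sometimes preferred during training so that every batch has the same size.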


[torch.from_numpy(x.astype("uint8")) for x in labels]

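Here labels is a list of Boolean values stored as NumPy arrays (x.astype() is a NumPy array method, so each element must be an array rather than a plain Python bool). The code is explained in detail below.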

The code above uses a list comprehension and the PyTorch function torch.from_numpy() to convert each Boolean array in labels into a PyTorch Tensor object.

Let's explain step by step what the code means:

torch.from_numpy(x.astype("uint8"))

x.astype("uint8") converts the Boolean array x into a NumPy array of unsigned 8-bit integers (True becomes 1, False becomes 0). The result is still a NumPy array, which is what torch.from_numpy() requires as input.
torch.from_numpy() converts a NumPy array into a PyTorch Tensor object; the tensor shares memory with the array rather than copying it.
Therefore, torch.from_numpy(x.astype("uint8")) converts the Boolean array x into a PyTorch Tensor of dtype torch.uint8.
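The two steps can be checked in isolation. This sketch assumes a single hypothetical Boolean mask x. It also shows that torch.from_numpy() shares memory with its input array instead of copying it, and that it accepts Boolean arrays directly (yielding torch.bool), so the uint8 cast is needed only when numeric 0/1 values are wanted.

```python
import numpy as np
import torch

# A hypothetical Boolean mask stored as a NumPy array.
x = np.array([True, False, True])

u8 = x.astype("uint8")   # new uint8 array: True -> 1, False -> 0
t = torch.from_numpy(u8)  # tensor that shares memory with u8
print(t, t.dtype)         # tensor([1, 0, 1], dtype=torch.uint8) torch.uint8

# from_numpy does not copy: modifying the array is visible in the tensor.
u8[1] = 7
print(t)                  # tensor([1, 7, 1], dtype=torch.uint8)

# Without the cast, from_numpy keeps the Boolean dtype instead.
print(torch.from_numpy(x).dtype)  # torch.bool
```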

for x in labels:

labels is a list of Boolean NumPy arrays.
The list comprehension iterates over labels and passes each element x to torch.from_numpy() for conversion.
The result is a list of the converted Tensor objects.
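For readers less used to list comprehensions, this sketch (assuming hypothetical Boolean masks of different lengths) shows that the comprehension is equivalent to an explicit loop. Because the masks differ in length, the result is naturally a list of tensors rather than one stacked tensor.

```python
import numpy as np
import torch

# Hypothetical labels: Boolean masks of different lengths.
labels = [np.array([True, False]), np.array([False, True, True])]

# The list comprehension from the text ...
tensors = [torch.from_numpy(x.astype("uint8")) for x in labels]

# ... is equivalent to this explicit loop.
tensors_loop = []
for x in labels:
    tensors_loop.append(torch.from_numpy(x.astype("uint8")))

# Each pair holds the same values, so torch.equal reports True.
for a, b in zip(tensors, tensors_loop):
    print(a, torch.equal(a, b))
```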

Here is an example:

import numpy as np
import torch

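# A list of Boolean values; each is wrapped in a 0-d NumPy array so that .astype() is available
labels = [np.array(v) for v in [True, False, True, True, False]]

# Convert each Boolean array to a uint8 Tensor
tensor_list = [torch.from_numpy(x.astype("uint8")) for x in labels]

# Print the converted Tensor objects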
for tensor in tensor_list:
    print(tensor)

Output:

tensor(1, dtype=torch.uint8)
tensor(0, dtype=torch.uint8)
tensor(1, dtype=torch.uint8)
tensor(1, dtype=torch.uint8)
tensor(0, dtype=torch.uint8)

In this example, labels holds Boolean values. Using a list comprehension and torch.from_numpy(), we convert them into a list of PyTorch Tensor objects, tensor_list, and then print each Tensor in it.

Source: blog.csdn.net/AdamCY888/article/details/131343554