参考文献:
https://pytorch.org/docs/stable/data.html#dataloader-collate-fn
https://blog.csdn.net/anshiquanshu/article/details/112868740
Pytorch 深層学習フレームワークを使用する場合、データセットとデータローダーは避けられず、後者は前者に依存し、データを効率的にロードするためのソリューション (マルチスレッド、バッチ トレーニングなど) を提供します。
RGB 画像を例にとると、データセットからのデータ形状は (3, H, W) ですが、データローダーからのデータ形状は (batch_size, 3, H, W) です。明らかに、もう 1 つのディメンション、つまりバッチ ディメンションがあります。これは明らかにデータローダーがデータを「スタッキング」していることです。実際、dataloader には Collate_fn と呼ばれるパラメータがあり、そのデフォルト値は None です。つまり、dataloader を使用し、collate_fn を指定しない場合、pytorch は実際にデフォルトの Collate_fn 関数を呼び出して、ここに配置される前にデータを「スタック」します。
ただし、データが可変長である場合、データを正常にスタックできません。たとえば、次のエラーが発生しました。
RuntimeError: スタックは各テンソルが等しいサイズであることを期待していますが、エントリ 0 で [2, 4] を取得し、エントリ 1 で [5, 4] を取得しました
一方のデータ長は 2、もう一方のデータ長は 5 です。当然、直接スタックすることはできません。現時点では、可変長データに直面した場合、それを埋めるためにcollate_fnをカスタマイズする必要があります。たとえば、pytorch のドキュメントには次のような一節があります。
カスタムの Collate_fn を使用して、照合順序をカスタマイズすることができます。たとえば、連続データをバッチの最大長までパディングします。
では、collate_fn をカスタマイズするにはどうすればよいでしょうか? このcollate_fnの入力と出力は何ですか? この例を見てみましょう。
def padding_collate_fn(data_batch):
batch_bbox_list = [item['bbox'] for item in data_batch]
batch_label_list = [item['label'] for item in data_batch]
batch_filename_list = [item['filename'] for item in data_batch]
padding_bbox = pad_sequence(batch_bbox_list, batch_first=True, padding_value=0)
padding_label = pad_sequence(batch_bbox_list, batch_first=True, padding_value=5)
result = dict()
result["bbox"] = padding_bbox
result["label"] = padding_label
result["filename"] = batch_filename_list
return result
まず、元のデータセットの出力は辞書ですが、上記のコードは辞書内の値を取り出して積み重ね、最終的に大きな辞書に返します。中でも、pad_sequence 関数は torch.nn.utils.rnn パッケージに含まれており、非常に便利です。
実際、バッチはデータセット[インデックス] ~ データセット[インデックス + バッチサイズ]で構成されるリストであり、これがわかれば問題は解決します。