RuntimeError: cada elemento en la lista del lote debe ser del mismo tamaño (使用collate_fn解决)

Durante el entrenamiento de PyTorch, se informó un error al cargar el conjunto de datos:

RuntimeError: cada elemento en la lista del lote debe ser del mismo tamaño

El motivo de este error es que los tamaños de datos en el mismo minilote son diferentes.

El tipo de retorno de getitem es tensor, tupla, lista

Aquí primero tomamos un ejemplo simple de clasificación de imágenes para explicar los problemas, principios y soluciones. La clase del conjunto de datos y el código para leer datos deben ser aproximadamente los siguientes:

from torch.utils.data import Dataset, DataLoader

class ImageDataset(Dataset):
	def __init__(self):
		# 初始化..., 把image路径和label信息都写入到 total_data_list 中
		self.total_data_list = ......
	def __len__(self):
		return len(self.total_data_list)
	def __getitem__(self, index):
		image, label = self.total_data_list[index]
		return image, label

dataset = ImageDataset()
dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True)

# 两种读取数据的方式
# 1. sample获取的是整个batch的数据,即:下面的[[img_0, ...],[lbl_0, ...]]
for index, sample in enumerate(dataloader):
	....
# 2. 直接从原数据中unpack,得到image是下面的[img_0, ...], label是下面的[lbl_0, ...]
for image, data in dataloader:
	....

Cada vez que se extrae un mini lote de datos del cargador de datos, el formato predeterminado es: (el tamaño del lote es 4)

[ [img_0, img_1, img_2, img_3], [lbl_0, lbl_1, lbl_2, lbl_3] ]

Cada vez que se saca un lote de datos, se devuelve una lista. Hay varias listas en la lista y la longitud de cada lista es el tamaño del lote. El error mencionado al principio es que entre varias listas, hay tamaños desiguales de elementos. en las listas.

Pero lo que __getitem__()devolvimos es una tupla (imagen, etiqueta) ¿Qué hizo el cargador de datos después de leer los 4 datos?
Al principio, el cargador de datos lee directamente 4 datos, a saber

[ (img_0, lbl_0), (img_1, lbl_1), (img_2, lbl_2), (img_3, lbl_3) ]

El cargador de datos primero expande cada elemento en la lista de datos obtenida en una dimensión antes de la primera dimensión, es decir: la
imagen original de 3x640x512 se convierte en una
etiqueta de 1x3x640x512, que originalmente era un número real, y se convierte en un tensor unidimensional.

Luego, conecte la imagen y la etiqueta a lo largo de la dimensión del tamaño del lote para formar dos tensores de 3x3x640x512 y 3x1, y luego combínelos en una lista para formar el formato de datos predeterminado anterior.

El código es el siguiente, donde x es [ (img_0, lbl_0), (img_1, lbl_1), (img_2, lbl_2), (img_3, lbl_3) ]:

collate_fn=lambda x:(
	torch.cat(
   		[x[i][j].unsqueeze(0) for i in range(len(x))], 0
   	) for j in range(len(x[0]))
)

Tenga en cuenta que se usa aquí collate_fn. Este es un parámetro del DataLoader que usamos principalmente para resolver este error. Indica cómo el DataLoader toma muestras. Puede usar este parámetro para definir su propia función para lograr una función de muestreo más precisa.

Si no se especifica collate_fn al definir DataLoader, es el método de muestreo predeterminado, lo que equivale a especificar collate_fn como el código anterior. También se puede personalizar, como por ejemplo:

collate_fn=lambda x:x
collate_fn=lambda x:torch.utils.data.dataloader.default_collate(list(filter(lambda x: x is not None, batch))) # 限制x不为空

El tipo de devolución de getitem es Dict

Lo anterior es el caso de devolver tensor, lista o tupla en __getitem__. Si los datos devueltos son datos de tipo dict, como por ejemplo:

class ImageDataset(Dataset):
	def __init__(self):
		# 初始化..., 把image路径和label信息都写入到 total_data_list 中
		self.total_data_list = ......
	def __len__(self):
		return len(self.total_data_list)
	def __getitem__(self, index):
		image, label = self.total_data_list[index]
		return {
    
    
			"image": image,
			"label": label
		}

En el cargador de datos, el método para procesar el valor en el dictado de muestra es el mismo que el descrito anteriormente: primero, se expanden las dimensiones y luego se combinan los datos del tamaño del lote. La diferencia es que el nuevo valor procesado, incluidos los datos del tamaño del lote, se almacena correspondiente a la clave original, y lo que finalmente se devuelve no es una lista, sino un tipo de datos dict.

{
	'image': [img_0, img_1, img_2, img_3], 
	'label': [lbl_0, lbl_1, lbl_2, lbl_3]
}

De hecho, es muy similar, excepto que la parte del valor se extrae del dict de acuerdo con la clave y luego se vuelve a colocar en la posición correspondiente en el dict después del procesamiento.

Para presentar formalmente el problema que encontré, el formato devuelto por __getitem__ es el siguiente:

return {
    
    
	
    "images": images,          # List[Tensor]: [N][3,Hi,Wi], N is number of images
    "intrinsics": intrinsics,  # Tensor: [N,3,3]
    "extrinsics": extrinsics,  # Tensor: [N,4,4]
    "depth_min": depth_min,    # Tensor: [1]
    "depth_max": depth_max,    # Tensor: [1]
    "depth_gt": depth_gt,      # Tensor: [1,H0,W0] if exists
    "mask": mask,              # Tensor: [1,H0,W0] if exists
    "filename": os.path.join(scan, "{:0>8}".format(view_ids[0]) + "{}")
}

N es igual a 1 imagen de referencia + N-1 imagen de origen. Cuando N = 6, los datos devueltos por getitem son los siguientes:

{
	'images' : [torch.Size([3, 480, 640]), torch.Size([3, 480, 640]), torch.Size([3, 480, 640]), torch.Size([3, 480, 640]), torch.Size([3, 480, 640]), torch.Size([3, 480, 640])]
	'intrinsics' : torch.Size([6, 3, 3])
	'extrinsics' : torch.Size([6, 4, 4])
	'depth_min' : torch.Size([1])
	'depth_max' : torch.Size([1])
	'depth_gt' : torch.Size([1, 480, 640])
	'mask' : torch.Size([1, 480, 640])
	'filename' : '57f8d9bbe73f6760f10e916a/00000182{}'
}

En circunstancias normales, cuando el tamaño del lote = 4, los 4 datos anteriores se combinarán de acuerdo con las reglas introducidas anteriormente y se generará un lote de datos con el siguiente formato:

{
	'images' : [torch.Size([4, 3, 480, 640]), torch.Size([4, 3, 480, 640]), torch.Size([4, 3, 480, 640]), torch.Size([4, 3, 480, 640]), torch.Size([4, 3, 480, 640]), torch.Size([4, 3, 480, 640])]
	'intrinsics' : torch.Size([4, 6, 3, 3])
	'extrinsics' : torch.Size([4, 6, 4, 4])
	'depth_min' : torch.Size([4])
	'depth_max' : torch.Size([4])
	'depth_gt' : torch.Size([4, 1, 480, 640])
	'mask' : torch.Size([4, 1, 480, 640])
	'filename' : ['57f8d9bbe73f6760f10e916a/00000182{}', '57f8d9bbe73f6760f10e916a/00000204{}', '57f8d9bbe73f6760f10e916a/00000235{}', '57f8d9bbe73f6760f10e916a/00000170{}']
}

El número de imágenes de referencia está determinado, solo hay una.
En el pairs.txt de los datos de entrenamiento, bajo la referencia especificada, habrá m imágenes src, ordenadas de mayor a menor según la puntuación. Durante el entrenamiento, debe especificar la cantidad de imágenes src que se utilizarán, que es la valor de N-1. Cuando m >= N-1, es decir, definitivamente puedes sacar las imágenes src del conjunto N-1, puedes tomar las N-1 con la puntuación más grande, o puedes tomar N -1 aleatoriamente por el bien de un entrenamiento sólido.
Sin embargo, si el número de imágenes fuente de una determinada referencia es m < N-1, es decir, como máximo m, al recuperar datos según el método mini-batch, puede suceder que entre los cuatro datos recuperados, el image El número total (ref+src) parece ser 6,6,6,5, lo que causa problemas al fusionar los cuatro datos, porque los dos datos de [1x6xCxHxW]y no se pueden [1x5xCxHxW]combinar. Entonces ocurrió el error anterior:RuntimeError: each element in list of batch should be of equal size

solución

Solución:
1. Establezca el tamaño del lote en 1. El primer método no debe considerarse solución, solo se puede decir que evita errores.
2. Personalice la función collate_fn y cambie el método de muestreo predeterminado

En el segundo método, el más sencillo y directo:

dataloader = DataLoader(dataset, batch_size=4, collate_fn=lambda x:x)

Es decir, se devuelve el formato original sin procesar: [ (img_0, lbl_0), (img_1, lbl_1), (img_2, lbl_2), (img_3, lbl_3) ]
si es un dict, el formato devuelto es:[ {'image':img_0, 'label': lbl_0}, {'image':img_1, 'label': lbl_1}, {'image':img_2, 'label': lbl_2}, {'image':img_3, 'label': lbl_3} ]

Luego, después de leer los datos del cargador de datos en el siguiente paso, simplemente cambie la forma en que se procesan los datos.

Pero si no está preparado para cambiar el código posterior (por ejemplo, utilizando el código de otras personas, que ya ha sido escrito y es muy completo), entonces debe implementar una lógica más compleja en el collate_fn personalizado para que los datos sean normales.

Mi solución actual es eliminar los datos que tienen un tamaño diferente de otros datos en el mini lote y devolver los datos restantes.
Específicamente para este ejemplo, eliminé los datos con m < N-1 y no los usé directamente, porque en pairs.txt, dichos datos son muy pequeños y tienen poco impacto en el entrenamiento. El método específico es obtener la longitud de la imagen en el dict para ver si es menor que N. El código es el siguiente:

# src_nums + 1 为N的值,1是指ref图像
src_nums = input_args.num_views
def is_data_ok(data):
    if len(data['images']) < src_nums + 1:
        return False
    else:
        return True

# batch是原始的数据 [ {'image':img_0, 'label': lbl_0}, {'image':img_1, 'label': lbl_1}, {'image':img_2, 'label': lbl_2}, {'image':img_3, 'label': lbl_3} ]
def collate_fn(batch):
	# 对batch中的每个数据进行判断,is_data_ok(batch[i])为true则保留,否则丢弃
    batch_new = list(filter(is_data_ok, batch))
    # default_collate(batch_new): 对于剩下的数据batch_new,按照DataLoader默认的方式进行处理(即上面介绍的扩展维度、合并等)
    return  torch.utils.data.dataloader.default_collate(batch_new)

# 设置drop_last 为 False:假设过滤掉了1个数据,此时返回的数据只有batchsize-1个,drop_last=True会将这样的数据丢弃不要
train_loader = DataLoader(train_dataset, input_args.batch_size, shuffle=True, num_workers=8, drop_last=False, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, input_args.batch_size, shuffle=False, num_workers=4, drop_last=False, collate_fn=collate_fn)

Creo que un método más científico es conservar estos datos y luego combinarlos cuando encuentre datos del mismo tamaño. El código aún no se ha escrito. Lo actualizaré más tarde. Los amigos que sepan cómo implementarlo pueden dar sugerencias.

referencia

solución de error:
http://www.manongjc.com/detail/25-hkwccoiwsencpam.html
https://www.cnblogs.com/vase/p/15354331.html
https://blog.csdn.net/weixin_44799217/article /detalles/115137820

collate_fn:
https://zhuanlan.zhihu.com/p/361830892
https://github.com/pytorch/pytorch/issues/57429
https://github.com/pytorch/pytorch/issues/67419
https://pytorch .org/tutorials/beginner/text_sentiment_ngrams_tutorial.html#generate-data-batch-and-iterator

drop_last:
https://blog.csdn.net/xijuezhu8128/article/details/107954141

lambda:
https://blog.csdn.net/zagfai/article/details/8972618

Supongo que te gusta

Origin blog.csdn.net/qq_41340996/article/details/123156330
Recomendado
Clasificación