Requisitos: hay 23 categorías en la segmentación semántica. Algunas categorías deben eliminarse. Las categorías eliminadas no se pasarán a la pérdida para el cálculo. Sin embargo, de forma predeterminada, se utiliza codificación one-hot durante el proceso de cálculo de la pérdida. La eliminación de categorías afectará codificación one-hot.clasificación.
Lo que debe hacer en secuencia: definir la categoría ignorada como -1, establecer una tabla de mapeo de valores de etiqueta nuevos y antiguos, modificar el valor de la etiqueta a través de la tabla de mapeo, filtrar el valor -1 y luego generar un nuevo one-hot codificación.
Lo que debe modificarse: el número de categorías, valores modificados y máscaras no son en realidad para el tensor1 que creamos, sino para el tensor de etiqueta real (por ejemplo, la dimensión es 1xn). Según la tabla de mapeo y la máscara, La etiqueta de 1xn se puede procesar, lo manejaré aquí, tensor1 es para facilitar la visualización.
El código se muestra a continuación:
import torch
import torch.nn.functional as F
if __name__ == '__main__':
# define 23 classes
tensor1 = torch.arange(0,23)
print(tensor1)
del_class = [4, 6, 8]
# edit del_class to -1
for i in range(len(tensor1)):
if tensor1[i] in del_class:
tensor1[i] = -1
# edit mapping dict
value = torch.unique(tensor1)
mapping = {}
for index in range(len(value)):
if value[index] == -1:
continue
old_value = int(value[index])
new_value = index - 1
mapping[old_value] = new_value
print(tensor1)
print(mapping)
# edit tensor value according to mapping_dict
for i in range(len(tensor1)):
if int(tensor1[i]) in mapping.keys():
new_value = mapping[int(tensor1[i])]
tensor1[i] = new_value
print(tensor1)
# mask -1 value
mask = (tensor1 != -1)
# generate one-hot value
tensor1_one = F.one_hot(tensor1[mask])
print(tensor1_one)