La cantidad de datos es demasiado grande para cargarlos en la memoria para entrenarlos al mismo tiempo.

pregunta:

Se refiere principalmente al campo médico, especialmente a la segmentación 3D. Ya que puede haber cientos de dicoms para cada dato. Al dividirlo con un parche, llevará mucho tiempo recargar un conjunto de datos tan grande cada vez. Cargue los datos y las etiquetas en la memoria al mismo tiempo, lo que requiere de decenas a cientos de gigabytes de memoria. Obviamente, una memoria tan grande no existe en general.

Solución:

El uso de iteradores de Python puede resolver perfectamente el problema de que los datos no se pueden cargar al mismo tiempo.

Un iterador es un objeto que recuerda dónde atravesar.

Se accede a los objetos iteradores desde el primer elemento de la colección hasta que se haya accedido a todos los elementos. Los iteradores sólo pueden avanzar y no retroceder.

Los iteradores tienen dos métodos básicos: iter() y next().

Reescribir el conjunto de datos en Pytorch.

Idea básica: dividir los datos en n partes, cargar una parte para el entrenamiento cada vez y cuando se complete el entrenamiento. Elimina la memoria y carga una nueva copia de datos.

import numpy as np
from torch.utils.data import DataLoader,Dataset


class train_test_dataset(Dataset):
    def __init__(self,data,fen):
        self.data = data
        self.index = iter(np.arange(self.data.shape[0]))
        self.jj = 0
        self.fen = fen


    def datagen(self):
        print("datagen")
        data2 = self.data[self.m]
        return data2




    def __getitem__(self, item):
        hh = self.data_gen1[item]
        return hh


    def __len__(self):
        if self.jj<self.fen-1:
            print("len")
            self.m = next(self.index)
            self.data_gen1 = self.datagen()
            #文件加载要在shuffle前完成
            np.random.shuffle(self.data_gen1)


            print(self.data_gen1.shape[0])
            self.jj+=1
            return self.data_gen1.shape[0]
        else:
            self.jj = 0
            print("len")
            self.m = next(self.index)
            self.data_gen1 = self.datagen()
            # 文件加载要在shuffle前完成
            np.random.shuffle(self.data_gen1)


            print(self.data_gen1.shape[0])
            self.index = iter(np.arange(self.data.shape[0]))
            return self.data_gen1.shape[0]


if __name__ == "__main__":
    data = np.arange(80).reshape(4,20)
    fen = 4
    test_data = train_test_dataset(data,fen)
    datalader = DataLoader(test_data, batch_size=3,)
    for epoch in range(51):
        for i in range(fen):#分为4份
            for _,step in enumerate(datalader):
                print(step)

Ejemplo real (no sé cuántos casos cargar, solo cargue todos los casos usados ​​por primera vez. Elimínelo más tarde)

def load_data(self):
        #加载部分数据并且返回(将数据分为10份


        if self.m<len(self.every_num_lis)-1:
            temp_volume_file = self.volume_file[self.every_num_lis[self.m]:self.every_num_lis[self.m+1]]
            temp_segmentation_file = self.segmentation_file[self.every_num_lis[self.m]:self.every_num_lis[self.m+1]]
            temp_startpoint_list = self.startpoint_list[self.every_num_lis[self.m]:self.every_num_lis[self.m+1]]
            temp_endpoint_list = self.endpoint_list[self.every_num_lis[self.m]:self.every_num_lis[self.m+1]]
            temp_pad_list = self.pad_list[self.every_num_lis[self.m]:self.every_num_lis[self.m+1]]


        else:#当取到最后一轮数据的时候
            temp_volume_file = self.volume_file[self.every_num_lis[self.m]:]
            temp_segmentation_file = self.segmentation_file[self.every_num_lis[self.m]:]
            temp_startpoint_list = self.startpoint_list[self.every_num_lis[self.m]:]
            temp_endpoint_list = self.endpoint_list[self.every_num_lis[self.m]:]
            temp_pad_list = self.pad_list[self.every_num_lis[self.m]:]


        #清空之前的数据
        global_key = []
        for key, value in globals().items():
            if "vvolume_" in key or "ssegmentation_" in key:
                global_key.append(key)
        for i in global_key:
            del globals()[i]




        name_cv = "hh"
        name2 = 0
        filename_label = []
        for i in range(len(temp_volume_file)):
            if name_cv != temp_volume_file[i]:
                name_cv = temp_volume_file[i]
                name2 += 1
                print(name2)
                globals()["vvolume_" + str(name2)] = sitk.GetArrayFromImage(sitk.ReadImage(name_cv, sitk.sitkInt16))
                globals()["ssegmentation_" + str(name2)] = sitk.GetArrayFromImage(sitk.ReadImage(temp_segmentation_file[i], sitk.sitkUInt8))
            filename_label.append(name2)


        #shuffle
        state = np.random.get_state()
        np.random.shuffle(filename_label)
        np.random.set_state(state)
        np.random.shuffle(temp_startpoint_list)
        np.random.set_state(state)
        np.random.shuffle(temp_endpoint_list)
        np.random.set_state(state)
        np.random.shuffle(temp_pad_list)
        return filename_label,temp_startpoint_list,temp_endpoint_list,temp_pad_list

Supongo que te gusta

Origin blog.csdn.net/weixin_41202834/article/details/121173754
Recomendado
Clasificación