Conjunto de datos ❀ modificación del contenido del conjunto de datos cifar10

 Leer datos del conjunto de datos Cifar10

file = 'data\cifar-10-batches-py\data_batch_1'
def unpickle(file):
    fo = open(file, 'rb')
    dict = pickle.load(fo, encoding='latin1')
    fo.close()
    return dict

# 第几张图片
line_number = 0
# 显示测试集图片
dict = unpickle(file)
data = dict.get("data")
label = dict.get("labels")
# 由于在cifar10的data中,图片数据存储为了一个一维数组的形式
# 但其本身是一个rgb三通道的32x32图片,因此我们将其整型如下,并存入r,g,b中
image_m = np.reshape(data[line_number], (3, 32, 32))
image_label = label[line_number]
r = image_m[0, :, :]
g = image_m[1, :, :]
b = image_m[2, :, :]
# 通过将r,g,b合并来输出原图
img32 = np.array(cv.merge([r, g, b]))

plt.figure()
plt.imshow(img32)
plt.show()

Crea los datos que deseas escribir.

Aquí solo hice una imagen img32_compress, que también es una imagen de tres canales, y convertí cada canal en una matriz unidimensional y la almacené en temp, porque en cifar10, la imagen se almacena unidimensionalmente.

    temp_r = np.reshape(img32_compress[:, :, 0], (1024, )).tolist()
    temp_g = np.reshape(img32_compress[:, :, 1], (1024, )).tolist()
    temp_b = np.reshape(img32_compress[:, :, 2], (1024, )).tolist()

 Luego almacenamos estas matrices unidimensionales en las posiciones correspondientes de las imágenes que queremos modificar. args.line_number se refiere al número de filas de datos en el diccionario del conjunto de datos cifar10. En datos, cada fila representa una imagen y el número de filas es el número de filas.

    dict.get("data")[args.line_number,0:1024] = temp_r
    dict.get("data")[args.line_number,1024:2048] = temp_g
    dict.get("data")[args.line_number,2048:3072] = temp_b

 Finalmente, al guardar en un archivo binario, el tipo np.array no es compatible, por lo que necesitamos usar .tolist() para cambiar los datos de np.array a list.  

Finalmente, use pickle.dump para escribir la nueva imagen que creamos en el archivo original f1

dict['data'] = dict['data'].tolist()

f1 = open(file, 'wb+')
pickle.dump(dict, f1)
# f1.write(json.dumps(dict).encode())
f1.close()

Supongo que te gusta

Origin blog.csdn.net/qq_42395917/article/details/127742465
Recomendado
Clasificación