# 数据集无损处理 (lossless dataset processing)

import numpy as np
import os,glob
from PIL import Image
import pandas as pd
import scipy.io as sio
class Dataset:
    """Cut the images of a distortion dataset into labelled square patches.

    Images are read from ``<cwd>/data/<class>/*.bmp`` for each distortion
    class in ``self.classes``. Each image is paired with one value of the
    ``dmos`` variable loaded from a MATLAB ``.mat`` file, and every patch cut
    from that image carries the image's label.
    """

    def __init__(self):
        # Distortion sub-directories under <cwd>/data/, visited in this order.
        self.classes = ['jp2k', 'jpeg', 'wn', 'gblur', 'fastfading']
        self.cwd = os.getcwd()
        # NOTE(review): `dataset` and `arr` are not read inside this class;
        # kept in case outside code relies on them.
        self.dataset = []
        # Running image counter; also the row index into the dmos label array.
        self.label_index = 0
        self.arr = [[]]
        # Kept for backward compatibility; no longer used as scratch space.
        self.img_list = [[[[]]]]
        # Result of the most recent img2array() call.
        self.img_label = []

    def crop_concat(self, img, window, stride, labels):
        """Cut ``img`` into square patches and pair each patch with ``labels``.

        Patches are ``window`` x ``window`` pixels with origins every
        ``stride`` pixels; only patches lying fully inside the image are kept.
        Order: all vertical offsets for the first horizontal offset, then the
        next horizontal offset, and so on.

        Args:
            img: image with ``.size == (width, height)`` and a ``.crop`` method
                taking a ``(left, upper, right, lower)`` box (e.g. a PIL image).
            window: patch side length in pixels; must be positive.
            stride: step between neighbouring patch origins; must be positive.
            labels: object stored next to every patch (shared, not copied).

        Returns:
            An ``(n_patches, 2)`` object ``ndarray`` whose rows are
            ``[patch_array, labels]``. ``n_patches`` is 0 when the image is
            smaller than ``window`` in either dimension.

        Raises:
            ValueError: if ``window`` or ``stride`` is not positive.
        """
        if window <= 0 or stride <= 0:
            raise ValueError(
                f"window and stride must be positive, got {window} and {stride}")
        width, height = img.size
        # Last origin is at most (size - window), so every patch fits; the
        # ranges are empty when the image is smaller than the window.
        xs = range(0, width - window + 1, stride)
        ys = range(0, height - window + 1, stride)
        patches = [np.asarray(img.crop((x, y, x + window, y + window)))
                   for x in xs for y in ys]
        # Fill the object array cell by cell: np.array([patch, labels]) on
        # differently shaped arrays is "ragged" and raises on NumPy >= 1.24.
        img_array = np.empty((len(patches), 2), dtype=object)
        for row, patch in enumerate(patches):
            img_array[row, 0] = patch
            img_array[row, 1] = labels
        return img_array

    @staticmethod
    def _natural_key(path):
        """Sort key ordering file names numerically, e.g. 'img2' before 'img10'."""
        import re
        parts = re.split(r'(\d+)', os.path.basename(path))
        # re.split with a capture group puts the digit runs at odd indices.
        parts[1::2] = [int(p) for p in parts[1::2]]
        return parts

    def img2array(self, mat_path='/home/xm/data/dmos.mat', window=120, stride=60):
        """Load the dmos labels and cut every dataset image into labelled patches.

        Images are visited class by class in ``self.classes`` order and,
        within a class, in natural file-name order. The k-th image visited
        gets row k of the ``dmos`` array as its label.

        Args:
            mat_path: MATLAB file containing a ``dmos`` variable (defaults to
                the path that used to be hard-coded here).
            window: patch side length forwarded to ``crop_concat``.
            stride: patch step forwarded to ``crop_concat``.

        Returns:
            The concatenated ``(total_patches, 2)`` object array of
            ``[patch, label]`` rows; also stored in ``self.img_label``.

        Raises:
            IndexError: if there are more images than dmos values.
        """
        with open(mat_path, 'rb') as f:
            labelset = sio.loadmat(f)['dmos']
        # One label per row (previously hard-coded as exactly 982 rows).
        labelset = labelset.reshape(-1, 1)
        # Restart numbering so repeated calls label images from row 0 again.
        self.label_index = 0
        chunks = []
        for name in self.classes:
            class_path = os.path.join(self.cwd, 'data', name)
            # glob() returns files in arbitrary order, but labels are assigned
            # by position, so sort. NOTE(review): assumes the numbered file
            # names follow the same order as the dmos vector — confirm.
            files = sorted(glob.glob(os.path.join(class_path, '*.bmp')),
                           key=self._natural_key)
            for infile in files:
                if self.label_index >= len(labelset):
                    raise IndexError(
                        f"more images than dmos labels ({len(labelset)}) in {mat_path}")
                with Image.open(infile) as img:
                    img_info = self.crop_concat(
                        img, window=window, stride=stride,
                        labels=labelset[self.label_index])
                print(img_info.shape)
                chunks.append(img_info)
                self.label_index += 1

        # Concatenate once at the end instead of re-copying after every image.
        if chunks:
            self.img_label = np.concatenate(chunks, axis=0)
        else:
            self.img_label = np.empty((0, 2), dtype=object)
        return self.img_label
def drop_zero_label_rows(data):
    """Return the rows of a ``(N, 2)`` ``[patch, label]`` array whose label is non-zero.

    A row is dropped when its label compares equal to 0. A one-element
    label array such as ``array([0.])`` counts as zero. An empty input
    yields an empty ``(0, 2)`` result.
    """
    keep = np.array([not bool(np.all(label == 0)) for label in data[:, 1]],
                    dtype=bool)
    return data[keep]


if __name__ == "__main__":
    # Guarded so that importing this module does not run the whole pipeline
    # or open an image viewer.
    haha = Dataset()
    dataset = haha.img2array()
    print(dataset.shape)
    dataset = drop_zero_label_rows(dataset)
    np.random.shuffle(dataset)
    print(dataset.shape)
    if len(dataset) > 100:
        # Let Pillow infer the mode from the patch's shape/dtype instead of
        # forcing 'RGB' (which breaks on single-channel patches).
        show_img = Image.fromarray(np.asarray(dataset[100][0]))
        show_img.show()
    else:
        print("fewer than 101 patches; nothing to preview")

# 猜你喜欢 (You may also like)

# 转载自 (reposted from): blog.csdn.net/baidu_36161077/article/details/75173155