Dataset import/preprocessing and prediction on multiple images

Import and preprocess the training and validation datasets, and preprocess a batch of images for prediction. Everything is organized into functions that can be called directly; see the code below.

import os
import json
import PIL.Image as Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms, datasets

def predata(batch_size, path, method):  # batch_size; dataset root path; "train" or "val" (the latter two are strings)
    # image preprocessing pipelines
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}
    # load the split with torchvision's ImageFolder from <path>/<method>
    assert os.path.exists(path), "{} path does not exist.".format(path)
    dataset = datasets.ImageFolder(root=os.path.join(path, method),
                                   transform=data_transform[method])
    # number of images in the split
    data_num = len(dataset)

    # the class-index json file is generated only when loading the training set
    if method == "train":
        # class-name-to-index mapping derived from the class subfolder names,
        # e.g. {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
        flower_list = dataset.class_to_idx
        # invert the mapping so a class name can be looked up by its index
        cla_dict = dict((val, key) for key, val in flower_list.items())
        # write the dict into a json file
        json_str = json.dumps(cla_dict, indent=4)
        with open('class_indices.json', 'w') as json_file:
            json_file.write(json_str)
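        # For the 5-class flower dataset, the resulting class_indices.json looks like:
        # {
        #     "0": "daisy",
        #     "1": "dandelion",
        #     "2": "roses",
        #     "3": "sunflowers",
        #     "4": "tulips"
        # }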

    # nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    nw = 0  # 0 loads the data in the main process
    print('Using {} dataloader workers per process'.format(nw))
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=batch_size,
                                         shuffle=True,
                                         num_workers=nw)
    return data_num, loader

class MyData(Dataset):
    """Minimal Dataset that yields the raw PIL images found in one folder."""
    def __init__(self, data_dir):
        self.data_dir = data_dir
        self.img_path = os.listdir(self.data_dir)

    def __getitem__(self, idx):  # usage: image = dataset[i] fetches the image at index i
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.data_dir, img_name)
        img = Image.open(img_item_path).convert('RGB')  # force 3 channels so Normalize never fails
        return img

    def __len__(self):
        return len(self.img_path)

def read_data(data_dir):  # absolute path of the folder containing the test images
    data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
    dataset = MyData(data_dir)
    image_list = []
    for i in range(len(dataset)):  # iterate over every image
        image = dataset[i]         # fetch the image at index i
        image = data_transform(image)
        image_list.append(image)
    # stack the per-image tensors into a single batch of shape (N, 3, 224, 224)
    batch_image = torch.stack(image_list, dim=0)
    return batch_image
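
For reference, a minimal sketch of how these helpers can be called (the dataset root below is a placeholder; point it at your own flower_data directory, and adjust the module name if the functions live somewhere other than utils.py):

import os
from utils import predata, read_data

data_root = r"E:\deep-learning-for-image-processing-master\data_set\flower_data"  # placeholder path
train_num, train_loader = predata(batch_size=16, path=data_root, method="train")
val_num, val_loader = predata(batch_size=16, path=data_root, method="val")
print("using {} images for training, {} for validation".format(train_num, val_num))

batch_image = read_data(os.path.join(data_root, "test"))  # tensor of shape (N, 3, 224, 224)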

Example of predicting multiple images:

import os
import json

import torch

from model_v2 import MobileNetV2
from utils import read_data

def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # load the index-to-label mapping produced during training
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)

    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create model
    model = MobileNetV2(num_classes=5).to(device)
    # load model weights
    model_weight_path = "./MobileNetV2.pth"
    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    model.eval()
    # path of the folder holding the images to predict
    test_path = r"E:\deep-learning-for-image-processing-master\data_set\flower_data\test"
    with torch.no_grad():
        # run the whole batch through the network in one forward pass
        batch_image = read_data(test_path).to(device)  # (25, 3, 224, 224)
        output = model(batch_image).cpu()              # (25, 5) raw scores
        print("output:", output)
        print("output.shape:", output.shape)
        predict = torch.softmax(output, dim=1)  # turn the scores into per-class probabilities
        print("predict:", predict)
        predict_cla = torch.argmax(predict, dim=1).numpy()  # most likely class index per image

    print("predict_cla:", predict_cla)

    # print the predicted class and its probability for every image in the batch
    for i in range(len(predict_cla)):
        print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla[i])],
                                                     predict[i][predict_cla[i]].numpy())
        print(print_res)



if __name__ == '__main__':
    main()

Running the script on a folder of 25 test images prints:

output: tensor([[-0.8376, -1.4547, -3.1795,  1.9097, -2.8242],
        [-3.0098, -2.2684, -0.4529, -1.6731,  1.9120],
        [ 1.7103, -0.2653, -3.0748, -0.6628, -3.2830],
        [ 1.2928, -1.5638, -1.8410, -0.9810, -1.4947],
        [ 1.7677, -0.7175, -2.3304, -1.0074, -1.3919],
        [-1.9958, -3.8992,  0.0284, -1.7568,  0.8421],
        [-0.1854,  0.6881, -2.1250, -1.3285, -2.4432],
        [-0.7020,  1.9533, -1.7579, -1.8751, -2.7175],
        [-3.0090, -3.2183,  0.7938, -1.1972, -0.3066],
        [-2.4225, -2.0883, -1.5710,  2.1600, -2.2080],
        [-0.9034, -1.3276, -2.8479,  1.5856, -1.3967],
        [-1.7756,  2.1748, -3.1034, -0.7712, -2.7011],
        [-3.7374, -2.9883,  2.6747, -2.3358, -0.1361],
        [-2.7635, -2.7453,  1.8817, -2.3586,  0.4804],
        [-2.3963, -2.7185, -0.4796, -1.0483,  1.9857],
        [-1.6062,  2.8672, -2.6138, -1.1941, -1.6695],
        [-1.7888,  3.5478, -3.5287, -2.1418, -2.6027],
        [-2.6447, -1.8472, -0.7186, -1.6111,  1.7937],
        [-0.1753, -0.0084, -2.3816, -0.2522, -1.8641],
        [ 3.3103, -1.1389, -3.4989, -0.6747, -3.5854],
        [-1.7156, -0.7643, -2.4082,  1.9063, -1.4952],
        [-1.0604, -2.9224, -0.0667, -0.6217, -0.9351],
        [-1.6808, -2.4863,  0.1389, -0.9924,  0.0727],
        [-1.9966, -3.1275,  0.6565, -3.1978,  1.8574],
        [ 1.0177, -0.5006, -1.8477, -1.2302, -1.9340]])
output.shape: torch.Size([25, 5])
predict: tensor([[5.7556e-02, 3.1053e-02, 5.5337e-03, 8.9796e-01, 7.8949e-03],
        [6.3671e-03, 1.3364e-02, 8.2114e-02, 2.4238e-02, 8.7392e-01],
        [8.0192e-01, 1.1121e-01, 6.6985e-03, 7.4732e-02, 5.4397e-03],
        [7.9019e-01, 4.5409e-02, 3.4414e-02, 8.1326e-02, 4.8659e-02],
        [8.3008e-01, 6.9153e-02, 1.3784e-02, 5.1752e-02, 3.5233e-02],
        [3.6944e-02, 5.5067e-03, 2.7965e-01, 4.6917e-02, 6.3098e-01],
        [2.5236e-01, 6.0451e-01, 3.6280e-02, 8.0462e-02, 2.6393e-02],
        [6.2422e-02, 8.8823e-01, 2.1716e-02, 1.9314e-02, 8.3183e-03],
        [1.4777e-02, 1.1987e-02, 6.6238e-01, 9.0457e-02, 2.2040e-01],
        [9.6397e-03, 1.3465e-02, 2.2587e-02, 9.4236e-01, 1.1946e-02],
        [6.9166e-02, 4.5257e-02, 9.8945e-03, 8.3345e-01, 4.2236e-02],
        [1.7747e-02, 9.2206e-01, 4.7038e-03, 4.8454e-02, 7.0336e-03],
        [1.5315e-03, 3.2391e-03, 9.3288e-01, 6.2206e-03, 5.6124e-02],
        [7.5058e-03, 7.6429e-03, 7.8122e-01, 1.1251e-02, 1.9238e-01],
        [1.0826e-02, 7.8443e-03, 7.3597e-02, 4.1676e-02, 8.6606e-01],
        [1.0932e-02, 9.5831e-01, 3.9915e-03, 1.6508e-02, 1.0261e-02],
        [4.7594e-03, 9.8895e-01, 8.3542e-04, 3.3437e-03, 2.1090e-03],
        [1.0253e-02, 2.2760e-02, 7.0362e-02, 2.8823e-02, 8.6780e-01],
        [2.9389e-01, 3.4729e-01, 3.2363e-02, 2.7216e-01, 5.4294e-02],
        [9.6862e-01, 1.1322e-02, 1.0689e-03, 1.8009e-02, 9.8037e-04],
        [2.3394e-02, 6.0568e-02, 1.1704e-02, 8.7517e-01, 2.9164e-02],
        [1.5289e-01, 2.3754e-02, 4.1299e-01, 2.3707e-01, 1.7330e-01],
        [6.5011e-02, 2.9051e-02, 4.0112e-01, 1.2940e-01, 3.7542e-01],
        [1.5872e-02, 5.1225e-03, 2.2535e-01, 4.7749e-03, 7.4888e-01],
        [6.9739e-01, 1.5279e-01, 3.9725e-02, 7.3661e-02, 3.6438e-02]])
predict_cla: [3 4 0 0 0 4 1 1 2 3 3 1 2 2 4 1 1 4 1 0 3 2 2 4 0]
class: sunflowers   prob: 0.898
class: tulips   prob: 0.874
class: daisy   prob: 0.802
class: daisy   prob: 0.79
class: daisy   prob: 0.83
class: tulips   prob: 0.631
class: dandelion   prob: 0.605
class: dandelion   prob: 0.888
class: roses   prob: 0.662
class: sunflowers   prob: 0.942
class: sunflowers   prob: 0.833
class: dandelion   prob: 0.922
class: roses   prob: 0.933
class: roses   prob: 0.781
class: tulips   prob: 0.866
class: dandelion   prob: 0.958
class: dandelion   prob: 0.989
class: tulips   prob: 0.868
class: dandelion   prob: 0.347
class: daisy   prob: 0.969
class: sunflowers   prob: 0.875
class: roses   prob: 0.413
class: roses   prob: 0.401
class: tulips   prob: 0.749
class: daisy   prob: 0.697
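
The console printout above can be hard to match back to the actual images. As a purely illustrative extension (not part of the original script; show_predictions is a hypothetical helper), the results could also be displayed with matplotlib, re-opening each test image and titling it with its predicted class:

import os
import matplotlib.pyplot as plt
from PIL import Image

def show_predictions(test_path, class_indict, predict, predict_cla):
    # os.listdir is used so the file order matches what MyData saw inside read_data
    for i, name in enumerate(os.listdir(test_path)):
        img = Image.open(os.path.join(test_path, name))
        plt.imshow(img)
        plt.title("class: {}   prob: {:.3f}".format(class_indict[str(predict_cla[i])],
                                                    float(predict[i][predict_cla[i]])))
        plt.axis("off")
        plt.show()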

Original post: blog.csdn.net/weixin_44040169/article/details/128006681