数据集导入预处理和多张图片预测

本文包括训练集和验证集的导入与预处理，以及待预测的多张图片的导入与预处理。
相关代码已整理成函数，可直接调用，见以下代码。

import os
import sys
import json
import PIL.Image as Image
from torch.utils.data import Dataset
import torch
from torchvision import transforms, datasets
from torchvision import transforms

def predata(batch_size, path, method, shuffle=None, num_workers=0,
            json_path='class_indices.json'):
    """Build an ImageFolder dataset and a DataLoader for one split.

    Args:
        batch_size: number of images per batch.
        path: dataset root directory; the split is read from
            ``os.path.join(path, method)``.
        method: split name, ``"train"`` or ``"val"`` (string). It selects both
            the sub-folder and the transform pipeline.
        shuffle: whether the DataLoader shuffles. ``None`` (default) shuffles
            only the ``"train"`` split; shuffling validation data has no benefit.
        num_workers: number of DataLoader worker processes (0 = load in the
            main process, as before).
        json_path: file that receives the index -> class-name mapping
            (written for the ``"train"`` split only).

    Returns:
        Tuple ``(number of images in the split, DataLoader over the split)``.

    Raises:
        ValueError: if ``method`` is not ``"train"`` or ``"val"``.
        FileNotFoundError: if ``path`` does not exist.
    """
    # Train split gets random augmentation; val gets a deterministic resize + center crop.
    # Both normalize with the ImageNet channel mean/std.
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}
    if method not in data_transform:
        raise ValueError("method must be 'train' or 'val', got {!r}".format(method))
    # Raise instead of assert: asserts are stripped when Python runs with -O.
    if not os.path.exists(path):
        raise FileNotFoundError("{} path does not exist.".format(path))

    dataset = datasets.ImageFolder(root=os.path.join(path, method),
                                   transform=data_transform[method])
    data_num = len(dataset)

    # Only the training run writes the label map; the prediction script reads it back.
    if method == "train":
        # class_to_idx maps folder name -> index, e.g. {'daisy': 0, 'dandelion': 1, ...};
        # invert it so a predicted index can be looked up as a class name.
        cla_dict = {idx: name for name, idx in dataset.class_to_idx.items()}
        with open(json_path, 'w', encoding='utf-8') as json_file:
            json.dump(cla_dict, json_file, indent=4)

    if shuffle is None:
        shuffle = method == "train"
    print('Using {} dataloader workers every process'.format(num_workers))
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=batch_size, shuffle=shuffle,
                                         num_workers=num_workers)
    return data_num, loader

class MyData(Dataset):
    """Minimal dataset over the image files located directly inside ``data_dir``.

    ``dataset[i]`` returns an un-transformed PIL image converted to RGB; the
    caller (``read_data``) applies the tensor transform.
    """

    # Same extensions torchvision's ImageFolder accepts (used by predata), so stray
    # files such as Thumbs.db / .DS_Store and sub-folders are never handed to PIL.
    IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm",
                      ".tif", ".tiff", ".webp")

    def __init__(self, data_dir):
        self.data_dir = data_dir
        # os.listdir order is arbitrary and platform dependent; sort it so index i
        # always refers to the same file and predictions can be matched to names.
        self.img_path = sorted(
            name for name in os.listdir(data_dir)
            if name.lower().endswith(self.IMG_EXTENSIONS)
            and os.path.isfile(os.path.join(data_dir, name))
        )

    def __getitem__(self, idx):
        """Return the idx-th image (sorted by file name) as an RGB PIL image."""
        img_item_path = os.path.join(self.data_dir, self.img_path[idx])
        # Grayscale / RGBA / palette images would otherwise yield 1- or 4-channel
        # tensors that break a 3-channel Normalize. convert() loads the pixels,
        # so the file handle can be closed right away.
        with Image.open(img_item_path) as img:
            return img.convert("RGB")

    def __len__(self):
        return len(self.img_path)

def read_data(data_dir, transform=None):
    """Load every image in ``data_dir`` and stack the images into one batch tensor.

    Args:
        data_dir: folder containing the images to predict (an absolute path in
            the original usage).
        transform: optional callable mapping a PIL image to a tensor. Defaults to
            Resize(256) -> CenterCrop(224) -> ToTensor -> ImageNet Normalize,
            the same pipeline ``predata`` uses for the "val" split.

    Returns:
        Tensor of shape (N, 3, 224, 224) with the default transform, where N is
        the number of images ``MyData`` lists for ``data_dir``.

    Raises:
        ValueError: if no images are found (torch.stack cannot stack an empty list).
    """
    if transform is None:
        transform = transforms.Compose(
            [transforms.Resize(256),
             transforms.CenterCrop(224),
             transforms.ToTensor(),
             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
    dataset = MyData(data_dir)
    # .convert("RGB") guards against grayscale/RGBA files, which would give tensors
    # with the wrong channel count for Normalize and torch.stack.
    images = [transform(dataset[i].convert("RGB")) for i in range(len(dataset))]
    if not images:
        raise ValueError("no images found in {}".format(data_dir))
    return torch.stack(images, dim=0)

多张图片预测示例：

import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model_v2 import MobileNetV2
from utils import read_data

def main():
    """Classify every image in a folder with a trained MobileNetV2 and print the results.

    Reads the index -> class-name map from ./class_indices.json, loads weights
    from ./MobileNetV2.pth, runs one batched forward pass over all images returned
    by ``read_data`` and prints each image's predicted class and softmax probability.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    json_path = './class_indices.json'  # label map written by predata(..., "train")
    assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)

    with open(json_path, "r", encoding="utf-8") as f:
        class_indict = json.load(f)

    # Size the classifier head from the label map instead of hard-coding 5, so the
    # script keeps working for datasets with a different number of classes.
    model = MobileNetV2(num_classes=len(class_indict)).to(device)
    model_weight_path = "./MobileNetV2.pth"
    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    model.eval()

    # Raw string: the original literal mixed "\d" (an invalid escape, warned about
    # by newer Pythons) with "\\f"; the resulting path value is identical.
    test_path = r"E:\deep-learning-for-image-processing-master\data_set\flower_data\test"
    with torch.no_grad():
        batch_image = read_data(test_path).to(device)  # (N, 3, 224, 224)
        output = model(batch_image).cpu()  # raw scores, presumably (N, num_classes)
        print("output:", output)
        print("output.shape:", output.shape)
        predict = torch.softmax(output, dim=1)  # scores -> probabilities per image
        print("predict:", predict)
        predict_cla = torch.argmax(predict, dim=1).numpy()

    print("predict_cla:", predict_cla)

    # JSON object keys are strings, hence str(cla) for the lookup.
    for probs, cla in zip(predict, predict_cla):
        print_res = "class: {}   prob: {:.3}".format(class_indict[str(cla)],
                                                     probs[cla].numpy())
        print(print_res)


if __name__ == '__main__':
    main()

输出:

output: tensor([[-0.8376, -1.4547, -3.1795,  1.9097, -2.8242],
        [-3.0098, -2.2684, -0.4529, -1.6731,  1.9120],
        [ 1.7103, -0.2653, -3.0748, -0.6628, -3.2830],
        [ 1.2928, -1.5638, -1.8410, -0.9810, -1.4947],
        [ 1.7677, -0.7175, -2.3304, -1.0074, -1.3919],
        [-1.9958, -3.8992,  0.0284, -1.7568,  0.8421],
        [-0.1854,  0.6881, -2.1250, -1.3285, -2.4432],
        [-0.7020,  1.9533, -1.7579, -1.8751, -2.7175],
        [-3.0090, -3.2183,  0.7938, -1.1972, -0.3066],
        [-2.4225, -2.0883, -1.5710,  2.1600, -2.2080],
        [-0.9034, -1.3276, -2.8479,  1.5856, -1.3967],
        [-1.7756,  2.1748, -3.1034, -0.7712, -2.7011],
        [-3.7374, -2.9883,  2.6747, -2.3358, -0.1361],
        [-2.7635, -2.7453,  1.8817, -2.3586,  0.4804],
        [-2.3963, -2.7185, -0.4796, -1.0483,  1.9857],
        [-1.6062,  2.8672, -2.6138, -1.1941, -1.6695],
        [-1.7888,  3.5478, -3.5287, -2.1418, -2.6027],
        [-2.6447, -1.8472, -0.7186, -1.6111,  1.7937],
        [-0.1753, -0.0084, -2.3816, -0.2522, -1.8641],
        [ 3.3103, -1.1389, -3.4989, -0.6747, -3.5854],
        [-1.7156, -0.7643, -2.4082,  1.9063, -1.4952],
        [-1.0604, -2.9224, -0.0667, -0.6217, -0.9351],
        [-1.6808, -2.4863,  0.1389, -0.9924,  0.0727],
        [-1.9966, -3.1275,  0.6565, -3.1978,  1.8574],
        [ 1.0177, -0.5006, -1.8477, -1.2302, -1.9340]])
output.shape: torch.Size([25, 5])
predict: tensor([[5.7556e-02, 3.1053e-02, 5.5337e-03, 8.9796e-01, 7.8949e-03],
        [6.3671e-03, 1.3364e-02, 8.2114e-02, 2.4238e-02, 8.7392e-01],
        [8.0192e-01, 1.1121e-01, 6.6985e-03, 7.4732e-02, 5.4397e-03],
        [7.9019e-01, 4.5409e-02, 3.4414e-02, 8.1326e-02, 4.8659e-02],
        [8.3008e-01, 6.9153e-02, 1.3784e-02, 5.1752e-02, 3.5233e-02],
        [3.6944e-02, 5.5067e-03, 2.7965e-01, 4.6917e-02, 6.3098e-01],
        [2.5236e-01, 6.0451e-01, 3.6280e-02, 8.0462e-02, 2.6393e-02],
        [6.2422e-02, 8.8823e-01, 2.1716e-02, 1.9314e-02, 8.3183e-03],
        [1.4777e-02, 1.1987e-02, 6.6238e-01, 9.0457e-02, 2.2040e-01],
        [9.6397e-03, 1.3465e-02, 2.2587e-02, 9.4236e-01, 1.1946e-02],
        [6.9166e-02, 4.5257e-02, 9.8945e-03, 8.3345e-01, 4.2236e-02],
        [1.7747e-02, 9.2206e-01, 4.7038e-03, 4.8454e-02, 7.0336e-03],
        [1.5315e-03, 3.2391e-03, 9.3288e-01, 6.2206e-03, 5.6124e-02],
        [7.5058e-03, 7.6429e-03, 7.8122e-01, 1.1251e-02, 1.9238e-01],
        [1.0826e-02, 7.8443e-03, 7.3597e-02, 4.1676e-02, 8.6606e-01],
        [1.0932e-02, 9.5831e-01, 3.9915e-03, 1.6508e-02, 1.0261e-02],
        [4.7594e-03, 9.8895e-01, 8.3542e-04, 3.3437e-03, 2.1090e-03],
        [1.0253e-02, 2.2760e-02, 7.0362e-02, 2.8823e-02, 8.6780e-01],
        [2.9389e-01, 3.4729e-01, 3.2363e-02, 2.7216e-01, 5.4294e-02],
        [9.6862e-01, 1.1322e-02, 1.0689e-03, 1.8009e-02, 9.8037e-04],
        [2.3394e-02, 6.0568e-02, 1.1704e-02, 8.7517e-01, 2.9164e-02],
        [1.5289e-01, 2.3754e-02, 4.1299e-01, 2.3707e-01, 1.7330e-01],
        [6.5011e-02, 2.9051e-02, 4.0112e-01, 1.2940e-01, 3.7542e-01],
        [1.5872e-02, 5.1225e-03, 2.2535e-01, 4.7749e-03, 7.4888e-01],
        [6.9739e-01, 1.5279e-01, 3.9725e-02, 7.3661e-02, 3.6438e-02]])
predict_cla: [3 4 0 0 0 4 1 1 2 3 3 1 2 2 4 1 1 4 1 0 3 2 2 4 0]
class: sunflowers   prob: 0.898
class: tulips   prob: 0.874
class: daisy   prob: 0.802
class: daisy   prob: 0.79
class: daisy   prob: 0.83
class: tulips   prob: 0.631
class: dandelion   prob: 0.605
class: dandelion   prob: 0.888
class: roses   prob: 0.662
class: sunflowers   prob: 0.942
class: sunflowers   prob: 0.833
class: dandelion   prob: 0.922
class: roses   prob: 0.933
class: roses   prob: 0.781
class: tulips   prob: 0.866
class: dandelion   prob: 0.958
class: dandelion   prob: 0.989
class: tulips   prob: 0.868
class: dandelion   prob: 0.347
class: daisy   prob: 0.969
class: sunflowers   prob: 0.875
class: roses   prob: 0.413
class: roses   prob: 0.401
class: tulips   prob: 0.749
class: daisy   prob: 0.697

猜你喜欢

转载自blog.csdn.net/weixin_44040169/article/details/128006681
今日推荐