Importation et prétraitement des jeux de données d'entraînement et de validation, puis importation et prétraitement de plusieurs images pour la prédiction.
Le code est organisé en fonctions pouvant être appelées directement ; voir le code ci-dessous.
import os
import sys
import json
import PIL.Image as Image
from torch.utils.data import Dataset
import torch
from torchvision import transforms, datasets
from torchvision import transforms
def predata(batch_size, path, method):  # batch_size, dataset root path, split name ("train" or "val") -- the last two are strings
    """Build an ImageFolder dataset and a DataLoader for one dataset split.

    Args:
        batch_size: Number of images per batch.
        path: Dataset root directory; images are read from ``path/method``,
            one sub-directory per class.
        method: Split name, ``"train"`` or ``"val"``. It selects both the
            sub-directory and the preprocessing pipeline; any other value
            raises ``KeyError`` (there is no matching transform).

    Returns:
        ``(num_images, loader)``: the number of images in the split and a
        ``torch.utils.data.DataLoader`` over it. Only the training loader is
        shuffled.

    Raises:
        AssertionError: if ``path`` does not exist (note: asserts are
            skipped under ``python -O``).

    Side effect: when ``method == "train"``, the index -> class-name mapping
    is written to ``class_indices.json`` in the current working directory.
    """
    # "train" applies random crop + horizontal flip augmentation; "val" uses a
    # deterministic resize + center crop. Both normalize with the same
    # per-channel mean/std.
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    assert os.path.exists(path), "{} path does not exist.".format(path)
    dataset = datasets.ImageFolder(root=os.path.join(path, method),
                                   transform=data_transform[method])
    num_images = len(dataset)

    # Only the training run writes the class-index json file.
    if method == "train":
        # ImageFolder derives class_to_idx from the class sub-directory names,
        # e.g. {'daisy': 0, 'dandelion': 1, ...}. Invert it so that a
        # predicted index can be mapped back to its class name.
        idx_to_class = {idx: name for name, idx in dataset.class_to_idx.items()}
        with open('class_indices.json', 'w', encoding='utf-8') as json_file:
            json_file.write(json.dumps(idx_to_class, indent=4))

    # Load in the main process. Worker processes could be enabled with e.g.
    # min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]).
    nw = 0
    print('Using {} dataloader workers every process'.format(nw))
    # Shuffling only matters for training; evaluation results do not depend
    # on sample order, so the validation loader keeps the dataset order.
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=batch_size,
                                         shuffle=(method == "train"),
                                         num_workers=nw)
    return num_images, loader
class MyData(Dataset):
    """Unlabelled image dataset over the files of a single directory.

    Items are returned as RGB PIL images in sorted filename order, so the
    i-th item (and the i-th prediction made on a stacked batch) corresponds
    to ``img_path[i]``.
    """

    def __init__(self, data_dir):
        self.data_dir = data_dir
        # os.listdir returns entries in arbitrary order; sort them so the
        # index -> file mapping is reproducible. Skip sub-directories, which
        # Image.open cannot read.
        self.img_path = sorted(
            name for name in os.listdir(data_dir)
            if os.path.isfile(os.path.join(data_dir, name))
        )

    def __getitem__(self, idx):
        """Return the image at position ``idx`` as an RGB PIL image.

        Usage: ``image = dataset[i]``.
        """
        img_item_path = os.path.join(self.data_dir, self.img_path[idx])
        # Convert to RGB so grayscale/RGBA/palette files yield 3 channels and
        # match a 3-channel Normalize (torchvision's ImageFolder loader does
        # the same). The context manager closes the underlying file once the
        # converted copy exists.
        with Image.open(img_item_path) as img:
            return img.convert("RGB")

    def __len__(self):
        """Return the number of image files found in ``data_dir``."""
        return len(self.img_path)
def read_data(data_dir):  # path of the folder holding the test images (absolute path)
    """Load every image in ``data_dir`` and stack them into a single batch.

    Each image goes through the same resize / center-crop / normalize
    pipeline that ``predata`` uses for the "val" split.

    Args:
        data_dir: Folder containing the images to predict.

    Returns:
        A tensor of shape ``(N, C, 224, 224)``, where N is the number of
        images and C is the number of channels (3 for RGB input). Images
        are in the order given by ``MyData(data_dir)``.

    Raises:
        RuntimeError: if the folder contains no images.
    """
    data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
    dataset = MyData(data_dir)
    num_images = len(dataset)
    if num_images == 0:
        # torch.stack([]) would fail with a less helpful message.
        raise RuntimeError("no images found in {}".format(data_dir))
    image_list = [data_transform(dataset[i]) for i in range(num_images)]
    return torch.stack(image_list, dim=0)
Exemple de prédiction sur plusieurs images :
import os
import json
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
from model_v2 import MobileNetV2
from utils import read_data
def main():
    """Classify every image in a test folder with a trained MobileNetV2 and print the results."""
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Load the index -> class-name mapping written during training.
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)
    with open(json_path, "r", encoding="utf-8") as f:
        class_indict = json.load(f)

    # One output unit per class in the mapping instead of a hard-coded 5.
    model = MobileNetV2(num_classes=len(class_indict)).to(device)
    model_weight_path = "./MobileNetV2.pth"
    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    model.eval()

    # Folder holding the images to predict. Raw string: the previous literal
    # relied on "\d" being an invalid escape sequence (SyntaxWarning on
    # Python 3.12+); the resulting path value is unchanged.
    test_path = r"E:\deep-learning-for-image-processing-master\data_set\flower_data\test"
    with torch.no_grad():
        batch_image = read_data(test_path).to(device)  # (N, C, 224, 224)
        output = model(batch_image).cpu()  # (N, num_classes) raw scores
        print("output:", output)
        print("output.shape:", output.shape)
        predict = torch.softmax(output, dim=1)  # scores -> per-class probabilities
        print("predict:", predict)
        predict_cla = torch.argmax(predict, dim=1).numpy()
        print("predict_cla:", predict_cla)
        for probs, cla in zip(predict, predict_cla):
            print_res = "class: {} prob: {:.3}".format(class_indict[str(cla)],
                                                       probs[cla].item())
            print(print_res)


if __name__ == '__main__':
    main()
Sortie :
output: tensor([[-0.8376, -1.4547, -3.1795, 1.9097, -2.8242],
[-3.0098, -2.2684, -0.4529, -1.6731, 1.9120],
[ 1.7103, -0.2653, -3.0748, -0.6628, -3.2830],
[ 1.2928, -1.5638, -1.8410, -0.9810, -1.4947],
[ 1.7677, -0.7175, -2.3304, -1.0074, -1.3919],
[-1.9958, -3.8992, 0.0284, -1.7568, 0.8421],
[-0.1854, 0.6881, -2.1250, -1.3285, -2.4432],
[-0.7020, 1.9533, -1.7579, -1.8751, -2.7175],
[-3.0090, -3.2183, 0.7938, -1.1972, -0.3066],
[-2.4225, -2.0883, -1.5710, 2.1600, -2.2080],
[-0.9034, -1.3276, -2.8479, 1.5856, -1.3967],
[-1.7756, 2.1748, -3.1034, -0.7712, -2.7011],
[-3.7374, -2.9883, 2.6747, -2.3358, -0.1361],
[-2.7635, -2.7453, 1.8817, -2.3586, 0.4804],
[-2.3963, -2.7185, -0.4796, -1.0483, 1.9857],
[-1.6062, 2.8672, -2.6138, -1.1941, -1.6695],
[-1.7888, 3.5478, -3.5287, -2.1418, -2.6027],
[-2.6447, -1.8472, -0.7186, -1.6111, 1.7937],
[-0.1753, -0.0084, -2.3816, -0.2522, -1.8641],
[ 3.3103, -1.1389, -3.4989, -0.6747, -3.5854],
[-1.7156, -0.7643, -2.4082, 1.9063, -1.4952],
[-1.0604, -2.9224, -0.0667, -0.6217, -0.9351],
[-1.6808, -2.4863, 0.1389, -0.9924, 0.0727],
[-1.9966, -3.1275, 0.6565, -3.1978, 1.8574],
[ 1.0177, -0.5006, -1.8477, -1.2302, -1.9340]])
output.shape: torch.Size([25, 5])
predict: tensor([[5.7556e-02, 3.1053e-02, 5.5337e-03, 8.9796e-01, 7.8949e-03],
[6.3671e-03, 1.3364e-02, 8.2114e-02, 2.4238e-02, 8.7392e-01],
[8.0192e-01, 1.1121e-01, 6.6985e-03, 7.4732e-02, 5.4397e-03],
[7.9019e-01, 4.5409e-02, 3.4414e-02, 8.1326e-02, 4.8659e-02],
[8.3008e-01, 6.9153e-02, 1.3784e-02, 5.1752e-02, 3.5233e-02],
[3.6944e-02, 5.5067e-03, 2.7965e-01, 4.6917e-02, 6.3098e-01],
[2.5236e-01, 6.0451e-01, 3.6280e-02, 8.0462e-02, 2.6393e-02],
[6.2422e-02, 8.8823e-01, 2.1716e-02, 1.9314e-02, 8.3183e-03],
[1.4777e-02, 1.1987e-02, 6.6238e-01, 9.0457e-02, 2.2040e-01],
[9.6397e-03, 1.3465e-02, 2.2587e-02, 9.4236e-01, 1.1946e-02],
[6.9166e-02, 4.5257e-02, 9.8945e-03, 8.3345e-01, 4.2236e-02],
[1.7747e-02, 9.2206e-01, 4.7038e-03, 4.8454e-02, 7.0336e-03],
[1.5315e-03, 3.2391e-03, 9.3288e-01, 6.2206e-03, 5.6124e-02],
[7.5058e-03, 7.6429e-03, 7.8122e-01, 1.1251e-02, 1.9238e-01],
[1.0826e-02, 7.8443e-03, 7.3597e-02, 4.1676e-02, 8.6606e-01],
[1.0932e-02, 9.5831e-01, 3.9915e-03, 1.6508e-02, 1.0261e-02],
[4.7594e-03, 9.8895e-01, 8.3542e-04, 3.3437e-03, 2.1090e-03],
[1.0253e-02, 2.2760e-02, 7.0362e-02, 2.8823e-02, 8.6780e-01],
[2.9389e-01, 3.4729e-01, 3.2363e-02, 2.7216e-01, 5.4294e-02],
[9.6862e-01, 1.1322e-02, 1.0689e-03, 1.8009e-02, 9.8037e-04],
[2.3394e-02, 6.0568e-02, 1.1704e-02, 8.7517e-01, 2.9164e-02],
[1.5289e-01, 2.3754e-02, 4.1299e-01, 2.3707e-01, 1.7330e-01],
[6.5011e-02, 2.9051e-02, 4.0112e-01, 1.2940e-01, 3.7542e-01],
[1.5872e-02, 5.1225e-03, 2.2535e-01, 4.7749e-03, 7.4888e-01],
[6.9739e-01, 1.5279e-01, 3.9725e-02, 7.3661e-02, 3.6438e-02]])
predict_cla: [3 4 0 0 0 4 1 1 2 3 3 1 2 2 4 1 1 4 1 0 3 2 2 4 0]
class: sunflowers prob: 0.898
class: tulips prob: 0.874
class: daisy prob: 0.802
class: daisy prob: 0.79
class: daisy prob: 0.83
class: tulips prob: 0.631
class: dandelion prob: 0.605
class: dandelion prob: 0.888
class: roses prob: 0.662
class: sunflowers prob: 0.942
class: sunflowers prob: 0.833
class: dandelion prob: 0.922
class: roses prob: 0.933
class: roses prob: 0.781
class: tulips prob: 0.866
class: dandelion prob: 0.958
class: dandelion prob: 0.989
class: tulips prob: 0.868
class: dandelion prob: 0.347
class: daisy prob: 0.969
class: sunflowers prob: 0.875
class: roses prob: 0.413
class: roses prob: 0.401
class: tulips prob: 0.749
class: daisy prob: 0.697