Computer Vision Image Classification [No. 1]: Training VGG16 on Your Own Dataset

This post covers an image classification task. Computer vision currently has three basic tasks: image classification, object detection, and image segmentation. Earlier posts on this blog focused on research and hands-on experience with detection tasks; the upcoming posts turn to image classification. If you are interested, stay tuned.

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import os, PIL
import numpy as np


# In[2]:


import torchvision
from torchvision import transforms, datasets
import pathlib, warnings

warnings.filterwarnings("ignore")             # Suppress warning messages

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device


# In[3]:


import random


# In[6]:


data_dir = 'E:/chaye/Tea_Leaf_Disease/'
data_dir = pathlib.Path(data_dir)

# Each subdirectory of data_dir holds the images of one class,
# so the directory name is the class name. path.name avoids a
# hard-coded split index, which raises IndexError when it does
# not match the depth of the path.
data_paths  = list(data_dir.glob('*'))
classeNames = [path.name for path in data_paths]
classeNames
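
The class name is simply the last component of each subdirectory path. The following is a minimal sketch of that idea, not the original notebook's code. The folder names `healthy` and `algal_leaf` are hypothetical placeholders, and `PureWindowsPath` is used only so the Windows-style paths behave the same on any operating system.

```python
from pathlib import PureWindowsPath

# Hypothetical class folders; the real names depend on the dataset on disk.
data_paths = [
    PureWindowsPath("E:/chaye/Tea_Leaf_Disease/healthy"),
    PureWindowsPath("E:/chaye/Tea_Leaf_Disease/algal_leaf"),
]

# .name returns the last path component regardless of depth or separator.
class_names = [path.name for path in data_paths]

# Splitting the string form only works when the index matches the depth:
# 'E:\\chaye\\Tea_Leaf_Disease\\healthy' has exactly four components,
# so index 3 is the class name and any larger index raises IndexError.
parts = str(data_paths[0]).split("\\")

# Sorting the names and enumerating them gives a class-to-index mapping,
# which is the convention torchvision's ImageFolder follows.
class_to_idx = {name: i for i, name in enumerate(sorted(class_names))}

print(class_names)
print(parts)
print(class_to_idx)
```

Because `.name` does not depend on how deeply the dataset folder is nested, it keeps working when the data is moved to another location.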


# In[7]:


# For more on transforms.Compose, see: https://blog.csdn.net/qq_38251616/article/details/124878863
train_transforms = transforms.Compose([
    transforms.Resize([224, 224]),  # Resize every input image to the same size
    # transforms.RandomHorizontalFlip(), # Random horizontal flip
    transforms.ToTensor(),          # Convert a PIL Image or numpy.ndarray to a tensor scaled to [0, 1]
    transforms.Normalize(           # Standardize each channel as (x - mean) / std, which helps the model converge
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])  # These mean/std values are the widely used ImageNet channel statistics.
])

total_data = datasets.ImageFolder('E:/chaye/Tea_Leaf_Disease/',transform=train_transforms)
total_data
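
To make the `Normalize` step concrete: it subtracts the channel mean and divides by the channel standard deviation, after `ToTensor` has already scaled pixel values to [0, 1]. The sketch below reproduces that arithmetic in plain Python for a single RGB pixel. The helper `normalize_pixel` is a hypothetical name used for illustration only and is not part of torchvision.

```python
# Channel statistics used in the transform above.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]


def normalize_pixel(rgb):
    """Apply (value - mean) / std to one RGB pixel already scaled to [0, 1]."""
    return [(v - m) / s for v, m, s in zip(rgb, mean, std)]


white = normalize_pixel([1.0, 1.0, 1.0])   # brightest possible pixel
black = normalize_pixel([0.0, 0.0, 0.0])   # darkest possible pixel
centre = normalize_pixel(mean)             # a pixel equal to the channel means

print(white)
print(black)
print(centre)
```

A pixel equal to the channel means maps to zero in every channel, darker pixels become negative, and brighter ones positive, so the network sees inputs roughly centred on zero.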



Origin blog.csdn.net/m0_70388905/article/details/130461908