计算机视觉图像分类【NO.1】VGG16训练自己的数据集

以下计算机视觉任务之图像分类任务,目前计算机视觉有三大基本任务,图像分类、图像检测、图像分割。此前博客主要对图像检测任务进行研究和经验分享,后续将对图像分类任务进行研究。想要了解的朋友请持续关注。

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import os, PIL
import numpy as np


# In[2]:


import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision
from torchvision import transforms, datasets
import os,PIL,pathlib,warnings

warnings.filterwarnings("ignore")             #忽略警告信息

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device


# In[3]:


import os,PIL,random,pathlib


# In[4]:


data_dir = 'E:/chaye/Tea_Leaf_Disease/'
data_dir = pathlib.Path(data_dir)

data_paths  = list(data_dir.glob('*'))
classeNames = [str(path).split("\\")[6] for path in data_paths]
classeNames


# In[5]:


data_dir = 'E:/chaye/Tea_Leaf_Disease/'
data_dir = pathlib.Path(data_dir)

data_paths  = list(data_dir.glob('*'))
classeNames = [str(path).split("\\")[5] for path in data_paths]
classeNames


# In[6]:


data_dir = 'E:/chaye/Tea_Leaf_Disease/'
data_dir = pathlib.Path(data_dir)

data_paths  = list(data_dir.glob('*'))
classeNames = [str(path).split("\\")[3] for path in data_paths]
classeNames


# In[7]:


# 关于transforms.Compose的更多介绍可以参考:https://blog.csdn.net/qq_38251616/article/details/124878863
train_transforms = transforms.Compose([
    transforms.Resize([224, 224]),  # 将输入图片resize成统一尺寸
    # transforms.RandomHorizontalFlip(), # 随机水平翻转
    transforms.ToTensor(),          # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间
    transforms.Normalize(           # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛
        mean=[0.485, 0.456, 0.406], 
        std=[0.229, 0.224, 0.225])  # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
])

total_data = datasets.ImageFolder('E:/chaye/Tea_Leaf_Disease/',transform=train_transforms)
total_data



猜你喜欢

转载自blog.csdn.net/m0_70388905/article/details/130461908
今日推荐