ピトーチ (3)

1. 古典的なネットワーク アーキテクチャの画像分類モデル

データ前処理部分:

  • データ拡張
  • データの前処理
  • DataLoader モジュールはバッチ データを直接読み取ります

ネットワークモジュールの設定:

  • 事前トレーニングされたモデルをロードします。torchvision には多くの古典的なネットワーク アーキテクチャがあり、直接呼び出すことができます。
  • 他の人がトレーニングしたタスクは私たちのタスクとまったく同じではないことに注意してください。最後のヘッド層 (通常は最後に完全に接続された層) を変更し、それを独自のタスクに変更する必要があります。
  • 継続中は、最初からすべてをトレーニングすることも、タスクの最後の層だけをトレーニングすることもできます。これは、最初の数層はすべて特徴抽出用であり、本質的なタスクの目標は同じであるためです。

ネットワーク モデルの保存とテスト:

  • モデルは選択的に保存できます。たとえば、現在の効果が検証セットで良好であれば、それを保存します。
  • 実際のテスト用にモデルを読み取る

2.転移学習

他の人がトレーニングしたモデルを使用して自分のモデルをトレーニングします

注: 両方のオブジェクトは可能な限り類似しています

転移学習 Web サイト:ローカルで開始 | PyTorch

3. 花の画像分類例

未完成の

#数据读取与预处理操作
data_dir = './a/'
# 训练集
train_dir = data_dir + '/train'
#验证集
valid_ir = data_dir + '/valid'

#制作数据源
data_transfroms = {
    'train':transforms.Compose([transforms.RandomRotation(45), #随机旋转(-45~45)
    transforms.CenterCrop(224), #从中心开始裁剪
    transforms.RandomHorizontalFlip(p = 0.5), #随机水平翻转
    transforms.RandomVerticalFlip(p = 0.5), #随机垂直翻转
    transforms.ColorJitter(brightness=0.2,contrast=0.1,saturation=0.1,hue = 0.1),
    transforms.RandomGrayscale(p = 0.025), #概率转换成灰度率,3通道就是R=G=B
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
    ]),
    'valid':transforms.Compose([transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
    ]),
}

#batch数据制作
batch_size = 8
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir,x),data_transfroms[x]) for x in ['train','valid']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x],batch_size = batch_size,shuffle = True) for x in ['train','valid']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train','valid']}
class_names = image_datasets['train'].classes


#读取标签对应的实际名字
with open('cat_to_name.json','r') as f:
    cat_to_name = json.load(f)

#加载model中提供的模型,并且直接用训练好的权重当做初始化参数
model_name = 'resnet'
#是否用人家训练好的特征来做
feature_extract = True

#是否用GPU来训练
train_on_gpu = torch.cuda.is_available()

if not train_on_gpu:
    print('cuda is not available. Training on CPU')
else:
    print('cuda is available. Training on GPU')

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def set_parameter_requires_grad(model,feature_extracting):
    if feature_extracting:
        for param in model.parameter():
            param.requires_grad = False

model_ft = models.resnet152()

おすすめ

転載: blog.csdn.net/weixin_64443786/article/details/132007292