Pytorch implements bird species classification and recognition (including training code and bird data set)

Table of contents

1. Introduction

2. Bird Dataset

(1) Bird-Dataset26

(2) Custom data set

3. Bird classification and recognition model training

(1) Project installation

(2) Prepare Train and Test data

(3) Configuration file: config.yaml

(4) Start training

(5) Visualize the training process

(6) Some optimization suggestions

(7) Some runtime error handling methods

4. Test results of bird classification and recognition model

5. Project source code download


1. Introduction

This project uses deep learning to build a training and testing pipeline for bird species classification and recognition, and implements a simple bird image classification system. The project collected a 26-species bird dataset, Bird-Dataset26, with roughly 20,000+ images. On Bird-Dataset26, a ResNet18-based classifier reaches about 98% accuracy on the training set and about 95% accuracy on the test set. The backbone network can be any of the commonly used models such as googlenet, resnet[18,34,50], inception_v3 and mobilenet_v2.

Model         Input size   Test accuracy
mobilenet_v2  224×224      95.0000%
googlenet     224×224      96.1538%
resnet18      224×224      95.9615%

[Respect the originality, please indicate the source when reprinting] https://blog.csdn.net/guyuealian/article/details/132588031

2. Bird Dataset

(1) Bird-Dataset26

The project collected a dataset covering multiple bird species, named Bird-Dataset26. It contains 26 bird species, including common species such as eagles and peacocks, with more than 20,000 images in total and roughly 700+ images per category on average. The data is split into train and test: the training set contains 20,000+ bird images and the test set contains 500+ bird images, which is sufficient for deep learning based bird species classification and recognition.

Part of Bird-Dataset26 was crawled from the Internet, so it contains some incorrect pictures. Although I have already cleaned some of them, I still recommend cleaning the dataset again before training, otherwise it will affect the recognition accuracy of the model. Bird image data can also be retrieved here: China Bird Watching Record Center

The 26 bird species included in Bird-Dataset26 are:

Eight-colored Thrush
White Wagtail
White-breasted Kingfisher
White-breasted Waterhen
Spotted Green Woodpecker
Red-necked Crane
Ruddy Shelduck
Red-breasted Woodpecker
Jungle Thrush
Hoopoe
Crested Lapwing
Gray Wagtail
Gray Hornbill
Myna
House Crow
Peacock
Blue-breasted Roller
Green-throated Bee-eater
Cattle Egret
Common Kingfisher
Common Leaf Warbler
Common Rosefinch
Wattled Lapwing
Mountain Wagtail
Eagle
Brown-bellied Treepie

(2) Custom data set

If you need to add new category data or train on a custom dataset, proceed as follows:

  • The Train and Test datasets require images of the same category to be placed in the same folder, and each subfolder must be named after its category (see the example layout after this list).

  • Class file: one class name per line: class_name.txt (make sure the file ends with a newline after the last line), for example:
A
B
C
D

  • Modify the data paths in the configuration file: config.yaml
train_data: # multiple datasets can be added
  - 'data/dataset/train1' 
  - 'data/dataset/train2'
test_data: 'data/dataset/test'
class_name: 'data/dataset/class_name.txt'
...
...
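
For reference, a possible directory layout is sketched below, in the same style as the project tree later in this article; the folder and class names here are only hypothetical examples.

data/dataset/
├── train1/
│   ├── Peacock/             # one subfolder per class, named after the class
│   └── House Crow/
├── test/
│   ├── Peacock/
│   └── House Crow/
└── class_name.txt

If the class names match the training subfolders, class_name.txt can also be generated with a small helper such as the sketch below (the paths are assumptions, not part of the project):

import os

train_dir = "data/dataset/train1"              # hypothetical training folder
names = sorted(os.listdir(train_dir))          # one subfolder per class
with open("data/dataset/class_name.txt", "w", encoding="utf-8") as f:
    f.write("\n".join(names) + "\n")           # keep the trailing newline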

3. Bird classification and recognition model training

This project uses the Bird-Dataset26 bird dataset as training and testing samples.

(1) Project installation

The basic framework structure of the entire project is as follows:

.
├── classifier                 # training-related tools
├── configs                    # training configuration files
├── data                       # training data
├── libs           
├── demo.py              # model inference demo
├── README.md            # project documentation
├── requirements.txt     # project dependencies
└── train.py             # training entry script

   The project's Python dependencies are listed in requirements.txt and can be installed with pip:

numpy==1.16.3
matplotlib==3.1.0
Pillow==6.0.0
easydict==1.9
opencv-contrib-python==4.5.2.52
opencv-python==4.5.1.48
pandas==1.1.5
PyYAML==5.3.1
scikit-image==0.17.2
scikit-learn==0.24.0
scipy==1.5.4
seaborn==0.11.2
tensorboard==2.5.0
tensorboardX==2.1
torch==1.7.1+cu110
torchvision==0.8.2+cu110
tqdm==4.55.1
xmltodict==0.12.0
basetrainer
pybaseutils==0.6.5
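
From the project root, the dependencies can then be installed with:

pip install -r requirements.txt

Note that the +cu110 builds of torch and torchvision are not on plain PyPI and may need to be installed from the official PyTorch wheel index that matches your CUDA version.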

 Please refer to the project installation tutorial (beginners are advised to read it first and configure the development environment).

(2) Prepare Train and Test data

Download the bird species classification dataset (Train and Test sets). Images of the same category must be placed in the same folder, and each subfolder must be named after its category.

Data augmentation mainly uses random cropping, random flipping, random rotation, color jitter and similar transforms:

import numbers
import random
import PIL.Image as Image
import numpy as np
from torchvision import transforms


def image_transform(input_size, rgb_mean=[0.5, 0.5, 0.5], rgb_std=[0.5, 0.5, 0.5], trans_type="train"):
    """
    不推荐使用:RandomResizedCrop(input_size), # bug:目标容易被crop掉
    :param input_size: [w,h]
    :param rgb_mean:
    :param rgb_std:
    :param trans_type:
    :return::
    """
    if trans_type == "train":
        transform = transforms.Compose([
            transforms.Resize([int(128 * input_size[1] / 112), int(128 * input_size[0] / 112)]),
            transforms.RandomHorizontalFlip(),  # random horizontal flip
            # transforms.RandomVerticalFlip(),  # random vertical flip
            transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1),
            transforms.RandomRotation(degrees=5),
            transforms.RandomCrop([input_size[1], input_size[0]]),
            transforms.ToTensor(),
            transforms.Normalize(mean=rgb_mean, std=rgb_std),
        ])
    elif trans_type == "val" or trans_type == "test":
        transform = transforms.Compose([
            transforms.Resize([input_size[1], input_size[0]]),
            # transforms.CenterCrop([input_size[1], input_size[0]]),
            # transforms.Resize(input_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=rgb_mean, std=rgb_std),
        ])
    else:
        raise Exception("transform_type ERROR:{}".format(trans_type))
    return transform
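
For reference, the transform above can be plugged into a standard torchvision ImageFolder pipeline. The sketch below uses hypothetical paths and is only an illustration, not the project's own data-loading code.

import torch
from torchvision import datasets

train_dir = "data/dataset/train1"     # hypothetical path, one subfolder per class
input_size = [224, 224]               # [w, h], matching config.yaml

train_transform = image_transform(input_size, trans_type="train")
train_set = datasets.ImageFolder(train_dir, transform=train_transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=32,
                                           shuffle=True, num_workers=8)

images, labels = next(iter(train_loader))
print(images.shape, labels.shape)     # torch.Size([32, 3, 224, 224]) torch.Size([32])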

Modify the data paths in the configuration file: config.yaml

# Training datasets; multiple datasets are supported
train_data:
  - '/path/to/Bird-Dataset26/train'
# Test dataset
test_data: '/path/to/Bird-Dataset26/test'
# Class name file
class_name: '/path/to/Bird-Dataset26/class_name.txt'

(3) Configuration file: config.yaml

  • Currently supported backbones include: googlenet, resnet[18,34,50], inception_v3, mobilenet_v2, etc. Other backbones can be customized and added.
  • Training parameters can be set through the (configs/config.yaml) configuration file

 The configuration file config.yaml is described as follows:

# Training datasets; multiple datasets are supported
train_data:
  - '/path/to/Bird-Dataset26/train'
# Test dataset
test_data: '/path/to/Bird-Dataset26/test'
# Class name file
class_name: '/path/to/Bird-Dataset26/class_name.txt'
train_transform: "train"       # data augmentation used for training
test_transform: "val"          # data augmentation used for testing
work_dir: "work_space/"        # directory where output models are saved
net_type: "resnet18"           # backbone network: resnet18/50, mobilenet_v2, googlenet, inception_v3
width_mult: 1.0
input_size: [ 224,224 ]        # model input size
rgb_mean: [ 0.5, 0.5, 0.5 ]    # per-channel means used to normalize inputs to [-1, 1]
rgb_std: [ 0.5, 0.5, 0.5 ]     # per-channel standard deviations used for normalization
batch_size: 32
lr: 0.01                       # initial learning rate
optim_type: "SGD"              # optimizer: SGD, Adam
loss_type: "CrossEntropyLoss"  # loss function: CrossEntropyLoss, LabelSmoothing
momentum: 0.9                  # SGD momentum
num_epochs: 100                # number of training epochs
num_warn_up: 3                 # number of warm-up epochs
num_workers: 8                 # number of data-loading worker processes
weight_decay: 0.0005           # weight decay, default 5e-4
scheduler: "multi-step"        # learning-rate schedule
milestones: [ 20,50,80 ]       # epochs at which the learning rate is reduced
gpu_id: [ 0 ]                  # GPU IDs
log_freq: 50                   # LOG printing frequency
progress: True                 # whether to show a progress bar
pretrained: False              # whether to use a pretrained model
finetune: False                # whether to finetune

Parameter     | Type      | Default       | Description
train_data    | str, list | -             | Training data files; multiple datasets are supported
test_data     | str, list | -             | Test data files; multiple datasets are supported
class_name    | str       | -             | Class name file
work_dir      | str       | work_space    | Training output workspace
net_type      | str       | resnet18      | Backbone type: {resnet18/50, mobilenet_v2, googlenet, inception_v3}
input_size    | list      | [128,128]     | Model input size [W,H]
batch_size    | int       | 32            | Batch size
lr            | float     | 0.1           | Initial learning rate
optim_type    | str       | SGD           | Optimizer: {SGD, Adam}
loss_type     | str       | CELoss        | Loss function
scheduler     | str       | multi-step    | Learning-rate schedule: {multi-step, cosine}
milestones    | list      | [30,80,100]   | Epochs at which the learning rate is reduced; only effective when scheduler=multi-step
momentum      | float     | 0.9           | SGD momentum factor
num_epochs    | int       | 120           | Number of training epochs
num_warn_up   | int       | 3             | Number of warm-up epochs
num_workers   | int       | 12            | Number of DataLoader worker processes
weight_decay  | float     | 5e-4          | Weight decay factor
gpu_id        | list      | [ 0 ]         | GPU IDs used for training; multiple IDs may be specified
log_freq      | int       | 20            | Frequency of LOG output
finetune      | str       | model.pth     | Model to finetune from
progress      | bool      | True          | Whether to show a progress bar
distributed   | bool      | False         | Whether to use distributed training
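
For reference, a configuration file like this can be read into a dot-accessible object with PyYAML and easydict, both of which are listed in requirements.txt. This is only a minimal sketch and not necessarily how the project itself loads its configuration.

import yaml
from easydict import EasyDict

# Load the training configuration into a dot-accessible object (illustrative only).
with open("configs/config.yaml", "r", encoding="utf-8") as f:
    cfg = EasyDict(yaml.safe_load(f))

print(cfg.net_type, cfg.input_size, cfg.lr)   # e.g. resnet18 [224, 224] 0.01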

(4) Start training

The training pipeline is simple to operate: put images of the same category in the same directory, fill in the corresponding data paths, and start training:

python train.py -c configs/config.yaml 

After training is complete, on the Bird-Dataset26 bird dataset the training-set accuracy is around 98% and the test-set accuracy is around 95%. The backbone network can be any of the commonly used models such as googlenet, resnet[18,34,50], inception_v3 and mobilenet_v2, and users can choose the model themselves.

Model         Input size   Test accuracy
mobilenet_v2  224×224      95.0000%
googlenet     224×224      96.1538%
resnet18      224×224      95.9615%

(5) Visualize the training process

The training process is visualized with TensorBoard. To use it, run in a terminal:
# basic usage
tensorboard --logdir=path/to/log/
# for example
tensorboard --logdir=data/pretrained/mobilenet_v2_1.0_224_224_CrossEntropyLoss_20230828_172209_6476/log

TensorBoard visualization of the training curves

(6) Some optimization suggestions

If you want to further improve accuracy, you can try:

  1. Most important: clean the dataset. Most of the Bird-Dataset26 data was crawled from the Internet and contains some incorrect pictures. Although I have already cleaned part of it, I still recommend cleaning the dataset again before training, otherwise it will affect the recognition accuracy of the model.
  2. Use a different backbone, such as resnet50 or a deeper model with more parameters.
  3. Add more data augmentation. Random cropping, random flipping, random rotation and color jitter are already supported; you can try more complex methods such as mixup or CutMix (see the mixup sketch after this list).
  4. Sample balancing: balance the samples per class to avoid long-tail problems.
  5. Tune hyperparameters, such as the learning-rate schedule and the optimizer (SGD, Adam, etc.).
  6. Loss function: the training code already supports cross-entropy and LabelSmoothing; you can try losses such as FocalLoss.
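
For reference, a minimal mixup sketch is shown below. It illustrates the general technique only; the helper name is hypothetical and this is not code from the project.

import numpy as np
import torch

def mixup_batch(images, labels, alpha=0.2):
    # Mix each image with a randomly chosen partner from the same batch.
    lam = np.random.beta(alpha, alpha)
    index = torch.randperm(images.size(0), device=images.device)
    mixed = lam * images + (1.0 - lam) * images[index]
    return mixed, labels, labels[index], lam

# Inside the training loop (illustrative):
# mixed, y_a, y_b, lam = mixup_batch(images, labels)
# outputs = model(mixed)
# loss = lam * criterion(outputs, y_a) + (1.0 - lam) * criterion(outputs, y_b)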

(7) Some runtime error handling methods

  • Do not include any directories, files or paths containing Chinese characters in the project, otherwise many exceptions will occur!

  • cannot import name 'load_state_dict_from_url' 

Because of version upgrades, some interface functions are no longer available; make sure the versions match:

torch==1.7.1

torchvision==0.8.2

Alternatively, in the corresponding Python file, change

from torchvision.models.resnet import model_urls, load_state_dict_from_url

to:

from torch.hub import load_state_dict_from_url
model_urls = {
    'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}

4. Test results of the bird classification and recognition model

 The demo.py file is used for inference and for testing the model's performance. Fill in the configuration file, the model file and the test images, and the test can be run:

import argparse


def get_parser():
    # configuration file
    config_file = "data/pretrained/mobilenet_v2_1.0_224_224_CrossEntropyLoss_20230828_172209_6476/config.yaml"
    # model file
    model_file = "data/pretrained/mobilenet_v2_1.0_224_224_CrossEntropyLoss_20230828_172209_6476/model/best_model_063_95.0000.pth"
    # directory of images to test
    image_dir = "data/test_images"
    parser = argparse.ArgumentParser(description="Inference Argument")
    parser.add_argument("-c", "--config_file", help="configs file", default=config_file, type=str)
    parser.add_argument("-m", "--model_file", help="model_file", default=model_file, type=str)
    parser.add_argument("--device", help="cuda device id", default="cuda:0", type=str)
    parser.add_argument("--image_dir", help="image file or directory", default=image_dir, type=str)
    return parser
#!/usr/bin/env bash
# Usage:
# python demo.py  -c "path/to/config.yaml" -m "path/to/model.pth" --image_dir "path/to/image_dir"

python demo.py -c data/pretrained/mobilenet_v2_1.0_224_224_CrossEntropyLoss_20230828_172209_6476/config.yaml -m data/pretrained/mobilenet_v2_1.0_224_224_CrossEntropyLoss_20230828_172209_6476/model/best_model_063_95.0000.pth --image_dir data/test_images

The test results are as follows (pred_index is the predicted class name, pred_score its confidence):

pred_index:['灰犀鸟'],pred_score:[0.5273883]

pred_index:['家鸦'],pred_score:[0.9989742]

pred_index:['鹰'],pred_score:[0.9795395]

pred_index:['孔雀'],pred_score:[0.9997749]
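
For reference, the core of such an inference step can be sketched with plain torchvision as below. It assumes a standard resnet18 checkpoint saved as a plain state_dict and the class file described earlier; the paths are hypothetical and this is not the project's demo.py code.

import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import models, transforms

class_file = "data/dataset/class_name.txt"    # hypothetical paths
model_file = "path/to/model.pth"
image_file = "path/to/test.jpg"

class_names = [line.strip() for line in open(class_file, encoding="utf-8") if line.strip()]

# Build the network with the right number of classes and load the weights.
model = models.resnet18(num_classes=len(class_names))
model.load_state_dict(torch.load(model_file, map_location="cpu"))
model.eval()

# Same preprocessing as the "val"/"test" transform above.
transform = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

image = transform(Image.open(image_file).convert("RGB")).unsqueeze(0)
with torch.no_grad():
    prob = F.softmax(model(image), dim=1)[0]
index = int(prob.argmax())
print("pred_index:[{}], pred_score:[{:.7f}]".format(class_names[index], prob[index].item()))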


5. Project source code download

 [Source code download] Pytorch implements bird species classification and recognition (including training code and bird dataset)

The complete project source code includes:

  • The Bird-Dataset26 bird dataset: 26 different bird species, more than 20,000 images in total, sufficient for deep learning bird species classification and recognition
  • Support for training on custom datasets
  • Model training with the following backbone networks: googlenet, resnet[18,34,50], inception_v3, mobilenet_v2, etc.; other backbones can be added by the user
  • Pretrained models are provided, so demo.py can be run on test images without retraining
