Weather and Time Multi-Label Classification Competition with PaddleClas

Creating together, growing together! This is day 3 of my participation in the Juejin "Daily New Plan · August Writing Challenge".

I. Weather and Time Multi-Label Classification Competition with PaddleClas

Weather and time classification competition page: www.datafountain.cn/competition…

1. Background

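In autonomous driving, the weather and the time of day (dawn, morning, afternoon, dusk, night) affect sensor accuracy. Rain and darkness, for example, significantly degrade visual sensors. This competition asks participants to classify the weather and the time of day of photos, so that an autonomous driving system can apply different strategies under different weather and lighting conditions.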

2. Task

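The dataset is provided by Yunce Data (云测数据). It contains 3000 real-world images captured by dashcams. The training set has 2600 images labeled with weather and time-of-day classes, and the test set has 400 unlabeled images. Participants must train on the training set using the OneFlow framework and then classify the weather and time of day of the test images.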

3. Data Overview

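The training data consists of 2600 images with manually annotated weather and time-of-day labels.

  • Weather: 5 classes (cloudy, sunny, rainy, snowy, foggy)
  • Time of day: 5 classes (dawn, morning, afternoon, dusk, night)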

(Example image: afternoon, cloudy)

(Example image: morning, rainy)

4. Data Description

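The dataset contains two folders, anno and image. The anno folder holds 2600 label JSON files, and the image folder holds 3000 JPEG photos taken by dashcams. Each image's labels are stored as a dictionary serialized to JSON: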

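Field | Values | Meaning
--- | --- | ---
period | dawn, morning, afternoon, dusk, night | Time of day the photo was taken
weather | cloudy, sunny, rainy, snowy, foggy | Weather in the photo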

5. Submission Requirements

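After training on the dataset with the OneFlow framework and running inference on the test images, participants must:
1. Serialize the classification results for the test images to a JSON file in the same format as the training labels, then upload it to the competition platform, which evaluates it automatically and returns the score.
2. Include a link to the GitHub repository of their model in the submission notes.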

6. Submission Example

{"annotations": [{"filename": "test_images\\00008.jpg", "period": "Morning", "weather": "Cloudy"}]}
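Below is a minimal sketch of producing a file in this format with Python's standard json module. The predictions dictionary and the file name submit.json are hypothetical placeholders. The useful detail is that json.dump escapes the backslash in Windows-style paths automatically, so the written file stays valid JSON.

```python
import json

def build_submission(predictions):
    """Build the submission dict from {filename: (period, weather)}."""
    return {
        "annotations": [
            {"filename": filename, "period": period, "weather": weather}
            for filename, (period, weather) in sorted(predictions.items())
        ]
    }

def write_submission(predictions, path="submit.json"):
    # json.dump writes each single backslash in a filename as "\\",
    # which is the escaped form JSON requires
    with open(path, "w", encoding="utf-8") as f:
        json.dump(build_submission(predictions), f, ensure_ascii=False)

# Hypothetical prediction for one test image (one backslash in the string)
preds = {"test_images\\00008.jpg": ("Morning", "Cloudy")}
print(json.dumps(build_submission(preds)))
```

Sorting by filename is only for a stable, diff-friendly output. The platform format does not require any particular order.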

7. Approach

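Overall, the task splits into two parts: predicting the time of day and predicting the weather. The steps are:

  • Generate the label lists for time of day and weather
  • Split the dataset
  • Balance the data (the classes are highly imbalanced)
  • Predict each attribute separately
  • Merge the predictions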
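The final merge step can be sketched as follows. The sketch assumes a single model that outputs one score per position of the 7-dimensional label vector built later in this article (4 period positions followed by 3 weather positions). Taking the argmax inside each segment always yields exactly one period and one weather label per image. Thresholding every score at 0.5 would not guarantee this. The class orders copy period_dic and weather_dic from the list-file script, and decode_scores is a hypothetical helper.

```python
# Class orders, matching period_dic / weather_dic in the label-list script
PERIODS = ["Dawn", "Dusk", "Morning", "Afternoon"]
WEATHERS = ["Cloudy", "Rainy", "Sunny"]

def argmax(values):
    """Index of the largest value (the first one on ties)."""
    return max(range(len(values)), key=lambda i: values[i])

def decode_scores(scores):
    """Split a 7-dim score vector into (period, weather) label strings."""
    expected = len(PERIODS) + len(WEATHERS)
    if len(scores) != expected:
        raise ValueError(f"expected {expected} scores, got {len(scores)}")
    period_scores = scores[:len(PERIODS)]
    weather_scores = scores[len(PERIODS):]
    return PERIODS[argmax(period_scores)], WEATHERS[argmax(weather_scores)]

print(decode_scores([0.1, 0.05, 0.7, 0.4, 0.2, 0.1, 0.9]))  # → ('Morning', 'Sunny')
```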

II. Dataset Preparation

1. Downloading the Data

# Download directly (very fast)
!wget https://awscdn.datafountain.cn/cometition_data2/Files/BDCI2021/555/train_dataset.zip
!wget https://awscdn.datafountain.cn/cometition_data2/Files/BDCI2021/555/test_dataset.zip
!wget https://awscdn.datafountain.cn/cometition_data2/Files/BDCI2021/555/submit_example.json
Output:
--2022-06-20 17:34:33--  https://awscdn.datafountain.cn/cometition_data2/Files/BDCI2021/555/train_dataset.zip
Resolving awscdn.datafountain.cn (awscdn.datafountain.cn)... 210.51.40.10, 210.51.40.12, 210.51.40.28, ...
Connecting to awscdn.datafountain.cn (awscdn.datafountain.cn)|210.51.40.10|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 419324853 (400M) [application/octet-stream]
Saving to: ‘train_dataset.zip’

train_dataset.zip   100%[===================>] 399.90M  20.2MB/s    in 32s     

2022-06-20 17:35:05 (12.7 MB/s) - ‘train_dataset.zip’ saved [419324853/419324853]

--2022-06-20 17:35:06--  https://awscdn.datafountain.cn/cometition_data2/Files/BDCI2021/555/test_dataset.zip
Resolving awscdn.datafountain.cn (awscdn.datafountain.cn)... 58.254.138.137, 58.254.138.145, 58.254.138.151, ...
Connecting to awscdn.datafountain.cn (awscdn.datafountain.cn)|58.254.138.137|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 62269247 (59M) [application/octet-stream]
Saving to: ‘test_dataset.zip’

test_dataset.zip    100%[===================>]  59.38M  3.13MB/s    in 9.7s    

2022-06-20 17:35:16 (6.13 MB/s) - ‘test_dataset.zip’ saved [62269247/62269247]

--2022-06-20 17:35:16--  https://awscdn.datafountain.cn/cometition_data2/Files/BDCI2021/555/submit_example.json
Resolving awscdn.datafountain.cn (awscdn.datafountain.cn)... 58.254.138.137, 58.254.138.145, 58.254.138.151, ...
Connecting to awscdn.datafountain.cn (awscdn.datafountain.cn)|58.254.138.137|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 593 [application/octet-stream]
Saving to: ‘submit_example.json’

submit_example.json 100%[===================>]     593  --.-KB/s    in 0s      

2022-06-20 17:35:17 (102 MB/s) - ‘submit_example.json’ saved [593/593]

2. Unzipping the Data

# Unzip the datasets
!unzip -qoa test_dataset.zip 
!unzip -qoa train_dataset.zip
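Before building the label list, it is worth confirming that the archives extracted as expected. Here is a minimal sketch that counts JPEG files per folder. The folder names train_images and test_images are taken from the paths used later in this article. The expected counts (2600 and 400) come from the competition description.

```python
import os

def count_images(folder, exts=(".jpg", ".jpeg")):
    """Count regular files in `folder` whose extension (case-insensitive) is in `exts`."""
    return sum(
        1
        for name in os.listdir(folder)
        if os.path.isfile(os.path.join(folder, name))
        and os.path.splitext(name)[1].lower() in exts
    )

# Compare against the counts stated in the competition description
for folder, expected in [("train_images", 2600), ("test_images", 400)]:
    if os.path.isdir(folder):
        print(f"{folder}: {count_images(folder)} images (expected {expected})")
    else:
        print(f"{folder}: not found")
```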

3. Building the Label List File

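  • Each line of train.txt contains an image path and a 7-dimensional multi-hot label vector, separated by a tab.
  • The first 4 positions one-hot encode the period (Dawn, Dusk, Morning, Afternoon), and the last 3 one-hot encode the weather (Cloudy, Rainy, Sunny). The competition lists 5 classes of each, but the training annotations only use these 4 periods and 3 weather types.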
# Convert the JSON annotations into a multi-label list file
%cd ~
import json
import os

with open('train.json', 'r') as f:
    train = json.load(f)

# Position of each class inside its one-hot segment of the label vector
period_dic = {'Dawn': 0, 'Dusk': 1, 'Morning': 2, 'Afternoon': 3}
weather_dic = {'Cloudy': 0, 'Rainy': 1, 'Sunny': 2}

with open('train.txt', 'w') as f_period:
    for item in train["annotations"]:
        period_list = [0] * len(period_dic)
        weather_list = [0] * len(weather_dic)
        period_list[period_dic[item['period']]] = 1
        weather_list[weather_dic[item['weather']]] = 1
        # Filenames use Windows separators, e.g. "train_images\00001.jpg"
        file_name = os.path.join(*item['filename'].split('\\'))
        # Line format: <image path>\t<4 period bits>,<3 weather bits>
        labels = ','.join(str(v) for v in period_list + weather_list)
        f_period.write(file_name + '\t' + labels + '\n')
print("Finished writing train.txt!!!")
Output:
/home/aistudio
Finished writing train.txt!!!
!head train.txt
Output:
train_images/00001.jpg	0,0,1,0,1,0,0
train_images/00002.jpg	0,0,0,1,1,0,0
train_images/00003.jpg	0,0,1,0,1,0,0
train_images/00004.jpg	0,0,1,0,0,0,1
train_images/00005.jpg	0,0,0,1,1,0,0
train_images/00006.jpg	0,0,0,1,1,0,0
train_images/00007.jpg	1,0,0,0,1,0,0
train_images/00009.jpg	0,0,0,1,1,0,0
train_images/00010.jpg	0,0,1,0,1,0,0
train_images/00011.jpg	0,0,1,0,0,1,0
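To double-check the list file, you can parse each line back and confirm that both segments are one-hot. A minimal sketch follows. parse_label_line is a hypothetical helper, and the class orders are copied from period_dic and weather_dic in the script above.

```python
# Class orders copied from period_dic / weather_dic in the script above
PERIODS = ["Dawn", "Dusk", "Morning", "Afternoon"]
WEATHERS = ["Cloudy", "Rainy", "Sunny"]

def parse_label_line(line):
    """Parse '<path>\\t<7 comma-separated bits>' into (path, period, weather)."""
    path, label_str = line.rstrip("\n").split("\t")
    bits = [int(b) for b in label_str.split(",")]
    period_bits, weather_bits = bits[:len(PERIODS)], bits[len(PERIODS):]
    # Segments must have the right length, contain only 0/1, and have exactly one 1
    if (len(weather_bits) != len(WEATHERS)
            or any(b not in (0, 1) for b in bits)
            or sum(period_bits) != 1 or sum(weather_bits) != 1):
        raise ValueError(f"malformed label vector: {label_str!r}")
    return path, PERIODS[period_bits.index(1)], WEATHERS[weather_bits.index(1)]

print(parse_label_line("train_images/00001.jpg\t0,0,1,0,1,0,0"))
# → ('train_images/00001.jpg', 'Morning', 'Cloudy')
```

Looping over every line of train.txt with this function raises on the first malformed entry.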

4. Splitting the Dataset

Split train.txt into a training set and a validation set at a ratio of 8:2.

# Split into training and validation sets
import pandas as pd
from sklearn.model_selection import train_test_split

def split_dataset(data_file):
    # Each line is "<path>\t<labels>"; read both columns as strings so the
    # comma-separated label field is kept intact
    data = pd.read_csv(data_file, header=None, sep='\t', dtype=str)
    train_dataset, eval_dataset = train_test_split(data, test_size=0.2, random_state=42)
    # len() counts rows; DataFrame.size would count cells (rows x columns)
    print(f'train dataset len: {len(train_dataset)}')
    print(f'eval dataset len: {len(eval_dataset)}')
    train_filename = 'train_' + data_file.split('.')[0] + '.txt'
    eval_filename = 'eval_' + data_file.split('.')[0] + '.txt'
    train_dataset.to_csv(train_filename, index=False, header=False, sep='\t')
    eval_dataset.to_csv(eval_filename, index=False, header=False, sep='\t')


data_file = 'train.txt'
split_dataset(data_file)
Output:
train dataset len: 2080
eval dataset len: 520
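Because the classes are highly imbalanced, a purely random split can leave rare period/weather combinations under-represented in the validation set. scikit-learn's train_test_split accepts a stratify argument for this purpose. The dependency-free sketch below shows the same idea: it splits the lines of each distinct label vector 8:2 separately. stratified_split is a hypothetical helper and not part of the original pipeline.

```python
import random
from collections import defaultdict

def stratified_split(lines, test_ratio=0.2, seed=42):
    """Split '<path>\\t<labels>' lines so each label vector keeps roughly the same ratio."""
    groups = defaultdict(list)
    for line in lines:
        groups[line.rstrip("\n").split("\t")[1]].append(line)
    rng = random.Random(seed)
    train, val = [], []
    for label in sorted(groups):  # sorted so the result is reproducible
        items = groups[label][:]
        rng.shuffle(items)
        n_val = round(len(items) * test_ratio)  # singleton groups stay in train
        val.extend(items[:n_val])
        train.extend(items[n_val:])
    return train, val

# Synthetic demo: 10 common and 5 rare samples
demo = [f"img{i}.jpg\t0,0,1,0,1,0,0" for i in range(10)] + \
       [f"rare{i}.jpg\t1,0,0,0,0,1,0" for i in range(5)]
train, val = stratified_split(demo)
print(len(train), len(val))  # → 12 3
```

To use it on the real data, read train.txt with readlines() and write the two returned lists with writelines() to train_train.txt and eval_train.txt.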
!head train_train.txt
Output:
train_images/00679.jpg	0,0,1,0,1,0,0
train_images/00054.jpg	0,0,1,0,0,0,1
train_images/02043.jpg	0,0,1,0,0,0,1
train_images/01120.jpg	0,0,1,0,1,0,0
train_images/02552.jpg	0,0,0,1,1,0,0
train_images/01980.jpg	0,0,1,0,1,0,0
train_images/00030.jpg	0,0,1,0,1,0,0
train_images/01219.jpg	0,0,1,0,0,0,1
train_images/02007.jpg	0,0,0,1,1,0,0
train_images/01971.jpg	0,0,0,1,1,0,0
!head eval_train.txt
Output:
train_images/01834.jpg	0,0,1,0,1,0,0
train_images/00235.jpg	0,0,1,0,0,0,1
train_images/00283.jpg	0,0,0,1,1,0,0
train_images/02453.jpg	0,0,1,0,0,0,1
train_images/01828.jpg	0,0,1,0,0,0,1
train_images/02236.jpg	0,0,1,0,0,1,0
train_images/01917.jpg	0,0,0,1,1,0,0
train_images/00566.jpg	0,0,1,0,0,0,1
train_images/00914.jpg	0,0,0,1,1,0,0
train_images/02394.jpg	0,0,0,1,0,0,1
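To see how imbalanced the classes are, and whether the split preserves their proportions, you can count the labels in each list file. This minimal sketch tallies period and weather classes separately for lines in the format above. class_counts is a hypothetical helper. No counts are reported here because they depend on the data.

```python
from collections import Counter

PERIODS = ["Dawn", "Dusk", "Morning", "Afternoon"]
WEATHERS = ["Cloudy", "Rainy", "Sunny"]

def class_counts(lines):
    """Return (period Counter, weather Counter) for '<path>\\t<7 bits>' lines."""
    periods, weathers = Counter(), Counter()
    for line in lines:
        bits = [int(b) for b in line.rstrip("\n").split("\t")[1].split(",")]
        for name, bit in zip(PERIODS + WEATHERS, bits):
            if bit:
                (periods if name in PERIODS else weathers)[name] += 1
    return periods, weathers

# Usage on the files produced above:
# for path in ["train_train.txt", "eval_train.txt"]:
#     with open(path) as f:
#         print(path, *class_counts(f))
```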

III. Environment Setup

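PaddleClas, PaddlePaddle's image recognition toolkit, is a set of tools for image recognition tasks aimed at both industry and academia. It helps users train better vision models and deploy them in real applications. Here we use PaddleClas's end-to-end image classification pipeline to solve the competition task quickly.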

1. Downloading PaddleClas

# Clone PaddleClas with git
!git clone https://gitee.com/paddlepaddle/PaddleClas.git --depth=1
Output:
Cloning into 'PaddleClas'...
remote: Enumerating objects: 2019, done.
remote: Counting objects: 100% (2019/2019), done.
remote: Compressing objects: 100% (1256/1256), done.
remote: Total 2019 (delta 1001), reused 1333 (delta 725), pack-reused 0
Receiving objects: 100% (2019/2019), 86.17 MiB | 2.21 MiB/s, done.
Resolving deltas: 100% (1001/1001), done.
Checking connectivity... done.

2. Installing PaddleClas

# Install
%cd ~/PaddleClas/
!pip install -U pip --user
!pip install -r requirements.txt
!pip install -e ./
%cd ~
Output:
/home/aistudio/PaddleClas
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Requirement already satisfied: pip in /home/aistudio/.data/webide/pip/lib/python3.7/site-packages (22.1.2)
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Requirement already satisfied: prettytable in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 1)) (0.7.2)
Requirement already satisfied: ujson in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 2)) (5.2.0)
Collecting opencv-python==4.4.0.46
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/1b/2d/62eba161d3d713e1720504de1c25d439b02c85159804d9ecead10be5d87e/opencv_python-4.4.0.46-cp37-cp37m-manylinux2014_x86_64.whl (49.5 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 49.5/49.5 MB 2.4 MB/s eta 0:00:00
Requirement already satisfied: pillow in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 4)) (7.1.2)
Requirement already satisfied: tqdm in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 5)) (4.36.1)
Requirement already satisfied: PyYAML in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 6)) (5.1.2)
Requirement already satisfied: visualdl>=2.2.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 7)) (2.2.0)
Requirement already satisfied: scipy in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 8)) (1.6.3)
Requirement already satisfied: scikit-learn>=0.21.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 9)) (0.24.2)
Requirement already satisfied: gast==0.3.3 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 10)) (0.3.3)
Collecting faiss-cpu==1.7.1.post2
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/4c/d6/072a9d18430b8c68c99ffb49fe14fbf89c62f71dcd4f5f692c7691447a14/faiss_cpu-1.7.1.post2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (8.4 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 8.4/8.4 MB 2.6 MB/s eta 0:00:00
Requirement already satisfied: easydict in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 12)) (1.9)
Requirement already satisfied: numpy>=1.14.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from opencv-python==4.4.0.46->-r requirements.txt (line 3)) (1.20.3)
Requirement already satisfied: pandas in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl>=2.2.0->-r requirements.txt (line 7)) (1.1.5)
Requirement already satisfied: protobuf>=3.11.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl>=2.2.0->-r requirements.txt (line 7)) (3.14.0)
Requirement already satisfied: matplotlib in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl>=2.2.0->-r requirements.txt (line 7)) (2.2.3)
Requirement already satisfied: pre-commit in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl>=2.2.0->-r requirements.txt (line 7)) (1.21.0)
Requirement already satisfied: bce-python-sdk in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl>=2.2.0->-r requirements.txt (line 7)) (0.8.53)
Requirement already satisfied: requests in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl>=2.2.0->-r requirements.txt (line 7)) (2.22.0)
Requirement already satisfied: shellcheck-py in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl>=2.2.0->-r requirements.txt (line 7)) (0.7.1.1)
Requirement already satisfied: six>=1.14.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl>=2.2.0->-r requirements.txt (line 7)) (1.16.0)
Requirement already satisfied: flake8>=3.7.9 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl>=2.2.0->-r requirements.txt (line 7)) (4.0.1)
Requirement already satisfied: Flask-Babel>=1.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl>=2.2.0->-r requirements.txt (line 7)) (1.0.0)
Requirement already satisfied: flask>=1.1.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl>=2.2.0->-r requirements.txt (line 7)) (1.1.1)
Requirement already satisfied: threadpoolctl>=2.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-learn>=0.21.0->-r requirements.txt (line 9)) (2.1.0)
Requirement already satisfied: joblib>=0.11 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-learn>=0.21.0->-r requirements.txt (line 9)) (0.14.1)
Requirement already satisfied: mccabe<0.7.0,>=0.6.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl>=2.2.0->-r requirements.txt (line 7)) (0.6.1)
Requirement already satisfied: pycodestyle<2.9.0,>=2.8.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl>=2.2.0->-r requirements.txt (line 7)) (2.8.0)
Requirement already satisfied: importlib-metadata<4.3 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl>=2.2.0->-r requirements.txt (line 7)) (4.2.0)
Requirement already satisfied: pyflakes<2.5.0,>=2.4.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl>=2.2.0->-r requirements.txt (line 7)) (2.4.0)
Requirement already satisfied: Jinja2>=2.10.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl>=2.2.0->-r requirements.txt (line 7)) (3.0.0)
Requirement already satisfied: click>=5.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl>=2.2.0->-r requirements.txt (line 7)) (7.0)
Requirement already satisfied: Werkzeug>=0.15 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl>=2.2.0->-r requirements.txt (line 7)) (0.16.0)
Requirement already satisfied: itsdangerous>=0.24 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl>=2.2.0->-r requirements.txt (line 7)) (1.1.0)
Requirement already satisfied: Babel>=2.3 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Flask-Babel>=1.0.0->visualdl>=2.2.0->-r requirements.txt (line 7)) (2.8.0)
Requirement already satisfied: pytz in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Flask-Babel>=1.0.0->visualdl>=2.2.0->-r requirements.txt (line 7)) (2019.3)
Requirement already satisfied: pycryptodome>=3.8.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from bce-python-sdk->visualdl>=2.2.0->-r requirements.txt (line 7)) (3.9.9)
Requirement already satisfied: future>=0.6.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from bce-python-sdk->visualdl>=2.2.0->-r requirements.txt (line 7)) (0.18.0)
Requirement already satisfied: cycler>=0.10 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib->visualdl>=2.2.0->-r requirements.txt (line 7)) (0.10.0)
Requirement already satisfied: kiwisolver>=1.0.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib->visualdl>=2.2.0->-r requirements.txt (line 7)) (1.1.0)
Requirement already satisfied: python-dateutil>=2.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib->visualdl>=2.2.0->-r requirements.txt (line 7)) (2.8.2)
Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib->visualdl>=2.2.0->-r requirements.txt (line 7)) (3.0.8)
Requirement already satisfied: nodeenv>=0.11.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl>=2.2.0->-r requirements.txt (line 7)) (1.3.4)
Requirement already satisfied: toml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl>=2.2.0->-r requirements.txt (line 7)) (0.10.0)
Requirement already satisfied: aspy.yaml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl>=2.2.0->-r requirements.txt (line 7)) (1.3.0)
Requirement already satisfied: virtualenv>=15.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl>=2.2.0->-r requirements.txt (line 7)) (16.7.9)
Requirement already satisfied: identify>=1.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl>=2.2.0->-r requirements.txt (line 7)) (1.4.10)
Requirement already satisfied: cfgv>=2.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl>=2.2.0->-r requirements.txt (line 7)) (2.0.1)
Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl>=2.2.0->-r requirements.txt (line 7)) (3.0.4)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl>=2.2.0->-r requirements.txt (line 7)) (1.25.6)
Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl>=2.2.0->-r requirements.txt (line 7)) (2019.9.11)
Requirement already satisfied: idna<2.9,>=2.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl>=2.2.0->-r requirements.txt (line 7)) (2.8)
Requirement already satisfied: zipp>=0.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from importlib-metadata<4.3->flake8>=3.7.9->visualdl>=2.2.0->-r requirements.txt (line 7)) (3.8.0)
Requirement already satisfied: typing-extensions>=3.6.4 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from importlib-metadata<4.3->flake8>=3.7.9->visualdl>=2.2.0->-r requirements.txt (line 7)) (4.2.0)
Requirement already satisfied: MarkupSafe>=2.0.0rc2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Jinja2>=2.10.1->flask>=1.1.1->visualdl>=2.2.0->-r requirements.txt (line 7)) (2.0.1)
Requirement already satisfied: setuptools in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from kiwisolver>=1.0.1->matplotlib->visualdl>=2.2.0->-r requirements.txt (line 7)) (56.2.0)
Installing collected packages: faiss-cpu, opencv-python
  Attempting uninstall: opencv-python
    Found existing installation: opencv-python 4.1.1.26
    Uninstalling opencv-python-4.1.1.26:
      Successfully uninstalled opencv-python-4.1.1.26
Successfully installed faiss-cpu-1.7.1.post2 opencv-python-4.4.0.46
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Obtaining file:///home/aistudio/PaddleClas
  Preparing metadata (setup.py) ... done
Requirement already satisfied: prettytable in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddleclas==0.0.0) (0.7.2)
Requirement already satisfied: ujson in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddleclas==0.0.0) (5.2.0)
Requirement already satisfied: opencv-python==4.4.0.46 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddleclas==0.0.0) (4.4.0.46)
Requirement already satisfied: pillow in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddleclas==0.0.0) (7.1.2)
Requirement already satisfied: tqdm in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddleclas==0.0.0) (4.36.1)
Requirement already satisfied: PyYAML in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddleclas==0.0.0) (5.1.2)
Requirement already satisfied: visualdl>=2.2.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddleclas==0.0.0) (2.2.0)
Requirement already satisfied: scipy in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddleclas==0.0.0) (1.6.3)
Requirement already satisfied: scikit-learn>=0.21.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddleclas==0.0.0) (0.24.2)
Requirement already satisfied: gast==0.3.3 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddleclas==0.0.0) (0.3.3)
Requirement already satisfied: faiss-cpu==1.7.1.post2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddleclas==0.0.0) (1.7.1.post2)
Requirement already satisfied: easydict in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from paddleclas==0.0.0) (1.9)
Requirement already satisfied: numpy>=1.14.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from opencv-python==4.4.0.46->paddleclas==0.0.0) (1.20.3)
Requirement already satisfied: joblib>=0.11 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-learn>=0.21.0->paddleclas==0.0.0) (0.14.1)
Requirement already satisfied: threadpoolctl>=2.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-learn>=0.21.0->paddleclas==0.0.0) (2.1.0)
Requirement already satisfied: bce-python-sdk in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl>=2.2.0->paddleclas==0.0.0) (0.8.53)
Requirement already satisfied: Flask-Babel>=1.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl>=2.2.0->paddleclas==0.0.0) (1.0.0)
Requirement already satisfied: pre-commit in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl>=2.2.0->paddleclas==0.0.0) (1.21.0)
Requirement already satisfied: flask>=1.1.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl>=2.2.0->paddleclas==0.0.0) (1.1.1)
Requirement already satisfied: protobuf>=3.11.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl>=2.2.0->paddleclas==0.0.0) (3.14.0)
Requirement already satisfied: shellcheck-py in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl>=2.2.0->paddleclas==0.0.0) (0.7.1.1)
Requirement already satisfied: pandas in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl>=2.2.0->paddleclas==0.0.0) (1.1.5)
Requirement already satisfied: six>=1.14.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl>=2.2.0->paddleclas==0.0.0) (1.16.0)
Requirement already satisfied: flake8>=3.7.9 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl>=2.2.0->paddleclas==0.0.0) (4.0.1)
Requirement already satisfied: matplotlib in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl>=2.2.0->paddleclas==0.0.0) (2.2.3)
Requirement already satisfied: requests in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl>=2.2.0->paddleclas==0.0.0) (2.22.0)
Requirement already satisfied: pycodestyle<2.9.0,>=2.8.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl>=2.2.0->paddleclas==0.0.0) (2.8.0)
Requirement already satisfied: pyflakes<2.5.0,>=2.4.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl>=2.2.0->paddleclas==0.0.0) (2.4.0)
Requirement already satisfied: mccabe<0.7.0,>=0.6.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl>=2.2.0->paddleclas==0.0.0) (0.6.1)
Requirement already satisfied: importlib-metadata<4.3 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl>=2.2.0->paddleclas==0.0.0) (4.2.0)
Requirement already satisfied: Jinja2>=2.10.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl>=2.2.0->paddleclas==0.0.0) (3.0.0)
Requirement already satisfied: click>=5.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl>=2.2.0->paddleclas==0.0.0) (7.0)
Requirement already satisfied: itsdangerous>=0.24 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl>=2.2.0->paddleclas==0.0.0) (1.1.0)
Requirement already satisfied: Werkzeug>=0.15 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl>=2.2.0->paddleclas==0.0.0) (0.16.0)
Requirement already satisfied: Babel>=2.3 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Flask-Babel>=1.0.0->visualdl>=2.2.0->paddleclas==0.0.0) (2.8.0)
Requirement already satisfied: pytz in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Flask-Babel>=1.0.0->visualdl>=2.2.0->paddleclas==0.0.0) (2019.3)
Requirement already satisfied: pycryptodome>=3.8.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from bce-python-sdk->visualdl>=2.2.0->paddleclas==0.0.0) (3.9.9)
Requirement already satisfied: future>=0.6.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from bce-python-sdk->visualdl>=2.2.0->paddleclas==0.0.0) (0.18.0)
Requirement already satisfied: python-dateutil>=2.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib->visualdl>=2.2.0->paddleclas==0.0.0) (2.8.2)
Requirement already satisfied: cycler>=0.10 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib->visualdl>=2.2.0->paddleclas==0.0.0) (0.10.0)
Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib->visualdl>=2.2.0->paddleclas==0.0.0) (3.0.8)
Requirement already satisfied: kiwisolver>=1.0.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib->visualdl>=2.2.0->paddleclas==0.0.0) (1.1.0)
Requirement already satisfied: aspy.yaml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl>=2.2.0->paddleclas==0.0.0) (1.3.0)
Requirement already satisfied: identify>=1.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl>=2.2.0->paddleclas==0.0.0) (1.4.10)
Requirement already satisfied: nodeenv>=0.11.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl>=2.2.0->paddleclas==0.0.0) (1.3.4)
Requirement already satisfied: cfgv>=2.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl>=2.2.0->paddleclas==0.0.0) (2.0.1)
Requirement already satisfied: toml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl>=2.2.0->paddleclas==0.0.0) (0.10.0)
Requirement already satisfied: virtualenv>=15.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl>=2.2.0->paddleclas==0.0.0) (16.7.9)
Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl>=2.2.0->paddleclas==0.0.0) (2019.9.11)
Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl>=2.2.0->paddleclas==0.0.0) (3.0.4)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl>=2.2.0->paddleclas==0.0.0) (1.25.6)
Requirement already satisfied: idna<2.9,>=2.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl>=2.2.0->paddleclas==0.0.0) (2.8)
Requirement already satisfied: zipp>=0.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from importlib-metadata<4.3->flake8>=3.7.9->visualdl>=2.2.0->paddleclas==0.0.0) (3.8.0)
Requirement already satisfied: typing-extensions>=3.6.4 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from importlib-metadata<4.3->flake8>=3.7.9->visualdl>=2.2.0->paddleclas==0.0.0) (4.2.0)
Requirement already satisfied: MarkupSafe>=2.0.0rc2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Jinja2>=2.10.1->flask>=1.1.1->visualdl>=2.2.0->paddleclas==0.0.0) (2.0.1)
Requirement already satisfied: setuptools in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from kiwisolver>=1.0.1->matplotlib->visualdl>=2.2.0->paddleclas==0.0.0) (56.2.0)
Installing collected packages: paddleclas
  Running setup.py develop for paddleclas
Successfully installed paddleclas-0.0.0
/home/aistudio

IV. Model Training and Evaluation

1. Configuration

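Start from PaddleClas/ppcls/configs/quick_start/professional/MobileNetV1_multilabel.yaml and modify it as shown below, saving the result as ~/MobileNetV1_multilabel.yaml. Note class_num: 7, which matches the 4 period + 3 weather positions of the label vector, and the cls_label_path entries that point to the split list files.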

# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: ./output/
  device: gpu
  save_interval: 1
  eval_during_train: True
  eval_interval: 1
  epochs: 10
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: ./inference
  use_multilabel: True
# model architecture
Arch:
  name: MobileNetV1
  class_num: 7
  pretrained: True

# loss function config for training/eval process
Loss:
  Train:
    - MultiLabelLoss:
        weight: 1.0
  Eval:
    - MultiLabelLoss:
        weight: 1.0


Optimizer:
  name: Momentum
  momentum: 0.9
  lr:
    name: Cosine
    learning_rate: 0.1
  regularizer:
    name: 'L2'
    coeff: 0.00004


# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: MultiLabelDataset
      image_root: ../
      cls_label_path: ../train_train.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - RandCropImage:
            size: 224
        - RandFlipImage:
            flip_code: 1
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''

    sampler:
      name: DistributedBatchSampler
      batch_size: 64
      drop_last: False
      shuffle: True
    loader:
      num_workers: 1
      use_shared_memory: True

  Eval:
    dataset: 
      name: MultiLabelDataset
      image_root: ../
      cls_label_path: ../eval_train.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - ResizeImage:
            resize_short: 256
        - CropImage:
            size: 224
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 256
      drop_last: False
      shuffle: False
    loader:
      num_workers: 1
      use_shared_memory: True

Infer:
  infer_imgs: /home/aistudio/test_images
  batch_size: 10
  transforms:
    - DecodeImage:
        to_rgb: True
        channel_first: False
    - ResizeImage:
        resize_short: 256
    - CropImage:
        size: 224
    - NormalizeImage:
        scale: 1.0/255.0
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
        order: ''
    - ToCHWImage:
  PostProcess:
    name: MultiLabelTopk
    topk: 5
    class_id_map_file: None

Metric:
  Train:
    - HammingDistance:
    - AccuracyScore:
  Eval:
    - HammingDistance:
    - AccuracyScore:

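With use_multilabel: True and MultiLabelLoss, each of the 7 outputs is treated as an independent binary decision. The standard loss for this setup is sigmoid binary cross-entropy. The plain-Python sketch below assumes MultiLabelLoss follows that formulation. PaddleClas's own reduction may differ, for example summing versus averaging over the 7 outputs, or applying optional label smoothing.

```python
import math

def sigmoid_bce(logits, targets):
    """Mean binary cross-entropy over all label positions, computed from raw logits.

    Uses the numerically stable form max(x, 0) - x*y + log(1 + exp(-|x|)).
    """
    assert len(logits) == len(targets)
    total = 0.0
    for x, y in zip(logits, targets):
        total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
    return total / len(logits)

# A confident, correct prediction for "Morning, Cloudy" gives a small loss
print(sigmoid_bce([-4, -4, 4, -4, 4, -4, -4], [0, 0, 1, 0, 1, 0, 0]))  # → approximately 0.018
```

The stable form avoids overflow in exp() for large negative logits, which the naive -[y·log σ(x) + (1-y)·log(1-σ(x))] would hit.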
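# Copy the modified config over the default one (optional: the training command below reads ../MobileNetV1_multilabel.yaml directly)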
%cd ~
!cp -f ./MobileNetV1_multilabel.yaml  PaddleClas/ppcls/configs/quick_start/professional/MobileNetV1_multilabel.yaml
Output:
/home/aistudio
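The Optimizer section uses a cosine learning-rate schedule starting at 0.1. This worked example assumes no warmup, a minimum learning rate of 0, and a per-iteration schedule over all training steps. These are assumptions about PaddleClas's Cosine defaults and are not stated in this article. Under them, the learning rate at step t out of T is lr(t) = 0.5 · 0.1 · (1 + cos(π t / T)). The training split holds 80% of the 2600 labeled images, or 2080 images. With batch_size 64 and drop_last False on a single GPU, one epoch is ceil(2080 / 64) = 33 steps, so 10 epochs give T = 330.

```python
import math

def cosine_lr(step, total_steps, base_lr=0.1, min_lr=0.0):
    """Cosine-annealed learning rate at `step` (0-based) out of `total_steps`."""
    progress = min(step, total_steps) / total_steps
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))

steps_per_epoch = math.ceil(2080 / 64)  # 33, since drop_last is False
total_steps = steps_per_epoch * 10      # 10 epochs -> 330 steps
for step in (0, total_steps // 2, total_steps):
    print(step, round(cosine_lr(step, total_steps), 4))  # 0 0.1 / 165 0.05 / 330 0.0
```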

2. Fixing a Bug

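Modify PaddleClas/ppcls/data/dataloader/multilabel_dataset.py as shown below. The relevant change is in __getitem__. The condition if self.label_ratio is not None: (left in the code as a comment) also holds when label_ratio is False, so each sample would be returned with its label packed together with the flag. Checking if self.label_ratio: instead returns plain (img, label) pairs when no label ratio is requested.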

from __future__ import print_function

import numpy as np
import os
import cv2

from ppcls.data.preprocess import transform
from ppcls.utils import logger

from .common_dataset import CommonDataset


class MultiLabelDataset(CommonDataset):
    def _load_anno(self, label_ratio=False):
        assert os.path.exists(self._cls_path)
        assert os.path.exists(self._img_root)
        self.images = []
        self.labels = []
        with open(self._cls_path) as fd:
            lines = fd.readlines()
            for l in lines:
                l = l.strip().split("\t")
                self.images.append(os.path.join(self._img_root, l[0]))

                labels = l[1].split(',')
                labels = [np.int64(i) for i in labels]

                self.labels.append(labels)
                assert os.path.exists(self.images[-1])
        
        self.label_ratio=label_ratio
        if label_ratio:
            return np.array(self.labels).mean(0).astype("float32")

    def __getitem__(self, idx):
        try:
            with open(self.images[idx], 'rb') as f:
                img = f.read()
            if self._transform_ops:
                img = transform(img, self._transform_ops)
            img = img.transpose((2, 0, 1))
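            label = np.array(self.labels[idx]).astype("float32")
            # Fix: the original check was `if self.label_ratio is not None:`. Because
            # label_ratio defaults to False (see _load_anno), that check was always
            # True and every sample came back as np.array([label, False]) instead of
            # the plain label vector, so test truthiness instead.
            if self.label_ratio: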
                return (img, np.array([label, self.label_ratio]))
            else:
                return (img, label)

        except Exception as ex:
            logger.error("Exception occurred when parsing line: {} with msg: {}".
                         format(self.images[idx], ex))
            rnd_idx = np.random.randint(self.__len__())
            return self.__getitem__(rnd_idx)
%cd ~
!cp multilabel_dataset.py PaddleClas/ppcls/data/dataloader/multilabel_dataset.py
/home/aistudio
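For reference, _load_anno above reads each line of the label file (cls_label_path, e.g. ../train_train.txt) as an image path, a TAB, and a comma-separated list of 0/1 flags. The following is a minimal sketch of producing and parsing such a line. It assumes the 7-class layout implied by the period_dic/weather_dic mappings used later at prediction time: ids 0-3 for Dawn/Dusk/Morning/Afternoon and ids 4-6 for Cloudy/Rainy/Sunny. The helper names to_label_line and parse_label_line and the example path are hypothetical.

```python
# Hypothetical helpers for the PaddleClas multi-label annotation format.
# Assumed class order (taken from the period/weather mappings used at prediction time):
#   ids 0-3: period  -> Dawn, Dusk, Morning, Afternoon
#   ids 4-6: weather -> Cloudy, Rainy, Sunny
PERIODS = ["Dawn", "Dusk", "Morning", "Afternoon"]
WEATHERS = ["Cloudy", "Rainy", "Sunny"]


def to_label_line(image_path, period, weather):
    """Return 'image_path<TAB>f0,...,f6' with exactly one period flag and one weather flag set."""
    flags = [0] * (len(PERIODS) + len(WEATHERS))
    flags[PERIODS.index(period)] = 1
    # Weather ids are offset by the number of period classes
    flags[len(PERIODS) + WEATHERS.index(weather)] = 1
    return image_path + "\t" + ",".join(str(f) for f in flags)


def parse_label_line(line):
    """Mirror of MultiLabelDataset._load_anno: split on TAB, then split the flags on commas."""
    path, labels = line.strip().split("\t")
    return path, [int(i) for i in labels.split(",")]


if __name__ == "__main__":
    line = to_label_line("train_images/00001.jpg", "Morning", "Cloudy")
    print(parse_label_line(line))  # ('train_images/00001.jpg', [0, 0, 1, 0, 1, 0, 0])
```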

3. Start training

# Start training
%cd ~/PaddleClas/

!python3 tools/train.py \
    -c ../MobileNetV1_multilabel.yaml \
    -o Arch.pretrained=True \
    -o Global.device=gpu
# Evaluate the model
%cd ~/PaddleClas/

!python  tools/eval.py \
        -c  ../MobileNetV1_multilabel.yaml \
        -o Global.pretrained_model=./output/MobileNetV1/best_model
/home/aistudio/PaddleClas
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/__init__.py:107: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import MutableMapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/rcsetup.py:20: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import Iterable, Mapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/colors.py:53: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import Sized
[2022/06/20 22:27:08] ppcls INFO: 
===========================================================
==        PaddleClas is powered by PaddlePaddle !        ==
===========================================================
==                                                       ==
==   For more info please go to the following website.   ==
==                                                       ==
==       https://github.com/PaddlePaddle/PaddleClas      ==
===========================================================

[2022/06/20 22:27:08] ppcls INFO: Arch : 
[2022/06/20 22:27:08] ppcls INFO:     class_num : 7
[2022/06/20 22:27:08] ppcls INFO:     name : MobileNetV1
[2022/06/20 22:27:08] ppcls INFO:     pretrained : True
[2022/06/20 22:27:08] ppcls INFO: DataLoader : 
[2022/06/20 22:27:08] ppcls INFO:     Eval : 
[2022/06/20 22:27:08] ppcls INFO:         dataset : 
[2022/06/20 22:27:08] ppcls INFO:             cls_label_path : ../eval_train.txt
[2022/06/20 22:27:08] ppcls INFO:             image_root : ../
[2022/06/20 22:27:08] ppcls INFO:             name : MultiLabelDataset
[2022/06/20 22:27:08] ppcls INFO:             transform_ops : 
[2022/06/20 22:27:08] ppcls INFO:                 DecodeImage : 
[2022/06/20 22:27:08] ppcls INFO:                     channel_first : False
[2022/06/20 22:27:08] ppcls INFO:                     to_rgb : True
[2022/06/20 22:27:08] ppcls INFO:                 ResizeImage : 
[2022/06/20 22:27:08] ppcls INFO:                     resize_short : 256
[2022/06/20 22:27:08] ppcls INFO:                 CropImage : 
[2022/06/20 22:27:08] ppcls INFO:                     size : 224
[2022/06/20 22:27:08] ppcls INFO:                 NormalizeImage : 
[2022/06/20 22:27:08] ppcls INFO:                     mean : [0.485, 0.456, 0.406]
[2022/06/20 22:27:08] ppcls INFO:                     order : 
[2022/06/20 22:27:08] ppcls INFO:                     scale : 1.0/255.0
[2022/06/20 22:27:08] ppcls INFO:                     std : [0.229, 0.224, 0.225]
[2022/06/20 22:27:08] ppcls INFO:         loader : 
[2022/06/20 22:27:08] ppcls INFO:             num_workers : 1
[2022/06/20 22:27:08] ppcls INFO:             use_shared_memory : True
[2022/06/20 22:27:08] ppcls INFO:         sampler : 
[2022/06/20 22:27:08] ppcls INFO:             batch_size : 256
[2022/06/20 22:27:08] ppcls INFO:             drop_last : False
[2022/06/20 22:27:08] ppcls INFO:             name : DistributedBatchSampler
[2022/06/20 22:27:08] ppcls INFO:             shuffle : False
[2022/06/20 22:27:08] ppcls INFO:     Train : 
[2022/06/20 22:27:08] ppcls INFO:         dataset : 
[2022/06/20 22:27:08] ppcls INFO:             cls_label_path : ../train_train.txt
[2022/06/20 22:27:08] ppcls INFO:             image_root : ../
[2022/06/20 22:27:08] ppcls INFO:             name : MultiLabelDataset
[2022/06/20 22:27:08] ppcls INFO:             transform_ops : 
[2022/06/20 22:27:08] ppcls INFO:                 DecodeImage : 
[2022/06/20 22:27:08] ppcls INFO:                     channel_first : False
[2022/06/20 22:27:08] ppcls INFO:                     to_rgb : True
[2022/06/20 22:27:08] ppcls INFO:                 RandCropImage : 
[2022/06/20 22:27:08] ppcls INFO:                     size : 224
[2022/06/20 22:27:08] ppcls INFO:                 RandFlipImage : 
[2022/06/20 22:27:08] ppcls INFO:                     flip_code : 1
[2022/06/20 22:27:08] ppcls INFO:                 NormalizeImage : 
[2022/06/20 22:27:08] ppcls INFO:                     mean : [0.485, 0.456, 0.406]
[2022/06/20 22:27:08] ppcls INFO:                     order : 
[2022/06/20 22:27:08] ppcls INFO:                     scale : 1.0/255.0
[2022/06/20 22:27:08] ppcls INFO:                     std : [0.229, 0.224, 0.225]
[2022/06/20 22:27:08] ppcls INFO:         loader : 
[2022/06/20 22:27:08] ppcls INFO:             num_workers : 1
[2022/06/20 22:27:08] ppcls INFO:             use_shared_memory : True
[2022/06/20 22:27:08] ppcls INFO:         sampler : 
[2022/06/20 22:27:08] ppcls INFO:             batch_size : 64
[2022/06/20 22:27:08] ppcls INFO:             drop_last : False
[2022/06/20 22:27:08] ppcls INFO:             name : DistributedBatchSampler
[2022/06/20 22:27:08] ppcls INFO:             shuffle : True
[2022/06/20 22:27:08] ppcls INFO: Global : 
[2022/06/20 22:27:08] ppcls INFO:     checkpoints : None
[2022/06/20 22:27:08] ppcls INFO:     device : gpu
[2022/06/20 22:27:08] ppcls INFO:     epochs : 10
[2022/06/20 22:27:08] ppcls INFO:     eval_during_train : True
[2022/06/20 22:27:08] ppcls INFO:     eval_interval : 1
[2022/06/20 22:27:08] ppcls INFO:     image_shape : [3, 224, 224]
[2022/06/20 22:27:08] ppcls INFO:     output_dir : ./output/
[2022/06/20 22:27:08] ppcls INFO:     pretrained_model : ./output/MobileNetV1/best_model
[2022/06/20 22:27:08] ppcls INFO:     print_batch_step : 10
[2022/06/20 22:27:08] ppcls INFO:     save_inference_dir : ./inference
[2022/06/20 22:27:08] ppcls INFO:     save_interval : 1
[2022/06/20 22:27:08] ppcls INFO:     use_multilabel : True
[2022/06/20 22:27:08] ppcls INFO:     use_visualdl : False
[2022/06/20 22:27:08] ppcls INFO: Infer : 
[2022/06/20 22:27:08] ppcls INFO:     PostProcess : 
[2022/06/20 22:27:08] ppcls INFO:         class_id_map_file : None
[2022/06/20 22:27:08] ppcls INFO:         name : MultiLabelTopk
[2022/06/20 22:27:08] ppcls INFO:         topk : 5
[2022/06/20 22:27:08] ppcls INFO:     batch_size : 10
[2022/06/20 22:27:08] ppcls INFO:     infer_imgs : ./deploy/images/0517_2715693311.jpg
[2022/06/20 22:27:08] ppcls INFO:     transforms : 
[2022/06/20 22:27:08] ppcls INFO:         DecodeImage : 
[2022/06/20 22:27:08] ppcls INFO:             channel_first : False
[2022/06/20 22:27:08] ppcls INFO:             to_rgb : True
[2022/06/20 22:27:08] ppcls INFO:         ResizeImage : 
[2022/06/20 22:27:08] ppcls INFO:             resize_short : 256
[2022/06/20 22:27:08] ppcls INFO:         CropImage : 
[2022/06/20 22:27:08] ppcls INFO:             size : 224
[2022/06/20 22:27:08] ppcls INFO:         NormalizeImage : 
[2022/06/20 22:27:08] ppcls INFO:             mean : [0.485, 0.456, 0.406]
[2022/06/20 22:27:08] ppcls INFO:             order : 
[2022/06/20 22:27:08] ppcls INFO:             scale : 1.0/255.0
[2022/06/20 22:27:08] ppcls INFO:             std : [0.229, 0.224, 0.225]
[2022/06/20 22:27:08] ppcls INFO:         ToCHWImage : None
[2022/06/20 22:27:08] ppcls INFO: Loss : 
[2022/06/20 22:27:08] ppcls INFO:     Eval : 
[2022/06/20 22:27:08] ppcls INFO:         MultiLabelLoss : 
[2022/06/20 22:27:08] ppcls INFO:             weight : 1.0
[2022/06/20 22:27:08] ppcls INFO:     Train : 
[2022/06/20 22:27:08] ppcls INFO:         MultiLabelLoss : 
[2022/06/20 22:27:08] ppcls INFO:             weight : 1.0
[2022/06/20 22:27:08] ppcls INFO: Metric : 
[2022/06/20 22:27:08] ppcls INFO:     Eval : 
[2022/06/20 22:27:08] ppcls INFO:         HammingDistance : None
[2022/06/20 22:27:08] ppcls INFO:         AccuracyScore : None
[2022/06/20 22:27:08] ppcls INFO:     Train : 
[2022/06/20 22:27:08] ppcls INFO:         HammingDistance : None
[2022/06/20 22:27:08] ppcls INFO:         AccuracyScore : None
[2022/06/20 22:27:08] ppcls INFO: Optimizer : 
[2022/06/20 22:27:08] ppcls INFO:     lr : 
[2022/06/20 22:27:08] ppcls INFO:         learning_rate : 0.1
[2022/06/20 22:27:08] ppcls INFO:         name : Cosine
[2022/06/20 22:27:08] ppcls INFO:     momentum : 0.9
[2022/06/20 22:27:08] ppcls INFO:     name : Momentum
[2022/06/20 22:27:08] ppcls INFO:     regularizer : 
[2022/06/20 22:27:08] ppcls INFO:         coeff : 4e-05
[2022/06/20 22:27:08] ppcls INFO:         name : L2
[2022/06/20 22:27:08] ppcls INFO: train with paddle 2.2.2 and device CUDAPlace(0)
W0620 22:27:08.304502  6225 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 10.1
W0620 22:27:08.310420  6225 device_context.cc:465] device: 0, cuDNN Version: 7.6.
[2022/06/20 22:27:11] ppcls INFO: unique_endpoints {''}
[2022/06/20 22:27:11] ppcls INFO: Found /home/aistudio/.paddleclas/weights/MobileNetV1_pretrained.pdparams
[2022/06/20 22:27:20] ppcls INFO: [Eval][Epoch 0][Iter: 0/3]MultiLabelLoss: 0.29409, loss: 0.29409, HammingDistance: 0.12612, AccuracyScore: 0.87388, batch_cost: 9.16829s, reader_cost: 9.08039, ips: 27.92233 images/sec
[2022/06/20 22:27:28] ppcls INFO: [Eval][Epoch 0][Avg]MultiLabelLoss: 0.28746, loss: 0.28746, HammingDistance: 0.12060, AccuracyScore: 0.87940
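The evaluation log reports HammingDistance: 0.12060 and AccuracyScore: 0.87940, which add up to exactly 1. This fits both metrics being computed over individual label flags rather than whole images. Under that reading, the Hamming distance is the fraction of the 7 flags per image predicted wrongly, and this accuracy is the fraction predicted correctly.

The NumPy sketch below reproduces that relationship. It illustrates the assumption and is not PaddleClas's own metric code. The 0.5 sigmoid threshold is also an assumption.

```python
import numpy as np


def multi_hot(logits, threshold=0.5):
    """Apply a sigmoid to raw logits and binarize each label at `threshold` (assumed 0.5)."""
    probs = 1.0 / (1.0 + np.exp(-logits))
    return (probs >= threshold).astype(np.int64)


def hamming_distance(preds, targets):
    """Fraction of individual label flags that disagree (equivalent to sklearn's hamming_loss)."""
    return float(np.mean(preds != targets))


def label_wise_accuracy(preds, targets):
    """Fraction of individual label flags that agree, i.e. 1 - Hamming distance."""
    return float(np.mean(preds == targets))


if __name__ == "__main__":
    # Two images, 7 flags each (4 period + 3 weather); values are made up for illustration
    targets = np.array([[0, 0, 1, 0, 1, 0, 0],
                        [0, 0, 0, 1, 0, 0, 1]])
    logits = np.array([[-3.0, -2.0, 2.5, -1.0, 1.5, -2.0, -0.5],
                       [-2.0, -3.0, 0.8, 0.3, -1.0, -2.0, 2.0]])
    preds = multi_hot(logits)
    print(hamming_distance(preds, targets))     # 1 wrong flag out of 14
    print(label_wise_accuracy(preds, targets))  # the remaining 13 out of 14
```

Each image has only 2 positive flags out of 7. A model that predicts all zeros would therefore already score about 0.714 on this label-wise accuracy. The share of images whose period and weather are both correct is a stricter check.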

V. Prediction

1. Modify topk.py

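Edit PaddleClas/ppcls/data/postprocess/topk.py so that, in multi-label mode, it outputs the scores of all classes instead of only those above the threshold.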

    def __call__(self, x, file_names=None, multilabel=False):
        if isinstance(x, dict):
            x = x['logits']
        assert isinstance(x, paddle.Tensor)
        if file_names is not None:
            assert x.shape[0] == len(file_names)
        x = F.softmax(x, axis=-1) if not multilabel else F.sigmoid(x)
        x = x.numpy()
        y = []
        for idx, probs in enumerate(x):
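            # Multi-label branch: the threshold is lowered from PaddleClas's default of
            # 0.5 to 0.0, so every sigmoid score is kept, in ascending class-id order
            index = probs.argsort(axis=0)[-self.topk:][::-1].astype(
                "int32") if not multilabel else np.where(
                    probs >= 0.0)[0].astype("int32")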
            clas_id_list = []
            score_list = []
            label_name_list = []
            for i in index:
                clas_id_list.append(i.item())
                score_list.append(probs[i].item())
                if self.class_id_map is not None:
                    label_name_list.append(self.class_id_map[i.item()])
            result = {
                "class_ids": clas_id_list,
                "scores": np.around(
                    score_list, decimals=5).tolist(),
            }
            if file_names is not None:
                result["file_name"] = file_names[idx]
            if label_name_list is not None:
                result["label_names"] = label_name_list          
            
            y.append(result)
        return y


Copy the edited topk.py into place:

!cp -f ~/topk.py ~/PaddleClas/ppcls/data/postprocess/topk.py
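Lowering the threshold to 0.0 returns every class for a simple reason. In multi-label mode the post-processor applies a sigmoid, whose outputs are never negative, and then keeps the class ids whose probability passes the threshold. Below is a NumPy sketch of just that selection step. The function name select_classes is hypothetical, and 0.5 as the original default is an assumption.

```python
import numpy as np


def select_classes(logits, threshold):
    """Multi-label selection as in the edited topk.py: sigmoid, then keep ids with prob >= threshold."""
    probs = 1.0 / (1.0 + np.exp(-np.asarray(logits, dtype=np.float64)))
    index = np.where(probs >= threshold)[0].astype("int32")
    # np.where returns ids in ascending order, so the scores stay in class-id order
    return index.tolist(), np.around(probs[index], decimals=5).tolist()


if __name__ == "__main__":
    logits = [-2.0, -1.0, 1.5, 0.2, 2.0, -0.3, -1.5]
    print(select_classes(logits, 0.5)[0])  # [2, 3, 4]
    print(select_classes(logits, 0.0)[0])  # [0, 1, 2, 3, 4, 5, 6]
```

The modified engine.py slices the scores by position: [0:4] for period and [4:7] for weather. That only works because all seven scores are kept in class-id order.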

2. Save the prediction results

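Modify the infer method in PaddleClas/ppcls/engine/engine.py so that it writes each prediction to myresult.csv, one line per image in the form image_path,period,weather: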

    @paddle.no_grad()
    def infer(self):
        assert self.mode == "infer" and self.eval_mode == "classification"
        total_trainer = dist.get_world_size()
        local_rank = dist.get_rank()
        image_list = get_image_list(self.config["Infer"]["infer_imgs"])
        # data split
        image_list = image_list[local_rank::total_trainer]

        batch_size = self.config["Infer"]["batch_size"]
        self.model.eval()
        batch_data = []
        image_file_list = []

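        # Open the output file; each line is "image_path,period,weather"
        with open('myresult.csv', mode='w') as ff:
            # Class-id-to-name maps: ids 0-3 are the period labels and ids 4-6 the
            # weather labels (offset by 4). The order must match the label lists
            # used when the training annotation files were generated.
            period_dic = {0: 'Dawn', 1: 'Dusk', 2: 'Morning', 3: 'Afternoon'}
            weather_dic = {0: 'Cloudy', 1: 'Rainy', 2: 'Sunny'}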

            
            for idx, image_file in enumerate(image_list):
                with open(image_file, 'rb') as f:
                    x = f.read()
                for process in self.preprocess_func:
                    x = process(x)
                batch_data.append(x)
                image_file_list.append(image_file)
                if len(batch_data) >= batch_size or idx == len(image_list) - 1:
                    batch_tensor = paddle.to_tensor(batch_data)

                    if self.amp and self.amp_eval:
                        with paddle.amp.auto_cast(
                                custom_black_list={
                                    "flatten_contiguous_range", "greater_than"
                                },
                                level=self.amp_level):
                            out = self.model(batch_tensor)
                    else:
                        out = self.model(batch_tensor)
                    
                    if isinstance(out, list):
                        out = out[0]
                    if isinstance(out, dict) and "Student" in out:
                        out = out["Student"]
                    if isinstance(out, dict) and "logits" in out:
                        out = out["logits"]
                    if isinstance(out, dict) and "output" in out:
                        out = out["output"]
                    result = self.postprocess_func(out, image_file_list)
                    for i in range(len(result)):
                        scores = result[i]['scores']
                        # scores[0:4] are the period classes and scores[4:7] the weather
                        # classes; take the highest-scoring class within each group
                        period = period_dic[scores[0:4].index(max(scores[0:4]))]
                        weather = weather_dic[scores[4:7].index(max(scores[4:7]))]
                        print(result[i]['file_name'], period, weather)
                        ff.write(result[i]['file_name'] + ',' + period + ',' + weather + '\n')
                    batch_data.clear()
                    image_file_list.clear()
!cp -f  ~/engine.py ~/PaddleClas/ppcls/engine/engine.py
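The decoding in infer treats the 7 scores as two independent single-label problems. The highest of the 4 period scores picks the period, and the highest of the 3 weather scores picks the weather. Here is the same logic as a standalone function that can be tested without running the model. The name decode_scores is hypothetical; the two mappings are copied from the modified engine.py.

```python
import numpy as np

# Mappings copied from the modified engine.py: ids 0-3 are periods, ids 4-6 weathers
PERIOD_DIC = {0: "Dawn", 1: "Dusk", 2: "Morning", 3: "Afternoon"}
WEATHER_DIC = {0: "Cloudy", 1: "Rainy", 2: "Sunny"}


def decode_scores(scores):
    """Split a 7-score vector into period and weather groups and take the argmax of each."""
    scores = np.asarray(scores)
    if scores.shape != (7,):
        raise ValueError("expected 7 scores in class-id order, got shape %s" % (scores.shape,))
    period = PERIOD_DIC[int(np.argmax(scores[0:4]))]
    weather = WEATHER_DIC[int(np.argmax(scores[4:7]))]
    return period, weather


if __name__ == "__main__":
    # Made-up scores: Morning (0.81) is the top period, Cloudy (0.66) the top weather
    print(decode_scores([0.02, 0.05, 0.81, 0.30, 0.66, 0.10, 0.40]))  # ('Morning', 'Cloudy')
```

np.argmax returns the first maximum on ties. This matches list.index(max(...)) in the engine.py code.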

3. Run prediction

In ../MobileNetV1_multilabel.yaml, point Infer.infer_imgs at the test image directory:

Infer:
  infer_imgs: /home/aistudio/test_images
  batch_size: 10
  transforms:
    - DecodeImage:
        to_rgb: True
        channel_first: False
    - ResizeImage:
        resize_short: 256
    - CropImage:
        size: 224
    - NormalizeImage:
        scale: 1.0/255.0
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
        order: ''
    - ToCHWImage:
  PostProcess:
    name: MultiLabelTopk
    topk: 5
    class_id_map_file: None
# Run prediction
%cd ~/PaddleClas/

!python3 tools/infer.py \
    -c ../MobileNetV1_multilabel.yaml \
    -o Arch.pretrained="./output/MobileNetV1/best_model"
!head myresult.csv
/home/aistudio/test_images/00008.jpg,Morning,Cloudy
/home/aistudio/test_images/00018.jpg,Morning,Cloudy
/home/aistudio/test_images/00020.jpg,Afternoon,Cloudy
/home/aistudio/test_images/00022.jpg,Morning,Sunny
/home/aistudio/test_images/00035.jpg,Afternoon,Cloudy
/home/aistudio/test_images/00065.jpg,Morning,Rainy
/home/aistudio/test_images/00074.jpg,Morning,Cloudy
/home/aistudio/test_images/00079.jpg,Morning,Cloudy
/home/aistudio/test_images/00080.jpg,Morning,Cloudy
/home/aistudio/test_images/00082.jpg,Afternoon,Cloudy

VI. Submission

1. Convert the format

import json

import pandas as pd

# myresult.csv has no header row; the columns are: image path, period, weather
data = pd.read_csv('myresult.csv', header=None, sep=',')
annotations_list = []
for i in range(len(data)):
    temp = {}
    # Match the submission example, which uses Windows-style paths like "test_images\00008.jpg"
    temp["filename"] = "test_images" + "\\" + data.loc[i][0].split('/')[-1]
    temp["period"] = data.loc[i][1]
    temp["weather"] = data.loc[i][2]
    annotations_list.append(temp)
myresult = {"annotations": annotations_list}

with open('result.json', 'w') as f:
    json.dump(myresult, f)
    print("Result file generated")

Result file generated
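Before uploading, it can help to sanity-check result.json against the submission example. The example shows a top-level annotations list whose entries each have filename, period and weather, with filenames of the form test_images\xxxxx.jpg.

Here is a minimal standard-library sketch. The helpers check_submission and check_file are hypothetical. The allowed values you pass in are up to you; the usage note below restricts them to the labels this model can emit.

```python
import json


def check_submission(result, allowed_periods, allowed_weathers, expected_count=None):
    """Return a list of problems found in a parsed result.json dict (empty list = looks OK)."""
    annotations = result.get("annotations")
    if not isinstance(annotations, list):
        return ["missing or non-list 'annotations'"]
    problems = []
    if expected_count is not None and len(annotations) != expected_count:
        problems.append("expected %d entries, found %d" % (expected_count, len(annotations)))
    seen = set()
    for i, ann in enumerate(annotations):
        name = ann.get("filename", "")
        # The submission example uses a backslash after the folder name
        if not name.startswith("test_images\\"):
            problems.append("entry %d: unexpected filename %r" % (i, name))
        if name in seen:
            problems.append("entry %d: duplicate filename %r" % (i, name))
        seen.add(name)
        if ann.get("period") not in allowed_periods:
            problems.append("entry %d: unknown period %r" % (i, ann.get("period")))
        if ann.get("weather") not in allowed_weathers:
            problems.append("entry %d: unknown weather %r" % (i, ann.get("weather")))
    return problems


def check_file(path, **kwargs):
    """Load a result.json from disk and run check_submission on it."""
    with open(path, encoding="utf-8") as f:
        return check_submission(json.load(f), **kwargs)


if __name__ == "__main__":
    sample = {"annotations": [
        {"filename": "test_images\\00008.jpg", "period": "Morning", "weather": "Cloudy"}]}
    print(check_submission(sample, {"Morning"}, {"Cloudy"}))  # []
```

For the real file, call check_file("result.json", allowed_periods={"Dawn", "Dusk", "Morning", "Afternoon"}, allowed_weathers={"Cloudy", "Rainy", "Sunny"}, expected_count=400). The allowed sets are the labels from the engine.py mappings. The count of 400 is the size of the test set.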

2. Submit and get the score

Download result.json and upload it to the competition platform to get your score.

3. Suggestions

This project was trained for only 10 epochs. Increasing Global.epochs in the config is likely to give better classification results.


Reposted from juejin.im/post/7125787896669601829