【AI研习社分类相关竞赛】美丽城市--垃圾分类识别

美丽城市–垃圾分类识别

0. 使用 TensorFlow-Slim 训练自己的图像分类模型

详见往期博客:使用 TensorFlow-Slim 训练自己的图像分类模型

1. 下载官方数据集,并按类别整理训练图片
  • 数据集处理代码如下。所有路径均为绝对路径(`/*` 为省略的路径前缀);代码以 cardboard 类为例,其余类别修改类名后重复运行即可
import csv
import os
import shutil

import cv2

# Inputs: the official label CSV (column 0 = image file name, column 1 = class
# name) and the directory that holds the raw training images it refers to.
filepath, file_pathname = (
    "/*/classgabbage/data/train.csv",
    "/*/classgabbage/data/train",
)

def read_path(file_pathname, special_filename, clas,
              dest_root='/*/classgabbage/data/class'):
    """Copy one training image into the folder for its class.

    Args:
        file_pathname: directory holding the raw training images.
        special_filename: image file name (first column of the label CSV).
        clas: class name; the image is copied into ``dest_root/clas/``.
        dest_root: root of the per-class output tree. Defaults to the
            directory this script has always written to.

    Returns:
        The destination path of the copy, or ``None`` when
        ``special_filename`` is not a file inside ``file_pathname``
        (nothing is written in that case).
    """
    src = os.path.join(file_pathname, special_filename)
    # A direct existence check replaces listing the whole directory on every
    # CSV row, which made the overall script quadratic in the dataset size.
    if not os.path.isfile(src):
        return None
    dest_dir = os.path.join(dest_root, clas)
    # cv2.imwrite silently returned False when this folder was missing.
    os.makedirs(dest_dir, exist_ok=True)
    dest = os.path.join(dest_dir, special_filename)
    # Copy the bytes verbatim: decoding and re-encoding with cv2 was slower,
    # lossy for JPEG, and crashed on files cv2 could not decode.
    shutil.copyfile(src, dest)
    return dest



# Class whose images are copied on this run. The default keeps the original
# single-class behaviour; set to None to copy every class in one pass.
TARGET_CLASS = 'cardboard'

# The csv module expects files it reads to be opened with newline=''.
with open(filepath, newline='') as f:
    for row in csv.reader(f):
        # Skip blank or malformed lines instead of failing on row[1].
        if len(row) < 2:
            continue
        image_name, label = row[0], row[1]
        if TARGET_CLASS is not None and label != TARGET_CLASS:
            continue
        # NOTE(review): with TARGET_CLASS = None a header row (if train.csv
        # has one) is passed through too; it should match no image file —
        # confirm.
        read_path(file_pathname, image_name, label)
        print("finish" + image_name)
2. 按照第 0 步中的流程完成模型训练与验证
  • 在此选用的模型是Inception-ResNet-v2;
  • 训练脚本为:
# Fine-tune Inception-ResNet-v2 from a pretrained checkpoint: the logits
# layers are not restored from the checkpoint (checkpoint_exclude_scopes) and
# are the only layers trained (trainable_scopes). Two clones run on GPUs 0 and 2.
train_flags=(
  --train_dir=/*/slim/class_gabbage/train_eval/training
  --dataset_dir=/*/slim/class_gabbage/data
  --dataset_name=gabbage
  --dataset_split_name=train
  --model_name=inception_resnet_v2
  --checkpoint_path=/*/models-master/research/slim/class_gabbage/inception_resnet_v2_2016_08_30.ckpt
  --checkpoint_exclude_scopes=InceptionResnetV2/AuxLogits,InceptionResnetV2/Logits
  --trainable_scopes=InceptionResnetV2/AuxLogits,InceptionResnetV2/Logits
  --max_number_of_steps=5000
  --learning_rate=0.004
  --save_interval_secs=60
  --save_summaries_secs=60
  --log_every_n_steps=2
  --train_image_size=300
  --num_epochs_per_decay=200
  --batch_size=32
  --clone_on_cpu=False
  --num_clones=2
  --optimizer=rmsprop
)
CUDA_VISIBLE_DEVICES=0,2 python3 train_image_classifier.py "${train_flags[@]}"

  • validation脚本为:
# Evaluate the checkpoint in the training directory on the validation split,
# on GPU 1. --eval_image_size=300 matches --train_image_size=300 used for
# training and the 300x300 resize used at inference; without it,
# eval_image_classifier.py falls back to the network's default input size.
# NOTE(review): --checkpoint_path should point at the training --train_dir and
# --dataset_dir at the training data directory; the path prefixes below differ
# from the training command — confirm they resolve to the same places.
CUDA_VISIBLE_DEVICES=1 python3 eval_image_classifier.py \
  --checkpoint_path=/*/class_gabbage/train_eval/training \
  --eval_dir=/*/slim/class_gabbage/train_eval/eval \
  --dataset_name=gabbage \
  --dataset_split_name=validation \
  --dataset_dir=/*/research/slim/class_gabbage/data \
  --model_name=inception_resnet_v2 \
  --eval_image_size=300

其余步骤(如导出并冻结模型)均可参考第 0 步中的链接。

3. 在测试集上进行推理,生成用于比赛提交的结果 CSV 文件
import tensorflow as tf
import numpy as np
import cv2
from datasets import dataset_utils
from IPython import display
import csv
import os


dataset_dir = '/*/research/slim/class_gabbage/data'
model_dir = '/*/models-master/research/slim/class_gabbage/model/inception_resnet_v2.pb'
file_pathname = '/*/workspace/classgabbage/data/test1'
result_path = '/*/models-master/research/slim/result1.csv'

# Network input size; must match --train_image_size used in training and the
# input size the frozen graph was exported with (assumed 300 — confirm).
width = 300
height = 300
dim = (width, height)


def _preprocess(img, size=dim):
    """Turn a BGR uint8 image from cv2.imread into a 1-image float32 batch.

    The image is converted to RGB, resized to ``size`` and scaled from
    [0, 255] to [-1, 1]; the result has shape (1, height, width, 3).
    """
    # cv2 loads images as BGR, while the training pipeline decoded them as RGB.
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    resized = cv2.resize(rgb, size).astype(np.float32)
    # x / 127.5 - 1 maps 0 -> -1 and 255 -> 1 exactly (x / 128 - 1 stopped
    # short of 1). np.float was removed from recent NumPy releases.
    # NOTE(review): slim's Inception eval preprocessing also central-crops
    # 87.5% of the image before resizing; consider matching it here.
    scaled = resized / 127.5 - 1.0
    return np.expand_dims(scaled, axis=0)


# The label map is the same for every image, so read it once.
label_map = dataset_utils.read_label_file(dataset_dir)

with open(model_dir, 'rb') as model_file:
    gd = tf.GraphDef.FromString(model_file.read())
inp, predictions = tf.import_graph_def(
    gd, return_elements=['input:0', 'InceptionResnetV2/Logits/Predictions:0'])

# Rows are appended ('a+'), as before; the file is now closed (and flushed)
# when the run finishes.
with open(result_path, 'a+', newline='') as f, tf.Session(graph=inp.graph):
    writer = csv.writer(f)
    # sorted() makes the row order of the submission file deterministic.
    for filename in sorted(os.listdir(file_pathname)):
        print(filename)
        img = cv2.imread(os.path.join(file_pathname, filename))
        if img is None:
            # cv2.imread returns None for unreadable or non-image files.
            print("Skipping unreadable file: " + filename)
            continue
        x = predictions.eval(feed_dict={inp: _preprocess(img)})
        top1 = label_map[int(x.argmax())]
        print("Top 1 Prediction: ", top1)
        # splitext handles extensions of any length (filename[:-4] did not).
        writer.writerow([os.path.splitext(filename)[0], top1])

猜你喜欢

转载自blog.csdn.net/qq_43348528/article/details/108033513
今日推荐