Tensorflow2.0/Keras Eager Execution Implementation of "Kaggle Histopathologic Cancer Detection Competition"

Kaggle project address: https://www.kaggle.com/c/histopathologic-cancer-detection/overview

This article records an implementation using Tensorflow2.0/Keras Eager Execution, and the data preprocessing adopts the Tensorflow standard Dataset method:

Other implementation references:

Pyorch Implementation of Kaggle Histopathology Cancel Detection

Keras Implementation of Kaggle Histopathologic Cancer Detection

Keras/Generator Implementation of Kaggle Histopathologic Cancer Detection

# -*- coding: utf-8 -*-
import tensorflow as tf
AUTOTUNE = tf.data.experimental.AUTOTUNE 


# tf.enable_eager_execution()

import numpy as np
import os,sys,csv
import cv2 as cv
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
import myimageutil as iu



"""
====================================================================================
<<1.初步了解掌握数据的情况>>
====================================================================================

用pandas简单处理一下CSV并画出来看一下

这里我借用了kaggle的这篇kernel里的plot的代码,有兴趣的童鞋可以读一下,
https://www.kaggle.com/qitvision/a-complete-ml-pipeline-fast-ai

"""
ROOT_PATH = 'D:/ai_data/histopathologic-cancer-detection'
CSV_PATH = 'D:/ai_data/histopathologic-cancer-detection/train_labels.csv'
TRAIN_PATH = 'D:/ai_data/histopathologic-cancer-detection/train'
TEST_PATH = 'D:/ai_data/histopathologic-cancer-detection/test'

print(">>>看一下根目录下有哪些东西:")
print(os.listdir(ROOT_PATH))

df = pd.read_csv(CSV_PATH)  #pandas里的数据集叫dataframe,和scala里的一样,我们简称df

# 接下来我们来看一下数据的情况
print(">>>这个数据集的大小:")
print(df.shape)

print(">>>这个数据集的样本分布:")
print(df['label'].value_counts())

print(">>>看一下数据:")
print(df.head())

# 这边我想说明一下,之前我们的第一篇walkthrough里是直接从csv中获得文件列表的,这边最好检查一下列表里的文件和文件夹里的是不是一一对应
print(">>>list一下训练图片文件夹里的图片:")
from glob import glob
train_file_paths = glob(TRAIN_PATH + '/*.tif')
test_file_paths  = glob(TEST_PATH + '/*.tif')
print("train_file_paths size:", len(train_file_paths)) 
print("test_file_paths size:", len(test_file_paths))

import re
def check_valid():
    assert len(train_file_paths) == len(df['id']),'图片数量不一致'
    ids_from_filepath = list(map(lambda filepath:''.join(re.findall(r'[a-z0-9]{40}',filepath)), train_file_paths))
    dif = list(set(ids_from_filepath)^set(df['id'])) #求两个list的差集,如果差集为0,那说明两个list相等
    if len(dif) == 0:
        print("文件名匹配正常")
    else:
        print("匹配异常,下列文件名有差异:")
        print(dif)
        exit()
check_valid()

# print(">>>数据没问题的话接下来看一下正负数据样例的图片:")
# iu.plotSamples(df,TRAIN_PATH) #要注意本次的图片数据是使用中间32X32像素的内容为基准进行标注的,所以画图把中间一块标注出来了,但实际分类的时候不一定要把中间裁出来

# print(">>>进入正题,我们拆分一下数据,把训练数据分成训练和测试2部分,比例为9:1")
train, val = train_test_split(train_file_paths, test_size=0.1, shuffle=True)

id_label_map = {
    
    k:v for k,v in zip(df.id.values, df.label.values)}

def get_paths_labels(pathlist):
    ids = []
    labels = []
    for item in pathlist:
        id = ''.join(re.findall(r'[a-z0-9]{40}',item))
        label = id_label_map[id]
        ids.append(item)
        labels.append(label)
    return ids,labels

train_paths,train_labels = get_paths_labels(train)
val_paths,val_labels = get_paths_labels(val)

# exit()

"""
====================================================================================
<<2.图片处理和扩增>>
====================================================================================

图片处理主要是要匹配CNN的输入大小,扩增是为了降低过拟合风险
无论是图片处理还是扩增都有太多方法了,比较常用的imageaug或者tf.image进行数据扩增,其实openCV什么都能干
imgaug堪称python里最强图片扩增工具,方法多,叠加方便,一个图像数据扩增100倍轻轻松松:
https://github.com/aleju/imgaug

使用tensorflow自带的tf.image进行augmentation,特点是能结合tf.dataset无缝使用:
http://androidkt.com/tensorflow-image-augmentation-using-tf-image/

这边我们使用imgaug进行处理,最后生成tf.dataset进行训练
"""

BATCH_SIZE = 32



#我们还是使用之前的方法读取tif文件,tensorflow本身不支持读取tif,所以只能用py_func调用外部函数来读取
def image_aug_cv(filepath,label):
    image_decoded = cv.imread(filepath.numpy().decode(), 1)

    image_resized = tf.image.resize(image_decoded, [224, 224])
    return aug_image(image_resized), label

def aug_image(image):
    return image / 255.0

def prepare_train_ds(filepaths,labels):
    global BATCH_SIZE
    paths_ds = tf.data.Dataset.from_tensor_slices(filepaths)
    labels_ds = tf.data.Dataset.from_tensor_slices(labels)
    paths_labels_ds = tf.data.Dataset.zip((paths_ds,labels_ds))
    images_labels_ds = paths_labels_ds.shuffle(buffer_size=300000)
    images_labels_ds = images_labels_ds.map(lambda filename,label : tf.py_function( func=image_aug_cv,
                                                                                    inp=[filename,label],
                                                                                    Tout=[tf.float32,tf.float32]),
                                                                                    num_parallel_calls=AUTOTUNE)
    # images_labels_ds = images_labels_ds.repeat()
    images_labels_ds = images_labels_ds.batch(BATCH_SIZE)
    images_labels_ds = images_labels_ds.prefetch(buffer_size = 200)

    return images_labels_ds


train_ds = prepare_train_ds(train_paths,np.asarray(train_labels).astype('float32').reshape((-1,1)))
val_ds = prepare_train_ds(val_paths,np.asarray(val_labels).astype('float32').reshape((-1,1)))



"""
====================================================================================
<<3.建模>>
====================================================================================

使用keras和比较新的NASnet来建立模型,方法和walkthrough里的一摸一样

"""
from tensorflow.keras.layers import concatenate, Activation, GlobalAveragePooling2D, Flatten
from tensorflow.keras.layers import Dense, Input, Dropout, MaxPooling2D, Concatenate, GlobalMaxPooling2D, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.applications.nasnet import NASNetMobile
# from tensorflow.keras.optimizers import Adam

nasnet = NASNetMobile(include_top=False, input_shape=(224, 224, 3))
x1 = GlobalMaxPooling2D()(nasnet.output)
x2 = GlobalAveragePooling2D()(nasnet.output)
x3 = Flatten()(nasnet.output)
out = Concatenate(axis=-1)([x1, x2, x3])
out = Dropout(0.5)(out)
predictions = Dense(1, activation="sigmoid",name = 'predictions')(out)
model = Model(inputs=nasnet.input, outputs=predictions)

model.trainable = True
# for layer in model.layers[:-3]:
#   layer.trainable = False

optimizer = tf.keras.optimizers.Adam(lr = 0.0001, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)
loss_func = tf.keras.losses.BinaryCrossentropy()

# model.summary()

train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.BinaryAccuracy(name='train_accuracy')

val_loss = tf.keras.metrics.Mean(name='val_loss')
val_accuracy = tf.keras.metrics.BinaryAccuracy(name='val_accuracy')

"""
====================================================================================
<<3.训练>>
====================================================================================

这边使用官方标准的tensorflow 2.0 Eager Execution的训练方法来训练网络

"""

# @tf.function
def train_step(images, labels):
  with tf.GradientTape() as tape:
    predictions = model(images)
    loss = loss_func(labels, predictions)
#   print("train loss:"+str(loss.numpy()))
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))
  train_loss(loss)
  train_accuracy(labels, predictions)
  

# @tf.function
def val_step(images, labels):
  predictions = model(images)
  loss = loss_func(labels, predictions)
#   print("val loss:"+str(loss.numpy()))
  val_loss(loss)
  val_accuracy(labels, predictions)
  

EPOCHS = 20

import datetime

for epoch in range(EPOCHS):
  for images, labels in train_ds:
    train_step(images, labels)
    

  for val_images, val_labels in val_ds:
    val_step(val_images, val_labels)

  template = 'Epoch {}, Loss: {}, Accuracy: {}, val Loss: {}, val Accuracy: {}'
  print (template.format(epoch+1, 
                         train_loss.result(),
                         train_accuracy.result()*100,
                         val_loss.result(),
                         val_accuracy.result()*100))
  print(datetime.datetime.now())

Guess you like

Origin blog.csdn.net/catscanner/article/details/109177649