Image segmentation with Keras

First, let's look at the results, although they are not very good.

Here is the code:

import tensorflow as tf
import matplotlib.pyplot as plt
import  os
import time
import numpy as np
import io
import PIL
from IPython.display import clear_output
import cv2
import sys
sys.path.append("/opt/LIP/examples")
from tensorflow_examples.models.pix2pix import pix2pix

# Side length, in pixels, of the square images and masks fed to the model.
# (The original file assigned IMG_WIDTH twice; the loaders and the model only
# ever use square IMG_WIDTH x IMG_WIDTH inputs, so one constant suffices.)
IMG_WIDTH = 128
# Dataset roots; each is expected to contain 'train/' and 'val/' sub-directories.
IM_PATH = '/opt/LIP/images/'
MS_PATH = '/opt/LIP/masks/'
# Number of per-pixel logits the network predicts (one per segmentation class).
# NOTE(review): assumes mask pixel values lie in [0, OUTPUT_CHANNELS) — confirm
# against the label set of the dataset in use.
OUTPUT_CHANNELS = 20
EPOCHS = 20
BATCH_SIZE = 256

def display(display_list):
    """Plot the given images side by side in one row and show the figure.

    Panels are titled positionally with 'Input Image', 'True Mask' and
    'Predicted Mask', so at most three entries may be passed; a fourth raises
    IndexError when its title is looked up.

    Args:
        display_list: sequence of image-like arrays/tensors accepted by
            tf.keras.preprocessing.image.array_to_img.
    """
    titles = ['Input Image', 'True Mask', 'Predicted Mask']
    panel_count = len(display_list)
    plt.figure(figsize=(15, 15))
    for idx, item in enumerate(display_list):
        plt.subplot(1, panel_count, idx + 1)
        plt.title(titles[idx])
        # array_to_img rescales the values into a displayable PIL image.
        plt.imshow(tf.keras.preprocessing.image.array_to_img(item))
        plt.axis('off')
    plt.show()


def load_input(image_file):
    """Read a JPEG file and return it as a normalized model input.

    Args:
        image_file: path to a JPEG image (a str or scalar string tensor, as
            supplied by tf.data.Dataset.map).

    Returns:
        float32 tensor of shape (IMG_WIDTH, IMG_WIDTH, 3) with values in [-1, 1].
    """
    # The original printed image_file here. Inside Dataset.map that only runs
    # once, while tracing, and prints a symbolic tensor, so it has been removed.
    raw = tf.io.read_file(image_file)
    image = tf.image.decode_jpeg(raw, channels=3)
    # Default (bilinear) resize is fine for photographs; it yields float32.
    image = tf.image.resize(image, [IMG_WIDTH, IMG_WIDTH])
    # Map pixel values from [0, 255] to [-1, 1].
    return (image / 127.5) - 1
def load_mask(image_file):
    """Read a single-channel PNG label mask and resize it to the model resolution.

    Args:
        image_file: path to a PNG mask (a str or scalar string tensor).
            Presumably each pixel value is a class index — confirm against
            the dataset's label encoding.

    Returns:
        float32 tensor of shape (IMG_WIDTH, IMG_WIDTH, 1). Every value is
        one of the label values present in the source mask.
    """
    raw = tf.io.read_file(image_file)
    mask = tf.image.decode_png(raw, channels=1)
    # Masks hold discrete labels, so they must be resized with nearest
    # neighbour. The default bilinear method averages neighbouring class IDs
    # at region boundaries, producing fractional values that are not classes.
    mask = tf.image.resize(mask, [IMG_WIDTH, IMG_WIDTH],
                           method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    # NEAREST_NEIGHBOR keeps the input dtype (uint8). Cast so callers still
    # receive float32, as they did with the previous bilinear resize.
    return tf.cast(mask, tf.float32)
def load(image_file, mask_file):
    """Load one (image, mask) pair from their file paths.

    Intended as the Dataset.map function over (image_path, mask_path) pairs.
    """
    return load_input(image_file), load_mask(mask_file)

def _paired_paths(split):
    """Return aligned (image_paths, mask_paths) lists for one dataset split.

    Images and masks are matched by file stem (the name without its
    extension), so e.g. 'x.jpg' pairs with 'x.png'. Pairing by stem rather than
    by position in two separately sorted listings means a missing or extra file
    raises an error instead of silently shifting every later pair.

    Args:
        split: sub-directory name under IM_PATH and MS_PATH, e.g. 'train'.

    Returns:
        Two equal-length lists of full paths, sorted by stem.

    Raises:
        ValueError: if some image has no mask, or some mask has no image.
    """
    image_dir = os.path.join(IM_PATH, split)
    mask_dir = os.path.join(MS_PATH, split)
    images = {os.path.splitext(name)[0]: name for name in os.listdir(image_dir)}
    masks = {os.path.splitext(name)[0]: name for name in os.listdir(mask_dir)}
    unmatched = sorted(images.keys() ^ masks.keys())
    if unmatched:
        raise ValueError(
            '{} split: {} image/mask files have no counterpart, e.g. {}'.format(
                split, len(unmatched), unmatched[:5]))
    stems = sorted(images)
    image_paths = [os.path.join(image_dir, images[stem]) for stem in stems]
    mask_paths = [os.path.join(mask_dir, masks[stem]) for stem in stems]
    return image_paths, mask_paths


def _make_dataset(split, shuffle=False):
    """Build an unbatched dataset of decoded (image, mask) tensors for one split.

    Args:
        split: sub-directory name under IM_PATH and MS_PATH.
        shuffle: if True, shuffle the full list of path pairs before decoding.
            The buffer is reshuffled each epoch (tf.data's default).
    """
    image_paths, mask_paths = _paired_paths(split)
    dataset = tf.data.Dataset.from_tensor_slices(
        (tf.constant(image_paths), tf.constant(mask_paths)))
    if shuffle:
        # Shuffle the cheap path strings rather than decoded images, so a
        # buffer covering the whole split costs little memory.
        dataset = dataset.shuffle(len(image_paths))
    return dataset.map(load, num_parallel_calls=4)


# train_data stays in file order, so the sample taken from it later is
# deterministic; the batched stream used for fitting is shuffled each epoch.
train_data = _make_dataset('train')
train_batched_data = (_make_dataset('train', shuffle=True)
                      .batch(BATCH_SIZE)
                      .prefetch(tf.data.experimental.AUTOTUNE))

val_data = _make_dataset('val')
val_batched_data = val_data.batch(BATCH_SIZE)


# Encoder backbone: MobileNetV2 without its classification head. The input
# size follows IMG_WIDTH instead of a hard-coded 128, so it stays consistent
# with the resolution the data loaders produce.
base_model = tf.keras.applications.MobileNetV2(
    input_shape=[IMG_WIDTH, IMG_WIDTH, 3], include_top=False)
# Use the activations of these layers as skip connections / bottleneck.
# Spatial sizes in the comments are for a 128x128 input.
layer_names = [
    'block_1_expand_relu',   # 64x64
    'block_3_expand_relu',   # 32x32
    'block_6_expand_relu',   # 16x16
    'block_13_expand_relu',  # 8x8
    'block_16_project',      # 4x4
]
layers = [base_model.get_layer(name).output for name in layer_names]
# Feature-extraction model returning all five activations, shallowest first.
down_stack = tf.keras.Model(inputs=base_model.input, outputs=layers)
# Freeze the encoder; only the decoder is trained.
down_stack.trainable = False
# Decoder: each block doubles the spatial resolution (sizes for a 128x128 input).
up_stack = [
    pix2pix.upsample(512, 3),  # 4x4 -> 8x8
    pix2pix.upsample(256, 3),  # 8x8 -> 16x16
    pix2pix.upsample(128, 3),  # 16x16 -> 32x32
    pix2pix.upsample(64, 3),   # 32x32 -> 64x64
]

def unet_model(output_channels):
    """Build a U-Net from the module-level down_stack encoder and up_stack decoder.

    Args:
        output_channels: number of per-pixel logits (classes) to predict.

    Returns:
        tf.keras.Model mapping (IMG_WIDTH, IMG_WIDTH, 3) images to
        (IMG_WIDTH, IMG_WIDTH, output_channels) unnormalized logits.
        The output size equals the input size for IMG_WIDTH values that are
        multiples of 32.
    """
    # Input size follows IMG_WIDTH (previously hard-coded to 128) so the model
    # matches the resolution produced by the data loaders.
    inputs = tf.keras.layers.Input(shape=[IMG_WIDTH, IMG_WIDTH, 3])
    # down_stack returns one activation per entry of layer_names, shallowest first.
    skips = down_stack(inputs)
    x = skips[-1]  # deepest feature map is the bottleneck
    # Upsample and concatenate with the remaining skips, deepest to shallowest.
    for up, skip in zip(up_stack, reversed(skips[:-1])):
        x = up(x)
        x = tf.keras.layers.Concatenate()([x, skip])
    # Final stride-2 transposed conv restores full resolution. No activation
    # is applied; the logits are consumed by a loss built with from_logits=True.
    x = tf.keras.layers.Conv2DTranspose(
        output_channels, 3, strides=2, padding='same')(x)
    return tf.keras.Model(inputs=inputs, outputs=x)

# Build the segmentation network and configure it for per-pixel classification.
model = unet_model(OUTPUT_CHANNELS)
# Sparse loss: masks hold class indices (not one-hot) and the model emits logits.
pixel_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(
    optimizer='adam',
    loss=pixel_loss,
    metrics=['accuracy'],
)
tf.keras.utils.plot_model(model, show_shapes=True)


# Keep the first (unbatched) training pair for visual checks during training.
for image, mask in train_data.take(1):
    sample_image = image
    sample_mask = mask
display([sample_image, sample_mask])

def create_mask(pred_mask):
    """Convert a batch of per-pixel logits into the label map of its first item.

    Args:
        pred_mask: predictions with the batch on the first axis and classes
            on the last axis, e.g. the output of model.predict.

    Returns:
        int64 tensor of the first batch element's argmax class indices, with a
        trailing channel axis of size 1 (e.g. (H, W, 1) for (B, H, W, C) input).
    """
    labels = tf.argmax(pred_mask, axis=-1)
    first = labels[0]
    return first[..., tf.newaxis]

def show_predictions(dataset=None, num=1):
    """Display input, ground-truth mask and predicted mask.

    Args:
        dataset: optional batched dataset of (images, masks). When given,
            the first element of each of its first `num` batches is shown.
            When omitted (or falsy), the module-level sample pair is shown.
        num: number of batches to draw from `dataset`.
    """
    if not dataset:
        # Add a batch axis for predict, then drop it again via create_mask.
        prediction = model.predict(sample_image[tf.newaxis, ...])
        display([sample_image, sample_mask, create_mask(prediction)])
        return
    for batch_images, batch_masks in dataset.take(num):
        batch_prediction = model.predict(batch_images)
        display([batch_images[0], batch_masks[0], create_mask(batch_prediction)])

# Preview the prediction on the sample pair before any training has happened.
show_predictions(dataset=None, num=1)


class DisplayCallback(tf.keras.callbacks.Callback):
    """Keras callback that redraws the sample prediction after every epoch."""

    def on_epoch_end(self, epoch, logs=None):
        # Replace the previous notebook output instead of appending to it.
        clear_output(wait=True)
        show_predictions()
        # Keras passes a 0-based epoch index; report it 1-based.
        epoch_number = epoch + 1
        print('\nSample Prediction after epoch {}\n'.format(epoch_number))

# Train the model. The validation split, which is already built as
# val_batched_data, is evaluated at the end of every epoch. This gives
# held-out loss/accuracy in model_history, so over-fitting is visible.
model_history = model.fit(train_batched_data,
                          epochs=EPOCHS,
                          validation_data=val_batched_data,
                          callbacks=[DisplayCallback()])

Description:

Preparation required before running this code:

(1) Download the dataset. This example uses the LIP dataset. I rearranged the downloaded files so they are easier to read: images go under /opt/LIP/images/train/ and /opt/LIP/images/val/, and masks go under /opt/LIP/masks/train/ and /opt/LIP/masks/val/.

(2) Download the tensorflow_examples package (it provides pix2pix.upsample) and place it in the directory added with sys.path.append (I placed it in /opt/LIP/examples).

(3) Be sure to use TensorFlow 2.3. Older versions are prone to errors, especially anything below 2.0.

The saved model file is about 60 MB (the saving step is not shown in the code above).

The main reference for this code: https://tensorflow.google.cn/tutorials/images/segmentation

Accuracy:

The results are not very good, but they are usable; after all, the model was trained for only 20 epochs.

If downloading the MobileNetV2 weights file is slow, you can download it in advance and place it in /root/.keras/models/.

You can look at:

This is because Keras saves automatically downloaded weight files to /root/.keras/models/ (when running as root).

Guess you like

Origin blog.csdn.net/zhou_438/article/details/108802058
Recommended