NSFW Image Classification

NSFW stands for "Not Safe (or Suitable) For Work". In this article, I'll show you how to build an image classification model that detects NSFW images.

Dataset

Because of the nature of this content, such images are not available from the usual dataset websites (Kaggle, etc.).

However, there is a GitHub repository dedicated to collecting URLs of this type of image, so we can use it directly. After cloning the project, run the code below to create one folder per class and download each image into its class folder.

 import os
 import urllib.request

 # One entry per class: output folder name, URL list file, and a short tag that
 # keeps image file names unique across classes.
 folders = ['drawings','hentai','neutral','porn','sexy']
 urls = ['urls_drawings.txt','urls_hentai.txt','urls_neutral.txt','urls_porn.txt','urls_sexy.txt']
 names = ['d','h','n','p','s']
 # Number of successfully downloaded images per class. Keeping it equal for all
 # classes keeps the dataset balanced; raise it for a larger model or machine.
 max_images = 2000

 for i, j, k in zip(folders, urls, names):
     # Create the class folder. exist_ok makes re-runs safe without hiding other
     # errors (bad path, permissions) the way a bare except would.
     folder_path = os.path.join('your directory', i)
     os.makedirs(folder_path, exist_ok=True)
     # Read this class's URL list, ignoring blank lines such as the trailing
     # newline at the end of the file.
     url_path = os.path.join('Datasets_Urls', j)
     with open(url_path, 'r', encoding='utf-8') as my_file:
         data_into_list = [line.strip() for line in my_file if line.strip()]
     icount = 0
     for ii in data_into_list:
         if icount >= max_images:
             break
         # Unique name: running index + class tag.
         image_name = 'image' + str(icount) + str(k) + '.png'
         image_path = os.path.join(folder_path, image_name)
         try:
             urllib.request.urlretrieve(ii, image_path)
         except Exception as e:
             # Scraped URL lists contain many dead links; downloading is
             # best-effort, so report the failure, drop any partial file and
             # move on without advancing the counter.
             print('Skipping {}: {}'.format(ii, e))
             if os.path.exists(image_path):
                 os.remove(image_path)
             continue
         icount += 1

Here `folders` holds the class (folder) names, `urls` holds the names of the URL text files (change these to match your file names), and `names` holds a short tag used to give every image a unique file name.

The code above downloads up to 2000 images per class (failed downloads are skipped). Change the 2000 limit in the code to download a different number of images.

Data preparation

The downloaded folders may contain files that are not images, so first we delete the unwanted file types.

 import os

 # Image formats (as detected from the file contents) that are kept.
 image_exts = ['jpeg','jpg','bmp','png']
 path_list = ['drawings','hentai','neutral','porn','sexy']
 cwd = os.getcwd()

 # Leading "magic" bytes of the image formats we recognise.
 _IMAGE_SIGNATURES = (
     (b'\xff\xd8\xff', 'jpeg'),
     (b'\x89PNG\r\n\x1a\n', 'png'),
     (b'GIF87a', 'gif'),
     (b'GIF89a', 'gif'),
     (b'BM', 'bmp'),
 )


 def _sniff_image_type(image_path):
     """Return the image format of *image_path* from its header bytes, or None.

     Replaces ``imghdr.what``, which was deprecated in Python 3.11 and removed
     in 3.13 (PEP 594).
     """
     with open(image_path, 'rb') as f:
         header = f.read(8)
     for signature, kind in _IMAGE_SIGNATURES:
         if header.startswith(signature):
             return kind
     return None


 def remove_other_images(path_list, root=None):
     """Delete files in each class folder whose detected format is not allowed.

     Args:
         path_list: class folder names located under ``<root>/DataSet``.
         root: directory that contains ``DataSet``; defaults to ``cwd``, the
             working directory captured when this snippet ran.
     """
     if root is None:
         root = cwd
     for ii in path_list:
         data_dir = os.path.join(root,'DataSet',ii)
         if not os.path.isdir(data_dir):
             print('Missing folder {}'.format(data_dir))
             continue
         for image in os.listdir(data_dir):
             # Files live directly in the class folder.
             image_path = os.path.join(data_dir,image)
             if not os.path.isfile(image_path):
                 continue
             try:
                 tip = _sniff_image_type(image_path)
                 if tip not in image_exts:
                     print('Image not in ext list {}'.format(image_path))
                     os.remove(image_path)
             except OSError:
                 print("Issue with image {}".format(image_path))
 remove_other_images(path_list)

The code above removes files whose actual image format is not in the allowed list. The format is detected from the file contents, not from the extension.

The downloaded images may also contain many duplicates, so we remove duplicate images from each folder.

 import hashlib
 import os

 cwd = os.getcwd()
 path_list = ['drawings','hentai','neutral','porn','sexy']


 def remove_dup_images(path_list, root=None):
     """Delete byte-identical duplicate files within each class folder.

     Files are compared by the MD5 digest of their contents. The first file, in
     sorted name order, with a given digest is kept and later ones are removed.
     Duplicates are only detected within a folder, not across classes. Absolute
     paths are used, so the process working directory is left unchanged.

     Args:
         path_list: class folder names located under ``<root>/DataSet``.
         root: directory that contains ``DataSet``; defaults to ``cwd``, the
             working directory captured when this snippet ran.
     """
     if root is None:
         root = cwd
     for ii in path_list:
         data_dir = os.path.join(root, 'DataSet', ii)
         if not os.path.isdir(data_dir):
             print('Missing folder {}'.format(data_dir))
             continue
         seen_hashes = set()
         duplicates = []
         # Sorting makes the choice of surviving copy independent of listdir order.
         for filename in sorted(os.listdir(data_dir)):
             file_path = os.path.join(data_dir, filename)
             if not os.path.isfile(file_path):
                 continue
             with open(file_path, 'rb') as f:
                 filehash = hashlib.md5(f.read()).hexdigest()
             if filehash in seen_hashes:
                 duplicates.append(file_path)
             else:
                 seen_hashes.add(filehash)
         for file_path in duplicates:
             os.remove(file_path)
         # Report once per folder rather than once per removed file.
         print('{} duplicates removed from {}'.format(len(duplicates),ii))
 remove_dup_images(path_list)

Here we use the MD5 hash (`hashlib.md5`) to find duplicate images within each class.

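MD5 computes a digest of each file's bytes, and byte-identical files produce the same digest. If a digest has already been seen in the folder, the file is a duplicate: we add it to a list and delete it afterwards. Note that this only catches exact copies, not resized or re-encoded versions of the same picture.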

Because we use TensorFlow, we must also check that every image can actually be decoded by TensorFlow, so we add this check:

 import os
 import tensorflow as tf
 
 os.chdir('{data-set} directory')
 cwd = os.getcwd()
 
 # Extensions accepted by tf.image.decode_image.
 supported_exts = ['.jpg', '.jpeg', '.png', '.gif', '.bmp']
 
 for ii in path_list:
     class_dir = os.path.join(cwd, ii)
     for image_file in os.listdir(class_dir):
         image_path = os.path.join(class_dir, image_file)
         if not os.path.isfile(image_path):
             continue
         # Check the extension first so unsupported files are never read.
         _, ext = os.path.splitext(image_file)
         if ext.lower() not in supported_exts:
             print('Unsupported image format:', ext)
             os.remove(image_path)
             continue
         with open(image_path, 'rb') as f:
             image_data = f.read()
         # Make sure TensorFlow can actually decode the bytes. Catch only the
         # decode error: a bare except would also delete the file on Ctrl-C or
         # on unrelated failures such as running out of memory.
         try:
             tf.image.decode_image(image_data)
         except tf.errors.InvalidArgumentError as err:
             print(image_file)
             print("unsupported:", err)
             os.remove(image_path)

That completes data preparation. After cleaning, split the data: create training, validation and testing folders (each keeping one sub-folder per class) and move the images into them manually. We use 80% of the images for training, 10% for validation and 10% for testing.

Model

First, import TensorFlow and the other libraries:

 import tensorflow as tf
 import os
 import numpy as np
 import matplotlib.pyplot as plt
 from sklearn.utils import shuffle
 import hashlib
 from imageio import imread
 import numpy as np
 from tensorflow.keras.preprocessing.image import ImageDataGenerator
 from tensorflow.keras.applications.vgg16 import VGG16
 from tensorflow.keras.applications.vgg16 import preprocess_input
 from tensorflow.keras.layers import Flatten,Dense,Input
 from tensorflow.keras.models import Model,Sequential
 from keras import optimizers

All images are resized to 224×224 pixels, the default input size of VGG16.

 # Square (height, width) input size expected by VGG16; kept as a list so that
 # "IMAGE_SIZE + [3]" yields the full (H, W, channels) shape.
 IMAGE_SIZE = [224] * 2

You can use the `ImageDataGenerator` class for data augmentation, which artificially increases the size and variety of the dataset. `ImageDataGenerator` produces randomly transformed versions of the images according to the given parameters and uses them for training. Note that the model is therefore trained on these transformed images rather than on the unmodified originals.

 # Training-time augmentation. Pixels are only rescaled to [0, 1], the same
 # preprocessing that test_datagen and the prediction code apply. Combining
 # VGG16's preprocess_input (ImageNet mean subtraction on 0-255 inputs) with
 # rescale would make the training inputs differ from every other stage. If you
 # prefer preprocess_input, use it (without rescale) consistently for training,
 # validation, testing and prediction instead.
 train_datagen = ImageDataGenerator(
     rescale=1./255,
     rotation_range=40,
     width_shift_range=0.2,
     height_shift_range=0.2,
     shear_range=0.2,
     zoom_range=0.2,
     horizontal_flip=True,
     fill_mode='nearest')

For the test set we only rescale the pixel values, without any augmentation:

 # Evaluation generator: no augmentation, pixel values scaled into [0, 1].
 test_datagen = ImageDataGenerator(rescale=1 / 255.0)

For demonstration, we directly use a VGG16 model pretrained on ImageNet, without its top classification layers:

# VGG16 convolutional base with ImageNet weights. include_top=False drops the
# original 1000-class classifier so we can attach our own 5-class head.
vgg = VGG16(input_shape=IMAGE_SIZE + [3], weights='imagenet', include_top=False)

Then freeze the pretrained layers so that only the new head is trained:

# Freeze every pretrained layer so only the new classification head is trained.
for frozen_layer in vgg.layers:
    frozen_layer.trainable = False

Finally, we add our own classification head:

# Flatten the VGG feature maps and map them to a softmax over the 5 classes.
x = Flatten()(vgg.output)
prediction = Dense(units=5, activation="softmax")(x)
model = Model(vgg.input, prediction)
model.summary()

`model.summary()` prints the resulting architecture:

Training

Load the training set from its directory:

# Stream augmented training batches from DataSet/train (one sub-folder per
# class); class_mode 'sparse' yields integer class indices as labels.
train_set = train_datagen.flow_from_directory(
    'DataSet/train',
    target_size=(224, 224),
    batch_size=32,
    class_mode='sparse',
)

Load the validation set:

# Validation images get the same preprocessing as training (rescale and
# preprocessing_function are copied from train_datagen) but none of its random
# augmentations, so validation metrics measure the real images.
val_datagen = ImageDataGenerator(
    rescale=train_datagen.rescale,
    preprocessing_function=train_datagen.preprocessing_function)
val_set = val_datagen.flow_from_directory('DataSet/validation',
                                          target_size=(224,224),
                                          batch_size=32,
                                          class_mode='sparse')

Because `class_mode='sparse'` yields integer class labels instead of one-hot vectors, use the 'sparse_categorical_crossentropy' loss:

from tensorflow.keras.metrics import MeanSquaredError
from tensorflow.keras.metrics import CategoricalAccuracy
# Labels are integer class indices (class_mode='sparse'), so use the sparse
# loss. With this loss Keras resolves 'accuracy' to sparse categorical accuracy.
# CategoricalAccuracy (expects one-hot labels) and MSE (integer label vs.
# probability vector) are not meaningful for these labels, so they are not
# passed as metrics. Custom metric names 'val_loss' / 'val_accuracy' would also
# clash with the names Keras gives the real validation metrics.
# Use the tf.keras optimizer so it matches the tf.keras model.
adam = tf.keras.optimizers.Adam()
model.compile(loss='sparse_categorical_crossentropy',
              optimizer=adam,
              metrics=['accuracy'])

Then you can train:

from datetime import datetime
from keras.callbacks import ModelCheckpoint

log_dir = 'vg_log'

# Write training curves to ./vg_log for TensorBoard.
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir = log_dir)

start = datetime.now()

# model.fit accepts directory iterators directly; fit_generator is deprecated
# (and removed in Keras 3). len(train_set) / len(val_set) are already numbers
# of batches, so they must not be divided by the batch size again.
history = model.fit(train_set,
                    validation_data=val_set,
                    epochs=100,
                    steps_per_epoch=len(train_set),
                    validation_steps=len(val_set),
                    callbacks=[tensorboard_callback],
                    verbose=1)

duration = datetime.now() - start
print("Time taken for training is ",duration)

duration = datetime.now() - start
print("Time taken for training is ",duration)

The model was trained for 100 epochs. It reached a validation accuracy of 80% and an F1 score of 93%.

Prediction

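The function below takes a list of image paths and predicts a class for each image. It blurs the images classified as Porn, Sexual or Hentai, then plots all images with their predicted labels.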

import numpy as np
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter


def _to_rgb(img):
    """Return *img* with exactly three channels (drops alpha, expands grayscale)."""
    if img.ndim == 2:
        return np.stack([img] * 3, axis=-1)
    return img[..., :3]


def _to_unit_range(img):
    """Return *img* as float32 scaled to [0, 1].

    mpimg.imread returns PNGs as floats already in [0, 1] and other formats as
    integer arrays, so only integer images are divided by their dtype maximum.
    """
    if np.issubdtype(img.dtype, np.integer):
        return img.astype(np.float32) / np.iinfo(img.dtype).max
    return img.astype(np.float32)


def print_classes(images,model):
    """Classify each image file, blur NSFW predictions, and plot them side by side.

    Args:
        images: paths of the image files to classify.
        model: trained Keras model. It is fed one (1, 224, 224, 3) batch per
            image with values in [0, 1] and returns one score per entry of
            ``classes`` (in that order).
    """
    classes = ['Drawing','Hentai','Neutral','Porn','Sexual']
    blurred_classes = {'Porn', 'Sexual', 'Hentai'}
    if not images:
        return
    # squeeze=False always returns a 2-D array of Axes, so a single image
    # works too (plain subplots(ncols=1) returns a bare Axes object).
    _, axes = plt.subplots(ncols=len(images), figsize=(20,20), squeeze=False)
    for ax, path in zip(axes[0], images):
        img = _to_rgb(mpimg.imread(path))
        resized = tf.image.resize(_to_unit_range(img), (224,224))
        scores = model.predict(np.expand_dims(resized, 0))
        label = classes[int(np.argmax(scores))]
        if label in blurred_classes:
            # Blur the two spatial axes only; a scalar sigma would also smear
            # values across the colour channels.
            img = gaussian_filter(img, sigma=(6, 6, 0))
        ax.imshow(img)
        ax.title.set_text(label)

li = ['test1.jpeg','test2.jpeg','test3.jpeg','test4.jpeg','test5.jpeg']
print_classes(li,model)

As you can see, the results are reasonably good.

Finally, the source code for this article:

https://avoid.overfit.cn/post/8f681841d02e4a8db7bcf77926e123f1

Author: Nikhil Thalappalli

Guess you like

Origin blog.csdn.net/m0_46510245/article/details/130789130