# Real-time bokeh with deep learning

A real-time bokeh effect based on deep learning requires several steps. Here we walk through a complete example using TensorFlow and OpenCV; to keep things simple, we use a U-Net for image segmentation. The example is divided into the following parts:

1. Data preprocessing
2. Building and training the U-Net model
3. Applying the model for real-time background blur

### Part 1: Data preprocessing

This example assumes that you already have a dataset of images with corresponding foreground (subject) segmentation masks. You can start from an existing dataset such as the [COCO dataset](https://cocodataset.org/) (a sketch for exporting COCO person masks follows the loading code below). The following code loads the image data into memory and splits it into training, validation, and test sets:

```python
import os
import numpy as np
import cv2
from sklearn.model_selection import train_test_split

def load_data(image_dir, mask_dir, image_size=(256, 256)):
    # Sort both listings so each image is paired with its own mask;
    # os.listdir() returns files in arbitrary order.
    image_files = sorted(os.listdir(image_dir))
    mask_files = sorted(os.listdir(mask_dir))

    images = []
    masks = []

    for img_file, mask_file in zip(image_files, mask_files):
        img = cv2.imread(os.path.join(image_dir, img_file))
        mask = cv2.imread(os.path.join(mask_dir, mask_file), cv2.IMREAD_GRAYSCALE)

        img = cv2.resize(img, image_size)
        # Nearest-neighbor interpolation keeps the mask values binary.
        mask = cv2.resize(mask, image_size, interpolation=cv2.INTER_NEAREST)

        images.append(img)
        masks.append(mask)

    # Normalize to [0, 1] and give the masks a channel dimension.
    images = np.array(images, dtype=np.float32) / 255.0
    masks = np.array(masks, dtype=np.float32) / 255.0
    masks = np.expand_dims(masks, axis=-1)

    return images, masks

images, masks = load_data('path/to/image/dir', 'path/to/mask/dir')

# 80/20 split, then 25% of the remainder for validation: 60/20/20 overall.
X_train, X_test, y_train, y_test = train_test_split(images, masks, test_size=0.2, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.25, random_state=42)
```
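
If you do start from COCO, note that it ships polygon annotations rather than mask images, so the masks must be exported first. Below is a minimal sketch, assuming you have installed `pycocotools` and downloaded the 2017 annotations; the annotation path, the `masks` output directory, and the restriction to the `person` category are our own illustrative choices:

```python
from pycocotools.coco import COCO
import numpy as np
import cv2
import os

coco = COCO('annotations/instances_train2017.json')  # assumed path
person_id = coco.getCatIds(catNms=['person'])
img_ids = coco.getImgIds(catIds=person_id)

os.makedirs('masks', exist_ok=True)
for img_id in img_ids:
    info = coco.loadImgs(img_id)[0]
    anns = coco.loadAnns(coco.getAnnIds(imgIds=img_id, catIds=person_id, iscrowd=False))
    # Union of all person instances into one binary foreground mask.
    mask = np.zeros((info['height'], info['width']), dtype=np.uint8)
    for ann in anns:
        mask = np.maximum(mask, coco.annToMask(ann))
    cv2.imwrite(os.path.join('masks', info['file_name'].replace('.jpg', '.png')), mask * 255)
```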

### Part 2: Building and training the U-Net model

Use TensorFlow to build the U-Net model and train it on the training data:

```python
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dropout, concatenate, UpSampling2D

def build_unet(input_shape=(256, 256, 3)):
    inputs = tf.keras.Input(input_shape)

    # Encoder
    conv1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(inputs)
    conv1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool1)
    conv2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    # Bottleneck (a deeper U-Net would add more encoder/decoder levels here)
    conv3 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool2)
    conv3 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv3)
    drop3 = Dropout(0.5)(conv3)

    # Decoder with skip connections
    up2 = UpSampling2D(size=(2, 2))(drop3)
    merge2 = concatenate([conv2, up2], axis=3)
    conv4 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge2)
    conv4 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv4)

    up1 = UpSampling2D(size=(2, 2))(conv4)
    merge1 = concatenate([conv1, up1], axis=3)
    conv5 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge1)
    conv5 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv5)
    output = Conv2D(1, 1, activation='sigmoid')(conv5)

    model = tf.keras.Model(inputs=inputs, outputs=output)

    return model

unet = build_unet()

unet.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
             loss='binary_crossentropy',
             metrics=['accuracy'])

unet.fit(X_train, y_train, batch_size=16, epochs=50, validation_data=(X_val, y_val))
```
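
In practice you will usually also want to keep the best weights and stop once validation loss plateaus. A sketch using standard Keras callbacks (the checkpoint filename and patience value are our own choices):

```python
callbacks = [
    # Save only the weights that achieve the best validation loss so far.
    tf.keras.callbacks.ModelCheckpoint('unet_best.h5', monitor='val_loss', save_best_only=True),
    # Stop training if validation loss does not improve for 5 epochs.
    tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True),
]

unet.fit(X_train, y_train, batch_size=16, epochs=50,
         validation_data=(X_val, y_val), callbacks=callbacks)
```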

### Part 3: Real-time background blur

Now we apply the trained U-Net model to frames captured from the camera in real time to achieve the background-blur effect:

```python
def blur_background(frame, model, blur_intensity=21):
    original_shape = frame.shape[:2]
    # Match the preprocessing used in training: resize and scale to [0, 1].
    resized_frame = cv2.resize(frame, (256, 256)).astype(np.float32) / 255.0

    mask = model.predict(np.expand_dims(resized_frame, axis=0), verbose=0)[0]
    mask = (mask > 0.5).astype(np.uint8)
    mask = cv2.resize(mask, original_shape[::-1], interpolation=cv2.INTER_NEAREST)

    # blur_intensity is the Gaussian kernel size and must be odd.
    blurred_frame = cv2.GaussianBlur(frame, (blur_intensity, blur_intensity), 0)
    # Keep original pixels where the mask marks foreground, blur elsewhere.
    result = np.where(mask[..., np.newaxis] == 1, frame, blurred_frame)

    return result

cap = cv2.VideoCapture(0)

while True:
    ret, frame = cap.read()
    if not ret:
        break

    result_frame = blur_background(frame, unet)
    cv2.imshow('Real-time background blur', result_frame)

    key = cv2.waitKey(1)
    if key == 27:  # press ESC to exit
        break

cap.release()
cv2.destroyAllWindows()
```
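
The hard 0.5 threshold above produces a jagged boundary between subject and background. A common refinement, not part of the original code, is to skip the threshold and alpha-blend using the raw sigmoid output as a feathered matte; a sketch, where the `feather` kernel size is our own illustrative parameter:

```python
def blur_background_soft(frame, model, blur_intensity=21, feather=15):
    h, w = frame.shape[:2]
    small = cv2.resize(frame, (256, 256)).astype(np.float32) / 255.0

    # Use the raw sigmoid output as a soft alpha matte instead of thresholding.
    alpha = model.predict(np.expand_dims(small, axis=0), verbose=0)[0, ..., 0]
    alpha = cv2.resize(alpha, (w, h))
    # Feather the edge so the subject fades smoothly into the blurred background.
    alpha = cv2.GaussianBlur(alpha, (feather, feather), 0)[..., np.newaxis]

    blurred = cv2.GaussianBlur(frame, (blur_intensity, blur_intensity), 0)
    return (alpha * frame + (1.0 - alpha) * blurred).astype(np.uint8)
```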

This example should produce a working real-time bokeh effect. However, running a full U-Net on every frame may fall short of real time on a CPU; to improve performance, optimizations may be required, such as a smaller model, a lower inference resolution, or hardware-accelerated inference.
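
For example, one common route is converting the trained Keras model to TensorFlow Lite; a minimal sketch with default optimizations (the output filename is our own choice):

```python
import tensorflow as tf

# Convert the trained Keras model to TensorFlow Lite with default
# optimizations (weight quantization) for faster CPU/edge inference.
converter = tf.lite.TFLiteConverter.from_keras_model(unet)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

with open('unet.tflite', 'wb') as f:
    f.write(tflite_model)

# At inference time, load it with the TFLite interpreter instead of Keras:
interpreter = tf.lite.Interpreter(model_path='unet.tflite')
interpreter.allocate_tensors()
```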
