Python image processing [12] Image denoising based on the wavelet transform

0. Preface

The wavelet transform is a general method for representing and analyzing images at multiple resolutions. It has applications in many areas of image processing, such as image compression and denoising. In this section we will learn how to denoise images with the wavelet transforms implemented in pywt and scikit-image.

1. Basics of the wavelet transform

We start with the basics of the wavelet transform. The following code applies the discrete wavelet transform (DWT) to an input grayscale image with one, two and three decomposition levels, and extracts the approximation image together with the horizontal, vertical and diagonal details. At the first level, high-pass filtering the original image produces three detail images, each describing the local variation of brightness (details) in one orientation, while low-pass filtering and downsampling produce an approximation image. At the next level, this approximation image is high-pass filtered to produce three smaller detail maps and low-pass filtered to produce the final approximation image. For display, each detail array is normalized and binarized with Otsu's threshold:

import numpy as np
import pywt
# plotting helpers from pywt's (private) documentation utilities
from pywt._doc_utils import wavedec2_keys, draw_2d_wp_basis
from skimage.filters import threshold_otsu
import matplotlib.pyplot as plt

x = pywt.data.ascent().astype(np.float32)
shape = x.shape

plt.rcParams.update({'font.size': 8})

max_lev = 3       # how many levels of decomposition to draw
label_levels = 3  # how many levels to explicitly label on the plots

fig, axes = plt.subplots(4, 2, figsize=[15, 35])
plt.subplots_adjust(0, 0, 1, 0.95, 0.05, 0.05)
for level in range(0, max_lev + 1):
    if level == 0:
        # show the original image before decomposition
        axes[0, 0].set_axis_off()
        axes[0, 1].imshow(x, cmap=plt.cm.gray)
        axes[0, 1].set_title('Image')
        axes[0, 1].set_axis_off()
        continue

    # plot subband boundaries of a standard DWT basis
    draw_2d_wp_basis(shape, wavedec2_keys(level), ax=axes[level, 0],
                     label_levels=label_levels)
    axes[level, 0].set_title('{} level\ndecomposition'.format(level))

    # compute the 2D DWT: c = [cA_n, (cH_n, cV_n, cD_n), ..., (cH_1, cV_1, cD_1)]
    c = pywt.wavedec2(x, 'db2', mode='periodization', level=level)
    # normalize the approximation coefficients for display
    c[0] /= np.abs(c[0]).max()
    # normalize each detail array independently, then binarize it with
    # Otsu's threshold so that the (mostly small) details stand out
    for detail_level in range(level):
        normalized = [d / np.abs(d).max() for d in c[detail_level + 1]]
        c[detail_level + 1] = [d > threshold_otsu(d) for d in normalized]
    # pack all coefficients into a single array and show it
    arr, slices = pywt.coeffs_to_array(c)
    axes[level, 1].imshow(arr, cmap=plt.cm.gray)
    axes[level, 1].set_title('Coefficients\n({} level)'.format(level))
    axes[level, 1].set_axis_off()

plt.tight_layout()
plt.show()

(Figure: Wavelet transform basics)
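To look at the coefficient structure without any plotting, here is a small sketch on a synthetic 8 × 8 ramp (our own toy input, not the ascent image), using the Haar wavelet with periodization. It checks three properties of a single-level pywt.dwt2(): every subband is half the input size along each axis, an orthonormal wavelet such as Haar preserves energy, and pywt.idwt2() reconstructs the input exactly. It also shows the list layout returned by the multilevel pywt.wavedec2().

```python
import numpy as np
import pywt

img = np.arange(64, dtype=float).reshape(8, 8)  # toy 8x8 "image"

# one level: approximation + (horizontal, vertical, diagonal) details
cA, (cH, cV, cD) = pywt.dwt2(img, 'haar', mode='periodization')
print(cA.shape, cH.shape, cV.shape, cD.shape)

# Haar is orthonormal, so the total energy is preserved (Parseval)
energy_in = np.sum(img ** 2)
energy_out = sum(np.sum(c ** 2) for c in (cA, cH, cV, cD))
print(np.isclose(energy_in, energy_out))

# the inverse DWT gives the image back (perfect reconstruction)
rec = pywt.idwt2((cA, (cH, cV, cD)), 'haar', mode='periodization')
print(np.allclose(rec, img))

# multilevel layout: [cA2, (cH2, cV2, cD2), (cH1, cV1, cD1)]
coeffs = pywt.wavedec2(img, 'haar', mode='periodization', level=2)
print(len(coeffs), coeffs[0].shape, coeffs[1][0].shape, coeffs[2][0].shape)
```

The coarsest approximation always comes first, followed by detail tuples from the coarsest to the finest level.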

2. Principles of wavelet-based denoising

The wavelet transform is widely used for image denoising. Wavelet-based denoising consists of the following steps:

  • Choose a wavelet (for example, a biorthogonal wavelet) and a number of decomposition levels N, then apply an N-level discrete wavelet transform to the image
  • After the decomposition, determine a threshold for each level (the Birgé-Massart strategy is a common way to choose thresholds) and apply it to the detail coefficients of all N levels
  • Finally, reconstruct the image from the modified coefficients using the inverse discrete wavelet transform

Note that different choices of wavelet, number of levels and thresholding strategy can produce noticeably different filtering results.
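The three steps can be sketched with pywt as follows. This is a minimal illustration under our own assumptions: the function name wavelet_denoise, the db2 wavelet, three levels and a single fixed threshold (instead of a per-level Birgé-Massart threshold) are all illustrative choices. Only the detail coefficients are thresholded and the approximation is left untouched. The test input is a synthetic flat gray image with additive Gaussian noise.

```python
import numpy as np
import pywt


def wavelet_denoise(img, wavelet='db2', level=3, threshold=0.3):
    # Step 1: N-level 2D DWT -> [cA_N, (cH_N, cV_N, cD_N), ..., (cH_1, cV_1, cD_1)]
    coeffs = pywt.wavedec2(img, wavelet, mode='periodization', level=level)
    # Step 2: soft-threshold the detail coefficients of every level;
    # the coarse approximation cA_N is kept as-is
    new_coeffs = [coeffs[0]]
    for details in coeffs[1:]:
        new_coeffs.append(tuple(pywt.threshold(d, threshold, mode='soft')
                                for d in details))
    # Step 3: inverse DWT from the modified coefficients
    return pywt.waverec2(new_coeffs, wavelet, mode='periodization')


rng = np.random.default_rng(0)
clean = np.full((64, 64), 0.5)                     # flat gray test image
noisy = clean + rng.normal(0.0, 0.1, clean.shape)  # Gaussian noise, sigma = 0.1
denoised = wavelet_denoise(noisy)
print("MSE noisy:   ", np.mean((noisy - clean) ** 2))
print("MSE denoised:", np.mean((denoised - clean) ** 2))
```

On real images, the threshold would be tied to an estimate of the noise level rather than fixed by hand.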

3. Wavelet image denoising with pywt

In this section, we add Gaussian noise to an input RGB image and then remove it by soft thresholding the wavelet coefficients.

(1) First, import the required libraries, read the input RGB image and add Gaussian noise with σ = 0.25 to obtain a noisy image:

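import numpy as np
import pywt
from skimage import img_as_float
import matplotlib.pyplot as plt
from skimage.io import imread

# read as float in [0, 1]; keep the RGB channels and drop an alpha channel if present
image = img_as_float(imread('3.png'))[..., :3]
noise_sigma = 0.25  # noise standard deviation on the [0, 1] scale
image += np.random.normal(0, noise_sigma, size=image.shape)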

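(2) pywt's Wavelet() function creates a wavelet object, and pywt.wavedec2() applies a multilevel 2D DWT with it. Here we use the Haar wavelet and derive the number of levels from the image height, which gives level=7 for this image: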

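wavelet = pywt.Wavelet('haar')
levels = int(np.floor(np.log2(image.shape[0])))
print(levels)  # 7 for the image used here
# transform the two spatial axes (rows, columns) of every color channel
wavelet_coeffs = pywt.wavedec2(image, wavelet, level=levels, axes=(0, 1))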
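The number of levels above is derived from the image height alone. A useful sanity check is pywt.dwt_max_level(), which returns the deepest level at which the decomposition filters still fit into the signal; it depends on the filter length (dec_len) of the wavelet. The sketch below assumes a hypothetical image side of 256 pixels. When the requested level exceeds this maximum, wavedec2() still runs but warns that all coefficients experience boundary effects, which is what happens with long filters such as coif5 or sym15.

```python
import pywt

side = 256  # hypothetical image height/width in pixels
for name in ['haar', 'db6', 'coif5', 'sym15']:
    w = pywt.Wavelet(name)
    # dec_len is the length of the decomposition filters
    print(name, w.dec_len, pywt.dwt_max_level(side, w.dec_len))
```

Short filters such as Haar allow many levels, while a 30-tap wavelet on the same image supports only a few.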

(3) Define a function denoise() that performs the denoising. It takes an image channel, a wavelet (object or name) and the (estimated) noise standard deviation as parameters. It computes the DWT coefficients, soft-thresholds all of them using the noise estimate as the threshold, reconstructs the image from the new coefficients with the inverse DWT and returns the result:

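def denoise(image, wavelet, noise_sigma):
    """Soft-threshold all DWT coefficients of a 2D image and reconstruct it."""
    # number of levels from the image height (pywt warns when this exceeds
    # the maximum useful level of a long filter such as coif5 or dmey)
    levels = int(np.floor(np.log2(image.shape[0])))
    wc = pywt.wavedec2(image, wavelet, level=levels)
    # flatten the coefficient list into one array and threshold it in one call
    arr, coeff_slices = pywt.coeffs_to_array(wc)
    arr = pywt.threshold(arr, noise_sigma, mode='soft')
    nwc = pywt.array_to_coeffs(arr, coeff_slices, output_format='wavedec2')
    # the reconstruction is one pixel larger along an odd-sized axis: crop it
    return pywt.waverec2(nwc, wavelet)[:image.shape[0], :image.shape[1]]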
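pywt.threshold() performs the actual shrinkage in denoise(). The toy array below (our own example) contrasts the two common modes: 'soft' sets coefficients with magnitude below the threshold t to zero and shrinks the others toward zero by t, i.e. sign(x)·max(|x| − t, 0), while 'hard' only zeroes the small coefficients and keeps the rest unchanged.

```python
import numpy as np
import pywt

coeffs = np.array([-2.0, -0.5, 0.0, 0.5, 2.0])
t = 1.0
soft = pywt.threshold(coeffs, t, mode='soft')  # shrink toward zero by t
hard = pywt.threshold(coeffs, t, mode='hard')  # keep or zero, no shrinkage
# soft thresholding written out explicitly
soft_manual = np.sign(coeffs) * np.maximum(np.abs(coeffs) - t, 0)
print(soft)
print(hard)
print(np.allclose(soft, soft_manual))
```

Soft thresholding avoids the jump at |x| = t that hard thresholding introduces, which is one reason it tends to give smoother reconstructions.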

(4) Try several types of discrete wavelets: for each one, use denoise() to remove the noise from each color channel of the noisy RGB input image, with a threshold of 1.5 × noise_sigma:

print(pywt.wavelist(kind='discrete'))  # all discrete wavelets available in pywt
wlts = ['bior1.5', 'coif5', 'db6', 'dmey', 'haar', 'rbio2.8', 'sym15']  # a subset of that list
Denoised = {}
for wlt in wlts:
    out = image.copy()
    for i in range(3):  # denoise each RGB channel separately
        out[..., i] = denoise(image[..., i], wavelet=wlt, noise_sigma=3/2*noise_sigma)
    Denoised[wlt] = np.clip(out, 0, 1)
print(len(Denoised))  # 7

(5) Finally, plot the output images obtained with the different wavelet types next to the noisy input and compare them visually:

plt.figure(figsize=(15,8))
plt.subplots_adjust(0,0,1,0.9,0.05,0.07)
plt.subplot(241), plt.imshow(np.clip(image,0,1)), plt.axis('off'), plt.title('noisy input', size=8)
i = 2
for wlt in Denoised:
    plt.subplot(2,4,i), plt.imshow(Denoised[wlt]), plt.axis('off'), plt.title(wlt, size=8)
    i += 1
plt.suptitle('Image Denoising with Wavelets', size=12)
plt.show()

(Figure: Image denoising with different wavelets)

4. Wavelet image denoising with scikit-image

4.1 Cycle spinning

The discrete wavelet transform is not translation invariant. Translation invariance can be achieved with the undecimated wavelet transform (also known as the stationary wavelet transform), but at the cost of redundancy: it produces more wavelet coefficients than there are pixels in the input image. To obtain approximate translation invariance while still denoising with the decimated discrete wavelet transform, we can instead use a technique called cycle spinning, which averages the results obtained from several spatially shifted copies of the image. For each shift n, we:

  • Circularly shift (cycle) the signal by n
  • Apply the denoising
  • Apply the inverse shift
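These steps (shift, denoise, undo the shift, then average over all shifts) can be sketched from scratch with NumPy and pywt. This is a minimal sketch under our own assumptions: the helper names dwt_soft_denoise and cycle_spin_denoise, the Haar wavelet, two levels, the fixed threshold and the flat synthetic test image are illustrative choices, and the input is assumed to be a 2D grayscale array (for RGB, the channel axis must be excluded from the shifts). scikit-image also provides a ready-made skimage.restoration.cycle_spin() function.

```python
import itertools

import numpy as np
import pywt


def dwt_soft_denoise(img, wavelet='haar', level=2, threshold=0.2):
    """Hypothetical helper: soft-threshold the detail coefficients, keep cA."""
    coeffs = pywt.wavedec2(img, wavelet, mode='periodization', level=level)
    coeffs = [coeffs[0]] + [
        tuple(pywt.threshold(d, threshold, mode='soft') for d in details)
        for details in coeffs[1:]
    ]
    return pywt.waverec2(coeffs, wavelet, mode='periodization')


def cycle_spin_denoise(img, denoise_fn, max_shift=1):
    """Average denoise_fn over all circular shifts (dy, dx) in [0, max_shift]^2."""
    shifts = list(itertools.product(range(max_shift + 1), repeat=2))
    acc = np.zeros(img.shape, dtype=float)
    for dy, dx in shifts:
        shifted = np.roll(img, shift=(dy, dx), axis=(0, 1))       # 1. shift
        restored = denoise_fn(shifted)                            # 2. denoise
        acc += np.roll(restored, shift=(-dy, -dx), axis=(0, 1))   # 3. undo shift
    return acc / len(shifts)                                      # 4. average


rng = np.random.default_rng(1)
clean = np.full((64, 64), 0.5)
noisy = clean + rng.normal(0.0, 0.1, clean.shape)
plain = dwt_soft_denoise(noisy)
spun = cycle_spin_denoise(noisy, dwt_soft_denoise, max_shift=1)  # n = 0, 1 per axis
print("MSE plain:", np.mean((plain - clean) ** 2))
print("MSE spun: ", np.mean((spun - clean) ** 2))
```

np.roll shifts circularly, which matches the "cycle" in cycle spinning; with max_shift=0 the function reduces to a single plain denoising pass.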

4.2 Improving the quality of image denoising

In this section we will see that, for 2D image denoising, cycle spinning can improve image quality substantially, and that most of the gain can be obtained by averaging only the shifts n = 0 and n = 1 along each axis:

(1) First, import all the required Python modules and functions:

from skimage.restoration import denoise_wavelet, estimate_sigma
from skimage import data, img_as_float
from skimage.util import random_noise
from skimage.metrics import peak_signal_noise_ratio
import numpy as np
import matplotlib.pyplot as plt
from skimage.io import imread

(2) Read the input image (keeping only the RGB channels) and add Gaussian noise with standard deviation sigma = 0.12 to it:

original = img_as_float(imread('3.png'))[...,:3]
sigma = 0.12
noisy = random_noise(original, var=sigma**2)

(3) Estimate the average noise standard deviation across color channels:

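# channel_axis=-1: the last axis holds the color channels
# (scikit-image >= 0.19; older versions use multichannel=True instead)
sigma_est = estimate_sigma(noisy, channel_axis=-1, average_sigmas=True)
print(f"Estimated Gaussian noise standard deviation = {sigma_est}")
# Estimated Gaussian noise standard deviation = 0.10885399509583953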

As the output shows, the estimated average noise standard deviation (about 0.109) is slightly smaller than the true value sigma = 0.12, partly because random_noise() clips the noisy image to [0, 1] by default.
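According to the scikit-image documentation, estimate_sigma() is based on the median absolute deviation (MAD) of wavelet detail coefficients. The idea can be sketched directly with pywt: the finest-scale diagonal coefficients are dominated by noise, and for Gaussian noise median(|d|)/0.6745 estimates its standard deviation (0.6745 is the 75th percentile of the standard normal distribution). The function name mad_sigma, the db2 wavelet and the pure-noise test input are our own assumptions; this is not scikit-image's implementation.

```python
import numpy as np
import pywt


def mad_sigma(img, wavelet='db2'):
    # finest-scale diagonal detail coefficients, dominated by noise
    _, (_, _, cD) = pywt.dwt2(img, wavelet, mode='periodization')
    # robust std estimate for Gaussian noise: median(|d|) / Phi^-1(0.75)
    return np.median(np.abs(cD)) / 0.6745


rng = np.random.default_rng(0)
noise = rng.normal(0.0, 0.1, size=(256, 256))  # pure noise, true sigma = 0.1
est = mad_sigma(noise)
print(est)
```

Because the median ignores a small number of large coefficients, edges in a real image disturb this estimate much less than they would disturb a plain standard deviation.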

(4) Denoise the image with the denoise_wavelet() function of the skimage.restoration module, using two different thresholding methods, BayesShrink and VisuShrink:

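# channel_axis=-1 marks the last axis as color channels
# (scikit-image >= 0.19; older versions use multichannel=True instead)
im_bayes = denoise_wavelet(noisy, channel_axis=-1, convert2ycbcr=True,
                           method='BayesShrink', mode='soft',
                           rescale_sigma=True)  # sigma is estimated internally
im_visushrink = denoise_wavelet(noisy, channel_axis=-1, convert2ycbcr=True,
                                method='VisuShrink', mode='soft',
                                sigma=sigma_est, rescale_sigma=True)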
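BayesShrink chooses a separate threshold for every detail subband. In its usual form the threshold is σ²/σ_x, where σ² is the noise variance and σ_x = sqrt(max(mean(d²) − σ², 0)) is the estimated standard deviation of the noise-free signal in the subband. A minimal sketch of that rule follows; the function name is hypothetical, and we assume, without claiming to reproduce scikit-image's code, that denoise_wavelet() applies an equivalent formula per subband.

```python
import numpy as np


def bayes_shrink_threshold(details, noise_sigma):
    """BayesShrink threshold for one zero-mean detail subband (sketch)."""
    noise_var = noise_sigma ** 2
    # for zero-mean coefficients: observed variance = signal var + noise var
    observed_var = np.mean(details ** 2)
    # clamp at machine epsilon so a pure-noise subband gets a huge threshold
    # (every coefficient in it is then set to zero)
    signal_sigma = np.sqrt(max(observed_var - noise_var, np.finfo(float).eps))
    return noise_var / signal_sigma


subband = np.full(100, np.sqrt(0.05))  # toy subband with mean square 0.05
print(bayes_shrink_threshold(subband, noise_sigma=0.1))  # 0.01 / sqrt(0.04), about 0.05
```

Subbands with strong signal get small thresholds and keep their detail, while subbands that look like pure noise are suppressed almost entirely.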

VisuShrink is designed to remove the noise with high probability, but this tends to produce images that look overly smooth. Repeat the denoising with smaller thresholds, obtained by passing sigma_est/2 and sigma_est/4 as sigma, and compare the results:

im_visushrink2 = denoise_wavelet(noisy, channel_axis=-1, convert2ycbcr=True,
                                 method='VisuShrink', mode='soft',
                                 sigma=sigma_est/2, rescale_sigma=True)
im_visushrink4 = denoise_wavelet(noisy, channel_axis=-1, convert2ycbcr=True,
                                 method='VisuShrink', mode='soft',
                                 sigma=sigma_est/4, rescale_sigma=True)
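VisuShrink uses a single universal threshold, σ·sqrt(2 ln N), where N is the number of coefficients (here, pixels). Because the threshold is proportional to σ, passing sigma_est/2 or sigma_est/4 halves or quarters it. The sketch below evaluates the formula for an assumed 256 × 256 image and σ = 0.1; it illustrates the rule and is not taken from scikit-image's source.

```python
import numpy as np


def universal_threshold(sigma, n_coeffs):
    # VisuShrink (universal) threshold: sigma * sqrt(2 * ln(N))
    return sigma * np.sqrt(2 * np.log(n_coeffs))


n = 256 * 256  # assumed number of pixels
sigma = 0.1    # assumed noise standard deviation
for factor in (1, 2, 4):
    t = universal_threshold(sigma / factor, n)
    print(f"sigma/{factor}: threshold = {t:.4f}")
```

At about 4.7σ for a 256 × 256 image, this threshold is large compared with the noise level itself, which is consistent with the over-smoothing described above.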

(5) Compute the PSNR of each image as a measure of image quality:

psnr_noisy = peak_signal_noise_ratio(original, noisy)
psnr_bayes = peak_signal_noise_ratio(original, im_bayes)
psnr_visushrink = peak_signal_noise_ratio(original, im_visushrink)
psnr_visushrink2 = peak_signal_noise_ratio(original, im_visushrink2)
psnr_visushrink4 = peak_signal_noise_ratio(original, im_visushrink4)
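PSNR follows the standard definition PSNR = 10·log10(R²/MSE), where R is the data range (1.0 for float images in [0, 1]) and MSE is the mean squared error against the reference. A NumPy-only sketch of that definition (psnr is our own helper name) makes the numbers easier to interpret: a uniform error of 0.1 on every pixel gives MSE = 0.01 and therefore 20 dB.

```python
import numpy as np


def psnr(reference, test, data_range=1.0):
    # PSNR in dB = 10 * log10(data_range^2 / MSE)
    ref = np.asarray(reference, dtype=float)
    out = np.asarray(test, dtype=float)
    mse = np.mean((ref - out) ** 2)
    return 10 * np.log10(data_range ** 2 / mse)


ref = np.zeros((8, 8))
print(psnr(ref, ref + 0.1))   # MSE = 0.01, about 20 dB
print(psnr(ref, ref + 0.01))  # MSE = 0.0001, about 40 dB
```

Every tenfold reduction of the MSE adds 10 dB, so differences of a few dB between the denoised images correspond to clearly different error levels.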

(6) Plot the input image, the denoised outputs and their PSNR values:

plt.figure(figsize=(20,20))
plt.subplots_adjust(0,0,1,1,0.05,0.05)
plt.subplot(231), plt.imshow(original), plt.axis('off'), plt.title('Original', size=10)
plt.subplot(232), plt.imshow(noisy), plt.axis('off'), plt.title('Noisy\nPSNR={:0.4g}'.format(psnr_noisy), size=10)
plt.subplot(233), plt.imshow(im_bayes/im_bayes.max()), plt.axis('off'), plt.title('Wavelet denoising\n(BayesShrink)\nPSNR={:0.4f}'.format(psnr_bayes), size=10)
plt.subplot(234), plt.imshow(im_visushrink/im_visushrink.max()), plt.axis('off')
plt.title('Wavelet denoising\n' + r'(VisuShrink, $\sigma=\sigma_{est}$)' + '\nPSNR={:0.4g}'.format(psnr_visushrink), size=10)
plt.subplot(235), plt.imshow(im_visushrink2/im_visushrink2.max()), plt.axis('off')
plt.title('Wavelet denoising\n' + r'(VisuShrink, $\sigma=\sigma_{est}/2$)' + '\nPSNR={:0.4g}'.format(psnr_visushrink2), size=10)
plt.subplot(236), plt.imshow(im_visushrink4/im_visushrink4.max()), plt.axis('off')
plt.title('Wavelet denoising\n' + r'(VisuShrink, $\sigma=\sigma_{est}/4$)' + '\nPSNR={:0.4g}'.format(psnr_visushrink4), size=10)
plt.show()

Running the code above produces the following result:

(Figure: Improving the denoising result)

Summary

The discrete wavelet transform is an analysis tool obtained by discretizing the scale and translation parameters of a basic (mother) wavelet. It can describe the frequency content of a signal locally in time (or space) and, conversely, its time-domain behavior locally in frequency. For images, the DWT converts the image into a set of wavelet coefficients that can be compressed and stored efficiently and from which the image can be reconstructed well. In this section, we have covered the basic principles of the wavelet transform and how to use the pywt and scikit-image libraries to denoise images with it.

Series links

Python image processing [1] Image and video processing basics
Python image processing [2] Exploring Python image processing library
Python image processing [3] Python image processing library application
Python image processing [4] Image linear transformation
Python image processing [5] Image distortion/unwarping
Python image processing [6] Find duplicate and similar images by hashing
Python image processing [7] Sampling, convolution and discrete Fourier transform
Python image processing [8] Blur images with low-pass filters
Python image processing [9] Python image processing using high-pass filter to perform edge detection
Python image processing [10] Discrete cosine transform based image compression
Python image processing [11] Deconvolution to perform image deblurring


Origin blog.csdn.net/qq_30167691/article/details/130150941