基于字典学习的图像去噪研究与实践

机器学习在图像处理中有非常多的应用,运用机器学习(包括现在非常流行的深度学习)技术,很多传统的图像处理问题都会取得相当不错的效果。今天我们就以机器学习中的字典学习(Dictionary Learning)为例,来展示其在图像去噪方面的应用。文中代码采用Python写成,其中使用了Scikit-learn包中提供的API,读者可以从【2】中获得演示用的完整代码(Jupyter notebook)。

一、什么是字典学习?

字典学习 (aka Sparse dictionary learning) is a branch of signal processing and machine learning. 特别地,我们也称其为a representation learning method. 字典学习 aims at finding a sparse representation of the input data (also known as sparse coding or 字典) in the form of a linear combination of basic elements as well as those basic elements themselves。These elements are called atoms and they compose a dictionary。字典中,some training data admits a sparse representation。The sparser the representation, the better the dictionary。如下图所示,现在我们有一组自然图像,我们希望学到一个字典,从而原图中的每一个小块都可以表示成字典中少数几个atoms之线性组合的形式。

二、字典学习应用于图像去噪的原理

首先来看看图像去噪问题的基本模型,如下图所示,对于一幅噪声图像y,它应该等于原图像x 假设噪声w,而在这个关系中,只有y是已知的,去噪的过程就是在此情况下推测x的过程。

在此基础上,我们通常把去噪问题看成是一个带约束条件的能量最小化问题,即对下面这个公式进行最下好从而求出未知项x。最小化公式中的第一项表示x和y要尽量接近,否则本来一幅猫噪声的图像,降噪之后变成了狗,这样的结果显然不是我们所期望的。第二项则表示对x的一个约束条件,否则如果没有这一项,那么x就变成了噪声图像y,那去噪也就失去了意义。这又变成了一个最大后验估计(Maximum-a-Posteriori (MAP) Estimation) 问题,关于MAP的概念你可以参考【3】。

关于第二项到底应该符合什么标准,不同学者都提出了各自的观点。其中一派就认为,x应该满足“稀疏”这个条件,因为这也是自然图像中所普遍存在的一个现实。

如果所x是一个m维的信号,D=[d1,..., dn]是一组标准基向量(basis vectors),其大小为m×p,也就是我们所说的字典。D is "adapted" to y if it can represent it with a few basis vectors, that is, there exists a sparse vector a,这里a是一个p维向量,such that y≈Da。这里a 就是所谓的稀疏编码。

为什稀疏对于去噪是好的?显然,a dictionary can be good for representing a class of signals, but not for representing white Gaussian noise。直觉上,也即是说,用字典来近似表达y的时候,就会略掉噪声。

于是乎,Elad and Aharon在2006年便提出了用字典学习来进行去噪的技术方案。首先,我们从y中提取所有的overlapping的矩形窗口(例如8×8),然后求解如下矩阵分解(matrix factorization)问题

基于这一基本原理,字典学习还可以应用于图像的inpaint(也就是图像修复),例如下面是修复前和修复后的图像对比

再来看看局部效果,可见毫无违和感

三、Scikit-learn实现的基于字典学习的图像去噪实例

首先引入所需的各种包,并读入一张Lena图像。

 
  1. print(__doc__)

  2.  
  3. from PIL import Image

  4.  
  5. import matplotlib.pyplot as plt

  6. import numpy as np

  7. import scipy as sp

  8.  
  9. from sklearn.decomposition import MiniBatchDictionaryLearning

  10. from sklearn.feature_extraction.image import extract_patches_2d

  11. from sklearn.feature_extraction.image import reconstruct_from_patches_2d

  12.  
  13. %matplotlib inline

  14.  
  15. from keras.preprocessing.image import load_img

  16. # load an image from file

  17. image = load_img('lena_gray_256.tif')

  18.  
  19. from keras.preprocessing.image import img_to_array

  20. # convert the image pixels to a numpy array

  21. image = img_to_array(image)

  22.  
  23. image = image[:,:,0]

  24.  
  25. print("original shape", image.shape)

接下来将原本取值为0~255之间整数值的像素灰度,转换到0~1之间的浮点数。你也可以用imshow将图像显示出来看看效果。

 
  1. image = image.astype('float32')

  2. image/=255

  3.  
  4. plt.imshow(image, cmap='gray')

然后是向其中加入高斯噪声,我们把含噪声的图像放在本文最后以便直观地对比去噪效果。

 
  1. noise = np.random.normal(loc=0, scale=0.05, size=image.shape)

  2. x_test_noisy1 = image + noise

  3. x_test_noisy1 = np.clip(x_test_noisy1, 0., 1.)

  4.  
  5. plt.imshow(x_test_noisy1, cmap='Greys_r')

接下来需要使用Scikit-learn中的Extract_pathches_2d来从图像中提取patches。通常来说,你可以从一组或者一个图像数据集中提取patches,并由此来训练字典。你也可以从一张图像中提取patches,注意它们都是overlapping的。你还可以将二者结合起来。这里我们的处理方式是从原图(为加噪声)中提取patches。你可能会疑惑,因为原图在实际应用中是未知的。后面我们也会演示从噪声图像中提取patches并去噪的效果。

 
  1. # Extract all reference patches from the original image

  2. print('Extracting reference patches...')

  3. patch_size = (5, 5)

  4. data = extract_patches_2d(image, patch_size)

  5. print(data.shape)

这里data的shape是一个三维向量[x, y, z],x表示patches的数量,y和z是每个小patch的宽和高。但在进行运算时,我们需要把每个小矩形“拉直”,于是有

 
  1. data = data.reshape(data.shape[0], -1)

  2. print(data.shape)

  3.  
  4. data -= np.mean(data, axis=0)

  5. data /= np.std(data, axis=0)

接下来使用MiniBatchDictionaryLearning函数和fit函数,就可以学到一个字典V。

 
  1. # #############################################################################

  2. # Learn the dictionary from reference patches

  3. print('Learning the dictionary...')

  4. dico = MiniBatchDictionaryLearning(n_components=144, alpha=1, n_iter=500)

  5. V = dico.fit(data).components_

你也可以查看一下字典V的形状,或者把它显示出来看看具体内容。这里我们设计的字典中一共有144个atoms,所以分12排来打印,每排12个atoms。

 
  1. print(V.shape)

  2.  
  3. plt.figure(figsize=(4.2, 4))

  4. for i, comp in enumerate(V[:144]):

  5. plt.subplot(12, 12, i + 1)

  6. plt.imshow(comp.reshape(patch_size), cmap=plt.cm.gray_r,

  7. interpolation='nearest')

  8. plt.xticks(())

  9. plt.yticks(())

  10. plt.suptitle('Dictionary learned from patches\n', fontsize=16)

  11. plt.subplots_adjust(0.08, 0.02, 0.92, 0.85, 0.08, 0.23)

字典如下图所示。

接下来要做的是从噪声图像中提取同样大小的patches,然后从字典中找到一组最逼近这个小patch的一组原子的线性组合,并用这个组合来重构噪声图像中的小patch。

 
  1. # #############################################################################

  2. # Extract noisy patches and reconstruct them using the dictionary

  3.  
  4. print('Extracting noisy patches... ')

  5. data = extract_patches_2d(x_test_noisy1, patch_size)

  6. data = data.reshape(data.shape[0], -1)

  7. intercept = np.mean(data, axis=0)

  8. data -= intercept

这里所使用的算法是OMP。函数recontruct_from_patches_2d用于重构图像。方法transform的作用是Encode the data as a sparse combination of the dictionary atoms

 
  1. print('Orthogonal Matching Pursuit\n2 atoms' + '...')

  2. reconstructions = x_test_noisy1.copy()

  3.  
  4. dico.set_params(transform_algorithm='omp', **{'transform_n_nonzero_coefs': 2})

  5. code = dico.transform(data)

  6. patches = np.dot(code, V)

  7.  
  8. patches += intercept

  9. patches = patches.reshape(len(data), *patch_size)

  10.  
  11. reconstructions = reconstruct_from_patches_2d(patches, (256, 256))

  12.  
  13. plt.imshow(reconstructions, cmap='Greys_r')

去噪后的图像由文末对比结果中的左下角图像给出。

四、从噪声图像中学习字典

现实中,噪声图像的原图像是无法获取的,否则就不需要去噪了。但我们可以直接从噪声图中学习字典。于是修改之前的代码,我们从直接从噪声图像中获取用于字典学习的patches。

 
  1. # Extract all reference patches from noisy image

  2. print('Extracting reference patches...')

  3. patch_size = (5, 5)

  4. data = extract_patches_2d(x_test_noisy1, patch_size)

  5. print(data.shape)

同样进行一些预处理和标准化之类的工作。

 
  1. data = data.reshape(data.shape[0], -1)

  2. print(data.shape)

  3.  
  4. data -= np.mean(data, axis=0)

  5. data /= np.std(data, axis=0)

解下来进行字典学习,并将学到的字典显示出来。

 
  1. # #############################################################################

  2. # Learn the dictionary from reference patches

  3. print('Learning the dictionary...')

  4. dico = MiniBatchDictionaryLearning(n_components=144, alpha=1, n_iter=500)

  5. V = dico.fit(data).components_

  6.  
  7. print(V.shape)

  8.  
  9. plt.figure(figsize=(4.2, 4))

  10. for i, comp in enumerate(V[:144]):

  11. plt.subplot(12, 12, i + 1)

  12. plt.imshow(comp.reshape(patch_size), cmap=plt.cm.gray_r,

  13. interpolation='nearest')

  14. plt.xticks(())

  15. plt.yticks(())

  16. plt.suptitle('Dictionary learned from patches\n', fontsize=16)

  17. plt.subplots_adjust(0.08, 0.02, 0.92, 0.85, 0.08, 0.23)

下图便是我们从噪声图像中学习到的字典。

再次使用之前的代码。其中,set_params的作用是Set the parameters of this estimator。其中,transform_n_nonzero_coefs 是 Number of nonzero coefficients to target in each column of the solution.  也就是稀疏解的稀疏程度。默认情况下,它等于 0.1 * n_features(features的数量)。该参数 is only used by algorithm=’lars’ and algorithm=’omp’ and is overridden by alpha in the omp case。

 
  1. # #############################################################################

  2. # Extract noisy patches and reconstruct them using the dictionary

  3.  
  4. print('Extracting noisy patches... ')

  5. data = extract_patches_2d(x_test_noisy1, patch_size)

  6. data = data.reshape(data.shape[0], -1)

  7. intercept = np.mean(data, axis=0)

  8. data -= intercept

  9.  
  10. print('Orthogonal Matching Pursuit\n2 atoms' + '...')

  11. reconstructions_frm_noise = x_test_noisy1.copy()

  12.  
  13. dico.set_params(transform_algorithm='omp', **{'transform_n_nonzero_coefs': 2})

  14. code = dico.transform(data)

  15. patches = np.dot(code, V)

  16.  
  17. patches += intercept

  18. patches = patches.reshape(len(data), *patch_size)

  19.  
  20. reconstructions_frm_noise = reconstruct_from_patches_2d(patches, (256, 256))

  21.  
  22. plt.imshow(reconstructions_frm_noise, cmap='Greys_r')

重构的去噪效果如下图中的右下角图像所示。

最后,顺便补充一句,你可以使用下面的代码来保存已经得到的图像结果。

 
  1. imgs = (reconstructions_frm_noise * 255).astype(np.uint8)

  2. Image.fromarray(imgs).save('lena_denoise_from_noise.png')

读者还可以参考Scikit-learn官方文档中给出的例子【1】,以了解其他参数或算法的使用。

猜你喜欢

转载自blog.csdn.net/qq_30263737/article/details/81087988