机器学习在图像处理中有非常多的应用,运用机器学习(包括现在非常流行的深度学习)技术,很多传统的图像处理问题都会取得相当不错的效果。今天我们就以机器学习中的字典学习(Dictionary Learning)为例,来展示其在图像去噪方面的应用。文中代码采用Python写成,其中使用了Scikit-learn包中提供的API,读者可以从【2】中获得演示用的完整代码(Jupyter notebook)。
一、什么是字典学习?
字典学习 (aka Sparse dictionary learning) is a branch of signal processing and machine learning. 特别地,我们也称其为a representation learning method. 字典学习 aims at finding a sparse representation of the input data (also known as sparse coding or 字典) in the form of a linear combination of basic elements as well as those basic elements themselves。These elements are called atoms and they compose a dictionary。字典中,some training data admits a sparse representation。The sparser the representation, the better the dictionary。如下图所示,现在我们有一组自然图像,我们希望学到一个字典,从而原图中的每一个小块都可以表示成字典中少数几个atoms之线性组合的形式。
二、字典学习应用于图像去噪的原理
首先来看看图像去噪问题的基本模型,如下图所示,对于一幅噪声图像y,它应该等于原图像x 假设噪声w,而在这个关系中,只有y是已知的,去噪的过程就是在此情况下推测x的过程。
在此基础上,我们通常把去噪问题看成是一个带约束条件的能量最小化问题,即对下面这个公式进行最下好从而求出未知项x。最小化公式中的第一项表示x和y要尽量接近,否则本来一幅猫噪声的图像,降噪之后变成了狗,这样的结果显然不是我们所期望的。第二项则表示对x的一个约束条件,否则如果没有这一项,那么x就变成了噪声图像y,那去噪也就失去了意义。这又变成了一个最大后验估计(Maximum-a-Posteriori (MAP) Estimation) 问题,关于MAP的概念你可以参考【3】。
关于第二项到底应该符合什么标准,不同学者都提出了各自的观点。其中一派就认为,x应该满足“稀疏”这个条件,因为这也是自然图像中所普遍存在的一个现实。
如果所x是一个m维的信号,D=[d1,..., dn]是一组标准基向量(basis vectors),其大小为m×p,也就是我们所说的字典。D is "adapted" to y if it can represent it with a few basis vectors, that is, there exists a sparse vector a,这里a是一个p维向量,such that y≈Da。这里a 就是所谓的稀疏编码。
为什稀疏对于去噪是好的?显然,a dictionary can be good for representing a class of signals, but not for representing white Gaussian noise。直觉上,也即是说,用字典来近似表达y的时候,就会略掉噪声。
于是乎,Elad and Aharon在2006年便提出了用字典学习来进行去噪的技术方案。首先,我们从y中提取所有的overlapping的矩形窗口(例如8×8),然后求解如下矩阵分解(matrix factorization)问题
基于这一基本原理,字典学习还可以应用于图像的inpaint(也就是图像修复),例如下面是修复前和修复后的图像对比
再来看看局部效果,可见毫无违和感
三、Scikit-learn实现的基于字典学习的图像去噪实例
首先引入所需的各种包,并读入一张Lena图像。
-
print(__doc__)
-
from PIL import Image
-
import matplotlib.pyplot as plt
-
import numpy as np
-
import scipy as sp
-
from sklearn.decomposition import MiniBatchDictionaryLearning
-
from sklearn.feature_extraction.image import extract_patches_2d
-
from sklearn.feature_extraction.image import reconstruct_from_patches_2d
-
%matplotlib inline
-
from keras.preprocessing.image import load_img
-
# load an image from file
-
image = load_img('lena_gray_256.tif')
-
from keras.preprocessing.image import img_to_array
-
# convert the image pixels to a numpy array
-
image = img_to_array(image)
-
image = image[:,:,0]
-
print("original shape", image.shape)
接下来将原本取值为0~255之间整数值的像素灰度,转换到0~1之间的浮点数。你也可以用imshow将图像显示出来看看效果。
-
image = image.astype('float32')
-
image/=255
-
plt.imshow(image, cmap='gray')
然后是向其中加入高斯噪声,我们把含噪声的图像放在本文最后以便直观地对比去噪效果。
-
noise = np.random.normal(loc=0, scale=0.05, size=image.shape)
-
x_test_noisy1 = image + noise
-
x_test_noisy1 = np.clip(x_test_noisy1, 0., 1.)
-
plt.imshow(x_test_noisy1, cmap='Greys_r')
接下来需要使用Scikit-learn中的Extract_pathches_2d来从图像中提取patches。通常来说,你可以从一组或者一个图像数据集中提取patches,并由此来训练字典。你也可以从一张图像中提取patches,注意它们都是overlapping的。你还可以将二者结合起来。这里我们的处理方式是从原图(为加噪声)中提取patches。你可能会疑惑,因为原图在实际应用中是未知的。后面我们也会演示从噪声图像中提取patches并去噪的效果。
-
# Extract all reference patches from the original image
-
print('Extracting reference patches...')
-
patch_size = (5, 5)
-
data = extract_patches_2d(image, patch_size)
-
print(data.shape)
这里data的shape是一个三维向量[x, y, z],x表示patches的数量,y和z是每个小patch的宽和高。但在进行运算时,我们需要把每个小矩形“拉直”,于是有
-
data = data.reshape(data.shape[0], -1)
-
print(data.shape)
-
data -= np.mean(data, axis=0)
-
data /= np.std(data, axis=0)
接下来使用MiniBatchDictionaryLearning函数和fit函数,就可以学到一个字典V。
-
# #############################################################################
-
# Learn the dictionary from reference patches
-
print('Learning the dictionary...')
-
dico = MiniBatchDictionaryLearning(n_components=144, alpha=1, n_iter=500)
-
V = dico.fit(data).components_
你也可以查看一下字典V的形状,或者把它显示出来看看具体内容。这里我们设计的字典中一共有144个atoms,所以分12排来打印,每排12个atoms。
-
print(V.shape)
-
plt.figure(figsize=(4.2, 4))
-
for i, comp in enumerate(V[:144]):
-
plt.subplot(12, 12, i + 1)
-
plt.imshow(comp.reshape(patch_size), cmap=plt.cm.gray_r,
-
interpolation='nearest')
-
plt.xticks(())
-
plt.yticks(())
-
plt.suptitle('Dictionary learned from patches\n', fontsize=16)
-
plt.subplots_adjust(0.08, 0.02, 0.92, 0.85, 0.08, 0.23)
字典如下图所示。
接下来要做的是从噪声图像中提取同样大小的patches,然后从字典中找到一组最逼近这个小patch的一组原子的线性组合,并用这个组合来重构噪声图像中的小patch。
-
# #############################################################################
-
# Extract noisy patches and reconstruct them using the dictionary
-
print('Extracting noisy patches... ')
-
data = extract_patches_2d(x_test_noisy1, patch_size)
-
data = data.reshape(data.shape[0], -1)
-
intercept = np.mean(data, axis=0)
-
data -= intercept
这里所使用的算法是OMP。函数recontruct_from_patches_2d用于重构图像。方法transform的作用是Encode the data as a sparse combination of the dictionary atoms
-
print('Orthogonal Matching Pursuit\n2 atoms' + '...')
-
reconstructions = x_test_noisy1.copy()
-
dico.set_params(transform_algorithm='omp', **{'transform_n_nonzero_coefs': 2})
-
code = dico.transform(data)
-
patches = np.dot(code, V)
-
patches += intercept
-
patches = patches.reshape(len(data), *patch_size)
-
reconstructions = reconstruct_from_patches_2d(patches, (256, 256))
-
plt.imshow(reconstructions, cmap='Greys_r')
去噪后的图像由文末对比结果中的左下角图像给出。
四、从噪声图像中学习字典
现实中,噪声图像的原图像是无法获取的,否则就不需要去噪了。但我们可以直接从噪声图中学习字典。于是修改之前的代码,我们从直接从噪声图像中获取用于字典学习的patches。
-
# Extract all reference patches from noisy image
-
print('Extracting reference patches...')
-
patch_size = (5, 5)
-
data = extract_patches_2d(x_test_noisy1, patch_size)
-
print(data.shape)
同样进行一些预处理和标准化之类的工作。
-
data = data.reshape(data.shape[0], -1)
-
print(data.shape)
-
data -= np.mean(data, axis=0)
-
data /= np.std(data, axis=0)
解下来进行字典学习,并将学到的字典显示出来。
-
# #############################################################################
-
# Learn the dictionary from reference patches
-
print('Learning the dictionary...')
-
dico = MiniBatchDictionaryLearning(n_components=144, alpha=1, n_iter=500)
-
V = dico.fit(data).components_
-
print(V.shape)
-
plt.figure(figsize=(4.2, 4))
-
for i, comp in enumerate(V[:144]):
-
plt.subplot(12, 12, i + 1)
-
plt.imshow(comp.reshape(patch_size), cmap=plt.cm.gray_r,
-
interpolation='nearest')
-
plt.xticks(())
-
plt.yticks(())
-
plt.suptitle('Dictionary learned from patches\n', fontsize=16)
-
plt.subplots_adjust(0.08, 0.02, 0.92, 0.85, 0.08, 0.23)
下图便是我们从噪声图像中学习到的字典。
再次使用之前的代码。其中,set_params的作用是Set the parameters of this estimator。其中,transform_n_nonzero_coefs 是 Number of nonzero coefficients to target in each column of the solution. 也就是稀疏解的稀疏程度。默认情况下,它等于 0.1 * n_features(features的数量)。该参数 is only used by algorithm=’lars’ and algorithm=’omp’ and is overridden by alpha in the omp case。
-
# #############################################################################
-
# Extract noisy patches and reconstruct them using the dictionary
-
print('Extracting noisy patches... ')
-
data = extract_patches_2d(x_test_noisy1, patch_size)
-
data = data.reshape(data.shape[0], -1)
-
intercept = np.mean(data, axis=0)
-
data -= intercept
-
print('Orthogonal Matching Pursuit\n2 atoms' + '...')
-
reconstructions_frm_noise = x_test_noisy1.copy()
-
dico.set_params(transform_algorithm='omp', **{'transform_n_nonzero_coefs': 2})
-
code = dico.transform(data)
-
patches = np.dot(code, V)
-
patches += intercept
-
patches = patches.reshape(len(data), *patch_size)
-
reconstructions_frm_noise = reconstruct_from_patches_2d(patches, (256, 256))
-
plt.imshow(reconstructions_frm_noise, cmap='Greys_r')
重构的去噪效果如下图中的右下角图像所示。
最后,顺便补充一句,你可以使用下面的代码来保存已经得到的图像结果。
-
imgs = (reconstructions_frm_noise * 255).astype(np.uint8)
-
Image.fromarray(imgs).save('lena_denoise_from_noise.png')
读者还可以参考Scikit-learn官方文档中给出的例子【1】,以了解其他参数或算法的使用。