图像超分之——寻找两张图差异的区域

本代码用于超分辨率或图像复原任务：对比两种方法的输出，找出它们相对于GT的PSNR差距较大的区域。

import os
import math
import numpy as np
import cv2
import glob
from skimage import transform
from skimage import measure
from collections import OrderedDict

import matplotlib.pyplot as plt
import matplotlib.patches as patches


def bgr2ycbcr(img, only_y=True):
    '''Convert a BGR image to YCbCr with the same coefficients as MATLAB's rgb2ycbcr.

    The coefficient rows are ordered (B, G, R), so the last axis of ``img`` must hold
    the channels in BGR order.

    Args:
        img: array whose last axis has 3 channels (B, G, R); either uint8 in [0, 255]
            or float in [0, 1]. The input array is never modified.
        only_y: if True return only the Y channel, otherwise Y, Cb and Cr.

    Returns:
        Array with the same dtype as ``img``. uint8 results are rounded; float results
        are rescaled back to [0, 1].
    '''
    in_img_type = img.dtype
    # Convert to a float *copy*: the scaling below is in place and must not touch the
    # caller's array (previously the astype result was discarded and img itself was scaled).
    img = img.astype(np.float64)
    if in_img_type != np.uint8:
        img *= 255.
    # convert
    if only_y:
        rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
                              [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)

def calculate_psnr(img1, img2):
    """Peak signal-to-noise ratio (dB) of two same-shaped images with peak value 255.

    Both inputs are promoted to float64 before subtracting, so integer data cannot wrap.
    Identical images give ``float('inf')``.
    """
    residual = img1.astype(np.float64) - img2.astype(np.float64)
    mean_sq = np.mean(np.square(residual))
    if not mean_sq:
        return float('inf')
    rmse = math.sqrt(mean_sq)
    return 20 * math.log10(255.0 / rmse)

def mse2psnr(mse):
    """Convert an MSE measured on [0, 1]-scaled data into PSNR in dB (peak value 1).

    A zero MSE maps to ``float('inf')``.
    """
    return float('inf') if mse == 0 else 20 * math.log10(1.0 / math.sqrt(mse))

def plot_heatmap(image, heat_map, alpha=0.5, display=False, save=None, cmap='viridis', axis='on',
                 dpi=80, verbose=False):
    """Overlay a heat map on an image, optionally showing it and/or saving it.

    ``heat_map`` is resized to the image's height and width with ``skimage.transform.resize``
    and min-max normalized to [0, 1] before being drawn on top of ``image``.

    Args:
        image: array accepted by ``imshow``; its first two axes give the output height and width.
        heat_map: 2-D array of scores.
        alpha: opacity of the heat-map layer.
        display: show the overlay with ``plt.show()``.
        save: if not None, path the overlay is saved to, sized to W x H pixels at ``dpi``.
        cmap: colormap used for the heat map.
        axis: argument to ``plt.axis`` for the displayed figure.
        dpi: resolution used to convert the pixel size into a figure size in inches.
        verbose: print the save path before saving.
    """
    height, width = image.shape[:2]

    heat_map_resized = transform.resize(heat_map, (height, width))

    # Min-max normalize to [0, 1]. A constant map would give 0/0 (NaN everywhere),
    # so it is mapped to all zeros instead.
    min_value = np.min(heat_map_resized)
    value_range = np.max(heat_map_resized) - min_value
    if value_range > 0:
        normalized_heat_map = (heat_map_resized - min_value) / value_range
    else:
        normalized_heat_map = np.zeros_like(heat_map_resized)

    if display:
        plt.imshow(image)
        plt.imshow(255 * normalized_heat_map, alpha=alpha, cmap=cmap)
        plt.axis(axis)
        plt.show()

    if save is not None:
        if verbose:
            print('save image: ' + save)

        fig = plt.figure(figsize=(width / float(dpi), height / float(dpi)))
        try:
            # A single axes filling the whole figure, so image pixels map 1:1 to output pixels.
            ax = fig.add_axes([0, 0, 1, 1])
            ax.axis('off')
            ax.imshow(image)
            ax.imshow(255 * normalized_heat_map, alpha=alpha, cmap=cmap)
            ax.set(xlim=[0, width], ylim=[height, 0], aspect=1)
            fig.savefig(save, dpi=dpi, transparent=True)
        finally:
            # Without this every call leaks a figure, which adds up in the per-image loop.
            plt.close(fig)
        
def to_bin(img, lower, upper):
    """Return a boolean mask that is True where ``lower < img < upper`` (both bounds exclusive)."""
    above = img > lower
    below = img < upper
    return above & below


def plot_diffmap(im_BSL, im_OCT, im_GT, heatmap, thres=0.4, alpha=0.5, display=False,
                 save=None, cmap='viridis', axis='on', dpi=80, verbose=False):
    """Outline the high-heat regions of ``heatmap`` on ``im_BSL`` and label their PSNR gain.

    ``heatmap`` is resized to the image size and min-max normalized to [0, 1]. Pixels whose
    normalized value exceeds ``thres`` are grouped into 8-connected regions
    (``connectivity=2``). Every region whose bounding box covers at least 100 pixels gets a
    yellow rectangle and a red label with ``PSNR(im_OCT, im_GT) - PSNR(im_BSL, im_GT)``
    computed inside the box.

    Args:
        im_BSL, im_OCT, im_GT: same-shape images. They are multiplied by 255 before the PSNR
            computation, so presumably they are scaled to [0, 1] (TODO confirm for new callers).
        heatmap: 2-D score map, resized to the image size.
        thres: exclusive lower bound on the normalized heat for a pixel to join a region.
        alpha: opacity of the heat-map overlay.
        display: unused; kept for signature compatibility with the other plotting helpers.
        save: output path. Nothing is drawn or written when it is None.
        cmap: colormap of the heat-map overlay.
        axis: value passed to ``ax.axis``.
        dpi: resolution used to size the figure to exactly W x H pixels.
        verbose: print the save path before saving.
    """
    if save is None:
        # The function only produces a file; without a path there is nothing to do.
        return

    H, W = im_BSL.shape[:2]

    # Resize to image resolution and min-max normalize to [0, 1].
    heatmap_resized = transform.resize(heatmap, (H, W))
    min_value = np.min(heatmap_resized)
    value_range = np.max(heatmap_resized) - min_value
    if value_range > 0:
        normalized_heatmap = (heatmap_resized - min_value) / value_range
    else:
        # A constant map has no hot regions; avoid 0/0 producing NaNs.
        normalized_heatmap = np.zeros_like(heatmap_resized)

    # Threshold and label 8-connected regions. Only the lower bound applies, so the peak
    # pixels (normalized value exactly 1.0) stay inside their region.
    bin_map = normalized_heatmap > thres
    label_map = measure.label(bin_map, connectivity=2)
    props = measure.regionprops(label_map)

    if verbose:
        print('save image: ' + save)

    fig = plt.figure(figsize=(W / float(dpi), H / float(dpi)))
    try:
        ax = fig.add_axes([0, 0, 1, 1])
        ax.axis('off')
        ax.imshow(im_BSL)
        ax.imshow(normalized_heatmap, alpha=alpha, cmap=cmap)
        ax.axis(axis)

        for prop in props:
            min_row, min_col, max_row, max_col = prop.bbox
            # Bounding-box pixel count (the value regionprops calls bbox_area / area_bbox).
            if (max_row - min_row) * (max_col - min_col) < 100:
                continue
            ax.add_patch(
                patches.Rectangle(
                    (min_col, min_row),
                    max_col - min_col,
                    max_row - min_row,
                    edgecolor='y',
                    linewidth=6,
                    fill=False
                ))
            region = (slice(min_row, max_row), slice(min_col, max_col))
            psnr = (calculate_psnr(im_OCT[region] * 255., im_GT[region] * 255.)
                    - calculate_psnr(im_BSL[region] * 255., im_GT[region] * 255.))

            # Right-align labels of boxes that start within 50 px of the right edge.
            h_aln = 'right' if W - min_col < 50 else 'left'

            if min_row < 20:
                # Too close to the top edge: put the label under the box instead of above it.
                ax.text(min_col, max_row, "{:+.2f}".format(psnr), color='r',
                        verticalalignment='top', horizontalalignment=h_aln, fontsize=26)
            else:
                ax.text(min_col, min_row, "{:+.2f}".format(psnr), color='r',
                        verticalalignment='bottom', horizontalalignment=h_aln, fontsize=26)

        ax.set(xlim=[0, W], ylim=[H, 0], aspect=1)
        fig.savefig(save, dpi=dpi, transparent=True)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

def plot_diff_patch(im_BSL, im_OCT, im_GT, heatmap, thres=0.4, alpha=0.5, display=False,
                 save=None, cmap='viridis', axis='on', dpi=80, verbose=False,
                 diff_dir='/data1/cropimage/diff6_cvpr'):
    """Save a grid of crops for the (up to 5) hottest regions of ``heatmap``.

    ``heatmap`` is resized to the image size and min-max normalized to [0, 1]. Pixels above
    ``thres`` are grouped into 8-connected regions (``connectivity=2``), and regions whose
    bounding box covers fewer than 100 pixels are dropped. The remaining regions are ranked by
    the mean normalized heat inside their bounding box. For each of the top 5, one row shows the
    crops of ``im_GT``, ``im_BSL``, ``im_OCT`` and the difference image. The difference crop is
    annotated with ``PSNR(im_OCT, im_GT) - PSNR(im_BSL, im_GT)`` and the bounding box
    ``(min_row, min_col, max_row, max_col)``.

    The full difference image ``clip(im_OCT - im_BSL + 0.5, 0, 1)`` is also written as an 8-bit
    image into ``diff_dir``, under the same file name as ``save``.

    Args:
        im_BSL, im_OCT, im_GT: same-shape (H, W, 3) images, presumably RGB in [0, 1]. They are
            multiplied by 255 for PSNR, and channels are reversed before ``cv2.imwrite``.
        heatmap: 2-D score map, resized to (H, W).
        thres: exclusive lower bound on the normalized heat for region membership.
        alpha, display, cmap, axis, dpi: unused; kept for signature compatibility.
        save: output path of the crop grid. Nothing is computed or written when it is None.
        verbose: print the save path before saving the grid.
        diff_dir: directory for the full difference image (created if missing); None skips it.
    """
    if save is None:
        # Every output file name derives from ``save``.
        return

    H, W = im_BSL.shape[:2]

    # Resize to image resolution and min-max normalize to [0, 1].
    heatmap_resized = transform.resize(heatmap, (H, W))
    min_value = np.min(heatmap_resized)
    value_range = np.max(heatmap_resized) - min_value
    if value_range > 0:
        normalized_heatmap = (heatmap_resized - min_value) / value_range
    else:
        # A constant map has no hot regions; avoid 0/0 producing NaNs.
        normalized_heatmap = np.zeros_like(heatmap_resized)

    # Label 8-connected regions above the threshold and score the large-enough ones.
    label_map = measure.label(normalized_heatmap > thres, connectivity=2)
    candidates = []
    for prop in measure.regionprops(label_map):
        min_row, min_col, max_row, max_col = prop.bbox
        # Bounding-box pixel count (the value regionprops calls bbox_area / area_bbox).
        if (max_row - min_row) * (max_col - min_col) < 100:
            continue
        region = (slice(min_row, max_row), slice(min_col, max_col))
        err = np.mean(normalized_heatmap[region])
        psnr = (calculate_psnr(im_OCT[region] * 255., im_GT[region] * 255.)
                - calculate_psnr(im_BSL[region] * 255., im_GT[region] * 255.))
        candidates.append((prop.bbox, region, err, psnr))
    # Hottest first; the sort is stable, so ties keep label order.
    candidates.sort(key=lambda c: c[2], reverse=True)

    # Signed difference shifted so that "no difference" is mid-gray (0.5).
    im_diff = np.clip(im_OCT - im_BSL + 0.5, 0.0, 1.0)
    if diff_dir is not None:
        os.makedirs(diff_dir, exist_ok=True)
        diff_u8 = np.round(im_diff * 255).astype(np.uint8)
        # Reverse the channel order for OpenCV, which writes BGR.
        cv2.imwrite(os.path.join(diff_dir, os.path.basename(save)), diff_u8[:, :, [2, 1, 0]])

    top = candidates[:5]
    if not top:
        # plt.subplots(nrows=0) raises; there is simply nothing to show.
        return

    if verbose:
        print('save image: ' + save)

    # squeeze=False keeps ``axes`` 2-D even when there is a single row.
    fig, axes = plt.subplots(nrows=len(top), ncols=4, figsize=(15, 15), squeeze=False)
    try:
        for row, (bbox, region, _, psnr) in zip(axes, top):
            min_row, min_col, max_row, max_col = bbox
            # Columns: ground truth, baseline, compared method, difference image.
            for ax, im in zip(row, (im_GT, im_BSL, im_OCT, im_diff)):
                ax.imshow(im[region])
            # Text positions are in the crop's data coordinates (its right edge).
            row[3].text(max_col - min_col, max_row - min_row,
                        "{:+.2f}".format(psnr), color='r', fontsize=16)
            row[3].text(max_col - min_col, 0,
                        "{}".format(bbox), color='r', fontsize=16)

        fig.savefig(save, dpi=300, bbox_inches='tight', transparent=False)
    finally:
        # Release the figure so the per-image loop does not accumulate open figures.
        plt.close(fig)

folder_BSL = "/data1/results10.09/(v1)Layer_HRLR_withoutconnection_SRResNet_16B64C_alpha=0.5/DIV2K_VAL/"
folder_OCT = "/data1/results/(ture)_1X1_directshare_SRResNet_44B64C_alpha=0.5/DIV2K_VAL/"
folder_GT = '/data1/data/DIV2K_VAL/DIV2K_valid_HR/'

crop_border = 4
suffix = ''  # suffix for Gen images
test_Y = False  # True: test Y channel only; False: test RGB channels

PSNR_all = []
SSIM_all = []
img_list = sorted(glob.glob(folder_OCT + '/*'))[:100]

if test_Y:
    print('Testing Y channel.')
else:
    print('Testing RGB channels.')


def _read_rgb01(path):
    """Read an image with OpenCV and return it as an RGB float array scaled to [0, 1].

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``. ``cv2.imread`` returns None instead
            of raising, which would otherwise surface later as an opaque TypeError.
    """
    im_bgr = cv2.imread(path)
    if im_bgr is None:
        raise FileNotFoundError('cannot read image: ' + path)
    return im_bgr[:, :, [2, 1, 0]] / 255.


patch_size = 32  # side length of the square sliding window
stride = 10  # step between window positions

save_dir = '/data1/cropimage/diff_cvpr/'
save_dir6 = '/data1/cropimage/heatdiff_cvpr/'
# savefig fails on a missing directory, so create the output folders once, up front.
os.makedirs(save_dir, exist_ok=True)
os.makedirs(save_dir6, exist_ok=True)

for img_path in img_list:
    base_name = os.path.splitext(os.path.basename(img_path))[0]
    im_OCT = _read_rgb01(img_path)
    im_BSL = _read_rgb01(os.path.join(folder_BSL, base_name.replace('Layer_HRLRadd_SRResNet_16B64C_alpha=0.5', 'SRResNet_16B64C') + suffix + '.png'))
    im_GT = _read_rgb01(os.path.join(folder_GT, base_name.replace('Layer_HRLRadd_SRResNet_16B64C_alpha=0.5', '').replace('_bicLRx4', '') + suffix + '.png'))

    H, W, C = im_OCT.shape
    # Top-left corners of every full window; the "+ 1" keeps the last valid start H - patch_size.
    H_axis = np.arange(0, H - patch_size + 1, stride)
    W_axis = np.arange(0, W - patch_size + 1, stride)
    err_map = np.zeros((len(H_axis), len(W_axis)))  # PSNR of each window's share of the error
    inv_map = np.zeros((len(H_axis), len(W_axis)))  # PSNR if that window's error were removed
    total_err = np.mean((im_OCT - im_BSL)**2)
    num_values = H * W * C

    for i, h in enumerate(H_axis):
        for j, w in enumerate(W_axis):
            patch_OCT = im_OCT[h:h+patch_size, w:w+patch_size, :]
            patch_BSL = im_BSL[h:h+patch_size, w:w+patch_size, :]
            # Normalized by the whole image size (not the window size): this is the window's
            # contribution to total_err, so total_err - patch_err is the MSE without it.
            patch_err = np.sum((patch_OCT - patch_BSL)**2) / num_values

            err_map[i, j] = mse2psnr(patch_err)
            # Clamp tiny negative values from floating-point cancellation (math.sqrt rejects them).
            inv_map[i, j] = mse2psnr(max(total_err - patch_err, 0.0))

    save_path = os.path.join(save_dir, base_name.replace('Layer_HRLRadd_SRResNet_16B64C_alpha=0.5', '') + '.png')
    save_path6 = os.path.join(save_dir6, base_name.replace('Layer_HRLRadd_SRResNet_16B64C_alpha=0.5', '') + '.png')
    plot_diffmap(im_BSL, im_OCT, im_GT, inv_map, alpha=0.5, save=save_path6, axis='off', display=False)
    plot_diff_patch(im_BSL, im_OCT, im_GT, inv_map, alpha=0.5, save=save_path, axis='off', display=False)

改进版（更换了数据路径和GT文件名映射，并在保存前自动创建输出目录）：

import os
import math
import numpy as np
import cv2
import glob
from skimage import transform
from skimage import measure
from collections import OrderedDict

import matplotlib.pyplot as plt
import matplotlib.patches as patches


def bgr2ycbcr(img, only_y=True):
    '''Convert a BGR image to YCbCr with the same coefficients as MATLAB's rgb2ycbcr.

    The coefficient rows are ordered (B, G, R), so the last axis of ``img`` must hold
    the channels in BGR order.

    Args:
        img: array whose last axis has 3 channels (B, G, R); either uint8 in [0, 255]
            or float in [0, 1]. The input array is never modified.
        only_y: if True return only the Y channel, otherwise Y, Cb and Cr.

    Returns:
        Array with the same dtype as ``img``. uint8 results are rounded; float results
        are rescaled back to [0, 1].
    '''
    in_img_type = img.dtype
    # Convert to a float *copy*: the scaling below is in place and must not touch the
    # caller's array (previously the astype result was discarded and img itself was scaled).
    img = img.astype(np.float64)
    if in_img_type != np.uint8:
        img *= 255.
    # convert
    if only_y:
        rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
                              [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)

def calculate_psnr(img1, img2):
    """Return the PSNR in dB between two same-shaped images whose peak value is 255.

    Values are compared in float64, so integer inputs do not wrap around on subtraction;
    identical images give ``float('inf')``.
    """
    diff = np.subtract(img1.astype(np.float64), img2.astype(np.float64))
    mse = float(np.mean(diff * diff))
    if mse == 0:
        return float('inf')
    peak_to_rmse = 255.0 / math.sqrt(mse)
    return 20 * math.log10(peak_to_rmse)

def mse2psnr(mse):
    """Map a mean squared error on [0, 1]-scaled data to PSNR in dB (peak value 1).

    Returns ``float('inf')`` for a zero error.
    """
    if mse != 0:
        return 20 * math.log10(1.0 / math.sqrt(mse))
    return float('inf')

def plot_heatmap(image, heat_map, alpha=0.5, display=False, save=None, cmap='viridis', axis='on',
                 dpi=80, verbose=False):
    """Overlay a heat map on an image, optionally showing it and/or saving it.

    ``heat_map`` is resized to the image's height and width with ``skimage.transform.resize``
    and min-max normalized to [0, 1] before being drawn on top of ``image``.

    Args:
        image: array accepted by ``imshow``; its first two axes give the output height and width.
        heat_map: 2-D array of scores.
        alpha: opacity of the heat-map layer.
        display: show the overlay with ``plt.show()``.
        save: if not None, path the overlay is saved to, sized to W x H pixels at ``dpi``.
        cmap: colormap used for the heat map.
        axis: argument to ``plt.axis`` for the displayed figure.
        dpi: resolution used to convert the pixel size into a figure size in inches.
        verbose: print the save path before saving.
    """
    height, width = image.shape[:2]

    heat_map_resized = transform.resize(heat_map, (height, width))

    # Min-max normalize to [0, 1]. A constant map would give 0/0 (NaN everywhere),
    # so it is mapped to all zeros instead.
    min_value = np.min(heat_map_resized)
    value_range = np.max(heat_map_resized) - min_value
    if value_range > 0:
        normalized_heat_map = (heat_map_resized - min_value) / value_range
    else:
        normalized_heat_map = np.zeros_like(heat_map_resized)

    if display:
        plt.imshow(image)
        plt.imshow(255 * normalized_heat_map, alpha=alpha, cmap=cmap)
        plt.axis(axis)
        plt.show()

    if save is not None:
        if verbose:
            print('save image: ' + save)

        fig = plt.figure(figsize=(width / float(dpi), height / float(dpi)))
        try:
            # A single axes filling the whole figure, so image pixels map 1:1 to output pixels.
            ax = fig.add_axes([0, 0, 1, 1])
            ax.axis('off')
            ax.imshow(image)
            ax.imshow(255 * normalized_heat_map, alpha=alpha, cmap=cmap)
            ax.set(xlim=[0, width], ylim=[height, 0], aspect=1)
            fig.savefig(save, dpi=dpi, transparent=True)
        finally:
            # Without this every call leaks a figure, which adds up in the per-image loop.
            plt.close(fig)
        
def to_bin(img, lower, upper):
    """Boolean mask of the entries of ``img`` lying strictly between ``lower`` and ``upper``."""
    return np.logical_and(img > lower, img < upper)


def plot_diffmap(im_BSL, im_OCT, im_GT, heatmap, thres=0.4, alpha=0.5, display=False,
                 save=None, cmap='viridis', axis='on', dpi=80, verbose=False):
    """Outline the high-heat regions of ``heatmap`` on ``im_BSL`` and label their PSNR gain.

    ``heatmap`` is resized to the image size and min-max normalized to [0, 1]. Pixels whose
    normalized value exceeds ``thres`` are grouped into 8-connected regions
    (``connectivity=2``). Every region whose bounding box covers at least 100 pixels gets a
    yellow rectangle and a red label with ``PSNR(im_OCT, im_GT) - PSNR(im_BSL, im_GT)``
    computed inside the box.

    Args:
        im_BSL, im_OCT, im_GT: same-shape images. They are multiplied by 255 before the PSNR
            computation, so presumably they are scaled to [0, 1] (TODO confirm for new callers).
        heatmap: 2-D score map, resized to the image size.
        thres: exclusive lower bound on the normalized heat for a pixel to join a region.
        alpha: opacity of the heat-map overlay.
        display: unused; kept for signature compatibility with the other plotting helpers.
        save: output path. Nothing is drawn or written when it is None.
        cmap: colormap of the heat-map overlay.
        axis: value passed to ``ax.axis``.
        dpi: resolution used to size the figure to exactly W x H pixels.
        verbose: print the save path before saving.
    """
    if save is None:
        # The function only produces a file; without a path there is nothing to do.
        return

    H, W = im_BSL.shape[:2]

    # Resize to image resolution and min-max normalize to [0, 1].
    heatmap_resized = transform.resize(heatmap, (H, W))
    min_value = np.min(heatmap_resized)
    value_range = np.max(heatmap_resized) - min_value
    if value_range > 0:
        normalized_heatmap = (heatmap_resized - min_value) / value_range
    else:
        # A constant map has no hot regions; avoid 0/0 producing NaNs.
        normalized_heatmap = np.zeros_like(heatmap_resized)

    # Threshold and label 8-connected regions. Only the lower bound applies, so the peak
    # pixels (normalized value exactly 1.0) stay inside their region.
    bin_map = normalized_heatmap > thres
    label_map = measure.label(bin_map, connectivity=2)
    props = measure.regionprops(label_map)

    if verbose:
        print('save image: ' + save)

    fig = plt.figure(figsize=(W / float(dpi), H / float(dpi)))
    try:
        ax = fig.add_axes([0, 0, 1, 1])
        ax.axis('off')
        ax.imshow(im_BSL)
        ax.imshow(normalized_heatmap, alpha=alpha, cmap=cmap)
        ax.axis(axis)

        for prop in props:
            min_row, min_col, max_row, max_col = prop.bbox
            # Bounding-box pixel count (the value regionprops calls bbox_area / area_bbox).
            if (max_row - min_row) * (max_col - min_col) < 100:
                continue
            ax.add_patch(
                patches.Rectangle(
                    (min_col, min_row),
                    max_col - min_col,
                    max_row - min_row,
                    edgecolor='y',
                    linewidth=6,
                    fill=False
                ))
            region = (slice(min_row, max_row), slice(min_col, max_col))
            psnr = (calculate_psnr(im_OCT[region] * 255., im_GT[region] * 255.)
                    - calculate_psnr(im_BSL[region] * 255., im_GT[region] * 255.))

            # Right-align labels of boxes that start within 50 px of the right edge.
            h_aln = 'right' if W - min_col < 50 else 'left'

            if min_row < 20:
                # Too close to the top edge: put the label under the box instead of above it.
                ax.text(min_col, max_row, "{:+.2f}".format(psnr), color='r',
                        verticalalignment='top', horizontalalignment=h_aln, fontsize=26)
            else:
                ax.text(min_col, min_row, "{:+.2f}".format(psnr), color='r',
                        verticalalignment='bottom', horizontalalignment=h_aln, fontsize=26)

        ax.set(xlim=[0, W], ylim=[H, 0], aspect=1)
        fig.savefig(save, dpi=dpi, transparent=True)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

def plot_diff_patch(im_BSL, im_OCT, im_GT, heatmap, thres=0.4, alpha=0.5, display=False,
                 save=None, cmap='viridis', axis='on', dpi=80, verbose=False,
                 diff_dir='/data1/cropimage/diff6_cvpr'):
    """Save a grid of crops for the (up to 5) hottest regions of ``heatmap``.

    ``heatmap`` is resized to the image size and min-max normalized to [0, 1]. Pixels above
    ``thres`` are grouped into 8-connected regions (``connectivity=2``), and regions whose
    bounding box covers fewer than 100 pixels are dropped. The remaining regions are ranked by
    the mean normalized heat inside their bounding box. For each of the top 5, one row shows the
    crops of ``im_GT``, ``im_BSL``, ``im_OCT`` and the difference image. The difference crop is
    annotated with ``PSNR(im_OCT, im_GT) - PSNR(im_BSL, im_GT)`` and the bounding box
    ``(min_row, min_col, max_row, max_col)``.

    The full difference image ``clip(im_OCT - im_BSL + 0.5, 0, 1)`` is also written as an 8-bit
    image into ``diff_dir``, under the same file name as ``save``.

    Args:
        im_BSL, im_OCT, im_GT: same-shape (H, W, 3) images, presumably RGB in [0, 1]. They are
            multiplied by 255 for PSNR, and channels are reversed before ``cv2.imwrite``.
        heatmap: 2-D score map, resized to (H, W).
        thres: exclusive lower bound on the normalized heat for region membership.
        alpha, display, cmap, axis, dpi: unused; kept for signature compatibility.
        save: output path of the crop grid. Nothing is computed or written when it is None.
        verbose: print the save path before saving the grid.
        diff_dir: directory for the full difference image (created if missing); None skips it.
    """
    if save is None:
        # Every output file name derives from ``save``.
        return

    H, W = im_BSL.shape[:2]

    # Resize to image resolution and min-max normalize to [0, 1].
    heatmap_resized = transform.resize(heatmap, (H, W))
    min_value = np.min(heatmap_resized)
    value_range = np.max(heatmap_resized) - min_value
    if value_range > 0:
        normalized_heatmap = (heatmap_resized - min_value) / value_range
    else:
        # A constant map has no hot regions; avoid 0/0 producing NaNs.
        normalized_heatmap = np.zeros_like(heatmap_resized)

    # Label 8-connected regions above the threshold and score the large-enough ones.
    label_map = measure.label(normalized_heatmap > thres, connectivity=2)
    candidates = []
    for prop in measure.regionprops(label_map):
        min_row, min_col, max_row, max_col = prop.bbox
        # Bounding-box pixel count (the value regionprops calls bbox_area / area_bbox).
        if (max_row - min_row) * (max_col - min_col) < 100:
            continue
        region = (slice(min_row, max_row), slice(min_col, max_col))
        err = np.mean(normalized_heatmap[region])
        psnr = (calculate_psnr(im_OCT[region] * 255., im_GT[region] * 255.)
                - calculate_psnr(im_BSL[region] * 255., im_GT[region] * 255.))
        candidates.append((prop.bbox, region, err, psnr))
    # Hottest first; the sort is stable, so ties keep label order.
    candidates.sort(key=lambda c: c[2], reverse=True)

    # Signed difference shifted so that "no difference" is mid-gray (0.5).
    im_diff = np.clip(im_OCT - im_BSL + 0.5, 0.0, 1.0)
    if diff_dir is not None:
        os.makedirs(diff_dir, exist_ok=True)
        diff_u8 = np.round(im_diff * 255).astype(np.uint8)
        # Reverse the channel order for OpenCV, which writes BGR.
        cv2.imwrite(os.path.join(diff_dir, os.path.basename(save)), diff_u8[:, :, [2, 1, 0]])

    top = candidates[:5]
    if not top:
        # plt.subplots(nrows=0) raises; there is simply nothing to show.
        return

    if verbose:
        print('save image: ' + save)

    # squeeze=False keeps ``axes`` 2-D even when there is a single row.
    fig, axes = plt.subplots(nrows=len(top), ncols=4, figsize=(15, 15), squeeze=False)
    try:
        for row, (bbox, region, _, psnr) in zip(axes, top):
            min_row, min_col, max_row, max_col = bbox
            # Columns: ground truth, baseline, compared method, difference image.
            for ax, im in zip(row, (im_GT, im_BSL, im_OCT, im_diff)):
                ax.imshow(im[region])
            # Text positions are in the crop's data coordinates (its right edge).
            row[3].text(max_col - min_col, max_row - min_row,
                        "{:+.2f}".format(psnr), color='r', fontsize=16)
            row[3].text(max_col - min_col, 0,
                        "{}".format(bbox), color='r', fontsize=16)

        fig.savefig(save, dpi=300, bbox_inches='tight', transparent=False)
    finally:
        # Release the figure so the per-image loop does not accumulate open figures.
        plt.close(fig)

folder_BSL = "/data1/results10.09/(multi_scale)_SRResNet_16B64C/DIV2K_VAL0.8/"
folder_OCT = "/data1/results/(multi)1X1_directshare_SRResNet_48B64C_alpha=0.5/DIV2K_VAL0.8/"
folder_GT = "/data1/data/multiscale_dataset/DIV2K_valid_HR_0.8/"

crop_border = 4
suffix = ''  # suffix for Gen images
test_Y = False  # True: test Y channel only; False: test RGB channels

PSNR_all = []
SSIM_all = []
img_list = sorted(glob.glob(folder_OCT + '/*'))[:100]

if test_Y:
    print('Testing Y channel.')
else:
    print('Testing RGB channels.')


def _read_rgb01(path):
    """Read an image with OpenCV and return it as an RGB float array scaled to [0, 1].

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``. ``cv2.imread`` returns None instead
            of raising, which would otherwise surface later as an opaque TypeError.
    """
    im_bgr = cv2.imread(path)
    if im_bgr is None:
        raise FileNotFoundError('cannot read image: ' + path)
    return im_bgr[:, :, [2, 1, 0]] / 255.


patch_size = 32  # side length of the square sliding window
stride = 10  # step between window positions

save_dir = '/data1/cropimage/heatmap/diff_DIV2K_VAL0.8/'
save_dir6 = '/data1/cropimage/heatmap/heatdiff_DIV2K_VAL0.8/'
# Create the output folders once, including missing parents (os.mkdir could not).
os.makedirs(save_dir, exist_ok=True)
os.makedirs(save_dir6, exist_ok=True)

for img_path in img_list:
    base_name = os.path.splitext(os.path.basename(img_path))[0]
    im_OCT = _read_rgb01(img_path)
    im_BSL = _read_rgb01(os.path.join(folder_BSL, base_name.replace('Layer_HRLRadd_SRResNet_16B64C_alpha=0.5', 'SRResNet_16B64C') + suffix + '.png'))
    im_GT = _read_rgb01(os.path.join(folder_GT, base_name.replace('_bicLRx4', '_bicLRx0.6') + suffix + '.png'))

    H, W, C = im_OCT.shape
    # Top-left corners of every full window; the "+ 1" keeps the last valid start H - patch_size.
    H_axis = np.arange(0, H - patch_size + 1, stride)
    W_axis = np.arange(0, W - patch_size + 1, stride)
    err_map = np.zeros((len(H_axis), len(W_axis)))  # PSNR of each window's share of the error
    inv_map = np.zeros((len(H_axis), len(W_axis)))  # PSNR if that window's error were removed
    total_err = np.mean((im_OCT - im_BSL)**2)
    num_values = H * W * C

    for i, h in enumerate(H_axis):
        for j, w in enumerate(W_axis):
            patch_OCT = im_OCT[h:h+patch_size, w:w+patch_size, :]
            patch_BSL = im_BSL[h:h+patch_size, w:w+patch_size, :]
            # Normalized by the whole image size (not the window size): this is the window's
            # contribution to total_err, so total_err - patch_err is the MSE without it.
            patch_err = np.sum((patch_OCT - patch_BSL)**2) / num_values

            err_map[i, j] = mse2psnr(patch_err)
            # Clamp tiny negative values from floating-point cancellation (math.sqrt rejects them).
            inv_map[i, j] = mse2psnr(max(total_err - patch_err, 0.0))

    save_path = os.path.join(save_dir, base_name.replace('Layer_HRLRadd_SRResNet_16B64C_alpha=0.5', '') + '.png')
    save_path6 = os.path.join(save_dir6, base_name.replace('Layer_HRLRadd_SRResNet_16B64C_alpha=0.5', '') + '.png')
    plot_diffmap(im_BSL, im_OCT, im_GT, inv_map, alpha=0.5, save=save_path6, axis='off', display=False)
    plot_diff_patch(im_BSL, im_OCT, im_GT, inv_map, alpha=0.5, save=save_path, axis='off', display=False)
发布了208 篇原创文章 · 获赞 198 · 访问量 23万+

猜你喜欢

转载自blog.csdn.net/gwplovekimi/article/details/103069899