论文复现代码Reversible Watermarking Algorithm Using Sorting and Prediction

论文简述

见我文章

论文简述附复现代码Reversible Watermarking Algorithm Using Sorting and Prediction比较经典的PEE算法_zrc007007的博客-CSDN博客

复现代码

注:

代码不可商用!

代码不可商用!

代码不可商用!

from skimage import io
import matplotlib.pyplot as plt
import numpy as np


class Test:
    '''
    Small hand-made 4x4 grayscale sample for checking the algorithm by hand.
    '''

    def __init__(self):
        sample_rows = (
            (251, 252, 253, 253),
            (252, 254, 252, 251),
            (252, 252, 252, 252),
            (252, 253, 252, 255),
        )
        # Nested tuples give numpy the same 2-D integer array as nested lists would.
        self.data = np.array(sample_rows)


def int2bin(x, length):
    '''
    Convert a non-negative integer to a zero-padded binary string.

    :param x: non-negative integer (Python or NumPy integer)
    :param length: exact number of characters of the result
    :return: string of '0'/'1' of exactly `length` characters, most significant bit first
    :raises ValueError: if x is negative (the previous digit loop never terminated,
                        because -1 // 2 == -1)
    :raises Exception: if x needs more than `length` bits
    '''
    x = int(x)
    if x < 0:
        raise ValueError("x must be non-negative, got {}".format(x))
    # Zero contributes no digits, so int2bin(0, 0) == '' exactly as before.
    bits = format(x, 'b') if x else ''
    if len(bits) > length:
        raise Exception("The length of x is longer actually!")
    return bits.zfill(length)


def countMiu(x1, x2, x3, x4):
    '''
    Local complexity: variance of the absolute differences between neighbours
    taken cyclically (x1-x2, x2-x3, ..., last-x1).

    A value of -1 marks a missing neighbour. With four neighbours all four are
    used; with three only x1..x3 are used; with fewer, 10 ** 5 is returned.

    :return: the (population) variance of the cyclic differences, or 10 ** 5
    '''
    if float(x4) != -1:
        ring = [float(x1), float(x2), float(x3), float(x4)]
    elif float(x3) != -1:
        ring = [float(x1), float(x2), float(x3)]
    else:
        return 10 ** 5

    count = len(ring)
    # Same order as before: |r0-r1|, |r1-r2|, ..., |r_last-r0|.
    gaps = [abs(ring[k] - ring[(k + 1) % count]) for k in range(count)]
    mean_gap = sum(gaps) / count
    return sum((gap - mean_gap) ** 2 for gap in gaps) / count


def countPrdMiu(img, row, column):
    '''
    Predict a pixel from its 4-neighbours and measure the local complexity.

    :param img: 2-D image array
    :param row: row index of the pixel
    :param column: column index of the pixel
    :return: (prd, miu) where prd is the floor of the mean of the in-bounds
             4-neighbours and miu is countMiu() of those neighbours
             (10 ** 5 when two or fewer neighbours exist, e.g. at corners)
    :raises ValueError: if the pixel has no in-bounds neighbour (1x1 image)
    '''
    height, width = img.shape[0], img.shape[1]
    up = (row - 1, column)
    down = (row + 1, column)
    left = (row, column - 1)
    right = (row, column + 1)
    # Keep the neighbour order of the former per-case code: countMiu looks at
    # differences between consecutive neighbours, so reordering could change
    # miu (and thus the sorting order) through floating-point rounding.
    if row == 0 or row == height - 1:
        order = (left, right, up, down)
    else:
        order = (up, down, left, right)
    # Only in-bounds neighbours are used. This fixes the top-right corner
    # (previously IndexError), the bottom-left corner (previously read the
    # wrapped-around img[H - 1][-1]) and the bottom-right corner of non-square
    # images (previously read img[H - 2][H - 1] instead of img[H - 2][W - 1]).
    neighbours = [int(img[r][c]) for r, c in order if 0 <= r < height and 0 <= c < width]
    if not neighbours:
        raise ValueError("Pixel ({}, {}) has no neighbour to predict from!".format(row, column))
    prd = sum(neighbours) // len(neighbours)
    # countMiu expects exactly four values, -1 marking a missing neighbour.
    miu = countMiu(*(neighbours + [-1] * (4 - len(neighbours))))
    return prd, miu


def HistogramShift(origin, prd, Tn, Tp, b):
    '''
    Map a pixel's prediction error through expansion or shifting.

    With d = origin - prd:
      Tn <= d <= Tp -> 2 * d + b   (expanded, carries bit b)
      d > Tp        -> d + Tp + 1  (shifted right)
      d < Tn        -> d + Tn      (shifted left)

    :param origin: pixel value (integer)
    :param prd: predicted value (integer)
    :param Tn: negative threshold
    :param Tp: non-negative threshold
    :param b: bit to embed (0 or 1); only used when d is expandable
    :return: the marked prediction error D (the marked pixel is prd + D)
    '''
    # Pixels arrive as numpy.uint8; uint8 - int wraps around (9 - 10 -> 255)
    # instead of going negative, so compute the error with Python ints.
    d = int(origin) - int(prd)
    if d < Tn:
        return d + Tn
    if d > Tp:
        return d + Tp + 1
    return 2 * d + b


def judgeTwiceFlow(origin, prd, Tn, Tp):
    '''
    Classify a pixel by whether marking it leaves the range [0, 255].

    The worst-case bit is assumed: 0 when the pixel is below its prediction,
    1 otherwise. Below-prediction pixels are checked against 0, the others
    against 255.

    :return: 1 if marking once already leaves the range, 0 if only marking the
             marked value a second time does, -1 if neither does
    '''
    bit = 0 if origin < prd else 1

    def out_of_range(value):
        return value < 0 if bit == 0 else value > 255

    once = prd + HistogramShift(origin, prd, Tn, Tp, bit)
    if out_of_range(once):
        return 1
    twice = prd + HistogramShift(once, prd, Tn, Tp, bit)
    return 0 if out_of_range(twice) else -1


def embed(img, row, column, origin, prd, Tn, Tp, b):
    '''
    Write the marked value prd + HistogramShift(...) of one pixel into img in place.
    '''
    marked_value = prd + HistogramShift(origin, prd, Tn, Tp, b)
    img[row][column] = marked_value


def countPosition(img, place):
    '''
    Map a linear pattern index to (row, column) on a checkerboard.

    Even `place` values walk the Cross set (row + column even), odd values the
    Dot set (row + column odd). For odd widths the plain row-major position
    already alternates parity between rows. For even widths, odd rows flip the
    lowest bit of the column offset.

    :param img: 2-D image array (only its shape is used)
    :param place: linear index
    :return: (row, column)
    '''
    width = img.shape[1]
    row, offset = divmod(place, width)
    if width % 2 == 1 or row % 2 == 0:
        return row, offset
    # offset ^ 1 equals offset + 1 for even offsets (Cross set, unchanged) and
    # offset - 1 for odd offsets (Dot set). The previous "+ 1" produced columns
    # 2..W for the Dot set: W is out of range and column 0 was never visited.
    return row, offset ^ 1


def embedOnePattern(place, Tn, Tp, P, processed):
    '''
    Embed header, location map, LSB backup and payload into one checkerboard pattern.

    Pixels with linear index place, place + 2, place + 4, ... (see countPosition)
    form the pattern: place == 0 is the Cross set, place == 1 the Dot set. Inside
    each row the pattern pixels are sorted by (miu, key), smallest miu first.

    In that order, positions whose row index and in-row rank satisfy
    row * width + rank < 34 receive the 34-bit header by LSB replacement. The
    header holds 7 bits of -Tn, 7 bits of Tp and 20 bits of len(P); for images at
    least 68 pixels wide these are the first 34 sorted pixels of row 0. The
    replaced LSBs are collected in Slsb. The remaining positions carry
    locationMap + Slsb + P by prediction-error expansion.

    :param place: first linear pixel index of the pattern (0: Cross set, 1: Dot set)
    :param Tn: negative expansion threshold
    :param Tp: non-negative expansion threshold
    :param P: payload, a string of '0'/'1' characters
    :param processed: image array; modified in place and also returned
    :return: (processed, the part of locationMap + Slsb + P that did not fit,
             needDot) where needDot is False when everything was embedded
    :raises Exception: if the location map and Slsb themselves did not fit
    '''
    # Read predictions and original values from a snapshot of `processed`. The
    # previous code read the module-level `img` global instead. That global only
    # exists when this file runs as a script, and it made the Dot pass predict
    # from the unmodified image rather than from the Cross-embedded pixels the
    # decoder sees.
    img = np.array(processed)
    encrypted = processed
    height, width = img.shape[0], img.shape[1]
    Slsb = ''
    head = int2bin(-Tn, 7) + int2bin(Tp, 7) + int2bin(len(P), 20)
    head_len = len(head)  # 34 header bits
    row_dic = [{} for _ in range(height)]
    # Rows the pattern never visits stay empty lists instead of placeholder ints.
    row_dic_list_sorted = [[] for _ in range(height)]
    start = 1
    locationMap = ''
    needDot = True
    while place < height * width:
        row, column = countPosition(img, place)
        prd, miu = countPrdMiu(img, row, column)
        row_dic[row][row, column, img[row][column], prd] = miu
        if (place + 2) // width != row:
            # Last pattern pixel of this row: the row is complete, so sort it
            # now and extend the location map in the sorted order.
            row_dic_list_sorted[row] = sorted(zip(row_dic[row].values(), row_dic[row].keys()))
            for _, (_, _, origin, pred) in row_dic_list_sorted[row]:
                flow = judgeTwiceFlow(origin, pred, Tn, Tp)
                # Leading flagged pixels (flow >= 0) are skipped until the
                # first pixel with flow == -1; after that, every flagged pixel
                # contributes its flag to the location map.
                if start == 1 and flow >= 0:
                    continue
                start = 0
                if flow != -1:
                    locationMap += str(flow)
        place += 2

    start = 1
    embedded_message = locationMap
    # Embed in the sorting order computed above.
    for x in range(height):
        for y, (_, (r, c, origin, pred)) in enumerate(row_dic_list_sorted[x]):
            if x * width + y < head_len:
                # Header bit replaces the LSB; the original LSB is backed up in Slsb.
                encrypted[r][c] = img[r][c] // 2 * 2 + int(head[0])
                Slsb += str(img[r][c] % 2)
                head = head[1:]
                continue
            if start == 1:
                embedded_message = locationMap + Slsb + P
                start = 0
            # int() avoids uint8 wrap-around (e.g. 251 - 252 -> 255).
            d = int(origin) - int(pred)
            if Tn <= d <= Tp and len(embedded_message) > 0:
                embed(encrypted, r, c, origin, pred, Tn, Tp, int(embedded_message[0]))
                embedded_message = embedded_message[1:]
            else:
                # Not expandable, or nothing left to embed: embed 0 (shift only).
                embed(encrypted, r, c, origin, pred, Tn, Tp, 0)
    if len(embedded_message) > len(P):
        raise Exception("Didn't finish embedding lacation map and Slsb!")
    if len(embedded_message) == 0:
        print("No need to embed Dot set!")
        needDot = False
    return encrypted, embedded_message, needDot


def encode(img, Tn, Tp, P):
    '''
    Encode the original image.

    The payload is embedded into the Cross set first. Whatever does not fit is
    then embedded into the Dot set.

    :param img: original image (2-D array)
    :param Tn: start threshold, negative, with -Tn < 128 (stored in 7 bits)
    :param Tp: end threshold, non-negative, < 128 (stored in 7 bits)
    :param P: payload, a string of '0'/'1' characters shorter than 2 ** 20
    :return: (encrypted image, bool telling whether the Dot set was used as well)
    :raises Exception: if the image is too small, a threshold has the wrong
                       sign, or the payload does not fit into the image
    :raises ValueError: if a threshold or len(P) does not fit the 34-bit header,
                        or P contains characters other than '0'/'1'
    '''
    if img.shape[0] * img.shape[1] < 34 * 4:
        raise Exception("The imagine is too small!")
    if Tn >= 0:
        raise Exception("Tn need to be negative!")
    if Tp < 0:
        raise Exception("Tp need to be positive!")
    # The header stores -Tn and Tp in 7 bits each and len(P) in 20 bits.
    if -Tn >= 2 ** 7 or Tp >= 2 ** 7:
        raise ValueError("-Tn and Tp must be smaller than 128!")
    if len(P) >= 2 ** 20:
        raise ValueError("P is longer than the 20-bit length field allows!")
    if any(bit not in '01' for bit in P):
        raise ValueError("P must be a string of '0'/'1' characters!")
    encrypted = np.array(img)
    encrypted, embedded_message, needDot = embedOnePattern(0, Tn, Tp, P, encrypted)
    if needDot is True:
        # Only the part of P that did not fit into the Cross set goes into the
        # Dot set (previously the whole P was embedded a second time).
        encrypted, embedded_message, needDot = embedOnePattern(1, Tn, Tp, embedded_message, encrypted)
        if len(embedded_message) > 0:
            raise Exception("Length of P, {}, is much more than this picture can be embedded!".format(len(P)))
        return encrypted, True
    return encrypted, False


def Slsb2int(string):
    '''
    Split the 34-bit header into its three unsigned fields.

    :param string: 34 characters: 7 bits of -Tn, 7 bits of Tp, 20 bits of len(P),
                   each field most significant bit first
    :return: (-Tn, Tp, len(P)) as integers
    :raises Exception: if the string is not exactly 34 characters long
    '''
    if len(string) != 34:
        raise Exception("Given string's length isn't 34!")

    def bits_to_int(bits):
        # Horner's scheme: same weighted sum as adding int(bit) * 2 ** position.
        value = 0
        for bit in bits:
            value = value * 2 + int(bit)
        return value

    return bits_to_int(string[:7]), bits_to_int(string[7:14]), bits_to_int(string[14:34])


def countPattenNum(isOdd, place, img):
    '''
    Number of pixels in the checkerboard pattern that starts at `place`.

    :param isOdd: must be exactly True (both dimensions odd) or False
    :param place: 0 for the Cross set, 1 for the Dot set
    :param img: image array (only its shape is used)
    :return: half the pixel count, plus one for the Cross set of an odd-sized
             image; None if isOdd is neither True nor False
    '''
    if isOdd is not True and isOdd is not False:
        return None
    extra = 1 if (isOdd is True and place == 0) else 0
    return img.shape[0] * img.shape[1] // 2 + extra


def judgeOnceFlow(origin, prd, Tn, Tp):
    '''
    Tell whether marking this value once more (worst-case bit) leaves [0, 255].

    Pixels at or above their prediction are marked with bit 1 and checked
    against 255; pixels below are marked with bit 0 and checked against 0.

    :return: exactly True or False (callers compare with `is True` / `is False`)
    '''
    if origin >= prd:
        return bool(prd + HistogramShift(origin, prd, Tn, Tp, 1) > 255)
    return bool(prd + HistogramShift(origin, prd, Tn, Tp, 0) < 0)


def extractPixel(origin, prd):
    '''
    Read the bit carried by an expanded pixel: the parity of origin - prd.

    :param origin: marked pixel value (integer)
    :param prd: predicted value (integer)
    :return: 0 or 1 as a Python int
    '''
    # Cast like HistogramShift/decodeShift do: avoids numpy.uint8 overflow on
    # the subtraction and returns a plain int instead of a numpy scalar.
    return (int(origin) - int(prd)) % 2


def decodeShift(origin, prd, Tn, Tp):
    '''
    Invert HistogramShift: recover the original prediction error from a marked pixel.

    With D = origin - prd:
      2 * Tn <= D <= 2 * Tp + 1 -> D // 2      (expanded error)
      D > 2 * Tp + 1            -> D - Tp - 1  (right-shifted error)
      D < 2 * Tn                -> D - Tn      (left-shifted error)

    :param origin: marked pixel value (integer)
    :param prd: predicted value (integer)
    :param Tn: negative threshold
    :param Tp: non-negative threshold
    :return: original prediction error d (the original pixel is prd + d)
    '''
    # numpy.uint8 - int wraps around (9 - 10 -> 255); use Python ints.
    D = int(origin) - int(prd)
    if D < 2 * Tn:
        return D - Tn
    if D > 2 * Tp + 1:
        return D - Tp - 1
    return D // 2


def restoreImg(needResumed, message, bitNum, row_dic_list_sorted, Tn, Tp):
    '''
    Put the original pixel values of one pattern back into the image.

    Header positions get their backed-up LSBs from Slsb = message[bitNum:bitNum + 34].
    The other positions are inverted with decodeShift. The exception is ambiguous
    pixels (judgeOnceFlow True), which consult the next location-map bit:
    '0' means the pixel was marked and is restored, otherwise it is left as is.

    :param needResumed: marked image array; restored in place and returned
    :param message: extracted bits: location map, then Slsb, then payload
    :param bitNum: number of location-map bits at the start of message
    :param row_dic_list_sorted: per-row lists of (miu, (row, column, marked value, prd))
                                in embedding order, as built by decodeOnePattern
    :param Tn: negative expansion threshold
    :param Tp: non-negative expansion threshold
    :return: the restored image (the same object as needResumed)
    '''
    processed = needResumed
    # Header positions are identified with the image width, as in
    # embedOnePattern and decodeOnePattern (previously shape[0] was used here,
    # which disagrees for non-square images).
    width = processed.shape[1]
    start = 1
    bitPosition = 0
    Slsb = message[bitNum:bitNum + 34]
    for i in range(processed.shape[0]):
        for j, (_, (r, c, modified, prd)) in enumerate(row_dic_list_sorted[i]):
            if i * width + j < 34:
                # Put the backed-up LSB back.
                processed[r][c] = processed[r][c] - processed[r][c] % 2 + int(Slsb[0])
                Slsb = Slsb[1:]
                continue
            ambiguous = judgeOnceFlow(modified, prd, Tn, Tp)
            if start == 1 and ambiguous is True:
                continue
            start = 0
            if ambiguous is False:
                processed[r][c] = prd + decodeShift(modified, prd, Tn, Tp)
                continue
            # message holds characters: the old `== 0` (int) comparison was
            # never true, so flagged pixels were never restored.
            if message[bitPosition] == '0':
                processed[r][c] = prd + decodeShift(modified, prd, Tn, Tp)
            bitPosition += 1
    return processed


def decodeOnePattern(place, encrypted):
    '''
    Extract the payload hidden in one checkerboard pattern and restore its pixels.

    Mirrors embedOnePattern. Pattern pixels are collected and sorted per row by
    (miu, key), and the 34-bit header is read from the LSBs of the header
    positions. Bits are then extracted by prediction-error expansion: first the
    location map, then Slsb, then the payload. Finally restoreImg puts the
    original pixel values back.

    :param place: first linear pixel index of the pattern (0: Cross set, 1: Dot set)
    :param encrypted: marked image array; restored in place and also returned
    :return: (restored image, extracted payload string)
    :raises Exception: if the header holds a non-negative Tn or a negative Tp, or
                       a location-map bit is needed before it has been extracted
    '''
    resumed = encrypted
    height, width = encrypted.shape[0], encrypted.shape[1]
    odd = height % 2 != 0 and width % 2 != 0
    head_len = 34
    start = 1
    message = ''
    bitPosition = 0
    row_dic = [{} for _ in range(height)]
    # Rows the pattern never visits stay empty lists instead of placeholder ints.
    row_dic_list_sorted = [[] for _ in range(height)]
    for _ in range(countPattenNum(odd, place, encrypted)):
        row, column = countPosition(encrypted, place)
        prd, miu = countPrdMiu(encrypted, row, column)
        row_dic[row][row, column, encrypted[row][column], prd] = miu
        if (place + 2) // width != row:
            # Last pattern pixel of this row: sort the completed row.
            row_dic_list_sorted[row] = sorted(zip(row_dic[row].values(), row_dic[row].keys()))
        place += 2

    # Header: LSBs of the positions with row * width + rank < 34.
    head = ''
    for i in range(height):
        if i * width >= head_len:
            break
        for j, (_, (r, c, _, _)) in enumerate(row_dic_list_sorted[i]):
            if i * width + j >= head_len:
                break
            head += str(encrypted[r][c] % 2)
    Tn, Tp, P = Slsb2int(head)
    Tn = -Tn
    if Tn >= 0:
        raise Exception("The extracted Tn isn't negative!")
    if Tp < 0:
        raise Exception("The extracted Tp isn't positive!")

    for i in range(height):
        for j, (_, (_, _, modified, prd)) in enumerate(row_dic_list_sorted[i]):
            if i * width + j < head_len:
                continue  # header position, carries no PEE bit
            ambiguous = judgeOnceFlow(modified, prd, Tn, Tp)
            if start == 1 and ambiguous is True:
                continue
            start = 0
            # int() avoids numpy.uint8 wrap-around in the difference.
            D = int(modified) - int(prd)
            expandable = 2 * Tn <= D <= 2 * Tp + 1
            if ambiguous is False:
                if expandable:
                    message += str(extractPixel(modified, prd))
                continue
            # Ambiguous pixel: the next location-map bit tells whether it was
            # marked. message holds characters, so compare with '0', not 0
            # (the old int comparisons were never true).
            if bitPosition >= len(message):
                raise Exception("Location map bit needed before it was extracted!")
            if message[bitPosition] == '0' and expandable:
                message += str(extractPixel(modified, prd))
            bitPosition += 1
    resumed = restoreImg(resumed, message, bitPosition, row_dic_list_sorted, Tn, Tp)
    info = message[bitPosition + head_len:bitPosition + head_len + P]
    # TODO: the second (Dot set) embedding still needs to be handled; to be written later.
    return resumed, info


def decode(encrypted, needDot):
    '''
    Decode the encrypted image.

    The Dot set is decoded first because it was embedded last. The payload is
    reassembled as Cross part followed by Dot part, matching encode(), which
    fills the Cross set first.

    :param encrypted: the encrypted image; it is not modified
    :param needDot: whether the Dot set also has to be decoded
    :return: (the original image, the secret information)
    :raises Exception: if the image is too small
    '''
    # Check the argument, not the module-level `img` global, which only exists
    # when this file runs as a script.
    if encrypted.shape[0] * encrypted.shape[1] < 34 * 4:
        raise Exception("The imagine is too small!")
    # decodeOnePattern restores in place; work on a copy so the caller's
    # encrypted image stays intact.
    working = np.array(encrypted)
    if needDot is False:
        imagine, info = decodeOnePattern(0, working)
    else:
        working, info_dot = decodeOnePattern(1, working)
        imagine, info_cross = decodeOnePattern(0, working)
        info = info_cross + info_dot
    return imagine, info


if __name__ == "__main__":
    img = io.imread("../img/lena_gray_512.tif")  # place the image you want to process here
    # img = Test().data
    # print(img)
    encrypted, needDot = encode(img, -1, 0, '100000')
    print("need Dot:", needDot)
    # Hand decode a copy so `encrypted` still shows the marked image in the plot
    # below, even if decoding restores its input in place.
    resumedImg, info = decode(encrypted.copy(), needDot)

    print("The secret info is:", info)
    print("img == resumedImg:")
    # One verdict instead of printing a full per-pixel boolean array.
    print(np.array_equal(img, resumedImg))

    plt.set_cmap(cmap="gray")
    plt.subplot(131)
    plt.imshow(img)
    plt.title("Original imagery")
    plt.subplot(132)
    plt.imshow(encrypted)
    plt.title("Encrypted imagery")
    plt.subplot(133)
    plt.imshow(resumedImg)
    plt.title("Resumed imagery")
    plt.show()

结果如下:

代码目前只实现了Cross集的嵌入与提取；如果嵌入信息超过Cross集的容量上限、需要继续使用Dot集嵌入，这部分我还没有写完，以后有机会再补充。

猜你喜欢

转载自blog.csdn.net/m0_46948660/article/details/128724346#comments_27369665