# -*- coding:utf-8 -*-
#Author: shenying
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import sys
import scipy.io
import scipy.misc
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow
from PIL import Image
from nst_utils import *
import numpy as np
import tensorflow as tf
def compute_content_cost(a_C, a_G):
    """Compute the content cost between content and generated activations.

    Implements::

        J_content = sum((a_C - a_G) ** 2) / (4 * n_H * n_W * n_C)

    where ``n_H``, ``n_W`` and ``n_C`` come from the static shape of ``a_G``.

    Args:
        a_C: activations of the chosen layer for the content image, with the
            same shape as ``a_G`` (e.g. the array returned by ``sess.run``).
        a_G: 4-D tensor ``(m, n_H, n_W, n_C)`` with the activations of the
            same layer for the generated image.

    Returns:
        Scalar tensor holding the content cost.
    """
    _, n_H, n_W, n_C = a_G.get_shape().as_list()
    # The squared differences are summed over every element, so unrolling the
    # activations first is unnecessary; skipping the reshape also avoids
    # hard-coding a batch size of 1.
    squared_diff = tf.square(tf.subtract(a_C, a_G))
    return tf.reduce_sum(squared_diff) / (4.0 * n_H * n_W * n_C)
# tf.reset_default_graph()
# with tf.Session() as test:
# tf.set_random_seed(1)
# a_C = tf.random_normal([1, 4, 4, 3], mean=1, stddev=4)
# a_G = tf.random_normal([1, 4, 4, 3], mean=1, stddev=4)
# J_content = compute_content_cost(a_C, a_G)
# print("J_content = " + str(J_content.eval()))
def gram_matrix(A):
    """Return the Gram matrix ``A . A^T`` of a 2-D tensor.

    Entry ``(i, j)`` of the result is the dot product of rows ``i`` and ``j``
    of ``A``; an input of shape ``(n, k)`` yields an ``(n, n)`` output.
    """
    return tf.matmul(A, A, transpose_b=True)
# Clears TensorFlow's default graph when this module is imported. This appears
# to be a leftover from the (commented-out) gram_matrix test cell below; the
# __main__ section resets the graph again before building the model.
tf.reset_default_graph()
# with tf.Session() as test:
# tf.set_random_seed(1)
# A = tf.random_normal([3, 2 * 1], mean=1, stddev=4)
# GA = gram_matrix(A)
#
# print("GA = " + str(GA.eval()))
def compute_layer_style_cost(a_S, a_G):
    """Compute the style cost contributed by a single layer.

    Implements::

        J_style_layer = sum((G_S - G_G) ** 2) / (4 * n_C**2 * (n_H * n_W)**2)

    where ``G_S`` and ``G_G`` are the ``(n_C, n_C)`` Gram matrices of the
    style and generated activations.

    Args:
        a_S: activations of the layer for the style image (e.g. the array
            returned by ``sess.run``); must hold ``n_H * n_W * n_C`` values.
        a_G: tensor ``(1, n_H, n_W, n_C)`` with the activations of the same
            layer for the generated image (the reshape below requires a batch
            size of 1).

    Returns:
        Scalar float tensor holding this layer's style cost.
    """
    _, n_H, n_W, n_C = a_G.get_shape().as_list()

    # Unroll to (n_H * n_W, n_C) and transpose so each row is one channel;
    # gram_matrix(X) = X X^T then gives the (n_C, n_C) channel correlations.
    a_S_unrolled = tf.transpose(tf.reshape(a_S, [n_H * n_W, n_C]))
    a_G_unrolled = tf.transpose(tf.reshape(a_G, [n_H * n_W, n_C]))
    GS = gram_matrix(a_S_unrolled)
    GG = gram_matrix(a_G_unrolled)

    # Compute the normaliser with Python float arithmetic. Squaring
    # n_C * n_H * n_W as an int32 tensor overflows for realistic layer sizes
    # (a 300x400x64 activation gives ~5.9e13 > 2**31), producing a wrong or
    # even negative denominator.
    normalizer = 4.0 * float(n_C) ** 2 * float(n_H * n_W) ** 2
    return tf.reduce_sum(tf.square(tf.subtract(GS, GG))) / normalizer
# tf.reset_default_graph()
# with tf.Session() as test:
# tf.set_random_seed(1)
# a_S = tf.random_normal([1, 4, 4, 3], mean=1, stddev=4)
# a_G = tf.random_normal([1, 4, 4, 3], mean=1, stddev=4)
# J_style_layer = compute_layer_style_cost(a_S, a_G)
# print("J_style_layer = " + str(J_style_layer.eval()))
def compute_style_cost(model, STYLE_LAYERS, sess=None):
    """Compute the weighted style cost over several layers.

    For each ``(layer_name, coeff)`` pair, ``model[layer_name]`` is evaluated
    once with ``sess`` to obtain the fixed style target ``a_S``, while the same
    tensor is kept as ``a_G`` so the cost stays a function of the model input.
    The result is ``sum(coeff * compute_layer_style_cost(a_S, a_G))``.

    The style image must already be assigned to the model's input before
    calling this, because ``a_S`` is read from whatever the input holds now.

    Args:
        model: mapping from layer name to the layer's output tensor.
        STYLE_LAYERS: iterable of ``(layer_name, coeff)`` pairs.
        sess: session used to evaluate the style activations. Defaults to the
            current default session (e.g. a ``tf.InteractiveSession``).

    Returns:
        Scalar tensor with the total style cost (the int ``0`` when
        ``STYLE_LAYERS`` is empty).

    Raises:
        RuntimeError: if ``sess`` is not given and no default session exists.
    """
    if sess is None:
        # Avoid depending on a module-level ``sess`` global, which only exists
        # when this file is run as a script.
        sess = tf.get_default_session()
        if sess is None:
            raise RuntimeError(
                "compute_style_cost needs a session: pass sess= or create "
                "a default session (e.g. tf.InteractiveSession())")

    J_style = 0
    for layer_name, coeff in STYLE_LAYERS:
        out = model[layer_name]
        a_S = sess.run(out)  # fixed target values from the current input
        a_G = out            # kept as a tensor so it depends on the input
        J_style += coeff * compute_layer_style_cost(a_S, a_G)
    return J_style
def total_cost(J_content, J_style, alpha=10, beta=40):
    """Combine content and style costs into the overall objective.

    Args:
        J_content: content cost (tensor or number).
        J_style: style cost (tensor or number).
        alpha: weight applied to the content cost.
        beta: weight applied to the style cost.

    Returns:
        ``alpha * J_content + beta * J_style``.
    """
    weighted_content = alpha * J_content
    weighted_style = beta * J_style
    return weighted_content + weighted_style
# tf.reset_default_graph()
# with tf.Session() as test:
# np.random.seed(3)
# J_content = np.random.randn()
# J_style = np.random.randn()
# J = total_cost(J_content, J_style)
# print("J = " + str(J))
def model_nn(sess, input_image, num_iterations=2000):
    """Run the style-transfer optimisation and return the generated image.

    Initialises all variables, assigns ``input_image`` to ``model['input']``
    and runs ``train_step`` ``num_iterations`` times. Every 200 iterations the
    costs are printed and the current image is saved to ``output/<i>.png`` and
    displayed. The final image is saved to ``output/generated_image.jpg``.

    Relies on module-level names created in the ``__main__`` section of this
    file: ``model``, ``train_step``, ``J``, ``J_content`` and ``J_style``.

    Args:
        sess: TensorFlow session that owns the graph.
        input_image: initial image assigned to ``model['input']``.
        num_iterations: number of optimisation steps.

    Returns:
        The generated image as read back from ``model['input']``.
    """
    sess.run(tf.global_variables_initializer())
    sess.run(model['input'].assign(input_image))

    for i in range(num_iterations):
        sess.run(train_step)
        if i % 200 == 0:
            # Fetch the image only when it is used, together with the costs.
            Jt, Jc, Js, generated_image = sess.run(
                [J, J_content, J_style, model['input']])
            print("Iteration " + str(i) + ":")
            print("total cost =" + str(Jt))
            print("content cost =" + str(Jc))
            print("style cost=" + str(Js))
            save_image('output/' + str(i) + '.png', generated_image)
            imshow(generated_image[0])
            plt.show()

    # Read the final image after the loop so it is defined even when
    # num_iterations <= 0 (previously an UnboundLocalError).
    generated_image = sess.run(model['input'])
    save_image('output/generated_image.jpg', generated_image)
    imshow(generated_image[0])
    plt.show()
    return generated_image
if __name__ == '__main__':
    # NOTE: the names bound here (model, J, J_content, J_style, train_step)
    # intentionally stay at module level: model_nn() reads them as globals.
    tf.reset_default_graph()
    sess = tf.InteractiveSession()

    def _read_image(path):
        """Load an image file as a numpy array.

        Replacement for scipy.misc.imread, which was removed in SciPy 1.2.
        """
        with Image.open(path) as img:
            return np.array(img)

    content_image = _read_image('images/louvre_small.jpg')
    imshow(content_image)
    plt.show()
    content_image = reshape_and_normalize_image(content_image)

    style_image = _read_image('images/monet.jpg')
    imshow(style_image)
    plt.show()
    style_image = reshape_and_normalize_image(style_image)

    generated_image = generate_noise_image(content_image)
    imshow(generated_image[0])
    plt.show()

    model = load_vgg_model("imagenet-vgg-verydeep-19.mat")

    # Content cost: conv4_2 activations of the content image are the target.
    sess.run(model['input'].assign(content_image))
    out = model['conv4_2']
    a_C = sess.run(out)
    a_G = out
    J_content = compute_content_cost(a_C, a_G)

    # Style cost: evaluated while the style image is assigned to the input.
    STYLE_LAYERS = [
        ('conv1_1', 0.2),
        ('conv2_1', 0.2),
        ('conv3_1', 0.2),
        ('conv4_1', 0.2),
        ('conv5_1', 0.2)]
    sess.run(model['input'].assign(style_image))
    J_style = compute_style_cost(model, STYLE_LAYERS)

    J = total_cost(J_content, J_style, alpha=10, beta=40)
    optimizer = tf.train.AdamOptimizer(3.0)
    train_step = optimizer.minimize(J)
    model_nn(sess, generated_image)
# Andrew Ng, Deep Learning Specialization, Course 4: neural style transfer
# (吴恩达 深度学习 第四课 风格转换).
# Reposted from blog.csdn.net/qq_31119155/article/details/81070033