对抗生成网络-图像卷积-mnist数据生成

情景说明:这里演示的代码使用的mnist数据集,使用[batch_size, noise_size], 即[None, 100]的数据集经过反卷积生成[None, 28, 28, 1]的mnist图片,范围为-1, 1 

模型的结构说明:

                对于生成网络:[None, 100] - > [None, 4*4*512](全连接) -> [None, 7, 7, 256](反卷积) -> [None, 14, 14, 128](反卷积) -> [None, 28, 28, 3](反卷积) -> tf.tanh(-1, 1的变化)

                对于判别网络:[None, 28, 28, 1] -> [None, 14, 14, 128](卷积) -> [None, 7, 7, 256](卷积) -> [None, 4, 4, 512](卷积) ->[None, 512*4*4](维度变化) ->[None, 1] (全连接) -> tf.sigmoid(0, 1的变化)

生成网络参数说明: 第一层全连接,使用的w为 [100, 4*4*512], 使用的b为[4*4*512]

                               第二层反卷积,使用的w为[4, 4, 512, 256], 使用的b为[256]

                               第三层反卷积,使用的w为[3, 3, 256, 128], 使用的b为[128]

                               第四层反卷积,使用的w为[3, 3, 128, 1], 使用的b为[1]

判别网络参数说明:第一层卷积,使用的w为[3, 3, 1, 128] 使用的b是[128]

                                第二层卷积, 使用的w为[3, 3, 128, 256], 使用的b是[256]

                                第三层卷积,使用的w为[3, 3, 256, 512], 使用的b是[512]

                                第四层全连接,使用的w为[4*4*512, 1] 使用的b是[1]

代码说明:由于上述代码建立了多个函数,因此我们先进行主函数的说明

建立函数train()进行代码的训练操作

 第一步:首先是定义参数, 包括input_size(输入图片的维度), noise_size(噪声图片的维度), output_dim(生成图片的维度),batch_size(一个batch的大小)

第二步:使用get_inputs(input_size, noise_size) 获得生成的初始化的real_image 和 noise_image, 对生成的real_image进行维度的变化,将其变化为[-1, 28, 28, 1], 以用于判别网络进行判断

第三步:将生成的real_image_reshape 和 noise_image 输入到get_loss里面 获得 d_loss 和 g_loss 

         get_loss函数说明:输入的数据为real_image_reshape, noise_image, output_dim

           第一步: 调用get_generator(noise_image, output_dim, True)生成g_outputs

           第二步:调用get_discrimator(real_image, reuse=False, ) 判别真实样本的输入

           第三步: 调用get_discrimator(g_outputs, reuse=True) 判别生成样本的输入

          第四步: 构造判别网络d_loss的损失值

                      第一步:构造判别网络的真实样本的损失值, tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=, label=tf.ones_like(real_logits)))

                      第二步:构造判别网络的生成样本的损失值

                      第三步:将两种损失值进行加和操作

          第五步:构造生成网络g_loss的损失值, tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=, label=tf.ones_like))

          第六步:返回d_loss 和 g_loss 

 第四步: 调用get_opt, 输入为d_loss, g_loss, learning_rate=0.001, 进行优化操作, 获得g_opt, 和d_opt 

           第一步:获得g_var 和 d_var 进行训练时生成网络的参数和判别网络的参数

                          第一步:使用varias = tf.train_variable() 获得训练时的参数

                          第二步:使用g_var = [v for v in varias if v.name.startswith('generator')] 来获得训练过程中的生成网络的参数

                          第三步:d_var = [v for v in varias if v.name.startswith('discrimator')] 来获得训练过程中的判别网络的参数

           第二步:构造优化的操作g_opt, d_opt 

                          第一步: 使用tf.train.Adaoptimer(learning_rate).minimize(d_loss, var_list=d_var), 构造d_opt 

                          第二步: 使用tf.train.Adaoptimer(learning_rate).minimize(g_loss, var_list=g_var)  构造g_opt 

第五步:进入循环操作,迭代30次,进行训练操作,一个epoch迭代的次数为range(mnist.train.num_examples // batch_size)

            第一步:读取一个batch_size的数据image_batch,并将其转换为-1, 1的范围

                           第一步:使用mnist.train.next_batch(batch_size)获得一个batch的数据

                           第二步:使用batch_image = batch[0]获得其中的图片

                           第三步: batch_image = batch_image * 2 - 1将image的范围从0, 1变化为-1, 1

             第二步:使用np.random.uniform(-1, 1, (batch_size, noise_size)) 生成一个batch的噪音数据,大小为batch_size, noise_size 

             第三步:使用sess.run(g_opt, feed_dict) 和 d_opt更新参数

             第四步:在一个epoch结束后,执行sess.run(d_loss) 和 sess,run(g_loss)

             第五步:调用show_sample_image获得生成后的图片, 输入为sess, n_sample, input_image, output_dim, reuse , 这里的input_image表示real_image

                              第一步:使用input_image.get_shape().as_list()[-1]获得noise的维度,这里是100

                              第二步:使用np.random.uniform(-1, 1, (n_sample, input_shape)) 生成50,100的噪音数据

                              第三步:使用sess.run(get_generator(input_image, output_dim, reuse), feed_dict={input_image, sample_noise})

                              第四步:返回生成的图片sample

              第六步:调用plot_show()进行图片的展示

              第七步:打印结果

                             

                                         

                                                                             

猜你喜欢

转载自www.cnblogs.com/my-love-is-python/p/10697953.html