DCGAN生成式对抗网络--keras实现

本文在 CIFAR-10 数据集上复现了 DCGAN。

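其中判别器用到的谱归一化（Spectral Normalization）层来自 SpectralNormalizationKeras 模块，需要先将该模块添加到 Python 环境中（即放在可以被 import 的路径下）。该模块基于 Keras 2.x 的接口（如 keras.engine、keras.legacy），代码如下：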

  1 from keras import backend as K
  2 from keras.engine import *
  3 from keras.legacy import interfaces
  4 from keras import activations
  5 from keras import initializers
  6 from keras import regularizers
  7 from keras import constraints
  8 from keras.utils.generic_utils import func_dump
  9 from keras.utils.generic_utils import func_load
 10 from keras.utils.generic_utils import deserialize_keras_object
 11 from keras.utils.generic_utils import has_arg
 12 from keras.utils import conv_utils
 13 from keras.legacy import interfaces
 14 from keras.layers import Dense, Conv1D, Conv2D, Conv3D, Conv2DTranspose, Embedding
 15 import tensorflow as tf
 16 
 17 class DenseSN(Dense):
 18     def build(self, input_shape):
 19         assert len(input_shape) >= 2
 20         input_dim = input_shape[-1]
 21         self.kernel = self.add_weight(shape=(input_dim, self.units),
 22                                       initializer=self.kernel_initializer,
 23                                       name='kernel',
 24                                       regularizer=self.kernel_regularizer,
 25                                       constraint=self.kernel_constraint)
 26         if self.use_bias:
 27             self.bias = self.add_weight(shape=(self.units,),
 28                                         initializer=self.bias_initializer,
 29                                         name='bias',
 30                                         regularizer=self.bias_regularizer,
 31                                         constraint=self.bias_constraint)
 32         else:
 33             self.bias = None
 34         self.u = self.add_weight(shape=tuple([1, self.kernel.shape.as_list()[-1]]),
 35                                  initializer=initializers.RandomNormal(0, 1),
 36                                  name='sn',
 37                                  trainable=False)
 38         self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dim})
 39         self.built = True
 40 
 41     def call(self, inputs, training=None):
 42         def _l2normalize(v, eps=1e-12):
 43             return v / (K.sum(v ** 2) ** 0.5 + eps)
 44         def power_iteration(W, u):
 45             _u = u
 46             _v = _l2normalize(K.dot(_u, K.transpose(W)))
 47             _u = _l2normalize(K.dot(_v, W))
 48             return _u, _v
 49         W_shape = self.kernel.shape.as_list()
 50         #Flatten the Tensor
 51         W_reshaped = K.reshape(self.kernel, [-1, W_shape[-1]])
 52         _u, _v = power_iteration(W_reshaped, self.u)
 53         #Calculate Sigma
 54         sigma=K.dot(_v, W_reshaped)
 55         sigma=K.dot(sigma, K.transpose(_u))
 56         #normalize it
 57         W_bar = W_reshaped / sigma
 58         #reshape weight tensor
 59         if training in {0, False}:
 60             W_bar = K.reshape(W_bar, W_shape)
 61         else:
 62             with tf.control_dependencies([self.u.assign(_u)]):
 63                  W_bar = K.reshape(W_bar, W_shape)
 64         output = K.dot(inputs, W_bar)
 65         if self.use_bias:
 66             output = K.bias_add(output, self.bias, data_format='channels_last')
 67         if self.activation is not None:
 68             output = self.activation(output)
 69         return output
 70 
 71 class _ConvSN(Layer):
 72 
 73     def __init__(self, rank,
 74                  filters,
 75                  kernel_size,
 76                  strides=1,
 77                  padding='valid',
 78                  data_format=None,
 79                  dilation_rate=1,
 80                  activation=None,
 81                  use_bias=True,
 82                  kernel_initializer='glorot_uniform',
 83                  bias_initializer='zeros',
 84                  kernel_regularizer=None,
 85                  bias_regularizer=None,
 86                  activity_regularizer=None,
 87                  kernel_constraint=None,
 88                  bias_constraint=None,
 89                  spectral_normalization=True,
 90                  **kwargs):
 91         super(_ConvSN, self).__init__(**kwargs)
 92         self.rank = rank
 93         self.filters = filters
 94         self.kernel_size = conv_utils.normalize_tuple(kernel_size, rank, 'kernel_size')
 95         self.strides = conv_utils.normalize_tuple(strides, rank, 'strides')
 96         self.padding = conv_utils.normalize_padding(padding)
 97         self.data_format = conv_utils.normalize_data_format(data_format)
 98         self.dilation_rate = conv_utils.normalize_tuple(dilation_rate, rank, 'dilation_rate')
 99         self.activation = activations.get(activation)
100         self.use_bias = use_bias
101         self.kernel_initializer = initializers.get(kernel_initializer)
102         self.bias_initializer = initializers.get(bias_initializer)
103         self.kernel_regularizer = regularizers.get(kernel_regularizer)
104         self.bias_regularizer = regularizers.get(bias_regularizer)
105         self.activity_regularizer = regularizers.get(activity_regularizer)
106         self.kernel_constraint = constraints.get(kernel_constraint)
107         self.bias_constraint = constraints.get(bias_constraint)
108         self.input_spec = InputSpec(ndim=self.rank + 2)
109         self.spectral_normalization = spectral_normalization
110         self.u = None
111 
112     def _l2normalize(self, v, eps=1e-12):
113         return v / (K.sum(v ** 2) ** 0.5 + eps)
114 
115     def power_iteration(self, u, W):
116         '''
117         Accroding the paper, we only need to do power iteration one time.
118         '''
119         v = self._l2normalize(K.dot(u, K.transpose(W)))
120         u = self._l2normalize(K.dot(v, W))
121         return u, v
122     def build(self, input_shape):
123         if self.data_format == 'channels_first':
124             channel_axis = 1
125         else:
126             channel_axis = -1
127         if input_shape[channel_axis] is None:
128             raise ValueError('The channel dimension of the inputs '
129                              'should be defined. Found `None`.')
130         input_dim = input_shape[channel_axis]
131         kernel_shape = self.kernel_size + (input_dim, self.filters)
132 
133         self.kernel = self.add_weight(shape=kernel_shape,
134                                       initializer=self.kernel_initializer,
135                                       name='kernel',
136                                       regularizer=self.kernel_regularizer,
137                                       constraint=self.kernel_constraint)
138 
139         #Spectral Normalization
140         if self.spectral_normalization:
141             self.u = self.add_weight(shape=tuple([1, self.kernel.shape.as_list()[-1]]),
142                                      initializer=initializers.RandomNormal(0, 1),
143                                      name='sn',
144                                      trainable=False)
145 
146         if self.use_bias:
147             self.bias = self.add_weight(shape=(self.filters,),
148                                         initializer=self.bias_initializer,
149                                         name='bias',
150                                         regularizer=self.bias_regularizer,
151                                         constraint=self.bias_constraint)
152         else:
153             self.bias = None
154         # Set input spec.
155         self.input_spec = InputSpec(ndim=self.rank + 2,
156                                     axes={channel_axis: input_dim})
157         self.built = True
158 
159     def call(self, inputs):
160         def _l2normalize(v, eps=1e-12):
161             return v / (K.sum(v ** 2) ** 0.5 + eps)
162         def power_iteration(W, u):
163             _u = u
164             _v = _l2normalize(K.dot(_u, K.transpose(W)))
165             _u = _l2normalize(K.dot(_v, W))
166             return _u, _v
167 
168         if self.spectral_normalization:
169             W_shape = self.kernel.shape.as_list()
170             #Flatten the Tensor
171             W_reshaped = K.reshape(self.kernel, [-1, W_shape[-1]])
172             _u, _v = power_iteration(W_reshaped, self.u)
173             #Calculate Sigma
174             sigma=K.dot(_v, W_reshaped)
175             sigma=K.dot(sigma, K.transpose(_u))
176             #normalize it
177             W_bar = W_reshaped / sigma
178             #reshape weight tensor
179             if training in {0, False}:
180                 W_bar = K.reshape(W_bar, W_shape)
181             else:
182                 with tf.control_dependencies([self.u.assign(_u)]):
183                     W_bar = K.reshape(W_bar, W_shape)
184 
185             #update weitht
186             self.kernel = W_bar
187 
188         if self.rank == 1:
189             outputs = K.conv1d(
190                 inputs,
191                 self.kernel,
192                 strides=self.strides[0],
193                 padding=self.padding,
194                 data_format=self.data_format,
195                 dilation_rate=self.dilation_rate[0])
196         if self.rank == 2:
197             outputs = K.conv2d(
198                 inputs,
199                 self.kernel,
200                 strides=self.strides,
201                 padding=self.padding,
202                 data_format=self.data_format,
203                 dilation_rate=self.dilation_rate)
204         if self.rank == 3:
205             outputs = K.conv3d(
206                 inputs,
207                 self.kernel,
208                 strides=self.strides,
209                 padding=self.padding,
210                 data_format=self.data_format,
211                 dilation_rate=self.dilation_rate)
212 
213         if self.use_bias:
214             outputs = K.bias_add(
215                 outputs,
216                 self.bias,
217                 data_format=self.data_format)
218 
219         if self.activation is not None:
220             return self.activation(outputs)
221         return outputs
222 
223     def compute_output_shape(self, input_shape):
224         if self.data_format == 'channels_last':
225             space = input_shape[1:-1]
226             new_space = []
227             for i in range(len(space)):
228                 new_dim = conv_utils.conv_output_length(
229                     space[i],
230                     self.kernel_size[i],
231                     padding=self.padding,
232                     stride=self.strides[i],
233                     dilation=self.dilation_rate[i])
234                 new_space.append(new_dim)
235             return (input_shape[0],) + tuple(new_space) + (self.filters,)
236         if self.data_format == 'channels_first':
237             space = input_shape[2:]
238             new_space = []
239             for i in range(len(space)):
240                 new_dim = conv_utils.conv_output_length(
241                     space[i],
242                     self.kernel_size[i],
243                     padding=self.padding,
244                     stride=self.strides[i],
245                     dilation=self.dilation_rate[i])
246                 new_space.append(new_dim)
247             return (input_shape[0], self.filters) + tuple(new_space)
248 
249     def get_config(self):
250         config = {
251             'rank': self.rank,
252             'filters': self.filters,
253             'kernel_size': self.kernel_size,
254             'strides': self.strides,
255             'padding': self.padding,
256             'data_format': self.data_format,
257             'dilation_rate': self.dilation_rate,
258             'activation': activations.serialize(self.activation),
259             'use_bias': self.use_bias,
260             'kernel_initializer': initializers.serialize(self.kernel_initializer),
261             'bias_initializer': initializers.serialize(self.bias_initializer),
262             'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
263             'bias_regularizer': regularizers.serialize(self.bias_regularizer),
264             'activity_regularizer': regularizers.serialize(self.activity_regularizer),
265             'kernel_constraint': constraints.serialize(self.kernel_constraint),
266             'bias_constraint': constraints.serialize(self.bias_constraint)
267         }
268         base_config = super(_Conv, self).get_config()
269         return dict(list(base_config.items()) + list(config.items()))
270 
class ConvSN2D(Conv2D):
    """Conv2D layer whose kernel is spectrally normalized on every call.

    How the normalized kernel is computed:

    - The kernel (``kernel_size + (in_channels, filters)``) is flattened to a
      ``(-1, filters)`` matrix.
    - Its largest singular value is estimated with one power-iteration step
      seeded by the non-trainable weight ``self.u``.
    - The convolution uses the kernel divided by that estimate.

    Unless ``training`` is 0/False, ``self.u`` is overwritten with the new
    estimate.
    """

    def build(self, input_shape):
        """Create ``kernel``, the optional ``bias`` and the ``sn`` vector ``u``."""
        channel_axis = 1 if self.data_format == 'channels_first' else -1
        if input_shape[channel_axis] is None:
            raise ValueError('The channel dimension of the inputs '
                             'should be defined. Found `None`.')
        n_in = input_shape[channel_axis]

        self.kernel = self.add_weight(shape=self.kernel_size + (n_in, self.filters),
                                      initializer=self.kernel_initializer,
                                      name='kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)

        if not self.use_bias:
            self.bias = None
        else:
            self.bias = self.add_weight(shape=(self.filters,),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)

        n_out = self.kernel.shape.as_list()[-1]
        self.u = self.add_weight(shape=(1, n_out),
                                 initializer=initializers.RandomNormal(0, 1),
                                 name='sn',
                                 trainable=False)

        self.input_spec = InputSpec(ndim=self.rank + 2,
                                    axes={channel_axis: n_in})
        self.built = True

    def _sn_kernel(self, training):
        """Return the kernel divided by its estimated spectral norm."""
        def unit(x, eps=1e-12):
            # Scale to unit L2 norm; eps avoids division by zero.
            return x / (K.sum(x ** 2) ** 0.5 + eps)

        full_shape = self.kernel.shape.as_list()
        mat = K.reshape(self.kernel, [-1, full_shape[-1]])
        # A single power-iteration step per call, as the paper suggests.
        v_hat = unit(K.dot(self.u, K.transpose(mat)))
        u_hat = unit(K.dot(v_hat, mat))
        sigma = K.dot(K.dot(v_hat, mat), K.transpose(u_hat))
        mat_sn = mat / sigma
        if training in {0, False}:
            return K.reshape(mat_sn, full_shape)
        with tf.control_dependencies([self.u.assign(u_hat)]):
            return K.reshape(mat_sn, full_shape)

    def call(self, inputs, training=None):
        """Convolve with the normalized kernel, add bias, apply activation."""
        outputs = K.conv2d(inputs,
                           self._sn_kernel(training),
                           strides=self.strides,
                           padding=self.padding,
                           data_format=self.data_format,
                           dilation_rate=self.dilation_rate)
        if self.use_bias:
            outputs = K.bias_add(outputs, self.bias, data_format=self.data_format)
        if self.activation is None:
            return outputs
        return self.activation(outputs)
349 
class ConvSN1D(Conv1D):
    """Conv1D layer whose kernel is spectrally normalized on every call.

    The kernel is flattened to a ``(-1, filters)`` matrix and divided by an
    estimate of its largest singular value. The estimate comes from one
    power-iteration step seeded by the non-trainable weight ``self.u``.
    Unless ``training`` is 0/False, ``self.u`` is overwritten with the new
    estimate.
    """

    def build(self, input_shape):
        """Create ``kernel``, the optional ``bias`` and the ``sn`` vector ``u``."""
        if self.data_format == 'channels_first':
            channel_axis = 1
        else:
            channel_axis = -1
        if input_shape[channel_axis] is None:
            raise ValueError('The channel dimension of the inputs '
                             'should be defined. Found `None`.')
        input_dim = input_shape[channel_axis]
        kernel_shape = self.kernel_size + (input_dim, self.filters)

        self.kernel = self.add_weight(shape=kernel_shape,
                                      initializer=self.kernel_initializer,
                                      name='kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)

        if self.use_bias:
            self.bias = self.add_weight(shape=(self.filters,),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None

        self.u = self.add_weight(shape=tuple([1, self.kernel.shape.as_list()[-1]]),
                                 initializer=initializers.RandomNormal(0, 1),
                                 name='sn',
                                 trainable=False)
        # Set input spec.
        self.input_spec = InputSpec(ndim=self.rank + 2,
                                    axes={channel_axis: input_dim})
        self.built = True

    def _normalized_kernel(self, training):
        """Return the kernel divided by its estimated spectral norm."""
        def _l2normalize(v, eps=1e-12):
            return v / (K.sum(v ** 2) ** 0.5 + eps)

        W_shape = self.kernel.shape.as_list()
        # Flatten the tensor to (-1, filters).
        W_reshaped = K.reshape(self.kernel, [-1, W_shape[-1]])
        # According to the paper, one power-iteration step is enough.
        _v = _l2normalize(K.dot(self.u, K.transpose(W_reshaped)))
        _u = _l2normalize(K.dot(_v, W_reshaped))
        sigma = K.dot(K.dot(_v, W_reshaped), K.transpose(_u))
        W_bar = W_reshaped / sigma
        if training in {0, False}:
            return K.reshape(W_bar, W_shape)
        with tf.control_dependencies([self.u.assign(_u)]):
            return K.reshape(W_bar, W_shape)

    def call(self, inputs, training=None):
        """Convolve with the normalized kernel, add bias, apply activation."""
        # K.conv1d takes scalar stride / dilation (it wraps them in a
        # 1-tuple itself), so pass the single element of the normalized
        # tuples, as Keras' own rank-1 convolution does.
        outputs = K.conv1d(
            inputs,
            self._normalized_kernel(training),
            strides=self.strides[0],
            padding=self.padding,
            data_format=self.data_format,
            dilation_rate=self.dilation_rate[0])
        if self.use_bias:
            outputs = K.bias_add(
                outputs,
                self.bias,
                data_format=self.data_format)
        if self.activation is not None:
            return self.activation(outputs)
        return outputs
428 
class ConvSN3D(Conv3D):
    """Conv3D layer whose kernel is spectrally normalized on every call.

    The 5-D kernel is viewed as a ``(-1, filters)`` matrix and divided by an
    estimate of its largest singular value. The estimate comes from one
    power-iteration step seeded by the non-trainable weight ``self.u``.
    Unless ``training`` is 0/False, ``self.u`` is overwritten with the new
    estimate.
    """

    def build(self, input_shape):
        """Create ``kernel``, the ``sn`` vector ``u`` and the optional ``bias``."""
        channel_axis = 1 if self.data_format == 'channels_first' else -1
        if input_shape[channel_axis] is None:
            raise ValueError('The channel dimension of the inputs '
                             'should be defined. Found `None`.')
        n_in = input_shape[channel_axis]

        self.kernel = self.add_weight(shape=self.kernel_size + (n_in, self.filters),
                                      initializer=self.kernel_initializer,
                                      name='kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)

        # NOTE: u is created before bias here (unlike ConvSN2D); the order
        # determines the layer's weight list, so it is kept as is.
        n_out = self.kernel.shape.as_list()[-1]
        self.u = self.add_weight(shape=(1, n_out),
                                 initializer=initializers.RandomNormal(0, 1),
                                 name='sn',
                                 trainable=False)

        if not self.use_bias:
            self.bias = None
        else:
            self.bias = self.add_weight(shape=(self.filters,),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)

        self.input_spec = InputSpec(ndim=self.rank + 2,
                                    axes={channel_axis: n_in})
        self.built = True

    def _sn_kernel(self, training):
        """Return the kernel divided by its estimated spectral norm."""
        def unit(x, eps=1e-12):
            # Scale to unit L2 norm; eps avoids division by zero.
            return x / (K.sum(x ** 2) ** 0.5 + eps)

        full_shape = self.kernel.shape.as_list()
        mat = K.reshape(self.kernel, [-1, full_shape[-1]])
        # A single power-iteration step per call, as the paper suggests.
        v_hat = unit(K.dot(self.u, K.transpose(mat)))
        u_hat = unit(K.dot(v_hat, mat))
        sigma = K.dot(K.dot(v_hat, mat), K.transpose(u_hat))
        mat_sn = mat / sigma
        if training in {0, False}:
            return K.reshape(mat_sn, full_shape)
        with tf.control_dependencies([self.u.assign(u_hat)]):
            return K.reshape(mat_sn, full_shape)

    def call(self, inputs, training=None):
        """Convolve with the normalized kernel, add bias, apply activation."""
        outputs = K.conv3d(inputs,
                           self._sn_kernel(training),
                           strides=self.strides,
                           padding=self.padding,
                           data_format=self.data_format,
                           dilation_rate=self.dilation_rate)
        if self.use_bias:
            outputs = K.bias_add(outputs, self.bias, data_format=self.data_format)
        if self.activation is None:
            return outputs
        return self.activation(outputs)
506 
507 
class EmbeddingSN(Embedding):
    """Embedding layer whose lookup table is spectrally normalized.

    On every call the ``(input_dim, output_dim)`` table is divided by an
    estimate of its largest singular value, and the lookup is done in the
    normalized table. The estimate comes from one power-iteration step seeded
    by the non-trainable weight ``self.u`` (shape ``(1, output_dim)``).
    Unless ``training`` is 0/False, ``self.u`` is overwritten with the new
    estimate. The trainable ``self.embeddings`` variable is never replaced.
    """

    def build(self, input_shape):
        """Create the ``embeddings`` table and the ``sn`` vector ``u``."""
        self.embeddings = self.add_weight(
            shape=(self.input_dim, self.output_dim),
            initializer=self.embeddings_initializer,
            name='embeddings',
            regularizer=self.embeddings_regularizer,
            constraint=self.embeddings_constraint,
            dtype=self.dtype)

        self.u = self.add_weight(shape=tuple([1, self.embeddings.shape.as_list()[-1]]),
                                 initializer=initializers.RandomNormal(0, 1),
                                 name='sn',
                                 trainable=False)

        self.built = True

    def call(self, inputs, training=None):
        """Look up ``inputs`` (cast to int32) in the normalized table."""
        if K.dtype(inputs) != 'int32':
            inputs = K.cast(inputs, 'int32')

        def _l2normalize(v, eps=1e-12):
            return v / (K.sum(v ** 2) ** 0.5 + eps)

        W_shape = self.embeddings.shape.as_list()
        # Flatten the tensor (already 2-D for an embedding table).
        W_reshaped = K.reshape(self.embeddings, [-1, W_shape[-1]])
        # According to the paper, one power-iteration step is enough.
        _v = _l2normalize(K.dot(self.u, K.transpose(W_reshaped)))
        _u = _l2normalize(K.dot(_v, W_reshaped))
        sigma = K.dot(K.dot(_v, W_reshaped), K.transpose(_u))
        W_bar = W_reshaped / sigma
        if training in {0, False}:
            embeddings_bar = K.reshape(W_bar, W_shape)
        else:
            with tf.control_dependencies([self.u.assign(_u)]):
                embeddings_bar = K.reshape(W_bar, W_shape)

        # Use the local normalized tensor; the weight variable stays intact.
        return K.gather(embeddings_bar, inputs)
557 
class ConvSN2DTranspose(Conv2DTranspose):
    """2-D transposed convolution whose kernel is spectrally normalized.

    The kernel has shape ``kernel_size + (filters, in_channels)`` and is
    flattened to a ``(-1, in_channels)`` matrix. That matrix is divided by an
    estimate of its largest singular value before the transposed
    convolution. The estimate comes from one power-iteration step seeded by
    the non-trainable weight ``self.u``. Unless ``training`` is 0/False,
    ``self.u`` is overwritten with the new estimate. The trainable
    ``self.kernel`` variable is never replaced.
    """

    def build(self, input_shape):
        """Create ``kernel``, the optional ``bias`` and the ``sn`` vector ``u``.

        Raises:
            ValueError: if the input is not rank 4 or has no channel size.
        """
        if len(input_shape) != 4:
            raise ValueError('Inputs should have rank 4; Received input shape: ' +
                             str(input_shape))
        if self.data_format == 'channels_first':
            channel_axis = 1
        else:
            channel_axis = -1
        if input_shape[channel_axis] is None:
            raise ValueError('The channel dimension of the inputs '
                             'should be defined. Found `None`.')
        input_dim = input_shape[channel_axis]
        kernel_shape = self.kernel_size + (self.filters, input_dim)

        self.kernel = self.add_weight(shape=kernel_shape,
                                      initializer=self.kernel_initializer,
                                      name='kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)
        if self.use_bias:
            self.bias = self.add_weight(shape=(self.filters,),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None

        self.u = self.add_weight(shape=tuple([1, self.kernel.shape.as_list()[-1]]),
                                 initializer=initializers.RandomNormal(0, 1),
                                 name='sn',
                                 trainable=False)

        # Set input spec.
        self.input_spec = InputSpec(ndim=4, axes={channel_axis: input_dim})
        self.built = True

    def _normalized_kernel(self, training):
        """Return the kernel divided by its estimated spectral norm."""
        def _l2normalize(v, eps=1e-12):
            return v / (K.sum(v ** 2) ** 0.5 + eps)

        W_shape = self.kernel.shape.as_list()
        # Flatten the tensor to (-1, in_channels).
        W_reshaped = K.reshape(self.kernel, [-1, W_shape[-1]])
        # According to the paper, one power-iteration step is enough.
        _v = _l2normalize(K.dot(self.u, K.transpose(W_reshaped)))
        _u = _l2normalize(K.dot(_v, W_reshaped))
        sigma = K.dot(K.dot(_v, W_reshaped), K.transpose(_u))
        W_bar = W_reshaped / sigma
        if training in {0, False}:
            return K.reshape(W_bar, W_shape)
        with tf.control_dependencies([self.u.assign(_u)]):
            return K.reshape(W_bar, W_shape)

    def call(self, inputs, training=None):
        """Apply the normalized transposed convolution, bias and activation."""
        input_shape = K.shape(inputs)
        batch_size = input_shape[0]
        if self.data_format == 'channels_first':
            h_axis, w_axis = 2, 3
        else:
            h_axis, w_axis = 1, 2

        height, width = input_shape[h_axis], input_shape[w_axis]
        kernel_h, kernel_w = self.kernel_size
        stride_h, stride_w = self.strides
        if self.output_padding is None:
            out_pad_h = out_pad_w = None
        else:
            out_pad_h, out_pad_w = self.output_padding

        # Infer the dynamic output shape.
        out_height = conv_utils.deconv_length(height,
                                              stride_h, kernel_h,
                                              self.padding,
                                              out_pad_h)
        out_width = conv_utils.deconv_length(width,
                                             stride_w, kernel_w,
                                             self.padding,
                                             out_pad_w)
        if self.data_format == 'channels_first':
            output_shape = (batch_size, self.filters, out_height, out_width)
        else:
            output_shape = (batch_size, out_height, out_width, self.filters)

        # Keep the normalized kernel local so self.kernel stays a variable.
        kernel = self._normalized_kernel(training)

        outputs = K.conv2d_transpose(
            inputs,
            kernel,
            output_shape,
            self.strides,
            padding=self.padding,
            data_format=self.data_format)

        if self.use_bias:
            outputs = K.bias_add(
                outputs,
                self.bias,
                data_format=self.data_format)

        if self.activation is not None:
            return self.activation(outputs)
        return outputs
View Code

完成这一步之后，开始正文。

首先是导入数据集:

 1 # 导入CIFAR10数据集
 2 # 读取数据
 3 def unpickle(file):
 4     import pickle
 5     with open(file, 'rb') as fo:
 6         dict = pickle.load(fo, encoding='bytes')
 7     return dict
 8 
 9 cifar={}
10 for i in range(5):
11     cifar1=unpickle('data_batch_'+str(i+1))
12     if i==0:
13         cifar[b'data']=cifar1[b'data']
14         cifar[b'labels']=cifar1[b'labels']
15     else:
16         cifar[b'data']=np.vstack([cifar1[b'data'],cifar[b'data']])
17         cifar[b'labels']=np.hstack([cifar1[b'labels'],cifar[b'labels']])
18 target_name=unpickle('batches.meta')
19 cifar[b'label_names']=target_name[b'label_names']
20 del cifar1
21 
22 # 定义数据格式
23 blank_image= np.zeros((len(cifar[b'data']),32,32,3), np.uint8)
24 for i in range(len(cifar[b'data'])):
25     blank_image[i] = np.zeros((32,32,3), np.uint8)
26     blank_image[i][:,:,0]=cifar[b'data'][i][0:1024].reshape(32,32)
27     blank_image[i][:,:,1]=cifar[b'data'][i][1024:1024*2].reshape(32,32)
28     blank_image[i][:,:,2]=cifar[b'data'][i][1024*2:1024*3].reshape(32,32)
29 cifar[b'data']=blank_image
30 
31 cifar_test=unpickle('test_batch')
32 cifar_test[b'labels']=np.array(cifar_test[b'labels'])
33 blank_image= np.zeros((len(cifar_test[b'data']),32,32,3), np.uint8)
34 for i in range(len(cifar_test[b'data'])):
35     blank_image[i] = np.zeros((32,32,3), np.uint8)
36     blank_image[i][:,:,0]=cifar_test[b'data'][i][0:1024].reshape(32,32)
37     blank_image[i][:,:,1]=cifar_test[b'data'][i][1024:1024*2].reshape(32,32)
38     blank_image[i][:,:,2]=cifar_test[b'data'][i][1024*2:1024*3].reshape(32,32)
39 cifar_test[b'data']=blank_image
40 
41 
42 x_train=cifar[b'data']
43 x_test=cifar[b'data']
44 x_test=cifar_test[b'data']
45 y_test=cifar_test[b'labels']
46 X = np.concatenate((x_test,x_train))
View Code

以上代码读取的是从 CIFAR-10 官方网站下载的 Python 版本数据文件（data_batch_1 至 data_batch_5、test_batch、batches.meta）。也可以直接使用 Keras 自带的 cifar10 加载接口：

from keras.datasets import cifar100, cifar10

# Alternative to the manual loader: Keras downloads and caches CIFAR-10 itself.
# By default it returns uint8 images of shape (N, 32, 32, 3); labels come back
# with shape (N, 1).
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
# The training code consumes the pooled array X, so build it here as well;
# otherwise this path alone leaves X undefined.
X = np.concatenate((x_test, x_train))

接下来是生成器与判别器的构造:

 1 # Hyperperemeter
 2 # 生成器与判别器构造
 3 BATCHSIZE=64
 4 LEARNING_RATE = 0.0002
 5 TRAINING_RATIO = 1
 6 BETA_1 = 0.0
 7 BETA_2 = 0.9
 8 EPOCHS = 500
 9 BN_MIMENTUM = 0.1
10 BN_EPSILON  = 0.00002
11 SAVE_DIR = 'img/generated_img_CIFAR10_DCGAN/'
12 
13 GENERATE_ROW_NUM = 8
14 GENERATE_BATCHSIZE = GENERATE_ROW_NUM*GENERATE_ROW_NUM
15 
def BuildGenerator(summary=True, latent_dim=128):
    """Build the DCGAN generator that maps a noise vector to a 32x32x3 image.

    Layout:
      1. Dense projection reshaped to 4x4x512.
      2. Three stride-2 transposed convolutions (256, 128, 64 filters), each
         with ReLU followed by BatchNormalization. They upsample
         4 -> 8 -> 16 -> 32.
      3. A stride-1 transposed convolution to 3 channels with tanh, so pixel
         values lie in [-1, 1].

    NOTE(review): the ReLU is applied before BatchNormalization, and the Dense
    projection has no nonlinearity. This differs from the usual DCGAN ordering
    (BN -> ReLU) but is kept as in the original.

    Args:
        summary: When True, print the model summary.
        latent_dim: Length of the input noise vector. Defaults to 128, the
            size of the noise fed during training.

    Returns:
        An uncompiled keras ``Sequential`` model.
    """
    model = Sequential()
    model.add(Dense(4 * 4 * 512, kernel_initializer='glorot_uniform', input_dim=latent_dim))
    model.add(Reshape((4, 4, 512)))
    # Each stage doubles the spatial size and narrows the channel count.
    for filters in (256, 128, 64):
        model.add(Conv2DTranspose(filters, kernel_size=4, strides=2, padding='same',
                                  activation='relu', kernel_initializer='glorot_uniform'))
        model.add(BatchNormalization(epsilon=BN_EPSILON, momentum=BN_MIMENTUM))
    model.add(Conv2DTranspose(3, kernel_size=3, strides=1, padding='same', activation='tanh'))
    if summary:
        print("Generator")
        model.summary()
    return model
31 
def BuildDiscriminator(summary=True, spectral_normalization=True, input_shape=(32, 32, 3)):
    """Build the DCGAN critic (discriminator).

    Seven 'same'-padded convolutions, each followed by LeakyReLU(0.1). The
    three stride-2 ones halve the spatial size (32 -> 16 -> 8 -> 4). The
    network ends with Flatten and Dense(1) with no activation, so the output
    is an unbounded score per image.

    Args:
        summary: When True, print the model summary.
        spectral_normalization: When True, use the spectrally normalized
            ConvSN2D/DenseSN layers; otherwise use plain Conv2D/Dense.
            NOTE(review): with plain layers nothing in this script bounds the
            critic (no weight clipping or gradient penalty).
        input_shape: Shape of one input image. Defaults to CIFAR-10's
            (32, 32, 3).

    Returns:
        An uncompiled keras ``Sequential`` model.
    """
    # Both variants share the same topology; only the layer classes differ.
    conv_layer = ConvSN2D if spectral_normalization else Conv2D
    dense_layer = DenseSN if spectral_normalization else Dense

    # (filters, kernel_size, strides) of each convolution, in order.
    conv_specs = (
        (64, 3, 1),
        (64, 4, 2),
        (128, 3, 1),
        (128, 4, 2),
        (256, 3, 1),
        (256, 4, 2),
        (512, 3, 1),
    )

    model = Sequential()
    for position, (filters, kernel_size, strides) in enumerate(conv_specs):
        # Only the first layer declares the input shape.
        shape_kwargs = {'input_shape': input_shape} if position == 0 else {}
        model.add(conv_layer(filters, kernel_size=kernel_size, strides=strides,
                             kernel_initializer='glorot_uniform', padding='same',
                             **shape_kwargs))
        model.add(LeakyReLU(0.1))
    model.add(Flatten())
    model.add(dense_layer(1, kernel_initializer='glorot_uniform'))
    if summary:
        print('Discriminator')
        print('Spectral Normalization: {}'.format(spectral_normalization))
        model.summary()
    return model
74 
def wasserstein_loss(y_true, y_pred):
    """Return the mean of the critic scores, each multiplied by its label.

    With labels of +1 / -1 this equals +mean(score) / -mean(score). Minimizing
    it lowers the scores of samples labelled +1 and raises those labelled -1.
    """
    signed_scores = y_true * y_pred
    return K.mean(signed_scores)
77 
# Instantiate both networks (each prints its summary with the default arguments).
generator, discriminator = BuildGenerator(), BuildDiscriminator()
View Code

然后是训练器的构造:

 1 # 生成器训练模型
 2 Noise_input_for_training_generator = Input(shape=(128,))
 3 Generated_image                    = generator(Noise_input_for_training_generator)
 4 Discriminator_output               = discriminator(Generated_image)
 5 model_for_training_generator       = Model(Noise_input_for_training_generator, Discriminator_output)
 6 print("model_for_training_generator")
 7 model_for_training_generator.summary()
 8 
 9 discriminator.trainable = False
10 model_for_training_generator.summary()
11 
12 model_for_training_generator.compile(optimizer=Adam(LEARNING_RATE, beta_1=BETA_1, beta_2=BETA_2), loss=wasserstein_loss)
13 
14 
15 # 判别器训练模型
16 Real_image                             = Input(shape=(32,32,3))
17 Noise_input_for_training_discriminator = Input(shape=(128,))
18 Fake_image                             = generator(Noise_input_for_training_discriminator)
19 Discriminator_output_for_real          = discriminator(Real_image)
20 Discriminator_output_for_fake          = discriminator(Fake_image)
21 
22 model_for_training_discriminator       = Model([Real_image,
23                                                 Noise_input_for_training_discriminator],
24                                                [Discriminator_output_for_real,
25                                                 Discriminator_output_for_fake])
26 print("model_for_training_discriminator")
27 generator.trainable = False
28 discriminator.trainable = True
29 model_for_training_discriminator.compile(optimizer=Adam(LEARNING_RATE, beta_1=BETA_1, beta_2=BETA_2), loss=[wasserstein_loss, wasserstein_loss])
30 model_for_training_discriminator.summary()
31 
32 
33 real_y = np.ones((BATCHSIZE, 1), dtype=np.float32)
34 fake_y = -real_y
35 
36 X = X/255*2-1
37 
38 plt.imshow((X[8787]+1)/2)
View Code

最后是重复训练:

 1 test_noise = np.random.randn(GENERATE_BATCHSIZE, 128)
 2 W_loss = []
 3 discriminator_loss = []
 4 generator_loss = []
 5 for epoch in range(EPOCHS):
 6     np.random.shuffle(X)
 7 
 8     print("epoch {} of {}".format(epoch+1, EPOCHS))
 9     num_batches = int(X.shape[0] // BATCHSIZE)
10 
11     print("number of batches: {}".format(int(X.shape[0] // (BATCHSIZE))))
12 
13     progress_bar = Progbar(target=int(X.shape[0] // (BATCHSIZE * TRAINING_RATIO)))
14     minibatches_size = BATCHSIZE * TRAINING_RATIO
15 
16     start_time = time()
17     for index in range(int(X.shape[0] // (BATCHSIZE * TRAINING_RATIO))):
18         progress_bar.update(index)
19         discriminator_minibatches = X[index * minibatches_size:(index + 1) * minibatches_size]
20 
21         for j in range(TRAINING_RATIO):
22             image_batch = discriminator_minibatches[j * BATCHSIZE : (j + 1) * BATCHSIZE]
23             noise = np.random.randn(BATCHSIZE, 128).astype(np.float32)
24             discriminator.trainable = True
25             generator.trainable = False
26             discriminator_loss.append(model_for_training_discriminator.train_on_batch([image_batch, noise],
27                                                                                       [real_y, fake_y]))
28         discriminator.trainable = False
29         generator.trainable = True
30         generator_loss.append(model_for_training_generator.train_on_batch(np.random.randn(BATCHSIZE, 128), real_y))
31 
32     print('\nepoch time: {}'.format(time()-start_time))
33 
34     W_real = model_for_training_generator.evaluate(test_noise, real_y)
35     print(W_real)
36     W_fake = model_for_training_generator.evaluate(test_noise, fake_y)
37     print(W_fake)
38     W_l = W_real+W_fake
39     print('wasserstein_loss: {}'.format(W_l))
40     W_loss.append(W_l)
41     #Generate image
42     generated_image = generator.predict(test_noise)
43     generated_image = (generated_image+1)/2
44     for i in range(GENERATE_ROW_NUM):
45         new = generated_image[i*GENERATE_ROW_NUM:i*GENERATE_ROW_NUM+GENERATE_ROW_NUM].reshape(32*GENERATE_ROW_NUM,32,3)
46         if i!=0:
47             old = np.concatenate((old,new),axis=1)
48         else:
49             old = new
50     print('plot generated_image')
51     plt.imsave('{}/SN_epoch_{}.png'.format(SAVE_DIR, epoch), old)
View Code

训练十轮之后，生成图片的质量有了显著提升，结果如下：

第一轮:

 

第10轮:

猜你喜欢

转载自www.cnblogs.com/techs-wenzhe/p/11765429.html