[A] Keras internals: an in-depth, source-level analysis of how a neural network runs

Copyright notice: when reproducing this article, please cite the source and credit "AI algorithms Wuhan study" https://blog.csdn.net/qq_36931982/article/details/90312852

model.fit(X_train,y_train,batch_size=BATCH_SIZE,nb_epoch=1,validation_data=(X_val,y_val))

The line above is how Keras trains a model with fit. What actually happens when it runs?

The call ultimately lands in the training.Model.fit() method, which mainly performs these steps:

  1. Validate the model parameters and do the preparation work for checking the data
  2. Prepare the input data and the model's training functions

Once this preparation is done, the remaining work is delegated to training_arrays.fit_loop(). Apart from data processing and preparation, the main training code lives in this loop, so it is the critical part:

callbacks.set_model(callback_model)
callbacks.set_params({
    'batch_size': batch_size,
    'epochs': epochs,
    'steps': steps_per_epoch,
    'samples': num_train_samples,
    'verbose': verbose,
    'do_validation': do_validation,
    'metrics': callback_metrics or [],
})
callbacks.on_train_begin()
for epoch in range(initial_epoch, epochs):
    # Record the history information for this epoch
    callbacks.on_epoch_begin(epoch)
    # Shuffle the indices, batch by batch if requested
    if shuffle == 'batch':
        index_array = batch_shuffle(index_array, batch_size)
    elif shuffle:
        np.random.shuffle(index_array)
    # Compute the (start, end) index range of every batch
    batches = make_batches(num_train_samples, batch_size)
    epoch_logs = {}
    # ........
    # Per-batch logic omitted here; see the next listing
    # ........
    callbacks.on_epoch_end(epoch, epoch_logs)
    if callback_model.stop_training:
        break

callbacks.on_train_end()
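
To make the two index helpers used in the loop concrete, here is a minimal pure-Python sketch of what make_batches and batch_shuffle are expected to do. It is written from scratch for illustration, assuming that make_batches returns (start, end) pairs and that batch-wise shuffling reorders whole batches while leaving a trailing partial batch at the end. The real Keras helpers work on NumPy arrays and may differ in detail.

```python
import random


def make_batches(size, batch_size):
    """Return (start, end) index pairs covering range(size) in batch_size steps."""
    return [(start, min(size, start + batch_size))
            for start in range(0, size, batch_size)]


def batch_shuffle(index_list, batch_size, rng=random):
    """Shuffle whole batches; a trailing partial batch stays at the end."""
    full = len(index_list) // batch_size * batch_size
    chunks = [index_list[i:i + batch_size] for i in range(0, full, batch_size)]
    rng.shuffle(chunks)  # reorder the batches, not the samples inside them
    shuffled = [idx for chunk in chunks for idx in chunk]
    return shuffled + index_list[full:]


batches = make_batches(10, 4)
print(batches)  # [(0, 4), (4, 8), (8, 10)]

# A seeded generator keeps the illustration reproducible.
shuffled = batch_shuffle(list(range(10)), 4, random.Random(0))
print(shuffled)
```

Shuffling in batch units keeps the samples of one batch contiguous, which is why the error message in the per-batch code recommends shuffle="batch" for HDF5 input.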

The `for epoch in` loop above iterates over the epochs. The core per-batch processing inside it is shown below:

            for batch_index, (batch_start, batch_end) in enumerate(batches):
                batch_ids = index_array[batch_start:batch_end]
                try:
                    if isinstance(ins[-1], float):
                        # Do not slice the training phase flag.
                        ins_batch = slice_arrays(
                            ins[:-1], batch_ids) + [ins[-1]]
                    else:
                        ins_batch = slice_arrays(ins, batch_ids)
                except TypeError:
                    raise TypeError('TypeError while preparing batch. '
                                    'If using HDF5 input data, '
                                    'pass shuffle="batch".')
                batch_logs = {}
                batch_logs['batch'] = batch_index
                batch_logs['size'] = len(batch_ids)
                
                # Callback at the start of each batch: logs contains size, the number of samples in the current batch
                callbacks.on_batch_begin(batch_index, batch_logs)
                for i in indices_for_conversion_to_dense:
                    ins_batch[i] = ins_batch[i].toarray()

                outs = f(ins_batch)
                outs = to_list(outs)
                for l, o in zip(out_labels, outs):
                    batch_logs[l] = o
                # Callback at the end of each batch: logs contains loss, plus acc if accuracy is enabled
                callbacks.on_batch_end(batch_index, batch_logs)
                if callback_model.stop_training:
                    break

                if batch_index == len(batches) - 1:  # Last batch.
                    if do_validation:
                        val_outs = test_loop(model, val_f, val_ins,
                                             batch_size=batch_size,
                                             verbose=0)
                        val_outs = to_list(val_outs)
                        # Same labels assumed.
                        for l, o in zip(out_labels, val_outs):
                            epoch_logs['val_' + l] = o
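
The branch on isinstance(ins[-1], float) is easy to miss, so here is a small sketch of the idea using plain Python lists. The helper name slice_batch and the sample data are hypothetical. The assumption is that a trailing float in the input list is a scalar flag (such as the training phase) rather than per-sample data, so it must be passed through whole instead of being indexed.

```python
def slice_batch(ins, batch_ids):
    """Pick rows batch_ids from every input, keeping a trailing float flag whole."""
    if isinstance(ins[-1], float):
        # The last entry is a scalar flag, not per-sample data: do not slice it.
        return [[arr[i] for i in batch_ids] for arr in ins[:-1]] + [ins[-1]]
    return [[arr[i] for i in batch_ids] for arr in ins]


x = [10, 11, 12, 13, 14]   # hypothetical features, one value per sample
y = [0, 1, 0, 1, 0]        # hypothetical labels
with_flag = slice_batch([x, y, 1.0], [3, 0])
without_flag = slice_batch([x, y], [3, 0])
print(with_flag)     # [[13, 10], [1, 0], 1.0]
print(without_flag)  # [[13, 10], [1, 0]]
```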

[1] Callback functions

That is the main body of the fit_loop() function. The key points in the code are the callback hooks:

  1. on_epoch_begin: called at the beginning of each epoch
  2. on_epoch_end: called at the end of each epoch
  3. on_batch_begin: called at the beginning of each batch
  4. on_batch_end: called at the end of each batch
  5. on_train_begin: called when training starts
  6. on_train_end: called when training ends
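
The order in which the six hooks listed above fire can be seen with a toy loop. This is only a sketch: RecordingCallback and toy_fit_loop are hypothetical stand-ins written for this illustration, not Keras classes, and the hook signatures are simplified.

```python
class RecordingCallback:
    """Hypothetical stand-in for a Keras callback: records each hook call."""

    def __init__(self):
        self.events = []

    def on_train_begin(self):
        self.events.append('train_begin')

    def on_epoch_begin(self, epoch):
        self.events.append('epoch_begin:%d' % epoch)

    def on_batch_begin(self, batch, logs):
        self.events.append('batch_begin:%d' % batch)

    def on_batch_end(self, batch, logs):
        self.events.append('batch_end:%d' % batch)

    def on_epoch_end(self, epoch, logs):
        self.events.append('epoch_end:%d' % epoch)

    def on_train_end(self):
        self.events.append('train_end')


def toy_fit_loop(callback, epochs, num_batches):
    """Fire the hooks in the same nesting order as the loop described above."""
    callback.on_train_begin()
    for epoch in range(epochs):
        callback.on_epoch_begin(epoch)
        for batch in range(num_batches):
            logs = {'batch': batch}
            callback.on_batch_begin(batch, logs)
            # The real loop would run one training step here.
            callback.on_batch_end(batch, logs)
        callback.on_epoch_end(epoch, {})
    callback.on_train_end()


recorder = RecordingCallback()
toy_fit_loop(recorder, epochs=1, num_batches=2)
print(recorder.events)
```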

The most important hook is on_batch_end. Take keras.callbacks.BaseLogger as an example.

It accumulates each batch's loss and acc values in totals, weighting each value by batch_size.

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        batch_size = logs.get('size', 0)
        self.seen += batch_size

        for k, v in logs.items():
            if k in self.stateful_metrics:
                self.totals[k] = v
            else:
                if k in self.totals:
                    self.totals[k] += v * batch_size
                else:
                    self.totals[k] = v * batch_size

Next is the on_epoch_end hook, again using keras.callbacks.BaseLogger as the example.

The on_epoch_end method of this class computes the epoch's average training loss and acc.

    def on_epoch_end(self, epoch, logs=None):
        if logs is not None:
            for k in self.params['metrics']:
                if k in self.totals:
                    # Make value available to next callbacks.
                    if k in self.stateful_metrics:
                        logs[k] = self.totals[k]
                    else:
                        logs[k] = self.totals[k] / self.seen
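
Multiplying by batch_size in on_batch_end and dividing by seen in on_epoch_end amounts to a sample-weighted mean. The sketch below reproduces that arithmetic with hypothetical per-batch losses; epoch_average is an illustrative helper, not part of Keras. It shows why the weighting matters when the last batch is smaller than the others.

```python
def epoch_average(batch_values, batch_sizes):
    """Sample-weighted mean: sum(value * size) / sum(size)."""
    total = 0.0
    seen = 0
    for value, size in zip(batch_values, batch_sizes):
        total += value * size   # mirrors totals[k] += v * batch_size
        seen += size            # mirrors self.seen += batch_size
    return total / seen         # mirrors totals[k] / self.seen


# Hypothetical per-batch losses; the last batch holds only 2 samples.
losses = [1.0, 1.0, 4.0]
sizes = [4, 4, 2]
weighted = epoch_average(losses, sizes)
naive = sum(losses) / len(losses)
print(weighted)  # 1.6
print(naive)     # 2.0
```

The unweighted mean over batches overstates the influence of the small final batch, while the weighted form equals the mean over all 10 samples.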

Additional notes:

keras.callbacks.ModelCheckpoint

Saves the model to a file in on_epoch_end.

keras.callbacks.History

Mainly records the training results of every epoch, including the loss and acc values.

keras.callbacks.ProgbarLogger

Prints intermediate state while training is running, mainly progress information.

[2] outs = f(ins_batch)

Here the function f() is passed in as a parameter. Stepping through it with a debugger shows that the call goes straight into the Keras backend. This matches the fact that Keras is a wrapper built on top of tf; the call simply dispatches to the function of whichever backend engine is in use.

After the data checks, execution proceeds to tensorflow_backend.Function._call, which is where the real tf work happens. The Function class provides the tooling for running a TensorFlow graph.

    def _call(self, inputs):
        if not isinstance(inputs, (list, tuple)):
            raise TypeError('`inputs` should be a list or tuple.')

        session = get_session()
        feed_arrays = []
        array_vals = []
        feed_symbols = []
        symbol_vals = []
        # Convert and prepare the input data
        for tensor, value in zip(self.inputs, inputs):
            if value is None:
                continue
            if is_tensor(value):
                # Case: feeding symbolic tensor.
                feed_symbols.append(tensor)
                symbol_vals.append(value)
            else:
                feed_arrays.append(tensor)
                # We need to do array conversion and type casting
                # at this level, since
                # `callable_fn` only supports exact matches.
                array_vals.append(
                    np.asarray(value,
                               dtype=tf.as_dtype(tensor.dtype).as_numpy_dtype))
        if self.feed_dict:
            for key in sorted(self.feed_dict.keys()):
                array_vals.append(
                    np.asarray(self.feed_dict[key],
                               dtype=tf.as_dtype(key.dtype).as_numpy_dtype))

        # Refresh callable if anything has changed.
        if (self._callable_fn is None or
                feed_arrays != self._feed_arrays or
                symbol_vals != self._symbol_vals or
                feed_symbols != self._feed_symbols or
                session != self._session):
            # Build a callable for the graph
            self._make_callable(feed_arrays,
                                feed_symbols,
                                symbol_vals,
                                session)
        # Run the graph
        if self.run_metadata:
            fetched = self._callable_fn(*array_vals, run_metadata=self.run_metadata)
        else:
            fetched = self._callable_fn(*array_vals)
        # Return the results
        return fetched[:len(self.outputs)]
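
The "Refresh callable if anything has changed" check is a caching pattern: the expensive build step runs only when the set of feeds differs from the previous call. Below is a minimal sketch of that pattern in plain Python. CachedCallable and build_sum are hypothetical names invented for this illustration, and the build step is a stand-in for the real graph preparation.

```python
class CachedCallable:
    """Hypothetical sketch: rebuild a callable only when its feed signature changes."""

    def __init__(self, build_fn):
        self._build_fn = build_fn   # expensive factory, e.g. graph preparation
        self._signature = None
        self._callable = None
        self.builds = 0             # counts rebuilds, for illustration only

    def __call__(self, names, values):
        signature = tuple(names)
        # Refresh the callable if anything about the feeds has changed.
        if self._callable is None or signature != self._signature:
            self._callable = self._build_fn(signature)
            self._signature = signature
            self.builds += 1
        return self._callable(*values)


def build_sum(signature):
    # Stand-in for the expensive build step; returns a plain function.
    return lambda *vals: sum(vals)


fn = CachedCallable(build_sum)
first = fn(['x', 'y'], [1, 2])          # first call builds the callable
second = fn(['x', 'y'], [3, 4])         # same signature, cached callable reused
third = fn(['x', 'y', 'z'], [1, 2, 3])  # signature changed, callable rebuilt
print(first, second, third, fn.builds)  # 3 7 6 2
```

During training the feeds are the same for every batch, so under this pattern the build cost is paid once and every later batch only pays for the run.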

Summary:

1. Keras invokes the tf computation one batch at a time. When a batch finishes, the results are returned to Keras, which can then work on them in memory.

 
