【项目实战】WaveNet 代码解析 —— train.py 【更新中】

WaveNet 代码解析 —— train.py

  简介

       本项目一个基于 WaveNet 生成神经网络体系结构的语音合成项目,它是使用 TensorFlow 实现的(项目地址)。
       
       WaveNet神经网络体系结构能直接生成原始音频波形,在文本到语音和一般音频生成方面显示了出色的结果(详情请参阅 WaveNet 的详细介绍)。
       
       由于 WaveNet 项目较大,代码较多。为了方便学习与整理,将按照工程文件的结构依次介绍。
       
       本文将介绍项目中的 train.py 文件:基于VCTK语料库的小波网络训练脚本。
       
       本脚本使用来自VCTK语料库的数据,用WaveNet训练网络(下载地址
       

  代码解析

    全局变量解析

       以下变量主要作为各功能参数的默认值,辅助开发人员对训练过程进行配置。

		BATCH_SIZE = 1								# 一批训练集中,样本音频的数量
		DATA_DIRECTORY = './VCTK-Corpus'			# 下载的VCTK数据集的路径
		LOGDIR_ROOT = './logdir'					# 训练日志的路径
		CHECKPOINT_EVERY = 50						# 保存训练模型的检查点数量
		NUM_STEPS = int(1e5)						# 训练的总次数
		LEARNING_RATE = 1e-3						# 学习率
		WAVENET_PARAMS = './wavenet_params.json'	# WaveNet 模型的相关参数路径
		STARTED_DATESTRING = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())				# 当前日期格式化
		SAMPLE_SIZE = 100000						# 样本数量大小
		L2_REGULARIZATION_STRENGTH = 0				# L2正则化中的系数
		SILENCE_THRESHOLD = 0.3						# 音量阈值大小
		EPSILON = 0.001								# 精度设置
		MOMENTUM = 0.9								# 优化器动量
		MAX_TO_KEEP = 5								# 保存的最大检查点数量
		METADATA = False							# 高级调试信息存储标志

    函数解析

      main()

        下面这段代码是 train.py 的主函数,主要作用是提取样本进行预处理、创建网络、训练模型、存取模型以及记录日志。

	def main():
	    # 解析命令行功能参数
	    args = get_arguments()
	
	    try:
	        # 验证并整理与目录有关的参数
	        directories = validate_directories(args)
	    except ValueError as e:
	        print("Some arguments are wrong:")
	        print(str(e))
	        return
	
	    # 将整理好的文件路径赋给相应变量
	    logdir = directories['logdir']
	    restore_from = directories['restore_from']
	
	    # 即使我们恢复了模型,如果训练的模型被写入到任意位置,我们也会把它当作新的训练
	    is_overwritten_training = logdir != restore_from
	
	    # 使用 josn 库的 load 函数读取 WaveNet 模型相关参数,将 json 格式的字符转换为 dict
	    with open(args.wavenet_params, 'r') as f:
	        wavenet_params = json.load(f)
	
	    # 创建线程协调器,多线程协调器相关知识可参考文章地址如下:
	    # https://blog.csdn.net/weixin_42721167/article/details/112795491
	    coord = tf.train.Coordinator()
	
	    # 从VCTK数据集中加载原始波形
	    with tf.name_scope('create_inputs'):
	        # 允许通过指定接近零的阈值跳过静默修剪
	        silence_threshold = args.silence_threshold if args.silence_threshold > \
	                                                      EPSILON else None
	        gc_enabled = args.gc_channels is not None
	        # 通用的后台音频读取器,对音频文件进行预处理并将它们排队到TensorFlow队列中
	        reader = AudioReader(
	            args.data_dir,
	            coord,
	            sample_rate=wavenet_params['sample_rate'],
	            gc_enabled=gc_enabled,
	            receptive_field=WaveNetModel.calculate_receptive_field(wavenet_params["filter_width"],
	                                                                   wavenet_params["dilations"],
	                                                                   wavenet_params["scalar_input"],
	                                                                   wavenet_params["initial_filter_width"]),
	            sample_size=args.sample_size,
	            silence_threshold=silence_threshold)
	        # 准备好的音频出队列
	        audio_batch = reader.dequeue(args.batch_size)
	        if gc_enabled:
	            gc_id_batch = reader.dequeue_gc(args.batch_size)
	        else:
	            gc_id_batch = None
	
	    # 创建 WaveNet 网络
	    net = WaveNetModel(
	        batch_size=args.batch_size,
	        dilations=wavenet_params["dilations"],
	        filter_width=wavenet_params["filter_width"],
	        residual_channels=wavenet_params["residual_channels"],
	        dilation_channels=wavenet_params["dilation_channels"],
	        skip_channels=wavenet_params["skip_channels"],
	        quantization_channels=wavenet_params["quantization_channels"],
	        use_biases=wavenet_params["use_biases"],
	        scalar_input=wavenet_params["scalar_input"],
	        initial_filter_width=wavenet_params["initial_filter_width"],
	        histograms=args.histograms,
	        global_condition_channels=args.gc_channels,
	        global_condition_cardinality=reader.gc_category_cardinality)
	
	    # 验证 l2 正则化系数
	    if args.l2_regularization_strength == 0:
	        args.l2_regularization_strength = None
	    
	    # 创建一个 WaveNet 网络并返回自动编码损耗
	    loss = net.loss(input_batch=audio_batch,
	                    global_condition_batch=gc_id_batch,
	                    l2_regularization_strength=args.l2_regularization_strength)
	    
	    # 创建对应的优化器
	    optimizer = optimizer_factory[args.optimizer](
	                    learning_rate=args.learning_rate,
	                    momentum=args.momentum)
	    
	    # 返回使用 trainable=True 创建的所有变量
	    trainable = tf.trainable_variables()
	    optim = optimizer.minimize(loss, var_list=trainable)
	
	    # 设置TensorBoard的日志记录
	    writer = tf.summary.FileWriter(logdir)
	    writer.add_graph(tf.get_default_graph())
	    # 收集关于训练的元信息
	    run_metadata = tf.RunMetadata()
	    summaries = tf.summary.merge_all()
	
	    # 建立会话
	    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
	    # 初始化变量
	    init = tf.global_variables_initializer()
	    sess.run(init)
	
	    # 存储模型检查点的保护程序
	    # 在创建这个 Saver 对象的时候, max_to_keep 参数表示要保留的最近检查点文件的最大数量,创建新文件时,将删除旧文件
	    saver = tf.train.Saver(var_list=tf.trainable_variables(), max_to_keep=args.max_checkpoints)
	
	    try:
	        # 恢复训练模型,获取训练步数
	        saved_global_step = load(saver, sess, restore_from)
	        if is_overwritten_training or saved_global_step is None:
	            # 第一个训练步骤将是 saved_global_step + 1,因此我们在这里输入-1表示新的或覆盖的训练
	            saved_global_step = -1
	
	    except:
	        print("Something went wrong while restoring checkpoint. "
	              "We will terminate training to avoid accidentally overwriting "
	              "the previous model.")
	        raise
	
	    # 开启入队线程启动器,详细介绍可参考这篇博客:
	    # https://blog.csdn.net/weixin_42721167/article/details/112795491
	    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
	    reader.start_threads(sess)
	
	    step = None
	    last_saved_step = saved_global_step
	    try:
	        # 从恢复模型的节点处开始训练
	        for step in range(saved_global_step + 1, args.num_steps):
	            # 获取当前时间
	            start_time = time.time()
	            # 当存储标志为 true 且训练次数为50的倍数时存储调试信息
	            if args.store_metadata and step % 50 == 0:
	                # 缓慢运行,存储额外的调试信息
	                print('Storing metadata')
	                
	                # RunOptions提供配置参数,供SessionRun调用时使用
	                run_options = tf.RunOptions(
	                    trace_level=tf.RunOptions.FULL_TRACE)
	                # 计算日志与自动编码的损失
	                summary, loss_value, _ = sess.run(
	                    [summaries, loss, optim],
	                    options=run_options,
	                    run_metadata=run_metadata)
	                
	                # 调用train_writer的add_summary方法将训练过程以及训练步数保存 
	                writer.add_summary(summary, step)
	                # 记录CPU/内存使用情况
	                writer.add_run_metadata(run_metadata,
	                                        'step_{:04d}'.format(step))
	                # Tensorflow的Timeline模块是用于描述张量图一个工具,可以记录在会话中每个操作执行时间和资源分配及消耗的情况
	                tl = timeline.Timeline(run_metadata.step_stats)
	                # 加载文件路径,打开文件,写入日志
	                timeline_path = os.path.join(logdir, 'timeline.trace')
	                with open(timeline_path, 'w') as f:
	                    f.write(tl.generate_chrome_trace_format(show_memory=True))
	            else:
	                # 在不保存模型的训练步数里,保存训练日志到 Tensorboard
	                summary, loss_value, _ = sess.run([summaries, loss, optim])
	                writer.add_summary(summary, step)
	
	            # 计算并打印训练一次的时间与结果
	            duration = time.time() - start_time
	            print('step {:d} - loss = {:.3f}, ({:.3f} sec/step)'
	                  .format(step, loss_value, duration))
	
	            # 每隔输入的检查点间隔存储一次训练模型
	            if step % args.checkpoint_every == 0:
	                save(saver, sess, logdir, step)
	                last_saved_step = step
	
	    except KeyboardInterrupt:
	        # 在 ctrl+C 显示之后引入一个换行符,这样保存消息就在它自己的行上了
	        print()
	    finally:
	        # 若训练到了更多步
	        if step > last_saved_step:
	            save(saver, sess, logdir, step)
	        coord.request_stop()
	        coord.join(threads)

       

      get_arguments()

        下面这段代码主要是获取命令行参数。
        运用 python 中的 argparse 模块对我们输入的命令行进行解析。

	def get_arguments():
	    def _str_to_bool(s):
	        """ 将string转换为bool """
	        """ 传入的字符串被限制为'true'或'false' """
	        if s.lower() not in ['true', 'false']:
	            raise ValueError('Argument needs to be a '
	                             'boolean, got {}'.format(s))
	                             
	        return {
    
    'true': True, 'false': False}[s.lower()]
	
	    # 创建解析器,解析的功能参数作为 WaveNet 的实例
	    parser = argparse.ArgumentParser(description='WaveNet example network')
	    # 添加可选功能参数: --batch_size; 该参数含义为: 一次要处理的 wav 文件数量
	    parser.add_argument('--batch_size', type=int, default=BATCH_SIZE,
	                        help='How many wav files to process at once. Default: ' + str(BATCH_SIZE) + '.')
	    # 添加可选功能参数: --data_dir; 该参数含义为: VCTK数据集的文件路径
	    parser.add_argument('--data_dir', type=str, default=DATA_DIRECTORY,
	                        help='The directory containing the VCTK corpus.')
	    # 添加可选功能参数: --store_metadata; 该参数含义为: 高级调试信息存储标志
	    parser.add_argument('--store_metadata', type=bool, default=METADATA,
	                        help='Whether to store advanced debugging information '
	                        '(execution time, memory consumption) for use with '
	                        'TensorBoard. Default: ' + str(METADATA) + '.')
	    # 添加可选功能参数: --logdir; 该参数含义为: 存储 TensorBoard 日志信息的文件路径;
	    # 需要注意: 该参数不能与'--logdir_root'或'--restore_from'一起使用
	    parser.add_argument('--logdir', type=str, default=None,
	                        help='Directory in which to store the logging '
	                        'information for TensorBoard. '
	                        'If the model already exists, it will restore '
	                        'the state and will continue training. '
	                        'Cannot use with --logdir_root and --restore_from.')
	    # 添加可选功能参数: --logdir_root; 该参数含义为: 放置日志输出和生成模型的文件路径,存放在带有日期的子目录下
	    # 需要注意: 该参数不能与'--logdir'一起使用
	    parser.add_argument('--logdir_root', type=str, default=None,
	                        help='Root directory to place the logging '
	                        'output and generated model. These are stored '
	                        'under the dated subdirectory of --logdir_root. '
	                        'Cannot use with --logdir.')
	    # 添加可选功能参数: --restore_from; 该参数含义为: 恢复模型的目录,能创建带有日期的子目录
	    # 需要注意: 该参数不能与'--logdir'一起使用
	    parser.add_argument('--restore_from', type=str, default=None,
	                        help='Directory in which to restore the model from. '
	                        'This creates the new model under the dated directory '
	                        'in --logdir_root. '
	                        'Cannot use with --logdir.')
	    # 添加可选功能参数: --checkpoint_every; 该参数含义为: 存放训练模型的检查点间隔
	    parser.add_argument('--checkpoint_every', type=int,
	                        default=CHECKPOINT_EVERY,
	                        help='How many steps to save each checkpoint after. Default: ' + str(CHECKPOINT_EVERY) + '.')
	    # 添加可选功能参数: --num_steps; 该参数含义为: 训练的次数
	    parser.add_argument('--num_steps', type=int, default=NUM_STEPS,
	                        help='Number of training steps. Default: ' + str(NUM_STEPS) + '.')
	    # 添加可选功能参数: --learning_rate; 该参数含义为: 训练的学习率
	    parser.add_argument('--learning_rate', type=float, default=LEARNING_RATE,
	                        help='Learning rate for training. Default: ' + str(LEARNING_RATE) + '.')
	    # 添加可选功能参数: --wavenet_params; 该参数含义为: WaveNet 模型的相关参数
	    parser.add_argument('--wavenet_params', type=str, default=WAVENET_PARAMS,
	                        help='JSON file with the network parameters. Default: ' + WAVENET_PARAMS + '.')
	    # 添加可选功能参数: --sample_size; 该参数含义为: 使用的样本数量
	    parser.add_argument('--sample_size', type=int, default=SAMPLE_SIZE,
	                        help='Concatenate and cut audio samples to this many '
	                        'samples. Default: ' + str(SAMPLE_SIZE) + '.')
	    # 添加可选功能参数: --l2_regularization_strength; 该参数含义为: L2正则化的系数
	    parser.add_argument('--l2_regularization_strength', type=float,
	                        default=L2_REGULARIZATION_STRENGTH,
	                        help='Coefficient in the L2 regularization. '
	                        'Default: False')
	    # 添加可选功能参数: --silence_threshold; 该参数含义为: 音量阈值限制
	    parser.add_argument('--silence_threshold', type=float,
	                        default=SILENCE_THRESHOLD,
	                        help='Volume threshold below which to trim the start '
	                        'and the end from the training set samples. Default: ' + str(SILENCE_THRESHOLD) + '.')
	    # 添加可选功能参数: --optimizer; 该参数含义为: 优化器选择
	    parser.add_argument('--optimizer', type=str, default='adam',
	                        choices=optimizer_factory.keys(),
	                        help='Select the optimizer specified by this option. Default: adam.')
	    # 添加可选功能参数: --momentum; 该参数含义为: 优化器动量大小
	    parser.add_argument('--momentum', type=float,
	                        default=MOMENTUM, help='Specify the momentum to be '
	                        'used by sgd or rmsprop optimizer. Ignored by the '
	                        'adam optimizer. Default: ' + str(MOMENTUM) + '.')
	    # 添加可选功能参数: --histograms; 该参数含义为: 直方图汇总存储标志
	    parser.add_argument('--histograms', type=_str_to_bool, default=False,
	                        help='Whether to store histogram summaries. Default: False')
	    # 添加可选功能参数: --gc_channels; 该参数含义为: 全局条件通道数量
	    parser.add_argument('--gc_channels', type=int, default=None,
	                        help='Number of global condition channels. Default: None. Expecting: Int')
	    # 添加可选功能参数: --max_checkpoints; 该参数含义为: 最大训练模型保存检查点数
	    parser.add_argument('--max_checkpoints', type=int, default=MAX_TO_KEEP,
	                        help='Maximum amount of checkpoints that will be kept alive. Default: '
	                             + str(MAX_TO_KEEP) + '.')
	    # 把parser中设置的所有"add_argument"给返回到args子类实例中并返回
	    return parser.parse_args()

     

      validate_directories(args)

        下面这段代码主要工作是:验证当前的几个目录是否冲突,将输入的目录参数规范化。
       

	def validate_directories(args):
	    """ 验证和整理与目录相关的参数 """
	
	    # 验证接断
	    # logdir 与 logdir_root 参数不能同时存在
	    if args.logdir and args.logdir_root:
	        raise ValueError("--logdir and --logdir_root cannot be "
	                         "specified at the same time.")
	
	    # logdir 与 restore_from 参数不能同时存在
	    if args.logdir and args.restore_from:
	        raise ValueError(
	            "--logdir and --restore_from cannot be specified at the same "
	            "time. This is to keep your previous model from unexpected "
	            "overwrites.\n"
	            "Use --logdir_root to specify the root of the directory which "
	            "will be automatically created with current date and time, or use "
	            "only --logdir to just continue the training from the last "
	            "checkpoint.")
	
	    # 整理阶段
	    # 为 logdir_root 参数赋予给定的值或是默认值
	    logdir_root = args.logdir_root
	    if logdir_root is None:
	        logdir_root = LOGDIR_ROOT
	
	    # 为 logdir 参数赋予给定的值或是 logdir_root 参数的默认值
	    logdir = args.logdir
	    if logdir is None:
	        logdir = get_default_logdir(logdir_root)
	        print('Using default logdir: {}'.format(logdir))
	
	    # 为 restore_from 参数赋予给定的值或是 logdir 参数的值
	    restore_from = args.restore_from
	    if restore_from is None:
	        # args.logdir and args.restore_from are exclusive,
	        # so it is guaranteed the logdir here is newly created.
	        restore_from = logdir
	
	    # 将验证并整理好的目录参数打包返回
	    return {
    
    
	        'logdir': logdir,
	        'logdir_root': args.logdir_root,
	        'restore_from': restore_from
	    }

     

      get_default_logdir(logdir_root)

        下面这段代码主要工作是:在给定的日志目录下,创建训练文件夹,再创建以带有当前日期时间的文件路径,并将该路径返回

扫描二维码关注公众号,回复: 12416104 查看本文章
	def get_default_logdir(logdir_root):
	    # 使用路径拼接函数 os.path.join() 在给定的目录下创建'train'目录
	    # 进而创建以当前日期时间为名的子目录,格式为:{0:%Y-%m-%dT%H-%M-%S}
	    logdir = os.path.join(logdir_root, 'train', STARTED_DATESTRING)
	    return logdir

       

      save(saver, sess, logdir, step)

        这段代码主要工作是:将给定的训练结果、会话以及检查点保存到指定的文件路径下

	def save(saver, sess, logdir, step):
	    # 设置保存的模型文件名,将文件路径进行拼接
	    model_name = 'model.ckpt'
	    checkpoint_path = os.path.join(logdir, model_name)
	    print('Storing checkpoint to {} ...'.format(logdir), end="")
	    
	    # 刷新缓冲区,保证正常输出
	    sys.stdout.flush()
	
	    # 若文件不存在则先创造文件
	    if not os.path.exists(logdir):
	        os.makedirs(logdir)
	
	    # 保存模型
	    saver.save(sess, checkpoint_path, global_step=step)
	    print(' Done.')

       

      load(saver, sess, logdir)

        这段代码主要工作是:将指定路径下的模型训练结果恢复到当前会话

	def load(saver, sess, logdir):
	    print("Trying to restore saved checkpoints from {} ...".format(logdir),
	          end="")
	
	    # 从指定路径下返回训练模型以及检查点
	    ckpt = tf.train.get_checkpoint_state(logdir)
	    if ckpt:
	        print("  Checkpoint found: {}".format(ckpt.model_checkpoint_path))
	        # 找到模型,获取检查点
	        global_step = int(ckpt.model_checkpoint_path
	                          .split('/')[-1]
	                          .split('-')[-1])
	        print("  Global step was: {}".format(global_step))
	        print("  Restoring...", end="")
	        
	        # 恢复最新检查点训练情况
	        saver.restore(sess, ckpt.model_checkpoint_path)
	        print(" Done.")
	        # 返回检查点
	        return global_step
	    else:
	        # 未找到模型,返回空值
	        print(" No checkpoint found.")
	        return None

       
       
       本文还在持续更新中!
       欢迎各位大佬交流讨论!

猜你喜欢

转载自blog.csdn.net/weixin_42721167/article/details/112907874