1122

由YouTube8M的视频模型到音频模型转化

  youtube8M的接口的参数较为容易设置,首先文件夹的train.py文件

import json
import os
import time

import eval_util
import export_model
import losses
import frame_level_models
import video_level_models
import readers
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow import app
from tensorflow import flags
from tensorflow import gfile
from tensorflow import logging
from tensorflow.python.client import device_lib
import utils

  这些包含了引入的文件夹的其他写好的py文件和需要用到的库,如果没有的话pip安装即可。

1.模型的保存地址设置:

flags.DEFINE_string("train_dir", "/tmp/yt8m_model/",
                      "The directory to save the model files in.")

  需要多提一句的是flags这个模块是用于执行py程序的外部参数设置的交互,具体可以参见tf.app.flags/tf.flags

2.数据集的存储地址指定:

flags.DEFINE_string(
      "train_data_pattern", "E:/Audio_project/audioset/audioset_v1_embeddings/bal_train/*.tfrecord",
      "File glob for the training dataset. If the files refer to Frame Level "
      "features (i.e. tensorflow.SequenceExample), then set --reader_type "
      "format. The (Sequence)Examples are expected to have 'rgb' byte array "
      "sequence feature as well as a 'labels' int64 context feature.")

  看代码应该知道,这里介绍设置数据集的指向地址,注意地址的分隔符是'' /  ” , 其实“\\”也是可以的。

3.特征的名字(这一部分需要特别注意):

flags.DEFINE_string("feature_names", "audio_embedding", "Name of the feature "
                      "to use for training.")

   注意要设置成“audio_embedding”

4.帧特征设置

flags.DEFINE_bool(
      "frame_features", True,
      "If set, then --train_data_pattern must be frame-level features. "
      "Otherwise, --train_data_pattern must be aggregated video-level "
      "features. The model must also be set appropriately (i.e. to read 3D "
      "batches VS 4D batches.")

  

猜你喜欢

转载自www.cnblogs.com/ChenKe-cheng/p/10004084.html