以horovd的HorovodAllreduceOp为例,学习如何在tensorflow上添加一个新的操作OP

参考: http://www.tensorfly.cn/tfdoc/how_tos/adding_an_op.html

添加新的OP需要3步(下述所有代码在here):

1. 定义 Op 的接口

// 1. 定义 Op 的接口
//    REGISTER_OP()向 TensorFlow 系统注册来定义 Op 的接口,该OP就是HorovodAllreduceOp.
//    在注册时, 指定 Op 的名称: REGISTER_OP("HorovodAllreduce")
//                     输入(类型和名称): Input("tensor: T")
//                     输出(类型和名称): Output("sum: T")
//                     和所需要任何 属性的文档说明Doc(R"doc(...)doc");
//
//    该 Op 接受一个 T 类型 tensor 作为输入, T 类型可以是{int32, int64, float32, float64}
//          输出一个 T 类型 tensor sum,sum是在所有的MPI进程中求和
REGISTER_OP("HorovodAllreduce")
    .Attr("T: {int32, int64, float32, float64}")
    .Input("tensor: T")
    .Output("sum: T")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      return Status::OK();
    })
    .Doc(R"doc(
Perform an MPI Allreduce on a tensor. All other processes that do a reduction
on a tensor with the same name must have the same dimension for that tensor.
Tensors are reduced with other tensors that have the same node name for the
allreduce.

Arguments
    tensor:     A tensor to reduce.

Output
    sum:    A tensor with the same shape as `tensor`, summed across all MPI processes.
)doc");

2. 为 Op 实现 kernel

// 2. 为 Op 实现 kernel。
//    在定义接口之后, 每一个实现称之为一个 "kernel",提供一个或多个 Op 的实现,即可以存在多个 kernel。
//    为这些 kernel 的每一个创建一个对应的类, 继承 AsyncOpKernel, 覆盖 ComputeAsync 方法。
//    ComputeAsync 方法提供一个类型为 OpKernelContext* 的参数 context, 用于访问一些有用的信息, 例如输入和输出的 tensor。
class HorovodAllreduceOp : public AsyncOpKernel {
public:
  // 防止类构造函数的隐式自动转换,只能显示调用该构造函数
  explicit HorovodAllreduceOp(OpKernelConstruction* context)
      : AsyncOpKernel(context) {}

  // 重写ComputeAsync()方法
  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    OP_REQUIRES_OK_ASYNC(context, ConvertStatus(common::CheckInitialized()),
                         done);

    auto node_name = name();
    auto device = GetDeviceID(context);
    auto tensor = context->input(0);
    Tensor* output;
    OP_REQUIRES_OK_ASYNC(
        context, context->allocate_output(0, tensor.shape(), &output), done);
    // ReadyEvent makes sure input tensor is ready, and output is allocated.
    // shared_ptr 是一个标准的共享所有权的智能指针, 允许多个指针指向同一个对象
    auto ready_event = std::shared_ptr<common::ReadyEvent>(RecordReadyEvent(context));
    // 模板函数 std::make_shared 可以返回一个指定类型的 std::shared_ptr
    auto hvd_context = std::make_shared<TFOpContext>(context);
    auto hvd_tensor = std::make_shared<TFTensor>(tensor);
    auto hvd_output = std::make_shared<TFTensor>(*output);
    // 将张量的Allreduce操作OP加入队列,加入谁的队列??
    auto enqueue_result = EnqueueTensorAllreduce(
        hvd_context, hvd_tensor, hvd_output, ready_event, node_name, device,
        [context, done](const common::Status& status) {
          context->SetStatus(ConvertStatus(status));
          done();
        });
    OP_REQUIRES_OK_ASYNC(context, ConvertStatus(enqueue_result), done);
  }
};

  

3. 注册OP到 TensorFlow 系统

// 3. 注册OP到 TensorFlow 系统
//    注册时可以指定该 kernel 运行时的多个约束条件. 例如可以指定一个 kernel 在 CPU 上运行, 另一个在 GPU 上运行
REGISTER_KERNEL_BUILDER(Name("HorovodAllreduce").Device(DEVICE_CPU),
                        HorovodAllreduceOp);
// 如果执行了GPU
#if HOROVOD_GPU_ALLREDUCE
REGISTER_KERNEL_BUILDER(Name("HorovodAllreduce").Device(DEVICE_GPU),
                        HorovodAllreduceOp);
#endif

  

猜你喜欢

转载自www.cnblogs.com/lixiaolun/p/9163431.html