How to make FasterTransformer support dynamic batch size and dynamic sequence length

The FasterTransformer operator

The FasterTransformer operator that NVIDIA open-sourced is defined in compiled code, and the project ships TensorRT and TensorFlow Python bindings with examples (see FasterTransformer.py). The TensorFlow custom operator, however, is inconvenient to use because its batch size and sequence length are fixed as attributes at graph-construction time. The steps below make both dimensions dynamic:

  1. Modify bert_transformer_op.cc: remove the batch_size, from_seq_len, and to_seq_len attrs and pass them as input tensors instead, as follows:
   .Input("output_bias: T")
   .Input("output_layernorm_beta: T")
   .Input("output_layernorm_gamma: T")
+  .Input("batch_size: int32")
+  .Input("from_seq_len: int32")
   .Output("output: T")
   .Attr("T: {float, half}")
-  .Attr("batch_size: int >= 1")
-  .Attr("from_seq_len: int >= 1")
-  .Attr("to_seq_len: int >= 1")
+  //.Attr("batch_size: int >= 1")
+  //.Attr("from_seq_len: int >= 1")
+  //.Attr("to_seq_len: int >= 1")
   .Attr("head_num: int >= 1")
   .Attr("size_per_head: int >= 1")
   .SetShapeFn([](shape_inference::InferenceContext *c) {
       int batch_size, from_seq_len, to_seq_len, head_num, size_per_head;
-      c->GetAttr("batch_size", &batch_size);
-      c->GetAttr("from_seq_len", &from_seq_len);
-      c->GetAttr("to_seq_len", &to_seq_len);
+      //c->GetAttr("batch_size", &batch_size);
+      //c->GetAttr("from_seq_len", &from_seq_len);
+      //c->GetAttr("to_seq_len", &to_seq_len);
       c->GetAttr("head_num", &head_num);
       c->GetAttr("size_per_head", &size_per_head);
-      c->set_output(0, c->MakeShape({batch_size * from_seq_len, head_num * size_per_head}));
+      //c->set_output(0, c->MakeShape({batch_size * from_seq_len, head_num * size_per_head}));
+      c->set_output(0, c->input(0));
       return Status::OK();
       });
 template <typename Device, typename T>
@@ -70,14 +71,15 @@ class BertTransformerOp : public OpKernel
   public:
     explicit BertTransformerOp(OpKernelConstruction *context) : OpKernel(context)
     {
-      OP_REQUIRES_OK(context, context->GetAttr("batch_size", &batch_size_));
-      OP_REQUIRES_OK(context, context->GetAttr("from_seq_len", &from_seq_len_));
-      OP_REQUIRES_OK(context, context->GetAttr("to_seq_len", &to_seq_len_));
+      //OP_REQUIRES_OK(context, context->GetAttr("batch_size", &batch_size_));
+      //OP_REQUIRES_OK(context, context->GetAttr("from_seq_len", &from_seq_len_));
+      //OP_REQUIRES_OK(context, context->GetAttr("to_seq_len", &to_seq_len_));
       OP_REQUIRES_OK(context, context->GetAttr("head_num", &head_num_));
       OP_REQUIRES_OK(context, context->GetAttr("size_per_head", &size_per_head_));
 
-      OP_REQUIRES(context, (from_seq_len_ == to_seq_len_),
-          errors::InvalidArgument("Only support from_seq_len == to_seq_len"));
+      //printf("++++++++ %d =%d \n", from_seq_len_, to_seq_len_)
+      //OP_REQUIRES(context, (from_seq_len_ == to_seq_len_),
+      ///    errors::InvalidArgument("Only support from_seq_len == to_seq_len"));
 
       try
       {
@@ -95,6 +97,11 @@ class BertTransformerOp : public OpKernel
       BertEncoderTransformer<EncoderTraits_> *encoder_transformer_;
       try
       {
+     
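+        // The two fake int32 inputs added in FasterTransformer.py have shape (3, batch) and (3, seq_len),
+        // so dividing their total element count by 3 recovers the dynamic values without a device-to-host copy.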
+        batch_size_ = context->input(19).flat<int32>().size()/3;
+        from_seq_len_ = context->input(20).flat<int32>().size()/3;
+        to_seq_len_ = from_seq_len_;
+        //printf("==>%d %d\n", batch_size_, from_seq_len_);
         fastertransformer::Allocator<AllocatorType::TF> allocator_(context);
         encoder_transformer_ = new BertEncoderTransformer<EncoderTraits_>(allocator_, 
           batch_size_, from_seq_len_, to_seq_len_, head_num_, size_per_head_);
@@ -104,7 +111,7 @@ class BertTransformerOp : public OpKernel
         OP_REQUIRES(context, false, errors::Internal(error.what()));
       }
       
-      OP_REQUIRES(context, context->num_inputs() == 19, errors::InvalidArgument("Less input arguments"));
+      OP_REQUIRES(context, context->num_inputs() == 21, errors::InvalidArgument("Less input arguments"));
 
       EncoderInitParam<DataType_> param; //init param here

Because the op's inputs live in CUDA (device) memory, their values cannot be read directly on the host (copying them from device memory to host memory would be too time-consuming). A tensor's shape and size, however, can be read directly, so we pass in fake tensors whose sizes encode the values, and recover batch_size and seq_len from those sizes.
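A minimal, standalone sketch of the trick in plain TensorFlow 1.x (the placeholder name input_tensor and the hidden size 768 are only illustrative, not part of the patch): tf.tile([[1],[2],[3]], [1, n]) yields a tensor of shape (3, n), so its element count is 3*n, and dividing by 3 recovers n without ever reading the tensor's contents.

    import numpy as np
    import tensorflow as tf

    # 768 is just an illustrative hidden size; only the first two dims matter here.
    input_tensor = tf.placeholder(tf.float32, (None, None, 768), 'input_tensor')
    fast_list_tensor = tf.shape(input_tensor)

    # Fake inputs of shape (3, batch) and (3, seq_len); only their sizes are ever used.
    fake_batch = tf.tile([[1], [2], [3]], [1, fast_list_tensor[0]])
    fake_seq = tf.tile([[1], [2], [3]], [1, fast_list_tensor[1]])

    with tf.Session() as sess:
        b, s = sess.run([tf.size(fake_batch), tf.size(fake_seq)],
                        feed_dict={input_tensor: np.zeros((8, 128, 768), np.float32)})
        print(b // 3, s // 3)  # -> 8 128, mirroring size()/3 in the C++ kernel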

  2. Amend FasterTransformer.py as follows:
    ...
    fast_list_tensor = tf.shape(input_tensor)
    ...
    layer_output = transformer_op_module.bert_transformer(
        layer_input,
        layer_input,
        trainable_vars[0], trainable_vars[2], trainable_vars[4], trainable_vars[1], trainable_vars[3], trainable_vars[5],
        attention_mask,
        trainable_vars[6], trainable_vars[7], trainable_vars[8], trainable_vars[9], trainable_vars[10], trainable_vars[11],
        trainable_vars[12], trainable_vars[13], trainable_vars[14], trainable_vars[15], tf.tile([[1],[2],[3]], [1,fast_list_tensor[0]]),
        tf.tile([[1],[2],[3]], [1,fast_list_tensor[1]]),
        #batch_size=batch_size, 
        #from_seq_len=seq_length, 
        #to_seq_len=seq_length, 
        head_num=num_attention_heads, size_per_head=attention_head_size)
  3. With the above changes, transformer_op_module no longer requires a fixed batch size or sequence length, so when building the representation (feature-extraction) model the inputs can be declared like this:
input_ids = tf.placeholder(tf.int32,(None, None), 'input_ids')
input_mask = tf.placeholder(tf.float32,(None, None), 'input_mask')
input_type_ids = tf.placeholder(tf.int32,(None, None), 'input_type_ids')

This produces a TensorFlow model that supports both dynamic batch size and dynamic sequence length.
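As a rough usage sketch (build_bert_model below is a hypothetical stand-in for however the BERT graph that calls the patched transformer_op_module is assembled in your code; a shape-preserving dummy is used here so the snippet runs on its own), the same graph can now handle different batch sizes and sequence lengths in one session:

    import numpy as np
    import tensorflow as tf

    input_ids = tf.placeholder(tf.int32, (None, None), 'input_ids')
    input_mask = tf.placeholder(tf.float32, (None, None), 'input_mask')
    input_type_ids = tf.placeholder(tf.int32, (None, None), 'input_type_ids')

    def build_bert_model(ids, mask, type_ids):
        # Hypothetical stand-in for the real BERT graph that calls
        # transformer_op_module.bert_transformer as patched above.
        return tf.cast(ids, tf.float32) * mask

    output = build_bert_model(input_ids, input_mask, input_type_ids)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # The same graph serves two different shapes without rebuilding.
        for batch, seq_len in [(1, 32), (8, 128)]:
            feed = {
                input_ids: np.zeros((batch, seq_len), np.int32),
                input_mask: np.ones((batch, seq_len), np.float32),
                input_type_ids: np.zeros((batch, seq_len), np.int32),
            }
            print(sess.run(output, feed).shape)  # (1, 32) then (8, 128)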

Origin www.cnblogs.com/th3Bear/p/11502641.html