c++读取protobuf里的map字段或嵌套message字段



首先需要了解, map字段是个语法糖, 比如 node_def.proto中的如下部分:
...
import "tensorflow/core/framework/attr_value.proto";
...
message NodeDef {
  string name = 1;
  string op = 2;
  repeated string input = 3;
  string device = 4;
  map<string, AttrValue> attr = 5;
};
最后一个map 其实是相当于一个嵌套的 repeated message:
message NodeDef {
  string name = 1;
  string op = 2;
  repeated string input = 3;
  string device = 4;


  message attrEntry {
    string key = 1;
    AttrValue value = 2;
  }


  repeated attrEntry attr = 5;
};




此处应该用到protobuf反射:  部分代码如下:


void getattrvalue_bykey( const tensorflow::NodeDef &node,  string needed_attrkey  ,ReturnStruct &returnstruct  ) {
     if (needed_attrkey.size()> 0 ) {
         // **1** : 从传入目标层 的attr属性中读信息
         const google::protobuf::Reflection* pNodeReflection   = node.GetReflection();
         const google::protobuf::Descriptor* pNodeDescriptor   = node.GetDescriptor();
         const google::protobuf::FieldDescriptor* pAttrField   = pNodeDescriptor->FindFieldByName("attr");


         for (int j = 0; j < node.attr_size(); ++j ) {     
             NodeDef::NodeDef_AttrEntry attrentry ;
             attrentry.MergeFrom( pNodeReflection->GetRepeatedMessage(node,  pAttrField , j  ) );


             auto pAttrentryReflection  = attrentry.GetReflection();
             auto pAttrentryDescriptor  = attrentry.GetDescriptor();
             auto pAttrkey_field        = pAttrentryDescriptor->FindFieldByName("key");
             auto pAttrValue_field      = pAttrentryDescriptor->FindFieldByName("value");


             if ( needed_attrkey == "strides") {
                 //strides  < == > attr的key为 "strides"   值为list 的第二和第三维。
                 if ( "strides" == pAttrentryReflection->GetString(attrentry , pAttrkey_field) ) {
                     ::tensorflow::AttrValue attrvalue;
                     attrvalue.MergeFrom( pAttrentryReflection->GetMessage(attrentry, pAttrValue_field ) );
                     auto pAttrvalueReflection = attrvalue.GetReflection();
                     auto pAttrvalueDescriptor = attrvalue.GetDescriptor();


                     // **2** : 从attr的 attrvalue属性中 读list信息
                     auto pListvalue_field       = pAttrvalueDescriptor->FindFieldByName("list");
                     ::tensorflow::AttrValue_ListValue listvalue ;
                     listvalue.MergeFrom(pAttrvalueReflection->GetMessage(attrvalue, pListvalue_field ) );
                     auto pListvalueReflection   = listvalue.GetReflection();
                     auto pListvalueDescriptor   = listvalue.GetDescriptor();
                     auto pIntegervalue_field    = pListvalueDescriptor->FindFieldByName("i");
                     //TODO: "strides" 值为list 的第二和第三维。 假设现在这两维的值是一样的。如果不一样,将来再处理
                     returnstruct.intvalue       = pListvalueReflection->GetRepeatedInt64(listvalue, pIntegervalue_field, 1 );
                     if (returnstruct.intvalue  != pListvalueReflection->GetRepeatedInt64(listvalue, pIntegervalue_field, 2 ) ) {
                        cout << "!!warning!  the width and hight strides is not equal !! now only use one !!" << endl;
                     }
                 }
             }
        }
    }
}



猜你喜欢

转载自blog.csdn.net/anthea_luo/article/details/80230921
今日推荐