ncnn之五:参数和模型文件结构

net.param

7767517 # magic number
158 183 # [layer count] [blob count]
# [layer type] [layer name] [input count] [output count] [input blobs] [output blobs] [layer specific params]
Input            data                             0 1 data
Convolution      conv1                            1 1 data conv1 0=64 1=3 11=3 3=2 13=2 4=0 14=0 5=1 6=1728
ReLU             relu_conv1                       1 1 conv1 relu_conv1
Convolution      conv2                            1 1 relu_conv1 conv2 0=64 1=3 11=3 3=2 13=2 4=0 14=0 5=1 6=36864

load_param代码:

// Parses a text-format param stream (magic number, "[layer count] [blob count]"
// header, then one line per layer) and fills the layers and blobs tables.
//
// fp: open stream positioned at the start of the param text; it is not closed here.
//
// Returns 0 once the stream has been consumed. Returns -1 when the magic number
// does not match, the count header is missing or negative, a layer type cannot
// be created, a layer declares a negative input/output count, or the stream
// describes more layers or blobs than the header declared.
int Net::load_param(FILE* fp)
{
    int magic = 0;
    // A failed read is rejected the same way as a mismatching magic number.
    if (fscanf(fp, "%d", &magic) != 1 || magic != 7767517)
    {
        fprintf(stderr, "param is too old, please regenerate\n");
        return -1;
    }

    // parse
    int layer_count = 0;
    int blob_count = 0;
    if (fscanf(fp, "%d %d", &layer_count, &blob_count) != 2 || layer_count < 0 || blob_count < 0)
    {
        // Negative counts would be converted to huge unsigned sizes by resize().
        fprintf(stderr, "invalid layer_count or blob_count\n");
        return -1;
    }

    layers.resize(layer_count);
    blobs.resize(blob_count);

    ParamDict pd;

    int layer_index = 0;
    int blob_index = 0;
    while (!feof(fp))
    {
        char layer_type[257];
        char layer_name[257];
        int bottom_count = 0;
        int top_count = 0;
        int nscan = fscanf(fp, "%256s %256s %d %d", layer_type, layer_name, &bottom_count, &top_count);
        if (nscan == EOF)
        {
            // End of input or a read error. On a read error feof() stays false,
            // so looping on it would never terminate.
            break;
        }
        if (nscan != 4)
        {
            // Incomplete layer header: skip it and try the next tokens.
            continue;
        }

        // Both tables were sized from the header; refuse to index past them.
        if (layer_index >= layer_count)
        {
            fprintf(stderr, "layer %s exceeds declared layer count %d\n", layer_name, layer_count);
            clear();
            return -1;
        }
        if (bottom_count < 0 || top_count < 0)
        {
            fprintf(stderr, "layer %s has invalid input or output count\n", layer_name);
            clear();
            return -1;
        }

        Layer* layer = create_layer(layer_type);
        if (!layer)
        {
            // Fall back to user-registered layer types.
            layer = create_custom_layer(layer_type);
        }
        if (!layer)
        {
            fprintf(stderr, "layer %s not exists or registered\n", layer_type);
            clear();
            return -1;
        }

        layer->type = std::string(layer_type);
        layer->name = std::string(layer_name);
//         fprintf(stderr, "new layer %d %s\n", layer_index, layer_name);

        // Input blobs: reuse a blob already known by name, otherwise claim the next free slot.
        layer->bottoms.resize(bottom_count);
        for (int i=0; i<bottom_count; i++)
        {
            char bottom_name[257];
            nscan = fscanf(fp, "%256s", bottom_name);
            if (nscan != 1)
            {
                continue;
            }

            int bottom_blob_index = find_blob_index_by_name(bottom_name);
            if (bottom_blob_index == -1)
            {
                if (blob_index >= blob_count)
                {
                    fprintf(stderr, "blob %s exceeds declared blob count %d\n", bottom_name, blob_count);
                    // NOTE(review): the half-built layer is parked in the layers table so
                    // that clear() can release it; this assumes clear() owns layer
                    // cleanup - confirm against its implementation.
                    layers[layer_index] = layer;
                    clear();
                    return -1;
                }

                Blob& blob = blobs[blob_index];
                bottom_blob_index = blob_index;
                blob.name = std::string(bottom_name);
//                 fprintf(stderr, "new blob %s\n", bottom_name);
                blob_index++;
            }

            Blob& blob = blobs[bottom_blob_index];
            blob.consumers.push_back(layer_index);
            layer->bottoms[i] = bottom_blob_index;
        }

        // Output blobs: every output takes the next free blob slot.
        layer->tops.resize(top_count);
        for (int i=0; i<top_count; i++)
        {
            char blob_name[257];
            nscan = fscanf(fp, "%256s", blob_name);
            if (nscan != 1)
            {
                continue;
            }

            if (blob_index >= blob_count)
            {
                fprintf(stderr, "blob %s exceeds declared blob count %d\n", blob_name, blob_count);
                // NOTE(review): same assumption about clear() as for input blobs.
                layers[layer_index] = layer;
                clear();
                return -1;
            }

            Blob& blob = blobs[blob_index];
            blob.name = std::string(blob_name);
//             fprintf(stderr, "new blob %s\n", blob_name);
            blob.producer = layer_index;
            layer->tops[i] = blob_index;
            blob_index++;
        }

        // layer specific params
        // NOTE(review): on either failure below the layer is skipped without being
        // stored or released, and the blobs wired above keep pointing at
        // layer_index; kept as best-effort to preserve existing behavior.
        int pdlr = pd.load_param(fp);
        if (pdlr != 0)
        {
            fprintf(stderr, "ParamDict load_param failed\n");
            continue;
        }

        int lr = layer->load_param(pd);
        if (lr != 0)
        {
            fprintf(stderr, "layer load_param failed\n");
            continue;
        }

        layers[layer_index] = layer;
        layer_index++;
    }
    return 0;
}

// Opens the param file at protopath, parses it with load_param(FILE*), and
// closes the file again. Returns -1 when the file cannot be opened, otherwise
// whatever the stream overload returned.
int Net::load_param(const char* protopath)
{
    FILE* param_file = fopen(protopath, "rb");
    if (param_file == NULL)
    {
        fprintf(stderr, "fopen %s failed\n", protopath);
        return -1;
    }

    const int status = load_param(param_file);

    // The stream is owned by this wrapper, so release it on every outcome.
    fclose(param_file);

    return status;
}

参数字典 ParamDict

index:0~19 对应整型或浮点型数据,
查阅 operation param weight table
在这里插入图片描述
index:-23300 减去 0~19 对应整型或浮点型数组,属于特殊参数(可能没有)。参数有两种写法:一种是 k=v 的类型;另一种是 k=len,v1,v2,v3…(数组类型)。这些参数在 ncnn 中存放到 ParamDict 结构中。根据读取的 index 判断:如果小于等于 -23300,表示为数组,那么等号右边第一个参数就是数组长度,后面依次是数组内容,即 [array size],int,int,…,int 或 [array size],float,float,…,float,例如:

0=1 1=2.5 -23303=2,2.0,3.0

index 为 -23303,小于等于 -23300,表明当前参数为数组,对应的参数 id 为 -(-23303) - 23300 = 3;等号右边第一个参数为 2,表明数组长度为 2,后面的 2.0,3.0 就是数组的内容。

int ParamDict::load_param(const DataReader& dr)
{
    clear();

//     0=100 1=1.250000 -23303=5,0.1,0.2,0.4,0.8,1.0

    // parse each key=value pair
    int id = 0;
    while (dr.scan("%d=", &id) == 1)
    {
        bool is_array = id <= -23300;
        if (is_array)
        {
            id = -id - 23300;
        }

        if (is_array)
        {
            int len = 0;
            int nscan = dr.scan("%d", &len);
            if (nscan != 1)
            {
                fprintf(stderr, "ParamDict read array length failed\n");
                return -1;
            }

            params[id].v.create(len);

            for (int j = 0; j < len; j++)
            {
                char vstr[16];
                nscan = dr.scan(",%15[^,\n ]", vstr);
                if (nscan != 1)
                {
                    fprintf(stderr, "ParamDict read array element failed\n");
                    return -1;
                }

                bool is_float = vstr_is_float(vstr);

                if (is_float)
                {
                    float* ptr = params[id].v;
                    nscan = sscanf(vstr, "%f", &ptr[j]);
                }
                else
                {
                    int* ptr = params[id].v;
                    nscan = sscanf(vstr, "%d", &ptr[j]);
                }
                if (nscan != 1)
                {
                    fprintf(stderr, "ParamDict parse array element failed\n");
                    return -1;
                }

                params[id].type = is_float ? 6 : 5;
            }
        }
        else
        {
            char vstr[16];
            int nscan = dr.scan("%15s", vstr);
            if (nscan != 1)
            {
                fprintf(stderr, "ParamDict read value failed\n");
                return -1;
            }

            bool is_float = vstr_is_float(vstr);

            if (is_float)
                nscan = sscanf(vstr, "%f", &params[id].f);
            else
                nscan = sscanf(vstr, "%d", &params[id].i);
            if (nscan != 1)
            {
                fprintf(stderr, "ParamDict parse value failed\n");
                return -1;
            }

            params[id].type = is_float ? 3 : 2;
        }
    }

    return 0;
}

其中 vstr_is_float 函数的原理很简单:判断数字对应的字符串中是否存在小数点 '.' 或字母 'e'(不区分大小写),分别对应小数的两种写法:普通的小数点表示法和科学计数法。

// Reports whether a numeric token should be parsed as a float.
//
// vstr: NUL-terminated token; at most the first 16 characters are inspected.
//
// Returns true when a decimal point or an exponent marker ('e' or 'E') occurs
// before the terminator, false otherwise (including for an empty token).
static bool vstr_is_float(const char vstr[16])
{
    // look ahead for determine isfloat
    for (int j=0; j<16; j++)
    {
        const char c = vstr[j];
        if (c == '\0')
            break;

        // Compare against both cases directly: passing a plain char to tolower()
        // is undefined for negative values and depends on the current locale.
        if (c == '.' || c == 'e' || c == 'E')
            return true;
    }

    return false;
}

net.bin

  +---------+---------+---------+---------+---------+---------+
  | weight1 | weight2 | weight3 | weight4 | ....... | weightN |
  +---------+---------+---------+---------+---------+---------+
  ^         ^         ^         ^
  0x0      0x80      0x140     0x1C0

其中 weight buffer

[flag] (optional)
[raw data]
[padding] (optional)

flag : unsigned int, little-endian, indicating the weight storage type, 0 => float32, 0x01306B47 => float16, otherwise => quantized int8, may be omitted if the layer implementation forced the storage type explicitly
raw data : raw weight data, little-endian, float32 data or float16 data or quantized table and indexes depending on the storage type flag
padding : padding space for 32bit alignment, may be omitted if already aligned

load_model代码:

// Loads the weights of every layer, in layer order, from the model stream fp.
//
// fp: open stream positioned at the start of the weight data; it is not closed here.
//
// Returns 0 when every layer loaded its weights. Returns -1 when no layers
// have been set up yet, when a layer slot is empty, or when a layer's own
// load_model reports failure; loading stops at the first such layer.
int Net::load_model(FILE* fp)
{
    if (layers.empty())
    {
        fprintf(stderr, "network graph not ready\n");
        return -1;
    }

    // load file
    ModelBinFromStdio mb(fp);
    for (size_t i=0; i<layers.size(); i++)
    {
        Layer* layer = layers[i];
        if (!layer)
        {
            // load_param(FILE*) sizes the table from the declared layer count but
            // only fills slots for layers that parsed successfully, so a slot
            // can still be null here.
            fprintf(stderr, "load_model error at layer %d, parameter file has inconsistent content\n", static_cast<int>(i));
            return -1;
        }

        if (layer->load_model(mb) != 0)
        {
            fprintf(stderr, "layer load_model %d failed\n", static_cast<int>(i));
            return -1;
        }
    }
    return 0;
}

// Opens the model file at modelpath, loads the weights with load_model(FILE*),
// and closes the file again. Returns -1 when the file cannot be opened,
// otherwise whatever the stream overload returned.
int Net::load_model(const char* modelpath)
{
    FILE* model_file = fopen(modelpath, "rb");
    if (model_file == NULL)
    {
        fprintf(stderr, "fopen %s failed\n", modelpath);
        return -1;
    }

    const int status = load_model(model_file);

    // The stream is owned by this wrapper, so release it on every outcome.
    fclose(model_file);

    return status;
}

参考:
1 param-and-model-file-structure
2 operation param weight table

发布了270 篇原创文章 · 获赞 344 · 访问量 65万+

猜你喜欢

转载自blog.csdn.net/shanglianlm/article/details/103247724