Keras 模型可视化：显示卷积（Conv）层的 kernel_size 和池化（Pooling）层的 pool_size

Keras 模型可视化的基本用法可参考另一篇文章：《keras 模型可视化和遇到的坑》。

但是默认的 plot_model(model, to_file='model_ResNet.png', show_shapes=True) 并不能显示卷积层的 kernel_size 和池化层的 pool_size。plot_model 内部是调用 model_to_dot 来生成每个节点的标签的，所以需要修改的是 model_to_dot 方法（Keras 2.x 中两者都位于 keras/utils/vis_utils.py）。

def plot_model(model,
               to_file='model.png',
               show_shapes=False,
               show_layer_names=True,
               rankdir='TB'):
    """Render ``model`` as a graph image and write it to ``to_file``.

    The image format is derived from the extension of ``to_file``
    (``'x.png'`` -> ``'png'``); a path without an extension is written
    as PNG.

    # Arguments
        model: the Keras model to draw.
        to_file: destination path of the rendered image.
        show_shapes: whether each node lists the layer's input/output shapes.
        show_layer_names: whether each node label starts with the layer name.
        rankdir: Graphviz ``rankdir`` attribute for the graph (default 'TB').
    """
    # model_to_dot builds the labels (including conv/pool window sizes).
    graph = model_to_dot(model, show_shapes, show_layer_names, rankdir)
    suffix = os.path.splitext(to_file)[1]
    # splitext keeps the leading dot ('.png'), so drop it; default to png.
    fmt = suffix[1:] if suffix else 'png'
    graph.write(to_file, format=fmt)

def _window_size_suffix(layer):
    """Return the layer's window size as a label suffix.

    Gives ``str(layer.kernel_size)`` for layers with a ``kernel_size``
    attribute (convolutions), otherwise ``str(layer.pool_size)`` for layers
    with a ``pool_size`` attribute (pooling), otherwise ``''``.
    """
    if hasattr(layer, 'kernel_size'):
        return str(layer.kernel_size)
    if hasattr(layer, 'pool_size'):
        return str(layer.pool_size)
    return ''


def model_to_dot(model,
                 show_shapes=False,
                 show_layer_names=True,
                 rankdir='TB'):
    """Convert a Keras model to a ``pydot.Dot`` graph.

    Patched from the stock Keras version. Convolution and pooling nodes also
    show their window size, e.g. ``conv1: Conv2D(7, 7)`` or
    ``max_pooling2d_1: MaxPooling2D(3, 3)``.

    # Arguments
        model: the Keras model to convert. A ``Sequential`` model is built
            first if necessary, and its inner ``model.model`` is drawn.
        show_shapes: whether to add each layer's input/output shapes.
        show_layer_names: whether to prefix each node label with the
            layer name. The window size is shown either way.
        rankdir: Graphviz ``rankdir`` attribute for the graph.

    # Returns
        A ``pydot.Dot`` with one record-shaped node per layer. Each inbound
        connection whose node key is in ``model._container_nodes`` becomes
        an edge.
    """
    from ..layers.wrappers import Wrapper
    from ..models import Sequential

    _check_pydot()
    dot = pydot.Dot()
    dot.set('rankdir', rankdir)
    dot.set('concentrate', True)
    dot.set_node_defaults(shape='record')

    if isinstance(model, Sequential):
        if not model.built:
            model.build()
        model = model.model
    layers = model.layers

    # Create graph nodes.
    for layer in layers:
        layer_id = str(id(layer))

        # Conv/pooling layers are detected by attribute, not by a fixed
        # class-name list. That way 1D/3D, transposed, separable and
        # wrapped variants are labelled too.
        layer_name = layer.name
        class_name = layer.__class__.__name__ + _window_size_suffix(layer)
        if isinstance(layer, Wrapper):
            # Append the wrapped layer's name/class (with its own window
            # size), e.g. "TimeDistributed(Conv2D(3, 3))".
            layer_name = '{}({})'.format(layer_name, layer.layer.name)
            child_class_name = (layer.layer.__class__.__name__ +
                                _window_size_suffix(layer.layer))
            class_name = '{}({})'.format(class_name, child_class_name)

        # Create node's label. The window size is already part of
        # class_name, so it appears whether or not names are shown.
        if show_layer_names:
            label = '{}: {}'.format(layer_name, class_name)
        else:
            label = class_name

        # Rebuild the label as a table including input/output shapes.
        if show_shapes:
            try:
                outputlabels = str(layer.output_shape)
            except AttributeError:
                outputlabels = 'multiple'
            if hasattr(layer, 'input_shape'):
                inputlabels = str(layer.input_shape)
            elif hasattr(layer, 'input_shapes'):
                inputlabels = ', '.join(
                    [str(ishape) for ishape in layer.input_shapes])
            else:
                inputlabels = 'multiple'
            label = '%s\n|{input:|output:}|{{%s}|{%s}}' % (label,
                                                           inputlabels,
                                                           outputlabels)
        node = pydot.Node(layer_id, label=label)
        dot.add_node(node)

    # Connect nodes with edges.
    for layer in layers:
        layer_id = str(id(layer))
        for i, node in enumerate(layer._inbound_nodes):
            node_key = layer.name + '_ib-' + str(i)
            # Only draw inbound nodes registered in this model (presumably
            # this skips nodes created when the layer is reused elsewhere).
            if node_key in model._container_nodes:
                for inbound_layer in node.inbound_layers:
                    inbound_layer_id = str(id(inbound_layer))
                    dot.add_edge(pydot.Edge(inbound_layer_id, layer_id))
    return dot

修改后运行以下代码，将 ResNet 模型可视化：

# coding=utf-8
'''
Draw the ResNet50 architecture (ImageNet weights) with plot_model,
assumed here to be the patched version that adds conv/pooling window sizes.

@author: admin
'''
from keras.applications.resnet50 import ResNet50
from keras.utils import plot_model

# Build ResNet50 with ImageNet weights; only its structure is drawn.
model = ResNet50(weights='imagenet')

# show_shapes=True adds each layer's input/output shapes to its node.
plot_model(
    model,
    to_file='model_ResNet1.png',
    show_shapes=True,
)

输出的 ResNet 模型结构图如下：
（此处原为输出的模型结构图图片，转载时未能保留。）

猜你喜欢

转载自blog.csdn.net/u011311291/article/details/80375664