Keras 模型可视化的基本用法可参考：《keras 模型可视化和遇到的坑》。
但是默认的 plot_model(model, to_file='model_ResNet.png', show_shapes=True) 并不能显示卷积层（conv）的 kernel_size 和池化层（pooling）的 pool_size，所以需要修改 plot_model 内部调用的 model_to_dot 方法（Keras 2.x 中二者都位于 keras/utils/vis_utils.py）。plot_model 本身只需把参数转交给 model_to_dot：
def plot_model(model,
               to_file='model.png',
               show_shapes=False,
               show_layer_names=True,
               rankdir='TB'):
    """Render a Keras model graph and write it to an image file.

    The output format is taken from the extension of ``to_file``
    (e.g. ``'model.svg'`` -> ``'svg'``); when ``to_file`` has no
    extension the graph is written in ``'png'`` format.

    # Arguments
        model: A Keras model instance.
        to_file: path of the file to write.
        show_shapes: whether to display input/output shape information.
        show_layer_names: whether to display layer names.
        rankdir: ``rankdir`` value forwarded to ``model_to_dot``.
    """
    # model_to_dot is the (modified) converter; this is where labels are built.
    graph = model_to_dot(model, show_shapes, show_layer_names, rankdir)
    suffix = os.path.splitext(to_file)[1]
    # Strip the leading dot of the suffix; fall back to PNG if there is none.
    file_format = suffix[1:] if suffix else 'png'
    graph.write(to_file, format=file_format)
def _sized_class_name(layer):
    """Return ``layer``'s class name, suffixed with its kernel or pool size.

    A layer with a non-None ``kernel_size`` attribute (e.g. ``Conv2D``) or
    ``pool_size`` attribute (e.g. ``MaxPooling2D`` / ``AveragePooling2D``)
    gets ``str()`` of that value appended, e.g. ``'Conv2D(3, 3)'``; any
    other layer yields its bare class name.
    """
    class_name = layer.__class__.__name__
    for attr in ('kernel_size', 'pool_size'):
        size = getattr(layer, attr, None)
        if size is not None:
            return class_name + str(size)
    return class_name


def model_to_dot(model,
                 show_shapes=False,
                 show_layer_names=True,
                 rankdir='TB'):
    """Convert a Keras model to a ``pydot.Dot`` graph.

    Modified version of Keras' ``model_to_dot``. Layers exposing a
    ``kernel_size`` or ``pool_size`` attribute (convolution / pooling
    layers) have that size appended to their class name in the node label,
    whether or not layer names are shown. For wrapper layers the wrapped
    layer's size is shown, e.g. ``TimeDistributed(Conv2D(3, 3))``.

    # Arguments
        model: A Keras model instance.
        show_shapes: whether to display input/output shape information.
        show_layer_names: whether to prefix each label with the layer name.
        rankdir: value set as the graph's ``rankdir`` attribute
            (Graphviz layout direction, e.g. ``'TB'`` or ``'LR'``).

    # Returns
        A ``pydot.Dot`` instance describing the model graph.
    """
    from ..layers.wrappers import Wrapper
    from ..models import Sequential

    _check_pydot()
    dot = pydot.Dot()
    dot.set('rankdir', rankdir)
    dot.set('concentrate', True)
    dot.set_node_defaults(shape='record')

    if isinstance(model, Sequential):
        if not model.built:
            model.build()
        model = model.model
    layers = model.layers

    # Create graph nodes.
    for layer in layers:
        layer_id = str(id(layer))

        # A wrapper is labelled with both its own and its wrapped layer's
        # name/class; the kernel/pool size comes from the wrapped layer.
        if isinstance(layer, Wrapper):
            layer_name = '{}({})'.format(layer.name, layer.layer.name)
            class_name = '{}({})'.format(layer.__class__.__name__,
                                         _sized_class_name(layer.layer))
        else:
            layer_name = layer.name
            class_name = _sized_class_name(layer)

        # Create node's label; the size info is part of class_name, so it
        # is kept even when layer names are hidden.
        if show_layer_names:
            label = '{}: {}'.format(layer_name, class_name)
        else:
            label = class_name

        # Rebuild the label as a table including input/output shapes.
        if show_shapes:
            try:
                outputlabels = str(layer.output_shape)
            except AttributeError:
                outputlabels = 'multiple'
            if hasattr(layer, 'input_shape'):
                inputlabels = str(layer.input_shape)
            elif hasattr(layer, 'input_shapes'):
                inputlabels = ', '.join(
                    [str(ishape) for ishape in layer.input_shapes])
            else:
                inputlabels = 'multiple'
            label = '%s\n|{input:|output:}|{{%s}|{%s}}' % (label,
                                                           inputlabels,
                                                           outputlabels)
        node = pydot.Node(layer_id, label=label)
        dot.add_node(node)

    # Connect nodes with edges. Only inbound nodes registered in the model's
    # node set are drawn. That set is ``_container_nodes`` in the Keras
    # version this was written for; NOTE(review): later releases appear to
    # name it ``_network_nodes`` -- fall back to it when the former is absent.
    model_nodes = getattr(model, '_container_nodes', None)
    if model_nodes is None:
        model_nodes = model._network_nodes
    for layer in layers:
        layer_id = str(id(layer))
        for i, node in enumerate(layer._inbound_nodes):
            node_key = layer.name + '_ib-' + str(i)
            if node_key in model_nodes:
                for inbound_layer in node.inbound_layers:
                    inbound_layer_id = str(id(inbound_layer))
                    dot.add_edge(pydot.Edge(inbound_layer_id, layer_id))
    return dot
修改完成后，运行以下代码将 ResNet50 模型可视化：
# coding=utf-8
'''
@author: admin

Plot ResNet50 with keras.utils.plot_model, after the model_to_dot patch
that adds kernel / pool sizes to convolution and pooling node labels.
'''
from keras.applications.resnet50 import ResNet50
from keras.utils import plot_model


def main():
    """Build ResNet50 with ImageNet weights and plot it to model_ResNet1.png."""
    model = ResNet50(weights='imagenet')
    plot_model(model, to_file='model_ResNet1.png', show_shapes=True)


# Guard the entry point so importing this module does not build the model
# and write the image as a side effect.
if __name__ == '__main__':
    main()
输出的 ResNet50 模型图如下：