概要
- トーチサマリは、各層のネットワーク構造とパラメータを記述するために深層学習で一般的に使用されるツールです。ネットワークモデルを構築する際に、ネットワークモデル内のパラメータが正しいか、さらに正しい次元の出力情報が与えられるかどうかを確認するために使用できます。
- ネットワークモデルの入力情報には、従来の単一入力に加えて、複数入力のネットワークモデルもあります。複数の入力が使用される場合、torchsummary の使用時にエラー メッセージが表示されます: TypeError: can't multiply sequence by non-int of type 'tuple'
- この場合に使用されるバージョン Torchsummary=1.5.1
- 参考:
- https://github.com/sksq96/pytorch-summary/issues/90
- https://blog.csdn.net/qq_43733107/article/details/126508616
質問
エラー メッセージによると、torchsummary/torchsummary.py の 100 行目にエラー コードが見つかります。
total_input_size = abs(np.prod(input_size) * batch_size * 4. / (1024 ** 2.))
その中には、input_size
一般的な画像情報などの入力情報の次元があります: (3, 64, 64)。また、入力が複数ある場合、np.prod
パラメータの乗算を直接実現することはできません。
解決
方法 1
total_input_size = abs(np.prod(sum((input_size),())) * batch_size * 4. / (1024 ** 2.))
方法 2
total_input_size = abs(np.sum([np.prod(in_tuple) for in_tuple in input_size]) * batch_size * 4. / (1024 ** 2.))
元のコードにアノテーションを付けた後、上記の 2 つの方法のいずれかを選択して機能を実現できます。また、コードから、1 つは加算後の乗算であり、もう 1 つはトラバーサル乗算の後の累積であることがわかります。