build() takes 1 positional argument but 2 were given: error in a custom TensorFlow 2 model

I ran into this error while debugging a Transformer.

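```python
import tensorflow as tf
from tensorflow.keras import initializers, layers


class AddPosEmbed(layers.Layer):
    """Adds a learnable positional embedding to a sequence of patch embeddings."""

    def __init__(self, embed_dim=768, num_patches=64, name=None):
        super(AddPosEmbed, self).__init__(name=name)
        self.embed_dim = embed_dim
        self.num_patches = num_patches

    def build(self, input_shape):
        # Keras calls build(input_shape) on the layer's first call, so this
        # parameter must stay in the signature even though it is unused here.
        self.pos_embed = self.add_weight(name="pos_embed",
                                         shape=[1, self.num_patches, self.embed_dim],
                                         initializer=initializers.RandomNormal(stddev=0.02),
                                         trainable=True,
                                         dtype=tf.float32)

    def call(self, inputs, **kwargs):
        # pos_embed has shape (1, num_patches, embed_dim) and broadcasts over the batch
        x = inputs + self.pos_embed

        return x
```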
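The addition in `call` relies on broadcasting: `pos_embed` has a leading dimension of 1, so the same positional table is added to every sample in the batch. Below is a minimal NumPy sketch of that step (TensorFlow follows the same broadcasting rules). The sizes (batch 2, 4 patches, dimension 8) are illustrative assumptions, not the layer's defaults of 64 and 768, and the arrays are plain stand-ins for the layer's inputs and weight.

```python
import numpy as np

batch, num_patches, embed_dim = 2, 4, 8  # illustrative sizes, not the 64/768 defaults

rng = np.random.default_rng(0)
# Stand-in for self.pos_embed: one shared table of shape (1, num_patches, embed_dim),
# drawn from N(0, 0.02^2) in the spirit of initializers.RandomNormal(stddev=0.02)
pos_embed = rng.normal(0.0, 0.02, size=(1, num_patches, embed_dim)).astype(np.float32)
# Stand-in for the patch embeddings fed into the layer
inputs = np.ones((batch, num_patches, embed_dim), dtype=np.float32)

x = inputs + pos_embed  # the leading dimension of 1 broadcasts across the batch
print(x.shape)  # (2, 4, 8)
# Every sample receives the identical positional offsets
print(np.allclose(x[0] - inputs[0], x[1] - inputs[1]))  # True
```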

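If you delete the input_shape parameter from build() in the code above, you get exactly this error. The reason is that Keras calls build() automatically the first time the layer is called and passes the input shape as a positional argument. Counting self, Python sees 2 positional arguments, but a signature of build(self) only accepts 1. So keep input_shape in the signature even if the layer never uses it.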
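The mechanism can be reproduced without TensorFlow. The sketch below uses a hypothetical, heavily simplified `MiniLayer` class as a stand-in for `keras.layers.Layer`. It is not the real Keras code; it only mimics the one behavior that matters here, namely that the first call invokes `build()` with the input shape as a positional argument. `BrokenLayer` and `FixedLayer` are likewise illustrative names.

```python
class MiniLayer:
    """Hypothetical, heavily simplified stand-in for keras.layers.Layer."""

    def __init__(self):
        self.built = False

    def build(self, input_shape):
        # Default build accepts the shape and does nothing else
        self.built = True

    def __call__(self, inputs):
        if not self.built:
            # Mimics Keras: build is called once, with the input shape passed positionally
            self.build((len(inputs),))
            self.built = True
        return self.call(inputs)

    def call(self, inputs):
        return inputs


class BrokenLayer(MiniLayer):
    def build(self):  # input_shape removed, as in the failing version
        self.offset = 1


class FixedLayer(MiniLayer):
    def build(self, input_shape):  # parameter kept, even though it is unused
        self.offset = 1

    def call(self, inputs):
        return [v + self.offset for v in inputs]


try:
    BrokenLayer()([1, 2, 3])
except TypeError as e:
    # Message ends with "takes 1 positional argument but 2 were given"
    print(e)

print(FixedLayer()([1, 2, 3]))  # [2, 3, 4]
```

The exact prefix of the TypeError message varies by Python version (newer versions include the class name), but the "takes 1 positional argument but 2 were given" part is the same as in the TensorFlow error.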

Reposted from blog.csdn.net/qq_44065334/article/details/120619320