这是调试 Transformer 模型时遇到的错误。
class AddPosEmbed(layers.Layer):
    """Add a learnable positional embedding to a sequence of patch embeddings.

    The embedding is a single trainable tensor of shape
    ``(1, num_patches, embed_dim)`` that is added to the inputs with
    broadcasting, so inputs must be broadcast-compatible with that shape
    (typically ``(batch, num_patches, embed_dim)`` -- confirm against caller).

    Args:
        embed_dim: Size of the last (embedding) axis of the positional embedding.
        num_patches: Length of the sequence axis the positional embedding covers.
        name: Optional layer name.
        **kwargs: Extra keyword arguments forwarded to ``layers.Layer``
            (e.g. ``trainable`` / ``dtype`` when the layer is re-created
            from ``get_config()``).
    """

    def __init__(self, embed_dim=768, num_patches=64, name=None, **kwargs):
        super().__init__(name=name, **kwargs)
        self.embed_dim = embed_dim
        self.num_patches = num_patches

    def build(self, input_shape):
        """Create the ``pos_embed`` weight.

        Keras invokes ``build(input_shape)`` itself, passing the input shape
        positionally the first time the layer is called. The ``input_shape``
        parameter therefore has to stay in this signature; removing it makes
        that call fail with a ``TypeError``.
        """
        self.pos_embed = self.add_weight(
            name="pos_embed",
            shape=[1, self.num_patches, self.embed_dim],
            initializer=initializers.RandomNormal(stddev=0.02),
            trainable=True,
            dtype=tf.float32,
        )
        # Let the base class finish its bookkeeping (marks the layer as built).
        super().build(input_shape)

    def call(self, inputs, **kwargs):
        """Return ``inputs + pos_embed``; the size-1 leading axis broadcasts over the batch."""
        return inputs + self.pos_embed

    def get_config(self):
        """Return the constructor arguments so the layer can be saved and re-created."""
        config = super().get_config()
        config.update({
            "embed_dim": self.embed_dim,
            "num_patches": self.num_patches,
        })
        return config
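如果删去上面代码中 build(self, input_shape) 的 input_shape 参数，就会报错。原因是 Keras 在首次调用该层时，会以 self.build(input_shape) 的形式把输入形状作为位置参数传入；build 的签名里没有这个参数，调用便会抛出 TypeError。因此即使 build 内部用不到 input_shape，也必须保留它。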