利用tf.gather_nd等一系列tf函数取出qebrain中的mis_matching feature

            shift_proj_inputs = self.emb_proj_layer(shift_inputs)
            _pre_qefv = tf.concat([shift_outputs, shift_proj_inputs], axis=-1)
            # _pre_qefv = shift_outputs + shift_proj_inputs
            # Notice, currently <s> is not to predict, but actually in our QE model, we can predict it.
            logits = self.output_layer(_pre_qefv)
            sample_id = tf.cast(tf.argmax(logits, axis=-1), tf.int32)
            # Extract logits feature for mismatching
            shape = tf.shape(fw_target_input)
            idx0 = tf.expand_dims(tf.range(shape[0]), -1)
            idx0 = tf.tile(idx0, [1, shape[1]])
            idx0 = tf.cast(idx0, fw_target_input.dtype)
            idx1 = tf.expand_dims(tf.range(shape[1]), 0)
            idx1 = tf.tile(idx1, [shape[0], 1])
            idx1 = tf.cast(idx1, fw_target_input.dtype)
            indices_real = tf.stack([idx0, idx1, fw_target_input], axis=-1)
            logits_mt = tf.gather_nd(logits, indices_real)
            logits_max = tf.reduce_max(logits, axis=-1)
            logits_diff = tf.subtract(logits_max, logits_mt)
            logits_same = tf.cast(tf.equal(sample_id, fw_target_input), tf.float32)
            logits_fea = tf.stack([logits_mt, logits_max, logits_diff, logits_same], axis=-1)

如题,在上面一段代码中,关键要理解的是tf.gather_nd函数。这个函数接受两个参数,然后会根据第二个参数从第一个参数中取出来对应位置的值。

logits的大小是[batch_size, sequence_len, vocab_size], indices_real的大小是[batch_size, sequence_len, 3],这里的3又对应的是[batch_size, sequence_len, vocab_size], 最后会取出[batch_size, sequence_len]个值,组成一个[batch_size, sequence_len]大小的向量,这就是所谓的mt feature。

那么indices_real是如何得到的呢?关键在于idx0和idx1。我们希望idx0的每个元素对应的是batch_size,idx1的每个元素对应的是seq_len,所以希望idx0是[[0,0,0,...],[1,1,1,...],...],idx1是[[0,1,2,...],[0,1,2,...],...],这样最后得到的元素就是[0,0,x],[0,1,x],...,[1,0,x],[1,1,x],...。所以对于idx0,需要在第1维上重复,而对于idx1,需要在第0维上重复。

猜你喜欢

转载自blog.csdn.net/bonjourdeutsch/article/details/99843482
今日推荐