[tensorflow] How to subsample a dataset, like bootstrapping, using the tf.data API? (How to sample a dataset in TensorFlow)

Copyright notice: Copyright reserved to Hazekiah Wang ([email protected]) https://blog.csdn.net/u010909964/article/details/83834657

The problem: I have a big dataset, and for each epoch I want to use only a random subset of it. How can I do this with the tf.data API?

There are two approaches.

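1. Use a filter: keep each element independently with probability `proba`, chosen so that on average `n_samples` batches of `batch_size` elements survive.

```python
import tensorflow as tf

def create_filter(proba):
    # True with probability `proba`: one uniform draw per element.
    return tf.less_equal(tf.random.uniform([], dtype=tf.float32), tf.cast(proba, tf.float32))

# Fraction to keep so that ~n_samples batches of batch_size remain per epoch.
proba = n_samples * batch_size / total_samples
ds = ds.filter(lambda *x: create_filter(proba))
```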
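For reference, here is a self-contained version of the filter approach. It is a minimal sketch assuming TensorFlow 2.x with eager execution; `tf.data.Dataset.range` stands in for the real dataset, and `subsample`, `num_batches` and the sizes are hypothetical names and values chosen for illustration.

```python
import tensorflow as tf

def subsample(ds, keep_prob):
    """Hypothetical helper: keep each element of `ds` independently with probability keep_prob."""
    keep_prob = tf.constant(keep_prob, dtype=tf.float32)
    # One uniform draw per element; the element survives if the draw is below keep_prob.
    return ds.filter(lambda *x: tf.random.uniform([], dtype=tf.float32) < keep_prob)

# Hypothetical sizes for illustration only.
total_samples = 10_000
batch_size = 32
num_batches = 30  # plays the role of n_samples in the original snippet
keep_prob = num_batches * batch_size / total_samples  # 0.096

subset = subsample(tf.data.Dataset.range(total_samples), keep_prob)
kept = [int(v) for v in subset.as_numpy_iterator()]
batches = subset.batch(batch_size)  # roughly num_batches batches per pass
```

Because each element is kept independently, the subset size is binomial with mean `keep_prob * total_samples` (960 here), so the number of batches per epoch fluctuates slightly. The filter also preserves the input order, so it is usually combined with a shuffle if the order matters.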
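2. Shuffle, then skip:

```python
# Shuffle with a buffer of half the dataset, then drop the first 90% so ~10% remains.
ds = ds.shuffle(buffer_size=total_samples // 2).skip(9 * total_samples // 10)
```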
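A runnable sketch of the shuffle-then-skip approach, under the same assumptions (TensorFlow 2.x, a `tf.data.Dataset.range` stand-in, a hypothetical dataset size). The `take` variant is an added alternative, not the original method: it keeps the first 10% of the shuffled stream instead of the last, so it returns the same number of elements without producing and discarding the skipped 90%.

```python
import tensorflow as tf

total_samples = 10_000  # hypothetical dataset size
base = tf.data.Dataset.range(total_samples)

# Original approach: shuffle, then drop the first 90% of the shuffled stream.
# reshuffle_each_iteration defaults to True, so every pass uses a new order.
skip_subset = (base.shuffle(buffer_size=total_samples // 2)
                   .skip(9 * total_samples // 10))

# Added variant (assumption, not from the original): keep the first 10% instead.
take_subset = (base.shuffle(buffer_size=total_samples // 2)
                   .take(total_samples // 10))

epoch_skip = [int(v) for v in skip_subset.as_numpy_iterator()]
epoch_take = [int(v) for v in take_subset.as_numpy_iterator()]
```

Unlike the filter, both versions return an exact subset size every epoch. The variant still fills the shuffle buffer first, and it has the same small-buffer weakness, only biased toward the start of the input instead of the end.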

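This approach is slow if you choose a large buffer_size, because the shuffle buffer has to be filled (and held in memory) before the first element comes out. If you choose a small buffer_size, however, the shuffle is not random enough when you skip most of the dataset: the elements that survive the skip come mostly from the end of the input, so the bootstrap subsets drawn in different epochs overlap heavily.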
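The overlap effect can be checked with a small pure-Python simulation. It assumes a simplified model of a shuffle buffer (fill a buffer of `buffer_size` elements, then repeatedly emit a random slot and refill it with the next input element, and finally drain the rest in random order); this approximates tf.data's `shuffle` but is not its actual source. `buffered_shuffle`, `skip_subset` and `mean_overlap` are hypothetical helpers.

```python
import itertools
import random

def buffered_shuffle(n, buffer_size, rng):
    """Simplified shuffle-buffer model (an assumption, not tf.data's code)."""
    it = iter(range(n))
    buf = list(itertools.islice(it, buffer_size))
    for x in it:
        i = rng.randrange(len(buf))
        yield buf[i]   # emit a random buffered element...
        buf[i] = x     # ...and refill its slot with the next input element
    # Input exhausted: emit what is left in the buffer in random order.
    rng.shuffle(buf)
    yield from buf

def skip_subset(n, buffer_size, skip, rng):
    """Elements that survive shuffle(buffer_size).skip(skip) in one pass."""
    return set(list(buffered_shuffle(n, buffer_size, rng))[skip:])

def mean_overlap(n, buffer_size, trials=10, seed=0):
    """Average fraction shared by two independent passes that skip 90%."""
    rng = random.Random(seed)
    skip = 9 * n // 10
    total = 0.0
    for _ in range(trials):
        a = skip_subset(n, buffer_size, skip, rng)
        b = skip_subset(n, buffer_size, skip, rng)
        total += len(a & b) / len(a)
    return total / trials

for size in (100, 5_000):
    print(size, round(mean_overlap(10_000, size), 3))
```

With a buffer of half the dataset, the average overlap should come out close to the 0.1 that a perfectly uniform shuffle would give for a 10% subset. With a buffer of 100 it rises above 0.9, because nearly every element that survives the skip comes from the last stretch of the input.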

