由于Tensorflow版本不一致的问题导致Estimator有很多坑!!!
问题代码如下
# 预定义estimator使用
output_dir = 'baseline_model'
if not os.path.exists(output_dir):
os.mkdir(output_dir)
# tensorflow版本有问题,需要改!!!
baseline_estimator = tf.estimator.BaselineClassifier(model_dir=output_dir,
n_classes=2)
baseline_estimator.train(input_fn=lambda : make_dataset(
train_df,y_train,epochs=100))
解决办法如下:
baseline_estimator = tf.compat.v1.estimator.BaselineClassifier(model_dir=output_dir,
n_classes=2)