Implementing a Simple LR Classifier with TensorFlow

This post implements a simple logistic regression (LR) classifier with TensorFlow.

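(1) Reading the data
The training data is stored in a CSV text file: the first row holds the feature names, and the last column is the label.
TensorFlow feature names do not support Chinese characters, so Chinese feature names are converted to ASCII names (f0, f1, ...) when the features are built.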

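import pandas as pd
import tensorflow as tf

def get_datas(file_name):
    # Load the CSV; the header row supplies the column names.
    datas = pd.read_csv(file_name)
    # Map every original column name (possibly Chinese) to an ASCII name: f0, f1, ...
    columns_map = {}
    for count, column in enumerate(datas.columns):
        columns_map[column] = 'f' + str(count)
    # Return the name mapping, the original column names, and the DataFrame.
    return columns_map, datas.columns, datas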
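
To make the renaming step concrete, here is a minimal sketch of the same mapping applied to a tiny in-memory CSV. It uses only the standard library instead of pandas, and the helper name build_columns_map and the sample headers are hypothetical, chosen purely for illustration.

```python
import csv
import io

def build_columns_map(header):
    """Map each original column name to an ASCII-only name: f0, f1, ..."""
    return {name: 'f' + str(i) for i, name in enumerate(header)}

# Hypothetical CSV with Chinese feature names; the last column is the label.
sample_csv = "年龄,收入,label\n25,3000,0\n40,8000,1\n"
rows = list(csv.reader(io.StringIO(sample_csv)))
header, records = rows[0], rows[1:]

columns_map = build_columns_map(header)
print(columns_map)  # {'年龄': 'f0', '收入': 'f1', 'label': 'f2'}
```

Note that the label column also receives a name (f2 here), just as in get_datas; it is simply never used as a feature.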

(2) Building the training input function

def input_fn(datas, columns, columns_map, batch_size):
    """Input function that feeds (features, labels) batches to the Estimator."""
    # Build the feature dict under the ASCII names; the last column is the label.
    feature_map = {}
    for column in columns[:-1]:
        feature_map[columns_map[column]] = datas.get(column)
    labels = datas.get(columns[-1])

    # Convert the inputs to a Dataset.
    dataset = tf.data.Dataset.from_tensor_slices((feature_map, labels))

    # Shuffle, repeat indefinitely, and batch the examples.
    dataset = dataset.shuffle(3000).repeat().batch(batch_size)

    # Return the read end of the pipeline (TensorFlow 1.x API).
    return dataset.make_one_shot_iterator().get_next()
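
The shuffle, repeat, and batch steps can be easier to reason about outside TensorFlow. The following is a plain-Python emulation of that behaviour, not tf.data itself: the function name batches and the training and seed parameters are hypothetical, and it assumes the shuffle buffer is at least as large as the data, so every epoch is fully shuffled. With training=False it makes a single ordered pass, which is one possible design for evaluation and prediction.

```python
import itertools
import random

def batches(rows, batch_size, training, seed=0):
    """Yield lists of rows; shuffle and repeat forever only when training."""
    rng = random.Random(seed)

    def stream():
        if not training:
            # Evaluation/prediction: one ordered pass, then stop.
            yield from rows
            return
        while True:
            # Training: a freshly shuffled epoch, repeated indefinitely.
            epoch = list(rows)
            rng.shuffle(epoch)
            yield from epoch

    it = stream()
    while True:
        # Batches are cut from the continuous stream, so in training mode
        # a batch may straddle two epochs.
        batch = list(itertools.islice(it, batch_size))
        if not batch:
            return
        yield batch

rows = list(range(6))
# The training stream never ends, so take only the first three batches.
train_batches = list(itertools.islice(batches(rows, 4, training=True), 3))
eval_batches = list(batches(rows, 4, training=False))
print(eval_batches)  # [[0, 1, 2, 3], [4, 5]]
```

Because the training stream is infinite, the consumer has to bound it, which is the role the steps argument plays for train and evaluate.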

(3) Training, evaluating, and predicting

# file_name and eval_file_name are the paths to the training and evaluation CSV files.
columns_map, columns, datas = get_datas(file_name)

eval_columns_map, eval_columns, eval_datas = get_datas(eval_file_name)

# One numeric feature column per input column, excluding the label.
feature_columns = [tf.feature_column.numeric_column(columns_map[name]) for name in columns[:-1]]

classifier = tf.estimator.LinearClassifier(feature_columns=feature_columns)
classifier.train(input_fn=lambda: input_fn(datas, columns, columns_map, 100), steps=5000)
# Evaluate on 10 batches of 100 examples.
result = classifier.evaluate(input_fn=lambda: input_fn(eval_datas, eval_columns, eval_columns_map, 100), steps=10)

# clear_output()

for key, value in sorted(result.items()):
    print('%s: %s' % (key, value))

import itertools

predict = classifier.predict(input_fn=lambda: input_fn(eval_datas, eval_columns, eval_columns_map, 100))

# input_fn repeats the data indefinitely, so stop after one pass over the evaluation set.
for count, pre in enumerate(itertools.islice(predict, len(eval_datas))):
    print('================ %s' % count)
    print(pre['logits'])
    print(pre['logistic'])
    print(pre['probabilities'])
    print(pre['class_ids'])
    print(pre['classes'])
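
The five printed fields are closely related for a binary logistic regression. The sketch below reproduces the relationship in plain Python, assuming the usual binary setup: logistic is the sigmoid of the logit, probabilities holds the two class probabilities, and class_ids is the more probable class. The function name binary_outputs is hypothetical, and representing classes as the class id in string form is an assumption about the estimator's default output.

```python
import math

def binary_outputs(logit):
    """Derive the per-example prediction fields from a single logit."""
    logistic = 1.0 / (1.0 + math.exp(-logit))   # P(label == 1)
    probabilities = [1.0 - logistic, logistic]  # [P(class 0), P(class 1)]
    class_id = 1 if logistic > 0.5 else 0       # the more probable class
    return {
        'logits': logit,
        'logistic': logistic,
        'probabilities': probabilities,
        'class_ids': class_id,
        'classes': str(class_id),               # assumed string form of class_ids
    }

for logit in (-2.0, 0.0, 2.0):
    print(binary_outputs(logit))
```

A positive logit therefore always yields class 1 and a negative logit class 0; the magnitude of the logit only changes how confident the prediction is.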

Evaluation results

accuracy: 0.995
accuracy_baseline: 0.689
auc: 0.9999533
auc_precision_recall: 0.99989796
average_loss: 0.020066695
global_step: 5000
label/mean: 0.311
loss: 2.0066695
precision: 1.0
prediction/mean: 0.3166183
recall: 0.98392284

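The sampling of the evaluation set may be problematic and needs further analysis to pin down the issue. Note that input_fn shuffles and repeats the data, so the 10 evaluation steps score 1,000 sampled examples, which is a single clean pass only if the evaluation file has exactly 1,000 rows.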
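
As a sanity check, the reported metrics can be cross-checked with simple arithmetic. The sketch below assumes that 10 steps of batch size 100 means 1,000 evaluated examples, that loss is the per-batch sum of the per-example average_loss, and that accuracy_baseline is the accuracy of always predicting the majority class; these assumptions are consistent with the numbers above but are not stated in the original post.

```python
batch_size = 100
eval_steps = 10
examples = batch_size * eval_steps               # examples scored by evaluate()

# Values copied from the evaluation output.
average_loss = 0.020066695
label_mean = 0.311
recall = 0.98392284

loss_per_batch = average_loss * batch_size       # compare with loss: 2.0066695
baseline = max(label_mean, 1 - label_mean)       # compare with accuracy_baseline: 0.689

positives = round(label_mean * examples)         # examples whose label is 1
true_positives = round(recall * positives)       # positives predicted as 1
false_negatives = positives - true_positives
# precision is 1.0, so there are no false positives and
# every error is a false negative.
accuracy = (examples - false_negatives) / examples

print(loss_per_batch, baseline, positives, true_positives, accuracy)
```

Under these assumptions the counts come out as 311 positives, 306 of them recovered, and 5 misses, which matches the reported accuracy of 0.995.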

Reposted from blog.csdn.net/WitsMakeMen/article/details/81589454