机器学习中遇到的问题

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/u011734144/article/details/84066630

1、ImportError: libcublas.so.9.0: cannot open shared object file: No such file or directory

找不到so库文件

用自己的账户运行包会报上面的错误,但是用root账户运行却没有问题

原因:root账户里面的LD_LIBRARY_PATH路径能找到这个so文件,但是自己的账户的LD_LIBRARY_PATH路径下找不到

解决办法:在~/.bashrc文件中添加:  export LD_LIBRARY_PATH=/usr/local/cuda-9.0/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}

   

2、对训练好的模型部署tornado服务

第一次http请求的时候正常,但是第二次请求的时候会报如下错误:

ValueError: Variable rnn/multi_rnn_cell/cell_0/gru_cell/gates/kernel already exists, disallowed. Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope? Originally defined at:

解决办法:

     在代码的开始清除计算graph:

    

tf.reset_default_graph()

3、训练模型时候的输入格式的问题

训练的模型是为了预测,训练好的模型可能会部署成服务,那么为了更好的兼容调用方的传参方式,模型的输入类型最好尽可能的接近调用方的传参方式。

比如: 我的模型一开始设计的传参方式是摘要,基本信息,正文等分词后的文档的id表示,这种情况下,调用方调用的时候肯定无法传这样的参数

办法: 应该将调用方可能传参的格式作为模型的输入,所有的参数的处理方到模型内部去。

4、Data is not binary and pos_label is not specified

相关代码:   

all_data = handle_all_data()

    all_data = np.array(all_data, dtype=np.float64)

    train_data = all_data[: 1000000]

    test_data = all_data[1000000:]

    model_svm = svm.SVC(kernel='rbf', C=1.0)

    train_x = train_data[:, :-1]

    train_label = train_data[:, -1]

    model_svm.fit(train_x, train_label)



    print cross_val_score(model_svm, train_x, train_label, cv=10, scoring='roc_auc')

用svm模型来训练,错误原因是train_x和train_label 中的特征值是字符串

错误位置:np.array在将数组特征转换的时候,没有指定dtype,所以这个方法默认会把特征转换成字符串类型的,需要加上dtype=np.float64来指明特征值的类型

5、模型fit的时候报错

错误:ValueError: Unknown label type: 'continuous'

原因: 模型是分类问题,对y做了标准化

x_scaled = preprocessing.StandardScaler().fit_transform(x)

#y因为是分类,所以这里不要标准化,标准化后会变成浮点数

y_scaled = preprocessing.StandardScaler().fit_transform(y.reshape(-1, 1))

说明: 在y是预测一个值的时候进行标准化是可以的,但是如果y是预测的分类,则不要用标准化

6、执行fit时报错

代码: preprocessing.StandardScaler().fit(norm_x)

错误:ValueError: Input contains NaN, infinity or a value too large for dtype('float64')

基于错误说明,说明数据中可能存在空值,无限大的值等,我的代码中是某些出现了空值,通过查看数据的分布即可知道是否有空值

(pdb)   all_data.describe() , 可以得到如下结果

count行的值是每个字段不为空的值的总数,如果某个字段的count值少于样本总数, 那说明这个字段存在空值

7、用matplotlib时报错

错误:ImportError: No module named Tkinter

解决办法: 添加如下黑体部分

import matplotlib

matplotlib.use('agg')

import matplotlib.pyplot as plt

猜你喜欢

转载自blog.csdn.net/u011734144/article/details/84066630