CNN检测模型统计检出率

        X, y = get_feature_charseq()
        #max_document_length=64
        volcab_file = "volcab.pkl"
        assert os.path.exists(volcab_file)
        pkl_file = open(volcab_file, 'rb')
        data = pickle.load(pkl_file)
        valid_chars, max_document_length, max_features = data["valid_chars"], data["max_len"], data["volcab_size"]

        print "max_features:", max_features
        print "max_len:", max_document_length

        model = get_cnn_model(max_document_length, max_features, is_training=False)
        model.load(filename)

        print "X:", X[:3]
        print "Y:", y[:3]
        print "len X:", len(X)
        X = pad_sequences(X, maxlen=max_document_length, value=0.)

        result = model.predict(X[1:4])
        print result

        cnt = 0
        page = 10000
        total_page = int((len(X)+page-1)/page)
        print "total_page:", total_page
        labels = y
        labels_pred = []
        for ii in range(0, total_page):
                predictions = model.predict(X[ii*page:ii*page+page])

                labels_pred += list(np.argmax(predictions, axis=1))

                for i,p in enumerate(predictions):
                    i = i+ii*page

                    if p[1] >= 0.95:
                        cnt += 1
        print cnt, " total black(0.95)"
        print(confusion_matrix(labels, labels_pred))
        print(classification_report(labels, labels_pred))

猜你喜欢

转载自www.cnblogs.com/bonelee/p/9075374.html
今日推荐