simbert&milvus实现相似句检索

朋友们,SimBERT 是一个效果较好的相似句检索模型,但在大规模检索场景下需要毫秒级的快速召回,这就离不开 Milvus 等向量检索库。下面用实际代码讲解 SimBERT 与 Milvus 的结合应用。

import numpy as np
from bert4keras.backend import keras, K
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer
import tensorflow as tf
from openapi_server.models.sentence_schema import SentenceSchema
from openapi_server.models.QaVecSchema import QaVecSchema
import connexion
from mysql_tool.connection import DBHelper
from config.loadconfig import get_logger
from milvus import Milvus, IndexType, MetricType, Status
import random
from bert4keras.snippets import sequence_padding
from apscheduler.schedulers.background import BackgroundScheduler
import datetime
import os

 

# Module-level setup: logger, TF1-style graph/session handles, SimBERT model
# paths, tokenizer, and the sentence encoder used by qa2vecs() below.
logger = get_logger(__name__)
global graph
# Capture the default graph and session at import time so worker threads
# (e.g. the APScheduler job) can re-enter them via as_default() later.
graph = tf.get_default_graph()
sess = keras.backend.get_session()
# NOTE(review): comment said "grand-parent directory" but this is just the
# absolute path of the current working directory — confirm intent.
upper2path = os.path.abspath(os.path.join(os.getcwd()))
# BERT configuration: SimBERT small (L-6, H-384, A-12) — pooled output dim 384,
# which must match the Milvus collection 'dimension' below.
config_path =   "/Users/Downloads/data/model/chinese_simbert_L-6_H-384_A-12/bert_config.json"
checkpoint_path =  "/Users/Downloads/data/model/chinese_simbert_L-6_H-384_A-12/bert_model.ckpt"
dict_path =  "/Users/Downloads/data/model/chinese_simbert_L-6_H-384_A-12/vocab.txt"
# Build the tokenizer (lower-cases input to match the uncased vocab).
tokenizer = Tokenizer(dict_path, do_lower_case=True)

 

# Build and load the SimBERT model; 'unilm' application + linear pooling is
# the standard SimBERT retrieval setup in bert4keras.
bert = build_transformer_model(
    config_path,
    checkpoint_path,
    with_pool='linear',
    application='unilm',
    return_keras_model=False,
)

# Sentence encoder: maps (token_ids, segment_ids) -> pooled 384-d vector
# (first output of the underlying keras model).
encoder = keras.models.Model(bert.model.inputs, bert.model.outputs[0])

 

向量入库:

def qa2vecs():
    """Encode every QA text with SimBERT and (re)build the Milvus collection.

    Flow: drop/recreate the collection, read all records from MySQL,
    encode each record's ``text`` into a 384-d vector, and insert
    normalized vectors into Milvus in batches of 5000 keyed by the
    record's MySQL ``id`` (so search results map back to MySQL rows).

    Returns:
        int: number of records processed (0 if the table is empty).
    """
    collection_reconstruct()
    data = qa_query()
    milvus, collection_name = MilvusHelper().connection()
    param = {
        'collection_name': collection_name,
        'dimension': 384,  # must match the encoder's pooled output size (H-384)
        'index_file_size': 256,  # optional
        'metric_type': MetricType.IP  # optional; IP on normalized vecs == cosine
    }
    milvus.create_collection(param)
    total = len(data)  # hoisted: used for progress and final-batch detection
    vecs = []
    ids = []
    processed = 0
    # Fix: close the Milvus connection even if encoding/insert raises;
    # the original leaked the connection on any exception.
    try:
        with sess.as_default():
            with graph.as_default():
                for processed, record in enumerate(data, start=1):
                    token_ids, segment_ids = tokenizer.encode(record["text"])
                    # Per-record predict; batch of one (token_ids, segment_ids) pair.
                    vec = encoder.predict([[token_ids], [segment_ids]])[0]
                    vecs.append(vec)
                    ids.append(record["id"])
                    # Flush every 5000 records, plus the final partial batch.
                    if ids and (len(ids) % 5000 == 0 or processed == total):
                        logger.info("data sync :{:.2f}%".format(processed * 100.0 / total))
                        milvus.insert(collection_name=collection_name,
                                      records=vecs_normalize(vecs), ids=ids, params=param)
                        vecs = []
                        ids = []
    finally:
        milvus.close()
    return processed

 上面向量入库时,Milvus 里只存了 id→vector 的映射(文本本身不入 Milvus);最终检索时拿到相似向量对应的 id,再用这些 id 去 MySQL 查出原始文本即可。

猜你喜欢

转载自blog.csdn.net/qq_23953717/article/details/130930167