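"""Hybrid retrieval over Milvus.

HybridRetriever stores document chunks in a single Milvus collection with both a
BM25-derived sparse vector and a dense embedding, and supports sparse, dense, and
hybrid (RRF-fused) search plus insert / update / delete maintenance operations.
"""
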
import time
import numpy as np
from pymilvus import (
    MilvusClient,
    DataType,
    Function,
    FunctionType,
    AnnSearchRequest,
    RRFRanker,
)
# from pymilvus.model.hybrid import BGEM3EmbeddingFunction
from pymilvus import model
from rag.load_model import sentence_transformer_ef
from utils.get_logger import setup_logger
import torch
- device = "cpu" if torch.cuda.is_available() else "cuda"
logger = setup_logger(__name__)

# embedding_path = r"G:/work/code/models/multilingual-e5-large-instruct/"
# sentence_transformer_ef = model.dense.SentenceTransformerEmbeddingFunction(model_name=embedding_path, device=device)

embedding_mapping = {
    "e5": sentence_transformer_ef,
    "multilingual-e5-large-instruct": sentence_transformer_ef,
}
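
# embedding_mapping above maps supported model aliases to the embedding function
# preloaded in rag.load_model; both aliases currently point to the same
# multilingual E5 instance.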


class HybridRetriever:
    def __init__(self, uri, embedding_name: str = "e5", collection_name: str = "hybrid"):
        self.uri = uri
        self.collection_name = collection_name
        # self.embedding_function = sentence_transformer_ef
        self.embedding_function = embedding_mapping.get(embedding_name, sentence_transformer_ef)
        self.use_reranker = True
        self.use_sparse = True
        self.client = MilvusClient(uri=uri)

    def has_collection(self):
        try:
            collection_flag = self.client.has_collection(self.collection_name)
            logger.info(f"Collection existence check result: {collection_flag}")
        except Exception as e:
            logger.info(f"Error while checking whether the collection exists: {e}")
            collection_flag = False
        return collection_flag
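
    # build_collection defines one collection holding both representations:
    # "content" is analyzed with the built-in Chinese analyzer, a BM25 Function
    # derives "sparse_vector" from it on the server side, and "dense_vector"
    # stores the dense embedding produced by the configured model.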

    def build_collection(self):
        if isinstance(self.embedding_function.dim, dict):
            dense_dim = self.embedding_function.dim["dense"]
        else:
            dense_dim = self.embedding_function.dim
        logger.info(f"Dense vector dimension for the new collection: {dense_dim}")
        analyzer_params = {"type": "chinese"}
        schema = MilvusClient.create_schema()
        schema.add_field(
            field_name="pk",
            datatype=DataType.VARCHAR,
            is_primary=True,
            auto_id=True,
            max_length=100,
        )
        schema.add_field(
            field_name="content",
            datatype=DataType.VARCHAR,
            max_length=65535,
            analyzer_params=analyzer_params,
            enable_match=True,
            enable_analyzer=True,
        )
        schema.add_field(
            field_name="sparse_vector", datatype=DataType.SPARSE_FLOAT_VECTOR
        )
        schema.add_field(
            field_name="dense_vector", datatype=DataType.FLOAT_VECTOR, dim=dense_dim
        )
        schema.add_field(field_name="doc_id", datatype=DataType.VARCHAR, max_length=64)
        schema.add_field(
            field_name="chunk_id", datatype=DataType.VARCHAR, max_length=64
        )
        schema.add_field(field_name="metadata", datatype=DataType.JSON)
        bm25_function = Function(
            name="bm25",
            function_type=FunctionType.BM25,
            input_field_names=["content"],
            output_field_names=["sparse_vector"],
        )
        schema.add_function(bm25_function)
        index_params = MilvusClient.prepare_index_params()
        index_params.add_index(
            field_name="sparse_vector",
            index_type="SPARSE_INVERTED_INDEX",
            metric_type="BM25",
        )
        index_params.add_index(
            field_name="dense_vector", index_type="FLAT", metric_type="IP"
        )
        try:
            self.client.create_collection(
                collection_name=self.collection_name,
                schema=schema,
                index_params=index_params,
            )
            return "create_collection_success"
        except Exception as e:
            logger.error(f"Failed to create collection {self.collection_name}: {e}")
            return "create_collection_error"

    def insert_data(self, chunk, metadata):
        logger.info("Preparing to insert a record")
        with torch.no_grad():
            embedding = self.embedding_function([chunk])
        logger.info("Computed the embedding for the text.")
        if isinstance(embedding, dict) and "dense" in embedding:
            # BGE-style embedding functions return a dict keyed by vector type
            dense_vec = embedding["dense"][0]
        else:
            dense_vec = embedding[0]

        try:
            self.client.insert(
                self.collection_name, {"dense_vector": dense_vec, **metadata}
            )
            logger.info("Inserted one record successfully.")
            return True, "success"
        except Exception as e:
            doc_id = metadata.get("doc_id")
            logger.error(f"Document {doc_id}: failed to insert data: {e}")
            self.delete_by_doc_id(doc_id=doc_id)
            return False, str(e)

    def batch_insert_data(self, chunks, metadatas):
        logger.info("Preparing to insert a batch of records")
        embedding_lists = self.embedding_function.encode_documents(chunks)
        logger.info("Computed the embeddings for the texts.")
        # BGE-style embedding functions return a dict keyed by vector type;
        # SentenceTransformer-style functions return a list/array of dense vectors
        if isinstance(embedding_lists, dict) and "dense" in embedding_lists:
            dense_vectors = embedding_lists["dense"]
        else:
            dense_vectors = embedding_lists
        record_lists = []
        for dense_vec, metadata in zip(dense_vectors, metadatas):
            if hasattr(dense_vec, "tolist"):
                dense_vec = dense_vec.tolist()
            record = {"dense_vector": dense_vec}
            record.update(metadata)
            record_lists.append(record)

        try:
            self.client.insert(
                self.collection_name, record_lists
            )
            logger.info("Inserted the batch successfully.")
            return True, "success"
        except Exception as e:
            doc_id = metadatas[0].get("doc_id") if metadatas else None
            logger.error(f"Document {doc_id}: failed to insert data: {e}")
            self.delete_by_doc_id(doc_id=doc_id)
            return False, str(e)

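    # search supports three modes: "sparse" runs BM25 full-text retrieval over
    # sparse_vector, "dense" runs vector search over dense_vector, and "hybrid"
    # issues both requests and fuses the result lists with Reciprocal Rank Fusion.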
    def search(self, query: str, k: int = 20, mode="hybrid"):
        output_fields = [
            "content",
            "doc_id",
            "chunk_id",
            "metadata",
        ]
        if mode in ["dense", "hybrid"]:
            with torch.no_grad():
                embedding = self.embedding_function([query])
            if isinstance(embedding, dict) and "dense" in embedding:
                dense_vec = embedding["dense"][0]
            else:
                dense_vec = embedding[0]
        if mode == "sparse":
            results = self.client.search(
                collection_name=self.collection_name,
                data=[query],
                anns_field="sparse_vector",
                limit=k,
                output_fields=output_fields,
            )
        elif mode == "dense":
            results = self.client.search(
                collection_name=self.collection_name,
                data=[dense_vec],
                anns_field="dense_vector",
                limit=k,
                output_fields=output_fields,
            )
        elif mode == "hybrid":
            full_text_search_params = {"metric_type": "BM25"}
            full_text_search_req = AnnSearchRequest(
                [query], "sparse_vector", full_text_search_params, limit=k
            )
            dense_search_params = {"metric_type": "IP"}
            dense_req = AnnSearchRequest(
                [dense_vec], "dense_vector", dense_search_params, limit=k
            )
            results = self.client.hybrid_search(
                self.collection_name,
                [full_text_search_req, dense_req],
                ranker=RRFRanker(),
                limit=k,
                output_fields=output_fields,
            )
        else:
            raise ValueError("Invalid mode")
        return [
            {
                "doc_id": doc["entity"]["doc_id"],
                "chunk_id": doc["entity"]["chunk_id"],
                "content": doc["entity"]["content"],
                "metadata": doc["entity"]["metadata"],
                "score": doc["distance"],
            }
            for doc in results[0]
        ]

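    # Example usage of search() (hypothetical endpoint and query, assuming the
    # collection has already been built and populated):
    #     retriever = HybridRetriever(uri="http://localhost:19530", collection_name="milvus_hybrid")
    #     hits = retriever.search("overtime records", k=3, mode="hybrid")
    #     # each hit: {"doc_id", "chunk_id", "content", "metadata", "score"}
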
    def query_filter(self, doc_id, filter_field):
        # Filter chunks by document id(s); if filter_field is given, additionally
        # require that the content contains the keyword.
        query_output_field = [
            "content",
            "chunk_id",
            "doc_id",
            "metadata",
        ]
        # doc_id may be a single id or a list of ids
        if isinstance(doc_id, (list, tuple)):
            doc_expr = f"doc_id in {list(doc_id)}"
        else:
            doc_expr = f"doc_id == '{doc_id}'"
        if filter_field:
            query_expr = f"{doc_expr} && content like '%{filter_field}%'"
        else:
            query_expr = doc_expr
        try:
            query_filter_results = self.client.query(collection_name=self.collection_name, filter=query_expr, output_fields=query_output_field)
        except Exception as e:
            logger.error(f"Keyword-filtered query failed: {e}")
            query_filter_results = [{"code": 500}]
        return query_filter_results

    def query_chunk_id(self, chunk_id):
        # Look up a single chunk by its chunk_id
        query_output_field = [
            "content",
            "doc_id",
            "chunk_id",
            # "metadata"
        ]
        query_expr = f"chunk_id == '{chunk_id}'"
        try:
            query_filter_results = self.client.query(collection_name=self.collection_name, filter=query_expr, output_fields=query_output_field)
        except Exception as e:
            logger.error(f"Query by chunk id failed: {e}")
            query_filter_results = [{"code": 500}]
        return query_filter_results

    def query_chunk_id_list(self, chunk_id_list):
        # Look up multiple chunks by a list of chunk_ids
        query_output_field = [
            "content",
            "doc_id",
            "chunk_id",
            # "metadata"
        ]
        query_expr = f"chunk_id in {chunk_id_list}"
        try:
            query_filter_results = self.client.query(collection_name=self.collection_name, filter=query_expr, output_fields=query_output_field)
        except Exception as e:
            logger.error(f"Query by chunk id list failed: {e}")
            query_filter_results = [{"code": 500}]
        return query_filter_results

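    # update_data re-embeds an edited chunk: it fetches the existing row by
    # chunk_id, recomputes the dense vector, refreshes chunk_len in the metadata,
    # upserts the row, and returns a status string plus the change in chunk length.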
    def update_data(self, chunk_id, chunk):
        # Fetch the existing record for this chunk_id
        chunk_expr = f"chunk_id == '{chunk_id}'"
        chunk_output_fields = [
            "pk",
            "doc_id",
            "chunk_id",
            "metadata",
        ]
        try:
            chunk_results = self.client.query(collection_name=self.collection_name, filter=chunk_expr, output_fields=chunk_output_fields)
            # logger.info(f"Chunk info to update for {chunk_id}: {chunk_results}")
        except Exception as e:
            logger.error(f"Query failed while updating chunk data: {e}")
            return "update_query_error", ""
        if not chunk_results:
            logger.info(f"No record found in the collection for chunk_id {chunk_id}; nothing to update")
            return "update_query_no_result", ""
        with torch.no_grad():
            embedding = self.embedding_function([chunk])
        if isinstance(embedding, dict) and "dense" in embedding:
            # BGE-style embedding functions return a dict keyed by vector type
            dense_vec = embedding["dense"][0]
        else:
            dense_vec = embedding[0]
        chunk_dict = chunk_results[0]
        metadata = chunk_dict.get("metadata")
        old_chunk_len = metadata.get("chunk_len", 0)
        chunk_len = len(chunk)
        metadata["chunk_len"] = chunk_len
        chunk_dict["content"] = chunk
        chunk_dict["dense_vector"] = dense_vec
        chunk_dict["metadata"] = metadata
        try:
            update_res = self.client.upsert(collection_name=self.collection_name, data=[chunk_dict])
            logger.info(f"Upsert response: {update_res}")
            return "update_success", chunk_len - old_chunk_len
        except Exception as e:
            logger.error(f"Failed to update data: {e}")
            return "update_error", ""

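    # Deletion helpers: delete_by_chunk_id / delete_by_doc_id first query the
    # matching primary keys, delete those rows, then flush and compact the
    # collection so the space is reclaimed.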
    def delete_collection(self, collection):
        try:
            self.client.drop_collection(collection_name=collection)
            return "delete_collection_success"
        except Exception as e:
            logger.error(f"Failed to drop collection {collection}: {e}")
            return "delete_collection_error"

    def delete_by_chunk_id(self, chunk_id: str = None):
        # Look up the primary keys for this chunk id; deletion is then done by primary key
        expr = f"chunk_id == '{chunk_id}'"
        try:
            results = self.client.query(collection_name=self.collection_name, filter=expr, output_fields=["pk", "metadata"])  # fetch primary keys
            logger.info(f"Records found for chunk_id {chunk_id}: {results}")
        except Exception as e:
            logger.error(f"Failed to query primary keys by chunk id: {e}")
            return "delete_query_error", []
        if not results:
            logger.info(f"No data found for chunk id: {chunk_id}")
            # return "delete_no_result", []
            return "delete_success", [0]

        # Extract the primary key values and the stored chunk lengths
        primary_keys = [result["pk"] for result in results]
        chunk_len = [result["metadata"]["chunk_len"] for result in results]
        logger.info(f"Primary keys to delete: {primary_keys}")

        # Perform the deletion
        expr_delete = f"pk in {primary_keys}"  # deletion filter expression
        try:
            delete_res = self.client.delete(collection_name=self.collection_name, filter=expr_delete)
            self.client.flush(collection_name=self.collection_name)
            self.client.compact(collection_name=self.collection_name)
            logger.info(f"Deleted data with chunk_id: {delete_res}")
            return "delete_success", chunk_len
        except Exception as e:
            logger.error(f"Failed to delete data: {e}")
            return "delete_error", []

    def delete_by_doc_id(self, doc_id: str = None):
        # Look up the primary keys for this document id; deletion is then done by primary key
        expr = f"doc_id == '{doc_id}'"
        try:
            results = self.client.query(collection_name=self.collection_name, filter=expr, output_fields=["pk"])  # fetch primary keys
        except Exception as e:
            logger.error(f"Failed to query primary keys by doc id: {e}")
            return "delete_query_error"
        if not results:
            logger.info(f"No data found for doc_id: {doc_id}")
            return "delete_no_result"

        # Extract the primary key values
        primary_keys = [result["pk"] for result in results]
        logger.info(f"Primary keys to delete: {primary_keys}")

        # Perform the deletion
        try:
            delete_res = self.client.delete(collection_name=self.collection_name, ids=primary_keys)
            self.client.flush(collection_name=self.collection_name)
            self.client.compact(collection_name=self.collection_name)
            logger.info(f"Deleted data with doc_id: {delete_res}")
            return "delete_success"
        except Exception as e:
            logger.error(f"Failed to delete data: {e}")
            return "delete_error"


# Test helpers
# def parse_json_to_schema_data_format():
#     sql_text_list = load_sql_query_ddl_info()
#     docs = []
#     for sql_info in sql_text_list:
#         sql_text = sql_info.get("sql_text")
#         source = ",".join(sql_info.get("table_list", []))
#         source = sql_info.get("source") if not source else source
#         sql = sql_info.get("sql")
#         ddl = sql_info.get("table_ddl")
#         metadata = {"source": source, "sql": sql, "ddl": ddl}
#         text_list = sql_info.get("sim_sql_text_list", [])
#         text_list.append(sql_text)
#         doc_id = str(uuid4())
#         for text in text_list:
#             chunk_id = str(uuid4())
#             insert_dict = {
#                 "content": text,
#                 "doc_id": doc_id,
#                 "chunk_id": chunk_id,
#                 "metadata": metadata
#             }
#             docs.append(insert_dict)
#
#     return docs


# def insert_data_to_milvus(standard_retriever):
#     sql_dataset = parse_json_to_schema_data_format()
#     standard_retriever.build_collection()
#     for sql_dict in sql_dataset:
#         text = sql_dict["content"]
#         standard_retriever.insert_data(text, sql_dict)
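

# A minimal local smoke test: it assumes a Milvus instance reachable at
# http://localhost:19530 and uses the embedding function registered in
# embedding_mapping; the commented sections show insert / search / delete /
# update calls that can be enabled one at a time.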
def main():
    # dense_ef = BGEM3EmbeddingFunction()
    # Alternative: load the embedding model directly instead of via embedding_mapping
    # embedding_path = r"G:/work/code/models/multilingual-e5-large-instruct/"
    # sentence_transformer_ef = model.dense.SentenceTransformerEmbeddingFunction(model_name=embedding_path, device=device)
    standard_retriever = HybridRetriever(
        uri="http://localhost:19530",
        embedding_name="e5",
        collection_name="milvus_hybrid",
    )
    # Insert data
    # insert_data_to_milvus(standard_retriever)
    # Search. mode="hybrid" for hybrid retrieval, "sparse" for BM25 full-text, "dense" for vector search
    # results = standard_retriever.search("查一下加班情况", mode="hybrid", k=3)
    # print(f"Sparse retrieval results: {results}")
    # print(f"Dense retrieval results: {results}")
    # print(f"Hybrid retrieval results: {results}")
    # Delete data by doc id
    # delete_start_time = time.time()
    # doc_id = "e72ebf78-6c0d-410b-8fbb-9a2057673064"
    # standard_retriever.delete_by_doc_id(doc_id=doc_id)
    # delete_end_time = time.time()
    # print(f"Deletion took {delete_end_time - delete_start_time} s")
    # Update data by chunk_id
    # update_start_time = time.time()
    # chunk = "查询一下员工加班的情况"
    # chunk_id = "a5e8dded-f5a7-4a1f-92cd-82fa8113b418"
    # standard_retriever.update_data(chunk_id, chunk)
    # update_end_time = time.time()
    # print(f"Update took {update_end_time - update_start_time} s")
    # Query by doc id(s) and keyword
    query_start_time = time.time()
    filter_field = "加班"
    doc_id = ["7b73ae0b-db97-4315-ba71-783fe7a69c61", "96bbe5a8-5fcf-4769-8343-938acb8735bd"]
    standard_retriever.query_filter(doc_id, filter_field)
    query_end_time = time.time()
    print(f"Keyword query took {query_end_time - query_start_time} s")


if __name__ == "__main__":
    main()