import time
import numpy as np
from pymilvus import (
    MilvusClient,
    DataType,
    Function,
    FunctionType,
    AnnSearchRequest,
    RRFRanker,
)
# from pymilvus.model.hybrid import BGEM3EmbeddingFunction
from pymilvus import model
from rag.load_model import sentence_transformer_ef
from utils.get_logger import setup_logger
import torch

# Use the GPU when it is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

logger = setup_logger(__name__)

# embedding_path = r"G:/work/code/models/multilingual-e5-large-instruct/"
# sentence_transformer_ef = model.dense.SentenceTransformerEmbeddingFunction(model_name=embedding_path, device=device)

# Map embedding model names to their embedding functions.
embedding_mapping = {
    "e5": sentence_transformer_ef,
    "multilingual-e5-large-instruct": sentence_transformer_ef,
}


class HybridRetriever:
    def __init__(self, uri, embedding_name: str = "e5", collection_name: str = "hybrid"):
        self.uri = uri
        self.collection_name = collection_name
        # self.embedding_function = sentence_transformer_ef
        # Fall back to the default "e5" embedding function for unknown names.
        self.embedding_function = embedding_mapping.get(embedding_name, embedding_mapping["e5"])
        self.use_reranker = True
        self.use_sparse = True
        self.client = MilvusClient(uri=uri)

    def has_collection(self):
        try:
            collection_flag = self.client.has_collection(self.collection_name)
            logger.info(f"Collection existence check result: {collection_flag}")
        except Exception as e:
            logger.info(f"Error while checking whether the collection exists: {e}")
            collection_flag = False
        return collection_flag

    def build_collection(self):
        if isinstance(self.embedding_function.dim, dict):
            dense_dim = self.embedding_function.dim["dense"]
        else:
            dense_dim = self.embedding_function.dim
        logger.info(f"Dense vector dimension for the new collection: {dense_dim}")

        analyzer_params = {"type": "chinese"}

        schema = MilvusClient.create_schema()
        schema.add_field(
            field_name="pk",
            datatype=DataType.VARCHAR,
            is_primary=True,
            auto_id=True,
            max_length=100,
        )
        schema.add_field(
            field_name="content",
            datatype=DataType.VARCHAR,
            max_length=65535,
            analyzer_params=analyzer_params,
            enable_match=True,
            enable_analyzer=True,
        )
        schema.add_field(
            field_name="sparse_vector", datatype=DataType.SPARSE_FLOAT_VECTOR
        )
        schema.add_field(
            field_name="dense_vector", datatype=DataType.FLOAT_VECTOR, dim=dense_dim
        )
        schema.add_field(field_name="doc_id", datatype=DataType.VARCHAR, max_length=64)
        schema.add_field(
            field_name="chunk_id", datatype=DataType.VARCHAR, max_length=64
        )
        schema.add_field(field_name="metadata", datatype=DataType.JSON)

        # BM25 function: derive the sparse vector from the raw text content.
        functions = Function(
            name="bm25",
            function_type=FunctionType.BM25,
            input_field_names=["content"],
            output_field_names=["sparse_vector"],
        )
        schema.add_function(functions)

        index_params = MilvusClient.prepare_index_params()
        index_params.add_index(
            field_name="sparse_vector",
            index_type="SPARSE_INVERTED_INDEX",
            metric_type="BM25",
        )
        index_params.add_index(
            field_name="dense_vector", index_type="FLAT", metric_type="IP"
        )

        try:
            self.client.create_collection(
                collection_name=self.collection_name,
                schema=schema,
                index_params=index_params,
            )
            return "create_collection_success"
        except Exception as e:
            logger.error(f"Failed to create collection {self.collection_name}: {e}")
            return "create_collection_error"

    def insert_data(self, chunk, metadata):
        logger.info("Preparing to insert data.")
        with torch.no_grad():
            embedding = self.embedding_function([chunk])
        logger.info("Obtained the embedding for the text.")
        if isinstance(embedding, dict) and "dense" in embedding:
            # BGE-style embedding functions return a dict keyed by "dense".
            dense_vec = embedding["dense"][0]
        else:
            dense_vec = embedding[0]
        try:
            self.client.insert(
                self.collection_name, {"dense_vector": dense_vec, **metadata}
            )
            logger.info("Inserted one record successfully.")
            return True, "success"
        except Exception as e:
            doc_id = metadata.get("doc_id")
            logger.error(f"Error inserting data for document {doc_id}: {e}")
            self.delete_by_doc_id(doc_id=doc_id)
            return False, str(e)

    def batch_insert_data(self, chunks, metadatas):
        logger.info("Preparing to insert data.")
        embedding_lists = self.embedding_function.encode_documents(chunks)
        logger.info("Obtained the embeddings for the texts.")
        record_lists = []
        for embedding, metadata in zip(embedding_lists, metadatas):
            if isinstance(embedding, dict) and "dense" in embedding:
                # BGE-style embedding functions return a dict keyed by "dense".
                dense_vec = embedding["dense"][0]
            else:
                dense_vec = embedding.tolist()
            # if hasattr(dense_vec, 'tolist'):
            #     dense_vec = dense_vec.tolist()
            # logger.info(f"Vector dimension: {dense_vec}")
            # if isinstance(dense_vec, (float, int)):
            #     dense_vec = [dense_vec]
            # if isinstance(dense_vec, np.float32):
            #     dense_vec = [float(dense_vec)]
            record = {"dense_vector": dense_vec}
            record.update(metadata)
            record_lists.append(record)
        try:
            self.client.insert(self.collection_name, record_lists)
            logger.info("Inserted data successfully.")
            return True, "success"
        except Exception as e:
            doc_id = metadata.get("doc_id")
            logger.error(f"Error inserting data for document {doc_id}: {e}")
            self.delete_by_doc_id(doc_id=doc_id)
            return False, str(e)

    def search(self, query: str, k: int = 20, mode="hybrid"):
        output_fields = [
            "content",
            "doc_id",
            "chunk_id",
            "metadata",
        ]
        if mode in ["dense", "hybrid"]:
            with torch.no_grad():
                embedding = self.embedding_function([query])
            if isinstance(embedding, dict) and "dense" in embedding:
                dense_vec = embedding["dense"][0]
            else:
                dense_vec = embedding[0]

        if mode == "sparse":
            results = self.client.search(
                collection_name=self.collection_name,
                data=[query],
                anns_field="sparse_vector",
                limit=k,
                output_fields=output_fields,
            )
        elif mode == "dense":
            results = self.client.search(
                collection_name=self.collection_name,
                data=[dense_vec],
                anns_field="dense_vector",
                limit=k,
                output_fields=output_fields,
            )
        elif mode == "hybrid":
            full_text_search_params = {"metric_type": "BM25"}
            full_text_search_req = AnnSearchRequest(
                [query], "sparse_vector", full_text_search_params, limit=k
            )
            dense_search_params = {"metric_type": "IP"}
            dense_req = AnnSearchRequest(
                [dense_vec], "dense_vector", dense_search_params, limit=k
            )
            results = self.client.hybrid_search(
                self.collection_name,
                [full_text_search_req, dense_req],
                ranker=RRFRanker(),
                limit=k,
                output_fields=output_fields,
            )
        else:
            raise ValueError("Invalid mode")
        return [
            {
                "doc_id": doc["entity"]["doc_id"],
                "chunk_id": doc["entity"]["chunk_id"],
                "content": doc["entity"]["content"],
                "metadata": doc["entity"]["metadata"],
                "score": doc["distance"],
            }
            for doc in results[0]
        ]

    def query_filter(self, doc_id, filter_field):
        # doc_id: document id(s); filter_field: keyword the chunk content must contain.
        query_output_field = [
            "content",
            "chunk_id",
            "doc_id",
            "metadata",
        ]
        # Support either a single doc id or a list of doc ids.
        if isinstance(doc_id, (list, tuple)):
            doc_expr = f"doc_id in {list(doc_id)}"
        else:
            doc_expr = f"doc_id == '{doc_id}'"
        # If a keyword is given, filter by doc id and keyword; otherwise by doc id only.
        if filter_field:
            query_expr = f"{doc_expr} && content like '%{filter_field}%'"
        else:
            query_expr = doc_expr
        try:
            query_filter_results = self.client.query(
                collection_name=self.collection_name,
                filter=query_expr,
                output_fields=query_output_field,
            )
        except Exception as e:
            logger.error(f"Keyword query failed: {e}")
            query_filter_results = [{"code": 500}]
        return query_filter_results
        # for result in query_filter_results:
        #     print(f"Result filtered by doc id and keyword: {result}\n\n")

    def query_chunk_id(self, chunk_id):
        # Query a chunk by its chunk id.
        query_output_field = [
            "content",
            "doc_id",
            "chunk_id",
            # "metadata"
        ]
        query_expr = f"chunk_id == '{chunk_id}'"
        try:
            query_filter_results = self.client.query(
                collection_name=self.collection_name,
                filter=query_expr,
                output_fields=query_output_field,
            )
        except Exception as e:
            logger.info(f"Error querying by chunk id: {e}")
            query_filter_results = [{"code": 500}]
        return query_filter_results

    def query_chunk_id_list(self, chunk_id_list):
        # Query chunks by a list of chunk ids.
        query_output_field = [
            "content",
            "doc_id",
            "chunk_id",
            # "metadata"
        ]
        query_expr = f"chunk_id in {chunk_id_list}"
        try:
            query_filter_results = self.client.query(
                collection_name=self.collection_name,
                filter=query_expr,
                output_fields=query_output_field,
            )
        except Exception as e:
            logger.info(f"Error querying by chunk id: {e}")
            query_filter_results = [{"code": 500}]
        return query_filter_results

    def update_data(self, chunk_id, chunk):
        # Look up the existing record by chunk id, then upsert the new content.
        chunk_expr = f"chunk_id == '{chunk_id}'"
        chunk_output_fields = [
            "pk",
            "doc_id",
            "chunk_id",
            "metadata",
        ]
        try:
            chunk_results = self.client.query(
                collection_name=self.collection_name,
                filter=chunk_expr,
                output_fields=chunk_output_fields,
            )
            # logger.info(f"Chunk info for update {chunk_id}: {chunk_results}")
        except Exception as e:
            logger.error(f"Query failed while updating chunk data: {e}")
            return "update_query_error", ""
        if not chunk_results:
            logger.info(f"No data found in the collection for chunk id {chunk_id}; nothing to update.")
            return "update_query_no_result", ""

        with torch.no_grad():
            embedding = self.embedding_function([chunk])
        if isinstance(embedding, dict) and "dense" in embedding:
            # BGE-style embedding functions return a dict keyed by "dense".
            dense_vec = embedding["dense"][0]
        else:
            dense_vec = embedding[0]

        chunk_dict = chunk_results[0]
        metadata = chunk_dict.get("metadata")
        old_chunk_len = metadata.get("chunk_len")
        chunk_len = len(chunk)
        metadata["chunk_len"] = chunk_len
        chunk_dict["content"] = chunk
        chunk_dict["dense_vector"] = dense_vec
        chunk_dict["metadata"] = metadata
        try:
            update_res = self.client.upsert(
                collection_name=self.collection_name, data=[chunk_dict]
            )
            logger.info(f"Upsert response: {update_res}")
            return "update_success", chunk_len - old_chunk_len
        except Exception as e:
            logger.error(f"Error while updating data: {e}")
            return "update_error", ""

    def delete_collection(self, collection):
        try:
            self.client.drop_collection(collection_name=collection)
            return "delete_collection_success"
        except Exception as e:
            logger.error(f"Failed to drop collection {collection}: {e}")
            return "delete_collection_error"

    def delete_by_chunk_id(self, chunk_id: str = None):
        # Query primary keys by chunk id; Milvus only supports deletion by primary key.
        expr = f"chunk_id == '{chunk_id}'"
        try:
            results = self.client.query(
                collection_name=self.collection_name,
                filter=expr,
output_fields=["pk","metadata"]) # 获取主键 id logger.info(f"根据切片id:{chunk_id},查询的数据:{results}") except Exception as e: logger.error(f"根据切片id查询主键失败:{e}") return "delete_query_error", [] if not results: print(f"No data found for chunk id: {chunk_id}") # return "delete_no_result", [] return "delete_success", [0] # 提取主键值 primary_keys = [result["pk"] for result in results] chunk_len = [result["metadata"]["chunk_len"] for result in results] logger.info(f"获取到的主键信息:{primary_keys}") # 执行删除操作 expr_delete = f"pk in {primary_keys}" # 构造删除表达式 try: delete_res = self.client.delete(collection_name=self.collection_name, filter=expr_delete) self.client.flush(collection_name=self.collection_name) self.client.compact(collection_name=self.collection_name) logger.info(f"Deleted data with chunk_id: {delete_res}") return "delete_success", chunk_len except Exception as e: logger.error(f"删除数据失败:{e}") return "delete_error", [] def delete_by_doc_id(self, doc_id:str =None): # 根据文档id查询主键值 milvus只支持主键删除 expr = f"doc_id == '{doc_id}'" try: results = self.client.query(collection_name=self.collection_name, filter=expr, output_fields=["pk"]) # 获取主键 id except Exception as e: logger.error(f"根据切片id查询主键失败:{e}") return "delete_query_error" if not results: print(f"No data found for doc_id: {doc_id}") return "delete_no_result" # 提取主键值 primary_keys = [result["pk"] for result in results] logger.info(f"获取到的主键信息:{primary_keys}") # 执行删除操作 try: delete_res = self.client.delete(collection_name=self.collection_name, ids=primary_keys) self.client.flush(collection_name=self.collection_name) self.client.compact(collection_name=self.collection_name) logger.info(f"Deleted data with doc_id: {delete_res}") return "delete_success" except Exception as e: logger.error(f"删除数据失败:{e}") return "delete_error" # # 执行删除操作 # expr_delete = f"pk in {primary_keys}" # 构造删除表达式 # try: # delete_res = self.client.delete(collection_name=self.collection_name, filter=expr_delete) # logger.info(f"Deleted data with doc_id: {delete_res}") # return "delete_success" # except Exception as e: # logger.error(f"删除数据失败:{e}") # return "delete_error" # 测试 # def parse_json_to_schema_data_format(): # sql_text_list = load_sql_query_ddl_info() # docs = [] # for sql_info in sql_text_list: # sql_text = sql_info.get("sql_text") # source = ",".join(sql_info.get("table_list", [])) # source = sql_info.get("source") if not source else source # sql = sql_info.get("sql") # ddl = sql_info.get("table_ddl") # metadata = {"source": source, "sql": sql, "ddl": ddl} # text_list = sql_info.get("sim_sql_text_list", []) # text_list.append(sql_text) # doc_id = str(uuid4()) # for text in text_list: # chunk_id = str(uuid4()) # insert_dict = { # "content": text, # "doc_id": doc_id, # "chunk_id": chunk_id, # "metadata": metadata # } # docs.append(insert_dict) # return docs # def insert_data_to_milvus(standard_retriever): # sql_dataset = parse_json_to_schema_data_format() # standard_retriever.build_collection() # for sql_dict in sql_dataset: # text = sql_dict["content"] # standard_retriever.insert_data(text, sql_dict) def main(): # dense_ef = BGEM3EmbeddingFunction() embedding_path = r"G:/work/code/models/multilingual-e5-large-instruct/" sentence_transformer_ef = model.dense.SentenceTransformerEmbeddingFunction(model_name=embedding_path,device=device) standard_retriever = HybridRetriever( uri="http://localhost:19530", collection_name="milvus_hybrid", dense_embedding_function=sentence_transformer_ef, ) # 插入数据 # insert_data_to_milvus(standard_retriever) # 查询 混合检索:hybrid ,稀疏检索sparse,向量检索:dense # results = 
    results = standard_retriever.search("查一下加班情况", mode="hybrid", k=3)
    # mode="sparse": sparse retrieval
    # print(f"Sparse search results: {results}")
    # mode="dense": dense vector retrieval
    # print(f"Dense search results: {results}")
    # mode="hybrid": hybrid retrieval
    # print(f"Hybrid search results: {results}")

    # Delete data by doc id.
    # delete_start_time = time.time()
    # doc_id = "e72ebf78-6c0d-410b-8fbb-9a2057673064"
    # standard_retriever.delete_by_doc_id(doc_id=doc_id)
    # delete_end_time = time.time()
    # print(f"Delete took {delete_end_time - delete_start_time} seconds")

    # Update data by chunk_id.
    # update_start_time = time.time()
    # chunk = "查询一下员工加班的情况"
    # chunk_id = "a5e8dded-f5a7-4a1f-92cd-82fa8113b418"
    # standard_retriever.update_data(chunk_id, chunk)
    # update_end_time = time.time()
    # print(f"Update took {update_end_time - update_start_time} seconds")

    # Query by doc id and keyword.
    query_start_time = time.time()
    filter_field = "加班"
    doc_id = [
        "7b73ae0b-db97-4315-ba71-783fe7a69c61",
        "96bbe5a8-5fcf-4769-8343-938acb8735bd",
    ]
    standard_retriever.query_filter(doc_id, filter_field)
    query_end_time = time.time()
    print(f"Keyword query took {query_end_time - query_start_time} seconds")


if __name__ == "__main__":
    main()
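
# The block below is a minimal, hypothetical sketch (not part of the original module) of how
# batch_insert_data could be exercised. It assumes the collection created by build_collection
# already exists and that each metadata dict carries the schema fields used above
# ("content", "doc_id", "chunk_id", "metadata"); the sample texts and the helper name
# batch_insert_example are made up for illustration.
# def batch_insert_example(retriever: HybridRetriever):
#     from uuid import uuid4
#     chunks = ["查一下加班情况", "查询一下员工加班的情况"]
#     doc_id = str(uuid4())
#     metadatas = [
#         {
#             "content": chunk,
#             "doc_id": doc_id,
#             "chunk_id": str(uuid4()),
#             "metadata": {"chunk_len": len(chunk)},
#         }
#         for chunk in chunks
#     ]
#     ok, msg = retriever.batch_insert_data(chunks, metadatas)
#     print(f"batch insert result: {ok}, {msg}")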