"""Milvus-backed hybrid retriever: dense vectors + BM25 sparse full-text search.

Dense vectors come from a configurable embedding model (bge-m3 or Qwen); the
sparse side is produced inside Milvus by a BM25 function over the ``content``
field. Results of both are fused with RRF in hybrid mode.
"""
import threading
import time

import numpy as np
import torch
from pymilvus import (
    MilvusClient,
    DataType,
    Function,
    FunctionType,
    AnnSearchRequest,
    RRFRanker,
)
from pymilvus import model

from rag.load_model import qwen_ed_ef, bge_m3_ef
from utils.get_logger import setup_logger

device = "cuda" if torch.cuda.is_available() else "cpu"

logger = setup_logger(__name__)

# Global lock: serializes PyTorch/CUDA inference calls. Concurrent inference
# from multiple threads was observed to core-dump, hence process-wide locking.
_embedding_lock = threading.Lock()

# Registry of available embedding functions, keyed by model name.
embedding_mapping = {
    "bge-m3": bge_m3_ef,
    "Qwen3-Embedding-0.6B": qwen_ed_ef,
}


class HybridRetriever:
    """Hybrid retriever over one Milvus collection with dense + sparse fields.

    Parameters
    ----------
    uri : str
        Milvus server URI.
    embedding_name : str
        Key into ``embedding_mapping``; unknown names leave
        ``self.embedding_function`` as ``None`` (dense operations will fail).
    collection_name : str
        Target collection name.
    """

    def __init__(self, uri, embedding_name: str = "e5", collection_name: str = "hybrid"):
        self.uri = uri
        self.collection_name = collection_name
        self.embedding_name = embedding_name
        logger.info(f"使用的embedding模型是:{embedding_name}")
        # NOTE(review): .get() returns None for unknown names — callers should
        # pass a key present in embedding_mapping.
        self.embedding_function = embedding_mapping.get(embedding_name)
        self.use_reranker = True
        self.use_sparse = True
        self.client = MilvusClient(uri=uri)

    # ------------------------------------------------------------------ #
    # Embedding helpers                                                  #
    # ------------------------------------------------------------------ #
    @staticmethod
    def _normalize_dense_embedding(embedding):
        """Extract a single float32 dense vector from an embedding result.

        BGE-style functions return ``{"dense": [...], "sparse": ...}``; other
        models return a plain sequence of vectors. Accelerated models (e.g.
        Qwen) may emit fp16, but Milvus only accepts fp32.
        """
        if isinstance(embedding, dict) and "dense" in embedding:
            vec = embedding["dense"][0]
        else:
            vec = embedding[0]
        if vec.dtype != np.float32:
            vec = vec.astype(np.float32)
        return vec

    def _embed(self, text: str):
        """Compute the dense vector for *text* under the global inference lock."""
        with _embedding_lock:
            with torch.no_grad():
                raw = self.embedding_function([text])
        logger.info("获取文本的向量信息。")
        return self._normalize_dense_embedding(raw)

    # ------------------------------------------------------------------ #
    # Collection management                                              #
    # ------------------------------------------------------------------ #
    def has_collection(self):
        """Return True if the collection exists; False on error or absence."""
        try:
            collection_flag = self.client.has_collection(self.collection_name)
            logger.info(f"查询向量库的结果:{collection_flag}")
        except Exception as e:
            logger.info(f"查询向量库是否存在出错:{e}")
            collection_flag = False
        return collection_flag

    def build_collection(self):
        """Create the collection schema + indexes (BM25 sparse, FLAT/IP dense)."""
        # bge-m3 exposes dims per modality as a dict; others expose a scalar.
        if isinstance(self.embedding_function.dim, dict):
            dense_dim = self.embedding_function.dim["dense"]
        else:
            dense_dim = self.embedding_function.dim
        logger.info(f"创建数据库的向量维度:{dense_dim}")

        analyzer_params = {"type": "chinese"}

        schema = MilvusClient.create_schema()
        schema.add_field(
            field_name="pk",
            datatype=DataType.VARCHAR,
            is_primary=True,
            auto_id=True,
            max_length=100,
        )
        schema.add_field(
            field_name="content",
            datatype=DataType.VARCHAR,
            max_length=65535,
            analyzer_params=analyzer_params,
            enable_match=True,
            enable_analyzer=True,
        )
        schema.add_field(
            field_name="sparse_vector", datatype=DataType.SPARSE_FLOAT_VECTOR
        )
        schema.add_field(
            field_name="dense_vector", datatype=DataType.FLOAT_VECTOR, dim=dense_dim
        )
        schema.add_field(field_name="doc_id", datatype=DataType.VARCHAR, max_length=64)
        schema.add_field(
            field_name="chunk_id", datatype=DataType.VARCHAR, max_length=64
        )
        schema.add_field(field_name="metadata", datatype=DataType.JSON)
        schema.add_field(field_name="Chapter", datatype=DataType.VARCHAR, max_length=1024)
        schema.add_field(field_name="Father_Chapter", datatype=DataType.VARCHAR, max_length=1024)

        # Server-side BM25: Milvus derives sparse_vector from content.
        functions = Function(
            name="bm25",
            function_type=FunctionType.BM25,
            input_field_names=["content"],
            output_field_names="sparse_vector",
        )
        schema.add_function(functions)

        index_params = MilvusClient.prepare_index_params()
        index_params.add_index(
            field_name="sparse_vector",
            index_type="SPARSE_INVERTED_INDEX",
            metric_type="BM25",
        )
        index_params.add_index(
            field_name="dense_vector", index_type="FLAT", metric_type="IP"
        )

        try:
            self.client.create_collection(
                collection_name=self.collection_name,
                schema=schema,
                index_params=index_params,
            )
            return "create_collection_success"
        except Exception as e:
            logger.error(f"创建{self.collection_name}数据库失败:{e}")
            return "create_collection_error"

    # ------------------------------------------------------------------ #
    # Insertion                                                          #
    # ------------------------------------------------------------------ #
    def insert_data(self, chunk, metadata):
        """Embed *chunk* and insert one record; rolls back the doc on failure.

        Returns (ok: bool, message: str).
        """
        logger.info("准备插入数据")
        dense_vec = self._embed(chunk)
        # bbox/page are layout-only fields not stored in the vector store.
        filtered_metadata = {k: v for k, v in metadata.items() if k not in ("bbox", "page")}
        try:
            self.client.insert(
                self.collection_name, {"dense_vector": dense_vec, **filtered_metadata}
            )
            logger.info("插入一条数据成功。")
            return True, "success"
        except Exception as e:
            doc_id = metadata.get("doc_id")
            logger.error(f"处理文档:{doc_id},插入数据出错:{e}")
            # Remove any partially inserted chunks of the same document.
            self.delete_by_doc_id(doc_id=doc_id)
            return False, str(e)

    def batch_insert_data(self, chunks, metadatas):
        """Embed *chunks* in one call and insert all records.

        Fixes vs. original: a BGE-style dict result ({"dense": [...]}) is now
        unpacked before zipping (zipping the dict iterated its keys), vectors
        are cast to fp32 like the single-insert path, and the error path no
        longer depends on the loop variable being bound.
        """
        logger.info("准备插入数据")
        with _embedding_lock:
            embedding_lists = self.embedding_function.encode_documents(chunks)
        logger.info("获取文本的向量信息。")

        # Unpack dict-style results once, up front.
        if isinstance(embedding_lists, dict) and "dense" in embedding_lists:
            dense_list = embedding_lists["dense"]
        else:
            dense_list = embedding_lists

        record_lists = []
        for dense_vec, metadata in zip(dense_list, metadatas):
            if hasattr(dense_vec, "dtype") and dense_vec.dtype != np.float32:
                # fp16 models (e.g. Qwen under acceleration) → Milvus needs fp32
                dense_vec = dense_vec.astype(np.float32)
            if hasattr(dense_vec, "tolist"):
                dense_vec = dense_vec.tolist()
            # bbox/page are layout-only fields not stored in the vector store.
            filtered_metadata = {k: v for k, v in metadata.items() if k not in ("bbox", "page")}
            record = {"dense_vector": dense_vec}
            record.update(filtered_metadata)
            record_lists.append(record)

        doc_id = metadatas[0].get("doc_id") if metadatas else None
        try:
            self.client.insert(self.collection_name, record_lists)
            logger.info("插入数据成功。")
            return True, "success"
        except Exception as e:
            logger.error(f"处理文档:{doc_id},插入数据出错:{e}")
            self.delete_by_doc_id(doc_id=doc_id)
            return False, str(e)

    # ------------------------------------------------------------------ #
    # Search                                                             #
    # ------------------------------------------------------------------ #
    def search(self, query: str, k: int = 20, mode="hybrid"):
        """Search the collection.

        mode: "sparse" (BM25 only), "dense" (vector only) or "hybrid"
        (RRF fusion of both). Returns a list of result dicts with a
        ``score`` taken from the Milvus distance.
        Raises ValueError for an unknown mode.
        """
        output_fields = [
            "content",
            "doc_id",
            "chunk_id",
            "metadata",
            "Father_Chapter",
            "Chapter",
        ]
        if mode in ["dense", "hybrid"]:
            dense_vec = self._embed(query)

        if mode == "sparse":
            results = self.client.search(
                collection_name=self.collection_name,
                data=[query],
                anns_field="sparse_vector",
                limit=k,
                output_fields=output_fields,
            )
        elif mode == "dense":
            results = self.client.search(
                collection_name=self.collection_name,
                data=[dense_vec],
                anns_field="dense_vector",
                limit=k,
                output_fields=output_fields,
            )
        elif mode == "hybrid":
            full_text_search_params = {"metric_type": "BM25"}
            full_text_search_req = AnnSearchRequest(
                [query], "sparse_vector", full_text_search_params, limit=k
            )
            dense_search_params = {"metric_type": "IP"}
            dense_req = AnnSearchRequest(
                [dense_vec], "dense_vector", dense_search_params, limit=k
            )
            results = self.client.hybrid_search(
                self.collection_name,
                [full_text_search_req, dense_req],
                ranker=RRFRanker(),
                limit=k,
                output_fields=output_fields,
            )
        else:
            raise ValueError("Invalid mode")
        return [
            {
                "doc_id": doc["entity"]["doc_id"],
                "chunk_id": doc["entity"]["chunk_id"],
                "content": doc["entity"]["content"],
                "Father_Chapter": doc["entity"]["Father_Chapter"],
                "Chapter": doc["entity"]["Chapter"],
                "metadata": doc["entity"]["metadata"],
                "score": doc["distance"],
            }
            for doc in results[0]
        ]

    # ------------------------------------------------------------------ #
    # Scalar queries                                                     #
    # ------------------------------------------------------------------ #
    def query_filter(self, doc_id, filter_field):
        """Query chunks of one document, optionally keyword-filtered on content.

        NOTE(review): doc_id/filter_field are interpolated into the Milvus
        filter expression unescaped — quotes in either will break the filter.
        """
        query_output_field = [
            "content",
            "chunk_id",
            "doc_id",
            "metadata",
        ]
        # With a keyword: filter on doc_id AND content; otherwise doc_id only.
        if filter_field:
            query_expr = f"doc_id == '{doc_id}' && content like '%{filter_field}%'"
        else:
            query_expr = f"doc_id == '{doc_id}'"
        try:
            query_filter_results = self.client.query(collection_name=self.collection_name, filter=query_expr, output_fields=query_output_field)
        except Exception as e:
            logger.error(f"根据关键词查询数据失败:{e}")
            query_filter_results = [{"code": 500}]
        return query_filter_results

    def query_chunk_id(self, chunk_id):
        """Fetch the chunk with the given chunk_id."""
        query_output_field = [
            "content",
            "doc_id",
            "chunk_id",
        ]
        query_expr = f"chunk_id == '{chunk_id}'"
        try:
            query_filter_results = self.client.query(collection_name=self.collection_name, filter=query_expr, output_fields=query_output_field)
        except Exception as e:
            logger.info(f"根据chunk id 查询出错:{e}")
            query_filter_results = [{"code": 500}]
        return query_filter_results

    def query_chunk_id_list(self, chunk_id_list):
        """Fetch all chunks whose chunk_id is in *chunk_id_list*."""
        query_output_field = [
            "content",
            "doc_id",
            "chunk_id",
        ]
        query_expr = f"chunk_id in {chunk_id_list}"
        try:
            query_filter_results = self.client.query(collection_name=self.collection_name, filter=query_expr, output_fields=query_output_field)
        except Exception as e:
            logger.info(f"根据chunk id 查询出错:{e}")
            query_filter_results = [{"code": 500}]
        return query_filter_results

    # ------------------------------------------------------------------ #
    # Update                                                             #
    # ------------------------------------------------------------------ #
    def update_data(self, chunk_id, chunk):
        """Replace content + dense vector of an existing chunk (upsert by pk).

        Returns (status, length_delta) where length_delta is
        len(new_chunk) - old chunk_len on success, "" otherwise.
        """
        chunk_expr = f"chunk_id == '{chunk_id}'"
        chunk_output_fields = [
            "pk",
            "doc_id",
            "chunk_id",
            "metadata",
            "Father_Chapter",
            "Chapter",
        ]
        try:
            chunk_results = self.client.query(collection_name=self.collection_name, filter=chunk_expr, output_fields=chunk_output_fields)
        except Exception as e:
            logger.error(f"更新切片数据时查询失败:{e}")
            return "update_query_error", ""
        if not chunk_results:
            logger.info(f"根据{chunk_id}未查询到对应数据,无法更新数据")
            return "update_query_no_result", ""

        dense_vec = self._embed(chunk)

        chunk_dict = chunk_results[0]
        logger.info(f"update_data 查询到的 pk: {chunk_dict.get('pk')}")
        metadata = chunk_dict.get("metadata")
        old_chunk_len = metadata.get("chunk_len")
        chunk_len = len(chunk)
        metadata["chunk_len"] = chunk_len
        chunk_dict["content"] = chunk
        chunk_dict["dense_vector"] = dense_vec
        chunk_dict["metadata"] = metadata
        try:
            update_res = self.client.upsert(collection_name=self.collection_name, data=[chunk_dict])
            logger.info(f"更新返回的数据:{update_res}")
            return "update_success", chunk_len - old_chunk_len
        except Exception as e:
            logger.error(f"更新数据时出错:{e}")
            return "update_error", ""

    # ------------------------------------------------------------------ #
    # Deletion                                                           #
    # ------------------------------------------------------------------ #
    def delete_collection(self, collection):
        """Drop an entire collection."""
        try:
            self.client.drop_collection(collection_name=collection)
            return "delete_collection_success"
        except Exception as e:
            logger.error(f"删除{collection}失败,出错原因:{e}")
            return "delete_collection_error"

    def delete_by_chunk_id(self, chunk_id: str = None):
        """Delete one chunk by chunk_id (lookup pk first; Milvus deletes by pk).

        Returns (status, [chunk_len, ...]) of the deleted rows.
        """
        expr = f"chunk_id == '{chunk_id}'"
        try:
            results = self.client.query(collection_name=self.collection_name, filter=expr, output_fields=["pk", "metadata"])
            logger.info(f"根据切片id:{chunk_id},查询的数据:{results}")
        except Exception as e:
            logger.error(f"根据切片id查询主键失败:{e}")
            return "delete_query_error", []
        if not results:
            print(f"No data found for chunk id: {chunk_id}")
            # Nothing to delete still counts as success with zero length.
            return "delete_success", [0]

        primary_keys = [result["pk"] for result in results]
        chunk_len = [result["metadata"]["chunk_len"] for result in results]
        logger.info(f"获取到的主键信息:{primary_keys}")

        expr_delete = f"pk in {primary_keys}"
        try:
            delete_res = self.client.delete(collection_name=self.collection_name, filter=expr_delete)
            self.client.flush(collection_name=self.collection_name)
            self.client.compact(collection_name=self.collection_name)
            logger.info(f"Deleted data with chunk_id: {delete_res}")
            return "delete_success", chunk_len
        except Exception as e:
            logger.error(f"删除数据失败:{e}")
            return "delete_error", []

    def batch_delete_by_chunk_ids(self, chunk_ids: list):
        """Batch-delete multiple chunks.

        Returns (status, {chunk_id: chunk_len}) for the deleted chunks.
        """
        if not chunk_ids:
            return "delete_success", {}

        chunk_ids_str = ", ".join([f"'{cid}'" for cid in chunk_ids])
        expr = f"chunk_id in [{chunk_ids_str}]"
        try:
            results = self.client.query(
                collection_name=self.collection_name,
                filter=expr,
                output_fields=["pk", "metadata", "chunk_id"]
            )
            logger.info(f"批量查询到 {len(results)} 个切片")
        except Exception as e:
            logger.error(f"批量查询切片失败:{e}")
            return "delete_query_error", {}

        if not results:
            logger.warning("未找到任何匹配的切片")
            return "delete_success", {cid: 0 for cid in chunk_ids}

        primary_keys = [result["pk"] for result in results]
        chunk_lens_dict = {result["chunk_id"]: result["metadata"]["chunk_len"] for result in results}
        logger.info(f"获取到的主键信息:{primary_keys}")

        # One IN-expression delete for all pks at once.
        expr_delete = f"pk in {primary_keys}"
        try:
            delete_res = self.client.delete(
                collection_name=self.collection_name,
                filter=expr_delete
            )
            self.client.flush(collection_name=self.collection_name)
            self.client.compact(collection_name=self.collection_name)
            logger.info(f"批量删除成功: {len(primary_keys)} 个切片, 结果: {delete_res}")
            return "delete_success", chunk_lens_dict
        except Exception as e:
            logger.error(f"批量删除失败:{e}")
            return "delete_error", {}

    def query_by_doc_ids(self, doc_ids):
        """Fetch all rows (all fields incl. dense_vector) for the given doc ids."""
        if not doc_ids:
            logger.warning("doc_ids 为空")
            return []

        doc_ids_str = ", ".join([f"'{doc_id}'" for doc_id in doc_ids])
        query_expr = f"doc_id in [{doc_ids_str}]"

        query_output_fields = [
            "pk",
            "content",
            "dense_vector",
            "doc_id",
            "chunk_id",
            "metadata",
            "Father_Chapter",
            "Chapter",
        ]

        try:
            query_results = self.client.query(
                collection_name=self.collection_name,
                filter=query_expr,
                output_fields=query_output_fields
            )
            logger.info(f"根据 doc_ids 查询到 {len(query_results)} 条数据")
            return query_results
        except Exception as e:
            logger.error(f"根据 doc_ids 查询数据失败:{e}")
            return []

    def delete_by_doc_id(self, doc_id: str = None):
        """Delete every chunk of one document (pk lookup, then delete by ids)."""
        expr = f"doc_id == '{doc_id}'"
        try:
            results = self.client.query(collection_name=self.collection_name, filter=expr, output_fields=["pk"])
        except Exception as e:
            logger.error(f"根据切片id查询主键失败:{e}")
            return "delete_query_error"
        if not results:
            print(f"No data found for doc_id: {doc_id}")
            return "delete_no_result"

        primary_keys = [result["pk"] for result in results]
        logger.info(f"获取到的主键信息:{primary_keys}")

        try:
            delete_res = self.client.delete(collection_name=self.collection_name, ids=primary_keys)
            # Optional housekeeping: persist, then reclaim disk space.
            self.client.flush(collection_name=self.collection_name)
            self.client.compact(collection_name=self.collection_name)
            logger.info(f"Deleted data with doc_id: {delete_res}")
            return "delete_success"
        except Exception as e:
            logger.error(f"删除数据失败:{e}")
            return "delete_error"

    def query_by_scalar_field(self, doc_id: str, field_name: str, field_value: str):
        """Query by a scalar field (e.g. Father_Chapter) within one document.

        Returns the same result shape as :meth:`search`, with score fixed to 0.
        """
        output_fields = [
            "content",
            "doc_id",
            "chunk_id",
            "metadata",
            "Father_Chapter",
            "Chapter",
        ]
        query_expr = f"doc_id == '{doc_id}' && {field_name} == '{field_value}'"
        try:
            query_results = self.client.query(
                collection_name=self.collection_name,
                filter=query_expr,
                output_fields=output_fields
            )
            return [
                {
                    "doc_id": doc["doc_id"],
                    "chunk_id": doc["chunk_id"],
                    "content": doc["content"],
                    "Father_Chapter": doc["Father_Chapter"],
                    "Chapter": doc["Chapter"],
                    "metadata": doc["metadata"],
                    "score": 0,
                }
                for doc in query_results
            ]
        except Exception as e:
            logger.error(f"标量查询失败:{e}")
            return []

    def update_dense_vector_by_chunk_id(self, chunk_id: str, enhanced_text: str):
        """Re-embed *enhanced_text* and replace the chunk's dense_vector only.

        Returns (success: bool, message: str).
        """
        chunk_expr = f"chunk_id == '{chunk_id}'"
        chunk_output_fields = ["pk", "content", "doc_id", "chunk_id", "metadata", "Chapter", "Father_Chapter"]
        try:
            chunk_results = self.client.query(
                collection_name=self.collection_name,
                filter=chunk_expr,
                output_fields=chunk_output_fields
            )
        except Exception as e:
            logger.error(f"更新向量时查询失败:{e}")
            return False, f"查询失败: {e}"
        if not chunk_results:
            logger.warning(f"根据 {chunk_id} 未查询到对应数据")
            return False, "未找到对应切片"

        dense_vec = self._embed(enhanced_text)

        chunk_dict = chunk_results[0]
        logger.info(f"查询到的 pk: {chunk_dict.get('pk')}")
        chunk_dict["dense_vector"] = dense_vec
        try:
            self.client.upsert(collection_name=self.collection_name, data=[chunk_dict])
            logger.info(f"更新切片 {chunk_id} 的向量成功")
            return True, "success"
        except Exception as e:
            logger.error(f"更新向量失败:{e}")
            return False, str(e)


def main():
    """Smoke test: keyword-filtered query against a local Milvus instance."""
    embedding_path = r"G:/work/code/models/multilingual-e5-large-instruct/"
    sentence_transformer_ef = model.dense.SentenceTransformerEmbeddingFunction(model_name=embedding_path, device=device)
    # BUG FIX: the original passed dense_embedding_function=..., which
    # HybridRetriever.__init__ does not accept (TypeError). The embedding is
    # selected via embedding_name; sentence_transformer_ef is kept only to
    # verify the local model loads.
    _ = sentence_transformer_ef
    standard_retriever = HybridRetriever(
        uri="http://10.1.14.18:19530",
        collection_name="milvus_hybrid",
    )

    # Keyword-filtered query per document (query_filter takes ONE doc_id).
    query_start_time = time.time()
    filter_field = "加班"
    doc_ids = ["7b73ae0b-db97-4315-ba71-783fe7a69c61", "96bbe5a8-5fcf-4769-8343-938acb8735bd"]
    for doc_id in doc_ids:
        standard_retriever.query_filter(doc_id, filter_field)
    query_end_time = time.time()
    print(f"关键字搜索数据的时间:{query_end_time-query_start_time}")


if __name__ == "__main__":
    main()