@@ -0,0 +1,507 @@
+import time
+
+import numpy as np
+from pymilvus import (
+    MilvusClient,
+    DataType,
+    Function,
+    FunctionType,
+    AnnSearchRequest,
+    RRFRanker,
+)
+# from pymilvus.model.hybrid import BGEM3EmbeddingFunction
+from pymilvus import model
+import torch
+
+from rag.load_model import sentence_transformer_ef
+from utils.get_logger import setup_logger
+
+# Use the GPU when one is available, otherwise fall back to the CPU.
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+logger = setup_logger(__name__)
+
+# embedding_path = r"G:/work/code/models/multilingual-e5-large-instruct/"
+# sentence_transformer_ef = model.dense.SentenceTransformerEmbeddingFunction(model_name=embedding_path, device=device)
+
+# Registry of supported embedding functions, keyed by short name.
+embedding_mapping = {
+    "e5": sentence_transformer_ef,
+    "multilingual-e5-large-instruct": sentence_transformer_ef,
+}
+
+
+class HybridRetriever:
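+    """Hybrid retriever over a Milvus collection.
+
+    Each chunk is stored with its text content, a BM25-derived sparse vector,
+    and a dense embedding, supporting sparse, dense, and RRF-fused hybrid search.
+    """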
+    def __init__(self, uri, embedding_name: str = "e5", collection_name: str = "hybrid"):
+        self.uri = uri
+        self.collection_name = collection_name
+        # self.embedding_function = sentence_transformer_ef
+        # Fall back to the default e5 embedding function for unknown names.
+        self.embedding_function = embedding_mapping.get(embedding_name, sentence_transformer_ef)
+        self.use_reranker = True
+        self.use_sparse = True
+        self.client = MilvusClient(uri=uri)
+
+    def has_collection(self):
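+        """Return True if the collection exists, False otherwise (including on error)."""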
+        try:
+            collection_flag = self.client.has_collection(self.collection_name)
+            logger.info(f"Collection existence check result: {collection_flag}")
+        except Exception as e:
+            logger.error(f"Error while checking whether the collection exists: {e}")
+            collection_flag = False
+        return collection_flag
+
+    def build_collection(self):
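+        """Create the collection: schema (content, sparse/dense vectors, ids,
+        metadata), a BM25 function feeding the sparse field, and both indexes."""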
+        if isinstance(self.embedding_function.dim, dict):
+            dense_dim = self.embedding_function.dim["dense"]
+        else:
+            dense_dim = self.embedding_function.dim
+        logger.info(f"Dense vector dimension for the new collection: {dense_dim}")
+        # Tokenize content with the Chinese analyzer for BM25.
+        analyzer_params = {
+            "type": "chinese"
+        }
+
+        schema = MilvusClient.create_schema()
+        schema.add_field(
+            field_name="pk",
+            datatype=DataType.VARCHAR,
+            is_primary=True,
+            auto_id=True,
+            max_length=100,
+        )
+        schema.add_field(
+            field_name="content",
+            datatype=DataType.VARCHAR,
+            max_length=65535,
+            analyzer_params=analyzer_params,
+            enable_match=True,
+            enable_analyzer=True,
+        )
+        schema.add_field(
+            field_name="sparse_vector", datatype=DataType.SPARSE_FLOAT_VECTOR
+        )
+        schema.add_field(
+            field_name="dense_vector", datatype=DataType.FLOAT_VECTOR, dim=dense_dim
+        )
+        schema.add_field(field_name="doc_id", datatype=DataType.VARCHAR, max_length=64)
+        schema.add_field(
+            field_name="chunk_id", datatype=DataType.VARCHAR, max_length=64
+        )
+        schema.add_field(field_name="metadata", datatype=DataType.JSON)
+
+        # BM25 function: Milvus derives the sparse vector from the content field.
+        functions = Function(
+            name="bm25",
+            function_type=FunctionType.BM25,
+            input_field_names=["content"],
+            output_field_names="sparse_vector",
+        )
+
+        schema.add_function(functions)
+
+        index_params = MilvusClient.prepare_index_params()
+        index_params.add_index(
+            field_name="sparse_vector",
+            index_type="SPARSE_INVERTED_INDEX",
+            metric_type="BM25",
+        )
+        index_params.add_index(
+            field_name="dense_vector", index_type="FLAT", metric_type="IP"
+        )
+        try:
+            self.client.create_collection(
+                collection_name=self.collection_name,
+                schema=schema,
+                index_params=index_params,
+            )
+            return "create_collection_success"
+        except Exception as e:
+            logger.error(f"Failed to create collection {self.collection_name}: {e}")
+            return "create_collection_error"
+
+    def insert_data(self, chunk, metadata):
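+        """Embed a single chunk and insert it; on failure, remove any rows already stored for its doc_id."""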
+        logger.info("Preparing to insert data")
+        with torch.no_grad():
+            embedding = self.embedding_function([chunk])
+        logger.info("Computed the embedding for the text.")
+        if isinstance(embedding, dict) and "dense" in embedding:
+            # BGE-style embedding functions return a dict with a "dense" key.
+            dense_vec = embedding["dense"][0]
+        else:
+            dense_vec = embedding[0]
+
+        try:
+            self.client.insert(
+                self.collection_name, {"dense_vector": dense_vec, **metadata}
+            )
+            logger.info("Inserted one record successfully.")
+            return True, "success"
+        except Exception as e:
+            doc_id = metadata.get("doc_id")
+            logger.error(f"Error inserting data for document {doc_id}: {e}")
+            self.delete_by_doc_id(doc_id=doc_id)
+            return False, str(e)
+
+    def batch_insert_data(self, chunks, metadatas):
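+        """Embed a batch of chunks and insert them in one call; on failure, roll back the affected documents."""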
+        logger.info("Preparing to insert data")
+        embedding_lists = self.embedding_function.encode_documents(chunks)
+        logger.info("Computed embeddings for the texts.")
+        record_lists = []
+        for embedding, metadata in zip(embedding_lists, metadatas):
+            if isinstance(embedding, dict) and "dense" in embedding:
+                # BGE-style embedding functions return a dict with a "dense" key.
+                dense_vec = embedding["dense"][0]
+            else:
+                # SentenceTransformer returns numpy arrays; convert to plain lists.
+                dense_vec = embedding.tolist()
+
+            record = {"dense_vector": dense_vec}
+            record.update(metadata)
+            record_lists.append(record)
+
+        try:
+            self.client.insert(
+                self.collection_name, record_lists
+            )
+            logger.info("Batch insert succeeded.")
+            return True, "success"
+        except Exception as e:
+            # Roll back every document in the failed batch, not just the last one.
+            doc_ids = {m.get("doc_id") for m in metadatas}
+            logger.error(f"Error inserting batch for documents {doc_ids}: {e}")
+            for doc_id in doc_ids:
+                self.delete_by_doc_id(doc_id=doc_id)
+            return False, str(e)
+
+    def search(self, query: str, k: int = 20, mode="hybrid"):
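+        """Retrieve the top-k chunks for a query.
+
+        mode: "sparse" (BM25 full-text), "dense" (vector), or "hybrid"
+        (both, fused with reciprocal rank fusion).
+        """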
+        output_fields = [
+            "content",
+            "doc_id",
+            "chunk_id",
+            "metadata",
+        ]
+        if mode in ["dense", "hybrid"]:
+            with torch.no_grad():
+                embedding = self.embedding_function([query])
+            if isinstance(embedding, dict) and "dense" in embedding:
+                dense_vec = embedding["dense"][0]
+            else:
+                dense_vec = embedding[0]
+
+        if mode == "sparse":
+            # Full-text search: Milvus converts the raw query text to a sparse
+            # vector via the registered BM25 function.
+            results = self.client.search(
+                collection_name=self.collection_name,
+                data=[query],
+                anns_field="sparse_vector",
+                limit=k,
+                output_fields=output_fields,
+            )
+        elif mode == "dense":
+            results = self.client.search(
+                collection_name=self.collection_name,
+                data=[dense_vec],
+                anns_field="dense_vector",
+                limit=k,
+                output_fields=output_fields,
+            )
+        elif mode == "hybrid":
+            full_text_search_params = {"metric_type": "BM25"}
+            full_text_search_req = AnnSearchRequest(
+                [query], "sparse_vector", full_text_search_params, limit=k
+            )
+
+            dense_search_params = {"metric_type": "IP"}
+            dense_req = AnnSearchRequest(
+                [dense_vec], "dense_vector", dense_search_params, limit=k
+            )
+
+            # Fuse the two result lists with reciprocal rank fusion.
+            results = self.client.hybrid_search(
+                self.collection_name,
+                [full_text_search_req, dense_req],
+                ranker=RRFRanker(),
+                limit=k,
+                output_fields=output_fields,
+            )
+        else:
+            raise ValueError("Invalid mode")
+        return [
+            {
+                "doc_id": doc["entity"]["doc_id"],
+                "chunk_id": doc["entity"]["chunk_id"],
+                "content": doc["entity"]["content"],
+                "metadata": doc["entity"]["metadata"],
+                "score": doc["distance"],
+            }
+            for doc in results[0]
+        ]
+
+    def query_filter(self, doc_id, filter_field):
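+        """Fetch all chunks of a document, optionally restricted to content containing filter_field."""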
+        # Chunks of the given doc_id whose content contains filter_field.
+        query_output_field = [
+            "content",
+            "chunk_id",
+            "doc_id",
+            "metadata",
+        ]
+        # query_expr = f"doc_id in {doc_id} && content like '%{filter_field}%'"
+        # If a keyword is given, filter by doc_id and keyword; otherwise by doc_id only.
+        if filter_field:
+            query_expr = f"doc_id == '{doc_id}' && content like '%{filter_field}%'"
+        else:
+            query_expr = f"doc_id == '{doc_id}'"
+
+        try:
+            query_filter_results = self.client.query(collection_name=self.collection_name, filter=query_expr, output_fields=query_output_field)
+        except Exception as e:
+            logger.error(f"Keyword query failed: {e}")
+            query_filter_results = [{"code": 500}]
+        return query_filter_results
+        # for result in query_filter_results:
+        #     print(f"Filtered result by doc_id and keyword: {result}\n\n")
+
+    def query_chunk_id(self, chunk_id):
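+        """Fetch a single chunk's content and ids by chunk_id."""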
+        # Look up a chunk by its chunk_id.
+        query_output_field = [
+            "content",
+            "doc_id",
+            "chunk_id",
+            # "metadata"
+        ]
+        query_expr = f"chunk_id == '{chunk_id}'"
+        try:
+            query_filter_results = self.client.query(collection_name=self.collection_name, filter=query_expr, output_fields=query_output_field)
+        except Exception as e:
+            logger.error(f"Error querying by chunk_id: {e}")
+            query_filter_results = [{"code": 500}]
+        return query_filter_results
+
+    def update_data(self, chunk_id, chunk):
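+        """Re-embed an edited chunk and upsert it; returns a status string and the change in chunk length."""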
+        # Look up the existing record by chunk_id.
+        chunk_expr = f"chunk_id == '{chunk_id}'"
+        chunk_output_fields = [
+            "pk",
+            "doc_id",
+            "chunk_id",
+            "metadata",
+        ]
+        try:
+            chunk_results = self.client.query(collection_name=self.collection_name, filter=chunk_expr, output_fields=chunk_output_fields)
+            # logger.info(f"Record to update for {chunk_id}: {chunk_results}")
+        except Exception as e:
+            logger.error(f"Query failed while updating chunk data: {e}")
+            return "update_query_error", ""
+        if not chunk_results:
+            logger.info(f"No record found in the collection for chunk_id {chunk_id}; nothing to update")
+            return "update_query_no_result", ""
+        with torch.no_grad():
+            embedding = self.embedding_function([chunk])
+        if isinstance(embedding, dict) and "dense" in embedding:
+            # BGE-style embedding functions return a dict with a "dense" key.
+            dense_vec = embedding["dense"][0]
+        else:
+            dense_vec = embedding[0]
+        chunk_dict = chunk_results[0]
+        metadata = chunk_dict.get("metadata")
+        old_chunk_len = metadata.get("chunk_len")
+        chunk_len = len(chunk)
+        metadata["chunk_len"] = chunk_len
+        chunk_dict["content"] = chunk
+        chunk_dict["dense_vector"] = dense_vec
+        chunk_dict["metadata"] = metadata
+        try:
+            update_res = self.client.upsert(collection_name=self.collection_name, data=[chunk_dict])
+            logger.info(f"Upsert result: {update_res}")
+            return "update_success", chunk_len - old_chunk_len
+        except Exception as e:
+            logger.error(f"Error updating data: {e}")
+            return "update_error", ""
+
+    def delete_collection(self, collection):
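+        """Drop an entire collection; returns a status string."""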
+        try:
+            self.client.drop_collection(collection_name=collection)
+            return "delete_collection_success"
+        except Exception as e:
+            logger.error(f"Failed to drop collection {collection}: {e}")
+            return "delete_collection_error"
+
+    def delete_by_chunk_id(self, chunk_id: str = None):
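+        """Delete a chunk by chunk_id; returns a status string and the deleted chunks' lengths."""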
+        # Look up primary keys first: this code deletes via primary key.
+        expr = f"chunk_id == '{chunk_id}'"
+        try:
+            results = self.client.query(collection_name=self.collection_name, filter=expr, output_fields=["pk", "metadata"])  # fetch primary keys
+            logger.info(f"Records found for chunk_id {chunk_id}: {results}")
+        except Exception as e:
+            logger.error(f"Failed to look up primary keys by chunk_id: {e}")
+            return "delete_query_error", []
+        if not results:
+            logger.info(f"No data found for chunk_id: {chunk_id}")
+            # return "delete_no_result", []
+            return "delete_success", 0
+
+        # Extract the primary key values.
+        primary_keys = [result["pk"] for result in results]
+        chunk_len = [result["metadata"]["chunk_len"] for result in results]
+        logger.info(f"Primary keys found: {primary_keys}")
+
+        # Perform the delete.
+        expr_delete = f"pk in {primary_keys}"  # build the delete expression
+        try:
+            delete_res = self.client.delete(collection_name=self.collection_name, filter=expr_delete)
+            logger.info(f"Deleted data with chunk_id: {delete_res}")
+            return "delete_success", chunk_len
+        except Exception as e:
+            logger.error(f"Failed to delete data: {e}")
+            return "delete_error", []
+
+    def delete_by_doc_id(self, doc_id: str = None):
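+        """Delete all chunks of a document by doc_id; returns a status string."""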
+        # Look up primary keys first: this code deletes via primary key.
+        expr = f"doc_id == '{doc_id}'"
+        try:
+            results = self.client.query(collection_name=self.collection_name, filter=expr, output_fields=["pk"])  # fetch primary keys
+        except Exception as e:
+            logger.error(f"Failed to look up primary keys by doc_id: {e}")
+            return "delete_query_error"
+        if not results:
+            logger.info(f"No data found for doc_id: {doc_id}")
+            return "delete_no_result"
+
+        # Extract the primary key values.
+        primary_keys = [result["pk"] for result in results]
+        logger.info(f"Primary keys found: {primary_keys}")
+
+        # Perform the delete.
+        expr_delete = f"pk in {primary_keys}"  # build the delete expression
+        try:
+            delete_res = self.client.delete(collection_name=self.collection_name, filter=expr_delete)
+            logger.info(f"Deleted data with doc_id: {delete_res}")
+            return "delete_success"
+        except Exception as e:
+            logger.error(f"Failed to delete data: {e}")
+            return "delete_error"
+
+
+# Test helpers
+# def parse_json_to_schema_data_format():
+#     sql_text_list = load_sql_query_ddl_info()
+#     docs = []
+#     for sql_info in sql_text_list:
+#         sql_text = sql_info.get("sql_text")
+#         source = ",".join(sql_info.get("table_list", []))
+#         source = sql_info.get("source") if not source else source
+#         sql = sql_info.get("sql")
+#         ddl = sql_info.get("table_ddl")
+#         metadata = {"source": source, "sql": sql, "ddl": ddl}
+#         text_list = sql_info.get("sim_sql_text_list", [])
+#         text_list.append(sql_text)
+#         doc_id = str(uuid4())
+#         for text in text_list:
+#             chunk_id = str(uuid4())
+#             insert_dict = {
+#                 "content": text,
+#                 "doc_id": doc_id,
+#                 "chunk_id": chunk_id,
+#                 "metadata": metadata
+#             }
+#             docs.append(insert_dict)
+
+#     return docs
+
+# def insert_data_to_milvus(standard_retriever):
+#     sql_dataset = parse_json_to_schema_data_format()
+#     standard_retriever.build_collection()
+#     for sql_dict in sql_dataset:
+#         text = sql_dict["content"]
+#         standard_retriever.insert_data(text, sql_dict)
+
+def main():
+    # dense_ef = BGEM3EmbeddingFunction()
+    # The embedding function is resolved via embedding_mapping, so a local
+    # SentenceTransformer instance is not needed here:
+    # embedding_path = r"G:/work/code/models/multilingual-e5-large-instruct/"
+    # sentence_transformer_ef = model.dense.SentenceTransformerEmbeddingFunction(model_name=embedding_path, device=device)
+    standard_retriever = HybridRetriever(
+        uri="http://localhost:19530",
+        collection_name="milvus_hybrid",
+        embedding_name="e5",
+    )
+    # Insert data
+    # insert_data_to_milvus(standard_retriever)
+
+    # Search modes: hybrid (fused), sparse (BM25), dense (vector)
+    # results = standard_retriever.search("查一下加班情况", mode="hybrid", k=3)
+    # mode="sparse": sparse retrieval
+    # print(f"Sparse retrieval results: {results}")
+
+    # mode="dense": dense retrieval
+    # print(f"Dense retrieval results: {results}")
+
+    # mode="hybrid"
+    # print(f"Hybrid retrieval results: {results}")
+
+    # Delete data by doc_id
+    # delete_start_time = time.time()
+    # doc_id = "e72ebf78-6c0d-410b-8fbb-9a2057673064"
+    # standard_retriever.delete_by_doc_id(doc_id=doc_id)
+    # delete_end_time = time.time()
+    # print(f"Delete took: {delete_end_time-delete_start_time}")
+
+    # Update data by chunk_id
+    # update_start_time = time.time()
+    # chunk = "查询一下员工加班的情况"
+    # chunk_id = "a5e8dded-f5a7-4a1f-92cd-82fa8113b418"
+    # standard_retriever.update_data(chunk_id, chunk)
+    # update_end_time = time.time()
+    # print(f"Update took: {update_end_time-update_start_time}")
+
+    # Query by doc_id and keyword
+    query_start_time = time.time()
+    filter_field = "加班"
+    doc_ids = ["7b73ae0b-db97-4315-ba71-783fe7a69c61", "96bbe5a8-5fcf-4769-8343-938acb8735bd"]
+    # query_filter expects a single doc_id, so query each document in turn.
+    for doc_id in doc_ids:
+        standard_retriever.query_filter(doc_id, filter_field)
+    query_end_time = time.time()
+    print(f"Keyword query took: {query_end_time-query_start_time}")
+
+
+if __name__ == "__main__":
+    main()