milvus_vector.py
import time

import numpy as np
from pymilvus import (
    MilvusClient,
    DataType,
    Function,
    FunctionType,
    AnnSearchRequest,
    RRFRanker,
)
# from pymilvus.model.hybrid import BGEM3EmbeddingFunction
from pymilvus import model
from rag.load_model import sentence_transformer_ef
from utils.get_logger import setup_logger
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
logger = setup_logger(__name__)

# embedding_path = r"G:/work/code/models/multilingual-e5-large-instruct/"
# sentence_transformer_ef = model.dense.SentenceTransformerEmbeddingFunction(model_name=embedding_path, device=device)

embedding_mapping = {
    "e5": sentence_transformer_ef,
    "multilingual-e5-large-instruct": sentence_transformer_ef,
}
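
# To register another dense embedder, extend embedding_mapping. A minimal sketch,
# assuming a sentence-transformers-compatible model (the model name below is
# illustrative; SentenceTransformerEmbeddingFunction is the pymilvus class used above):
#
#   bge_ef = model.dense.SentenceTransformerEmbeddingFunction(
#       model_name="BAAI/bge-m3", device=device
#   )
#   embedding_mapping["bge-m3"] = bge_ef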

class HybridRetriever:
    def __init__(self, uri, embedding_name: str = "e5", collection_name: str = "hybrid"):
        self.uri = uri
        self.collection_name = collection_name
        # Fall back to the default e5 embedding function for unknown names.
        self.embedding_function = embedding_mapping.get(embedding_name, sentence_transformer_ef)
        self.use_reranker = True
        self.use_sparse = True
        self.client = MilvusClient(uri=uri)

    def has_collection(self):
        try:
            collection_flag = self.client.has_collection(self.collection_name)
            logger.info(f"Collection existence check: {collection_flag}")
        except Exception as e:
            logger.info(f"Error while checking whether the collection exists: {e}")
            collection_flag = False
        return collection_flag

    def build_collection(self):
        if isinstance(self.embedding_function.dim, dict):
            dense_dim = self.embedding_function.dim["dense"]
        else:
            dense_dim = self.embedding_function.dim
        logger.info(f"Dense vector dimension for the collection: {dense_dim}")
        analyzer_params = {"type": "chinese"}
        schema = MilvusClient.create_schema()
        schema.add_field(
            field_name="pk",
            datatype=DataType.VARCHAR,
            is_primary=True,
            auto_id=True,
            max_length=100,
        )
        schema.add_field(
            field_name="content",
            datatype=DataType.VARCHAR,
            max_length=65535,
            analyzer_params=analyzer_params,
            enable_match=True,
            enable_analyzer=True,
        )
        schema.add_field(
            field_name="sparse_vector", datatype=DataType.SPARSE_FLOAT_VECTOR
        )
        schema.add_field(
            field_name="dense_vector", datatype=DataType.FLOAT_VECTOR, dim=dense_dim
        )
        schema.add_field(field_name="doc_id", datatype=DataType.VARCHAR, max_length=64)
        schema.add_field(
            field_name="chunk_id", datatype=DataType.VARCHAR, max_length=64
        )
        schema.add_field(field_name="metadata", datatype=DataType.JSON)
        # BM25 function: Milvus derives the sparse vector from the analyzed content field.
        bm25_function = Function(
            name="bm25",
            function_type=FunctionType.BM25,
            input_field_names=["content"],
            output_field_names=["sparse_vector"],
        )
        schema.add_function(bm25_function)
        index_params = MilvusClient.prepare_index_params()
        index_params.add_index(
            field_name="sparse_vector",
            index_type="SPARSE_INVERTED_INDEX",
            metric_type="BM25",
        )
        index_params.add_index(
            field_name="dense_vector", index_type="FLAT", metric_type="IP"
        )
        try:
            self.client.create_collection(
                collection_name=self.collection_name,
                schema=schema,
                index_params=index_params,
            )
            return "create_collection_success"
        except Exception as e:
            logger.error(f"Failed to create collection {self.collection_name}: {e}")
            return "create_collection_error"
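
    # A minimal usage sketch, assuming a Milvus server at http://localhost:19530
    # and that rag.load_model provides the e5 embedding function:
    #
    #   retriever = HybridRetriever(uri="http://localhost:19530", collection_name="hybrid")
    #   if not retriever.has_collection():
    #       retriever.build_collection()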

    def insert_data(self, chunk, metadata):
        logger.info("Preparing to insert data")
        with torch.no_grad():
            embedding = self.embedding_function([chunk])
        logger.info("Computed the embedding for the text.")
        if isinstance(embedding, dict) and "dense" in embedding:
            # BGE-style embedding functions return a dict keyed by "dense"/"sparse".
            dense_vec = embedding["dense"][0]
        else:
            dense_vec = embedding[0]
        try:
            self.client.insert(
                self.collection_name, {"dense_vector": dense_vec, **metadata}
            )
            logger.info("Inserted one record successfully.")
            return True, "success"
        except Exception as e:
            doc_id = metadata.get("doc_id")
            logger.error(f"Document {doc_id}: insert failed: {e}")
            self.delete_by_doc_id(doc_id=doc_id)
            return False, str(e)

    def batch_insert_data(self, chunks, metadatas):
        logger.info("Preparing to insert data")
        embedding_lists = self.embedding_function.encode_documents(chunks)
        logger.info("Computed the embeddings for the texts.")
        # BGE-style embedders return a dict keyed by "dense"/"sparse"; others
        # return a sequence of numpy vectors, one per chunk.
        if isinstance(embedding_lists, dict) and "dense" in embedding_lists:
            dense_vectors = list(embedding_lists["dense"])
        else:
            dense_vectors = [embedding.tolist() for embedding in embedding_lists]
        record_lists = []
        for dense_vec, metadata in zip(dense_vectors, metadatas):
            record = {"dense_vector": dense_vec}
            record.update(metadata)
            record_lists.append(record)
        try:
            self.client.insert(self.collection_name, record_lists)
            logger.info("Inserted the batch successfully.")
            return True, "success"
        except Exception as e:
            # Roll back every document of the failed batch.
            for metadata in metadatas:
                doc_id = metadata.get("doc_id")
                logger.error(f"Document {doc_id}: insert failed: {e}")
                self.delete_by_doc_id(doc_id=doc_id)
            return False, str(e)
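
    # Record-shape sketch for insert_data/batch_insert_data; every schema field
    # except dense_vector must be supplied by the caller (ids are illustrative):
    #
    #   meta = {
    #       "content": "查询一下员工加班的情况",
    #       "doc_id": "7b73ae0b-...",
    #       "chunk_id": "a5e8dded-...",
    #       "metadata": {"source": "hr.docx", "chunk_len": 11},
    #   }
    #   retriever.insert_data(meta["content"], meta)
    #   retriever.batch_insert_data([meta["content"]], [meta])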

    def search(self, query: str, k: int = 20, mode="hybrid"):
        output_fields = [
            "content",
            "doc_id",
            "chunk_id",
            "metadata",
        ]
        if mode in ["dense", "hybrid"]:
            with torch.no_grad():
                embedding = self.embedding_function([query])
            if isinstance(embedding, dict) and "dense" in embedding:
                dense_vec = embedding["dense"][0]
            else:
                dense_vec = embedding[0]
        if mode == "sparse":
            # BM25 full-text search: pass the raw query string; Milvus analyzes it.
            results = self.client.search(
                collection_name=self.collection_name,
                data=[query],
                anns_field="sparse_vector",
                limit=k,
                output_fields=output_fields,
            )
        elif mode == "dense":
            results = self.client.search(
                collection_name=self.collection_name,
                data=[dense_vec],
                anns_field="dense_vector",
                limit=k,
                output_fields=output_fields,
            )
        elif mode == "hybrid":
            # Fuse BM25 and dense results with reciprocal rank fusion.
            full_text_search_params = {"metric_type": "BM25"}
            full_text_search_req = AnnSearchRequest(
                [query], "sparse_vector", full_text_search_params, limit=k
            )
            dense_search_params = {"metric_type": "IP"}
            dense_req = AnnSearchRequest(
                [dense_vec], "dense_vector", dense_search_params, limit=k
            )
            results = self.client.hybrid_search(
                self.collection_name,
                [full_text_search_req, dense_req],
                ranker=RRFRanker(),
                limit=k,
                output_fields=output_fields,
            )
        else:
            raise ValueError("Invalid mode")
        return [
            {
                "doc_id": doc["entity"]["doc_id"],
                "chunk_id": doc["entity"]["chunk_id"],
                "content": doc["entity"]["content"],
                "metadata": doc["entity"]["metadata"],
                "score": doc["distance"],
            }
            for doc in results[0]
        ]
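
    # Usage sketch for the three retrieval modes (the query string is an example):
    #
    #   hits = retriever.search("查一下加班情况", k=3, mode="hybrid")  # RRF fusion
    #   hits = retriever.search("查一下加班情况", k=3, mode="dense")   # IP over dense_vector
    #   hits = retriever.search("查一下加班情况", k=3, mode="sparse")  # BM25 full text
    #   # each hit: {"doc_id", "chunk_id", "content", "metadata", "score"}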

    def query_filter(self, doc_id, filter_field):
        # Return the chunks of document doc_id; if filter_field is given, keep
        # only chunks whose content contains that keyword.
        query_output_field = [
            "content",
            "chunk_id",
            "doc_id",
            "metadata",
        ]
        if filter_field:
            query_expr = f"doc_id == '{doc_id}' && content like '%{filter_field}%'"
        else:
            query_expr = f"doc_id == '{doc_id}'"
        try:
            query_filter_results = self.client.query(
                collection_name=self.collection_name,
                filter=query_expr,
                output_fields=query_output_field,
            )
        except Exception as e:
            logger.error(f"Keyword query failed: {e}")
            query_filter_results = [{"code": 500}]
        return query_filter_results

    def query_chunk_id(self, chunk_id):
        # Fetch a single chunk by its chunk_id.
        query_output_field = [
            "content",
            "doc_id",
            "chunk_id",
            # "metadata"
        ]
        query_expr = f"chunk_id == '{chunk_id}'"
        try:
            query_filter_results = self.client.query(
                collection_name=self.collection_name,
                filter=query_expr,
                output_fields=query_output_field,
            )
        except Exception as e:
            logger.info(f"Query by chunk_id failed: {e}")
            query_filter_results = [{"code": 500}]
        return query_filter_results

    def query_chunk_id_list(self, chunk_id_list):
        # Fetch several chunks at once by a list of chunk_ids.
        query_output_field = [
            "content",
            "doc_id",
            "chunk_id",
            # "metadata"
        ]
        query_expr = f"chunk_id in {chunk_id_list}"
        try:
            query_filter_results = self.client.query(
                collection_name=self.collection_name,
                filter=query_expr,
                output_fields=query_output_field,
            )
        except Exception as e:
            logger.info(f"Query by chunk_id list failed: {e}")
            query_filter_results = [{"code": 500}]
        return query_filter_results
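
    # Query-helper sketch (ids are illustrative):
    #
    #   retriever.query_filter("7b73ae0b-...", "加班")        # doc_id + keyword
    #   retriever.query_chunk_id("a5e8dded-...")              # one chunk
    #   retriever.query_chunk_id_list(["a5e8dded-...", ...])  # several chunks at once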

    def update_data(self, chunk_id, chunk):
        # Look up the existing record by chunk_id, then upsert the new content.
        chunk_expr = f"chunk_id == '{chunk_id}'"
        chunk_output_fields = [
            "pk",
            "doc_id",
            "chunk_id",
            "metadata",
        ]
        try:
            chunk_results = self.client.query(
                collection_name=self.collection_name,
                filter=chunk_expr,
                output_fields=chunk_output_fields,
            )
            # logger.info(f"Record to update for {chunk_id}: {chunk_results}")
        except Exception as e:
            logger.error(f"Query failed while updating the chunk: {e}")
            return "update_query_error", ""
        if not chunk_results:
            logger.info(f"No record found for chunk_id {chunk_id}; nothing to update")
            return "update_query_no_result", ""
        with torch.no_grad():
            embedding = self.embedding_function([chunk])
        if isinstance(embedding, dict) and "dense" in embedding:
            # BGE-style embedding functions return a dict keyed by "dense"/"sparse".
            dense_vec = embedding["dense"][0]
        else:
            dense_vec = embedding[0]
        chunk_dict = chunk_results[0]
        metadata = chunk_dict.get("metadata")
        old_chunk_len = metadata.get("chunk_len", 0)  # default 0 if the old record lacks chunk_len
        chunk_len = len(chunk)
        metadata["chunk_len"] = chunk_len
        chunk_dict["content"] = chunk
        chunk_dict["dense_vector"] = dense_vec
        chunk_dict["metadata"] = metadata
        try:
            update_res = self.client.upsert(
                collection_name=self.collection_name, data=[chunk_dict]
            )
            logger.info(f"Upsert response: {update_res}")
            return "update_success", chunk_len - old_chunk_len
        except Exception as e:
            logger.error(f"Error while updating data: {e}")
            return "update_error", ""
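
    # Update sketch: re-embeds the new text and upserts it under the same pk.
    # Returns ("update_success", <length delta>) on success (chunk_id illustrative):
    #
    #   status, delta = retriever.update_data("a5e8dded-...", "查询一下员工加班的情况")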

    def delete_collection(self, collection):
        try:
            self.client.drop_collection(collection_name=collection)
            return "delete_collection_success"
        except Exception as e:
            logger.error(f"Failed to drop collection {collection}: {e}")
            return "delete_collection_error"

    def delete_by_chunk_id(self, chunk_id: str = None):
        # Milvus deletes by primary key, so first look up the pk for this chunk_id.
        expr = f"chunk_id == '{chunk_id}'"
        try:
            results = self.client.query(
                collection_name=self.collection_name,
                filter=expr,
                output_fields=["pk", "metadata"],
            )  # fetch primary keys
            logger.info(f"Records found for chunk_id {chunk_id}: {results}")
        except Exception as e:
            logger.error(f"Primary-key lookup by chunk_id failed: {e}")
            return "delete_query_error", []
        if not results:
            logger.info(f"No data found for chunk_id: {chunk_id}")
            # return "delete_no_result", []
            return "delete_success", [0]
        # Extract the primary keys and the stored chunk lengths.
        primary_keys = [result["pk"] for result in results]
        chunk_len = [result["metadata"]["chunk_len"] for result in results]
        logger.info(f"Primary keys to delete: {primary_keys}")
        # Perform the deletion.
        expr_delete = f"pk in {primary_keys}"  # build the delete expression
        try:
            delete_res = self.client.delete(
                collection_name=self.collection_name, filter=expr_delete
            )
            self.client.flush(collection_name=self.collection_name)
            self.client.compact(collection_name=self.collection_name)
            logger.info(f"Deleted data with chunk_id: {delete_res}")
            return "delete_success", chunk_len
        except Exception as e:
            logger.error(f"Delete failed: {e}")
            return "delete_error", []

    def delete_by_doc_id(self, doc_id: str = None):
        # Milvus deletes by primary key, so first look up the pks for this doc_id.
        expr = f"doc_id == '{doc_id}'"
        try:
            results = self.client.query(
                collection_name=self.collection_name,
                filter=expr,
                output_fields=["pk"],
            )  # fetch primary keys
        except Exception as e:
            logger.error(f"Primary-key lookup by doc_id failed: {e}")
            return "delete_query_error"
        if not results:
            logger.info(f"No data found for doc_id: {doc_id}")
            return "delete_no_result"
        # Extract the primary keys.
        primary_keys = [result["pk"] for result in results]
        logger.info(f"Primary keys to delete: {primary_keys}")
        # Perform the deletion.
        try:
            delete_res = self.client.delete(
                collection_name=self.collection_name, ids=primary_keys
            )
            self.client.flush(collection_name=self.collection_name)
            self.client.compact(collection_name=self.collection_name)
            logger.info(f"Deleted data with doc_id: {delete_res}")
            return "delete_success"
        except Exception as e:
            logger.error(f"Delete failed: {e}")
            return "delete_error"
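
    # Deletion sketch: both helpers first query the pk values because Milvus
    # deletes by primary key; flush + compact then reclaim space (ids illustrative):
    #
    #   retriever.delete_by_chunk_id("a5e8dded-...")
    #   retriever.delete_by_doc_id("7b73ae0b-...")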

# Test helpers (uuid4 and load_sql_query_ddl_info must be imported before enabling):
# def parse_json_to_schema_data_format():
#     sql_text_list = load_sql_query_ddl_info()
#     docs = []
#     for sql_info in sql_text_list:
#         sql_text = sql_info.get("sql_text")
#         source = ",".join(sql_info.get("table_list", []))
#         source = sql_info.get("source") if not source else source
#         sql = sql_info.get("sql")
#         ddl = sql_info.get("table_ddl")
#         metadata = {"source": source, "sql": sql, "ddl": ddl}
#         text_list = sql_info.get("sim_sql_text_list", [])
#         text_list.append(sql_text)
#         doc_id = str(uuid4())
#         for text in text_list:
#             chunk_id = str(uuid4())
#             insert_dict = {
#                 "content": text,
#                 "doc_id": doc_id,
#                 "chunk_id": chunk_id,
#                 "metadata": metadata,
#             }
#             docs.append(insert_dict)
#     return docs


# def insert_data_to_milvus(standard_retriever):
#     sql_dataset = parse_json_to_schema_data_format()
#     standard_retriever.build_collection()
#     for sql_dict in sql_dataset:
#         text = sql_dict["content"]
#         standard_retriever.insert_data(text, sql_dict)

def main():
    # dense_ef = BGEM3EmbeddingFunction()
    # HybridRetriever resolves its embedder through embedding_mapping via
    # embedding_name; a locally constructed function is kept here for reference.
    # embedding_path = r"G:/work/code/models/multilingual-e5-large-instruct/"
    # sentence_transformer_ef = model.dense.SentenceTransformerEmbeddingFunction(
    #     model_name=embedding_path, device=device
    # )
    standard_retriever = HybridRetriever(
        uri="http://localhost:19530",
        embedding_name="e5",
        collection_name="milvus_hybrid",
    )
    # Insert data
    # insert_data_to_milvus(standard_retriever)
    # Retrieval modes: hybrid (fused), sparse (BM25), dense (vector)
    # results = standard_retriever.search("查一下加班情况", mode="hybrid", k=3)
    # mode="sparse" -> sparse retrieval
    # print(f"Sparse retrieval results: {results}")
    # mode="dense" -> dense retrieval
    # print(f"Dense retrieval results: {results}")
    # mode="hybrid" -> hybrid retrieval
    # print(f"Hybrid retrieval results: {results}")
    # Delete data by doc id
    # delete_start_time = time.time()
    # doc_id = "e72ebf78-6c0d-410b-8fbb-9a2057673064"
    # standard_retriever.delete_by_doc_id(doc_id=doc_id)
    # delete_end_time = time.time()
    # print(f"Delete took: {delete_end_time - delete_start_time}")
    # Update data by chunk_id
    # update_start_time = time.time()
    # chunk = "查询一下员工加班的情况"
    # chunk_id = "a5e8dded-f5a7-4a1f-92cd-82fa8113b418"
    # standard_retriever.update_data(chunk_id, chunk)
    # update_end_time = time.time()
    # print(f"Update took: {update_end_time - update_start_time}")
    # Query by doc id and keyword (query_filter expects one doc_id per call)
    query_start_time = time.time()
    filter_field = "加班"
    doc_ids = ["7b73ae0b-db97-4315-ba71-783fe7a69c61", "96bbe5a8-5fcf-4769-8343-938acb8735bd"]
    for doc_id in doc_ids:
        standard_retriever.query_filter(doc_id, filter_field)
    query_end_time = time.time()
    print(f"Keyword query took: {query_end_time - query_start_time}")


if __name__ == "__main__":
    main()