# milvus_vector.py

import time

import numpy as np
from pymilvus import (
    MilvusClient,
    DataType,
    Function,
    FunctionType,
    AnnSearchRequest,
    RRFRanker,
)
# from pymilvus.model.hybrid import BGEM3EmbeddingFunction
from pymilvus import model
from rag.load_model import sentence_transformer_ef
from utils.get_logger import setup_logger
import torch
# Use the GPU when available, otherwise fall back to CPU (the original condition was inverted).
device = "cuda" if torch.cuda.is_available() else "cpu"
logger = setup_logger(__name__)

# embedding_path = r"G:/work/code/models/multilingual-e5-large-instruct/"
# sentence_transformer_ef = model.dense.SentenceTransformerEmbeddingFunction(model_name=embedding_path, device=device)
embedding_mapping = {
    "e5": sentence_transformer_ef,
    "multilingual-e5-large-instruct": sentence_transformer_ef,
}
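
# To register another dense model, add it to the mapping above. A minimal
# sketch (assumes the BGEM3EmbeddingFunction import commented out at the top;
# the "bge-m3" key is an illustrative name):
# embedding_mapping["bge-m3"] = BGEM3EmbeddingFunction(device=device)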


class HybridRetriever:
    def __init__(self, uri, embedding_name: str = "e5", collection_name: str = "hybrid"):
        self.uri = uri
        self.collection_name = collection_name
        # self.embedding_function = sentence_transformer_ef
        # Fall back to the default embedding function for unknown names
        # (the original fell back to the string "e5", which would break encoding).
        self.embedding_function = embedding_mapping.get(embedding_name, sentence_transformer_ef)
        self.use_reranker = True
        self.use_sparse = True
        self.client = MilvusClient(uri=uri)

    def has_collection(self):
        try:
            collection_flag = self.client.has_collection(self.collection_name)
            logger.info(f"Collection existence check result: {collection_flag}")
        except Exception as e:
            logger.info(f"Error while checking whether the collection exists: {e}")
            collection_flag = False
        return collection_flag

    def build_collection(self):
        if isinstance(self.embedding_function.dim, dict):
            dense_dim = self.embedding_function.dim["dense"]
        else:
            dense_dim = self.embedding_function.dim
        logger.info(f"Dense vector dimension for the new collection: {dense_dim}")
        analyzer_params = {"type": "chinese"}
        schema = MilvusClient.create_schema()
        schema.add_field(
            field_name="pk",
            datatype=DataType.VARCHAR,
            is_primary=True,
            auto_id=True,
            max_length=100,
        )
        schema.add_field(
            field_name="content",
            datatype=DataType.VARCHAR,
            max_length=65535,
            analyzer_params=analyzer_params,
            enable_match=True,
            enable_analyzer=True,
        )
        schema.add_field(
            field_name="sparse_vector", datatype=DataType.SPARSE_FLOAT_VECTOR
        )
        schema.add_field(
            field_name="dense_vector", datatype=DataType.FLOAT_VECTOR, dim=dense_dim
        )
        schema.add_field(field_name="doc_id", datatype=DataType.VARCHAR, max_length=64)
        schema.add_field(
            field_name="chunk_id", datatype=DataType.VARCHAR, max_length=64
        )
        schema.add_field(field_name="metadata", datatype=DataType.JSON)
        bm25_function = Function(
            name="bm25",
            function_type=FunctionType.BM25,
            input_field_names=["content"],
            output_field_names=["sparse_vector"],
        )
        schema.add_function(bm25_function)
        index_params = MilvusClient.prepare_index_params()
        index_params.add_index(
            field_name="sparse_vector",
            index_type="SPARSE_INVERTED_INDEX",
            metric_type="BM25",
        )
        index_params.add_index(
            field_name="dense_vector", index_type="FLAT", metric_type="IP"
        )
        try:
            self.client.create_collection(
                collection_name=self.collection_name,
                schema=schema,
                index_params=index_params,
            )
            return "create_collection_success"
        except Exception as e:
            logger.error(f"Failed to create collection {self.collection_name}: {e}")
            return "create_collection_error"
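
    # Minimal setup sketch (assumes a Milvus server reachable at the given uri):
    #   retriever = HybridRetriever(uri="http://localhost:19530")
    #   if not retriever.has_collection():
    #       retriever.build_collection()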

    def insert_data(self, chunk, metadata):
        logger.info("Preparing to insert data")
        with torch.no_grad():
            embedding = self.embedding_function([chunk])
        logger.info("Computed the embedding for the text.")
        if isinstance(embedding, dict) and "dense" in embedding:
            # BGE-style embedding functions return a dict keyed by vector type.
            dense_vec = embedding["dense"][0]
        else:
            dense_vec = embedding[0]
        try:
            self.client.insert(
                self.collection_name, {"dense_vector": dense_vec, **metadata}
            )
            logger.info("Inserted one record successfully.")
            return True, "success"
        except Exception as e:
            doc_id = metadata.get("doc_id")
            logger.error(f"Document {doc_id}: insert failed: {e}")
            # Roll back any rows already written for this document.
            self.delete_by_doc_id(doc_id=doc_id)
            return False, str(e)
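
    # insert_data expects `metadata` to carry the schema fields defined above,
    # e.g. {"content": ..., "doc_id": ..., "chunk_id": ..., "metadata": {...}};
    # the sparse vector is generated server-side by the BM25 function.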

    def batch_insert_data(self, chunks, metadatas):
        logger.info("Preparing to insert data in batch")
        embedding_lists = self.embedding_function.encode_documents(chunks)
        logger.info("Computed embeddings for the texts.")
        record_lists = []
        for embedding, metadata in zip(embedding_lists, metadatas):
            if isinstance(embedding, dict) and "dense" in embedding:
                # BGE-style embedding functions return a dict keyed by vector type.
                dense_vec = embedding["dense"][0]
            else:
                dense_vec = embedding.tolist()
            record = {"dense_vector": dense_vec}
            record.update(metadata)
            record_lists.append(record)
        try:
            self.client.insert(self.collection_name, record_lists)
            logger.info("Batch insert succeeded.")
            return True, "success"
        except Exception as e:
            # Roll back every document in the failed batch, not just the last
            # one seen by the loop variable above.
            doc_ids = {m.get("doc_id") for m in metadatas}
            logger.error(f"Batch insert failed for documents {doc_ids}: {e}")
            for doc_id in doc_ids:
                self.delete_by_doc_id(doc_id=doc_id)
            return False, str(e)

    def search(self, query: str, k: int = 20, mode="hybrid"):
        output_fields = [
            "content",
            "doc_id",
            "chunk_id",
            "metadata",
        ]
        if mode in ["dense", "hybrid"]:
            with torch.no_grad():
                embedding = self.embedding_function([query])
            if isinstance(embedding, dict) and "dense" in embedding:
                dense_vec = embedding["dense"][0]
            else:
                dense_vec = embedding[0]
        if mode == "sparse":
            results = self.client.search(
                collection_name=self.collection_name,
                data=[query],
                anns_field="sparse_vector",
                limit=k,
                output_fields=output_fields,
            )
        elif mode == "dense":
            results = self.client.search(
                collection_name=self.collection_name,
                data=[dense_vec],
                anns_field="dense_vector",
                limit=k,
                output_fields=output_fields,
            )
        elif mode == "hybrid":
            full_text_search_params = {"metric_type": "BM25"}
            full_text_search_req = AnnSearchRequest(
                [query], "sparse_vector", full_text_search_params, limit=k
            )
            dense_search_params = {"metric_type": "IP"}
            dense_req = AnnSearchRequest(
                [dense_vec], "dense_vector", dense_search_params, limit=k
            )
            results = self.client.hybrid_search(
                self.collection_name,
                [full_text_search_req, dense_req],
                ranker=RRFRanker(),
                limit=k,
                output_fields=output_fields,
            )
        else:
            raise ValueError("Invalid mode")
        return [
            {
                "doc_id": doc["entity"]["doc_id"],
                "chunk_id": doc["entity"]["chunk_id"],
                "content": doc["entity"]["content"],
                "metadata": doc["entity"]["metadata"],
                "score": doc["distance"],
            }
            for doc in results[0]
        ]
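
    # Search usage sketch (query text and k are illustrative):
    #   hits = retriever.search("查一下加班情况", k=5, mode="hybrid")  # BM25 + dense via RRF
    #   hits = retriever.search("查一下加班情况", k=5, mode="sparse")  # BM25 only
    #   hits = retriever.search("查一下加班情况", k=5, mode="dense")   # dense only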

    def query_filter(self, doc_id, filter_field):
        # Return chunks of the given document(s) whose content contains filter_field.
        query_output_field = [
            "content",
            "chunk_id",
            "doc_id",
            "metadata",
        ]
        # Filter by doc_id; when a keyword is given, additionally match on content.
        # doc_id may be a single id or a list of ids (main() passes a list).
        doc_expr = (
            f"doc_id in {doc_id}" if isinstance(doc_id, list) else f"doc_id == '{doc_id}'"
        )
        if filter_field:
            query_expr = f"{doc_expr} && content like '%{filter_field}%'"
        else:
            query_expr = doc_expr
        try:
            query_filter_results = self.client.query(
                collection_name=self.collection_name,
                filter=query_expr,
                output_fields=query_output_field,
            )
        except Exception as e:
            logger.error(f"Keyword-filtered query failed: {e}")
            query_filter_results = [{"code": 500}]
        return query_filter_results

    def query_chunk_id(self, chunk_id):
        # Look up a chunk by its chunk_id.
        query_output_field = [
            "content",
            "doc_id",
            "chunk_id",
            # "metadata"
        ]
        query_expr = f"chunk_id == '{chunk_id}'"
        try:
            query_filter_results = self.client.query(
                collection_name=self.collection_name,
                filter=query_expr,
                output_fields=query_output_field,
            )
        except Exception as e:
            logger.info(f"Query by chunk id failed: {e}")
            query_filter_results = [{"code": 500}]
        return query_filter_results

    def query_chunk_id_list(self, chunk_id_list):
        # Look up chunks by a list of chunk_ids.
        query_output_field = [
            "content",
            "doc_id",
            "chunk_id",
            # "metadata"
        ]
        query_expr = f"chunk_id in {chunk_id_list}"
        try:
            query_filter_results = self.client.query(
                collection_name=self.collection_name,
                filter=query_expr,
                output_fields=query_output_field,
            )
        except Exception as e:
            logger.info(f"Query by chunk id list failed: {e}")
            query_filter_results = [{"code": 500}]
        return query_filter_results
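
    # Query usage sketch (ids and keyword are illustrative):
    #   retriever.query_filter(doc_id=["id-1", "id-2"], filter_field="加班")
    #   retriever.query_chunk_id("chunk-1")
    #   retriever.query_chunk_id_list(["chunk-1", "chunk-2"])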

    def update_data(self, chunk_id, chunk):
        # Look up the existing record for this chunk_id.
        chunk_expr = f"chunk_id == '{chunk_id}'"
        chunk_output_fields = [
            "pk",
            "doc_id",
            "chunk_id",
            "metadata",
        ]
        try:
            chunk_results = self.client.query(
                collection_name=self.collection_name,
                filter=chunk_expr,
                output_fields=chunk_output_fields,
            )
            # logger.info(f"Record to update for {chunk_id}: {chunk_results}")
        except Exception as e:
            logger.error(f"Query failed while updating chunk data: {e}")
            return "update_query_error", ""
        if not chunk_results:
            logger.info(f"No record found for {chunk_id} in the vector store; cannot update")
            return "update_query_no_result", ""
        with torch.no_grad():
            embedding = self.embedding_function([chunk])
        if isinstance(embedding, dict) and "dense" in embedding:
            # BGE-style embedding functions return a dict keyed by vector type.
            dense_vec = embedding["dense"][0]
        else:
            dense_vec = embedding[0]
        chunk_dict = chunk_results[0]
        metadata = chunk_dict.get("metadata")
        old_chunk_len = metadata.get("chunk_len")
        chunk_len = len(chunk)
        metadata["chunk_len"] = chunk_len
        chunk_dict["content"] = chunk
        chunk_dict["dense_vector"] = dense_vec
        chunk_dict["metadata"] = metadata
        try:
            update_res = self.client.upsert(
                collection_name=self.collection_name, data=[chunk_dict]
            )
            logger.info(f"Upsert response: {update_res}")
            return "update_success", chunk_len - old_chunk_len
        except Exception as e:
            logger.error(f"Error while updating data: {e}")
            return "update_error", ""
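
    # Update usage sketch: returns a status string and the change in chunk
    # length (id and text are illustrative):
    #   status, length_delta = retriever.update_data("chunk-1", "new chunk text")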

    def delete_collection(self, collection):
        try:
            self.client.drop_collection(collection_name=collection)
            return "delete_collection_success"
        except Exception as e:
            logger.error(f"Failed to drop collection {collection}: {e}")
            return "delete_collection_error"

    def delete_by_chunk_id(self, chunk_id: str = None):
        # Look up the primary keys for the chunk id first; deletion then targets the pks.
        expr = f"chunk_id == '{chunk_id}'"
        try:
            results = self.client.query(
                collection_name=self.collection_name,
                filter=expr,
                output_fields=["pk", "metadata"],
            )  # fetch the primary keys
            logger.info(f"Records found for chunk id {chunk_id}: {results}")
        except Exception as e:
            logger.error(f"Failed to query primary keys by chunk id: {e}")
            return "delete_query_error", []
        if not results:
            logger.info(f"No data found for chunk id: {chunk_id}")
            # return "delete_no_result", []
            return "delete_success", [0]
        # Extract the primary key values.
        primary_keys = [result["pk"] for result in results]
        chunk_len = [result["metadata"]["chunk_len"] for result in results]
        logger.info(f"Primary keys to delete: {primary_keys}")
        # Perform the deletion.
        expr_delete = f"pk in {primary_keys}"  # build the delete expression
        try:
            delete_res = self.client.delete(
                collection_name=self.collection_name, filter=expr_delete
            )
            self.client.flush(collection_name=self.collection_name)
            self.client.compact(collection_name=self.collection_name)
            logger.info(f"Deleted data with chunk_id: {delete_res}")
            return "delete_success", chunk_len
        except Exception as e:
            logger.error(f"Failed to delete data: {e}")
            return "delete_error", []

    def delete_by_doc_id(self, doc_id: str = None):
        # Look up the primary keys for the doc id first; deletion then targets the pks.
        expr = f"doc_id == '{doc_id}'"
        try:
            results = self.client.query(
                collection_name=self.collection_name,
                filter=expr,
                output_fields=["pk"],
            )  # fetch the primary keys
        except Exception as e:
            logger.error(f"Failed to query primary keys by doc id: {e}")
            return "delete_query_error"
        if not results:
            logger.info(f"No data found for doc_id: {doc_id}")
            return "delete_no_result"
        # Extract the primary key values.
        primary_keys = [result["pk"] for result in results]
        logger.info(f"Primary keys to delete: {primary_keys}")
        # Perform the deletion.
        try:
            delete_res = self.client.delete(
                collection_name=self.collection_name, ids=primary_keys
            )
            self.client.flush(collection_name=self.collection_name)
            self.client.compact(collection_name=self.collection_name)
            logger.info(f"Deleted data with doc_id: {delete_res}")
            return "delete_success"
        except Exception as e:
            logger.error(f"Failed to delete data: {e}")
            return "delete_error"
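
    # Deletion usage sketch (ids are illustrative):
    #   status, chunk_lens = retriever.delete_by_chunk_id(chunk_id="chunk-1")
    #   status = retriever.delete_by_doc_id(doc_id="doc-1")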


# Test helpers
# def parse_json_to_schema_data_format():
#     sql_text_list = load_sql_query_ddl_info()
#     docs = []
#     for sql_info in sql_text_list:
#         sql_text = sql_info.get("sql_text")
#         source = ",".join(sql_info.get("table_list", []))
#         source = sql_info.get("source") if not source else source
#         sql = sql_info.get("sql")
#         ddl = sql_info.get("table_ddl")
#         metadata = {"source": source, "sql": sql, "ddl": ddl}
#         text_list = sql_info.get("sim_sql_text_list", [])
#         text_list.append(sql_text)
#         doc_id = str(uuid4())
#         for text in text_list:
#             chunk_id = str(uuid4())
#             insert_dict = {
#                 "content": text,
#                 "doc_id": doc_id,
#                 "chunk_id": chunk_id,
#                 "metadata": metadata,
#             }
#             docs.append(insert_dict)
#     return docs
#
# def insert_data_to_milvus(standard_retriever):
#     sql_dataset = parse_json_to_schema_data_format()
#     standard_retriever.build_collection()
#     for sql_dict in sql_dataset:
#         text = sql_dict["content"]
#         standard_retriever.insert_data(text, sql_dict)


def main():
    # dense_ef = BGEM3EmbeddingFunction()
    embedding_path = r"G:/work/code/models/multilingual-e5-large-instruct/"
    sentence_transformer_ef = model.dense.SentenceTransformerEmbeddingFunction(
        model_name=embedding_path, device=device
    )
    # Register the locally loaded embedding function so HybridRetriever can look
    # it up by name; __init__ takes an embedding_name, not an embedding function
    # (the original passed a nonexistent dense_embedding_function kwarg).
    embedding_mapping["e5"] = sentence_transformer_ef
    standard_retriever = HybridRetriever(
        uri="http://localhost:19530",
        embedding_name="e5",
        collection_name="milvus_hybrid",
    )
    # Insert data
    # insert_data_to_milvus(standard_retriever)
    # Search modes: hybrid retrieval "hybrid", sparse retrieval "sparse", dense retrieval "dense"
    # results = standard_retriever.search("查一下加班情况", mode="hybrid", k=3)
    # mode="sparse": sparse retrieval
    # print(f"Sparse retrieval results: {results}")
    # mode="dense": dense retrieval
    # print(f"Dense retrieval results: {results}")
    # mode="hybrid": hybrid retrieval
    # print(f"Hybrid retrieval results: {results}")
    # Delete data by doc id
    # delete_start_time = time.time()
    # doc_id = "e72ebf78-6c0d-410b-8fbb-9a2057673064"
    # standard_retriever.delete_by_doc_id(doc_id=doc_id)
    # delete_end_time = time.time()
    # print(f"Deletion took: {delete_end_time - delete_start_time}")
    # Update data by chunk_id
    # update_start_time = time.time()
    # chunk = "查询一下员工加班的情况"
    # chunk_id = "a5e8dded-f5a7-4a1f-92cd-82fa8113b418"
    # standard_retriever.update_data(chunk_id, chunk)
    # update_end_time = time.time()
    # print(f"Update took: {update_end_time - update_start_time}")
    # Query by doc id and keyword
    query_start_time = time.time()
    filter_field = "加班"
    doc_id = ["7b73ae0b-db97-4315-ba71-783fe7a69c61", "96bbe5a8-5fcf-4769-8343-938acb8735bd"]
    standard_retriever.query_filter(doc_id, filter_field)
    query_end_time = time.time()
    print(f"Keyword query took: {query_end_time - query_start_time}")


if __name__ == "__main__":
    main()