milvus_vector.py

import time
import threading

import numpy as np
import torch
from pymilvus import (
    MilvusClient,
    DataType,
    Function,
    FunctionType,
    AnnSearchRequest,
    RRFRanker,
)
# from pymilvus import model  # uncomment for the SentenceTransformer path in main()
from rag.load_model import qwen_ed_ef, bge_m3_ef
from utils.get_logger import setup_logger

device = "cuda" if torch.cuda.is_available() else "cpu"
logger = setup_logger(__name__)

# Global lock: serializes PyTorch/CUDA inference so that concurrent threads
# cannot trigger core dumps.
_embedding_lock = threading.Lock()

# Registry of available embedding functions, keyed by model name.
embedding_mapping = {
    "bge-m3": bge_m3_ef,
    "Qwen3-Embedding-0.6B": qwen_ed_ef,
}
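# Another dense model can be registered the same way. Illustrative sketch only:
# the model path below is a placeholder taken from the commented example in main().
# from pymilvus import model
# embedding_mapping["multilingual-e5-large-instruct"] = model.dense.SentenceTransformerEmbeddingFunction(
#     model_name=r"G:/work/code/models/multilingual-e5-large-instruct/", device=device
# )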
class HybridRetriever:
    def __init__(self, uri, embedding_name: str = "bge-m3", collection_name: str = "hybrid"):
        self.uri = uri
        self.collection_name = collection_name
        self.embedding_name = embedding_name
        logger.info(f"Using embedding model: {embedding_name}")
        self.embedding_function = embedding_mapping.get(embedding_name)
        if self.embedding_function is None:
            # Fail fast instead of crashing later on a None embedding function.
            raise ValueError(f"Unknown embedding model: {embedding_name}")
        self.use_reranker = True
        self.use_sparse = True
        self.client = MilvusClient(uri=uri)
    def has_collection(self):
        try:
            collection_flag = self.client.has_collection(self.collection_name)
            logger.info(f"Collection existence check result: {collection_flag}")
        except Exception as e:
            logger.error(f"Failed to check whether the collection exists: {e}")
            collection_flag = False
        return collection_flag
    def build_collection(self):
        # BGE-M3 reports its dimensions per vector type as a dict;
        # other models expose a single int.
        if isinstance(self.embedding_function.dim, dict):
            dense_dim = self.embedding_function.dim["dense"]
        else:
            dense_dim = self.embedding_function.dim
        logger.info(f"Dense vector dimension for the new collection: {dense_dim}")
        analyzer_params = {"type": "chinese"}
        schema = MilvusClient.create_schema()
        schema.add_field(
            field_name="pk",
            datatype=DataType.VARCHAR,
            is_primary=True,
            auto_id=True,
            max_length=100,
        )
        schema.add_field(
            field_name="content",
            datatype=DataType.VARCHAR,
            max_length=65535,
            analyzer_params=analyzer_params,
            enable_match=True,
            enable_analyzer=True,
        )
        schema.add_field(
            field_name="sparse_vector", datatype=DataType.SPARSE_FLOAT_VECTOR
        )
        schema.add_field(
            field_name="dense_vector", datatype=DataType.FLOAT_VECTOR, dim=dense_dim
        )
        schema.add_field(field_name="doc_id", datatype=DataType.VARCHAR, max_length=64)
        schema.add_field(field_name="chunk_id", datatype=DataType.VARCHAR, max_length=64)
        schema.add_field(field_name="metadata", datatype=DataType.JSON)
        schema.add_field(field_name="Chapter", datatype=DataType.VARCHAR, max_length=1024)
        schema.add_field(field_name="Father_Chapter", datatype=DataType.VARCHAR, max_length=1024)
        # BM25 function: Milvus derives the sparse vector from `content` at insert time.
        bm25_function = Function(
            name="bm25",
            function_type=FunctionType.BM25,
            input_field_names=["content"],
            output_field_names="sparse_vector",
        )
        schema.add_function(bm25_function)
        index_params = MilvusClient.prepare_index_params()
        index_params.add_index(
            field_name="sparse_vector",
            index_type="SPARSE_INVERTED_INDEX",
            metric_type="BM25",
        )
        index_params.add_index(
            field_name="dense_vector", index_type="FLAT", metric_type="IP"
        )
        try:
            self.client.create_collection(
                collection_name=self.collection_name,
                schema=schema,
                index_params=index_params,
            )
            return "create_collection_success"
        except Exception as e:
            logger.error(f"Failed to create collection {self.collection_name}: {e}")
            return "create_collection_error"
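    # Typical bootstrap (sketch; the URI and names are illustrative):
    #   retriever = HybridRetriever(uri="http://localhost:19530",
    #                               embedding_name="bge-m3",
    #                               collection_name="hybrid")
    #   if not retriever.has_collection():
    #       retriever.build_collection()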
    def insert_data(self, chunk, metadata):
        logger.info("Preparing to insert one record.")
        with _embedding_lock:
            with torch.no_grad():
                embedding = self.embedding_function([chunk])
        logger.info("Computed the embedding for the text.")
        if isinstance(embedding, dict) and "dense" in embedding:
            # BGE-M3 returns a dict; take the first dense vector.
            embedding = embedding["dense"][0]
        else:
            embedding = embedding[0]
        if embedding.dtype != np.float32:
            # Qwen-style models emit fp16 under acceleration; Milvus expects fp32.
            embedding = embedding.astype(np.float32)
        dense_vec = embedding
        # Drop fields the vector store does not persist.
        filtered_metadata = {k: v for k, v in metadata.items() if k not in ("bbox", "page")}
        try:
            self.client.insert(
                self.collection_name, {"dense_vector": dense_vec, **filtered_metadata}
            )
            logger.info("Inserted one record successfully.")
            return True, "success"
        except Exception as e:
            doc_id = metadata.get("doc_id")
            logger.error(f"Document {doc_id}: insert failed: {e}")
            # Roll back any partial inserts for this document.
            self.delete_by_doc_id(doc_id=doc_id)
            return False, str(e)
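    # Example call (sketch; field values are illustrative and must match the schema):
    #   retriever.insert_data(
    #       "chunk text",
    #       {"content": "chunk text", "doc_id": "d1", "chunk_id": "c1",
    #        "metadata": {"chunk_len": 10}, "Chapter": "", "Father_Chapter": "",
    #        "bbox": [], "page": 1},  # bbox/page are filtered out before insert
    #   )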
    def batch_insert_data(self, chunks, metadatas):
        logger.info("Preparing to insert a batch of records.")
        with _embedding_lock:
            embedding_lists = self.embedding_function.encode_documents(chunks)
        logger.info("Computed embeddings for the batch.")
        if isinstance(embedding_lists, dict) and "dense" in embedding_lists:
            # BGE-M3 returns {"dense": [...], ...}; keep only the dense vectors so the
            # zip below iterates vectors rather than dict keys.
            embedding_lists = embedding_lists["dense"]
        record_lists = []
        for embedding, metadata in zip(embedding_lists, metadatas):
            dense_vec = embedding.tolist() if hasattr(embedding, "tolist") else embedding
            # Drop fields the vector store does not persist.
            filtered_metadata = {k: v for k, v in metadata.items() if k not in ("bbox", "page")}
            record = {"dense_vector": dense_vec}
            record.update(filtered_metadata)
            record_lists.append(record)
        try:
            self.client.insert(self.collection_name, record_lists)
            logger.info("Batch insert succeeded.")
            return True, "success"
        except Exception as e:
            # Assumes all chunks in one batch share a doc_id; roll the document back.
            doc_id = metadatas[0].get("doc_id") if metadatas else None
            logger.error(f"Document {doc_id}: batch insert failed: {e}")
            self.delete_by_doc_id(doc_id=doc_id)
            return False, str(e)
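    # Batch variant of the same flow (sketch; a shared doc_id per batch is assumed):
    #   chunks = ["text 1", "text 2"]
    #   metadatas = [{"content": c, "doc_id": "d1", "chunk_id": f"c{i}",
    #                 "metadata": {"chunk_len": len(c)}, "Chapter": "", "Father_Chapter": ""}
    #                for i, c in enumerate(chunks)]
    #   retriever.batch_insert_data(chunks, metadatas)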
    def search(self, query: str, k: int = 20, mode="hybrid"):
        output_fields = [
            "content",
            "doc_id",
            "chunk_id",
            "metadata",
            "Father_Chapter",
            "Chapter",
        ]
        if mode in ["dense", "hybrid"]:
            with _embedding_lock:
                with torch.no_grad():
                    embedding = self.embedding_function([query])
            logger.info("Computed the embedding for the query.")
            if isinstance(embedding, dict) and "dense" in embedding:
                # BGE-M3 returns a dict; take the first dense vector.
                embedding = embedding["dense"][0]
            else:
                embedding = embedding[0]
            if embedding.dtype != np.float32:
                # Qwen-style models emit fp16 under acceleration; Milvus expects fp32.
                embedding = embedding.astype(np.float32)
            dense_vec = embedding
        if mode == "sparse":
            results = self.client.search(
                collection_name=self.collection_name,
                data=[query],
                anns_field="sparse_vector",
                limit=k,
                output_fields=output_fields,
            )
        elif mode == "dense":
            results = self.client.search(
                collection_name=self.collection_name,
                data=[dense_vec],
                anns_field="dense_vector",
                limit=k,
                output_fields=output_fields,
            )
        elif mode == "hybrid":
            full_text_search_params = {"metric_type": "BM25"}
            full_text_search_req = AnnSearchRequest(
                [query], "sparse_vector", full_text_search_params, limit=k
            )
            dense_search_params = {"metric_type": "IP"}
            dense_req = AnnSearchRequest(
                [dense_vec], "dense_vector", dense_search_params, limit=k
            )
            results = self.client.hybrid_search(
                self.collection_name,
                [full_text_search_req, dense_req],
                ranker=RRFRanker(),
                limit=k,
                output_fields=output_fields,
            )
        else:
            raise ValueError("Invalid mode")
        return [
            {
                "doc_id": doc["entity"]["doc_id"],
                "chunk_id": doc["entity"]["chunk_id"],
                "content": doc["entity"]["content"],
                "Father_Chapter": doc["entity"]["Father_Chapter"],
                "Chapter": doc["entity"]["Chapter"],
                "metadata": doc["entity"]["metadata"],
                "score": doc["distance"],
            }
            for doc in results[0]
        ]
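    # The three retrieval modes (sketch, mirroring the examples in main()):
    #   retriever.search("查一下加班情况", k=3, mode="sparse")  # BM25 full-text only
    #   retriever.search("查一下加班情况", k=3, mode="dense")   # dense vectors only
    #   retriever.search("查一下加班情况", k=3, mode="hybrid")  # both, fused with RRF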
    def query_filter(self, doc_id, filter_field):
        # Query within one document: if a keyword is given, match chunks whose
        # content contains it; otherwise return all chunks for the doc_id.
        query_output_field = [
            "content",
            "chunk_id",
            "doc_id",
            "metadata",
        ]
        if filter_field:
            query_expr = f"doc_id == '{doc_id}' && content like '%{filter_field}%'"
        else:
            query_expr = f"doc_id == '{doc_id}'"
        try:
            query_filter_results = self.client.query(
                collection_name=self.collection_name,
                filter=query_expr,
                output_fields=query_output_field,
            )
        except Exception as e:
            logger.error(f"Keyword-filtered query failed: {e}")
            query_filter_results = [{"code": 500}]
        return query_filter_results
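    # Example (sketch; the doc_id is illustrative):
    #   retriever.query_filter("7b73ae0b-db97-4315-ba71-783fe7a69c61", "加班")
    #   retriever.query_filter("7b73ae0b-db97-4315-ba71-783fe7a69c61", "")  # all chunks of the doc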
    def query_chunk_id(self, chunk_id):
        # Fetch a single chunk by its chunk_id.
        query_output_field = [
            "content",
            "doc_id",
            "chunk_id",
        ]
        query_expr = f"chunk_id == '{chunk_id}'"
        try:
            query_filter_results = self.client.query(
                collection_name=self.collection_name,
                filter=query_expr,
                output_fields=query_output_field,
            )
        except Exception as e:
            logger.error(f"Query by chunk_id failed: {e}")
            query_filter_results = [{"code": 500}]
        return query_filter_results

    def query_chunk_id_list(self, chunk_id_list):
        # Fetch multiple chunks by a list of chunk_ids.
        query_output_field = [
            "content",
            "doc_id",
            "chunk_id",
        ]
        query_expr = f"chunk_id in {chunk_id_list}"
        try:
            query_filter_results = self.client.query(
                collection_name=self.collection_name,
                filter=query_expr,
                output_fields=query_output_field,
            )
        except Exception as e:
            logger.error(f"Query by chunk_id list failed: {e}")
            query_filter_results = [{"code": 500}]
        return query_filter_results
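    # Example (sketch; the chunk_id is illustrative): single vs. list lookup.
    #   retriever.query_chunk_id("a5e8dded-f5a7-4a1f-92cd-82fa8113b418")
    #   retriever.query_chunk_id_list(["a5e8dded-f5a7-4a1f-92cd-82fa8113b418"])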
    def update_data(self, chunk_id, chunk):
        # Look up the existing record for this chunk_id.
        chunk_expr = f"chunk_id == '{chunk_id}'"
        chunk_output_fields = [
            "pk",
            "doc_id",
            "chunk_id",
            "metadata",
            "Father_Chapter",
            "Chapter",
        ]
        try:
            chunk_results = self.client.query(
                collection_name=self.collection_name,
                filter=chunk_expr,
                output_fields=chunk_output_fields,
            )
        except Exception as e:
            logger.error(f"Lookup before update failed: {e}")
            return "update_query_error", ""
        if not chunk_results:
            logger.info(f"No record found for chunk_id {chunk_id}; nothing to update.")
            return "update_query_no_result", ""
        # Hold the lock while embedding to avoid concurrent CUDA access.
        with _embedding_lock:
            with torch.no_grad():
                embedding = self.embedding_function([chunk])
        logger.info("Computed the embedding for the text.")
        if isinstance(embedding, dict) and "dense" in embedding:
            # BGE-M3 returns a dict; take the first dense vector.
            embedding = embedding["dense"][0]
        else:
            embedding = embedding[0]
        if embedding.dtype != np.float32:
            # Qwen-style models emit fp16 under acceleration; Milvus expects fp32.
            embedding = embedding.astype(np.float32)
        dense_vec = embedding
        chunk_dict = chunk_results[0]
        logger.info(f"update_data found pk: {chunk_dict.get('pk')}")
        metadata = chunk_dict.get("metadata")
        old_chunk_len = metadata.get("chunk_len") or 0
        chunk_len = len(chunk)
        metadata["chunk_len"] = chunk_len
        chunk_dict["content"] = chunk
        chunk_dict["dense_vector"] = dense_vec
        chunk_dict["metadata"] = metadata
        try:
            update_res = self.client.upsert(collection_name=self.collection_name, data=[chunk_dict])
            logger.info(f"Upsert response: {update_res}")
            return "update_success", chunk_len - old_chunk_len
        except Exception as e:
            logger.error(f"Update failed: {e}")
            return "update_error", ""
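    # Example (sketch, mirroring main(); returns a status string and the change
    # in chunk length):
    #   status, len_delta = retriever.update_data(
    #       "a5e8dded-f5a7-4a1f-92cd-82fa8113b418", "查询一下员工加班的情况")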
    def delete_collection(self, collection):
        try:
            self.client.drop_collection(collection_name=collection)
            return "delete_collection_success"
        except Exception as e:
            logger.error(f"Failed to drop collection {collection}: {e}")
            return "delete_collection_error"
    def delete_by_chunk_id(self, chunk_id: str = None):
        # Milvus deletes by primary key, so first resolve the pk(s) for this chunk_id.
        expr = f"chunk_id == '{chunk_id}'"
        try:
            results = self.client.query(
                collection_name=self.collection_name,
                filter=expr,
                output_fields=["pk", "metadata"],
            )
            logger.info(f"Records found for chunk_id {chunk_id}: {results}")
        except Exception as e:
            logger.error(f"Primary-key lookup by chunk_id failed: {e}")
            return "delete_query_error", []
        if not results:
            logger.warning(f"No data found for chunk_id: {chunk_id}")
            return "delete_success", [0]
        primary_keys = [result["pk"] for result in results]
        chunk_len = [result["metadata"]["chunk_len"] for result in results]
        logger.info(f"Primary keys to delete: {primary_keys}")
        # Delete via an IN expression on the primary key.
        expr_delete = f"pk in {primary_keys}"
        try:
            delete_res = self.client.delete(collection_name=self.collection_name, filter=expr_delete)
            self.client.flush(collection_name=self.collection_name)
            self.client.compact(collection_name=self.collection_name)
            logger.info(f"Deleted data with chunk_id: {delete_res}")
            return "delete_success", chunk_len
        except Exception as e:
            logger.error(f"Delete failed: {e}")
            return "delete_error", []
    def batch_delete_by_chunk_ids(self, chunk_ids: list):
        """
        Delete multiple chunks in one pass.
        Args:
            chunk_ids: list of chunk IDs.
        Returns:
            (status, chunk_lens_dict): a status string and a {chunk_id: chunk_len} dict.
        """
        if not chunk_ids:
            return "delete_success", {}
        # Build one IN expression for the whole batch.
        chunk_ids_str = ", ".join([f"'{cid}'" for cid in chunk_ids])
        expr = f"chunk_id in [{chunk_ids_str}]"
        try:
            # Resolve primary keys and metadata for all chunks at once.
            results = self.client.query(
                collection_name=self.collection_name,
                filter=expr,
                output_fields=["pk", "metadata", "chunk_id"],
            )
            logger.info(f"Batch query matched {len(results)} chunks")
        except Exception as e:
            logger.error(f"Batch chunk query failed: {e}")
            return "delete_query_error", {}
        if not results:
            logger.warning("No matching chunks found")
            return "delete_success", {cid: 0 for cid in chunk_ids}
        primary_keys = [result["pk"] for result in results]
        chunk_lens_dict = {result["chunk_id"]: result["metadata"]["chunk_len"] for result in results}
        logger.info(f"Primary keys to delete: {primary_keys}")
        # One delete call for the whole batch via an IN expression on the primary key.
        expr_delete = f"pk in {primary_keys}"
        try:
            delete_res = self.client.delete(
                collection_name=self.collection_name,
                filter=expr_delete,
            )
            self.client.flush(collection_name=self.collection_name)
            self.client.compact(collection_name=self.collection_name)
            logger.info(f"Batch delete succeeded: {len(primary_keys)} chunks, result: {delete_res}")
            return "delete_success", chunk_lens_dict
        except Exception as e:
            logger.error(f"Batch delete failed: {e}")
            return "delete_error", {}
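    # Example (sketch; IDs are illustrative): single vs. batch deletion, both
    # returning a status plus chunk-length bookkeeping.
    #   status, lens = retriever.delete_by_chunk_id("a5e8dded-f5a7-4a1f-92cd-82fa8113b418")
    #   status, lens_by_id = retriever.batch_delete_by_chunk_ids(["c1", "c2"])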
    def query_by_doc_ids(self, doc_ids):
        """
        Fetch all records for a list of doc_ids.
        Args:
            doc_ids: list of document IDs.
        Returns:
            list of query results including all fields.
        """
        if not doc_ids:
            logger.warning("doc_ids is empty")
            return []
        doc_ids_str = ", ".join([f"'{doc_id}'" for doc_id in doc_ids])
        query_expr = f"doc_id in [{doc_ids_str}]"
        query_output_fields = [
            "pk",
            "content",
            "dense_vector",
            "doc_id",
            "chunk_id",
            "metadata",
            "Father_Chapter",
            "Chapter",
        ]
        try:
            query_results = self.client.query(
                collection_name=self.collection_name,
                filter=query_expr,
                output_fields=query_output_fields,
            )
            logger.info(f"Query by doc_ids returned {len(query_results)} records")
            return query_results
        except Exception as e:
            logger.error(f"Query by doc_ids failed: {e}")
            return []
    def delete_by_doc_id(self, doc_id: str = None):
        # Milvus deletes by primary key, so first resolve the pk values for this doc_id.
        expr = f"doc_id == '{doc_id}'"
        try:
            results = self.client.query(
                collection_name=self.collection_name, filter=expr, output_fields=["pk"]
            )
        except Exception as e:
            logger.error(f"Primary-key lookup by doc_id failed: {e}")
            return "delete_query_error"
        if not results:
            logger.warning(f"No data found for doc_id: {doc_id}")
            return "delete_no_result"
        primary_keys = [result["pk"] for result in results]
        logger.info(f"Primary keys to delete: {primary_keys}")
        try:
            delete_res = self.client.delete(collection_name=self.collection_name, ids=primary_keys)
            # Optional: flush persists the delete, compact reclaims disk space.
            self.client.flush(collection_name=self.collection_name)
            self.client.compact(collection_name=self.collection_name)
            logger.info(f"Deleted data with doc_id: {delete_res}")
            return "delete_success"
        except Exception as e:
            logger.error(f"Delete failed: {e}")
            return "delete_error"
    def query_by_scalar_field(self, doc_id: str, field_name: str, field_value: str):
        """
        Query by a scalar field (e.g. Father_Chapter) within one document.
        Args:
            doc_id: document ID.
            field_name: field name (e.g. Father_Chapter).
            field_value: field value to match.
        Returns:
            list of results in the same format as search().
        """
        output_fields = [
            "content",
            "doc_id",
            "chunk_id",
            "metadata",
            "Father_Chapter",
            "Chapter",
        ]
        query_expr = f"doc_id == '{doc_id}' && {field_name} == '{field_value}'"
        try:
            query_results = self.client.query(
                collection_name=self.collection_name,
                filter=query_expr,
                output_fields=output_fields,
            )
            return [
                {
                    "doc_id": doc["doc_id"],
                    "chunk_id": doc["chunk_id"],
                    "content": doc["content"],
                    "Father_Chapter": doc["Father_Chapter"],
                    "Chapter": doc["Chapter"],
                    "metadata": doc["metadata"],
                    "score": 0,
                }
                for doc in query_results
            ]
        except Exception as e:
            logger.error(f"Scalar-field query failed: {e}")
            return []
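    # Example (sketch; the doc_id and chapter title are illustrative):
    #   retriever.query_by_scalar_field("d1", "Father_Chapter", "Chapter 1")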
    def update_dense_vector_by_chunk_id(self, chunk_id: str, enhanced_text: str):
        """
        Replace a chunk's dense_vector with the embedding of enhanced text.
        Args:
            chunk_id: chunk ID.
            enhanced_text: enhanced text (body + qa + question + summary).
        Returns:
            (success, message)
        """
        # 1. Look up the existing record for this chunk_id.
        chunk_expr = f"chunk_id == '{chunk_id}'"
        chunk_output_fields = ["pk", "content", "doc_id", "chunk_id", "metadata", "Chapter", "Father_Chapter"]
        try:
            chunk_results = self.client.query(
                collection_name=self.collection_name,
                filter=chunk_expr,
                output_fields=chunk_output_fields,
            )
        except Exception as e:
            logger.error(f"Lookup before vector update failed: {e}")
            return False, f"query failed: {e}"
        if not chunk_results:
            logger.warning(f"No record found for chunk_id {chunk_id}")
            return False, "chunk not found"
        # 2. Embed the enhanced text.
        with _embedding_lock:
            with torch.no_grad():
                embedding = self.embedding_function([enhanced_text])
        logger.info("Computed the embedding for the text.")
        if isinstance(embedding, dict) and "dense" in embedding:
            # BGE-M3 returns a dict; take the first dense vector.
            embedding = embedding["dense"][0]
        else:
            embedding = embedding[0]
        if embedding.dtype != np.float32:
            # Qwen-style models emit fp16 under acceleration; Milvus expects fp32.
            embedding = embedding.astype(np.float32)
        dense_vec = embedding
        # 3. Upsert the record with the new vector; stored content stays unchanged.
        chunk_dict = chunk_results[0]
        logger.info(f"Found pk: {chunk_dict.get('pk')}")
        chunk_dict["dense_vector"] = dense_vec
        try:
            self.client.upsert(collection_name=self.collection_name, data=[chunk_dict])
            logger.info(f"Updated the vector for chunk {chunk_id}")
            return True, "success"
        except Exception as e:
            logger.error(f"Vector update failed: {e}")
            return False, str(e)
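    # Example (sketch; IDs and text are illustrative): re-embed a chunk using
    # enhanced text without changing its stored content.
    #   ok, msg = retriever.update_dense_vector_by_chunk_id("c1", "body + qa + summary text")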
# Test helpers (left commented out: they depend on load_sql_query_ddl_info and
# uuid4, which are not imported in this module).
# def parse_json_to_schema_data_format():
#     sql_text_list = load_sql_query_ddl_info()
#     docs = []
#     for sql_info in sql_text_list:
#         sql_text = sql_info.get("sql_text")
#         source = ",".join(sql_info.get("table_list", []))
#         source = sql_info.get("source") if not source else source
#         sql = sql_info.get("sql")
#         ddl = sql_info.get("table_ddl")
#         metadata = {"source": source, "sql": sql, "ddl": ddl}
#         text_list = sql_info.get("sim_sql_text_list", [])
#         text_list.append(sql_text)
#         doc_id = str(uuid4())
#         for text in text_list:
#             chunk_id = str(uuid4())
#             insert_dict = {
#                 "content": text,
#                 "doc_id": doc_id,
#                 "chunk_id": chunk_id,
#                 "metadata": metadata,
#             }
#             docs.append(insert_dict)
#     return docs

# def insert_data_to_milvus(standard_retriever):
#     sql_dataset = parse_json_to_schema_data_format()
#     standard_retriever.build_collection()
#     for sql_dict in sql_dataset:
#         text = sql_dict["content"]
#         standard_retriever.insert_data(text, sql_dict)
def main():
    # Optional SentenceTransformer-based embedding (requires `from pymilvus import model`):
    # embedding_path = r"G:/work/code/models/multilingual-e5-large-instruct/"
    # sentence_transformer_ef = model.dense.SentenceTransformerEmbeddingFunction(
    #     model_name=embedding_path, device=device
    # )
    standard_retriever = HybridRetriever(
        uri="http://10.1.14.18:19530",
        embedding_name="bge-m3",  # any key from embedding_mapping
        collection_name="milvus_hybrid",
    )
    # Insert data
    # insert_data_to_milvus(standard_retriever)
    # Search; modes: hybrid (fused), sparse (BM25), dense (vectors)
    # results = standard_retriever.search("查一下加班情况", mode="hybrid", k=3)
    # print(f"Search results: {results}")
    # Delete by doc_id
    # delete_start_time = time.time()
    # doc_id = "e72ebf78-6c0d-410b-8fbb-9a2057673064"
    # standard_retriever.delete_by_doc_id(doc_id=doc_id)
    # delete_end_time = time.time()
    # print(f"Delete took: {delete_end_time - delete_start_time}s")
    # Update by chunk_id
    # update_start_time = time.time()
    # chunk = "查询一下员工加班的情况"
    # chunk_id = "a5e8dded-f5a7-4a1f-92cd-82fa8113b418"
    # standard_retriever.update_data(chunk_id, chunk)
    # update_end_time = time.time()
    # print(f"Update took: {update_end_time - update_start_time}s")
    # Query by doc_id and keyword
    query_start_time = time.time()
    filter_field = "加班"
    doc_ids = ["7b73ae0b-db97-4315-ba71-783fe7a69c61", "96bbe5a8-5fcf-4769-8343-938acb8735bd"]
    for doc_id in doc_ids:
        # query_filter expects a single doc_id per call
        standard_retriever.query_filter(doc_id, filter_field)
    query_end_time = time.time()
    print(f"Keyword query took: {query_end_time - query_start_time}s")


if __name__ == "__main__":
    main()