slice_metadata.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291
  1. # 切片元数据生成相关业务逻辑
  2. import json
  3. import re
  4. from utils.get_logger import setup_logger
  5. from rag.db import MilvusOperate, MysqlOperate
  6. logger = setup_logger(__name__)
  7. async def process_single_slice_metadata(
  8. knowledge_id: str,
  9. slice_id: str,
  10. slice_text: str,
  11. enable_qa: bool,
  12. enable_question: bool,
  13. enable_summary: bool,
  14. embedding_id: str = "e5",
  15. separator_num: str = "-1"
  16. ) -> tuple:
  17. """
  18. 处理单个切片的元数据生成(供 update_slice 和 insert_slice 调用)
  19. 返回: (success: bool, message: str)
  20. qa、question、summary
  21. """
  22. mysql_client = MysqlOperate()
  23. milvus_client = MilvusOperate(collection_name=knowledge_id, embedding_name=embedding_id)
  24. # 若三者都不启用,直接将 slice_text 写入向量库
  25. if not any([enable_qa, enable_question, enable_summary]):
  26. vector_success, vector_msg = milvus_client.hybrid_retriever.update_dense_vector_by_chunk_id(
  27. slice_id, slice_text
  28. )
  29. return vector_success, vector_msg
  30. # 动态构建提示词
  31. from rag.llm import VllmApi
  32. vllm_client = VllmApi({})
  33. json_fields = []
  34. requirements = []
  35. if enable_qa and separator_num != "3":
  36. json_fields.append(' "qa": [{"question": "问题1", "answer": "答案1"}, {"question": "问题2", "answer": "答案2"}, {"question": "问题3", "answer": "答案3"}]')
  37. requirements.append("qa问答对必须能从正文中直接找到答案")
  38. elif enable_qa:
  39. json_fields.append(' "qa": [{"question": "问题1", "answer": "答案1"}]')
  40. requirements.append("qa问答对必须能从正文中直接找到答案")
  41. if enable_question:
  42. json_fields.append(' "question": "根据正文生成的一个核心问题"')
  43. requirements.append("question必须是正文能够回答的问题")
  44. if enable_summary:
  45. json_fields.append(' "summary": "正文的简洁摘要(100字以内)"')
  46. requirements.append("summary必须是对正文的准确概括")
  47. json_template = "{\n" + ",\n".join(json_fields) + "\n}"
  48. requirements_text = "\n".join([f"{i+1}. {r}" for i, r in enumerate(requirements)])
  49. requirements_text += f"\n{len(requirements)+1}. 只输出JSON,不要有其他内容"
  50. prompt = fr"""请根据以下正文内容,严格基于正文信息生成指定字段,不要编造任何正文中不存在的内容,无论输入内容的长短都必须按照下面要求输出。
  51. 【重要JSON格式要求】
  52. - 必须输出标准JSON格式,使用英文双引号(")作为字符串分隔符
  53. - qa问答尽量简短精炼
  54. - 正文中涉及的所有转义内容必须严格遵循 JSON 标准转义规则(如 \"、\\、\n、\t、\uXXXX),禁止使用任何非 JSON 规范的转义表达(如 \circ、\alpha 等)
  55. 【正文内容】
  56. {slice_text}
  57. 请按照以下JSON格式输出:
  58. {json_template}
  59. 要求:
  60. {requirements_text}"""
  61. try:
  62. model = "/opt/vllm/models/Qwen/Qwen3-30B-A3B-Instruct-2507"
  63. llm_result = ""
  64. async for chunk in vllm_client.chat(prompt, model, temperature=0.3, top_p=0.7, max_tokens=2048, stream=False):
  65. llm_result = chunk.get("data", "")
  66. break
  67. logger.info(f"LLM生成结果:{llm_result}")
  68. json_match = re.search(r'\{[\s\S]*\}', llm_result)
  69. if not json_match:
  70. return False, "LLM返回格式错误"
  71. result_json = json.loads(json_match.group())
  72. qa_str = json.dumps(result_json.get("qa", []), ensure_ascii=False) if enable_qa else None
  73. question_str = result_json.get("question", "") if enable_question else None
  74. summary_str = result_json.get("summary", "") if enable_summary else None
  75. # 更新 MySQL
  76. update_success, update_msg = mysql_client.update_slice_llm_fields(
  77. knowledge_id, slice_id, qa_str, question_str, summary_str
  78. )
  79. if not update_success:
  80. return False, f"更新MySQL失败: {update_msg}"
  81. # 更新向量库
  82. enhanced_parts = [slice_text]
  83. if enable_question and question_str:
  84. enhanced_parts.append(f"问题:{question_str}")
  85. if enable_summary and summary_str:
  86. enhanced_parts.append(f"摘要:{summary_str}")
  87. if enable_qa and qa_str:
  88. enhanced_parts.append(f"QA问答:{qa_str}")
  89. enhanced_text = "\n\n".join(enhanced_parts)
  90. vector_success, vector_msg = milvus_client.hybrid_retriever.update_dense_vector_by_chunk_id(
  91. slice_id, enhanced_text
  92. )
  93. return vector_success, vector_msg if not vector_success else "success"
  94. except Exception as e:
  95. logger.error(f"处理切片 {slice_id} 元数据出错: {e}")
  96. return False, str(e)
  97. async def generate_slice_metadata(request_json: dict) -> dict:
  98. """
  99. 根据 knowledge_id 和 document_id 生成切片的 qa、question、summary 字段,
  100. 并更新 MySQL 和向量库。若三者都不启用,则直接将 slice_text 写入向量库。
  101. 参数:
  102. - knowledge_id: 知识库ID
  103. - document_id: 文档ID(若有slice_id则可选)
  104. - slice_id: 切片ID(可选,若提供则只处理该切片)
  105. - embedding_id: 向量模型ID(可选,默认 e5)
  106. - qa: 是否生成QA问答对(默认False)
  107. - question: 是否生成核心问题(默认False)
  108. - summary: 是否生成摘要(默认False)
  109. 返回:
  110. - dict: 操作结果
  111. """
  112. knowledge_id = request_json.get("knowledge_id")
  113. document_id = request_json.get("document_id")
  114. slice_id = request_json.get("slice_id")
  115. embedding_id = request_json.get("embedding_id", "e5")
  116. enable_qa = request_json.get("qa", False)
  117. enable_question = request_json.get("question", False)
  118. enable_summary = request_json.get("summary", False)
  119. if not knowledge_id:
  120. return {"code": 400, "message": "knowledge_id 不能为空"}
  121. if not slice_id and not document_id:
  122. return {"code": 400, "message": "document_id 和 slice_id 至少需要提供一个"}
  123. logger.info(f"生成切片元数据 - 知识库: {knowledge_id}, 文档: {document_id}, 切片: {slice_id}, qa={enable_qa}, question={enable_question}, summary={enable_summary}")
  124. mysql_client = MysqlOperate()
  125. milvus_client = MilvusOperate(collection_name=knowledge_id, embedding_name=embedding_id)
  126. # 1. 查询切片数据
  127. if slice_id:
  128. success, slice_data = mysql_client.query_slice_by_id(knowledge_id, slice_id)
  129. else:
  130. success, slice_data = mysql_client.query_slice_by_knowledge_and_doc(knowledge_id, document_id)
  131. if not success:
  132. return {"code": 500, "message": f"查询切片数据失败: {slice_data}"}
  133. if not slice_data:
  134. return {"code": 404, "message": "未找到切片数据"}
  135. logger.info(f"查询到 {len(slice_data)} 条切片数据")
  136. # 2. 为每个切片生成 qa、question、summary
  137. from rag.llm import VllmApi
  138. vllm_client = VllmApi({})
  139. success_count = 0
  140. fail_count = 0
  141. for slice_item in slice_data:
  142. current_slice_id = slice_item["slice_id"]
  143. slice_text = slice_item["slice_text"]
  144. # 若三者都不启用,直接将 slice_text 写入向量库并跳过
  145. if not any([enable_qa, enable_question, enable_summary]):
  146. vector_success, vector_msg = milvus_client.hybrid_retriever.update_dense_vector_by_chunk_id(
  147. current_slice_id, slice_text
  148. )
  149. if vector_success:
  150. success_count += 1
  151. logger.info(f"切片 {current_slice_id} 直接写入向量库成功")
  152. else:
  153. fail_count += 1
  154. logger.warning(f"切片 {current_slice_id} 写入向量库失败: {vector_msg}")
  155. continue
  156. # 动态构建提示词(三指标,有则添加无则略过)
  157. json_fields = []
  158. requirements = []
  159. if enable_qa:
  160. json_fields.append(' "qa": [{"question": "问题1", "answer": "答案1"}, {"question": "问题2", "answer": "答案2"}, {"question": "问题3", "answer": "答案3"}]')
  161. requirements.append("qa问答对必须能从正文中直接找到答案")
  162. if enable_question:
  163. json_fields.append(' "question": "根据正文生成的一个核心问题"')
  164. requirements.append("question必须是正文能够回答的问题")
  165. if enable_summary:
  166. json_fields.append(' "summary": "正文的简洁摘要(100字以内)"')
  167. requirements.append("summary必须是对正文的准确概括")
  168. # 组织Json模板
  169. json_template = "{\n" + ",\n".join(json_fields) + "\n}"
  170. requirements_text = "\n".join([f"{i+1}. {r}" for i, r in enumerate(requirements)])
  171. requirements_text += f"\n{len(requirements)+1}. 只输出JSON,不要有其他内容"
  172. prompt = f"""请根据以下正文内容,严格基于正文信息生成指定字段,不要编造任何正文中不存在的内容。
  173. 【重要JSON格式要求】
  174. - 必须输出标准JSON格式,使用英文双引号(")作为字符串分隔符
  175. - 如果正文中包含引号内容(如"信息"),在JSON值中应写成转义形式(如\"信息\")或用其他方式表达
  176. 【正文内容】
  177. {slice_text}
  178. 请按照以下JSON格式输出:
  179. {json_template}
  180. 要求:
  181. {requirements_text}"""
  182. try:
  183. # 调用 LLM
  184. model = "/opt/vllm/models/Qwen/Qwen3-30B-A3B-Instruct-2507"
  185. llm_result = ""
  186. async for chunk in vllm_client.chat(prompt, model, temperature=0.3, top_p=0.7, max_tokens=2048, stream=False):
  187. llm_result = chunk
  188. break
  189. logger.info(f"LLM生成的:{llm_result}")
  190. # 解析 LLM 结果(提取JSON结果)
  191. json_match = re.search(r'\{[\s\S]*\}', llm_result)
  192. if json_match:
  193. result_json = json.loads(json_match.group())
  194. qa_str = json.dumps(result_json.get("qa", []), ensure_ascii=False) if enable_qa else None
  195. question_str = result_json.get("question", "") if enable_question else None
  196. summary_str = result_json.get("summary", "") if enable_summary else None
  197. else:
  198. logger.warning(f"切片 {current_slice_id} LLM返回格式错误")
  199. fail_count += 1
  200. continue
  201. # 3. 更新 MySQL - 只更新启用的字段
  202. update_success, update_msg = mysql_client.update_slice_llm_fields(
  203. knowledge_id, current_slice_id, qa_str, question_str, summary_str
  204. )
  205. if not update_success:
  206. logger.error(f"更新切片 {current_slice_id} MySQL失败: {update_msg}")
  207. fail_count += 1
  208. continue
  209. # 4. 更新向量库 - 只拼接启用的字段
  210. enhanced_parts = [slice_text]
  211. if enable_question and question_str:
  212. enhanced_parts.append(f"问题:{question_str}")
  213. if enable_summary and summary_str:
  214. enhanced_parts.append(f"摘要:{summary_str}")
  215. if enable_qa and qa_str:
  216. enhanced_parts.append(f"QA问答:{qa_str}")
  217. enhanced_text = "\n\n".join(enhanced_parts)
  218. vector_success, vector_msg = milvus_client.hybrid_retriever.update_dense_vector_by_chunk_id(
  219. current_slice_id, enhanced_text
  220. )
  221. if not vector_success:
  222. logger.warning(f"更新切片 {current_slice_id} 向量失败: {vector_msg}")
  223. success_count += 1
  224. logger.info(f"切片 {current_slice_id} 处理成功")
  225. except Exception as e:
  226. logger.error(f"处理切片 {current_slice_id} 出错: {e}")
  227. fail_count += 1
  228. continue
  229. result = {
  230. "code": 200,
  231. "message": "处理完成",
  232. "data": {
  233. "total": len(slice_data),
  234. "success": success_count,
  235. "failed": fail_count
  236. }
  237. }
  238. logger.info(f"生成切片元数据结果: {result}")
  239. return result