| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291 |
- # 切片元数据生成相关业务逻辑
- import json
- import re
- from utils.get_logger import setup_logger
- from rag.db import MilvusOperate, MysqlOperate
- logger = setup_logger(__name__)
async def process_single_slice_metadata(
    knowledge_id: str,
    slice_id: str,
    slice_text: str,
    enable_qa: bool,
    enable_question: bool,
    enable_summary: bool,
    embedding_id: str = "e5",
    separator_num: str = "-1"
) -> tuple:
    """
    Generate LLM metadata (qa / question / summary) for one slice and persist it.

    Per the original note, this is meant to be called by update_slice and
    insert_slice.

    Flow:
      * If none of the three flags is enabled, the raw slice text is written to
        the vector store and that call's (success, message) is returned as-is.
      * Otherwise a prompt is built from the enabled fields, the LLM is called
        once, the first "{...}" span of its reply is parsed as JSON, the enabled
        fields are written to MySQL, and the slice text plus the non-empty
        generated fields is written to the vector store.

    Args:
        knowledge_id: Knowledge base id; also used as the Milvus collection name.
        slice_id: Id of the slice (chunk) to update.
        slice_text: Body text of the slice.
        enable_qa: Whether to request QA pairs.
        enable_question: Whether to request a core question.
        enable_summary: Whether to request a summary.
        embedding_id: Embedding name handed to MilvusOperate.
        separator_num: When equal to "3", the QA template asks for a single
            QA pair instead of three.

    Returns:
        (success: bool, message: str). On the LLM path the message is "success"
        when the vector update succeeds, otherwise the failure message. Any
        exception raised after the prompt is built is logged and returned as
        (False, str(exception)).
    """
    mysql_client = MysqlOperate()
    milvus_client = MilvusOperate(collection_name=knowledge_id, embedding_name=embedding_id)

    # Nothing to generate: store the plain slice text in the vector store.
    if not any([enable_qa, enable_question, enable_summary]):
        vector_success, vector_msg = milvus_client.hybrid_retriever.update_dense_vector_by_chunk_id(
            slice_id, slice_text
        )
        return vector_success, vector_msg

    # Build the prompt from the enabled fields only.
    from rag.llm import VllmApi
    vllm_client = VllmApi({})

    json_fields = []
    requirements = []
    if enable_qa and separator_num != "3":
        json_fields.append(' "qa": [{"question": "问题1", "answer": "答案1"}, {"question": "问题2", "answer": "答案2"}, {"question": "问题3", "answer": "答案3"}]')
        requirements.append("qa问答对必须能从正文中直接找到答案")
    elif enable_qa:
        json_fields.append(' "qa": [{"question": "问题1", "answer": "答案1"}]')
        requirements.append("qa问答对必须能从正文中直接找到答案")
    if enable_question:
        json_fields.append(' "question": "根据正文生成的一个核心问题"')
        requirements.append("question必须是正文能够回答的问题")
    if enable_summary:
        json_fields.append(' "summary": "正文的简洁摘要(100字以内)"')
        requirements.append("summary必须是对正文的准确概括")

    json_template = "{\n" + ",\n".join(json_fields) + "\n}"
    requirements_text = "\n".join([f"{i+1}. {r}" for i, r in enumerate(requirements)])
    requirements_text += f"\n{len(requirements)+1}. 只输出JSON,不要有其他内容"

    # Raw f-string: the escape examples must reach the model with their backslashes.
    prompt = fr"""请根据以下正文内容,严格基于正文信息生成指定字段,不要编造任何正文中不存在的内容,无论输入内容的长短都必须按照下面要求输出。
【重要JSON格式要求】
- 必须输出标准JSON格式,使用英文双引号(")作为字符串分隔符
- qa问答尽量简短精炼
- 正文中涉及的所有转义内容必须严格遵循 JSON 标准转义规则(如 \"、\\、\n、\t、\uXXXX),禁止使用任何非 JSON 规范的转义表达(如 \circ、\alpha 等)
【正文内容】
{slice_text}
请按照以下JSON格式输出:
{json_template}
要求:
{requirements_text}"""
    try:
        model = "/opt/vllm/models/Qwen/Qwen3-30B-A3B-Instruct-2507"
        llm_result = ""
        # Only the first yielded chunk is consumed (stream=False).
        async for chunk in vllm_client.chat(prompt, model, temperature=0.3, top_p=0.7, max_tokens=2048, stream=False):
            # Accept both a {"data": text} payload and a bare text chunk.
            llm_result = chunk.get("data", "") if isinstance(chunk, dict) else chunk
            break

        logger.info(f"LLM生成结果:{llm_result}")

        # A non-string payload (e.g. data=None) cannot contain JSON text; report
        # it as a format error instead of letting re.search raise TypeError.
        if not isinstance(llm_result, str):
            return False, "LLM返回格式错误"

        # Greedy match from the first "{" to the last "}" in the reply.
        json_match = re.search(r'\{[\s\S]*\}', llm_result)
        if not json_match:
            return False, "LLM返回格式错误"

        result_json = json.loads(json_match.group())
        qa_items = result_json.get("qa", []) if enable_qa else None
        qa_str = json.dumps(qa_items, ensure_ascii=False) if enable_qa else None
        question_str = result_json.get("question", "") if enable_question else None
        summary_str = result_json.get("summary", "") if enable_summary else None

        # Persist the generated fields to MySQL (disabled fields are passed as None).
        update_success, update_msg = mysql_client.update_slice_llm_fields(
            knowledge_id, slice_id, qa_str, question_str, summary_str
        )
        if not update_success:
            return False, f"更新MySQL失败: {update_msg}"

        # Build the text to embed: slice body plus every non-empty generated field.
        enhanced_parts = [slice_text]
        if enable_question and question_str:
            enhanced_parts.append(f"问题:{question_str}")
        if enable_summary and summary_str:
            enhanced_parts.append(f"摘要:{summary_str}")
        # Gate on the parsed list: the serialized form ("[]" / "null") is always
        # truthy and would otherwise add an empty QA section to the embedding.
        if enable_qa and qa_items:
            enhanced_parts.append(f"QA问答:{qa_str}")
        enhanced_text = "\n\n".join(enhanced_parts)

        vector_success, vector_msg = milvus_client.hybrid_retriever.update_dense_vector_by_chunk_id(
            slice_id, enhanced_text
        )
        return vector_success, vector_msg if not vector_success else "success"

    except Exception as e:
        logger.error(f"处理切片 {slice_id} 元数据出错: {e}")
        return False, str(e)
async def generate_slice_metadata(request_json: dict) -> dict:
    """
    Generate qa / question / summary fields for slices and persist them.

    Slices are looked up by slice_id when provided, otherwise by
    (knowledge_id, document_id). For every slice the enabled fields are
    requested from the LLM, written to MySQL, and the slice text plus the
    non-empty generated fields is written to the vector store. If none of the
    three fields is enabled, the plain slice text is written to the vector
    store instead.

    Keys read from request_json:
        - knowledge_id: knowledge base id (required; also the Milvus collection name)
        - document_id: document id (optional when slice_id is given)
        - slice_id: slice id (optional; when given only that slice is queried)
        - embedding_id: embedding name (optional, defaults to "e5")
        - qa: whether to generate QA pairs (default False)
        - question: whether to generate a core question (default False)
        - summary: whether to generate a summary (default False)

    Returns:
        dict with "code" and "message". Code 400 for missing ids, 500 when the
        slice query fails, 404 when no slices are found; otherwise code 200
        with "data" holding the total / success / failed counters.
    """
    knowledge_id = request_json.get("knowledge_id")
    document_id = request_json.get("document_id")
    slice_id = request_json.get("slice_id")
    embedding_id = request_json.get("embedding_id", "e5")
    enable_qa = request_json.get("qa", False)
    enable_question = request_json.get("question", False)
    enable_summary = request_json.get("summary", False)

    if not knowledge_id:
        return {"code": 400, "message": "knowledge_id 不能为空"}

    if not slice_id and not document_id:
        return {"code": 400, "message": "document_id 和 slice_id 至少需要提供一个"}

    logger.info(f"生成切片元数据 - 知识库: {knowledge_id}, 文档: {document_id}, 切片: {slice_id}, qa={enable_qa}, question={enable_question}, summary={enable_summary}")

    mysql_client = MysqlOperate()
    milvus_client = MilvusOperate(collection_name=knowledge_id, embedding_name=embedding_id)

    # 1. Load the slices to process.
    if slice_id:
        success, slice_data = mysql_client.query_slice_by_id(knowledge_id, slice_id)
    else:
        success, slice_data = mysql_client.query_slice_by_knowledge_and_doc(knowledge_id, document_id)
    if not success:
        return {"code": 500, "message": f"查询切片数据失败: {slice_data}"}

    if not slice_data:
        return {"code": 404, "message": "未找到切片数据"}

    logger.info(f"查询到 {len(slice_data)} 条切片数据")

    # 2. Prepare everything that does not depend on the individual slice.
    from rag.llm import VllmApi
    vllm_client = VllmApi({})

    metadata_enabled = any([enable_qa, enable_question, enable_summary])

    # Prompt scaffolding: only the enabled fields are requested.
    json_fields = []
    requirements = []
    if enable_qa:
        json_fields.append(' "qa": [{"question": "问题1", "answer": "答案1"}, {"question": "问题2", "answer": "答案2"}, {"question": "问题3", "answer": "答案3"}]')
        requirements.append("qa问答对必须能从正文中直接找到答案")
    if enable_question:
        json_fields.append(' "question": "根据正文生成的一个核心问题"')
        requirements.append("question必须是正文能够回答的问题")
    if enable_summary:
        json_fields.append(' "summary": "正文的简洁摘要(100字以内)"')
        requirements.append("summary必须是对正文的准确概括")

    json_template = "{\n" + ",\n".join(json_fields) + "\n}"
    requirements_text = "\n".join([f"{i+1}. {r}" for i, r in enumerate(requirements)])
    requirements_text += f"\n{len(requirements)+1}. 只输出JSON,不要有其他内容"

    model = "/opt/vllm/models/Qwen/Qwen3-30B-A3B-Instruct-2507"

    success_count = 0
    fail_count = 0

    for slice_item in slice_data:
        current_slice_id = slice_item["slice_id"]
        slice_text = slice_item["slice_text"]

        # Nothing to generate: store the plain slice text and move on.
        if not metadata_enabled:
            vector_success, vector_msg = milvus_client.hybrid_retriever.update_dense_vector_by_chunk_id(
                current_slice_id, slice_text
            )
            if vector_success:
                success_count += 1
                logger.info(f"切片 {current_slice_id} 直接写入向量库成功")
            else:
                fail_count += 1
                logger.warning(f"切片 {current_slice_id} 写入向量库失败: {vector_msg}")
            continue

        # Raw f-string so the \" escaping example keeps its backslashes; in a
        # plain f-string it would collapse to a bare quote.
        prompt = fr"""请根据以下正文内容,严格基于正文信息生成指定字段,不要编造任何正文中不存在的内容。
【重要JSON格式要求】
- 必须输出标准JSON格式,使用英文双引号(")作为字符串分隔符
- 如果正文中包含引号内容(如"信息"),在JSON值中应写成转义形式(如\"信息\")或用其他方式表达
【正文内容】
{slice_text}
请按照以下JSON格式输出:
{json_template}
要求:
{requirements_text}"""
        try:
            # Call the LLM; only the first yielded chunk is consumed (stream=False).
            llm_result = ""
            async for chunk in vllm_client.chat(prompt, model, temperature=0.3, top_p=0.7, max_tokens=2048, stream=False):
                # Accept both a {"data": text} payload and a bare text chunk,
                # so the regex below always receives the reply text.
                llm_result = chunk.get("data", "") if isinstance(chunk, dict) else chunk
                break

            logger.info(f"LLM生成的:{llm_result}")

            # Extract the JSON object: greedy match from the first "{" to the last "}".
            json_match = re.search(r'\{[\s\S]*\}', llm_result) if isinstance(llm_result, str) else None
            if not json_match:
                logger.warning(f"切片 {current_slice_id} LLM返回格式错误")
                fail_count += 1
                continue

            result_json = json.loads(json_match.group())
            qa_items = result_json.get("qa", []) if enable_qa else None
            qa_str = json.dumps(qa_items, ensure_ascii=False) if enable_qa else None
            question_str = result_json.get("question", "") if enable_question else None
            summary_str = result_json.get("summary", "") if enable_summary else None

            # 3. Update MySQL (disabled fields are passed as None).
            update_success, update_msg = mysql_client.update_slice_llm_fields(
                knowledge_id, current_slice_id, qa_str, question_str, summary_str
            )
            if not update_success:
                logger.error(f"更新切片 {current_slice_id} MySQL失败: {update_msg}")
                fail_count += 1
                continue

            # 4. Update the vector store with the slice body plus non-empty fields.
            enhanced_parts = [slice_text]
            if enable_question and question_str:
                enhanced_parts.append(f"问题:{question_str}")
            if enable_summary and summary_str:
                enhanced_parts.append(f"摘要:{summary_str}")
            # Gate on the parsed list: the serialized form ("[]" / "null") is
            # always truthy and would add an empty QA section to the embedding.
            if enable_qa and qa_items:
                enhanced_parts.append(f"QA问答:{qa_str}")
            enhanced_text = "\n\n".join(enhanced_parts)

            vector_success, vector_msg = milvus_client.hybrid_retriever.update_dense_vector_by_chunk_id(
                current_slice_id, enhanced_text
            )
            # A vector failure is only logged here; the slice still counts as
            # processed because its MySQL fields were already updated.
            if not vector_success:
                logger.warning(f"更新切片 {current_slice_id} 向量失败: {vector_msg}")

            success_count += 1
            logger.info(f"切片 {current_slice_id} 处理成功")

        except Exception as e:
            logger.error(f"处理切片 {current_slice_id} 出错: {e}")
            fail_count += 1
            continue

    result = {
        "code": 200,
        "message": "处理完成",
        "data": {
            "total": len(slice_data),
            "success": success_count,
            "failed": fail_count
        }
    }
    logger.info(f"生成切片元数据结果: {result}")
    return result
|