from rag.db import MilvusOperate, MysqlOperate from rag.load_model import * from rag.llm import VllmApi from config import redis_config_dict import json import re import gc import redis from utils.get_logger import setup_logger import random import concurrent.futures from concurrent.futures import ThreadPoolExecutor, as_completed import time import httpx # 停用词 from rag.sensitive_words_detection import TxtSensitiveWordStore, SecurityTokenizer, build_warning # LightRAG from lightRAG import AsyncLightRAGManager store = TxtSensitiveWordStore("./rag/sensitive_words/") tokenizer_sensitive = SecurityTokenizer(store) # from rag.load_model import tokenizer logger = setup_logger(__name__) # rerank_model_mapping = { # "bge-reranker-v2-m3": (bce_rerank_tokenizer, bce_rerank_base_model), # "rerank": (bce_rerank_tokenizer, bce_rerank_base_model), # "Qwen3-Reranker-0.6B": (qwen_rerank_tokenizer, qwen_rerank_base_model) # } rerank_model_mapping_vllm = { "bge-reranker-v2-m3": (rerank_bge_url, rerank_bge_model), "Qwen3-Reranker-0.6B": (rerank_qwen_url, rerank_qwen_model) } redis_host = redis_config_dict.get("host") redis_port = redis_config_dict.get("port") redis_db = redis_config_dict.get("db") class ChatRetrieverRag: def __init__(self, chat_json: dict = None, chat_id: str=None): self.chat_id = chat_id self.redis_client = redis.StrictRedis(host=redis_host, port=redis_port, db=redis_db) if not chat_id: self.vllm_client = VllmApi(chat_json) def count_tokens(self,text): tokens = tokenizer.tokenize(text) return len(tokens) # 手动rerank # def rerank_result(self,query, top_k, hybrid_search_result, rerank_embedding_name): # rerank_list = [] # for result in hybrid_search_result: # rerank_list.append([query, result["content"]]) # tokenizer, rerank_model = rerank_model_mapping.get(rerank_embedding_name) # # 重排序 # rerank_model.eval() # inputs = tokenizer(rerank_list, padding=True, truncation=True, max_length=512, return_tensors="pt") # inputs = {k: v.to(device) for k, v in inputs.items()} # with torch.no_grad(): # logits = rerank_model(**inputs, return_dict=True).logits.view(-1,).float() # logits = logits.detach().cpu() # scores = torch.sigmoid(logits).tolist() # # scores_list = scores.tolist() # logger.info(f"重排序的得分:{scores}") # # sorted_pairs = sorted(zip(scores, hybrid_search_result), key=lambda x: x[0], reverse=True) # threshold = 0.3 # sorted_pairs = sorted( ((s, r) for s, r in zip(scores, hybrid_search_result) if s >= threshold), key=lambda x: x[0], reverse=True ) # if not sorted_pairs: # sorted_scores = [] # sorted_search = [] # else: # sorted_scores, sorted_search = zip(*sorted_pairs) # logger.info(f"过滤后的{sorted_scores}") # # sorted_scores, sorted_search = zip(*sorted_pairs) # sorted_scores_list = list(sorted_scores) # sorted_search_list = list(sorted_search) # for score, search in zip(sorted_scores_list, sorted_search_list): # search["rerank_score"] = score # del inputs, logits # gc.collect() # torch.cuda.empty_cache() # search_result = random.sample(sorted_search_list, min(top_k, len(sorted_search_list))) # return search_result # rerank(url) async def rerank_result_async(self, query, top_k, hybrid_search_result, rerank_embedding_name): """ Args: query (str): 查询文本 top_k (int): 最终返回结果数量 hybrid_search_result (list[dict]): 原始搜索结果列表 rerank_embedding_name (str): rerank 模型 Returns: list[dict]: 添加了 "rerank_score" 字段的结果 """ # 查模型信息 rerank_url, rerank_model = rerank_model_mapping_vllm.get(rerank_embedding_name) if not hybrid_search_result: return [] # 构建候选文档列表 documents = [r["content"] for r in hybrid_search_result] payload = { "model": rerank_model, "query": query, "documents": documents } service_url = f"{rerank_url}/v1/rerank" # 使用异步 HTTP 请求 try: async with httpx.AsyncClient(timeout=60.0) as client: resp = await client.post(service_url, json=payload) resp.raise_for_status() result = resp.json() except Exception as e: logger.error(f"调用 rerank 服务失败: {e}") # 出错 fallback scores = [0.0] * len(documents) else: # 解析分数 if "results" in result: scores = [item.get("relevance_score", 0.0) for item in result["results"]] elif "data" in result: scores = [item.get("score", 0.0) for item in result["data"]] else: logger.warning(f"未知 rerank 返回格式: {result}") scores = [0.0] * len(documents) logger.info(f"重排序的得分:{scores}") # 构建分数与搜索结果列表 sorted_pairs = list(zip(scores, hybrid_search_result)) # # 阈值过滤 # threshold = 0.2 # filtered_pairs = [(s, r) for s, r in sorted_pairs if s >= threshold] # if not filtered_pairs: # filtered_pairs = sorted_pairs # logger.info(f"过滤后的{filtered_pairs}") # # 按分数降序排序 # filtered_pairs.sort(key=lambda x: x[0], reverse=True) threshold = 0.3 # sorted_pairs = sorted( # ((s, r) for s, r in sorted_pairs if s >= threshold) or sorted_pairs, # key=lambda x: x[0], # reverse=True # ) filtered_pairs = [(s, r) for s, r in sorted_pairs if s >= threshold] if not filtered_pairs: filtered_pairs = sorted_pairs sorted_pairs = sorted( filtered_pairs, key=lambda x: x[0], reverse=True ) if not sorted_pairs: sorted_scores_list = [] sorted_search_list = [] else: sorted_scores_list, sorted_search_list = zip(*sorted_pairs) logger.info(f"过滤后的{sorted_scores_list}") # 写入 rerank_score # sorted_scores_list, sorted_search_list = zip(*filtered_pairs) sorted_scores_list, sorted_search_list = list(sorted_scores_list), list(sorted_search_list) for score, search in zip(sorted_scores_list, sorted_search_list): search["rerank_score"] = score # 清理 GPU 内存 gc.collect() torch.cuda.empty_cache() # 随机采样 top_k search_result = random.sample(sorted_search_list, min(top_k, len(sorted_search_list))) # logger.info(f"query='{query}' rerank结果: {[r['rerank_score'] for r in search_result]}") return search_result async def segment_query_with_llm(self, query, max_retries=3): """ 使用LLM对长问题进行切分,提取多个子问题 """ segment_prompt = f""" 你是一个专业的问题分析助手。请将下面的复杂问题分解为3-5个具体的子问题,每个子问题都应该是独立的、可检索的。 原问题:{query} 请按照以下JSON格式返回: {{ "sub_questions": [ "子问题1", "子问题2", "子问题3" ] }} 要求: 1. 每个子问题都要具体明确 2. 子问题之间相互独立 3. 覆盖原问题的主要方面 4. 适合向量检索 """ for attempt in range(max_retries): try: logger.info(f"第{attempt + 1}次尝试问题切分") # 构造请求参数 messages = [{"role": "user", "content": segment_prompt}] model = "/opt/vllm/models/Qwen/Qwen3-30B-A3B-Instruct-2507" # 调用LLM response_text = "" async for chunk in self.vllm_client.chat( prompt=segment_prompt, model=model, temperature=0.1, top_p=0.9, max_tokens=1024, stream=True, history=[] ): # print(chunk) if chunk.get("event") == "add": response_text += chunk.get("data", "") logger.info(f"LLM返回的切分结果:{response_text}") # 解析JSON响应 import json if "```json" in response_text: json_pattern = r'```json\s*(.*?)```' matches = re.findall(json_pattern, response_text, re.DOTALL) if matches: response_text = matches[0] result = json.loads(response_text.strip()) sub_questions = result.get("sub_questions", []) if sub_questions and len(sub_questions) > 0: logger.info(f"成功切分为{len(sub_questions)}个子问题:{sub_questions}") return sub_questions else: logger.warning("LLM返回的子问题列表为空") except Exception as e: logger.error(f"第{attempt + 1}次问题切分失败:{str(e)}") if attempt == max_retries - 1: logger.error("问题切分达到最大重试次数,使用原问题") return [query] return [query] def request_milvus_multi(self, collection_name_list, querys, rag_embedding_name, top, mode): """ 支持多query的向量检索方法 """ results = [] doc_id_to_collection = {} def query_collection(collection_name, query): return collection_name, query, MilvusOperate( collection_name=collection_name, embedding_name=rag_embedding_name )._search(query, k=top, mode=mode) tasks = [(c, q) for c in collection_name_list for q in querys] max_workers = min(20, len(tasks)) with ThreadPoolExecutor(max_workers=max_workers) as executor: future_to_task = { executor.submit(query_collection, c, q): (c, q) for c, q in tasks } for future in as_completed(future_to_task): collection_name, query = future_to_task[future] try: _, _, search_result = future.result() for doc in search_result: doc_id = doc["doc_id"] if doc_id not in doc_id_to_collection: doc_id_to_collection[doc_id] = collection_name results.extend(search_result) except Exception as e: logger.error(f"查询集合 {collection_name} 对 query {query} 出错: {e}") return results, doc_id_to_collection def request_milvus(self, collection_name_list, rag_embedding_name, query, top, mode): search_multi_collection_list = [] search_doc_id_to_knowledge_id_dict = {} # 多线程查询 def query_collection(collection_name): return collection_name, MilvusOperate( collection_name=collection_name, embedding_name=rag_embedding_name )._search(query, k=top, mode=mode) max_workers = min(10, len(collection_name_list)) with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: future_to_name = { executor.submit(query_collection, name): name for name in collection_name_list } for future in concurrent.futures.as_completed(future_to_name): collection_name = future_to_name[future] try: _, hybrid_search_result = future.result() for result in hybrid_search_result: doc_id = result.get("doc_id") if doc_id not in search_doc_id_to_knowledge_id_dict: search_doc_id_to_knowledge_id_dict[doc_id] = collection_name search_multi_collection_list.extend(hybrid_search_result) except Exception as e: logger.error(f"查询集合 {collection_name} 时出错: {str(e)}") continue return search_multi_collection_list, search_doc_id_to_knowledge_id_dict async def retriever_result(self, chat_json, rewrite = None): mode_search_mode = { "embedding": "dense", "keyword": "sparse", "mixed": "hybrid", } collection_name_list = chat_json.get("knowledgeIds") top = chat_json.get("sliceCount", 15) retriever_info = json.loads(chat_json.get("knowledgeInfo", "{}")) rag_embedding_name = chat_json.get("embeddingId") mode_rag = retriever_info.get("recall_method") mode = mode_search_mode.get(mode_rag, "hybrid") rerank_embedding_name = retriever_info.get("rerank_model_name") rerank_status = retriever_info.get("rerank_status") if rewrite: chat_json["query"] = rewrite query = chat_json.get("query") query_list = [query] # search_multi_collection_list = [] # search_doc_id_to_knowledge_id_dict = {} # for collection_name in collection_name_list: # hybrid_search_result = MilvusOperate( # collection_name=collection_name,embedding_name=rag_embedding_name)._search(query, k=top, mode=mode) # for result in hybrid_search_result: # doc_id = result.get("doc_id") # if doc_id not in search_doc_id_to_knowledge_id_dict: # search_doc_id_to_knowledge_id_dict[doc_id] = collection_name # search_multi_collection_list.extend(hybrid_search_result) """ search_multi_collection_list信息格式: { "doc_id": "07917ce4-ce63-11f0-ad4f-ecd68ae97d92", "chunk_id": "7a4aba3e-ce63-11f0-8a8d-ecd68ae97d92", "content": "## 3.13 高处作业 24\n", "Father_Chapter": "3->3.6", "Chapter": "3->3.6->3.6.2", "metadata": { "source": "超长ceshi.pdf", "chunk_index": 14, "chunk_len": 16 }, "score": 0.016393441706895828, "rerank_score": 1 } """ start_time = time.time() two_time = -1 # if rewrite: # query = rewrite # query_list = [rewrite] # 1. 检查是否需要问题切分 problem_segmentation = chat_json.get("Problem_segmentation", False) if problem_segmentation: logger.info("启用问题切分模式") # 使用LLM切分问题 query_list = await self.segment_query_with_llm(query) logger.info(f"切分后的问题列表:{query_list}") # [query + "\n" + summary if summary else query for query in query_list] # 多轮检索 # if not query_list: # query_list = [query + "\n" + summary] if summary else [query] logger.info(f"query_list:{query_list}") search_multi_collection_list, search_doc_id_to_knowledge_id_dict = self.request_milvus_multi( collection_name_list, query_list, rag_embedding_name, top, mode ) # # 单次检索 # query_list = query # search_multi_collection_list, search_doc_id_to_knowledge_id_dict = self.request_milvus( # collection_name_list, rag_embedding_name, query, top, mode # ) logger.info(f"根据{collection_name_list}检索到结果:{search_multi_collection_list}") if len(search_multi_collection_list) <= 0: return [], search_doc_id_to_knowledge_id_dict # 2. 重排序 if rerank_status: # rerank_result_list = self.rerank_result(query, top, search_multi_collection_list, rerank_embedding_name) rerank_result_list = await self.rerank_result_async(query, top, search_multi_collection_list, rerank_embedding_name) else: for result in search_multi_collection_list: result["rerank_score"] = 1 rerank_result_list = random.sample(search_multi_collection_list, min(top, len(search_multi_collection_list))) # 3. 收集文档ID,查询MySQL判断是否需要章节扩展 doc_ids = list({r["doc_id"] for r in rerank_result_list}) # 收集知识库ID,用于知识图谱检索 knowledge_ids = list({r["metadata"].get("knowledge_id", "") for r in rerank_result_list}) logger.info(f"检索到的知识库id{knowledge_ids}") mysql_client = MysqlOperate() enabled_doc_ids = mysql_client.query_parent_generation_enabled(doc_ids) enabled_kn_rs = mysql_client.query_knowledge_by_ids(knowledge_ids) enabled_kn_gp_ids = set() for r in enabled_kn_rs: if r["knowledge_graph"]: enabled_kn_gp_ids.add(r["knowledge_id"]) if enabled_kn_gp_ids: gp_add_list = [] enabled_kn_gp_ids = list(enabled_kn_gp_ids) gp_time = time.time() logger.info(f"需要查询知识图谱的知识库:{enabled_kn_gp_ids}") # if chat_json.get("lightrag"): for kn_id in enabled_kn_gp_ids: result = await AsyncLightRAGManager().retrieve(label=kn_id, query=query) chunks_add = result["data"].get("chunks", []) if chunks_add: # chunk_dict = {} for chunk_add in chunks_add: # chunk_dict["content"] = chunk_add gp_add_list.append({ "doc_id": None, "chunk_id": None, "content": chunk_add["content"], "metadata": {}, "rerank_score": 1, "revision_status": None, "section": None, "parent_section": None, }) logger.info(f"知识图谱扩展的内容:{gp_add_list}") rerank_result_list.extend(gp_add_list) if enabled_doc_ids: tmp_time = time.time() logger.info(f"需要章节扩展的文档: {enabled_doc_ids}") existing_chunk_ids = {r["chunk_id"] for r in rerank_result_list} # 收集需要扩展的 (collection_name, doc_id, Father_Chapter) 组合-集合名称-文档id-父级标题 father_chapter_set = set() for result in rerank_result_list: doc_id = result.get("doc_id") if doc_id not in enabled_doc_ids: continue father_chapter = result.get("Father_Chapter") collection_name = search_doc_id_to_knowledge_id_dict.get(doc_id) if father_chapter and collection_name: father_chapter_set.add((collection_name, doc_id, father_chapter)) # 4. 二次检索并追加结果 for collection_name, doc_id, father_chapter in father_chapter_set: chapter_results = MilvusOperate( collection_name=collection_name, embedding_name=rag_embedding_name )._query_by_scalar_field(doc_id, "Father_Chapter", father_chapter) for r in chapter_results: if r["chunk_id"] not in existing_chunk_ids: r["rerank_score"] = 1 rerank_result_list.append(r) existing_chunk_ids.add(r["chunk_id"]) two_time = time.time() - tmp_time logger.info(f"父子章节扩展后的结果数: {len(rerank_result_list)}") up_time = time.time() - start_time two_status = "未启动父子召回" if two_time < 0 else f"启动父子召回耗时:{two_time}" logger.info(f"检索耗时:{up_time}({two_status})") return rerank_result_list, search_doc_id_to_knowledge_id_dict async def generate_rag_response(self, chat_json, chunk_content): # logger.info(f"rag聊天的请求参数:{chat_json}") # retriever_result_list = self.retriever_result(chat_json) # logger.info(f"向量库中获取的最终结果:{retriever_result_list}") prompt = chat_json.get("prompt") query = chat_json.get("query") stream = chat_json.get("stream", True) # chunk_content = "" # for retriever in retriever_result_list: # chunk_content += retriever["content"] prompt = prompt.replace("{知识}", chunk_content).replace("{用户}", query) logger.info(f"请求的提示词:{prompt}") temperature = float(chat_json.get("temperature", 0.6)) top_p = float(chat_json.get("topP", 0.7)) max_token = chat_json.get("maxToken", 4096) enable_think = chat_json.get("enable_think", False) history = chat_json.get("messages", []) if history: history[-1]["content"] = prompt # 调用模型获取返回的结果 # model = chat_json.get("model", "DeepSeek-R1-Distill-Qwen-14B") model = "/opt/vllm/models/Qwen/Qwen3-30B-A3B-Instruct-2507" chat_resp = "" async for chunk in self.vllm_client.chat(prompt, model, temperature=temperature, top_p=top_p, max_tokens=max_token, stream=stream, history=history, enable_think=enable_think): logger.info(f"chat_message中接收到的LLM信息{chunk}") if stream: if chunk.get("event") == "add": chat_resp += chunk.get("data") chunk["data"] = chat_resp yield chunk else: yield chunk else: chunk def parse_retriever_list(self, retriever_result_list, search_doc_id_to_knowledge_id_dict): doc_info = {} chunk_content = "" # 组织成每个doc对应的json格式 for retriever in retriever_result_list: chunk_text = retriever["content"] chunk_content += chunk_text doc_id = retriever["doc_id"] chunk_id = retriever["chunk_id"] rerank_score = retriever["rerank_score"] doc_name = retriever.get("metadata").get("source","") if doc_id in doc_info: c_d = { "chunk_id": chunk_id, "rerank_score": rerank_score, "chunk_len": len(chunk_text), "document_name": retriever.get("document_name"), "revision_status": retriever.get("revision_status"), "ref_slice_id": retriever.get("ref_slice_id"), "revision_group_id": retriever.get("revision_group_id"), "section": retriever.get("section"), "parent_section": retriever.get("parent_section"), } doc_info[doc_id]["chunk_list"].append(c_d) # doc_info[doc_id]["rerank_score"].append(rerank_score) else: c_d = { "chunk_id": chunk_id, "rerank_score": rerank_score, "chunk_len": len(chunk_text), "document_name": retriever.get("document_name"), "revision_status": retriever.get("revision_status"), "ref_slice_id": retriever.get("ref_slice_id"), "revision_group_id": retriever.get("revision_group_id"), "section": retriever.get("section"), "parent_section": retriever.get("parent_section"), } doc_info[doc_id] = { "doc_name": doc_name, "chunk_list": [c_d], } # doc_list = [] knowledge_info_dict = {} """ { "knowledge_id": [{}, {}], "knowledge_id" } """ for k, v in doc_info.items(): d = {} knowledge_id = search_doc_id_to_knowledge_id_dict.get(k) d["doc_id"] = k d["doc_name"] = v.get("doc_name") d["chunk_nums"] = len(v.get("chunk_list")) d["chunk_info_list"] = v.get("chunk_list") # d["chunk_len"] = len(v.get("chunk_id_list")) if knowledge_id in knowledge_info_dict: knowledge_info_dict[knowledge_id].append(d) else: knowledge_info_dict[knowledge_id] = [d] # doc_list.append(d) return chunk_content, knowledge_info_dict def _inject_revision_rules_to_prompt(self, prompt: str, rules_text: str): if not rules_text: return prompt if not prompt: return rules_text # insert_anchor = "用户输入问题" # if insert_anchor in prompt: # idx = prompt.find(insert_anchor) # return prompt[:idx] + rules_text + "\n\n" + prompt[idx:] # if "{用户}" in prompt: # return prompt.replace("{用户}", rules_text + "\n\n{用户}", 1) return prompt + "\n\n" + rules_text def _normalize_revision_status(self, value): return value # if value is None: # return None # if isinstance(value, bool): # return int(value) # if isinstance(value, int): # return value # try: # s = str(value).strip() # if s == "": # return None # return int(s) # except Exception: # return None def _normalize_slice_id(self, value): # if value is None: # return None # s = str(value).strip() # if s == "": # return None # if s.isdigit() and len(s) < 20: # return s.zfill(20) return value def _format_slice_block(self, doc_name: str, version_tag: str, section: str, parent_section: str, text: str, extra_title_suffix: str = ""): title = (doc_name or "未知文档") if version_tag: title += f"({version_tag})" if extra_title_suffix: title += extra_title_suffix lines = [f"> {title}"] if parent_section: lines.append(f"{parent_section}") if section and section != parent_section: lines.append(f"{section}" if parent_section else f"{section}") if text: lines.append(str(text).strip()) return "\n".join(lines).strip() + "\n\n" # def _build_llm_chunk_content(self, retriever_result_list): # slice_map = {} # for r in retriever_result_list: # sid = self._normalize_slice_id(r.get("chunk_id")) # if sid: # # def _build_llm_chunk_content(self, retriever_result_list): # # slice_map = {} # # for r in retriever_result_list: # # sid = self._normalize_slice_id(r.get("chunk_id")) # # if sid: # # slice_map[sid] = r # original_to_revision_ids = {} # revision_ids = set() # for r in retriever_result_list: # # if self._normalize_revision_status(r.get("revision_status")) == 1 and r.get("ref_slice_id"): # if r.get("revision_status") == "1" and r.get("ref_slice_id"): # original_to_revision_ids[r.get("chunk_id")] = r.get("ref_slice_id") # revision_ids.add(r.get("ref_slice_id")) # if sid: # slice_map[sid] = r # visited = set() # blocks = [] # for r in retriever_result_list: # sid = self._normalize_slice_id(r.get("chunk_id")) # if not sid or sid in visited: # continue # revision_status = r.get("revision_status") # ref_slice_id = self._normalize_slice_id(r.get("ref_slice_id")) # if revision_status == "1" and ref_slice_id: # original = r # revision = slice_map.get(ref_slice_id) # blocks.append( # self._format_slice_block( # original.get("document_name"), # "原始版", # original.get("section"), # original.get("parent_section"), # original.get("content"), # extra_title_suffix=" [已修订]", # ) # ) # visited.add(sid) # revision_sid = self._normalize_slice_id(revision.get("chunk_id")) if revision else None # if revision and revision_sid and revision_sid not in visited: # blocks.append( # self._format_slice_block( # revision.get("document_name"), # "修订版", # revision.get("section"), # revision.get("parent_section"), # revision.get("content"), # ) # ) # visited.add(revision_sid) # continue # if revision_status == "0": # deprecated_text = (r.get("content") or "") # deprecated_text = deprecated_text + "\n\n【已废弃】此切片已弃用,请优先参考修订版或最新文档。" # blocks.append( # self._format_slice_block( # r.get("document_name"), # "已废弃", # r.get("section"), # r.get("parent_section"), # deprecated_text, # ) # ) # visited.add(sid) # continue # if sid in revision_ids: # continue # blocks.append( # self._format_slice_block( # r.get("document_name"), # "", # r.get("section"), # r.get("parent_section"), # r.get("content"), # ) # ) # visited.add(sid) # return "".join(blocks).strip() def _build_llm_chunk_content(self, retriever_result_list): """ 构建 LLM 用的切片内容,仅发送最初检索到的向量数据库内容块,每个块添加 chunk_id。 """ logger.info(f"_build_llm_chunk_content 接收到的切片数量: {len(retriever_result_list)}") blocks = [] for r in retriever_result_list: chunk_id = r.get("chunk_id") content = r.get("content", "") doc_name = r.get("document_name") or r.get("metadata", {}).get("source", "未知文档") # 格式:[chunk_id: xxx] 文档名\n内容 block = f"[chunk_id: {chunk_id}]\n> {doc_name}\n{content}\n\n" blocks.append(block) return "".join(blocks).strip() # def _apply_revision_and_doc_name(self, chat_json, retriever_result_list, search_doc_id_to_knowledge_id_dict): # mysql_client = MysqlOperate() # slice_ids = [self._normalize_slice_id(r.get("chunk_id")) for r in retriever_result_list if r.get("chunk_id")] # slice_ids = [s for s in slice_ids if s] # slice_rows = [] # success, rows_or_error = mysql_client.query_slice_revision_info_by_slice_ids_any(slice_ids) # if success: # slice_rows = rows_or_error # slice_info_map = {self._normalize_slice_id(r.get("slice_id")): r for r in slice_rows if r.get("slice_id")} # logger.info(slice_rows) # has_revision_or_deprecated = False # revision_slice_ids = set() # original_to_revision_id = {} # for r in retriever_result_list: # sid = self._normalize_slice_id(r.get("chunk_id")) # info = slice_info_map.get(sid) # if info: # normalized_revision_status = self._normalize_revision_status(info.get("revision_status")) # r["revision_status"] = normalized_revision_status # r["ref_slice_id"] = self._normalize_slice_id(info.get("ref_slice_id")) # r["section"] = info.get("section") # r["parent_section"] = info.get("parent_section") # if normalized_revision_status in ("0", "1"): # has_revision_or_deprecated = True # logger.info("修订切片") # else: # r.setdefault("revision_status", None) # r.setdefault("ref_slice_id", None) # r.setdefault("section", None) # r.setdefault("parent_section", None) # ref_id = self._normalize_slice_id(r.get("ref_slice_id")) # sid = self._normalize_slice_id(r.get("chunk_id")) # if r.get("revision_status") == "1" and ref_id and sid: # original_to_revision_id[sid] = ref_id # revision_slice_ids.add(ref_id) # r["revision_group_id"] = sid # else: # r.setdefault("revision_group_id", None) # ref_slice_rows = [] # if revision_slice_ids: # success, ref_rows_or_error = mysql_client.query_slice_revision_info_by_slice_ids_any(list(revision_slice_ids)) # if success: # ref_slice_rows = ref_rows_or_error # revision_id_to_original_id = {} # for original_id, revision_id in original_to_revision_id.items(): # if revision_id: # revision_id_to_original_id[revision_id] = original_id # for r in retriever_result_list: # original_id = revision_id_to_original_id.get(r.get("chunk_id")) # if original_id: # r["revision_group_id"] = original_id # r["revision_of_slice_id"] = original_id # else: # r.setdefault("revision_of_slice_id", None) # for row in (slice_rows + ref_slice_rows): # doc_id = row.get("document_id") # knowledge_id = row.get("knowledge_id") # if doc_id and knowledge_id and doc_id not in search_doc_id_to_knowledge_id_dict: # search_doc_id_to_knowledge_id_dict[doc_id] = knowledge_id # success, ref_rows_or_error = mysql_client.query_slice_revision_info_by_slice_ids_any(list(revision_slice_ids)) # if success: # ref_slice_rows = ref_rows_or_error # revision_id_to_original_id = {} # for original_id, revision_id in original_to_revision_id.items(): # if revision_id: # revision_id_to_original_id[revision_id] = original_id # for r in retriever_result_list: # original_id = revision_id_to_original_id.get(r.get("chunk_id")) # if original_id: # r["revision_group_id"] = original_id # r["revision_of_slice_id"] = original_id # else: # r.setdefault("revision_of_slice_id", None) # for row in (slice_rows + ref_slice_rows): # doc_id = row.get("document_id") # knowledge_id = row.get("knowledge_id") # if doc_id and knowledge_id and doc_id not in search_doc_id_to_knowledge_id_dict: # search_doc_id_to_knowledge_id_dict[doc_id] = knowledge_id # existing_slice_ids = {r.get("chunk_id") for r in retriever_result_list if r.get("chunk_id")} # existing_slice_ids = {self._normalize_slice_id(s) for s in existing_slice_ids} # existing_slice_ids = {s for s in existing_slice_ids if s} # ref_map = {r.get("slice_id"): r for r in ref_slice_rows if r.get("slice_id")} # for revision_id in revision_slice_ids: # if revision_id in existing_slice_ids: # continue # info = ref_map.get(revision_id) # if not info: # continue # doc_id = info.get("document_id") # text = info.get("slice_text") or "" # metadata = {"source": "", "chunk_index": None, "chunk_len": len(text)} # original_id = revision_id_to_original_id.get(revision_id) # retriever_result_list.append( # { # "doc_id": doc_id, # "chunk_id": revision_id, # "content": text, # "Father_Chapter": "", # "Chapter": "", # "metadata": metadata, # "score": 0, # "rerank_score": 1, # "revision_status": None, # "ref_slice_id": None, # "revision_group_id": original_id, # "revision_of_slice_id": original_id, # "section": info.get("section"), # "parent_section": info.get("parent_section"), # } # ) # existing_slice_ids.add(revision_id) # doc_ids = list({r.get("doc_id") for r in retriever_result_list if r.get("doc_id")}) # success, doc_name_map_or_error = mysql_client.query_document_names_by_document_ids(doc_ids) # doc_name_map = doc_name_map_or_error if success else {} # for r in retriever_result_list: # doc_id = r.get("doc_id") # doc_name = doc_name_map.get(doc_id) # r["document_name"] = doc_name # if r.get("metadata") is None: # r["metadata"] = {} # if doc_name: # r["metadata"]["source"] = doc_name # rules_text = "" # if has_revision_or_deprecated: # rules_text = ( # "切片中可能含有废弃或修订切片。\n" # "- 若切片标题包含“已废弃”,表示该内容已不再适用;如需引用请明确说明其已废弃,并优先给出最新/修订信息(若有)。\n" # "- 若切片标题包含“原始版/修订版”,表示同一主题存在版本差异;回答时优先采用修订版,必要时对比说明差异。\n" # "- 输出时可参考如下结构组织:\n" # "> 文档A(原始版)[已修订]\n" # "> 文档B(修订版)\n" # "> 文档C(已废弃)\n" # ) # chunk_content_for_llm = self._build_llm_chunk_content(retriever_result_list) # return retriever_result_list, search_doc_id_to_knowledge_id_dict, chunk_content_for_llm, rules_text def _apply_revision_and_doc_name(self, chat_json, retriever_result_list, search_doc_id_to_knowledge_id_dict): """ 处理切片的修订状态、原始/修订关系、废弃状态,并补充文档名称。 返回更新后的 retriever_result_list、search_doc_id_to_knowledge_id_dict、 LLM 拼接内容以及修订映射信息。 新逻辑: 1. 查询修订切片的完整内容 2. 将修订后的切片给LLM使用 3. 保留原始切片信息用于Redis存储 """ mysql_client = MysqlOperate() # 收集所有 chunk_id slice_ids = [self._normalize_slice_id(r.get("chunk_id")) for r in retriever_result_list if r.get("chunk_id")] slice_ids = [s for s in slice_ids if s] # 查询切片修订信息 slice_rows = [] if slice_ids: success, rows_or_error = mysql_client.query_slice_revision_info_by_slice_ids_any(slice_ids) if success: slice_rows = rows_or_error # 构建 slice_id -> row 映射 slice_info_map = {self._normalize_slice_id(r.get("slice_id")): r for r in slice_rows if r.get("slice_id")} # 构建原始->修订映射和修订映射信息 original_to_revision_id = {} revision_slice_ids = set() # 收集所有修订切片ID chunk_revision_map = {} # 用于存储修订关系 for r in retriever_result_list: # 当前切片id sid = self._normalize_slice_id(r.get("chunk_id")) # 切片id对应的row info = slice_info_map.get(sid) if info: # 切片修订状态 revision_status = self._normalize_revision_status(info.get("revision_status")) # 记录状态-修订切片id-父子标题 r["revision_status"] = revision_status r["ref_slice_id"] = self._normalize_slice_id(info.get("ref_slice_id")) r["section"] = info.get("section") r["parent_section"] = info.get("parent_section") # 原始->修订映射 ref_id = r.get("ref_slice_id") # 如果是被修订的切片 if revision_status == "1" and ref_id and sid: # 记录原始切片-修订切片的id映射 original_to_revision_id[sid] = ref_id # 记录修订切片id revision_slice_ids.add(ref_id) # 记录后续使用 chunk_revision_map[sid] = {"revised_to": ref_id} elif revision_status == "0": # 废弃切片 chunk_revision_map[sid] = {"deprecated": True} else: r.setdefault("revision_status", None) r.setdefault("ref_slice_id", None) r.setdefault("section", None) r.setdefault("parent_section", None) # 查询修订切片的完整信息 revision_slice_info_map = {} # 如果有被修订的切片 if revision_slice_ids: # 查询修订切片信息 success, revision_rows_or_error = mysql_client.query_slice_revision_info_by_slice_ids_any(list(revision_slice_ids)) if success: # 构建修订id->row的映射 revision_slice_info_map = { self._normalize_slice_id(r.get("slice_id")): r for r in revision_rows_or_error if r.get("slice_id") } logger.info(f"查询到 {len(revision_slice_info_map)} 个修订切片的完整信息") # 补充文档名称 doc_ids = list({r.get("doc_id") for r in retriever_result_list if r.get("doc_id")}) success, doc_name_map_or_error = mysql_client.query_document_names_by_document_ids(doc_ids) doc_name_map = doc_name_map_or_error if success else {} for r in retriever_result_list: doc_id = r.get("doc_id") doc_name = doc_name_map.get(doc_id) r["document_name"] = doc_name if r.get("metadata") is None: r["metadata"] = {} if doc_name: r["metadata"]["source"] = doc_name # 创建用于LLM的结果列表(使用修订后的切片,过滤废弃切片) llm_result_list = [] deprecated_chunks = [] # 记录废弃切片,用于后续存储到Redis for r in retriever_result_list: sid = self._normalize_slice_id(r.get("chunk_id")) revision_status = r.get("revision_status") # 过滤废弃切片,不传给LLM if revision_status == "0": deprecated_chunks.append(r) logger.info(f"过滤废弃切片 {sid},不传给LLM") continue # 处理 revision_status="1" 但 ref_slice_id 为空的情况 if revision_status == "1" and not r.get("ref_slice_id"): info = slice_info_map.get(sid) revision_text = info.get("revision_slice_text", "") if info else "" if not revision_text: revision_text = r.get("content", "") r["content"] = revision_text llm_result_list.append(r) logger.info(f"使用切片 {sid} 的 revision_slice_text(无映射关系)") continue # 如果这是一个有修订的原始切片,使用修订后的内容 if sid in original_to_revision_id: revision_id = original_to_revision_id[sid] revision_info = revision_slice_info_map.get(revision_id) # 获取当前切片的 revision_slice_text info = slice_info_map.get(sid) revision_text = info.get("revision_slice_text", "") if info else "" if not revision_text: revision_text = r.get("content", "") if revision_info: # 创建修订切片的副本用于LLM,使用当前切片的 revision_slice_text revised_chunk = { "chunk_id": revision_id, "content": revision_text, "doc_id": revision_info.get("document_id", r.get("doc_id")), "document_name": doc_name_map.get(revision_info.get("document_id"), r.get("document_name")), "section": revision_info.get("section"), "parent_section": revision_info.get("parent_section"), "revision_status": None, "ref_slice_id": None, "metadata": r.get("metadata", {}), "rerank_score": r.get("rerank_score", 1), "is_revised_version": True, "original_chunk_id": sid } llm_result_list.append(revised_chunk) logger.info(f"使用原始切片 {sid} 的 revision_slice_text + 映射切片 {revision_id} 的元数据") else: # 如果没有查询到修订切片信息,使用当前切片的 revision_slice_text r["content"] = revision_text llm_result_list.append(r) logger.warning(f"未找到映射切片 {revision_id},使用原始切片 {sid} 的 revision_slice_text") else: # 没有修订的切片直接使用 llm_result_list.append(r) logger.info(f"过滤后传给LLM的切片数: {len(llm_result_list)}, 废弃切片数: {len(deprecated_chunks)}") # 生成 LLM 拼接内容(使用修订后的切片) chunk_content_for_llm = self._build_llm_chunk_content(llm_result_list) return retriever_result_list, search_doc_id_to_knowledge_id_dict, chunk_content_for_llm, chunk_revision_map, llm_result_list async def generate_event(self, chat_json, request): logger.info(f"rag聊天的请求参数:{chat_json}") start_time = time.time() query = chat_json.get("messages")[-1].get("content") text = query hits = tokenizer_sensitive.detect(text) if hits: # print(build_warning(hits)) yield {"id": "", "event": "add", "data": build_warning(hits)} return app_name = chat_json.get("name", "") label_prompt = """ 你是一个 RAG 系统的输入判定器,只做判断,不生成答案。 请严格判断用户输入是否属于以下三类之一: 1. 语义模糊(NEED_CLARIFY):输入没有明确的主语、对象或范围,或者类似“是什么”“怎么样”“基本规范”“规定是什么”等,无法确定查询目标。 2. 明显拼写或术语错误(NEED_CORRECT):输入包含错别字、无意义组合、逻辑混乱或非标准术语,例如“天气怎么阳”“明天什么暗拍”“低功耗贾张氏设计标准”。 3. 输入清晰(OK):输入语义明确、包含实体、对象、拼写和术语正确,可直接用于检索和回答。 要求严格输出标准 JSON,只输出 JSON,不生成任何解释性文字。 输出 JSON 格式: {{ "decision": "OK" 或 "NEED_CLARIFY" 或 "NEED_CORRECT" 或 "SENSITIVE", "reason": "一句话说明原因(可为空字符串)", "correct": "如果是NEED_CORRECT给我出纠正的内容,否则为空", "suggested_question": "如果需要反问或纠正,给出建议文本,否则为空字符串,如有此回答,尽量生动形象,增强用户体验", "summary": "如果有上下文,给出上下文摘要", "rewrite": "结合上下文重写用户问题,去除指代和模糊表达,使其可独立理解并适合知识检索" }} 注意事项: - 如果输入模糊或有错别字,必须返回对应类型并提供简短建议(suggested_question)。 - 所有字段必须存在。 - 输出必须严格遵守 JSON 格式。 - 不允许有多余文字、换行或 Markdown。 - 上下文重要性: - 如果用户问题较模糊但可依赖之前对话(如“之前问过xxx,接着问为什么?怎么做?”),判定为OK。 - 根据最近对话内容,将当前问题改写为简短、独立、明确、适合检索的查询,输出到 rewrite 字段。 - 如果有上下文,summary 字段应简明概括最近相关对话内容。 - 示例: 输入:“是什么?” → NEED_CLARIFY 输入:“都会各有各的、是什么?” → NEED_CLARIFY 输入:“天气怎么阳” → NEED_CORRECT 输入:“明天的温度是多少?” → OK 输入上下文: Q: 明天的温度是多少? A: 最高温度 22°C,最低温度 14°C 输入:“为什么?” → OK,rewrite: “为什么明天温度变化幅度大?” """ # - 通过上下文进行判断,假如上一个问题已经回答过,用户接着问“为什么?怎么做?”之类的同样判定为OK,并根据最近对话内容,将当前用户问题改写为一个不含指代词、对象明确、可独立理解、适合用于知识检索的查询,输出到rewrite字段。 # 用户输入: # "{user_input}" chat_json["query"] = query import copy label_messages = copy.deepcopy(chat_json.get("messages")) # label_prompt = label_prompt.format(user_input=query) # label_messages = chat_json.get("messages").copy() # label_messages[-1]["content"] = label_prompt label_messages.insert(0, {"role": "system","content": f"{label_prompt}"}) logger.info(f"追问识别的提示词:{label_messages}") completion = await self.vllm_client.generate_non_stream_async( prompt=label_messages, model="/opt/vllm/models/Qwen/Qwen3-30B-A3B-Instruct-2507" ) raw_text = completion.strip() logger.info(f"问题判定的结果:{raw_text}") # 尝试解析 JSON try: data = json.loads(raw_text) except json.JSONDecodeError: logger.error("判定器返回的文本不是有效 JSON") data = {"decision": "OK", "reason": "", "suggested_question": ""} decision = data.get("decision", "OK") summary = data.get("summary", "") rewrite = data.get("rewrite", "") if decision == "SENSITIVE": yield {"id": "", "event": "add", "data": data.get("sensitive")} return if decision == "NEED_CORRECT": new_query = data.get("correct", query) chat_json["query"] = new_query decision = "OK" if decision != "OK": # 使用app_name进行向量检索,生成反问 try: collection_name_list = chat_json.get("knowledgeIds", []) rag_embedding_name = chat_json.get("embeddingId") if collection_name_list and rag_embedding_name: # 使用原始query进行检索,获取相关上下文 search_results, _ = self.request_milvus( collection_name_list=collection_name_list, rag_embedding_name=rag_embedding_name, query=query, top=3, # 只检索3条相关内容 mode="hybrid" ) # 提取检索到的内容 context_content = "\n".join([r.get("content", "") for r in search_results[:3]]) if context_content: # 构建智能反问提示词 clarify_prompt = f"""你是一个RAG应用助手。用户的输入存在问题:{data.get('reason', '输入不够清晰')} 用户输入:{query} 知识库中检索到以下相关内容: {context_content} 请基于检索到的内容,用友好、引导性的语气向用户提出1-2个具体的问题,帮助用户明确他们的需求。要求: 1. 问题要具体、有针对性 2. 语气要友好、鼓励性 3. 可以提供一些选项供用户选择 4. 直接输出问题,不要有多余的解释 示例格式: "您可能想了解XXX方面的内容。请问您是想了解: 1. XXX的具体定义和标准? 2. XXX的应用场景和案例? 还是有其他具体问题呢?" """ rq_tokens_str = "" # 使用流式chat接口进行反问 model = "/opt/vllm/models/Qwen/Qwen3-30B-A3B-Instruct-2507" full_response = "" logger.info(f"反问的提示词是:{clarify_prompt}") async for chunk in self.vllm_client.chat( prompt=clarify_prompt, model=model, temperature=0.7, top_p=0.9, max_tokens=512, stream=True, history=[] ): if chunk.get("event") == "add": # 累积完整响应 full_response += chunk.get("data", "") rq_tokens_str = full_response chunk["data"] = full_response # 直接yield当前chunk,保持流式输出 yield {"id": chunk.get("id", ""), "event": "add", "data": chunk.get("data")} elif chunk.get("event") == "finish": yield {"id": chunk.get("id", ""), "event": "finish", "data": f'{{"inputToken": {self.count_tokens(query)}, "outputToken": {self.count_tokens(rq_tokens_str)}, "timeConsuming": {time.time() - start_time}}}'} logger.info(f"反问生成完成,完整响应: {full_response}") return except Exception as e: logger.error(f"反问生成失败,使用默认回复: {e}") # 如果检索失败或没有配置知识库,使用原始的suggested_question reply = data.get("suggested_question", "抱歉,我没有理解您的问题。能否请您更详细地描述一下?") yield {"id": "", "event": "add", "data": reply} yield {"id": "", "event": "finish", "data": ""} return if decision == "OK": chat_id = "" SENTINEL = "⧗⧗LLM_STREAM_SENTINEL⧗⧗" try: output_token_str = "" tokens = self.count_tokens(chat_json.get("query") + chat_json.get("prompt")) logger.info(f"用户输入的token数:{tokens}") retriever_result_list, search_doc_id_to_knowledge_id_dict = await self.retriever_result(chat_json, rewrite) logger.info(f"向量库中获取的最终结果:{retriever_result_list}") retriever_result_list, search_doc_id_to_knowledge_id_dict, chunk_content_for_llm, chunk_revision_map, llm_result_list = self._apply_revision_and_doc_name( chat_json, retriever_result_list, search_doc_id_to_knowledge_id_dict, ) # 添加引用指令到提示词 citation_instruction = f"\n\n《额外指令》\n\n1、chunk_id仅供正文结束后使用,禁止在正文中以任何形式输出任何chunk_id\n\n2、在回答完成后,请按照以下格式列出你引用的所有chunk_id(必须输出,无引用cited_chunks可返回空字符串),格式为:{SENTINEL}\n[cited_chunks: id1, id2, id3]" chat_json["prompt"] = chat_json.get("prompt", "") + citation_instruction first = True sentinel_detected = False buffering = False buffer_content = "" clean_output = "" end_time = None async for event in self.generate_rag_response(chat_json, chunk_content_for_llm): chat_id = event.get("id") if first: end_time = time.time() first = False if await request.is_disconnected(): logger.info(f"chat id:{chat_id}连接中断") yield {"id": chat_id, "event": "interrupted", "data": ""} return if event.get("event") == "add": full_data = event.get("data", "") output_token_str = full_data # 如果已检测到完整哨兵线,收集后续内容用于解析chunk_ids if sentinel_detected: # re.IGNORECASE不区分大小写 cited_match = re.search(r'\[cited_chunks[::]\s*([^\]]+)\]', full_data, re.IGNORECASE) # if not cited_match: # sentinel_detected = False # continue if cited_match: cited_ids_str = cited_match.group(1) cited_chunk_ids = [cid.strip().upper() for cid in cited_ids_str.split(',')] logger.info(f"解析到LLM引用的chunk_ids: {cited_chunk_ids}") # 构建包含引用块的knowledge_info_dict,同时独立保存修订切片和原始切片 knowledge_info_dict = {} # 创建chunk_id映射,用于快速查找(使用小写作为key以避免大小写问题) llm_chunk_map = {r.get("chunk_id").lower() if r.get("chunk_id") else None: r for r in llm_result_list} original_chunk_map = {r.get("chunk_id").lower() if r.get("chunk_id") else None: r for r in retriever_result_list} # 记录日志,调试 logger.info(f"llm_result_list中的chunk_ids: {[r.get('chunk_id') for r in llm_result_list]}") logger.info(f"retriever_result_list中的chunk_ids: {[r.get('chunk_id') for r in retriever_result_list]}") for cited_id in cited_chunk_ids: cited_id_normalized = cited_id.strip().upper() # 在LLM结果列表中查找(可能是修订版本)- 使用小写匹配 llm_chunk = llm_chunk_map.get(cited_id_normalized.lower()) if not llm_chunk: logger.warning(f"未找到LLM引用的chunk_id: {cited_id_normalized}") sentinel_detected = False continue doc_id = llm_chunk.get("doc_id") knowledge_id = search_doc_id_to_knowledge_id_dict.get(doc_id) # 判断这个切片是否是修订版本 is_revised_version = llm_chunk.get("is_revised_version", False) # 构建当前切片的chunk_info current_chunk_info = { "chunk_id": llm_chunk.get("chunk_id"), "content": llm_chunk.get("content", ""), "rerank_score": llm_chunk.get("rerank_score"), "chunk_len": len(llm_chunk.get("content", "")), "document_name": llm_chunk.get("document_name"), "revision_status": llm_chunk.get("revision_status"), "ref_slice_id": llm_chunk.get("ref_slice_id"), "section": llm_chunk.get("section"), "parent_section": llm_chunk.get("parent_section"), } # 初始化chunk_info_list chunk_info_list = [] # 如果这是修订版本,需要同时保存修订切片和原始切片 if is_revised_version: original_chunk_id = llm_chunk.get("original_chunk_id") # 使用小写查找原始切片 original_chunk = original_chunk_map.get(original_chunk_id.lower() if original_chunk_id else None) logger.info(f"查找原始切片: original_chunk_id={original_chunk_id}, found={original_chunk is not None}") if original_chunk: # 1. 先添加修订切片信息 current_chunk_info["slice_type"] = "revised" chunk_info_list.append(current_chunk_info) # 2. 再添加原始切片信息 original_chunk_info = { "chunk_id": original_chunk.get("chunk_id"), "content": original_chunk.get("content", ""), "rerank_score": original_chunk.get("rerank_score"), "chunk_len": len(original_chunk.get("content", "")), "document_name": original_chunk.get("document_name"), "revision_status": original_chunk.get("revision_status"), "ref_slice_id": original_chunk.get("ref_slice_id"), # 保留数据库原始字段,指向修订切片 "section": original_chunk.get("section"), "parent_section": original_chunk.get("parent_section"), "slice_type": "original" # 标记为原始切片 } # 添加 revised_to_chunk_id 和 chunk_revision_map 的内容(废弃),因为 ref_slice_id 已经包含了指向修订切片的信息 chunk_info_list.append(original_chunk_info) logger.info(f"独立保存修订切片 {llm_chunk.get('chunk_id')} 和原始切片 {original_chunk_id} 信息") else: # 没有找到原始切片,只保存修订切片 current_chunk_info["slice_type"] = "revised" chunk_info_list.append(current_chunk_info) logger.warning(f"未找到修订切片 {llm_chunk.get('chunk_id')} 对应的原始切片 {original_chunk_id}") else: # 非修订切片,检查是否有废弃切片需要一起存储 # 在retriever_result_list中查找是否有revision_status="0"的切片 chunk_in_retriever = original_chunk_map.get(cited_id_normalized.lower()) if chunk_in_retriever and chunk_in_retriever.get("revision_status") == "0": # 这是一个废弃切片,标记为deprecated current_chunk_info["slice_type"] = "deprecated" logger.info(f"存储废弃切片 {cited_id_normalized} 到Redis") else: # 普通切片 current_chunk_info["slice_type"] = "normal" chunk_info_list.append(current_chunk_info) doc_info = { "doc_id": doc_id, "doc_name": llm_chunk.get("document_name"), "chunk_nums": len(chunk_info_list), "chunk_info_list": chunk_info_list } if knowledge_id in knowledge_info_dict: existing_doc = None for doc in knowledge_info_dict[knowledge_id]: if doc["doc_id"] == doc_id: existing_doc = doc break if existing_doc: existing_doc["chunk_info_list"].extend(chunk_info_list) existing_doc["chunk_nums"] += len(chunk_info_list) else: knowledge_info_dict[knowledge_id].append(doc_info) else: knowledge_info_dict[knowledge_id] = [doc_info] # 额外处理:将所有废弃切片也存储到Redis(即使LLM没有引用) for r in retriever_result_list: if r.get("revision_status") == "0": deprecated_chunk_id = r.get("chunk_id") doc_id = r.get("doc_id") knowledge_id = search_doc_id_to_knowledge_id_dict.get(doc_id) # 检查是否已经存储过这个废弃切片 already_stored = False if knowledge_id in knowledge_info_dict: for doc in knowledge_info_dict[knowledge_id]: if doc["doc_id"] == doc_id: for chunk_info in doc["chunk_info_list"]: if chunk_info["chunk_id"] == deprecated_chunk_id: already_stored = True break if already_stored: break if not already_stored: # 构建废弃切片信息 deprecated_chunk_info = { "chunk_id": deprecated_chunk_id, "content": r.get("content", ""), "rerank_score": r.get("rerank_score"), "chunk_len": len(r.get("content", "")), "document_name": r.get("document_name"), "revision_status": "0", # 明确标记为废弃 "ref_slice_id": r.get("ref_slice_id"), "section": r.get("section"), "parent_section": r.get("parent_section"), "slice_type": "deprecated" } doc_info = { "doc_id": doc_id, "doc_name": r.get("document_name"), "chunk_nums": 1, "chunk_info_list": [deprecated_chunk_info] } if knowledge_id in knowledge_info_dict: existing_doc = None for doc in knowledge_info_dict[knowledge_id]: if doc["doc_id"] == doc_id: existing_doc = doc break if existing_doc: existing_doc["chunk_info_list"].append(deprecated_chunk_info) existing_doc["chunk_nums"] += 1 else: knowledge_info_dict[knowledge_id].append(doc_info) else: knowledge_info_dict[knowledge_id] = [doc_info] logger.info(f"额外存储废弃切片 {deprecated_chunk_id} 到Redis") # 存储到Redis self.redis_client.set(chat_id, json.dumps(knowledge_info_dict)) logger.info(f"存储引用的块到Redis: {len(cited_chunk_ids)}个(包含修订、原始和废弃切片信息)") # 返回finish事件 output_token = self.count_tokens(clean_output) time_out = end_time - start_time if end_time else 0 finish_data = { "inputToken": tokens, "outputToken": output_token, "timeConsuming": time_out } # yield {"id": chat_id, "event": "finish", "data": '你猜!'} logger.info(f"111111111111111111111111111:{event}") event["event"] = "finish" event["data"] = finish_data # yield event send = {"id": chat_id, "event": "finish", "data": f'{{"inputToken": {tokens}, "outputToken": {output_token}, "timeConsuming": {time_out}}}'} yield send # yield {"id": chat_id, "event": "finish", "inputToken": tokens, "outputToken": output_token, "timeConsuming": time_out} logger.info(f"{send}") return # 直接返回,不再继续处理后续事件 continue # 检查是否包含完整哨兵线 if SENTINEL in full_data: sentinel_pos = full_data.find(SENTINEL) sentinel_detected = True clean_output = full_data[:sentinel_pos] logger.info(f"检测到完整哨兵线") continue # 检查是否包含哨兵线的第一个字符 '⧗' if not buffering and '⧗' in full_data: first_char_pos = full_data.find('⧗') # 发送 '⧗' 之前的内容 if first_char_pos > 0: event["data"] = full_data[:first_char_pos] yield event buffering = True buffer_content = full_data[first_char_pos:] clean_output = full_data[:first_char_pos] logger.info(f"检测到哨兵线第一个字符,开始缓冲") continue # 如果正在缓冲 if buffering: buffer_content = full_data[len(clean_output):] if SENTINEL.startswith(buffer_content): # 可能是哨兵线的一部分,继续缓冲 continue elif event.get("event") != "finish": # 不是哨兵线,发送缓冲的内容 buffering = False yield event elif event.get("event") != "finish": # 正常返回 yield event elif event.get("event") == "finish": # 如果已经通过哨兵线发送过finish事件,跳过这个finish事件 if sentinel_detected: logger.info("哨兵线已处理,跳过LLM的finish事件") continue # 如果LLM提前结束但没有哨兵线,使用所有块 _chunk_content, knowledge_info_dict = self.parse_retriever_list(retriever_result_list, search_doc_id_to_knowledge_id_dict) self.redis_client.set(chat_id, json.dumps(knowledge_info_dict)) logger.warning("LLM未输出哨兵线,存储所有检索块") # 发送finish事件(LLM未输出哨兵线的情况) output_token = self.count_tokens(output_token_str) time_out = end_time - start_time if end_time else 0 finish_data = { "inputToken": tokens, "outputToken": output_token, "timeConsuming": time_out } event["data"] = finish_data send = {"id": chat_id, "event": "finish", "data": f'{{"inputToken": {tokens}, "outputToken": {output_token}, "timeConsuming": {time_out}}}'} yield send logger.info(f"LLM未输出哨兵线,发送finish事件: {send}") elif event.get("event") != "finish": yield event except Exception as e: logger.error(f"执行出错:{e}") yield {"id": chat_id, "event": "finish", "data": ""} return async def query_redis_chunk_by_chat_id(self, chat_id): chunk_redis_str = self.redis_client.get(chat_id) chunk_json = json.loads(chunk_redis_str) logger.info(f"redis中存储的chunk 信息:{chunk_json}") #knowledge_id = chunk_json.get("knowledge_id") knowledge_id = list(chunk_json.keys())[0] if chunk_json else None chunk_list = [] #for chunk in chunk_json["doc"]: # doc_chunk_list = chunk.get("chunk_info_list") # chunk_list.extend(doc_chunk_list) for knowledge_id_key, doc_list in chunk_json.items(): for doc in doc_list: doc_chunk_list = doc.get("chunk_info_list") if doc_chunk_list: chunk_list.extend(doc_chunk_list) if len(chunk_list) > 3: random_chunk_list = random.sample(chunk_list, 3) else: random_chunk_list = chunk_list chunk_id_list = [chunk["chunk_id"] for chunk in random_chunk_list] return knowledge_id, chunk_id_list async def generate_relevant_query(self, query_json): chat_id = query_json.get("chat_id") if chat_id: knowledge_id, chunk_id_list = await self.query_redis_chunk_by_chat_id(chat_id) chunk_content_list = MilvusOperate(collection_name=knowledge_id, embedding_name=query_json.get("embedding_id"))._search_by_chunk_id_list(chunk_id_list) chunk_content = "\n".join(chunk_content_list) # generate_query_prompt = "请根据下面提供的信息,帮我生成" query_messages = query_json.get("messages") user_dict = query_messages.pop() messages = [{"role": "assistant", "content": chunk_content}, user_dict] else: messages = query_json.get("messages") #model = query_json.get("model") model = "/opt/vllm/models/Qwen/Qwen3-30B-A3B-Instruct-2507" query_result = self.vllm_client.chat(model=model, stream=False, history=messages) async for result in query_result: # result_json = json.loads(result) # logger.info(f"生成的问题:{result_json}") # result_str = result_json.get("choices", [{}])[0].get("message", {}).get("content", "").strip() result = result.get("data", "").strip() logger.info(f"模型生成的问题:{result}") try: if "```json" in result: json_pattern = r'```json\s(.*?)```' matches = re.findall(json_pattern, result, re.DOTALL) result = matches[0] query_json = json.loads(result) except Exception as e: query_json = eval(result) query_list = query_json.get("问题") return {"code": 200, "data": query_list} async def search_slice(self): try: chunk_redis_str = self.redis_client.get(self.chat_id) chunk_json = json.loads(chunk_redis_str) chunk_json["code"] = 200 except Exception as e: logger.error(f"查询redis报错:{e}") chunk_json = { "code": 500, "message": str(e) } return chunk_json