| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691 |
- from rag.db import MilvusOperate, MysqlOperate
- from rag.load_model import *
- from rag.llm import VllmApi
- from config import redis_config_dict
- import json
- import re
- import gc
- import redis
- from utils.get_logger import setup_logger
- import random
- import concurrent.futures
- from concurrent.futures import ThreadPoolExecutor, as_completed
- import time
- import httpx
- # 停用词
- from rag.sensitive_words_detection import TxtSensitiveWordStore, SecurityTokenizer, build_warning
- # LightRAG
- from lightRAG import AsyncLightRAGManager
# Module-wide sensitive-word store backed by txt files on disk.
store = TxtSensitiveWordStore("./rag/sensitive_words/")
tokenizer_sensitive = SecurityTokenizer(store)
# from rag.load_model import tokenizer
logger = setup_logger(__name__)
# rerank_model_mapping = {
#     "bge-reranker-v2-m3": (bce_rerank_tokenizer, bce_rerank_base_model),
#     "rerank": (bce_rerank_tokenizer, bce_rerank_base_model),
#     "Qwen3-Reranker-0.6B": (qwen_rerank_tokenizer, qwen_rerank_base_model)
# }
# Maps rerank model name -> (service base url, model identifier) of the HTTP
# rerank endpoints; the url/model values come from the star import of
# rag.load_model above.
rerank_model_mapping_vllm = {
    "bge-reranker-v2-m3": (rerank_bge_url, rerank_bge_model),
    "Qwen3-Reranker-0.6B": (rerank_qwen_url, rerank_qwen_model)
}
# Redis connection settings pulled from config.redis_config_dict.
redis_host = redis_config_dict.get("host")
redis_port = redis_config_dict.get("port")
redis_db = redis_config_dict.get("db")
- class ChatRetrieverRag:
- def __init__(self, chat_json: dict = None, chat_id: str=None):
- self.chat_id = chat_id
- self.redis_client = redis.StrictRedis(host=redis_host, port=redis_port, db=redis_db)
- if not chat_id:
- self.vllm_client = VllmApi(chat_json)
-
- def count_tokens(self,text):
- tokens = tokenizer.tokenize(text)
- return len(tokens)
- # 手动rerank
- # def rerank_result(self,query, top_k, hybrid_search_result, rerank_embedding_name):
- # rerank_list = []
- # for result in hybrid_search_result:
- # rerank_list.append([query, result["content"]])
- # tokenizer, rerank_model = rerank_model_mapping.get(rerank_embedding_name)
- # # 重排序
- # rerank_model.eval()
- # inputs = tokenizer(rerank_list, padding=True, truncation=True, max_length=512, return_tensors="pt")
- # inputs = {k: v.to(device) for k, v in inputs.items()}
- # with torch.no_grad():
- # logits = rerank_model(**inputs, return_dict=True).logits.view(-1,).float()
- # logits = logits.detach().cpu()
- # scores = torch.sigmoid(logits).tolist()
- # # scores_list = scores.tolist()
- # logger.info(f"重排序的得分:{scores}")
- # # sorted_pairs = sorted(zip(scores, hybrid_search_result), key=lambda x: x[0], reverse=True)
- # threshold = 0.3
- # sorted_pairs = sorted( ((s, r) for s, r in zip(scores, hybrid_search_result) if s >= threshold), key=lambda x: x[0], reverse=True )
- # if not sorted_pairs:
- # sorted_scores = []
- # sorted_search = []
- # else:
- # sorted_scores, sorted_search = zip(*sorted_pairs)
- # logger.info(f"过滤后的{sorted_scores}")
- # # sorted_scores, sorted_search = zip(*sorted_pairs)
- # sorted_scores_list = list(sorted_scores)
- # sorted_search_list = list(sorted_search)
- # for score, search in zip(sorted_scores_list, sorted_search_list):
- # search["rerank_score"] = score
- # del inputs, logits
- # gc.collect()
- # torch.cuda.empty_cache()
- # search_result = random.sample(sorted_search_list, min(top_k, len(sorted_search_list)))
- # return search_result
- # rerank(url)
- async def rerank_result_async(self, query, top_k, hybrid_search_result, rerank_embedding_name):
- """
- Args:
- query (str): 查询文本
- top_k (int): 最终返回结果数量
- hybrid_search_result (list[dict]): 原始搜索结果列表
- rerank_embedding_name (str): rerank 模型
-
- Returns:
- list[dict]: 添加了 "rerank_score" 字段的结果
- """
- # 查模型信息
- rerank_url, rerank_model = rerank_model_mapping_vllm.get(rerank_embedding_name)
-
- if not hybrid_search_result:
- return []
- # 构建候选文档列表
- documents = [r["content"] for r in hybrid_search_result]
- payload = {
- "model": rerank_model,
- "query": query,
- "documents": documents
- }
- service_url = f"{rerank_url}/v1/rerank"
- # 使用异步 HTTP 请求
- try:
- async with httpx.AsyncClient(timeout=60.0) as client:
- resp = await client.post(service_url, json=payload)
- resp.raise_for_status()
- result = resp.json()
- except Exception as e:
- logger.error(f"调用 rerank 服务失败: {e}")
- # 出错 fallback
- scores = [0.0] * len(documents)
- else:
- # 解析分数
- if "results" in result:
- scores = [item.get("relevance_score", 0.0) for item in result["results"]]
- elif "data" in result:
- scores = [item.get("score", 0.0) for item in result["data"]]
- else:
- logger.warning(f"未知 rerank 返回格式: {result}")
- scores = [0.0] * len(documents)
- logger.info(f"重排序的得分:{scores}")
- # 构建分数与搜索结果列表
- sorted_pairs = list(zip(scores, hybrid_search_result))
- # # 阈值过滤
- # threshold = 0.2
- # filtered_pairs = [(s, r) for s, r in sorted_pairs if s >= threshold]
- # if not filtered_pairs:
- # filtered_pairs = sorted_pairs
- # logger.info(f"过滤后的{filtered_pairs}")
- # # 按分数降序排序
- # filtered_pairs.sort(key=lambda x: x[0], reverse=True)
- threshold = 0.3
- # sorted_pairs = sorted(
- # ((s, r) for s, r in sorted_pairs if s >= threshold) or sorted_pairs,
- # key=lambda x: x[0],
- # reverse=True
- # )
- filtered_pairs = [(s, r) for s, r in sorted_pairs if s >= threshold]
- if not filtered_pairs:
- filtered_pairs = sorted_pairs
- sorted_pairs = sorted(
- filtered_pairs,
- key=lambda x: x[0],
- reverse=True
- )
- if not sorted_pairs:
- sorted_scores_list = []
- sorted_search_list = []
- else:
- sorted_scores_list, sorted_search_list = zip(*sorted_pairs)
- logger.info(f"过滤后的{sorted_scores_list}")
- # 写入 rerank_score
- # sorted_scores_list, sorted_search_list = zip(*filtered_pairs)
- sorted_scores_list, sorted_search_list = list(sorted_scores_list), list(sorted_search_list)
- for score, search in zip(sorted_scores_list, sorted_search_list):
- search["rerank_score"] = score
- # 清理 GPU 内存
- gc.collect()
- torch.cuda.empty_cache()
- # 随机采样 top_k
- search_result = random.sample(sorted_search_list, min(top_k, len(sorted_search_list)))
- # logger.info(f"query='{query}' rerank结果: {[r['rerank_score'] for r in search_result]}")
- return search_result
-
- async def segment_query_with_llm(self, query, max_retries=3):
- """
- 使用LLM对长问题进行切分,提取多个子问题
- """
- segment_prompt = f"""
- 你是一个专业的问题分析助手。请将下面的复杂问题分解为3-5个具体的子问题,每个子问题都应该是独立的、可检索的。
- 原问题:{query}
- 请按照以下JSON格式返回:
- {{
- "sub_questions": [
- "子问题1",
- "子问题2",
- "子问题3"
- ]
- }}
- 要求:
- 1. 每个子问题都要具体明确
- 2. 子问题之间相互独立
- 3. 覆盖原问题的主要方面
- 4. 适合向量检索
- """
-
- for attempt in range(max_retries):
- try:
- logger.info(f"第{attempt + 1}次尝试问题切分")
-
- # 构造请求参数
- messages = [{"role": "user", "content": segment_prompt}]
- model = "/opt/vllm/models/Qwen/Qwen3-30B-A3B-Instruct-2507"
-
- # 调用LLM
- response_text = ""
- async for chunk in self.vllm_client.chat(
- prompt=segment_prompt,
- model=model,
- temperature=0.1,
- top_p=0.9,
- max_tokens=1024,
- stream=True,
- history=[]
- ):
- # print(chunk)
- if chunk.get("event") == "add":
- response_text += chunk.get("data", "")
-
- logger.info(f"LLM返回的切分结果:{response_text}")
-
- # 解析JSON响应
- import json
- if "```json" in response_text:
- json_pattern = r'```json\s*(.*?)```'
- matches = re.findall(json_pattern, response_text, re.DOTALL)
- if matches:
- response_text = matches[0]
-
- result = json.loads(response_text.strip())
- sub_questions = result.get("sub_questions", [])
-
- if sub_questions and len(sub_questions) > 0:
- logger.info(f"成功切分为{len(sub_questions)}个子问题:{sub_questions}")
- return sub_questions
- else:
- logger.warning("LLM返回的子问题列表为空")
-
- except Exception as e:
- logger.error(f"第{attempt + 1}次问题切分失败:{str(e)}")
- if attempt == max_retries - 1:
- logger.error("问题切分达到最大重试次数,使用原问题")
- return [query]
-
- return [query]
- def request_milvus_multi(self, collection_name_list, querys, rag_embedding_name, top, mode):
- """
- 支持多query的向量检索方法
- """
- results = []
- doc_id_to_collection = {}
- def query_collection(collection_name, query):
- return collection_name, query, MilvusOperate(
- collection_name=collection_name,
- embedding_name=rag_embedding_name
- )._search(query, k=top, mode=mode)
- tasks = [(c, q) for c in collection_name_list for q in querys]
- max_workers = min(20, len(tasks))
- with ThreadPoolExecutor(max_workers=max_workers) as executor:
- future_to_task = {
- executor.submit(query_collection, c, q): (c, q) for c, q in tasks
- }
- for future in as_completed(future_to_task):
- collection_name, query = future_to_task[future]
- try:
- _, _, search_result = future.result()
- for doc in search_result:
- doc_id = doc["doc_id"]
- if doc_id not in doc_id_to_collection:
- doc_id_to_collection[doc_id] = collection_name
- results.extend(search_result)
- except Exception as e:
- logger.error(f"查询集合 {collection_name} 对 query {query} 出错: {e}")
- return results, doc_id_to_collection
- def request_milvus(self, collection_name_list, rag_embedding_name, query, top, mode):
- search_multi_collection_list = []
- search_doc_id_to_knowledge_id_dict = {}
- # 多线程查询
- def query_collection(collection_name):
- return collection_name, MilvusOperate(
- collection_name=collection_name,
- embedding_name=rag_embedding_name
- )._search(query, k=top, mode=mode)
-
- max_workers = min(10, len(collection_name_list))
- with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
- future_to_name = {
- executor.submit(query_collection, name): name
- for name in collection_name_list
- }
-
- for future in concurrent.futures.as_completed(future_to_name):
- collection_name = future_to_name[future]
- try:
- _, hybrid_search_result = future.result()
-
- for result in hybrid_search_result:
- doc_id = result.get("doc_id")
-
- if doc_id not in search_doc_id_to_knowledge_id_dict:
- search_doc_id_to_knowledge_id_dict[doc_id] = collection_name
-
- search_multi_collection_list.extend(hybrid_search_result)
-
- except Exception as e:
- logger.error(f"查询集合 {collection_name} 时出错: {str(e)}")
- continue
- return search_multi_collection_list, search_doc_id_to_knowledge_id_dict
    async def retriever_result(self, chat_json, rewrite = None):
        """
        End-to-end retrieval for a chat request: Milvus vector search (with
        optional LLM question segmentation), optional remote rerank, optional
        LightRAG knowledge-graph expansion, and optional parent/child chapter
        expansion -- the last two gated by flags read from MySQL.

        Args:
            chat_json (dict): request payload; reads knowledgeIds, sliceCount,
                knowledgeInfo (JSON string), embeddingId, query,
                Problem_segmentation.
            rewrite (str, optional): rewritten query; when given it replaces
                chat_json["query"].

        Returns:
            tuple: (result chunk list, doc_id -> collection/knowledge-id map).
        """
        # UI recall-method name -> Milvus search mode.
        mode_search_mode = {
            "embedding": "dense",
            "keyword": "sparse",
            "mixed": "hybrid",
        }
        collection_name_list = chat_json.get("knowledgeIds")
        top = chat_json.get("sliceCount", 15)
        retriever_info = json.loads(chat_json.get("knowledgeInfo", "{}"))
        rag_embedding_name = chat_json.get("embeddingId")
        mode_rag = retriever_info.get("recall_method")
        mode = mode_search_mode.get(mode_rag, "hybrid")
        rerank_embedding_name = retriever_info.get("rerank_model_name")
        rerank_status = retriever_info.get("rerank_status")
        if rewrite:
            chat_json["query"] = rewrite
        query = chat_json.get("query")
        query_list = [query]

        # search_multi_collection_list = []
        # search_doc_id_to_knowledge_id_dict = {}
        # for collection_name in collection_name_list:
        #     hybrid_search_result = MilvusOperate(
        #         collection_name=collection_name,embedding_name=rag_embedding_name)._search(query, k=top, mode=mode)
        #     for result in hybrid_search_result:
        #         doc_id = result.get("doc_id")
        #         if doc_id not in search_doc_id_to_knowledge_id_dict:
        #             search_doc_id_to_knowledge_id_dict[doc_id] = collection_name
        #     search_multi_collection_list.extend(hybrid_search_result)

        """
        search_multi_collection_list信息格式:
        {
            "doc_id": "07917ce4-ce63-11f0-ad4f-ecd68ae97d92",
            "chunk_id": "7a4aba3e-ce63-11f0-8a8d-ecd68ae97d92",
            "content": "## 3.13 高处作业 24\n",
            "Father_Chapter": "3->3.6",
            "Chapter": "3->3.6->3.6.2",
            "metadata": {
                "source": "超长ceshi.pdf",
                "chunk_index": 14,
                "chunk_len": 16
            },
            "score": 0.016393441706895828,
            "rerank_score": 1
        }
        """
        start_time = time.time()
        two_time = -1  # sentinel: parent/child chapter expansion did not run
        # if rewrite:
        #     query = rewrite
        #     query_list = [rewrite]
        # 1. Check whether LLM question segmentation is requested.
        problem_segmentation = chat_json.get("Problem_segmentation", False)

        if problem_segmentation:
            logger.info("启用问题切分模式")
            # Split the question into sub-questions with the LLM.
            query_list = await self.segment_query_with_llm(query)
            logger.info(f"切分后的问题列表:{query_list}")
            # [query + "\n" + summary if summary else query for query in query_list]

        # Multi-query retrieval across all collections.
        # if not query_list:
        #     query_list = [query + "\n" + summary] if summary else [query]

        logger.info(f"query_list:{query_list}")
        search_multi_collection_list, search_doc_id_to_knowledge_id_dict = self.request_milvus_multi(
            collection_name_list, query_list, rag_embedding_name, top, mode
        )
        # # Single-query retrieval (previous behaviour, kept for reference):
        # query_list = query
        # search_multi_collection_list, search_doc_id_to_knowledge_id_dict = self.request_milvus(
        #     collection_name_list, rag_embedding_name, query, top, mode
        # )
        logger.info(f"根据{collection_name_list}检索到结果:{search_multi_collection_list}")
        if len(search_multi_collection_list) <= 0:
            return [], search_doc_id_to_knowledge_id_dict
        # 2. Rerank via the remote service, or assign a constant score.
        if rerank_status:
            # rerank_result_list = self.rerank_result(query, top, search_multi_collection_list, rerank_embedding_name)
            rerank_result_list = await self.rerank_result_async(query, top, search_multi_collection_list, rerank_embedding_name)
        else:
            for result in search_multi_collection_list:
                result["rerank_score"] = 1
            rerank_result_list = random.sample(search_multi_collection_list, min(top, len(search_multi_collection_list)))
        # 3. Collect document ids; MySQL flags decide whether chapter
        #    expansion / knowledge-graph expansion run.
        doc_ids = list({r["doc_id"] for r in rerank_result_list})
        # Collect knowledge-base ids for knowledge-graph retrieval.
        knowledge_ids = list({r["metadata"].get("knowledge_id", "") for r in rerank_result_list})
        logger.info(f"检索到的知识库id{knowledge_ids}")
        mysql_client = MysqlOperate()
        enabled_doc_ids = mysql_client.query_parent_generation_enabled(doc_ids)
        enabled_kn_rs = mysql_client.query_knowledge_by_ids(knowledge_ids)
        enabled_kn_gp_ids = set()
        for r in enabled_kn_rs:
            if r["knowledge_graph"]:
                enabled_kn_gp_ids.add(r["knowledge_id"])
        if enabled_kn_gp_ids:
            gp_add_list = []
            enabled_kn_gp_ids = list(enabled_kn_gp_ids)
            gp_time = time.time()  # (unused; kept as-is)
            logger.info(f"需要查询知识图谱的知识库:{enabled_kn_gp_ids}")
            # if chat_json.get("lightrag"):
            for kn_id in enabled_kn_gp_ids:
                result = await AsyncLightRAGManager().retrieve(label=kn_id, query=query)
                chunks_add = result["data"].get("chunks", [])
                if chunks_add:
                    # chunk_dict = {}
                    for chunk_add in chunks_add:
                        # chunk_dict["content"] = chunk_add
                        # Graph chunks have no doc/chunk ids; score pinned to 1.
                        gp_add_list.append({
                            "doc_id": None,
                            "chunk_id": None,
                            "content": chunk_add["content"],
                            "metadata": {},
                            "rerank_score": 1,
                            "revision_status": None,
                            "section": None,
                            "parent_section": None,
                        })
            logger.info(f"知识图谱扩展的内容:{gp_add_list}")
            rerank_result_list.extend(gp_add_list)

        if enabled_doc_ids:
            tmp_time = time.time()
            logger.info(f"需要章节扩展的文档: {enabled_doc_ids}")
            existing_chunk_ids = {r["chunk_id"] for r in rerank_result_list}
            # Collect (collection_name, doc_id, Father_Chapter) combos to
            # expand: collection name + document id + parent chapter title.
            father_chapter_set = set()
            for result in rerank_result_list:
                doc_id = result.get("doc_id")
                if doc_id not in enabled_doc_ids:
                    continue
                father_chapter = result.get("Father_Chapter")
                collection_name = search_doc_id_to_knowledge_id_dict.get(doc_id)
                if father_chapter and collection_name:
                    father_chapter_set.add((collection_name, doc_id, father_chapter))

            # 4. Secondary retrieval: pull every sibling chunk of the parent
            #    chapter and append the ones not already present.
            for collection_name, doc_id, father_chapter in father_chapter_set:
                chapter_results = MilvusOperate(
                    collection_name=collection_name, embedding_name=rag_embedding_name
                )._query_by_scalar_field(doc_id, "Father_Chapter", father_chapter)
                for r in chapter_results:
                    if r["chunk_id"] not in existing_chunk_ids:
                        r["rerank_score"] = 1
                        rerank_result_list.append(r)
                        existing_chunk_ids.add(r["chunk_id"])
            two_time = time.time() - tmp_time
            logger.info(f"父子章节扩展后的结果数: {len(rerank_result_list)}")
        up_time = time.time() - start_time
        two_status = "未启动父子召回" if two_time < 0 else f"启动父子召回耗时:{two_time}"
        logger.info(f"检索耗时:{up_time}({two_status})")
        return rerank_result_list, search_doc_id_to_knowledge_id_dict
- async def generate_rag_response(self, chat_json, chunk_content):
- # logger.info(f"rag聊天的请求参数:{chat_json}")
- # retriever_result_list = self.retriever_result(chat_json)
- # logger.info(f"向量库中获取的最终结果:{retriever_result_list}")
- prompt = chat_json.get("prompt")
- query = chat_json.get("query")
- stream = chat_json.get("stream", True)
- # chunk_content = ""
- # for retriever in retriever_result_list:
- # chunk_content += retriever["content"]
- prompt = prompt.replace("{知识}", chunk_content).replace("{用户}", query)
- logger.info(f"请求的提示词:{prompt}")
- temperature = float(chat_json.get("temperature", 0.6))
- top_p = float(chat_json.get("topP", 0.7))
- max_token = chat_json.get("maxToken", 4096)
- enable_think = chat_json.get("enable_think", False)
-
- history = chat_json.get("messages", [])
- if history:
- history[-1]["content"] = prompt
- # 调用模型获取返回的结果
- # model = chat_json.get("model", "DeepSeek-R1-Distill-Qwen-14B")
- model = "/opt/vllm/models/Qwen/Qwen3-30B-A3B-Instruct-2507"
- chat_resp = ""
- async for chunk in self.vllm_client.chat(prompt, model, temperature=temperature, top_p=top_p, max_tokens=max_token, stream=stream, history=history, enable_think=enable_think):
- logger.info(f"chat_message中接收到的LLM信息{chunk}")
- if stream:
- if chunk.get("event") == "add":
- chat_resp += chunk.get("data")
- chunk["data"] = chat_resp
- yield chunk
- else:
- yield chunk
- else:
- chunk
-
- def parse_retriever_list(self, retriever_result_list, search_doc_id_to_knowledge_id_dict):
- doc_info = {}
- chunk_content = ""
- # 组织成每个doc对应的json格式
- for retriever in retriever_result_list:
- chunk_text = retriever["content"]
- chunk_content += chunk_text
- doc_id = retriever["doc_id"]
- chunk_id = retriever["chunk_id"]
- rerank_score = retriever["rerank_score"]
- doc_name = retriever.get("metadata").get("source","")
- if doc_id in doc_info:
- c_d = {
- "chunk_id": chunk_id,
- "rerank_score": rerank_score,
- "chunk_len": len(chunk_text),
- "document_name": retriever.get("document_name"),
- "revision_status": retriever.get("revision_status"),
- "ref_slice_id": retriever.get("ref_slice_id"),
- "revision_group_id": retriever.get("revision_group_id"),
- "section": retriever.get("section"),
- "parent_section": retriever.get("parent_section"),
- }
- doc_info[doc_id]["chunk_list"].append(c_d)
- # doc_info[doc_id]["rerank_score"].append(rerank_score)
- else:
- c_d = {
- "chunk_id": chunk_id,
- "rerank_score": rerank_score,
- "chunk_len": len(chunk_text),
- "document_name": retriever.get("document_name"),
- "revision_status": retriever.get("revision_status"),
- "ref_slice_id": retriever.get("ref_slice_id"),
- "revision_group_id": retriever.get("revision_group_id"),
- "section": retriever.get("section"),
- "parent_section": retriever.get("parent_section"),
- }
- doc_info[doc_id] = {
- "doc_name": doc_name,
- "chunk_list": [c_d],
- }
- # doc_list = []
- knowledge_info_dict = {}
- """
- {
- "knowledge_id": [{}, {}],
- "knowledge_id"
- }
- """
- for k, v in doc_info.items():
- d = {}
- knowledge_id = search_doc_id_to_knowledge_id_dict.get(k)
- d["doc_id"] = k
- d["doc_name"] = v.get("doc_name")
- d["chunk_nums"] = len(v.get("chunk_list"))
- d["chunk_info_list"] = v.get("chunk_list")
- # d["chunk_len"] = len(v.get("chunk_id_list"))
- if knowledge_id in knowledge_info_dict:
- knowledge_info_dict[knowledge_id].append(d)
- else:
- knowledge_info_dict[knowledge_id] = [d]
- # doc_list.append(d)
- return chunk_content, knowledge_info_dict
- def _inject_revision_rules_to_prompt(self, prompt: str, rules_text: str):
- if not rules_text:
- return prompt
- if not prompt:
- return rules_text
- # insert_anchor = "用户输入问题"
- # if insert_anchor in prompt:
- # idx = prompt.find(insert_anchor)
- # return prompt[:idx] + rules_text + "\n\n" + prompt[idx:]
- # if "{用户}" in prompt:
- # return prompt.replace("{用户}", rules_text + "\n\n{用户}", 1)
- return prompt + "\n\n" + rules_text
- def _normalize_revision_status(self, value):
- return value
- # if value is None:
- # return None
- # if isinstance(value, bool):
- # return int(value)
- # if isinstance(value, int):
- # return value
- # try:
- # s = str(value).strip()
- # if s == "":
- # return None
- # return int(s)
- # except Exception:
- # return None
- def _normalize_slice_id(self, value):
- # if value is None:
- # return None
- # s = str(value).strip()
- # if s == "":
- # return None
- # if s.isdigit() and len(s) < 20:
- # return s.zfill(20)
- return value
- def _format_slice_block(self, doc_name: str, version_tag: str, section: str, parent_section: str, text: str, extra_title_suffix: str = ""):
- title = (doc_name or "未知文档")
- if version_tag:
- title += f"({version_tag})"
- if extra_title_suffix:
- title += extra_title_suffix
- lines = [f"> {title}"]
- if parent_section:
- lines.append(f"{parent_section}")
- if section and section != parent_section:
- lines.append(f"{section}" if parent_section else f"{section}")
- if text:
- lines.append(str(text).strip())
- return "\n".join(lines).strip() + "\n\n"
- # def _build_llm_chunk_content(self, retriever_result_list):
- # slice_map = {}
- # for r in retriever_result_list:
- # sid = self._normalize_slice_id(r.get("chunk_id"))
- # if sid:
- # # def _build_llm_chunk_content(self, retriever_result_list):
- # # slice_map = {}
- # # for r in retriever_result_list:
- # # sid = self._normalize_slice_id(r.get("chunk_id"))
- # # if sid:
- # # slice_map[sid] = r
- # original_to_revision_ids = {}
- # revision_ids = set()
- # for r in retriever_result_list:
- # # if self._normalize_revision_status(r.get("revision_status")) == 1 and r.get("ref_slice_id"):
- # if r.get("revision_status") == "1" and r.get("ref_slice_id"):
- # original_to_revision_ids[r.get("chunk_id")] = r.get("ref_slice_id")
- # revision_ids.add(r.get("ref_slice_id"))
- # if sid:
- # slice_map[sid] = r
- # visited = set()
- # blocks = []
- # for r in retriever_result_list:
- # sid = self._normalize_slice_id(r.get("chunk_id"))
- # if not sid or sid in visited:
- # continue
-
- # revision_status = r.get("revision_status")
- # ref_slice_id = self._normalize_slice_id(r.get("ref_slice_id"))
- # if revision_status == "1" and ref_slice_id:
- # original = r
- # revision = slice_map.get(ref_slice_id)
- # blocks.append(
- # self._format_slice_block(
- # original.get("document_name"),
- # "原始版",
- # original.get("section"),
- # original.get("parent_section"),
- # original.get("content"),
- # extra_title_suffix=" [已修订]",
- # )
- # )
- # visited.add(sid)
- # revision_sid = self._normalize_slice_id(revision.get("chunk_id")) if revision else None
- # if revision and revision_sid and revision_sid not in visited:
- # blocks.append(
- # self._format_slice_block(
- # revision.get("document_name"),
- # "修订版",
- # revision.get("section"),
- # revision.get("parent_section"),
- # revision.get("content"),
- # )
- # )
- # visited.add(revision_sid)
- # continue
- # if revision_status == "0":
- # deprecated_text = (r.get("content") or "")
- # deprecated_text = deprecated_text + "\n\n【已废弃】此切片已弃用,请优先参考修订版或最新文档。"
- # blocks.append(
- # self._format_slice_block(
- # r.get("document_name"),
- # "已废弃",
- # r.get("section"),
- # r.get("parent_section"),
- # deprecated_text,
- # )
- # )
- # visited.add(sid)
- # continue
- # if sid in revision_ids:
- # continue
- # blocks.append(
- # self._format_slice_block(
- # r.get("document_name"),
- # "",
- # r.get("section"),
- # r.get("parent_section"),
- # r.get("content"),
- # )
- # )
- # visited.add(sid)
- # return "".join(blocks).strip()
- def _build_llm_chunk_content(self, retriever_result_list):
- """
- 构建 LLM 用的切片内容,仅发送最初检索到的向量数据库内容块,每个块添加 chunk_id。
- """
- logger.info(f"_build_llm_chunk_content 接收到的切片数量: {len(retriever_result_list)}")
-
- blocks = []
- for r in retriever_result_list:
- chunk_id = r.get("chunk_id")
- content = r.get("content", "")
- doc_name = r.get("document_name") or r.get("metadata", {}).get("source", "未知文档")
-
- # 格式:[chunk_id: xxx] 文档名\n内容
- block = f"[chunk_id: {chunk_id}]\n> {doc_name}\n{content}\n\n"
- blocks.append(block)
-
- return "".join(blocks).strip()
- # def _apply_revision_and_doc_name(self, chat_json, retriever_result_list, search_doc_id_to_knowledge_id_dict):
- # mysql_client = MysqlOperate()
- # slice_ids = [self._normalize_slice_id(r.get("chunk_id")) for r in retriever_result_list if r.get("chunk_id")]
- # slice_ids = [s for s in slice_ids if s]
- # slice_rows = []
- # success, rows_or_error = mysql_client.query_slice_revision_info_by_slice_ids_any(slice_ids)
- # if success:
- # slice_rows = rows_or_error
- # slice_info_map = {self._normalize_slice_id(r.get("slice_id")): r for r in slice_rows if r.get("slice_id")}
- # logger.info(slice_rows)
- # has_revision_or_deprecated = False
- # revision_slice_ids = set()
- # original_to_revision_id = {}
- # for r in retriever_result_list:
- # sid = self._normalize_slice_id(r.get("chunk_id"))
- # info = slice_info_map.get(sid)
- # if info:
- # normalized_revision_status = self._normalize_revision_status(info.get("revision_status"))
- # r["revision_status"] = normalized_revision_status
- # r["ref_slice_id"] = self._normalize_slice_id(info.get("ref_slice_id"))
- # r["section"] = info.get("section")
- # r["parent_section"] = info.get("parent_section")
- # if normalized_revision_status in ("0", "1"):
- # has_revision_or_deprecated = True
- # logger.info("修订切片")
- # else:
- # r.setdefault("revision_status", None)
- # r.setdefault("ref_slice_id", None)
- # r.setdefault("section", None)
- # r.setdefault("parent_section", None)
- # ref_id = self._normalize_slice_id(r.get("ref_slice_id"))
- # sid = self._normalize_slice_id(r.get("chunk_id"))
- # if r.get("revision_status") == "1" and ref_id and sid:
- # original_to_revision_id[sid] = ref_id
- # revision_slice_ids.add(ref_id)
- # r["revision_group_id"] = sid
- # else:
- # r.setdefault("revision_group_id", None)
- # ref_slice_rows = []
- # if revision_slice_ids:
- # success, ref_rows_or_error = mysql_client.query_slice_revision_info_by_slice_ids_any(list(revision_slice_ids))
- # if success:
- # ref_slice_rows = ref_rows_or_error
- # revision_id_to_original_id = {}
- # for original_id, revision_id in original_to_revision_id.items():
- # if revision_id:
- # revision_id_to_original_id[revision_id] = original_id
- # for r in retriever_result_list:
- # original_id = revision_id_to_original_id.get(r.get("chunk_id"))
- # if original_id:
- # r["revision_group_id"] = original_id
- # r["revision_of_slice_id"] = original_id
- # else:
- # r.setdefault("revision_of_slice_id", None)
- # for row in (slice_rows + ref_slice_rows):
- # doc_id = row.get("document_id")
- # knowledge_id = row.get("knowledge_id")
- # if doc_id and knowledge_id and doc_id not in search_doc_id_to_knowledge_id_dict:
- # search_doc_id_to_knowledge_id_dict[doc_id] = knowledge_id
- # success, ref_rows_or_error = mysql_client.query_slice_revision_info_by_slice_ids_any(list(revision_slice_ids))
- # if success:
- # ref_slice_rows = ref_rows_or_error
- # revision_id_to_original_id = {}
- # for original_id, revision_id in original_to_revision_id.items():
- # if revision_id:
- # revision_id_to_original_id[revision_id] = original_id
- # for r in retriever_result_list:
- # original_id = revision_id_to_original_id.get(r.get("chunk_id"))
- # if original_id:
- # r["revision_group_id"] = original_id
- # r["revision_of_slice_id"] = original_id
- # else:
- # r.setdefault("revision_of_slice_id", None)
- # for row in (slice_rows + ref_slice_rows):
- # doc_id = row.get("document_id")
- # knowledge_id = row.get("knowledge_id")
- # if doc_id and knowledge_id and doc_id not in search_doc_id_to_knowledge_id_dict:
- # search_doc_id_to_knowledge_id_dict[doc_id] = knowledge_id
- # existing_slice_ids = {r.get("chunk_id") for r in retriever_result_list if r.get("chunk_id")}
- # existing_slice_ids = {self._normalize_slice_id(s) for s in existing_slice_ids}
- # existing_slice_ids = {s for s in existing_slice_ids if s}
- # ref_map = {r.get("slice_id"): r for r in ref_slice_rows if r.get("slice_id")}
- # for revision_id in revision_slice_ids:
- # if revision_id in existing_slice_ids:
- # continue
- # info = ref_map.get(revision_id)
- # if not info:
- # continue
- # doc_id = info.get("document_id")
- # text = info.get("slice_text") or ""
- # metadata = {"source": "", "chunk_index": None, "chunk_len": len(text)}
- # original_id = revision_id_to_original_id.get(revision_id)
- # retriever_result_list.append(
- # {
- # "doc_id": doc_id,
- # "chunk_id": revision_id,
- # "content": text,
- # "Father_Chapter": "",
- # "Chapter": "",
- # "metadata": metadata,
- # "score": 0,
- # "rerank_score": 1,
- # "revision_status": None,
- # "ref_slice_id": None,
- # "revision_group_id": original_id,
- # "revision_of_slice_id": original_id,
- # "section": info.get("section"),
- # "parent_section": info.get("parent_section"),
- # }
- # )
- # existing_slice_ids.add(revision_id)
- # doc_ids = list({r.get("doc_id") for r in retriever_result_list if r.get("doc_id")})
- # success, doc_name_map_or_error = mysql_client.query_document_names_by_document_ids(doc_ids)
- # doc_name_map = doc_name_map_or_error if success else {}
- # for r in retriever_result_list:
- # doc_id = r.get("doc_id")
- # doc_name = doc_name_map.get(doc_id)
- # r["document_name"] = doc_name
- # if r.get("metadata") is None:
- # r["metadata"] = {}
- # if doc_name:
- # r["metadata"]["source"] = doc_name
- # rules_text = ""
- # if has_revision_or_deprecated:
- # rules_text = (
- # "切片中可能含有废弃或修订切片。\n"
- # "- 若切片标题包含“已废弃”,表示该内容已不再适用;如需引用请明确说明其已废弃,并优先给出最新/修订信息(若有)。\n"
- # "- 若切片标题包含“原始版/修订版”,表示同一主题存在版本差异;回答时优先采用修订版,必要时对比说明差异。\n"
- # "- 输出时可参考如下结构组织:\n"
- # "> 文档A(原始版)[已修订]\n"
- # "> 文档B(修订版)\n"
- # "> 文档C(已废弃)\n"
- # )
- # chunk_content_for_llm = self._build_llm_chunk_content(retriever_result_list)
- # return retriever_result_list, search_doc_id_to_knowledge_id_dict, chunk_content_for_llm, rules_text
def _apply_revision_and_doc_name(self, chat_json, retriever_result_list, search_doc_id_to_knowledge_id_dict):
    """
    Resolve slice revision state (revised / deprecated), the original->revised
    relation, and attach document names to every retrieved chunk.

    Returns the updated retriever_result_list, search_doc_id_to_knowledge_id_dict,
    the concatenated content handed to the LLM, the revision mapping, and the
    list of chunks actually given to the LLM.

    Flow:
    1. Query the full content of the revision slices.
    2. Hand the revised slices to the LLM (deprecated ones are filtered out).
    3. Keep the original slice info on retriever_result_list for Redis storage.

    NOTE(review): chat_json is currently unread in this method.
    """
    mysql_client = MysqlOperate()
    # Collect every chunk_id, normalized; drop empty ids.
    slice_ids = [self._normalize_slice_id(r.get("chunk_id")) for r in retriever_result_list if r.get("chunk_id")]
    slice_ids = [s for s in slice_ids if s]
    # Look up revision metadata rows for the retrieved slices.
    slice_rows = []
    if slice_ids:
        success, rows_or_error = mysql_client.query_slice_revision_info_by_slice_ids_any(slice_ids)
        if success:
            slice_rows = rows_or_error
    # Build slice_id -> row mapping.
    slice_info_map = {self._normalize_slice_id(r.get("slice_id")): r for r in slice_rows if r.get("slice_id")}
    # Build the original -> revision mapping plus the revision map returned to the caller.
    original_to_revision_id = {}
    revision_slice_ids = set()  # every revision slice id that must be fetched
    chunk_revision_map = {}  # revision relations keyed by the original slice id
    for r in retriever_result_list:
        # Current slice id.
        sid = self._normalize_slice_id(r.get("chunk_id"))
        # DB row for this slice, if any.
        info = slice_info_map.get(sid)
        if info:
            # Revision status of the slice ("1" = has a revision, "0" = deprecated).
            revision_status = self._normalize_revision_status(info.get("revision_status"))
            # Record status, revision slice id and section titles on the chunk.
            r["revision_status"] = revision_status
            r["ref_slice_id"] = self._normalize_slice_id(info.get("ref_slice_id"))
            r["section"] = info.get("section")
            r["parent_section"] = info.get("parent_section")
            # Original -> revision mapping.
            ref_id = r.get("ref_slice_id")
            # This slice has been revised elsewhere.
            if revision_status == "1" and ref_id and sid:
                # Remember the original-slice -> revision-slice id mapping.
                original_to_revision_id[sid] = ref_id
                # Remember the revision slice id for the bulk fetch below.
                revision_slice_ids.add(ref_id)
                # Expose the relation to the caller.
                chunk_revision_map[sid] = {"revised_to": ref_id}
            elif revision_status == "0":
                # Deprecated slice.
                chunk_revision_map[sid] = {"deprecated": True}
        else:
            r.setdefault("revision_status", None)
            r.setdefault("ref_slice_id", None)
            r.setdefault("section", None)
            r.setdefault("parent_section", None)
    # Fetch the full rows of the revision slices.
    revision_slice_info_map = {}
    # Only if some slice actually has a revision.
    if revision_slice_ids:
        # Query the revision slice rows.
        success, revision_rows_or_error = mysql_client.query_slice_revision_info_by_slice_ids_any(list(revision_slice_ids))
        if success:
            # Build revision id -> row mapping.
            revision_slice_info_map = {
                self._normalize_slice_id(r.get("slice_id")): r
                for r in revision_rows_or_error if r.get("slice_id")
            }
            logger.info(f"查询到 {len(revision_slice_info_map)} 个修订切片的完整信息")
    # Attach document names to every chunk.
    doc_ids = list({r.get("doc_id") for r in retriever_result_list if r.get("doc_id")})
    success, doc_name_map_or_error = mysql_client.query_document_names_by_document_ids(doc_ids)
    doc_name_map = doc_name_map_or_error if success else {}
    for r in retriever_result_list:
        doc_id = r.get("doc_id")
        doc_name = doc_name_map.get(doc_id)
        r["document_name"] = doc_name
        if r.get("metadata") is None:
            r["metadata"] = {}
        if doc_name:
            r["metadata"]["source"] = doc_name
    # Build the result list for the LLM (revised content in, deprecated slices out).
    llm_result_list = []
    deprecated_chunks = []  # deprecated slices, kept for later Redis storage

    for r in retriever_result_list:
        sid = self._normalize_slice_id(r.get("chunk_id"))
        revision_status = r.get("revision_status")

        # Deprecated slices are not sent to the LLM.
        if revision_status == "0":
            deprecated_chunks.append(r)
            logger.info(f"过滤废弃切片 {sid},不传给LLM")
            continue

        # revision_status == "1" but no ref_slice_id: use the slice's own
        # revision_slice_text (no mapping relation available).
        if revision_status == "1" and not r.get("ref_slice_id"):
            info = slice_info_map.get(sid)
            revision_text = info.get("revision_slice_text", "") if info else ""
            if not revision_text:
                revision_text = r.get("content", "")
            r["content"] = revision_text
            llm_result_list.append(r)
            logger.info(f"使用切片 {sid} 的 revision_slice_text(无映射关系)")
            continue

        # Original slice that has a revision: swap in the revised content.
        if sid in original_to_revision_id:
            revision_id = original_to_revision_id[sid]
            revision_info = revision_slice_info_map.get(revision_id)

            # Revised text from the current slice's row; fall back to the
            # original content if the field is empty.
            info = slice_info_map.get(sid)
            revision_text = info.get("revision_slice_text", "") if info else ""
            if not revision_text:
                revision_text = r.get("content", "")

            if revision_info:
                # Fresh chunk for the LLM: revised text + the revision slice's
                # metadata, tagged so the cited-chunk handler can map it back.
                revised_chunk = {
                    "chunk_id": revision_id,
                    "content": revision_text,
                    "doc_id": revision_info.get("document_id", r.get("doc_id")),
                    "document_name": doc_name_map.get(revision_info.get("document_id"), r.get("document_name")),
                    "section": revision_info.get("section"),
                    "parent_section": revision_info.get("parent_section"),
                    "revision_status": None,
                    "ref_slice_id": None,
                    "metadata": r.get("metadata", {}),
                    "rerank_score": r.get("rerank_score", 1),
                    "is_revised_version": True,
                    "original_chunk_id": sid
                }
                llm_result_list.append(revised_chunk)
                logger.info(f"使用原始切片 {sid} 的 revision_slice_text + 映射切片 {revision_id} 的元数据")
            else:
                # Revision row missing: keep the original chunk but with the
                # revised text.
                r["content"] = revision_text
                llm_result_list.append(r)
                logger.warning(f"未找到映射切片 {revision_id},使用原始切片 {sid} 的 revision_slice_text")
        else:
            # Plain, unrevised slice: pass through untouched.
            llm_result_list.append(r)

    logger.info(f"过滤后传给LLM的切片数: {len(llm_result_list)}, 废弃切片数: {len(deprecated_chunks)}")
    # Concatenate the LLM-facing content (revised slices included).
    chunk_content_for_llm = self._build_llm_chunk_content(llm_result_list)
    return retriever_result_list, search_doc_id_to_knowledge_id_dict, chunk_content_for_llm, chunk_revision_map, llm_result_list
async def generate_event(self, chat_json, request):
    """Main RAG chat event stream.

    Yields SSE-style {"id", "event", "data"} dicts. Pipeline:
      1. Sensitive-word screening of the latest user message.
      2. LLM triage of the input (OK / NEED_CLARIFY / NEED_CORRECT /
         SENSITIVE) which also rewrites the query for retrieval.
      3. Unclear input: stream a retrieval-grounded clarifying question.
      4. OK input: retrieval + revision handling, then stream the answer;
         the LLM is instructed to end with a sentinel followed by a
         "[cited_chunks: ...]" trailer, which is parsed and the cited
         (plus deprecated) chunks persisted to Redis under chat_id.
    """
    logger.info(f"rag聊天的请求参数:{chat_json}")
    start_time = time.time()
    # The latest user message is the raw query.
    query = chat_json.get("messages")[-1].get("content")
    text = query
    hits = tokenizer_sensitive.detect(text)
    if hits:
        # Sensitive words detected: emit the warning and stop.
        # print(build_warning(hits))
        yield {"id": "", "event": "add", "data": build_warning(hits)}
        return
    app_name = chat_json.get("name", "")
    # Triage prompt: asks the judge model for a strict-JSON verdict on the input.
    label_prompt = """
你是一个 RAG 系统的输入判定器,只做判断,不生成答案。
请严格判断用户输入是否属于以下三类之一:
1. 语义模糊(NEED_CLARIFY):输入没有明确的主语、对象或范围,或者类似“是什么”“怎么样”“基本规范”“规定是什么”等,无法确定查询目标。
2. 明显拼写或术语错误(NEED_CORRECT):输入包含错别字、无意义组合、逻辑混乱或非标准术语,例如“天气怎么阳”“明天什么暗拍”“低功耗贾张氏设计标准”。
3. 输入清晰(OK):输入语义明确、包含实体、对象、拼写和术语正确,可直接用于检索和回答。
要求严格输出标准 JSON,只输出 JSON,不生成任何解释性文字。
输出 JSON 格式:
{{
"decision": "OK" 或 "NEED_CLARIFY" 或 "NEED_CORRECT" 或 "SENSITIVE",
"reason": "一句话说明原因(可为空字符串)",
"correct": "如果是NEED_CORRECT给我出纠正的内容,否则为空",
"suggested_question": "如果需要反问或纠正,给出建议文本,否则为空字符串,如有此回答,尽量生动形象,增强用户体验",
"summary": "如果有上下文,给出上下文摘要",
"rewrite": "结合上下文重写用户问题,去除指代和模糊表达,使其可独立理解并适合知识检索"
}}
注意事项:
- 如果输入模糊或有错别字,必须返回对应类型并提供简短建议(suggested_question)。
- 所有字段必须存在。
- 输出必须严格遵守 JSON 格式。
- 不允许有多余文字、换行或 Markdown。
- 上下文重要性:
- 如果用户问题较模糊但可依赖之前对话(如“之前问过xxx,接着问为什么?怎么做?”),判定为OK。
- 根据最近对话内容,将当前问题改写为简短、独立、明确、适合检索的查询,输出到 rewrite 字段。
- 如果有上下文,summary 字段应简明概括最近相关对话内容。
- 示例:
输入:“是什么?” → NEED_CLARIFY
输入:“都会各有各的、是什么?” → NEED_CLARIFY
输入:“天气怎么阳” → NEED_CORRECT
输入:“明天的温度是多少?” → OK
输入上下文:
Q: 明天的温度是多少?
A: 最高温度 22°C,最低温度 14°C
输入:“为什么?” → OK,rewrite: “为什么明天温度变化幅度大?”
"""
    # - 通过上下文进行判断,假如上一个问题已经回答过,用户接着问“为什么?怎么做?”之类的同样判定为OK,并根据最近对话内容,将当前用户问题改写为一个不含指代词、对象明确、可独立理解、适合用于知识检索的查询,输出到rewrite字段。
    # 用户输入:
    # "{user_input}"
    chat_json["query"] = query
    import copy
    # Deep-copy so the triage system message never leaks into the caller's history.
    label_messages = copy.deepcopy(chat_json.get("messages"))
    # label_prompt = label_prompt.format(user_input=query)
    # label_messages = chat_json.get("messages").copy()
    # label_messages[-1]["content"] = label_prompt
    label_messages.insert(0, {"role": "system","content": f"{label_prompt}"})
    logger.info(f"追问识别的提示词:{label_messages}")
    completion = await self.vllm_client.generate_non_stream_async(
        prompt=label_messages,
        model="/opt/vllm/models/Qwen/Qwen3-30B-A3B-Instruct-2507"
    )
    raw_text = completion.strip()
    logger.info(f"问题判定的结果:{raw_text}")
    # Parse the judge's verdict; fall back to OK on malformed JSON.
    try:
        data = json.loads(raw_text)
    except json.JSONDecodeError:
        logger.error("判定器返回的文本不是有效 JSON")
        data = {"decision": "OK", "reason": "", "suggested_question": ""}
    decision = data.get("decision", "OK")
    summary = data.get("summary", "")
    rewrite = data.get("rewrite", "")
    if decision == "SENSITIVE":
        # NOTE(review): the triage prompt defines no "sensitive" field, so
        # data.get("sensitive") is presumably None here — confirm intent.
        yield {"id": "", "event": "add", "data": data.get("sensitive")}
        return
    if decision == "NEED_CORRECT":
        # Adopt the corrected query and continue as a normal request.
        new_query = data.get("correct", query)
        chat_json["query"] = new_query
        decision = "OK"
    if decision != "OK":
        # Unclear input: retrieve a little context with app_name's knowledge
        # bases and generate a clarifying question.
        try:
            collection_name_list = chat_json.get("knowledgeIds", [])
            rag_embedding_name = chat_json.get("embeddingId")

            if collection_name_list and rag_embedding_name:
                # Retrieve with the raw query to ground the clarifying question.
                search_results, _ = self.request_milvus(
                    collection_name_list=collection_name_list,
                    rag_embedding_name=rag_embedding_name,
                    query=query,
                    top=3,  # only three related passages are needed
                    mode="hybrid"
                )

                # Concatenate the retrieved passages.
                context_content = "\n".join([r.get("content", "") for r in search_results[:3]])

                if context_content:
                    # Prompt asking the model to pose 1-2 targeted follow-up questions.
                    clarify_prompt = f"""你是一个RAG应用助手。用户的输入存在问题:{data.get('reason', '输入不够清晰')}
用户输入:{query}
知识库中检索到以下相关内容:
{context_content}
请基于检索到的内容,用友好、引导性的语气向用户提出1-2个具体的问题,帮助用户明确他们的需求。要求:
1. 问题要具体、有针对性
2. 语气要友好、鼓励性
3. 可以提供一些选项供用户选择
4. 直接输出问题,不要有多余的解释
示例格式:
"您可能想了解XXX方面的内容。请问您是想了解:
1. XXX的具体定义和标准?
2. XXX的应用场景和案例?
还是有其他具体问题呢?"
"""

                    rq_tokens_str = ""
                    # Stream the clarifying question through the chat endpoint.
                    model = "/opt/vllm/models/Qwen/Qwen3-30B-A3B-Instruct-2507"
                    full_response = ""
                    logger.info(f"反问的提示词是:{clarify_prompt}")
                    async for chunk in self.vllm_client.chat(
                        prompt=clarify_prompt,
                        model=model,
                        temperature=0.7,
                        top_p=0.9,
                        max_tokens=512,
                        stream=True,
                        history=[]
                    ):
                        if chunk.get("event") == "add":
                            # Accumulate the full response.
                            full_response += chunk.get("data", "")
                            rq_tokens_str = full_response
                            # NOTE(review): data is replaced by the whole
                            # accumulated text, so each "add" carries the full
                            # response so far — confirm the client expects that.
                            chunk["data"] = full_response
                            # Forward the chunk immediately to keep streaming.
                            yield {"id": chunk.get("id", ""), "event": "add", "data": chunk.get("data")}
                        elif chunk.get("event") == "finish":
                            yield {"id": chunk.get("id", ""), "event": "finish", "data": f'{{"inputToken": {self.count_tokens(query)}, "outputToken": {self.count_tokens(rq_tokens_str)}, "timeConsuming": {time.time() - start_time}}}'}

                    logger.info(f"反问生成完成,完整响应: {full_response}")
                    return

        except Exception as e:
            logger.error(f"反问生成失败,使用默认回复: {e}")

        # Retrieval failed or no knowledge base configured: fall back to the
        # judge's suggested_question.
        reply = data.get("suggested_question", "抱歉,我没有理解您的问题。能否请您更详细地描述一下?")
        yield {"id": "", "event": "add", "data": reply}
        yield {"id": "", "event": "finish", "data": ""}
        return
    if decision == "OK":

        chat_id = ""
        # Marker the LLM is instructed to emit before its citation trailer.
        SENTINEL = "⧗⧗LLM_STREAM_SENTINEL⧗⧗"

        try:
            output_token_str = ""
            tokens = self.count_tokens(chat_json.get("query") + chat_json.get("prompt"))
            logger.info(f"用户输入的token数:{tokens}")

            retriever_result_list, search_doc_id_to_knowledge_id_dict = await self.retriever_result(chat_json, rewrite)
            logger.info(f"向量库中获取的最终结果:{retriever_result_list}")
            retriever_result_list, search_doc_id_to_knowledge_id_dict, chunk_content_for_llm, chunk_revision_map, llm_result_list = self._apply_revision_and_doc_name(
                chat_json,
                retriever_result_list,
                search_doc_id_to_knowledge_id_dict,
            )
            # Append the citation instruction to the prompt.
            citation_instruction = f"\n\n《额外指令》\n\n1、chunk_id仅供正文结束后使用,禁止在正文中以任何形式输出任何chunk_id\n\n2、在回答完成后,请按照以下格式列出你引用的所有chunk_id(必须输出,无引用cited_chunks可返回空字符串),格式为:{SENTINEL}\n[cited_chunks: id1, id2, id3]"
            chat_json["prompt"] = chat_json.get("prompt", "") + citation_instruction
            first = True
            sentinel_detected = False  # full sentinel seen in the stream
            buffering = False          # holding back a possible partial sentinel
            buffer_content = ""
            clean_output = ""          # answer text with the sentinel trailer stripped
            end_time = None

            async for event in self.generate_rag_response(chat_json, chunk_content_for_llm):
                chat_id = event.get("id")

                if first:
                    # Time-to-first-token marker.
                    end_time = time.time()
                    first = False

                if await request.is_disconnected():
                    logger.info(f"chat id:{chat_id}连接中断")
                    yield {"id": chat_id, "event": "interrupted", "data": ""}
                    return

                if event.get("event") == "add":
                    # NOTE(review): data appears to be the cumulative output so
                    # far (full_data), not a delta — the buffering logic below
                    # relies on that; confirm against generate_rag_response.
                    full_data = event.get("data", "")
                    output_token_str = full_data

                    # After the full sentinel: keep consuming chunks until the
                    # [cited_chunks: ...] trailer can be parsed.
                    if sentinel_detected:
                        # re.IGNORECASE: case-insensitive match.
                        cited_match = re.search(r'\[cited_chunks[::]\s*([^\]]+)\]', full_data, re.IGNORECASE)
                        # if not cited_match:
                        # sentinel_detected = False
                        # continue
                        if cited_match:
                            cited_ids_str = cited_match.group(1)
                            cited_chunk_ids = [cid.strip().upper() for cid in cited_ids_str.split(',')]
                            logger.info(f"解析到LLM引用的chunk_ids: {cited_chunk_ids}")

                            # knowledge_info_dict collects the cited chunks,
                            # storing revised and original slices independently.
                            knowledge_info_dict = {}

                            # chunk_id lookup maps; lower-cased keys to sidestep casing issues.
                            llm_chunk_map = {r.get("chunk_id").lower() if r.get("chunk_id") else None: r for r in llm_result_list}
                            original_chunk_map = {r.get("chunk_id").lower() if r.get("chunk_id") else None: r for r in retriever_result_list}

                            # Debug logging.
                            logger.info(f"llm_result_list中的chunk_ids: {[r.get('chunk_id') for r in llm_result_list]}")
                            logger.info(f"retriever_result_list中的chunk_ids: {[r.get('chunk_id') for r in retriever_result_list]}")

                            for cited_id in cited_chunk_ids:
                                cited_id_normalized = cited_id.strip().upper()

                                # Look up in the LLM result list (may be a revised slice) via lower-case key.
                                llm_chunk = llm_chunk_map.get(cited_id_normalized.lower())

                                if not llm_chunk:
                                    logger.warning(f"未找到LLM引用的chunk_id: {cited_id_normalized}")
                                    # NOTE(review): resetting sentinel_detected
                                    # here re-opens sentinel handling for later
                                    # stream chunks — confirm this is intended.
                                    sentinel_detected = False
                                    continue

                                doc_id = llm_chunk.get("doc_id")
                                knowledge_id = search_doc_id_to_knowledge_id_dict.get(doc_id)

                                # Is this chunk the revised version of an original slice?
                                is_revised_version = llm_chunk.get("is_revised_version", False)

                                # chunk_info for the cited chunk itself.
                                current_chunk_info = {
                                    "chunk_id": llm_chunk.get("chunk_id"),
                                    "content": llm_chunk.get("content", ""),
                                    "rerank_score": llm_chunk.get("rerank_score"),
                                    "chunk_len": len(llm_chunk.get("content", "")),
                                    "document_name": llm_chunk.get("document_name"),
                                    "revision_status": llm_chunk.get("revision_status"),
                                    "ref_slice_id": llm_chunk.get("ref_slice_id"),
                                    "section": llm_chunk.get("section"),
                                    "parent_section": llm_chunk.get("parent_section"),
                                }

                                # Chunks persisted for this citation.
                                chunk_info_list = []

                                # Revised version: store both the revised and the original slice.
                                if is_revised_version:
                                    original_chunk_id = llm_chunk.get("original_chunk_id")
                                    # Lower-cased lookup of the original slice.
                                    original_chunk = original_chunk_map.get(original_chunk_id.lower() if original_chunk_id else None)

                                    logger.info(f"查找原始切片: original_chunk_id={original_chunk_id}, found={original_chunk is not None}")

                                    if original_chunk:
                                        # 1. First add the revised slice info.
                                        current_chunk_info["slice_type"] = "revised"
                                        chunk_info_list.append(current_chunk_info)

                                        # 2. Then add the original slice info.
                                        original_chunk_info = {
                                            "chunk_id": original_chunk.get("chunk_id"),
                                            "content": original_chunk.get("content", ""),
                                            "rerank_score": original_chunk.get("rerank_score"),
                                            "chunk_len": len(original_chunk.get("content", "")),
                                            "document_name": original_chunk.get("document_name"),
                                            "revision_status": original_chunk.get("revision_status"),
                                            "ref_slice_id": original_chunk.get("ref_slice_id"),  # raw DB field, points at the revised slice
                                            "section": original_chunk.get("section"),
                                            "parent_section": original_chunk.get("parent_section"),
                                            "slice_type": "original"  # mark as the original slice
                                        }
                                        # revised_to_chunk_id / chunk_revision_map content dropped (obsolete):
                                        # ref_slice_id already points at the revised slice.

                                        chunk_info_list.append(original_chunk_info)
                                        logger.info(f"独立保存修订切片 {llm_chunk.get('chunk_id')} 和原始切片 {original_chunk_id} 信息")
                                    else:
                                        # Original slice missing: store only the revised one.
                                        current_chunk_info["slice_type"] = "revised"
                                        chunk_info_list.append(current_chunk_info)
                                        logger.warning(f"未找到修订切片 {llm_chunk.get('chunk_id')} 对应的原始切片 {original_chunk_id}")
                                else:
                                    # Not a revised slice: check whether it is a
                                    # deprecated one (revision_status == "0" in the
                                    # retriever list) and tag it accordingly.
                                    chunk_in_retriever = original_chunk_map.get(cited_id_normalized.lower())
                                    if chunk_in_retriever and chunk_in_retriever.get("revision_status") == "0":
                                        # Deprecated slice.
                                        current_chunk_info["slice_type"] = "deprecated"
                                        logger.info(f"存储废弃切片 {cited_id_normalized} 到Redis")
                                    else:
                                        # Plain slice.
                                        current_chunk_info["slice_type"] = "normal"
                                    chunk_info_list.append(current_chunk_info)

                                doc_info = {
                                    "doc_id": doc_id,
                                    "doc_name": llm_chunk.get("document_name"),
                                    "chunk_nums": len(chunk_info_list),
                                    "chunk_info_list": chunk_info_list
                                }

                                # Merge into the per-knowledge, per-document structure.
                                if knowledge_id in knowledge_info_dict:
                                    existing_doc = None
                                    for doc in knowledge_info_dict[knowledge_id]:
                                        if doc["doc_id"] == doc_id:
                                            existing_doc = doc
                                            break
                                    if existing_doc:
                                        existing_doc["chunk_info_list"].extend(chunk_info_list)
                                        existing_doc["chunk_nums"] += len(chunk_info_list)
                                    else:
                                        knowledge_info_dict[knowledge_id].append(doc_info)
                                else:
                                    knowledge_info_dict[knowledge_id] = [doc_info]

                            # Extra pass: persist every deprecated slice to Redis
                            # even when the LLM did not cite it.
                            for r in retriever_result_list:
                                if r.get("revision_status") == "0":
                                    deprecated_chunk_id = r.get("chunk_id")
                                    doc_id = r.get("doc_id")
                                    knowledge_id = search_doc_id_to_knowledge_id_dict.get(doc_id)

                                    # Skip if this deprecated slice was already stored.
                                    already_stored = False
                                    if knowledge_id in knowledge_info_dict:
                                        for doc in knowledge_info_dict[knowledge_id]:
                                            if doc["doc_id"] == doc_id:
                                                for chunk_info in doc["chunk_info_list"]:
                                                    if chunk_info["chunk_id"] == deprecated_chunk_id:
                                                        already_stored = True
                                                        break
                                            if already_stored:
                                                break

                                    if not already_stored:
                                        # Build the deprecated slice record.
                                        deprecated_chunk_info = {
                                            "chunk_id": deprecated_chunk_id,
                                            "content": r.get("content", ""),
                                            "rerank_score": r.get("rerank_score"),
                                            "chunk_len": len(r.get("content", "")),
                                            "document_name": r.get("document_name"),
                                            "revision_status": "0",  # explicitly deprecated
                                            "ref_slice_id": r.get("ref_slice_id"),
                                            "section": r.get("section"),
                                            "parent_section": r.get("parent_section"),
                                            "slice_type": "deprecated"
                                        }

                                        doc_info = {
                                            "doc_id": doc_id,
                                            "doc_name": r.get("document_name"),
                                            "chunk_nums": 1,
                                            "chunk_info_list": [deprecated_chunk_info]
                                        }

                                        if knowledge_id in knowledge_info_dict:
                                            existing_doc = None
                                            for doc in knowledge_info_dict[knowledge_id]:
                                                if doc["doc_id"] == doc_id:
                                                    existing_doc = doc
                                                    break
                                            if existing_doc:
                                                existing_doc["chunk_info_list"].append(deprecated_chunk_info)
                                                existing_doc["chunk_nums"] += 1
                                            else:
                                                knowledge_info_dict[knowledge_id].append(doc_info)
                                        else:
                                            knowledge_info_dict[knowledge_id] = [doc_info]

                                        logger.info(f"额外存储废弃切片 {deprecated_chunk_id} 到Redis")

                            # Persist to Redis under the chat id.
                            self.redis_client.set(chat_id, json.dumps(knowledge_info_dict))
                            logger.info(f"存储引用的块到Redis: {len(cited_chunk_ids)}个(包含修订、原始和废弃切片信息)")

                            # Emit the finish event ourselves.
                            output_token = self.count_tokens(clean_output)
                            time_out = end_time - start_time if end_time else 0
                            finish_data = {
                                "inputToken": tokens,
                                "outputToken": output_token,
                                "timeConsuming": time_out
                            }
                            # yield {"id": chat_id, "event": "finish", "data": '你猜!'}
                            logger.info(f"111111111111111111111111111:{event}")
                            event["event"] = "finish"
                            event["data"] = finish_data
                            # yield event
                            send = {"id": chat_id, "event": "finish", "data": f'{{"inputToken": {tokens}, "outputToken": {output_token}, "timeConsuming": {time_out}}}'}
                            yield send
                            # yield {"id": chat_id, "event": "finish", "inputToken": tokens, "outputToken": output_token, "timeConsuming": time_out}
                            logger.info(f"{send}")
                            return  # done; skip all remaining stream events
                        continue

                    # Does this chunk contain the complete sentinel?
                    if SENTINEL in full_data:
                        sentinel_pos = full_data.find(SENTINEL)
                        sentinel_detected = True
                        clean_output = full_data[:sentinel_pos]
                        logger.info(f"检测到完整哨兵线")
                        continue

                    # Does it contain the sentinel's first character '⧗'?
                    if not buffering and '⧗' in full_data:
                        first_char_pos = full_data.find('⧗')
                        # Flush everything before '⧗' to the client.
                        if first_char_pos > 0:
                            event["data"] = full_data[:first_char_pos]
                            yield event
                        buffering = True
                        buffer_content = full_data[first_char_pos:]
                        clean_output = full_data[:first_char_pos]
                        logger.info(f"检测到哨兵线第一个字符,开始缓冲")
                        continue

                    # Currently buffering a possible partial sentinel.
                    if buffering:
                        buffer_content = full_data[len(clean_output):]
                        if SENTINEL.startswith(buffer_content):
                            # Still looks like a sentinel prefix: keep buffering.
                            continue
                        elif event.get("event") != "finish":
                            # False alarm: release the buffered content.
                            buffering = False
                            yield event
                    elif event.get("event") != "finish":
                        # Normal pass-through.
                        yield event

                elif event.get("event") == "finish":
                    # The sentinel path already emitted finish: skip this one.
                    if sentinel_detected:
                        logger.info("哨兵线已处理,跳过LLM的finish事件")
                        continue

                    # LLM ended without a sentinel: store every retrieved chunk.
                    _chunk_content, knowledge_info_dict = self.parse_retriever_list(retriever_result_list, search_doc_id_to_knowledge_id_dict)
                    self.redis_client.set(chat_id, json.dumps(knowledge_info_dict))
                    logger.warning("LLM未输出哨兵线,存储所有检索块")

                    # Emit the finish event (no-sentinel case).
                    output_token = self.count_tokens(output_token_str)
                    time_out = end_time - start_time if end_time else 0
                    finish_data = {
                        "inputToken": tokens,
                        "outputToken": output_token,
                        "timeConsuming": time_out
                    }
                    event["data"] = finish_data

                    send = {"id": chat_id, "event": "finish", "data": f'{{"inputToken": {tokens}, "outputToken": {output_token}, "timeConsuming": {time_out}}}'}
                    yield send
                    logger.info(f"LLM未输出哨兵线,发送finish事件: {send}")
                elif event.get("event") != "finish":
                    # Any other event type: forward untouched.
                    yield event

        except Exception as e:
            logger.error(f"执行出错:{e}")
            yield {"id": chat_id, "event": "finish", "data": ""}
            return
async def query_redis_chunk_by_chat_id(self, chat_id):
    """Load the cached retrieval chunks for a chat and pick up to three ids.

    Reads the knowledge_id -> [doc_info] structure stored in Redis under
    chat_id, flattens every document's chunk_info_list, and returns
    (first knowledge_id or None, list of chunk ids — randomly sampled
    down to three when more are cached).
    """
    chunk_json = json.loads(self.redis_client.get(chat_id))
    logger.info(f"redis中存储的chunk 信息:{chunk_json}")
    # First knowledge id in insertion order, if the cache is non-empty.
    knowledge_id = next(iter(chunk_json)) if chunk_json else None
    # Flatten all chunk_info_lists across every knowledge base and document.
    chunk_list = []
    for doc_entries in chunk_json.values():
        for doc_entry in doc_entries:
            entry_chunks = doc_entry.get("chunk_info_list")
            if entry_chunks:
                chunk_list.extend(entry_chunks)
    # Cap at three chunks, chosen at random when there are more.
    picked = random.sample(chunk_list, 3) if len(chunk_list) > 3 else chunk_list
    return knowledge_id, [chunk["chunk_id"] for chunk in picked]
-
async def generate_relevant_query(self, query_json):
    """Generate suggested follow-up questions for a chat.

    When query_json carries a chat_id, pull up to three cached chunk ids
    from Redis, fetch their content from Milvus, and prepend it as an
    assistant message so the model grounds its suggestions; otherwise the
    raw messages are used as-is.

    Returns {"code": 200, "data": <list from the model's "问题" key>} for
    the first chunk the model yields, or None if the stream is empty.

    Raises whatever json.loads / ast.literal_eval raise when the model's
    output is neither valid JSON nor a Python literal.
    """
    import ast  # local import: used only for the safe fallback parse below

    chat_id = query_json.get("chat_id")
    if chat_id:
        knowledge_id, chunk_id_list = await self.query_redis_chunk_by_chat_id(chat_id)
        chunk_content_list = MilvusOperate(
            collection_name=knowledge_id,
            embedding_name=query_json.get("embedding_id"),
        )._search_by_chunk_id_list(chunk_id_list)
        chunk_content = "\n".join(chunk_content_list)
        query_messages = query_json.get("messages")
        # Read (do not pop) the last message so the caller's list is not mutated.
        user_dict = query_messages[-1]
        messages = [{"role": "assistant", "content": chunk_content}, user_dict]
    else:
        messages = query_json.get("messages")
    # model = query_json.get("model")
    model = "/opt/vllm/models/Qwen/Qwen3-30B-A3B-Instruct-2507"
    query_result = self.vllm_client.chat(model=model, stream=False, history=messages)

    async for result in query_result:
        result = result.get("data", "").strip()
        logger.info(f"模型生成的问题:{result}")
        try:
            if "```json" in result:
                # Strip a ```json ... ``` fence if one is present; guard
                # against a fence marker with no closing fence (previously
                # matches[0] raised IndexError and fell into eval()).
                matches = re.findall(r'```json\s(.*?)```', result, re.DOTALL)
                if matches:
                    result = matches[0]
            parsed = json.loads(result)
        except Exception:
            # The model sometimes emits a Python-literal dict instead of JSON.
            # ast.literal_eval evaluates literals only — unlike the previous
            # eval(), it cannot execute arbitrary model output.
            parsed = ast.literal_eval(result)
        # "parsed" deliberately does not shadow the query_json parameter.
        query_list = parsed.get("问题")
        return {"code": 200, "data": query_list}
async def search_slice(self):
    """Fetch the cached chunk payload for this chat from Redis.

    On success, returns the parsed JSON object with "code": 200 merged in.
    On any lookup/parse failure, logs the error and returns
    {"code": 500, "message": <error text>}.
    """
    try:
        raw_payload = self.redis_client.get(self.chat_id)
        payload = json.loads(raw_payload)
        payload["code"] = 200
        return payload
    except Exception as e:
        logger.error(f"查询redis报错:{e}")
        return {"code": 500, "message": str(e)}
|