| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679 |
- # 请求的入口
- import os
- os.environ["CUDA_VISIBLE_DEVICES"] = "4"
- # os.environ["REQUEST_TIMEOUT"] = "300"
- # import tracemalloc
- # tracemalloc.start()
- import faulthandler
- faulthandler.enable()
- from fastapi import FastAPI, File, UploadFile, Form, Request, Response, WebSocket, WebSocketDisconnect, Depends, APIRouter, Body, HTTPException
- from fastapi.responses import JSONResponse, FileResponse, StreamingResponse
- from fastapi.staticfiles import StaticFiles
- from fastapi.middleware.cors import CORSMiddleware
- from sse_starlette import EventSourceResponse
- import uvicorn
- from utils.get_logger import setup_logger
- from rag.vector_db.milvus_vector import HybridRetriever
- from response_info import generate_message, generate_response
- from rag.db import MilvusOperate, MysqlOperate
- from rag.file_process import ParseFile
- from rag.documents_process import ProcessDocuments
- from rag.task_registry import task_registry
- from rag.chat_message import ChatRetrieverRag
- # from rag_evaluation import RAGEvaluator, run_evaluation
- from rag.rag_evaluation_service import run_rag_evaluation
- from rag.routes import (
- copy_docs_to_collection as do_copy_docs,
- copy_multi_docs_to_collection as do_copy_multi_docs,
- process_single_slice_metadata,
- generate_slice_metadata as do_generate_slice_metadata,
- update_slice,
- insert_slice,
- )
- from datasets import Dataset
- # from utils.wav2text import Audio2Text
- import json
- import uuid
- from lightRAG import AsyncLightRAGManager
- from rag.llm import VllmApi
- from rag.upload_queue_manager import get_queue_manager
logger = setup_logger(__name__)
app = FastAPI()

# Module-wide upload queue manager, created with max_concurrent=10.
queue_manager = get_queue_manager(max_concurrent=10)

# Mount the "static" directory that sits next to this module, but only when
# it is actually present on disk.
static_dir = os.path.join(os.path.dirname(__file__), "static")
if os.path.isdir(static_dir):
    app.mount(
        "/static",
        StaticFiles(directory=static_dir),
        name="static",
    )

# CORS: wildcard origins, methods and headers, with credentials allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive - confirm this is intended outside of development.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
@app.get("/health")
async def health_check():
    """Liveness probe: unconditionally report the service as healthy."""
    return dict(status="healthy")
@app.post("/upload_knowledge")
async def upload_file_to_db(file_json: dict):
    """Queue a document-upload job and return immediately.

    The job is handed to the module-level ``queue_manager`` (created with
    ``max_concurrent=10``); the request does not wait for the document to be
    processed.

    Args:
        file_json: Request body. ``file_json["docs"][0]["document_id"]`` is
            used as the task id; the whole body is forwarded to
            ``ProcessDocuments``.

    Returns:
        JSONResponse with ``code`` 200, the task id and a queue snapshot, or
        ``code`` 400 when no ``document_id`` can be extracted.
    """
    logger.info(f"上传文件请求参数：{file_json}")

    # Extract the task id defensively: a missing, empty or malformed "docs"
    # entry must yield the 400 response below rather than an unhandled
    # IndexError/TypeError/AttributeError.
    docs = file_json.get("docs")
    first_doc = docs[0] if isinstance(docs, list) and docs else None
    task_id = first_doc.get("document_id") if isinstance(first_doc, dict) else None

    if not task_id:
        logger.error("缺少 document_id 参数")
        return JSONResponse({"code": 400, "message": "缺少 document_id 参数"})

    # Start the queue manager on first use.
    if not queue_manager.is_running:
        await queue_manager.start()
        logger.info("队列管理器已启动")

    async def process_task():
        """Run the document-processing pipeline for this request."""
        parse_file = ProcessDocuments(file_json)
        return await parse_file.process_documents(file_json, task_id)

    await queue_manager.submit_task(task_id, process_task)

    queue_status = queue_manager.get_queue_status()
    logger.info(f"任务已提交到队列: {task_id}, 队列状态: {queue_status}")

    queue_size = queue_status["queue_size"]
    active_count = queue_status["active_tasks_count"]
    return JSONResponse({
        "code": 200,
        "message": "任务已提交到队列，正在处理中",
        "task_id": task_id,
        "queue_status": {
            "queue_size": queue_size,
            "active_tasks_count": active_count,
            # Reported as waiting + running tasks at the time of submission.
            "position_in_queue": queue_size + active_count,
        },
    })
@app.get("/upload_queue/status")
async def get_queue_status():
    """Return the overall state of the upload queue.

    The ``data`` field is whatever ``queue_manager.get_queue_status()``
    returns. Per the queue manager's contract (not visible in this module)
    that is expected to include is_running, max_concurrent, queue_size,
    active_tasks_count, completed_tasks_count, failed_tasks_count and
    active_tasks - TODO confirm against the manager.
    """
    snapshot = queue_manager.get_queue_status()
    logger.info(f"查询队列状态: {snapshot}")
    payload = {"code": 200, "data": snapshot}
    return JSONResponse(payload)
@app.get("/upload_queue/task/{task_id}")
async def get_task_status(task_id: str):
    """Look up one upload task by id.

    Args:
        task_id: Task identifier (the document_id used at submission).

    Returns:
        JSONResponse with ``code`` 200 and the task record, or ``code`` 404
        (record still attached) when the manager reports ``not_found``.
    """
    task_status = queue_manager.get_task_status(task_id)
    logger.info(f"查询任务状态 [{task_id}]: {task_status}")

    known = task_status["status"] != "not_found"
    if known:
        return JSONResponse({"code": 200, "data": task_status})

    not_found_body = {
        "code": 404,
        "message": f"任务不存在: {task_id}",
        "data": task_status,
    }
    return JSONResponse(not_found_body)
@app.post("/upload_queue/clear_history")
async def clear_queue_history():
    """Ask the queue manager to clear its history.

    Presumably this drops the completed/failed task records without touching
    running tasks - that behaviour lives in ``queue_manager.clear_history``
    and is not visible here.
    """
    queue_manager.clear_history()
    logger.info("队列历史记录已清空")
    body = {"code": 200, "message": "队列历史记录已清空"}
    return JSONResponse(body)
- # @app.post("/network/search")
- # async def chat_with_rag(request: Request, chat_json:dict):
- # retriever = ChatRetrieverRag(chat_json)
- # return EventSourceResponse(retriever.generate_event(chat_json, request), ping=300)
@app.post("/rag/chat")
async def chat_with_rag(request: Request, chat_json: dict):
    """Answer a chat request as a server-sent-event stream.

    The events come from ``ChatRetrieverRag.generate_event``; the SSE
    response is created with ``ping=300``.
    """
    event_source = ChatRetrieverRag(chat_json).generate_event(chat_json, request)
    return EventSourceResponse(event_source, ping=300)
@app.post("/rag/token_speed")
async def get_token_speed(request_json: dict):
    """Run one non-streaming LLM completion and report its token throughput.

    Args:
        request_json: Body with ``query`` (defaults to "你好" when the key is
            absent) and ``model``. The alias "Qwen3-30B" is mapped to the
            full model path before validation.

    Returns:
        JSONResponse with ``code`` 200 and ``token_speed`` (completion tokens
        per second) plus ``elapsed_time`` in seconds, both rounded to two
        decimals; ``code`` 400 for a missing or unknown query/model;
        ``code`` 500 when the completion call fails.
    """
    import time
    from config import model_name_vllm_url_dict

    query = request_json.get("query", "你好")
    model = request_json.get("model")
    if model == "Qwen3-30B":
        model = "/opt/vllm/models/Qwen/Qwen3-30B-A3B-Instruct-2507"

    if not query:
        return JSONResponse({"code": 400, "message": "query不能为空"})
    if not model:
        return JSONResponse({"code": 400, "message": "model不能为空"})

    # The model must be one of the configured vLLM endpoints.
    if model not in model_name_vllm_url_dict:
        return JSONResponse({"code": 400, "message": f"模型{model}不存在"})

    logger.info(f"token速度测试请求 - query: {query}, model: {model}")

    try:
        vllm_client = VllmApi(request_json)
        messages = [{"role": "user", "content": query}]

        # perf_counter is monotonic, so the measured interval cannot be
        # distorted by wall-clock adjustments while the request is in flight.
        start_time = time.perf_counter()
        completion = await vllm_client.generate_non_stream_async(
            prompt=messages,
            model=model,
            return_full_response=True
        )
        elapsed_time = time.perf_counter() - start_time

        # Token counts come from the usage block of the completion.
        usage = completion.usage
        completion_tokens = usage.completion_tokens
        prompt_tokens = usage.prompt_tokens

        # Guard against a zero-length interval.
        token_speed = completion_tokens / elapsed_time if elapsed_time > 0 else 0

        logger.info(f"token速度测试完成 - 耗时: {elapsed_time:.2f}秒, 输入tokens: {prompt_tokens}, 输出tokens: {completion_tokens}, 速度: {token_speed:.2f} tokens/秒")

        return JSONResponse({
            "code": 200,
            "data": {
                "token_speed": round(token_speed, 2),
                "elapsed_time": round(elapsed_time, 2)
            }
        })

    except Exception as e:
        # Endpoint boundary: report any failure as a 500 payload.
        logger.error(f"token速度测试失败: {str(e)}")
        return JSONResponse({"code": 500, "message": f"请求失败: {str(e)}"})
@app.post("/rag/query")
async def generate_query(request: Request, query_json: dict):
    """Return the result of ``ChatRetrieverRag.generate_relevant_query`` for the body."""
    logger.info(f"请求参数：{query_json}")
    retriever = ChatRetrieverRag(query_json)
    payload = await retriever.generate_relevant_query(query_json)
    return JSONResponse(payload)
@app.get("/rag/slice/search/{chat_id}")
async def generate_query(request: Request, chat_id: str = None):
    """Return the slice information recorded for a chat id.

    NOTE(review): this handler shares its function name with the
    POST /rag/query handler defined earlier in the file; both routes are
    registered, but the module attribute is rebound to this one.
    """
    logger.info(f"根据chat_id查询切片信息：{chat_id}")
    retriever = ChatRetrieverRag(chat_id=chat_id)
    slices = await retriever.search_slice()
    return JSONResponse(slices)
@app.delete("/rag/delete_slice/{slice_id}/{knowledge_id}/{document_id}")
async def delete_by_chunk_id(slice_id: str = None, knowledge_id: str = None, document_id: str = None):
    """Delete one slice from the collection named after ``knowledge_id``."""
    logger.info(f"删除切片接口中，知识库：{knowledge_id}, 切片id：{slice_id}")
    milvus = MilvusOperate(collection_name=knowledge_id)
    outcome = milvus._delete_by_chunk_id(slice_id, knowledge_id, document_id)
    logger.info(f"删除切片信息的结果：{outcome}")
    return JSONResponse(outcome)
@app.post("/rag/batch_delete_slice")
async def batch_delete_by_chunk_id(delete_json: dict):
    """Delete several slices of one document in a single call.

    Args:
        delete_json: Body with ``knowledge_id`` (collection name),
            ``document_id`` and ``slice_id`` (a non-empty list of slice ids).

    Returns:
        JSONResponse with ``code`` 400 when a field is missing or invalid;
        otherwise whatever ``MilvusOperate._batch_delete_by_chunk_ids``
        returns.
    """
    knowledge_id = delete_json.get("knowledge_id")
    document_id = delete_json.get("document_id")
    slice_ids = delete_json.get("slice_id", [])

    def reject(log_text, message):
        """Log a validation failure and build the matching 400 payload."""
        logger.error(log_text)
        return JSONResponse({"code": 400, "message": message})

    # Validate in the same order as the fields are documented.
    if not knowledge_id:
        return reject("批量删除切片：缺少 knowledge_id 参数", "缺少 knowledge_id 参数")
    if not document_id:
        return reject("批量删除切片：缺少 document_id 参数", "缺少 document_id 参数")
    if not (slice_ids and isinstance(slice_ids, list)):
        return reject("批量删除切片：slice_id 必须是非空列表", "slice_id 必须是非空列表")

    logger.info(f"批量删除切片接口，知识库：{knowledge_id}, 文档ID：{document_id}, 切片数量：{len(slice_ids)}")

    outcome = MilvusOperate(collection_name=knowledge_id)._batch_delete_by_chunk_ids(
        slice_ids, knowledge_id, document_id
    )

    logger.info(f"批量删除切片结果：{outcome}")
    return JSONResponse(outcome)
@app.delete("/rag/delete_doc/{doc_id}/{knowledge_id}")
async def delete_by_doc_id(doc_id: str = None, knowledge_id: str = None):
    """Delete a document from the vector collection and the knowledge graph.

    Args:
        doc_id: Document whose vectors are removed from Milvus.
        knowledge_id: Knowledge base id; used as the Milvus collection name
            and as the LightRAG workspace.

    Returns:
        JSONResponse wrapping the result of ``MilvusOperate._delete_by_doc_id``.
        The outcome of the knowledge-graph deletion is only logged.
    """
    logger.info(f"删除文档id接口，知识库：{knowledge_id}, 文档id：{doc_id}")
    resp = MilvusOperate(collection_name=knowledge_id)._delete_by_doc_id(doc_id=doc_id)

    rows = MysqlOperate().query_knowledge_by_ids([knowledge_id])
    # Only touch the graph when some row has a truthy knowledge_graph value.
    # The previous check tested the list of flags itself, which is truthy
    # whenever any row exists, so the condition never filtered anything.
    if rows and any(row["knowledge_graph"] for row in rows):
        rag_mgr = await AsyncLightRAGManager().init_workspace(knowledge_id)
        # NOTE(review): knowledge_id (not doc_id) is passed as the document
        # id here - confirm this matches the id used when the document was
        # inserted into the graph.
        result = await rag_mgr.adelete_by_doc_id(knowledge_id, delete_llm_cache=True)

        status = result.status
        if status == "success":
            logger.info("知识图谱删除成功")
        elif status == "not_found":
            logger.info("知识图谱文档不存在")
        elif status == "not_allowed":
            # Deletion refused (the original comment says: pipeline busy).
            logger.warning("知识图谱当前不允许删除")
        elif status in ("failure", "fail"):
            # "fail" is accepted too: upstream LightRAG appears to report
            # that spelling - TODO confirm against the installed version.
            logger.error(f"知识图谱删除失败: {result.message}")
        else:
            # Make unexpected statuses visible instead of dropping them.
            logger.warning(f"知识图谱删除返回未知状态: {status}")

    logger.info(f"删除文档的结果：{resp}")
    return JSONResponse(resp)
@app.put("/rag/update_slice")
async def put_by_id(slice_json: dict):
    """Forward the body to ``update_slice`` and return its result as JSON."""
    return JSONResponse(await update_slice(slice_json))
@app.post("/rag/insert_slice")
async def insert_slice_text(slice_json: dict):
    """Forward the body to ``insert_slice`` and return its result as JSON."""
    return JSONResponse(await insert_slice(slice_json))
@app.get("/rag/search/{knowledge_id}/{slice_id}")
async def search_by_doc_id(knowledge_id: str = None, slice_id: str = None):
    """Fetch one slice by id from the collection named after ``knowledge_id``."""
    logger.info(f"根据切片id查询的数据库名：{knowledge_id}，切片id：{slice_id}")

    # The knowledge base id doubles as the collection name.
    milvus = MilvusOperate(collection_name=knowledge_id)
    found = milvus._search_by_chunk_id(slice_id)

    logger.info(f"根据切片id查询结果：{found}")
    return JSONResponse(found)
@app.post("/rag/search_word")
async def search_by_key_word(search_json: dict):
    """Run a keyword search in the collection given by ``knowledge_id``.

    The whole request body is handed to ``MilvusOperate._search_by_key_word``.
    """
    target_collection = search_json.get("knowledge_id")
    logger.info(f"根据关键字请求的参数：{search_json}")
    milvus = MilvusOperate(collection_name=target_collection)
    hits = milvus._search_by_key_word(search_json)
    logger.info(f"根据关键字查询的结果：{hits}")
    return JSONResponse(hits)
@app.delete("/rag/delete_knowledge/{knowledge_id}")
async def delete_collection(knowledge_id: str = None):
    """Drop a knowledge base: its Milvus collection and its graph storages.

    Args:
        knowledge_id: Knowledge base id; used as the Milvus collection name
            and as the LightRAG workspace.

    Returns:
        JSONResponse wrapping the result of ``MilvusOperate._delete_collection``.
        Per-storage drop outcomes are only logged.
    """
    logger.info(f"删除数据库请求的参数：{knowledge_id}")
    resp = MilvusOperate(collection_name=knowledge_id)._delete_collection()

    rows = MysqlOperate().query_knowledge_by_ids([knowledge_id])
    # Only touch the graph when some row has a truthy knowledge_graph value.
    # The previous check tested the list of flags itself, which is truthy
    # whenever any row exists, so the condition never filtered anything.
    if rows and any(row["knowledge_graph"] for row in rows):
        import asyncio

        rag = await AsyncLightRAGManager().init_workspace(knowledge_id)
        candidates = [
            rag.text_chunks,
            rag.full_docs,
            rag.full_entities,
            rag.full_relations,
            rag.entity_chunks,
            rag.relation_chunks,
            rag.entities_vdb,
            rag.relationships_vdb,
            rag.chunks_vdb,
            rag.chunk_entity_relation_graph,
            rag.doc_status,
        ]
        # Filter first, so that every gathered result lines up with the
        # storage that produced it. Indexing the unfiltered list would
        # mislabel results as soon as one storage is None.
        storages = [storage for storage in candidates if storage is not None]

        # Drop all storages concurrently; exceptions are collected, not raised.
        results = await asyncio.gather(
            *(storage.drop() for storage in storages),
            return_exceptions=True,
        )

        for storage, result in zip(storages, results):
            storage_name = storage.__class__.__name__
            if isinstance(result, Exception):
                logger.error(f"删除失败 {storage_name}: {result}")
            else:
                logger.info(f"成功删除 {storage_name}")

    logger.info(f"删除向量库结果：{resp}")
    return JSONResponse(resp)
@app.post("/rag/create_collection")
async def create_collection(collection: dict):
    """Create a vector collection.

    Reads ``knowledge_id`` (collection name) and ``embedding_id`` (embedding
    name) from the body and returns the result of
    ``MilvusOperate._create_collection``.
    """
    collection_name = collection.get("knowledge_id")
    embedding_name = collection.get("embedding_id")
    logger.info(f"创建向量库的库名：{collection_name}，向量名称：{embedding_name}")

    milvus = MilvusOperate(
        collection_name=collection_name,
        embedding_name=embedding_name,
    )
    created = milvus._create_collection()
    logger.info(f"创建向量库结果：{created}")
    return JSONResponse(created)
@app.post("/rag/copy_docs_to_collection")
async def copy_docs_to_collection_route(copy_request: dict):
    """Copy the specified documents into a new collection (index)."""
    return JSONResponse(await do_copy_docs(copy_request))
@app.post("/rag/copy_multi_docs_to_collection")
async def copy_multi_docs_to_collection_route(copy_request: dict):
    """Copy documents from several source collections into a target collection.

    Per the original description this covers both the vector store and the
    MySQL metadata; the work itself happens in ``copy_multi_docs_to_collection``.
    """
    return JSONResponse(await do_copy_multi_docs(copy_request))
@app.post("/rag/generate_slice_metadata")
async def generate_slice_metadata_route(request_json: dict):
    """Generate the qa, question and summary fields for slices."""
    return JSONResponse(await do_generate_slice_metadata(request_json))
- # ==================== 任务删除接口 ====================
@app.post("/rag/cancel_task/{task_id}")
async def cancel_task(task_id: str):
    """Cancel a document-processing task, running or still queued.

    The bm_document row for ``task_id`` is deleted first, regardless of
    whether the cancellation that follows succeeds.

    Args:
        task_id: Task identifier (the document_id).

    Returns:
        JSONResponse with ``code`` 200 when the task was cancelled or removed
        from the queue, 404 when it is neither registered nor queued, and 400
        with the registry's message for any other cancellation failure.
    """
    # Delete the bm_document record before attempting the cancellation.
    del_success, del_info = MysqlOperate().delete_document(task_id)
    success, message = task_registry.cancel(task_id)

    if success:
        if not del_success:
            logger.warning(f"任务 {task_id} 取消成功，但删除 bm_document 记录失败: {del_info}")
        else:
            logger.info(f"任务 {task_id} 已被取消，bm_document 记录已删除")
        return JSONResponse({"code": 200, "message": "任务取消成功"})

    # Any failure other than "task does not exist" is reported as-is.
    if "不存在" not in message:
        logger.info(f"message: {message}")
        return JSONResponse({"code": 400, "message": message})

    # Not registered: the task may still be waiting in the upload queue.
    dequeued = await queue_manager.remove_from_queue(task_id)
    if not dequeued:
        logger.info(f"任务 {task_id} 未找到")
        return JSONResponse({"code": 404, "message": "任务不存在"})

    if not del_success:
        logger.warning(f"任务 {task_id} 从队列中删除成功，但删除 bm_document 记录失败: {del_info}")
    else:
        logger.info(f"任务 {task_id} 从队列中删除成功，bm_document 记录已删除")
    return JSONResponse({"code": 200, "message": "任务已从队列中删除"})
-
- # @app.get("/rag/tasks")
- # async def get_active_tasks():
- # """
- # 获取所有活跃任务列表(调试用)
- # """
- # tasks = task_registry.get_all_tasks()
- # return JSONResponse({
- # "code": 200,
- # "data": tasks,
- # "count": len(tasks)
- # })
- # ==== RAG评估 ====
@app.get("/rag/eval")
async def rag_eval_page():
    """Serve the RAG evaluation page from the static directory.

    Returns the HTML file when it exists, otherwise a JSON payload with
    ``code`` 404.
    """
    html_path = os.path.join(static_dir, "rag_eval.html")
    if not os.path.isfile(html_path):
        return JSONResponse({"code": 404, "message": "页面不存在"})
    return FileResponse(html_path, media_type="text/html")
@app.post("/rag/evaluate")
async def rag_evaluate(eval_request: dict):
    """Run a RAG evaluation and return its result.

    Args:
        eval_request: Body with the following keys:
            - file_url: URL of the evaluation file (required).
            - knowledge_ids: list of knowledge base ids (required).
            - embedding_id: embedding model id (default "e5").
            - temperature: sampling temperature (default 0.6).
            - top_p: top_p value (default 0.7).
            - max_tokens: maximum tokens (default 4096).
            - model: model name (default
              "/opt/vllm/models/Qwen/Qwen3-30B-A3B-Instruct-2507").
            - slice_count: number of slices to retrieve (default 5).
            - task_id, tenant_id, prompt: forwarded unchanged (default None).
            - rerank_model_name: reranker (default "Qwen3-Reranker-0.6B").

    Returns:
        JSONResponse wrapping the result of ``run_rag_evaluation``, or a
        ``code`` 400 payload when a required field is missing or
        ``slice_count`` is not an integer.
    """
    logger.info(f"RAG评估的请求参数：{eval_request}")
    file_url = eval_request.get("file_url")
    knowledge_ids = eval_request.get("knowledge_ids", [])

    if not file_url:
        return JSONResponse({"code": 400, "message": "file_url 不能为空"})
    if not knowledge_ids:
        return JSONResponse({"code": 400, "message": "knowledge_ids 不能为空"})

    # Reject a non-numeric or null slice_count with a 400 instead of letting
    # int() raise and surface as an unhandled server error.
    try:
        slice_count = int(eval_request.get("slice_count", 5))
    except (TypeError, ValueError):
        return JSONResponse({"code": 400, "message": "slice_count 必须是整数"})

    logger.info(f"RAG评估请求 - 文件: {file_url}, 知识库: {knowledge_ids}")

    result = await run_rag_evaluation(
        file_url=file_url,
        knowledge_ids=knowledge_ids,
        embedding_id=eval_request.get("embedding_id", "e5"),
        temperature=eval_request.get("temperature", 0.6),
        top_p=eval_request.get("top_p", 0.7),
        max_tokens=eval_request.get("max_tokens", 4096),
        model=eval_request.get("model", "/opt/vllm/models/Qwen/Qwen3-30B-A3B-Instruct-2507"),
        slice_count=slice_count,
        task_id=eval_request.get("task_id"),
        tenant_id=eval_request.get("tenant_id"),
        prompt=eval_request.get("prompt"),
        rerank_model_name=eval_request.get("rerank_model_name", "Qwen3-Reranker-0.6B"),
    )

    logger.info(f"RAG评估结果: code={result.get('code')}")
    return JSONResponse(result)
- # @app.post("/ai/translate/audio")
- # async def audio_text(request: Request, file: UploadFile = File(...)):
- # uuid_str = uuid.uuid4()
- # file_name = f"{uuid_str}_{file.filename}"
- # file_path = os.path.join("./tmp_file/audio/", file_name)
- # with open(file_path, "wb") as f:
- # f.write(await file.read())
-
- # asr_text = Audio2Text(file_path).audo_to_text()
- # logger.info(f"asr result: {asr_text}, audio file name: {file_name}")
- # return asr_text
- from pathlib import Path
- import shutil
- # ===== 文件目录配置 =====
# Only files with this suffix are listed, downloaded or uploaded.
ALLOWED_SUFFIX = ".txt"
# Directory containing this module.
BASE_DIR = Path(__file__).parent
# Word-list directory; created at import time when it does not exist yet.
FILE_DIR = BASE_DIR.joinpath("rag", "sensitive_words")
FILE_DIR.mkdir(parents=True, exist_ok=True)
- # 停用词库文件列表
@app.get("/files")
def list_files():
    """List the names of the regular ``.txt`` files directly inside FILE_DIR.

    Returns:
        dict with ``count`` (number of files) and ``files`` (their names).
    """
    names = []
    for entry in FILE_DIR.iterdir():
        if not entry.is_file():
            continue
        if entry.suffix == ALLOWED_SUFFIX:
            names.append(entry.name)
    return {"count": len(names), "files": names}
- # 停用词库文件下载(txt)
@app.get("/files/{filename}")
def download_file(filename: str):
    """Download one ``.txt`` file from FILE_DIR.

    Args:
        filename: Bare file name; names containing a path separator are
            rejected.

    Returns:
        FileResponse streaming the file as ``text/plain``.

    Raises:
        HTTPException: 400 for a name containing a path separator or a file
            without the ``.txt`` suffix; 404 when no regular file of that
            name exists.
    """
    # Guard against path traversal.
    if "/" in filename or "\\" in filename:
        raise HTTPException(status_code=400, detail="非法文件名")
    file_path = FILE_DIR / filename
    # is_file() rather than exists(): a directory (including "." or "..")
    # must not reach FileResponse, which can only serve regular files.
    if not file_path.is_file():
        raise HTTPException(status_code=404, detail="文件不存在")
    if file_path.suffix != ALLOWED_SUFFIX:
        raise HTTPException(status_code=400, detail="只允许下载 txt 文件")
    return FileResponse(
        path=file_path,
        filename=filename,
        media_type="text/plain",
    )
- # 停用词库文件上传(txt)
@app.post("/files/upload")
def upload_file(file: UploadFile = File(...)):
    """Upload a ``.txt`` file into FILE_DIR, overwriting any existing file.

    Args:
        file: Uploaded file; its client-supplied name is used as the target
            name.

    Returns:
        dict with ``message``, ``filename`` and the stored ``path``.

    Raises:
        HTTPException: 400 when the name is empty, contains a path separator,
            or does not end with ``.txt``.
    """
    filename = file.filename
    if not filename:
        raise HTTPException(status_code=400, detail="文件名不能为空")
    # The name comes from the client: reject path separators so the write
    # cannot escape FILE_DIR (same rule as the download endpoint).
    if "/" in filename or "\\" in filename:
        raise HTTPException(status_code=400, detail="非法文件名")
    if not filename.endswith(ALLOWED_SUFFIX):
        raise HTTPException(status_code=400, detail="只允许上传 txt 文件")
    target_path = FILE_DIR / filename
    # Overwrite any existing file of the same name.
    with target_path.open("wb") as f:
        shutil.copyfileobj(file.file, f)
    return {
        "message": "上传成功",
        "filename": filename,
        "path": str(target_path),
    }
if __name__ == "__main__":
    # Direct execution: serve the app on all interfaces, port 6000, using
    # the asyncio event loop.
    uvicorn.run(
        app,
        loop="asyncio",
        port=6000,
        host="0.0.0.0",
    )
|