rag_server.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679
  1. # 请求的入口
  2. import os
  3. os.environ["CUDA_VISIBLE_DEVICES"] = "4"
  4. # os.environ["REQUEST_TIMEOUT"] = "300"
  5. # import tracemalloc
  6. # tracemalloc.start()
  7. import faulthandler
  8. faulthandler.enable()
  9. from fastapi import FastAPI, File, UploadFile, Form, Request, Response, WebSocket, WebSocketDisconnect, Depends, APIRouter, Body, HTTPException
  10. from fastapi.responses import JSONResponse, FileResponse, StreamingResponse
  11. from fastapi.staticfiles import StaticFiles
  12. from fastapi.middleware.cors import CORSMiddleware
  13. from sse_starlette import EventSourceResponse
  14. import uvicorn
  15. from utils.get_logger import setup_logger
  16. from rag.vector_db.milvus_vector import HybridRetriever
  17. from response_info import generate_message, generate_response
  18. from rag.db import MilvusOperate, MysqlOperate
  19. from rag.file_process import ParseFile
  20. from rag.documents_process import ProcessDocuments
  21. from rag.task_registry import task_registry
  22. from rag.chat_message import ChatRetrieverRag
  23. # from rag_evaluation import RAGEvaluator, run_evaluation
  24. from rag.rag_evaluation_service import run_rag_evaluation
  25. from rag.routes import (
  26. copy_docs_to_collection as do_copy_docs,
  27. copy_multi_docs_to_collection as do_copy_multi_docs,
  28. process_single_slice_metadata,
  29. generate_slice_metadata as do_generate_slice_metadata,
  30. update_slice,
  31. insert_slice,
  32. )
  33. from datasets import Dataset
  34. # from utils.wav2text import Audio2Text
  35. import json
  36. import uuid
  37. from lightRAG import AsyncLightRAGManager
  38. from rag.llm import VllmApi
  39. from rag.upload_queue_manager import get_queue_manager
logger = setup_logger(__name__)
app = FastAPI()
# Initialise the global upload queue manager.
queue_manager = get_queue_manager(max_concurrent=10)
# Mount static files; only done when a "static" directory exists next to this file.
static_dir = os.path.join(os.path.dirname(__file__), "static")
if os.path.isdir(static_dir):
    app.mount("/static", StaticFiles(directory=static_dir), name="static")
# Configure CORS.
# NOTE(review): wildcard origins/methods/headers combined with
# allow_credentials=True is very permissive — confirm this is intended
# outside of development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
  56. @app.get("/health")
  57. async def health_check():
  58. return {"status": "healthy"}
  59. @app.post("/upload_knowledge")
  60. async def upload_file_to_db(file_json: dict):
  61. """
  62. 文档上传接口(带队列并发控制)
  63. 功能:
  64. 1. 使用队列管理器控制并发,最多同时处理5个上传任务
  65. 2. 任务提交到队列后立即返回,不阻塞
  66. 3. 前一个任务完成后,下一个任务自动开始
  67. """
  68. logger.info(f"上传文件请求参数:{file_json}")
  69. # 提取任务ID
  70. task_id = file_json.get("docs", [{}])[0].get("document_id")
  71. if not task_id:
  72. logger.error("缺少 document_id 参数")
  73. return JSONResponse({"code": 400, "message": "缺少 document_id 参数"})
  74. # 检查队列管理器是否运行
  75. if not queue_manager.is_running:
  76. await queue_manager.start()
  77. logger.info("队列管理器已启动")
  78. # 定义任务处理函数
  79. async def process_task():
  80. parse_file = ProcessDocuments(file_json)
  81. return await parse_file.process_documents(file_json, task_id)
  82. # 提交任务到队列
  83. await queue_manager.submit_task(task_id, process_task)
  84. # 获取队列状态
  85. queue_status = queue_manager.get_queue_status()
  86. logger.info(f"任务已提交到队列: {task_id}, 队列状态: {queue_status}")
  87. return JSONResponse({
  88. "code": 200,
  89. "message": "任务已提交到队列,正在处理中",
  90. "task_id": task_id,
  91. "queue_status": {
  92. "queue_size": queue_status["queue_size"],
  93. "active_tasks_count": queue_status["active_tasks_count"],
  94. "position_in_queue": queue_status["queue_size"] + queue_status["active_tasks_count"]
  95. }
  96. })
  97. @app.get("/upload_queue/status")
  98. async def get_queue_status():
  99. """
  100. 获取上传队列的整体状态
  101. 返回:
  102. - is_running: 队列管理器是否运行中
  103. - max_concurrent: 最大并发数
  104. - queue_size: 等待队列中的任务数
  105. - active_tasks_count: 正在执行的任务数
  106. - completed_tasks_count: 已完成的任务数
  107. - failed_tasks_count: 失败的任务数
  108. - active_tasks: 正在执行的任务ID列表
  109. """
  110. status = queue_manager.get_queue_status()
  111. logger.info(f"查询队列状态: {status}")
  112. return JSONResponse({
  113. "code": 200,
  114. "data": status
  115. })
  116. @app.get("/upload_queue/task/{task_id}")
  117. async def get_task_status(task_id: str):
  118. """
  119. 获取指定任务的状态
  120. 参数:
  121. task_id: 任务ID(document_id)
  122. 返回:
  123. - status: 任务状态(running/completed/failed/not_found)
  124. - 其他任务详细信息
  125. """
  126. task_status = queue_manager.get_task_status(task_id)
  127. logger.info(f"查询任务状态 [{task_id}]: {task_status}")
  128. if task_status["status"] == "not_found":
  129. return JSONResponse({
  130. "code": 404,
  131. "message": f"任务不存在: {task_id}",
  132. "data": task_status
  133. })
  134. return JSONResponse({
  135. "code": 200,
  136. "data": task_status
  137. })
  138. @app.post("/upload_queue/clear_history")
  139. async def clear_queue_history():
  140. """
  141. 清空队列的历史记录(已完成和失败的任务)
  142. 注意:不会影响正在执行的任务
  143. """
  144. queue_manager.clear_history()
  145. logger.info("队列历史记录已清空")
  146. return JSONResponse({
  147. "code": 200,
  148. "message": "队列历史记录已清空"
  149. })
  150. # @app.post("/network/search")
  151. # async def chat_with_rag(request: Request, chat_json:dict):
  152. # retriever = ChatRetrieverRag(chat_json)
  153. # return EventSourceResponse(retriever.generate_event(chat_json, request), ping=300)
  154. @app.post("/rag/chat")
  155. async def chat_with_rag(request: Request, chat_json:dict):
  156. retriever = ChatRetrieverRag(chat_json)
  157. return EventSourceResponse(retriever.generate_event(chat_json, request), ping=300)
  158. @app.post("/rag/token_speed")
  159. async def get_token_speed(request_json: dict):
  160. """
  161. 调用非流式LLM并返回token生成速度
  162. 参数:
  163. - query: 用户查询
  164. - model: 模型名称
  165. 返回:
  166. - token_speed: token生成速度(tokens/秒)
  167. """
  168. import time
  169. from config import model_name_vllm_url_dict
  170. query = request_json.get("query", "你好")
  171. model = request_json.get("model")
  172. if model == "Qwen3-30B":
  173. model = "/opt/vllm/models/Qwen/Qwen3-30B-A3B-Instruct-2507"
  174. if not query:
  175. return JSONResponse({"code": 400, "message": "query不能为空"})
  176. if not model:
  177. return JSONResponse({"code": 400, "message": "model不能为空"})
  178. # 验证模型是否存在
  179. if model not in model_name_vllm_url_dict:
  180. return JSONResponse({"code": 400, "message": f"模型{model}不存在"})
  181. logger.info(f"token速度测试请求 - query: {query}, model: {model}")
  182. try:
  183. # 创建VllmApi实例
  184. vllm_client = VllmApi(request_json)
  185. # 记录开始时间
  186. start_time = time.time()
  187. # 调用非流式接口,获取完整响应
  188. messages = [{"role": "user", "content": query}]
  189. completion = await vllm_client.generate_non_stream_async(
  190. prompt=messages,
  191. model=model,
  192. return_full_response=True
  193. )
  194. # 记录结束时间
  195. end_time = time.time()
  196. elapsed_time = end_time - start_time
  197. # 从LLM返回的usage字段获取token信息
  198. usage = completion.usage
  199. completion_tokens = usage.completion_tokens
  200. prompt_tokens = usage.prompt_tokens
  201. total_tokens = usage.total_tokens
  202. # 计算token生成速度
  203. token_speed = completion_tokens / elapsed_time if elapsed_time > 0 else 0
  204. logger.info(f"token速度测试完成 - 耗时: {elapsed_time:.2f}秒, 输入tokens: {prompt_tokens}, 输出tokens: {completion_tokens}, 速度: {token_speed:.2f} tokens/秒")
  205. return JSONResponse({
  206. "code": 200,
  207. "data": {
  208. "token_speed": round(token_speed, 2),
  209. # "prompt_tokens": prompt_tokens,
  210. # "completion_tokens": completion_tokens,
  211. # "total_tokens": total_tokens,
  212. "elapsed_time": round(elapsed_time, 2)
  213. }
  214. })
  215. except Exception as e:
  216. logger.error(f"token速度测试失败: {str(e)}")
  217. return JSONResponse({"code": 500, "message": f"请求失败: {str(e)}"})
  218. @app.post("/rag/query")
  219. async def generate_query(request: Request, query_json:dict):
  220. logger.info(f"请求参数:{query_json}")
  221. relevant_query = ChatRetrieverRag(query_json)
  222. relevant_json = await relevant_query.generate_relevant_query(query_json)
  223. return JSONResponse(relevant_json)
  224. @app.get("/rag/slice/search/{chat_id}")
  225. async def generate_query(request: Request, chat_id:str = None):
  226. logger.info(f"根据chat_id查询切片信息:{chat_id}")
  227. chat = ChatRetrieverRag(chat_id=chat_id)
  228. chunk_json = await chat.search_slice()
  229. return JSONResponse(chunk_json)
  230. @app.delete("/rag/delete_slice/{slice_id}/{knowledge_id}/{document_id}")
  231. async def delete_by_chunk_id(slice_id:str=None, knowledge_id:str=None, document_id:str=None):
  232. logger.info(f"删除切片接口中,知识库:{knowledge_id}, 切片id:{slice_id}")
  233. resp = MilvusOperate(collection_name=knowledge_id)._delete_by_chunk_id(slice_id, knowledge_id, document_id)
  234. logger.info(f"删除切片信息的结果:{resp}")
  235. return JSONResponse(resp)
  236. @app.post("/rag/batch_delete_slice")
  237. async def batch_delete_by_chunk_id(delete_json: dict):
  238. """
  239. 批量删除切片接口
  240. 参数:
  241. delete_json: {
  242. "knowledge_id": str, # 知识库ID
  243. "document_id": str, # 文档ID
  244. "slice_id": list # 切片ID列表
  245. }
  246. 返回:
  247. {
  248. "code": 200/500,
  249. "message": str
  250. }
  251. """
  252. knowledge_id = delete_json.get("knowledge_id")
  253. document_id = delete_json.get("document_id")
  254. slice_ids = delete_json.get("slice_id", [])
  255. # 参数验证
  256. if not knowledge_id:
  257. logger.error("批量删除切片:缺少 knowledge_id 参数")
  258. return JSONResponse({"code": 400, "message": "缺少 knowledge_id 参数"})
  259. if not document_id:
  260. logger.error("批量删除切片:缺少 document_id 参数")
  261. return JSONResponse({"code": 400, "message": "缺少 document_id 参数"})
  262. if not slice_ids or not isinstance(slice_ids, list):
  263. logger.error("批量删除切片:slice_id 必须是非空列表")
  264. return JSONResponse({"code": 400, "message": "slice_id 必须是非空列表"})
  265. logger.info(f"批量删除切片接口,知识库:{knowledge_id}, 文档ID:{document_id}, 切片数量:{len(slice_ids)}")
  266. # 创建 MilvusOperate 实例并执行批量删除
  267. milvus_client = MilvusOperate(collection_name=knowledge_id)
  268. resp = milvus_client._batch_delete_by_chunk_ids(slice_ids, knowledge_id, document_id)
  269. logger.info(f"批量删除切片结果:{resp}")
  270. return JSONResponse(resp)
  271. @app.delete("/rag/delete_doc/{doc_id}/{knowledge_id}")
  272. async def delete_by_doc_id(doc_id:str=None, knowledge_id:str=None):
  273. logger.info(f"删除文档id接口,知识库:{knowledge_id}, 文档id:{doc_id}")
  274. resp = MilvusOperate(collection_name=knowledge_id)._delete_by_doc_id(doc_id=doc_id)
  275. rows = MysqlOperate().query_knowledge_by_ids([knowledge_id])
  276. if rows:
  277. rows_gp = [row["knowledge_graph"] for row in rows]
  278. if rows_gp:
  279. rag_mgr = AsyncLightRAGManager()
  280. rag_mgr = await rag_mgr.init_workspace(knowledge_id)
  281. result = await rag_mgr.adelete_by_doc_id(knowledge_id, delete_llm_cache=True)
  282. # 删除成功
  283. if result.status == "success":
  284. logger.info("知识图谱删除成功")
  285. # 文档不存在
  286. elif result.status == "not_found":
  287. logger.info("知识图谱文档不存在")
  288. # 不允许删除(管道忙碌)
  289. elif result.status == "not_allowed":
  290. logger.info("知识图谱当前不允许删除")
  291. # 删除失败
  292. elif result.status == "failure":
  293. logger.info(f"知识图谱删除失败: {result.message}")
  294. logger.info(f"删除文档的结果:{resp}")
  295. return JSONResponse(resp)
  296. @app.put("/rag/update_slice")
  297. async def put_by_id(slice_json:dict):
  298. resp = await update_slice(slice_json)
  299. return JSONResponse(resp)
  300. @app.post("/rag/insert_slice")
  301. async def insert_slice_text(slice_json:dict):
  302. resp = await insert_slice(slice_json)
  303. return JSONResponse(resp)
  304. @app.get("/rag/search/{knowledge_id}/{slice_id}")
  305. async def search_by_doc_id(knowledge_id:str=None, slice_id:str=None):
  306. # 根据切片id查询切片信息
  307. # print(f"知识库:{knowledge_id}, 切片:{slice_id}")
  308. logger.info(f"根据切片id查询的数据库名:{knowledge_id},切片id:{slice_id}")
  309. collection_name = knowledge_id # 根据传过来的id处理对应知识库
  310. resp = MilvusOperate(collection_name=collection_name)._search_by_chunk_id(slice_id)
  311. logger.info(f"根据切片id查询结果:{resp}")
  312. return JSONResponse(resp)
  313. @app.post("/rag/search_word")
  314. async def search_by_key_word(search_json:dict):
  315. # 根据doc_id 查询切片列表信息
  316. collection_name = search_json.get("knowledge_id")
  317. logger.info(f"根据关键字请求的参数:{search_json}")
  318. resp = MilvusOperate(collection_name=collection_name)._search_by_key_word(search_json)
  319. logger.info(f"根据关键字查询的结果:{resp}")
  320. return JSONResponse(resp)
  321. @app.delete("/rag/delete_knowledge/{knowledge_id}")
  322. async def delete_collection(knowledge_id: str = None):
  323. logger.info(f"删除数据库请求的参数:{knowledge_id}")
  324. resp = MilvusOperate(collection_name=knowledge_id)._delete_collection()
  325. rows = MysqlOperate().query_knowledge_by_ids([knowledge_id])
  326. if rows:
  327. rows_gp = [row["knowledge_graph"] for row in rows]
  328. if rows_gp:
  329. rag = AsyncLightRAGManager()
  330. rag = await rag.init_workspace(knowledge_id)
  331. storages = [
  332. rag.text_chunks,
  333. rag.full_docs,
  334. rag.full_entities,
  335. rag.full_relations,
  336. rag.entity_chunks,
  337. rag.relation_chunks,
  338. rag.entities_vdb,
  339. rag.relationships_vdb,
  340. rag.chunks_vdb,
  341. rag.chunk_entity_relation_graph,
  342. rag.doc_status,
  343. ]
  344. # 并行执行所有drop操作
  345. import asyncio
  346. drop_tasks = []
  347. for storage in storages:
  348. if storage is not None:
  349. drop_tasks.append(storage.drop())
  350. results = await asyncio.gather(*drop_tasks, return_exceptions=True)
  351. # 检查结果
  352. for i, result in enumerate(results):
  353. storage_name = storages[i].__class__.__name__
  354. if isinstance(result, Exception):
  355. logger.info(f"删除失败 {storage_name}: {result}")
  356. else:
  357. logger.info(f"成功删除 {storage_name}")
  358. logger.info(f"删除向量库结果:{resp}")
  359. return JSONResponse(resp)
  360. @app.post("/rag/create_collection")
  361. async def create_collection(collection: dict):
  362. collection_name = collection.get("knowledge_id")
  363. embedding_name = collection.get("embedding_id")
  364. logger.info(f"创建向量库的库名:{collection_name},向量名称:{embedding_name}")
  365. resp = MilvusOperate(collection_name=collection_name, embedding_name=embedding_name)._create_collection()
  366. logger.info(f"创建向量库结果:{resp}")
  367. return JSONResponse(resp)
  368. @app.post("/rag/copy_docs_to_collection")
  369. async def copy_docs_to_collection_route(copy_request: dict):
  370. """复制指定文档到新集合(索引)"""
  371. resp = await do_copy_docs(copy_request)
  372. return JSONResponse(resp)
  373. @app.post("/rag/copy_multi_docs_to_collection")
  374. async def copy_multi_docs_to_collection_route(copy_request: dict):
  375. """从多个源集合复制文档到目标集合(包含向量库和MySQL元数据)"""
  376. resp = await do_copy_multi_docs(copy_request)
  377. return JSONResponse(resp)
  378. @app.post("/rag/generate_slice_metadata")
  379. async def generate_slice_metadata_route(request_json: dict):
  380. """生成切片的 qa、question、summary 字段"""
  381. resp = await do_generate_slice_metadata(request_json)
  382. return JSONResponse(resp)
  383. # ==================== 任务删除接口 ====================
  384. @app.post("/rag/cancel_task/{task_id}")
  385. async def cancel_task(task_id: str):
  386. """
  387. 取消正在执行的文档处理任务
  388. 参数:
  389. - task_id: 任务ID(即 document_id)
  390. 返回:
  391. - code: 200 成功, 404 任务不存在, 400 任务已取消
  392. """
  393. mysql_client = MysqlOperate()
  394. # 首先删除 bm_document 表中的记录
  395. del_success, del_info = mysql_client.delete_document(task_id)
  396. # 尝试取消已注册的任务
  397. success, message = task_registry.cancel(task_id)
  398. if success:
  399. if del_success:
  400. logger.info(f"任务 {task_id} 已被取消,bm_document 记录已删除")
  401. else:
  402. logger.warning(f"任务 {task_id} 取消成功,但删除 bm_document 记录失败: {del_info}")
  403. return JSONResponse({"code": 200, "message": "任务取消成功"})
  404. # 如果任务未注册,尝试从队列中删除
  405. if "不存在" in message:
  406. removed = await queue_manager.remove_from_queue(task_id)
  407. if removed:
  408. if del_success:
  409. logger.info(f"任务 {task_id} 从队列中删除成功,bm_document 记录已删除")
  410. else:
  411. logger.warning(f"任务 {task_id} 从队列中删除成功,但删除 bm_document 记录失败: {del_info}")
  412. return JSONResponse({"code": 200, "message": "任务已从队列中删除"})
  413. logger.info(f"任务 {task_id} 未找到")
  414. return JSONResponse({"code": 404, "message": "任务不存在"})
  415. logger.info(f"message: {message}")
  416. return JSONResponse({"code": 400, "message": message})
  417. # @app.get("/rag/tasks")
  418. # async def get_active_tasks():
  419. # """
  420. # 获取所有活跃任务列表(调试用)
  421. # """
  422. # tasks = task_registry.get_all_tasks()
  423. # return JSONResponse({
  424. # "code": 200,
  425. # "data": tasks,
  426. # "count": len(tasks)
  427. # })
  428. # ==== RAG评估 ====
  429. @app.get("/rag/eval")
  430. async def rag_eval_page():
  431. """RAG评估页面"""
  432. html_path = os.path.join(static_dir, "rag_eval.html")
  433. if os.path.isfile(html_path):
  434. return FileResponse(html_path, media_type="text/html")
  435. return JSONResponse({"code": 404, "message": "页面不存在"})
  436. @app.post("/rag/evaluate")
  437. async def rag_evaluate(eval_request: dict):
  438. """
  439. RAG评估接口
  440. 参数:
  441. - file_url: 网络文件路径(JSON格式,包含question和ground_truth)
  442. - knowledge_ids: 知识库ID列表
  443. - embedding_id: 嵌入模型ID(默认 e5)
  444. - temperature: 温度参数(默认 0.6)
  445. - top_p: top_p参数(默认 0.7)
  446. - max_tokens: 最大token数(默认 4096)
  447. - model: 模型名称(默认 Qwen3-Coder-30B-loft)
  448. - slice_count: 检索切片数量(默认 5)
  449. 返回:
  450. - knowledge_ids: 知识库ID列表
  451. - metrics: 评估指标及值
  452. """
  453. logger.info(f"RAG评估的请求参数:{eval_request}")
  454. file_url = eval_request.get("file_url")
  455. knowledge_ids = eval_request.get("knowledge_ids", [])
  456. if not file_url:
  457. return JSONResponse({"code": 400, "message": "file_url 不能为空"})
  458. if not knowledge_ids:
  459. return JSONResponse({"code": 400, "message": "knowledge_ids 不能为空"})
  460. logger.info(f"RAG评估请求 - 文件: {file_url}, 知识库: {knowledge_ids}")
  461. result = await run_rag_evaluation(
  462. file_url=file_url,
  463. knowledge_ids=knowledge_ids,
  464. embedding_id=eval_request.get("embedding_id", "e5"),
  465. temperature=eval_request.get("temperature", 0.6),
  466. top_p=eval_request.get("top_p", 0.7),
  467. max_tokens=eval_request.get("max_tokens", 4096),
  468. model=eval_request.get("model", "/opt/vllm/models/Qwen/Qwen3-30B-A3B-Instruct-2507"),
  469. slice_count=int(eval_request.get("slice_count", 5)),
  470. task_id=eval_request.get("task_id"),
  471. tenant_id = eval_request.get("tenant_id"),
  472. prompt = eval_request.get("prompt"),
  473. rerank_model_name = eval_request.get("rerank_model_name", "Qwen3-Reranker-0.6B"),
  474. )
  475. logger.info(f"RAG评估结果: code={result.get('code')}")
  476. return JSONResponse(result)
  477. # @app.post("/ai/translate/audio")
  478. # async def audio_text(request: Request, file: UploadFile = File(...)):
  479. # uuid_str = uuid.uuid4()
  480. # file_name = f"{uuid_str}_{file.filename}"
  481. # file_path = os.path.join("./tmp_file/audio/", file_name)
  482. # with open(file_path, "wb") as f:
  483. # f.write(await file.read())
  484. # asr_text = Audio2Text(file_path).audo_to_text()
  485. # logger.info(f"asr result: {asr_text}, audio file name: {file_name}")
  486. # return asr_text
from pathlib import Path
import shutil
# ===== File directory configuration =====
# Directory (relative to this file) that the /files endpoints read and write.
BASE_DIR = Path(__file__).parent
FILE_DIR = BASE_DIR / "rag" / "sensitive_words"
# Created at import time so the endpoints can rely on it existing.
FILE_DIR.mkdir(parents=True, exist_ok=True)
# Only files with this suffix may be listed, downloaded or uploaded.
ALLOWED_SUFFIX = ".txt"
  494. # 停用词库文件列表
  495. @app.get("/files")
  496. def list_files():
  497. files = [
  498. f.name
  499. for f in FILE_DIR.iterdir()
  500. if f.is_file() and f.suffix == ALLOWED_SUFFIX
  501. ]
  502. return {
  503. "count": len(files),
  504. "files": files,
  505. }
  506. # 停用词库文件下载(txt)
  507. @app.get("/files/{filename}")
  508. def download_file(filename: str):
  509. # 防止路径穿越
  510. if "/" in filename or "\\" in filename:
  511. raise HTTPException(status_code=400, detail="非法文件名")
  512. file_path = FILE_DIR / filename
  513. if not file_path.exists():
  514. raise HTTPException(status_code=404, detail="文件不存在")
  515. if file_path.suffix != ALLOWED_SUFFIX:
  516. raise HTTPException(status_code=400, detail="只允许下载 txt 文件")
  517. return FileResponse(
  518. path=file_path,
  519. filename=filename,
  520. media_type="text/plain",
  521. )
  522. # 停用词库文件上传(txt)
  523. @app.post("/files/upload")
  524. def upload_file(file: UploadFile = File(...)):
  525. filename = file.filename
  526. if not filename:
  527. raise HTTPException(status_code=400, detail="文件名不能为空")
  528. if not filename.endswith(ALLOWED_SUFFIX):
  529. raise HTTPException(status_code=400, detail="只允许上传 txt 文件")
  530. target_path = FILE_DIR / filename
  531. # 覆盖写入
  532. with target_path.open("wb") as f:
  533. shutil.copyfileobj(file.file, f)
  534. return {
  535. "message": "上传成功",
  536. "filename": filename,
  537. "path": str(target_path),
  538. }
# Script entry point: serve the app with uvicorn on all interfaces.
# NOTE(review): port 6000 is on the WHATWG Fetch "bad ports" list, so browsers
# refuse direct connections to it — relevant for the /rag/eval page; confirm
# the deployment fronts this server with a proxy or uses another port.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=6000, loop="asyncio")