rag_server.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. # 请求的入口
  2. import os
  3. os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"
  4. from fastapi import FastAPI, File, UploadFile, Form, Request, Response, WebSocket, WebSocketDisconnect, Depends, APIRouter, Body, HTTPException
  5. from fastapi.responses import JSONResponse, FileResponse, StreamingResponse
  6. from fastapi.middleware.cors import CORSMiddleware
  7. from sse_starlette import EventSourceResponse
  8. import uvicorn
  9. from utils.get_logger import setup_logger
  10. from rag.vector_db.milvus_vector import HybridRetriever
  11. from response_info import generate_message, generate_response
  12. from rag.db import MilvusOperate, MysqlOperate
  13. from rag.file_process import ParseFile
  14. from rag.documents_process import ProcessDocuments
  15. from rag.task_registry import task_registry
  16. from rag.chat_message import ChatRetrieverRag
  17. logger = setup_logger(__name__)
  18. app = FastAPI()
  19. # 设置跨域
  20. app.add_middleware(
  21. CORSMiddleware,
  22. allow_origins=["*"],
  23. allow_credentials=True,
  24. allow_methods=["*"],
  25. allow_headers=["*"],
  26. )
  27. @app.get("/health")
  28. async def health_check():
  29. return {"status": "healthy"}
  30. @app.post("/upload_knowledge")
  31. async def upload_file_to_db(file_json: dict):
  32. logger.info(f"上传文件请求参数:{file_json}")
  33. parse_file = ProcessDocuments(file_json)
  34. resp = await parse_file.process_documents(file_json)
  35. # parse_file = ParseFile(file_json)
  36. # resp = parse_file.save_file_to_db()
  37. logger.info(f"上传文件响应结果:{resp}")
  38. return JSONResponse(resp)
  39. # @app.post("/network/search")
  40. # async def chat_with_rag(request: Request, chat_json:dict):
  41. # retriever = ChatRetrieverRag(chat_json)
  42. # return EventSourceResponse(retriever.generate_event(chat_json, request), ping=300)
  43. @app.post("/rag/chat")
  44. async def chat_with_rag(request: Request, chat_json:dict):
  45. retriever = ChatRetrieverRag(chat_json)
  46. return EventSourceResponse(retriever.generate_event(chat_json, request), ping=300)
  47. @app.post("/rag/chat/sync")
  48. async def chat_with_rag_sync(chat_json: dict):
  49. """非流式RAG聊天接口"""
  50. logger.info(f"非流式RAG请求参数:{chat_json}")
  51. retriever = ChatRetrieverRag(chat_json)
  52. result = await retriever.generate_sync_response(chat_json)
  53. logger.info(f"非流式RAG响应结果:{result.get('code')}")
  54. return JSONResponse(result)
  55. @app.post("/rag/query")
  56. async def generate_query(request: Request, query_json:dict):
  57. logger.info(f"请求参数:{query_json}")
  58. relevant_query = ChatRetrieverRag(query_json)
  59. relevant_json = await relevant_query.generate_relevant_query(query_json)
  60. return JSONResponse(relevant_json)
  61. @app.get("/rag/slice/search/{chat_id}")
  62. async def generate_query(request: Request, chat_id:str = None):
  63. chat = ChatRetrieverRag(chat_id=chat_id)
  64. chunk_json = await chat.search_slice()
  65. return JSONResponse(chunk_json)
  66. @app.delete("/rag/delete_slice/{slice_id}/{knowledge_id}/{document_id}")
  67. async def delete_by_chunk_id(slice_id:str=None, knowledge_id:str=None, document_id:str=None):
  68. logger.info(f"删除切片接口中,知识库:{knowledge_id}, 切片id:{slice_id}")
  69. resp = MilvusOperate(collection_name=knowledge_id)._delete_by_chunk_id(slice_id, knowledge_id, document_id)
  70. logger.info(f"删除切片信息的结果:{resp}")
  71. return JSONResponse(resp)
  72. @app.delete("/rag/delete_doc/{doc_id}/{knowledge_id}")
  73. async def delete_by_doc_id(doc_id:str=None, knowledge_id:str=None):
  74. logger.info(f"删除文档id接口,知识库:{knowledge_id}, 文档id:{doc_id}")
  75. resp = MilvusOperate(collection_name=knowledge_id)._delete_by_doc_id(doc_id=doc_id)
  76. logger.info(f"删除文档的结果:{resp}")
  77. return JSONResponse(resp)
  78. @app.put("/rag/update_slice")
  79. async def put_by_id(slice_json:dict):
  80. logger.info(f"更新切片信息的请求参数:{slice_json}")
  81. collection_name = slice_json.get("knowledge_id")
  82. resp = MilvusOperate(collection_name=collection_name)._put_by_id(slice_json)
  83. logger.info(f"更新切片信息的结果:{resp}")
  84. return JSONResponse(resp)
  85. @app.post("/rag/insert_slice")
  86. async def insert_slice_text(slice_json:dict):
  87. logger.info(f"新增切片信息的请求参数:{slice_json}")
  88. collection_name = slice_json.get("knowledge_id")
  89. resp = MilvusOperate(collection_name=collection_name)._insert_slice(slice_json)
  90. logger.info(f"新增切片信息的结果:{resp}")
  91. return JSONResponse(resp)
  92. @app.get("/rag/search/{knowledge_id}/{slice_id}")
  93. async def search_by_doc_id(knowledge_id:str=None, slice_id:str=None):
  94. # 根据切片id查询切片信息
  95. # print(f"知识库:{knowledge_id}, 切片:{slice_id}")
  96. logger.info(f"根据切片id查询的数据库名:{knowledge_id},切片id:{slice_id}")
  97. collection_name = knowledge_id # 根据传过来的id处理对应知识库
  98. resp = MilvusOperate(collection_name=collection_name)._search_by_chunk_id(slice_id)
  99. logger.info(f"根据切片id查询结果:{resp}")
  100. return JSONResponse(resp)
  101. @app.post("/rag/search_word")
  102. async def search_by_key_word(search_json:dict):
  103. # 根据doc_id 查询切片列表信息
  104. collection_name = search_json.get("knowledge_id")
  105. logger.info(f"根据关键字请求的参数:{search_json}")
  106. resp = MilvusOperate(collection_name=collection_name)._search_by_key_word(search_json)
  107. logger.info(f"根据关键字查询的结果:{resp}")
  108. return JSONResponse(resp)
  109. @app.delete("/rag/delete_knowledge/{knowledge_id}")
  110. async def delete_collection(knowledge_id: str = None):
  111. logger.info(f"删除数据库请求的参数:{knowledge_id}")
  112. resp = MilvusOperate(collection_name=knowledge_id)._delete_collection()
  113. logger.info(f"删除向量库结果:{resp}")
  114. return JSONResponse(resp)
  115. @app.post("/rag/create_collection")
  116. async def create_collection(collection: dict):
  117. collection_name = collection.get("knowledge_id")
  118. embedding_name = collection.get("embedding_id")
  119. logger.info(f"创建向量库的库名:{collection_name},向量名称:{embedding_name}")
  120. resp = MilvusOperate(collection_name=collection_name, embedding_name=embedding_name)._create_collection()
  121. logger.info(f"创建向量库结果:{resp}")
  122. return JSONResponse(resp)
  123. # ==================== 任务取消接口 ====================
  124. @app.post("/rag/cancel_task/{task_id}")
  125. async def cancel_task(task_id: str):
  126. """
  127. 取消正在执行的文档处理任务
  128. 参数:
  129. - task_id: 任务ID(即 document_id)
  130. 返回:
  131. - code: 200 成功, 404 任务不存在, 400 任务已取消
  132. """
  133. mysql_client = MysqlOperate()
  134. # 首先删除 bm_document 表中的记录
  135. del_success, del_info = mysql_client.delete_document(task_id)
  136. success, message = task_registry.cancel(task_id)
  137. if success:
  138. if del_success:
  139. logger.info(f"任务 {task_id} 已被取消,bm_document 记录已删除")
  140. else:
  141. logger.warning(f"任务 {task_id} 取消成功,但删除 bm_document 记录失败: {del_info}")
  142. return JSONResponse({"code": 200, "message": "任务取消成功"})
  143. if "不存在" in message:
  144. return JSONResponse({"code": 404, "message": message})
  145. return JSONResponse({"code": 400, "message": message})
  146. if __name__ == "__main__":
  147. uvicorn.run(app, host="0.0.0.0", port=6666)