rag_server.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. # 请求的入口
  2. from fastapi import FastAPI, File, UploadFile, Form, Request, Response, WebSocket, WebSocketDisconnect, Depends, APIRouter, Body
  3. from fastapi.responses import JSONResponse, FileResponse, StreamingResponse
  4. from fastapi.middleware.cors import CORSMiddleware
  5. from sse_starlette import EventSourceResponse
  6. import uvicorn
  7. from utils.get_logger import setup_logger
  8. from rag.vector_db.milvus_vector import HybridRetriever
  9. from response_info import generate_message, generate_response
  10. from rag.db import MilvusOperate
  11. from rag.file_process import ParseFile
  12. from rag.documents_process import ProcessDocuments
  13. from rag.chat_message import ChatRetrieverRag
  14. logger = setup_logger(__name__)
  15. app = FastAPI()
  16. # 设置跨域
  17. app.add_middleware(
  18. CORSMiddleware,
  19. allow_origins=["*"],
  20. allow_credentials=True,
  21. allow_methods=["*"],
  22. allow_headers=["*"],
  23. )
  24. @app.get("/health")
  25. async def health_check():
  26. return {"status": "healthy"}
  27. @app.post("/upload_knowledge")
  28. async def upload_file_to_db(file_json: dict):
  29. logger.info(f"上传文件请求参数:{file_json}")
  30. parse_file = ProcessDocuments(file_json)
  31. resp = await parse_file.process_documents(file_json)
  32. # parse_file = ParseFile(file_json)
  33. # resp = parse_file.save_file_to_db()
  34. logger.info(f"上传文件响应结果:{resp}")
  35. return JSONResponse(resp)
  36. # @app.post("/network/search")
  37. # async def chat_with_rag(request: Request, chat_json:dict):
  38. # retriever = ChatRetrieverRag(chat_json)
  39. # return EventSourceResponse(retriever.generate_event(chat_json, request), ping=300)
  40. @app.post("/rag/chat")
  41. async def chat_with_rag(request: Request, chat_json:dict):
  42. retriever = ChatRetrieverRag(chat_json)
  43. return EventSourceResponse(retriever.generate_event(chat_json, request), ping=300)
  44. @app.post("/rag/query")
  45. async def generate_query(request: Request, query_json:dict):
  46. logger.info(f"请求参数:{query_json}")
  47. relevant_query = ChatRetrieverRag(query_json)
  48. relevant_json = await relevant_query.generate_relevant_query(query_json)
  49. return JSONResponse(relevant_json)
  50. @app.get("/rag/slice/search/{chat_id}")
  51. async def generate_query(request: Request, chat_id:str = None):
  52. chat = ChatRetrieverRag(chat_id=chat_id)
  53. chunk_json = await chat.search_slice()
  54. return JSONResponse(chunk_json)
  55. @app.delete("/rag/delete_slice/{slice_id}/{knowledge_id}/{document_id}")
  56. async def delete_by_chunk_id(slice_id:str=None, knowledge_id:str=None, document_id:str=None):
  57. logger.info(f"删除切片接口中,知识库:{knowledge_id}, 切片id:{slice_id}")
  58. resp = MilvusOperate(collection_name=knowledge_id)._delete_by_chunk_id(slice_id, knowledge_id, document_id)
  59. logger.info(f"删除切片信息的结果:{resp}")
  60. return JSONResponse(resp)
  61. @app.delete("/rag/delete_doc/{doc_id}/{knowledge_id}")
  62. async def delete_by_doc_id(doc_id:str=None, knowledge_id:str=None):
  63. logger.info(f"删除文档id接口,知识库:{knowledge_id}, 文档id:{doc_id}")
  64. resp = MilvusOperate(collection_name=knowledge_id)._delete_by_doc_id(doc_id=doc_id)
  65. logger.info(f"删除文档的结果:{resp}")
  66. return JSONResponse(resp)
  67. @app.put("/rag/update_slice")
  68. async def put_by_id(slice_json:dict):
  69. logger.info(f"更新切片信息的请求参数:{slice_json}")
  70. collection_name = slice_json.get("knowledge_id")
  71. resp = MilvusOperate(collection_name=collection_name)._put_by_id(slice_json)
  72. logger.info(f"更新切片信息的结果:{resp}")
  73. return JSONResponse(resp)
  74. @app.post("/rag/insert_slice")
  75. async def insert_slice_text(slice_json:dict):
  76. logger.info(f"新增切片信息的请求参数:{slice_json}")
  77. collection_name = slice_json.get("knowledge_id")
  78. resp = MilvusOperate(collection_name=collection_name)._insert_slice(slice_json)
  79. logger.info(f"新增切片信息的结果:{resp}")
  80. return JSONResponse(resp)
  81. @app.get("/rag/search/{knowledge_id}/{slice_id}")
  82. async def search_by_doc_id(knowledge_id:str=None, slice_id:str=None):
  83. # 根据切片id查询切片信息
  84. # print(f"知识库:{knowledge_id}, 切片:{slice_id}")
  85. logger.info(f"根据切片id查询的数据库名:{knowledge_id},切片id:{slice_id}")
  86. collection_name = knowledge_id # 根据传过来的id处理对应知识库
  87. resp = MilvusOperate(collection_name=collection_name)._search_by_chunk_id(slice_id)
  88. logger.info(f"根据切片id查询结果:{resp}")
  89. return JSONResponse(resp)
  90. @app.post("/rag/search_word")
  91. async def search_by_key_word(search_json:dict):
  92. # 根据doc_id 查询切片列表信息
  93. collection_name = search_json.get("knowledge_id")
  94. logger.info(f"根据关键字请求的参数:{search_json}")
  95. resp = MilvusOperate(collection_name=collection_name)._search_by_key_word(search_json)
  96. logger.info(f"根据关键字查询的结果:{resp}")
  97. return JSONResponse(resp)
  98. @app.delete("/rag/delete_knowledge/{knowledge_id}")
  99. async def delete_collection(knowledge_id: str = None):
  100. logger.info(f"删除数据库请求的参数:{knowledge_id}")
  101. resp = MilvusOperate(collection_name=knowledge_id)._delete_collection()
  102. logger.info(f"删除向量库结果:{resp}")
  103. return JSONResponse(resp)
  104. @app.post("/rag/create_collection")
  105. async def create_collection(collection: dict):
  106. collection_name = collection.get("knowledge_id")
  107. embedding_name = collection.get("embedding_id")
  108. logger.info(f"创建向量库的库名:{collection_name},向量名称:{embedding_name}")
  109. resp = MilvusOperate(collection_name=collection_name, embedding_name=embedding_name)._create_collection()
  110. logger.info(f"创建向量库结果:{resp}")
  111. return JSONResponse(resp)
  112. if __name__ == "__main__":
  113. uvicorn.run(app, host="0.0.0.0", port=18079)