| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146 |
- # 请求的入口
- from fastapi import FastAPI, File, UploadFile, Form, Request, Response, WebSocket, WebSocketDisconnect, Depends, APIRouter, Body
- from fastapi.responses import JSONResponse, FileResponse, StreamingResponse
- from fastapi.middleware.cors import CORSMiddleware
- from sse_starlette import EventSourceResponse
- import uvicorn
- from utils.get_logger import setup_logger
- from rag.vector_db.milvus_vector import HybridRetriever
- from response_info import generate_message, generate_response
- from rag.db import MilvusOperate
- from rag.file_process import ParseFile
- from rag.documents_process import ProcessDocuments
- from rag.chat_message import ChatRetrieverRag
logger = setup_logger(__name__)

# The ASGI application every route below is registered on.
app = FastAPI()

# Cross-origin (CORS) configuration: any origin, any method and any header
# are accepted, and credentials (cookies) are allowed.
# NOTE(review): FastAPI's CORS documentation states that allow_origins,
# allow_methods and allow_headers must not be ["*"] when
# allow_credentials=True; consider listing the real frontend origins.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
@app.get("/health")
async def health_check():
    """Liveness probe: report that the service process is answering.

    Always returns the same fixed payload; no backend is contacted.
    """
    payload = {"status": "healthy"}
    return payload
@app.post("/upload_knowledge")
async def upload_file_to_db(file_json: dict):
    """Ingest documents into a knowledge base.

    The raw JSON body is given to ``ProcessDocuments`` and its
    ``process_documents`` coroutine does the work. Whatever that coroutine
    returns is sent back to the client as JSON.

    Args:
        file_json: request body. Its schema is defined by ProcessDocuments,
            which is not visible in this file.
    """
    # Log the incoming request parameters.
    logger.info(f"上传文件请求参数:{file_json}")
    processor = ProcessDocuments(file_json)
    result = await processor.process_documents(file_json)
    # Log the response that is about to be returned.
    logger.info(f"上传文件响应结果:{result}")
    return JSONResponse(result)
- # @app.post("/network/search")
- # async def chat_with_rag(request: Request, chat_json:dict):
- # retriever = ChatRetrieverRag(chat_json)
- # return EventSourceResponse(retriever.generate_event(chat_json, request), ping=300)
@app.post("/rag/chat")
async def chat_with_rag(request: Request, chat_json: dict):
    """Answer a RAG chat request as a Server-Sent Events stream.

    The events come from ``ChatRetrieverRag.generate_event``, which receives
    both the body and the live ``request``. Presumably the request is passed
    so the generator can detect a client disconnect; confirm this in
    ChatRetrieverRag. ``ping=300`` makes sse-starlette emit a keep-alive
    ping every 300 seconds.
    """
    event_stream = ChatRetrieverRag(chat_json).generate_event(chat_json, request)
    return EventSourceResponse(event_stream, ping=300)
@app.post("/rag/query")
async def generate_query(request: Request, query_json: dict):
    """Return the related queries produced by ChatRetrieverRag.

    ``ChatRetrieverRag.generate_relevant_query`` computes the result from
    the request body, and the result is returned as JSON.
    ``request`` is accepted but not used.
    """
    # Log the request parameters.
    logger.info(f"请求参数:{query_json}")
    retriever = ChatRetrieverRag(query_json)
    return JSONResponse(await retriever.generate_relevant_query(query_json))
@app.get(
    "/rag/slice/search/{chat_id}",
    # Keep the route name, and therefore the generated OpenAPI operationId,
    # that this endpoint had when the function was named ``generate_query``.
    name="generate_query",
)
async def search_slice_by_chat_id(request: Request, chat_id: str = None):
    """Return the slices (chunks) that ChatRetrieverRag finds for a chat.

    This handler was previously also named ``generate_query``. That
    redefinition rebound the module-level name of the POST /rag/query
    handler (flake8 F811), making it unreachable by name. It has a distinct
    name now; the HTTP route, its name and its behavior are unchanged.

    Args:
        request: the incoming request; accepted but unused.
        chat_id: path parameter identifying the chat. FastAPI always treats
            path parameters as required, so the ``None`` default is not used
            for HTTP calls.
    """
    chat = ChatRetrieverRag(chat_id=chat_id)
    chunk_json = await chat.search_slice()
    return JSONResponse(chunk_json)
@app.delete("/rag/delete_slice/{slice_id}/{knowledge_id}/{document_id}")
async def delete_by_chunk_id(
    slice_id: str = None,
    knowledge_id: str = None,
    document_id: str = None,
):
    """Delete one slice (chunk) from the collection named by ``knowledge_id``.

    All three path parameters are forwarded to
    ``MilvusOperate._delete_by_chunk_id``. The backend's result is logged
    and returned as JSON.

    NOTE(review): the MilvusOperate call is synchronous (not awaited) inside
    an async handler. If it does blocking network I/O, it stalls the event
    loop; consider a plain ``def`` handler or ``run_in_threadpool``.
    """
    # document_id is part of the delete key, so it is logged alongside the
    # other identifiers to make destructive operations traceable.
    logger.info(
        f"删除切片接口中,知识库:{knowledge_id}, 文档id:{document_id}, 切片id:{slice_id}"
    )
    resp = MilvusOperate(collection_name=knowledge_id)._delete_by_chunk_id(
        slice_id, knowledge_id, document_id
    )
    logger.info(f"删除切片信息的结果:{resp}")
    return JSONResponse(resp)
@app.delete("/rag/delete_doc/{doc_id}/{knowledge_id}")
async def delete_by_doc_id(doc_id: str = None, knowledge_id: str = None):
    """Delete a document, by ``doc_id``, from the collection named by ``knowledge_id``.

    The backend's result is logged and returned as JSON.

    NOTE(review): the MilvusOperate call is synchronous inside an async
    handler and may block the event loop if it does network I/O.
    """
    logger.info(f"删除文档id接口,知识库:{knowledge_id}, 文档id:{doc_id}")
    operator = MilvusOperate(collection_name=knowledge_id)
    result = operator._delete_by_doc_id(doc_id=doc_id)
    logger.info(f"删除文档的结果:{result}")
    return JSONResponse(result)
@app.put("/rag/update_slice")
async def put_by_id(slice_json: dict):
    """Update a slice in the collection named by ``slice_json["knowledge_id"]``.

    The whole body is forwarded to ``MilvusOperate._put_by_id``. A missing
    ``knowledge_id`` is not rejected here: ``None`` is passed on as the
    collection name.

    NOTE(review): the MilvusOperate call is synchronous inside an async
    handler and may block the event loop if it does network I/O.
    """
    logger.info(f"更新切片信息的请求参数:{slice_json}")
    milvus = MilvusOperate(collection_name=slice_json.get("knowledge_id"))
    result = milvus._put_by_id(slice_json)
    logger.info(f"更新切片信息的结果:{result}")
    return JSONResponse(result)
@app.post("/rag/insert_slice")
async def insert_slice_text(slice_json: dict):
    """Insert a new slice into the collection named by ``slice_json["knowledge_id"]``.

    The whole body is forwarded to ``MilvusOperate._insert_slice``. A missing
    ``knowledge_id`` is not rejected here: ``None`` is passed on as the
    collection name.

    NOTE(review): the MilvusOperate call is synchronous inside an async
    handler and may block the event loop if it does network I/O.
    """
    logger.info(f"新增切片信息的请求参数:{slice_json}")
    milvus = MilvusOperate(collection_name=slice_json.get("knowledge_id"))
    result = milvus._insert_slice(slice_json)
    logger.info(f"新增切片信息的结果:{result}")
    return JSONResponse(result)
@app.get("/rag/search/{knowledge_id}/{slice_id}")
async def search_by_doc_id(knowledge_id: str = None, slice_id: str = None):
    """Look up slice information by slice id.

    Despite the function name, the lookup key is ``slice_id``, passed to
    ``MilvusOperate._search_by_chunk_id``. ``knowledge_id`` only selects the
    collection. The name is kept so existing references keep working.

    NOTE(review): the MilvusOperate call is synchronous inside an async
    handler and may block the event loop if it does network I/O.
    """
    logger.info(f"根据切片id查询的数据库名:{knowledge_id},切片id:{slice_id}")
    milvus = MilvusOperate(collection_name=knowledge_id)
    result = milvus._search_by_chunk_id(slice_id)
    logger.info(f"根据切片id查询结果:{result}")
    return JSONResponse(result)
@app.post("/rag/search_word")
async def search_by_key_word(search_json: dict):
    """Run a keyword search over the slices of one collection.

    ``search_json["knowledge_id"]`` selects the collection. The whole body
    is forwarded to ``MilvusOperate._search_by_key_word``, whose expected
    keys are not visible here.

    NOTE(review): the MilvusOperate call is synchronous inside an async
    handler and may block the event loop if it does network I/O.
    """
    logger.info(f"根据关键字请求的参数:{search_json}")
    milvus = MilvusOperate(collection_name=search_json.get("knowledge_id"))
    result = milvus._search_by_key_word(search_json)
    logger.info(f"根据关键字查询的结果:{result}")
    return JSONResponse(result)
@app.delete("/rag/delete_knowledge/{knowledge_id}")
async def delete_collection(knowledge_id: str = None):
    """Delete the vector collection named by ``knowledge_id``.

    The backend's result is logged and returned as JSON.

    NOTE(review): the MilvusOperate call is synchronous inside an async
    handler and may block the event loop if it does network I/O.
    """
    logger.info(f"删除数据库请求的参数:{knowledge_id}")
    milvus = MilvusOperate(collection_name=knowledge_id)
    result = milvus._delete_collection()
    logger.info(f"删除向量库结果:{result}")
    return JSONResponse(result)
@app.post("/rag/create_collection")
async def create_collection(collection: dict):
    """Create a vector collection for a knowledge base.

    Body keys read here:
        knowledge_id: used as the collection name.
        embedding_id: passed to MilvusOperate as ``embedding_name``.

    NOTE(review): the MilvusOperate call is synchronous inside an async
    handler and may block the event loop if it does network I/O.
    """
    collection_name = collection.get("knowledge_id")
    embedding_name = collection.get("embedding_id")
    logger.info(f"创建向量库的库名:{collection_name},向量名称:{embedding_name}")
    milvus = MilvusOperate(collection_name=collection_name, embedding_name=embedding_name)
    result = milvus._create_collection()
    logger.info(f"创建向量库结果:{result}")
    return JSONResponse(result)
if __name__ == "__main__":
    # Serve the app directly with uvicorn on all interfaces, port 18079.
    uvicorn.run(app=app, host="0.0.0.0", port=18079)
|