"""HTTP entry point of the RAG service.

Exposes FastAPI routes for health checks, document ingestion, RAG chat
(streamed over Server-Sent Events), related-query generation, and CRUD on the
Milvus-backed knowledge bases (collections) and their slices (chunks).
"""
from fastapi import (
    FastAPI, File, UploadFile, Form, Request, Response, WebSocket,
    WebSocketDisconnect, Depends, APIRouter, Body,
)
from fastapi.responses import JSONResponse, FileResponse, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from sse_starlette import EventSourceResponse
import uvicorn
from utils.get_logger import setup_logger
from rag.vector_db.milvus_vector import HybridRetriever
from response_info import generate_message, generate_response
from rag.db import MilvusOperate
from rag.file_process import ParseFile
from rag.documents_process import ProcessDocuments
from rag.chat_message import ChatRetrieverRag

logger = setup_logger(__name__)
app = FastAPI()

# Configure CORS.
# NOTE(review): with allow_origins=["*"] and allow_credentials=True, Starlette
# reflects the request's Origin header for credentialed requests, which
# effectively allows credentialed cross-origin access from any site. Restrict
# allow_origins to known front-end origins for production deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# NOTE(review): the MilvusOperate methods below are called synchronously (no
# `await`) inside `async def` handlers, so any blocking I/O they perform runs
# on the event loop thread and stalls other requests (including SSE streams).
# Consider declaring those handlers with plain `def` (FastAPI then runs them in
# its threadpool) once MilvusOperate is confirmed safe to use from threads.


@app.get("/health")
async def health_check():
    """Liveness probe; always returns ``{"status": "healthy"}``."""
    return {"status": "healthy"}


@app.post("/upload_knowledge")
async def upload_file_to_db(file_json: dict):
    """Ingest the documents described by ``file_json`` into a knowledge base.

    The request body is handed unchanged to ``ProcessDocuments``; its schema is
    defined there (TODO confirm the expected keys). Returns the result of
    ``process_documents`` as a JSON response.
    """
    logger.info(f"上传文件请求参数:{file_json}")
    parse_file = ProcessDocuments(file_json)
    resp = await parse_file.process_documents(file_json)
    # Previous synchronous ingestion path, kept for reference:
    # parse_file = ParseFile(file_json)
    # resp = parse_file.save_file_to_db()
    logger.info(f"上传文件响应结果:{resp}")
    return JSONResponse(resp)


# Disabled endpoint, kept for reference:
# @app.post("/network/search")
# async def chat_with_rag(request: Request, chat_json:dict):
#     retriever = ChatRetrieverRag(chat_json)
#     return EventSourceResponse(retriever.generate_event(chat_json, request), ping=300)


@app.post("/rag/chat")
async def chat_with_rag(request: Request, chat_json: dict):
    """Stream a RAG chat answer to the client as Server-Sent Events.

    The raw ``request`` is forwarded to ``generate_event``, presumably so the
    generator can detect client disconnects (verify in ChatRetrieverRag).
    ``ping=300`` makes sse_starlette emit a keep-alive ping every 300 seconds.
    """
    retriever = ChatRetrieverRag(chat_json)
    return EventSourceResponse(retriever.generate_event(chat_json, request), ping=300)


@app.post("/rag/query")
async def generate_query(request: Request, query_json: dict):
    """Return the related queries produced by ``generate_relevant_query``.

    ``request`` is not used by this handler.
    """
    logger.info(f"请求参数:{query_json}")
    relevant_query = ChatRetrieverRag(query_json)
    relevant_json = await relevant_query.generate_relevant_query(query_json)
    return JSONResponse(relevant_json)


# Previously also named `generate_query`, which shadowed the `/rag/query`
# handler above at module level and gave both routes the same route name.
@app.get("/rag/slice/search/{chat_id}")
async def search_slice_by_chat_id(request: Request, chat_id: str = None):
    """Return the slices that ``ChatRetrieverRag.search_slice`` finds for ``chat_id``.

    ``request`` is not used by this handler.
    """
    chat = ChatRetrieverRag(chat_id=chat_id)
    chunk_json = await chat.search_slice()
    return JSONResponse(chunk_json)


@app.delete("/rag/delete_slice/{slice_id}/{knowledge_id}/{document_id}")
async def delete_by_chunk_id(slice_id: str = None, knowledge_id: str = None, document_id: str = None):
    """Delete one slice from the collection named ``knowledge_id``."""
    logger.info(f"删除切片接口中,知识库:{knowledge_id}, 切片id:{slice_id}")
    resp = MilvusOperate(collection_name=knowledge_id)._delete_by_chunk_id(slice_id, knowledge_id, document_id)
    logger.info(f"删除切片信息的结果:{resp}")
    return JSONResponse(resp)


@app.delete("/rag/delete_doc/{doc_id}/{knowledge_id}")
async def delete_by_doc_id(doc_id: str = None, knowledge_id: str = None):
    """Delete a document's data from the collection named ``knowledge_id``."""
    logger.info(f"删除文档id接口,知识库:{knowledge_id}, 文档id:{doc_id}")
    resp = MilvusOperate(collection_name=knowledge_id)._delete_by_doc_id(doc_id=doc_id)
    logger.info(f"删除文档的结果:{resp}")
    return JSONResponse(resp)


@app.put("/rag/update_slice")
async def put_by_id(slice_json: dict):
    """Update a slice; the target collection is ``slice_json["knowledge_id"]``.

    NOTE(review): a missing "knowledge_id" yields collection_name=None, which
    is passed straight to MilvusOperate without validation.
    """
    logger.info(f"更新切片信息的请求参数:{slice_json}")
    collection_name = slice_json.get("knowledge_id")
    resp = MilvusOperate(collection_name=collection_name)._put_by_id(slice_json)
    logger.info(f"更新切片信息的结果:{resp}")
    return JSONResponse(resp)


@app.post("/rag/insert_slice")
async def insert_slice_text(slice_json: dict):
    """Insert a new slice into the collection ``slice_json["knowledge_id"]``."""
    logger.info(f"新增切片信息的请求参数:{slice_json}")
    collection_name = slice_json.get("knowledge_id")
    resp = MilvusOperate(collection_name=collection_name)._insert_slice(slice_json)
    logger.info(f"新增切片信息的结果:{resp}")
    return JSONResponse(resp)


@app.get("/rag/search/{knowledge_id}/{slice_id}")
async def search_by_doc_id(knowledge_id: str = None, slice_id: str = None):
    """Look up a single slice by its slice (chunk) id.

    Despite the function name, the lookup is by ``slice_id`` via
    ``_search_by_chunk_id``, not by document id.
    """
    logger.info(f"根据切片id查询的数据库名:{knowledge_id},切片id:{slice_id}")
    # Each knowledge base is stored in its own collection named after its id.
    collection_name = knowledge_id
    resp = MilvusOperate(collection_name=collection_name)._search_by_chunk_id(slice_id)
    logger.info(f"根据切片id查询结果:{resp}")
    return JSONResponse(resp)


@app.post("/rag/search_word")
async def search_by_key_word(search_json: dict):
    """Search slices by keyword in the collection ``search_json["knowledge_id"]``.

    The original comment described this as "query the slice list by doc_id";
    the filter fields actually used are whatever ``_search_by_key_word`` reads
    from ``search_json`` (TODO confirm).
    """
    collection_name = search_json.get("knowledge_id")
    logger.info(f"根据关键字请求的参数:{search_json}")
    resp = MilvusOperate(collection_name=collection_name)._search_by_key_word(search_json)
    logger.info(f"根据关键字查询的结果:{resp}")
    return JSONResponse(resp)


@app.delete("/rag/delete_knowledge/{knowledge_id}")
async def delete_collection(knowledge_id: str = None):
    """Drop the vector collection that backs the knowledge base ``knowledge_id``."""
    logger.info(f"删除数据库请求的参数:{knowledge_id}")
    resp = MilvusOperate(collection_name=knowledge_id)._delete_collection()
    logger.info(f"删除向量库结果:{resp}")
    return JSONResponse(resp)


@app.post("/rag/create_collection")
async def create_collection(collection: dict):
    """Create a vector collection named ``collection["knowledge_id"]``.

    ``collection["embedding_id"]`` is passed to MilvusOperate as
    ``embedding_name``, presumably selecting the embedding model (verify).
    """
    collection_name = collection.get("knowledge_id")
    embedding_name = collection.get("embedding_id")
    logger.info(f"创建向量库的库名:{collection_name},向量名称:{embedding_name}")
    resp = MilvusOperate(collection_name=collection_name, embedding_name=embedding_name)._create_collection()
    logger.info(f"创建向量库结果:{resp}")
    return JSONResponse(resp)


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=18079)