import asyncio
import base64
import contextlib
import hashlib
import http.cookiejar
import itertools
import os
import ssl
from typing import Dict, Iterable, List, Tuple
from urllib.parse import quote, urlsplit

import httpx
from fastapi import FastAPI, Request, WebSocket
from fastapi.responses import StreamingResponse, Response
from fastapi.websockets import WebSocketState
import uvicorn

# ===================== Configuration =====================
WORKERS = [
    "http://127.0.0.1:6001",
    "http://127.0.0.1:6002",
    "http://127.0.0.1:6003",
    "http://127.0.0.1:6004",
]

REQUEST_TIMEOUT = None  # Must stay None: streaming / SSE responses may stay open indefinitely.

# Hop-by-hop headers describe a single connection and must not be forwarded
# by a proxy (RFC 9110 §7.6.1). Headers listed in `Connection` are dropped too.
HOP_BY_HOP_HEADERS = frozenset({
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "proxy-connection",
    "te",
    "trailer",
    "transfer-encoding",
    "upgrade",
})

# ===================== Core state =====================
worker_cycle = itertools.cycle(WORKERS)


def next_worker() -> str:
    """Return the next worker base URL in round-robin order."""
    return next(worker_cycle)


def _forwardable_headers(
    items: Iterable[Tuple[str, str]], drop: Iterable[str] = ()
) -> List[Tuple[str, str]]:
    """Return *items* minus hop-by-hop headers and any header named in *drop*.

    Repeated header names (e.g. several ``Set-Cookie``) are kept, in order.
    """
    pairs = list(items)
    excluded = set(HOP_BY_HOP_HEADERS)
    excluded.update(name.lower() for name in drop)
    for name, value in pairs:
        if name.lower() == "connection":
            excluded.update(token.strip().lower() for token in value.split(","))
    return [(name, value) for name, value in pairs if name.lower() not in excluded]


def _bad_gateway() -> Response:
    """Response returned when the selected worker cannot be reached."""
    return Response(content="Bad Gateway", status_code=502, media_type="text/plain")


# ===================== FastAPI =====================
client = httpx.AsyncClient(
    timeout=REQUEST_TIMEOUT,
    limits=httpx.Limits(max_connections=1000),
    # The client is shared by every user of the proxy, so it must never store
    # cookies: a worker's Set-Cookie would otherwise be replayed on other
    # users' requests. An empty allowed-domain list rejects every cookie.
    cookies=http.cookiejar.CookieJar(
        policy=http.cookiejar.DefaultCookiePolicy(allowed_domains=[])
    ),
)


@contextlib.asynccontextmanager
async def _lifespan(_app: FastAPI):
    """Close the shared upstream connection pool when the server shuts down."""
    try:
        yield
    finally:
        await client.aclose()


app = FastAPI(lifespan=_lifespan)


# ===================== Health check =====================
# Registered before the catch-all proxy route: routes are matched in
# registration order, so registering it later would forward /health to a worker.
@app.get("/health")
async def health():
    return {
        "status": "ok",
        "workers": WORKERS,
    }


# ===================== HTTP forwarding =====================
@app.api_route(
    "/{path:path}",
    methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
)
async def proxy_http(request: Request, path: str):
    """Forward a plain HTTP or SSE request to the next worker.

    Requests whose ``Accept`` header includes ``text/event-stream`` are relayed
    as a stream; all other responses are buffered. The worker's status code and
    end-to-end headers are passed back. Returns 502 if the worker is unreachable.
    """
    backend = next_worker()
    url = f"{backend}/{quote(path)}"
    if request.url.query:
        # Append the raw query string: routing it through a mapping would drop
        # repeated keys such as ?a=1&a=2.
        url = f"{url}?{request.url.query}"

    # httpx computes Host and Content-Length for the upstream request itself.
    headers = _forwardable_headers(
        request.headers.items(), drop=("host", "content-length")
    )
    body = await request.body()

    # Streaming / SSE
    if "text/event-stream" in request.headers.get("accept", ""):
        upstream_request = client.build_request(
            request.method, url, headers=headers, content=body
        )
        try:
            upstream = await client.send(upstream_request, stream=True)
        except httpx.RequestError:
            return _bad_gateway()

        async def relay():
            try:
                # Raw bytes are still content-encoded, so the worker's
                # Content-Encoding header remains valid and is forwarded as-is.
                async for chunk in upstream.aiter_raw():
                    yield chunk
            finally:
                await upstream.aclose()

        stream_response = StreamingResponse(
            relay(),
            status_code=upstream.status_code,
            # Fall back to SSE only when the worker did not declare a type,
            # otherwise the forwarded Content-Type would be duplicated.
            media_type=None if "content-type" in upstream.headers else "text/event-stream",
        )
        for name, value in _forwardable_headers(upstream.headers.multi_items()):
            stream_response.headers.append(name, value)
        return stream_response

    # Plain (buffered) request
    try:
        upstream = await client.request(
            method=request.method,
            url=url,
            headers=headers,
            content=body,
        )
    except httpx.RequestError:
        return _bad_gateway()

    # `.content` has already been decoded by httpx, so the worker's
    # Content-Encoding / Content-Length no longer describe it; Starlette
    # recomputes Content-Length from the body.
    response = Response(content=upstream.content, status_code=upstream.status_code)
    for name, value in _forwardable_headers(
        upstream.headers.multi_items(), drop=("content-encoding", "content-length")
    ):
        response.headers.append(name, value)
    return response


# ===================== WebSocket forwarding =====================
# httpx has no WebSocket client, so a minimal RFC 6455 client (no extensions,
# no subprotocols) is used for the worker side of the relay.
_WS_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
_OP_CONTINUATION = 0x0
_OP_TEXT = 0x1
_OP_BINARY = 0x2
_OP_CLOSE = 0x8
_OP_PING = 0x9
_OP_PONG = 0xA


def _apply_mask(payload: bytes, mask: bytes) -> bytes:
    """XOR *payload* with the repeating 4-byte *mask* (RFC 6455 §5.3)."""
    if not payload:
        return b""
    size = len(payload)
    key = (mask * (size // 4 + 1))[:size]
    # Big-endian integer XOR keeps byte positions; to_bytes keeps leading zeros.
    return (int.from_bytes(payload, "big") ^ int.from_bytes(key, "big")).to_bytes(size, "big")


class _BackendWebSocket:
    """Client side of a WebSocket connection to one worker."""

    def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
        self._reader = reader
        self._writer = writer
        # Both relay directions may write (data frames and pong replies).
        self._send_lock = asyncio.Lock()
        self._closed = False

    @classmethod
    async def connect(cls, url: str) -> "_BackendWebSocket":
        """Connect to a ws:// or wss:// *url* and perform the opening handshake.

        Raises ConnectionError if the worker does not complete the handshake;
        socket / TLS failures propagate as OSError.
        """
        parts = urlsplit(url)
        secure = parts.scheme == "wss"
        port = parts.port or (443 if secure else 80)
        reader, writer = await asyncio.open_connection(
            parts.hostname, port, ssl=ssl.create_default_context() if secure else None
        )
        try:
            key = base64.b64encode(os.urandom(16)).decode("ascii")
            target = parts.path or "/"
            if parts.query:
                target = f"{target}?{parts.query}"
            handshake = (
                f"GET {target} HTTP/1.1\r\n"
                f"Host: {parts.netloc}\r\n"
                "Upgrade: websocket\r\n"
                "Connection: Upgrade\r\n"
                f"Sec-WebSocket-Key: {key}\r\n"
                "Sec-WebSocket-Version: 13\r\n"
                "\r\n"
            )
            writer.write(handshake.encode("latin-1"))
            await writer.drain()

            # readuntil leaves any frames sent right after the 101 in the buffer.
            head = (await reader.readuntil(b"\r\n\r\n")).decode("latin-1")
            status_line, *header_lines = head.split("\r\n")
            fields: Dict[str, str] = {}
            for line in header_lines:
                name, sep, value = line.partition(":")
                if sep:
                    fields[name.strip().lower()] = value.strip()
            expected_accept = base64.b64encode(
                hashlib.sha1((key + _WS_GUID).encode("ascii")).digest()
            ).decode("ascii")
            status = status_line.split(" ", 2)
            if (
                len(status) < 2
                or status[1] != "101"
                or fields.get("sec-websocket-accept") != expected_accept
            ):
                raise ConnectionError(
                    f"WebSocket handshake with {url} failed: {status_line!r}"
                )
        except BaseException:
            writer.close()
            raise
        return cls(reader, writer)

    async def send(self, opcode: int, payload: bytes = b"") -> None:
        """Send one unfragmented frame; clients must mask every frame (RFC 6455 §5.1)."""
        size = len(payload)
        first = 0x80 | opcode  # FIN bit + opcode
        if size < 126:
            header = bytes([first, 0x80 | size])
        elif size < 1 << 16:
            header = bytes([first, 0x80 | 126]) + size.to_bytes(2, "big")
        else:
            header = bytes([first, 0x80 | 127]) + size.to_bytes(8, "big")
        mask = os.urandom(4)
        async with self._send_lock:
            self._writer.write(header + mask + _apply_mask(payload, mask))
            await self._writer.drain()

    async def _read_frame(self) -> Tuple[bool, int, bytes]:
        """Read one frame and return (fin, opcode, unmasked payload)."""
        first, second = await self._reader.readexactly(2)
        size = second & 0x7F
        if size == 126:
            size = int.from_bytes(await self._reader.readexactly(2), "big")
        elif size == 127:
            size = int.from_bytes(await self._reader.readexactly(8), "big")
        mask = await self._reader.readexactly(4) if second & 0x80 else b""
        payload = await self._reader.readexactly(size)
        if mask:
            payload = _apply_mask(payload, mask)
        return bool(first & 0x80), first & 0x0F, payload

    async def recv(self) -> Tuple[int, bytes]:
        """Return the next data message as (opcode, payload).

        Fragmented messages are reassembled, pings are answered and pongs are
        skipped. A close frame is returned as (_OP_CLOSE, close payload).
        Raises asyncio.IncompleteReadError if the worker drops the connection.
        """
        opcode = _OP_BINARY
        fragments: List[bytes] = []
        while True:
            fin, frame_opcode, payload = await self._read_frame()
            if frame_opcode == _OP_PING:
                await self.send(_OP_PONG, payload)
                continue
            if frame_opcode == _OP_PONG:
                continue
            if frame_opcode == _OP_CLOSE:
                return _OP_CLOSE, payload
            if frame_opcode != _OP_CONTINUATION:
                opcode, fragments = frame_opcode, []
            fragments.append(payload)
            if fin:
                return opcode, b"".join(fragments)

    async def close(self) -> None:
        """Send a best-effort close frame (code 1000) and close the socket; idempotent."""
        if self._closed:
            return
        self._closed = True
        try:
            # The worker may already be gone; there is nobody left to tell.
            with contextlib.suppress(OSError, RuntimeError):
                await self.send(_OP_CLOSE, (1000).to_bytes(2, "big"))
        finally:
            self._writer.close()


@app.websocket("/{path:path}")
async def proxy_ws(ws: WebSocket, path: str):
    """Relay a WebSocket session between the client and the next worker.

    Text and binary messages are forwarded in both directions until either side
    closes. If the worker cannot be reached, the client is rejected with close
    code 1011.
    """
    backend = next_worker()
    # http://... -> ws://..., https://... -> wss://... (replace only the scheme).
    backend_url = backend.replace("http", "ws", 1) + "/" + quote(path)
    if ws.url.query:
        backend_url = f"{backend_url}?{ws.url.query}"

    try:
        upstream = await _BackendWebSocket.connect(backend_url)
    except (OSError, EOFError, asyncio.LimitOverrunError, ValueError):
        await ws.close(code=1011)
        return
    await ws.accept()

    async def client_to_backend():
        while True:
            message = await ws.receive()
            if message["type"] == "websocket.disconnect":
                return
            if message.get("text") is not None:
                await upstream.send(_OP_TEXT, message["text"].encode("utf-8"))
            elif message.get("bytes") is not None:
                await upstream.send(_OP_BINARY, message["bytes"])

    async def backend_to_client():
        while True:
            opcode, payload = await upstream.recv()
            if opcode == _OP_CLOSE:
                code = int.from_bytes(payload[:2], "big") if len(payload) >= 2 else 1000
                await ws.close(code=code)
                return
            if opcode == _OP_TEXT:
                await ws.send_text(payload.decode("utf-8"))
            else:
                await ws.send_bytes(payload)

    relays = [
        asyncio.create_task(client_to_backend()),
        asyncio.create_task(backend_to_client()),
    ]
    try:
        # Stop as soon as either direction ends (close, disconnect or error).
        await asyncio.wait(relays, return_when=asyncio.FIRST_COMPLETED)
    finally:
        for task in relays:
            task.cancel()
        # Collect outcomes so relay errors are not reported as never retrieved.
        await asyncio.gather(*relays, return_exceptions=True)
        await upstream.close()
        if (
            ws.application_state == WebSocketState.CONNECTED
            and ws.client_state == WebSocketState.CONNECTED
        ):
            # The relay ended abnormally (e.g. the worker dropped the connection).
            with contextlib.suppress(RuntimeError, OSError):
                await ws.close(code=1011)


# ===================== Startup =====================
if __name__ == "__main__":
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=6000,
        log_level="info",
    )