| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134 |
import asyncio
import base64
import hashlib
import itertools
import os
import ssl
import struct
import urllib.parse
from typing import Dict

import httpx
import uvicorn
from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
from fastapi.responses import Response, StreamingResponse
# ===================== 配置区 (Configuration) =====================
# Base URLs of the upstream workers; next_worker() hands them out round-robin.
WORKERS = [f"http://127.0.0.1:{port}" for port in range(6001, 6005)]

# Passed to httpx as ``timeout``. None disables every httpx timeout, so
# long-lived streaming (SSE) responses are never cut off.
REQUEST_TIMEOUT = None
# ===================== 核心状态 (Core state) =====================
# Endless iterator that walks the WORKERS list in order, wrapping around.
worker_cycle = itertools.cycle(WORKERS)


def next_worker() -> str:
    """Pick the worker whose turn it is.

    Returns:
        The base URL of the next worker in round-robin order.
    """
    upcoming = next(worker_cycle)
    return upcoming
# ===================== FastAPI =====================
app = FastAPI()

# One shared httpx client (and connection pool, capped at 1000 connections)
# used for every request forwarded to the workers.
client = httpx.AsyncClient(
    limits=httpx.Limits(max_connections=1000),
    timeout=REQUEST_TIMEOUT,
)
# ===================== HTTP 转发 (HTTP forwarding) =====================
# Hop-by-hop headers (RFC 9110 §7.6.1) describe a single connection and must
# not be relayed by a proxy in either direction.
_HOP_BY_HOP_HEADERS = frozenset(
    {
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "proxy-connection",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
    }
)


def _forwardable_headers(pairs, drop=()) -> list:
    """Filter header ``(name, value)`` pairs before relaying them.

    Removes hop-by-hop headers, any header named in a ``Connection`` header,
    and every name in *drop* (all compared case-insensitively). Repeated
    headers, e.g. several ``Set-Cookie`` lines, stay separate pairs; a plain
    ``dict`` would collapse them into one broken value.
    """
    pairs = list(pairs)
    excluded = set(_HOP_BY_HOP_HEADERS)
    excluded.update(name.lower() for name in drop)
    for name, value in pairs:
        if name.lower() == "connection":
            excluded.update(token.strip().lower() for token in value.split(","))
    return [(name, value) for name, value in pairs if name.lower() not in excluded]


def _with_headers(response: Response, pairs) -> Response:
    """Append every ``(name, value)`` pair to *response*, keeping duplicates."""
    for name, value in pairs:
        response.headers.append(name, value)
    return response


def _bad_gateway() -> Response:
    """Build the reply sent when the worker cannot be reached or fails mid-read."""
    return Response(content="Bad Gateway", status_code=502, media_type="text/plain")


@app.api_route(
    "/{path:path}",
    methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
)
async def proxy_http(request: Request, path: str):
    """Forward a plain HTTP or SSE request to the next round-robin worker.

    The worker's status code, headers and body are relayed unchanged (body
    bytes are passed through raw, so ``Content-Encoding`` stays valid). When
    the caller accepts ``text/event-stream`` the body is streamed chunk by
    chunk; otherwise it is buffered first. Replies 502 when the worker cannot
    be reached.
    """
    backend = next_worker()
    body = await request.body()
    upstream_request = client.build_request(
        method=request.method,
        url=f"{backend}/{path}",
        # httpx derives Host from the URL, so the caller's Host is dropped.
        headers=_forwardable_headers(request.headers.items(), drop=("host",)),
        content=body,
        # multi_items() keeps repeated keys such as ?tag=a&tag=b.
        params=request.query_params.multi_items(),
    )
    try:
        upstream = await client.send(upstream_request, stream=True)
    except httpx.RequestError:
        return _bad_gateway()

    # Streaming / SSE: Accept may list several media types, so use containment.
    if "text/event-stream" in request.headers.get("accept", "").lower():

        async def relay():
            try:
                async for chunk in upstream.aiter_raw():
                    yield chunk
            finally:
                await upstream.aclose()

        # Keep the worker's Content-Type; only fall back to SSE if it sent none.
        fallback_type = None if "content-type" in upstream.headers else "text/event-stream"
        streaming = StreamingResponse(
            relay(),
            status_code=upstream.status_code,
            media_type=fallback_type,
        )
        return _with_headers(streaming, _forwardable_headers(upstream.headers.multi_items()))

    # Plain request: buffer the raw (still content-encoded) body.
    try:
        raw_body = b"".join([chunk async for chunk in upstream.aiter_raw()])
    except httpx.RequestError:
        return _bad_gateway()
    finally:
        await upstream.aclose()
    buffered = Response(content=raw_body, status_code=upstream.status_code)
    # Starlette already derived Content-Length from raw_body; drop the
    # worker's copy so the header is not sent twice.
    return _with_headers(
        buffered,
        _forwardable_headers(upstream.headers.multi_items(), drop=("content-length",)),
    )
# ===================== WebSocket 转发 (WebSocket forwarding) =====================
# RFC 6455 constants used by the minimal backend WebSocket client below.
_WS_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
_WS_OP_CONTINUATION = 0x0
_WS_OP_TEXT = 0x1
_WS_OP_BINARY = 0x2
_WS_OP_CLOSE = 0x8
_WS_OP_PING = 0x9
_WS_OP_PONG = 0xA
# Exceptions that only mean "one side of the relay went away".
_RELAY_END_ERRORS = (asyncio.CancelledError, WebSocketDisconnect, OSError, EOFError)


def _ws_mask(payload: bytes, mask_key: bytes) -> bytes:
    """XOR *payload* with the repeating 4-byte *mask_key* (RFC 6455 §5.3)."""
    if not payload:
        return b""
    size = len(payload)
    keystream = (mask_key * (size // 4 + 1))[:size]
    mixed = int.from_bytes(payload, "big") ^ int.from_bytes(keystream, "big")
    return mixed.to_bytes(size, "big")


class _BackendWebSocket:
    """Minimal RFC 6455 WebSocket client built on asyncio streams.

    httpx has no WebSocket support (``httpx.AsyncClient`` has no
    ``ws_connect``), so the proxy speaks the protocol to the worker itself.
    Only what the relay needs is implemented: the opening handshake, masked
    outgoing frames, reassembly of fragmented incoming messages, ping/pong
    and the closing handshake. Extensions and subprotocols are not negotiated.
    """

    def __init__(self, reader, writer):
        self._reader = reader
        self._writer = writer
        # Both relay directions may write (data frames and pong replies);
        # serialise write+drain so frames never interleave.
        self._write_lock = asyncio.Lock()
        self._close_sent = False
        # Close code to pass on to the caller; 1000 = normal closure.
        self.close_code = 1000

    @classmethod
    async def connect(cls, url: str) -> "_BackendWebSocket":
        """Connect to a ``ws://`` or ``wss://`` *url* and complete the opening handshake.

        Raises:
            OSError: connecting failed, or the worker refused the upgrade
                (``ConnectionError`` is an ``OSError``).
            EOFError: the worker closed the connection mid-handshake
                (``asyncio.IncompleteReadError``).
        """
        parts = urllib.parse.urlsplit(url)
        secure = parts.scheme == "wss"
        reader, writer = await asyncio.open_connection(
            parts.hostname,
            parts.port or (443 if secure else 80),
            ssl=ssl.create_default_context() if secure else None,
        )
        key = base64.b64encode(os.urandom(16)).decode("ascii")
        target = parts.path or "/"
        if parts.query:
            target = f"{target}?{parts.query}"
        request_head = (
            f"GET {target} HTTP/1.1\r\n"
            f"Host: {parts.netloc}\r\n"
            "Upgrade: websocket\r\n"
            "Connection: Upgrade\r\n"
            f"Sec-WebSocket-Key: {key}\r\n"
            "Sec-WebSocket-Version: 13\r\n"
            "\r\n"
        )
        try:
            writer.write(request_head.encode("latin-1"))
            await writer.drain()
            response_head = await reader.readuntil(b"\r\n\r\n")
            status_line, *header_lines = response_head.decode("latin-1").split("\r\n")
            status_fields = status_line.split()
            if len(status_fields) < 2 or status_fields[1] != "101":
                raise ConnectionError(f"worker refused WebSocket upgrade: {status_line!r}")
            received = {}
            for line in header_lines:
                name, sep, value = line.partition(":")
                if sep:
                    received[name.strip().lower()] = value.strip()
            expected_accept = base64.b64encode(
                hashlib.sha1((key + _WS_GUID).encode("ascii")).digest()
            ).decode("ascii")
            if received.get("sec-websocket-accept") != expected_accept:
                raise ConnectionError("worker sent an invalid Sec-WebSocket-Accept")
        except BaseException:
            writer.close()
            raise
        return cls(reader, writer)

    async def send(self, opcode: int, payload: bytes) -> None:
        """Send *payload* as one final frame; client frames must be masked (RFC 6455 §5.1)."""
        length = len(payload)
        head = bytearray([0x80 | opcode])  # FIN bit + opcode
        if length < 126:
            head.append(0x80 | length)  # MASK bit + 7-bit length
        elif length < 1 << 16:
            head.append(0x80 | 126)
            head += struct.pack("!H", length)
        else:
            head.append(0x80 | 127)
            head += struct.pack("!Q", length)
        mask_key = os.urandom(4)
        frame = bytes(head) + mask_key + _ws_mask(payload, mask_key)
        async with self._write_lock:
            if self._close_sent:
                return  # nothing may follow a Close frame (RFC 6455 §5.5.1)
            if opcode == _WS_OP_CLOSE:
                self._close_sent = True
            self._writer.write(frame)
            await self._writer.drain()

    async def _read_frame(self):
        """Read one frame and return ``(fin, opcode, unmasked_payload)``."""
        first, second = await self._reader.readexactly(2)
        length = second & 0x7F
        if length == 126:
            (length,) = struct.unpack("!H", await self._reader.readexactly(2))
        elif length == 127:
            (length,) = struct.unpack("!Q", await self._reader.readexactly(8))
        # Servers must not mask, but unmask anyway if the bit is set.
        mask_key = await self._reader.readexactly(4) if second & 0x80 else b""
        payload = await self._reader.readexactly(length)
        if mask_key:
            payload = _ws_mask(payload, mask_key)
        return bool(first & 0x80), first & 0x0F, payload

    async def receive(self):
        """Return the next data message as ``(opcode, payload)``, or None once the worker is gone.

        Control frames are handled here: pings are answered, pongs ignored,
        and a Close frame is acknowledged (its code kept in ``close_code``)
        before returning None.
        """
        message_opcode = _WS_OP_BINARY
        fragments = []
        while True:
            try:
                fin, opcode, payload = await self._read_frame()
            except (EOFError, OSError):
                self.close_code = 1011  # connection dropped without a Close frame
                return None
            if opcode == _WS_OP_PING:
                await self.send(_WS_OP_PONG, payload)
            elif opcode == _WS_OP_PONG:
                continue
            elif opcode == _WS_OP_CLOSE:
                if len(payload) >= 2:
                    (self.close_code,) = struct.unpack("!H", payload[:2])
                try:
                    await self.send(_WS_OP_CLOSE, payload[:2])  # echo to finish the handshake
                except OSError:
                    pass  # worker already hung up; nothing left to acknowledge
                return None
            else:
                if opcode != _WS_OP_CONTINUATION:
                    message_opcode = opcode
                    fragments = []
                fragments.append(payload)
                if fin:
                    return message_opcode, b"".join(fragments)

    async def close(self) -> None:
        """Send a normal-closure frame unless one was already sent, then close the socket."""
        try:
            await self.send(_WS_OP_CLOSE, struct.pack("!H", 1000))
        except OSError:
            pass  # worker already gone; nothing left to notify
        finally:
            self._writer.close()


@app.websocket("/{path:path}")
async def proxy_ws(ws: WebSocket, path: str):
    """Relay a WebSocket session between the caller and the next round-robin worker.

    The worker connection is opened before the caller's handshake is
    accepted, so an unreachable worker rejects the handshake rather than
    accepting and immediately closing. Text and binary messages are relayed
    both ways until either side closes; the other side is then closed too.
    """
    backend = next_worker()
    # Swap only the scheme prefix: http:// -> ws://, https:// -> wss://.
    backend_url = backend.replace("http", "ws", 1) + "/" + urllib.parse.quote(path)
    if ws.url.query:
        backend_url += "?" + ws.url.query
    try:
        upstream = await _BackendWebSocket.connect(backend_url)
    except (OSError, EOFError, asyncio.LimitOverrunError):
        await ws.close(code=1011)  # before accept(): rejects the handshake
        return
    await ws.accept()

    client_gone = False

    async def client_to_backend():
        nonlocal client_gone
        while True:
            message = await ws.receive()
            if message["type"] == "websocket.disconnect":
                client_gone = True
                return
            if message.get("text") is not None:
                await upstream.send(_WS_OP_TEXT, message["text"].encode("utf-8"))
            elif message.get("bytes") is not None:
                await upstream.send(_WS_OP_BINARY, message["bytes"])

    async def backend_to_client():
        while True:
            message = await upstream.receive()
            if message is None:
                return
            opcode, payload = message
            if opcode == _WS_OP_TEXT:
                await ws.send_text(payload.decode("utf-8"))
            else:
                await ws.send_bytes(payload)

    pumps = [
        asyncio.create_task(client_to_backend()),
        asyncio.create_task(backend_to_client()),
    ]
    try:
        # Stop as soon as either direction ends; asyncio.gather would leave
        # the other pump running after one side disconnects.
        await asyncio.wait(pumps, return_when=asyncio.FIRST_COMPLETED)
    finally:
        for pump in pumps:
            pump.cancel()
        outcomes = await asyncio.gather(*pumps, return_exceptions=True)
        await upstream.close()
        if not client_gone:
            try:
                await ws.close(code=upstream.close_code)
            except (RuntimeError, OSError, WebSocketDisconnect):
                pass  # the caller's connection is already closed
    # Disconnects are the normal way a session ends; anything else is a bug.
    for outcome in outcomes:
        if isinstance(outcome, BaseException) and not isinstance(outcome, _RELAY_END_ERRORS):
            raise outcome
# ===================== 健康检查 (Health check) =====================
@app.get("/health")
async def health():
    """Report that the proxy is up, along with the configured worker URLs."""
    return {
        "status": "ok",
        "workers": WORKERS,
    }


# Starlette matches routes in registration order. The catch-all
# "/{path:path}" proxy route is registered earlier in this module, so without
# this move GET /health would be forwarded to a worker instead of answered here.
_health_route = next(
    route for route in app.router.routes if getattr(route, "endpoint", None) is health
)
app.router.routes.remove(_health_route)
app.router.routes.insert(0, _health_route)
# ===================== 启动 (Startup) =====================
if __name__ == "__main__":
    # Serve the proxy on all interfaces, port 6000.
    server_options = {
        "host": "0.0.0.0",
        "port": 6000,
        "log_level": "info",
    }
    uvicorn.run(app, **server_options)
|