# lb_gateway.py (3.3 KB)
import asyncio
import itertools
from typing import Dict

import httpx
import uvicorn
from fastapi import BackgroundTasks, FastAPI, Request, WebSocket, WebSocketDisconnect
from fastapi.responses import Response, StreamingResponse
  8. # ===================== 配置区 =====================
  9. WORKERS = [
  10. "http://127.0.0.1:6001",
  11. "http://127.0.0.1:6002",
  12. "http://127.0.0.1:6003",
  13. "http://127.0.0.1:6004",
  14. ]
  15. REQUEST_TIMEOUT = None # Streaming / SSE 必须是 None
  16. # ===================== 核心状态 =====================
  17. worker_cycle = itertools.cycle(WORKERS)
  18. def next_worker() -> str:
  19. """轮询获取下一个 worker"""
  20. return next(worker_cycle)
  21. # ===================== FastAPI =====================
  22. app = FastAPI()
  23. client = httpx.AsyncClient(
  24. timeout=REQUEST_TIMEOUT,
  25. limits=httpx.Limits(max_connections=1000)
  26. )
  27. # ===================== HTTP 转发 =====================
  28. @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
  29. async def proxy_http(request: Request, path: str):
  30. """
  31. 普通 HTTP / SSE 请求转发
  32. """
  33. backend = next_worker()
  34. url = f"{backend}/{path}"
  35. headers = dict(request.headers)
  36. headers.pop("host", None)
  37. body = await request.body()
  38. # Streaming / SSE
  39. if request.headers.get("accept") == "text/event-stream":
  40. async def stream():
  41. async with client.stream(
  42. method=request.method,
  43. url=url,
  44. headers=headers,
  45. content=body,
  46. params=request.query_params,
  47. ) as resp:
  48. async for chunk in resp.aiter_raw():
  49. yield chunk
  50. return StreamingResponse(
  51. stream(),
  52. media_type="text/event-stream",
  53. )
  54. # 普通请求
  55. resp = await client.request(
  56. method=request.method,
  57. url=url,
  58. headers=headers,
  59. content=body,
  60. params=request.query_params,
  61. )
  62. return Response(
  63. content=resp.content,
  64. status_code=resp.status_code,
  65. headers=dict(resp.headers),
  66. media_type=resp.headers.get("content-type"),
  67. )
  68. # ===================== WebSocket 转发 =====================
  69. @app.websocket("/{path:path}")
  70. async def proxy_ws(ws: WebSocket, path: str):
  71. """
  72. WebSocket 直通转发
  73. """
  74. await ws.accept()
  75. backend = next_worker()
  76. backend_ws = backend.replace("http", "ws") + f"/{path}"
  77. async with httpx.AsyncClient() as session:
  78. async with session.ws_connect(backend_ws) as backend_socket:
  79. async def client_to_backend():
  80. while True:
  81. msg = await ws.receive_text()
  82. await backend_socket.send_text(msg)
  83. async def backend_to_client():
  84. while True:
  85. msg = await backend_socket.receive_text()
  86. await ws.send_text(msg)
  87. await asyncio.gather(
  88. client_to_backend(),
  89. backend_to_client(),
  90. )
  91. # ===================== 健康检查 =====================
  92. @app.get("/health")
  93. async def health():
  94. return {
  95. "status": "ok",
  96. "workers": WORKERS,
  97. }
  98. # ===================== 启动 =====================
  99. if __name__ == "__main__":
  100. uvicorn.run(
  101. app,
  102. host="0.0.0.0",
  103. port=6000,
  104. log_level="info",
  105. )