paddleocr_server.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. import os
  4. import json
  5. import re
  6. from pathlib import Path
  7. from fastapi import FastAPI, HTTPException
  8. from pydantic import BaseModel
  9. from typing import Optional
  10. import uvicorn
  11. from paddleocr import PaddleOCRVL
app = FastAPI(title="PaddleOCR-VL Service")
# Global pipeline instance: lazily created by get_pipeline() and reset to
# None by safe_predict_with_retry() when a predict call fails.
pipeline = None
class ParseRequest(BaseModel):
    """Request body for POST /parse."""
    # Path of the PDF file to parse (must exist on this host).
    pdf_path: str
    # Root directory under which a per-document output folder is created.
    output_dir: Optional[str] = "./tmp_file/paddleocr_parsed"
    # Base URL of the vllm-server used as the VL recognition backend.
    vl_rec_server_url: Optional[str] = "http://127.0.0.1:8118/v1"
class ParseResponse(BaseModel):
    """Envelope returned by POST /parse."""
    # Application-level status: 200 on success, 500 on failure
    # (the HTTP status itself stays 200 in both cases).
    code: int
    # Human-readable result message.
    message: str
    # Result payload on success (json_data, md_path, json_path, ...); None on failure.
    data: Optional[dict] = None
# NOTE(review): these imports sit mid-file; conventionally they belong in the
# import block at the top of the module.
import time
import traceback
MAX_RETRY = 3  # maximum number of predict attempts before giving up
RETRY_INTERVAL = 5  # seconds to wait between retries
  27. def safe_predict_with_retry(
  28. pdf_path: str,
  29. vl_rec_server_url: str,
  30. **predict_kwargs
  31. ):
  32. """
  33. 重试
  34. - 失败后重建 pipeline
  35. - 等待 vllm-server 重启
  36. """
  37. global pipeline
  38. last_exception = None
  39. for attempt in range(1, MAX_RETRY + 1):
  40. try:
  41. ocr_pipeline = get_pipeline(vl_rec_server_url)
  42. return ocr_pipeline.predict(
  43. input=pdf_path,
  44. **predict_kwargs
  45. )
  46. except Exception as e:
  47. last_exception = e
  48. print(f"PaddleOCR predict 失败,第 {attempt}/{MAX_RETRY} 次")
  49. traceback.print_exc()
  50. pipeline = None
  51. if attempt < MAX_RETRY:
  52. print(f"[INFO] 等待 {RETRY_INTERVAL}s 后重试...")
  53. time.sleep(RETRY_INTERVAL)
  54. # 所有重试失败
  55. raise RuntimeError(
  56. f"PaddleOCR predict 连续失败 {MAX_RETRY} 次,最后错误: {last_exception}"
  57. )
  58. def get_pipeline(vl_rec_server_url: str):
  59. """获取或创建 PaddleOCRVL pipeline"""
  60. global pipeline
  61. if pipeline is None:
  62. pipeline = PaddleOCRVL(
  63. vl_rec_backend="vllm-server",
  64. vl_rec_server_url=vl_rec_server_url
  65. )
  66. return pipeline
  67. @app.post("/parse", response_model=ParseResponse)
  68. async def parse_pdf(request: ParseRequest):
  69. """
  70. 解析 PDF 文件
  71. 返回:
  72. - json_data: 合并后的 JSON 数据
  73. - md_path: MD 文件保存路径
  74. - pdf_file_name: PDF 文件名(不含扩展名)
  75. """
  76. try:
  77. pdf_path = request.pdf_path
  78. output_dir = Path(request.output_dir)
  79. if not os.path.exists(pdf_path):
  80. raise HTTPException(status_code=400, detail=f"PDF 文件不存在: {pdf_path}")
  81. pdf_file_name = Path(pdf_path).stem
  82. # 创建输出目录
  83. output_path = output_dir / pdf_file_name
  84. output_path.mkdir(parents=True, exist_ok=True)
  85. # 获取 pipeline
  86. ocr_pipeline = get_pipeline(request.vl_rec_server_url)
  87. # 执行解析
  88. # output = ocr_pipeline.predict(
  89. # input=pdf_path,
  90. # format_block_content=True,
  91. # use_chart_recognition=True
  92. # )
  93. output = safe_predict_with_retry(
  94. pdf_path=pdf_path,
  95. vl_rec_server_url=request.vl_rec_server_url,
  96. format_block_content=True,
  97. use_chart_recognition=True
  98. )
  99. # 收集所有页面数据
  100. json_list = []
  101. markdown_list = []
  102. markdown_images = {}
  103. for res in output:
  104. json_data = res.json
  105. json_list.append({"res": json_data})
  106. md_info = res.markdown
  107. markdown_list.append(md_info)
  108. page_images = md_info.get("markdown_images", {})
  109. if page_images:
  110. markdown_images.update(page_images)
  111. # 合并 MD 数据
  112. markdown_texts = ocr_pipeline.concatenate_markdown_pages(markdown_list)
  113. # 保存 MD 文件
  114. md_file_path = output_path / f"{pdf_file_name}.md"
  115. with open(md_file_path, "w", encoding="utf-8") as f:
  116. f.write(markdown_texts)
  117. # 保存图片
  118. images_dir = output_path / "imgs"
  119. for img_path, image in markdown_images.items():
  120. file_path = output_path / img_path
  121. file_path.parent.mkdir(parents=True, exist_ok=True)
  122. image.save(file_path)
  123. # 合并 JSON 数据
  124. merged_json = {
  125. "input_file": str(pdf_path),
  126. "total_pages": len(json_list),
  127. "pages": json_list
  128. }
  129. # 保存 JSON 文件
  130. json_file_path = output_path / f"{pdf_file_name}.json"
  131. with open(json_file_path, "w", encoding="utf-8") as f:
  132. json.dump(merged_json, f, ensure_ascii=False, indent=2)
  133. return ParseResponse(
  134. code=200,
  135. message="解析成功",
  136. data={
  137. "json_data": merged_json,
  138. "md_path": str(md_file_path),
  139. "json_path": str(json_file_path),
  140. "pdf_file_name": pdf_file_name,
  141. "output_dir": str(output_path)
  142. }
  143. )
  144. except Exception as e:
  145. return ParseResponse(
  146. code=500,
  147. message=f"解析失败: {str(e)}",
  148. data=None
  149. )
  150. @app.get("/health")
  151. async def health_check():
  152. """健康检查"""
  153. return {"status": "ok"}
  154. if __name__ == "__main__":
  155. import argparse
  156. parser = argparse.ArgumentParser()
  157. parser.add_argument("--host", default="0.0.0.0", help="服务地址")
  158. parser.add_argument("--port", type=int, default=8119, help="服务端口")
  159. args = parser.parse_args()
  160. uvicorn.run(app, host=args.host, port=args.port)