| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205 |
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- import os
- import json
- import re
- from pathlib import Path
- from fastapi import FastAPI, HTTPException
- from pydantic import BaseModel
- from typing import Optional
- import uvicorn
- from paddleocr import PaddleOCRVL
app = FastAPI(title="PaddleOCR-VL Service")
# Global pipeline instance: starts as None, is created on first use by
# get_pipeline() and is reset to None by safe_predict_with_retry() after a
# failed prediction so that the next attempt builds a fresh one.
pipeline = None
class ParseRequest(BaseModel):
    # Request body of POST /parse.

    # Path of the PDF to parse; parse_pdf checks it with os.path.exists().
    pdf_path: str
    # Root directory for results; parse_pdf writes into
    # <output_dir>/<PDF file stem>/.
    # NOTE(review): the type admits None, but parse_pdf passes the value
    # straight to Path() -- confirm whether null should really be accepted.
    output_dir: Optional[str] = "./tmp_file/paddleocr_parsed"
    # Recognition server URL handed to PaddleOCRVL as vl_rec_server_url.
    vl_rec_server_url: Optional[str] = "http://127.0.0.1:8118/v1"
class ParseResponse(BaseModel):
    # Response envelope returned by POST /parse.

    # Status reported in the body: 200 on success, an error code (e.g. 500)
    # otherwise.
    code: int
    # Human-readable outcome; on failure it carries the error text.
    message: str
    # Result payload on success (paths and merged JSON); None on failure.
    data: Optional[dict] = None
- import time
- import traceback
# Maximum number of predict attempts made by safe_predict_with_retry().
MAX_RETRY = 3
# Pause between two attempts; passed to time.sleep().
RETRY_INTERVAL = 5  # seconds
def safe_predict_with_retry(
    pdf_path: str,
    vl_rec_server_url: str,
    **predict_kwargs
):
    """Run ``pipeline.predict`` on *pdf_path*, retrying on failure.

    Up to ``MAX_RETRY`` attempts are made.  After a failed attempt the cached
    global pipeline is dropped so that the next attempt rebuilds it through
    ``get_pipeline``, and the function sleeps ``RETRY_INTERVAL`` seconds
    before trying again (giving the backend, e.g. the vllm server, time to
    come back).

    Args:
        pdf_path: Passed to ``predict`` as ``input``.
        vl_rec_server_url: Recognition server URL used to obtain the pipeline.
        **predict_kwargs: Forwarded unchanged to ``predict``.

    Returns:
        Whatever ``predict`` returns on the first successful attempt.

    Raises:
        RuntimeError: When every attempt failed.  The last underlying
            exception is attached as ``__cause__``.
    """
    global pipeline
    last_exception = None
    for attempt in range(1, MAX_RETRY + 1):
        try:
            ocr_pipeline = get_pipeline(vl_rec_server_url)
            return ocr_pipeline.predict(
                input=pdf_path,
                **predict_kwargs
            )
        except Exception as e:
            last_exception = e
            print(f"PaddleOCR predict 失败,第 {attempt}/{MAX_RETRY} 次")
            traceback.print_exc()
            # Force get_pipeline() to build a fresh instance next time.
            pipeline = None
            # No point in sleeping after the final attempt.
            if attempt < MAX_RETRY:
                print(f"[INFO] 等待 {RETRY_INTERVAL}s 后重试...")
                time.sleep(RETRY_INTERVAL)
    # Every attempt failed: chain the last error so its traceback survives.
    raise RuntimeError(
        f"PaddleOCR predict 连续失败 {MAX_RETRY} 次,最后错误: {last_exception}"
    ) from last_exception
# Server URL the cached pipeline was built for, so that get_pipeline() can
# tell when a caller asks for a different recognition server.
_pipeline_server_url = None


def get_pipeline(vl_rec_server_url: str):
    """Return the shared PaddleOCRVL pipeline, creating it when needed.

    A new pipeline is built when none is cached (first call, or after the
    global ``pipeline`` was reset to ``None``) and also when
    *vl_rec_server_url* differs from the URL the cached pipeline was built
    with; otherwise the cached instance is reused.

    Args:
        vl_rec_server_url: URL of the recognition server, passed to
            ``PaddleOCRVL`` together with ``vl_rec_backend="vllm-server"``.

    Returns:
        The cached (possibly just created) ``PaddleOCRVL`` instance.
    """
    global pipeline, _pipeline_server_url
    if pipeline is None or _pipeline_server_url != vl_rec_server_url:
        pipeline = PaddleOCRVL(
            vl_rec_backend="vllm-server",
            vl_rec_server_url=vl_rec_server_url,
        )
        # Record the URL only once construction has succeeded.
        _pipeline_server_url = vl_rec_server_url
    return pipeline
@app.post("/parse", response_model=ParseResponse)
async def parse_pdf(request: ParseRequest):
    """Parse a PDF with PaddleOCR-VL and save Markdown, images and JSON.

    Results are written under ``<output_dir>/<PDF file stem>/``.  Failures are
    reported through the ``code`` / ``message`` fields of the returned
    ``ParseResponse`` instead of being raised.

    Args:
        request: Carries ``pdf_path`` (file to parse), ``output_dir`` (root
            directory for results) and ``vl_rec_server_url`` (recognition
            server used by the pipeline).

    Returns:
        ``ParseResponse`` with

        - ``code=200`` and ``data`` containing ``json_data`` (merged JSON of
          all pages), ``md_path``, ``json_path``, ``pdf_file_name`` (file
          name without extension) and ``output_dir`` on success;
        - ``code=400`` when ``pdf_path`` does not exist;
        - ``code=500`` and the error text for any other failure.
    """
    # NOTE(review): prediction (including retries that sleep) and the file
    # writes below are blocking calls made directly inside this coroutine, so
    # other requests such as /health wait until it finishes.
    pdf_path = request.pdf_path

    # Report a missing input as a client error.  This is checked outside the
    # catch-all below so that it is not turned into a generic 500.
    if not os.path.exists(pdf_path):
        return ParseResponse(
            code=400,
            message=f"PDF 文件不存在: {pdf_path}",
            data=None
        )

    try:
        pdf_file_name = Path(pdf_path).stem

        # Create the per-document output directory.
        output_path = Path(request.output_dir) / pdf_file_name
        output_path.mkdir(parents=True, exist_ok=True)

        # Run the prediction; pipeline creation happens inside the retry
        # wrapper so that construction failures are retried as well.
        output = safe_predict_with_retry(
            pdf_path=pdf_path,
            vl_rec_server_url=request.vl_rec_server_url,
            format_block_content=True,
            use_chart_recognition=True
        )

        # Collect the per-page results.
        json_list = []
        markdown_list = []
        markdown_images = {}
        for res in output:
            json_list.append({"res": res.json})

            md_info = res.markdown
            markdown_list.append(md_info)
            markdown_images.update(md_info.get("markdown_images") or {})

        # Fetch the pipeline only now: after a retry the instance may have
        # been rebuilt, and this returns the one that is currently cached.
        ocr_pipeline = get_pipeline(request.vl_rec_server_url)
        markdown_texts = ocr_pipeline.concatenate_markdown_pages(markdown_list)

        # Save the merged Markdown.
        md_file_path = output_path / f"{pdf_file_name}.md"
        with open(md_file_path, "w", encoding="utf-8") as f:
            f.write(markdown_texts)

        # Save the images at their relative paths below the output directory.
        for img_path, image in markdown_images.items():
            file_path = output_path / img_path
            file_path.parent.mkdir(parents=True, exist_ok=True)
            image.save(file_path)

        # Merge and save the JSON data.
        merged_json = {
            "input_file": str(pdf_path),
            "total_pages": len(json_list),
            "pages": json_list
        }
        json_file_path = output_path / f"{pdf_file_name}.json"
        with open(json_file_path, "w", encoding="utf-8") as f:
            json.dump(merged_json, f, ensure_ascii=False, indent=2)

        return ParseResponse(
            code=200,
            message="解析成功",
            data={
                "json_data": merged_json,
                "md_path": str(md_file_path),
                "json_path": str(json_file_path),
                "pdf_file_name": pdf_file_name,
                "output_dir": str(output_path)
            }
        )

    except Exception as e:
        # Keep the traceback in the server output; the client only gets the
        # message.
        traceback.print_exc()
        return ParseResponse(
            code=500,
            message=f"解析失败: {str(e)}",
            data=None
        )
@app.get("/health")
async def health_check():
    """Health probe: always answers with a fixed "ok" status."""
    status_payload = {"status": "ok"}
    return status_payload
if __name__ == "__main__":
    # Script entry point: read the bind address from the command line and
    # serve the FastAPI app with uvicorn.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("--host", default="0.0.0.0", help="服务地址")
    cli.add_argument("--port", type=int, default=8119, help="服务端口")
    bind = vars(cli.parse_args())

    uvicorn.run(app, host=bind["host"], port=bind["port"])
|