parent
4de8e64451
commit
64cfbd851e
@ -0,0 +1,217 @@ |
||||
import os |
||||
import tempfile |
||||
import logging |
||||
import time |
||||
from fastapi import FastAPI, HTTPException, status, Body |
||||
from pydantic import BaseModel |
||||
import fitz # PyMuPDF |
||||
from paddleocr import PaddleOCR |
||||
from typing import Optional, Dict, Any |
||||
|
||||
# Logging configuration (unchanged).
# Records are written both to a UTF-8 encoded log file and to a StreamHandler.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    handlers=[
        logging.FileHandler('/var/log/ocr_service.log', encoding='utf-8'),
        logging.StreamHandler()
    ]
)
# Service-wide logger; its own level is set to DEBUG, independently of the
# INFO level passed to basicConfig above.
logger = logging.getLogger('OCRService')
logger.setLevel(logging.DEBUG)
||||
|
||||
# Security configuration (unchanged).
# File extensions (lower-case, without the dot) accepted by validate_file.
ALLOWED_EXTENSIONS = {'pdf', 'jpg', 'jpeg', 'png'}
# Largest accepted input file size, in bytes.
MAX_FILE_SIZE = 100 * 1024 * 1024 # 100MB
# Directory that safe_resolve_path resolves request paths against.
BASE_STORAGE = '/data/files'
||||
|
||||
# OCR engine configuration (unchanged).
# One module-level PaddleOCR instance, used by both process_pdf and
# process_image.
# NOTE(review): the keyword arguments are passed straight to PaddleOCR; their
# exact effect depends on the installed PaddleOCR version - confirm before tuning.
ocr_engine = PaddleOCR(
    use_angle_cls=True,
    lang='ch',
    enable_mkldnn=True,
    det_limit_side_len=2200,
    det_db_box_thresh=0.4
)
||||
|
||||
# FastAPI application object; the routes and the OCRServiceError handler in
# this module register themselves on it through decorators.
app = FastAPI()
||||
|
||||
class OCRRequest(BaseModel):
    """Request body accepted by the POST /ocr endpoint."""

    # Client-supplied path; ocr_service passes it to safe_resolve_path, which
    # resolves it against BASE_STORAGE.
    file_path: str
||||
|
||||
class OCRResponse(BaseModel):
    """Response schema declared for the POST /ocr endpoint."""

    # "success" on the normal path; the OCRServiceError handler emits "error".
    status: str
    # Result dictionary produced by process_pdf or process_image.
    data: Optional[Dict[str, Any]] = None
    # Human-readable error description; None on success.
    message: Optional[str] = None
    # Status code mirrored into the body by the error handler; None on success.
    code: Optional[int] = None
    # Elapsed time as a string with two decimals and an "s" suffix.
    elapsed: str
||||
|
||||
class OCRServiceError(Exception):
    """Error raised by the OCR service for problems that are reported to the client.

    Attributes:
        message: Human-readable description of the problem.
        status_code: HTTP status code to answer with (defaults to 400).
    """

    def __init__(self, message: str, status_code: int = 400):
        # Hand the message to Exception as well, so that str(exc), exc.args,
        # log lines and tracebacks show it instead of an empty string.
        super().__init__(message)
        self.message = message
        self.status_code = status_code
||||
|
||||
# Path resolution and validation helpers (unchanged).
||||
def safe_resolve_path(user_path: str) -> str:
    """Resolve a client-supplied path to an absolute path inside BASE_STORAGE.

    The path is normalised, stripped of leading slashes so it is always
    treated as relative, joined onto BASE_STORAGE and made absolute. The
    result must be BASE_STORAGE itself or lie underneath it.

    Args:
        user_path: Path taken from the request.

    Returns:
        The absolute path inside BASE_STORAGE.

    Raises:
        OCRServiceError: If the path is empty, escapes BASE_STORAGE, or cannot
            be processed at all (for example, it is not a string).
    """
    try:
        clean_path = os.path.normpath(user_path).lstrip('/')
        if not clean_path:
            raise ValueError("空路径参数")

        base_dir = os.path.abspath(BASE_STORAGE)
        abs_path = os.path.abspath(os.path.join(base_dir, clean_path))

        # Compare whole path components. A plain startswith() prefix test
        # would also accept sibling directories whose name merely begins with
        # the storage directory name (e.g. "<BASE_STORAGE>_backup/secret").
        if os.path.commonpath([base_dir, abs_path]) != base_dir:
            raise ValueError("非法路径访问")

        # NOTE(review): symbolic links are not resolved here, so a link placed
        # inside BASE_STORAGE can still point outside it - confirm whether
        # that needs to be blocked as well.
        return abs_path
    except Exception as e:
        logger.error(f"路径解析失败: {str(e)}")
        raise OCRServiceError("无效文件路径") from e
||||
|
||||
def validate_file(file_path: str) -> str:
    """Check that *file_path* is an acceptable OCR input and return its extension.

    The checks run in this order: the path exists, it is a regular file, its
    size does not exceed MAX_FILE_SIZE, and its extension is listed in
    ALLOWED_EXTENSIONS.

    Args:
        file_path: Path of the file to check.

    Returns:
        The file extension in lower case, without the leading dot.

    Raises:
        OCRServiceError: With a specific message when one of the checks fails,
            or with a generic message when a check itself raises.
    """
    try:
        if not os.path.exists(file_path):
            raise OCRServiceError("文件不存在")
        if not os.path.isfile(file_path):
            raise OCRServiceError("路径不是文件")

        size_in_bytes = os.path.getsize(file_path)
        if size_in_bytes > MAX_FILE_SIZE:
            raise OCRServiceError(f"文件超过大小限制 ({MAX_FILE_SIZE//1024//1024}MB)")

        suffix = os.path.splitext(file_path)[1]
        ext = suffix.lstrip('.').lower()
        if ext in ALLOWED_EXTENSIONS:
            return ext
        raise OCRServiceError(f"不支持的文件类型: {ext}")
    except Exception as e:
        # Validation failures raised above pass through untouched; anything
        # else is logged and wrapped in the service error type.
        if isinstance(e, OCRServiceError):
            raise
        logger.error(f"文件验证异常: {str(e)}")
        raise OCRServiceError("文件验证失败") from e
||||
|
||||
# Result formatting and processing helpers (unchanged).
||||
def format_ocr_result(raw_data):
    """Convert the raw result of ``ocr_engine.ocr`` into JSON-friendly dicts.

    Only ``raw_data[0]`` is read; every entry in it is unpacked as
    ``(boxes, (text, score))``.

    Args:
        raw_data: Value returned by ``ocr_engine.ocr`` for one image.

    Returns:
        A list of ``{"text", "score", "boxes"}`` dicts, with the score and the
        box coordinates converted to ``float``. The list is empty when nothing
        was recognised or when the structure could not be parsed.
    """
    try:
        # "Nothing recognised" is a normal outcome, not a parse failure.
        # NOTE(review): some PaddleOCR releases appear to report such a page
        # as [None] - confirm against the installed version.
        if not raw_data or not raw_data[0]:
            return []

        results = []
        for boxes, (text, score) in raw_data[0]:
            # Recognised text is diagnostic output, so log it at DEBUG rather
            # than polluting the error log.
            logger.debug("识别内容: %s", text)
            results.append({
                "text": text,
                "score": float(score),
                "boxes": [list(map(float, point)) for point in boxes],
            })
        return results
    except Exception as e:
        logger.error("解析失败:%s", str(e), exc_info=True)
        return []
||||
|
||||
def process_pdf(file_path: str):
    """Run OCR on every page of a PDF file.

    Each page is rendered to a grayscale PNG in a temporary file, which is
    then passed to ``ocr_engine.ocr``.

    Args:
        file_path: Path of the PDF to process.

    Returns:
        A dict with ``"type": "pdf"``, the ``"page_count"`` and a ``"pages"``
        list holding, per page, the 1-based page number, the formatted OCR
        content and the processing time as a string such as ``"1.23s"``.

    Raises:
        OCRServiceError: If the PDF is encrypted and cannot be opened with an
            empty password, or if any step of the processing fails.
    """
    render_dpi = 900
    try:
        logger.info(f"开始处理PDF: {file_path}")
        # The context manager closes the document, and the file handle behind
        # it, on success and on failure alike.
        with fitz.open(file_path) as doc:
            if doc.is_encrypted and not doc.authenticate(""):
                raise OCRServiceError("加密PDF需要密码")

            pages = []
            for page_num, page in enumerate(doc, start=1):
                page_start = time.time()
                # Only dpi is passed: per the PyMuPDF documentation a matrix
                # argument is ignored whenever dpi is given, so the rendering
                # resolution is unchanged.
                pix = page.get_pixmap(
                    colorspace=fitz.csGRAY,
                    alpha=False,
                    dpi=render_dpi,
                )
                with tempfile.NamedTemporaryFile(suffix=".png") as tmp:
                    pix.save(tmp.name)
                    page_result = ocr_engine.ocr(tmp.name, cls=True)
                pages.append({
                    "page": page_num,
                    "content": format_ocr_result(page_result),
                    "process_time": f"{time.time() - page_start:.2f}s",
                })

            # Read the page count while the document is still open.
            return {
                "type": "pdf",
                "page_count": len(doc),
                "pages": pages,
            }
    except OCRServiceError:
        raise
    except Exception as e:
        # Keep the traceback in the log; the client only sees a generic message.
        logger.error(f"PDF处理异常: {str(e)}", exc_info=True)
        raise OCRServiceError("PDF处理失败") from e
||||
|
||||
def process_image(file_path: str):
    """Run OCR on a single image file.

    Args:
        file_path: Path of the image to process.

    Returns:
        A dict with ``"type": "image"`` and the formatted OCR ``"results"``.

    Raises:
        Exception: Whatever the OCR engine raises is logged and re-raised
            unchanged, so the caller decides how to report it.
    """
    try:
        logger.info(f"开始处理图像: {file_path}")
        result = ocr_engine.ocr(file_path, cls=True)
        logger.debug("原始OCR数据结构类型: %s", type(result))
        return {
            "type": "image",
            "results": format_ocr_result(result),
        }
    except Exception:
        # Record the traceback here; without it the log line says nothing
        # about what actually failed.
        logger.error("图像处理异常", exc_info=True)
        raise
||||
|
||||
@app.exception_handler(OCRServiceError)
async def ocr_exception_handler(request, exc: OCRServiceError):
    """Turn an OCRServiceError into a JSON error response.

    Args:
        request: The incoming request (not used).
        exc: The error being handled.

    Returns:
        A JSONResponse whose HTTP status is ``exc.status_code`` and whose body
        carries ``status``, ``message``, ``code`` and a fixed ``elapsed``.
    """
    # JSONResponse is not among the module-level imports visible in this file;
    # without this import every OCRServiceError would end in a NameError
    # instead of the intended error response.
    from fastapi.responses import JSONResponse

    return JSONResponse(
        status_code=exc.status_code,
        content={
            "status": "error",
            "message": exc.message,
            "code": exc.status_code,
            "elapsed": "0.00s",
        },
    )
||||
|
||||
@app.post("/ocr", response_model=OCRResponse)
async def ocr_service(request_data: OCRRequest = Body(...)):
    """Handle POST /ocr: resolve and validate the requested file, then run OCR.

    Args:
        request_data: Request body carrying the file path.

    Returns:
        A dict with ``status``, ``data``, ``message``, ``code`` and ``elapsed``.

    Raises:
        OCRServiceError: Propagated unchanged from the helpers.
        HTTPException: With status 500 for any other failure.
    """
    start_time = time.time()
    response = dict(status="success", data=None, message=None, code=None)

    try:
        abs_path = safe_resolve_path(request_data.file_path)
        logger.debug(f"处理请求文件: {abs_path}")

        # PDFs are processed page by page; every other allowed extension is
        # handled as a single image.
        is_pdf = validate_file(abs_path) == 'pdf'
        processor = process_pdf if is_pdf else process_image
        response["data"] = processor(abs_path)
    except Exception as e:
        # Service errors go to the registered exception handler untouched.
        if isinstance(e, OCRServiceError):
            raise
        logger.error(f"系统异常: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="系统内部错误"
        )
    finally:
        response["elapsed"] = f"{time.time() - start_time:.2f}s"

    return response
||||
|
||||
@app.get("/healthcheck", response_model=Dict[str, Any])
async def health_check():
    """Handle GET /healthcheck: report that the service is up.

    Returns:
        A dict with a fixed ``status`` and ``service`` name plus the current
        ``time.time()`` value as ``timestamp``.
    """
    now = time.time()
    return dict(status="ok", timestamp=now, service="OCR")
||||
|
||||
# Script entry point: runs only when the module is executed directly.
if __name__ == "__main__":
    # Imported here, so uvicorn is only required for direct execution.
    import uvicorn
    # Create the storage root if it does not exist yet.
    os.makedirs(BASE_STORAGE, exist_ok=True)
    # Serve the application on all interfaces, port 5000.
    uvicorn.run(app, host="0.0.0.0", port=5000)
Loading…
Reference in new issue