From a8c3125446edc6e01528158d5cd2331d129b2b24 Mon Sep 17 00:00:00 2001 From: zhangzhichao Date: Tue, 20 May 2025 13:15:00 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9C=8D=E5=8A=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- helper/page_detection/pdf_detection.py | 64 ++++++++------------------ server.py | 14 ++++-- server.sh | 3 ++ 3 files changed, 33 insertions(+), 48 deletions(-) create mode 100755 server.sh diff --git a/helper/page_detection/pdf_detection.py b/helper/page_detection/pdf_detection.py index 8c31e01..bfd41c9 100644 --- a/helper/page_detection/pdf_detection.py +++ b/helper/page_detection/pdf_detection.py @@ -770,13 +770,10 @@ def load_predictor(model_dir, precision_mode=precision_map[run_mode], use_static=False, use_calib_mode=trt_calib_mode) - if FLAGS.collect_trt_shape_info: - config.collect_shape_range_info(FLAGS.tuned_trt_shape_file) - elif os.path.exists(FLAGS.tuned_trt_shape_file): - print(f'Use dynamic shape file: ' - f'{FLAGS.tuned_trt_shape_file} for TRT...') - config.enable_tuned_tensorrt_dynamic_shape( - FLAGS.tuned_trt_shape_file, True) + if os.path.exists('shape_range_info.pbtxt'): + print('Use dynamic shape file: shape_range_info.pbtxt for TRT...') + config.enable_tuned_tensorrt_dynamic_shape( + 'shape_range_info.pbtxt', True) if use_dynamic_shape: min_input_shape = { @@ -849,20 +846,10 @@ def visualize(image_list, result, labels, output_dir='output/', threshold=0.5): print("save result to: " + out_path) -def print_arguments(args): - print('----------- Running Arguments -----------') - for arg, value in sorted(vars(args).items()): - print('%s: %s' % (arg, value)) - print('------------------------------------------') - - class Pipeline(object): def __init__(self, model_dir): - if FLAGS.use_fd_format: - deploy_file = os.path.join(model_dir, 'inference.yml') - else: - deploy_file = os.path.join(model_dir, 'infer_cfg.yml') + deploy_file = os.path.join(model_dir, 'infer_cfg.yml') with open(deploy_file) as f: yml_conf = yaml.safe_load(f) arch = yml_conf['arch'] @@ -876,40 +863,27 @@ class Pipeline(object): self.detector = eval(detector_func)( model_dir, - device=FLAGS.device, - run_mode=FLAGS.run_mode, - batch_size=FLAGS.batch_size, - trt_min_shape=FLAGS.trt_min_shape, - trt_max_shape=FLAGS.trt_max_shape, - trt_opt_shape=FLAGS.trt_opt_shape, - trt_calib_mode=FLAGS.trt_calib_mode, - cpu_threads=FLAGS.cpu_threads, - enable_mkldnn=FLAGS.enable_mkldnn, - enable_mkldnn_bfloat16=FLAGS.enable_mkldnn_bfloat16, - threshold=FLAGS.threshold, - output_dir=FLAGS.output_dir, - use_fd_format=FLAGS.use_fd_format) + device='GPU', + run_mode='paddle', + batch_size=1, + trt_min_shape=1, + trt_max_shape=1280, + trt_opt_shape=640, + trt_calib_mode=False, + cpu_threads=1, + enable_mkldnn=False, + enable_mkldnn_bfloat16=False, + threshold=0.5, + output_dir='output', + use_fd_format=False) def __call__(self, image): if isinstance(image, np.ndarray): image = [image] results = self.detector.predict_image( image, - visual=FLAGS.save_images) + visual=False) return results paddle.enable_static() -parser = argsparser() -FLAGS = parser.parse_args() -# print_arguments(FLAGS) -FLAGS.device = 'GPU' -FLAGS.save_images = False -FLAGS.device = FLAGS.device.upper() -assert FLAGS.device in ['CPU', 'GPU', 'XPU', 'NPU', 'MLU', 'GCU' - ], "device should be CPU, GPU, XPU, MLU, NPU or GCU" -assert not FLAGS.use_gpu, "use_gpu has been deprecated, please use --device" - -assert not ( - FLAGS.enable_mkldnn == False and FLAGS.enable_mkldnn_bfloat16 == True -), 'To enable mkldnn bfloat, please turn on both enable_mkldnn and enable_mkldnn_bfloat16' diff --git a/server.py b/server.py index 0d4cc20..7e07b8a 100644 --- a/server.py +++ b/server.py @@ -1,19 +1,27 @@ from flask import Flask, request, jsonify import requests from pipeline import pdf2markdown_pipeline +import concurrent.futures +from loguru import logger app = Flask(__name__) +executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + + +def pdf2markdown_task(pdf_paths, callback_url): + for pdf_path in pdf_paths: + process_status, pdf_id = pdf2markdown_pipeline(pdf_path) + requests.post(callback_url, json={'pdfId': pdf_id, 'processStatus': process_status}) @app.route('/pdf-qa-server/pdf-to-md', methods=['POST']) def pdf2markdown(): data = request.json + logger.info(f'request params: {data}') pdf_paths = data['pathList'] callback_url = data['webhookUrl'] - for pdf_path in pdf_paths: - process_status, pdf_id = pdf2markdown_pipeline(pdf_path) - requests.post(callback_url, json={'pdfId': pdf_id, 'processStatus': process_status}) + executor.submit(pdf2markdown_task, pdf_paths, callback_url) return jsonify({}) diff --git a/server.sh b/server.sh new file mode 100755 index 0000000..cbf0f25 --- /dev/null +++ b/server.sh @@ -0,0 +1,3 @@ +#! /bin/bash + +gunicorn -w 1 -b 0.0.0.0:8000 server:app