You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

99 lines
3.2 KiB
Python

import os
import uuid
import shutil
import tempfile
import gc
import fitz
import torch
import base64
import filetype
import litserve as ls
from pathlib import Path
from fastapi import HTTPException
class MinerUAPI(ls.LitAPI):
def __init__(self, output_dir='/tmp'):
self.output_dir = Path(output_dir)
def setup(self, device):
if device.startswith('cuda'):
os.environ['CUDA_VISIBLE_DEVICES'] = device.split(':')[-1]
if torch.cuda.device_count() > 1:
raise RuntimeError("Remove any CUDA actions before setting 'CUDA_VISIBLE_DEVICES'.")
from magic_pdf.tools.cli import do_parse, convert_file_to_pdf
from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
self.do_parse = do_parse
self.convert_file_to_pdf = convert_file_to_pdf
model_manager = ModelSingleton()
model_manager.get_model(True, False)
model_manager.get_model(False, False)
print(f'Model initialization complete on {device}!')
def decode_request(self, request):
file = request['file']
file = self.cvt2pdf(file)
opts = request.get('kwargs', {})
opts.setdefault('debug_able', False)
opts.setdefault('parse_method', 'auto')
return file, opts
def predict(self, inputs):
try:
pdf_name = str(uuid.uuid4())
output_dir = self.output_dir.joinpath(pdf_name)
self.do_parse(self.output_dir, pdf_name, inputs[0], [], **inputs[1])
return output_dir
except Exception as e:
shutil.rmtree(output_dir, ignore_errors=True)
raise HTTPException(status_code=500, detail=str(e))
finally:
self.clean_memory()
def encode_response(self, response):
return {'output_dir': response}
def clean_memory(self):
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
gc.collect()
def cvt2pdf(self, file_base64):
try:
temp_dir = Path(tempfile.mkdtemp())
temp_file = temp_dir.joinpath('tmpfile')
file_bytes = base64.b64decode(file_base64)
file_ext = filetype.guess_extension(file_bytes)
if file_ext in ['pdf', 'jpg', 'png', 'doc', 'docx', 'ppt', 'pptx']:
if file_ext == 'pdf':
return file_bytes
elif file_ext in ['jpg', 'png']:
with fitz.open(stream=file_bytes, filetype=file_ext) as f:
return f.convert_to_pdf()
else:
temp_file.write_bytes(file_bytes)
self.convert_file_to_pdf(temp_file, temp_dir)
return temp_file.with_suffix('.pdf').read_bytes()
else:
raise Exception('Unsupported file format')
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
finally:
shutil.rmtree(temp_dir, ignore_errors=True)
if __name__ == '__main__':
server = ls.LitServer(
MinerUAPI(output_dir='/tmp'),
accelerator='cuda',
devices='auto',
workers_per_device=1,
timeout=False
)
server.run(port=8000)