fu-hsi-service/src/main/java/com/supervision/police/service/impl/FileOcrProcessServiceImpl.java

304 lines
12 KiB
Java

package com.supervision.police.service.impl;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.supervision.constant.OcrProcessStatus;
import com.supervision.minio.domain.MinioFile;
import com.supervision.minio.service.MinioService;
import com.supervision.police.domain.FileOcrProcess;
import com.supervision.police.domain.NoteRecord;
import com.supervision.police.dto.OCRReqDTO;
import com.supervision.police.dto.OCRResDTO;
import com.supervision.police.dto.RecordFileDTO;
import com.supervision.police.service.FileOcrProcessService;
import com.supervision.police.mapper.FileOcrProcessMapper;
import com.supervision.police.service.NoteRecordService;
import com.supervision.police.service.OCRService;
import com.supervision.utils.PDFReadUtil;
import com.supervision.utils.WordReadUtil;
import io.swagger.v3.oas.annotations.OpenAPI31;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.aop.framework.AopContext;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Propagation;
import org.springframework.transaction.annotation.Transactional;
import java.util.*;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
/**
 * Service implementation for database operations on the {@code file_ocr_process} table
 * (per-file OCR recognition progress).
 *
 * <p>Keeps one {@link FileOcrProcess} row per file id and drives text extraction for it: files are
 * either sent to the remote {@link OCRService}, or (Word / PDF) read directly from MinIO through
 * {@link WordReadUtil} / {@link PDFReadUtil}.
 *
 * @author Administrator
 * @createDate 2024-08-30 17:35:23
 */
@Slf4j
@Service
@RequiredArgsConstructor
public class FileOcrProcessServiceImpl extends ServiceImpl<FileOcrProcessMapper, FileOcrProcess>
        implements FileOcrProcessService {

    /**
     * Status code stored for a successfully recognized file. This is the value
     * {@link #updateByOcrRes(OCRResDTO)} converts the OCR service's status {@code 0} into.
     */
    private static final Integer OCR_SUCCESS_STATUS = 1;

    /** Maximum number of files handled per OCR batch in {@link #doOCRTask(List)}. */
    @Value("${ocr.pool.max-size:20}")
    private Integer poolMaxSize;

    private final OCRService ocrService;
    private final MinioService minioService;
    private final NoteRecordService noteRecordService;

    /**
     * Registers progress rows for the given files and hands the ones that need recognition to
     * {@link #asyncDoOCRTask(List)}, called through the Spring proxy so that {@code @Async} applies.
     *
     * <p>NOTE(review): the async task may start before this transaction commits, so rows inserted
     * here might not yet be visible to it. Consider dispatching after commit.
     *
     * @param fileIdList ids of the files to recognize
     * @return the progress rows of the given files, ordered like {@code fileIdList}
     */
    @Override
    @Transactional(transactionManager = "dataSourceTransactionManager", rollbackFor = Exception.class)
    public List<FileOcrProcess> asyncSubmitOCR(List<String> fileIdList) {
        // A plain this::asyncDoOCRTask would bypass the proxy and run synchronously.
        return submitOCR(fileIdList, ((FileOcrProcessService) AopContext.currentProxy())::asyncDoOCRTask);
    }

    /**
     * Registers progress rows for the given files and recognizes them with the remote OCR service
     * on the calling thread.
     *
     * @param fileIdList ids of the files to recognize
     * @return the progress rows of the given files after recognition, ordered like {@code fileIdList}
     */
    @Override
    public List<FileOcrProcess> syncSubmitOCR(List<String> fileIdList) {
        return submitOCR(fileIdList, this::doOCRTask);
    }

    /**
     * Registers progress rows for the given files and passes those needing work to {@code consumer}
     * on the calling thread.
     *
     * @param fileIdList ids of the files to recognize
     * @param consumer   the recognition routine to run on the rows that need processing
     * @return the progress rows of the given files after {@code consumer} ran, ordered like {@code fileIdList}
     */
    @Override
    public List<FileOcrProcess> syncSubmitOCR(List<String> fileIdList, Consumer<List<FileOcrProcess>> consumer) {
        return submitOCR(fileIdList, consumer);
    }

    /**
     * Creates or resets progress rows for {@code fileIdList} and runs {@code consumer} on the rows
     * that need recognition.
     *
     * <p>A file without a row gets a new {@code UNPROCESS} row. A file whose row is in a failed state
     * is reset to {@code UNPROCESS} and retried. Every other file is left untouched and is not passed
     * to {@code consumer}. Duplicate ids in {@code fileIdList} are processed once.
     *
     * @return the current rows of all given files (re-read after {@code consumer} ran), ordered like
     *         {@code fileIdList}; an empty list when {@code fileIdList} is empty
     */
    private List<FileOcrProcess> submitOCR(List<String> fileIdList, Consumer<List<FileOcrProcess>> consumer) {
        if (CollUtil.isEmpty(fileIdList)) {
            log.info("submitOCR:fileIds为空。不提交ocr任务...");
            return new ArrayList<>(1);
        }
        // Explicit HashMap: the map is mutated below, and Collectors.toMap does not guarantee mutability.
        Map<String, FileOcrProcess> ocrProcessMap = super.lambdaQuery().in(FileOcrProcess::getFileId, fileIdList).list()
                .stream()
                .collect(Collectors.toMap(FileOcrProcess::getFileId, fileOcrProcess -> fileOcrProcess,
                        (k1, k2) -> k1, HashMap::new));
        List<FileOcrProcess> processList = new ArrayList<>();
        for (String fileId : fileIdList) {
            FileOcrProcess fileOcrProcess = ocrProcessMap.get(fileId);
            if (fileOcrProcess == null) {
                fileOcrProcess = new FileOcrProcess(fileId, OcrProcessStatus.UNPROCESS.getCode());
                super.save(fileOcrProcess);
                // Remember the new row so a repeated id does not insert a second one.
                ocrProcessMap.put(fileId, fileOcrProcess);
                processList.add(fileOcrProcess);
            } else if (OcrProcessStatus.isFailCode(fileOcrProcess.getStatus())) {
                // Retry files whose previous recognition failed.
                fileOcrProcess.setStatus(OcrProcessStatus.UNPROCESS.getCode());
                super.updateById(fileOcrProcess);
                processList.add(fileOcrProcess);
            }
        }
        log.debug("submitOCR:提交识别任务到异步处理器中...");
        consumer.accept(processList);
        List<FileOcrProcess> resultList = super.lambdaQuery().in(FileOcrProcess::getFileId, fileIdList).list();
        return sortByIdOrder(fileIdList, resultList, FileOcrProcess::getFileId);
    }

    /**
     * Recognizes the given files with the remote OCR service and stores the results.
     *
     * <p>Files are handled in batches of at most {@code poolMaxSize}. Within a batch, one request is
     * sent per file type. The method is {@code synchronized}, so calls on this service instance run
     * one at a time.
     *
     * @param fileOcrProcesses rows of the files to recognize; {@code null} or empty is a no-op
     */
    @Override
    public synchronized void doOCRTask(List<FileOcrProcess> fileOcrProcesses) {
        log.info("doOCRTask:开始识别文件...{}", JSONUtil.toJsonStr(fileOcrProcesses));
        if (CollUtil.isEmpty(fileOcrProcesses)) {
            log.info("asyncOcr:当前暂无识别的任务,结束...");
            return;
        }
        for (List<FileOcrProcess> ocrProcesses : CollUtil.split(fileOcrProcesses, poolMaxSize)) {
            List<String> fileIdList = ocrProcesses.stream().map(FileOcrProcess::getFileId).collect(Collectors.toList());
            for (OCRReqDTO ocrReqDTO : buildOCRReqDTO(fileIdList)) {
                recognizeByOcrService(ocrReqDTO);
            }
        }
    }

    /**
     * Sends one OCR request and persists its outcome.
     *
     * <p>Requested files are first marked {@code PROCESSING}. Each returned result is stored via
     * {@link #updateByOcrRes(OCRResDTO)}. Files the service did not return a result for, and all
     * files of the request when the call throws, are marked {@code FAIL} so they can be retried
     * instead of staying {@code PROCESSING}.
     */
    private void recognizeByOcrService(OCRReqDTO ocrReqDTO) {
        List<String> requestedFileIds = ocrReqDTO.getFile_ids();
        log.info("ocr:开始识别文件:{}", JSONUtil.toJsonStr(ocrReqDTO));
        this.updateOCrStatus(requestedFileIds, OcrProcessStatus.PROCESSING.getCode());
        try {
            List<OCRResDTO> ocrRes = ocrService.ocr(ocrReqDTO);
            log.info("ocr:识别结果:{}", JSONUtil.toJsonStr(ocrRes));
            Set<String> answeredFileIds = new HashSet<>();
            for (OCRResDTO ocrRe : Objects.requireNonNullElse(ocrRes, List.<OCRResDTO>of())) {
                // Status 2 is the OCR service's per-file failure status.
                if (Integer.valueOf(2).equals(ocrRe.getStatus())) {
                    log.info("ocr:文件{}识别失败,原因:{}", ocrRe.getFile_id(), ocrRe.getError_msg());
                }
                this.updateByOcrRes(ocrRe);
                answeredFileIds.add(ocrRe.getFile_id());
            }
            List<String> unansweredFileIds = requestedFileIds.stream()
                    .filter(fileId -> !answeredFileIds.contains(fileId))
                    .collect(Collectors.toList());
            if (!unansweredFileIds.isEmpty()) {
                log.warn("ocr:识别结果中缺少文件,标记为失败:{}", unansweredFileIds);
                this.updateOCrStatus(unansweredFileIds, OcrProcessStatus.FAIL.getCode());
            }
        } catch (Exception e) {
            log.error("远程调用ocr识别接口失败", e);
            this.updateOCrStatus(requestedFileIds, OcrProcessStatus.FAIL.getCode());
        }
    }

    /**
     * Asynchronous entry point for {@link #doOCRTask(List)}. It only runs asynchronously when invoked
     * through the Spring proxy.
     *
     * @param fileOcrProcesses rows of the files to recognize
     */
    @Async
    @Override
    public void asyncDoOCRTask(List<FileOcrProcess> fileOcrProcesses) {
        log.debug("asyncDoOCRTask:开始识别文件...{}", JSONUtil.toJsonStr(fileOcrProcesses));
        doOCRTask(fileOcrProcesses);
    }

    /**
     * Recognizes up to 99999 files that are still in the {@code UNPROCESS} state.
     */
    @Override
    public void doAllOCRTask() {
        // Query the same code submitOCR writes for new/reset rows.
        List<FileOcrProcess> allFileOcrProcesses = pageListByStatus(OcrProcessStatus.UNPROCESS.getCode(), 99999);
        doOCRTask(allFileOcrProcesses);
    }

    /**
     * Returns the first page of rows with the given status.
     *
     * @param status status code to match
     * @param size   page size, i.e. the maximum number of rows returned
     * @return the matching rows of page 1
     */
    @Override
    public List<FileOcrProcess> pageListByStatus(Integer status, Integer size) {
        return super.lambdaQuery().eq(FileOcrProcess::getStatus, status).page(new Page<>(1, size)).getRecords();
    }

    /**
     * Counts rows with the given status.
     *
     * @throws ArithmeticException if the count does not fit in an {@code int}
     */
    @Override
    public Integer countByStatus(Integer status) {
        return Math.toIntExact(super.lambdaQuery().eq(FileOcrProcess::getStatus, status).count());
    }

    /**
     * Sets the status of every row whose file id is in {@code ocrIdList}.
     *
     * @param ocrIdList file ids to update; an empty or {@code null} list updates nothing
     * @param ocrStatus new status code
     * @return the update result, or {@code false} when there was nothing to update
     */
    @Override
    public Boolean updateOCrStatus(List<String> ocrIdList, Integer ocrStatus) {
        if (CollUtil.isEmpty(ocrIdList)) {
            // An empty IN list would render as "IN ()", which is invalid SQL.
            return false;
        }
        return super.lambdaUpdate().in(FileOcrProcess::getFileId, ocrIdList)
                .set(FileOcrProcess::getStatus, ocrStatus).update();
    }

    /**
     * Stores one OCR result (text, status and drawn-image id) on the row of its file.
     *
     * <p>Side effect: a status of {@code 0} on {@code ocrResDTO} is rewritten in place to
     * {@link #OCR_SUCCESS_STATUS} before saving.
     *
     * @param ocrResDTO result returned by the OCR service
     * @return the update result
     */
    @Override
    public Boolean updateByOcrRes(OCRResDTO ocrResDTO) {
        // Status code conversion: the OCR service's 0 (presumably its success code) is stored as
        // OCR_SUCCESS_STATUS.
        if (Integer.valueOf(0).equals(ocrResDTO.getStatus())) {
            ocrResDTO.setStatus(OCR_SUCCESS_STATUS);
        }
        return super.lambdaUpdate().eq(FileOcrProcess::getFileId, ocrResDTO.getFile_id())
                .set(FileOcrProcess::getOcrText, ocrResDTO.getOcr_text())
                .set(FileOcrProcess::getStatus, ocrResDTO.getStatus())
                .set(FileOcrProcess::getDrawImgId, ocrResDTO.getDraw_img_id())
                .update();
    }

    /**
     * Queries file details for the given ids, without filtering by status.
     *
     * @return the mapper's rows, or an empty list when {@code fileIdList} is empty
     */
    @Override
    public List<RecordFileDTO> queryFileList(List<String> fileIdList) {
        if (CollUtil.isEmpty(fileIdList)) {
            return new ArrayList<>(1);
        }
        return super.baseMapper.queryFileList(null, fileIdList);
    }

    /**
     * Same as {@link #queryFileList(List)}, with the result reordered to follow {@code fileIdList}.
     */
    @Override
    public List<RecordFileDTO> queryFileListWithIdSort(List<String> fileIdList) {
        List<RecordFileDTO> recordFileDTOS = this.queryFileList(fileIdList);
        return sortByIdOrder(fileIdList, recordFileDTOS, RecordFileDTO::getFileId);
    }

    /**
     * Queries file details for the given ids, passing {@code status} through to the mapper.
     *
     * @return the mapper's rows, or an empty list when {@code fileIdList} is empty
     */
    @Override
    public List<RecordFileDTO> queryFileList(String status, List<String> fileIdList) {
        if (CollUtil.isEmpty(fileIdList)) {
            return new ArrayList<>(1);
        }
        return super.baseMapper.queryFileList(status, fileIdList);
    }

    /**
     * Queries the files attached to a note record. The file ids come from the record's
     * comma-separated {@code fileIds} field.
     *
     * @param recordId note record id
     * @return the record's files, or an empty list when the id is empty, the record is missing,
     *         or it has no file ids
     */
    @Override
    public List<RecordFileDTO> queryListByRecordId(String recordId) {
        if (StrUtil.isEmpty(recordId)) {
            return new ArrayList<>(1);
        }
        NoteRecord noteRecord = noteRecordService.getById(recordId);
        if (Objects.isNull(noteRecord) || StrUtil.isEmpty(noteRecord.getFileIds())) {
            return new ArrayList<>(1);
        }
        // Tolerate spaces around ids and empty segments such as "a,,b".
        List<String> fileIds = Arrays.stream(noteRecord.getFileIds().split(","))
                .map(String::trim)
                .filter(StrUtil::isNotEmpty)
                .toList();
        return queryFileList(fileIds);
    }

    /**
     * Same as {@link #queryFileListWithIdSort(List)}, but runs with any surrounding transaction
     * suspended ({@code Propagation.NOT_SUPPORTED}).
     */
    @Override
    @Transactional(transactionManager = "dataSourceTransactionManager", propagation = Propagation.NOT_SUPPORTED)
    public List<RecordFileDTO> queryFileListWithIdSortNoTransaction(List<String> fileIdList) {
        return this.queryFileListWithIdSort(fileIdList);
    }

    /** Extracts the text of Word files stored in MinIO and saves it on their rows. */
    @Override
    public void doWordCRTask(List<FileOcrProcess> fileOcrProcesses) {
        doMcr(fileOcrProcesses, (fileId) -> WordReadUtil.readWordInMinio(minioService, fileId));
    }

    /** Extracts the text of PDF files stored in MinIO and saves it on their rows. */
    @Override
    public void doPdfCRTask(List<FileOcrProcess> fileOcrProcesses) {
        doMcr(fileOcrProcesses, (fileId) -> PDFReadUtil.readPdfInMinio(minioService, fileId));
    }

    /**
     * Recognizes files synchronously, choosing the extractor by {@code fileType}.
     *
     * <ul>
     *   <li>"doc" and "docx" use the Word reader.</li>
     *   <li>"pdf" uses the PDF reader.</li>
     *   <li>Anything else goes to the remote OCR service.</li>
     * </ul>
     *
     * <p>The match is case-sensitive.
     *
     * @return the progress rows of the given files, ordered like {@code fileIds}
     */
    @Override
    public List<FileOcrProcess> multipleTypeOcrProcess(List<String> fileIds, String fileType) {
        if (StrUtil.equalsAny(fileType, "doc", "docx")) {
            return this.syncSubmitOCR(fileIds, this::doWordCRTask);
        } else if (StrUtil.equalsAny(fileType, "pdf")) {
            return this.syncSubmitOCR(fileIds, this::doPdfCRTask);
        } else {
            return this.submitOCR(fileIds, this::doOCRTask);
        }
    }

    /**
     * Extracts text for each file with {@code function} and stores it.
     *
     * <p>Each file is marked {@code PROCESSING} first. On success it gets the text and
     * {@link #OCR_SUCCESS_STATUS}. A {@code null} text or any exception marks it {@code FAIL}.
     *
     * @param fileOcrProcesses rows of the files to process; {@code null} or empty is a no-op
     * @param function         maps a file id to its extracted text
     */
    private void doMcr(List<FileOcrProcess> fileOcrProcesses, Function<String, String> function) {
        log.info("doMcr:开始识别文件...{}", JSONUtil.toJsonStr(fileOcrProcesses));
        if (CollUtil.isEmpty(fileOcrProcesses)) {
            log.info("doMcr:当前暂无识别的任务,结束...");
            return;
        }
        for (FileOcrProcess ocrProcess : fileOcrProcesses) {
            String fileId = ocrProcess.getFileId();
            log.info("ocr:开始识别文件:{}", JSONUtil.toJsonStr(ocrProcess));
            this.updateOCrStatus(List.of(fileId), OcrProcessStatus.PROCESSING.getCode());
            try {
                String ocrText = function.apply(fileId);
                Assert.notNull(ocrText, "识别结果为空");
                // Text was extracted: store it with the same success code OCR results are stored with,
                // so the file does not stay in PROCESSING.
                this.lambdaUpdate().eq(FileOcrProcess::getFileId, fileId)
                        .set(FileOcrProcess::getStatus, OCR_SUCCESS_STATUS)
                        .set(FileOcrProcess::getOcrText, ocrText).update();
            } catch (Exception e) {
                log.error("doMcr识别失败", e);
                this.updateOCrStatus(List.of(fileId), OcrProcessStatus.FAIL.getCode());
            }
        }
    }

    /**
     * Builds one OCR request per file type for the given file ids, using the file metadata returned
     * by {@link MinioService#listMinioFile(List)}.
     *
     * <p>NOTE(review): ids for which MinIO returns no file presumably end up in no request, and their
     * rows keep their current status. Verify against {@code listMinioFile}.
     */
    private List<OCRReqDTO> buildOCRReqDTO(List<String> fileIdList) {
        List<MinioFile> minioFiles = minioService.listMinioFile(fileIdList);
        return minioFiles.stream().collect(Collectors.groupingBy(MinioFile::getFileType))
                .entrySet().stream()
                .map(entry -> new OCRReqDTO(
                        entry.getValue().stream().map(MinioFile::getId).collect(Collectors.toList()),
                        entry.getKey()))
                .collect(Collectors.toList());
    }

    /**
     * Reorders {@code targetList} to follow {@code idList}.
     *
     * <p>Ids with no matching element are skipped. When several elements share an id, the first one
     * wins. When either list has fewer than two elements, {@code targetList} is returned unchanged.
     *
     * @param function extracts an element's id
     */
    private <T> List<T> sortByIdOrder(List<String> idList, List<T> targetList, Function<T, String> function) {
        if (CollUtil.size(idList) < 2 || CollUtil.size(targetList) < 2) {
            return targetList;
        }
        Map<String, T> targetMap = targetList.stream()
                .collect(Collectors.toMap(function, target -> target, (k1, k2) -> k1));
        return idList.stream().map(targetMap::get).filter(Objects::nonNull).collect(Collectors.toList());
    }
}