diff --git a/src/main/java/com/supervision/common/constant/TaskRecordConstants.java b/src/main/java/com/supervision/common/constant/TaskRecordConstants.java index 5eabe64..96d677c 100644 --- a/src/main/java/com/supervision/common/constant/TaskRecordConstants.java +++ b/src/main/java/com/supervision/common/constant/TaskRecordConstants.java @@ -12,6 +12,7 @@ public class TaskRecordConstants { public static final String TASK_TYPE_SPECIFIED_CASE = "1"; public static final String TASK_TYPE_SPECIFIED_RECORD = "2"; public static final String TASK_TYPE_SPECIFIED_EVIDENCE = "3"; + public static final String TASK_TYPE_ONE_CLICK = "4"; // 任务状态 public static final String TASK_STATUS_WAITING = "0"; public static final String TASK_STATUS_PROCESSING = "1"; diff --git a/src/main/java/com/supervision/job/XxlJobTask.java b/src/main/java/com/supervision/job/XxlJobTask.java index 7b251e1..1e291e4 100644 --- a/src/main/java/com/supervision/job/XxlJobTask.java +++ b/src/main/java/com/supervision/job/XxlJobTask.java @@ -13,6 +13,7 @@ import com.xxl.job.core.handler.annotation.XxlJob; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Component; +import org.springframework.util.StopWatch; import java.util.List; import java.util.Map; @@ -20,6 +21,7 @@ import java.util.Map; import static com.supervision.common.constant.NotePromptConstants.TYPE_GRAPH_REASONING; import static com.supervision.common.constant.NotePromptConstants.TYPE_STRUCTURAL_REASONING; import static com.supervision.common.constant.TaskRecordConstants.*; +import static com.supervision.common.constant.XxlJobConstants.TASK_NAME_PROMPT_EXTRACT_TASK; @Slf4j @Component @@ -55,17 +57,19 @@ public class XxlJobTask { /** * 提示词提取任务 */ - @XxlJob(XxlJobConstants.TASK_NAME_PROMPT_EXTRACT_TASK) + @XxlJob(TASK_NAME_PROMPT_EXTRACT_TASK) public void promptExtractTask() { String jobParam = XxlJobHelper.getJobParam(); log.info("【提取任务】任务开始。参数: {}", jobParam); + StopWatch stopWatch = new StopWatch(TASK_NAME_PROMPT_EXTRACT_TASK + " stopwatch"); + stopWatch.start("Data preparation and task status check"); Map map = JSON.parseObject(XxlJobHelper.getJobParam(), Map.class); + String taskId = map.get("taskId"); + String caseId = map.get("caseId"); + String promptId = map.get("promptId"); + String executeId = map.get("executeId"); try { - String taskId = map.get("taskId"); - String caseId = map.get("caseId"); - String promptId = map.get("promptId"); - String executeId = map.get("executeId"); - NotePrompt prompt = notePromptService.getById(promptId); + NotePrompt prompt = notePromptService.getBaseMapper().selectById(promptId); boolean executable = true; TaskRecord taskRecord = taskRecordService.getById(taskId); switch (taskRecord.getStatus()) { @@ -92,15 +96,18 @@ public class XxlJobTask { default: break; } + stopWatch.stop(); if (executable) { + stopWatch.start("Model case analysis status check"); this.modelCaseAnalyzingStatusCheck(caseId); - TaskCaseRecord taskCaseRecord = taskCaseRecordService.lambdaQuery().eq(TaskCaseRecord::getCaseId, caseId).eq(TaskCaseRecord::getTaskRecordId, taskId).one(); + TaskCaseRecord taskCaseRecord = taskCaseRecordService.lambdaQuery().eq(TaskCaseRecord::getCaseId, caseId).eq(TaskCaseRecord::getTaskRecordId, taskId).eq(TaskCaseRecord::getPromptId, promptId).one(); if (TASK_STATUS_WAITING.equals(taskCaseRecord.getStatus())) { log.info("任务状态为等待中,任务状态更新为处理中,任务案件ID: 【{}】", taskCaseRecord.getId()); taskCaseRecord.setStatus(TASK_STATUS_PROCESSING); taskCaseRecordService.updateById(taskCaseRecord); } - + stopWatch.stop(); + stopWatch.start("Extract task execution"); switch (prompt.getType()) { case TYPE_GRAPH_REASONING: log.info("【图推理】任务开始。任务ID: 【{}】", taskId); @@ -114,14 +121,17 @@ public class XxlJobTask { log.error("未知的任务类型"); break; } + stopWatch.stop(); + stopWatch.start("Complete task"); taskRecordService.completeTask(taskId, map.get("executeId"), true); - log.info("【提取任务】任务结束。任务ID: 【{}】", taskId); + stopWatch.stop(); } } catch (Exception e) { - log.error("任务执行失败", e); + stopWatch.stop(); + log.error("任务执行失败!Task ID:{}", taskId, e); taskRecordService.completeTask(map.get("taskId"), map.get("executeId"), false); } finally { - log.info("【提取任务】任务结束。"); + log.info("【提取任务】任务结束。任务ID: 【{}】。耗时:{}", taskId, stopWatch.prettyPrint()); } } diff --git a/src/main/java/com/supervision/police/controller/TaskRecordController.java b/src/main/java/com/supervision/police/controller/TaskRecordController.java index 2fb333e..44a7d73 100644 --- a/src/main/java/com/supervision/police/controller/TaskRecordController.java +++ b/src/main/java/com/supervision/police/controller/TaskRecordController.java @@ -22,7 +22,6 @@ public class TaskRecordController { private final TaskRecordService taskRecordService; - @Operation(summary = "执行提示词提取任务") @PostMapping("/executePromptExtractTask") public R executePromptExtractTask(@RequestBody TaskRecordVo taskRecordVo) { @@ -30,6 +29,13 @@ public class TaskRecordController { return R.ok(); } + @Operation(summary = "一键提取任务") + @PostMapping("/executeAllPromptExtractTask") + public R executeAllPromptExtractTask() { + taskRecordService.executeAllPromptExtractTask(); + return R.ok(); + } + @Operation(summary = "查询任务列表") @PostMapping("/taskList") diff --git a/src/main/java/com/supervision/police/domain/TaskCaseRecord.java b/src/main/java/com/supervision/police/domain/TaskCaseRecord.java index a2ae6aa..847246f 100644 --- a/src/main/java/com/supervision/police/domain/TaskCaseRecord.java +++ b/src/main/java/com/supervision/police/domain/TaskCaseRecord.java @@ -32,6 +32,11 @@ public class TaskCaseRecord implements Serializable { */ private String caseId; + /** + * 提示词ID + */ + private String promptId; + /** * 等待处理的ID,逗号分隔 */ diff --git a/src/main/java/com/supervision/police/service/TaskRecordService.java b/src/main/java/com/supervision/police/service/TaskRecordService.java index d2524d9..972c016 100644 --- a/src/main/java/com/supervision/police/service/TaskRecordService.java +++ b/src/main/java/com/supervision/police/service/TaskRecordService.java @@ -20,6 +20,8 @@ public interface TaskRecordService extends IService { void executePromptExtractTask(TaskRecordVo taskRecordVo); + void executeAllPromptExtractTask(); + void graphExtract(NotePrompt prompt, String caseId, String executeId); /** diff --git a/src/main/java/com/supervision/police/service/impl/TaskRecordServiceImpl.java b/src/main/java/com/supervision/police/service/impl/TaskRecordServiceImpl.java index 1dfcaa0..dbccffb 100644 --- a/src/main/java/com/supervision/police/service/impl/TaskRecordServiceImpl.java +++ b/src/main/java/com/supervision/police/service/impl/TaskRecordServiceImpl.java @@ -22,6 +22,7 @@ import org.jetbrains.annotations.NotNull; import org.springframework.beans.BeanUtils; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; +import org.springframework.util.StopWatch; import java.time.LocalDateTime; import java.time.format.DateTimeFormatter; @@ -70,14 +71,14 @@ public class TaskRecordServiceImpl extends ServiceImpl modelCases = this.getModelCases(taskRecordVo, taskRecord); + NotePrompt prompt = notePromptService.getById(taskRecord.getPromptId()); + List modelCases = this.getModelCases(taskRecord); if (!CollUtil.isEmpty(modelCases)) { for (ModelCase modelCase : modelCases) { String caseId = modelCase.getId(); - List ids = this.getIds(taskRecordVo, taskRecord, prompt, caseId); + List ids = this.getIds(taskRecord, caseId, prompt.getType()); if (!ids.isEmpty()) { - this.invokeXxlJob(taskRecordVo, taskRecord, caseId, ids); + this.invokeXxlJob(taskRecord, prompt.getId(), caseId, ids); } else { log.info("案件【{}】没有笔录", caseId); } @@ -92,6 +93,48 @@ public class TaskRecordServiceImpl extends ServiceImpl().eq(NotePrompt::getType, TYPE_GRAPH_REASONING).or().eq(NotePrompt::getType, TYPE_STRUCTURAL_REASONING)).forEach(notePrompt -> { + try { + NotePrompt prompt = notePromptService.getById(notePrompt.getId()); + List modelCases = this.getModelCases(taskRecord); + if (!CollUtil.isEmpty(modelCases)) { + //异步调用xxl-job执行任务 + Thread thread = new Thread(() -> { + for (ModelCase modelCase : modelCases) { + String caseId = modelCase.getId(); + List ids = this.getIds(taskRecord, caseId, prompt.getType()); + if (!ids.isEmpty()) { + taskRecord.setPromptId(prompt.getId()); + this.invokeXxlJob(taskRecord, prompt.getId(), caseId, ids); + } else { + log.info("案件【{}】没有笔录或证据", caseId); + } + } + }); + thread.start(); + } else { + log.info("查无案件"); + } + } catch (Exception e) { + taskRecord.setStatus(TASK_STATUS_FAIL); + super.updateById(taskRecord); + log.error("任务执行失败", e); + } + }); + stopWatch.stop(); + log.info("一键提取任务触发完成。耗时:{}", stopWatch.getTotalTimeSeconds()); + } + @Override public void graphExtract(NotePrompt prompt, String caseId, String executeId) { List tripleInfos = extractTripleInfoService.extractTripleInfo(prompt, caseId, executeId); @@ -113,25 +156,24 @@ public class TaskRecordServiceImpl extends ServiceImpl getIds(TaskRecordVo taskRecordVo, TaskRecord taskRecord, NotePrompt prompt, String caseId) { + private @NotNull List getIds(TaskRecord taskRecord, String caseId, String promptType) { //查出当前案件相关笔录或证据 List ids = new ArrayList<>(); //如果类型为指定笔录或证据,直接取传入的id if (TASK_TYPE_SPECIFIED_RECORD.equals(taskRecord.getType())) { - ids = List.of(taskRecordVo.getRecordId().split(",")); + ids = List.of(taskRecord.getRecordId().split(",")); } else if (TASK_TYPE_SPECIFIED_EVIDENCE.equals(taskRecord.getType())) { - ids = List.of(taskRecordVo.getEvidenceId().split(",")); + ids = List.of(taskRecord.getEvidenceId().split(",")); } else { //如果是案件维度,根据案件ID查找笔录或证据 - if (TYPE_GRAPH_REASONING.equals(prompt.getType())) { + if (TYPE_GRAPH_REASONING.equals(promptType)) { ids = noteRecordService.lambdaQuery().eq(NoteRecord::getCaseId, caseId).eq(NoteRecord::getDataStatus, DataStatus.AVAILABLE.getCode()).list().stream().map(NoteRecord::getId).toList(); - } else if (TYPE_STRUCTURAL_REASONING.equals(prompt.getType())) { + } else if (TYPE_STRUCTURAL_REASONING.equals(promptType)) { ids = caseEvidenceService.lambdaQuery().eq(CaseEvidence::getCaseId, caseId).list().stream().map(CaseEvidence::getId).toList(); } } @@ -141,15 +183,15 @@ public class TaskRecordServiceImpl extends ServiceImpl ids) { + private void invokeXxlJob(TaskRecord taskRecord, String promptId, String caseId, List ids) { TaskCaseRecord taskCaseRecord = new TaskCaseRecord(); taskCaseRecord.setTaskRecordId(taskRecord.getId()); taskCaseRecord.setCaseId(caseId); + taskCaseRecord.setPromptId(promptId); taskCaseRecord.setWaitingId(ids.stream().reduce((a, b) -> a + "," + b).orElse("")); taskCaseRecord.setStatus(TASK_STATUS_WAITING); taskCaseRecordService.save(taskCaseRecord); @@ -158,7 +200,7 @@ public class TaskRecordServiceImpl extends ServiceImpl getModelCases(TaskRecordVo taskRecordVo, TaskRecord taskRecord) { + private List getModelCases(TaskRecord taskRecord) { List modelCases; LambdaQueryWrapper queryWrapper = new LambdaQueryWrapper<>(); queryWrapper.eq(ModelCase::getDataStatus, DataStatus.AVAILABLE.getCode()); //根据任务类型查找案件 if (TASK_TYPE_SPECIFIED_CASE.equals(taskRecord.getType())) { - queryWrapper.in(ModelCase::getId, List.of(taskRecordVo.getCaseId().split(","))); + queryWrapper.in(ModelCase::getId, List.of(taskRecord.getCaseId().split(","))); } else if (TASK_TYPE_SPECIFIED_RECORD.equals(taskRecord.getType()) || TASK_TYPE_SPECIFIED_EVIDENCE.equals(taskRecord.getType())) { queryWrapper.eq(ModelCase::getId, taskRecord.getCaseId()); } @@ -298,13 +339,16 @@ public class TaskRecordServiceImpl extends ServiceImpl taskCaseRecords = taskCaseRecordService.queryByTaskId(taskId); String taskStatus = this.determineStatus(taskCaseRecords); - log.info("completeTask:任务ID:【{}】,初始任务状态:【{}】,计算后任务状态:【{}】", taskId, taskCaseRecord.getStatus(), taskStatus); + log.info("completeTask:任务ID:【{}】,初始任务案件状态:【{}】,计算后任务状态:【{}】", taskId, taskCaseRecord.getStatus(), taskStatus); if (!StrUtil.equals(taskStatus, taskRecord.getStatus())) { taskRecord.setStatus(taskStatus); super.updateById(taskRecord);