全量提示词一键提取初版,追加日志

master
yaxin 3 months ago
parent 5997fef737
commit a7cb70e44f

@ -12,6 +12,7 @@ public class TaskRecordConstants {
public static final String TASK_TYPE_SPECIFIED_CASE = "1"; public static final String TASK_TYPE_SPECIFIED_CASE = "1";
public static final String TASK_TYPE_SPECIFIED_RECORD = "2"; public static final String TASK_TYPE_SPECIFIED_RECORD = "2";
public static final String TASK_TYPE_SPECIFIED_EVIDENCE = "3"; public static final String TASK_TYPE_SPECIFIED_EVIDENCE = "3";
public static final String TASK_TYPE_ONE_CLICK = "4";
// 任务状态 // 任务状态
public static final String TASK_STATUS_WAITING = "0"; public static final String TASK_STATUS_WAITING = "0";
public static final String TASK_STATUS_PROCESSING = "1"; public static final String TASK_STATUS_PROCESSING = "1";

@ -13,6 +13,7 @@ import com.xxl.job.core.handler.annotation.XxlJob;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import org.springframework.util.StopWatch;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -20,6 +21,7 @@ import java.util.Map;
import static com.supervision.common.constant.NotePromptConstants.TYPE_GRAPH_REASONING; import static com.supervision.common.constant.NotePromptConstants.TYPE_GRAPH_REASONING;
import static com.supervision.common.constant.NotePromptConstants.TYPE_STRUCTURAL_REASONING; import static com.supervision.common.constant.NotePromptConstants.TYPE_STRUCTURAL_REASONING;
import static com.supervision.common.constant.TaskRecordConstants.*; import static com.supervision.common.constant.TaskRecordConstants.*;
import static com.supervision.common.constant.XxlJobConstants.TASK_NAME_PROMPT_EXTRACT_TASK;
@Slf4j @Slf4j
@Component @Component
@ -55,17 +57,19 @@ public class XxlJobTask {
/** /**
* *
*/ */
@XxlJob(XxlJobConstants.TASK_NAME_PROMPT_EXTRACT_TASK) @XxlJob(TASK_NAME_PROMPT_EXTRACT_TASK)
public void promptExtractTask() { public void promptExtractTask() {
String jobParam = XxlJobHelper.getJobParam(); String jobParam = XxlJobHelper.getJobParam();
log.info("【提取任务】任务开始。参数: {}", jobParam); log.info("【提取任务】任务开始。参数: {}", jobParam);
StopWatch stopWatch = new StopWatch(TASK_NAME_PROMPT_EXTRACT_TASK + " stopwatch");
stopWatch.start("Data preparation and task status check");
Map<String, String> map = JSON.parseObject(XxlJobHelper.getJobParam(), Map.class); Map<String, String> map = JSON.parseObject(XxlJobHelper.getJobParam(), Map.class);
String taskId = map.get("taskId");
String caseId = map.get("caseId");
String promptId = map.get("promptId");
String executeId = map.get("executeId");
try { try {
String taskId = map.get("taskId"); NotePrompt prompt = notePromptService.getBaseMapper().selectById(promptId);
String caseId = map.get("caseId");
String promptId = map.get("promptId");
String executeId = map.get("executeId");
NotePrompt prompt = notePromptService.getById(promptId);
boolean executable = true; boolean executable = true;
TaskRecord taskRecord = taskRecordService.getById(taskId); TaskRecord taskRecord = taskRecordService.getById(taskId);
switch (taskRecord.getStatus()) { switch (taskRecord.getStatus()) {
@ -92,15 +96,18 @@ public class XxlJobTask {
default: default:
break; break;
} }
stopWatch.stop();
if (executable) { if (executable) {
stopWatch.start("Model case analysis status check");
this.modelCaseAnalyzingStatusCheck(caseId); this.modelCaseAnalyzingStatusCheck(caseId);
TaskCaseRecord taskCaseRecord = taskCaseRecordService.lambdaQuery().eq(TaskCaseRecord::getCaseId, caseId).eq(TaskCaseRecord::getTaskRecordId, taskId).one(); TaskCaseRecord taskCaseRecord = taskCaseRecordService.lambdaQuery().eq(TaskCaseRecord::getCaseId, caseId).eq(TaskCaseRecord::getTaskRecordId, taskId).eq(TaskCaseRecord::getPromptId, promptId).one();
if (TASK_STATUS_WAITING.equals(taskCaseRecord.getStatus())) { if (TASK_STATUS_WAITING.equals(taskCaseRecord.getStatus())) {
log.info("任务状态为等待中任务状态更新为处理中任务案件ID: 【{}】", taskCaseRecord.getId()); log.info("任务状态为等待中任务状态更新为处理中任务案件ID: 【{}】", taskCaseRecord.getId());
taskCaseRecord.setStatus(TASK_STATUS_PROCESSING); taskCaseRecord.setStatus(TASK_STATUS_PROCESSING);
taskCaseRecordService.updateById(taskCaseRecord); taskCaseRecordService.updateById(taskCaseRecord);
} }
stopWatch.stop();
stopWatch.start("Extract task execution");
switch (prompt.getType()) { switch (prompt.getType()) {
case TYPE_GRAPH_REASONING: case TYPE_GRAPH_REASONING:
log.info("【图推理】任务开始。任务ID: 【{}】", taskId); log.info("【图推理】任务开始。任务ID: 【{}】", taskId);
@ -114,14 +121,17 @@ public class XxlJobTask {
log.error("未知的任务类型"); log.error("未知的任务类型");
break; break;
} }
stopWatch.stop();
stopWatch.start("Complete task");
taskRecordService.completeTask(taskId, map.get("executeId"), true); taskRecordService.completeTask(taskId, map.get("executeId"), true);
log.info("【提取任务】任务结束。任务ID: 【{}】", taskId); stopWatch.stop();
} }
} catch (Exception e) { } catch (Exception e) {
log.error("任务执行失败", e); stopWatch.stop();
log.error("任务执行失败Task ID:{}", taskId, e);
taskRecordService.completeTask(map.get("taskId"), map.get("executeId"), false); taskRecordService.completeTask(map.get("taskId"), map.get("executeId"), false);
} finally { } finally {
log.info("【提取任务】任务结束。"); log.info("【提取任务】任务结束。任务ID: 【{}】。耗时:{}", taskId, stopWatch.prettyPrint());
} }
} }

@ -22,7 +22,6 @@ public class TaskRecordController {
private final TaskRecordService taskRecordService; private final TaskRecordService taskRecordService;
@Operation(summary = "执行提示词提取任务") @Operation(summary = "执行提示词提取任务")
@PostMapping("/executePromptExtractTask") @PostMapping("/executePromptExtractTask")
public R<?> executePromptExtractTask(@RequestBody TaskRecordVo taskRecordVo) { public R<?> executePromptExtractTask(@RequestBody TaskRecordVo taskRecordVo) {
@ -30,6 +29,13 @@ public class TaskRecordController {
return R.ok(); return R.ok();
} }
@Operation(summary = "一键提取任务")
@PostMapping("/executeAllPromptExtractTask")
public R<?> executeAllPromptExtractTask() {
taskRecordService.executeAllPromptExtractTask();
return R.ok();
}
@Operation(summary = "查询任务列表") @Operation(summary = "查询任务列表")
@PostMapping("/taskList") @PostMapping("/taskList")

@ -32,6 +32,11 @@ public class TaskCaseRecord implements Serializable {
*/ */
private String caseId; private String caseId;
/**
* ID
*/
private String promptId;
/** /**
* ID * ID
*/ */

@ -20,6 +20,8 @@ public interface TaskRecordService extends IService<TaskRecord> {
void executePromptExtractTask(TaskRecordVo taskRecordVo); void executePromptExtractTask(TaskRecordVo taskRecordVo);
void executeAllPromptExtractTask();
void graphExtract(NotePrompt prompt, String caseId, String executeId); void graphExtract(NotePrompt prompt, String caseId, String executeId);
/** /**

@ -22,6 +22,7 @@ import org.jetbrains.annotations.NotNull;
import org.springframework.beans.BeanUtils; import org.springframework.beans.BeanUtils;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional; import org.springframework.transaction.annotation.Transactional;
import org.springframework.util.StopWatch;
import java.time.LocalDateTime; import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter; import java.time.format.DateTimeFormatter;
@ -70,14 +71,14 @@ public class TaskRecordServiceImpl extends ServiceImpl<TaskRecordMapper, TaskRec
taskRecord.setName(this.generateTaskName(taskRecord.getType())); taskRecord.setName(this.generateTaskName(taskRecord.getType()));
super.save(taskRecord); super.save(taskRecord);
try { try {
NotePrompt prompt = notePromptService.getById(taskRecordVo.getPromptId()); NotePrompt prompt = notePromptService.getById(taskRecord.getPromptId());
List<ModelCase> modelCases = this.getModelCases(taskRecordVo, taskRecord); List<ModelCase> modelCases = this.getModelCases(taskRecord);
if (!CollUtil.isEmpty(modelCases)) { if (!CollUtil.isEmpty(modelCases)) {
for (ModelCase modelCase : modelCases) { for (ModelCase modelCase : modelCases) {
String caseId = modelCase.getId(); String caseId = modelCase.getId();
List<String> ids = this.getIds(taskRecordVo, taskRecord, prompt, caseId); List<String> ids = this.getIds(taskRecord, caseId, prompt.getType());
if (!ids.isEmpty()) { if (!ids.isEmpty()) {
this.invokeXxlJob(taskRecordVo, taskRecord, caseId, ids); this.invokeXxlJob(taskRecord, prompt.getId(), caseId, ids);
} else { } else {
log.info("案件【{}】没有笔录", caseId); log.info("案件【{}】没有笔录", caseId);
} }
@ -92,6 +93,48 @@ public class TaskRecordServiceImpl extends ServiceImpl<TaskRecordMapper, TaskRec
} }
} }
@Override
public void executeAllPromptExtractTask() {
//查出所有图谱推理和结构化推理的提示词
log.info("开始触发一键提取任务");
StopWatch stopWatch = new StopWatch("One-click extraction stopwatch");
stopWatch.start();
TaskRecord taskRecord = new TaskRecord();
taskRecord.setType(TASK_TYPE_ONE_CLICK);
taskRecord.setName(this.generateTaskName(taskRecord.getType()));
super.save(taskRecord);
notePromptService.list(new LambdaQueryWrapper<NotePrompt>().eq(NotePrompt::getType, TYPE_GRAPH_REASONING).or().eq(NotePrompt::getType, TYPE_STRUCTURAL_REASONING)).forEach(notePrompt -> {
try {
NotePrompt prompt = notePromptService.getById(notePrompt.getId());
List<ModelCase> modelCases = this.getModelCases(taskRecord);
if (!CollUtil.isEmpty(modelCases)) {
//异步调用xxl-job执行任务
Thread thread = new Thread(() -> {
for (ModelCase modelCase : modelCases) {
String caseId = modelCase.getId();
List<String> ids = this.getIds(taskRecord, caseId, prompt.getType());
if (!ids.isEmpty()) {
taskRecord.setPromptId(prompt.getId());
this.invokeXxlJob(taskRecord, prompt.getId(), caseId, ids);
} else {
log.info("案件【{}】没有笔录或证据", caseId);
}
}
});
thread.start();
} else {
log.info("查无案件");
}
} catch (Exception e) {
taskRecord.setStatus(TASK_STATUS_FAIL);
super.updateById(taskRecord);
log.error("任务执行失败", e);
}
});
stopWatch.stop();
log.info("一键提取任务触发完成。耗时:{}", stopWatch.getTotalTimeSeconds());
}
@Override @Override
public void graphExtract(NotePrompt prompt, String caseId, String executeId) { public void graphExtract(NotePrompt prompt, String caseId, String executeId) {
List<TripleInfo> tripleInfos = extractTripleInfoService.extractTripleInfo(prompt, caseId, executeId); List<TripleInfo> tripleInfos = extractTripleInfoService.extractTripleInfo(prompt, caseId, executeId);
@ -113,25 +156,24 @@ public class TaskRecordServiceImpl extends ServiceImpl<TaskRecordMapper, TaskRec
/** /**
* ID * ID
* *
* @param taskRecordVo * @param taskRecord
* @param taskRecord * @param caseId ID
* @param prompt * @param promptType
* @param caseId ID
* @return ID * @return ID
*/ */
private @NotNull List<String> getIds(TaskRecordVo taskRecordVo, TaskRecord taskRecord, NotePrompt prompt, String caseId) { private @NotNull List<String> getIds(TaskRecord taskRecord, String caseId, String promptType) {
//查出当前案件相关笔录或证据 //查出当前案件相关笔录或证据
List<String> ids = new ArrayList<>(); List<String> ids = new ArrayList<>();
//如果类型为指定笔录或证据直接取传入的id //如果类型为指定笔录或证据直接取传入的id
if (TASK_TYPE_SPECIFIED_RECORD.equals(taskRecord.getType())) { if (TASK_TYPE_SPECIFIED_RECORD.equals(taskRecord.getType())) {
ids = List.of(taskRecordVo.getRecordId().split(",")); ids = List.of(taskRecord.getRecordId().split(","));
} else if (TASK_TYPE_SPECIFIED_EVIDENCE.equals(taskRecord.getType())) { } else if (TASK_TYPE_SPECIFIED_EVIDENCE.equals(taskRecord.getType())) {
ids = List.of(taskRecordVo.getEvidenceId().split(",")); ids = List.of(taskRecord.getEvidenceId().split(","));
} else { } else {
//如果是案件维度根据案件ID查找笔录或证据 //如果是案件维度根据案件ID查找笔录或证据
if (TYPE_GRAPH_REASONING.equals(prompt.getType())) { if (TYPE_GRAPH_REASONING.equals(promptType)) {
ids = noteRecordService.lambdaQuery().eq(NoteRecord::getCaseId, caseId).eq(NoteRecord::getDataStatus, DataStatus.AVAILABLE.getCode()).list().stream().map(NoteRecord::getId).toList(); ids = noteRecordService.lambdaQuery().eq(NoteRecord::getCaseId, caseId).eq(NoteRecord::getDataStatus, DataStatus.AVAILABLE.getCode()).list().stream().map(NoteRecord::getId).toList();
} else if (TYPE_STRUCTURAL_REASONING.equals(prompt.getType())) { } else if (TYPE_STRUCTURAL_REASONING.equals(promptType)) {
ids = caseEvidenceService.lambdaQuery().eq(CaseEvidence::getCaseId, caseId).list().stream().map(CaseEvidence::getId).toList(); ids = caseEvidenceService.lambdaQuery().eq(CaseEvidence::getCaseId, caseId).list().stream().map(CaseEvidence::getId).toList();
} }
} }
@ -141,15 +183,15 @@ public class TaskRecordServiceImpl extends ServiceImpl<TaskRecordMapper, TaskRec
/** /**
* xxl-job * xxl-job
* *
* @param taskRecordVo * @param taskRecord
* @param taskRecord * @param caseId ID
* @param caseId ID * @param ids ID
* @param ids ID
*/ */
private void invokeXxlJob(TaskRecordVo taskRecordVo, TaskRecord taskRecord, String caseId, List<String> ids) { private void invokeXxlJob(TaskRecord taskRecord, String promptId, String caseId, List<String> ids) {
TaskCaseRecord taskCaseRecord = new TaskCaseRecord(); TaskCaseRecord taskCaseRecord = new TaskCaseRecord();
taskCaseRecord.setTaskRecordId(taskRecord.getId()); taskCaseRecord.setTaskRecordId(taskRecord.getId());
taskCaseRecord.setCaseId(caseId); taskCaseRecord.setCaseId(caseId);
taskCaseRecord.setPromptId(promptId);
taskCaseRecord.setWaitingId(ids.stream().reduce((a, b) -> a + "," + b).orElse("")); taskCaseRecord.setWaitingId(ids.stream().reduce((a, b) -> a + "," + b).orElse(""));
taskCaseRecord.setStatus(TASK_STATUS_WAITING); taskCaseRecord.setStatus(TASK_STATUS_WAITING);
taskCaseRecordService.save(taskCaseRecord); taskCaseRecordService.save(taskCaseRecord);
@ -158,7 +200,7 @@ public class TaskRecordServiceImpl extends ServiceImpl<TaskRecordMapper, TaskRec
params.put("taskId", taskRecord.getId()); params.put("taskId", taskRecord.getId());
params.put("caseId", caseId); params.put("caseId", caseId);
params.put("executeId", id); params.put("executeId", id);
params.put("promptId", taskRecordVo.getPromptId()); params.put("promptId", taskRecord.getPromptId());
//map转String作为参数 //map转String作为参数
xxlJobService.executeTaskByJobHandler(TASK_NAME_PROMPT_EXTRACT_TASK, new JSONObject(params).toString()); xxlJobService.executeTaskByJobHandler(TASK_NAME_PROMPT_EXTRACT_TASK, new JSONObject(params).toString());
} }
@ -167,17 +209,16 @@ public class TaskRecordServiceImpl extends ServiceImpl<TaskRecordMapper, TaskRec
/** /**
* *
* *
* @param taskRecordVo * @param taskRecord
* @param taskRecord
* @return * @return
*/ */
private List<ModelCase> getModelCases(TaskRecordVo taskRecordVo, TaskRecord taskRecord) { private List<ModelCase> getModelCases(TaskRecord taskRecord) {
List<ModelCase> modelCases; List<ModelCase> modelCases;
LambdaQueryWrapper<ModelCase> queryWrapper = new LambdaQueryWrapper<>(); LambdaQueryWrapper<ModelCase> queryWrapper = new LambdaQueryWrapper<>();
queryWrapper.eq(ModelCase::getDataStatus, DataStatus.AVAILABLE.getCode()); queryWrapper.eq(ModelCase::getDataStatus, DataStatus.AVAILABLE.getCode());
//根据任务类型查找案件 //根据任务类型查找案件
if (TASK_TYPE_SPECIFIED_CASE.equals(taskRecord.getType())) { if (TASK_TYPE_SPECIFIED_CASE.equals(taskRecord.getType())) {
queryWrapper.in(ModelCase::getId, List.of(taskRecordVo.getCaseId().split(","))); queryWrapper.in(ModelCase::getId, List.of(taskRecord.getCaseId().split(",")));
} else if (TASK_TYPE_SPECIFIED_RECORD.equals(taskRecord.getType()) || TASK_TYPE_SPECIFIED_EVIDENCE.equals(taskRecord.getType())) { } else if (TASK_TYPE_SPECIFIED_RECORD.equals(taskRecord.getType()) || TASK_TYPE_SPECIFIED_EVIDENCE.equals(taskRecord.getType())) {
queryWrapper.eq(ModelCase::getId, taskRecord.getCaseId()); queryWrapper.eq(ModelCase::getId, taskRecord.getCaseId());
} }
@ -298,13 +339,16 @@ public class TaskRecordServiceImpl extends ServiceImpl<TaskRecordMapper, TaskRec
Boolean success = taskCaseRecordService.updateStatus(taskId, List.of(TASK_STATUS_WAITING, TASK_STATUS_PROCESSING), TASK_STATUS_CANCELED); Boolean success = taskCaseRecordService.updateStatus(taskId, List.of(TASK_STATUS_WAITING, TASK_STATUS_PROCESSING), TASK_STATUS_CANCELED);
log.info("completeTask:任务状态更新完成,task_case数据任务状态【{}】变动任务ID: 【{}】", taskId, success ? "产生" : "无"); log.info("completeTask:任务状态更新完成,task_case数据任务状态【{}】变动任务ID: 【{}】", taskId, success ? "产生" : "无");
taskRecord.setStatus(TASK_STATUS_CANCELED); taskRecord.setStatus(TASK_STATUS_CANCELED);
this.updateById(taskRecord);
if (!StrUtil.equals(TASK_STATUS_CANCELED, taskRecord.getStatus())) {
this.updateById(taskRecord);
}
return; return;
} }
if (StrUtil.equalsAny(taskRecord.getStatus(), TASK_STATUS_WAITING, TASK_STATUS_PROCESSING)) { if (StrUtil.equalsAny(taskRecord.getStatus(), TASK_STATUS_WAITING, TASK_STATUS_PROCESSING)) {
List<TaskCaseRecord> taskCaseRecords = taskCaseRecordService.queryByTaskId(taskId); List<TaskCaseRecord> taskCaseRecords = taskCaseRecordService.queryByTaskId(taskId);
String taskStatus = this.determineStatus(taskCaseRecords); String taskStatus = this.determineStatus(taskCaseRecords);
log.info("completeTask:任务ID:【{}】,初始任务状态:【{}】,计算后任务状态:【{}】", taskId, taskCaseRecord.getStatus(), taskStatus); log.info("completeTask:任务ID:【{}】,初始任务案件状态:【{}】,计算后任务状态:【{}】", taskId, taskCaseRecord.getStatus(), taskStatus);
if (!StrUtil.equals(taskStatus, taskRecord.getStatus())) { if (!StrUtil.equals(taskStatus, taskRecord.getStatus())) {
taskRecord.setStatus(taskStatus); taskRecord.setStatus(taskStatus);
super.updateById(taskRecord); super.updateById(taskRecord);

Loading…
Cancel
Save