From 732a1303b19f4f5ec6b45c6cd2f0507a8712036e Mon Sep 17 00:00:00 2001 From: liu Date: Mon, 5 Aug 2024 10:43:58 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BB=A3=E7=A0=81=E6=8F=90=E4=BA=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../supervision/police/domain/CaseTaskRecord.java | 4 ++++ .../police/mapper/CaseTaskRecordMapper.java | 5 +++++ .../police/service/CaseTaskRecordService.java | 4 ++++ .../service/impl/CaseTaskRecordServiceImpl.java | 13 +++++++++++++ .../service/impl/ExtractTripleInfoServiceImpl.java | 6 ++++++ .../service/impl/RecordSplitTypeServiceImpl.java | 8 ++++++++ .../supervision/thread/RecordSplitTypeThread.java | 7 ++++++- src/main/resources/mapper/CaseTaskRecordMapper.xml | 13 +++++++++++++ 8 files changed, 59 insertions(+), 1 deletion(-) diff --git a/src/main/java/com/supervision/police/domain/CaseTaskRecord.java b/src/main/java/com/supervision/police/domain/CaseTaskRecord.java index 4c63c0f..7b57acf 100644 --- a/src/main/java/com/supervision/police/domain/CaseTaskRecord.java +++ b/src/main/java/com/supervision/police/domain/CaseTaskRecord.java @@ -35,6 +35,10 @@ public class CaseTaskRecord implements Serializable { */ private Integer status; + private Integer taskCount; + + private Integer finishCount; + /** * 提交日期 */ diff --git a/src/main/java/com/supervision/police/mapper/CaseTaskRecordMapper.java b/src/main/java/com/supervision/police/mapper/CaseTaskRecordMapper.java index ed94c7a..a03dfe6 100644 --- a/src/main/java/com/supervision/police/mapper/CaseTaskRecordMapper.java +++ b/src/main/java/com/supervision/police/mapper/CaseTaskRecordMapper.java @@ -2,6 +2,7 @@ package com.supervision.police.mapper; import com.supervision.police.domain.CaseTaskRecord; import com.baomidou.mybatisplus.core.mapper.BaseMapper; +import org.apache.ibatis.annotations.Param; /** * @author Administrator @@ -11,6 +12,10 @@ import com.baomidou.mybatisplus.core.mapper.BaseMapper; */ public interface CaseTaskRecordMapper extends BaseMapper { + void taskCountIncrement(@Param("caseId") String caseId,@Param("recordId") String recordId); + + void finishCountIncrement(@Param("caseId") String caseId,@Param("recordId") String recordId); + } diff --git a/src/main/java/com/supervision/police/service/CaseTaskRecordService.java b/src/main/java/com/supervision/police/service/CaseTaskRecordService.java index 9a06c7f..ffac810 100644 --- a/src/main/java/com/supervision/police/service/CaseTaskRecordService.java +++ b/src/main/java/com/supervision/police/service/CaseTaskRecordService.java @@ -10,4 +10,8 @@ import com.baomidou.mybatisplus.extension.service.IService; */ public interface CaseTaskRecordService extends IService { + void taskCountIncrement(String caseId,String recordId); + + void finishCountIncrement(String caseId, String recordId); + } diff --git a/src/main/java/com/supervision/police/service/impl/CaseTaskRecordServiceImpl.java b/src/main/java/com/supervision/police/service/impl/CaseTaskRecordServiceImpl.java index 5762e4b..e0d3830 100644 --- a/src/main/java/com/supervision/police/service/impl/CaseTaskRecordServiceImpl.java +++ b/src/main/java/com/supervision/police/service/impl/CaseTaskRecordServiceImpl.java @@ -5,6 +5,8 @@ import com.supervision.police.domain.CaseTaskRecord; import com.supervision.police.service.CaseTaskRecordService; import com.supervision.police.mapper.CaseTaskRecordMapper; import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Propagation; +import org.springframework.transaction.annotation.Transactional; /** * @author Administrator @@ -15,6 +17,17 @@ import org.springframework.stereotype.Service; public class CaseTaskRecordServiceImpl extends ServiceImpl implements CaseTaskRecordService{ + @Override + @Transactional(transactionManager = "dataSourceTransactionManager", rollbackFor = Exception.class,propagation = Propagation.NOT_SUPPORTED) + public synchronized void taskCountIncrement(String caseId, String recordId) { + this.baseMapper.taskCountIncrement(caseId, recordId); + } + + @Override + @Transactional(transactionManager = "dataSourceTransactionManager", rollbackFor = Exception.class,propagation = Propagation.NOT_SUPPORTED) + public synchronized void finishCountIncrement(String caseId, String recordId) { + this.baseMapper.finishCountIncrement(caseId, recordId); + } } diff --git a/src/main/java/com/supervision/police/service/impl/ExtractTripleInfoServiceImpl.java b/src/main/java/com/supervision/police/service/impl/ExtractTripleInfoServiceImpl.java index 69350aa..b0d3510 100644 --- a/src/main/java/com/supervision/police/service/impl/ExtractTripleInfoServiceImpl.java +++ b/src/main/java/com/supervision/police/service/impl/ExtractTripleInfoServiceImpl.java @@ -98,6 +98,8 @@ public class ExtractTripleInfoServiceImpl implements ExtractTripleInfoService { log.info("{} 切分笔录类型:{}对应的提示词:{} 提示词模板为空,跳过", recordSplit.getId(), typeName, prompt.getId()); continue; } + // task+1 + caseTaskRecordService.taskCountIncrement(caseId, recordSplit.getNoteRecordId()); try { log.info("提交任务到线程池中进行三元组提取"); TripleExtractThread tripleExtractThread = new TripleExtractThread(chatClient, caseId, recordSplit.getNoteRecordId(), recordSplit.getId(), prompt, recordSplit.getQuestion(), recordSplit.getAnswer()); @@ -128,6 +130,8 @@ public class ExtractTripleInfoServiceImpl implements ExtractTripleInfoService { try { // 如果提取到结果,且不为空,就进行保存 if (future.isDone()) { + // 完成+1 + caseTaskRecordService.finishCountIncrement(caseId, recordSplit.getNoteRecordId()); TripleInfo tripleInfo = future.get(); if (tripleInfo != null) { tripleInfos.add(tripleInfo); @@ -146,6 +150,8 @@ public class ExtractTripleInfoServiceImpl implements ExtractTripleInfoService { // 将还在执行的线程中断 futures.forEach(future -> { future.cancel(true); + // 完成+1 + caseTaskRecordService.finishCountIncrement(caseId, recordSplit.getNoteRecordId()); }); break; } diff --git a/src/main/java/com/supervision/police/service/impl/RecordSplitTypeServiceImpl.java b/src/main/java/com/supervision/police/service/impl/RecordSplitTypeServiceImpl.java index 6d5a0c6..6b97e29 100644 --- a/src/main/java/com/supervision/police/service/impl/RecordSplitTypeServiceImpl.java +++ b/src/main/java/com/supervision/police/service/impl/RecordSplitTypeServiceImpl.java @@ -52,9 +52,12 @@ public class RecordSplitTypeServiceImpl implements RecordSplitTypeService { log.error("分类任务线程休眠失败"); } List> futures = new ArrayList<>(); + for (NoteRecordSplit recordSplit : splitList) { // 进行分类 log.info("分类任务提交线程池进行分类"); + // 任务+1 + caseTaskRecordService.taskCountIncrement(recordSplit.getCaseId(), recordSplit.getNoteRecordId()); RecordSplitTypeThread recordSplitTypeThread = new RecordSplitTypeThread(allTypeList, recordSplit, chatClient, noteRecordSplitService); // 分类之后的id Future afterTypeSplitIdFuture = RecordSplitTypeThreadPool.recordSplitTypeExecutor.submit(recordSplitTypeThread); @@ -71,6 +74,8 @@ public class RecordSplitTypeServiceImpl implements RecordSplitTypeService { try { // 如果分类成功,就开始提取三元组 if (future.isDone()) { + // 完成+1 + splitList.stream().findAny().ifPresent(noteRecordSplit -> caseTaskRecordService.finishCountIncrement(noteRecordSplit.getCaseId(), noteRecordSplit.getNoteRecordId())); String afterTypeSplitId = future.get(); if (StrUtil.isNotBlank(afterTypeSplitId)) { Optional optById = noteRecordSplitService.getOptById(afterTypeSplitId); @@ -93,6 +98,9 @@ public class RecordSplitTypeServiceImpl implements RecordSplitTypeService { // 将还在执行的线程中断 futures.forEach(future -> { future.cancel(true); + // 完成+1 + splitList.stream().findAny().ifPresent(noteRecordSplit -> caseTaskRecordService.finishCountIncrement(noteRecordSplit.getCaseId(), noteRecordSplit.getNoteRecordId())); + }); break; } diff --git a/src/main/java/com/supervision/thread/RecordSplitTypeThread.java b/src/main/java/com/supervision/thread/RecordSplitTypeThread.java index 6db7712..d489461 100644 --- a/src/main/java/com/supervision/thread/RecordSplitTypeThread.java +++ b/src/main/java/com/supervision/thread/RecordSplitTypeThread.java @@ -6,6 +6,7 @@ import cn.hutool.core.util.StrUtil; import cn.hutool.json.JSONUtil; import com.supervision.police.domain.ModelRecordType; import com.supervision.police.domain.NoteRecordSplit; +import com.supervision.police.service.CaseTaskRecordService; import com.supervision.police.service.NoteRecordSplitService; import com.supervision.springaidemo.dto.QARecordNodeDTO; import lombok.Data; @@ -41,7 +42,9 @@ public class RecordSplitTypeThread implements Callable { private final NoteRecordSplitService noteRecordSplitService; - public RecordSplitTypeThread(List allTypeList, NoteRecordSplit noteRecordSplit, OllamaChatClient chatClient, NoteRecordSplitService noteRecordSplitService) { + + public RecordSplitTypeThread(List allTypeList, NoteRecordSplit noteRecordSplit, OllamaChatClient chatClient, + NoteRecordSplitService noteRecordSplitService) { this.allTypeList = allTypeList; this.chatClient = chatClient; this.noteRecordSplitService = noteRecordSplitService; @@ -90,6 +93,7 @@ public class RecordSplitTypeThread implements Callable { public String call() throws Exception { String type; try { + StopWatch stopWatch = new StopWatch(); // 首先拼接分类模板 List typeContextList = new ArrayList<>(); @@ -127,6 +131,7 @@ public class RecordSplitTypeThread implements Callable { type = "无"; } noteRecordSplitService.lambdaUpdate().set(NoteRecordSplit::getRecordType, type).eq(NoteRecordSplit::getId, noteRecordSplit.getId()).update(); + return noteRecordSplit.getId(); } diff --git a/src/main/resources/mapper/CaseTaskRecordMapper.xml b/src/main/resources/mapper/CaseTaskRecordMapper.xml index a9821fc..5fc0a68 100644 --- a/src/main/resources/mapper/CaseTaskRecordMapper.xml +++ b/src/main/resources/mapper/CaseTaskRecordMapper.xml @@ -16,4 +16,17 @@ id,case_id,record_id, status,submit_time + + UPDATE case_task_record + SET task_count = task_count + 1 + WHERE case_id = #{caseId} and record_id = #{recordId} + + + + UPDATE case_task_record + SET finish_count = finish_count + 1 + WHERE case_id = #{caseId} and record_id = #{recordId} + + +