Merge remote-tracking branch 'origin/dev_1.0.0' into dev_1.0.0

topo_dev
xueqingkun 9 months ago
commit c2f384c8a4

@ -2,6 +2,8 @@ package com.supervision.neo4j.service.impl;
import cn.hutool.core.collection.CollUtil; import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.ObjectUtil; import cn.hutool.core.util.ObjectUtil;
import cn.hutool.json.JSON;
import cn.hutool.json.JSONUtil;
import cn.hutool.poi.excel.ExcelReader; import cn.hutool.poi.excel.ExcelReader;
import cn.hutool.poi.excel.ExcelUtil; import cn.hutool.poi.excel.ExcelUtil;
import com.supervision.common.domain.R; import com.supervision.common.domain.R;
@ -20,6 +22,9 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.util.*; import java.util.*;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
/** /**
* @author qmy * @author qmy

@ -35,6 +35,10 @@ public class CaseTaskRecord implements Serializable {
*/ */
private Integer status; private Integer status;
private Integer taskCount;
private Integer finishCount;
/** /**
* *
*/ */

@ -2,6 +2,7 @@ package com.supervision.police.mapper;
import com.supervision.police.domain.CaseTaskRecord; import com.supervision.police.domain.CaseTaskRecord;
import com.baomidou.mybatisplus.core.mapper.BaseMapper; import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import org.apache.ibatis.annotations.Param;
/** /**
* @author Administrator * @author Administrator
@ -11,6 +12,10 @@ import com.baomidou.mybatisplus.core.mapper.BaseMapper;
*/ */
public interface CaseTaskRecordMapper extends BaseMapper<CaseTaskRecord> { public interface CaseTaskRecordMapper extends BaseMapper<CaseTaskRecord> {
void taskCountIncrement(@Param("caseId") String caseId,@Param("recordId") String recordId);
void finishCountIncrement(@Param("caseId") String caseId,@Param("recordId") String recordId);
} }

@ -10,4 +10,8 @@ import com.baomidou.mybatisplus.extension.service.IService;
*/ */
public interface CaseTaskRecordService extends IService<CaseTaskRecord> { public interface CaseTaskRecordService extends IService<CaseTaskRecord> {
void taskCountIncrement(String caseId,String recordId);
void finishCountIncrement(String caseId, String recordId);
} }

@ -5,6 +5,8 @@ import com.supervision.police.domain.CaseTaskRecord;
import com.supervision.police.service.CaseTaskRecordService; import com.supervision.police.service.CaseTaskRecordService;
import com.supervision.police.mapper.CaseTaskRecordMapper; import com.supervision.police.mapper.CaseTaskRecordMapper;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Propagation;
import org.springframework.transaction.annotation.Transactional;
/** /**
* @author Administrator * @author Administrator
@ -15,6 +17,17 @@ import org.springframework.stereotype.Service;
public class CaseTaskRecordServiceImpl extends ServiceImpl<CaseTaskRecordMapper, CaseTaskRecord> public class CaseTaskRecordServiceImpl extends ServiceImpl<CaseTaskRecordMapper, CaseTaskRecord>
implements CaseTaskRecordService{ implements CaseTaskRecordService{
@Override
@Transactional(transactionManager = "dataSourceTransactionManager", rollbackFor = Exception.class,propagation = Propagation.NOT_SUPPORTED)
public synchronized void taskCountIncrement(String caseId, String recordId) {
this.baseMapper.taskCountIncrement(caseId, recordId);
}
@Override
@Transactional(transactionManager = "dataSourceTransactionManager", rollbackFor = Exception.class,propagation = Propagation.NOT_SUPPORTED)
public synchronized void finishCountIncrement(String caseId, String recordId) {
this.baseMapper.finishCountIncrement(caseId, recordId);
}
} }

@ -98,6 +98,8 @@ public class ExtractTripleInfoServiceImpl implements ExtractTripleInfoService {
log.info("{} 切分笔录类型:{}对应的提示词:{} 提示词模板为空,跳过", recordSplit.getId(), typeName, prompt.getId()); log.info("{} 切分笔录类型:{}对应的提示词:{} 提示词模板为空,跳过", recordSplit.getId(), typeName, prompt.getId());
continue; continue;
} }
// task+1
caseTaskRecordService.taskCountIncrement(caseId, recordSplit.getNoteRecordId());
try { try {
log.info("提交任务到线程池中进行三元组提取"); log.info("提交任务到线程池中进行三元组提取");
TripleExtractThread tripleExtractThread = new TripleExtractThread(chatClient, caseId, recordSplit.getNoteRecordId(), recordSplit.getId(), prompt, recordSplit.getQuestion(), recordSplit.getAnswer()); TripleExtractThread tripleExtractThread = new TripleExtractThread(chatClient, caseId, recordSplit.getNoteRecordId(), recordSplit.getId(), prompt, recordSplit.getQuestion(), recordSplit.getAnswer());
@ -128,6 +130,8 @@ public class ExtractTripleInfoServiceImpl implements ExtractTripleInfoService {
try { try {
// 如果提取到结果,且不为空,就进行保存 // 如果提取到结果,且不为空,就进行保存
if (future.isDone()) { if (future.isDone()) {
// 完成+1
caseTaskRecordService.finishCountIncrement(caseId, recordSplit.getNoteRecordId());
TripleInfo tripleInfo = future.get(); TripleInfo tripleInfo = future.get();
if (tripleInfo != null) { if (tripleInfo != null) {
tripleInfos.add(tripleInfo); tripleInfos.add(tripleInfo);
@ -146,6 +150,8 @@ public class ExtractTripleInfoServiceImpl implements ExtractTripleInfoService {
// 将还在执行的线程中断 // 将还在执行的线程中断
futures.forEach(future -> { futures.forEach(future -> {
future.cancel(true); future.cancel(true);
// 完成+1
caseTaskRecordService.finishCountIncrement(caseId, recordSplit.getNoteRecordId());
}); });
break; break;
} }

@ -52,9 +52,12 @@ public class RecordSplitTypeServiceImpl implements RecordSplitTypeService {
log.error("分类任务线程休眠失败"); log.error("分类任务线程休眠失败");
} }
List<Future<String>> futures = new ArrayList<>(); List<Future<String>> futures = new ArrayList<>();
for (NoteRecordSplit recordSplit : splitList) { for (NoteRecordSplit recordSplit : splitList) {
// 进行分类 // 进行分类
log.info("分类任务提交线程池进行分类"); log.info("分类任务提交线程池进行分类");
// 任务+1
caseTaskRecordService.taskCountIncrement(recordSplit.getCaseId(), recordSplit.getNoteRecordId());
RecordSplitTypeThread recordSplitTypeThread = new RecordSplitTypeThread(allTypeList, recordSplit, chatClient, noteRecordSplitService); RecordSplitTypeThread recordSplitTypeThread = new RecordSplitTypeThread(allTypeList, recordSplit, chatClient, noteRecordSplitService);
// 分类之后的id // 分类之后的id
Future<String> afterTypeSplitIdFuture = RecordSplitTypeThreadPool.recordSplitTypeExecutor.submit(recordSplitTypeThread); Future<String> afterTypeSplitIdFuture = RecordSplitTypeThreadPool.recordSplitTypeExecutor.submit(recordSplitTypeThread);
@ -71,6 +74,8 @@ public class RecordSplitTypeServiceImpl implements RecordSplitTypeService {
try { try {
// 如果分类成功,就开始提取三元组 // 如果分类成功,就开始提取三元组
if (future.isDone()) { if (future.isDone()) {
// 完成+1
splitList.stream().findAny().ifPresent(noteRecordSplit -> caseTaskRecordService.finishCountIncrement(noteRecordSplit.getCaseId(), noteRecordSplit.getNoteRecordId()));
String afterTypeSplitId = future.get(); String afterTypeSplitId = future.get();
if (StrUtil.isNotBlank(afterTypeSplitId)) { if (StrUtil.isNotBlank(afterTypeSplitId)) {
Optional<NoteRecordSplit> optById = noteRecordSplitService.getOptById(afterTypeSplitId); Optional<NoteRecordSplit> optById = noteRecordSplitService.getOptById(afterTypeSplitId);
@ -93,6 +98,9 @@ public class RecordSplitTypeServiceImpl implements RecordSplitTypeService {
// 将还在执行的线程中断 // 将还在执行的线程中断
futures.forEach(future -> { futures.forEach(future -> {
future.cancel(true); future.cancel(true);
// 完成+1
splitList.stream().findAny().ifPresent(noteRecordSplit -> caseTaskRecordService.finishCountIncrement(noteRecordSplit.getCaseId(), noteRecordSplit.getNoteRecordId()));
}); });
break; break;
} }

@ -6,6 +6,7 @@ import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil; import cn.hutool.json.JSONUtil;
import com.supervision.police.domain.ModelRecordType; import com.supervision.police.domain.ModelRecordType;
import com.supervision.police.domain.NoteRecordSplit; import com.supervision.police.domain.NoteRecordSplit;
import com.supervision.police.service.CaseTaskRecordService;
import com.supervision.police.service.NoteRecordSplitService; import com.supervision.police.service.NoteRecordSplitService;
import com.supervision.springaidemo.dto.QARecordNodeDTO; import com.supervision.springaidemo.dto.QARecordNodeDTO;
import lombok.Data; import lombok.Data;
@ -41,7 +42,9 @@ public class RecordSplitTypeThread implements Callable<String> {
private final NoteRecordSplitService noteRecordSplitService; private final NoteRecordSplitService noteRecordSplitService;
public RecordSplitTypeThread(List<ModelRecordType> allTypeList, NoteRecordSplit noteRecordSplit, OllamaChatClient chatClient, NoteRecordSplitService noteRecordSplitService) {
public RecordSplitTypeThread(List<ModelRecordType> allTypeList, NoteRecordSplit noteRecordSplit, OllamaChatClient chatClient,
NoteRecordSplitService noteRecordSplitService) {
this.allTypeList = allTypeList; this.allTypeList = allTypeList;
this.chatClient = chatClient; this.chatClient = chatClient;
this.noteRecordSplitService = noteRecordSplitService; this.noteRecordSplitService = noteRecordSplitService;
@ -90,6 +93,7 @@ public class RecordSplitTypeThread implements Callable<String> {
public String call() throws Exception { public String call() throws Exception {
String type; String type;
try { try {
StopWatch stopWatch = new StopWatch(); StopWatch stopWatch = new StopWatch();
// 首先拼接分类模板 // 首先拼接分类模板
List<String> typeContextList = new ArrayList<>(); List<String> typeContextList = new ArrayList<>();
@ -127,6 +131,7 @@ public class RecordSplitTypeThread implements Callable<String> {
type = "无"; type = "无";
} }
noteRecordSplitService.lambdaUpdate().set(NoteRecordSplit::getRecordType, type).eq(NoteRecordSplit::getId, noteRecordSplit.getId()).update(); noteRecordSplitService.lambdaUpdate().set(NoteRecordSplit::getRecordType, type).eq(NoteRecordSplit::getId, noteRecordSplit.getId()).update();
return noteRecordSplit.getId(); return noteRecordSplit.getId();
} }

@ -16,4 +16,17 @@
id,case_id,record_id, id,case_id,record_id,
status,submit_time status,submit_time
</sql> </sql>
<update id="taskCountIncrement">
UPDATE case_task_record
SET task_count = task_count + 1
WHERE case_id = #{caseId} and record_id = #{recordId}
</update>
<update id="finishCountIncrement">
UPDATE case_task_record
SET finish_count = finish_count + 1
WHERE case_id = #{caseId} and record_id = #{recordId}
</update>
</mapper> </mapper>

Loading…
Cancel
Save