1. 笔录提取进度暂存
parent
0a94e03bad
commit
c9030869a6
@ -0,0 +1,16 @@
|
||||
package com.supervision.police.service;
|
||||
|
||||
import com.supervision.police.domain.ModelRecordType;
|
||||
import com.supervision.police.domain.NoteRecordSplit;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface RecordSplitClassifyService {
|
||||
|
||||
/**
|
||||
* 对问答对进行分类
|
||||
* @param allTypeList
|
||||
* @param splitList
|
||||
*/
|
||||
void classify(List<ModelRecordType> allTypeList, List<NoteRecordSplit> splitList);
|
||||
}
|
@ -1,11 +0,0 @@
|
||||
package com.supervision.police.service;
|
||||
|
||||
import com.supervision.police.domain.ModelRecordType;
|
||||
import com.supervision.police.domain.NoteRecordSplit;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface RecordSplitTypeService {
|
||||
|
||||
void type(List<ModelRecordType> allTypeList, List<NoteRecordSplit> splitList);
|
||||
}
|
@ -0,0 +1,100 @@
|
||||
package com.supervision.police.service.impl;
|
||||
|
||||
import cn.hutool.core.collection.CollUtil;
|
||||
import com.supervision.police.domain.CaseTaskRecord;
|
||||
import com.supervision.police.domain.ModelRecordType;
|
||||
import com.supervision.police.domain.NoteRecordSplit;
|
||||
import com.supervision.police.service.CaseTaskRecordService;
|
||||
import com.supervision.police.service.ExtractTripleInfoService;
|
||||
import com.supervision.police.service.NoteRecordSplitService;
|
||||
import com.supervision.police.service.RecordSplitClassifyService;
|
||||
import com.supervision.thread.RecordSplitClassifyTask;
|
||||
import com.supervision.thread.RecordSplitClassifyThreadPool;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.springframework.ai.ollama.OllamaChatClient;
|
||||
import org.springframework.scheduling.annotation.Async;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.concurrent.Future;
|
||||
import java.util.function.Consumer;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
@RequiredArgsConstructor
|
||||
public class RecordSplitClassifyServiceImpl implements RecordSplitClassifyService {
|
||||
|
||||
private final OllamaChatClient chatClient;
|
||||
|
||||
private final NoteRecordSplitService noteRecordSplitService;
|
||||
|
||||
private final ExtractTripleInfoService extractTripleInfoService;
|
||||
|
||||
private final CaseTaskRecordService caseTaskRecordService;
|
||||
|
||||
@Async
|
||||
@Override
|
||||
@Transactional(transactionManager = "dataSourceTransactionManager", rollbackFor = Exception.class)
|
||||
public void classify(List<ModelRecordType> allTypeList, List<NoteRecordSplit> splitList) {
|
||||
|
||||
if (CollUtil.isEmpty(splitList)){
|
||||
log.warn("classify:没有需要分类的笔录片段,停止分类操作...");
|
||||
return;
|
||||
}
|
||||
|
||||
log.info("classify:开始执行笔录分类任务,笔录片段个数:{}", splitList.size());
|
||||
List<RecordSplitClassifyTask> taskList = splitList.stream()
|
||||
.peek(recordSplit -> {
|
||||
caseTaskRecordService.taskCountIncrement(recordSplit.getCaseId(), recordSplit.getNoteRecordId());
|
||||
log.info("classify:分类任务数量加1,笔录片段id:{}", recordSplit.getId());
|
||||
})
|
||||
.map(recordSplit -> new RecordSplitClassifyTask(allTypeList, recordSplit, chatClient, postClassifyRecord()))
|
||||
.toList();
|
||||
|
||||
if (CollUtil.isEmpty(taskList)){
|
||||
log.warn("classify:没有可用的分类任务,停止分类操作...");
|
||||
return;
|
||||
}
|
||||
try {
|
||||
log.info("classify:提交{}个分类任务....",taskList.size());
|
||||
RecordSplitClassifyThreadPool.executorService.invokeAll(taskList);
|
||||
} catch (Exception e) {
|
||||
log.error("classify:分类任务执行出现异常", e);
|
||||
}finally {
|
||||
setFinishStatus(CollUtil.getFirst(splitList));
|
||||
}
|
||||
}
|
||||
|
||||
private void setFinishStatus(NoteRecordSplit noteRecordSplit) {
|
||||
caseTaskRecordService.lambdaUpdate().set(CaseTaskRecord::getStatus, 2).eq(CaseTaskRecord::getCaseId, noteRecordSplit.getCaseId())
|
||||
.eq(CaseTaskRecord::getRecordId, noteRecordSplit.getNoteRecordId()).update();
|
||||
}
|
||||
|
||||
@NotNull
|
||||
private Consumer<NoteRecordSplit> postClassifyRecord() {
|
||||
return (noteRecordSplit) -> {
|
||||
// 更新笔录片段中的分类结果
|
||||
noteRecordSplitService.lambdaUpdate().set(NoteRecordSplit::getRecordType, noteRecordSplit.getRecordType())
|
||||
.eq(NoteRecordSplit::getId, noteRecordSplit.getId()).update();
|
||||
|
||||
caseTaskRecordService.taskCountIncrement(noteRecordSplit.getCaseId(), noteRecordSplit.getNoteRecordId());
|
||||
|
||||
// 提取三元组信息
|
||||
try {
|
||||
extractTripleInfoService.extractTripleInfo(noteRecordSplit.getCaseId(), noteRecordSplit.getPersonName(), noteRecordSplit.getId());
|
||||
} catch (Exception e) {
|
||||
log.error("postClassifyRecord:提取三元组信息出现异常", e);
|
||||
}finally {
|
||||
// 更新分类任务完成数量加1
|
||||
log.info("postClassifyRecord:分类任务完成数量加1,笔录片段id:{}", noteRecordSplit.getId());
|
||||
caseTaskRecordService.finishCountIncrement(noteRecordSplit.getCaseId(), noteRecordSplit.getNoteRecordId());
|
||||
}
|
||||
|
||||
|
||||
};
|
||||
}
|
||||
}
|
@ -1,117 +0,0 @@
|
||||
package com.supervision.police.service.impl;
|
||||
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import com.supervision.police.domain.CaseTaskRecord;
|
||||
import com.supervision.police.domain.ModelRecordType;
|
||||
import com.supervision.police.domain.NoteRecordSplit;
|
||||
import com.supervision.police.service.CaseTaskRecordService;
|
||||
import com.supervision.police.service.ExtractTripleInfoService;
|
||||
import com.supervision.police.service.NoteRecordSplitService;
|
||||
import com.supervision.police.service.RecordSplitTypeService;
|
||||
import com.supervision.thread.RecordSplitTypeThread;
|
||||
import com.supervision.thread.RecordSplitTypeThreadPool;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.ai.ollama.OllamaChatClient;
|
||||
import org.springframework.scheduling.annotation.Async;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.concurrent.Future;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
@RequiredArgsConstructor
|
||||
public class RecordSplitTypeServiceImpl implements RecordSplitTypeService {
|
||||
|
||||
private final OllamaChatClient chatClient;
|
||||
|
||||
private final NoteRecordSplitService noteRecordSplitService;
|
||||
|
||||
private final ExtractTripleInfoService extractTripleInfoService;
|
||||
|
||||
private final CaseTaskRecordService caseTaskRecordService;
|
||||
|
||||
@Async
|
||||
@Override
|
||||
@Transactional(transactionManager = "dataSourceTransactionManager", rollbackFor = Exception.class)
|
||||
public void type(List<ModelRecordType> allTypeList, List<NoteRecordSplit> splitList) {
|
||||
// 这里线程休眠1秒,因为首先报保证消息记录能够插入完成,插入完成之后,再去提交大模型,让大模型去分类.防止分类太快,分类结果出来了,插入还没有插入完成
|
||||
try {
|
||||
Thread.sleep(1000);
|
||||
} catch (Exception e) {
|
||||
log.error("分类任务线程休眠失败");
|
||||
}
|
||||
List<Future<String>> futures = new ArrayList<>();
|
||||
log.info("开始执行笔录分类任务");
|
||||
for (NoteRecordSplit recordSplit : splitList) {
|
||||
// 进行分类
|
||||
log.info("分类任务提交线程池进行分类");
|
||||
// 任务+1
|
||||
caseTaskRecordService.taskCountIncrement(recordSplit.getCaseId(), recordSplit.getNoteRecordId());
|
||||
RecordSplitTypeThread recordSplitTypeThread = new RecordSplitTypeThread(allTypeList, recordSplit, chatClient, noteRecordSplitService);
|
||||
// 分类之后的id
|
||||
Future<String> afterTypeSplitIdFuture = RecordSplitTypeThreadPool.recordSplitTypeExecutor.submit(recordSplitTypeThread);
|
||||
futures.add(afterTypeSplitIdFuture);
|
||||
log.info("分类任务线程池提交分类成功");
|
||||
}
|
||||
log.info("----------{}-----------", "分类任务全部提交成功了");
|
||||
// 校验分类任务是否完成,如果分类完成,那么就去提取三元组
|
||||
AtomicInteger atomicInteger = new AtomicInteger(0);
|
||||
while (futures.size() > 0) {
|
||||
Iterator<Future<String>> iterator = futures.iterator();
|
||||
while (iterator.hasNext()) {
|
||||
Future<String> future = iterator.next();
|
||||
try {
|
||||
// 如果分类成功,就开始提取三元组
|
||||
if (future.isDone()) {
|
||||
// 完成+1
|
||||
splitList.stream().findAny().ifPresent(noteRecordSplit -> caseTaskRecordService.finishCountIncrement(noteRecordSplit.getCaseId(), noteRecordSplit.getNoteRecordId()));
|
||||
String afterTypeSplitId = future.get();
|
||||
if (StrUtil.isNotBlank(afterTypeSplitId)) {
|
||||
Optional<NoteRecordSplit> optById = noteRecordSplitService.getOptById(afterTypeSplitId);
|
||||
if (optById.isPresent()) {
|
||||
NoteRecordSplit recordSplit = optById.get();
|
||||
extractTripleInfoService.extractTripleInfo(recordSplit.getCaseId(), recordSplit.getPersonName(), afterTypeSplitId);
|
||||
}
|
||||
}
|
||||
iterator.remove();
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.info("分类任务从线程中获取任务失败");
|
||||
iterator.remove();
|
||||
}
|
||||
}
|
||||
try {
|
||||
int currentCount = atomicInteger.incrementAndGet();
|
||||
if (currentCount > 1000) {
|
||||
log.info("分类任务执行超时,遍历任务已执行:{}次,任务还剩余:{}个,不再继续执行", currentCount, futures.size());
|
||||
// 将还在执行的线程中断
|
||||
futures.forEach(future -> {
|
||||
future.cancel(true);
|
||||
// 完成+1
|
||||
splitList.stream().findAny().ifPresent(noteRecordSplit -> caseTaskRecordService.finishCountIncrement(noteRecordSplit.getCaseId(), noteRecordSplit.getNoteRecordId()));
|
||||
|
||||
});
|
||||
break;
|
||||
}
|
||||
log.info("分类任务已检查{}遍,任务剩余{}个,休眠5s后继续检查", currentCount, futures.size());
|
||||
Thread.sleep(1000 * 5);
|
||||
} catch (Exception e) {
|
||||
log.error(e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
log.info("分类任务执行完毕");
|
||||
Optional<NoteRecordSplit> first = splitList.stream().findFirst();
|
||||
if (first.isPresent()) {
|
||||
NoteRecordSplit recordSplit = first.get();
|
||||
// 分类任务执行完成之后,就将任务进行更新
|
||||
caseTaskRecordService.lambdaUpdate().set(CaseTaskRecord::getStatus, 2).eq(CaseTaskRecord::getCaseId, recordSplit.getCaseId())
|
||||
.eq(CaseTaskRecord::getRecordId, recordSplit.getNoteRecordId()).update();
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue