169 lines
8.0 KiB
Java
169 lines
8.0 KiB
Java
package com.supervision.police.service.impl;
|
|
|
|
import cn.hutool.core.collection.CollUtil;
|
|
import cn.hutool.core.util.StrUtil;
|
|
import com.alibaba.druid.util.StringUtils;
|
|
import com.supervision.police.domain.*;
|
|
import com.supervision.police.mapper.NotePromptMapper;
|
|
import com.supervision.police.mapper.NoteRecordSplitMapper;
|
|
import com.supervision.police.mapper.TripleInfoMapper;
|
|
import com.supervision.police.service.*;
|
|
import com.supervision.thread.TripleExtractThread;
|
|
import com.supervision.thread.TripleExtractThreadPool;
|
|
import lombok.RequiredArgsConstructor;
|
|
import lombok.extern.slf4j.Slf4j;
|
|
import org.json.JSONArray;
|
|
import org.json.JSONObject;
|
|
import org.springframework.ai.chat.ChatResponse;
|
|
import org.springframework.ai.chat.messages.UserMessage;
|
|
import org.springframework.ai.chat.prompt.Prompt;
|
|
import org.springframework.ai.ollama.OllamaChatClient;
|
|
import org.springframework.beans.factory.annotation.Autowired;
|
|
import org.springframework.scheduling.annotation.Async;
|
|
import org.springframework.stereotype.Service;
|
|
import org.springframework.util.StopWatch;
|
|
|
|
import java.time.LocalDateTime;
|
|
import java.util.ArrayList;
|
|
import java.util.Iterator;
|
|
import java.util.List;
|
|
import java.util.Map;
|
|
import java.util.concurrent.Future;
|
|
import java.util.concurrent.atomic.AtomicInteger;
|
|
import java.util.stream.Collectors;
|
|
|
|
@Slf4j
|
|
@Service
|
|
@RequiredArgsConstructor
|
|
public class ExtractTripleInfoServiceImpl implements ExtractTripleInfoService {
|
|
|
|
private final CaseTaskRecordService caseTaskRecordService;
|
|
|
|
private final ModelRecordTypeService modelRecordTypeService;
|
|
|
|
private final NotePromptService notePromptService;
|
|
|
|
private final TripleInfoService tripleInfoService;
|
|
|
|
private final OllamaChatClient chatClient;
|
|
|
|
@Autowired
|
|
private NoteRecordSplitService noteRecordSplitService;
|
|
|
|
|
|
@Async
|
|
public void extractTripleInfo(String caseId, String name, String recordId) {
|
|
// 首先获取所有切分后的笔录
|
|
List<NoteRecordSplit> recordSplitList = noteRecordSplitService.lambdaQuery().eq(StrUtil.isNotBlank(recordId), NoteRecordSplit::getNoteRecordsId, recordId)
|
|
.eq(NoteRecordSplit::getCaseId, caseId).eq(NoteRecordSplit::getPersonName, name).list();
|
|
// 获取所有的分类
|
|
List<ModelRecordType> allTypeList = modelRecordTypeService.list();
|
|
Map<String, String> allTypeMap = allTypeList.stream().collect(Collectors.toMap(ModelRecordType::getRecordType, ModelRecordType::getId, (k1, k2) -> k1));
|
|
|
|
List<TripleInfo> tripleInfos = new ArrayList<>();
|
|
List<Future<TripleInfo>> futures = new ArrayList<>();
|
|
// 对切分后的笔录进行遍历
|
|
for (NoteRecordSplit recordSplit : recordSplitList) {
|
|
String recordType = recordSplit.getRecordType();
|
|
if (StrUtil.isBlank(recordType)) {
|
|
log.info("{} 切分笔录不属于任何类型,跳过", recordSplit.getId());
|
|
}
|
|
String[] split = recordType.split(";");
|
|
for (String typeName : split) {
|
|
String typeId = allTypeMap.get(typeName);
|
|
if (StrUtil.isBlank(typeId)) {
|
|
log.info("{} 切分笔录类型:{}未找到,跳过", recordSplit.getId(), typeName);
|
|
} else {
|
|
// 根据笔录类型找到所有的提取三元组的提示词
|
|
// 一个提示词可能关联多个类型,要进行拆分操作
|
|
List<NotePrompt> prompts = notePromptService.lambdaQuery().eq(NotePrompt::getTypeId, typeId).list();
|
|
if (CollUtil.isEmpty(prompts)) {
|
|
log.info("{} 切分笔录类型:{}无对应的提示词,跳过", recordSplit.getId(), typeName);
|
|
} else {
|
|
// 遍历提示词进行提取
|
|
for (NotePrompt prompt : prompts) {
|
|
if (StringUtils.isEmpty(prompt.getPrompt())) {
|
|
log.info("{} 切分笔录类型:{}对应的提示词:{} 提示词模板为空,跳过", recordSplit.getId(), typeName, prompt.getId());
|
|
continue;
|
|
}
|
|
try {
|
|
log.info("提交任务到线程池中进行三元组提取");
|
|
TripleExtractThread tripleExtractThread = new TripleExtractThread(chatClient, caseId, recordId, recordSplit.getId(), prompt, recordSplit.getQuestion(), recordSplit.getAnswer());
|
|
Future<TripleInfo> submit = TripleExtractThreadPool.chatExecutor.submit(tripleExtractThread);
|
|
futures.add(submit);
|
|
log.info("任务提交成功");
|
|
} catch (Exception e) {
|
|
log.error(e.getMessage(), e);
|
|
}
|
|
}
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
}
|
|
try {
|
|
log.info("休眠5秒,5秒之后再去查询三元组的结果,需要查询的任务数量为:{}", futures.size());
|
|
Thread.sleep(1000 * 5);
|
|
} catch (Exception e) {
|
|
log.error(e.getMessage(), e);
|
|
}
|
|
// 计数器
|
|
AtomicInteger atomicInteger = new AtomicInteger(0);
|
|
while (futures.size() > 0) {
|
|
Iterator<Future<TripleInfo>> iterator = futures.iterator();
|
|
while (iterator.hasNext()) {
|
|
Future<TripleInfo> future = iterator.next();
|
|
try {
|
|
// 如果提取到结果,且不为空,就进行保存
|
|
if (future.isDone()) {
|
|
TripleInfo tripleInfo = future.get();
|
|
if (tripleInfo != null) {
|
|
tripleInfos.add(tripleInfo);
|
|
}
|
|
iterator.remove();
|
|
}
|
|
} catch (Exception e) {
|
|
log.info("从线程中获取任务失败");
|
|
iterator.remove();
|
|
}
|
|
}
|
|
try {
|
|
int currentCount = atomicInteger.incrementAndGet();
|
|
if (currentCount > 1000) {
|
|
log.info("任务执行超时,遍历任务已执行:{}次,任务还剩余:{}个,不再继续执行", currentCount, futures.size());
|
|
// 将还在执行的线程中断
|
|
futures.forEach(future -> {
|
|
future.cancel(true);
|
|
});
|
|
break;
|
|
}
|
|
log.info("已检查{}遍,任务剩余{}个,休眠5s后继续检查", currentCount, futures.size());
|
|
Thread.sleep(1000 * 5);
|
|
} catch (Exception e) {
|
|
log.error(e.getMessage(), e);
|
|
}
|
|
}
|
|
// 如果有提取到三元组信息
|
|
if (CollUtil.isNotEmpty(tripleInfos)) {
|
|
// 首先清除现在已经提取过的三元组信息
|
|
tripleInfoService.lambdaUpdate().eq(TripleInfo::getRecordId, recordId).remove();
|
|
// TODO 这里,如果已经生成了图谱,怎么办?
|
|
// 首先要把这个笔录已经提取过的三元组记录删除掉,删除掉之后才可以重新提取
|
|
tripleInfoService.saveBatch(tripleInfos);
|
|
}
|
|
if (CollUtil.isNotEmpty(futures)) {
|
|
// 将任务标记为成功
|
|
caseTaskRecordService.lambdaUpdate().set(CaseTaskRecord::getStatus, 2).set(CaseTaskRecord::getFinishTime, LocalDateTime.now())
|
|
.eq(CaseTaskRecord::getType, 2).eq(CaseTaskRecord::getRecordId, recordId)
|
|
.eq(CaseTaskRecord::getCaseId, caseId).update();
|
|
} else {
|
|
// 否则标记为失败
|
|
caseTaskRecordService.lambdaUpdate().set(CaseTaskRecord::getStatus, 3).set(CaseTaskRecord::getFinishTime, LocalDateTime.now())
|
|
.eq(CaseTaskRecord::getType, 2).eq(CaseTaskRecord::getRecordId, recordId)
|
|
.eq(CaseTaskRecord::getCaseId, caseId).update();
|
|
}
|
|
log.info("三元组提取任务执行完毕,结束");
|
|
}
|
|
}
|