fu-hsi-service/src/main/java/com/supervision/police/service/impl/ExtractTripleInfoServiceImp...

169 lines
8.0 KiB
Java

package com.supervision.police.service.impl;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.StrUtil;
import com.alibaba.druid.util.StringUtils;
import com.supervision.police.domain.*;
import com.supervision.police.mapper.NotePromptMapper;
import com.supervision.police.mapper.NoteRecordSplitMapper;
import com.supervision.police.mapper.TripleInfoMapper;
import com.supervision.police.service.*;
import com.supervision.thread.TripleExtractThread;
import com.supervision.thread.TripleExtractThreadPool;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.json.JSONArray;
import org.json.JSONObject;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
import org.springframework.util.StopWatch;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
@Slf4j
@Service
@RequiredArgsConstructor
public class ExtractTripleInfoServiceImpl implements ExtractTripleInfoService {

    /** Pause before the first poll and between two polls of the outstanding tasks, in milliseconds. */
    private static final long POLL_INTERVAL_MILLIS = 1000L * 5;

    /** Number of polling rounds after which the remaining tasks are cancelled as timed out. */
    private static final int MAX_POLL_ROUNDS = 1000;

    /** {@code CaseTaskRecord.type} value this service updates (the task row for this extraction). */
    private static final int TASK_TYPE_TRIPLE_EXTRACT = 2;

    /** {@code CaseTaskRecord.status} written when the extraction finished. */
    private static final int TASK_STATUS_SUCCESS = 2;

    /** {@code CaseTaskRecord.status} written when the extraction failed or timed out. */
    private static final int TASK_STATUS_FAILED = 3;

    private final CaseTaskRecordService caseTaskRecordService;
    private final ModelRecordTypeService modelRecordTypeService;
    private final NotePromptService notePromptService;
    private final TripleInfoService tripleInfoService;
    private final OllamaChatClient chatClient;
    @Autowired
    private NoteRecordSplitService noteRecordSplitService;

    /**
     * Extracts {@link TripleInfo} records from the split note records of one person in one case.
     *
     * <p>For every split record, each of its ";"-separated record types is resolved to a type id.
     * Every non-empty prompt linked to that type id is then submitted to
     * {@code TripleExtractThreadPool.chatExecutor} as a {@link TripleExtractThread}. The method
     * polls the submitted tasks until all are done or the poll budget is exhausted. If any triples
     * were produced, the triples previously stored for {@code recordId} are deleted and replaced.
     * Finally the matching {@link CaseTaskRecord} is marked successful (status 2) or failed
     * (status 3).
     *
     * <p>Annotated {@code @Async}: when Spring async support is enabled this runs off the
     * caller's thread and returns immediately.
     *
     * @param caseId   case whose split records are processed
     * @param name     person name the split records must belong to
     * @param recordId note record to restrict to; when blank, all split records of the
     *                 case/person are processed
     */
    @Async
    public void extractTripleInfo(String caseId, String name, String recordId) {
        List<Future<TripleInfo>> futures = submitExtractTasks(caseId, name, recordId);
        // Remember this before collectResults drains the list.
        boolean anySubmitted = !futures.isEmpty();

        List<TripleInfo> tripleInfos = new ArrayList<>();
        boolean allFinished = collectResults(futures, tripleInfos);

        // If any triples were extracted, replace what was stored before.
        if (CollUtil.isNotEmpty(tripleInfos)) {
            // TODO 这里,如果已经生成了图谱,怎么办? (what if a graph was already built from the old triples?)
            // The triples already extracted for this record must be removed before saving the new ones.
            // NOTE(review): when recordId is blank this deletes by a blank/null record id,
            // which most likely matches nothing — confirm the intended behaviour.
            tripleInfoService.lambdaUpdate().eq(TripleInfo::getRecordId, recordId).remove();
            tripleInfoService.saveBatch(tripleInfos);
        }

        // Success only when work was submitted and every task finished. The previous code
        // tested the drained futures list, which inverted success and timeout.
        int status = anySubmitted && allFinished ? TASK_STATUS_SUCCESS : TASK_STATUS_FAILED;
        updateTaskStatus(caseId, recordId, status);
        log.info("三元组提取任务执行完毕,结束");
    }

    /** Loads the split records and submits one extraction task per (split, type, prompt) combination. */
    private List<Future<TripleInfo>> submitExtractTasks(String caseId, String name, String recordId) {
        // First fetch all split note records.
        List<NoteRecordSplit> recordSplitList = noteRecordSplitService.lambdaQuery()
                .eq(StrUtil.isNotBlank(recordId), NoteRecordSplit::getNoteRecordsId, recordId)
                .eq(NoteRecordSplit::getCaseId, caseId)
                .eq(NoteRecordSplit::getPersonName, name)
                .list();
        // Map every record type name to its id (the first id wins on duplicate names).
        Map<String, String> allTypeMap = modelRecordTypeService.list().stream()
                .collect(Collectors.toMap(ModelRecordType::getRecordType, ModelRecordType::getId, (k1, k2) -> k1));
        // The prompts for a type do not depend on the split, so query them once per type id.
        Map<String, List<NotePrompt>> promptsByTypeId = new HashMap<>();

        List<Future<TripleInfo>> futures = new ArrayList<>();
        for (NoteRecordSplit recordSplit : recordSplitList) {
            String recordType = recordSplit.getRecordType();
            if (StrUtil.isBlank(recordType)) {
                log.info("{} 切分笔录不属于任何类型,跳过", recordSplit.getId());
                // Skipping is required here: falling through NPE'd on split() for a null type.
                continue;
            }
            for (String typeName : recordType.split(";")) {
                String typeId = allTypeMap.get(typeName);
                if (StrUtil.isBlank(typeId)) {
                    log.info("{} 切分笔录类型:{}未找到,跳过", recordSplit.getId(), typeName);
                    continue;
                }
                // All triple-extraction prompts linked to this record type.
                List<NotePrompt> prompts = promptsByTypeId.computeIfAbsent(typeId,
                        id -> notePromptService.lambdaQuery().eq(NotePrompt::getTypeId, id).list());
                if (CollUtil.isEmpty(prompts)) {
                    log.info("{} 切分笔录类型:{}无对应的提示词,跳过", recordSplit.getId(), typeName);
                    continue;
                }
                for (NotePrompt prompt : prompts) {
                    if (StrUtil.isEmpty(prompt.getPrompt())) {
                        log.info("{} 切分笔录类型:{}对应的提示词:{} 提示词模板为空,跳过", recordSplit.getId(), typeName, prompt.getId());
                        continue;
                    }
                    try {
                        log.info("提交任务到线程池中进行三元组提取");
                        // NOTE(review): recordId (possibly blank) is passed rather than
                        // recordSplit.getNoteRecordsId() — confirm this is intended.
                        TripleExtractThread tripleExtractThread = new TripleExtractThread(chatClient, caseId, recordId,
                                recordSplit.getId(), prompt, recordSplit.getQuestion(), recordSplit.getAnswer());
                        futures.add(TripleExtractThreadPool.chatExecutor.submit(tripleExtractThread));
                        log.info("任务提交成功");
                    } catch (Exception e) {
                        // Best effort: a failed submission must not prevent the remaining prompts.
                        log.error(e.getMessage(), e);
                    }
                }
            }
        }
        return futures;
    }

    /**
     * Polls {@code futures} until it is empty, moving every non-null result into {@code tripleInfos}.
     * Finished futures are removed from {@code futures}. Any that remain after a timeout or an
     * interrupt are cancelled and left in the list.
     *
     * @return {@code true} if every task reached a terminal state, {@code false} on timeout or interrupt
     */
    private boolean collectResults(List<Future<TripleInfo>> futures, List<TripleInfo> tripleInfos) {
        if (futures.isEmpty()) {
            return true;
        }
        log.info("休眠5秒,5秒之后再去查询三元组的结果,需要查询的任务数量为:{}", futures.size());
        if (!pause(POLL_INTERVAL_MILLIS)) {
            cancelAll(futures);
            return false;
        }
        int round = 0;
        while (true) {
            drainFinished(futures, tripleInfos);
            if (futures.isEmpty()) {
                return true;
            }
            round++;
            if (round > MAX_POLL_ROUNDS) {
                log.info("任务执行超时,遍历任务已执行:{}次,任务还剩余:{}个,不再继续执行", round, futures.size());
                // Interrupt the tasks that are still running.
                cancelAll(futures);
                return false;
            }
            log.info("已检查{}遍,任务剩余{}个,休眠5s后继续检查", round, futures.size());
            if (!pause(POLL_INTERVAL_MILLIS)) {
                cancelAll(futures);
                return false;
            }
        }
    }

    /** Removes every finished future from {@code futures}, keeping its result if it is non-null. */
    private static void drainFinished(List<Future<TripleInfo>> futures, List<TripleInfo> tripleInfos) {
        Iterator<Future<TripleInfo>> iterator = futures.iterator();
        while (iterator.hasNext()) {
            Future<TripleInfo> future = iterator.next();
            if (!future.isDone()) {
                continue;
            }
            iterator.remove();
            try {
                TripleInfo tripleInfo = future.get();
                if (tripleInfo != null) {
                    tripleInfos.add(tripleInfo);
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                log.error("从线程中获取任务失败", e);
            } catch (ExecutionException | CancellationException e) {
                log.error("从线程中获取任务失败", e);
            }
        }
    }

    /**
     * Sleeps for {@code millis} milliseconds.
     *
     * @return {@code false} if the thread was interrupted (the interrupt flag is restored)
     */
    private static boolean pause(long millis) {
        try {
            Thread.sleep(millis);
            return true;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            log.warn("等待三元组提取结果时线程被中断,取消剩余任务", e);
            return false;
        }
    }

    /** Requests cancellation (with interruption) of every future in the list. */
    private static void cancelAll(List<Future<TripleInfo>> futures) {
        futures.forEach(future -> future.cancel(true));
    }

    /** Sets status and finish time on the triple-extraction task row of this case/record. */
    private void updateTaskStatus(String caseId, String recordId, int status) {
        caseTaskRecordService.lambdaUpdate()
                .set(CaseTaskRecord::getStatus, status)
                .set(CaseTaskRecord::getFinishTime, LocalDateTime.now())
                .eq(CaseTaskRecord::getType, TASK_TYPE_TRIPLE_EXTRACT)
                .eq(CaseTaskRecord::getRecordId, recordId)
                .eq(CaseTaskRecord::getCaseId, caseId)
                .update();
    }
}