package com.supervision.police.service.impl; import cn.hutool.core.collection.CollUtil; import cn.hutool.core.util.StrUtil; import com.alibaba.druid.util.StringUtils; import com.supervision.police.domain.*; import com.supervision.police.mapper.NotePromptMapper; import com.supervision.police.mapper.NoteRecordSplitMapper; import com.supervision.police.mapper.TripleInfoMapper; import com.supervision.police.service.*; import com.supervision.thread.TripleExtractThread; import com.supervision.thread.TripleExtractThreadPool; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.json.JSONArray; import org.json.JSONObject; import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.ollama.OllamaChatClient; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.scheduling.annotation.Async; import org.springframework.stereotype.Service; import org.springframework.util.StopWatch; import java.time.LocalDateTime; import java.util.ArrayList; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.concurrent.Future; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; @Slf4j @Service @RequiredArgsConstructor public class ExtractTripleInfoServiceImpl implements ExtractTripleInfoService { private final CaseTaskRecordService caseTaskRecordService; private final ModelRecordTypeService modelRecordTypeService; private final NotePromptService notePromptService; private final TripleInfoService tripleInfoService; private final OllamaChatClient chatClient; @Autowired private NoteRecordSplitService noteRecordSplitService; @Async public void extractTripleInfo(String caseId, String name, String recordId) { // 首先获取所有切分后的笔录 List recordSplitList = noteRecordSplitService.lambdaQuery().eq(StrUtil.isNotBlank(recordId), NoteRecordSplit::getNoteRecordsId, recordId) .eq(NoteRecordSplit::getCaseId, caseId).eq(NoteRecordSplit::getPersonName, name).list(); // 获取所有的分类 List allTypeList = modelRecordTypeService.list(); Map allTypeMap = allTypeList.stream().collect(Collectors.toMap(ModelRecordType::getRecordType, ModelRecordType::getId, (k1, k2) -> k1)); List tripleInfos = new ArrayList<>(); List> futures = new ArrayList<>(); // 对切分后的笔录进行遍历 for (NoteRecordSplit recordSplit : recordSplitList) { String recordType = recordSplit.getRecordType(); if (StrUtil.isBlank(recordType)) { log.info("{} 切分笔录不属于任何类型,跳过", recordSplit.getId()); } String[] split = recordType.split(";"); for (String typeName : split) { String typeId = allTypeMap.get(typeName); if (StrUtil.isBlank(typeId)) { log.info("{} 切分笔录类型:{}未找到,跳过", recordSplit.getId(), typeName); } else { // 根据笔录类型找到所有的提取三元组的提示词 // 一个提示词可能关联多个类型,要进行拆分操作 List prompts = notePromptService.lambdaQuery().eq(NotePrompt::getTypeId, typeId).list(); if (CollUtil.isEmpty(prompts)) { log.info("{} 切分笔录类型:{}无对应的提示词,跳过", recordSplit.getId(), typeName); } else { // 遍历提示词进行提取 for (NotePrompt prompt : prompts) { if (StringUtils.isEmpty(prompt.getPrompt())) { log.info("{} 切分笔录类型:{}对应的提示词:{} 提示词模板为空,跳过", recordSplit.getId(), typeName, prompt.getId()); continue; } try { log.info("提交任务到线程池中进行三元组提取"); TripleExtractThread tripleExtractThread = new TripleExtractThread(chatClient, caseId, recordId, recordSplit.getId(), prompt, recordSplit.getQuestion(), recordSplit.getAnswer()); Future submit = TripleExtractThreadPool.chatExecutor.submit(tripleExtractThread); futures.add(submit); log.info("任务提交成功"); } catch (Exception e) { log.error(e.getMessage(), e); } } } } } } try { log.info("休眠5秒,5秒之后再去查询三元组的结果,需要查询的任务数量为:{}", futures.size()); Thread.sleep(1000 * 5); } catch (Exception e) { log.error(e.getMessage(), e); } // 计数器 AtomicInteger atomicInteger = new AtomicInteger(0); while (futures.size() > 0) { Iterator> iterator = futures.iterator(); while (iterator.hasNext()) { Future future = iterator.next(); try { // 如果提取到结果,且不为空,就进行保存 if (future.isDone()) { TripleInfo tripleInfo = future.get(); if (tripleInfo != null) { tripleInfos.add(tripleInfo); } iterator.remove(); } } catch (Exception e) { log.info("从线程中获取任务失败"); iterator.remove(); } } try { int currentCount = atomicInteger.incrementAndGet(); if (currentCount > 1000) { log.info("任务执行超时,遍历任务已执行:{}次,任务还剩余:{}个,不再继续执行", currentCount, futures.size()); // 将还在执行的线程中断 futures.forEach(future -> { future.cancel(true); }); break; } log.info("已检查{}遍,任务剩余{}个,休眠5s后继续检查", currentCount, futures.size()); Thread.sleep(1000 * 5); } catch (Exception e) { log.error(e.getMessage(), e); } } // 如果有提取到三元组信息 if (CollUtil.isNotEmpty(tripleInfos)) { // 首先清除现在已经提取过的三元组信息 tripleInfoService.lambdaUpdate().eq(TripleInfo::getRecordId, recordId).remove(); // TODO 这里,如果已经生成了图谱,怎么办? // 首先要把这个笔录已经提取过的三元组记录删除掉,删除掉之后才可以重新提取 tripleInfoService.saveBatch(tripleInfos); } if (CollUtil.isNotEmpty(futures)) { // 将任务标记为成功 caseTaskRecordService.lambdaUpdate().set(CaseTaskRecord::getStatus, 2).set(CaseTaskRecord::getFinishTime, LocalDateTime.now()) .eq(CaseTaskRecord::getType, 2).eq(CaseTaskRecord::getRecordId, recordId) .eq(CaseTaskRecord::getCaseId, caseId).update(); } else { // 否则标记为失败 caseTaskRecordService.lambdaUpdate().set(CaseTaskRecord::getStatus, 3).set(CaseTaskRecord::getFinishTime, LocalDateTime.now()) .eq(CaseTaskRecord::getType, 2).eq(CaseTaskRecord::getRecordId, recordId) .eq(CaseTaskRecord::getCaseId, caseId).update(); } log.info("三元组提取任务执行完毕,结束"); } }