From e8f109b77c846946ff2f6bceee7bf39ea7d4e9ef Mon Sep 17 00:00:00 2001 From: liu Date: Thu, 1 Aug 2024 10:23:58 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BB=A3=E7=A0=81=E6=8F=90=E4=BA=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../service/ExtractTripleInfoService.java | 2 +- .../impl/ExtractTripleInfoServiceImpl.java | 95 +++++++++---------- .../impl/ModelRecordTypeServiceImpl.java | 1 + .../impl/NoteRecordSplitServiceImpl.java | 43 +++++++-- .../impl/RecordSplitTypeServiceImpl.java | 14 ++- src/main/resources/application-dev.yml | 8 +- 6 files changed, 101 insertions(+), 62 deletions(-) diff --git a/src/main/java/com/supervision/police/service/ExtractTripleInfoService.java b/src/main/java/com/supervision/police/service/ExtractTripleInfoService.java index d9805e3..334393a 100644 --- a/src/main/java/com/supervision/police/service/ExtractTripleInfoService.java +++ b/src/main/java/com/supervision/police/service/ExtractTripleInfoService.java @@ -2,5 +2,5 @@ package com.supervision.police.service; public interface ExtractTripleInfoService { - void extractTripleInfo(String caseId, String name, String recordId); + void extractTripleInfo(String caseId, String name, String recordSplitId); } diff --git a/src/main/java/com/supervision/police/service/impl/ExtractTripleInfoServiceImpl.java b/src/main/java/com/supervision/police/service/impl/ExtractTripleInfoServiceImpl.java index f7504de..69350aa 100644 --- a/src/main/java/com/supervision/police/service/impl/ExtractTripleInfoServiceImpl.java +++ b/src/main/java/com/supervision/police/service/impl/ExtractTripleInfoServiceImpl.java @@ -2,6 +2,7 @@ package com.supervision.police.service.impl; import cn.hutool.core.collection.CollUtil; import cn.hutool.core.util.StrUtil; +import cn.hutool.json.JSONUtil; import com.alibaba.druid.util.StringUtils; import com.supervision.police.domain.*; import com.supervision.police.mapper.NotePromptMapper; @@ -25,10 +26,7 @@ import org.springframework.transaction.annotation.Transactional; import org.springframework.util.StopWatch; import java.time.LocalDateTime; -import java.util.ArrayList; -import java.util.Iterator; -import java.util.List; -import java.util.Map; +import java.util.*; import java.util.concurrent.Future; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; @@ -56,10 +54,14 @@ public class ExtractTripleInfoServiceImpl implements ExtractTripleInfoService { @Async @Transactional(transactionManager = "dataSourceTransactionManager", rollbackFor = Exception.class) - public void extractTripleInfo(String caseId, String name, String recordId) { + public void extractTripleInfo(String caseId, String name, String recordSplitId) { // 首先获取所有切分后的笔录 - List recordSplitList = noteRecordSplitService.lambdaQuery().eq(StrUtil.isNotBlank(recordId), NoteRecordSplit::getNoteRecordId, recordId) - .eq(NoteRecordSplit::getCaseId, caseId).eq(NoteRecordSplit::getPersonName, name).list(); + Optional optById = noteRecordSplitService.getOptById(recordSplitId); + if (optById.isEmpty()) { + log.info("{} 切分笔录不存在,跳过", recordSplitId); + return; + } + NoteRecordSplit recordSplit = optById.get(); // 获取所有的分类 List allTypeList = modelRecordTypeService.list(); Map allTypeMap = allTypeList.stream().collect(Collectors.toMap(ModelRecordType::getRecordType, ModelRecordType::getId, (k1, k2) -> k1)); @@ -67,51 +69,49 @@ public class ExtractTripleInfoServiceImpl implements ExtractTripleInfoService { List tripleInfos = new ArrayList<>(); List> futures = new ArrayList<>(); // 对切分后的笔录进行遍历 - for (NoteRecordSplit recordSplit : recordSplitList) { - String recordType = recordSplit.getRecordType(); - if (StrUtil.isBlank(recordType)) { - log.info("{} 切分笔录不属于任何类型,跳过", recordSplit.getId()); - } - String[] split = recordType.split(";"); - for (String typeName : split) { - String typeId = allTypeMap.get(typeName); - if (StrUtil.isBlank(typeId)) { - log.info("{} 切分笔录类型:{}未找到,跳过", recordSplit.getId(), typeName); + String recordType = recordSplit.getRecordType(); + if (StrUtil.isBlank(recordType)) { + log.info("{} 切分笔录不属于任何类型,跳过", recordSplit.getId()); + } + String[] split = recordType.split(";"); + for (String typeName : split) { + String typeId = allTypeMap.get(typeName); + if (StrUtil.isBlank(typeId)) { + log.info("{} 切分笔录类型:{}未找到,跳过", recordSplit.getId(), typeName); + } else { + // 根据笔录类型找到所有的提取三元组的提示词 + // 一个提示词可能关联多个类型,要进行拆分操作 + List promptTypeRelList = notePromptTypeRelService.lambdaQuery().eq(NotePromptTypeRel::getTypeId, typeId).select(NotePromptTypeRel::getPromptId).list(); + if (CollUtil.isEmpty(promptTypeRelList)) { + log.info("{} 切分笔录类型:{}无对应的提示词,跳过", recordSplit.getId(), typeName); + continue; + } + List prompts = notePromptService.lambdaQuery() + .in(NotePrompt::getId, promptTypeRelList.stream().map(NotePromptTypeRel::getPromptId).collect(Collectors.toSet())) + .list(); + if (CollUtil.isEmpty(prompts)) { + log.info("{} 切分笔录类型:{}无对应的提示词,跳过", recordSplit.getId(), typeName); } else { - // 根据笔录类型找到所有的提取三元组的提示词 - // 一个提示词可能关联多个类型,要进行拆分操作 - List promptTypeRelList = notePromptTypeRelService.lambdaQuery().eq(NotePromptTypeRel::getTypeId, typeId).select(NotePromptTypeRel::getPromptId).list(); - if (CollUtil.isEmpty(promptTypeRelList)) { - log.info("{} 切分笔录类型:{}无对应的提示词,跳过", recordSplit.getId(), typeName); - continue; - } - List prompts = notePromptService.lambdaQuery() - .in(NotePrompt::getId, promptTypeRelList.stream().map(NotePromptTypeRel::getPromptId).collect(Collectors.toSet())) - .list(); - if (CollUtil.isEmpty(prompts)) { - log.info("{} 切分笔录类型:{}无对应的提示词,跳过", recordSplit.getId(), typeName); - } else { - // 遍历提示词进行提取 - for (NotePrompt prompt : prompts) { - if (StringUtils.isEmpty(prompt.getPrompt())) { - log.info("{} 切分笔录类型:{}对应的提示词:{} 提示词模板为空,跳过", recordSplit.getId(), typeName, prompt.getId()); - continue; - } - try { - log.info("提交任务到线程池中进行三元组提取"); - TripleExtractThread tripleExtractThread = new TripleExtractThread(chatClient, caseId, recordId, recordSplit.getId(), prompt, recordSplit.getQuestion(), recordSplit.getAnswer()); - Future submit = TripleExtractThreadPool.chatExecutor.submit(tripleExtractThread); - futures.add(submit); - log.info("任务提交成功"); - } catch (Exception e) { - log.error(e.getMessage(), e); - } + // 遍历提示词进行提取 + for (NotePrompt prompt : prompts) { + if (StringUtils.isEmpty(prompt.getPrompt())) { + log.info("{} 切分笔录类型:{}对应的提示词:{} 提示词模板为空,跳过", recordSplit.getId(), typeName, prompt.getId()); + continue; + } + try { + log.info("提交任务到线程池中进行三元组提取"); + TripleExtractThread tripleExtractThread = new TripleExtractThread(chatClient, caseId, recordSplit.getNoteRecordId(), recordSplit.getId(), prompt, recordSplit.getQuestion(), recordSplit.getAnswer()); + Future submit = TripleExtractThreadPool.chatExecutor.submit(tripleExtractThread); + futures.add(submit); + log.info("三元组提取任务提交成功"); + } catch (Exception e) { + log.error(e.getMessage(), e); } } - } } + } try { log.info("休眠5秒,5秒之后再去查询三元组的结果,需要查询的任务数量为:{}", futures.size()); @@ -157,10 +157,9 @@ public class ExtractTripleInfoServiceImpl implements ExtractTripleInfoService { } // 如果有提取到三元组信息 if (CollUtil.isNotEmpty(tripleInfos)) { - // 首先清除现在已经提取过的三元组信息 - tripleInfoService.lambdaUpdate().eq(TripleInfo::getRecordId, recordId).remove(); for (TripleInfo tripleInfo : tripleInfos) { tripleInfoService.save(tripleInfo); + log.info("保存三元组信息{}", JSONUtil.toJsonStr(tripleInfo)); } } log.info("三元组提取任务执行完毕,结束"); diff --git a/src/main/java/com/supervision/police/service/impl/ModelRecordTypeServiceImpl.java b/src/main/java/com/supervision/police/service/impl/ModelRecordTypeServiceImpl.java index c3d2b33..cd276fb 100644 --- a/src/main/java/com/supervision/police/service/impl/ModelRecordTypeServiceImpl.java +++ b/src/main/java/com/supervision/police/service/impl/ModelRecordTypeServiceImpl.java @@ -277,6 +277,7 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl allTypeList = modelRecordTypeService.lambdaQuery().list(); // 根据recordId查询所有的分割后的笔录 List list = noteRecordSplitService.lambdaQuery().eq(NoteRecordSplit::getNoteRecordId, recordId).list(); diff --git a/src/main/java/com/supervision/police/service/impl/NoteRecordSplitServiceImpl.java b/src/main/java/com/supervision/police/service/impl/NoteRecordSplitServiceImpl.java index 003cd68..4c95cab 100644 --- a/src/main/java/com/supervision/police/service/impl/NoteRecordSplitServiceImpl.java +++ b/src/main/java/com/supervision/police/service/impl/NoteRecordSplitServiceImpl.java @@ -1,6 +1,7 @@ package com.supervision.police.service.impl; import cn.hutool.core.collection.CollUtil; +import cn.hutool.core.util.StrUtil; import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import com.baomidou.mybatisplus.core.toolkit.Wrappers; import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; @@ -11,10 +12,8 @@ import com.supervision.config.BusinessException; import com.supervision.minio.domain.MinioFile; import com.supervision.minio.mapper.MinioFileMapper; import com.supervision.minio.service.MinioService; -import com.supervision.police.domain.CaseTaskRecord; -import com.supervision.police.domain.ModelRecordType; -import com.supervision.police.domain.NoteRecordSplit; -import com.supervision.police.domain.NoteRecord; +import com.supervision.neo4j.service.Neo4jService; +import com.supervision.police.domain.*; import com.supervision.police.dto.NoteRecordDTO; import com.supervision.police.dto.NoteRecordDetailDTO; import com.supervision.police.mapper.ModelCaseMapper; @@ -48,9 +47,14 @@ public class NoteRecordSplitServiceImpl extends ServiceImpl list = tripleInfoService.lambdaQuery().eq(TripleInfo::getRecordId, id).list(); + list.forEach(item -> { + // 如果已经入库,就删除已经入库的图 + if (StrUtil.equals("1", item.getAddNeo4j())) { + // 删除尾节点(不删除头节点),如果删除尾节点,关系会被自动删除 + try { + neo4jService.delNode(item.getEndNodeGraphId()); + } catch (Exception e) { + log.error("删除关系失败:{}", item.getRelation(), e); + } + + } + + }); + // TODO 是否需要把model_atomic_result的结果也删除了? } } diff --git a/src/main/java/com/supervision/police/service/impl/RecordSplitTypeServiceImpl.java b/src/main/java/com/supervision/police/service/impl/RecordSplitTypeServiceImpl.java index 6b00c5c..6d5a0c6 100644 --- a/src/main/java/com/supervision/police/service/impl/RecordSplitTypeServiceImpl.java +++ b/src/main/java/com/supervision/police/service/impl/RecordSplitTypeServiceImpl.java @@ -2,6 +2,7 @@ package com.supervision.police.service.impl; import cn.hutool.core.collection.ConcurrentHashSet; import cn.hutool.core.util.StrUtil; +import com.supervision.police.domain.CaseTaskRecord; import com.supervision.police.domain.ModelRecordType; import com.supervision.police.domain.NoteRecordSplit; import com.supervision.police.domain.TripleInfo; @@ -38,7 +39,7 @@ public class RecordSplitTypeServiceImpl implements RecordSplitTypeService { private final ExtractTripleInfoService extractTripleInfoService; - private final ConcurrentHashSet recordSplitIdSet = new ConcurrentHashSet(); + private final CaseTaskRecordService caseTaskRecordService; @Async @Override @@ -60,7 +61,8 @@ public class RecordSplitTypeServiceImpl implements RecordSplitTypeService { futures.add(afterTypeSplitIdFuture); log.info("分类任务线程池提交分类成功"); } - // 如果分类完成了,那么就去提取三元组 + log.info("----------{}-----------", "分类任务全部提交成功了"); + // 校验分类任务是否完成,如果分类完成,那么就去提取三元组 AtomicInteger atomicInteger = new AtomicInteger(0); while (futures.size() > 0) { Iterator> iterator = futures.iterator(); @@ -101,7 +103,13 @@ public class RecordSplitTypeServiceImpl implements RecordSplitTypeService { } } log.info("分类任务执行完毕"); - // 分类任务执行完成之后,就将任务进行更新 + Optional first = splitList.stream().findFirst(); + if (first.isPresent()) { + NoteRecordSplit recordSplit = first.get(); + // 分类任务执行完成之后,就将任务进行更新 + caseTaskRecordService.lambdaUpdate().set(CaseTaskRecord::getStatus, 2).eq(CaseTaskRecord::getCaseId, recordSplit.getCaseId()) + .eq(CaseTaskRecord::getRecordId, recordSplit.getNoteRecordId()).update(); + } } diff --git a/src/main/resources/application-dev.yml b/src/main/resources/application-dev.yml index 63243aa..3080e71 100644 --- a/src/main/resources/application-dev.yml +++ b/src/main/resources/application-dev.yml @@ -4,15 +4,15 @@ spring: ai: # 文档地址 https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/ollama-chat.html ollama: - base-url: http://113.128.242.110:11434 -# base-url: http://192.168.10.70:11434 +# base-url: http://113.128.242.110:11434 + base-url: http://192.168.10.70:11434 # base-url: http://124.220.94.55:8060 chat: enabled: true options: #model: qwen2:7b - #model: llama3-chinese:8b - model: qwen2:72b + model: llama3-chinese:8b +# model: qwen2:72b # 控制模型在请求后加载到内存中的时间(稍微长一点的时间,避免重复加载浪费性能,加快处理速度) keep_alive: 30m # 例如0.3