From 0ad8d0b73a25a07cb5c618f4555f1b575593eb50 Mon Sep 17 00:00:00 2001 From: liu Date: Tue, 23 Jul 2024 16:50:04 +0800 Subject: [PATCH] =?UTF-8?q?=E6=8F=90=E4=BA=A4=E7=AC=94=E5=BD=95=E6=8B=86?= =?UTF-8?q?=E5=88=86=E5=88=86=E7=B1=BB=E7=9A=84=E7=9B=B8=E5=85=B3=E4=BB=A3?= =?UTF-8?q?=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pom.xml | 6 + .../police/domain/CaseTaskRecord.java | 10 ++ .../police/domain/ModelRecordType.java | 5 + .../supervision/police/domain/NotePrompt.java | 4 + .../police/domain/NoteRecordSplit.java | 6 - .../police/mapper/NoteRecordSplitMapper.java | 6 +- .../service/RecordSplitTypeService.java | 12 ++ .../impl/ExtractTripleInfoServiceImpl.java | 116 +++++++++++--- .../impl/ModelRecordTypeServiceImpl.java | 15 +- .../impl/NoteRecordSplitServiceImpl.java | 70 ++------- .../impl/RecordSplitTypeServiceImpl.java | 47 ++++++ .../thread/RecordSplitTypeThread.java | 148 ++++++++++++++++++ .../thread/RecordSplitTypeThreadPool.java | 13 ++ .../thread/TripleExtractThread.java | 47 ++++-- .../thread/TripleExtractThreadPool.java | 2 +- src/main/resources/application.yml | 1 + .../mapper/NoteRecordSplitMapper.xml | 18 +-- 17 files changed, 406 insertions(+), 120 deletions(-) create mode 100644 src/main/java/com/supervision/police/service/RecordSplitTypeService.java create mode 100644 src/main/java/com/supervision/police/service/impl/RecordSplitTypeServiceImpl.java create mode 100644 src/main/java/com/supervision/thread/RecordSplitTypeThread.java create mode 100644 src/main/java/com/supervision/thread/RecordSplitTypeThreadPool.java diff --git a/pom.xml b/pom.xml index 0037d8e..bd0e1c2 100644 --- a/pom.xml +++ b/pom.xml @@ -40,6 +40,12 @@ spring-ai-ollama-spring-boot-starter + + + org.apache.httpcomponents.client5 + httpclient5 + + cn.hutool hutool-all diff --git a/src/main/java/com/supervision/police/domain/CaseTaskRecord.java b/src/main/java/com/supervision/police/domain/CaseTaskRecord.java index ff1e4cd..d55439c 100644 --- a/src/main/java/com/supervision/police/domain/CaseTaskRecord.java +++ b/src/main/java/com/supervision/police/domain/CaseTaskRecord.java @@ -20,6 +20,11 @@ public class CaseTaskRecord implements Serializable { @TableId private String id; + /** + * 类型 1笔录分类 2提取三元组 + */ + private Integer type; + /** * 案件ID */ @@ -40,6 +45,11 @@ public class CaseTaskRecord implements Serializable { */ private LocalDateTime submitTime; + /** + * 完成时间 + */ + private LocalDateTime finishTime; + @TableField(exist = false) private static final long serialVersionUID = 1L; diff --git a/src/main/java/com/supervision/police/domain/ModelRecordType.java b/src/main/java/com/supervision/police/domain/ModelRecordType.java index 5701c40..795b943 100644 --- a/src/main/java/com/supervision/police/domain/ModelRecordType.java +++ b/src/main/java/com/supervision/police/domain/ModelRecordType.java @@ -27,6 +27,11 @@ public class ModelRecordType implements Serializable { */ private String recordType; + /** + * 区别点 + */ + private String recordTypeExt; + /** * 提示词 */ diff --git a/src/main/java/com/supervision/police/domain/NotePrompt.java b/src/main/java/com/supervision/police/domain/NotePrompt.java index 1edda19..922a54c 100644 --- a/src/main/java/com/supervision/police/domain/NotePrompt.java +++ b/src/main/java/com/supervision/police/domain/NotePrompt.java @@ -30,6 +30,10 @@ public class NotePrompt implements Serializable { */ private String prompt; + private String startEntityType; + + private String endEntityType; + /** * 创建人ID */ diff --git a/src/main/java/com/supervision/police/domain/NoteRecordSplit.java b/src/main/java/com/supervision/police/domain/NoteRecordSplit.java index 27461f1..2818312 100644 --- a/src/main/java/com/supervision/police/domain/NoteRecordSplit.java +++ b/src/main/java/com/supervision/police/domain/NoteRecordSplit.java @@ -53,12 +53,6 @@ public class NoteRecordSplit implements Serializable { */ private String recordType; - /** - * 笔录类型id - */ - @TableField(exist = false) - private String recordTypeId; - /** * 完整笔录id */ diff --git a/src/main/java/com/supervision/police/mapper/NoteRecordSplitMapper.java b/src/main/java/com/supervision/police/mapper/NoteRecordSplitMapper.java index ccf0a22..5c4202f 100644 --- a/src/main/java/com/supervision/police/mapper/NoteRecordSplitMapper.java +++ b/src/main/java/com/supervision/police/mapper/NoteRecordSplitMapper.java @@ -10,8 +10,8 @@ public interface NoteRecordSplitMapper extends BaseMapper { List selectByRecordType(@Param("recordType") String recordType); - List selectRecord(@Param("caseId") String caseId, - @Param("name") String name, - @Param("recordId") String recordId); +// List selectRecord(@Param("caseId") String caseId, +// @Param("name") String name, +// @Param("recordId") String recordId); } diff --git a/src/main/java/com/supervision/police/service/RecordSplitTypeService.java b/src/main/java/com/supervision/police/service/RecordSplitTypeService.java new file mode 100644 index 0000000..cf5bd13 --- /dev/null +++ b/src/main/java/com/supervision/police/service/RecordSplitTypeService.java @@ -0,0 +1,12 @@ +package com.supervision.police.service; + +import com.supervision.police.domain.ModelRecordType; +import com.supervision.police.domain.NoteRecordSplit; +import com.supervision.springaidemo.dto.QARecordNodeDTO; + +import java.util.List; + +public interface RecordSplitTypeService { + + void type(List allTypeList, QARecordNodeDTO qa, NoteRecordSplit noteRecord); +} diff --git a/src/main/java/com/supervision/police/service/impl/ExtractTripleInfoServiceImpl.java b/src/main/java/com/supervision/police/service/impl/ExtractTripleInfoServiceImpl.java index db021af..088a7eb 100644 --- a/src/main/java/com/supervision/police/service/impl/ExtractTripleInfoServiceImpl.java +++ b/src/main/java/com/supervision/police/service/impl/ExtractTripleInfoServiceImpl.java @@ -1,10 +1,9 @@ package com.supervision.police.service.impl; +import cn.hutool.core.collection.CollUtil; import cn.hutool.core.util.StrUtil; import com.alibaba.druid.util.StringUtils; -import com.supervision.police.domain.NotePrompt; -import com.supervision.police.domain.NoteRecordSplit; -import com.supervision.police.domain.TripleInfo; +import com.supervision.police.domain.*; import com.supervision.police.mapper.NotePromptMapper; import com.supervision.police.mapper.NoteRecordSplitMapper; import com.supervision.police.mapper.TripleInfoMapper; @@ -19,6 +18,7 @@ import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.ollama.OllamaChatClient; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.scheduling.annotation.Async; import org.springframework.stereotype.Service; import org.springframework.util.StopWatch; @@ -27,14 +27,19 @@ import java.time.LocalDateTime; import java.util.ArrayList; import java.util.Iterator; import java.util.List; +import java.util.Map; import java.util.concurrent.Future; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; @Slf4j @Service @RequiredArgsConstructor public class ExtractTripleInfoServiceImpl implements ExtractTripleInfoService { - private final NoteRecordSplitMapper noteRecordSplitMapper; + private final CaseTaskRecordService caseTaskRecordService; + + private final ModelRecordTypeService modelRecordTypeService; private final NotePromptService notePromptService; @@ -42,31 +47,69 @@ public class ExtractTripleInfoServiceImpl implements ExtractTripleInfoService { private final OllamaChatClient chatClient; + @Autowired + private NoteRecordSplitService noteRecordSplitService; + @Async public void extractTripleInfo(String caseId, String name, String recordId) { // 首先获取所有切分后的笔录 - List recordSplitList = noteRecordSplitMapper.selectRecord(caseId, name, recordId); + List recordSplitList = noteRecordSplitService.lambdaQuery().eq(StrUtil.isNotBlank(recordId), NoteRecordSplit::getNoteRecordsId, recordId) + .eq(NoteRecordSplit::getCaseId, caseId).eq(NoteRecordSplit::getPersonName, name).list(); + // 获取所有的分类 + List allTypeList = modelRecordTypeService.list(); + Map allTypeMap = allTypeList.stream().collect(Collectors.toMap(ModelRecordType::getRecordType, ModelRecordType::getId, (k1, k2) -> k1)); + List tripleInfos = new ArrayList<>(); List> futures = new ArrayList<>(); // 对切分后的笔录进行遍历 for (NoteRecordSplit recordSplit : recordSplitList) { - // 根据笔录类型找到所有的提取三元组的提示词 - List prompts = notePromptService.lambdaQuery().eq(NotePrompt::getTypeId, recordSplit.getRecordTypeId()).list(); - // 遍历提示词进行提取 - for (NotePrompt prompt : prompts) { - if (StringUtils.isEmpty(prompt.getPrompt())) { - continue; - } - try { - log.info("提交任务到线程池中进行三元组提取"); - Future submit = TripleExtractThreadPool.chatExecutor.submit(new TripleExtractThread(chatClient, caseId, recordId, recordSplit.getId(), prompt.getPrompt(), recordSplit.getQuestion(), recordSplit.getAnswer())); - futures.add(submit); - } catch (Exception e) { - log.error(e.getMessage(), e); + String recordType = recordSplit.getRecordType(); + if (StrUtil.isBlank(recordType)) { + log.info("{} 切分笔录不属于任何类型,跳过", recordSplit.getId()); + } + String[] split = recordType.split(";"); + for (String typeName : split) { + String typeId = allTypeMap.get(typeName); + if (StrUtil.isBlank(typeId)) { + log.info("{} 切分笔录类型:{}未找到,跳过", recordSplit.getId(), typeName); + } else { + // 根据笔录类型找到所有的提取三元组的提示词 + // 一个提示词可能关联多个类型,要进行拆分操作 + List prompts = notePromptService.lambdaQuery().eq(NotePrompt::getTypeId, typeId).list(); + if (CollUtil.isEmpty(prompts)) { + log.info("{} 切分笔录类型:{}无对应的提示词,跳过", recordSplit.getId(), typeName); + } else { + // 遍历提示词进行提取 + for (NotePrompt prompt : prompts) { + if (StringUtils.isEmpty(prompt.getPrompt())) { + log.info("{} 切分笔录类型:{}对应的提示词:{} 提示词模板为空,跳过", recordSplit.getId(), typeName, prompt.getId()); + continue; + } + try { + log.info("提交任务到线程池中进行三元组提取"); + TripleExtractThread tripleExtractThread = new TripleExtractThread(chatClient, caseId, recordId, recordSplit.getId(), prompt, recordSplit.getQuestion(), recordSplit.getAnswer()); + Future submit = TripleExtractThreadPool.chatExecutor.submit(tripleExtractThread); + futures.add(submit); + log.info("任务提交成功"); + } catch (Exception e) { + log.error(e.getMessage(), e); + } + } + } + } + } } + try { + log.info("休眠5秒,5秒之后再去查询三元组的结果,需要查询的任务数量为:{}", futures.size()); + Thread.sleep(1000 * 5); + } catch (Exception e) { + log.error(e.getMessage(), e); + } + // 计数器 + AtomicInteger atomicInteger = new AtomicInteger(0); while (futures.size() > 0) { Iterator> iterator = futures.iterator(); while (iterator.hasNext()) { @@ -86,15 +129,40 @@ public class ExtractTripleInfoServiceImpl implements ExtractTripleInfoService { } } try { - log.info("检查一遍,休眠1s后继续检查"); - Thread.sleep(1000); + int currentCount = atomicInteger.incrementAndGet(); + if (currentCount > 1000) { + log.info("任务执行超时,遍历任务已执行:{}次,任务还剩余:{}个,不再继续执行", currentCount, futures.size()); + // 将还在执行的线程中断 + futures.forEach(future -> { + future.cancel(true); + }); + break; + } + log.info("已检查{}遍,任务剩余{}个,休眠5s后继续检查", currentCount, futures.size()); + Thread.sleep(1000 * 5); } catch (Exception e) { log.error(e.getMessage(), e); } } - // 首先清除 - tripleInfoService.lambdaUpdate().eq(TripleInfo::getRecordId, recordId).remove(); - // 首先要把这个笔录已经提取过的三元组记录删除掉,删除掉之后才可以重新提取 - tripleInfoService.saveBatch(tripleInfos); + // 如果有提取到三元组信息 + if (CollUtil.isNotEmpty(tripleInfos)) { + // 首先清除现在已经提取过的三元组信息 + tripleInfoService.lambdaUpdate().eq(TripleInfo::getRecordId, recordId).remove(); + // TODO 这里,如果已经生成了图谱,怎么办? + // 首先要把这个笔录已经提取过的三元组记录删除掉,删除掉之后才可以重新提取 + tripleInfoService.saveBatch(tripleInfos); + } + if (CollUtil.isNotEmpty(futures)) { + // 将任务标记为成功 + caseTaskRecordService.lambdaUpdate().set(CaseTaskRecord::getStatus, 2).set(CaseTaskRecord::getFinishTime, LocalDateTime.now()) + .eq(CaseTaskRecord::getType, 2).eq(CaseTaskRecord::getRecordId, recordId) + .eq(CaseTaskRecord::getCaseId, caseId).update(); + } else { + // 否则标记为失败 + caseTaskRecordService.lambdaUpdate().set(CaseTaskRecord::getStatus, 3).set(CaseTaskRecord::getFinishTime, LocalDateTime.now()) + .eq(CaseTaskRecord::getType, 2).eq(CaseTaskRecord::getRecordId, recordId) + .eq(CaseTaskRecord::getCaseId, caseId).update(); + } + log.info("三元组提取任务执行完毕,结束"); } } diff --git a/src/main/java/com/supervision/police/service/impl/ModelRecordTypeServiceImpl.java b/src/main/java/com/supervision/police/service/impl/ModelRecordTypeServiceImpl.java index cb577a6..d9e063e 100644 --- a/src/main/java/com/supervision/police/service/impl/ModelRecordTypeServiceImpl.java +++ b/src/main/java/com/supervision/police/service/impl/ModelRecordTypeServiceImpl.java @@ -18,6 +18,7 @@ import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.ollama.OllamaChatClient; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; import org.springframework.util.StopWatch; @@ -44,7 +45,8 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl queryType(String name, Integer page, Integer size) { @@ -114,11 +116,11 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl getThreeInfo(String caseId, String name, String recordId) { - //boolean taskStatus = taskExtractStatusCheck(caseId, recordId); + boolean taskStatus = taskExtractStatusCheck(caseId, recordId); // 如果校验结果为false,则说明需要进行提取三元组操作 - //if (!taskStatus) { - // extractTripleInfo.extractTripleInfo(caseId, name, recordId); - //} + if (!taskStatus) { + extractTripleInfo.extractTripleInfo(caseId, name, recordId); + } // 这里进行查询 return tripleInfoService.lambdaQuery().eq(TripleInfo::getRecordId, recordId).list(); } @@ -128,9 +130,10 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl caseTaskRecordOpt = caseTaskRecordService.lambdaQuery().eq(CaseTaskRecord::getCaseId, caseId).eq(CaseTaskRecord::getRecordId, recordId).oneOpt(); + Optional caseTaskRecordOpt = caseTaskRecordService.lambdaQuery().eq(CaseTaskRecord::getType, 2).eq(CaseTaskRecord::getCaseId, caseId).eq(CaseTaskRecord::getRecordId, recordId).oneOpt(); if (caseTaskRecordOpt.isEmpty()) { CaseTaskRecord newCaseTaskRecord = new CaseTaskRecord(); + newCaseTaskRecord.setType(2); newCaseTaskRecord.setCaseId(caseId); newCaseTaskRecord.setRecordId(recordId); newCaseTaskRecord.setStatus(1); diff --git a/src/main/java/com/supervision/police/service/impl/NoteRecordSplitServiceImpl.java b/src/main/java/com/supervision/police/service/impl/NoteRecordSplitServiceImpl.java index 013aa1f..287e63e 100644 --- a/src/main/java/com/supervision/police/service/impl/NoteRecordSplitServiceImpl.java +++ b/src/main/java/com/supervision/police/service/impl/NoteRecordSplitServiceImpl.java @@ -12,16 +12,21 @@ import com.supervision.minio.domain.MinioFile; import com.supervision.minio.mapper.MinioFileMapper; import com.supervision.minio.service.MinioService; import com.supervision.police.domain.ModelCase; +import com.supervision.police.domain.ModelRecordType; import com.supervision.police.domain.NoteRecordSplit; import com.supervision.police.domain.NoteRecord; import com.supervision.police.mapper.ModelCaseMapper; import com.supervision.police.mapper.ModelRecordTypeMapper; import com.supervision.police.mapper.NoteRecordSplitMapper; import com.supervision.police.mapper.NoteRecordMapper; +import com.supervision.police.service.ModelRecordTypeService; import com.supervision.police.service.NoteRecordSplitService; +import com.supervision.police.service.RecordSplitTypeService; import com.supervision.springaidemo.dto.QARecordNodeDTO; import com.supervision.springaidemo.util.RecordRegexUtil; import com.supervision.springaidemo.util.WordReadUtil; +import com.supervision.thread.RecordSplitTypeThread; +import com.supervision.thread.RecordSplitTypeThreadPool; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.json.JSONObject; @@ -29,7 +34,9 @@ import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.ollama.OllamaChatClient; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; import org.springframework.util.StopWatch; import org.springframework.web.multipart.MultipartFile; @@ -43,8 +50,6 @@ import java.util.stream.Collectors; @RequiredArgsConstructor public class NoteRecordSplitServiceImpl extends ServiceImpl implements NoteRecordSplitService { - private final NoteRecordSplitMapper noteRecordSplitMapper; - private final NoteRecordMapper noteRecordMapper; private final MinioService minioService; @@ -53,25 +58,15 @@ public class NoteRecordSplitServiceImpl extends ServiceImpl fileList) throws IOException { //上传文件,获取文件ids List fileIds = new ArrayList<>(); @@ -110,8 +105,7 @@ public class NoteRecordSplitServiceImpl extends ServiceImpl allTypes = modelRecordTypeMapper.getAllType(); - + List allTypeList = modelRecordTypeService.lambdaQuery().list(); if (i > 0) { //拆分笔录 for (MultipartFile file : fileList) { @@ -127,44 +121,12 @@ public class NoteRecordSplitServiceImpl extends ServiceImpl paramMap = new HashMap<>(); - paramMap.put("allTypes", CollUtil.join(allTypes, ";")); - paramMap.put("question", qa.getQuestion()); - paramMap.put("answer", qa.getAnswer()); - Prompt prompt = new Prompt(new UserMessage(StrUtil.format(TYPE_TEMPLATE, paramMap))); - StopWatch stopWatch = new StopWatch(); - stopWatch.start(); - log.info("开始分析:"); - ChatResponse call = chatClient.call(prompt); - stopWatch.stop(); - log.info("耗时:{}", stopWatch.getTotalTimeSeconds()); - String content = call.getResult().getOutput().getContent(); - log.info("问:{}, 答:{}", qa.getQuestion(), qa.getAnswer()); - log.info("分析的结果是:{}", content); - JSONObject jsonObject = new JSONObject(content); - String type = jsonObject.getString("type").trim(); - System.out.println("问:" + qa.getQuestion() + "答:" + qa.getAnswer()); - System.out.println("分析的结果是:" + type); -/* // todo 写死测试 - String type = ""; - if (qa.getQuestion().contains("你为了骗取更多的钱都做了哪些准备")) { - type = "诈骗准备"; - } else { - continue; - }*/ - //保存笔录 - noteRecord.setRecordType(type); - noteRecordSplitMapper.insert(noteRecord); + this.save(noteRecord); + // 通过异步的形式提交分类 + recordSplitTypeService.type(allTypeList, qa, noteRecord); } catch (Exception e) { log.error(e.getMessage(), e); } -// ModelRecordType exist = modelRecordTypeMapper.queryByName(type); -// if (exist == null) { -// ModelRecordType modelRecordType = new ModelRecordType(); -// modelRecordType.setRecordType(type); -// modelRecordTypeMapper.insert(modelRecordType); -// } } } return "保存成功"; diff --git a/src/main/java/com/supervision/police/service/impl/RecordSplitTypeServiceImpl.java b/src/main/java/com/supervision/police/service/impl/RecordSplitTypeServiceImpl.java new file mode 100644 index 0000000..6ae390d --- /dev/null +++ b/src/main/java/com/supervision/police/service/impl/RecordSplitTypeServiceImpl.java @@ -0,0 +1,47 @@ +package com.supervision.police.service.impl; + +import com.supervision.police.domain.ModelRecordType; +import com.supervision.police.domain.NoteRecordSplit; +import com.supervision.police.service.CaseTaskRecordService; +import com.supervision.police.service.NoteRecordSplitService; +import com.supervision.police.service.RecordSplitTypeService; +import com.supervision.springaidemo.dto.QARecordNodeDTO; +import com.supervision.thread.RecordSplitTypeThread; +import com.supervision.thread.RecordSplitTypeThreadPool; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.ai.ollama.OllamaChatClient; +import org.springframework.scheduling.annotation.Async; +import org.springframework.stereotype.Service; + +import java.util.List; + +@Service +@Slf4j +@RequiredArgsConstructor +public class RecordSplitTypeServiceImpl implements RecordSplitTypeService { + + private final OllamaChatClient chatClient; + + private final NoteRecordSplitService noteRecordSplitService; + + @Async + @Override + public void type(List allTypeList, QARecordNodeDTO qa, NoteRecordSplit noteRecord){ + // 这里线程休眠1秒,因为首先报保证消息记录能够插入完成,插入完成之后,再去提交大模型,让大模型去分类.防止分类太快,分类结果出来了,插入还没有插入完成 + try { + Thread.sleep(1000); + }catch (Exception e){ + log.error("线程休眠失败"); + } + // 首先创建一个提取任务 + + // 进行分类 + log.info("提交线程池进行分类"); + RecordSplitTypeThread recordSplitTypeThread = new RecordSplitTypeThread(allTypeList, qa, chatClient, noteRecordSplitService, noteRecord); + RecordSplitTypeThreadPool.recordSplitTypeExecutor.submit(recordSplitTypeThread); + log.info("线程池提交分类成功"); + // 这里应该对分类任务的执行过程进行监控,分类结束之后,才能提取三元组的关系.问了产品,暂时先不做,等后面在考虑 + + } +} diff --git a/src/main/java/com/supervision/thread/RecordSplitTypeThread.java b/src/main/java/com/supervision/thread/RecordSplitTypeThread.java new file mode 100644 index 0000000..9706fff --- /dev/null +++ b/src/main/java/com/supervision/thread/RecordSplitTypeThread.java @@ -0,0 +1,148 @@ +package com.supervision.thread; + +import cn.hutool.core.collection.CollUtil; +import cn.hutool.core.thread.ThreadUtil; +import cn.hutool.core.util.StrUtil; +import cn.hutool.json.JSONUtil; +import com.supervision.police.domain.ModelRecordType; +import com.supervision.police.domain.NoteRecordSplit; +import com.supervision.police.service.NoteRecordSplitService; +import com.supervision.springaidemo.dto.QARecordNodeDTO; +import lombok.Data; +import lombok.extern.slf4j.Slf4j; +import org.json.JSONObject; +import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.ollama.OllamaChatClient; +import org.springframework.util.StopWatch; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutorService; +import java.util.stream.Collectors; + +/** + * 笔录分类线程 + */ +@Slf4j +public class RecordSplitTypeThread implements Callable { + + + private final List allTypeList; + + private final QARecordNodeDTO qa; + + private final OllamaChatClient chatClient; + + private final NoteRecordSplitService noteRecordSplitService; + + private final NoteRecordSplit noteRecord; + + + public RecordSplitTypeThread(List allTypeList, QARecordNodeDTO qa, OllamaChatClient chatClient, NoteRecordSplitService noteRecordSplitService, NoteRecordSplit noteRecord) { + this.allTypeList = allTypeList; + this.qa = qa; + this.chatClient = chatClient; + this.noteRecordSplitService = noteRecordSplitService; + this.noteRecord = noteRecord; + } + + private static final String TYPE_TEMPLATE = """ + 分类任务: 对话笔记录文本分类。目标: 将给定的对话笔记录分配到预定义的类别中,这些类别包括但不限于:{allTypes}。" + 说明: 提供一段对话笔记录文本,分类器应该识别出对话的主题,并将其归类到上述类别中的一个。 + 示例输入: 文本: + 办案警官问:你为了骗取更多的钱都做了哪些准备?。裴金禄回答:我刚开始我就是自己想了一些关于骗钱的点子,后面为了更不容易让别人识破我为了更佳逼真,我就从网上随便搜了一家租赁公司,我就搜到了兰州胜利机械租赁有限公司,我又想到了我管理的中铁北京局和中铁电气化局施工公司。我先是通过百度搜索了“办证”之后就在网页上面弹出了一个页面上面有一个QQ号,我就加上了。加上之后我就将我的要求给他说了,要求他给我刻两个假的公章,一个是兰州胜利机械租赁有限公司合同专用章,另一个是中铁北京局集团有限公司合同专用章。我还要求他给我伪造了一张兰州胜利机械租赁有限公司的营业执照 + 预期输出: { type: '诈骗准备' } + 任务要求: + 1. 分类器应当准确地识别对话的主题。 + 2. 如果一段对话笔记录包含多个主题,请选择最相关的类别。 + 3. 必须考虑上下文语境和专业术语来确定正确的分类。 + 对话内容为:{question} {answer} + """; + + private static final String NEW_TEMPLATE = """ + 分类任务: 将对话笔录文本进行分类。 + 目标: 将给定的对话分配到预定义的类别中 + 预定义类别为:{typeContext}。 + 说明: 提供一段对话笔记录文本,分类器应该识别出对话的主题,并将其归类到上述类别中的一个或多个。 + --- + 示例输入: + 办案警官问:你和上述的这些受害人签订协议之后是否实际履行合同? + 行为人XXX回答:我和他们签订合同就是为了骗他们相信我,我就是伪造的一些假合同,等我把他们的钱骗到手之后我就不会履行合同。 + 预期输出: {"result":[{"type":"合同和协议","explain":"行为人XXX提到签订合同"},{"type":"虚假信息和伪造","explain":"行为人XXX提到合同是假合同"}]} + --- + 任务要求: + 1. 分类器应当准确地识别对话的主题,分类来自于预定义的类别。 + 2. 分类器应该实事求是按照对话进行分类,不要有过多的推测。 + 2. 如果一段对话笔记录包含多个主题,请选择最相关的类别,最多可选择三个分类。 + 3. 如果不涉及任何分类则回复{"result":[]} + --- + 以下为问答对内容: + {question} + {answer} + --- + 返回格式为json,字段名要严格一致:{"result":[{"type":"分类1","explain":"分类原因"},{"type":"分类2","explain":"分类原因"}]} + """; + + private static final String TYPE_CONTEXT_TEMPLATE = "{分类type:{type},区别点(分类释义):{typeExt}}"; + + @Override + public Boolean call() throws Exception { + String type; + try { + StopWatch stopWatch = new StopWatch(); + // 首先拼接分类模板 + List typeContextList = new ArrayList<>(); + for (ModelRecordType modelRecordType : allTypeList) { + String format = StrUtil.format(TYPE_CONTEXT_TEMPLATE, Map.of("type", modelRecordType.getRecordType(), "typeExt", modelRecordType.getRecordTypeExt())); + typeContextList.add(format); + } + // 开始对笔录进行分类 + Map paramMap = new HashMap<>(); + paramMap.put("typeContext", CollUtil.join(typeContextList, ";")); + paramMap.put("question", qa.getQuestion()); + paramMap.put("answer", qa.getAnswer()); + Prompt prompt = new Prompt(new UserMessage(StrUtil.format(NEW_TEMPLATE, paramMap))); + stopWatch.start(); + log.info("开始分析:"); + ChatResponse call = chatClient.call(prompt); + stopWatch.stop(); + log.info("耗时:{}", stopWatch.getTotalTimeSeconds()); + String content = call.getResult().getOutput().getContent(); + log.info("问:{}, 答:{}", qa.getQuestion(), qa.getAnswer()); + log.info("分析的结果是:{}", content); + TypeResultDTO result = JSONUtil.toBean(content, TypeResultDTO.class); + List typeList = result.getResult(); + + if (CollUtil.isNotEmpty(typeList)) { + // 将type进行拼接,并以分号进行分割 + type = CollUtil.join(typeList.stream().map(TypeNodeDTO::getType).collect(Collectors.toSet()), ";"); + } else { + // 如果没有提取到,就是无 + type = "无"; + } + + } catch (Exception e) { + log.error("分类任务执行失败:{}", e.getMessage(), e); + type = "无"; + } + noteRecordSplitService.lambdaUpdate().set(NoteRecordSplit::getRecordType, type).eq(NoteRecordSplit::getId, noteRecord.getId()).update(); + return true; + } + + + @Data + public static class TypeNodeDTO { + private String type; + private String explain; + } + + @Data + public static class TypeResultDTO { + private List result; + } +} diff --git a/src/main/java/com/supervision/thread/RecordSplitTypeThreadPool.java b/src/main/java/com/supervision/thread/RecordSplitTypeThreadPool.java new file mode 100644 index 0000000..6f93433 --- /dev/null +++ b/src/main/java/com/supervision/thread/RecordSplitTypeThreadPool.java @@ -0,0 +1,13 @@ +package com.supervision.thread; + +import cn.hutool.core.thread.ThreadUtil; + +import java.util.concurrent.ExecutorService; + +/** + * 笔录分类线程池 + */ +public class RecordSplitTypeThreadPool { + + public static final ExecutorService recordSplitTypeExecutor = ThreadUtil.newFixedExecutor(5, Integer.MAX_VALUE, "recordSplitType", false); +} diff --git a/src/main/java/com/supervision/thread/TripleExtractThread.java b/src/main/java/com/supervision/thread/TripleExtractThread.java index 7b1c038..0b97a8a 100644 --- a/src/main/java/com/supervision/thread/TripleExtractThread.java +++ b/src/main/java/com/supervision/thread/TripleExtractThread.java @@ -2,6 +2,7 @@ package com.supervision.thread; import cn.hutool.core.util.StrUtil; import cn.hutool.json.JSONUtil; +import com.supervision.police.domain.NotePrompt; import com.supervision.police.domain.TripleInfo; import com.supervision.springaidemo.domain.ModelMetric; import com.supervision.springaidemo.domain.NoteCheckRecord; @@ -18,6 +19,7 @@ import org.springframework.ai.ollama.OllamaChatClient; import org.springframework.util.StopWatch; import java.time.LocalDateTime; +import java.util.HashMap; import java.util.concurrent.Callable; @Slf4j @@ -25,7 +27,7 @@ public class TripleExtractThread implements Callable { private final OllamaChatClient chatClient; - private final String prompt; + private final NotePrompt prompt; private final String question; @@ -38,7 +40,8 @@ public class TripleExtractThread implements Callable { private final String recordId; - public TripleExtractThread(OllamaChatClient chatClient, String caseId, String recordId, String recordSplitId, String prompt, String question, String answer) { + public TripleExtractThread(OllamaChatClient chatClient, String caseId, String recordId, String recordSplitId, + NotePrompt prompt, String question, String answer) { this.question = question; this.chatClient = chatClient; this.answer = answer; @@ -53,8 +56,10 @@ public class TripleExtractThread implements Callable { try { StopWatch stopWatch = new StopWatch(); // 分析三元组 - Prompt ask = new Prompt(new UserMessage(prompt + question + answer)); stopWatch.start(); + HashMap paramMap = new HashMap<>(); + paramMap.put("qaRecord", question + answer); + Prompt ask = new Prompt(new UserMessage(StrUtil.format(prompt.getPrompt(), paramMap))); log.info("开始分析:"); ChatResponse call = chatClient.call(ask); stopWatch.stop(); @@ -63,21 +68,27 @@ public class TripleExtractThread implements Callable { log.info("分析的结果是:{}", content); // 获取从提示词中提取到的三元组信息 JSONObject jsonObject = new JSONObject(content); - JSONArray threeInfo = jsonObject.getJSONArray("result"); - for (int i = 0; i < threeInfo.length(); i++) { - JSONObject object = threeInfo.getJSONObject(i); - String startNodeType = object.getString("startNodeType"); - String entity = object.getString("entity"); - String endNodeType = object.getString("endNodeType"); - String property = object.getString("property"); - String value = object.getString("value"); - // 去空,如果存在任何的空值,则忽略 - if (StrUtil.hasEmpty(startNodeType, entity, endNodeType, property, value)) { - continue; - } - // 构建三元组信息 - return new TripleInfo(entity, property, value, caseId, recordId, recordSplitId, LocalDateTime.now(), startNodeType, endNodeType); + // 修改,经测试,一次提取多个三元组效果较差,改成一次只提取一个三元组 + //JSONArray threeInfo = jsonObject.getJSONArray("result"); + //for (int i = 0; i < threeInfo.length(); i++) { + //JSONObject object = threeInfo.getJSONObject(i); + String entity = jsonObject.getString("主体"); + String relation = jsonObject.getString("关系"); + String value = jsonObject.getString("客体"); + // 类型信息从notePrompt对象中获取 + // String startNodeType = object.getString("startNodeType"); + // String endNodeType = object.getString("endNodeType"); + // 去空,如果存在任何的空值,则忽略 +// if (StrUtil.hasEmpty(startNodeType, entity, endNodeType, property, value)) { +// continue; +// } + if (StrUtil.hasEmpty(entity, relation, value)) { + log.info("提取三元组信息出现空值,忽略,主体:{},关系:{},客体:{}", entity, relation, value); + return null; } + // 构建三元组信息 + return new TripleInfo(entity, relation, value, caseId, recordId, recordSplitId, LocalDateTime.now(), prompt.getStartEntityType(), prompt.getEndEntityType()); + //} } catch (Exception e) { log.error("提取三元组出现错误", e); } @@ -85,4 +96,6 @@ public class TripleExtractThread implements Callable { } + + } diff --git a/src/main/java/com/supervision/thread/TripleExtractThreadPool.java b/src/main/java/com/supervision/thread/TripleExtractThreadPool.java index b75ae5f..53f1bef 100644 --- a/src/main/java/com/supervision/thread/TripleExtractThreadPool.java +++ b/src/main/java/com/supervision/thread/TripleExtractThreadPool.java @@ -6,5 +6,5 @@ import java.util.concurrent.ExecutorService; public class TripleExtractThreadPool { - public static final ExecutorService chatExecutor = ThreadUtil.newFixedExecutor(20, Integer.MAX_VALUE, "tripleExtract", false); + public static final ExecutorService chatExecutor = ThreadUtil.newFixedExecutor(5, Integer.MAX_VALUE, "tripleExtract", false); } diff --git a/src/main/resources/application.yml b/src/main/resources/application.yml index 5ff0d4d..f2acfa1 100644 --- a/src/main/resources/application.yml +++ b/src/main/resources/application.yml @@ -3,6 +3,7 @@ spring: active: dev main: allow-bean-definition-overriding: true + allow-circular-references: true mvc: path match: matching-strategy: ANT_PATH_MATCHER diff --git a/src/main/resources/mapper/NoteRecordSplitMapper.xml b/src/main/resources/mapper/NoteRecordSplitMapper.xml index d10a4d3..b20ed2e 100644 --- a/src/main/resources/mapper/NoteRecordSplitMapper.xml +++ b/src/main/resources/mapper/NoteRecordSplitMapper.xml @@ -8,13 +8,13 @@ select * from note_record_split nr where record_type = #{recordType} - + + + + + + + + +