提交笔录拆分分类的相关代码

topo_dev
liu 9 months ago
parent be3946dbfd
commit 0ad8d0b73a

@ -40,6 +40,12 @@
<artifactId>spring-ai-ollama-spring-boot-starter</artifactId>
</dependency>
<!--将http客户端引入httpclient,以提供给spring-ai使用,因为minio引入了okhttp,如果不引入httpclient,会导致自动使用okhttp(okhttp超时时间短,会导致大模型来不及回答消息)-->
<dependency>
<groupId>org.apache.httpcomponents.client5</groupId>
<artifactId>httpclient5</artifactId>
</dependency>
<dependency>
<groupId>cn.hutool</groupId>
<artifactId>hutool-all</artifactId>

@ -20,6 +20,11 @@ public class CaseTaskRecord implements Serializable {
@TableId
private String id;
/**
* 1 2
*/
private Integer type;
/**
* ID
*/
@ -40,6 +45,11 @@ public class CaseTaskRecord implements Serializable {
*/
private LocalDateTime submitTime;
/**
*
*/
private LocalDateTime finishTime;
@TableField(exist = false)
private static final long serialVersionUID = 1L;

@ -27,6 +27,11 @@ public class ModelRecordType implements Serializable {
*/
private String recordType;
/**
*
*/
private String recordTypeExt;
/**
*
*/

@ -30,6 +30,10 @@ public class NotePrompt implements Serializable {
*/
private String prompt;
private String startEntityType;
private String endEntityType;
/**
* ID
*/

@ -53,12 +53,6 @@ public class NoteRecordSplit implements Serializable {
*/
private String recordType;
/**
* id
*/
@TableField(exist = false)
private String recordTypeId;
/**
* id
*/

@ -10,8 +10,8 @@ public interface NoteRecordSplitMapper extends BaseMapper<NoteRecordSplit> {
List<NoteRecordSplit> selectByRecordType(@Param("recordType") String recordType);
List<NoteRecordSplit> selectRecord(@Param("caseId") String caseId,
@Param("name") String name,
@Param("recordId") String recordId);
// List<NoteRecordSplit> selectRecord(@Param("caseId") String caseId,
// @Param("name") String name,
// @Param("recordId") String recordId);
}

@ -0,0 +1,12 @@
package com.supervision.police.service;
import com.supervision.police.domain.ModelRecordType;
import com.supervision.police.domain.NoteRecordSplit;
import com.supervision.springaidemo.dto.QARecordNodeDTO;
import java.util.List;
/**
 * Service contract for assigning record types to a single split question/answer record.
 */
public interface RecordSplitTypeService {
/**
 * Classifies one split record against the supplied candidate types.
 *
 * @param allTypeList candidate record types handed to the classifier
 * @param qa          the question/answer pair taken from the record
 * @param noteRecord  the split-record entity that corresponds to {@code qa}
 */
void type(List<ModelRecordType> allTypeList, QARecordNodeDTO qa, NoteRecordSplit noteRecord);
}

@ -1,10 +1,9 @@
package com.supervision.police.service.impl;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.StrUtil;
import com.alibaba.druid.util.StringUtils;
import com.supervision.police.domain.NotePrompt;
import com.supervision.police.domain.NoteRecordSplit;
import com.supervision.police.domain.TripleInfo;
import com.supervision.police.domain.*;
import com.supervision.police.mapper.NotePromptMapper;
import com.supervision.police.mapper.NoteRecordSplitMapper;
import com.supervision.police.mapper.TripleInfoMapper;
@ -19,6 +18,7 @@ import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
import org.springframework.util.StopWatch;
@ -27,14 +27,19 @@ import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
@Slf4j
@Service
@RequiredArgsConstructor
public class ExtractTripleInfoServiceImpl implements ExtractTripleInfoService {
private final NoteRecordSplitMapper noteRecordSplitMapper;
private final CaseTaskRecordService caseTaskRecordService;
private final ModelRecordTypeService modelRecordTypeService;
private final NotePromptService notePromptService;
@ -42,31 +47,69 @@ public class ExtractTripleInfoServiceImpl implements ExtractTripleInfoService {
private final OllamaChatClient chatClient;
@Autowired
private NoteRecordSplitService noteRecordSplitService;
@Async
public void extractTripleInfo(String caseId, String name, String recordId) {
// 首先获取所有切分后的笔录
List<NoteRecordSplit> recordSplitList = noteRecordSplitMapper.selectRecord(caseId, name, recordId);
List<NoteRecordSplit> recordSplitList = noteRecordSplitService.lambdaQuery().eq(StrUtil.isNotBlank(recordId), NoteRecordSplit::getNoteRecordsId, recordId)
.eq(NoteRecordSplit::getCaseId, caseId).eq(NoteRecordSplit::getPersonName, name).list();
// 获取所有的分类
List<ModelRecordType> allTypeList = modelRecordTypeService.list();
Map<String, String> allTypeMap = allTypeList.stream().collect(Collectors.toMap(ModelRecordType::getRecordType, ModelRecordType::getId, (k1, k2) -> k1));
List<TripleInfo> tripleInfos = new ArrayList<>();
List<Future<TripleInfo>> futures = new ArrayList<>();
// 对切分后的笔录进行遍历
for (NoteRecordSplit recordSplit : recordSplitList) {
// 根据笔录类型找到所有的提取三元组的提示词
List<NotePrompt> prompts = notePromptService.lambdaQuery().eq(NotePrompt::getTypeId, recordSplit.getRecordTypeId()).list();
// 遍历提示词进行提取
for (NotePrompt prompt : prompts) {
if (StringUtils.isEmpty(prompt.getPrompt())) {
continue;
}
try {
log.info("提交任务到线程池中进行三元组提取");
Future<TripleInfo> submit = TripleExtractThreadPool.chatExecutor.submit(new TripleExtractThread(chatClient, caseId, recordId, recordSplit.getId(), prompt.getPrompt(), recordSplit.getQuestion(), recordSplit.getAnswer()));
futures.add(submit);
} catch (Exception e) {
log.error(e.getMessage(), e);
String recordType = recordSplit.getRecordType();
if (StrUtil.isBlank(recordType)) {
log.info("{} 切分笔录不属于任何类型,跳过", recordSplit.getId());
}
String[] split = recordType.split(";");
for (String typeName : split) {
String typeId = allTypeMap.get(typeName);
if (StrUtil.isBlank(typeId)) {
log.info("{} 切分笔录类型:{}未找到,跳过", recordSplit.getId(), typeName);
} else {
// 根据笔录类型找到所有的提取三元组的提示词
// 一个提示词可能关联多个类型,要进行拆分操作
List<NotePrompt> prompts = notePromptService.lambdaQuery().eq(NotePrompt::getTypeId, typeId).list();
if (CollUtil.isEmpty(prompts)) {
log.info("{} 切分笔录类型:{}无对应的提示词,跳过", recordSplit.getId(), typeName);
} else {
// 遍历提示词进行提取
for (NotePrompt prompt : prompts) {
if (StringUtils.isEmpty(prompt.getPrompt())) {
log.info("{} 切分笔录类型:{}对应的提示词:{} 提示词模板为空,跳过", recordSplit.getId(), typeName, prompt.getId());
continue;
}
try {
log.info("提交任务到线程池中进行三元组提取");
TripleExtractThread tripleExtractThread = new TripleExtractThread(chatClient, caseId, recordId, recordSplit.getId(), prompt, recordSplit.getQuestion(), recordSplit.getAnswer());
Future<TripleInfo> submit = TripleExtractThreadPool.chatExecutor.submit(tripleExtractThread);
futures.add(submit);
log.info("任务提交成功");
} catch (Exception e) {
log.error(e.getMessage(), e);
}
}
}
}
}
}
try {
log.info("休眠5秒,5秒之后再去查询三元组的结果,需要查询的任务数量为:{}", futures.size());
Thread.sleep(1000 * 5);
} catch (Exception e) {
log.error(e.getMessage(), e);
}
// 计数器
AtomicInteger atomicInteger = new AtomicInteger(0);
while (futures.size() > 0) {
Iterator<Future<TripleInfo>> iterator = futures.iterator();
while (iterator.hasNext()) {
@ -86,15 +129,40 @@ public class ExtractTripleInfoServiceImpl implements ExtractTripleInfoService {
}
}
try {
log.info("检查一遍,休眠1s后继续检查");
Thread.sleep(1000);
int currentCount = atomicInteger.incrementAndGet();
if (currentCount > 1000) {
log.info("任务执行超时,遍历任务已执行:{}次,任务还剩余:{}个,不再继续执行", currentCount, futures.size());
// 将还在执行的线程中断
futures.forEach(future -> {
future.cancel(true);
});
break;
}
log.info("已检查{}遍,任务剩余{}个,休眠5s后继续检查", currentCount, futures.size());
Thread.sleep(1000 * 5);
} catch (Exception e) {
log.error(e.getMessage(), e);
}
}
// 首先清除
tripleInfoService.lambdaUpdate().eq(TripleInfo::getRecordId, recordId).remove();
// 首先要把这个笔录已经提取过的三元组记录删除掉,删除掉之后才可以重新提取
tripleInfoService.saveBatch(tripleInfos);
// 如果有提取到三元组信息
if (CollUtil.isNotEmpty(tripleInfos)) {
// 首先清除现在已经提取过的三元组信息
tripleInfoService.lambdaUpdate().eq(TripleInfo::getRecordId, recordId).remove();
// TODO 这里,如果已经生成了图谱,怎么办?
// 首先要把这个笔录已经提取过的三元组记录删除掉,删除掉之后才可以重新提取
tripleInfoService.saveBatch(tripleInfos);
}
if (CollUtil.isNotEmpty(futures)) {
// 将任务标记为成功
caseTaskRecordService.lambdaUpdate().set(CaseTaskRecord::getStatus, 2).set(CaseTaskRecord::getFinishTime, LocalDateTime.now())
.eq(CaseTaskRecord::getType, 2).eq(CaseTaskRecord::getRecordId, recordId)
.eq(CaseTaskRecord::getCaseId, caseId).update();
} else {
// 否则标记为失败
caseTaskRecordService.lambdaUpdate().set(CaseTaskRecord::getStatus, 3).set(CaseTaskRecord::getFinishTime, LocalDateTime.now())
.eq(CaseTaskRecord::getType, 2).eq(CaseTaskRecord::getRecordId, recordId)
.eq(CaseTaskRecord::getCaseId, caseId).update();
}
log.info("三元组提取任务执行完毕,结束");
}
}

@ -18,6 +18,7 @@ import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.StopWatch;
@ -44,7 +45,8 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl<ModelRecordTypeMappe
private final CaseTaskRecordService caseTaskRecordService;
private final ExtractTripleInfoService extractTripleInfo;
@Autowired
private ExtractTripleInfoService extractTripleInfo;
@Override
public List<ModelRecordType> queryType(String name, Integer page, Integer size) {
@ -114,11 +116,11 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl<ModelRecordTypeMappe
@Override
public List<TripleInfo> getThreeInfo(String caseId, String name, String recordId) {
//boolean taskStatus = taskExtractStatusCheck(caseId, recordId);
boolean taskStatus = taskExtractStatusCheck(caseId, recordId);
// 如果校验结果为false,则说明需要进行提取三元组操作
//if (!taskStatus) {
// extractTripleInfo.extractTripleInfo(caseId, name, recordId);
//}
if (!taskStatus) {
extractTripleInfo.extractTripleInfo(caseId, name, recordId);
}
// 这里进行查询
return tripleInfoService.lambdaQuery().eq(TripleInfo::getRecordId, recordId).list();
}
@ -128,9 +130,10 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl<ModelRecordTypeMappe
*/
private boolean taskExtractStatusCheck(String caseId, String recordId) {
// 首先查询是否存在任务,如果不存在,就新建
Optional<CaseTaskRecord> caseTaskRecordOpt = caseTaskRecordService.lambdaQuery().eq(CaseTaskRecord::getCaseId, caseId).eq(CaseTaskRecord::getRecordId, recordId).oneOpt();
Optional<CaseTaskRecord> caseTaskRecordOpt = caseTaskRecordService.lambdaQuery().eq(CaseTaskRecord::getType, 2).eq(CaseTaskRecord::getCaseId, caseId).eq(CaseTaskRecord::getRecordId, recordId).oneOpt();
if (caseTaskRecordOpt.isEmpty()) {
CaseTaskRecord newCaseTaskRecord = new CaseTaskRecord();
newCaseTaskRecord.setType(2);
newCaseTaskRecord.setCaseId(caseId);
newCaseTaskRecord.setRecordId(recordId);
newCaseTaskRecord.setStatus(1);

@ -12,16 +12,21 @@ import com.supervision.minio.domain.MinioFile;
import com.supervision.minio.mapper.MinioFileMapper;
import com.supervision.minio.service.MinioService;
import com.supervision.police.domain.ModelCase;
import com.supervision.police.domain.ModelRecordType;
import com.supervision.police.domain.NoteRecordSplit;
import com.supervision.police.domain.NoteRecord;
import com.supervision.police.mapper.ModelCaseMapper;
import com.supervision.police.mapper.ModelRecordTypeMapper;
import com.supervision.police.mapper.NoteRecordSplitMapper;
import com.supervision.police.mapper.NoteRecordMapper;
import com.supervision.police.service.ModelRecordTypeService;
import com.supervision.police.service.NoteRecordSplitService;
import com.supervision.police.service.RecordSplitTypeService;
import com.supervision.springaidemo.dto.QARecordNodeDTO;
import com.supervision.springaidemo.util.RecordRegexUtil;
import com.supervision.springaidemo.util.WordReadUtil;
import com.supervision.thread.RecordSplitTypeThread;
import com.supervision.thread.RecordSplitTypeThreadPool;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.json.JSONObject;
@ -29,7 +34,9 @@ import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.util.StopWatch;
import org.springframework.web.multipart.MultipartFile;
@ -43,8 +50,6 @@ import java.util.stream.Collectors;
@RequiredArgsConstructor
public class NoteRecordSplitServiceImpl extends ServiceImpl<NoteRecordSplitMapper, NoteRecordSplit> implements NoteRecordSplitService {
private final NoteRecordSplitMapper noteRecordSplitMapper;
private final NoteRecordMapper noteRecordMapper;
private final MinioService minioService;
@ -53,25 +58,15 @@ public class NoteRecordSplitServiceImpl extends ServiceImpl<NoteRecordSplitMappe
private final MinioFileMapper minioFileMapper;
private final OllamaChatClient chatClient;
@Autowired
private ModelRecordTypeService modelRecordTypeService;
private final ModelRecordTypeMapper modelRecordTypeMapper;
@Autowired
private RecordSplitTypeService recordSplitTypeService;
private static final String TYPE_TEMPLATE = """
: : {allTypes}"
:
: :
便QQ
: { type: '' }
:
1.
2.
3.
{question} {answer}
""";
@Override
// @Transactional(rollbackFor = Exception.class)
@Transactional(transactionManager = "dataSourceTransactionManager", rollbackFor = Exception.class)
public String uploadRecords(NoteRecord records, List<MultipartFile> fileList) throws IOException {
//上传文件获取文件ids
List<String> fileIds = new ArrayList<>();
@ -110,8 +105,7 @@ public class NoteRecordSplitServiceImpl extends ServiceImpl<NoteRecordSplitMappe
i = noteRecordMapper.updateById(records);
}
//所有对话类型
List<String> allTypes = modelRecordTypeMapper.getAllType();
List<ModelRecordType> allTypeList = modelRecordTypeService.lambdaQuery().list();
if (i > 0) {
//拆分笔录
for (MultipartFile file : fileList) {
@ -127,44 +121,12 @@ public class NoteRecordSplitServiceImpl extends ServiceImpl<NoteRecordSplitMappe
noteRecord.setQuestion(qa.getQuestion());
noteRecord.setAnswer(qa.getAnswer());
noteRecord.setCreateTime(LocalDateTime.now());
// 开始对笔录进行分类
Map<String, String> paramMap = new HashMap<>();
paramMap.put("allTypes", CollUtil.join(allTypes, ";"));
paramMap.put("question", qa.getQuestion());
paramMap.put("answer", qa.getAnswer());
Prompt prompt = new Prompt(new UserMessage(StrUtil.format(TYPE_TEMPLATE, paramMap)));
StopWatch stopWatch = new StopWatch();
stopWatch.start();
log.info("开始分析:");
ChatResponse call = chatClient.call(prompt);
stopWatch.stop();
log.info("耗时:{}", stopWatch.getTotalTimeSeconds());
String content = call.getResult().getOutput().getContent();
log.info("问:{}, 答:{}", qa.getQuestion(), qa.getAnswer());
log.info("分析的结果是:{}", content);
JSONObject jsonObject = new JSONObject(content);
String type = jsonObject.getString("type").trim();
System.out.println("问:" + qa.getQuestion() + "答:" + qa.getAnswer());
System.out.println("分析的结果是:" + type);
/* // todo 写死测试
String type = "";
if (qa.getQuestion().contains("你为了骗取更多的钱都做了哪些准备")) {
type = "诈骗准备";
} else {
continue;
}*/
//保存笔录
noteRecord.setRecordType(type);
noteRecordSplitMapper.insert(noteRecord);
this.save(noteRecord);
// 通过异步的形式提交分类
recordSplitTypeService.type(allTypeList, qa, noteRecord);
} catch (Exception e) {
log.error(e.getMessage(), e);
}
// ModelRecordType exist = modelRecordTypeMapper.queryByName(type);
// if (exist == null) {
// ModelRecordType modelRecordType = new ModelRecordType();
// modelRecordType.setRecordType(type);
// modelRecordTypeMapper.insert(modelRecordType);
// }
}
}
return "保存成功";

@ -0,0 +1,47 @@
package com.supervision.police.service.impl;
import com.supervision.police.domain.ModelRecordType;
import com.supervision.police.domain.NoteRecordSplit;
import com.supervision.police.service.CaseTaskRecordService;
import com.supervision.police.service.NoteRecordSplitService;
import com.supervision.police.service.RecordSplitTypeService;
import com.supervision.springaidemo.dto.QARecordNodeDTO;
import com.supervision.thread.RecordSplitTypeThread;
import com.supervision.thread.RecordSplitTypeThreadPool;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
import java.util.List;
@Service
@Slf4j
@RequiredArgsConstructor
public class RecordSplitTypeServiceImpl implements RecordSplitTypeService {
private final OllamaChatClient chatClient;
private final NoteRecordSplitService noteRecordSplitService;
@Async
@Override
public void type(List<ModelRecordType> allTypeList, QARecordNodeDTO qa, NoteRecordSplit noteRecord){
// 这里线程休眠1秒,因为首先报保证消息记录能够插入完成,插入完成之后,再去提交大模型,让大模型去分类.防止分类太快,分类结果出来了,插入还没有插入完成
try {
Thread.sleep(1000);
}catch (Exception e){
log.error("线程休眠失败");
}
// 首先创建一个提取任务
// 进行分类
log.info("提交线程池进行分类");
RecordSplitTypeThread recordSplitTypeThread = new RecordSplitTypeThread(allTypeList, qa, chatClient, noteRecordSplitService, noteRecord);
RecordSplitTypeThreadPool.recordSplitTypeExecutor.submit(recordSplitTypeThread);
log.info("线程池提交分类成功");
// 这里应该对分类任务的执行过程进行监控,分类结束之后,才能提取三元组的关系.问了产品,暂时先不做,等后面在考虑
}
}

@ -0,0 +1,148 @@
package com.supervision.thread;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.thread.ThreadUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil;
import com.supervision.police.domain.ModelRecordType;
import com.supervision.police.domain.NoteRecordSplit;
import com.supervision.police.service.NoteRecordSplitService;
import com.supervision.springaidemo.dto.QARecordNodeDTO;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.json.JSONObject;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.util.StopWatch;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.stream.Collectors;
/**
 * Task that asks the chat model to classify one split question/answer record and then writes the
 * resulting type string back to that record.
 */
@Slf4j
public class RecordSplitTypeThread implements Callable<Boolean> {

    /** Type value stored when the model returns no type or the classification fails. */
    private static final String NO_TYPE = "无";

    /** Separator used both between type descriptions in the prompt and between stored type names. */
    private static final String SEPARATOR = ";";

    // Prompt template; {typeContext}, {question} and {answer} are filled in by StrUtil.format.
    // The text is runtime data sent to the model and must be kept verbatim.
    private static final String NEW_TEMPLATE = """
            :
            :
            {typeContext}
            :
            ---
            :
            XXX
            : {"result":[{"type":"合同和协议","explain":"行为人XXX提到签订合同"},{"type":"虚假信息和伪造","explain":"行为人XXX提到合同是假合同"}]}
            ---
            :
            1. ,
            2. ,
            2. ,
            3. {"result":[]}
            ---
            :
            {question}
            {answer}
            ---
            json,:{"result":[{"type":"分类1","explain":"分类原因"},{"type":"分类2","explain":"分类原因"}]}
            """;

    /** Per-type fragment of the prompt; {type} and {typeExt} are filled in for every candidate type. */
    private static final String TYPE_CONTEXT_TEMPLATE = "{分类type:{type},区别点(分类释义):{typeExt}}";

    private final List<ModelRecordType> allTypeList;

    private final QARecordNodeDTO qa;

    private final OllamaChatClient chatClient;

    private final NoteRecordSplitService noteRecordSplitService;

    private final NoteRecordSplit noteRecord;

    /**
     * @param allTypeList            candidate record types offered to the model
     * @param qa                     the question/answer pair to classify
     * @param chatClient             chat client used to call the model
     * @param noteRecordSplitService service used to persist the classification result
     * @param noteRecord             the split record whose type is updated
     */
    public RecordSplitTypeThread(List<ModelRecordType> allTypeList, QARecordNodeDTO qa, OllamaChatClient chatClient,
                                 NoteRecordSplitService noteRecordSplitService, NoteRecordSplit noteRecord) {
        this.allTypeList = allTypeList;
        this.qa = qa;
        this.chatClient = chatClient;
        this.noteRecordSplitService = noteRecordSplitService;
        this.noteRecord = noteRecord;
    }

    /**
     * Classifies the record and stores the result. Any failure during classification is logged and
     * the record type is set to {@code "无"} instead.
     *
     * @return always {@code true}
     */
    @Override
    public Boolean call() throws Exception {
        String type;
        try {
            type = classify();
        } catch (Exception e) {
            log.error("分类任务执行失败:{}", e.getMessage(), e);
            type = NO_TYPE;
        }
        noteRecordSplitService.lambdaUpdate()
                .set(NoteRecordSplit::getRecordType, type)
                .eq(NoteRecordSplit::getId, noteRecord.getId())
                .update();
        return true;
    }

    /** Calls the model and turns its JSON answer into a semicolon-separated list of type names. */
    private String classify() {
        StopWatch stopWatch = new StopWatch();
        Map<String, String> paramMap = new HashMap<>();
        paramMap.put("typeContext", buildTypeContext());
        paramMap.put("question", qa.getQuestion());
        paramMap.put("answer", qa.getAnswer());
        Prompt prompt = new Prompt(new UserMessage(StrUtil.format(NEW_TEMPLATE, paramMap)));
        stopWatch.start();
        log.info("开始分析:");
        ChatResponse call = chatClient.call(prompt);
        stopWatch.stop();
        log.info("耗时:{}", stopWatch.getTotalTimeSeconds());
        String content = call.getResult().getOutput().getContent();
        log.info("问:{}, 答:{}", qa.getQuestion(), qa.getAnswer());
        log.info("分析的结果是:{}", content);
        TypeResultDTO result = JSONUtil.toBean(content, TypeResultDTO.class);
        List<TypeNodeDTO> typeList = result.getResult();
        if (CollUtil.isEmpty(typeList)) {
            // Nothing was classified.
            return NO_TYPE;
        }
        // De-duplicate the returned type names and join them with semicolons.
        return CollUtil.join(typeList.stream().map(TypeNodeDTO::getType).collect(Collectors.toSet()), SEPARATOR);
    }

    /**
     * Renders every candidate type into the prompt fragment. Null names or descriptions are
     * rendered as empty strings: Map.of rejects null values, so a type without a description would
     * otherwise abort the whole classification before the model is called.
     */
    private String buildTypeContext() {
        List<String> typeContextList = new ArrayList<>();
        for (ModelRecordType modelRecordType : allTypeList) {
            Map<String, String> typeParams = new HashMap<>();
            typeParams.put("type", StrUtil.nullToEmpty(modelRecordType.getRecordType()));
            typeParams.put("typeExt", StrUtil.nullToEmpty(modelRecordType.getRecordTypeExt()));
            typeContextList.add(StrUtil.format(TYPE_CONTEXT_TEMPLATE, typeParams));
        }
        return CollUtil.join(typeContextList, SEPARATOR);
    }

    /** One classification entry of the model's JSON answer. */
    @Data
    public static class TypeNodeDTO {

        private String type;

        private String explain;
    }

    /** Root object of the model's JSON answer. */
    @Data
    public static class TypeResultDTO {

        private List<TypeNodeDTO> result;
    }
}

@ -0,0 +1,13 @@
package com.supervision.thread;
import cn.hutool.core.thread.ThreadUtil;
import java.util.concurrent.ExecutorService;
/**
 * Holder for the shared executor that runs record-split classification tasks.
 */
public class RecordSplitTypeThreadPool {
// Created through Hutool's ThreadUtil.newFixedExecutor with the thread-name prefix "recordSplitType".
// NOTE(review): presumably 5 worker threads, a queue capacity of Integer.MAX_VALUE and non-blocking
// submission — confirm the parameter meanings against the Hutool documentation.
public static final ExecutorService recordSplitTypeExecutor = ThreadUtil.newFixedExecutor(5, Integer.MAX_VALUE, "recordSplitType", false);
}

@ -2,6 +2,7 @@ package com.supervision.thread;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil;
import com.supervision.police.domain.NotePrompt;
import com.supervision.police.domain.TripleInfo;
import com.supervision.springaidemo.domain.ModelMetric;
import com.supervision.springaidemo.domain.NoteCheckRecord;
@ -18,6 +19,7 @@ import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.util.StopWatch;
import java.time.LocalDateTime;
import java.util.HashMap;
import java.util.concurrent.Callable;
@Slf4j
@ -25,7 +27,7 @@ public class TripleExtractThread implements Callable<TripleInfo> {
private final OllamaChatClient chatClient;
private final String prompt;
private final NotePrompt prompt;
private final String question;
@ -38,7 +40,8 @@ public class TripleExtractThread implements Callable<TripleInfo> {
private final String recordId;
public TripleExtractThread(OllamaChatClient chatClient, String caseId, String recordId, String recordSplitId, String prompt, String question, String answer) {
public TripleExtractThread(OllamaChatClient chatClient, String caseId, String recordId, String recordSplitId,
NotePrompt prompt, String question, String answer) {
this.question = question;
this.chatClient = chatClient;
this.answer = answer;
@ -53,8 +56,10 @@ public class TripleExtractThread implements Callable<TripleInfo> {
try {
StopWatch stopWatch = new StopWatch();
// 分析三元组
Prompt ask = new Prompt(new UserMessage(prompt + question + answer));
stopWatch.start();
HashMap<String, String> paramMap = new HashMap<>();
paramMap.put("qaRecord", question + answer);
Prompt ask = new Prompt(new UserMessage(StrUtil.format(prompt.getPrompt(), paramMap)));
log.info("开始分析:");
ChatResponse call = chatClient.call(ask);
stopWatch.stop();
@ -63,21 +68,27 @@ public class TripleExtractThread implements Callable<TripleInfo> {
log.info("分析的结果是:{}", content);
// 获取从提示词中提取到的三元组信息
JSONObject jsonObject = new JSONObject(content);
JSONArray threeInfo = jsonObject.getJSONArray("result");
for (int i = 0; i < threeInfo.length(); i++) {
JSONObject object = threeInfo.getJSONObject(i);
String startNodeType = object.getString("startNodeType");
String entity = object.getString("entity");
String endNodeType = object.getString("endNodeType");
String property = object.getString("property");
String value = object.getString("value");
// 去空,如果存在任何的空值,则忽略
if (StrUtil.hasEmpty(startNodeType, entity, endNodeType, property, value)) {
continue;
}
// 构建三元组信息
return new TripleInfo(entity, property, value, caseId, recordId, recordSplitId, LocalDateTime.now(), startNodeType, endNodeType);
// 修改,经测试,一次提取多个三元组效果较差,改成一次只提取一个三元组
//JSONArray threeInfo = jsonObject.getJSONArray("result");
//for (int i = 0; i < threeInfo.length(); i++) {
//JSONObject object = threeInfo.getJSONObject(i);
String entity = jsonObject.getString("主体");
String relation = jsonObject.getString("关系");
String value = jsonObject.getString("客体");
// 类型信息从notePrompt对象中获取
// String startNodeType = object.getString("startNodeType");
// String endNodeType = object.getString("endNodeType");
// 去空,如果存在任何的空值,则忽略
// if (StrUtil.hasEmpty(startNodeType, entity, endNodeType, property, value)) {
// continue;
// }
if (StrUtil.hasEmpty(entity, relation, value)) {
log.info("提取三元组信息出现空值,忽略,主体:{},关系:{},客体:{}", entity, relation, value);
return null;
}
// 构建三元组信息
return new TripleInfo(entity, relation, value, caseId, recordId, recordSplitId, LocalDateTime.now(), prompt.getStartEntityType(), prompt.getEndEntityType());
//}
} catch (Exception e) {
log.error("提取三元组出现错误", e);
}
@ -85,4 +96,6 @@ public class TripleExtractThread implements Callable<TripleInfo> {
}
}

@ -6,5 +6,5 @@ import java.util.concurrent.ExecutorService;
public class TripleExtractThreadPool {
public static final ExecutorService chatExecutor = ThreadUtil.newFixedExecutor(20, Integer.MAX_VALUE, "tripleExtract", false);
public static final ExecutorService chatExecutor = ThreadUtil.newFixedExecutor(5, Integer.MAX_VALUE, "tripleExtract", false);
}

@ -3,6 +3,7 @@ spring:
active: dev
main:
allow-bean-definition-overriding: true
allow-circular-references: true
mvc:
path match:
matching-strategy: ANT_PATH_MATCHER

@ -8,13 +8,13 @@
select * from note_record_split nr
where record_type = #{recordType}
</select>
<select id="selectRecord" resultType="com.supervision.police.domain.NoteRecordSplit">
select nr.*, mrt.id as recordTypeId
from note_record_split nr
left join model_record_type mrt on nr.record_type = mrt.record_type
where nr.case_id = #{caseId} and nr.person_name = #{name}
<if test="recordId != null and recordId != ''">
and nr.note_records_id = #{recordId}
</if>
</select>
<!-- <select id="selectRecord" resultType="com.supervision.police.domain.NoteRecordSplit">-->
<!-- select nr.*, mrt.id as recordTypeId-->
<!-- from note_record_split nr-->
<!-- left join model_record_type mrt on nr.record_type = mrt.record_type-->
<!-- where nr.case_id = #{caseId} and nr.person_name = #{name}-->
<!-- <if test="recordId != null and recordId != ''">-->
<!-- and nr.note_records_id = #{recordId}-->
<!-- </if>-->
<!-- </select>-->
</mapper>

Loading…
Cancel
Save