Merge remote-tracking branch 'origin/dev_1.0.0' into dev_1.0.0

topo_dev
xueqingkun 9 months ago
commit 164f1567b2

@ -40,6 +40,12 @@
<artifactId>spring-ai-ollama-spring-boot-starter</artifactId> <artifactId>spring-ai-ollama-spring-boot-starter</artifactId>
</dependency> </dependency>
<!--将http客户端引入httpclient,以提供给spring-ai使用,因为minio引入了okhttp,如果不引入httpclient,会导致自动使用okhttp(okhttp超时时间短,会导致大模型来不及回答消息)-->
<dependency>
<groupId>org.apache.httpcomponents.client5</groupId>
<artifactId>httpclient5</artifactId>
</dependency>
<dependency> <dependency>
<groupId>cn.hutool</groupId> <groupId>cn.hutool</groupId>
<artifactId>hutool-all</artifactId> <artifactId>hutool-all</artifactId>

@ -20,6 +20,11 @@ public class CaseTaskRecord implements Serializable {
@TableId @TableId
private String id; private String id;
/**
* 1 2
*/
private Integer type;
/** /**
* ID * ID
*/ */
@ -40,6 +45,11 @@ public class CaseTaskRecord implements Serializable {
*/ */
private LocalDateTime submitTime; private LocalDateTime submitTime;
/**
*
*/
private LocalDateTime finishTime;
@TableField(exist = false) @TableField(exist = false)
private static final long serialVersionUID = 1L; private static final long serialVersionUID = 1L;

@ -27,6 +27,11 @@ public class ModelRecordType implements Serializable {
*/ */
private String recordType; private String recordType;
/**
*
*/
private String recordTypeExt;
/** /**
* *
*/ */

@ -30,6 +30,10 @@ public class NotePrompt implements Serializable {
*/ */
private String prompt; private String prompt;
private String startEntityType;
private String endEntityType;
/** /**
* ID * ID
*/ */

@ -53,12 +53,6 @@ public class NoteRecordSplit implements Serializable {
*/ */
private String recordType; private String recordType;
/**
* id
*/
@TableField(exist = false)
private String recordTypeId;
/** /**
* id * id
*/ */

@ -10,8 +10,8 @@ public interface NoteRecordSplitMapper extends BaseMapper<NoteRecordSplit> {
List<NoteRecordSplit> selectByRecordType(@Param("recordType") String recordType); List<NoteRecordSplit> selectByRecordType(@Param("recordType") String recordType);
List<NoteRecordSplit> selectRecord(@Param("caseId") String caseId, // List<NoteRecordSplit> selectRecord(@Param("caseId") String caseId,
@Param("name") String name, // @Param("name") String name,
@Param("recordId") String recordId); // @Param("recordId") String recordId);
} }

@ -0,0 +1,12 @@
package com.supervision.police.service;
import com.supervision.police.domain.ModelRecordType;
import com.supervision.police.domain.NoteRecordSplit;
import com.supervision.springaidemo.dto.QARecordNodeDTO;
import java.util.List;
public interface RecordSplitTypeService {
void type(List<ModelRecordType> allTypeList, QARecordNodeDTO qa, NoteRecordSplit noteRecord);
}

@ -1,10 +1,9 @@
package com.supervision.police.service.impl; package com.supervision.police.service.impl;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.StrUtil; import cn.hutool.core.util.StrUtil;
import com.alibaba.druid.util.StringUtils; import com.alibaba.druid.util.StringUtils;
import com.supervision.police.domain.NotePrompt; import com.supervision.police.domain.*;
import com.supervision.police.domain.NoteRecordSplit;
import com.supervision.police.domain.TripleInfo;
import com.supervision.police.mapper.NotePromptMapper; import com.supervision.police.mapper.NotePromptMapper;
import com.supervision.police.mapper.NoteRecordSplitMapper; import com.supervision.police.mapper.NoteRecordSplitMapper;
import com.supervision.police.mapper.TripleInfoMapper; import com.supervision.police.mapper.TripleInfoMapper;
@ -19,6 +18,7 @@ import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.OllamaChatClient; import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.scheduling.annotation.Async; import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.util.StopWatch; import org.springframework.util.StopWatch;
@ -27,14 +27,19 @@ import java.time.LocalDateTime;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Iterator; import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.concurrent.Future; import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
@Slf4j @Slf4j
@Service @Service
@RequiredArgsConstructor @RequiredArgsConstructor
public class ExtractTripleInfoServiceImpl implements ExtractTripleInfoService { public class ExtractTripleInfoServiceImpl implements ExtractTripleInfoService {
private final NoteRecordSplitMapper noteRecordSplitMapper; private final CaseTaskRecordService caseTaskRecordService;
private final ModelRecordTypeService modelRecordTypeService;
private final NotePromptService notePromptService; private final NotePromptService notePromptService;
@ -42,31 +47,69 @@ public class ExtractTripleInfoServiceImpl implements ExtractTripleInfoService {
private final OllamaChatClient chatClient; private final OllamaChatClient chatClient;
@Autowired
private NoteRecordSplitService noteRecordSplitService;
@Async @Async
public void extractTripleInfo(String caseId, String name, String recordId) { public void extractTripleInfo(String caseId, String name, String recordId) {
// 首先获取所有切分后的笔录 // 首先获取所有切分后的笔录
List<NoteRecordSplit> recordSplitList = noteRecordSplitMapper.selectRecord(caseId, name, recordId); List<NoteRecordSplit> recordSplitList = noteRecordSplitService.lambdaQuery().eq(StrUtil.isNotBlank(recordId), NoteRecordSplit::getNoteRecordsId, recordId)
.eq(NoteRecordSplit::getCaseId, caseId).eq(NoteRecordSplit::getPersonName, name).list();
// 获取所有的分类
List<ModelRecordType> allTypeList = modelRecordTypeService.list();
Map<String, String> allTypeMap = allTypeList.stream().collect(Collectors.toMap(ModelRecordType::getRecordType, ModelRecordType::getId, (k1, k2) -> k1));
List<TripleInfo> tripleInfos = new ArrayList<>(); List<TripleInfo> tripleInfos = new ArrayList<>();
List<Future<TripleInfo>> futures = new ArrayList<>(); List<Future<TripleInfo>> futures = new ArrayList<>();
// 对切分后的笔录进行遍历 // 对切分后的笔录进行遍历
for (NoteRecordSplit recordSplit : recordSplitList) { for (NoteRecordSplit recordSplit : recordSplitList) {
// 根据笔录类型找到所有的提取三元组的提示词 String recordType = recordSplit.getRecordType();
List<NotePrompt> prompts = notePromptService.lambdaQuery().eq(NotePrompt::getTypeId, recordSplit.getRecordTypeId()).list(); if (StrUtil.isBlank(recordType)) {
// 遍历提示词进行提取 log.info("{} 切分笔录不属于任何类型,跳过", recordSplit.getId());
for (NotePrompt prompt : prompts) { }
if (StringUtils.isEmpty(prompt.getPrompt())) { String[] split = recordType.split(";");
continue; for (String typeName : split) {
} String typeId = allTypeMap.get(typeName);
try { if (StrUtil.isBlank(typeId)) {
log.info("提交任务到线程池中进行三元组提取"); log.info("{} 切分笔录类型:{}未找到,跳过", recordSplit.getId(), typeName);
Future<TripleInfo> submit = TripleExtractThreadPool.chatExecutor.submit(new TripleExtractThread(chatClient, caseId, recordId, recordSplit.getId(), prompt.getPrompt(), recordSplit.getQuestion(), recordSplit.getAnswer())); } else {
futures.add(submit); // 根据笔录类型找到所有的提取三元组的提示词
} catch (Exception e) { // 一个提示词可能关联多个类型,要进行拆分操作
log.error(e.getMessage(), e); List<NotePrompt> prompts = notePromptService.lambdaQuery().eq(NotePrompt::getTypeId, typeId).list();
if (CollUtil.isEmpty(prompts)) {
log.info("{} 切分笔录类型:{}无对应的提示词,跳过", recordSplit.getId(), typeName);
} else {
// 遍历提示词进行提取
for (NotePrompt prompt : prompts) {
if (StringUtils.isEmpty(prompt.getPrompt())) {
log.info("{} 切分笔录类型:{}对应的提示词:{} 提示词模板为空,跳过", recordSplit.getId(), typeName, prompt.getId());
continue;
}
try {
log.info("提交任务到线程池中进行三元组提取");
TripleExtractThread tripleExtractThread = new TripleExtractThread(chatClient, caseId, recordId, recordSplit.getId(), prompt, recordSplit.getQuestion(), recordSplit.getAnswer());
Future<TripleInfo> submit = TripleExtractThreadPool.chatExecutor.submit(tripleExtractThread);
futures.add(submit);
log.info("任务提交成功");
} catch (Exception e) {
log.error(e.getMessage(), e);
}
}
}
} }
} }
} }
try {
log.info("休眠5秒,5秒之后再去查询三元组的结果,需要查询的任务数量为:{}", futures.size());
Thread.sleep(1000 * 5);
} catch (Exception e) {
log.error(e.getMessage(), e);
}
// 计数器
AtomicInteger atomicInteger = new AtomicInteger(0);
while (futures.size() > 0) { while (futures.size() > 0) {
Iterator<Future<TripleInfo>> iterator = futures.iterator(); Iterator<Future<TripleInfo>> iterator = futures.iterator();
while (iterator.hasNext()) { while (iterator.hasNext()) {
@ -86,15 +129,40 @@ public class ExtractTripleInfoServiceImpl implements ExtractTripleInfoService {
} }
} }
try { try {
log.info("检查一遍,休眠1s后继续检查"); int currentCount = atomicInteger.incrementAndGet();
Thread.sleep(1000); if (currentCount > 1000) {
log.info("任务执行超时,遍历任务已执行:{}次,任务还剩余:{}个,不再继续执行", currentCount, futures.size());
// 将还在执行的线程中断
futures.forEach(future -> {
future.cancel(true);
});
break;
}
log.info("已检查{}遍,任务剩余{}个,休眠5s后继续检查", currentCount, futures.size());
Thread.sleep(1000 * 5);
} catch (Exception e) { } catch (Exception e) {
log.error(e.getMessage(), e); log.error(e.getMessage(), e);
} }
} }
// 首先清除 // 如果有提取到三元组信息
tripleInfoService.lambdaUpdate().eq(TripleInfo::getRecordId, recordId).remove(); if (CollUtil.isNotEmpty(tripleInfos)) {
// 首先要把这个笔录已经提取过的三元组记录删除掉,删除掉之后才可以重新提取 // 首先清除现在已经提取过的三元组信息
tripleInfoService.saveBatch(tripleInfos); tripleInfoService.lambdaUpdate().eq(TripleInfo::getRecordId, recordId).remove();
// TODO 这里,如果已经生成了图谱,怎么办?
// 首先要把这个笔录已经提取过的三元组记录删除掉,删除掉之后才可以重新提取
tripleInfoService.saveBatch(tripleInfos);
}
if (CollUtil.isNotEmpty(futures)) {
// 将任务标记为成功
caseTaskRecordService.lambdaUpdate().set(CaseTaskRecord::getStatus, 2).set(CaseTaskRecord::getFinishTime, LocalDateTime.now())
.eq(CaseTaskRecord::getType, 2).eq(CaseTaskRecord::getRecordId, recordId)
.eq(CaseTaskRecord::getCaseId, caseId).update();
} else {
// 否则标记为失败
caseTaskRecordService.lambdaUpdate().set(CaseTaskRecord::getStatus, 3).set(CaseTaskRecord::getFinishTime, LocalDateTime.now())
.eq(CaseTaskRecord::getType, 2).eq(CaseTaskRecord::getRecordId, recordId)
.eq(CaseTaskRecord::getCaseId, caseId).update();
}
log.info("三元组提取任务执行完毕,结束");
} }
} }

@ -18,6 +18,7 @@ import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.OllamaChatClient; import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.util.StopWatch; import org.springframework.util.StopWatch;
@ -44,7 +45,8 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl<ModelRecordTypeMappe
private final CaseTaskRecordService caseTaskRecordService; private final CaseTaskRecordService caseTaskRecordService;
private final ExtractTripleInfoService extractTripleInfo; @Autowired
private ExtractTripleInfoService extractTripleInfo;
@Override @Override
public List<ModelRecordType> queryType(String name, Integer page, Integer size) { public List<ModelRecordType> queryType(String name, Integer page, Integer size) {
@ -114,11 +116,11 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl<ModelRecordTypeMappe
@Override @Override
public List<TripleInfo> getThreeInfo(String caseId, String name, String recordId) { public List<TripleInfo> getThreeInfo(String caseId, String name, String recordId) {
//boolean taskStatus = taskExtractStatusCheck(caseId, recordId); boolean taskStatus = taskExtractStatusCheck(caseId, recordId);
// 如果校验结果为false,则说明需要进行提取三元组操作 // 如果校验结果为false,则说明需要进行提取三元组操作
//if (!taskStatus) { if (!taskStatus) {
// extractTripleInfo.extractTripleInfo(caseId, name, recordId); extractTripleInfo.extractTripleInfo(caseId, name, recordId);
//} }
// 这里进行查询 // 这里进行查询
return tripleInfoService.lambdaQuery().eq(TripleInfo::getRecordId, recordId).list(); return tripleInfoService.lambdaQuery().eq(TripleInfo::getRecordId, recordId).list();
} }
@ -128,9 +130,10 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl<ModelRecordTypeMappe
*/ */
private boolean taskExtractStatusCheck(String caseId, String recordId) { private boolean taskExtractStatusCheck(String caseId, String recordId) {
// 首先查询是否存在任务,如果不存在,就新建 // 首先查询是否存在任务,如果不存在,就新建
Optional<CaseTaskRecord> caseTaskRecordOpt = caseTaskRecordService.lambdaQuery().eq(CaseTaskRecord::getCaseId, caseId).eq(CaseTaskRecord::getRecordId, recordId).oneOpt(); Optional<CaseTaskRecord> caseTaskRecordOpt = caseTaskRecordService.lambdaQuery().eq(CaseTaskRecord::getType, 2).eq(CaseTaskRecord::getCaseId, caseId).eq(CaseTaskRecord::getRecordId, recordId).oneOpt();
if (caseTaskRecordOpt.isEmpty()) { if (caseTaskRecordOpt.isEmpty()) {
CaseTaskRecord newCaseTaskRecord = new CaseTaskRecord(); CaseTaskRecord newCaseTaskRecord = new CaseTaskRecord();
newCaseTaskRecord.setType(2);
newCaseTaskRecord.setCaseId(caseId); newCaseTaskRecord.setCaseId(caseId);
newCaseTaskRecord.setRecordId(recordId); newCaseTaskRecord.setRecordId(recordId);
newCaseTaskRecord.setStatus(1); newCaseTaskRecord.setStatus(1);

@ -12,16 +12,21 @@ import com.supervision.minio.domain.MinioFile;
import com.supervision.minio.mapper.MinioFileMapper; import com.supervision.minio.mapper.MinioFileMapper;
import com.supervision.minio.service.MinioService; import com.supervision.minio.service.MinioService;
import com.supervision.police.domain.ModelCase; import com.supervision.police.domain.ModelCase;
import com.supervision.police.domain.ModelRecordType;
import com.supervision.police.domain.NoteRecordSplit; import com.supervision.police.domain.NoteRecordSplit;
import com.supervision.police.domain.NoteRecord; import com.supervision.police.domain.NoteRecord;
import com.supervision.police.mapper.ModelCaseMapper; import com.supervision.police.mapper.ModelCaseMapper;
import com.supervision.police.mapper.ModelRecordTypeMapper; import com.supervision.police.mapper.ModelRecordTypeMapper;
import com.supervision.police.mapper.NoteRecordSplitMapper; import com.supervision.police.mapper.NoteRecordSplitMapper;
import com.supervision.police.mapper.NoteRecordMapper; import com.supervision.police.mapper.NoteRecordMapper;
import com.supervision.police.service.ModelRecordTypeService;
import com.supervision.police.service.NoteRecordSplitService; import com.supervision.police.service.NoteRecordSplitService;
import com.supervision.police.service.RecordSplitTypeService;
import com.supervision.springaidemo.dto.QARecordNodeDTO; import com.supervision.springaidemo.dto.QARecordNodeDTO;
import com.supervision.springaidemo.util.RecordRegexUtil; import com.supervision.springaidemo.util.RecordRegexUtil;
import com.supervision.springaidemo.util.WordReadUtil; import com.supervision.springaidemo.util.WordReadUtil;
import com.supervision.thread.RecordSplitTypeThread;
import com.supervision.thread.RecordSplitTypeThreadPool;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.json.JSONObject; import org.json.JSONObject;
@ -29,7 +34,9 @@ import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.OllamaChatClient; import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.util.StopWatch; import org.springframework.util.StopWatch;
import org.springframework.web.multipart.MultipartFile; import org.springframework.web.multipart.MultipartFile;
@ -43,8 +50,6 @@ import java.util.stream.Collectors;
@RequiredArgsConstructor @RequiredArgsConstructor
public class NoteRecordSplitServiceImpl extends ServiceImpl<NoteRecordSplitMapper, NoteRecordSplit> implements NoteRecordSplitService { public class NoteRecordSplitServiceImpl extends ServiceImpl<NoteRecordSplitMapper, NoteRecordSplit> implements NoteRecordSplitService {
private final NoteRecordSplitMapper noteRecordSplitMapper;
private final NoteRecordMapper noteRecordMapper; private final NoteRecordMapper noteRecordMapper;
private final MinioService minioService; private final MinioService minioService;
@ -53,25 +58,15 @@ public class NoteRecordSplitServiceImpl extends ServiceImpl<NoteRecordSplitMappe
private final MinioFileMapper minioFileMapper; private final MinioFileMapper minioFileMapper;
private final OllamaChatClient chatClient; @Autowired
private ModelRecordTypeService modelRecordTypeService;
private final ModelRecordTypeMapper modelRecordTypeMapper; @Autowired
private RecordSplitTypeService recordSplitTypeService;
private static final String TYPE_TEMPLATE = """
: : {allTypes}"
:
: :
便QQ
: { type: '' }
:
1.
2.
3.
{question} {answer}
""";
@Override @Override
// @Transactional(rollbackFor = Exception.class) @Transactional(transactionManager = "dataSourceTransactionManager", rollbackFor = Exception.class)
public String uploadRecords(NoteRecord records, List<MultipartFile> fileList) throws IOException { public String uploadRecords(NoteRecord records, List<MultipartFile> fileList) throws IOException {
//上传文件获取文件ids //上传文件获取文件ids
List<String> fileIds = new ArrayList<>(); List<String> fileIds = new ArrayList<>();
@ -110,8 +105,7 @@ public class NoteRecordSplitServiceImpl extends ServiceImpl<NoteRecordSplitMappe
i = noteRecordMapper.updateById(records); i = noteRecordMapper.updateById(records);
} }
//所有对话类型 //所有对话类型
List<String> allTypes = modelRecordTypeMapper.getAllType(); List<ModelRecordType> allTypeList = modelRecordTypeService.lambdaQuery().list();
if (i > 0) { if (i > 0) {
//拆分笔录 //拆分笔录
for (MultipartFile file : fileList) { for (MultipartFile file : fileList) {
@ -127,44 +121,12 @@ public class NoteRecordSplitServiceImpl extends ServiceImpl<NoteRecordSplitMappe
noteRecord.setQuestion(qa.getQuestion()); noteRecord.setQuestion(qa.getQuestion());
noteRecord.setAnswer(qa.getAnswer()); noteRecord.setAnswer(qa.getAnswer());
noteRecord.setCreateTime(LocalDateTime.now()); noteRecord.setCreateTime(LocalDateTime.now());
// 开始对笔录进行分类 this.save(noteRecord);
Map<String, String> paramMap = new HashMap<>(); // 通过异步的形式提交分类
paramMap.put("allTypes", CollUtil.join(allTypes, ";")); recordSplitTypeService.type(allTypeList, qa, noteRecord);
paramMap.put("question", qa.getQuestion());
paramMap.put("answer", qa.getAnswer());
Prompt prompt = new Prompt(new UserMessage(StrUtil.format(TYPE_TEMPLATE, paramMap)));
StopWatch stopWatch = new StopWatch();
stopWatch.start();
log.info("开始分析:");
ChatResponse call = chatClient.call(prompt);
stopWatch.stop();
log.info("耗时:{}", stopWatch.getTotalTimeSeconds());
String content = call.getResult().getOutput().getContent();
log.info("问:{}, 答:{}", qa.getQuestion(), qa.getAnswer());
log.info("分析的结果是:{}", content);
JSONObject jsonObject = new JSONObject(content);
String type = jsonObject.getString("type").trim();
System.out.println("问:" + qa.getQuestion() + "答:" + qa.getAnswer());
System.out.println("分析的结果是:" + type);
/* // todo 写死测试
String type = "";
if (qa.getQuestion().contains("你为了骗取更多的钱都做了哪些准备")) {
type = "诈骗准备";
} else {
continue;
}*/
//保存笔录
noteRecord.setRecordType(type);
noteRecordSplitMapper.insert(noteRecord);
} catch (Exception e) { } catch (Exception e) {
log.error(e.getMessage(), e); log.error(e.getMessage(), e);
} }
// ModelRecordType exist = modelRecordTypeMapper.queryByName(type);
// if (exist == null) {
// ModelRecordType modelRecordType = new ModelRecordType();
// modelRecordType.setRecordType(type);
// modelRecordTypeMapper.insert(modelRecordType);
// }
} }
} }
return "保存成功"; return "保存成功";

@ -0,0 +1,47 @@
package com.supervision.police.service.impl;
import com.supervision.police.domain.ModelRecordType;
import com.supervision.police.domain.NoteRecordSplit;
import com.supervision.police.service.CaseTaskRecordService;
import com.supervision.police.service.NoteRecordSplitService;
import com.supervision.police.service.RecordSplitTypeService;
import com.supervision.springaidemo.dto.QARecordNodeDTO;
import com.supervision.thread.RecordSplitTypeThread;
import com.supervision.thread.RecordSplitTypeThreadPool;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
import java.util.List;
@Service
@Slf4j
@RequiredArgsConstructor
public class RecordSplitTypeServiceImpl implements RecordSplitTypeService {
private final OllamaChatClient chatClient;
private final NoteRecordSplitService noteRecordSplitService;
@Async
@Override
public void type(List<ModelRecordType> allTypeList, QARecordNodeDTO qa, NoteRecordSplit noteRecord){
// 这里线程休眠1秒,因为首先报保证消息记录能够插入完成,插入完成之后,再去提交大模型,让大模型去分类.防止分类太快,分类结果出来了,插入还没有插入完成
try {
Thread.sleep(1000);
}catch (Exception e){
log.error("线程休眠失败");
}
// 首先创建一个提取任务
// 进行分类
log.info("提交线程池进行分类");
RecordSplitTypeThread recordSplitTypeThread = new RecordSplitTypeThread(allTypeList, qa, chatClient, noteRecordSplitService, noteRecord);
RecordSplitTypeThreadPool.recordSplitTypeExecutor.submit(recordSplitTypeThread);
log.info("线程池提交分类成功");
// 这里应该对分类任务的执行过程进行监控,分类结束之后,才能提取三元组的关系.问了产品,暂时先不做,等后面在考虑
}
}

@ -0,0 +1,148 @@
package com.supervision.thread;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.thread.ThreadUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil;
import com.supervision.police.domain.ModelRecordType;
import com.supervision.police.domain.NoteRecordSplit;
import com.supervision.police.service.NoteRecordSplitService;
import com.supervision.springaidemo.dto.QARecordNodeDTO;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.json.JSONObject;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.util.StopWatch;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.stream.Collectors;
/**
* 线
*/
@Slf4j
public class RecordSplitTypeThread implements Callable<Boolean> {
private final List<ModelRecordType> allTypeList;
private final QARecordNodeDTO qa;
private final OllamaChatClient chatClient;
private final NoteRecordSplitService noteRecordSplitService;
private final NoteRecordSplit noteRecord;
public RecordSplitTypeThread(List<ModelRecordType> allTypeList, QARecordNodeDTO qa, OllamaChatClient chatClient, NoteRecordSplitService noteRecordSplitService, NoteRecordSplit noteRecord) {
this.allTypeList = allTypeList;
this.qa = qa;
this.chatClient = chatClient;
this.noteRecordSplitService = noteRecordSplitService;
this.noteRecord = noteRecord;
}
private static final String TYPE_TEMPLATE = """
: : {allTypes}"
:
: :
便QQ
: { type: '' }
:
1.
2.
3.
{question} {answer}
""";
private static final String NEW_TEMPLATE = """
:
:
{typeContext}
:
---
:
XXX
: {"result":[{"type":"合同和协议","explain":"行为人XXX提到签订合同"},{"type":"虚假信息和伪造","explain":"行为人XXX提到合同是假合同"}]}
---
:
1. ,
2. ,
2. ,
3. {"result":[]}
---
:
{question}
{answer}
---
json,:{"result":[{"type":"分类1","explain":"分类原因"},{"type":"分类2","explain":"分类原因"}]}
""";
private static final String TYPE_CONTEXT_TEMPLATE = "{分类type:{type},区别点(分类释义):{typeExt}}";
@Override
public Boolean call() throws Exception {
String type;
try {
StopWatch stopWatch = new StopWatch();
// 首先拼接分类模板
List<String> typeContextList = new ArrayList<>();
for (ModelRecordType modelRecordType : allTypeList) {
String format = StrUtil.format(TYPE_CONTEXT_TEMPLATE, Map.of("type", modelRecordType.getRecordType(), "typeExt", modelRecordType.getRecordTypeExt()));
typeContextList.add(format);
}
// 开始对笔录进行分类
Map<String, String> paramMap = new HashMap<>();
paramMap.put("typeContext", CollUtil.join(typeContextList, ";"));
paramMap.put("question", qa.getQuestion());
paramMap.put("answer", qa.getAnswer());
Prompt prompt = new Prompt(new UserMessage(StrUtil.format(NEW_TEMPLATE, paramMap)));
stopWatch.start();
log.info("开始分析:");
ChatResponse call = chatClient.call(prompt);
stopWatch.stop();
log.info("耗时:{}", stopWatch.getTotalTimeSeconds());
String content = call.getResult().getOutput().getContent();
log.info("问:{}, 答:{}", qa.getQuestion(), qa.getAnswer());
log.info("分析的结果是:{}", content);
TypeResultDTO result = JSONUtil.toBean(content, TypeResultDTO.class);
List<TypeNodeDTO> typeList = result.getResult();
if (CollUtil.isNotEmpty(typeList)) {
// 将type进行拼接,并以分号进行分割
type = CollUtil.join(typeList.stream().map(TypeNodeDTO::getType).collect(Collectors.toSet()), ";");
} else {
// 如果没有提取到,就是无
type = "无";
}
} catch (Exception e) {
log.error("分类任务执行失败:{}", e.getMessage(), e);
type = "无";
}
noteRecordSplitService.lambdaUpdate().set(NoteRecordSplit::getRecordType, type).eq(NoteRecordSplit::getId, noteRecord.getId()).update();
return true;
}
@Data
public static class TypeNodeDTO {
private String type;
private String explain;
}
@Data
public static class TypeResultDTO {
private List<TypeNodeDTO> result;
}
}

@ -0,0 +1,13 @@
package com.supervision.thread;
import cn.hutool.core.thread.ThreadUtil;
import java.util.concurrent.ExecutorService;
/**
* 线
*/
public class RecordSplitTypeThreadPool {
public static final ExecutorService recordSplitTypeExecutor = ThreadUtil.newFixedExecutor(5, Integer.MAX_VALUE, "recordSplitType", false);
}

@ -2,6 +2,7 @@ package com.supervision.thread;
import cn.hutool.core.util.StrUtil; import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil; import cn.hutool.json.JSONUtil;
import com.supervision.police.domain.NotePrompt;
import com.supervision.police.domain.TripleInfo; import com.supervision.police.domain.TripleInfo;
import com.supervision.springaidemo.domain.ModelMetric; import com.supervision.springaidemo.domain.ModelMetric;
import com.supervision.springaidemo.domain.NoteCheckRecord; import com.supervision.springaidemo.domain.NoteCheckRecord;
@ -18,6 +19,7 @@ import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.util.StopWatch; import org.springframework.util.StopWatch;
import java.time.LocalDateTime; import java.time.LocalDateTime;
import java.util.HashMap;
import java.util.concurrent.Callable; import java.util.concurrent.Callable;
@Slf4j @Slf4j
@ -25,7 +27,7 @@ public class TripleExtractThread implements Callable<TripleInfo> {
private final OllamaChatClient chatClient; private final OllamaChatClient chatClient;
private final String prompt; private final NotePrompt prompt;
private final String question; private final String question;
@ -38,7 +40,8 @@ public class TripleExtractThread implements Callable<TripleInfo> {
private final String recordId; private final String recordId;
public TripleExtractThread(OllamaChatClient chatClient, String caseId, String recordId, String recordSplitId, String prompt, String question, String answer) { public TripleExtractThread(OllamaChatClient chatClient, String caseId, String recordId, String recordSplitId,
NotePrompt prompt, String question, String answer) {
this.question = question; this.question = question;
this.chatClient = chatClient; this.chatClient = chatClient;
this.answer = answer; this.answer = answer;
@ -53,8 +56,10 @@ public class TripleExtractThread implements Callable<TripleInfo> {
try { try {
StopWatch stopWatch = new StopWatch(); StopWatch stopWatch = new StopWatch();
// 分析三元组 // 分析三元组
Prompt ask = new Prompt(new UserMessage(prompt + question + answer));
stopWatch.start(); stopWatch.start();
HashMap<String, String> paramMap = new HashMap<>();
paramMap.put("qaRecord", question + answer);
Prompt ask = new Prompt(new UserMessage(StrUtil.format(prompt.getPrompt(), paramMap)));
log.info("开始分析:"); log.info("开始分析:");
ChatResponse call = chatClient.call(ask); ChatResponse call = chatClient.call(ask);
stopWatch.stop(); stopWatch.stop();
@ -63,21 +68,27 @@ public class TripleExtractThread implements Callable<TripleInfo> {
log.info("分析的结果是:{}", content); log.info("分析的结果是:{}", content);
// 获取从提示词中提取到的三元组信息 // 获取从提示词中提取到的三元组信息
JSONObject jsonObject = new JSONObject(content); JSONObject jsonObject = new JSONObject(content);
JSONArray threeInfo = jsonObject.getJSONArray("result"); // 修改,经测试,一次提取多个三元组效果较差,改成一次只提取一个三元组
for (int i = 0; i < threeInfo.length(); i++) { //JSONArray threeInfo = jsonObject.getJSONArray("result");
JSONObject object = threeInfo.getJSONObject(i); //for (int i = 0; i < threeInfo.length(); i++) {
String startNodeType = object.getString("startNodeType"); //JSONObject object = threeInfo.getJSONObject(i);
String entity = object.getString("entity"); String entity = jsonObject.getString("主体");
String endNodeType = object.getString("endNodeType"); String relation = jsonObject.getString("关系");
String property = object.getString("property"); String value = jsonObject.getString("客体");
String value = object.getString("value"); // 类型信息从notePrompt对象中获取
// 去空,如果存在任何的空值,则忽略 // String startNodeType = object.getString("startNodeType");
if (StrUtil.hasEmpty(startNodeType, entity, endNodeType, property, value)) { // String endNodeType = object.getString("endNodeType");
continue; // 去空,如果存在任何的空值,则忽略
} // if (StrUtil.hasEmpty(startNodeType, entity, endNodeType, property, value)) {
// 构建三元组信息 // continue;
return new TripleInfo(entity, property, value, caseId, recordId, recordSplitId, LocalDateTime.now(), startNodeType, endNodeType); // }
if (StrUtil.hasEmpty(entity, relation, value)) {
log.info("提取三元组信息出现空值,忽略,主体:{},关系:{},客体:{}", entity, relation, value);
return null;
} }
// 构建三元组信息
return new TripleInfo(entity, relation, value, caseId, recordId, recordSplitId, LocalDateTime.now(), prompt.getStartEntityType(), prompt.getEndEntityType());
//}
} catch (Exception e) { } catch (Exception e) {
log.error("提取三元组出现错误", e); log.error("提取三元组出现错误", e);
} }
@ -85,4 +96,6 @@ public class TripleExtractThread implements Callable<TripleInfo> {
} }
} }

@ -6,5 +6,5 @@ import java.util.concurrent.ExecutorService;
public class TripleExtractThreadPool { public class TripleExtractThreadPool {
public static final ExecutorService chatExecutor = ThreadUtil.newFixedExecutor(20, Integer.MAX_VALUE, "tripleExtract", false); public static final ExecutorService chatExecutor = ThreadUtil.newFixedExecutor(5, Integer.MAX_VALUE, "tripleExtract", false);
} }

@ -3,6 +3,7 @@ spring:
active: dev active: dev
main: main:
allow-bean-definition-overriding: true allow-bean-definition-overriding: true
allow-circular-references: true
mvc: mvc:
path match: path match:
matching-strategy: ANT_PATH_MATCHER matching-strategy: ANT_PATH_MATCHER

@ -8,13 +8,13 @@
select * from note_record_split nr select * from note_record_split nr
where record_type = #{recordType} where record_type = #{recordType}
</select> </select>
<select id="selectRecord" resultType="com.supervision.police.domain.NoteRecordSplit"> <!-- <select id="selectRecord" resultType="com.supervision.police.domain.NoteRecordSplit">-->
select nr.*, mrt.id as recordTypeId <!-- select nr.*, mrt.id as recordTypeId-->
from note_record_split nr <!-- from note_record_split nr-->
left join model_record_type mrt on nr.record_type = mrt.record_type <!-- left join model_record_type mrt on nr.record_type = mrt.record_type-->
where nr.case_id = #{caseId} and nr.person_name = #{name} <!-- where nr.case_id = #{caseId} and nr.person_name = #{name}-->
<if test="recordId != null and recordId != ''"> <!-- <if test="recordId != null and recordId != ''">-->
and nr.note_records_id = #{recordId} <!-- and nr.note_records_id = #{recordId}-->
</if> <!-- </if>-->
</select> <!-- </select>-->
</mapper> </mapper>

Loading…
Cancel
Save