|
|
|
@ -8,6 +8,7 @@ import com.baomidou.mybatisplus.core.conditions.Wrapper;
|
|
|
|
|
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
|
|
|
|
|
import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper;
|
|
|
|
|
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
|
|
|
|
|
import com.supervision.common.constant.NotePromptConstants;
|
|
|
|
|
import com.supervision.common.domain.R;
|
|
|
|
|
import com.supervision.config.BusinessException;
|
|
|
|
|
import com.supervision.neo4j.domain.CaseNode;
|
|
|
|
@ -21,14 +22,10 @@ import com.supervision.police.mapper.NoteRecordSplitMapper;
|
|
|
|
|
import com.supervision.police.service.*;
|
|
|
|
|
import lombok.RequiredArgsConstructor;
|
|
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
|
|
|
import org.springframework.ai.chat.ChatResponse;
|
|
|
|
|
import org.springframework.ai.chat.messages.UserMessage;
|
|
|
|
|
import org.springframework.ai.chat.prompt.Prompt;
|
|
|
|
|
import org.springframework.ai.ollama.OllamaChatClient;
|
|
|
|
|
import org.springframework.beans.factory.annotation.Autowired;
|
|
|
|
|
import org.springframework.stereotype.Service;
|
|
|
|
|
import org.springframework.transaction.annotation.Transactional;
|
|
|
|
|
import org.springframework.util.StopWatch;
|
|
|
|
|
|
|
|
|
|
import java.util.*;
|
|
|
|
|
import java.util.function.Function;
|
|
|
|
@ -154,95 +151,99 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl<ModelRecordTypeMappe
|
|
|
|
|
@Override
|
|
|
|
|
@Transactional(transactionManager = "dataSourceTransactionManager", rollbackFor = Exception.class)
|
|
|
|
|
public R<?> addOrUpdPrompt(NotePrompt prompt) {
|
|
|
|
|
List<String> typeList = prompt.getTypeList();
|
|
|
|
|
if (CollUtil.isEmpty(typeList)) {
|
|
|
|
|
throw new RuntimeException("类型信息不能为空");
|
|
|
|
|
}
|
|
|
|
|
boolean save;
|
|
|
|
|
if (StringUtils.isEmpty(prompt.getId())) {
|
|
|
|
|
// 新增的时候,校验是否已经存在相同的三元组关系,如果已经存在了相同的三元组关系,不允许添加
|
|
|
|
|
checkHasSameTriple(prompt.getStartEntityType(), prompt.getRelType(), prompt.getEndEntityType(), null);
|
|
|
|
|
save = notePromptService.save(prompt);
|
|
|
|
|
// 新增prompt绑定的分类信息
|
|
|
|
|
for (String typeId : typeList) {
|
|
|
|
|
NotePromptTypeRel rel = new NotePromptTypeRel();
|
|
|
|
|
rel.setPromptId(prompt.getId());
|
|
|
|
|
rel.setTypeId(typeId);
|
|
|
|
|
notePromptTypeRelService.save(rel);
|
|
|
|
|
String type = prompt.getType();
|
|
|
|
|
if (NotePromptConstants.TYPE_GRAPH_REASONING.equals(type)) {
|
|
|
|
|
List<String> typeList = prompt.getTypeList();
|
|
|
|
|
if (CollUtil.isEmpty(typeList)) {
|
|
|
|
|
throw new RuntimeException("类型信息不能为空");
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
checkHasSameTriple(prompt.getStartEntityType(), prompt.getRelType(), prompt.getEndEntityType(), prompt.getId());
|
|
|
|
|
save = notePromptService.updateById(prompt);
|
|
|
|
|
// 更新prompt绑定的分类信息
|
|
|
|
|
// 首先查询已经有的,如果都存在,就不变,如果数据库有,前端没有,就删除,如果前端有,数据库没有,就新增
|
|
|
|
|
List<NotePromptTypeRel> existDatabaseRelList = notePromptTypeRelService.lambdaQuery().eq(NotePromptTypeRel::getPromptId, prompt.getId()).list();
|
|
|
|
|
|
|
|
|
|
if (CollUtil.isNotEmpty(existDatabaseRelList)) {
|
|
|
|
|
Set<String> existTypeList = existDatabaseRelList.stream().map(NotePromptTypeRel::getTypeId).collect(Collectors.toSet());
|
|
|
|
|
Set<String> frontRelIdList = new HashSet<>(typeList);
|
|
|
|
|
// 删除(数据库有,前端没有的)
|
|
|
|
|
List<String> deleteIdList = existTypeList.stream().filter(id -> !frontRelIdList.contains(id)).collect(Collectors.toList());
|
|
|
|
|
if (CollUtil.isNotEmpty(deleteIdList)) {
|
|
|
|
|
notePromptTypeRelService.lambdaUpdate().in(NotePromptTypeRel::getTypeId, deleteIdList).eq(NotePromptTypeRel::getPromptId, prompt.getId()).remove();
|
|
|
|
|
}
|
|
|
|
|
// 新增(前端有数据库没有的)
|
|
|
|
|
frontRelIdList.stream().filter(id -> !existTypeList.contains(id)).forEach(e -> {
|
|
|
|
|
NotePromptTypeRel rel = new NotePromptTypeRel();
|
|
|
|
|
rel.setPromptId(prompt.getId());
|
|
|
|
|
rel.setTypeId(e);
|
|
|
|
|
notePromptTypeRelService.save(rel);
|
|
|
|
|
});
|
|
|
|
|
} else {
|
|
|
|
|
// 如果数据库里面没查到,直接新增,一般不会走这一步
|
|
|
|
|
boolean save;
|
|
|
|
|
if (StringUtils.isEmpty(prompt.getId())) {
|
|
|
|
|
// 新增的时候,校验是否已经存在相同的三元组关系,如果已经存在了相同的三元组关系,不允许添加
|
|
|
|
|
checkHasSameTriple(prompt.getStartEntityType(), prompt.getRelType(), prompt.getEndEntityType(), null);
|
|
|
|
|
save = notePromptService.save(prompt);
|
|
|
|
|
// 新增prompt绑定的分类信息
|
|
|
|
|
for (String typeId : typeList) {
|
|
|
|
|
NotePromptTypeRel rel = new NotePromptTypeRel();
|
|
|
|
|
rel.setPromptId(prompt.getId());
|
|
|
|
|
rel.setTypeId(typeId);
|
|
|
|
|
notePromptTypeRelService.save(rel);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
checkHasSameTriple(prompt.getStartEntityType(), prompt.getRelType(), prompt.getEndEntityType(), prompt.getId());
|
|
|
|
|
save = notePromptService.updateById(prompt);
|
|
|
|
|
// 更新prompt绑定的分类信息
|
|
|
|
|
// 首先查询已经有的,如果都存在,就不变,如果数据库有,前端没有,就删除,如果前端有,数据库没有,就新增
|
|
|
|
|
List<NotePromptTypeRel> existDatabaseRelList = notePromptTypeRelService.lambdaQuery().eq(NotePromptTypeRel::getPromptId, prompt.getId()).list();
|
|
|
|
|
|
|
|
|
|
if (CollUtil.isNotEmpty(existDatabaseRelList)) {
|
|
|
|
|
Set<String> existTypeList = existDatabaseRelList.stream().map(NotePromptTypeRel::getTypeId).collect(Collectors.toSet());
|
|
|
|
|
Set<String> frontRelIdList = new HashSet<>(typeList);
|
|
|
|
|
// 删除(数据库有,前端没有的)
|
|
|
|
|
List<String> deleteIdList = existTypeList.stream().filter(id -> !frontRelIdList.contains(id)).collect(Collectors.toList());
|
|
|
|
|
if (CollUtil.isNotEmpty(deleteIdList)) {
|
|
|
|
|
notePromptTypeRelService.lambdaUpdate().in(NotePromptTypeRel::getTypeId, deleteIdList).eq(NotePromptTypeRel::getPromptId, prompt.getId()).remove();
|
|
|
|
|
}
|
|
|
|
|
// 新增(前端有数据库没有的)
|
|
|
|
|
frontRelIdList.stream().filter(id -> !existTypeList.contains(id)).forEach(e -> {
|
|
|
|
|
NotePromptTypeRel rel = new NotePromptTypeRel();
|
|
|
|
|
rel.setPromptId(prompt.getId());
|
|
|
|
|
rel.setTypeId(e);
|
|
|
|
|
notePromptTypeRelService.save(rel);
|
|
|
|
|
});
|
|
|
|
|
} else {
|
|
|
|
|
// 如果数据库里面没查到,直接新增,一般不会走这一步
|
|
|
|
|
for (String typeId : typeList) {
|
|
|
|
|
NotePromptTypeRel rel = new NotePromptTypeRel();
|
|
|
|
|
rel.setPromptId(prompt.getId());
|
|
|
|
|
rel.setTypeId(typeId);
|
|
|
|
|
notePromptTypeRelService.save(rel);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// 更新类型字段
|
|
|
|
|
List<TripleInfoDTO> tripleList = prompt.getTripleList();
|
|
|
|
|
for (TripleInfoDTO dto : tripleList) {
|
|
|
|
|
if ("头节点".equals(dto.getType())) {
|
|
|
|
|
notePromptService.lambdaUpdate().set(NotePrompt::getStartEntityTemplate, dto.getTemplateName())
|
|
|
|
|
.set(NotePrompt::getStartEntityType, dto.getValue())
|
|
|
|
|
.eq(NotePrompt::getId, prompt.getId()).update();
|
|
|
|
|
} else if ("关系".equals(dto.getType())) {
|
|
|
|
|
notePromptService.lambdaUpdate().set(NotePrompt::getRelTemplate, dto.getTemplateName())
|
|
|
|
|
.set(NotePrompt::getRelType, dto.getValue())
|
|
|
|
|
.eq(NotePrompt::getId, prompt.getId()).update();
|
|
|
|
|
} else if ("尾节点".equals(dto.getType())) {
|
|
|
|
|
notePromptService.lambdaUpdate().set(NotePrompt::getEndEntityTemplate, dto.getTemplateName())
|
|
|
|
|
.set(NotePrompt::getEndEntityType, dto.getValue())
|
|
|
|
|
.eq(NotePrompt::getId, prompt.getId()).update();
|
|
|
|
|
// 更新类型字段
|
|
|
|
|
List<TripleInfoDTO> tripleList = prompt.getTripleList();
|
|
|
|
|
for (TripleInfoDTO dto : tripleList) {
|
|
|
|
|
if ("头节点".equals(dto.getType())) {
|
|
|
|
|
notePromptService.lambdaUpdate().set(NotePrompt::getStartEntityTemplate, dto.getTemplateName())
|
|
|
|
|
.set(NotePrompt::getStartEntityType, dto.getValue())
|
|
|
|
|
.eq(NotePrompt::getId, prompt.getId()).update();
|
|
|
|
|
} else if ("关系".equals(dto.getType())) {
|
|
|
|
|
notePromptService.lambdaUpdate().set(NotePrompt::getRelTemplate, dto.getTemplateName())
|
|
|
|
|
.set(NotePrompt::getRelType, dto.getValue())
|
|
|
|
|
.eq(NotePrompt::getId, prompt.getId()).update();
|
|
|
|
|
} else if ("尾节点".equals(dto.getType())) {
|
|
|
|
|
notePromptService.lambdaUpdate().set(NotePrompt::getEndEntityTemplate, dto.getTemplateName())
|
|
|
|
|
.set(NotePrompt::getEndEntityType, dto.getValue())
|
|
|
|
|
.eq(NotePrompt::getId, prompt.getId()).update();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 获取所有的类型
|
|
|
|
|
List<ModelRecordType> modelRecordTypes = list();
|
|
|
|
|
// 根据提示词id获取类型和提示词的关系表
|
|
|
|
|
List<NotePromptTypeRel> relList = notePromptTypeRelService.list(new QueryWrapper<NotePromptTypeRel>().eq("prompt_id", prompt.getId()));
|
|
|
|
|
//根据typeId集合过滤出对应的modelRecordType的name
|
|
|
|
|
List<String> typeNames = modelRecordTypes.stream().filter(e -> relList.stream().map(NotePromptTypeRel::getTypeId).toList().contains(e.getId())).map(ModelRecordType::getRecordType).toList();
|
|
|
|
|
//根据typeNames模糊匹配查询note_record_split
|
|
|
|
|
List<NoteRecordSplit> noteRecordSplits = noteRecordSplitService.list().stream()
|
|
|
|
|
.filter(record -> record != null && record.getRecordType() != null && typeNames.stream().anyMatch(typeName -> Arrays.asList(record.getRecordType().split(",")).contains(typeName)))
|
|
|
|
|
.toList();
|
|
|
|
|
//过滤并去重涉及到的的note_record_id
|
|
|
|
|
Set<String> recordIds = noteRecordSplits.stream().map(NoteRecordSplit::getNoteRecordId).collect(Collectors.toSet());
|
|
|
|
|
//根据note_record_id更新note_record表的isPromptUpdate字段
|
|
|
|
|
log.info("开始更新笔录表提示词更新状态【is_prompt_update】,涉及到的笔录有:{}", recordIds);
|
|
|
|
|
boolean updated = noteRecordService.update(new UpdateWrapper<NoteRecord>().set("is_prompt_update", true).in("id", recordIds));
|
|
|
|
|
|
|
|
|
|
if (save && updated) {
|
|
|
|
|
return R.ok("保存成功");
|
|
|
|
|
// 获取所有的类型
|
|
|
|
|
List<ModelRecordType> modelRecordTypes = list();
|
|
|
|
|
// 根据提示词id获取类型和提示词的关系表
|
|
|
|
|
List<NotePromptTypeRel> relList = notePromptTypeRelService.list(new QueryWrapper<NotePromptTypeRel>().eq("prompt_id", prompt.getId()));
|
|
|
|
|
//根据typeId集合过滤出对应的modelRecordType的name
|
|
|
|
|
List<String> typeNames = modelRecordTypes.stream().filter(e -> relList.stream().map(NotePromptTypeRel::getTypeId).toList().contains(e.getId())).map(ModelRecordType::getRecordType).toList();
|
|
|
|
|
//根据typeNames模糊匹配查询note_record_split
|
|
|
|
|
List<NoteRecordSplit> noteRecordSplits = noteRecordSplitService.list().stream()
|
|
|
|
|
.filter(record -> record != null && record.getRecordType() != null && typeNames.stream().anyMatch(typeName -> Arrays.asList(record.getRecordType().split(",")).contains(typeName)))
|
|
|
|
|
.toList();
|
|
|
|
|
//过滤并去重涉及到的的note_record_id
|
|
|
|
|
Set<String> recordIds = noteRecordSplits.stream().map(NoteRecordSplit::getNoteRecordId).collect(Collectors.toSet());
|
|
|
|
|
//根据note_record_id更新note_record表的isPromptUpdate字段
|
|
|
|
|
log.info("开始更新笔录表提示词更新状态【is_prompt_update】,涉及到的笔录有:{}", recordIds);
|
|
|
|
|
boolean updated = noteRecordService.update(new UpdateWrapper<NoteRecord>().set("is_prompt_update", true).in("id", recordIds));
|
|
|
|
|
if (!save || !updated) {
|
|
|
|
|
return R.fail("保存失败");
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
return R.fail("保存失败");
|
|
|
|
|
notePromptService.saveOrUpdate(prompt);
|
|
|
|
|
}
|
|
|
|
|
return R.ok("保存成功");
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private void checkHasSameTriple(String startEntityType, String relType, String endEntityType, String promptId) {
|
|
|
|
@ -252,6 +253,7 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl<ModelRecordTypeMappe
|
|
|
|
|
if (StrUtil.isBlank(promptId)) {
|
|
|
|
|
throw new RuntimeException("该三元组关系已经存在,请勿重复添加");
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
|
|
// 校验list查出来的是不是和promptId相等,如果不想等,也报错
|
|
|
|
|
if (!list.get(0).getId().equals(promptId)) {
|
|
|
|
|
throw new RuntimeException("该三元组关系已经存在,请勿重复添加");
|
|
|
|
@ -307,7 +309,7 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl<ModelRecordTypeMappe
|
|
|
|
|
// 更新为1执行中
|
|
|
|
|
throw new BusinessException("笔录解析任务未完成,请等待");
|
|
|
|
|
}
|
|
|
|
|
}else {
|
|
|
|
|
} else {
|
|
|
|
|
throw new BusinessException("请先进行笔录提取");
|
|
|
|
|
}
|
|
|
|
|
// 这里进行查询
|
|
|
|
|