提示词新增、修改、分页列表改造

topo_dev
DESKTOP-DDTUS3E\yaxin 7 months ago
parent 696affff79
commit 2061ac049d

@ -0,0 +1,25 @@
package com.supervision.common.constant;
public class NotePromptConstants {
/**
* -
*/
public static final String CASE_TYPE_ENGINEERING_CONTRACT_FRAUD = "1";
/**
* -
*/
public static final String TYPE_STRUCTURAL_REASONING = "1";
/**
* -
*/
public static final String TYPE_GRAPH_REASONING = "2";
/**
* -
*/
public static final String TYPE_CLASSIFICATION = "3";
}

@ -1,7 +1,9 @@
package com.supervision.police.controller;
import com.baomidou.mybatisplus.core.metadata.IPage;
import com.supervision.common.domain.R;
import com.supervision.police.domain.NotePrompt;
import com.supervision.police.dto.NotePromptDTO;
import com.supervision.police.service.NotePromptService;
import io.swagger.annotations.ApiOperation;
import io.swagger.v3.oas.annotations.Operation;
@ -9,8 +11,6 @@ import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.bind.annotation.*;
import java.util.List;
@RestController
@Slf4j
@RequestMapping("/prompt")
@ -24,18 +24,16 @@ public class PromptController {
/**
*
* @param caseType 1
* @param type 34
*
* @param notePrompt
* @return
*/
@GetMapping("/list")
@PostMapping("/list")
@Operation(summary = "查询提示词列表")
public R<List<NotePrompt>> listPrompt(@RequestParam(name = "caseType",defaultValue = "1") String caseType,@RequestParam("type") String type) {
NotePrompt notePrompt = new NotePrompt();
notePrompt.setId("04ac794c-6457-11ef-a77c-0242ac11000d");
notePrompt.setCaseType(caseType);
notePrompt.setType(type);
List<NotePrompt> notePrompts = promptService.listPrompt(notePrompt);
public R<IPage<NotePromptDTO>> listPrompt(@RequestBody NotePrompt notePrompt,
@RequestParam(required = false, defaultValue = "1") Integer page,
@RequestParam(required = false, defaultValue = "20") Integer size) {
IPage<NotePromptDTO> notePrompts = promptService.listPrompt(page, size, notePrompt);
return R.ok(notePrompts);
}
}

@ -0,0 +1,11 @@
package com.supervision.police.dto;
import com.supervision.police.domain.NotePrompt;
import lombok.Data;
import lombok.EqualsAndHashCode;
@EqualsAndHashCode(callSuper = true)
@Data
public class NotePromptDTO extends NotePrompt {
private Integer matchNum;
}

@ -1,12 +1,13 @@
package com.supervision.police.mapper;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.supervision.police.domain.NotePrompt;
import com.supervision.police.dto.NotePromptDTO;
import org.apache.ibatis.annotations.Param;
import java.util.List;
public interface NotePromptMapper extends BaseMapper<NotePrompt> {
Page<NotePromptDTO> selectNotePromptWithMatchNum(Page<NotePromptDTO> page, @Param("notePrompt") NotePrompt notePrompt);
}

@ -1,7 +1,9 @@
package com.supervision.police.service;
import com.baomidou.mybatisplus.core.metadata.IPage;
import com.baomidou.mybatisplus.extension.service.IService;
import com.supervision.police.domain.NotePrompt;
import com.supervision.police.dto.NotePromptDTO;
import java.util.List;
@ -10,5 +12,5 @@ public interface NotePromptService extends IService<NotePrompt> {
List<NotePrompt> listPromptBySplitId(String recordSplitId);
List<NotePrompt> listPrompt(NotePrompt notePrompt);
IPage<NotePromptDTO> listPrompt(int page, int size, NotePrompt notePrompt);
}

@ -8,6 +8,7 @@ import com.baomidou.mybatisplus.core.conditions.Wrapper;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.supervision.common.constant.NotePromptConstants;
import com.supervision.common.domain.R;
import com.supervision.config.BusinessException;
import com.supervision.neo4j.domain.CaseNode;
@ -21,14 +22,10 @@ import com.supervision.police.mapper.NoteRecordSplitMapper;
import com.supervision.police.service.*;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.util.StopWatch;
import java.util.*;
import java.util.function.Function;
@ -154,95 +151,99 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl<ModelRecordTypeMappe
@Override
@Transactional(transactionManager = "dataSourceTransactionManager", rollbackFor = Exception.class)
public R<?> addOrUpdPrompt(NotePrompt prompt) {
List<String> typeList = prompt.getTypeList();
if (CollUtil.isEmpty(typeList)) {
throw new RuntimeException("类型信息不能为空");
}
boolean save;
if (StringUtils.isEmpty(prompt.getId())) {
// 新增的时候,校验是否已经存在相同的三元组关系,如果已经存在了相同的三元组关系,不允许添加
checkHasSameTriple(prompt.getStartEntityType(), prompt.getRelType(), prompt.getEndEntityType(), null);
save = notePromptService.save(prompt);
// 新增prompt绑定的分类信息
for (String typeId : typeList) {
NotePromptTypeRel rel = new NotePromptTypeRel();
rel.setPromptId(prompt.getId());
rel.setTypeId(typeId);
notePromptTypeRelService.save(rel);
String type = prompt.getType();
if (NotePromptConstants.TYPE_GRAPH_REASONING.equals(type)) {
List<String> typeList = prompt.getTypeList();
if (CollUtil.isEmpty(typeList)) {
throw new RuntimeException("类型信息不能为空");
}
} else {
checkHasSameTriple(prompt.getStartEntityType(), prompt.getRelType(), prompt.getEndEntityType(), prompt.getId());
save = notePromptService.updateById(prompt);
// 更新prompt绑定的分类信息
// 首先查询已经有的,如果都存在,就不变,如果数据库有,前端没有,就删除,如果前端有,数据库没有,就新增
List<NotePromptTypeRel> existDatabaseRelList = notePromptTypeRelService.lambdaQuery().eq(NotePromptTypeRel::getPromptId, prompt.getId()).list();
if (CollUtil.isNotEmpty(existDatabaseRelList)) {
Set<String> existTypeList = existDatabaseRelList.stream().map(NotePromptTypeRel::getTypeId).collect(Collectors.toSet());
Set<String> frontRelIdList = new HashSet<>(typeList);
// 删除(数据库有,前端没有的)
List<String> deleteIdList = existTypeList.stream().filter(id -> !frontRelIdList.contains(id)).collect(Collectors.toList());
if (CollUtil.isNotEmpty(deleteIdList)) {
notePromptTypeRelService.lambdaUpdate().in(NotePromptTypeRel::getTypeId, deleteIdList).eq(NotePromptTypeRel::getPromptId, prompt.getId()).remove();
}
// 新增(前端有数据库没有的)
frontRelIdList.stream().filter(id -> !existTypeList.contains(id)).forEach(e -> {
NotePromptTypeRel rel = new NotePromptTypeRel();
rel.setPromptId(prompt.getId());
rel.setTypeId(e);
notePromptTypeRelService.save(rel);
});
} else {
// 如果数据库里面没查到,直接新增,一般不会走这一步
boolean save;
if (StringUtils.isEmpty(prompt.getId())) {
// 新增的时候,校验是否已经存在相同的三元组关系,如果已经存在了相同的三元组关系,不允许添加
checkHasSameTriple(prompt.getStartEntityType(), prompt.getRelType(), prompt.getEndEntityType(), null);
save = notePromptService.save(prompt);
// 新增prompt绑定的分类信息
for (String typeId : typeList) {
NotePromptTypeRel rel = new NotePromptTypeRel();
rel.setPromptId(prompt.getId());
rel.setTypeId(typeId);
notePromptTypeRelService.save(rel);
}
} else {
checkHasSameTriple(prompt.getStartEntityType(), prompt.getRelType(), prompt.getEndEntityType(), prompt.getId());
save = notePromptService.updateById(prompt);
// 更新prompt绑定的分类信息
// 首先查询已经有的,如果都存在,就不变,如果数据库有,前端没有,就删除,如果前端有,数据库没有,就新增
List<NotePromptTypeRel> existDatabaseRelList = notePromptTypeRelService.lambdaQuery().eq(NotePromptTypeRel::getPromptId, prompt.getId()).list();
if (CollUtil.isNotEmpty(existDatabaseRelList)) {
Set<String> existTypeList = existDatabaseRelList.stream().map(NotePromptTypeRel::getTypeId).collect(Collectors.toSet());
Set<String> frontRelIdList = new HashSet<>(typeList);
// 删除(数据库有,前端没有的)
List<String> deleteIdList = existTypeList.stream().filter(id -> !frontRelIdList.contains(id)).collect(Collectors.toList());
if (CollUtil.isNotEmpty(deleteIdList)) {
notePromptTypeRelService.lambdaUpdate().in(NotePromptTypeRel::getTypeId, deleteIdList).eq(NotePromptTypeRel::getPromptId, prompt.getId()).remove();
}
// 新增(前端有数据库没有的)
frontRelIdList.stream().filter(id -> !existTypeList.contains(id)).forEach(e -> {
NotePromptTypeRel rel = new NotePromptTypeRel();
rel.setPromptId(prompt.getId());
rel.setTypeId(e);
notePromptTypeRelService.save(rel);
});
} else {
// 如果数据库里面没查到,直接新增,一般不会走这一步
for (String typeId : typeList) {
NotePromptTypeRel rel = new NotePromptTypeRel();
rel.setPromptId(prompt.getId());
rel.setTypeId(typeId);
notePromptTypeRelService.save(rel);
}
}
}
}
// 更新类型字段
List<TripleInfoDTO> tripleList = prompt.getTripleList();
for (TripleInfoDTO dto : tripleList) {
if ("头节点".equals(dto.getType())) {
notePromptService.lambdaUpdate().set(NotePrompt::getStartEntityTemplate, dto.getTemplateName())
.set(NotePrompt::getStartEntityType, dto.getValue())
.eq(NotePrompt::getId, prompt.getId()).update();
} else if ("关系".equals(dto.getType())) {
notePromptService.lambdaUpdate().set(NotePrompt::getRelTemplate, dto.getTemplateName())
.set(NotePrompt::getRelType, dto.getValue())
.eq(NotePrompt::getId, prompt.getId()).update();
} else if ("尾节点".equals(dto.getType())) {
notePromptService.lambdaUpdate().set(NotePrompt::getEndEntityTemplate, dto.getTemplateName())
.set(NotePrompt::getEndEntityType, dto.getValue())
.eq(NotePrompt::getId, prompt.getId()).update();
// 更新类型字段
List<TripleInfoDTO> tripleList = prompt.getTripleList();
for (TripleInfoDTO dto : tripleList) {
if ("头节点".equals(dto.getType())) {
notePromptService.lambdaUpdate().set(NotePrompt::getStartEntityTemplate, dto.getTemplateName())
.set(NotePrompt::getStartEntityType, dto.getValue())
.eq(NotePrompt::getId, prompt.getId()).update();
} else if ("关系".equals(dto.getType())) {
notePromptService.lambdaUpdate().set(NotePrompt::getRelTemplate, dto.getTemplateName())
.set(NotePrompt::getRelType, dto.getValue())
.eq(NotePrompt::getId, prompt.getId()).update();
} else if ("尾节点".equals(dto.getType())) {
notePromptService.lambdaUpdate().set(NotePrompt::getEndEntityTemplate, dto.getTemplateName())
.set(NotePrompt::getEndEntityType, dto.getValue())
.eq(NotePrompt::getId, prompt.getId()).update();
}
}
}
// 获取所有的类型
List<ModelRecordType> modelRecordTypes = list();
// 根据提示词id获取类型和提示词的关系表
List<NotePromptTypeRel> relList = notePromptTypeRelService.list(new QueryWrapper<NotePromptTypeRel>().eq("prompt_id", prompt.getId()));
//根据typeId集合过滤出对应的modelRecordType的name
List<String> typeNames = modelRecordTypes.stream().filter(e -> relList.stream().map(NotePromptTypeRel::getTypeId).toList().contains(e.getId())).map(ModelRecordType::getRecordType).toList();
//根据typeNames模糊匹配查询note_record_split
List<NoteRecordSplit> noteRecordSplits = noteRecordSplitService.list().stream()
.filter(record -> record != null && record.getRecordType() != null && typeNames.stream().anyMatch(typeName -> Arrays.asList(record.getRecordType().split(",")).contains(typeName)))
.toList();
//过滤并去重涉及到的的note_record_id
Set<String> recordIds = noteRecordSplits.stream().map(NoteRecordSplit::getNoteRecordId).collect(Collectors.toSet());
//根据note_record_id更新note_record表的isPromptUpdate字段
log.info("开始更新笔录表提示词更新状态【is_prompt_update】涉及到的笔录有{}", recordIds);
boolean updated = noteRecordService.update(new UpdateWrapper<NoteRecord>().set("is_prompt_update", true).in("id", recordIds));
if (save && updated) {
return R.ok("保存成功");
// 获取所有的类型
List<ModelRecordType> modelRecordTypes = list();
// 根据提示词id获取类型和提示词的关系表
List<NotePromptTypeRel> relList = notePromptTypeRelService.list(new QueryWrapper<NotePromptTypeRel>().eq("prompt_id", prompt.getId()));
//根据typeId集合过滤出对应的modelRecordType的name
List<String> typeNames = modelRecordTypes.stream().filter(e -> relList.stream().map(NotePromptTypeRel::getTypeId).toList().contains(e.getId())).map(ModelRecordType::getRecordType).toList();
//根据typeNames模糊匹配查询note_record_split
List<NoteRecordSplit> noteRecordSplits = noteRecordSplitService.list().stream()
.filter(record -> record != null && record.getRecordType() != null && typeNames.stream().anyMatch(typeName -> Arrays.asList(record.getRecordType().split(",")).contains(typeName)))
.toList();
//过滤并去重涉及到的的note_record_id
Set<String> recordIds = noteRecordSplits.stream().map(NoteRecordSplit::getNoteRecordId).collect(Collectors.toSet());
//根据note_record_id更新note_record表的isPromptUpdate字段
log.info("开始更新笔录表提示词更新状态【is_prompt_update】涉及到的笔录有{}", recordIds);
boolean updated = noteRecordService.update(new UpdateWrapper<NoteRecord>().set("is_prompt_update", true).in("id", recordIds));
if (!save || !updated) {
return R.fail("保存失败");
}
} else {
return R.fail("保存失败");
notePromptService.saveOrUpdate(prompt);
}
return R.ok("保存成功");
}
private void checkHasSameTriple(String startEntityType, String relType, String endEntityType, String promptId) {
@ -252,6 +253,7 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl<ModelRecordTypeMappe
if (StrUtil.isBlank(promptId)) {
throw new RuntimeException("该三元组关系已经存在,请勿重复添加");
} else {
// 校验list查出来的是不是和promptId相等,如果不想等,也报错
if (!list.get(0).getId().equals(promptId)) {
throw new RuntimeException("该三元组关系已经存在,请勿重复添加");
@ -307,7 +309,7 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl<ModelRecordTypeMappe
// 更新为1执行中
throw new BusinessException("笔录解析任务未完成,请等待");
}
}else {
} else {
throw new BusinessException("请先进行笔录提取");
}
// 这里进行查询

@ -2,11 +2,14 @@ package com.supervision.police.service.impl;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.StrUtil;
import com.baomidou.mybatisplus.core.metadata.IPage;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.supervision.police.domain.ModelRecordType;
import com.supervision.police.domain.NotePrompt;
import com.supervision.police.domain.NotePromptTypeRel;
import com.supervision.police.domain.NoteRecordSplit;
import com.supervision.police.dto.NotePromptDTO;
import com.supervision.police.mapper.NotePromptMapper;
import com.supervision.police.service.ModelRecordTypeService;
import com.supervision.police.service.NotePromptService;
@ -32,11 +35,14 @@ public class NotePromptServiceImpl extends ServiceImpl<NotePromptMapper, NotePro
@Autowired
private ModelRecordTypeService modelRecordTypeService;
@Autowired
private NotePromptMapper notePromptMapper;
private final NotePromptTypeRelService notePromptTypeRelService;
@Override
public List<NotePrompt> listPromptBySplitId(String recordSplitId){
public List<NotePrompt> listPromptBySplitId(String recordSplitId) {
List<NotePrompt> notePromptList = new ArrayList<>();
// 首先获取所有切分后的笔录
@ -62,7 +68,7 @@ public class NotePromptServiceImpl extends ServiceImpl<NotePromptMapper, NotePro
for (String typeName : split) {
String typeId = allTypeMap.get(typeName);
if (StrUtil.isBlank(typeId)) {
log.info("listPromptBySplitId:笔录片段id:{} typeName:{}未在全局分类中找到数据...", recordSplit.getId(),typeName);
log.info("listPromptBySplitId:笔录片段id:{} typeName:{}未在全局分类中找到数据...", recordSplit.getId(), typeName);
continue;
}
@ -70,14 +76,14 @@ public class NotePromptServiceImpl extends ServiceImpl<NotePromptMapper, NotePro
// 一个提示词可能关联多个类型,要进行拆分操作
List<NotePromptTypeRel> promptTypeRelList = notePromptTypeRelService.lambdaQuery().eq(NotePromptTypeRel::getTypeId, typeId).select(NotePromptTypeRel::getPromptId).list();
if (CollUtil.isEmpty(promptTypeRelList)) {
log.info("listPromptBySplitId:笔录片段:{}根据typeId:{},typeName:{},未找到对应的提示词信息...", recordSplit.getId(), typeId,typeName);
log.info("listPromptBySplitId:笔录片段:{}根据typeId:{},typeName:{},未找到对应的提示词信息...", recordSplit.getId(), typeId, typeName);
continue;
}
List<String> promptIdList = promptTypeRelList.stream().map(NotePromptTypeRel::getPromptId).toList();
List<NotePrompt> list = super.lambdaQuery().in(NotePrompt::getId, promptIdList).list();
if (CollUtil.isEmpty(list)){
log.info("listPromptBySplitId:根据 promptIdList:{},未找到对应的提示词信息...",CollUtil.join(promptIdList,","));
if (CollUtil.isEmpty(list)) {
log.info("listPromptBySplitId:根据 promptIdList:{},未找到对应的提示词信息...", CollUtil.join(promptIdList, ","));
continue;
}
notePromptList.addAll(list);
@ -87,8 +93,7 @@ public class NotePromptServiceImpl extends ServiceImpl<NotePromptMapper, NotePro
}
@Override
public List<NotePrompt> listPrompt(NotePrompt notePrompt) {
return super.lambdaQuery(notePrompt).list();
public IPage<NotePromptDTO> listPrompt(int page, int size, NotePrompt notePrompt) {
return notePromptMapper.selectNotePromptWithMatchNum(new Page<>(page, size), notePrompt);
}
}

@ -3,4 +3,20 @@
PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN"
"http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<mapper namespace="com.supervision.police.mapper.NotePromptMapper">
<select id="selectNotePromptWithMatchNum" resultType="com.supervision.police.dto.NotePromptDTO">
SELECT
np.*,
(SELECT COUNT(*) FROM model_atomic_index mai WHERE mai.prompt_id = np.id) AS match_num
FROM
note_prompt np
WHERE 1=1
<if test="notePrompt.name != null and notePrompt.name != ''">
AND np.name LIKE CONCAT('%', #{notePrompt.name}, '%')
</if>
<if test="notePrompt.type != null and notePrompt.type != ''">
AND np.type = #{notePrompt.type}
</if>
ORDER BY np.create_time DESC
</select>
</mapper>
Loading…
Cancel
Save