提示词配置代码优化

topo_dev
liu 9 months ago
parent 4c316bdc60
commit 6ee70ac254

@ -6,6 +6,7 @@ import com.supervision.police.domain.NotePrompt;
import com.supervision.police.domain.NoteRecord;
import com.supervision.police.domain.TripleInfo;
import com.supervision.police.dto.ListDTO;
import com.supervision.police.dto.TypeDTO;
import com.supervision.police.service.ModelRecordTypeService;
import com.supervision.police.service.NoteRecordSplitService;
import io.swagger.annotations.ApiOperation;
@ -44,10 +45,10 @@ public class RecordController {
return R.ok(modelRecordTypeService.queryType(name, page, size));
}
// @PostMapping("saveType")
// public R<?> saveType(@RequestBody ModelRecordType type) {
// return modelRecordTypeService.saveType(type);
// }
@GetMapping("queryTypeListChoose")
public R<List<TypeDTO>> queryTypeListChoose() {
return R.ok(modelRecordTypeService.queryTypeListChoose());
}
/**
*
@ -88,7 +89,7 @@ public class RecordController {
}
@GetMapping("testExtractThreeInfo")
public void testExtractThreeInfo(){
public void testExtractThreeInfo() {
modelRecordTypeService.testExtractThreeInfo();
}
@ -138,5 +139,4 @@ public class RecordController {
}
}

@ -6,6 +6,7 @@ import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import com.fasterxml.jackson.annotation.JsonFormat;
import com.supervision.police.dto.TripleInfoDTO;
import com.supervision.police.dto.TypeDTO;
import lombok.Data;
import java.io.Serializable;
@ -25,7 +26,12 @@ public class NotePrompt implements Serializable {
/**
* id
*/
private String typeId;
// private String typeId;
/**
* prompt
*/
@TableField(exist = false)
private List<String> typeList;
/**
*

@ -0,0 +1,15 @@
package com.supervision.police.dto;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data
@AllArgsConstructor
@NoArgsConstructor
public class TypeDTO {
private String id;
private String recordType;
}

@ -5,6 +5,7 @@ import com.supervision.common.domain.R;
import com.supervision.police.domain.ModelRecordType;
import com.supervision.police.domain.NotePrompt;
import com.supervision.police.domain.TripleInfo;
import com.supervision.police.dto.TypeDTO;
import java.util.List;
@ -12,6 +13,8 @@ public interface ModelRecordTypeService extends IService<ModelRecordType> {
List<ModelRecordType> queryType(String name, Integer page, Integer size);
List<TypeDTO> queryTypeListChoose();
ModelRecordType queryByName(String content);
R<?> saveType(ModelRecordType type);

@ -51,9 +51,11 @@ public class ExtractTripleInfoServiceImpl implements ExtractTripleInfoService {
@Autowired
private NoteRecordSplitService noteRecordSplitService;
private final NotePromptTypeRelService notePromptTypeRelService;
@Async
@Transactional(transactionManager = "dataSourceTransactionManager",rollbackFor = Exception.class)
@Transactional(transactionManager = "dataSourceTransactionManager", rollbackFor = Exception.class)
public void extractTripleInfo(String caseId, String name, String recordId) {
// 首先获取所有切分后的笔录
List<NoteRecordSplit> recordSplitList = noteRecordSplitService.lambdaQuery().eq(StrUtil.isNotBlank(recordId), NoteRecordSplit::getNoteRecordsId, recordId)
@ -78,7 +80,14 @@ public class ExtractTripleInfoServiceImpl implements ExtractTripleInfoService {
} else {
// 根据笔录类型找到所有的提取三元组的提示词
// 一个提示词可能关联多个类型,要进行拆分操作
List<NotePrompt> prompts = notePromptService.lambdaQuery().eq(NotePrompt::getTypeId, typeId).list();
List<NotePromptTypeRel> promptTypeRelList = notePromptTypeRelService.lambdaQuery().eq(NotePromptTypeRel::getTypeId, typeId).select(NotePromptTypeRel::getPromptId).list();
if (CollUtil.isEmpty(promptTypeRelList)) {
log.info("{} 切分笔录类型:{}无对应的提示词,跳过", recordSplit.getId(), typeName);
continue;
}
List<NotePrompt> prompts = notePromptService.lambdaQuery()
.in(NotePrompt::getId, promptTypeRelList.stream().map(NotePromptTypeRel::getPromptId).collect(Collectors.toSet()))
.list();
if (CollUtil.isEmpty(prompts)) {
log.info("{} 切分笔录类型:{}无对应的提示词,跳过", recordSplit.getId(), typeName);
} else {

@ -1,5 +1,6 @@
package com.supervision.police.service.impl;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.StrUtil;
import com.alibaba.druid.util.StringUtils;
import com.baomidou.mybatisplus.core.conditions.Wrapper;
@ -12,6 +13,7 @@ import com.supervision.neo4j.domain.Rel;
import com.supervision.neo4j.service.Neo4jService;
import com.supervision.police.domain.*;
import com.supervision.police.dto.TripleInfoDTO;
import com.supervision.police.dto.TypeDTO;
import com.supervision.police.mapper.ModelRecordTypeMapper;
import com.supervision.police.mapper.NoteRecordSplitMapper;
import com.supervision.police.service.*;
@ -27,9 +29,9 @@ import org.springframework.transaction.annotation.Transactional;
import org.springframework.util.StopWatch;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.*;
import java.util.function.Function;
import java.util.stream.Collectors;
@Slf4j
@Service
@ -50,24 +52,48 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl<ModelRecordTypeMappe
private final CaseTaskRecordService caseTaskRecordService;
private final NotePromptTypeRelService notePromptTypeRelService;
@Autowired
private ExtractTripleInfoService extractTripleInfo;
@Override
public List<TypeDTO> queryTypeListChoose() {
List<ModelRecordType> list = modelRecordTypeMapper.selectByName(null);
return list.stream().map(e -> new TypeDTO(e.getId(),e.getRecordType())).collect(Collectors.toList());
}
@Override
public List<ModelRecordType> queryType(String name, Integer page, Integer size) {
List<ModelRecordType> list = modelRecordTypeMapper.selectByName(name);
Map<String, TypeDTO> modelRecordTypeMap = list.stream().map(e -> new TypeDTO(e.getId(),e.getRecordType())).collect(Collectors.toMap(TypeDTO::getId, Function.identity()));
// 获取类型和三元组提示词的关系表
List<NotePromptTypeRel> relList = notePromptTypeRelService.lambdaQuery().list();
Map<String, List<NotePromptTypeRel>> typeIdMap = relList.stream().collect(Collectors.groupingBy(NotePromptTypeRel::getTypeId));
Map<String, List<NotePromptTypeRel>> promptIdMap = relList.stream().collect(Collectors.groupingBy(NotePromptTypeRel::getPromptId));
for (ModelRecordType modelRecordType : list) {
//笔录内容
List<NoteRecordSplit> noteRecords = noteRecordSplitMapper.selectByRecordType(modelRecordType.getRecordType());
modelRecordType.setRecords(noteRecords);
// grideOptions
//提示词
List<NotePrompt> prompts = notePromptService.lambdaQuery().eq(NotePrompt::getTypeId, modelRecordType.getId()).list();
for (NotePrompt prompt : prompts) {
prompt.setTripleList(buildTripleInfo(prompt));
// 根据类型表获取所有的prompt
List<NotePromptTypeRel> promptRelList = typeIdMap.get(modelRecordType.getId());
if (CollUtil.isNotEmpty(promptRelList)) {
Set<String> promptIdSet = promptRelList.stream().map(NotePromptTypeRel::getPromptId).collect(Collectors.toSet());
List<NotePrompt> prompts = Optional.ofNullable(notePromptService.listByIds(promptIdSet)).orElse(new ArrayList<>());
for (NotePrompt prompt : prompts) {
List<String> modelRecordTypes = new ArrayList<>();
List<NotePromptTypeRel> promptTypeRelList = promptIdMap.get(prompt.getId());
for (NotePromptTypeRel notePromptTypeRel : promptTypeRelList) {
TypeDTO existType = modelRecordTypeMap.get(notePromptTypeRel.getTypeId());
modelRecordTypes.add(existType.getId());
}
prompt.setTypeList(modelRecordTypes);
prompt.setTripleList(buildTripleInfo(prompt));
}
modelRecordType.setPrompts(prompts);
}
modelRecordType.setPrompts(prompts);
}
return list;
}
@ -116,15 +142,56 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl<ModelRecordTypeMappe
}
@Override
@Transactional(transactionManager = "dataSourceTransactionManager",rollbackFor = Exception.class)
@Transactional(transactionManager = "dataSourceTransactionManager", rollbackFor = Exception.class)
public R<?> addOrUpdPrompt(NotePrompt prompt) {
List<String> typeList = prompt.getTypeList();
if (CollUtil.isEmpty(typeList)) {
throw new RuntimeException("类型信息不能为空");
}
int i = 0;
boolean save;
if (StringUtils.isEmpty(prompt.getId())) {
save = notePromptService.save(prompt);
// 新增prompt绑定的分类信息
for (String typeId : typeList) {
NotePromptTypeRel rel = new NotePromptTypeRel();
rel.setPromptId(prompt.getId());
rel.setTypeId(typeId);
notePromptTypeRelService.save(rel);
}
} else {
save = notePromptService.updateById(prompt);
// 更新prompt绑定的分类信息
// 首先查询已经有的,如果都存在,就不变,如果数据库有,前端没有,就删除,如果前端有,数据库没有,就新增
List<NotePromptTypeRel> existDatabaseRelList = notePromptTypeRelService.lambdaQuery().eq(NotePromptTypeRel::getPromptId, prompt.getId()).list();
if (CollUtil.isNotEmpty(existDatabaseRelList)) {
Set<String> existTypeList = existDatabaseRelList.stream().map(NotePromptTypeRel::getTypeId).collect(Collectors.toSet());
Set<String> frontRelIdList = new HashSet<>(typeList);
// 删除(数据库有,前端没有的)
List<String> deleteIdList = existTypeList.stream().filter(id -> !frontRelIdList.contains(id)).collect(Collectors.toList());
if (CollUtil.isNotEmpty(deleteIdList)) {
notePromptTypeRelService.removeByIds(deleteIdList);
}
// 新增(前端有数据库没有的)
frontRelIdList.stream().filter(id -> !existTypeList.contains(id)).forEach(e ->{
NotePromptTypeRel rel = new NotePromptTypeRel();
rel.setPromptId(prompt.getId());
rel.setTypeId(e);
notePromptTypeRelService.save(rel);
});
}else {
// 如果数据库里面没查到,直接新增,一般不会走这一步
for (String typeId : typeList) {
NotePromptTypeRel rel = new NotePromptTypeRel();
rel.setPromptId(prompt.getId());
rel.setTypeId(typeId);
notePromptTypeRelService.save(rel);
}
}
}
// 更新类型字段
List<TripleInfoDTO> tripleList = prompt.getTripleList();
for (TripleInfoDTO dto : tripleList) {
@ -132,11 +199,11 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl<ModelRecordTypeMappe
notePromptService.lambdaUpdate().set(NotePrompt::getStartEntityTemplate, dto.getTemplateName())
.set(NotePrompt::getStartEntityType, dto.getValue())
.eq(NotePrompt::getId, prompt.getId()).update();
}else if ("关系".equals(dto.getType())){
} else if ("关系".equals(dto.getType())) {
notePromptService.lambdaUpdate().set(NotePrompt::getRelTemplate, dto.getTemplateName())
.set(NotePrompt::getRelType, dto.getValue())
.eq(NotePrompt::getId, prompt.getId()).update();
}else if ("尾节点".equals(dto.getType())){
} else if ("尾节点".equals(dto.getType())) {
notePromptService.lambdaUpdate().set(NotePrompt::getEndEntityTemplate, dto.getTemplateName())
.set(NotePrompt::getEndEntityType, dto.getValue())
.eq(NotePrompt::getId, prompt.getId()).update();
@ -154,6 +221,8 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl<ModelRecordTypeMappe
public R<?> delPrompt(NotePrompt prompt) {
String id = prompt.getId();
boolean removeById = notePromptService.removeById(id);
// 删除同时从所有分类里面都删除掉
notePromptTypeRelService.lambdaUpdate().eq(NotePromptTypeRel::getPromptId, id).remove();
if (removeById) {
return R.ok("删除成功");
} else {

Loading…
Cancel
Save