diff --git a/src/main/java/com/supervision/police/controller/RecordController.java b/src/main/java/com/supervision/police/controller/RecordController.java index c19d86a..0f7a8ee 100644 --- a/src/main/java/com/supervision/police/controller/RecordController.java +++ b/src/main/java/com/supervision/police/controller/RecordController.java @@ -6,6 +6,7 @@ import com.supervision.police.domain.NotePrompt; import com.supervision.police.domain.NoteRecord; import com.supervision.police.domain.TripleInfo; import com.supervision.police.dto.ListDTO; +import com.supervision.police.dto.TypeDTO; import com.supervision.police.service.ModelRecordTypeService; import com.supervision.police.service.NoteRecordSplitService; import io.swagger.annotations.ApiOperation; @@ -44,10 +45,10 @@ public class RecordController { return R.ok(modelRecordTypeService.queryType(name, page, size)); } -// @PostMapping("saveType") -// public R saveType(@RequestBody ModelRecordType type) { -// return modelRecordTypeService.saveType(type); -// } + @GetMapping("queryTypeListChoose") + public R> queryTypeListChoose() { + return R.ok(modelRecordTypeService.queryTypeListChoose()); + } /** * 保存提示词 @@ -88,7 +89,7 @@ public class RecordController { } @GetMapping("testExtractThreeInfo") - public void testExtractThreeInfo(){ + public void testExtractThreeInfo() { modelRecordTypeService.testExtractThreeInfo(); } @@ -138,5 +139,4 @@ public class RecordController { } - } diff --git a/src/main/java/com/supervision/police/domain/NotePrompt.java b/src/main/java/com/supervision/police/domain/NotePrompt.java index 96881f8..ff09ed7 100644 --- a/src/main/java/com/supervision/police/domain/NotePrompt.java +++ b/src/main/java/com/supervision/police/domain/NotePrompt.java @@ -6,6 +6,7 @@ import com.baomidou.mybatisplus.annotation.TableId; import com.baomidou.mybatisplus.annotation.TableName; import com.fasterxml.jackson.annotation.JsonFormat; import com.supervision.police.dto.TripleInfoDTO; +import com.supervision.police.dto.TypeDTO; import lombok.Data; import java.io.Serializable; @@ -25,7 +26,12 @@ public class NotePrompt implements Serializable { /** * 笔录类型id */ - private String typeId; + // private String typeId; + /** + * prompt关联的类型 + */ + @TableField(exist = false) + private List typeList; /** * 提示词 diff --git a/src/main/java/com/supervision/police/dto/TypeDTO.java b/src/main/java/com/supervision/police/dto/TypeDTO.java new file mode 100644 index 0000000..e048397 --- /dev/null +++ b/src/main/java/com/supervision/police/dto/TypeDTO.java @@ -0,0 +1,15 @@ +package com.supervision.police.dto; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +@Data +@AllArgsConstructor +@NoArgsConstructor +public class TypeDTO { + + private String id; + + private String recordType; +} diff --git a/src/main/java/com/supervision/police/service/ModelRecordTypeService.java b/src/main/java/com/supervision/police/service/ModelRecordTypeService.java index 8dc52fe..d26fe87 100644 --- a/src/main/java/com/supervision/police/service/ModelRecordTypeService.java +++ b/src/main/java/com/supervision/police/service/ModelRecordTypeService.java @@ -5,6 +5,7 @@ import com.supervision.common.domain.R; import com.supervision.police.domain.ModelRecordType; import com.supervision.police.domain.NotePrompt; import com.supervision.police.domain.TripleInfo; +import com.supervision.police.dto.TypeDTO; import java.util.List; @@ -12,6 +13,8 @@ public interface ModelRecordTypeService extends IService { List queryType(String name, Integer page, Integer size); + List queryTypeListChoose(); + ModelRecordType queryByName(String content); R saveType(ModelRecordType type); diff --git a/src/main/java/com/supervision/police/service/impl/ExtractTripleInfoServiceImpl.java b/src/main/java/com/supervision/police/service/impl/ExtractTripleInfoServiceImpl.java index 198f778..b0fd21b 100644 --- a/src/main/java/com/supervision/police/service/impl/ExtractTripleInfoServiceImpl.java +++ b/src/main/java/com/supervision/police/service/impl/ExtractTripleInfoServiceImpl.java @@ -51,9 +51,11 @@ public class ExtractTripleInfoServiceImpl implements ExtractTripleInfoService { @Autowired private NoteRecordSplitService noteRecordSplitService; + private final NotePromptTypeRelService notePromptTypeRelService; + @Async - @Transactional(transactionManager = "dataSourceTransactionManager",rollbackFor = Exception.class) + @Transactional(transactionManager = "dataSourceTransactionManager", rollbackFor = Exception.class) public void extractTripleInfo(String caseId, String name, String recordId) { // 首先获取所有切分后的笔录 List recordSplitList = noteRecordSplitService.lambdaQuery().eq(StrUtil.isNotBlank(recordId), NoteRecordSplit::getNoteRecordsId, recordId) @@ -78,7 +80,14 @@ public class ExtractTripleInfoServiceImpl implements ExtractTripleInfoService { } else { // 根据笔录类型找到所有的提取三元组的提示词 // 一个提示词可能关联多个类型,要进行拆分操作 - List prompts = notePromptService.lambdaQuery().eq(NotePrompt::getTypeId, typeId).list(); + List promptTypeRelList = notePromptTypeRelService.lambdaQuery().eq(NotePromptTypeRel::getTypeId, typeId).select(NotePromptTypeRel::getPromptId).list(); + if (CollUtil.isEmpty(promptTypeRelList)) { + log.info("{} 切分笔录类型:{}无对应的提示词,跳过", recordSplit.getId(), typeName); + continue; + } + List prompts = notePromptService.lambdaQuery() + .in(NotePrompt::getId, promptTypeRelList.stream().map(NotePromptTypeRel::getPromptId).collect(Collectors.toSet())) + .list(); if (CollUtil.isEmpty(prompts)) { log.info("{} 切分笔录类型:{}无对应的提示词,跳过", recordSplit.getId(), typeName); } else { diff --git a/src/main/java/com/supervision/police/service/impl/ModelRecordTypeServiceImpl.java b/src/main/java/com/supervision/police/service/impl/ModelRecordTypeServiceImpl.java index 80fb67d..e72c807 100644 --- a/src/main/java/com/supervision/police/service/impl/ModelRecordTypeServiceImpl.java +++ b/src/main/java/com/supervision/police/service/impl/ModelRecordTypeServiceImpl.java @@ -1,5 +1,6 @@ package com.supervision.police.service.impl; +import cn.hutool.core.collection.CollUtil; import cn.hutool.core.util.StrUtil; import com.alibaba.druid.util.StringUtils; import com.baomidou.mybatisplus.core.conditions.Wrapper; @@ -12,6 +13,7 @@ import com.supervision.neo4j.domain.Rel; import com.supervision.neo4j.service.Neo4jService; import com.supervision.police.domain.*; import com.supervision.police.dto.TripleInfoDTO; +import com.supervision.police.dto.TypeDTO; import com.supervision.police.mapper.ModelRecordTypeMapper; import com.supervision.police.mapper.NoteRecordSplitMapper; import com.supervision.police.service.*; @@ -27,9 +29,9 @@ import org.springframework.transaction.annotation.Transactional; import org.springframework.util.StopWatch; import java.time.LocalDateTime; -import java.util.ArrayList; -import java.util.List; -import java.util.Optional; +import java.util.*; +import java.util.function.Function; +import java.util.stream.Collectors; @Slf4j @Service @@ -50,24 +52,48 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl queryTypeListChoose() { + List list = modelRecordTypeMapper.selectByName(null); + return list.stream().map(e -> new TypeDTO(e.getId(),e.getRecordType())).collect(Collectors.toList()); + } + @Override public List queryType(String name, Integer page, Integer size) { List list = modelRecordTypeMapper.selectByName(name); + Map modelRecordTypeMap = list.stream().map(e -> new TypeDTO(e.getId(),e.getRecordType())).collect(Collectors.toMap(TypeDTO::getId, Function.identity())); + // 获取类型和三元组提示词的关系表 + List relList = notePromptTypeRelService.lambdaQuery().list(); + Map> typeIdMap = relList.stream().collect(Collectors.groupingBy(NotePromptTypeRel::getTypeId)); + Map> promptIdMap = relList.stream().collect(Collectors.groupingBy(NotePromptTypeRel::getPromptId)); for (ModelRecordType modelRecordType : list) { //笔录内容 List noteRecords = noteRecordSplitMapper.selectByRecordType(modelRecordType.getRecordType()); modelRecordType.setRecords(noteRecords); - // grideOptions //提示词 - List prompts = notePromptService.lambdaQuery().eq(NotePrompt::getTypeId, modelRecordType.getId()).list(); - for (NotePrompt prompt : prompts) { - prompt.setTripleList(buildTripleInfo(prompt)); + // 根据类型表获取所有的prompt + List promptRelList = typeIdMap.get(modelRecordType.getId()); + if (CollUtil.isNotEmpty(promptRelList)) { + Set promptIdSet = promptRelList.stream().map(NotePromptTypeRel::getPromptId).collect(Collectors.toSet()); + List prompts = Optional.ofNullable(notePromptService.listByIds(promptIdSet)).orElse(new ArrayList<>()); + for (NotePrompt prompt : prompts) { + List modelRecordTypes = new ArrayList<>(); + List promptTypeRelList = promptIdMap.get(prompt.getId()); + for (NotePromptTypeRel notePromptTypeRel : promptTypeRelList) { + TypeDTO existType = modelRecordTypeMap.get(notePromptTypeRel.getTypeId()); + modelRecordTypes.add(existType.getId()); + } + prompt.setTypeList(modelRecordTypes); + prompt.setTripleList(buildTripleInfo(prompt)); + } + modelRecordType.setPrompts(prompts); } - modelRecordType.setPrompts(prompts); } return list; } @@ -116,15 +142,56 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl addOrUpdPrompt(NotePrompt prompt) { + List typeList = prompt.getTypeList(); + if (CollUtil.isEmpty(typeList)) { + throw new RuntimeException("类型信息不能为空"); + } int i = 0; boolean save; if (StringUtils.isEmpty(prompt.getId())) { save = notePromptService.save(prompt); + // 新增prompt绑定的分类信息 + for (String typeId : typeList) { + NotePromptTypeRel rel = new NotePromptTypeRel(); + rel.setPromptId(prompt.getId()); + rel.setTypeId(typeId); + notePromptTypeRelService.save(rel); + } } else { save = notePromptService.updateById(prompt); + // 更新prompt绑定的分类信息 + // 首先查询已经有的,如果都存在,就不变,如果数据库有,前端没有,就删除,如果前端有,数据库没有,就新增 + List existDatabaseRelList = notePromptTypeRelService.lambdaQuery().eq(NotePromptTypeRel::getPromptId, prompt.getId()).list(); + + if (CollUtil.isNotEmpty(existDatabaseRelList)) { + Set existTypeList = existDatabaseRelList.stream().map(NotePromptTypeRel::getTypeId).collect(Collectors.toSet()); + Set frontRelIdList = new HashSet<>(typeList); + // 删除(数据库有,前端没有的) + List deleteIdList = existTypeList.stream().filter(id -> !frontRelIdList.contains(id)).collect(Collectors.toList()); + if (CollUtil.isNotEmpty(deleteIdList)) { + notePromptTypeRelService.removeByIds(deleteIdList); + } + // 新增(前端有数据库没有的) + frontRelIdList.stream().filter(id -> !existTypeList.contains(id)).forEach(e ->{ + NotePromptTypeRel rel = new NotePromptTypeRel(); + rel.setPromptId(prompt.getId()); + rel.setTypeId(e); + notePromptTypeRelService.save(rel); + }); + }else { + // 如果数据库里面没查到,直接新增,一般不会走这一步 + for (String typeId : typeList) { + NotePromptTypeRel rel = new NotePromptTypeRel(); + rel.setPromptId(prompt.getId()); + rel.setTypeId(typeId); + notePromptTypeRelService.save(rel); + } + } } + + // 更新类型字段 List tripleList = prompt.getTripleList(); for (TripleInfoDTO dto : tripleList) { @@ -132,11 +199,11 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl delPrompt(NotePrompt prompt) { String id = prompt.getId(); boolean removeById = notePromptService.removeById(id); + // 删除同时从所有分类里面都删除掉 + notePromptTypeRelService.lambdaUpdate().eq(NotePromptTypeRel::getPromptId, id).remove(); if (removeById) { return R.ok("删除成功"); } else {