diff --git a/virtual-patient-graph/src/main/java/com/supervision/controller/GraphNebulaController.java b/virtual-patient-graph/src/main/java/com/supervision/controller/GraphNebulaController.java index 6abfb6d2..3a6d40fc 100644 --- a/virtual-patient-graph/src/main/java/com/supervision/controller/GraphNebulaController.java +++ b/virtual-patient-graph/src/main/java/com/supervision/controller/GraphNebulaController.java @@ -27,15 +27,15 @@ public class GraphNebulaController { @ApiOperation("查询图谱") @GetMapping("queryGraph") - public GraphVO queryGraph(String processId) { - return graphNebulaService.queryGraph(processId); + public GraphVO queryGraph(String processId, Integer level) { + return graphNebulaService.queryGraph(processId, level); } @ApiOperation("查询树形结构图") @GetMapping("queryTreeGraph") - public List queryTreeGraph(String processId) { - return graphNebulaService.queryTreeGraph(processId); + public List queryTreeGraph(String processId, Integer level) { + return graphNebulaService.queryTreeGraph(processId, level); } diff --git a/virtual-patient-graph/src/main/java/com/supervision/service/GraphNebulaService.java b/virtual-patient-graph/src/main/java/com/supervision/service/GraphNebulaService.java index 0891d658..009af60c 100644 --- a/virtual-patient-graph/src/main/java/com/supervision/service/GraphNebulaService.java +++ b/virtual-patient-graph/src/main/java/com/supervision/service/GraphNebulaService.java @@ -9,8 +9,8 @@ public interface GraphNebulaService { void creatGraphByNebula(String processId); - GraphVO queryGraph(String processId); + GraphVO queryGraph(String processId, Integer level); - List queryTreeGraph(String processId); + List queryTreeGraph(String processId, Integer level); } diff --git a/virtual-patient-graph/src/main/java/com/supervision/service/impl/GraphNebulaServiceImpl.java b/virtual-patient-graph/src/main/java/com/supervision/service/impl/GraphNebulaServiceImpl.java index b9fe54c7..7c423b03 100644 --- a/virtual-patient-graph/src/main/java/com/supervision/service/impl/GraphNebulaServiceImpl.java +++ b/virtual-patient-graph/src/main/java/com/supervision/service/impl/GraphNebulaServiceImpl.java @@ -7,6 +7,7 @@ import cn.hutool.core.lang.Assert; import cn.hutool.core.util.NumberUtil; import cn.hutool.core.util.ObjectUtil; import cn.hutool.core.util.StrUtil; +import com.alibaba.nacos.common.utils.UuidUtils; import com.supervision.dao.*; import com.supervision.domain.*; import com.supervision.enums.TagEnum; @@ -245,7 +246,7 @@ public class GraphNebulaServiceImpl implements GraphNebulaService { @Override - public GraphVO queryGraph(String processId) { + public GraphVO queryGraph(String processId, Integer level) { Process process = Optional.ofNullable(processService.getById(processId)).orElseThrow(() -> new BusinessException("未找到对应的问诊流程")); // 如果图谱ID为空,则创建图谱 if (StrUtil.isEmpty(process.getGraphId())) { @@ -287,7 +288,14 @@ public class GraphNebulaServiceImpl implements GraphNebulaService { }); } } - nodeList.add(nodeVO); + // 校验级别(根据参数的级别来进行判断) + if (ObjectUtil.isNotEmpty(level)) { + if (level >= nodeVO.getNodeLevel()) { + nodeList.add(nodeVO); + } + } else { + nodeList.add(nodeVO); + } } // 构建边 List> edges = subgraph.getEdges(); @@ -295,21 +303,28 @@ public class GraphNebulaServiceImpl implements GraphNebulaService { EdgeVO edgeVO = new EdgeVO(); edgeVO.setSource(edge.getSrcID()); edgeVO.setTarget(edge.getDstID()); + Map properties = edge.getProperties(); Object nameObject = properties.get("edgeValue"); if (ObjectUtil.isNotEmpty(nameObject)) { edgeVO.setName(String.valueOf(nameObject)); + edgeVO.setLabel(edgeVO.getName()); } edgeVO.setParams(properties); edgeList.add(edgeVO); } } + // 这里,需要遍历,把重点不存在的节点连线给删掉 + Set nodeIdSet = nodeList.stream().map(NodeVO::getId).collect(Collectors.toSet()); + // 如果指向的节点不存在,那么这个边也不存在 + edgeList.removeIf(edgeVO -> !nodeIdSet.contains(edgeVO.getTarget())); + return new GraphVO(nodeList, edgeList); } @Override - public List queryTreeGraph(String processId) { - GraphVO graphVO = queryGraph(processId); + public List queryTreeGraph(String processId, Integer level) { + GraphVO graphVO = queryGraph(processId, level); List treeNodeList = graphVO.getNodes().stream().map(node -> BeanUtil.toBean(node, TreeNodeVO.class)).collect(Collectors.toList()); // 首先找到第一级节点 List firstNodeList = treeNodeList.stream().filter(node -> node.getNodeLevel() == 1).collect(Collectors.toList()); @@ -321,9 +336,25 @@ public class GraphNebulaServiceImpl implements GraphNebulaService { for (TreeNodeVO nodeVO : firstNodeList) { recursionBuildTree(nodeVO, treeNodeMap, graphVO.getEdges()); } + // 为所有节点分配新的唯一ID(前端需要ID字段为唯一ID) + recursionGenerateSingleId(firstNodeList); return firstNodeList; } + /** + * 为属性结构构建新的唯一ID,把原先的ID迁移到GraphId + */ + private void recursionGenerateSingleId(List firstNodeList) { + for (TreeNodeVO treeNodeVO : firstNodeList) { + String uuid = UuidUtils.generateUuid(); + treeNodeVO.setGraphId(treeNodeVO.getId()); + treeNodeVO.setId(uuid); + if (CollUtil.isNotEmpty(treeNodeVO.getChildren())) { + recursionGenerateSingleId(treeNodeVO.getChildren()); + } + } + } + private void recursionBuildTree(TreeNodeVO preNode, Map treeNodeMap, List edgeList) { // 通过preNode的ID找到所有的子节点 List childNode = new ArrayList<>(); @@ -336,7 +367,7 @@ public class GraphNebulaServiceImpl implements GraphNebulaService { ; } } - preNode.setChildNodeList(childNode); + preNode.setChildren(childNode); for (TreeNodeVO treeNodeVO : childNode) { recursionBuildTree(treeNodeVO, treeNodeMap, edgeList); } diff --git a/virtual-patient-graph/src/main/java/com/supervision/vo/EdgeVO.java b/virtual-patient-graph/src/main/java/com/supervision/vo/EdgeVO.java index 453522b1..b384ddac 100644 --- a/virtual-patient-graph/src/main/java/com/supervision/vo/EdgeVO.java +++ b/virtual-patient-graph/src/main/java/com/supervision/vo/EdgeVO.java @@ -16,6 +16,8 @@ public class EdgeVO { @ApiModelProperty("连线展示的名称,可能为空") private String name; + @ApiModelProperty("连线展示的名称,=name,前端需要") + private String label; @ApiModelProperty("连线所拥有的属性") private Map params; diff --git a/virtual-patient-graph/src/main/java/com/supervision/vo/TreeNodeVO.java b/virtual-patient-graph/src/main/java/com/supervision/vo/TreeNodeVO.java index 39384497..522db22d 100644 --- a/virtual-patient-graph/src/main/java/com/supervision/vo/TreeNodeVO.java +++ b/virtual-patient-graph/src/main/java/com/supervision/vo/TreeNodeVO.java @@ -12,8 +12,10 @@ import java.util.Map; @ApiModel public class TreeNodeVO { - @ApiModelProperty("ID") + @ApiModelProperty("重新分配的唯一ID") private String id; + @ApiModelProperty("图谱ID") + private String graphId; @ApiModelProperty("节点值") private String nodeValue; @ApiModelProperty("节点颜色") @@ -31,5 +33,5 @@ public class TreeNodeVO { * 子节点 */ @ApiModelProperty("子节点") - private List childNodeList; + private List children; } diff --git a/virtual-patient-web/src/test/java/com/supervision/AskTemplateIdTest.java b/virtual-patient-web/src/test/java/com/supervision/AskTemplateIdTest.java index 4abd7794..0433e612 100644 --- a/virtual-patient-web/src/test/java/com/supervision/AskTemplateIdTest.java +++ b/virtual-patient-web/src/test/java/com/supervision/AskTemplateIdTest.java @@ -1,8 +1,11 @@ package com.supervision; import cn.hutool.core.collection.CollUtil; +import cn.hutool.core.map.MapUtil; import cn.hutool.core.util.ObjectUtil; import cn.hutool.core.util.ReUtil; +import cn.hutool.http.HttpRequest; +import cn.hutool.http.HttpResponse; import cn.hutool.http.HttpUtil; import cn.hutool.json.JSONObject; import cn.hutool.json.JSONUtil; @@ -12,12 +15,11 @@ import com.baomidou.mybatisplus.core.incrementer.DefaultIdentifierGenerator; import com.supervision.model.AskPatientAnswer; import com.supervision.model.AskTemplateQuestionLibrary; import com.supervision.model.CommonDic; -import com.supervision.pojo.vo.TalkResultResVO; -import com.supervision.pojo.vo.TalkVideoReqVO; import com.supervision.service.AskPatientAnswerService; import com.supervision.service.AskService; import com.supervision.service.AskTemplateQuestionLibraryService; import com.supervision.service.CommonDicService; +import lombok.Data; import lombok.extern.slf4j.Slf4j; import org.junit.Test; import org.junit.runner.RunWith; @@ -25,7 +27,6 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; -import java.io.IOException; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -46,6 +47,9 @@ public class AskTemplateIdTest { @Autowired private CommonDicService commonDicService; + @Autowired + private AskService askService; + @Test public void creatAskId() { Object o = new Object(); @@ -125,18 +129,84 @@ public class AskTemplateIdTest { } } - - @Autowired - private AskService askService; + @Test + public void generateByGpt() { + String api_key = "sk-FDNQ1bhd7007e62e714eT3BLbKFJ004fcC3ebDeA4542a516"; + String url = "https://aigptx.top/v1/chat/completions"; + + String question = "假设你是一个精通RASA NLU调优的工程师且具备丰富医疗经验;" + + "我现在有一个意图,请你根据这个意图,针对这个问题示例,提出10条医生在问诊时,可能根据这个意图来提问患者的问题.\n" + + "注意,问题不要超出这个意图的范围,始终契合意图的关键词\n" + + "回答请使用json array的格式,示例:[\"相似问题1\",\"相似问题2\"]\n" + + "### 下面是问题示例\n" + + "这种感觉持续多久了?"; + + GptParam gptParam = new GptParam(); + GptMessage gptMessage = new GptMessage(); + gptMessage.setContent(question); + gptParam.setMessages(CollUtil.newArrayList(gptMessage)); + + HttpResponse response = HttpRequest.post(url) + .header("Authorization", "Bearer " + api_key) + .body(JSONUtil.toJsonStr(gptParam)) + .execute(); + String body = response.body(); + System.out.println(body); + } @Test - public void testRasa() throws IOException { - TalkVideoReqVO talkVideoReqVO = new TalkVideoReqVO(); - talkVideoReqVO.setText("你现在感觉怎么样?"); - talkVideoReqVO.setProcessId("1749312510591934465"); - TalkResultResVO talkResultResVO = askService.talkByVideo(talkVideoReqVO); + public void testRasa() { + List aqtList = commonDicService.lambdaQuery().eq(CommonDic::getGroupCode, "AQT").isNotNull(CommonDic::getParentId).ne(CommonDic::getParentId, 179).list(); + Map dictMap = aqtList.stream().collect(Collectors.toMap(CommonDic::getId, Function.identity())); + List list = askPatientAnswerService.lambdaQuery().isNotNull(AskPatientAnswer::getQuestion).eq(AskPatientAnswer::getAnswerType, 1).list(); + List libraryList = askTemplateQuestionLibraryService.list(); + Map libraryMap = libraryList.stream().collect(Collectors.toMap(AskTemplateQuestionLibrary::getId, Function.identity())); + for (AskPatientAnswer answer : list) { + Map build = MapUtil.builder().put("text", answer.getQuestion()).build(); + String post = HttpUtil.post("http://localhost:5005/model/parse", JSONUtil.toJsonStr(build)); + RasaResult bean = JSONUtil.toBean(post, RasaResult.class); + ResaIntentResult intent = bean.getIntent(); + if (intent.getName().startsWith("Q")) { + String id = intent.getName().split("_")[1]; + if (!id.equals(answer.getLibraryQuestionId())) { + log.info("问题:{}匹配不正确,走了其他回答,期望ID为:{},实际ID为:{},实际分类为:{},期望分类为:{}", bean.getText(), answer.getLibraryQuestionId(), id, + dictMap.get(libraryMap.get(id).getDictId()).getNameZhPath(), + dictMap.get(libraryMap.get(answer.getLibraryQuestionId()).getDictId()).getNameZhPath() + + ); + } + } else { + log.info("问题:{}匹配不正确,走了默认回答", bean.getText()); + } + + } + + } + + @Data + private static class GptParam { + private List messages; + // # 如果需要切换模型,在这里修改 + private String model = "gpt-3.5-turbo"; + } + + @Data + private static class GptMessage { + private String role = "user"; + private String content; + } + + @Data + private static class RasaResult { + private String text; + private ResaIntentResult intent; + private List intent_ranking; + } - System.out.println(JSONUtil.toJsonStr(talkResultResVO)); + @Data + private static class ResaIntentResult { + private String name; + private Double confidence; }