Merge remote-tracking branch 'origin/dev_2.1.0' into dev_2.1.0

dev_2.1.0
xueqingkun 1 year ago
commit 50442fcb4f

@ -27,15 +27,15 @@ public class GraphNebulaController {
@ApiOperation("查询图谱") @ApiOperation("查询图谱")
@GetMapping("queryGraph") @GetMapping("queryGraph")
public GraphVO queryGraph(String processId) { public GraphVO queryGraph(String processId, Integer level) {
return graphNebulaService.queryGraph(processId); return graphNebulaService.queryGraph(processId, level);
} }
@ApiOperation("查询树形结构图") @ApiOperation("查询树形结构图")
@GetMapping("queryTreeGraph") @GetMapping("queryTreeGraph")
public List<TreeNodeVO> queryTreeGraph(String processId) { public List<TreeNodeVO> queryTreeGraph(String processId, Integer level) {
return graphNebulaService.queryTreeGraph(processId); return graphNebulaService.queryTreeGraph(processId, level);
} }

@ -9,8 +9,8 @@ public interface GraphNebulaService {
void creatGraphByNebula(String processId); void creatGraphByNebula(String processId);
GraphVO queryGraph(String processId); GraphVO queryGraph(String processId, Integer level);
List<TreeNodeVO> queryTreeGraph(String processId); List<TreeNodeVO> queryTreeGraph(String processId, Integer level);
} }

@ -7,6 +7,7 @@ import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.NumberUtil; import cn.hutool.core.util.NumberUtil;
import cn.hutool.core.util.ObjectUtil; import cn.hutool.core.util.ObjectUtil;
import cn.hutool.core.util.StrUtil; import cn.hutool.core.util.StrUtil;
import com.alibaba.nacos.common.utils.UuidUtils;
import com.supervision.dao.*; import com.supervision.dao.*;
import com.supervision.domain.*; import com.supervision.domain.*;
import com.supervision.enums.TagEnum; import com.supervision.enums.TagEnum;
@ -245,7 +246,7 @@ public class GraphNebulaServiceImpl implements GraphNebulaService {
@Override @Override
public GraphVO queryGraph(String processId) { public GraphVO queryGraph(String processId, Integer level) {
Process process = Optional.ofNullable(processService.getById(processId)).orElseThrow(() -> new BusinessException("未找到对应的问诊流程")); Process process = Optional.ofNullable(processService.getById(processId)).orElseThrow(() -> new BusinessException("未找到对应的问诊流程"));
// 如果图谱ID为空,则创建图谱 // 如果图谱ID为空,则创建图谱
if (StrUtil.isEmpty(process.getGraphId())) { if (StrUtil.isEmpty(process.getGraphId())) {
@ -287,29 +288,43 @@ public class GraphNebulaServiceImpl implements GraphNebulaService {
}); });
} }
} }
// 校验级别(根据参数的级别来进行判断)
if (ObjectUtil.isNotEmpty(level)) {
if (level >= nodeVO.getNodeLevel()) {
nodeList.add(nodeVO); nodeList.add(nodeVO);
} }
} else {
nodeList.add(nodeVO);
}
}
// 构建边 // 构建边
List<NgEdge<String>> edges = subgraph.getEdges(); List<NgEdge<String>> edges = subgraph.getEdges();
for (NgEdge<String> edge : edges) { for (NgEdge<String> edge : edges) {
EdgeVO edgeVO = new EdgeVO(); EdgeVO edgeVO = new EdgeVO();
edgeVO.setSource(edge.getSrcID()); edgeVO.setSource(edge.getSrcID());
edgeVO.setTarget(edge.getDstID()); edgeVO.setTarget(edge.getDstID());
Map<String, Object> properties = edge.getProperties(); Map<String, Object> properties = edge.getProperties();
Object nameObject = properties.get("edgeValue"); Object nameObject = properties.get("edgeValue");
if (ObjectUtil.isNotEmpty(nameObject)) { if (ObjectUtil.isNotEmpty(nameObject)) {
edgeVO.setName(String.valueOf(nameObject)); edgeVO.setName(String.valueOf(nameObject));
edgeVO.setLabel(edgeVO.getName());
} }
edgeVO.setParams(properties); edgeVO.setParams(properties);
edgeList.add(edgeVO); edgeList.add(edgeVO);
} }
} }
// 这里,需要遍历,把重点不存在的节点连线给删掉
Set<String> nodeIdSet = nodeList.stream().map(NodeVO::getId).collect(Collectors.toSet());
// 如果指向的节点不存在,那么这个边也不存在
edgeList.removeIf(edgeVO -> !nodeIdSet.contains(edgeVO.getTarget()));
return new GraphVO(nodeList, edgeList); return new GraphVO(nodeList, edgeList);
} }
@Override @Override
public List<TreeNodeVO> queryTreeGraph(String processId) { public List<TreeNodeVO> queryTreeGraph(String processId, Integer level) {
GraphVO graphVO = queryGraph(processId); GraphVO graphVO = queryGraph(processId, level);
List<TreeNodeVO> treeNodeList = graphVO.getNodes().stream().map(node -> BeanUtil.toBean(node, TreeNodeVO.class)).collect(Collectors.toList()); List<TreeNodeVO> treeNodeList = graphVO.getNodes().stream().map(node -> BeanUtil.toBean(node, TreeNodeVO.class)).collect(Collectors.toList());
// 首先找到第一级节点 // 首先找到第一级节点
List<TreeNodeVO> firstNodeList = treeNodeList.stream().filter(node -> node.getNodeLevel() == 1).collect(Collectors.toList()); List<TreeNodeVO> firstNodeList = treeNodeList.stream().filter(node -> node.getNodeLevel() == 1).collect(Collectors.toList());
@ -321,9 +336,25 @@ public class GraphNebulaServiceImpl implements GraphNebulaService {
for (TreeNodeVO nodeVO : firstNodeList) { for (TreeNodeVO nodeVO : firstNodeList) {
recursionBuildTree(nodeVO, treeNodeMap, graphVO.getEdges()); recursionBuildTree(nodeVO, treeNodeMap, graphVO.getEdges());
} }
// 为所有节点分配新的唯一ID(前端需要ID字段为唯一ID)
recursionGenerateSingleId(firstNodeList);
return firstNodeList; return firstNodeList;
} }
/**
* ID,IDGraphId
*/
private void recursionGenerateSingleId(List<TreeNodeVO> firstNodeList) {
for (TreeNodeVO treeNodeVO : firstNodeList) {
String uuid = UuidUtils.generateUuid();
treeNodeVO.setGraphId(treeNodeVO.getId());
treeNodeVO.setId(uuid);
if (CollUtil.isNotEmpty(treeNodeVO.getChildren())) {
recursionGenerateSingleId(treeNodeVO.getChildren());
}
}
}
private void recursionBuildTree(TreeNodeVO preNode, Map<String, TreeNodeVO> treeNodeMap, List<EdgeVO> edgeList) { private void recursionBuildTree(TreeNodeVO preNode, Map<String, TreeNodeVO> treeNodeMap, List<EdgeVO> edgeList) {
// 通过preNode的ID找到所有的子节点 // 通过preNode的ID找到所有的子节点
List<TreeNodeVO> childNode = new ArrayList<>(); List<TreeNodeVO> childNode = new ArrayList<>();
@ -336,7 +367,7 @@ public class GraphNebulaServiceImpl implements GraphNebulaService {
; ;
} }
} }
preNode.setChildNodeList(childNode); preNode.setChildren(childNode);
for (TreeNodeVO treeNodeVO : childNode) { for (TreeNodeVO treeNodeVO : childNode) {
recursionBuildTree(treeNodeVO, treeNodeMap, edgeList); recursionBuildTree(treeNodeVO, treeNodeMap, edgeList);
} }

@ -16,6 +16,8 @@ public class EdgeVO {
@ApiModelProperty("连线展示的名称,可能为空") @ApiModelProperty("连线展示的名称,可能为空")
private String name; private String name;
@ApiModelProperty("连线展示的名称,=name,前端需要")
private String label;
@ApiModelProperty("连线所拥有的属性") @ApiModelProperty("连线所拥有的属性")
private Map<String, Object> params; private Map<String, Object> params;

@ -12,8 +12,10 @@ import java.util.Map;
@ApiModel @ApiModel
public class TreeNodeVO { public class TreeNodeVO {
@ApiModelProperty("ID") @ApiModelProperty("重新分配的唯一ID")
private String id; private String id;
@ApiModelProperty("图谱ID")
private String graphId;
@ApiModelProperty("节点值") @ApiModelProperty("节点值")
private String nodeValue; private String nodeValue;
@ApiModelProperty("节点颜色") @ApiModelProperty("节点颜色")
@ -31,5 +33,5 @@ public class TreeNodeVO {
* *
*/ */
@ApiModelProperty("子节点") @ApiModelProperty("子节点")
private List<TreeNodeVO> childNodeList; private List<TreeNodeVO> children;
} }

@ -1,8 +1,11 @@
package com.supervision; package com.supervision;
import cn.hutool.core.collection.CollUtil; import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.map.MapUtil;
import cn.hutool.core.util.ObjectUtil; import cn.hutool.core.util.ObjectUtil;
import cn.hutool.core.util.ReUtil; import cn.hutool.core.util.ReUtil;
import cn.hutool.http.HttpRequest;
import cn.hutool.http.HttpResponse;
import cn.hutool.http.HttpUtil; import cn.hutool.http.HttpUtil;
import cn.hutool.json.JSONObject; import cn.hutool.json.JSONObject;
import cn.hutool.json.JSONUtil; import cn.hutool.json.JSONUtil;
@ -12,12 +15,11 @@ import com.baomidou.mybatisplus.core.incrementer.DefaultIdentifierGenerator;
import com.supervision.model.AskPatientAnswer; import com.supervision.model.AskPatientAnswer;
import com.supervision.model.AskTemplateQuestionLibrary; import com.supervision.model.AskTemplateQuestionLibrary;
import com.supervision.model.CommonDic; import com.supervision.model.CommonDic;
import com.supervision.pojo.vo.TalkResultResVO;
import com.supervision.pojo.vo.TalkVideoReqVO;
import com.supervision.service.AskPatientAnswerService; import com.supervision.service.AskPatientAnswerService;
import com.supervision.service.AskService; import com.supervision.service.AskService;
import com.supervision.service.AskTemplateQuestionLibraryService; import com.supervision.service.AskTemplateQuestionLibraryService;
import com.supervision.service.CommonDicService; import com.supervision.service.CommonDicService;
import lombok.Data;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
@ -25,7 +27,6 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest; import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
import java.io.IOException;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -46,6 +47,9 @@ public class AskTemplateIdTest {
@Autowired @Autowired
private CommonDicService commonDicService; private CommonDicService commonDicService;
@Autowired
private AskService askService;
@Test @Test
public void creatAskId() { public void creatAskId() {
Object o = new Object(); Object o = new Object();
@ -125,18 +129,84 @@ public class AskTemplateIdTest {
} }
} }
@Test
@Autowired public void generateByGpt() {
private AskService askService; String api_key = "sk-FDNQ1bhd7007e62e714eT3BLbKFJ004fcC3ebDeA4542a516";
String url = "https://aigptx.top/v1/chat/completions";
String question = "假设你是一个精通RASA NLU调优的工程师且具备丰富医疗经验;" +
"我现在有一个意图请你根据这个意图针对这个问题示例提出10条医生在问诊时,可能根据这个意图来提问患者的问题.\n" +
"注意,问题不要超出这个意图的范围,始终契合意图的关键词\n" +
"回答请使用json array的格式示例[\"相似问题1\",\"相似问题2\"]\n" +
"### 下面是问题示例\n" +
"这种感觉持续多久了?";
GptParam gptParam = new GptParam();
GptMessage gptMessage = new GptMessage();
gptMessage.setContent(question);
gptParam.setMessages(CollUtil.newArrayList(gptMessage));
HttpResponse response = HttpRequest.post(url)
.header("Authorization", "Bearer " + api_key)
.body(JSONUtil.toJsonStr(gptParam))
.execute();
String body = response.body();
System.out.println(body);
}
@Test @Test
public void testRasa() throws IOException { public void testRasa() {
TalkVideoReqVO talkVideoReqVO = new TalkVideoReqVO(); List<CommonDic> aqtList = commonDicService.lambdaQuery().eq(CommonDic::getGroupCode, "AQT").isNotNull(CommonDic::getParentId).ne(CommonDic::getParentId, 179).list();
talkVideoReqVO.setText("你现在感觉怎么样?"); Map<Long, CommonDic> dictMap = aqtList.stream().collect(Collectors.toMap(CommonDic::getId, Function.identity()));
talkVideoReqVO.setProcessId("1749312510591934465"); List<AskPatientAnswer> list = askPatientAnswerService.lambdaQuery().isNotNull(AskPatientAnswer::getQuestion).eq(AskPatientAnswer::getAnswerType, 1).list();
TalkResultResVO talkResultResVO = askService.talkByVideo(talkVideoReqVO); List<AskTemplateQuestionLibrary> libraryList = askTemplateQuestionLibraryService.list();
Map<String, AskTemplateQuestionLibrary> libraryMap = libraryList.stream().collect(Collectors.toMap(AskTemplateQuestionLibrary::getId, Function.identity()));
for (AskPatientAnswer answer : list) {
Map<Object, Object> build = MapUtil.builder().put("text", answer.getQuestion()).build();
String post = HttpUtil.post("http://localhost:5005/model/parse", JSONUtil.toJsonStr(build));
RasaResult bean = JSONUtil.toBean(post, RasaResult.class);
ResaIntentResult intent = bean.getIntent();
if (intent.getName().startsWith("Q")) {
String id = intent.getName().split("_")[1];
if (!id.equals(answer.getLibraryQuestionId())) {
log.info("问题:{}匹配不正确,走了其他回答,期望ID为:{},实际ID为:{},实际分类为:{},期望分类为:{}", bean.getText(), answer.getLibraryQuestionId(), id,
dictMap.get(libraryMap.get(id).getDictId()).getNameZhPath(),
dictMap.get(libraryMap.get(answer.getLibraryQuestionId()).getDictId()).getNameZhPath()
);
}
} else {
log.info("问题:{}匹配不正确,走了默认回答", bean.getText());
}
}
}
@Data
private static class GptParam {
private List<GptMessage> messages;
// # 如果需要切换模型,在这里修改
private String model = "gpt-3.5-turbo";
}
@Data
private static class GptMessage {
private String role = "user";
private String content;
}
@Data
private static class RasaResult {
private String text;
private ResaIntentResult intent;
private List<ResaIntentResult> intent_ranking;
}
System.out.println(JSONUtil.toJsonStr(talkResultResVO)); @Data
private static class ResaIntentResult {
private String name;
private Double confidence;
} }

Loading…
Cancel
Save