diff --git a/virtual-patient-graph/src/main/java/com/supervision/service/impl/GraphNebulaServiceImpl.java b/virtual-patient-graph/src/main/java/com/supervision/service/impl/GraphNebulaServiceImpl.java index 7c423b03..d94e33b5 100644 --- a/virtual-patient-graph/src/main/java/com/supervision/service/impl/GraphNebulaServiceImpl.java +++ b/virtual-patient-graph/src/main/java/com/supervision/service/impl/GraphNebulaServiceImpl.java @@ -347,7 +347,9 @@ public class GraphNebulaServiceImpl implements GraphNebulaService { private void recursionGenerateSingleId(List firstNodeList) { for (TreeNodeVO treeNodeVO : firstNodeList) { String uuid = UuidUtils.generateUuid(); - treeNodeVO.setGraphId(treeNodeVO.getId()); + if (StrUtil.isBlank(treeNodeVO.getGraphId())) { + treeNodeVO.setGraphId(treeNodeVO.getId()); + } treeNodeVO.setId(uuid); if (CollUtil.isNotEmpty(treeNodeVO.getChildren())) { recursionGenerateSingleId(treeNodeVO.getChildren()); diff --git a/virtual-patient-web/src/test/java/com/supervision/AskTemplateIdTest.java b/virtual-patient-web/src/test/java/com/supervision/AskTemplateIdTest.java index 0433e612..c8b6fea4 100644 --- a/virtual-patient-web/src/test/java/com/supervision/AskTemplateIdTest.java +++ b/virtual-patient-web/src/test/java/com/supervision/AskTemplateIdTest.java @@ -1,12 +1,14 @@ package com.supervision; import cn.hutool.core.collection.CollUtil; +import cn.hutool.core.io.FileUtil; import cn.hutool.core.map.MapUtil; import cn.hutool.core.util.ObjectUtil; import cn.hutool.core.util.ReUtil; import cn.hutool.http.HttpRequest; import cn.hutool.http.HttpResponse; import cn.hutool.http.HttpUtil; +import cn.hutool.json.JSONArray; import cn.hutool.json.JSONObject; import cn.hutool.json.JSONUtil; import cn.hutool.poi.excel.ExcelReader; @@ -27,9 +29,9 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import java.io.File; +import java.nio.charset.StandardCharsets; +import java.util.*; import java.util.function.Function; import java.util.stream.Collectors; @@ -156,9 +158,10 @@ public class AskTemplateIdTest { @Test public void testRasa() { + String medicalId = "1"; List aqtList = commonDicService.lambdaQuery().eq(CommonDic::getGroupCode, "AQT").isNotNull(CommonDic::getParentId).ne(CommonDic::getParentId, 179).list(); Map dictMap = aqtList.stream().collect(Collectors.toMap(CommonDic::getId, Function.identity())); - List list = askPatientAnswerService.lambdaQuery().isNotNull(AskPatientAnswer::getQuestion).eq(AskPatientAnswer::getAnswerType, 1).list(); + List list = askPatientAnswerService.lambdaQuery().isNotNull(AskPatientAnswer::getQuestion).eq(AskPatientAnswer::getMedicalId, medicalId).list(); List libraryList = askTemplateQuestionLibraryService.list(); Map libraryMap = libraryList.stream().collect(Collectors.toMap(AskTemplateQuestionLibrary::getId, Function.identity())); for (AskPatientAnswer answer : list) { @@ -169,20 +172,91 @@ public class AskTemplateIdTest { if (intent.getName().startsWith("Q")) { String id = intent.getName().split("_")[1]; if (!id.equals(answer.getLibraryQuestionId())) { - log.info("问题:{}匹配不正确,走了其他回答,期望ID为:{},实际ID为:{},实际分类为:{},期望分类为:{}", bean.getText(), answer.getLibraryQuestionId(), id, + log.info("问题:{}匹配不正确,走了其他回答,实际分类为:{},期望分类为:{},期望ID为:{},实际ID为:{}", + bean.getText(), dictMap.get(libraryMap.get(id).getDictId()).getNameZhPath(), - dictMap.get(libraryMap.get(answer.getLibraryQuestionId()).getDictId()).getNameZhPath() + dictMap.get(libraryMap.get(answer.getLibraryQuestionId()).getDictId()).getNameZhPath(), + answer.getLibraryQuestionId(), + id ); } } else { - log.info("问题:{}匹配不正确,走了默认回答", bean.getText()); + log.info("问题:{}匹配不正确,走了默认回答,期望分类为:{},期望ID为:{}", + bean.getText(), + dictMap.get(libraryMap.get(answer.getLibraryQuestionId()).getDictId()).getNameZhPath(), + answer.getLibraryQuestionId()); } } } + /** + * 相似度匹配数据清洗 + */ + @Test + public void similarityMatchingDataCleaning() { + List aqtList = commonDicService.lambdaQuery().eq(CommonDic::getGroupCode, "AQT").isNotNull(CommonDic::getParentId).ne(CommonDic::getParentId, 179).list(); + Map dictMap = aqtList.stream().collect(Collectors.toMap(CommonDic::getNameZhPath, Function.identity())); + List libraryList = askTemplateQuestionLibraryService.list(); + Map libraryMap = libraryList.stream().collect(Collectors.toMap(AskTemplateQuestionLibrary::getDictId, Function.identity())); + // 遍历match_result.json文件 + JSONArray jsonArray = JSONUtil.readJSONArray(FileUtil.file("/Users/flevance/Desktop/知识图谱调优/matching_results.json"), StandardCharsets.UTF_8); + // 根据测试问题判断,匹配度最高的问题,是不是包含在question列表内,如果不是,就提示 + for (Object o : jsonArray) { + JSONObject jsonObject = (JSONObject) o; + String testQuestion = jsonObject.get("测试问题", String.class); + // 测试数据对应的意图 + String intent = jsonObject.get("对应意图", String.class); + JSONArray matchList = jsonObject.get("匹配问题", JSONArray.class); + // 如果不为空,找到最大的那一个 + find: + if (CollUtil.isNotEmpty(matchList)) { + for (Object o1 : matchList) { + JSONObject matchNode = (JSONObject) o1; + Double similarity = matchNode.get("Similarity", Double.class); + String question = matchNode.get("Question", String.class); + CommonDic dic = dictMap.get(intent); + if (ObjectUtil.isEmpty(dic)) { + log.info("测试问题:{},对应意图:{},期望的意图在数据库中未找到", testQuestion, intent); + } + AskTemplateQuestionLibrary askTemplateQuestionLibrary = libraryMap.get(dic.getId()); + if (ObjectUtil.isEmpty(askTemplateQuestionLibrary)) { + log.info("测试问题:{},对应意图:{},期望的意图找到了,但是期望意图对应的知识库数据未找到", testQuestion, intent); + } + Set dataBaseQuestionSet = new HashSet<>(askTemplateQuestionLibrary.getQuestion()); + if (!dataBaseQuestionSet.contains(question)) { + log.info("测试问题:{},对应意图:{},实际匹配相似度最高的问题为:{},但并不是实际期待匹配的意图", testQuestion, intent, question); + } else { + break find; + } + } + } + + } + } + + @Test + public void testQuestionFromExcel() { + ExcelReader reader = ExcelUtil.getReader("/Users/flevance/Desktop/知识图谱调优/副本T训练语料.xlsx"); + List> read = reader.read(1); + ArrayList> maps = new ArrayList<>(); + for (List objects : read) { + Map node = new HashMap<>(); + String s0 = (String) objects.get(0); + String s1 = (String) objects.get(1); + node.put("类目", s0.trim() + "/" + s1.trim()); + String s2 = (String) objects.get(2); + node.put("测试问题", s2.trim()); + maps.add(node); + } + JSONArray objects = JSONUtil.parseArray(maps); + String stringPretty = objects.toStringPretty(); + File file = FileUtil.newFile("/Users/flevance/Desktop/知识图谱调优/副本T训练语料.json"); + FileUtil.writeUtf8String(stringPretty, file); + } + @Data private static class GptParam { private List messages;