From 3e1a47d947c431b74b92f01f67baa735e084a468 Mon Sep 17 00:00:00 2001 From: xueqingkun Date: Tue, 28 May 2024 09:48:26 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E6=B5=8B=E8=AF=95=E7=B1=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../test/java/com/supervision/VecTest.java | 148 ++++++++++++++++++ 1 file changed, 148 insertions(+) diff --git a/virtual-patient-web/src/test/java/com/supervision/VecTest.java b/virtual-patient-web/src/test/java/com/supervision/VecTest.java index 57b8079b..955900ed 100644 --- a/virtual-patient-web/src/test/java/com/supervision/VecTest.java +++ b/virtual-patient-web/src/test/java/com/supervision/VecTest.java @@ -2,6 +2,7 @@ package com.supervision; import cn.hutool.core.collection.CollUtil; import cn.hutool.core.lang.Assert; +import cn.hutool.core.lang.Pair; import cn.hutool.core.map.MapUtil; import cn.hutool.core.util.NumberUtil; import cn.hutool.core.util.ObjectUtil; @@ -14,8 +15,10 @@ import cn.hutool.json.JSONUtil; import cn.hutool.poi.excel.ExcelReader; import cn.hutool.poi.excel.ExcelUtil; import cn.hutool.poi.excel.ExcelWriter; +import com.supervision.model.AskPatientAnswer; import com.supervision.model.AskTemplateQuestionLibrary; import com.supervision.model.CommonDic; +import com.supervision.service.AskPatientAnswerService; import com.supervision.service.AskTemplateQuestionLibraryService; import com.supervision.service.CommonDicService; import lombok.extern.slf4j.Slf4j; @@ -25,6 +28,7 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import java.security.PublicKey; import java.util.*; import java.util.stream.Collectors; @@ -406,6 +410,47 @@ public class VecTest { saveVec(vecData); } + + + @Test + public void initDescVecData(){ + List questionLibraries = askTemplateQuestionLibraryService.list(); + // load vec data + List> vecData = questionLibraries.stream().map(library -> { + Map map = new HashMap<>(); + map.put("questionCode", library.getId()); + map.put("questionList", CollUtil.newArrayList(library.getDescription())); + return map; + }).collect(Collectors.toList()); + saveVec(vecData); + } + + @Test + public void initMedicalVecData(){ + String medicalId = "1"; + + List answerList = askPatientAnswerService.lambdaQuery().eq(AskPatientAnswer::getMedicalId, medicalId).list(); + Map answerMap = answerList.stream().collect(Collectors.toMap(AskPatientAnswer::getLibraryQuestionId, v -> v)); + List questionLibraries = askTemplateQuestionLibraryService.list(); + List> result = new ArrayList<>(); + for (AskTemplateQuestionLibrary questionLibrary : questionLibraries) { + if (!questionLibrary.getDescription().contains(questionLibrary.getDescription())){ + questionLibrary.getQuestion().add(questionLibrary.getDescription()); + } + + AskPatientAnswer askPatientAnswer = answerMap.get(questionLibrary.getId()); + if (ObjectUtil.isEmpty(askPatientAnswer)){ + log.info("问题:id:{} desc:{},未设置回复答案,跳过",questionLibrary.getId(),questionLibrary.getDescription()); + continue; + } + Map map = new HashMap<>(); + map.put("questionCode", questionLibrary.getId()); + map.put("questionList", questionLibrary.getQuestion()); + result.add(map); + } + saveVec(result); + } + private void saveVec(List> List){ HttpRequest request = HttpRequest.post(BASE_URL + "/updateDatabase") @@ -455,4 +500,107 @@ public class VecTest { return null; } + // todo: 找出相似度大于0.5的标准问的数据 + + + @Test + public void standardQuestionMatch(){ + // initDescVecData() first step + List questionLibraries = askTemplateQuestionLibraryService.list(); + List commonDics = commonDicService.lambdaQuery().eq(CommonDic::getGroupCode, "AQT").list(); + Map dicMap = commonDics.stream().collect(Collectors.toMap(CommonDic::getId, v -> v)); + Map libraryMap = questionLibraries.stream().collect(Collectors.toMap(AskTemplateQuestionLibrary::getId, library -> library)); + + List> result = new ArrayList<>(); + + for (AskTemplateQuestionLibrary questionLibrary : questionLibraries) { + String description = questionLibrary.getDescription(); + List> maps = questionMatch(description); + // 只获取前四条数据 + if (CollUtil.isEmpty(maps)){ + log.warn("questionMatch:问题:{}没有匹配到结果", description); + continue; + } + + if (maps.size() > 3){ + // 截取前四条数据 + maps = maps.subList(0, 4); + } + for (Map matchMap : maps) { + String targetId = MapUtil.getStr(matchMap, "matchQuestionCode"); + String matchQuestion = MapUtil.getStr(matchMap, "matchQuestion"); + String matchScore = MapUtil.getStr(matchMap, "matchScore"); + + if (StrUtil.equals(targetId,questionLibrary.getId())){ + log.info("匹配到自己,跳过该条数据,id:{}",targetId); + continue; + } + Map tmp = new HashMap<>(); + tmp.put("sourceId",questionLibrary.getId()); + tmp.put("sourceDesc",description); + String sourceNamePath = dicMap.get(questionLibrary.getDictId()).getNameZhPath(); + tmp.put("sourceDicPath",sourceNamePath); + tmp.put("targetId",targetId); + tmp.put("targetDesc",matchQuestion); + tmp.put("matchScore",matchScore); + AskTemplateQuestionLibrary library = libraryMap.get(targetId); + tmp.put("targetDicPath",dicMap.get(library.getDictId()).getNameZhPath()); + result.add(tmp); + } + + } + String filePath = "F:\\tmp\\1\\问题库问题对比-标准问.xlsx"; + ExcelWriter writer = ExcelUtil.getWriter(filePath, "标准问匹配"); + writer.setDefaultRowHeight(18); + writer.addHeaderAlias("sourceId", "源问题id"); + writer.addHeaderAlias("sourceDesc", "源问题"); + writer.addHeaderAlias("sourceDicPath", "源分类"); + writer.addHeaderAlias("targetId", "目标问题"); + writer.addHeaderAlias("targetDesc", "目标问题"); + writer.addHeaderAlias("targetDicPath", "目标分类"); + writer.addHeaderAlias("matchScore", "相似度"); + writer.write(result,true); + writer.close(); + } + + + @Autowired + private AskPatientAnswerService askPatientAnswerService; + + @Test + public void medicalQA(){ + String medicalId = "1"; + + List answerList = askPatientAnswerService.lambdaQuery().eq(AskPatientAnswer::getMedicalId, medicalId).list(); + Map answerMap = answerList.stream().collect(Collectors.toMap(AskPatientAnswer::getLibraryQuestionId, v -> v)); + + List questionLibraries = askTemplateQuestionLibraryService.list(); + + List> result = new ArrayList<>(); + for (AskTemplateQuestionLibrary questionLibrary : questionLibraries) { + if (!questionLibrary.getDescription().contains(questionLibrary.getDescription())){ + questionLibrary.getQuestion().add(questionLibrary.getDescription()); + } + + AskPatientAnswer askPatientAnswer = answerMap.get(questionLibrary.getId()); + if (ObjectUtil.isEmpty(askPatientAnswer)){ + log.info("问题:id:{} desc:{},未设置回复答案,跳过",questionLibrary.getId(),questionLibrary.getDescription()); + continue; + } + for (String question : questionLibrary.getQuestion()) { + HashMap map = new HashMap<>(); + map.put("Q",question); + map.put("A",askPatientAnswer.getAnswer()); + result.add(map); + } + } + + String filePath = "F:\\tmp\\1\\病历问答.xlsx"; + ExcelWriter writer = ExcelUtil.getWriter(filePath); + writer.addHeaderAlias("Q", "Q"); + writer.addHeaderAlias("A", "A"); + writer.write(result,false); + writer.close(); + } + }