You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
virtual-patient/virtual-patient-web/src/test/java/com/supervision/AskTemplateIdTest.java

214 lines
10 KiB
Java

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package com.supervision;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.map.MapUtil;
import cn.hutool.core.util.ObjectUtil;
import cn.hutool.core.util.ReUtil;
import cn.hutool.http.HttpRequest;
import cn.hutool.http.HttpResponse;
import cn.hutool.http.HttpUtil;
import cn.hutool.json.JSONObject;
import cn.hutool.json.JSONUtil;
import cn.hutool.poi.excel.ExcelReader;
import cn.hutool.poi.excel.ExcelUtil;
import com.baomidou.mybatisplus.core.incrementer.DefaultIdentifierGenerator;
import com.supervision.model.AskPatientAnswer;
import com.supervision.model.AskTemplateQuestionLibrary;
import com.supervision.model.CommonDic;
import com.supervision.service.AskPatientAnswerService;
import com.supervision.service.AskService;
import com.supervision.service.AskTemplateQuestionLibraryService;
import com.supervision.service.CommonDicService;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
@Slf4j
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
@RunWith(SpringJUnit4ClassRunner.class)
public class AskTemplateIdTest {

    /** Dictionary group code used to select the intent dictionary entries. */
    private static final String AQT_GROUP_CODE = "AQT";

    /**
     * Parent id whose children are excluded from the AQT intent dictionary.
     * NOTE(review): the meaning of node 179 is not visible in this file — confirm.
     */
    private static final int EXCLUDED_AQT_PARENT_ID = 179;

    /** Default corpus workbook; override with {@code -DaskTemplate.corpusExcel=/path/to/file.xlsx}. */
    private static final String DEFAULT_CORPUS_EXCEL = "/Users/flevance/Desktop/虚拟病人/语料库/标准病人语料1226-2.xlsx";

    /** Default Rasa NLU parse endpoint; override with {@code -Drasa.parseUrl=...}. */
    private static final String DEFAULT_RASA_PARSE_URL = "http://localhost:5005/model/parse";

    @Autowired
    private AskTemplateQuestionLibraryService askTemplateQuestionLibraryService;

    @Autowired
    private AskPatientAnswerService askPatientAnswerService;

    @Autowired
    private CommonDicService commonDicService;

    @Autowired
    private AskService askService;

    /**
     * Replaces the id of every question-library row with a new id from MyBatis-Plus's
     * {@link DefaultIdentifierGenerator}.
     * <p>
     * WARNING: this rewrites primary keys in place. Other rows that store these ids (for example
     * {@code AskPatientAnswer#getLibraryQuestionId()}, which {@link #testRasa()} resolves against
     * library ids) are not updated here.
     */
    @Test
    public void creatAskId() {
        DefaultIdentifierGenerator idGenerator = new DefaultIdentifierGenerator();
        for (AskTemplateQuestionLibrary library : askTemplateQuestionLibraryService.list()) {
            askTemplateQuestionLibraryService.lambdaUpdate()
                    .set(AskTemplateQuestionLibrary::getId, String.valueOf(idGenerator.nextId(library)))
                    .eq(AskTemplateQuestionLibrary::getId, library.getId())
                    .update();
        }
    }

    /**
     * Replaces the id of every patient-answer row with a new id from MyBatis-Plus's
     * {@link DefaultIdentifierGenerator}. Like {@link #creatAskId()}, this rewrites primary keys in place.
     */
    @Test
    public void creatAnswerId() {
        DefaultIdentifierGenerator idGenerator = new DefaultIdentifierGenerator();
        for (AskPatientAnswer answer : askPatientAnswerService.list()) {
            askPatientAnswerService.lambdaUpdate()
                    .set(AskPatientAnswer::getId, String.valueOf(idGenerator.nextId(answer)))
                    .eq(AskPatientAnswer::getId, answer.getId())
                    .update();
        }
    }

    /**
     * Links question-library rows to dictionary entries using the corpus spreadsheet.
     * <p>
     * For sheet rows 1..86, columns 0 and 1 are joined as {@code col0 + "/" + col1} and looked up
     * against {@code CommonDic.nameZhPath}. On a hit, every library row whose {@code code} equals
     * column 5 gets that dictionary id.
     */
    @Test
    public void insertDict() {
        String excelPath = System.getProperty("askTemplate.corpusExcel", DEFAULT_CORPUS_EXCEL);
        // ExcelReader keeps the workbook open; try-with-resources releases it even if a query fails.
        try (ExcelReader reader = ExcelUtil.getReader(excelPath)) {
            List<List<Object>> rows = reader.read(1, 86);
            for (List<Object> row : rows) {
                // Column 5 (the question code) is required; short rows would otherwise throw
                // IndexOutOfBoundsException and abort the whole import.
                if (row.size() < 6) {
                    log.warn("Skipping spreadsheet row with only {} columns: {}", row.size(), row);
                    continue;
                }
                String pathOne = String.valueOf(row.get(0));
                String pathTwo = String.valueOf(row.get(1));
                CommonDic dic = commonDicService.lambdaQuery()
                        .eq(CommonDic::getNameZhPath, pathOne + "/" + pathTwo)
                        .last("limit 1")
                        .one();
                if (ObjectUtil.isNotEmpty(dic)) {
                    String code = String.valueOf(row.get(5));
                    askTemplateQuestionLibraryService.lambdaUpdate()
                            .set(AskTemplateQuestionLibrary::getDictId, dic.getId())
                            .eq(AskTemplateQuestionLibrary::getCode, code)
                            .update();
                }
            }
        }
    }

    /**
     * Uses Baidu ERNIE Bot (文心一言) to generate similar questions for each library row.
     * The results are stored as a JSON array in the row's {@code question} column, with the
     * row's description as the first element.
     * <p>
     * This API is PAID: do not run it casually, and keep an eye on the number of calls.
     * Credentials come from the {@code BAIDU_CLIENT_ID} and {@code BAIDU_CLIENT_SECRET}
     * environment variables.
     */
    @Test
    public void generateQuestion() {
        String tokenResponse = HttpUtil.get("https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials"
                + "&client_id=" + requireEnv("BAIDU_CLIENT_ID")
                + "&client_secret=" + requireEnv("BAIDU_CLIENT_SECRET"));
        // Do not print the raw response: on success it contains the bearer access token.
        JSONObject tokenJson = JSONUtil.parseObj(tokenResponse);
        String accessToken = tokenJson.getStr("access_token");
        if (ObjectUtil.isEmpty(accessToken)) {
            throw new IllegalStateException("Failed to obtain Baidu access token: " + tokenResponse);
        }
        String chatUrl = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro?access_token=" + accessToken;

        Map<Long, CommonDic> dictMap = loadAqtDictMap();
        for (AskTemplateQuestionLibrary ask : askTemplateQuestionLibraryService.list()) {
            try {
                CommonDic dic = dictMap.get(ask.getDictId());
                if (ObjectUtil.isEmpty(dic)) {
                    continue;
                }
                String description = ask.getDescription();
                Map<String, Object> map = new HashMap<>();
                map.put("role", "user");
                map.put("content", "假设你是一个精通RASA NLU调优的工程师我现在有一个意图有一个问题示例请你根据这个意图针对这个问题示例提出30条与这个问题示例类似的问题注意问题不要超出这个意图的范围回答请使用json array的格式示例[\"相似问题1\",\"相似问题2\"]\n" +
                        "### 下面是意图和问题示例\n" + dic.getNameZhPath() + ":" + ask.getQuestion());
                HashMap<String, Object> param = new HashMap<>();
                param.put("messages", CollUtil.newArrayList(map));
                String askAnswer = HttpUtil.post(chatUrl, JSONUtil.toJsonStr(param));
                String result = JSONUtil.parseObj(askAnswer).getStr("result");
                // Pull the first bracketed JSON array out of the model's free-text reply.
                String arrayJson = ReUtil.get("\\[(.*?)\\]", result, 0);
                if (ObjectUtil.isEmpty(arrayJson)) {
                    // No array found. This happens when the response has no "result" (e.g. an error payload)
                    // or the model ignored the requested format. Skip the row rather than passing null on
                    // or overwriting the stored question with a description-only list.
                    log.warn("{}: no JSON array in model response: {}", description, askAnswer);
                    continue;
                }
                List<String> question = JSONUtil.toList(arrayJson, String.class);
                question.add(0, description);
                askTemplateQuestionLibraryService.lambdaUpdate()
                        .set(AskTemplateQuestionLibrary::getQuestion, JSONUtil.toJsonStr(question))
                        .eq(AskTemplateQuestionLibrary::getId, ask.getId())
                        .update();
            } catch (Exception e) {
                // Best-effort batch: log and continue with the next row.
                log.error("{}生成错误", ask.getDescription(), e);
            }
        }
    }

    /**
     * Sends one fixed prompt to an OpenAI-compatible chat-completions endpoint and logs the raw
     * response body. The API key is read from the {@code GPT_API_KEY} environment variable.
     */
    @Test
    public void generateByGpt() {
        String apiKey = requireEnv("GPT_API_KEY");
        String url = "https://aigptx.top/v1/chat/completions";
        String question = "假设你是一个精通RASA NLU调优的工程师且具备丰富医疗经验;" +
                "我现在有一个意图请你根据这个意图针对这个问题示例提出10条医生在问诊时,可能根据这个意图来提问患者的问题.\n" +
                "注意,问题不要超出这个意图的范围,始终契合意图的关键词\n" +
                "回答请使用json array的格式示例[\"相似问题1\",\"相似问题2\"]\n" +
                "### 下面是问题示例\n" +
                "这种感觉持续多久了?";
        GptMessage gptMessage = new GptMessage();
        gptMessage.setContent(question);
        GptParam gptParam = new GptParam();
        gptParam.setMessages(CollUtil.newArrayList(gptMessage));
        // HttpResponse holds the underlying connection/stream; close it once the body is read.
        try (HttpResponse response = HttpRequest.post(url)
                .header("Authorization", "Bearer " + apiKey)
                .body(JSONUtil.toJsonStr(gptParam))
                .execute()) {
            log.info("GPT response: {}", response.body());
        }
    }

    /**
     * Sends each stored patient question (answer type 1) to the Rasa NLU parse endpoint. It then
     * logs every case where the predicted intent does not map back to the expected library row.
     * <p>
     * An intent whose name starts with {@code "Q"} is matched by taking the library id from the
     * second {@code '_'}-separated segment of the name. Any other intent, or no intent at all,
     * is reported as falling through to the default answer.
     */
    @Test
    public void testRasa() {
        Map<Long, CommonDic> dictMap = loadAqtDictMap();
        List<AskPatientAnswer> answers = askPatientAnswerService.lambdaQuery()
                .isNotNull(AskPatientAnswer::getQuestion)
                .eq(AskPatientAnswer::getAnswerType, 1)
                .list();
        Map<String, AskTemplateQuestionLibrary> libraryMap = askTemplateQuestionLibraryService.list().stream()
                .collect(Collectors.toMap(AskTemplateQuestionLibrary::getId, Function.identity()));
        String parseUrl = System.getProperty("rasa.parseUrl", DEFAULT_RASA_PARSE_URL);

        for (AskPatientAnswer answer : answers) {
            Map<Object, Object> build = MapUtil.builder().put("text", answer.getQuestion()).build();
            String post = HttpUtil.post(parseUrl, JSONUtil.toJsonStr(build));
            RasaResult bean = JSONUtil.toBean(post, RasaResult.class);
            ResaIntentResult intent = bean.getIntent();
            String intentName = intent == null ? null : intent.getName();
            if (intentName == null || !intentName.startsWith("Q")) {
                log.info("问题:{}匹配不正确,走了默认回答", bean.getText());
                continue;
            }
            String[] nameParts = intentName.split("_");
            if (nameParts.length < 2) {
                // Cannot recover a library id; report instead of throwing ArrayIndexOutOfBoundsException.
                log.warn("Unexpected intent name {} for question {}", intentName, bean.getText());
                continue;
            }
            String id = nameParts[1];
            if (!id.equals(answer.getLibraryQuestionId())) {
                // Lookups are null-safe: an id may be missing from libraryMap, or its dict may be
                // excluded from dictMap (e.g. children of EXCLUDED_AQT_PARENT_ID).
                log.info("问题:{}匹配不正确,走了其他回答,期望ID为:{},实际ID为:{},实际分类为:{},期望分类为:{}", bean.getText(), answer.getLibraryQuestionId(), id,
                        categoryOf(libraryMap, dictMap, id),
                        categoryOf(libraryMap, dictMap, answer.getLibraryQuestionId()));
            }
        }
    }

    /** Loads the AQT intent dictionary keyed by id, excluding entries with no parent or under the excluded parent. */
    private Map<Long, CommonDic> loadAqtDictMap() {
        List<CommonDic> aqtList = commonDicService.lambdaQuery()
                .eq(CommonDic::getGroupCode, AQT_GROUP_CODE)
                .isNotNull(CommonDic::getParentId)
                .ne(CommonDic::getParentId, EXCLUDED_AQT_PARENT_ID)
                .list();
        return aqtList.stream().collect(Collectors.toMap(CommonDic::getId, Function.identity()));
    }

    /**
     * Resolves the dictionary path of a library row's category.
     *
     * @param libraryMap library rows keyed by id
     * @param dictMap    dictionary entries keyed by id
     * @param libraryId  id of the library row; may be null or unknown
     * @return the category's {@code nameZhPath}, or a descriptive placeholder when it cannot be resolved
     */
    private static String categoryOf(Map<String, AskTemplateQuestionLibrary> libraryMap,
                                     Map<Long, CommonDic> dictMap, Object libraryId) {
        AskTemplateQuestionLibrary library = libraryMap.get(libraryId);
        if (library == null) {
            return "<unknown library " + libraryId + ">";
        }
        CommonDic dic = dictMap.get(library.getDictId());
        return dic == null ? "<unknown dict " + library.getDictId() + ">" : dic.getNameZhPath();
    }

    /**
     * Reads a required credential from the environment so secrets are never committed to source.
     * NOTE: the keys previously hard-coded in this file should be treated as leaked and rotated.
     *
     * @param name environment variable name
     * @return its non-empty value
     * @throws IllegalStateException if the variable is unset or empty
     */
    private static String requireEnv(String name) {
        String value = System.getenv(name);
        if (value == null || value.isEmpty()) {
            throw new IllegalStateException("Environment variable " + name + " must be set to run this test");
        }
        return value;
    }

    /** Request body for an OpenAI-style chat-completions call. */
    @Data
    private static class GptParam {
        private List<GptMessage> messages;
        // Change the model here if needed.
        private String model = "gpt-3.5-turbo";
    }

    /** A single chat message; the role defaults to "user". */
    @Data
    private static class GptMessage {
        private String role = "user";
        private String content;
    }

    /** Subset of the Rasa {@code /model/parse} response used by {@link #testRasa()}. */
    @Data
    private static class RasaResult {
        private String text;
        private ResaIntentResult intent;
        // Field name mirrors the JSON key in the Rasa response.
        private List<ResaIntentResult> intent_ranking;
    }

    /** A predicted intent with its confidence score. */
    @Data
    private static class ResaIntentResult {
        private String name;
        private Double confidence;
    }
}