|
|
package com.supervision;
|
|
|
|
|
|
import cn.hutool.core.collection.CollUtil;
|
|
|
import cn.hutool.core.map.MapUtil;
|
|
|
import cn.hutool.core.util.ObjectUtil;
|
|
|
import cn.hutool.core.util.ReUtil;
|
|
|
import cn.hutool.http.HttpRequest;
|
|
|
import cn.hutool.http.HttpResponse;
|
|
|
import cn.hutool.http.HttpUtil;
|
|
|
import cn.hutool.json.JSONObject;
|
|
|
import cn.hutool.json.JSONUtil;
|
|
|
import cn.hutool.poi.excel.ExcelReader;
|
|
|
import cn.hutool.poi.excel.ExcelUtil;
|
|
|
import com.baomidou.mybatisplus.core.incrementer.DefaultIdentifierGenerator;
|
|
|
import com.supervision.model.AskPatientAnswer;
|
|
|
import com.supervision.model.AskTemplateQuestionLibrary;
|
|
|
import com.supervision.model.CommonDic;
|
|
|
import com.supervision.service.AskPatientAnswerService;
|
|
|
import com.supervision.service.AskService;
|
|
|
import com.supervision.service.AskTemplateQuestionLibraryService;
|
|
|
import com.supervision.service.CommonDicService;
|
|
|
import lombok.Data;
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
|
import org.junit.Test;
|
|
|
import org.junit.runner.RunWith;
|
|
|
import org.springframework.beans.factory.annotation.Autowired;
|
|
|
import org.springframework.boot.test.context.SpringBootTest;
|
|
|
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
|
|
|
|
|
|
import java.util.HashMap;
|
|
|
import java.util.List;
|
|
|
import java.util.Map;
|
|
|
import java.util.function.Function;
|
|
|
import java.util.stream.Collectors;
|
|
|
|
|
|
@Slf4j
|
|
|
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
|
|
|
@RunWith(SpringJUnit4ClassRunner.class)
|
|
|
public class AskTemplateIdTest {
|
|
|
|
|
|
@Autowired
|
|
|
private AskTemplateQuestionLibraryService askTemplateQuestionLibraryService;
|
|
|
|
|
|
@Autowired
|
|
|
private AskPatientAnswerService askPatientAnswerService;
|
|
|
|
|
|
@Autowired
|
|
|
private CommonDicService commonDicService;
|
|
|
|
|
|
@Autowired
|
|
|
private AskService askService;
|
|
|
|
|
|
@Test
|
|
|
public void creatAskId() {
|
|
|
Object o = new Object();
|
|
|
DefaultIdentifierGenerator defaultIdentifierGenerator = new DefaultIdentifierGenerator();
|
|
|
|
|
|
List<AskTemplateQuestionLibrary> list = askTemplateQuestionLibraryService.list();
|
|
|
for (AskTemplateQuestionLibrary askTemplateQuestionLibrary : list) {
|
|
|
askTemplateQuestionLibraryService.lambdaUpdate()
|
|
|
.set(AskTemplateQuestionLibrary::getId, String.valueOf(defaultIdentifierGenerator.nextId(o)))
|
|
|
.eq(AskTemplateQuestionLibrary::getId, askTemplateQuestionLibrary.getId()).update();
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
public void creatAnswerId() {
|
|
|
Object o = new Object();
|
|
|
DefaultIdentifierGenerator defaultIdentifierGenerator = new DefaultIdentifierGenerator();
|
|
|
|
|
|
List<AskPatientAnswer> list = askPatientAnswerService.list();
|
|
|
for (AskPatientAnswer answer : list) {
|
|
|
askPatientAnswerService.lambdaUpdate()
|
|
|
.set(AskPatientAnswer::getId, String.valueOf(defaultIdentifierGenerator.nextId(o)))
|
|
|
.eq(AskPatientAnswer::getId, answer.getId()).update();
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
public void insertDict() {
|
|
|
ExcelReader reader = ExcelUtil.getReader("/Users/flevance/Desktop/虚拟病人/语料库/标准病人语料1226-2.xlsx");
|
|
|
List<List<Object>> read = reader.read(1, 86);
|
|
|
for (List<Object> objects : read) {
|
|
|
String pathOne = String.valueOf(objects.get(0));
|
|
|
String pathTwo = String.valueOf(objects.get(1));
|
|
|
CommonDic dic = commonDicService.lambdaQuery().eq(CommonDic::getNameZhPath, pathOne + "/" + pathTwo).last("limit 1").one();
|
|
|
if (ObjectUtil.isNotEmpty(dic)) {
|
|
|
String code = String.valueOf(objects.get(5));
|
|
|
askTemplateQuestionLibraryService.lambdaUpdate().set(AskTemplateQuestionLibrary::getDictId, dic.getId())
|
|
|
.eq(AskTemplateQuestionLibrary::getCode, code).update();
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
|
|
|
@Test
|
|
|
public void generateQuestion() {
|
|
|
// 使用文心一言,这里是付费的,不要随便用哦,注意数量
|
|
|
String post = HttpUtil.get("https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=8VdV7Hm4ZmPVzRovCQXaH8nN&client_secret=Bte6UKtrexydMuhWrbqciolzYzS8rEYm");
|
|
|
System.out.println(post);
|
|
|
JSONObject parse = JSONUtil.parseObj(post);
|
|
|
String accessToken = parse.getStr("access_token");
|
|
|
List<CommonDic> aqtList = commonDicService.lambdaQuery().eq(CommonDic::getGroupCode, "AQT").isNotNull(CommonDic::getParentId).ne(CommonDic::getParentId, 179).list();
|
|
|
Map<Long, CommonDic> dictMap = aqtList.stream().collect(Collectors.toMap(CommonDic::getId, Function.identity()));
|
|
|
List<AskTemplateQuestionLibrary> list = askTemplateQuestionLibraryService.list();
|
|
|
for (AskTemplateQuestionLibrary ask : list) {
|
|
|
try {
|
|
|
CommonDic dic = dictMap.get(ask.getDictId());
|
|
|
if (ObjectUtil.isNotEmpty(dic)) {
|
|
|
String description = ask.getDescription();
|
|
|
Map<String, Object> map = new HashMap<>();
|
|
|
map.put("role", "user");
|
|
|
map.put("content", "假设你是一个精通RASA NLU调优的工程师,我现在有一个意图,有一个问题示例,请你根据这个意图,针对这个问题示例,提出30条与这个问题示例类似的问题,注意,问题不要超出这个意图的范围,回答请使用json array的格式,示例:[\"相似问题1\",\"相似问题2\"]\n" +
|
|
|
"### 下面是意图和问题示例\n" + dic.getNameZhPath() + ":" + ask.getQuestion());
|
|
|
HashMap<String, Object> param = new HashMap<>();
|
|
|
param.put("messages", CollUtil.newArrayList(map));
|
|
|
String askAnswer = HttpUtil.post("https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro?access_token=" + accessToken, JSONUtil.toJsonStr(param));
|
|
|
JSONObject answerJSON = JSONUtil.parseObj(askAnswer);
|
|
|
String result = answerJSON.getStr("result");
|
|
|
String s = ReUtil.get("\\[(.*?)\\]", result, 0);
|
|
|
List<String> question = JSONUtil.toList(s, String.class);
|
|
|
question.add(0, description);
|
|
|
askTemplateQuestionLibraryService.lambdaUpdate().set(AskTemplateQuestionLibrary::getQuestion, JSONUtil.toJsonStr(question)).eq(AskTemplateQuestionLibrary::getId, ask.getId()).update();
|
|
|
}
|
|
|
|
|
|
} catch (Exception e) {
|
|
|
log.error("{}生成错误", ask.getDescription(), e);
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
public void generateByGpt() {
|
|
|
String api_key = "sk-FDNQ1bhd7007e62e714eT3BLbKFJ004fcC3ebDeA4542a516";
|
|
|
String url = "https://aigptx.top/v1/chat/completions";
|
|
|
|
|
|
String question = "假设你是一个精通RASA NLU调优的工程师且具备丰富医疗经验;" +
|
|
|
"我现在有一个意图,请你根据这个意图,针对这个问题示例,提出10条医生在问诊时,可能根据这个意图来提问患者的问题.\n" +
|
|
|
"注意,问题不要超出这个意图的范围,始终契合意图的关键词\n" +
|
|
|
"回答请使用json array的格式,示例:[\"相似问题1\",\"相似问题2\"]\n" +
|
|
|
"### 下面是问题示例\n" +
|
|
|
"这种感觉持续多久了?";
|
|
|
|
|
|
GptParam gptParam = new GptParam();
|
|
|
GptMessage gptMessage = new GptMessage();
|
|
|
gptMessage.setContent(question);
|
|
|
gptParam.setMessages(CollUtil.newArrayList(gptMessage));
|
|
|
|
|
|
HttpResponse response = HttpRequest.post(url)
|
|
|
.header("Authorization", "Bearer " + api_key)
|
|
|
.body(JSONUtil.toJsonStr(gptParam))
|
|
|
.execute();
|
|
|
String body = response.body();
|
|
|
System.out.println(body);
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
public void testRasa() {
|
|
|
List<CommonDic> aqtList = commonDicService.lambdaQuery().eq(CommonDic::getGroupCode, "AQT").isNotNull(CommonDic::getParentId).ne(CommonDic::getParentId, 179).list();
|
|
|
Map<Long, CommonDic> dictMap = aqtList.stream().collect(Collectors.toMap(CommonDic::getId, Function.identity()));
|
|
|
List<AskPatientAnswer> list = askPatientAnswerService.lambdaQuery().isNotNull(AskPatientAnswer::getQuestion).eq(AskPatientAnswer::getAnswerType, 1).list();
|
|
|
List<AskTemplateQuestionLibrary> libraryList = askTemplateQuestionLibraryService.list();
|
|
|
Map<String, AskTemplateQuestionLibrary> libraryMap = libraryList.stream().collect(Collectors.toMap(AskTemplateQuestionLibrary::getId, Function.identity()));
|
|
|
for (AskPatientAnswer answer : list) {
|
|
|
Map<Object, Object> build = MapUtil.builder().put("text", answer.getQuestion()).build();
|
|
|
String post = HttpUtil.post("http://localhost:5005/model/parse", JSONUtil.toJsonStr(build));
|
|
|
RasaResult bean = JSONUtil.toBean(post, RasaResult.class);
|
|
|
ResaIntentResult intent = bean.getIntent();
|
|
|
if (intent.getName().startsWith("Q")) {
|
|
|
String id = intent.getName().split("_")[1];
|
|
|
if (!id.equals(answer.getLibraryQuestionId())) {
|
|
|
log.info("问题:{}匹配不正确,走了其他回答,期望ID为:{},实际ID为:{},实际分类为:{},期望分类为:{}", bean.getText(), answer.getLibraryQuestionId(), id,
|
|
|
dictMap.get(libraryMap.get(id).getDictId()).getNameZhPath(),
|
|
|
dictMap.get(libraryMap.get(answer.getLibraryQuestionId()).getDictId()).getNameZhPath()
|
|
|
|
|
|
);
|
|
|
}
|
|
|
} else {
|
|
|
log.info("问题:{}匹配不正确,走了默认回答", bean.getText());
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
@Data
|
|
|
private static class GptParam {
|
|
|
private List<GptMessage> messages;
|
|
|
// # 如果需要切换模型,在这里修改
|
|
|
private String model = "gpt-3.5-turbo";
|
|
|
}
|
|
|
|
|
|
@Data
|
|
|
private static class GptMessage {
|
|
|
private String role = "user";
|
|
|
private String content;
|
|
|
}
|
|
|
|
|
|
@Data
|
|
|
private static class RasaResult {
|
|
|
private String text;
|
|
|
private ResaIntentResult intent;
|
|
|
private List<ResaIntentResult> intent_ranking;
|
|
|
}
|
|
|
|
|
|
@Data
|
|
|
private static class ResaIntentResult {
|
|
|
private String name;
|
|
|
private Double confidence;
|
|
|
}
|
|
|
|
|
|
|
|
|
}
|