virtual-patient/virtual-patient-web/src/test/java/com/supervision/AskTemplateIdTest.java

124 lines
5.5 KiB
Java

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package com.supervision;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.ArrayUtil;
import cn.hutool.core.util.ObjectUtil;
import cn.hutool.core.util.ReUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.http.HttpUtil;
import cn.hutool.json.JSON;
import cn.hutool.json.JSONArray;
import cn.hutool.json.JSONObject;
import cn.hutool.json.JSONUtil;
import cn.hutool.poi.excel.ExcelReader;
import cn.hutool.poi.excel.ExcelUtil;
import com.baomidou.mybatisplus.core.incrementer.DefaultIdentifierGenerator;
import com.supervision.model.AskPatientAnswer;
import com.supervision.model.AskTemplateQuestionLibrary;
import com.supervision.model.CommonDic;
import com.supervision.service.AskPatientAnswerService;
import com.supervision.service.AskTemplateQuestionLibraryService;
import com.supervision.service.CommonDicService;
import com.supervision.util.RedisSequenceUtil;
import lombok.extern.slf4j.Slf4j;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
 * One-off maintenance tasks written as Spring Boot tests. Each test loads the full application
 * context and updates rows through the project's MyBatis-Plus services, so running it modifies the
 * configured database.
 */
@Slf4j
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
@RunWith(SpringJUnit4ClassRunner.class)
public class AskTemplateIdTest {

    /** System property naming the corpus workbook read by {@link #insertDict()}. */
    private static final String CORPUS_EXCEL_PATH_PROPERTY = "corpus.excel.path";
    /** Environment variable consulted when {@link #CORPUS_EXCEL_PATH_PROPERTY} is not set. */
    private static final String CORPUS_EXCEL_PATH_ENV = "CORPUS_EXCEL_PATH";
    /** Workbook used when no path is configured (the path this test originally hard-coded). */
    private static final String DEFAULT_CORPUS_EXCEL_PATH = "/Users/flevance/Desktop/虚拟病人/语料库/标准病人语料1226-2.xlsx";

    /** Row range passed to {@code ExcelReader#read(int, int)}. */
    private static final int CORPUS_FIRST_ROW = 1;
    private static final int CORPUS_LAST_ROW = 86;

    /** Column indexes inside one corpus row. */
    private static final int COL_PATH_ONE = 0;
    private static final int COL_PATH_TWO = 1;
    private static final int COL_CODE = 5;

    /** Baidu API credentials: system property first, then environment variable. */
    private static final String BAIDU_CLIENT_ID_PROPERTY = "baidu.client.id";
    private static final String BAIDU_CLIENT_ID_ENV = "BAIDU_CLIENT_ID";
    private static final String BAIDU_CLIENT_SECRET_PROPERTY = "baidu.client.secret";
    private static final String BAIDU_CLIENT_SECRET_ENV = "BAIDU_CLIENT_SECRET";

    private static final String BAIDU_TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token";
    private static final String ERNIE_CHAT_URL =
            "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro?access_token=";

    @Autowired
    private AskTemplateQuestionLibraryService askTemplateQuestionLibraryService;
    @Autowired
    private AskPatientAnswerService askPatientAnswerService;
    @Autowired
    private CommonDicService commonDicService;

    /**
     * Replaces the id of every {@link AskTemplateQuestionLibrary} row with a freshly generated one.
     * Rows are updated one by one, matched on their current id.
     */
    @Test
    public void creatAskId() {
        // nextId(...) takes an entity argument; a placeholder is passed on the assumption that the
        // generated value does not depend on it — confirm against the MyBatis-Plus version in use.
        Object placeholder = new Object();
        DefaultIdentifierGenerator generator = new DefaultIdentifierGenerator();
        List<AskTemplateQuestionLibrary> list = askTemplateQuestionLibraryService.list();
        for (AskTemplateQuestionLibrary library : list) {
            boolean updated = askTemplateQuestionLibraryService.lambdaUpdate()
                    .set(AskTemplateQuestionLibrary::getId, String.valueOf(generator.nextId(placeholder)))
                    .eq(AskTemplateQuestionLibrary::getId, library.getId())
                    .update();
            if (!updated) {
                log.warn("Failed to reassign id of AskTemplateQuestionLibrary {}", library.getId());
            }
        }
    }

    /**
     * Replaces the id of every {@link AskPatientAnswer} row with a freshly generated one.
     * Rows are updated one by one, matched on their current id.
     */
    @Test
    public void creatAnswerId() {
        // Same placeholder assumption as in creatAskId.
        Object placeholder = new Object();
        DefaultIdentifierGenerator generator = new DefaultIdentifierGenerator();
        List<AskPatientAnswer> list = askPatientAnswerService.list();
        for (AskPatientAnswer answer : list) {
            boolean updated = askPatientAnswerService.lambdaUpdate()
                    .set(AskPatientAnswer::getId, String.valueOf(generator.nextId(placeholder)))
                    .eq(AskPatientAnswer::getId, answer.getId())
                    .update();
            if (!updated) {
                log.warn("Failed to reassign id of AskPatientAnswer {}", answer.getId());
            }
        }
    }

    /**
     * Links template questions to dictionary entries using the corpus workbook.
     *
     * <p>For each row, columns 0 and 1 are joined as {@code "one/two"} and looked up as a
     * {@link CommonDic} name path. On a match, every template question whose code equals column 5
     * gets that dictionary id. The workbook path is taken from the {@code corpus.excel.path} system
     * property or the {@code CORPUS_EXCEL_PATH} environment variable, else the original default.
     */
    @Test
    public void insertDict() {
        String excelPath = readSetting(CORPUS_EXCEL_PATH_PROPERTY, CORPUS_EXCEL_PATH_ENV, DEFAULT_CORPUS_EXCEL_PATH);
        // The reader holds the open workbook; close it even if an update below throws.
        try (ExcelReader reader = ExcelUtil.getReader(excelPath)) {
            List<List<Object>> rows = reader.read(CORPUS_FIRST_ROW, CORPUS_LAST_ROW);
            for (List<Object> row : rows) {
                // A row that ends before the code column cannot be matched; skip it instead of
                // aborting the whole import with IndexOutOfBoundsException.
                if (row == null || row.size() <= COL_CODE) {
                    log.warn("Skipping corpus row without a code column: {}", row);
                    continue;
                }
                String pathOne = String.valueOf(row.get(COL_PATH_ONE));
                String pathTwo = String.valueOf(row.get(COL_PATH_TWO));
                CommonDic dic = commonDicService.lambdaQuery()
                        .eq(CommonDic::getNameZhPath, pathOne + "/" + pathTwo)
                        .last("limit 1")
                        .one();
                if (ObjectUtil.isNotEmpty(dic)) {
                    String code = String.valueOf(row.get(COL_CODE));
                    askTemplateQuestionLibraryService.lambdaUpdate()
                            .set(AskTemplateQuestionLibrary::getDictId, dic.getId())
                            .eq(AskTemplateQuestionLibrary::getCode, code)
                            .update();
                }
            }
        }
    }

    /**
     * Asks ERNIE Bot (Wenxin Yiyan) for 20 rephrasings of each template question's description and
     * stores them, preceded by the description itself, as a JSON array in the question column.
     *
     * <p>This API is billed per request — do not run it casually, and watch the volume. Credentials
     * are read from the {@code baidu.client.id}/{@code baidu.client.secret} system properties or the
     * {@code BAIDU_CLIENT_ID}/{@code BAIDU_CLIENT_SECRET} environment variables. If they are absent,
     * the test does nothing, which also prevents accidental paid runs.
     * NOTE(review): credentials were previously committed in this file; they should be rotated.
     *
     * @throws IllegalStateException if credentials are configured but no access token is returned
     */
    @Test
    public void generateQuestion() {
        String clientId = readSetting(BAIDU_CLIENT_ID_PROPERTY, BAIDU_CLIENT_ID_ENV, null);
        String clientSecret = readSetting(BAIDU_CLIENT_SECRET_PROPERTY, BAIDU_CLIENT_SECRET_ENV, null);
        if (StrUtil.isBlank(clientId) || StrUtil.isBlank(clientSecret)) {
            log.warn("Baidu API credentials are not configured; skipping generateQuestion");
            return;
        }

        Map<String, Object> tokenParams = new HashMap<>();
        tokenParams.put("grant_type", "client_credentials");
        tokenParams.put("client_id", clientId);
        tokenParams.put("client_secret", clientSecret);
        String tokenResponse = HttpUtil.get(BAIDU_TOKEN_URL, tokenParams);
        String accessToken = JSONUtil.parseObj(tokenResponse).getStr("access_token");
        if (StrUtil.isBlank(accessToken)) {
            // Fail once here rather than sending one doomed request per row. The response holds no
            // token in this branch, so including it in the message does not leak one.
            throw new IllegalStateException("Failed to obtain ERNIE access token: " + tokenResponse);
        }
        String chatUrl = ERNIE_CHAT_URL + accessToken;

        List<AskTemplateQuestionLibrary> list = askTemplateQuestionLibraryService.list();
        for (AskTemplateQuestionLibrary ask : list) {
            try {
                String description = ask.getDescription();
                if (StrUtil.isBlank(description)) {
                    log.warn("Skipping question {} with blank description", ask.getId());
                    continue;
                }
                Map<String, Object> message = new HashMap<>();
                message.put("role", "user");
                message.put("content", "请把下面这句话以20种不同的方式提问以JSONARRAY的形式输出:" + description);
                Map<String, Object> param = new HashMap<>();
                param.put("messages", CollUtil.newArrayList(message));
                String askAnswer = HttpUtil.post(chatUrl, JSONUtil.toJsonStr(param));
                String result = JSONUtil.parseObj(askAnswer).getStr("result");
                // The model answers in free text; take the first [...] span as the JSON array.
                String jsonArray = ReUtil.get("\\[(.*?)\\]", result, 0);
                if (StrUtil.isBlank(jsonArray)) {
                    // Error responses have no usable array. Skip rather than persist a list that
                    // might contain only the description, overwriting stored questions.
                    log.error("No JSON array in ERNIE answer for {}: {}", description, askAnswer);
                    continue;
                }
                List<String> questions = JSONUtil.toList(jsonArray, String.class);
                if (CollUtil.isEmpty(questions)) {
                    log.error("Empty question list in ERNIE answer for {}: {}", description, askAnswer);
                    continue;
                }
                // Keep the original description as the first entry.
                questions.add(0, description);
                askTemplateQuestionLibraryService.lambdaUpdate()
                        .set(AskTemplateQuestionLibrary::getQuestion, JSONUtil.toJsonStr(questions))
                        .eq(AskTemplateQuestionLibrary::getId, ask.getId())
                        .update();
            } catch (Exception e) {
                log.error("{}生成错误", ask.getDescription(), e);
            }
        }
    }

    /**
     * Reads a setting from a JVM system property, then an environment variable, then a default.
     * Blank values count as unset.
     *
     * @param propertyName system property consulted first
     * @param envName      environment variable consulted when the property is blank
     * @param defaultValue value returned when both are blank; may be {@code null}
     * @return the first non-blank value found, or {@code defaultValue}
     */
    private static String readSetting(String propertyName, String envName, String defaultValue) {
        String value = System.getProperty(propertyName);
        if (StrUtil.isBlank(value)) {
            value = System.getenv(envName);
        }
        return StrUtil.isBlank(value) ? defaultValue : value;
    }
}