diff --git a/virtual-patient-web/src/test/java/com/supervision/AskTemplateIdTest.java b/virtual-patient-web/src/test/java/com/supervision/AskTemplateIdTest.java index f98cdb2d..4abd7794 100644 --- a/virtual-patient-web/src/test/java/com/supervision/AskTemplateIdTest.java +++ b/virtual-patient-web/src/test/java/com/supervision/AskTemplateIdTest.java @@ -1,7 +1,6 @@ package com.supervision; import cn.hutool.core.collection.CollUtil; -import cn.hutool.core.util.ArrayUtil; import cn.hutool.core.util.ObjectUtil; import cn.hutool.core.util.ReUtil; import cn.hutool.http.HttpUtil; @@ -30,6 +29,8 @@ import java.io.IOException; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; @Slf4j @SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT) @@ -87,6 +88,7 @@ public class AskTemplateIdTest { } } + @Test public void generateQuestion() { // 使用文心一言,这里是付费的,不要随便用哦,注意数量 @@ -94,31 +96,36 @@ public class AskTemplateIdTest { System.out.println(post); JSONObject parse = JSONUtil.parseObj(post); String accessToken = parse.getStr("access_token"); - + List aqtList = commonDicService.lambdaQuery().eq(CommonDic::getGroupCode, "AQT").isNotNull(CommonDic::getParentId).ne(CommonDic::getParentId, 179).list(); + Map dictMap = aqtList.stream().collect(Collectors.toMap(CommonDic::getId, Function.identity())); List list = askTemplateQuestionLibraryService.list(); for (AskTemplateQuestionLibrary ask : list) { try { - String description = ask.getDescription(); - Map map = new HashMap<>(); - map.put("role", "user"); - map.put("content", "请把下面这句话以20种不同的方式提问,以JSONARRAY的形式输出:" + description); - HashMap param = new HashMap<>(); - param.put("messages", CollUtil.newArrayList(map)); - String askAnswer = HttpUtil.post("https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro?access_token=" + accessToken, JSONUtil.toJsonStr(param)); - JSONObject answerJSON = JSONUtil.parseObj(askAnswer); - String result = answerJSON.getStr("result"); - String s = ReUtil.get("\\[(.*?)\\]", result, 0); - List question = JSONUtil.toList(s, String.class); - question.add(0, description); - askTemplateQuestionLibraryService.lambdaUpdate().set(AskTemplateQuestionLibrary::getQuestion, JSONUtil.toJsonStr(question)).eq(AskTemplateQuestionLibrary::getId, ask.getId()).update(); + CommonDic dic = dictMap.get(ask.getDictId()); + if (ObjectUtil.isNotEmpty(dic)) { + String description = ask.getDescription(); + Map map = new HashMap<>(); + map.put("role", "user"); + map.put("content", "假设你是一个精通RASA NLU调优的工程师,我现在有一个意图,有一个问题示例,请你根据这个意图,针对这个问题示例,提出30条与这个问题示例类似的问题,注意,问题不要超出这个意图的范围,回答请使用json array的格式,示例:[\"相似问题1\",\"相似问题2\"]\n" + + "### 下面是意图和问题示例\n" + dic.getNameZhPath() + ":" + ask.getQuestion()); + HashMap param = new HashMap<>(); + param.put("messages", CollUtil.newArrayList(map)); + String askAnswer = HttpUtil.post("https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro?access_token=" + accessToken, JSONUtil.toJsonStr(param)); + JSONObject answerJSON = JSONUtil.parseObj(askAnswer); + String result = answerJSON.getStr("result"); + String s = ReUtil.get("\\[(.*?)\\]", result, 0); + List question = JSONUtil.toList(s, String.class); + question.add(0, description); + askTemplateQuestionLibraryService.lambdaUpdate().set(AskTemplateQuestionLibrary::getQuestion, JSONUtil.toJsonStr(question)).eq(AskTemplateQuestionLibrary::getId, ask.getId()).update(); + } + } catch (Exception e) { log.error("{}生成错误", ask.getDescription(), e); } - } - } + @Autowired private AskService askService;