package com.supervision.service.impl; import cn.hutool.core.collection.CollUtil; import cn.hutool.core.io.FileUtil; import cn.hutool.core.io.IoUtil; import cn.hutool.json.JSONUtil; import com.supervision.exception.BusinessException; import com.supervision.model.*; import com.supervision.pojo.rasa.train.intent.DomainYmlTemplate; import com.supervision.pojo.rasa.train.intent.Intent; import com.supervision.pojo.rasa.train.intent.NluYmlTemplate; import com.supervision.pojo.rasa.train.intent.RuleYmlTemplate; import com.supervision.service.*; import freemarker.template.Configuration; import freemarker.template.Template; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Service; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.File; import java.io.PrintWriter; import java.util.*; import java.util.stream.Collectors; import java.util.zip.ZipEntry; import java.util.zip.ZipOutputStream; @Slf4j @Service @RequiredArgsConstructor public class RasaServiceImpl implements RasaService { private final AskIntentService askIntentService; private final AskQuestionService askQuestionService; private final AskDefaultIntentService askDefaultIntentService; private final AskDefaultQuestionService askDefaultQuestionService; private final AskDefaultAnswerService askDefaultAnswerService; private final AskAnswerService askAnswerService; private final ConfigPhysicalToolService configPhysicalToolService; private final ConfigAncillaryItemService configAncillaryItemService; @Override public void generateRasaYml(String diseaseId) { Map ymalFileMap = new HashMap<>(); // 默认问答MAP Map defaultIntentCodeAndIdMap = new HashMap<>(); // 疾病对应的问答MAP Map questionCodeAndIdMap = new HashMap<>(); // 问诊工具MAP Map toolCodeIdMap = new HashMap<>(); // 辅助检查MAP Map ancillaryCodeIdMap = new HashMap<>(); List ruleList = new ArrayList<>(); // 开始生成各种yaml文件 generateNlu(diseaseId, defaultIntentCodeAndIdMap, questionCodeAndIdMap, toolCodeIdMap, ancillaryCodeIdMap, ymalFileMap); generateDomain(defaultIntentCodeAndIdMap, questionCodeAndIdMap, toolCodeIdMap, ancillaryCodeIdMap, ruleList, ymalFileMap); generateRule(ruleList, ymalFileMap); // 生成压缩文件 ByteArrayOutputStream bos = new ByteArrayOutputStream(); try (ZipOutputStream zipOutputStream = new ZipOutputStream(bos)) { for (Map.Entry fileEntry : ymalFileMap.entrySet()) { zipOutputStream.putNextEntry(new ZipEntry(fileEntry.getKey())); IoUtil.copy(FileUtil.getInputStream(fileEntry.getValue()), zipOutputStream); zipOutputStream.closeEntry(); } zipOutputStream.finish(); } catch (Exception e) { log.error("生成ZIP文件失败", e); throw new BusinessException("生成ZIP文件失败"); } // TODO 这是压缩文件的字节流,这里需要把自己流调用Python程序 byte[] byteArray = bos.toByteArray(); File file = new File("rasa.zip"); IoUtil.copy(new ByteArrayInputStream(byteArray),FileUtil.getOutputStream(file)); } private void generateNlu(String diseaseId, Map defaultIntentCodeAndIdMap, Map intentCodeAndIdMap, Map toolCodeIdMap, Map itemCodeIdMap, Map ymalFileMap) { // 首先生成根据意图查找到nlu文件 List nluList = new ArrayList<>(); // 默认意图 List defaultIntentList = askDefaultIntentService.list(); // 根据默认意图找到所有的问题 if (CollUtil.isNotEmpty(defaultIntentList)) { Set defaultIntentIdSet = defaultIntentList.stream().map(AskDefaultIntent::getId).collect(Collectors.toSet()); // 去默认问题表找问题 List questionList = askDefaultQuestionService.lambdaQuery().in(AskDefaultQuestion::getDefaultIntentId, defaultIntentIdSet).list(); Map> defaultQuestionByDefaultIntentId = questionList.stream().collect(Collectors.groupingBy(AskDefaultQuestion::getDefaultIntentId)); // 生成nlu的节点 for (AskDefaultIntent askDefaultIntent : defaultIntentList) { List questions = defaultQuestionByDefaultIntentId.get(askDefaultIntent.getId()); if (CollUtil.isNotEmpty(questions)) { // 开始生成 NluYmlTemplate.Nlu nlu = new NluYmlTemplate.Nlu(); nlu.setIntent(askDefaultIntent.getCode()); // 注意,这里的格式应该是 "- 你好\n- 你好啊\n- 你好你好" 这种的格式,所以我们要拼接 nlu.setExamples(questions.stream().map(AskDefaultQuestion::getQuestion).collect(Collectors.toList())); nluList.add(nlu); // 添加到map中,key为意图编码,value为意图ID defaultIntentCodeAndIdMap.put(askDefaultIntent.getCode(), new Intent(askDefaultIntent.getId(), askDefaultIntent.getCode(), askDefaultIntent.getDescription())); } } } // 然后处理该疾病对应的意图 List askIntentList = askIntentService.lambdaQuery().eq(AskIntent::getDiseaseId, diseaseId).list(); // 根据默认意图找到所有的问题 if (CollUtil.isNotEmpty(askIntentList)) { Set askIntentListSet = askIntentList.stream().map(AskIntent::getId).collect(Collectors.toSet()); // 去默认问题表找问题 List questionList = askQuestionService.lambdaQuery().in(AskQuestion::getIntentId, askIntentListSet).list(); Map> questionByDefaultIntentId = questionList.stream().collect(Collectors.groupingBy(AskQuestion::getIntentId)); // 生成nlu的节点 for (AskIntent askIntent : askIntentList) { List questions = questionByDefaultIntentId.get(askIntent.getId()); if (CollUtil.isNotEmpty(questions)) { // 开始生成 NluYmlTemplate.Nlu nlu = new NluYmlTemplate.Nlu(); nlu.setIntent(askIntent.getCode()); nlu.setExamples(questions.stream().map(AskQuestion::getQuestion).collect(Collectors.toList())); nluList.add(nlu); // 添加到map中,key为意图编码,value为意图ID intentCodeAndIdMap.put(askIntent.getCode(), new Intent(askIntent.getId(), askIntent.getCode(), askIntent.getDescription())); } } } // 这里处理呼出的问题(code和问题不能为空) List physicalToolList = configPhysicalToolService.lambdaQuery() .isNotNull(ConfigPhysicalTool::getCode) .isNotNull(ConfigPhysicalTool::getCallOutQuestion).list(); for (ConfigPhysicalTool tool : physicalToolList) { // 把呼出的问题全部加进去 NluYmlTemplate.Nlu nlu = new NluYmlTemplate.Nlu(); String toolIntent = "tool_" + tool.getCode(); nlu.setIntent(toolIntent); nlu.setExamples(tool.getCallOutQuestion()); nluList.add(nlu); // 生成tool的map,key是code,value是工具对应的ID toolCodeIdMap.put(toolIntent, tool); } // 生成呼出的辅助检查 List ancillaryItemList = configAncillaryItemService.lambdaQuery() .isNotNull(ConfigAncillaryItem::getCode) .isNotNull(ConfigAncillaryItem::getCallOutQuestion).list(); for (ConfigAncillaryItem ancillary : ancillaryItemList) { // 把呼出的问题全部加进去 NluYmlTemplate.Nlu nlu = new NluYmlTemplate.Nlu(); String itemIntent = "ancillary_" + ancillary.getCode(); nlu.setIntent(itemIntent); nlu.setExamples(ancillary.getCallOutQuestion()); nluList.add(nlu); // 生成tool的map,key是item,value是工具对应的ID itemCodeIdMap.put(itemIntent, ancillary); } // 生成后生成yml对象 // createYmlFile(nluYml, "nlu.yml"); // 加载模板配置 NluYmlTemplate nluYmlTemplate = new NluYmlTemplate(); nluYmlTemplate.setNlu(nluList); createYmlFile(NluYmlTemplate.class, "nlu.ftl", nluYmlTemplate, "nlu.yml", ymalFileMap); } public void generateDomain(Map defaultQuestionCodeAndIdMap, Map questionCodeAndIdMap, Map toolCodeIdMap, Map ancillaryCodeIdMap, List ruleList, Map ymalFileMap) { LinkedHashMap> responses = new LinkedHashMap<>(); // 首先根据默认意图找到所有的意图ID Collection defaultIntentIdColl = defaultQuestionCodeAndIdMap.values(); // 找到默认意图对应的回复 if (CollUtil.isNotEmpty(defaultIntentIdColl)) { Set defaultIntentIdSet = defaultIntentIdColl.stream().map(Intent::getId).collect(Collectors.toSet()); List defaultAnswerList = askDefaultAnswerService.lambdaQuery().in(AskDefaultAnswer::getDefaultIntentId, defaultIntentIdSet).list(); Map answerMap = defaultAnswerList.stream().collect(Collectors.toMap(AskDefaultAnswer::getId, AskDefaultAnswer::getAnswer, (k1, k2) -> k1)); for (Map.Entry entry : defaultQuestionCodeAndIdMap.entrySet()) { String defaultIntentCode = entry.getKey(); Intent defaultIntent = entry.getValue(); String answer = answerMap.get(defaultIntent.getId()); String utter = "utter_" + defaultIntentCode; responses.put(utter, CollUtil.newArrayList(answer)); ruleList.add(new RuleYmlTemplate.Rule(defaultIntent.getDesc(), defaultIntent.getCode(), utter)); } } // 然后疾病意图对应的回复 Collection intentIdColl = questionCodeAndIdMap.values(); if (CollUtil.isNotEmpty(intentIdColl)) { Set intentIdSet = intentIdColl.stream().map(Intent::getId).collect(Collectors.toSet()); List answerList = askAnswerService.lambdaQuery().in(AskAnswer::getIntentId, intentIdSet).list(); Map answerMap = answerList.stream().collect(Collectors.toMap(AskAnswer::getId, AskAnswer::getAnswer, (k1, k2) -> k1)); for (Map.Entry entry : questionCodeAndIdMap.entrySet()) { String intentCode = entry.getKey(); Intent intent = entry.getValue(); String answer = answerMap.get(intent.getId()); String utter = "utter_" + intentCode; responses.put(utter, CollUtil.newArrayList(answer)); ruleList.add(new RuleYmlTemplate.Rule(intent.getDesc(), intent.getCode(), utter)); } } // 生成呼出tool对应的回复 for (Map.Entry entry : toolCodeIdMap.entrySet()) { String intentCode = entry.getKey(); ConfigPhysicalTool tool = entry.getValue(); String utter = "utter_" + intentCode; String answer = "---tool---" + tool.getId(); responses.put(utter, CollUtil.newArrayList(JSONUtil.toJsonStr(answer))); ruleList.add(new RuleYmlTemplate.Rule(tool.getToolName(), intentCode, utter)); } // 然后呼出辅助检查对应的回复 for (Map.Entry entry : ancillaryCodeIdMap.entrySet()) { String intentCode = entry.getKey(); ConfigAncillaryItem ancillary = entry.getValue(); String utter = "utter_" + intentCode; String answer = "---ancillary---" + ancillary.getId(); responses.put(utter, CollUtil.newArrayList(JSONUtil.toJsonStr(answer))); ruleList.add(new RuleYmlTemplate.Rule(ancillary.getItemName(), intentCode, utter)); } DomainYmlTemplate domainYmlTemplate = new DomainYmlTemplate(); // 意图 List intentList = new ArrayList<>(); intentList.addAll(defaultQuestionCodeAndIdMap.keySet()); intentList.addAll(questionCodeAndIdMap.keySet()); domainYmlTemplate.setIntents(intentList); // 回复 domainYmlTemplate.setResponses(responses); // action List actionList = new ArrayList<>(responses.keySet()); domainYmlTemplate.setActions(actionList); // 生成yml文件 createYmlFile(DomainYmlTemplate.class, "domain.ftl", domainYmlTemplate, "domain.yml", ymalFileMap); } /** * 生成rule */ public void generateRule(List ruleList, Map ymalFileMap) { RuleYmlTemplate ruleYmlTemplate = new RuleYmlTemplate(); ruleYmlTemplate.setRules(ruleList); // 生成yml文件 createYmlFile(RuleYmlTemplate.class, "rules.ftl", ruleYmlTemplate, "rules.yml", ymalFileMap); } private void createYmlFile(Class clazz, String ftlName, Object data, String ymlName, Map ymalFileMap) { try { // 这个版本和maven依赖的版本一致 Configuration configuration = new Configuration(Configuration.VERSION_2_3_31); configuration.setClassForTemplateLoading(clazz, "/templates"); // 模板文件的所在目录 // 获取模板 Template template = configuration.getTemplate(ftlName); File tempFile = FileUtil.createTempFile(); // 创建输出文件 try (PrintWriter out = new PrintWriter(tempFile);) { // 填充并生成输出 template.process(data, out); } catch (Exception e) { log.error("文件生成失败"); } ymalFileMap.put(ymlName, tempFile); } catch (Exception e) { log.error("导出模板失败", e); } } }