package com.supervision.service.impl; import cn.hutool.core.collection.CollUtil; import cn.hutool.core.io.FileUtil; import cn.hutool.core.io.IoUtil; import cn.hutool.core.util.StrUtil; import cn.hutool.json.JSONUtil; import com.supervision.exception.BusinessException; import com.supervision.model.*; import com.supervision.pojo.rasa.train.DomainYmlTemplate; import com.supervision.pojo.rasa.train.NluYmlTemplate; import com.supervision.pojo.rasa.train.QuestionAnswerDTO; import com.supervision.pojo.rasa.train.RuleYmlTemplate; import com.supervision.service.*; import freemarker.template.Configuration; import freemarker.template.Template; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Service; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.File; import java.io.PrintWriter; import java.util.*; import java.util.function.Function; import java.util.stream.Collectors; import java.util.zip.ZipEntry; import java.util.zip.ZipOutputStream; @Slf4j @Service @RequiredArgsConstructor public class RasaServiceImpl implements RasaService { private final AskDefaultQuestionAnswerService askDefaultQuestionAnswerService; private final AskDiseaseQuestionAnswerService askDiseaseQuestionAnswerService; private final AskTemplateQuestionService askTemplateQuestionService; private final ConfigPhysicalToolService configPhysicalToolService; private final ConfigAncillaryItemService configAncillaryItemService; @Override public void generateRasaYml(String diseaseId) { Map ymalFileMap = new HashMap<>(); // 默认问答MAP Map questionCodeAndIdMap = new HashMap<>(); List ruleList = new ArrayList<>(); // 开始生成各种yaml文件 generateNlu(diseaseId, questionCodeAndIdMap, ymalFileMap); generateDomain(questionCodeAndIdMap, ruleList, ymalFileMap); generateRule(ruleList, ymalFileMap); // 生成压缩文件 ByteArrayOutputStream bos = new ByteArrayOutputStream(); try (ZipOutputStream zipOutputStream = new ZipOutputStream(bos)) { for (Map.Entry fileEntry : ymalFileMap.entrySet()) { zipOutputStream.putNextEntry(new ZipEntry(fileEntry.getKey())); IoUtil.copy(FileUtil.getInputStream(fileEntry.getValue()), zipOutputStream); zipOutputStream.closeEntry(); } zipOutputStream.finish(); } catch (Exception e) { log.error("生成ZIP文件失败", e); throw new BusinessException("生成ZIP文件失败"); } // TODO 这是压缩文件的字节流,这里需要把自己流调用Python程序 byte[] byteArray = bos.toByteArray(); File file = new File("rasa.zip"); IoUtil.copy(new ByteArrayInputStream(byteArray), FileUtil.getOutputStream(file)); } private void generateNlu(String diseaseId, Map intentCodeAndIdMap, Map ymalFileMap) { // 首先生成根据意图查找到nlu文件 List nluList = new ArrayList<>(); // 默认意图 List defaultQuestionAnswerList = askDefaultQuestionAnswerService.lambdaQuery().isNotNull(AskDefaultQuestionAnswer::getAnswer).list(); // 生成默认意图的nlu for (AskDefaultQuestionAnswer defaultQA : defaultQuestionAnswerList) { if (CollUtil.isNotEmpty(defaultQA.getQuestion()) && CollUtil.isNotEmpty(defaultQA.getAnswer())) { // 开始生成 NluYmlTemplate.Nlu nlu = new NluYmlTemplate.Nlu(); // 拼接格式:code_id(防止重复) String intentCode = defaultQA.getCode() + "_" + defaultQA.getId(); nlu.setIntent(intentCode); nlu.setExamples(defaultQA.getQuestion()); nluList.add(nlu); // 添加到map中,key为意图编码,value为意图ID intentCodeAndIdMap.put(intentCode, new QuestionAnswerDTO(defaultQA.getQuestion(), defaultQA.getAnswer(), defaultQA.getDescription())); } } // 然后处理该疾病对应的意图 List diseaseQuestionAnswerList = askDiseaseQuestionAnswerService.lambdaQuery() .eq(AskDiseaseQuestionAnswer::getDiseaseId, diseaseId).list(); // 使用通用模板的 Map templateQuestionMap = new HashMap<>(); // 根据默认意图找到所有的问题 if (CollUtil.isNotEmpty(diseaseQuestionAnswerList)) { // 首先找到使用通用模板的问题 List templateQuestionList = diseaseQuestionAnswerList.stream() .filter(e -> StrUtil.isNotBlank(e.getTemplateQuestionId())).collect(Collectors.toList()); if (CollUtil.isNotEmpty(templateQuestionList)) { Set templateQuestionIdList = templateQuestionList.stream().map(AskDiseaseQuestionAnswer::getTemplateQuestionId).collect(Collectors.toSet()); List list = askTemplateQuestionService.lambdaQuery().in(AskTemplateQuestion::getId, templateQuestionIdList).list(); templateQuestionMap = list.stream().collect(Collectors.toMap(AskTemplateQuestion::getId, Function.identity())); } } // 这里开始遍历 for (AskDiseaseQuestionAnswer askDiseaseQuestionAnswer : diseaseQuestionAnswerList) { // 如果走模板的问题 if (StrUtil.isNotBlank(askDiseaseQuestionAnswer.getTemplateQuestionId()) && templateQuestionMap.containsKey(askDiseaseQuestionAnswer.getTemplateQuestionId())) { AskTemplateQuestion askTemplateQuestion = templateQuestionMap.get(askDiseaseQuestionAnswer.getTemplateQuestionId()); // 开始生成 NluYmlTemplate.Nlu nlu = new NluYmlTemplate.Nlu(); // 拼接格式:code_answerId(防止重复) String intentCode = askTemplateQuestion.getCode() + "_" + askDiseaseQuestionAnswer.getId(); nlu.setIntent(intentCode); nlu.setExamples(askTemplateQuestion.getQuestion()); nluList.add(nlu); intentCodeAndIdMap.put(intentCode, new QuestionAnswerDTO(askTemplateQuestion.getQuestion(), askDiseaseQuestionAnswer.getAnswer(), askTemplateQuestion.getDescription())); } } // 这里处理呼出的问题(code和问题不能为空) List physicalToolList = configPhysicalToolService.lambdaQuery() .isNotNull(ConfigPhysicalTool::getCode) .isNotNull(ConfigPhysicalTool::getCallOutQuestion).list(); for (ConfigPhysicalTool tool : physicalToolList) { // 把呼出的问题全部加进去 NluYmlTemplate.Nlu nlu = new NluYmlTemplate.Nlu(); String toolIntent = "tool_" + tool.getCode(); nlu.setIntent(toolIntent); nlu.setExamples(tool.getCallOutQuestion()); nluList.add(nlu); // answer格式为:---tool---工具ID intentCodeAndIdMap.put(toolIntent, new QuestionAnswerDTO(tool.getCallOutQuestion(), CollUtil.newArrayList("---tool---" + tool.getId()), "呼出-tool-" + tool.getToolName())); } // 生成呼出的辅助检查 List ancillaryItemList = configAncillaryItemService.lambdaQuery() .isNotNull(ConfigAncillaryItem::getCode) .isNotNull(ConfigAncillaryItem::getCallOutQuestion).list(); for (ConfigAncillaryItem ancillary : ancillaryItemList) { // 把辅助问诊的问题全部加进去 NluYmlTemplate.Nlu nlu = new NluYmlTemplate.Nlu(); String itemIntent = "ancillary_" + ancillary.getCode(); nlu.setIntent(itemIntent); nlu.setExamples(ancillary.getCallOutQuestion()); nluList.add(nlu); // answer格式为:---ancillary---工具ID intentCodeAndIdMap.put(itemIntent, new QuestionAnswerDTO(ancillary.getCallOutQuestion(), CollUtil.newArrayList("---ancillary---" + ancillary.getId()), "呼出-ancillary-" + ancillary.getItemName())); } NluYmlTemplate nluYmlTemplate = new NluYmlTemplate(); nluYmlTemplate.setNlu(nluList); // 生成后生成yml文件 createYmlFile(NluYmlTemplate.class, "nlu.ftl", nluYmlTemplate, "nlu.yml", ymalFileMap); } public void generateDomain(Map questionCodeAndIdMap, List ruleList, Map ymalFileMap) { LinkedHashMap> responses = new LinkedHashMap<>(); for (Map.Entry entry : questionCodeAndIdMap.entrySet()) { String intentCode = entry.getKey(); QuestionAnswerDTO value = entry.getValue(); String utter = "utter_" + intentCode; responses.put(utter, CollUtil.newArrayList(value.getAnswerList())); ruleList.add(new RuleYmlTemplate.Rule(value.getDesc(), intentCode, utter)); } DomainYmlTemplate domainYmlTemplate = new DomainYmlTemplate(); // 意图 List intentList = new ArrayList<>(questionCodeAndIdMap.keySet()); domainYmlTemplate.setIntents(intentList); // 回复 domainYmlTemplate.setResponses(responses); // action List actionList = new ArrayList<>(responses.keySet()); domainYmlTemplate.setActions(actionList); // 生成yml文件 createYmlFile(DomainYmlTemplate.class, "domain.ftl", domainYmlTemplate, "domain.yml", ymalFileMap); } /** * 生成rule */ public void generateRule(List ruleList, Map ymalFileMap) { RuleYmlTemplate ruleYmlTemplate = new RuleYmlTemplate(); ruleYmlTemplate.setRules(ruleList); // 生成yml文件 createYmlFile(RuleYmlTemplate.class, "rules.ftl", ruleYmlTemplate, "rules.yml", ymalFileMap); } private void createYmlFile(Class clazz, String ftlName, Object data, String ymlName, Map ymalFileMap) { try { // 这个版本和maven依赖的版本一致 Configuration configuration = new Configuration(Configuration.VERSION_2_3_31); configuration.setClassForTemplateLoading(clazz, "/templates"); // 模板文件的所在目录 // 获取模板 Template template = configuration.getTemplate(ftlName); File tempFile = FileUtil.createTempFile(); // 创建输出文件 try (PrintWriter out = new PrintWriter(tempFile);) { // 填充并生成输出 template.process(data, out); } catch (Exception e) { log.error("文件生成失败"); } ymalFileMap.put(ymlName, tempFile); } catch (Exception e) { log.error("导出模板失败", e); } } }