From d56351d6067ee055073f61a83fb08e19303230a2 Mon Sep 17 00:00:00 2001 From: liu Date: Thu, 26 Oct 2023 10:30:28 +0800 Subject: [PATCH] =?UTF-8?q?=E6=8F=90=E4=BA=A4=E7=94=9F=E6=88=90rasa?= =?UTF-8?q?=E7=9A=84yml=E6=96=87=E4=BB=B6=E7=9A=84=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- domain.yml | 33 --------- nlu.yml | 28 ------- rules.yml | 28 ------- .../service/impl/RasaServiceImpl.java | 73 +++++++++++++------ 4 files changed, 49 insertions(+), 113 deletions(-) delete mode 100644 domain.yml delete mode 100644 nlu.yml delete mode 100644 rules.yml diff --git a/domain.yml b/domain.yml deleted file mode 100644 index 7ca48fdc..00000000 --- a/domain.yml +++ /dev/null @@ -1,33 +0,0 @@ -version: "3.1" - -intents: - - self_introduction - - goodbye - - greet - - ask_bushufu - -responses: - utter_self_introduction: - - text: "再见" - utter_goodbye: - - text: "你好医生" - utter_greet: - - text: "你好" - utter_ask_bushufu: - - text: "我最近感觉心跳特别快,喘不上气。" - utter_tool_tool_shizhen: - - text: "---tool---1" - utter_tool_tool_huxi: - - text: "---tool---10" - -actions: - - utter_self_introduction - - utter_goodbye - - utter_greet - - utter_ask_bushufu - - utter_tool_tool_shizhen - - utter_tool_tool_huxi - -session_config: - session_expiration_time: 60 - carry_over_slots_to_new_session: true diff --git a/nlu.yml b/nlu.yml deleted file mode 100644 index fc043472..00000000 --- a/nlu.yml +++ /dev/null @@ -1,28 +0,0 @@ -version: "3.1" - -nlu: - - intent: greet - examples: | - - 你好 - - 你好啊 - - 你好你好 - - intent: goodbye - examples: | - - 再见 - - 拜拜 - - intent: self_introduction - examples: | - - 我是张医生 - - intent: ask_bushufu - examples: | - - 今天您有什么不舒服?哪里不舒服? - - intent: tool_tool_shizhen - examples: | - - 1 - - 22 - - 333 - - intent: tool_tool_huxi - examples: | - - 1 - - 22 - - 333 diff --git a/rules.yml b/rules.yml deleted file mode 100644 index fb074ea2..00000000 --- a/rules.yml +++ /dev/null @@ -1,28 +0,0 @@ -version: "3.1" - -rules: - - - rule: 自我介绍 - steps: - - intent: self_introduction - - action: utter_self_introduction - - rule: 再见 - steps: - - intent: goodbye - - action: utter_goodbye - - rule: 问候 - steps: - - intent: greet - - action: utter_greet - - rule: 问不舒服 - steps: - - intent: ask_bushufu - - action: utter_ask_bushufu - - rule: 视诊 - steps: - - intent: tool_tool_shizhen - - action: utter_tool_tool_shizhen - - rule: 呼吸 - steps: - - intent: tool_tool_huxi - - action: utter_tool_tool_huxi diff --git a/virtual-patient-web/src/main/java/com/supervision/service/impl/RasaServiceImpl.java b/virtual-patient-web/src/main/java/com/supervision/service/impl/RasaServiceImpl.java index 2204b9e1..9305a017 100644 --- a/virtual-patient-web/src/main/java/com/supervision/service/impl/RasaServiceImpl.java +++ b/virtual-patient-web/src/main/java/com/supervision/service/impl/RasaServiceImpl.java @@ -1,14 +1,15 @@ package com.supervision.service.impl; import cn.hutool.core.collection.CollUtil; -import cn.hutool.core.map.MapUtil; -import cn.hutool.core.util.StrUtil; +import cn.hutool.core.io.FileUtil; +import cn.hutool.core.io.IoUtil; import cn.hutool.json.JSONUtil; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; -import com.fasterxml.jackson.dataformat.yaml.YAMLGenerator; +import com.supervision.exception.BusinessException; import com.supervision.model.*; -import com.supervision.pojo.rasa.train.intent.*; +import com.supervision.pojo.rasa.train.intent.DomainYmlTemplate; +import com.supervision.pojo.rasa.train.intent.Intent; +import com.supervision.pojo.rasa.train.intent.NluYmlTemplate; +import com.supervision.pojo.rasa.train.intent.RuleYmlTemplate; import com.supervision.service.*; import freemarker.template.Configuration; import freemarker.template.Template; @@ -17,11 +18,14 @@ import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Service; +import java.io.ByteArrayOutputStream; import java.io.File; import java.io.IOException; import java.io.PrintWriter; import java.util.*; import java.util.stream.Collectors; +import java.util.zip.ZipEntry; +import java.util.zip.ZipOutputStream; @Slf4j @Service @@ -44,23 +48,42 @@ public class RasaServiceImpl implements RasaService { @Override public void generateRasaYml(String diseaseId) { + + Map ymalFileMap = new HashMap<>(); + Map defaultIntentCodeAndIdMap = new HashMap<>(); Map questionCodeAndIdMap = new HashMap<>(); Map toolCodeIdMap = new HashMap<>(); + List ruleList = new ArrayList<>(); + // 开始生成各种yaml文件 try { - generateNlu(diseaseId, defaultIntentCodeAndIdMap, questionCodeAndIdMap, toolCodeIdMap); + generateNlu(diseaseId, defaultIntentCodeAndIdMap, questionCodeAndIdMap, toolCodeIdMap, ymalFileMap); } catch (Exception e) { e.printStackTrace(); } - generateDomain(defaultIntentCodeAndIdMap, questionCodeAndIdMap, toolCodeIdMap, ruleList); - generateRule(ruleList); - + generateDomain(defaultIntentCodeAndIdMap, questionCodeAndIdMap, toolCodeIdMap, ruleList, ymalFileMap); + generateRule(ruleList, ymalFileMap); + // 生成压缩文件 + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + try (ZipOutputStream zipOutputStream = new ZipOutputStream(bos)) { + for (Map.Entry fileEntry : ymalFileMap.entrySet()) { + zipOutputStream.putNextEntry(new ZipEntry(fileEntry.getKey())); + IoUtil.copy(FileUtil.getInputStream(fileEntry.getValue()), zipOutputStream); + zipOutputStream.closeEntry(); + } + zipOutputStream.finish(); + } catch (Exception e) { + log.error("生成ZIP文件失败", e); + throw new BusinessException("生成ZIP文件失败"); + } + // TODO 这是压缩文件的字节流,这里需要把自己流调用Python程序 + byte[] byteArray = bos.toByteArray(); } private void generateNlu(String diseaseId, Map defaultIntentCodeAndIdMap, Map intentCodeAndIdMap, - Map toolCodeIdMap) throws IOException, TemplateException { + Map toolCodeIdMap, Map ymalFileMap) { // 首先生成根据意图查找到nlu文件 List nluList = new ArrayList<>(); // 默认意图 @@ -129,14 +152,14 @@ public class RasaServiceImpl implements RasaService { // 加载模板配置 NluYmlTemplate nluYmlTemplate = new NluYmlTemplate(); nluYmlTemplate.setNlu(nluList); - createYmlFile(NluYmlTemplate.class, "nlu.ftl", nluYmlTemplate, "nlu.yml"); + createYmlFile(NluYmlTemplate.class, "nlu.ftl", nluYmlTemplate, "nlu.yml", ymalFileMap); } public void generateDomain(Map defaultQuestionCodeAndIdMap, Map questionCodeAndIdMap, Map toolCodeIdMap, - List ruleList) { + List ruleList, Map ymalFileMap) { LinkedHashMap> responses = new LinkedHashMap<>(); // 首先根据默认意图找到所有的意图ID Collection defaultIntentIdColl = defaultQuestionCodeAndIdMap.values(); @@ -174,7 +197,7 @@ public class RasaServiceImpl implements RasaService { String intentCode = entry.getKey(); ConfigPhysicalTool tool = entry.getValue(); String utter = "utter_" + intentCode; - String answer = "---tool---" + tool.getId() ; + String answer = "---tool---" + tool.getId(); responses.put(utter, CollUtil.newArrayList(JSONUtil.toJsonStr(answer))); ruleList.add(new RuleYmlTemplate.Rule(tool.getToolName(), intentCode, utter)); } @@ -192,33 +215,35 @@ public class RasaServiceImpl implements RasaService { List actionList = new ArrayList<>(responses.keySet()); domainYmlTemplate.setActions(actionList); // 生成yml文件 - createYmlFile(DomainYmlTemplate.class, "domain.ftl", domainYmlTemplate, "domain.yml"); + createYmlFile(DomainYmlTemplate.class, "domain.ftl", domainYmlTemplate, "domain.yml", ymalFileMap); } /** * 生成rule */ - public void generateRule(List ruleList) { + public void generateRule(List ruleList, Map ymalFileMap) { RuleYmlTemplate ruleYmlTemplate = new RuleYmlTemplate(); ruleYmlTemplate.setRules(ruleList); // 生成yml文件 - createYmlFile(RuleYmlTemplate.class, "rules.ftl", ruleYmlTemplate, "rules.yml"); + createYmlFile(RuleYmlTemplate.class, "rules.ftl", ruleYmlTemplate, "rules.yml", ymalFileMap); } - private void createYmlFile(Class clazz, String ftlName, Object data, String ymlName) { + private void createYmlFile(Class clazz, String ftlName, Object data, String ymlName, Map ymalFileMap) { try { // 这个版本和maven依赖的版本一致 Configuration configuration = new Configuration(Configuration.VERSION_2_3_31); configuration.setClassForTemplateLoading(clazz, "/templates"); // 模板文件的所在目录 // 获取模板 Template template = configuration.getTemplate(ftlName); + File tempFile = FileUtil.createTempFile(); // 创建输出文件 - PrintWriter out = new PrintWriter(ymlName); - // 填充并生成输出 - template.process(data, out); - - // 关闭资源 - out.close(); + try (PrintWriter out = new PrintWriter(tempFile);) { + // 填充并生成输出 + template.process(data, out); + } catch (Exception e) { + log.error("文件生成失败"); + } + ymalFileMap.put(ymlName, tempFile); } catch (Exception e) { log.error("导出模板失败", e); }