提交生成rasa的yml文件的代码

dev_v1.0.1
liu 2 years ago
parent 8edb5cab16
commit d56351d606

@ -1,33 +0,0 @@
version: "3.1"
intents:
- self_introduction
- goodbye
- greet
- ask_bushufu
responses:
utter_self_introduction:
- text: "再见"
utter_goodbye:
- text: "你好医生"
utter_greet:
- text: "你好"
utter_ask_bushufu:
- text: "我最近感觉心跳特别快,喘不上气。"
utter_tool_tool_shizhen:
- text: "---tool---1"
utter_tool_tool_huxi:
- text: "---tool---10"
actions:
- utter_self_introduction
- utter_goodbye
- utter_greet
- utter_ask_bushufu
- utter_tool_tool_shizhen
- utter_tool_tool_huxi
session_config:
session_expiration_time: 60
carry_over_slots_to_new_session: true

@ -1,28 +0,0 @@
version: "3.1"
nlu:
- intent: greet
examples: |
- 你好
- 你好啊
- 你好你好
- intent: goodbye
examples: |
- 再见
- 拜拜
- intent: self_introduction
examples: |
- 我是张医生
- intent: ask_bushufu
examples: |
- 今天您有什么不舒服?哪里不舒服?
- intent: tool_tool_shizhen
examples: |
- 1
- 22
- 333
- intent: tool_tool_huxi
examples: |
- 1
- 22
- 333

@ -1,28 +0,0 @@
version: "3.1"
rules:
- rule: 自我介绍
steps:
- intent: self_introduction
- action: utter_self_introduction
- rule: 再见
steps:
- intent: goodbye
- action: utter_goodbye
- rule: 问候
steps:
- intent: greet
- action: utter_greet
- rule: 问不舒服
steps:
- intent: ask_bushufu
- action: utter_ask_bushufu
- rule: 视诊
steps:
- intent: tool_tool_shizhen
- action: utter_tool_tool_shizhen
- rule: 呼吸
steps:
- intent: tool_tool_huxi
- action: utter_tool_tool_huxi

@ -1,14 +1,15 @@
package com.supervision.service.impl;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.map.MapUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.core.io.FileUtil;
import cn.hutool.core.io.IoUtil;
import cn.hutool.json.JSONUtil;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
import com.fasterxml.jackson.dataformat.yaml.YAMLGenerator;
import com.supervision.exception.BusinessException;
import com.supervision.model.*;
import com.supervision.pojo.rasa.train.intent.*;
import com.supervision.pojo.rasa.train.intent.DomainYmlTemplate;
import com.supervision.pojo.rasa.train.intent.Intent;
import com.supervision.pojo.rasa.train.intent.NluYmlTemplate;
import com.supervision.pojo.rasa.train.intent.RuleYmlTemplate;
import com.supervision.service.*;
import freemarker.template.Configuration;
import freemarker.template.Template;
@ -17,11 +18,14 @@ import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.*;
import java.util.stream.Collectors;
import java.util.zip.ZipEntry;
import java.util.zip.ZipOutputStream;
@Slf4j
@Service
@ -44,23 +48,42 @@ public class RasaServiceImpl implements RasaService {
@Override
public void generateRasaYml(String diseaseId) {
Map<String, File> ymalFileMap = new HashMap<>();
Map<String, Intent> defaultIntentCodeAndIdMap = new HashMap<>();
Map<String, Intent> questionCodeAndIdMap = new HashMap<>();
Map<String, ConfigPhysicalTool> toolCodeIdMap = new HashMap<>();
List<RuleYmlTemplate.Rule> ruleList = new ArrayList<>();
// 开始生成各种yaml文件
try {
generateNlu(diseaseId, defaultIntentCodeAndIdMap, questionCodeAndIdMap, toolCodeIdMap);
generateNlu(diseaseId, defaultIntentCodeAndIdMap, questionCodeAndIdMap, toolCodeIdMap, ymalFileMap);
} catch (Exception e) {
e.printStackTrace();
}
generateDomain(defaultIntentCodeAndIdMap, questionCodeAndIdMap, toolCodeIdMap, ruleList);
generateRule(ruleList);
generateDomain(defaultIntentCodeAndIdMap, questionCodeAndIdMap, toolCodeIdMap, ruleList, ymalFileMap);
generateRule(ruleList, ymalFileMap);
// 生成压缩文件
ByteArrayOutputStream bos = new ByteArrayOutputStream();
try (ZipOutputStream zipOutputStream = new ZipOutputStream(bos)) {
for (Map.Entry<String, File> fileEntry : ymalFileMap.entrySet()) {
zipOutputStream.putNextEntry(new ZipEntry(fileEntry.getKey()));
IoUtil.copy(FileUtil.getInputStream(fileEntry.getValue()), zipOutputStream);
zipOutputStream.closeEntry();
}
zipOutputStream.finish();
} catch (Exception e) {
log.error("生成ZIP文件失败", e);
throw new BusinessException("生成ZIP文件失败");
}
// TODO 这是压缩文件的字节流,这里需要把自己流调用Python程序
byte[] byteArray = bos.toByteArray();
}
private void generateNlu(String diseaseId, Map<String, Intent> defaultIntentCodeAndIdMap,
Map<String, Intent> intentCodeAndIdMap,
Map<String, ConfigPhysicalTool> toolCodeIdMap) throws IOException, TemplateException {
Map<String, ConfigPhysicalTool> toolCodeIdMap, Map<String, File> ymalFileMap) {
// 首先生成根据意图查找到nlu文件
List<NluYmlTemplate.Nlu> nluList = new ArrayList<>();
// 默认意图
@ -129,14 +152,14 @@ public class RasaServiceImpl implements RasaService {
// 加载模板配置
NluYmlTemplate nluYmlTemplate = new NluYmlTemplate();
nluYmlTemplate.setNlu(nluList);
createYmlFile(NluYmlTemplate.class, "nlu.ftl", nluYmlTemplate, "nlu.yml");
createYmlFile(NluYmlTemplate.class, "nlu.ftl", nluYmlTemplate, "nlu.yml", ymalFileMap);
}
public void generateDomain(Map<String, Intent> defaultQuestionCodeAndIdMap,
Map<String, Intent> questionCodeAndIdMap, Map<String, ConfigPhysicalTool> toolCodeIdMap,
List<RuleYmlTemplate.Rule> ruleList) {
List<RuleYmlTemplate.Rule> ruleList, Map<String, File> ymalFileMap) {
LinkedHashMap<String, List<String>> responses = new LinkedHashMap<>();
// 首先根据默认意图找到所有的意图ID
Collection<Intent> defaultIntentIdColl = defaultQuestionCodeAndIdMap.values();
@ -174,7 +197,7 @@ public class RasaServiceImpl implements RasaService {
String intentCode = entry.getKey();
ConfigPhysicalTool tool = entry.getValue();
String utter = "utter_" + intentCode;
String answer = "---tool---" + tool.getId() ;
String answer = "---tool---" + tool.getId();
responses.put(utter, CollUtil.newArrayList(JSONUtil.toJsonStr(answer)));
ruleList.add(new RuleYmlTemplate.Rule(tool.getToolName(), intentCode, utter));
}
@ -192,33 +215,35 @@ public class RasaServiceImpl implements RasaService {
List<String> actionList = new ArrayList<>(responses.keySet());
domainYmlTemplate.setActions(actionList);
// 生成yml文件
createYmlFile(DomainYmlTemplate.class, "domain.ftl", domainYmlTemplate, "domain.yml");
createYmlFile(DomainYmlTemplate.class, "domain.ftl", domainYmlTemplate, "domain.yml", ymalFileMap);
}
/**
* rule
*/
public void generateRule(List<RuleYmlTemplate.Rule> ruleList) {
public void generateRule(List<RuleYmlTemplate.Rule> ruleList, Map<String, File> ymalFileMap) {
RuleYmlTemplate ruleYmlTemplate = new RuleYmlTemplate();
ruleYmlTemplate.setRules(ruleList);
// 生成yml文件
createYmlFile(RuleYmlTemplate.class, "rules.ftl", ruleYmlTemplate, "rules.yml");
createYmlFile(RuleYmlTemplate.class, "rules.ftl", ruleYmlTemplate, "rules.yml", ymalFileMap);
}
private void createYmlFile(Class<?> clazz, String ftlName, Object data, String ymlName) {
private void createYmlFile(Class<?> clazz, String ftlName, Object data, String ymlName, Map<String, File> ymalFileMap) {
try {
// 这个版本和maven依赖的版本一致
Configuration configuration = new Configuration(Configuration.VERSION_2_3_31);
configuration.setClassForTemplateLoading(clazz, "/templates"); // 模板文件的所在目录
// 获取模板
Template template = configuration.getTemplate(ftlName);
File tempFile = FileUtil.createTempFile();
// 创建输出文件
PrintWriter out = new PrintWriter(ymlName);
// 填充并生成输出
template.process(data, out);
// 关闭资源
out.close();
try (PrintWriter out = new PrintWriter(tempFile);) {
// 填充并生成输出
template.process(data, out);
} catch (Exception e) {
log.error("文件生成失败");
}
ymalFileMap.put(ymlName, tempFile);
} catch (Exception e) {
log.error("导出模板失败", e);
}

Loading…
Cancel
Save