You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
virtual-patient/virtual-patient-web/src/main/java/com/supervision/service/impl/RasaServiceImpl.java

287 lines
15 KiB
Java

package com.supervision.service.impl;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.io.FileUtil;
import cn.hutool.core.io.IoUtil;
import cn.hutool.json.JSONUtil;
import com.supervision.exception.BusinessException;
import com.supervision.model.*;
import com.supervision.pojo.rasa.train.intent.DomainYmlTemplate;
import com.supervision.pojo.rasa.train.intent.Intent;
import com.supervision.pojo.rasa.train.intent.NluYmlTemplate;
import com.supervision.pojo.rasa.train.intent.RuleYmlTemplate;
import com.supervision.service.*;
import freemarker.template.Configuration;
import freemarker.template.Template;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.PrintWriter;
import java.io.Writer;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.stream.Collectors;
import java.util.zip.ZipEntry;
import java.util.zip.ZipOutputStream;
/**
 * Builds the Rasa training files (nlu.yml, domain.yml, rules.yml) for one disease and packs them into a zip.
 *
 * <p>The training data is assembled from four sources: the default intents (not filtered by disease),
 * the intents configured for the given disease, the call-out questions of {@link ConfigPhysicalTool}s
 * and the call-out questions of {@link ConfigAncillaryItem}s. Each yml file is rendered from a
 * FreeMarker template located under {@code /templates} on the classpath.
 */
@Slf4j
@Service
@RequiredArgsConstructor
public class RasaServiceImpl implements RasaService {

    /** Prefix of the response/action name generated for every intent. */
    private static final String UTTER_PREFIX = "utter_";

    private final AskIntentService askIntentService;
    private final AskQuestionService askQuestionService;
    private final AskDefaultIntentService askDefaultIntentService;
    private final AskDefaultQuestionService askDefaultQuestionService;
    private final AskDefaultAnswerService askDefaultAnswerService;
    private final AskAnswerService askAnswerService;
    private final ConfigPhysicalToolService configPhysicalToolService;
    private final ConfigAncillaryItemService configAncillaryItemService;

    /**
     * Generates nlu.yml, domain.yml and rules.yml for the given disease and zips them.
     *
     * @param diseaseId id of the disease whose intents are included besides the default ones
     * @throws BusinessException if a template cannot be rendered or the zip cannot be built
     */
    @Override
    public void generateRasaYml(String diseaseId) {
        // file name inside the zip -> temporary file holding its rendered content
        Map<String, File> yamlFileMap = new LinkedHashMap<>();
        // default intents, keyed by intent code
        Map<String, Intent> defaultIntentCodeAndIdMap = new LinkedHashMap<>();
        // intents of this disease, keyed by intent code
        Map<String, Intent> questionCodeAndIdMap = new LinkedHashMap<>();
        // physical tools, keyed by their generated intent name ("tool_" + code)
        Map<String, ConfigPhysicalTool> toolCodeIdMap = new LinkedHashMap<>();
        // ancillary items, keyed by their generated intent name ("ancillary_" + code)
        Map<String, ConfigAncillaryItem> ancillaryCodeIdMap = new LinkedHashMap<>();
        List<RuleYmlTemplate.Rule> ruleList = new ArrayList<>();
        try {
            // generateNlu fills the intent maps read by generateDomain, which in turn fills ruleList
            generateNlu(diseaseId, defaultIntentCodeAndIdMap, questionCodeAndIdMap, toolCodeIdMap,
                    ancillaryCodeIdMap, yamlFileMap);
            generateDomain(defaultIntentCodeAndIdMap, questionCodeAndIdMap, toolCodeIdMap,
                    ancillaryCodeIdMap, ruleList, yamlFileMap);
            generateRule(ruleList, yamlFileMap);
            byte[] zipBytes = zipYamlFiles(yamlFileMap);
            // TODO these are the bytes of the zip file; they still need to be handed to the Python program.
            //  Until then they are written to rasa.zip in the working directory.
            FileUtil.writeBytes(zipBytes, new File("rasa.zip"));
        } finally {
            // the yml temp files are only staging files for the zip
            for (File tempFile : yamlFileMap.values()) {
                FileUtil.del(tempFile);
            }
        }
    }

    /** Packs every generated yml file into an in-memory zip, one entry per map key. */
    private byte[] zipYamlFiles(Map<String, File> yamlFileMap) {
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        try (ZipOutputStream zipOutputStream = new ZipOutputStream(bos)) {
            for (Map.Entry<String, File> fileEntry : yamlFileMap.entrySet()) {
                zipOutputStream.putNextEntry(new ZipEntry(fileEntry.getKey()));
                // readBytes opens and closes the file itself, so no stream is leaked
                zipOutputStream.write(FileUtil.readBytes(fileEntry.getValue()));
                zipOutputStream.closeEntry();
            }
            zipOutputStream.finish();
        } catch (Exception e) {
            log.error("生成ZIP文件失败", e);
            throw new BusinessException("生成ZIP文件失败");
        }
        return bos.toByteArray();
    }

    /**
     * Renders nlu.yml and records every intent it contains in the given maps.
     *
     * <p>Default and disease intents are only included when they have at least one question.
     * Tools and ancillary items are included when both their code and call-out question are non-null.
     */
    private void generateNlu(String diseaseId, Map<String, Intent> defaultIntentCodeAndIdMap,
                             Map<String, Intent> intentCodeAndIdMap,
                             Map<String, ConfigPhysicalTool> toolCodeIdMap,
                             Map<String, ConfigAncillaryItem> itemCodeIdMap,
                             Map<String, File> yamlFileMap) {
        List<NluYmlTemplate.Nlu> nluList = new ArrayList<>();

        // 1. default intents and their questions from the default question table
        List<AskDefaultIntent> defaultIntentList = askDefaultIntentService.list();
        if (CollUtil.isNotEmpty(defaultIntentList)) {
            Set<String> defaultIntentIdSet = defaultIntentList.stream()
                    .map(AskDefaultIntent::getId).collect(Collectors.toSet());
            List<AskDefaultQuestion> questionList = askDefaultQuestionService.lambdaQuery()
                    .in(AskDefaultQuestion::getDefaultIntentId, defaultIntentIdSet).list();
            Map<String, List<AskDefaultQuestion>> questionsByIntentId = questionList.stream()
                    .collect(Collectors.groupingBy(AskDefaultQuestion::getDefaultIntentId));
            for (AskDefaultIntent askDefaultIntent : defaultIntentList) {
                List<AskDefaultQuestion> questions = questionsByIntentId.get(askDefaultIntent.getId());
                if (CollUtil.isNotEmpty(questions)) {
                    NluYmlTemplate.Nlu nlu = new NluYmlTemplate.Nlu();
                    nlu.setIntent(askDefaultIntent.getCode());
                    // one example per question; presumably nlu.ftl renders them as "- ..." list items
                    nlu.setExamples(questions.stream().map(AskDefaultQuestion::getQuestion).collect(Collectors.toList()));
                    nluList.add(nlu);
                    defaultIntentCodeAndIdMap.put(askDefaultIntent.getCode(),
                            new Intent(askDefaultIntent.getId(), askDefaultIntent.getCode(), askDefaultIntent.getDescription()));
                }
            }
        }

        // 2. intents of this disease and their questions from the question table
        List<AskIntent> askIntentList = askIntentService.lambdaQuery().eq(AskIntent::getDiseaseId, diseaseId).list();
        if (CollUtil.isNotEmpty(askIntentList)) {
            Set<String> askIntentIdSet = askIntentList.stream().map(AskIntent::getId).collect(Collectors.toSet());
            List<AskQuestion> questionList = askQuestionService.lambdaQuery()
                    .in(AskQuestion::getIntentId, askIntentIdSet).list();
            Map<String, List<AskQuestion>> questionsByIntentId = questionList.stream()
                    .collect(Collectors.groupingBy(AskQuestion::getIntentId));
            for (AskIntent askIntent : askIntentList) {
                List<AskQuestion> questions = questionsByIntentId.get(askIntent.getId());
                if (CollUtil.isNotEmpty(questions)) {
                    NluYmlTemplate.Nlu nlu = new NluYmlTemplate.Nlu();
                    nlu.setIntent(askIntent.getCode());
                    nlu.setExamples(questions.stream().map(AskQuestion::getQuestion).collect(Collectors.toList()));
                    nluList.add(nlu);
                    intentCodeAndIdMap.put(askIntent.getCode(),
                            new Intent(askIntent.getId(), askIntent.getCode(), askIntent.getDescription()));
                }
            }
        }

        // 3. call-out questions of physical tools (code and call-out question must be non-null)
        List<ConfigPhysicalTool> physicalToolList = configPhysicalToolService.lambdaQuery()
                .isNotNull(ConfigPhysicalTool::getCode)
                .isNotNull(ConfigPhysicalTool::getCallOutQuestion).list();
        for (ConfigPhysicalTool tool : physicalToolList) {
            String toolIntent = "tool_" + tool.getCode();
            NluYmlTemplate.Nlu nlu = new NluYmlTemplate.Nlu();
            nlu.setIntent(toolIntent);
            nlu.setExamples(tool.getCallOutQuestion());
            nluList.add(nlu);
            // key: generated intent name, value: the tool entity itself
            toolCodeIdMap.put(toolIntent, tool);
        }

        // 4. call-out questions of ancillary items
        List<ConfigAncillaryItem> ancillaryItemList = configAncillaryItemService.lambdaQuery()
                .isNotNull(ConfigAncillaryItem::getCode)
                .isNotNull(ConfigAncillaryItem::getCallOutQuestion).list();
        for (ConfigAncillaryItem ancillary : ancillaryItemList) {
            String itemIntent = "ancillary_" + ancillary.getCode();
            NluYmlTemplate.Nlu nlu = new NluYmlTemplate.Nlu();
            nlu.setIntent(itemIntent);
            nlu.setExamples(ancillary.getCallOutQuestion());
            nluList.add(nlu);
            // key: generated intent name, value: the ancillary item entity itself
            itemCodeIdMap.put(itemIntent, ancillary);
        }

        NluYmlTemplate nluYmlTemplate = new NluYmlTemplate();
        nluYmlTemplate.setNlu(nluList);
        createYmlFile(NluYmlTemplate.class, "nlu.ftl", nluYmlTemplate, "nlu.yml", yamlFileMap);
    }

    /**
     * Renders domain.yml and appends one rule (intent -> utter response) per answered intent to {@code ruleList}.
     *
     * <p>Default and disease intents without a configured answer get no response and no rule; a warning is logged.
     * Tool and ancillary intents answer with a marker string ("---tool---"/"---ancillary---" + entity id).
     */
    public void generateDomain(Map<String, Intent> defaultQuestionCodeAndIdMap,
                               Map<String, Intent> questionCodeAndIdMap,
                               Map<String, ConfigPhysicalTool> toolCodeIdMap,
                               Map<String, ConfigAncillaryItem> ancillaryCodeIdMap,
                               List<RuleYmlTemplate.Rule> ruleList, Map<String, File> yamlFileMap) {
        LinkedHashMap<String, List<String>> responses = new LinkedHashMap<>();

        // responses of the default intents
        if (!defaultQuestionCodeAndIdMap.isEmpty()) {
            Set<String> defaultIntentIdSet = defaultQuestionCodeAndIdMap.values().stream()
                    .map(Intent::getId).collect(Collectors.toSet());
            List<AskDefaultAnswer> defaultAnswerList = askDefaultAnswerService.lambdaQuery()
                    .in(AskDefaultAnswer::getDefaultIntentId, defaultIntentIdSet).list();
            // keyed by the owning intent id, because that is what the intents are looked up by;
            // null answers are dropped (toMap rejects null values)
            Map<String, String> answerByIntentId = defaultAnswerList.stream()
                    .filter(answer -> answer.getAnswer() != null)
                    .collect(Collectors.toMap(AskDefaultAnswer::getDefaultIntentId, AskDefaultAnswer::getAnswer,
                            (first, second) -> first));
            addIntentResponses(defaultQuestionCodeAndIdMap, answerByIntentId, responses, ruleList);
        }

        // responses of the disease intents
        if (!questionCodeAndIdMap.isEmpty()) {
            Set<String> intentIdSet = questionCodeAndIdMap.values().stream()
                    .map(Intent::getId).collect(Collectors.toSet());
            List<AskAnswer> answerList = askAnswerService.lambdaQuery()
                    .in(AskAnswer::getIntentId, intentIdSet).list();
            Map<String, String> answerByIntentId = answerList.stream()
                    .filter(answer -> answer.getAnswer() != null)
                    .collect(Collectors.toMap(AskAnswer::getIntentId, AskAnswer::getAnswer, (first, second) -> first));
            addIntentResponses(questionCodeAndIdMap, answerByIntentId, responses, ruleList);
        }

        // responses of the tool call-out intents
        for (Map.Entry<String, ConfigPhysicalTool> entry : toolCodeIdMap.entrySet()) {
            String intentCode = entry.getKey();
            ConfigPhysicalTool tool = entry.getValue();
            String utter = UTTER_PREFIX + intentCode;
            String answer = "---tool---" + tool.getId();
            responses.put(utter, CollUtil.newArrayList(JSONUtil.toJsonStr(answer)));
            ruleList.add(new RuleYmlTemplate.Rule(tool.getToolName(), intentCode, utter));
        }

        // responses of the ancillary call-out intents
        for (Map.Entry<String, ConfigAncillaryItem> entry : ancillaryCodeIdMap.entrySet()) {
            String intentCode = entry.getKey();
            ConfigAncillaryItem ancillary = entry.getValue();
            String utter = UTTER_PREFIX + intentCode;
            String answer = "---ancillary---" + ancillary.getId();
            responses.put(utter, CollUtil.newArrayList(JSONUtil.toJsonStr(answer)));
            ruleList.add(new RuleYmlTemplate.Rule(ancillary.getItemName(), intentCode, utter));
        }

        DomainYmlTemplate domainYmlTemplate = new DomainYmlTemplate();
        // declare every intent written to nlu.yml, including the tool_/ancillary_ intents referenced by the rules
        List<String> intentList = new ArrayList<>();
        intentList.addAll(defaultQuestionCodeAndIdMap.keySet());
        intentList.addAll(questionCodeAndIdMap.keySet());
        intentList.addAll(toolCodeIdMap.keySet());
        intentList.addAll(ancillaryCodeIdMap.keySet());
        domainYmlTemplate.setIntents(intentList);
        domainYmlTemplate.setResponses(responses);
        // every response is also declared as an action
        domainYmlTemplate.setActions(new ArrayList<>(responses.keySet()));
        createYmlFile(DomainYmlTemplate.class, "domain.ftl", domainYmlTemplate, "domain.yml", yamlFileMap);
    }

    /**
     * Adds an "utter_" response and a matching rule for each intent that has an answer.
     *
     * @param intentByCode     intents keyed by intent code
     * @param answerByIntentId answer text keyed by intent id
     */
    private void addIntentResponses(Map<String, Intent> intentByCode, Map<String, String> answerByIntentId,
                                    Map<String, List<String>> responses, List<RuleYmlTemplate.Rule> ruleList) {
        for (Map.Entry<String, Intent> entry : intentByCode.entrySet()) {
            String intentCode = entry.getKey();
            Intent intent = entry.getValue();
            String answer = answerByIntentId.get(intent.getId());
            if (answer == null) {
                log.warn("No answer configured for intent {}, skipping its response and rule", intentCode);
                continue;
            }
            String utter = UTTER_PREFIX + intentCode;
            responses.put(utter, CollUtil.newArrayList(answer));
            ruleList.add(new RuleYmlTemplate.Rule(intent.getDesc(), intent.getCode(), utter));
        }
    }

    /**
     * Renders rules.yml from the rules collected while building the domain.
     */
    public void generateRule(List<RuleYmlTemplate.Rule> ruleList, Map<String, File> yamlFileMap) {
        RuleYmlTemplate ruleYmlTemplate = new RuleYmlTemplate();
        ruleYmlTemplate.setRules(ruleList);
        createYmlFile(RuleYmlTemplate.class, "rules.ftl", ruleYmlTemplate, "rules.yml", yamlFileMap);
    }

    /**
     * Renders {@code data} with the FreeMarker template {@code /templates/<ftlName>} into a UTF-8 temp file
     * and registers it in {@code yamlFileMap} under {@code ymlName}.
     *
     * @throws BusinessException if the template cannot be loaded or rendered; no partial file is registered
     */
    private void createYmlFile(Class<?> clazz, String ftlName, Object data, String ymlName,
                               Map<String, File> yamlFileMap) {
        // this version matches the freemarker version declared in the maven dependencies
        Configuration configuration = new Configuration(Configuration.VERSION_2_3_31);
        configuration.setClassForTemplateLoading(clazz, "/templates");
        // templates and data contain Chinese text; never depend on the platform charset
        configuration.setDefaultEncoding(StandardCharsets.UTF_8.name());
        File tempFile = FileUtil.createTempFile();
        try (Writer out = FileUtil.getWriter(tempFile, StandardCharsets.UTF_8, false)) {
            Template template = configuration.getTemplate(ftlName);
            template.process(data, out);
        } catch (Exception e) {
            // the writer is already closed here, so the partial file can be removed
            FileUtil.del(tempFile);
            log.error("导出模板失败, template={}, output={}", ftlName, ymlName, e);
            throw new BusinessException("导出模板失败");
        }
        yamlFileMap.put(ymlName, tempFile);
    }
}