|
|
|
package com.supervision.service.impl;
|
|
|
|
|
|
|
|
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.io.FileUtil;
import cn.hutool.core.io.IoUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil;
import com.supervision.exception.BusinessException;
import com.supervision.model.*;
import com.supervision.pojo.rasa.train.DomainYmlTemplate;
import com.supervision.pojo.rasa.train.NluYmlTemplate;
import com.supervision.pojo.rasa.train.QuestionAnswerDTO;
import com.supervision.pojo.rasa.train.RuleYmlTemplate;
import com.supervision.service.*;
import freemarker.template.Configuration;
import freemarker.template.Template;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.InputStream;
import java.io.PrintWriter;
import java.io.Writer;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.util.*;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.zip.ZipEntry;
import java.util.zip.ZipOutputStream;
|
|
|
|
|
|
|
|
@Slf4j
|
|
|
|
@Service
|
|
|
|
@RequiredArgsConstructor
|
|
|
|
public class RasaServiceImpl implements RasaService {
|
|
|
|
|
|
|
|
private final AskDefaultQuestionAnswerService askDefaultQuestionAnswerService;
|
|
|
|
|
|
|
|
private final AskDiseaseQuestionAnswerService askDiseaseQuestionAnswerService;
|
|
|
|
|
|
|
|
private final AskTemplateQuestionService askTemplateQuestionService;
|
|
|
|
|
|
|
|
private final ConfigPhysicalToolService configPhysicalToolService;
|
|
|
|
|
|
|
|
private final ConfigAncillaryItemService configAncillaryItemService;
|
|
|
|
|
|
|
|
@Override
|
|
|
|
public void generateRasaYml(String diseaseId) {
|
|
|
|
|
|
|
|
Map<String, File> ymalFileMap = new HashMap<>();
|
|
|
|
// 默认问答MAP
|
|
|
|
Map<String, QuestionAnswerDTO> questionCodeAndIdMap = new HashMap<>();
|
|
|
|
|
|
|
|
List<RuleYmlTemplate.Rule> ruleList = new ArrayList<>();
|
|
|
|
|
|
|
|
// 开始生成各种yaml文件
|
|
|
|
generateNlu(diseaseId, questionCodeAndIdMap, ymalFileMap);
|
|
|
|
generateDomain(questionCodeAndIdMap, ruleList, ymalFileMap);
|
|
|
|
generateRule(ruleList, ymalFileMap);
|
|
|
|
// 生成压缩文件
|
|
|
|
ByteArrayOutputStream bos = new ByteArrayOutputStream();
|
|
|
|
try (ZipOutputStream zipOutputStream = new ZipOutputStream(bos)) {
|
|
|
|
for (Map.Entry<String, File> fileEntry : ymalFileMap.entrySet()) {
|
|
|
|
zipOutputStream.putNextEntry(new ZipEntry(fileEntry.getKey()));
|
|
|
|
IoUtil.copy(FileUtil.getInputStream(fileEntry.getValue()), zipOutputStream);
|
|
|
|
zipOutputStream.closeEntry();
|
|
|
|
}
|
|
|
|
zipOutputStream.finish();
|
|
|
|
} catch (Exception e) {
|
|
|
|
log.error("生成ZIP文件失败", e);
|
|
|
|
throw new BusinessException("生成ZIP文件失败");
|
|
|
|
}
|
|
|
|
// TODO 这是压缩文件的字节流,这里需要把自己流调用Python程序
|
|
|
|
byte[] byteArray = bos.toByteArray();
|
|
|
|
File file = new File("rasa.zip");
|
|
|
|
IoUtil.copy(new ByteArrayInputStream(byteArray), FileUtil.getOutputStream(file));
|
|
|
|
}
|
|
|
|
|
|
|
|
private void generateNlu(String diseaseId,
|
|
|
|
Map<String, QuestionAnswerDTO> intentCodeAndIdMap,
|
|
|
|
Map<String, File> ymalFileMap) {
|
|
|
|
// 首先生成根据意图查找到nlu文件
|
|
|
|
List<NluYmlTemplate.Nlu> nluList = new ArrayList<>();
|
|
|
|
// 默认意图
|
|
|
|
List<AskDefaultQuestionAnswer> defaultQuestionAnswerList = askDefaultQuestionAnswerService.lambdaQuery().isNotNull(AskDefaultQuestionAnswer::getAnswer).list();
|
|
|
|
// 生成默认意图的nlu
|
|
|
|
for (AskDefaultQuestionAnswer defaultQA : defaultQuestionAnswerList) {
|
|
|
|
if (CollUtil.isNotEmpty(defaultQA.getQuestion()) && CollUtil.isNotEmpty(defaultQA.getAnswer())) {
|
|
|
|
// 开始生成
|
|
|
|
NluYmlTemplate.Nlu nlu = new NluYmlTemplate.Nlu();
|
|
|
|
// 拼接格式:code_id(防止重复)
|
|
|
|
String intentCode = defaultQA.getCode() + "_" + defaultQA.getId();
|
|
|
|
nlu.setIntent(intentCode);
|
|
|
|
nlu.setExamples(defaultQA.getQuestion());
|
|
|
|
nluList.add(nlu);
|
|
|
|
// 添加到map中,key为意图编码,value为意图ID
|
|
|
|
intentCodeAndIdMap.put(intentCode, new QuestionAnswerDTO(defaultQA.getQuestion(), CollUtil.newArrayList("default_" + defaultQA.getId()), defaultQA.getDescription()));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
// 然后处理该疾病对应的意图
|
|
|
|
List<AskDiseaseQuestionAnswer> diseaseQuestionAnswerList = askDiseaseQuestionAnswerService.lambdaQuery()
|
|
|
|
.eq(AskDiseaseQuestionAnswer::getDiseaseId, diseaseId).list();
|
|
|
|
// 使用通用模板的
|
|
|
|
Map<String, AskTemplateQuestion> templateQuestionMap = new HashMap<>();
|
|
|
|
// 根据默认意图找到所有的问题
|
|
|
|
if (CollUtil.isNotEmpty(diseaseQuestionAnswerList)) {
|
|
|
|
// 首先找到使用通用模板的问题
|
|
|
|
List<AskDiseaseQuestionAnswer> templateQuestionList = diseaseQuestionAnswerList.stream()
|
|
|
|
.filter(e -> StrUtil.isNotBlank(e.getTemplateQuestionId())).collect(Collectors.toList());
|
|
|
|
if (CollUtil.isNotEmpty(templateQuestionList)) {
|
|
|
|
Set<String> templateQuestionIdList = templateQuestionList.stream().map(AskDiseaseQuestionAnswer::getTemplateQuestionId).collect(Collectors.toSet());
|
|
|
|
List<AskTemplateQuestion> list = askTemplateQuestionService.lambdaQuery().in(AskTemplateQuestion::getId, templateQuestionIdList).list();
|
|
|
|
templateQuestionMap = list.stream().collect(Collectors.toMap(AskTemplateQuestion::getId, Function.identity()));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
// 这里开始遍历
|
|
|
|
for (AskDiseaseQuestionAnswer askDiseaseQuestionAnswer : diseaseQuestionAnswerList) {
|
|
|
|
// 如果走模板的问题
|
|
|
|
if (StrUtil.isNotBlank(askDiseaseQuestionAnswer.getTemplateQuestionId()) && templateQuestionMap.containsKey(askDiseaseQuestionAnswer.getTemplateQuestionId())) {
|
|
|
|
AskTemplateQuestion askTemplateQuestion = templateQuestionMap.get(askDiseaseQuestionAnswer.getTemplateQuestionId());
|
|
|
|
// 开始生成
|
|
|
|
NluYmlTemplate.Nlu nlu = new NluYmlTemplate.Nlu();
|
|
|
|
// 拼接格式:code_answerId(防止重复)
|
|
|
|
String intentCode = askTemplateQuestion.getCode() + "_" + askDiseaseQuestionAnswer.getId();
|
|
|
|
nlu.setIntent(intentCode);
|
|
|
|
nlu.setExamples(askTemplateQuestion.getQuestion());
|
|
|
|
nluList.add(nlu);
|
|
|
|
intentCodeAndIdMap.put(intentCode, new QuestionAnswerDTO(askTemplateQuestion.getQuestion(), CollUtil.newArrayList("disease_" + askDiseaseQuestionAnswer.getId()), askTemplateQuestion.getDescription()));
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// 这里处理呼出的问题(code和问题不能为空)
|
|
|
|
List<ConfigPhysicalTool> physicalToolList = configPhysicalToolService.lambdaQuery()
|
|
|
|
.isNotNull(ConfigPhysicalTool::getCode)
|
|
|
|
.isNotNull(ConfigPhysicalTool::getCallOutQuestion).list();
|
|
|
|
|
|
|
|
for (ConfigPhysicalTool tool : physicalToolList) {
|
|
|
|
// 把呼出的问题全部加进去
|
|
|
|
NluYmlTemplate.Nlu nlu = new NluYmlTemplate.Nlu();
|
|
|
|
String toolIntent = "tool_" + tool.getCode();
|
|
|
|
nlu.setIntent(toolIntent);
|
|
|
|
nlu.setExamples(tool.getCallOutQuestion());
|
|
|
|
nluList.add(nlu);
|
|
|
|
// answer格式为:---tool---工具ID
|
|
|
|
intentCodeAndIdMap.put(toolIntent,
|
|
|
|
new QuestionAnswerDTO(tool.getCallOutQuestion(),
|
|
|
|
CollUtil.newArrayList("tool_" + tool.getId()), "tool-" + tool.getToolName()));
|
|
|
|
}
|
|
|
|
|
|
|
|
// 生成呼出的辅助检查
|
|
|
|
List<ConfigAncillaryItem> ancillaryItemList = configAncillaryItemService.lambdaQuery()
|
|
|
|
.isNotNull(ConfigAncillaryItem::getCode)
|
|
|
|
.isNotNull(ConfigAncillaryItem::getCallOutQuestion).list();
|
|
|
|
|
|
|
|
for (ConfigAncillaryItem ancillary : ancillaryItemList) {
|
|
|
|
// 把辅助问诊的问题全部加进去
|
|
|
|
NluYmlTemplate.Nlu nlu = new NluYmlTemplate.Nlu();
|
|
|
|
String itemIntent = "ancillary_" + ancillary.getCode();
|
|
|
|
nlu.setIntent(itemIntent);
|
|
|
|
nlu.setExamples(ancillary.getCallOutQuestion());
|
|
|
|
nluList.add(nlu);
|
|
|
|
// answer格式为:---ancillary---工具ID
|
|
|
|
intentCodeAndIdMap.put(itemIntent,
|
|
|
|
new QuestionAnswerDTO(ancillary.getCallOutQuestion(),
|
|
|
|
CollUtil.newArrayList("ancillary_" + ancillary.getId()), "呼出-ancillary-" + ancillary.getItemName()));
|
|
|
|
}
|
|
|
|
NluYmlTemplate nluYmlTemplate = new NluYmlTemplate();
|
|
|
|
nluYmlTemplate.setNlu(nluList);
|
|
|
|
|
|
|
|
// 生成后生成yml文件
|
|
|
|
createYmlFile(NluYmlTemplate.class, "nlu.ftl", nluYmlTemplate, "nlu.yml", ymalFileMap);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
public void generateDomain(Map<String, QuestionAnswerDTO> questionCodeAndIdMap,
|
|
|
|
List<RuleYmlTemplate.Rule> ruleList, Map<String, File> ymalFileMap) {
|
|
|
|
LinkedHashMap<String, List<String>> responses = new LinkedHashMap<>();
|
|
|
|
for (Map.Entry<String, QuestionAnswerDTO> entry : questionCodeAndIdMap.entrySet()) {
|
|
|
|
String intentCode = entry.getKey();
|
|
|
|
QuestionAnswerDTO value = entry.getValue();
|
|
|
|
String utter = "utter_" + intentCode;
|
|
|
|
responses.put(utter, CollUtil.newArrayList(value.getAnswerList()));
|
|
|
|
ruleList.add(new RuleYmlTemplate.Rule(value.getDesc(), intentCode, utter));
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
DomainYmlTemplate domainYmlTemplate = new DomainYmlTemplate();
|
|
|
|
// 意图
|
|
|
|
List<String> intentList = new ArrayList<>(questionCodeAndIdMap.keySet());
|
|
|
|
domainYmlTemplate.setIntents(intentList);
|
|
|
|
// 回复
|
|
|
|
domainYmlTemplate.setResponses(responses);
|
|
|
|
// action
|
|
|
|
List<String> actionList = new ArrayList<>(responses.keySet());
|
|
|
|
domainYmlTemplate.setActions(actionList);
|
|
|
|
// 生成yml文件
|
|
|
|
createYmlFile(DomainYmlTemplate.class, "domain.ftl", domainYmlTemplate, "domain.yml", ymalFileMap);
|
|
|
|
}
|
|
|
|
|
|
|
|
/**
|
|
|
|
* 生成rule
|
|
|
|
*/
|
|
|
|
public void generateRule(List<RuleYmlTemplate.Rule> ruleList, Map<String, File> ymalFileMap) {
|
|
|
|
RuleYmlTemplate ruleYmlTemplate = new RuleYmlTemplate();
|
|
|
|
ruleYmlTemplate.setRules(ruleList);
|
|
|
|
// 生成yml文件
|
|
|
|
createYmlFile(RuleYmlTemplate.class, "rules.ftl", ruleYmlTemplate, "rules.yml", ymalFileMap);
|
|
|
|
}
|
|
|
|
|
|
|
|
private void createYmlFile(Class<?> clazz, String ftlName, Object data, String ymlName, Map<String, File> ymalFileMap) {
|
|
|
|
try {
|
|
|
|
// 这个版本和maven依赖的版本一致
|
|
|
|
Configuration configuration = new Configuration(Configuration.VERSION_2_3_31);
|
|
|
|
configuration.setClassForTemplateLoading(clazz, "/templates"); // 模板文件的所在目录
|
|
|
|
// 获取模板
|
|
|
|
Template template = configuration.getTemplate(ftlName);
|
|
|
|
File tempFile = FileUtil.createTempFile();
|
|
|
|
// 创建输出文件
|
|
|
|
try (PrintWriter out = new PrintWriter(tempFile);) {
|
|
|
|
// 填充并生成输出
|
|
|
|
template.process(data, out);
|
|
|
|
} catch (Exception e) {
|
|
|
|
log.error("文件生成失败");
|
|
|
|
}
|
|
|
|
ymalFileMap.put(ymlName, tempFile);
|
|
|
|
} catch (Exception e) {
|
|
|
|
log.error("导出模板失败", e);
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|