diff --git a/virtual-patient-model/src/main/java/com/supervision/model/ConfigAncillaryItem.java b/virtual-patient-model/src/main/java/com/supervision/model/ConfigAncillaryItem.java index bbb36886..9e72f9f7 100644 --- a/virtual-patient-model/src/main/java/com/supervision/model/ConfigAncillaryItem.java +++ b/virtual-patient-model/src/main/java/com/supervision/model/ConfigAncillaryItem.java @@ -26,6 +26,8 @@ public class ConfigAncillaryItem implements Serializable { @TableId private String id; + private String code; + /** * 类别 */ diff --git a/virtual-patient-model/src/main/resources/mapper/ConfigAncillaryItemMapper.xml b/virtual-patient-model/src/main/resources/mapper/ConfigAncillaryItemMapper.xml index d7b56462..ff83bd49 100644 --- a/virtual-patient-model/src/main/resources/mapper/ConfigAncillaryItemMapper.xml +++ b/virtual-patient-model/src/main/resources/mapper/ConfigAncillaryItemMapper.xml @@ -6,6 +6,7 @@ + diff --git a/virtual-patient-web/src/main/java/com/supervision/service/impl/RasaServiceImpl.java b/virtual-patient-web/src/main/java/com/supervision/service/impl/RasaServiceImpl.java index 9305a017..3d9cda84 100644 --- a/virtual-patient-web/src/main/java/com/supervision/service/impl/RasaServiceImpl.java +++ b/virtual-patient-web/src/main/java/com/supervision/service/impl/RasaServiceImpl.java @@ -13,14 +13,13 @@ import com.supervision.pojo.rasa.train.intent.RuleYmlTemplate; import com.supervision.service.*; import freemarker.template.Configuration; import freemarker.template.Template; -import freemarker.template.TemplateException; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Service; +import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.File; -import java.io.IOException; import java.io.PrintWriter; import java.util.*; import java.util.stream.Collectors; @@ -46,23 +45,26 @@ public class RasaServiceImpl implements RasaService { private final ConfigPhysicalToolService configPhysicalToolService; + private final ConfigAncillaryItemService configAncillaryItemService; + @Override public void generateRasaYml(String diseaseId) { Map ymalFileMap = new HashMap<>(); - + // 默认问答MAP Map defaultIntentCodeAndIdMap = new HashMap<>(); + // 疾病对应的问答MAP Map questionCodeAndIdMap = new HashMap<>(); + // 问诊工具MAP Map toolCodeIdMap = new HashMap<>(); + // 辅助检查MAP + Map ancillaryCodeIdMap = new HashMap<>(); List ruleList = new ArrayList<>(); + // 开始生成各种yaml文件 - try { - generateNlu(diseaseId, defaultIntentCodeAndIdMap, questionCodeAndIdMap, toolCodeIdMap, ymalFileMap); - } catch (Exception e) { - e.printStackTrace(); - } - generateDomain(defaultIntentCodeAndIdMap, questionCodeAndIdMap, toolCodeIdMap, ruleList, ymalFileMap); + generateNlu(diseaseId, defaultIntentCodeAndIdMap, questionCodeAndIdMap, toolCodeIdMap, ancillaryCodeIdMap, ymalFileMap); + generateDomain(defaultIntentCodeAndIdMap, questionCodeAndIdMap, toolCodeIdMap, ancillaryCodeIdMap, ruleList, ymalFileMap); generateRule(ruleList, ymalFileMap); // 生成压缩文件 ByteArrayOutputStream bos = new ByteArrayOutputStream(); @@ -79,11 +81,15 @@ public class RasaServiceImpl implements RasaService { } // TODO 这是压缩文件的字节流,这里需要把自己流调用Python程序 byte[] byteArray = bos.toByteArray(); + File file = new File("rasa.zip"); + IoUtil.copy(new ByteArrayInputStream(byteArray),FileUtil.getOutputStream(file)); } private void generateNlu(String diseaseId, Map defaultIntentCodeAndIdMap, Map intentCodeAndIdMap, - Map toolCodeIdMap, Map ymalFileMap) { + Map toolCodeIdMap, + Map itemCodeIdMap, + Map ymalFileMap) { // 首先生成根据意图查找到nlu文件 List nluList = new ArrayList<>(); // 默认意图 @@ -146,6 +152,22 @@ public class RasaServiceImpl implements RasaService { // 生成tool的map,key是code,value是工具对应的ID toolCodeIdMap.put(toolIntent, tool); } + // 生成呼出的辅助检查 + List ancillaryItemList = configAncillaryItemService.lambdaQuery() + .isNotNull(ConfigAncillaryItem::getCode) + .isNotNull(ConfigAncillaryItem::getCallOutQuestion).list(); + + for (ConfigAncillaryItem ancillary : ancillaryItemList) { + // 把呼出的问题全部加进去 + NluYmlTemplate.Nlu nlu = new NluYmlTemplate.Nlu(); + String itemIntent = "ancillary_" + ancillary.getCode(); + nlu.setIntent(itemIntent); + nlu.setExamples(ancillary.getCallOutQuestion()); + nluList.add(nlu); + // 生成tool的map,key是item,value是工具对应的ID + itemCodeIdMap.put(itemIntent, ancillary); + } + // 生成后生成yml对象 // createYmlFile(nluYml, "nlu.yml"); @@ -158,7 +180,9 @@ public class RasaServiceImpl implements RasaService { public void generateDomain(Map defaultQuestionCodeAndIdMap, - Map questionCodeAndIdMap, Map toolCodeIdMap, + Map questionCodeAndIdMap, + Map toolCodeIdMap, + Map ancillaryCodeIdMap, List ruleList, Map ymalFileMap) { LinkedHashMap> responses = new LinkedHashMap<>(); // 首先根据默认意图找到所有的意图ID @@ -201,6 +225,15 @@ public class RasaServiceImpl implements RasaService { responses.put(utter, CollUtil.newArrayList(JSONUtil.toJsonStr(answer))); ruleList.add(new RuleYmlTemplate.Rule(tool.getToolName(), intentCode, utter)); } + // 然后呼出辅助检查对应的回复 + for (Map.Entry entry : ancillaryCodeIdMap.entrySet()) { + String intentCode = entry.getKey(); + ConfigAncillaryItem ancillary = entry.getValue(); + String utter = "utter_" + intentCode; + String answer = "---ancillary---" + ancillary.getId(); + responses.put(utter, CollUtil.newArrayList(JSONUtil.toJsonStr(answer))); + ruleList.add(new RuleYmlTemplate.Rule(ancillary.getItemName(), intentCode, utter)); + } DomainYmlTemplate domainYmlTemplate = new DomainYmlTemplate();