提交生成rasa的yml文件的代码

dev_v1.0.1
liu 2 years ago
parent d56351d606
commit f491ec674f

@ -26,6 +26,8 @@ public class ConfigAncillaryItem implements Serializable {
@TableId @TableId
private String id; private String id;
private String code;
/** /**
* *
*/ */

@ -6,6 +6,7 @@
<resultMap id="BaseResultMap" type="com.supervision.model.ConfigAncillaryItem"> <resultMap id="BaseResultMap" type="com.supervision.model.ConfigAncillaryItem">
<id property="id" column="id" jdbcType="VARCHAR"/> <id property="id" column="id" jdbcType="VARCHAR"/>
<id property="code" column="code" jdbcType="VARCHAR"/>
<result property="itemClass" column="item_class" jdbcType="VARCHAR"/> <result property="itemClass" column="item_class" jdbcType="VARCHAR"/>
<result property="itemName" column="item_name" jdbcType="VARCHAR"/> <result property="itemName" column="item_name" jdbcType="VARCHAR"/>
<result property="info" column="info" jdbcType="VARCHAR"/> <result property="info" column="info" jdbcType="VARCHAR"/>

@ -13,14 +13,13 @@ import com.supervision.pojo.rasa.train.intent.RuleYmlTemplate;
import com.supervision.service.*; import com.supervision.service.*;
import freemarker.template.Configuration; import freemarker.template.Configuration;
import freemarker.template.Template; import freemarker.template.Template;
import freemarker.template.TemplateException;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
import java.io.File; import java.io.File;
import java.io.IOException;
import java.io.PrintWriter; import java.io.PrintWriter;
import java.util.*; import java.util.*;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@ -46,23 +45,26 @@ public class RasaServiceImpl implements RasaService {
private final ConfigPhysicalToolService configPhysicalToolService; private final ConfigPhysicalToolService configPhysicalToolService;
private final ConfigAncillaryItemService configAncillaryItemService;
@Override @Override
public void generateRasaYml(String diseaseId) { public void generateRasaYml(String diseaseId) {
Map<String, File> ymalFileMap = new HashMap<>(); Map<String, File> ymalFileMap = new HashMap<>();
// 默认问答MAP
Map<String, Intent> defaultIntentCodeAndIdMap = new HashMap<>(); Map<String, Intent> defaultIntentCodeAndIdMap = new HashMap<>();
// 疾病对应的问答MAP
Map<String, Intent> questionCodeAndIdMap = new HashMap<>(); Map<String, Intent> questionCodeAndIdMap = new HashMap<>();
// 问诊工具MAP
Map<String, ConfigPhysicalTool> toolCodeIdMap = new HashMap<>(); Map<String, ConfigPhysicalTool> toolCodeIdMap = new HashMap<>();
// 辅助检查MAP
Map<String, ConfigAncillaryItem> ancillaryCodeIdMap = new HashMap<>();
List<RuleYmlTemplate.Rule> ruleList = new ArrayList<>(); List<RuleYmlTemplate.Rule> ruleList = new ArrayList<>();
// 开始生成各种yaml文件 // 开始生成各种yaml文件
try { generateNlu(diseaseId, defaultIntentCodeAndIdMap, questionCodeAndIdMap, toolCodeIdMap, ancillaryCodeIdMap, ymalFileMap);
generateNlu(diseaseId, defaultIntentCodeAndIdMap, questionCodeAndIdMap, toolCodeIdMap, ymalFileMap); generateDomain(defaultIntentCodeAndIdMap, questionCodeAndIdMap, toolCodeIdMap, ancillaryCodeIdMap, ruleList, ymalFileMap);
} catch (Exception e) {
e.printStackTrace();
}
generateDomain(defaultIntentCodeAndIdMap, questionCodeAndIdMap, toolCodeIdMap, ruleList, ymalFileMap);
generateRule(ruleList, ymalFileMap); generateRule(ruleList, ymalFileMap);
// 生成压缩文件 // 生成压缩文件
ByteArrayOutputStream bos = new ByteArrayOutputStream(); ByteArrayOutputStream bos = new ByteArrayOutputStream();
@ -79,11 +81,15 @@ public class RasaServiceImpl implements RasaService {
} }
// TODO 这是压缩文件的字节流,这里需要把自己流调用Python程序 // TODO 这是压缩文件的字节流,这里需要把自己流调用Python程序
byte[] byteArray = bos.toByteArray(); byte[] byteArray = bos.toByteArray();
File file = new File("rasa.zip");
IoUtil.copy(new ByteArrayInputStream(byteArray),FileUtil.getOutputStream(file));
} }
private void generateNlu(String diseaseId, Map<String, Intent> defaultIntentCodeAndIdMap, private void generateNlu(String diseaseId, Map<String, Intent> defaultIntentCodeAndIdMap,
Map<String, Intent> intentCodeAndIdMap, Map<String, Intent> intentCodeAndIdMap,
Map<String, ConfigPhysicalTool> toolCodeIdMap, Map<String, File> ymalFileMap) { Map<String, ConfigPhysicalTool> toolCodeIdMap,
Map<String, ConfigAncillaryItem> itemCodeIdMap,
Map<String, File> ymalFileMap) {
// 首先生成根据意图查找到nlu文件 // 首先生成根据意图查找到nlu文件
List<NluYmlTemplate.Nlu> nluList = new ArrayList<>(); List<NluYmlTemplate.Nlu> nluList = new ArrayList<>();
// 默认意图 // 默认意图
@ -146,6 +152,22 @@ public class RasaServiceImpl implements RasaService {
// 生成tool的map,key是code,value是工具对应的ID // 生成tool的map,key是code,value是工具对应的ID
toolCodeIdMap.put(toolIntent, tool); toolCodeIdMap.put(toolIntent, tool);
} }
// 生成呼出的辅助检查
List<ConfigAncillaryItem> ancillaryItemList = configAncillaryItemService.lambdaQuery()
.isNotNull(ConfigAncillaryItem::getCode)
.isNotNull(ConfigAncillaryItem::getCallOutQuestion).list();
for (ConfigAncillaryItem ancillary : ancillaryItemList) {
// 把呼出的问题全部加进去
NluYmlTemplate.Nlu nlu = new NluYmlTemplate.Nlu();
String itemIntent = "ancillary_" + ancillary.getCode();
nlu.setIntent(itemIntent);
nlu.setExamples(ancillary.getCallOutQuestion());
nluList.add(nlu);
// 生成tool的map,key是item,value是工具对应的ID
itemCodeIdMap.put(itemIntent, ancillary);
}
// 生成后生成yml对象 // 生成后生成yml对象
// createYmlFile(nluYml, "nlu.yml"); // createYmlFile(nluYml, "nlu.yml");
@ -158,7 +180,9 @@ public class RasaServiceImpl implements RasaService {
public void generateDomain(Map<String, Intent> defaultQuestionCodeAndIdMap, public void generateDomain(Map<String, Intent> defaultQuestionCodeAndIdMap,
Map<String, Intent> questionCodeAndIdMap, Map<String, ConfigPhysicalTool> toolCodeIdMap, Map<String, Intent> questionCodeAndIdMap,
Map<String, ConfigPhysicalTool> toolCodeIdMap,
Map<String, ConfigAncillaryItem> ancillaryCodeIdMap,
List<RuleYmlTemplate.Rule> ruleList, Map<String, File> ymalFileMap) { List<RuleYmlTemplate.Rule> ruleList, Map<String, File> ymalFileMap) {
LinkedHashMap<String, List<String>> responses = new LinkedHashMap<>(); LinkedHashMap<String, List<String>> responses = new LinkedHashMap<>();
// 首先根据默认意图找到所有的意图ID // 首先根据默认意图找到所有的意图ID
@ -201,6 +225,15 @@ public class RasaServiceImpl implements RasaService {
responses.put(utter, CollUtil.newArrayList(JSONUtil.toJsonStr(answer))); responses.put(utter, CollUtil.newArrayList(JSONUtil.toJsonStr(answer)));
ruleList.add(new RuleYmlTemplate.Rule(tool.getToolName(), intentCode, utter)); ruleList.add(new RuleYmlTemplate.Rule(tool.getToolName(), intentCode, utter));
} }
// 然后呼出辅助检查对应的回复
for (Map.Entry<String, ConfigAncillaryItem> entry : ancillaryCodeIdMap.entrySet()) {
String intentCode = entry.getKey();
ConfigAncillaryItem ancillary = entry.getValue();
String utter = "utter_" + intentCode;
String answer = "---ancillary---" + ancillary.getId();
responses.put(utter, CollUtil.newArrayList(JSONUtil.toJsonStr(answer)));
ruleList.add(new RuleYmlTemplate.Rule(ancillary.getItemName(), intentCode, utter));
}
DomainYmlTemplate domainYmlTemplate = new DomainYmlTemplate(); DomainYmlTemplate domainYmlTemplate = new DomainYmlTemplate();

Loading…
Cancel
Save