|
|
|
@ -4,9 +4,18 @@ import cn.hutool.core.collection.CollUtil;
|
|
|
|
|
import cn.hutool.core.io.FileUtil;
|
|
|
|
|
import cn.hutool.core.io.IoUtil;
|
|
|
|
|
import cn.hutool.core.util.StrUtil;
|
|
|
|
|
import cn.hutool.http.HttpRequest;
|
|
|
|
|
import cn.hutool.http.HttpResponse;
|
|
|
|
|
import cn.hutool.http.HttpUtil;
|
|
|
|
|
import cn.hutool.json.JSONUtil;
|
|
|
|
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
|
|
|
|
import com.fasterxml.jackson.core.type.TypeReference;
|
|
|
|
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
|
|
|
|
import com.supervision.domain.GlobalResult;
|
|
|
|
|
import com.supervision.exception.BusinessException;
|
|
|
|
|
import com.supervision.model.*;
|
|
|
|
|
import com.supervision.pojo.paddlespeech.res.PaddleSpeechResDTO;
|
|
|
|
|
import com.supervision.pojo.paddlespeech.res.TtsResultDTO;
|
|
|
|
|
import com.supervision.pojo.rasa.train.DomainYmlTemplate;
|
|
|
|
|
import com.supervision.pojo.rasa.train.NluYmlTemplate;
|
|
|
|
|
import com.supervision.pojo.rasa.train.QuestionAnswerDTO;
|
|
|
|
@ -16,6 +25,7 @@ import freemarker.template.Configuration;
|
|
|
|
|
import freemarker.template.Template;
|
|
|
|
|
import lombok.RequiredArgsConstructor;
|
|
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
|
|
|
import org.springframework.beans.factory.annotation.Value;
|
|
|
|
|
import org.springframework.stereotype.Service;
|
|
|
|
|
|
|
|
|
|
import java.io.ByteArrayInputStream;
|
|
|
|
@ -43,8 +53,19 @@ public class RasaServiceImpl implements RasaService {
|
|
|
|
|
|
|
|
|
|
private final ConfigAncillaryItemService configAncillaryItemService;
|
|
|
|
|
|
|
|
|
|
private static final ObjectMapper objectMapper = new ObjectMapper();
|
|
|
|
|
|
|
|
|
|
@Value("${rasa.base-url}${rasa.saveRasaFile}")
|
|
|
|
|
private String saveRasaFileUrl;
|
|
|
|
|
|
|
|
|
|
@Value("${rasa.base-url}${rasa.train}")
|
|
|
|
|
private String trainRasaUrl;
|
|
|
|
|
|
|
|
|
|
@Value("${rasa.base-url}${rasa.run}")
|
|
|
|
|
private String runRasaUrl;
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public void generateRasaYml(String diseaseId) {
|
|
|
|
|
public GlobalResult<String> generateRasaYml(String patientId) {
|
|
|
|
|
|
|
|
|
|
Map<String, File> ymalFileMap = new HashMap<>();
|
|
|
|
|
// 默认问答MAP
|
|
|
|
@ -53,29 +74,42 @@ public class RasaServiceImpl implements RasaService {
|
|
|
|
|
List<RuleYmlTemplate.Rule> ruleList = new ArrayList<>();
|
|
|
|
|
|
|
|
|
|
// 开始生成各种yaml文件
|
|
|
|
|
generateNlu(diseaseId, questionCodeAndIdMap, ymalFileMap);
|
|
|
|
|
generateNlu(patientId, questionCodeAndIdMap, ymalFileMap);
|
|
|
|
|
generateDomain(questionCodeAndIdMap, ruleList, ymalFileMap);
|
|
|
|
|
generateRule(ruleList, ymalFileMap);
|
|
|
|
|
// 生成压缩文件
|
|
|
|
|
List<File> tempFile = new ArrayList<>();
|
|
|
|
|
File tempZipFile = FileUtil.createTempFile(".zip", true);
|
|
|
|
|
ByteArrayOutputStream bos = new ByteArrayOutputStream();
|
|
|
|
|
try (ZipOutputStream zipOutputStream = new ZipOutputStream(bos)) {
|
|
|
|
|
for (Map.Entry<String, File> fileEntry : ymalFileMap.entrySet()) {
|
|
|
|
|
zipOutputStream.putNextEntry(new ZipEntry(fileEntry.getKey()));
|
|
|
|
|
IoUtil.copy(FileUtil.getInputStream(fileEntry.getValue()), zipOutputStream);
|
|
|
|
|
zipOutputStream.closeEntry();
|
|
|
|
|
tempFile.add(fileEntry.getValue());
|
|
|
|
|
}
|
|
|
|
|
zipOutputStream.finish();
|
|
|
|
|
// 调用接口传文件
|
|
|
|
|
HttpRequest request = HttpRequest.post(saveRasaFileUrl);
|
|
|
|
|
IoUtil.copy(new ByteArrayInputStream(bos.toByteArray()), FileUtil.getOutputStream(tempZipFile));
|
|
|
|
|
request.form("file", tempZipFile);
|
|
|
|
|
request.form("modelId", patientId);
|
|
|
|
|
HttpResponse response = request.execute();
|
|
|
|
|
String responseBody = response.body();
|
|
|
|
|
log.info(responseBody);
|
|
|
|
|
return objectMapper.readValue(responseBody, new TypeReference<GlobalResult<String>>() {
|
|
|
|
|
});
|
|
|
|
|
} catch (Exception e) {
|
|
|
|
|
log.error("生成ZIP文件失败", e);
|
|
|
|
|
throw new BusinessException("生成ZIP文件失败");
|
|
|
|
|
} finally {
|
|
|
|
|
// 最后把临时文件删除
|
|
|
|
|
tempFile.forEach(FileUtil::del);
|
|
|
|
|
FileUtil.del(tempZipFile);
|
|
|
|
|
}
|
|
|
|
|
// TODO 这是压缩文件的字节流,这里需要把自己流调用Python程序
|
|
|
|
|
byte[] byteArray = bos.toByteArray();
|
|
|
|
|
File file = new File("rasa.zip");
|
|
|
|
|
IoUtil.copy(new ByteArrayInputStream(byteArray), FileUtil.getOutputStream(file));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private void generateNlu(String diseaseId,
|
|
|
|
|
private void generateNlu(String patientId,
|
|
|
|
|
Map<String, QuestionAnswerDTO> intentCodeAndIdMap,
|
|
|
|
|
Map<String, File> ymalFileMap) {
|
|
|
|
|
// 首先生成根据意图查找到nlu文件
|
|
|
|
@ -98,7 +132,7 @@ public class RasaServiceImpl implements RasaService {
|
|
|
|
|
}
|
|
|
|
|
// 然后处理该疾病对应的意图
|
|
|
|
|
List<AskDiseaseQuestionAnswer> diseaseQuestionAnswerList = askDiseaseQuestionAnswerService.lambdaQuery()
|
|
|
|
|
.eq(AskDiseaseQuestionAnswer::getDiseaseId, diseaseId).list();
|
|
|
|
|
.eq(AskDiseaseQuestionAnswer::getPatientId, patientId).list();
|
|
|
|
|
// 使用通用模板的
|
|
|
|
|
Map<String, AskTemplateQuestion> templateQuestionMap = new HashMap<>();
|
|
|
|
|
// 根据默认意图找到所有的问题
|
|
|
|
@ -215,7 +249,7 @@ public class RasaServiceImpl implements RasaService {
|
|
|
|
|
configuration.setClassForTemplateLoading(clazz, "/templates"); // 模板文件的所在目录
|
|
|
|
|
// 获取模板
|
|
|
|
|
Template template = configuration.getTemplate(ftlName);
|
|
|
|
|
File tempFile = FileUtil.createTempFile();
|
|
|
|
|
File tempFile = FileUtil.createTempFile(".yml", true);
|
|
|
|
|
// 创建输出文件
|
|
|
|
|
try (PrintWriter out = new PrintWriter(tempFile);) {
|
|
|
|
|
// 填充并生成输出
|
|
|
|
@ -230,4 +264,22 @@ public class RasaServiceImpl implements RasaService {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public GlobalResult<String> trainRasa(String patientId) throws JsonProcessingException {
|
|
|
|
|
Map<String, Object> param = new HashMap<>();
|
|
|
|
|
param.put("modelId", patientId);
|
|
|
|
|
String responseBody = HttpUtil.post(trainRasaUrl, param);
|
|
|
|
|
return objectMapper.readValue(responseBody, new TypeReference<GlobalResult<String>>() {
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public GlobalResult<String> runRasa(String patientId) throws JsonProcessingException {
|
|
|
|
|
Map<String, Object> param = new HashMap<>();
|
|
|
|
|
param.put("modelId", patientId);
|
|
|
|
|
String responseBody = HttpUtil.post(runRasaUrl, param);
|
|
|
|
|
return objectMapper.readValue(responseBody, new TypeReference<GlobalResult<String>>() {
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|