manage: 添加rasa部署接口(集成生成yml文件、训练模型、运行模型)

dev_2.1.0
xueqingkun 1 year ago
parent 429087608c
commit 22b1b49dbd

@ -11,7 +11,7 @@ WORKDIR /data/vp
COPY target/virtual-patient-rasa-1.0-SNAPSHOT.jar /data/vp/virtual-patient-rasa-1.0-SNAPSHOT.jar COPY target/virtual-patient-rasa-1.0-SNAPSHOT.jar /data/vp/virtual-patient-rasa-1.0-SNAPSHOT.jar
# 复制rasa配置文件到 rasa目录下 # 复制rasa配置文件到 rasa目录下
COPY docs/rasa /rasa COPY docs/rasa /rasa
COPY docs/1 /data/vp/rasa/models/1 COPY docs/1 /data/vp/rasa/models/
# 暴漏服务端口 # 暴漏服务端口
EXPOSE 8890 EXPOSE 8890

@ -3,7 +3,7 @@ FROM rasa_dev:1.0.0
COPY ./bert_chinese /usr/local/text2vec/bert_chinese COPY ./bert_chinese /usr/local/text2vec/bert_chinese
COPY ./app.py /usr/local/text2vec/ COPY ./app.py /usr/local/text2vec/
COPY ./question.json /usr/local/text2vec/ #COPY ./question.json /usr/local/text2vec/
RUN source /root/anaconda3/etc/profile.d/conda.sh && \ RUN source /root/anaconda3/etc/profile.d/conda.sh && \
conda create --name text2vec_env python=3.9 -y && \ conda create --name text2vec_env python=3.9 -y && \

@ -58,6 +58,11 @@
<groupId>cn.hutool</groupId> <groupId>cn.hutool</groupId>
<artifactId>hutool-all</artifactId> <artifactId>hutool-all</artifactId>
</dependency> </dependency>
<!--用来生成yml文件,jakson的模板库满足不了多行文本|管道符的需求-->
<dependency>
<groupId>org.freemarker</groupId>
<artifactId>freemarker</artifactId>
</dependency>
</dependencies> </dependencies>

@ -2,6 +2,7 @@ package com.supervision.rasa;
import com.supervision.config.WebConfig; import com.supervision.config.WebConfig;
import com.supervision.rasa.service.RasaModelManager; import com.supervision.rasa.service.RasaModelManager;
import com.supervision.rasa.service.Text2vecService;
import org.mybatis.spring.annotation.MapperScan; import org.mybatis.spring.annotation.MapperScan;
import org.springframework.boot.SpringApplication; import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.boot.autoconfigure.SpringBootApplication;
@ -20,8 +21,17 @@ public class VirtualPatientRasaApplication {
public static void main(String[] args) { public static void main(String[] args) {
ConfigurableApplicationContext context = SpringApplication.run(VirtualPatientRasaApplication.class, args); ConfigurableApplicationContext context = SpringApplication.run(VirtualPatientRasaApplication.class, args);
// 启动rasa服务
RasaModelManager rasaModelManager = context.getBean(RasaModelManager.class); RasaModelManager rasaModelManager = context.getBean(RasaModelManager.class);
try {
rasaModelManager.wakeUpInterruptServerScheduled(); rasaModelManager.wakeUpInterruptServerScheduled();
} catch (Exception e) {
throw new RuntimeException(e);
}
// 初始化文本匹配数据
Text2vecService text2vecService = context.getBean(Text2vecService.class);
text2vecService.initText2vecDataset();
} }
} }

@ -13,7 +13,7 @@ import org.springframework.web.bind.annotation.*;
import java.io.*; import java.io.*;
import java.util.concurrent.*; import java.util.concurrent.*;
@Api(tags = "rasa文件保存") @Api(tags = "rasa管理")
@RestController @RestController
@RequestMapping("rasaCmd") @RequestMapping("rasaCmd")
@RequiredArgsConstructor @RequiredArgsConstructor
@ -25,7 +25,6 @@ public class RasaCmdController {
@PostMapping("/trainExec") @PostMapping("/trainExec")
public String trainExec(@RequestBody RasaCmdArgumentVo argument) throws IOException, ExecutionException, InterruptedException, TimeoutException { public String trainExec(@RequestBody RasaCmdArgumentVo argument) throws IOException, ExecutionException, InterruptedException, TimeoutException {
argument.setModelId("1");
return rasaCmdService.trainExec(argument); return rasaCmdService.trainExec(argument);
} }
@ -34,7 +33,6 @@ public class RasaCmdController {
@PostMapping("/runExec") @PostMapping("/runExec")
public String runExec(@RequestBody RasaCmdArgumentVo argument) throws ExecutionException, InterruptedException, TimeoutException { public String runExec(@RequestBody RasaCmdArgumentVo argument) throws ExecutionException, InterruptedException, TimeoutException {
argument.setModelId("1");
String outString = rasaCmdService.runExec(argument); String outString = rasaCmdService.runExec(argument);
if (StrUtil.isEmptyIfStr(outString) || !outString.contains(RasaConstant.RUN_SUCCESS_MESSAGE)){ if (StrUtil.isEmptyIfStr(outString) || !outString.contains(RasaConstant.RUN_SUCCESS_MESSAGE)){
throw new BusinessException("任务执行异常。详细日志:"+outString); throw new BusinessException("任务执行异常。详细日志:"+outString);
@ -43,5 +41,12 @@ public class RasaCmdController {
} }
@ApiOperation("部署rasa")
@PostMapping("/deploy")
public boolean deployRasa() throws Exception {
return rasaCmdService.deployRasa();
}
} }

@ -0,0 +1,28 @@
package com.supervision.rasa.pojo.dto;
import lombok.Data;
import java.util.LinkedHashMap;
import java.util.List;
@Data
public class DomainYmlTemplate {
private List<String> intents;
private LinkedHashMap<String,List<String>> responses;
private List<String> actions;
private SessionConfig session_config = new SessionConfig();
@Data
public static class SessionConfig{
private final int session_expiration_time = 60;
private final Boolean carry_over_slots_to_new_session = true;
}
}

@ -0,0 +1,28 @@
package com.supervision.rasa.pojo.dto;
import lombok.Data;
import java.util.List;
@Data
public class NluYmlTemplate {
private List<Nlu> nlu;
@Data
public static class Nlu{
private String intent;
private List<String> examples;
public Nlu(String intent, List<String> examples) {
this.intent = intent;
this.examples = examples;
}
public Nlu() {
}
}
}

@ -0,0 +1,20 @@
package com.supervision.rasa.pojo.dto;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.List;
@Data
@AllArgsConstructor
@NoArgsConstructor
public class QuestionAnswerDTO {
private List<String> questionList;
private List<String> answerList;
private String desc;
}

@ -0,0 +1,44 @@
package com.supervision.rasa.pojo.dto;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.ArrayList;
import java.util.List;
@Data
public class RuleYmlTemplate {
private List<Rule> rules;
@Data
public static class Rule {
private String rule;
private List<Step> steps;
public Rule() {
}
public Rule(String rule, String intent, String action) {
this.rule = rule;
steps = new ArrayList<>();
Step step = new Step(intent, action);
steps.add(step);
}
}
@Data
@AllArgsConstructor
@NoArgsConstructor
public static class Step {
private String intent;
private String action;
}
}

@ -11,4 +11,12 @@ public class Text2vecDataVo {
@ApiModelProperty("问题") @ApiModelProperty("问题")
private String question; private String question;
public Text2vecDataVo(String id, String question) {
this.id = id;
this.question = question;
}
public Text2vecDataVo() {
}
} }

@ -1,9 +1,11 @@
package com.supervision.rasa.service; package com.supervision.rasa.service;
import com.supervision.rasa.pojo.dto.QuestionAnswerDTO;
import com.supervision.rasa.pojo.vo.RasaCmdArgumentVo; import com.supervision.rasa.pojo.vo.RasaCmdArgumentVo;
import java.io.IOException; import java.io.IOException;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException; import java.util.concurrent.TimeoutException;
import java.util.function.Predicate; import java.util.function.Predicate;
@ -19,4 +21,10 @@ public interface RasaCmdService {
String getShellPath(String shell); String getShellPath(String shell);
boolean deployRasa() throws Exception;
Map<String, QuestionAnswerDTO> generateRasaYml(String path);
public Map<String, QuestionAnswerDTO> getIntentCodeAndIdMap();
} }

@ -51,38 +51,39 @@ public class RasaModelManager {
// 2. 重新启动中断的服务 // 2. 重新启动中断的服务
for (RasaModelInfo rasaModelInfo : activeRasaList) { for (RasaModelInfo rasaModelInfo : activeRasaList) {
if (!PortUtil.portIsActive(rasaModelInfo.getPort())){ if (PortUtil.portIsActive(rasaModelInfo.getPort())) {
log.info("wakeUpInterruptServer: port:{} is run..", rasaModelInfo.getPort());
continue;
}
try { try {
RasaRunParam rasaRunParam = RasaRunParam.build(rasaModelInfo.getRunCmd()); RasaRunParam rasaRunParam = RasaRunParam.build(rasaModelInfo.getRunCmd());
rasaRunParam.setPort(String.valueOf(rasaModelInfo.getPort())); rasaRunParam.setPort(String.valueOf(rasaModelInfo.getPort()));
rasaRunParam.setShellPath(rasaCmdService.getShellPath(RasaConstant.RUN_SHELL)); rasaRunParam.setShellPath(rasaCmdService.getShellPath(RasaConstant.RUN_SHELL));
String rasaModelPath = rasaRunParam.getRasaModelPath(); String rasaModelPath = rasaRunParam.getRasaModelPath();
if (StrUtil.isEmpty(rasaModelPath) || !FileUtil.exist(rasaModelPath)){ if (StrUtil.isEmpty(rasaModelPath) || !FileUtil.exist(rasaModelPath)) {
log.info("wakeUpInterruptServer: rasa model path {} not exist,attempt find last ...",rasaModelPath); log.info("wakeUpInterruptServer: rasa model path {} not exist,attempt find last ...", rasaModelPath);
String modeParentPath = replaceDuplicateSeparator(String.join(File.separator, modelsPath, rasaModelInfo.getModelId())); String modeParentPath = replaceDuplicateSeparator(String.join(File.separator, modelsPath));
String fixedModePath = listLastFilePath(modeParentPath, f -> f.getName().matches("-?\\d+(\\.\\d+)?.tar.gz")); String fixedModePath = listLastFilePath(modeParentPath, f -> f.getName().matches("-?\\d+(\\.\\d+)?.tar.gz"));
Assert.notEmpty(fixedModePath,"wakeUpInterruptService: no rasa model in path {} ",modeParentPath); Assert.notEmpty(fixedModePath, "wakeUpInterruptService: no rasa model in path {} ", modeParentPath);
rasaRunParam.setRasaModelPath(fixedModePath); rasaRunParam.setRasaModelPath(fixedModePath);
} }
log.info("wakeUpInterruptServer : use fixedModePath :{}",rasaRunParam.getRasaModelPath()); log.info("wakeUpInterruptServer : use fixedModePath :{}", rasaRunParam.getRasaModelPath());
List<String> outMessageList = rasaCmdService.execCmd(rasaRunParam.toList(), List<String> outMessageList = rasaCmdService.execCmd(rasaRunParam.toList(),
s -> StrUtil.isNotBlank(s) && s.contains(RasaConstant.RUN_SUCCESS_MESSAGE), 300); s -> StrUtil.isNotBlank(s) && s.contains(RasaConstant.RUN_SUCCESS_MESSAGE), 300);
rasaModelInfo.setRunLog(String.join("\r\n",outMessageList)); rasaModelInfo.setRunLog(String.join("\r\n", outMessageList));
rasaModelInfo.setRunCmd(rasaRunParam.toList()); rasaModelInfo.setRunCmd(rasaRunParam.toList());
rasaModeService.updateById(rasaModelInfo); rasaModeService.updateById(rasaModelInfo);
if (!runIsSuccess(outMessageList)){ if (!runIsSuccess(outMessageList)) {
log.info("wakeUpInterruptServer: restart server port for {} failed,details info : {}",rasaModelInfo.getPort(),String.join("\r\n",outMessageList)); log.info("wakeUpInterruptServer: restart server port for {} failed,details info : {}", rasaModelInfo.getPort(), String.join("\r\n", outMessageList));
} }
} catch (InterruptedException | ExecutionException | TimeoutException e ) { } catch (InterruptedException | ExecutionException | TimeoutException e) {
log.info("wakeUpInterruptServer: restart server port for {} failed",rasaModelInfo.getPort()); log.info("wakeUpInterruptServer: restart server port for {} failed", rasaModelInfo.getPort());
throw new RuntimeException(e); throw new RuntimeException(e);
} }
log.info("wakeUpInterruptServer: restart server port for {} success ",rasaModelInfo.getPort()); log.info("wakeUpInterruptServer: restart server port for {} success ", rasaModelInfo.getPort());
}else {
log.info("wakeUpInterruptServer: port:{} is run..",rasaModelInfo.getPort());
}
} }
} }

@ -21,4 +21,10 @@ public interface Text2vecService {
* @return * @return
*/ */
List<Text2vecMatchesRes> matches(Text2vecMatchesReq text2vecMatchesReq); List<Text2vecMatchesRes> matches(Text2vecMatchesReq text2vecMatchesReq);
/**
*
*/
void initText2vecDataset();
} }

@ -1,5 +1,6 @@
package com.supervision.rasa.service; package com.supervision.rasa.service;
import cn.hutool.core.io.FileUtil;
import cn.hutool.core.lang.Assert; import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.StrUtil; import cn.hutool.core.util.StrUtil;
import cn.hutool.http.HttpUtil; import cn.hutool.http.HttpUtil;
@ -7,21 +8,25 @@ import cn.hutool.json.JSON;
import cn.hutool.json.JSONArray; import cn.hutool.json.JSONArray;
import cn.hutool.json.JSONObject; import cn.hutool.json.JSONObject;
import cn.hutool.json.JSONUtil; import cn.hutool.json.JSONUtil;
import com.supervision.rasa.pojo.dto.QuestionAnswerDTO;
import com.supervision.rasa.pojo.dto.Text2vecDataVo; import com.supervision.rasa.pojo.dto.Text2vecDataVo;
import com.supervision.rasa.pojo.dto.Text2vecMatchesReq; import com.supervision.rasa.pojo.dto.Text2vecMatchesReq;
import com.supervision.rasa.pojo.dto.Text2vecMatchesRes; import com.supervision.rasa.pojo.dto.Text2vecMatchesRes;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j; import lombok.extern.log4j.Log4j;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.stream.Collectors;
@Slf4j @Slf4j
@Service @Service
@RequiredArgsConstructor
public class Text2vecServiceImpl implements Text2vecService { public class Text2vecServiceImpl implements Text2vecService {
@Value("${text2vec.service.domain}") @Value("${text2vec.service.domain}")
@ -30,6 +35,13 @@ public class Text2vecServiceImpl implements Text2vecService {
private final String UPDATE_DATASET_PATH = "update_dataset"; private final String UPDATE_DATASET_PATH = "update_dataset";
private final String MATCHES_PATH = "matches"; private final String MATCHES_PATH = "matches";
private final String GET_ALL_SIMILARITIES_PATH = "get_all_similarities"; private final String GET_ALL_SIMILARITIES_PATH = "get_all_similarities";
private final RasaCmdService rasaCmdService;
public Text2vecServiceImpl(@Autowired @Lazy RasaCmdService rasaCmdService) {
this.rasaCmdService = rasaCmdService;
}
@Override @Override
public boolean updateDataset(List<Text2vecDataVo> text2vecDataVoList) { public boolean updateDataset(List<Text2vecDataVo> text2vecDataVoList) {
@ -65,4 +77,19 @@ public class Text2vecServiceImpl implements Text2vecService {
return JSONUtil.toList(JSONUtil.parseArray(jsonBody.get("results")), Text2vecMatchesRes.class); return JSONUtil.toList(JSONUtil.parseArray(jsonBody.get("results")), Text2vecMatchesRes.class);
} }
public void initText2vecDataset(){
log.info("initText2vecDataset ...");
if (FileUtil.exist("/usr/local/text2vec/question.json")){
log.info("question.json文件已经存在不进行text2vec数据初始化操作....");
return;
}
Map<String, QuestionAnswerDTO> intentCodeAndIdMap = rasaCmdService.getIntentCodeAndIdMap();
// 更新text2vec数据信息
List<Text2vecDataVo> text2vecDataVoList = intentCodeAndIdMap.entrySet().stream()
.flatMap(entry -> entry.getValue().getQuestionList().stream()
.map(question -> new Text2vecDataVo(entry.getKey(), question))).collect(Collectors.toList());
this.updateDataset(text2vecDataVoList);
}
} }

@ -1,17 +1,30 @@
package com.supervision.rasa.service.impl; package com.supervision.rasa.service.impl;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.collection.CollectionUtil; import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.collection.ListUtil; import cn.hutool.core.collection.ListUtil;
import cn.hutool.core.io.FileUtil; import cn.hutool.core.io.FileUtil;
import cn.hutool.core.lang.Pair;
import cn.hutool.core.util.StrUtil; import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil;
import com.supervision.exception.BusinessException; import com.supervision.exception.BusinessException;
import com.supervision.model.AskTemplateQuestionLibrary;
import com.supervision.model.ConfigAncillaryItem;
import com.supervision.model.ConfigPhysicalTool;
import com.supervision.model.RasaModelInfo; import com.supervision.model.RasaModelInfo;
import com.supervision.rasa.config.ThreadPoolExecutorConfig; import com.supervision.rasa.config.ThreadPoolExecutorConfig;
import com.supervision.rasa.constant.RasaConstant; import com.supervision.rasa.constant.RasaConstant;
import com.supervision.rasa.pojo.dto.*;
import com.supervision.rasa.pojo.vo.RasaCmdArgumentVo; import com.supervision.rasa.pojo.vo.RasaCmdArgumentVo;
import com.supervision.rasa.service.RasaCmdService; import com.supervision.rasa.service.RasaCmdService;
import com.supervision.rasa.service.Text2vecService;
import com.supervision.rasa.util.PortUtil; import com.supervision.rasa.util.PortUtil;
import com.supervision.service.AskTemplateQuestionLibraryService;
import com.supervision.service.ConfigAncillaryItemService;
import com.supervision.service.ConfigPhysicalToolService;
import com.supervision.service.RasaModeService; import com.supervision.service.RasaModeService;
import freemarker.template.Configuration;
import freemarker.template.Template;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
@ -23,6 +36,7 @@ import java.io.*;
import java.util.*; import java.util.*;
import java.util.concurrent.*; import java.util.concurrent.*;
import java.util.function.Predicate; import java.util.function.Predicate;
import java.util.stream.Collectors;
@Service @Service
@Slf4j @Slf4j
@ -46,26 +60,36 @@ public class RasaCmdServiceImpl implements RasaCmdService {
@Value("${rasa.shell-env:/bin/bash}") @Value("${rasa.shell-env:/bin/bash}")
private String shellEnv; private String shellEnv;
@Value("${rasa.data-path:/home/rasa/model_resource/}")
private String rasaFilePath;
private final RasaModeService rasaModeService; private final RasaModeService rasaModeService;
private final ConfigPhysicalToolService configPhysicalToolService;
private final ConfigAncillaryItemService configAncillaryItemService;
private final AskTemplateQuestionLibraryService askTemplateQuestionLibraryService;
private final Text2vecService text2vecService;
private final ConcurrentHashMap<String,String> shellPathCache = new ConcurrentHashMap<>(); private final ConcurrentHashMap<String,String> shellPathCache = new ConcurrentHashMap<>();
@Override @Override
@Transactional @Transactional
public String trainExec(RasaCmdArgumentVo argument) throws ExecutionException, InterruptedException, TimeoutException { public String trainExec(RasaCmdArgumentVo argument) throws ExecutionException, InterruptedException, TimeoutException {
log.info("trainExec:start train rasa model ....argument:{}", JSONUtil.toJsonStr(argument));
argument.setFixedModelNameIfAbsent(); argument.setFixedModelNameIfAbsent();
// /rasa/v3_jiazhuangxian/domain.yml domain的路径应该是从zip文件中加压出来的文件的路径后面拼上/domain.yml // /rasa/v3_jiazhuangxian/domain.yml domain的路径应该是从zip文件中加压出来的文件的路径后面拼上/domain.yml
String domain = replaceDuplicateSeparator(String.join(File.separator,dataPath,argument.getModelId(),"domain.yml")); String domain = replaceDuplicateSeparator(String.join(File.separator,dataPath,"domain.yml"));
// /rasa/v3_jiazhuangxian/ yml文件的路径应该是从zip文件中加压出来的文件的路径,在配置文件中配置 // /rasa/v3_jiazhuangxian/ yml文件的路径应该是从zip文件中加压出来的文件的路径,在配置文件中配置
String localDataPath = replaceDuplicateSeparator(String.join(File.separator,dataPath,argument.getModelId())); String localDataPath = replaceDuplicateSeparator(String.join(File.separator,dataPath));
// /rasa/models 生成出来的模型的存放路径,也写在配置文件里面 // /rasa/models 生成出来的模型的存放路径,也写在配置文件里面
String localModelsPath = replaceDuplicateSeparator(String.join(File.separator,modelsPath,argument.getModelId())); String localModelsPath = replaceDuplicateSeparator(String.join(File.separator,modelsPath));
List<String> cmds = ListUtil.toList(shellEnv, getShellPath(RasaConstant.TRAIN_SHELL),config,localDataPath,domain,localModelsPath); List<String> cmds = ListUtil.toList(shellEnv, getShellPath(RasaConstant.TRAIN_SHELL),config,localDataPath,domain,localModelsPath);
@ -84,8 +108,10 @@ public class RasaCmdServiceImpl implements RasaCmdService {
cmds.set(1,null); cmds.set(1,null);
rasaModelInfo.setTrainCmd(cmds); rasaModelInfo.setTrainCmd(cmds);
rasaModelInfo.setTrainLog(outMessageString); rasaModelInfo.setTrainLog(outMessageString);
rasaModelInfo.setModelId("1");
rasaModeService.saveOrUpdateByModelId(rasaModelInfo); rasaModeService.saveOrUpdateByModelId(rasaModelInfo);
log.info("trainExec:end train rasa model ....");
return outMessageString; return outMessageString;
} }
@ -94,7 +120,7 @@ public class RasaCmdServiceImpl implements RasaCmdService {
@Override @Override
public String runExec(RasaCmdArgumentVo argument) throws ExecutionException, InterruptedException, TimeoutException { public String runExec(RasaCmdArgumentVo argument) throws ExecutionException, InterruptedException, TimeoutException {
log.info("runExec:start runExec rasa model ....args:{}",JSONUtil.toJsonStr(argument));
// 1. 查找可用端口 // 1. 查找可用端口
int port = PortUtil.findUnusedPort(5050, 100000,rasaModeService.listActivePort()); int port = PortUtil.findUnusedPort(5050, 100000,rasaModeService.listActivePort());
log.info("runExec findUnusedPort is : {}",port); log.info("runExec findUnusedPort is : {}",port);
@ -104,7 +130,7 @@ public class RasaCmdServiceImpl implements RasaCmdService {
// aaa1111.tar.gz这个前面的文件名应该是--fixed-model-name指定的.tar.gz是文件后缀代码拼接 // aaa1111.tar.gz这个前面的文件名应该是--fixed-model-name指定的.tar.gz是文件后缀代码拼接
String fixedModePath; String fixedModePath;
String modeParentPath = replaceDuplicateSeparator(String.join(File.separator, modelsPath, argument.getModelId())); String modeParentPath = replaceDuplicateSeparator(String.join(File.separator, modelsPath));
if (StrUtil.isEmpty(argument.getFixedModelName())){ if (StrUtil.isEmpty(argument.getFixedModelName())){
fixedModePath = listLastFilePath(modeParentPath, f -> f.getName().matches("-?\\d+(\\.\\d+)?.tar.gz")); fixedModePath = listLastFilePath(modeParentPath, f -> f.getName().matches("-?\\d+(\\.\\d+)?.tar.gz"));
}else { }else {
@ -142,6 +168,7 @@ public class RasaCmdServiceImpl implements RasaCmdService {
rasaModelInfo.setRunLog(outMessageString); rasaModelInfo.setRunLog(outMessageString);
rasaModeService.saveOrUpdateByModelId(rasaModelInfo); rasaModeService.saveOrUpdateByModelId(rasaModelInfo);
log.info("runExec:runExec end ....");
return outMessageString; return outMessageString;
} }
@ -176,56 +203,63 @@ public class RasaCmdServiceImpl implements RasaCmdService {
@Override
public boolean deployRasa() throws Exception {
private boolean trainIsSuccess(List<String> messageList){ // 1.生成rasa模型语料文件
Map<String, QuestionAnswerDTO> questionAnswerDTOMap = generateRasaYml(String.join(File.separator, rasaFilePath));
return containKey(messageList,RasaConstant.TRAN_SUCCESS_MESSAGE); // 2.训练模型
} trainExec(new RasaCmdArgumentVo());
//3.运行模型
RasaCmdArgumentVo rasaCmdArgumentVo = new RasaCmdArgumentVo();
rasaCmdArgumentVo.setModelId("1");
runExec(rasaCmdArgumentVo);
private boolean runIsSuccess(List<String> messageList){ // 更新text2vec数据信息
List<Text2vecDataVo> text2vecDataVoList = questionAnswerDTOMap.entrySet().stream()
.flatMap(entry -> entry.getValue().getQuestionList().stream()
.map(question -> new Text2vecDataVo(entry.getKey(), question))).collect(Collectors.toList());
text2vecService.updateDataset(text2vecDataVoList);
return containKey(messageList,RasaConstant.RUN_SUCCESS_MESSAGE); return true;
} }
private boolean containKey(List<String> messageList,String keyWord){ @Override
public Map<String, QuestionAnswerDTO> generateRasaYml(String path) {
if (CollectionUtil.isEmpty(messageList)){ log.info("generateRasaYml:start generateRasaYml ....");
return false;
}
if (StrUtil.isEmpty(keyWord)){
return false;
}
return messageList.stream().anyMatch(s->StrUtil.isNotEmpty(s) && s.contains(keyWord));
}
private String replaceDuplicateSeparator(String path){ // 默认问答MAP
List<RuleYmlTemplate.Rule> ruleList = new ArrayList<>();
if (StrUtil.isEmpty(path)){ Map<String, File> ymalFileMap = new HashMap<>();
return path; Map<String, QuestionAnswerDTO> intentCodeAndIdMap = getIntentCodeAndIdMap();
}
return path.replace(File.separator + File.separator, File.separator); // 开始生成各种yaml文件
} Pair<String, File> nulFilePair = generateNlu(intentCodeAndIdMap);
ymalFileMap.put(nulFilePair.getKey(),nulFilePair.getValue());
Pair<String, File> domainFilePair = generateDomain(intentCodeAndIdMap, ruleList);
ymalFileMap.put(domainFilePair.getKey(),domainFilePair.getValue());
Pair<String, File> ruleFilePair = generateRule(ruleList);
ymalFileMap.put(ruleFilePair.getKey(),ruleFilePair.getValue());
private String listLastFilePath(String path, FileFilter filter){ // 把文件复制到指定位置
File file = listLastFile(path, filter); for (Map.Entry<String, File> fileEntry : ymalFileMap.entrySet()) {
if (null == file){ try {
return null; FileUtil.copy(fileEntry.getValue(), new File(StrUtil.join(File.separator, path,fileEntry.getKey())), true);
}finally {
FileUtil.del(fileEntry.getValue());
} }
return file.getPath();
} }
private File listLastFile(String path,FileFilter filter){ log.info("generateRasaYml:end generateRasaYml ....");
File file = new File(path); return intentCodeAndIdMap;
File[] files = file.listFiles(filter);
if (null == files){
return null;
} }
return Arrays.stream(files).max(Comparator.comparing(File::getName)).orElse(null);
}
public String getShellPath(String shell){ public String getShellPath(String shell){
@ -266,6 +300,172 @@ public class RasaCmdServiceImpl implements RasaCmdService {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
} }
private Pair<String, File> generateNlu(Map<String, QuestionAnswerDTO> intentCodeAndIdMap) {
// 首先生成根据意图查找到nlu文件
List<NluYmlTemplate.Nlu> nluList = intentCodeAndIdMap.entrySet()
.stream().map(entry ->
new NluYmlTemplate.Nlu(entry.getKey(), entry.getValue().getQuestionList()))
.collect(Collectors.toList());
NluYmlTemplate nluYmlTemplate = new NluYmlTemplate();
nluYmlTemplate.setNlu(nluList);
// 生成后生成yml文件
return createYmlFile(NluYmlTemplate.class, "nlu.ftl", nluYmlTemplate, "nlu.yml");
}
public Map<String, QuestionAnswerDTO> getIntentCodeAndIdMap() {
Map<String, QuestionAnswerDTO> intentCodeAndIdMap = new HashMap<>();
// 默认意图
List<AskTemplateQuestionLibrary> askTemplateQuestionLibraryList = askTemplateQuestionLibraryService.lambdaQuery().list();
// 生成默认意图的nlu
for (AskTemplateQuestionLibrary questionLibrary : askTemplateQuestionLibraryList) {
// 开始生成
// 拼接格式:code_id(防止重复)
String intentCode = questionLibrary.getCode() + "_" + questionLibrary.getId();
intentCodeAndIdMap.put(intentCode, new QuestionAnswerDTO(questionLibrary.getQuestion(), CollUtil.newArrayList( questionLibrary.getId()), questionLibrary.getDescription()));
}
// 这里处理呼出的问题(code和问题不能为空)
List<ConfigPhysicalTool> physicalToolList = configPhysicalToolService.lambdaQuery()
.isNotNull(ConfigPhysicalTool::getCode)
.isNotNull(ConfigPhysicalTool::getCallOutQuestion).list();
for (ConfigPhysicalTool tool : physicalToolList) {
// 把呼出的问题全部加进去
String toolIntent = "tool_" + tool.getCode();
// answer格式为:---tool---工具ID
intentCodeAndIdMap.put(toolIntent,
new QuestionAnswerDTO(tool.getCallOutQuestion(),
CollUtil.newArrayList("tool_" + tool.getId()), "tool-" + tool.getToolName()));
}
// 生成呼出的辅助检查
List<ConfigAncillaryItem> ancillaryItemList = configAncillaryItemService.lambdaQuery()
.isNotNull(ConfigAncillaryItem::getCode)
.isNotNull(ConfigAncillaryItem::getCallOutQuestion).list();
for (ConfigAncillaryItem ancillary : ancillaryItemList) {
// 把辅助问诊的问题全部加进去
String itemIntent = "ancillary_" + ancillary.getCode();
// answer格式为:---ancillary---工具ID
intentCodeAndIdMap.put(itemIntent,
new QuestionAnswerDTO(ancillary.getCallOutQuestion(),
CollUtil.newArrayList("ancillary_" + ancillary.getId()), "呼出-ancillary-" + ancillary.getItemName()));
}
return intentCodeAndIdMap;
}
public Pair<String, File> generateDomain(Map<String, QuestionAnswerDTO> questionCodeAndIdMap, List<RuleYmlTemplate.Rule> ruleList) {
LinkedHashMap<String, List<String>> responses = new LinkedHashMap<>();
for (Map.Entry<String, QuestionAnswerDTO> entry : questionCodeAndIdMap.entrySet()) {
String intentCode = entry.getKey();
QuestionAnswerDTO value = entry.getValue();
String utter = "utter_" + intentCode;
responses.put(utter, CollUtil.newArrayList(value.getAnswerList()));
ruleList.add(new RuleYmlTemplate.Rule(value.getDesc(), intentCode, utter));
}
DomainYmlTemplate domainYmlTemplate = new DomainYmlTemplate();
// 意图
List<String> intentList = new ArrayList<>(questionCodeAndIdMap.keySet());
domainYmlTemplate.setIntents(intentList);
// 回复
domainYmlTemplate.setResponses(responses);
// action
List<String> actionList = new ArrayList<>(responses.keySet());
domainYmlTemplate.setActions(actionList);
// 生成yml文件
return createYmlFile(DomainYmlTemplate.class, "domain.ftl", domainYmlTemplate, "domain.yml");
}
/**
* rule
*/
public Pair<String, File> generateRule(List<RuleYmlTemplate.Rule> ruleList) {
RuleYmlTemplate ruleYmlTemplate = new RuleYmlTemplate();
ruleYmlTemplate.setRules(ruleList);
// 生成yml文件
return createYmlFile(RuleYmlTemplate.class, "rules.ftl", ruleYmlTemplate, "rules.yml");
}
private Pair<String,File> createYmlFile(Class<?> clazz, String ftlName, Object data, String ymlName) {
try {
// 这个版本和maven依赖的版本一致
Configuration configuration = new Configuration(Configuration.VERSION_2_3_31);
configuration.setClassForTemplateLoading(clazz, "/templates"); // 模板文件的所在目录
// 获取模板
Template template = configuration.getTemplate(ftlName);
File tempFile = FileUtil.createTempFile(".yml", true);
// 创建输出文件
try (PrintWriter out = new PrintWriter(tempFile);) {
// 填充并生成输出
template.process(data, out);
} catch (Exception e) {
log.error("文件生成失败");
}
return Pair.of(ymlName,tempFile);
} catch (Exception e) {
log.error("导出模板失败", e);
throw new RuntimeException("文件生成失败", e);
}
}
private boolean trainIsSuccess(List<String> messageList){
return containKey(messageList,RasaConstant.TRAN_SUCCESS_MESSAGE);
}
private boolean runIsSuccess(List<String> messageList){
return containKey(messageList,RasaConstant.RUN_SUCCESS_MESSAGE);
}
private boolean containKey(List<String> messageList,String keyWord){
if (CollectionUtil.isEmpty(messageList)){
return false;
}
if (StrUtil.isEmpty(keyWord)){
return false;
}
return messageList.stream().anyMatch(s->StrUtil.isNotEmpty(s) && s.contains(keyWord));
}
private String replaceDuplicateSeparator(String path){
if (StrUtil.isEmpty(path)){
return path;
}
return path.replace(File.separator + File.separator, File.separator);
}
private String listLastFilePath(String path, FileFilter filter){
File file = listLastFile(path, filter);
if (null == file){
return null;
}
return file.getPath();
}
private File listLastFile(String path,FileFilter filter){
File file = new File(path);
File[] files = file.listFiles(filter);
if (null == files){
return null;
}
return Arrays.stream(files).max(Comparator.comparing(File::getName)).orElse(null);
}
} }

@ -0,0 +1,28 @@
recipe: default.v1
language: zh
pipeline:
- name: JiebaTokenizer
- name: LanguageModelFeaturizer
model_name: bert
model_weights: bert-base-chinese
- name: RegexFeaturizer
- name: DIETClassifier
epochs: 100
learning_rate: 0.001
tensorboard_log_directory: ./log
- name: ResponseSelector
epochs: 100
learning_rate: 0.001
- name: FallbackClassifier
threshold: 0.87
ambiguity_threshold: 0.1
- name: EntitySynonymMapper
policies:
- name: MemoizationPolicy
- name: TEDPolicy
- name: RulePolicy
core_fallback_threshold: 0.87
core_fallback_action_name: "action_default_fallback"
enable_fallback_prediction: True

@ -0,0 +1,23 @@
version: "3.1"
intents:
<#list intents as intent>
- ${intent}
</#list>
responses:
<#list responses?keys as response>
${response}:
<#list responses[response] as item>
- text: "${item}"
</#list>
</#list>
actions:
<#list actions as action>
- ${action}
</#list>
session_config:
session_expiration_time: 60
carry_over_slots_to_new_session: true

@ -0,0 +1,10 @@
version: "3.1"
nlu:
<#list nlu as item>
- intent: ${item.intent}
examples: |
<#list item.examples as example>
- ${example}
</#list>
</#list>

@ -0,0 +1,12 @@
version: "3.1"
rules:
<#list rules as item>
- rule: ${item.rule}
steps:
<#list item.steps as ss>
- intent: ${ss.intent}
- action: ${ss.action}
</#list>
</#list>

@ -1,13 +1,36 @@
package com.supervision.rasa; package com.supervision.rasa;
import com.supervision.rasa.pojo.dto.QuestionAnswerDTO;
import com.supervision.rasa.pojo.dto.Text2vecDataVo;
import com.supervision.rasa.service.RasaCmdService;
import com.supervision.rasa.service.Text2vecService;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest; import org.springframework.boot.test.context.SpringBootTest;
import java.io.File;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
@SpringBootTest @SpringBootTest
class VirtualPatientRasaApplicationTests { class VirtualPatientRasaApplicationTests {
@Autowired
private RasaCmdService rasaCmdService;
@Autowired
private Text2vecService text2vecService;
@Test @Test
void contextLoads() { void contextLoads() {
/*Map<String, QuestionAnswerDTO> questionAnswerDTOMap = rasaCmdService.generateRasaYml("F:\\tmp\\rasa");
System.out.println(questionAnswerDTOMap);*/
Map<String, QuestionAnswerDTO> questionAnswerDTOMap = rasaCmdService.generateRasaYml(String.join(File.separator, "F:\\tmp\\rasa"));
List<Text2vecDataVo> text2vecDataVoList = questionAnswerDTOMap.entrySet().stream()
.flatMap(entry -> entry.getValue().getQuestionList().stream()
.map(question -> new Text2vecDataVo(entry.getKey(), question))).collect(Collectors.toList());
text2vecService.updateDataset(text2vecDataVoList);
} }
} }

Loading…
Cancel
Save