Merge remote-tracking branch 'origin/dev_2.1.0' into dev_2.1.0

dev_2.1.0
liu 1 year ago
commit 7c4fb9dc63

@ -3,11 +3,13 @@ package com.supervision.manage.controller.config;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.supervision.manage.service.AskQuestionLibraryManageService;
import com.supervision.model.AskTemplateQuestionLibrary;
import com.supervision.model.CommonDic;
import com.supervision.vo.manage.AskQuestionLibraryReqVo;
import com.supervision.vo.manage.AskQuestionLibraryResVo;
import io.swagger.annotations.Api;
import io.swagger.annotations.ApiOperation;
import io.swagger.v3.oas.annotations.parameters.RequestBody;
import lombok.RequiredArgsConstructor;
import org.springframework.web.bind.annotation.*;
@ -39,6 +41,31 @@ public class AskQuestionLibraryManageController {
}
@ApiOperation("保存问题库信息")
@PostMapping("/saveQuestionLibrary")
public String saveQuestionLibrary(@RequestBody AskTemplateQuestionLibrary askTemplateQuestionLibrary) {
return askQuestionLibraryManageService.saveQuestionLibrary(askTemplateQuestionLibrary);
}
@ApiOperation("更新问题库信息")
@PostMapping("/updateQuestionLibrary")
public boolean updateQuestionLibrary(@RequestBody AskTemplateQuestionLibrary askTemplateQuestionLibrary) {
return askQuestionLibraryManageService.updateQuestionLibrary(askTemplateQuestionLibrary);
}
@ApiOperation("删除问题库信息")
@PostMapping("/deleteQuestionLibrary")
public boolean deleteQuestionLibrary(@RequestParam("id") String id) {
return askQuestionLibraryManageService.deleteQuestionLibrary(id);
}
@ApiOperation("查询问题类目编码列表")
@GetMapping("/queryItemList")
public List<CommonDic> queryItemList() {

@ -1,6 +1,7 @@
package com.supervision.manage.service;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.supervision.model.AskTemplateQuestionLibrary;
import com.supervision.model.CommonDic;
import com.supervision.vo.manage.AskQuestionLibraryReqVo;
import com.supervision.vo.manage.AskQuestionLibraryResVo;
@ -14,4 +15,9 @@ public interface AskQuestionLibraryManageService {
List<CommonDic> queryItemList();
String saveQuestionLibrary(AskTemplateQuestionLibrary askTemplateQuestionLibrary);
boolean updateQuestionLibrary(AskTemplateQuestionLibrary askTemplateQuestionLibrary);
boolean deleteQuestionLibrary(String id);
}

@ -1,8 +1,9 @@
package com.supervision.manage.service.impl;
import com.baomidou.mybatisplus.extension.conditions.query.LambdaQueryChainWrapper;
import cn.hutool.core.lang.Assert;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.supervision.manage.service.AskQuestionLibraryManageService;
import com.supervision.model.AskTemplateQuestionLibrary;
import com.supervision.model.CommonDic;
import com.supervision.service.AskTemplateQuestionLibraryService;
import com.supervision.service.CommonDicService;
@ -38,4 +39,33 @@ public class AskQuestionLibraryManageServiceImpl implements AskQuestionLibraryMa
public List<CommonDic> queryItemList() {
return commonDicService.lambdaQuery().eq(CommonDic::getGroupCode, "AQT").isNull(CommonDic::getParentId).list();
}
@Override
public String saveQuestionLibrary(AskTemplateQuestionLibrary askTemplateQuestionLibrary) {
assertSave(askTemplateQuestionLibrary);
askTemplateQuestionLibraryService.save(askTemplateQuestionLibrary);
return askTemplateQuestionLibrary.getId();
}
@Override
public boolean updateQuestionLibrary(AskTemplateQuestionLibrary askTemplateQuestionLibrary) {
Assert.notEmpty(askTemplateQuestionLibrary.getId(),"id不能为空");
assertSave(askTemplateQuestionLibrary);
return askTemplateQuestionLibraryService.updateById(askTemplateQuestionLibrary);
}
@Override
public boolean deleteQuestionLibrary(String id) {
Assert.notEmpty(id,"id不能为空");
return askTemplateQuestionLibraryService.removeById(id);
}
private void assertSave(AskTemplateQuestionLibrary askTemplateQuestionLibrary){
Assert.notEmpty(askTemplateQuestionLibrary.getCode(),"编码不能为空");
Assert.notEmpty(askTemplateQuestionLibrary.getQuestion(),"问题不能为空");
Assert.notEmpty(askTemplateQuestionLibrary.getDefaultAnswer(),"默认回答不能为空");
}
}

@ -12,4 +12,7 @@ public class AskQuestionLibraryReqVo {
@ApiModelProperty("疾病id")
private String diseaseId;
@ApiModelProperty("问题")
private String question;
}

@ -28,4 +28,7 @@ public class AskQuestionLibraryResVo {
@ApiModelProperty("问题类目名")
private String nameZhPath;
@ApiModelProperty("类目名")
private String nameZh;
}

@ -46,14 +46,19 @@
cd.code as code,
atql.dict_id as dictId,
atql.question as question,
cd.name_zh_path as nameZhPath
cd.name_zh_path as nameZhPath,
cd.name_zh as nameZh
from vp_ask_template_question_library atql
left join vp_common_dic cd on atql.dict_id = cd.id
<where>
<if test="askQuestionLibrary.code != null and askQuestionLibrary.code != '' ">
AND cd.code = #{askQuestionLibrary.code}
cd.code = #{askQuestionLibrary.code}
</if>
<if test="askQuestionLibrary.question != null and askQuestionLibrary.question != '' ">
JSON_EXTRACT(question, '$[*]') like CONCAT('%', #{askQuestionLibrary.question}, '%')
</if>
</where>
order by atql.create_time desc
</sql>
<select id="queryList" resultMap="askQuestionLibraryResultMap" parameterType="com.supervision.vo.manage.AskQuestionLibraryReqVo">

@ -11,7 +11,7 @@ WORKDIR /data/vp
COPY target/virtual-patient-rasa-1.0-SNAPSHOT.jar /data/vp/virtual-patient-rasa-1.0-SNAPSHOT.jar
# 复制rasa配置文件到 rasa目录下
COPY docs/rasa /rasa
COPY docs/1 /data/vp/rasa/models/1
COPY docs/1 /data/vp/rasa/models/
# 暴漏服务端口
EXPOSE 8890

@ -3,7 +3,7 @@ FROM rasa_dev:1.0.0
COPY ./bert_chinese /usr/local/text2vec/bert_chinese
COPY ./app.py /usr/local/text2vec/
COPY ./question.json /usr/local/text2vec/
#COPY ./question.json /usr/local/text2vec/
RUN source /root/anaconda3/etc/profile.d/conda.sh && \
conda create --name text2vec_env python=3.9 -y && \

@ -58,6 +58,11 @@
<groupId>cn.hutool</groupId>
<artifactId>hutool-all</artifactId>
</dependency>
<!--用来生成yml文件,jakson的模板库满足不了多行文本|管道符的需求-->
<dependency>
<groupId>org.freemarker</groupId>
<artifactId>freemarker</artifactId>
</dependency>
</dependencies>

@ -2,6 +2,7 @@ package com.supervision.rasa;
import com.supervision.config.WebConfig;
import com.supervision.rasa.service.RasaModelManager;
import com.supervision.rasa.service.Text2vecService;
import org.mybatis.spring.annotation.MapperScan;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
@ -20,8 +21,17 @@ public class VirtualPatientRasaApplication {
public static void main(String[] args) {
ConfigurableApplicationContext context = SpringApplication.run(VirtualPatientRasaApplication.class, args);
// 启动rasa服务
RasaModelManager rasaModelManager = context.getBean(RasaModelManager.class);
rasaModelManager.wakeUpInterruptServerScheduled();
try {
rasaModelManager.wakeUpInterruptServerScheduled();
} catch (Exception e) {
throw new RuntimeException(e);
}
// 初始化文本匹配数据
Text2vecService text2vecService = context.getBean(Text2vecService.class);
text2vecService.initText2vecDataset();
}
}

@ -13,7 +13,7 @@ import org.springframework.web.bind.annotation.*;
import java.io.*;
import java.util.concurrent.*;
@Api(tags = "rasa文件保存")
@Api(tags = "rasa管理")
@RestController
@RequestMapping("rasaCmd")
@RequiredArgsConstructor
@ -25,7 +25,6 @@ public class RasaCmdController {
@PostMapping("/trainExec")
public String trainExec(@RequestBody RasaCmdArgumentVo argument) throws IOException, ExecutionException, InterruptedException, TimeoutException {
argument.setModelId("1");
return rasaCmdService.trainExec(argument);
}
@ -34,7 +33,6 @@ public class RasaCmdController {
@PostMapping("/runExec")
public String runExec(@RequestBody RasaCmdArgumentVo argument) throws ExecutionException, InterruptedException, TimeoutException {
argument.setModelId("1");
String outString = rasaCmdService.runExec(argument);
if (StrUtil.isEmptyIfStr(outString) || !outString.contains(RasaConstant.RUN_SUCCESS_MESSAGE)){
throw new BusinessException("任务执行异常。详细日志:"+outString);
@ -43,5 +41,12 @@ public class RasaCmdController {
}
@ApiOperation("部署rasa")
@PostMapping("/deploy")
public boolean deployRasa() throws Exception {
return rasaCmdService.deployRasa();
}
}

@ -0,0 +1,28 @@
package com.supervision.rasa.pojo.dto;
import lombok.Data;
import java.util.LinkedHashMap;
import java.util.List;
@Data
public class DomainYmlTemplate {
private List<String> intents;
private LinkedHashMap<String,List<String>> responses;
private List<String> actions;
private SessionConfig session_config = new SessionConfig();
@Data
public static class SessionConfig{
private final int session_expiration_time = 60;
private final Boolean carry_over_slots_to_new_session = true;
}
}

@ -0,0 +1,28 @@
package com.supervision.rasa.pojo.dto;
import lombok.Data;
import java.util.List;
@Data
public class NluYmlTemplate {
private List<Nlu> nlu;
@Data
public static class Nlu{
private String intent;
private List<String> examples;
public Nlu(String intent, List<String> examples) {
this.intent = intent;
this.examples = examples;
}
public Nlu() {
}
}
}

@ -0,0 +1,20 @@
package com.supervision.rasa.pojo.dto;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.List;
@Data
@AllArgsConstructor
@NoArgsConstructor
public class QuestionAnswerDTO {
private List<String> questionList;
private List<String> answerList;
private String desc;
}

@ -0,0 +1,44 @@
package com.supervision.rasa.pojo.dto;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.ArrayList;
import java.util.List;
@Data
public class RuleYmlTemplate {
private List<Rule> rules;
@Data
public static class Rule {
private String rule;
private List<Step> steps;
public Rule() {
}
public Rule(String rule, String intent, String action) {
this.rule = rule;
steps = new ArrayList<>();
Step step = new Step(intent, action);
steps.add(step);
}
}
@Data
@AllArgsConstructor
@NoArgsConstructor
public static class Step {
private String intent;
private String action;
}
}

@ -11,4 +11,12 @@ public class Text2vecDataVo {
@ApiModelProperty("问题")
private String question;
public Text2vecDataVo(String id, String question) {
this.id = id;
this.question = question;
}
public Text2vecDataVo() {
}
}

@ -1,9 +1,11 @@
package com.supervision.rasa.service;
import com.supervision.rasa.pojo.dto.QuestionAnswerDTO;
import com.supervision.rasa.pojo.vo.RasaCmdArgumentVo;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
import java.util.function.Predicate;
@ -19,4 +21,10 @@ public interface RasaCmdService {
String getShellPath(String shell);
boolean deployRasa() throws Exception;
Map<String, QuestionAnswerDTO> generateRasaYml(String path);
public Map<String, QuestionAnswerDTO> getIntentCodeAndIdMap();
}

@ -51,39 +51,40 @@ public class RasaModelManager {
// 2. 重新启动中断的服务
for (RasaModelInfo rasaModelInfo : activeRasaList) {
if (!PortUtil.portIsActive(rasaModelInfo.getPort())){
try {
RasaRunParam rasaRunParam = RasaRunParam.build(rasaModelInfo.getRunCmd());
rasaRunParam.setPort(String.valueOf(rasaModelInfo.getPort()));
rasaRunParam.setShellPath(rasaCmdService.getShellPath(RasaConstant.RUN_SHELL));
String rasaModelPath = rasaRunParam.getRasaModelPath();
if (StrUtil.isEmpty(rasaModelPath) || !FileUtil.exist(rasaModelPath)){
log.info("wakeUpInterruptServer: rasa model path {} not exist,attempt find last ...",rasaModelPath);
String modeParentPath = replaceDuplicateSeparator(String.join(File.separator, modelsPath, rasaModelInfo.getModelId()));
String fixedModePath = listLastFilePath(modeParentPath, f -> f.getName().matches("-?\\d+(\\.\\d+)?.tar.gz"));
Assert.notEmpty(fixedModePath,"wakeUpInterruptService: no rasa model in path {} ",modeParentPath);
rasaRunParam.setRasaModelPath(fixedModePath);
}
log.info("wakeUpInterruptServer : use fixedModePath :{}",rasaRunParam.getRasaModelPath());
List<String> outMessageList = rasaCmdService.execCmd(rasaRunParam.toList(),
s -> StrUtil.isNotBlank(s) && s.contains(RasaConstant.RUN_SUCCESS_MESSAGE), 300);
rasaModelInfo.setRunLog(String.join("\r\n",outMessageList));
rasaModelInfo.setRunCmd(rasaRunParam.toList());
rasaModeService.updateById(rasaModelInfo);
if (!runIsSuccess(outMessageList)){
log.info("wakeUpInterruptServer: restart server port for {} failed,details info : {}",rasaModelInfo.getPort(),String.join("\r\n",outMessageList));
}
} catch (InterruptedException | ExecutionException | TimeoutException e ) {
log.info("wakeUpInterruptServer: restart server port for {} failed",rasaModelInfo.getPort());
throw new RuntimeException(e);
}
log.info("wakeUpInterruptServer: restart server port for {} success ",rasaModelInfo.getPort());
}else {
log.info("wakeUpInterruptServer: port:{} is run..",rasaModelInfo.getPort());
}
}
if (PortUtil.portIsActive(rasaModelInfo.getPort())) {
log.info("wakeUpInterruptServer: port:{} is run..", rasaModelInfo.getPort());
continue;
}
try {
RasaRunParam rasaRunParam = RasaRunParam.build(rasaModelInfo.getRunCmd());
rasaRunParam.setPort(String.valueOf(rasaModelInfo.getPort()));
rasaRunParam.setShellPath(rasaCmdService.getShellPath(RasaConstant.RUN_SHELL));
String rasaModelPath = rasaRunParam.getRasaModelPath();
if (StrUtil.isEmpty(rasaModelPath) || !FileUtil.exist(rasaModelPath)) {
log.info("wakeUpInterruptServer: rasa model path {} not exist,attempt find last ...", rasaModelPath);
String modeParentPath = replaceDuplicateSeparator(String.join(File.separator, modelsPath));
String fixedModePath = listLastFilePath(modeParentPath, f -> f.getName().matches("-?\\d+(\\.\\d+)?.tar.gz"));
Assert.notEmpty(fixedModePath, "wakeUpInterruptService: no rasa model in path {} ", modeParentPath);
rasaRunParam.setRasaModelPath(fixedModePath);
}
log.info("wakeUpInterruptServer : use fixedModePath :{}", rasaRunParam.getRasaModelPath());
List<String> outMessageList = rasaCmdService.execCmd(rasaRunParam.toList(),
s -> StrUtil.isNotBlank(s) && s.contains(RasaConstant.RUN_SUCCESS_MESSAGE), 300);
rasaModelInfo.setRunLog(String.join("\r\n", outMessageList));
rasaModelInfo.setRunCmd(rasaRunParam.toList());
rasaModeService.updateById(rasaModelInfo);
if (!runIsSuccess(outMessageList)) {
log.info("wakeUpInterruptServer: restart server port for {} failed,details info : {}", rasaModelInfo.getPort(), String.join("\r\n", outMessageList));
}
} catch (InterruptedException | ExecutionException | TimeoutException e) {
log.info("wakeUpInterruptServer: restart server port for {} failed", rasaModelInfo.getPort());
throw new RuntimeException(e);
}
log.info("wakeUpInterruptServer: restart server port for {} success ", rasaModelInfo.getPort());
}
}

@ -21,4 +21,10 @@ public interface Text2vecService {
* @return
*/
List<Text2vecMatchesRes> matches(Text2vecMatchesReq text2vecMatchesReq);
/**
*
*/
void initText2vecDataset();
}

@ -1,5 +1,6 @@
package com.supervision.rasa.service;
import cn.hutool.core.io.FileUtil;
import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.StrUtil;
import cn.hutool.http.HttpUtil;
@ -7,21 +8,25 @@ import cn.hutool.json.JSON;
import cn.hutool.json.JSONArray;
import cn.hutool.json.JSONObject;
import cn.hutool.json.JSONUtil;
import com.supervision.rasa.pojo.dto.QuestionAnswerDTO;
import com.supervision.rasa.pojo.dto.Text2vecDataVo;
import com.supervision.rasa.pojo.dto.Text2vecMatchesReq;
import com.supervision.rasa.pojo.dto.Text2vecMatchesRes;
import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Service;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
@Slf4j
@Service
@RequiredArgsConstructor
public class Text2vecServiceImpl implements Text2vecService {
@Value("${text2vec.service.domain}")
@ -30,6 +35,13 @@ public class Text2vecServiceImpl implements Text2vecService {
private final String UPDATE_DATASET_PATH = "update_dataset";
private final String MATCHES_PATH = "matches";
private final String GET_ALL_SIMILARITIES_PATH = "get_all_similarities";
private final RasaCmdService rasaCmdService;
public Text2vecServiceImpl(@Autowired @Lazy RasaCmdService rasaCmdService) {
this.rasaCmdService = rasaCmdService;
}
@Override
public boolean updateDataset(List<Text2vecDataVo> text2vecDataVoList) {
@ -65,4 +77,19 @@ public class Text2vecServiceImpl implements Text2vecService {
return JSONUtil.toList(JSONUtil.parseArray(jsonBody.get("results")), Text2vecMatchesRes.class);
}
public void initText2vecDataset(){
log.info("initText2vecDataset ...");
if (FileUtil.exist("/usr/local/text2vec/question.json")){
log.info("question.json文件已经存在不进行text2vec数据初始化操作....");
return;
}
Map<String, QuestionAnswerDTO> intentCodeAndIdMap = rasaCmdService.getIntentCodeAndIdMap();
// 更新text2vec数据信息
List<Text2vecDataVo> text2vecDataVoList = intentCodeAndIdMap.entrySet().stream()
.flatMap(entry -> entry.getValue().getQuestionList().stream()
.map(question -> new Text2vecDataVo(entry.getKey(), question))).collect(Collectors.toList());
this.updateDataset(text2vecDataVoList);
}
}

@ -1,17 +1,30 @@
package com.supervision.rasa.service.impl;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.collection.ListUtil;
import cn.hutool.core.io.FileUtil;
import cn.hutool.core.lang.Pair;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil;
import com.supervision.exception.BusinessException;
import com.supervision.model.AskTemplateQuestionLibrary;
import com.supervision.model.ConfigAncillaryItem;
import com.supervision.model.ConfigPhysicalTool;
import com.supervision.model.RasaModelInfo;
import com.supervision.rasa.config.ThreadPoolExecutorConfig;
import com.supervision.rasa.constant.RasaConstant;
import com.supervision.rasa.pojo.dto.*;
import com.supervision.rasa.pojo.vo.RasaCmdArgumentVo;
import com.supervision.rasa.service.RasaCmdService;
import com.supervision.rasa.service.Text2vecService;
import com.supervision.rasa.util.PortUtil;
import com.supervision.service.AskTemplateQuestionLibraryService;
import com.supervision.service.ConfigAncillaryItemService;
import com.supervision.service.ConfigPhysicalToolService;
import com.supervision.service.RasaModeService;
import freemarker.template.Configuration;
import freemarker.template.Template;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
@ -23,6 +36,7 @@ import java.io.*;
import java.util.*;
import java.util.concurrent.*;
import java.util.function.Predicate;
import java.util.stream.Collectors;
@Service
@Slf4j
@ -46,26 +60,36 @@ public class RasaCmdServiceImpl implements RasaCmdService {
@Value("${rasa.shell-env:/bin/bash}")
private String shellEnv;
@Value("${rasa.data-path:/home/rasa/model_resource/}")
private String rasaFilePath;
private final RasaModeService rasaModeService;
private final ConfigPhysicalToolService configPhysicalToolService;
private final ConfigAncillaryItemService configAncillaryItemService;
private final AskTemplateQuestionLibraryService askTemplateQuestionLibraryService;
private final Text2vecService text2vecService;
private final ConcurrentHashMap<String,String> shellPathCache = new ConcurrentHashMap<>();
@Override
@Transactional
public String trainExec(RasaCmdArgumentVo argument) throws ExecutionException, InterruptedException, TimeoutException {
log.info("trainExec:start train rasa model ....argument:{}", JSONUtil.toJsonStr(argument));
argument.setFixedModelNameIfAbsent();
// /rasa/v3_jiazhuangxian/domain.yml domain的路径应该是从zip文件中加压出来的文件的路径后面拼上/domain.yml
String domain = replaceDuplicateSeparator(String.join(File.separator,dataPath,argument.getModelId(),"domain.yml"));
String domain = replaceDuplicateSeparator(String.join(File.separator,dataPath,"domain.yml"));
// /rasa/v3_jiazhuangxian/ yml文件的路径应该是从zip文件中加压出来的文件的路径,在配置文件中配置
String localDataPath = replaceDuplicateSeparator(String.join(File.separator,dataPath,argument.getModelId()));
String localDataPath = replaceDuplicateSeparator(String.join(File.separator,dataPath));
// /rasa/models 生成出来的模型的存放路径,也写在配置文件里面
String localModelsPath = replaceDuplicateSeparator(String.join(File.separator,modelsPath,argument.getModelId()));
String localModelsPath = replaceDuplicateSeparator(String.join(File.separator,modelsPath));
List<String> cmds = ListUtil.toList(shellEnv, getShellPath(RasaConstant.TRAIN_SHELL),config,localDataPath,domain,localModelsPath);
@ -84,8 +108,10 @@ public class RasaCmdServiceImpl implements RasaCmdService {
cmds.set(1,null);
rasaModelInfo.setTrainCmd(cmds);
rasaModelInfo.setTrainLog(outMessageString);
rasaModelInfo.setModelId("1");
rasaModeService.saveOrUpdateByModelId(rasaModelInfo);
log.info("trainExec:end train rasa model ....");
return outMessageString;
}
@ -94,7 +120,7 @@ public class RasaCmdServiceImpl implements RasaCmdService {
@Override
public String runExec(RasaCmdArgumentVo argument) throws ExecutionException, InterruptedException, TimeoutException {
log.info("runExec:start runExec rasa model ....args:{}",JSONUtil.toJsonStr(argument));
// 1. 查找可用端口
int port = PortUtil.findUnusedPort(5050, 100000,rasaModeService.listActivePort());
log.info("runExec findUnusedPort is : {}",port);
@ -104,7 +130,7 @@ public class RasaCmdServiceImpl implements RasaCmdService {
// aaa1111.tar.gz这个前面的文件名应该是--fixed-model-name指定的.tar.gz是文件后缀代码拼接
String fixedModePath;
String modeParentPath = replaceDuplicateSeparator(String.join(File.separator, modelsPath, argument.getModelId()));
String modeParentPath = replaceDuplicateSeparator(String.join(File.separator, modelsPath));
if (StrUtil.isEmpty(argument.getFixedModelName())){
fixedModePath = listLastFilePath(modeParentPath, f -> f.getName().matches("-?\\d+(\\.\\d+)?.tar.gz"));
}else {
@ -142,6 +168,7 @@ public class RasaCmdServiceImpl implements RasaCmdService {
rasaModelInfo.setRunLog(outMessageString);
rasaModeService.saveOrUpdateByModelId(rasaModelInfo);
log.info("runExec:runExec end ....");
return outMessageString;
}
@ -176,56 +203,63 @@ public class RasaCmdServiceImpl implements RasaCmdService {
@Override
public boolean deployRasa() throws Exception {
private boolean trainIsSuccess(List<String> messageList){
// 1.生成rasa模型语料文件
Map<String, QuestionAnswerDTO> questionAnswerDTOMap = generateRasaYml(String.join(File.separator, rasaFilePath));
return containKey(messageList,RasaConstant.TRAN_SUCCESS_MESSAGE);
}
// 2.训练模型
trainExec(new RasaCmdArgumentVo());
//3.运行模型
RasaCmdArgumentVo rasaCmdArgumentVo = new RasaCmdArgumentVo();
rasaCmdArgumentVo.setModelId("1");
runExec(rasaCmdArgumentVo);
private boolean runIsSuccess(List<String> messageList){
// 更新text2vec数据信息
List<Text2vecDataVo> text2vecDataVoList = questionAnswerDTOMap.entrySet().stream()
.flatMap(entry -> entry.getValue().getQuestionList().stream()
.map(question -> new Text2vecDataVo(entry.getKey(), question))).collect(Collectors.toList());
text2vecService.updateDataset(text2vecDataVoList);
return containKey(messageList,RasaConstant.RUN_SUCCESS_MESSAGE);
return true;
}
private boolean containKey(List<String> messageList,String keyWord){
@Override
public Map<String, QuestionAnswerDTO> generateRasaYml(String path) {
if (CollectionUtil.isEmpty(messageList)){
return false;
}
if (StrUtil.isEmpty(keyWord)){
return false;
}
return messageList.stream().anyMatch(s->StrUtil.isNotEmpty(s) && s.contains(keyWord));
}
log.info("generateRasaYml:start generateRasaYml ....");
private String replaceDuplicateSeparator(String path){
// 默认问答MAP
List<RuleYmlTemplate.Rule> ruleList = new ArrayList<>();
if (StrUtil.isEmpty(path)){
return path;
}
Map<String, File> ymalFileMap = new HashMap<>();
Map<String, QuestionAnswerDTO> intentCodeAndIdMap = getIntentCodeAndIdMap();
return path.replace(File.separator + File.separator, File.separator);
}
// 开始生成各种yaml文件
Pair<String, File> nulFilePair = generateNlu(intentCodeAndIdMap);
ymalFileMap.put(nulFilePair.getKey(),nulFilePair.getValue());
Pair<String, File> domainFilePair = generateDomain(intentCodeAndIdMap, ruleList);
ymalFileMap.put(domainFilePair.getKey(),domainFilePair.getValue());
Pair<String, File> ruleFilePair = generateRule(ruleList);
ymalFileMap.put(ruleFilePair.getKey(),ruleFilePair.getValue());
// 把文件复制到指定位置
for (Map.Entry<String, File> fileEntry : ymalFileMap.entrySet()) {
try {
FileUtil.copy(fileEntry.getValue(), new File(StrUtil.join(File.separator, path,fileEntry.getKey())), true);
}finally {
FileUtil.del(fileEntry.getValue());
}
private String listLastFilePath(String path, FileFilter filter){
File file = listLastFile(path, filter);
if (null == file){
return null;
}
return file.getPath();
log.info("generateRasaYml:end generateRasaYml ....");
return intentCodeAndIdMap;
}
private File listLastFile(String path,FileFilter filter){
File file = new File(path);
File[] files = file.listFiles(filter);
if (null == files){
return null;
}
return Arrays.stream(files).max(Comparator.comparing(File::getName)).orElse(null);
}
public String getShellPath(String shell){
@ -266,6 +300,172 @@ public class RasaCmdServiceImpl implements RasaCmdService {
throw new RuntimeException(e);
}
}
private Pair<String, File> generateNlu(Map<String, QuestionAnswerDTO> intentCodeAndIdMap) {
// 首先生成根据意图查找到nlu文件
List<NluYmlTemplate.Nlu> nluList = intentCodeAndIdMap.entrySet()
.stream().map(entry ->
new NluYmlTemplate.Nlu(entry.getKey(), entry.getValue().getQuestionList()))
.collect(Collectors.toList());
NluYmlTemplate nluYmlTemplate = new NluYmlTemplate();
nluYmlTemplate.setNlu(nluList);
// 生成后生成yml文件
return createYmlFile(NluYmlTemplate.class, "nlu.ftl", nluYmlTemplate, "nlu.yml");
}
public Map<String, QuestionAnswerDTO> getIntentCodeAndIdMap() {
Map<String, QuestionAnswerDTO> intentCodeAndIdMap = new HashMap<>();
// 默认意图
List<AskTemplateQuestionLibrary> askTemplateQuestionLibraryList = askTemplateQuestionLibraryService.lambdaQuery().list();
// 生成默认意图的nlu
for (AskTemplateQuestionLibrary questionLibrary : askTemplateQuestionLibraryList) {
// 开始生成
// 拼接格式:code_id(防止重复)
String intentCode = questionLibrary.getCode() + "_" + questionLibrary.getId();
intentCodeAndIdMap.put(intentCode, new QuestionAnswerDTO(questionLibrary.getQuestion(), CollUtil.newArrayList( questionLibrary.getId()), questionLibrary.getDescription()));
}
// 这里处理呼出的问题(code和问题不能为空)
List<ConfigPhysicalTool> physicalToolList = configPhysicalToolService.lambdaQuery()
.isNotNull(ConfigPhysicalTool::getCode)
.isNotNull(ConfigPhysicalTool::getCallOutQuestion).list();
for (ConfigPhysicalTool tool : physicalToolList) {
// 把呼出的问题全部加进去
String toolIntent = "tool_" + tool.getCode();
// answer格式为:---tool---工具ID
intentCodeAndIdMap.put(toolIntent,
new QuestionAnswerDTO(tool.getCallOutQuestion(),
CollUtil.newArrayList("tool_" + tool.getId()), "tool-" + tool.getToolName()));
}
// 生成呼出的辅助检查
List<ConfigAncillaryItem> ancillaryItemList = configAncillaryItemService.lambdaQuery()
.isNotNull(ConfigAncillaryItem::getCode)
.isNotNull(ConfigAncillaryItem::getCallOutQuestion).list();
for (ConfigAncillaryItem ancillary : ancillaryItemList) {
// 把辅助问诊的问题全部加进去
String itemIntent = "ancillary_" + ancillary.getCode();
// answer格式为:---ancillary---工具ID
intentCodeAndIdMap.put(itemIntent,
new QuestionAnswerDTO(ancillary.getCallOutQuestion(),
CollUtil.newArrayList("ancillary_" + ancillary.getId()), "呼出-ancillary-" + ancillary.getItemName()));
}
return intentCodeAndIdMap;
}
public Pair<String, File> generateDomain(Map<String, QuestionAnswerDTO> questionCodeAndIdMap, List<RuleYmlTemplate.Rule> ruleList) {
LinkedHashMap<String, List<String>> responses = new LinkedHashMap<>();
for (Map.Entry<String, QuestionAnswerDTO> entry : questionCodeAndIdMap.entrySet()) {
String intentCode = entry.getKey();
QuestionAnswerDTO value = entry.getValue();
String utter = "utter_" + intentCode;
responses.put(utter, CollUtil.newArrayList(value.getAnswerList()));
ruleList.add(new RuleYmlTemplate.Rule(value.getDesc(), intentCode, utter));
}
DomainYmlTemplate domainYmlTemplate = new DomainYmlTemplate();
// 意图
List<String> intentList = new ArrayList<>(questionCodeAndIdMap.keySet());
domainYmlTemplate.setIntents(intentList);
// 回复
domainYmlTemplate.setResponses(responses);
// action
List<String> actionList = new ArrayList<>(responses.keySet());
domainYmlTemplate.setActions(actionList);
// 生成yml文件
return createYmlFile(DomainYmlTemplate.class, "domain.ftl", domainYmlTemplate, "domain.yml");
}
/**
* rule
*/
public Pair<String, File> generateRule(List<RuleYmlTemplate.Rule> ruleList) {
RuleYmlTemplate ruleYmlTemplate = new RuleYmlTemplate();
ruleYmlTemplate.setRules(ruleList);
// 生成yml文件
return createYmlFile(RuleYmlTemplate.class, "rules.ftl", ruleYmlTemplate, "rules.yml");
}
private Pair<String,File> createYmlFile(Class<?> clazz, String ftlName, Object data, String ymlName) {
try {
// 这个版本和maven依赖的版本一致
Configuration configuration = new Configuration(Configuration.VERSION_2_3_31);
configuration.setClassForTemplateLoading(clazz, "/templates"); // 模板文件的所在目录
// 获取模板
Template template = configuration.getTemplate(ftlName);
File tempFile = FileUtil.createTempFile(".yml", true);
// 创建输出文件
try (PrintWriter out = new PrintWriter(tempFile);) {
// 填充并生成输出
template.process(data, out);
} catch (Exception e) {
log.error("文件生成失败");
}
return Pair.of(ymlName,tempFile);
} catch (Exception e) {
log.error("导出模板失败", e);
throw new RuntimeException("文件生成失败", e);
}
}
private boolean trainIsSuccess(List<String> messageList){
return containKey(messageList,RasaConstant.TRAN_SUCCESS_MESSAGE);
}
private boolean runIsSuccess(List<String> messageList){
return containKey(messageList,RasaConstant.RUN_SUCCESS_MESSAGE);
}
private boolean containKey(List<String> messageList,String keyWord){
if (CollectionUtil.isEmpty(messageList)){
return false;
}
if (StrUtil.isEmpty(keyWord)){
return false;
}
return messageList.stream().anyMatch(s->StrUtil.isNotEmpty(s) && s.contains(keyWord));
}
private String replaceDuplicateSeparator(String path){
if (StrUtil.isEmpty(path)){
return path;
}
return path.replace(File.separator + File.separator, File.separator);
}
private String listLastFilePath(String path, FileFilter filter){
File file = listLastFile(path, filter);
if (null == file){
return null;
}
return file.getPath();
}
private File listLastFile(String path,FileFilter filter){
File file = new File(path);
File[] files = file.listFiles(filter);
if (null == files){
return null;
}
return Arrays.stream(files).max(Comparator.comparing(File::getName)).orElse(null);
}
}

@ -0,0 +1,28 @@
recipe: default.v1
language: zh
pipeline:
- name: JiebaTokenizer
- name: LanguageModelFeaturizer
model_name: bert
model_weights: bert-base-chinese
- name: RegexFeaturizer
- name: DIETClassifier
epochs: 100
learning_rate: 0.001
tensorboard_log_directory: ./log
- name: ResponseSelector
epochs: 100
learning_rate: 0.001
- name: FallbackClassifier
threshold: 0.87
ambiguity_threshold: 0.1
- name: EntitySynonymMapper
policies:
- name: MemoizationPolicy
- name: TEDPolicy
- name: RulePolicy
core_fallback_threshold: 0.87
core_fallback_action_name: "action_default_fallback"
enable_fallback_prediction: True

@ -0,0 +1,23 @@
version: "3.1"
intents:
<#list intents as intent>
- ${intent}
</#list>
responses:
<#list responses?keys as response>
${response}:
<#list responses[response] as item>
- text: "${item}"
</#list>
</#list>
actions:
<#list actions as action>
- ${action}
</#list>
session_config:
session_expiration_time: 60
carry_over_slots_to_new_session: true

@ -0,0 +1,10 @@
version: "3.1"
nlu:
<#list nlu as item>
- intent: ${item.intent}
examples: |
<#list item.examples as example>
- ${example}
</#list>
</#list>

@ -0,0 +1,12 @@
version: "3.1"
rules:
<#list rules as item>
- rule: ${item.rule}
steps:
<#list item.steps as ss>
- intent: ${ss.intent}
- action: ${ss.action}
</#list>
</#list>

@ -1,13 +1,36 @@
package com.supervision.rasa;
import com.supervision.rasa.pojo.dto.QuestionAnswerDTO;
import com.supervision.rasa.pojo.dto.Text2vecDataVo;
import com.supervision.rasa.service.RasaCmdService;
import com.supervision.rasa.service.Text2vecService;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import java.io.File;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
@SpringBootTest
class VirtualPatientRasaApplicationTests {
@Autowired
private RasaCmdService rasaCmdService;
@Autowired
private Text2vecService text2vecService;
@Test
void contextLoads() {
/*Map<String, QuestionAnswerDTO> questionAnswerDTOMap = rasaCmdService.generateRasaYml("F:\\tmp\\rasa");
System.out.println(questionAnswerDTOMap);*/
Map<String, QuestionAnswerDTO> questionAnswerDTOMap = rasaCmdService.generateRasaYml(String.join(File.separator, "F:\\tmp\\rasa"));
List<Text2vecDataVo> text2vecDataVoList = questionAnswerDTOMap.entrySet().stream()
.flatMap(entry -> entry.getValue().getQuestionList().stream()
.map(question -> new Text2vecDataVo(entry.getKey(), question))).collect(Collectors.toList());
text2vecService.updateDataset(text2vecDataVoList);
}
}

Loading…
Cancel
Save