From 22b1b49dbd7c025cf47156c58554e87891501006 Mon Sep 17 00:00:00 2001 From: xueqingkun Date: Mon, 15 Jan 2024 17:47:29 +0800 Subject: [PATCH] =?UTF-8?q?manage:=20=E6=B7=BB=E5=8A=A0rasa=E9=83=A8?= =?UTF-8?q?=E7=BD=B2=E6=8E=A5=E5=8F=A3(=E9=9B=86=E6=88=90=E7=94=9F?= =?UTF-8?q?=E6=88=90yml=E6=96=87=E4=BB=B6=E3=80=81=E8=AE=AD=E7=BB=83?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E3=80=81=E8=BF=90=E8=A1=8C=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- virtual-patient-rasa/Dockerfile | 2 +- virtual-patient-rasa/docker_1_1_0/Dockerfile | 2 +- virtual-patient-rasa/pom.xml | 5 + .../rasa/VirtualPatientRasaApplication.java | 12 +- .../rasa/controller/RasaCmdController.java | 11 +- .../rasa/pojo/dto/DomainYmlTemplate.java | 28 ++ .../rasa/pojo/dto/NluYmlTemplate.java | 28 ++ .../rasa/pojo/dto/QuestionAnswerDTO.java | 20 ++ .../rasa/pojo/dto/RuleYmlTemplate.java | 44 +++ .../rasa/pojo/dto/Text2vecDataVo.java | 8 + .../rasa/service/RasaCmdService.java | 8 + .../rasa/service/RasaModelManager.java | 67 ++--- .../rasa/service/Text2vecService.java | 6 + .../rasa/service/Text2vecServiceImpl.java | 29 +- .../rasa/service/impl/RasaCmdServiceImpl.java | 276 +++++++++++++++--- .../src/main/resources/templates/config.ftl | 28 ++ .../src/main/resources/templates/domain.ftl | 23 ++ .../src/main/resources/templates/nlu.ftl | 10 + .../src/main/resources/templates/rules.ftl | 12 + .../VirtualPatientRasaApplicationTests.java | 23 ++ 20 files changed, 564 insertions(+), 78 deletions(-) create mode 100644 virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/dto/DomainYmlTemplate.java create mode 100644 virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/dto/NluYmlTemplate.java create mode 100644 virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/dto/QuestionAnswerDTO.java create mode 100644 virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/dto/RuleYmlTemplate.java create mode 100644 virtual-patient-rasa/src/main/resources/templates/config.ftl create mode 100644 virtual-patient-rasa/src/main/resources/templates/domain.ftl create mode 100644 virtual-patient-rasa/src/main/resources/templates/nlu.ftl create mode 100644 virtual-patient-rasa/src/main/resources/templates/rules.ftl diff --git a/virtual-patient-rasa/Dockerfile b/virtual-patient-rasa/Dockerfile index df15199d..2ee034f0 100644 --- a/virtual-patient-rasa/Dockerfile +++ b/virtual-patient-rasa/Dockerfile @@ -11,7 +11,7 @@ WORKDIR /data/vp COPY target/virtual-patient-rasa-1.0-SNAPSHOT.jar /data/vp/virtual-patient-rasa-1.0-SNAPSHOT.jar # 复制rasa配置文件到 rasa目录下 COPY docs/rasa /rasa -COPY docs/1 /data/vp/rasa/models/1 +COPY docs/1 /data/vp/rasa/models/ # 暴漏服务端口 EXPOSE 8890 diff --git a/virtual-patient-rasa/docker_1_1_0/Dockerfile b/virtual-patient-rasa/docker_1_1_0/Dockerfile index 512049a0..1c2e9d78 100644 --- a/virtual-patient-rasa/docker_1_1_0/Dockerfile +++ b/virtual-patient-rasa/docker_1_1_0/Dockerfile @@ -3,7 +3,7 @@ FROM rasa_dev:1.0.0 COPY ./bert_chinese /usr/local/text2vec/bert_chinese COPY ./app.py /usr/local/text2vec/ -COPY ./question.json /usr/local/text2vec/ +#COPY ./question.json /usr/local/text2vec/ RUN source /root/anaconda3/etc/profile.d/conda.sh && \ conda create --name text2vec_env python=3.9 -y && \ diff --git a/virtual-patient-rasa/pom.xml b/virtual-patient-rasa/pom.xml index 8b290d6d..f934f761 100644 --- a/virtual-patient-rasa/pom.xml +++ b/virtual-patient-rasa/pom.xml @@ -58,6 +58,11 @@ cn.hutool hutool-all + + + org.freemarker + freemarker + diff --git a/virtual-patient-rasa/src/main/java/com/supervision/rasa/VirtualPatientRasaApplication.java b/virtual-patient-rasa/src/main/java/com/supervision/rasa/VirtualPatientRasaApplication.java index 1eb8428a..041efa1d 100644 --- a/virtual-patient-rasa/src/main/java/com/supervision/rasa/VirtualPatientRasaApplication.java +++ b/virtual-patient-rasa/src/main/java/com/supervision/rasa/VirtualPatientRasaApplication.java @@ -2,6 +2,7 @@ package com.supervision.rasa; import com.supervision.config.WebConfig; import com.supervision.rasa.service.RasaModelManager; +import com.supervision.rasa.service.Text2vecService; import org.mybatis.spring.annotation.MapperScan; import org.springframework.boot.SpringApplication; import org.springframework.boot.autoconfigure.SpringBootApplication; @@ -20,8 +21,17 @@ public class VirtualPatientRasaApplication { public static void main(String[] args) { ConfigurableApplicationContext context = SpringApplication.run(VirtualPatientRasaApplication.class, args); + // 启动rasa服务 RasaModelManager rasaModelManager = context.getBean(RasaModelManager.class); - rasaModelManager.wakeUpInterruptServerScheduled(); + try { + rasaModelManager.wakeUpInterruptServerScheduled(); + } catch (Exception e) { + throw new RuntimeException(e); + } + + // 初始化文本匹配数据 + Text2vecService text2vecService = context.getBean(Text2vecService.class); + text2vecService.initText2vecDataset(); } } diff --git a/virtual-patient-rasa/src/main/java/com/supervision/rasa/controller/RasaCmdController.java b/virtual-patient-rasa/src/main/java/com/supervision/rasa/controller/RasaCmdController.java index d1708be7..fda52353 100644 --- a/virtual-patient-rasa/src/main/java/com/supervision/rasa/controller/RasaCmdController.java +++ b/virtual-patient-rasa/src/main/java/com/supervision/rasa/controller/RasaCmdController.java @@ -13,7 +13,7 @@ import org.springframework.web.bind.annotation.*; import java.io.*; import java.util.concurrent.*; -@Api(tags = "rasa文件保存") +@Api(tags = "rasa管理") @RestController @RequestMapping("rasaCmd") @RequiredArgsConstructor @@ -25,7 +25,6 @@ public class RasaCmdController { @PostMapping("/trainExec") public String trainExec(@RequestBody RasaCmdArgumentVo argument) throws IOException, ExecutionException, InterruptedException, TimeoutException { - argument.setModelId("1"); return rasaCmdService.trainExec(argument); } @@ -34,7 +33,6 @@ public class RasaCmdController { @PostMapping("/runExec") public String runExec(@RequestBody RasaCmdArgumentVo argument) throws ExecutionException, InterruptedException, TimeoutException { - argument.setModelId("1"); String outString = rasaCmdService.runExec(argument); if (StrUtil.isEmptyIfStr(outString) || !outString.contains(RasaConstant.RUN_SUCCESS_MESSAGE)){ throw new BusinessException("任务执行异常。详细日志:"+outString); @@ -43,5 +41,12 @@ public class RasaCmdController { } + @ApiOperation("部署rasa") + @PostMapping("/deploy") + public boolean deployRasa() throws Exception { + return rasaCmdService.deployRasa(); + + } + } diff --git a/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/dto/DomainYmlTemplate.java b/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/dto/DomainYmlTemplate.java new file mode 100644 index 00000000..e4a5394e --- /dev/null +++ b/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/dto/DomainYmlTemplate.java @@ -0,0 +1,28 @@ +package com.supervision.rasa.pojo.dto; + +import lombok.Data; + +import java.util.LinkedHashMap; +import java.util.List; + +@Data +public class DomainYmlTemplate { + + private List intents; + + private LinkedHashMap> responses; + + private List actions; + + private SessionConfig session_config = new SessionConfig(); + + + @Data + public static class SessionConfig{ + + private final int session_expiration_time = 60; + + private final Boolean carry_over_slots_to_new_session = true; + } + +} diff --git a/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/dto/NluYmlTemplate.java b/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/dto/NluYmlTemplate.java new file mode 100644 index 00000000..d7ef1e5a --- /dev/null +++ b/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/dto/NluYmlTemplate.java @@ -0,0 +1,28 @@ +package com.supervision.rasa.pojo.dto; + +import lombok.Data; + +import java.util.List; + +@Data +public class NluYmlTemplate { + + private List nlu; + + @Data + public static class Nlu{ + private String intent; + + private List examples; + + public Nlu(String intent, List examples) { + this.intent = intent; + this.examples = examples; + } + + public Nlu() { + } + } + + +} diff --git a/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/dto/QuestionAnswerDTO.java b/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/dto/QuestionAnswerDTO.java new file mode 100644 index 00000000..375d1390 --- /dev/null +++ b/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/dto/QuestionAnswerDTO.java @@ -0,0 +1,20 @@ +package com.supervision.rasa.pojo.dto; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.util.List; + +@Data +@AllArgsConstructor +@NoArgsConstructor +public class QuestionAnswerDTO { + + private List questionList; + + private List answerList; + + private String desc; + +} diff --git a/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/dto/RuleYmlTemplate.java b/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/dto/RuleYmlTemplate.java new file mode 100644 index 00000000..43e2284f --- /dev/null +++ b/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/dto/RuleYmlTemplate.java @@ -0,0 +1,44 @@ +package com.supervision.rasa.pojo.dto; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.util.ArrayList; +import java.util.List; + +@Data +public class RuleYmlTemplate { + + private List rules; + + + @Data + public static class Rule { + private String rule; + + private List steps; + + public Rule() { + } + + public Rule(String rule, String intent, String action) { + this.rule = rule; + steps = new ArrayList<>(); + Step step = new Step(intent, action); + steps.add(step); + } + } + + @Data + @AllArgsConstructor + @NoArgsConstructor + public static class Step { + private String intent; + + private String action; + + + } + +} diff --git a/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/dto/Text2vecDataVo.java b/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/dto/Text2vecDataVo.java index 15b748da..6b50ba9a 100644 --- a/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/dto/Text2vecDataVo.java +++ b/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/dto/Text2vecDataVo.java @@ -11,4 +11,12 @@ public class Text2vecDataVo { @ApiModelProperty("问题") private String question; + + public Text2vecDataVo(String id, String question) { + this.id = id; + this.question = question; + } + + public Text2vecDataVo() { + } } diff --git a/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/RasaCmdService.java b/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/RasaCmdService.java index ede109b1..92e1256e 100644 --- a/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/RasaCmdService.java +++ b/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/RasaCmdService.java @@ -1,9 +1,11 @@ package com.supervision.rasa.service; +import com.supervision.rasa.pojo.dto.QuestionAnswerDTO; import com.supervision.rasa.pojo.vo.RasaCmdArgumentVo; import java.io.IOException; import java.util.List; +import java.util.Map; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeoutException; import java.util.function.Predicate; @@ -19,4 +21,10 @@ public interface RasaCmdService { String getShellPath(String shell); + boolean deployRasa() throws Exception; + + Map generateRasaYml(String path); + + public Map getIntentCodeAndIdMap(); + } diff --git a/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/RasaModelManager.java b/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/RasaModelManager.java index a05c34da..ba3fbacb 100644 --- a/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/RasaModelManager.java +++ b/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/RasaModelManager.java @@ -51,39 +51,40 @@ public class RasaModelManager { // 2. 重新启动中断的服务 for (RasaModelInfo rasaModelInfo : activeRasaList) { - if (!PortUtil.portIsActive(rasaModelInfo.getPort())){ - try { - RasaRunParam rasaRunParam = RasaRunParam.build(rasaModelInfo.getRunCmd()); - rasaRunParam.setPort(String.valueOf(rasaModelInfo.getPort())); - rasaRunParam.setShellPath(rasaCmdService.getShellPath(RasaConstant.RUN_SHELL)); - String rasaModelPath = rasaRunParam.getRasaModelPath(); - if (StrUtil.isEmpty(rasaModelPath) || !FileUtil.exist(rasaModelPath)){ - log.info("wakeUpInterruptServer: rasa model path {} not exist,attempt find last ...",rasaModelPath); - String modeParentPath = replaceDuplicateSeparator(String.join(File.separator, modelsPath, rasaModelInfo.getModelId())); - String fixedModePath = listLastFilePath(modeParentPath, f -> f.getName().matches("-?\\d+(\\.\\d+)?.tar.gz")); - Assert.notEmpty(fixedModePath,"wakeUpInterruptService: no rasa model in path {} ",modeParentPath); - rasaRunParam.setRasaModelPath(fixedModePath); - } - log.info("wakeUpInterruptServer : use fixedModePath :{}",rasaRunParam.getRasaModelPath()); - List outMessageList = rasaCmdService.execCmd(rasaRunParam.toList(), - s -> StrUtil.isNotBlank(s) && s.contains(RasaConstant.RUN_SUCCESS_MESSAGE), 300); - - rasaModelInfo.setRunLog(String.join("\r\n",outMessageList)); - rasaModelInfo.setRunCmd(rasaRunParam.toList()); - rasaModeService.updateById(rasaModelInfo); - - if (!runIsSuccess(outMessageList)){ - log.info("wakeUpInterruptServer: restart server port for {} failed,details info : {}",rasaModelInfo.getPort(),String.join("\r\n",outMessageList)); - } - } catch (InterruptedException | ExecutionException | TimeoutException e ) { - log.info("wakeUpInterruptServer: restart server port for {} failed",rasaModelInfo.getPort()); - throw new RuntimeException(e); - } - log.info("wakeUpInterruptServer: restart server port for {} success ",rasaModelInfo.getPort()); - }else { - log.info("wakeUpInterruptServer: port:{} is run..",rasaModelInfo.getPort()); - } - } + if (PortUtil.portIsActive(rasaModelInfo.getPort())) { + log.info("wakeUpInterruptServer: port:{} is run..", rasaModelInfo.getPort()); + continue; + } + + try { + RasaRunParam rasaRunParam = RasaRunParam.build(rasaModelInfo.getRunCmd()); + rasaRunParam.setPort(String.valueOf(rasaModelInfo.getPort())); + rasaRunParam.setShellPath(rasaCmdService.getShellPath(RasaConstant.RUN_SHELL)); + String rasaModelPath = rasaRunParam.getRasaModelPath(); + if (StrUtil.isEmpty(rasaModelPath) || !FileUtil.exist(rasaModelPath)) { + log.info("wakeUpInterruptServer: rasa model path {} not exist,attempt find last ...", rasaModelPath); + String modeParentPath = replaceDuplicateSeparator(String.join(File.separator, modelsPath)); + String fixedModePath = listLastFilePath(modeParentPath, f -> f.getName().matches("-?\\d+(\\.\\d+)?.tar.gz")); + Assert.notEmpty(fixedModePath, "wakeUpInterruptService: no rasa model in path {} ", modeParentPath); + rasaRunParam.setRasaModelPath(fixedModePath); + } + log.info("wakeUpInterruptServer : use fixedModePath :{}", rasaRunParam.getRasaModelPath()); + List outMessageList = rasaCmdService.execCmd(rasaRunParam.toList(), + s -> StrUtil.isNotBlank(s) && s.contains(RasaConstant.RUN_SUCCESS_MESSAGE), 300); + + rasaModelInfo.setRunLog(String.join("\r\n", outMessageList)); + rasaModelInfo.setRunCmd(rasaRunParam.toList()); + rasaModeService.updateById(rasaModelInfo); + + if (!runIsSuccess(outMessageList)) { + log.info("wakeUpInterruptServer: restart server port for {} failed,details info : {}", rasaModelInfo.getPort(), String.join("\r\n", outMessageList)); + } + } catch (InterruptedException | ExecutionException | TimeoutException e) { + log.info("wakeUpInterruptServer: restart server port for {} failed", rasaModelInfo.getPort()); + throw new RuntimeException(e); + } + log.info("wakeUpInterruptServer: restart server port for {} success ", rasaModelInfo.getPort()); + } } diff --git a/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/Text2vecService.java b/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/Text2vecService.java index 9d064b67..c7b8677e 100644 --- a/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/Text2vecService.java +++ b/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/Text2vecService.java @@ -21,4 +21,10 @@ public interface Text2vecService { * @return */ List matches(Text2vecMatchesReq text2vecMatchesReq); + + + /** + * 初始化语料库 + */ + void initText2vecDataset(); } diff --git a/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/Text2vecServiceImpl.java b/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/Text2vecServiceImpl.java index 35c67986..4058c7dd 100644 --- a/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/Text2vecServiceImpl.java +++ b/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/Text2vecServiceImpl.java @@ -1,5 +1,6 @@ package com.supervision.rasa.service; +import cn.hutool.core.io.FileUtil; import cn.hutool.core.lang.Assert; import cn.hutool.core.util.StrUtil; import cn.hutool.http.HttpUtil; @@ -7,21 +8,25 @@ import cn.hutool.json.JSON; import cn.hutool.json.JSONArray; import cn.hutool.json.JSONObject; import cn.hutool.json.JSONUtil; +import com.supervision.rasa.pojo.dto.QuestionAnswerDTO; import com.supervision.rasa.pojo.dto.Text2vecDataVo; import com.supervision.rasa.pojo.dto.Text2vecMatchesReq; import com.supervision.rasa.pojo.dto.Text2vecMatchesRes; import lombok.RequiredArgsConstructor; import lombok.extern.log4j.Log4j; import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; +import org.springframework.context.annotation.Lazy; import org.springframework.stereotype.Service; import java.util.List; +import java.util.Map; import java.util.Objects; +import java.util.stream.Collectors; @Slf4j @Service -@RequiredArgsConstructor public class Text2vecServiceImpl implements Text2vecService { @Value("${text2vec.service.domain}") @@ -30,6 +35,13 @@ public class Text2vecServiceImpl implements Text2vecService { private final String UPDATE_DATASET_PATH = "update_dataset"; private final String MATCHES_PATH = "matches"; private final String GET_ALL_SIMILARITIES_PATH = "get_all_similarities"; + + private final RasaCmdService rasaCmdService; + + public Text2vecServiceImpl(@Autowired @Lazy RasaCmdService rasaCmdService) { + this.rasaCmdService = rasaCmdService; + } + @Override public boolean updateDataset(List text2vecDataVoList) { @@ -65,4 +77,19 @@ public class Text2vecServiceImpl implements Text2vecService { return JSONUtil.toList(JSONUtil.parseArray(jsonBody.get("results")), Text2vecMatchesRes.class); } + + public void initText2vecDataset(){ + log.info("initText2vecDataset ..."); + if (FileUtil.exist("/usr/local/text2vec/question.json")){ + log.info("question.json文件已经存在,不进行text2vec数据初始化操作...."); + return; + } + Map intentCodeAndIdMap = rasaCmdService.getIntentCodeAndIdMap(); + // 更新text2vec数据信息 + List text2vecDataVoList = intentCodeAndIdMap.entrySet().stream() + .flatMap(entry -> entry.getValue().getQuestionList().stream() + .map(question -> new Text2vecDataVo(entry.getKey(), question))).collect(Collectors.toList()); + this.updateDataset(text2vecDataVoList); + + } } diff --git a/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/impl/RasaCmdServiceImpl.java b/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/impl/RasaCmdServiceImpl.java index 9153c549..807900c0 100644 --- a/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/impl/RasaCmdServiceImpl.java +++ b/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/impl/RasaCmdServiceImpl.java @@ -1,17 +1,30 @@ package com.supervision.rasa.service.impl; +import cn.hutool.core.collection.CollUtil; import cn.hutool.core.collection.CollectionUtil; import cn.hutool.core.collection.ListUtil; import cn.hutool.core.io.FileUtil; +import cn.hutool.core.lang.Pair; import cn.hutool.core.util.StrUtil; +import cn.hutool.json.JSONUtil; import com.supervision.exception.BusinessException; +import com.supervision.model.AskTemplateQuestionLibrary; +import com.supervision.model.ConfigAncillaryItem; +import com.supervision.model.ConfigPhysicalTool; import com.supervision.model.RasaModelInfo; import com.supervision.rasa.config.ThreadPoolExecutorConfig; import com.supervision.rasa.constant.RasaConstant; +import com.supervision.rasa.pojo.dto.*; import com.supervision.rasa.pojo.vo.RasaCmdArgumentVo; import com.supervision.rasa.service.RasaCmdService; +import com.supervision.rasa.service.Text2vecService; import com.supervision.rasa.util.PortUtil; +import com.supervision.service.AskTemplateQuestionLibraryService; +import com.supervision.service.ConfigAncillaryItemService; +import com.supervision.service.ConfigPhysicalToolService; import com.supervision.service.RasaModeService; +import freemarker.template.Configuration; +import freemarker.template.Template; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Value; @@ -23,6 +36,7 @@ import java.io.*; import java.util.*; import java.util.concurrent.*; import java.util.function.Predicate; +import java.util.stream.Collectors; @Service @Slf4j @@ -46,26 +60,36 @@ public class RasaCmdServiceImpl implements RasaCmdService { @Value("${rasa.shell-env:/bin/bash}") private String shellEnv; + @Value("${rasa.data-path:/home/rasa/model_resource/}") + private String rasaFilePath; + private final RasaModeService rasaModeService; + private final ConfigPhysicalToolService configPhysicalToolService; + + private final ConfigAncillaryItemService configAncillaryItemService; + private final AskTemplateQuestionLibraryService askTemplateQuestionLibraryService; + + private final Text2vecService text2vecService; private final ConcurrentHashMap shellPathCache = new ConcurrentHashMap<>(); @Override @Transactional public String trainExec(RasaCmdArgumentVo argument) throws ExecutionException, InterruptedException, TimeoutException { + log.info("trainExec:start train rasa model ....argument:{}", JSONUtil.toJsonStr(argument)); argument.setFixedModelNameIfAbsent(); // /rasa/v3_jiazhuangxian/domain.yml domain的路径,应该是从zip文件中加压出来的文件的路径后面拼上/domain.yml - String domain = replaceDuplicateSeparator(String.join(File.separator,dataPath,argument.getModelId(),"domain.yml")); + String domain = replaceDuplicateSeparator(String.join(File.separator,dataPath,"domain.yml")); // /rasa/v3_jiazhuangxian/ yml文件的路径,应该是从zip文件中加压出来的文件的路径,在配置文件中配置 - String localDataPath = replaceDuplicateSeparator(String.join(File.separator,dataPath,argument.getModelId())); + String localDataPath = replaceDuplicateSeparator(String.join(File.separator,dataPath)); // /rasa/models 生成出来的模型的存放路径,也写在配置文件里面 - String localModelsPath = replaceDuplicateSeparator(String.join(File.separator,modelsPath,argument.getModelId())); + String localModelsPath = replaceDuplicateSeparator(String.join(File.separator,modelsPath)); List cmds = ListUtil.toList(shellEnv, getShellPath(RasaConstant.TRAIN_SHELL),config,localDataPath,domain,localModelsPath); @@ -84,8 +108,10 @@ public class RasaCmdServiceImpl implements RasaCmdService { cmds.set(1,null); rasaModelInfo.setTrainCmd(cmds); rasaModelInfo.setTrainLog(outMessageString); + rasaModelInfo.setModelId("1"); rasaModeService.saveOrUpdateByModelId(rasaModelInfo); + log.info("trainExec:end train rasa model ...."); return outMessageString; } @@ -94,7 +120,7 @@ public class RasaCmdServiceImpl implements RasaCmdService { @Override public String runExec(RasaCmdArgumentVo argument) throws ExecutionException, InterruptedException, TimeoutException { - + log.info("runExec:start runExec rasa model ....args:{}",JSONUtil.toJsonStr(argument)); // 1. 查找可用端口 int port = PortUtil.findUnusedPort(5050, 100000,rasaModeService.listActivePort()); log.info("runExec findUnusedPort is : {}",port); @@ -104,7 +130,7 @@ public class RasaCmdServiceImpl implements RasaCmdService { // aaa1111.tar.gz这个,前面的文件名应该是--fixed-model-name指定的,.tar.gz是文件后缀,代码拼接 String fixedModePath; - String modeParentPath = replaceDuplicateSeparator(String.join(File.separator, modelsPath, argument.getModelId())); + String modeParentPath = replaceDuplicateSeparator(String.join(File.separator, modelsPath)); if (StrUtil.isEmpty(argument.getFixedModelName())){ fixedModePath = listLastFilePath(modeParentPath, f -> f.getName().matches("-?\\d+(\\.\\d+)?.tar.gz")); }else { @@ -142,6 +168,7 @@ public class RasaCmdServiceImpl implements RasaCmdService { rasaModelInfo.setRunLog(outMessageString); rasaModeService.saveOrUpdateByModelId(rasaModelInfo); + log.info("runExec:runExec end ...."); return outMessageString; } @@ -176,56 +203,63 @@ public class RasaCmdServiceImpl implements RasaCmdService { + @Override + public boolean deployRasa() throws Exception { - private boolean trainIsSuccess(List messageList){ + // 1.生成rasa模型语料文件 + Map questionAnswerDTOMap = generateRasaYml(String.join(File.separator, rasaFilePath)); - return containKey(messageList,RasaConstant.TRAN_SUCCESS_MESSAGE); - } + // 2.训练模型 + trainExec(new RasaCmdArgumentVo()); + //3.运行模型 + RasaCmdArgumentVo rasaCmdArgumentVo = new RasaCmdArgumentVo(); + rasaCmdArgumentVo.setModelId("1"); + runExec(rasaCmdArgumentVo); - private boolean runIsSuccess(List messageList){ + // 更新text2vec数据信息 + List text2vecDataVoList = questionAnswerDTOMap.entrySet().stream() + .flatMap(entry -> entry.getValue().getQuestionList().stream() + .map(question -> new Text2vecDataVo(entry.getKey(), question))).collect(Collectors.toList()); + text2vecService.updateDataset(text2vecDataVoList); - return containKey(messageList,RasaConstant.RUN_SUCCESS_MESSAGE); + return true; } - private boolean containKey(List messageList,String keyWord){ + @Override + public Map generateRasaYml(String path) { - if (CollectionUtil.isEmpty(messageList)){ - return false; - } - if (StrUtil.isEmpty(keyWord)){ - return false; - } - return messageList.stream().anyMatch(s->StrUtil.isNotEmpty(s) && s.contains(keyWord)); - } + log.info("generateRasaYml:start generateRasaYml ...."); - private String replaceDuplicateSeparator(String path){ + // 默认问答MAP + List ruleList = new ArrayList<>(); - if (StrUtil.isEmpty(path)){ - return path; - } + Map ymalFileMap = new HashMap<>(); + Map intentCodeAndIdMap = getIntentCodeAndIdMap(); - return path.replace(File.separator + File.separator, File.separator); - } + // 开始生成各种yaml文件 + Pair nulFilePair = generateNlu(intentCodeAndIdMap); + ymalFileMap.put(nulFilePair.getKey(),nulFilePair.getValue()); + Pair domainFilePair = generateDomain(intentCodeAndIdMap, ruleList); + ymalFileMap.put(domainFilePair.getKey(),domainFilePair.getValue()); + Pair ruleFilePair = generateRule(ruleList); + ymalFileMap.put(ruleFilePair.getKey(),ruleFilePair.getValue()); + + // 把文件复制到指定位置 + for (Map.Entry fileEntry : ymalFileMap.entrySet()) { + try { + FileUtil.copy(fileEntry.getValue(), new File(StrUtil.join(File.separator, path,fileEntry.getKey())), true); + }finally { + FileUtil.del(fileEntry.getValue()); + } - private String listLastFilePath(String path, FileFilter filter){ - File file = listLastFile(path, filter); - if (null == file){ - return null; } - return file.getPath(); + log.info("generateRasaYml:end generateRasaYml ...."); + return intentCodeAndIdMap; } - private File listLastFile(String path,FileFilter filter){ - File file = new File(path); - File[] files = file.listFiles(filter); - if (null == files){ - return null; - } - return Arrays.stream(files).max(Comparator.comparing(File::getName)).orElse(null); - } public String getShellPath(String shell){ @@ -266,6 +300,172 @@ public class RasaCmdServiceImpl implements RasaCmdService { throw new RuntimeException(e); } } + private Pair generateNlu(Map intentCodeAndIdMap) { + // 首先生成根据意图查找到nlu文件 + List nluList = intentCodeAndIdMap.entrySet() + .stream().map(entry -> + new NluYmlTemplate.Nlu(entry.getKey(), entry.getValue().getQuestionList())) + .collect(Collectors.toList()); + + NluYmlTemplate nluYmlTemplate = new NluYmlTemplate(); + nluYmlTemplate.setNlu(nluList); + // 生成后生成yml文件 + return createYmlFile(NluYmlTemplate.class, "nlu.ftl", nluYmlTemplate, "nlu.yml"); + } + + public Map getIntentCodeAndIdMap() { + Map intentCodeAndIdMap = new HashMap<>(); + // 默认意图 + List askTemplateQuestionLibraryList = askTemplateQuestionLibraryService.lambdaQuery().list(); + // 生成默认意图的nlu + for (AskTemplateQuestionLibrary questionLibrary : askTemplateQuestionLibraryList) { + // 开始生成 + // 拼接格式:code_id(防止重复) + String intentCode = questionLibrary.getCode() + "_" + questionLibrary.getId(); + intentCodeAndIdMap.put(intentCode, new QuestionAnswerDTO(questionLibrary.getQuestion(), CollUtil.newArrayList( questionLibrary.getId()), questionLibrary.getDescription())); + } + + // 这里处理呼出的问题(code和问题不能为空) + List physicalToolList = configPhysicalToolService.lambdaQuery() + .isNotNull(ConfigPhysicalTool::getCode) + .isNotNull(ConfigPhysicalTool::getCallOutQuestion).list(); + + for (ConfigPhysicalTool tool : physicalToolList) { + // 把呼出的问题全部加进去 + String toolIntent = "tool_" + tool.getCode(); + // answer格式为:---tool---工具ID + intentCodeAndIdMap.put(toolIntent, + new QuestionAnswerDTO(tool.getCallOutQuestion(), + CollUtil.newArrayList("tool_" + tool.getId()), "tool-" + tool.getToolName())); + } + + // 生成呼出的辅助检查 + List ancillaryItemList = configAncillaryItemService.lambdaQuery() + .isNotNull(ConfigAncillaryItem::getCode) + .isNotNull(ConfigAncillaryItem::getCallOutQuestion).list(); + + for (ConfigAncillaryItem ancillary : ancillaryItemList) { + // 把辅助问诊的问题全部加进去 + String itemIntent = "ancillary_" + ancillary.getCode(); + // answer格式为:---ancillary---工具ID + intentCodeAndIdMap.put(itemIntent, + new QuestionAnswerDTO(ancillary.getCallOutQuestion(), + CollUtil.newArrayList("ancillary_" + ancillary.getId()), "呼出-ancillary-" + ancillary.getItemName())); + } + return intentCodeAndIdMap; + } + + + public Pair generateDomain(Map questionCodeAndIdMap, List ruleList) { + LinkedHashMap> responses = new LinkedHashMap<>(); + for (Map.Entry entry : questionCodeAndIdMap.entrySet()) { + String intentCode = entry.getKey(); + QuestionAnswerDTO value = entry.getValue(); + String utter = "utter_" + intentCode; + responses.put(utter, CollUtil.newArrayList(value.getAnswerList())); + ruleList.add(new RuleYmlTemplate.Rule(value.getDesc(), intentCode, utter)); + } + + + DomainYmlTemplate domainYmlTemplate = new DomainYmlTemplate(); + // 意图 + List intentList = new ArrayList<>(questionCodeAndIdMap.keySet()); + domainYmlTemplate.setIntents(intentList); + // 回复 + domainYmlTemplate.setResponses(responses); + // action + List actionList = new ArrayList<>(responses.keySet()); + domainYmlTemplate.setActions(actionList); + // 生成yml文件 + return createYmlFile(DomainYmlTemplate.class, "domain.ftl", domainYmlTemplate, "domain.yml"); + } + + /** + * 生成rule + */ + public Pair generateRule(List ruleList) { + RuleYmlTemplate ruleYmlTemplate = new RuleYmlTemplate(); + ruleYmlTemplate.setRules(ruleList); + // 生成yml文件 + return createYmlFile(RuleYmlTemplate.class, "rules.ftl", ruleYmlTemplate, "rules.yml"); + } + + private Pair createYmlFile(Class clazz, String ftlName, Object data, String ymlName) { + try { + // 这个版本和maven依赖的版本一致 + Configuration configuration = new Configuration(Configuration.VERSION_2_3_31); + configuration.setClassForTemplateLoading(clazz, "/templates"); // 模板文件的所在目录 + // 获取模板 + Template template = configuration.getTemplate(ftlName); + File tempFile = FileUtil.createTempFile(".yml", true); + // 创建输出文件 + try (PrintWriter out = new PrintWriter(tempFile);) { + // 填充并生成输出 + template.process(data, out); + } catch (Exception e) { + log.error("文件生成失败"); + } + return Pair.of(ymlName,tempFile); + } catch (Exception e) { + log.error("导出模板失败", e); + throw new RuntimeException("文件生成失败", e); + } + + } + + + + private boolean trainIsSuccess(List messageList){ + + return containKey(messageList,RasaConstant.TRAN_SUCCESS_MESSAGE); + } + + + private boolean runIsSuccess(List messageList){ + + return containKey(messageList,RasaConstant.RUN_SUCCESS_MESSAGE); + } + + private boolean containKey(List messageList,String keyWord){ + + if (CollectionUtil.isEmpty(messageList)){ + return false; + } + if (StrUtil.isEmpty(keyWord)){ + return false; + } + return messageList.stream().anyMatch(s->StrUtil.isNotEmpty(s) && s.contains(keyWord)); + } + + private String replaceDuplicateSeparator(String path){ + + if (StrUtil.isEmpty(path)){ + return path; + } + + return path.replace(File.separator + File.separator, File.separator); + } + + + + private String listLastFilePath(String path, FileFilter filter){ + File file = listLastFile(path, filter); + if (null == file){ + return null; + } + return file.getPath(); + } + private File listLastFile(String path,FileFilter filter){ + File file = new File(path); + File[] files = file.listFiles(filter); + if (null == files){ + return null; + } + + return Arrays.stream(files).max(Comparator.comparing(File::getName)).orElse(null); + } + + } diff --git a/virtual-patient-rasa/src/main/resources/templates/config.ftl b/virtual-patient-rasa/src/main/resources/templates/config.ftl new file mode 100644 index 00000000..d87a0641 --- /dev/null +++ b/virtual-patient-rasa/src/main/resources/templates/config.ftl @@ -0,0 +1,28 @@ +recipe: default.v1 +language: zh + +pipeline: + - name: JiebaTokenizer + - name: LanguageModelFeaturizer + model_name: bert + model_weights: bert-base-chinese + - name: RegexFeaturizer + - name: DIETClassifier + epochs: 100 + learning_rate: 0.001 + tensorboard_log_directory: ./log + - name: ResponseSelector + epochs: 100 + learning_rate: 0.001 + - name: FallbackClassifier + threshold: 0.87 + ambiguity_threshold: 0.1 + - name: EntitySynonymMapper + +policies: + - name: MemoizationPolicy + - name: TEDPolicy + - name: RulePolicy + core_fallback_threshold: 0.87 + core_fallback_action_name: "action_default_fallback" + enable_fallback_prediction: True diff --git a/virtual-patient-rasa/src/main/resources/templates/domain.ftl b/virtual-patient-rasa/src/main/resources/templates/domain.ftl new file mode 100644 index 00000000..8ee679bb --- /dev/null +++ b/virtual-patient-rasa/src/main/resources/templates/domain.ftl @@ -0,0 +1,23 @@ +version: "3.1" + +intents: +<#list intents as intent> + - ${intent} + + +responses: +<#list responses?keys as response> + ${response}: + <#list responses[response] as item> + - text: "${item}" + + + +actions: +<#list actions as action> + - ${action} + + +session_config: + session_expiration_time: 60 + carry_over_slots_to_new_session: true diff --git a/virtual-patient-rasa/src/main/resources/templates/nlu.ftl b/virtual-patient-rasa/src/main/resources/templates/nlu.ftl new file mode 100644 index 00000000..70097b65 --- /dev/null +++ b/virtual-patient-rasa/src/main/resources/templates/nlu.ftl @@ -0,0 +1,10 @@ +version: "3.1" + +nlu: +<#list nlu as item> + - intent: ${item.intent} + examples: | + <#list item.examples as example> + - ${example} + + diff --git a/virtual-patient-rasa/src/main/resources/templates/rules.ftl b/virtual-patient-rasa/src/main/resources/templates/rules.ftl new file mode 100644 index 00000000..d2e2c0fd --- /dev/null +++ b/virtual-patient-rasa/src/main/resources/templates/rules.ftl @@ -0,0 +1,12 @@ +version: "3.1" + +rules: + +<#list rules as item> + - rule: ${item.rule} + steps: + <#list item.steps as ss> + - intent: ${ss.intent} + - action: ${ss.action} + + \ No newline at end of file diff --git a/virtual-patient-rasa/src/test/java/com/supervision/rasa/VirtualPatientRasaApplicationTests.java b/virtual-patient-rasa/src/test/java/com/supervision/rasa/VirtualPatientRasaApplicationTests.java index a6b4f6f6..36720faa 100644 --- a/virtual-patient-rasa/src/test/java/com/supervision/rasa/VirtualPatientRasaApplicationTests.java +++ b/virtual-patient-rasa/src/test/java/com/supervision/rasa/VirtualPatientRasaApplicationTests.java @@ -1,13 +1,36 @@ package com.supervision.rasa; +import com.supervision.rasa.pojo.dto.QuestionAnswerDTO; +import com.supervision.rasa.pojo.dto.Text2vecDataVo; +import com.supervision.rasa.service.RasaCmdService; +import com.supervision.rasa.service.Text2vecService; import org.junit.jupiter.api.Test; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.context.SpringBootTest; +import java.io.File; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + @SpringBootTest class VirtualPatientRasaApplicationTests { + @Autowired + private RasaCmdService rasaCmdService; + + @Autowired + private Text2vecService text2vecService; @Test void contextLoads() { + /*Map questionAnswerDTOMap = rasaCmdService.generateRasaYml("F:\\tmp\\rasa"); + System.out.println(questionAnswerDTOMap);*/ + + Map questionAnswerDTOMap = rasaCmdService.generateRasaYml(String.join(File.separator, "F:\\tmp\\rasa")); + List text2vecDataVoList = questionAnswerDTOMap.entrySet().stream() + .flatMap(entry -> entry.getValue().getQuestionList().stream() + .map(question -> new Text2vecDataVo(entry.getKey(), question))).collect(Collectors.toList()); + text2vecService.updateDataset(text2vecDataVoList); } }