diff --git a/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/dto/RasaRunParam.java b/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/dto/RasaRunParam.java new file mode 100644 index 00000000..0c77ad9a --- /dev/null +++ b/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/dto/RasaRunParam.java @@ -0,0 +1,72 @@ +package com.supervision.rasa.pojo.dto; + +import cn.hutool.core.collection.CollUtil; +import lombok.Data; + +import java.util.List; + +/** + * rasa 启动参数 + */ +@Data +public class RasaRunParam { + + /** + * bash 路径 + */ + private String bashPath; + + /** + * rasa 启动脚本路径 + */ + private String shellPath; + + /** + * rasa 模型路径 + */ + private String rasaModelPath; + + /** + * rasa 配置文件位置 + */ + private String endpointsPath; + + /** + * rasa 服务端口 + */ + private String port; + + + /** + * 通过list构建RasaRunParam对象 + * @param args bashPath = args[0], shellPath = args[1], rasaModelPath = args[2], endpointsPath = args[3], port = args[4] + * @return + */ + public static RasaRunParam build(List args) { + RasaRunParam rasaRunParam = new RasaRunParam(); + if (CollUtil.isEmpty(args)){ + return rasaRunParam; + } + + rasaRunParam.setBashPath(args.get(0)); + if (args.size()>1){ + rasaRunParam.setShellPath(args.get(1)); + } + if (args.size()>2){ + rasaRunParam.setRasaModelPath(args.get(2)); + } + if (args.size()>3){ + rasaRunParam.setEndpointsPath(args.get(3)); + } + if (args.size()>4){ + rasaRunParam.setPort(args.get(4)); + } + return rasaRunParam; + + } + + + public List toList(){ + return CollUtil.newArrayList(bashPath,shellPath,rasaModelPath,endpointsPath,port); + } +} diff --git a/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/dto/RasaTrainParam.java b/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/dto/RasaTrainParam.java new file mode 100644 index 00000000..289606f8 --- /dev/null +++ b/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/dto/RasaTrainParam.java @@ -0,0 +1,86 @@ +package com.supervision.rasa.pojo.dto; + +import cn.hutool.core.collection.CollUtil; +import lombok.Data; + +import java.util.List; + +/** + * rasa 训练参数 + */ +@Data +public class RasaTrainParam { + + /** + * bash路径 + */ + private String bashPath; + + /** + * rasa 训练脚本路径 + */ + private String shellPath; + + /** + * rasa 训练配置文件路径 + */ + private String configPath; + + /** + * rasa 训练数据路径 (rules.yml nlu.yml) + */ + private String localDataPath; + + /** + * rasa domain.yml 存放路径 + */ + private String domainPath; + + /** + * rasa 训练出的模型存放路径 + */ + private String localModelsPath; + + /** + * 训练出的模型名称 + */ + private String fixedModelName; + + /** + * 通过list构建RasaTrainParam对象 + * @param args 参数列表 bashPath = args[0] shellPath = args[1] configPath = args[2] + * localDataPath = args[3] domainPath = args[4] localModelsPath = args[5] fixedModelName = args[6] + * @return + */ + public static RasaTrainParam build(List args) { + RasaTrainParam rasaTrainParam = new RasaTrainParam(); + if (CollUtil.isEmpty(args)){ + return rasaTrainParam; + } + rasaTrainParam.bashPath = args.get(0); + if (args.size() > 1){ + rasaTrainParam.shellPath = args.get(1); + } + if (args.size() > 2){ + rasaTrainParam.configPath = args.get(2); + } + if (args.size() > 3){ + rasaTrainParam.localDataPath = args.get(3); + } + if (args.size() > 4){ + rasaTrainParam.domainPath = args.get(4); + } + if (args.size() > 5){ + rasaTrainParam.localModelsPath = args.get(5); + } + if (args.size() > 6){ + rasaTrainParam.fixedModelName = args.get(6); + } + return rasaTrainParam; + } + + public List toList() { + return CollUtil.newArrayList(bashPath, shellPath, configPath, localDataPath, domainPath, localModelsPath, fixedModelName); + } + +} diff --git a/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/RasaModelManager.java b/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/RasaModelManager.java index 188c0ba5..6111da14 100644 --- a/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/RasaModelManager.java +++ b/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/RasaModelManager.java @@ -4,6 +4,7 @@ import cn.hutool.core.collection.CollectionUtil; import cn.hutool.core.util.StrUtil; import com.supervision.model.RasaModelInfo; import com.supervision.rasa.constant.RasaConstant; +import com.supervision.rasa.pojo.dto.RasaRunParam; import com.supervision.rasa.util.PortUtil; import com.supervision.service.RasaModeService; import lombok.RequiredArgsConstructor; @@ -42,14 +43,16 @@ public class RasaModelManager { for (RasaModelInfo rasaModelInfo : activeRasaList) { if (!PortUtil.portIsActive(rasaModelInfo.getPort())){ try { - List runCmd = rasaModelInfo.getRunCmd(); - runCmd.add(String.valueOf(rasaModelInfo.getPort())); - runCmd.set(1,rasaCmdService.getShellPath(RasaConstant.RUN_SHELL)); - List outMessageList = rasaCmdService.execCmd(rasaModelInfo.getRunCmd(), + RasaRunParam rasaRunParam = RasaRunParam.build(rasaModelInfo.getRunCmd()); + rasaRunParam.setPort(String.valueOf(rasaModelInfo.getPort())); + rasaRunParam.setShellPath(rasaCmdService.getShellPath(RasaConstant.RUN_SHELL)); + List outMessageList = rasaCmdService.execCmd(rasaRunParam.toList(), s -> StrUtil.isNotBlank(s) && s.contains(RasaConstant.RUN_SUCCESS_MESSAGE), 300); - rasaModelInfo.setTrainLog(String.join("\r\n",outMessageList)); + rasaModelInfo.setRunLog(String.join("\r\n",outMessageList)); + rasaModelInfo.setRunCmd(rasaRunParam.toList()); rasaModeService.updateById(rasaModelInfo); + if (!runIsSuccess(outMessageList)){ log.info("wakeUpInterruptServer: restart server port for {} failed,details info : {}",rasaModelInfo.getPort(),String.join("\r\n",outMessageList)); }