diff --git a/virtual-patient-model/src/main/java/com/supervision/service/RasaModeService.java b/virtual-patient-model/src/main/java/com/supervision/service/RasaModeService.java index 71dda8a2..fee7395d 100644 --- a/virtual-patient-model/src/main/java/com/supervision/service/RasaModeService.java +++ b/virtual-patient-model/src/main/java/com/supervision/service/RasaModeService.java @@ -2,8 +2,12 @@ package com.supervision.service; import com.baomidou.mybatisplus.extension.service.IService; import com.supervision.model.RasaModelInfo; +import sun.reflect.generics.tree.VoidDescriptor; public interface RasaModeService extends IService { RasaModelInfo queryByModelId(String modelId); + + + RasaModelInfo saveOrUpdateByModelId(RasaModelInfo rasaModelInfo); } diff --git a/virtual-patient-model/src/main/java/com/supervision/service/impl/RasaModeServiceImpl.java b/virtual-patient-model/src/main/java/com/supervision/service/impl/RasaModeServiceImpl.java index 20d2e70a..cf94a66b 100644 --- a/virtual-patient-model/src/main/java/com/supervision/service/impl/RasaModeServiceImpl.java +++ b/virtual-patient-model/src/main/java/com/supervision/service/impl/RasaModeServiceImpl.java @@ -1,7 +1,9 @@ package com.supervision.service.impl; +import cn.hutool.core.util.StrUtil; import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; +import com.supervision.exception.BusinessException; import com.supervision.mapper.RasaModelInfoMapper; import com.supervision.model.RasaModelInfo; import com.supervision.service.RasaModeService; @@ -18,4 +20,19 @@ public class RasaModeServiceImpl extends ServiceImpl talkRasa(String question, String sessionId){ + @PostMapping("talkRasa") + public List talkRasa(@RequestBody RasaTalkVo rasaTalkVo){ - return null; + return rasaTalkService.talkRasa(rasaTalkVo); } } diff --git a/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/vo/RasaArgument.java b/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/vo/RasaArgumentVo.java similarity index 92% rename from virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/vo/RasaArgument.java rename to virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/vo/RasaArgumentVo.java index 504811e2..f45f8827 100644 --- a/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/vo/RasaArgument.java +++ b/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/vo/RasaArgumentVo.java @@ -3,7 +3,7 @@ package com.supervision.rasa.pojo.vo; import lombok.Data; @Data -public class RasaArgument { +public class RasaArgumentVo { private String config; private String data; diff --git a/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/vo/RasaTalkVo.java b/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/vo/RasaTalkVo.java new file mode 100644 index 00000000..0b20ebdc --- /dev/null +++ b/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/vo/RasaTalkVo.java @@ -0,0 +1,22 @@ +package com.supervision.rasa.pojo.vo; + +import lombok.Data; + +@Data +public class RasaTalkVo { + + /** + * 问题 + */ + private String question; + /** + * 会话标识 + */ + private String sessionId; + + /** + * 模型id + */ + private String modelId; + +} diff --git a/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/RasaCmdService.java b/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/RasaCmdService.java index 908132fa..fca00ca2 100644 --- a/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/RasaCmdService.java +++ b/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/RasaCmdService.java @@ -1,6 +1,6 @@ package com.supervision.rasa.service; -import com.supervision.rasa.pojo.vo.RasaArgument; +import com.supervision.rasa.pojo.vo.RasaArgumentVo; import java.io.IOException; import java.util.concurrent.ExecutionException; @@ -8,8 +8,10 @@ import java.util.concurrent.TimeoutException; public interface RasaCmdService { - String trainExec(RasaArgument argument) throws IOException, ExecutionException, InterruptedException, TimeoutException; + String trainExec(RasaArgumentVo argument) throws IOException, ExecutionException, InterruptedException, TimeoutException; - String runExec( RasaArgument argument) throws ExecutionException, InterruptedException, TimeoutException; + String runExec( RasaArgumentVo argument) throws ExecutionException, InterruptedException, TimeoutException; + + void test(); } diff --git a/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/RasaTalkService.java b/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/RasaTalkService.java index 2630f7fa..7386e8e2 100644 --- a/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/RasaTalkService.java +++ b/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/RasaTalkService.java @@ -1,9 +1,11 @@ package com.supervision.rasa.service; +import com.supervision.rasa.pojo.vo.RasaTalkVo; + import java.util.List; public interface RasaTalkService { - List talkRasa(String question, String sessionId) ; + List talkRasa(RasaTalkVo rasaTalkVo) ; } diff --git a/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/impl/RasaCmdServiceImpl.java b/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/impl/RasaCmdServiceImpl.java index 0b9ccb45..a202b365 100644 --- a/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/impl/RasaCmdServiceImpl.java +++ b/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/impl/RasaCmdServiceImpl.java @@ -3,14 +3,17 @@ package com.supervision.rasa.service.impl; import cn.hutool.core.collection.CollectionUtil; import cn.hutool.core.collection.ListUtil; import cn.hutool.core.util.StrUtil; +import com.supervision.model.RasaModelInfo; import com.supervision.rasa.config.ThreadPoolExecutorConfig; -import com.supervision.rasa.pojo.vo.RasaArgument; +import com.supervision.rasa.pojo.vo.RasaArgumentVo; import com.supervision.rasa.service.RasaCmdService; +import com.supervision.rasa.util.PortUtil; import com.supervision.service.RasaModeService; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; import java.io.*; import java.util.ArrayList; @@ -51,8 +54,12 @@ public class RasaCmdServiceImpl implements RasaCmdService { private final RasaModeService rasaModeService; + private final String TRAN_SUCCESS_MESSAGE = "Your Rasa model is trained and saved at"; + private final String RUN_SUCCESS_MESSAGE = "Rasa server is up and running"; + @Override - public String trainExec(RasaArgument argument) throws ExecutionException, InterruptedException, TimeoutException { + @Transactional + public String trainExec(RasaArgumentVo argument) throws ExecutionException, InterruptedException, TimeoutException { String domain = dataPath+"domain.yml"; List cmds = ListUtil.toList(shellEnv, trainShell,config,dataPath,domain,modelsPath); @@ -62,41 +69,60 @@ public class RasaCmdServiceImpl implements RasaCmdService { List outMessage = execCmd(cmds, s -> false, 300); + //保存 模型信息 + RasaModelInfo rasaModelInfo = new RasaModelInfo(); + rasaModelInfo.setModelId(argument.getModelId()); + rasaModelInfo.setTranStatus(trainIsSuccess(outMessage)?1:0); + rasaModelInfo.setServerStatus(-1); + rasaModeService.saveOrUpdateByModelId(rasaModelInfo); - boolean isSuccess = trainIsSuccess(outMessage); return String.join("\r\n",outMessage); } - private boolean trainIsSuccess(List messageList){ - - if (CollectionUtil.isEmpty(messageList)){ - return false; - } - String keyWord = "Your Rasa model is trained and saved at"; - if (StrUtil.isEmpty(keyWord)){ - return false; - } - return messageList.stream().anyMatch(s->StrUtil.isNotEmpty(s) && s.contains(keyWord)); - } - - private boolean containKey(List messageList,String key){ - return true; - } - - @Override - public String runExec(RasaArgument argument) throws ExecutionException, InterruptedException, TimeoutException { - + public String runExec(RasaArgumentVo argument) throws ExecutionException, InterruptedException, TimeoutException { + + + // 1. 查找模型信息是否为空 + RasaModelInfo dbRasaModelInfo = rasaModeService.queryByModelId(argument.getModelId()); + int unusedPort; + if (null == dbRasaModelInfo || null == dbRasaModelInfo.getPort()){ + unusedPort = PortUtil.findUnusedPort(5050, 100000); + log.info("runExec findUnusedPort is : {}",unusedPort); + argument.setPort(String.valueOf(unusedPort)); + }else { + argument.setPort(String.valueOf(dbRasaModelInfo.getPort())); + // todo:杀掉该端口对应的进程 + } + // 2. 运行模型 String mPath = modelsPath+argument.getFixedModelName()+".tar.gz"; List cmds = ListUtil.toList(shellEnv, runShell,mPath,endpoints,argument.getPort()); log.info("runExec cmd : {}",StrUtil.join(" ",cmds)); - return String.join("\r\n",execCmd(cmds,s-> StrUtil.isNotBlank(s)&& s.contains("Rasa server is up and running"),90)); + + List outMessageList = execCmd(cmds, s -> StrUtil.isNotBlank(s) && s.contains(RUN_SUCCESS_MESSAGE), 300); + + // 3. 更新模型信息 + RasaModelInfo rasaModelInfo = new RasaModelInfo(); + rasaModelInfo.setModelId(argument.getModelId()); + rasaModelInfo.setPort(0); + rasaModelInfo.setServerStatus(runIsSuccess(outMessageList)?1:0); + rasaModeService.saveOrUpdateByModelId(rasaModelInfo); + + return String.join("\r\n",outMessageList); } + @Override + public void test() { + RasaModelInfo rasaModelInfo = new RasaModelInfo(); + rasaModelInfo.setId("1"); + rasaModelInfo.setModelId("1"); + rasaModelInfo.setTranStatus(1); + rasaModeService.saveOrUpdate(rasaModelInfo); + } private List execCmd(List cmds, Predicate endPredicate, long timeOut) throws InterruptedException, ExecutionException, TimeoutException { @@ -127,4 +153,29 @@ public class RasaCmdServiceImpl implements RasaCmdService { return future.get(timeOut, TimeUnit.SECONDS); } + + + + private boolean trainIsSuccess(List messageList){ + + return containKey(messageList,TRAN_SUCCESS_MESSAGE); + } + + + private boolean runIsSuccess(List messageList){ + + return containKey(messageList,RUN_SUCCESS_MESSAGE); + } + + private boolean containKey(List messageList,String keyWord){ + + if (CollectionUtil.isEmpty(messageList)){ + return false; + } + if (StrUtil.isEmpty(keyWord)){ + return false; + } + return messageList.stream().anyMatch(s->StrUtil.isNotEmpty(s) && s.contains(keyWord)); + } + } diff --git a/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/impl/RasaTalkServiceImpl.java b/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/impl/RasaTalkServiceImpl.java index b8c4624d..384b5638 100644 --- a/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/impl/RasaTalkServiceImpl.java +++ b/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/impl/RasaTalkServiceImpl.java @@ -1,16 +1,23 @@ package com.supervision.rasa.service.impl; +import cn.hutool.core.text.CharSequenceUtil; +import cn.hutool.core.util.StrUtil; import cn.hutool.http.HttpUtil; import cn.hutool.json.JSONUtil; +import com.supervision.exception.BusinessException; +import com.supervision.model.RasaModelInfo; import com.supervision.rasa.pojo.dto.RasaReqDTO; import com.supervision.rasa.pojo.dto.RasaResDTO; +import com.supervision.rasa.pojo.vo.RasaTalkVo; import com.supervision.rasa.service.RasaTalkService; +import com.supervision.service.RasaModeService; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Service; import java.util.List; +import java.util.function.Predicate; import java.util.stream.Collectors; @Service @Slf4j @@ -19,14 +26,32 @@ public class RasaTalkServiceImpl implements RasaTalkService { @Value("${rasa.url}") private String rasaUrl; + + private final RasaModeService rasaModeService; @Override - public List talkRasa(String question, String sessionId) { + public List talkRasa(RasaTalkVo rasaTalkVo) { + + if (StrUtil.isEmpty(rasaTalkVo.getModelId())){ + throw new BusinessException("modelId is not allow empty "); + } + RasaModelInfo rasaModelInfo = rasaModeService.queryByModelId(rasaTalkVo.getModelId()); + if (null == rasaModelInfo){ + throw new BusinessException(" not find model info , check if modelId is available "); + }else if (null == rasaModelInfo.getPort()){ + throw new BusinessException(" rasa model port is empty , check if rasa is start "); + } RasaReqDTO rasaReqDTO = new RasaReqDTO(); - rasaReqDTO.setSender(sessionId); - rasaReqDTO.setMessage(question); - String post = HttpUtil.post(rasaUrl, JSONUtil.toJsonStr(rasaReqDTO)); + rasaReqDTO.setSender(rasaTalkVo.getSessionId()); + rasaReqDTO.setMessage(rasaTalkVo.getQuestion()); + String post = HttpUtil.post(getRasaUrl(rasaModelInfo.getPort()), JSONUtil.toJsonStr(rasaReqDTO)); List list = JSONUtil.toList(post, RasaResDTO.class); + return list.stream().map(RasaResDTO::getText).collect(Collectors.toList()); } + + private String getRasaUrl(int port){ + + return StrUtil.format(rasaUrl, port); + } } diff --git a/virtual-patient-rasa/src/main/resources/application.yml b/virtual-patient-rasa/src/main/resources/application.yml index dcbc4595..733a2604 100644 --- a/virtual-patient-rasa/src/main/resources/application.yml +++ b/virtual-patient-rasa/src/main/resources/application.yml @@ -18,7 +18,7 @@ rasa: config: /rasa/config-local.yml # 启动rasa需要的配置文件,在配置文件中配置 train-shell: /home/rasa_manage/train.sh run-shell: /home/rasa_manage/run.sh - url: 192.168.10.137:5005/webhooks/rest/webhook + url: 192.168.10.137:{}/webhooks/rest/webhook spring: profiles: active: dev