From 28cb30cbc23da554b1f221761add9329c01104a8 Mon Sep 17 00:00:00 2001 From: xueqingkun <xueqingkun@126.com> Date: Tue, 31 Oct 2023 16:01:42 +0800 Subject: [PATCH] =?UTF-8?q?rasa=EF=BC=9A=E9=9B=86=E6=88=90=E4=BC=9A?= =?UTF-8?q?=E8=AF=9D=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../supervision/service/RasaModeService.java | 4 + .../service/impl/RasaModeServiceImpl.java | 17 ++++ .../rasa/controller/RasaCmdController.java | 21 ++-- .../rasa/controller/RasaFileController.java | 4 +- .../rasa/controller/RasaTalkController.java | 18 ++-- ...{RasaArgument.java => RasaArgumentVo.java} | 2 +- .../supervision/rasa/pojo/vo/RasaTalkVo.java | 22 +++++ .../rasa/service/RasaCmdService.java | 8 +- .../rasa/service/RasaTalkService.java | 4 +- .../rasa/service/impl/RasaCmdServiceImpl.java | 97 ++++++++++++++----- .../service/impl/RasaTalkServiceImpl.java | 33 ++++++- .../src/main/resources/application.yml | 2 +- 12 files changed, 182 insertions(+), 50 deletions(-) rename virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/vo/{RasaArgument.java => RasaArgumentVo.java} (92%) create mode 100644 virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/vo/RasaTalkVo.java diff --git a/virtual-patient-model/src/main/java/com/supervision/service/RasaModeService.java b/virtual-patient-model/src/main/java/com/supervision/service/RasaModeService.java index 71dda8a2..fee7395d 100644 --- a/virtual-patient-model/src/main/java/com/supervision/service/RasaModeService.java +++ b/virtual-patient-model/src/main/java/com/supervision/service/RasaModeService.java @@ -2,8 +2,12 @@ package com.supervision.service; import com.baomidou.mybatisplus.extension.service.IService; import com.supervision.model.RasaModelInfo; +import sun.reflect.generics.tree.VoidDescriptor; public interface RasaModeService extends IService<RasaModelInfo> { RasaModelInfo queryByModelId(String modelId); + + + RasaModelInfo saveOrUpdateByModelId(RasaModelInfo rasaModelInfo); } diff --git a/virtual-patient-model/src/main/java/com/supervision/service/impl/RasaModeServiceImpl.java b/virtual-patient-model/src/main/java/com/supervision/service/impl/RasaModeServiceImpl.java index 20d2e70a..cf94a66b 100644 --- a/virtual-patient-model/src/main/java/com/supervision/service/impl/RasaModeServiceImpl.java +++ b/virtual-patient-model/src/main/java/com/supervision/service/impl/RasaModeServiceImpl.java @@ -1,7 +1,9 @@ package com.supervision.service.impl; +import cn.hutool.core.util.StrUtil; import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; +import com.supervision.exception.BusinessException; import com.supervision.mapper.RasaModelInfoMapper; import com.supervision.model.RasaModelInfo; import com.supervision.service.RasaModeService; @@ -18,4 +20,19 @@ public class RasaModeServiceImpl extends ServiceImpl<RasaModelInfoMapper,RasaMod return getOne(queryWrapper); } + + @Override + public RasaModelInfo saveOrUpdateByModelId(RasaModelInfo rasaModelInfo) { + + if (StrUtil.isEmpty(rasaModelInfo.getModelId())){ + throw new BusinessException("modelId is not allow empty..."); + } + RasaModelInfo dbModelInfo = this.queryByModelId(rasaModelInfo.getModelId()); + if (null != dbModelInfo && StrUtil.isNotEmpty(dbModelInfo.getId())){ + rasaModelInfo.setId(dbModelInfo.getId()); + super.updateById(rasaModelInfo); + } + super.save(rasaModelInfo); + return rasaModelInfo; + } } diff --git a/virtual-patient-rasa/src/main/java/com/supervision/rasa/controller/RasaCmdController.java b/virtual-patient-rasa/src/main/java/com/supervision/rasa/controller/RasaCmdController.java index 3743290e..5bdf2db7 100644 --- a/virtual-patient-rasa/src/main/java/com/supervision/rasa/controller/RasaCmdController.java +++ b/virtual-patient-rasa/src/main/java/com/supervision/rasa/controller/RasaCmdController.java @@ -1,13 +1,12 @@ package com.supervision.rasa.controller; import cn.hutool.core.util.StrUtil; -import com.supervision.rasa.pojo.vo.RasaArgument; +import com.supervision.rasa.pojo.vo.RasaArgumentVo; import com.supervision.rasa.service.RasaCmdService; import com.supervision.exception.BusinessException; import io.swagger.annotations.Api; import io.swagger.annotations.ApiOperation; import lombok.RequiredArgsConstructor; -import org.springframework.beans.factory.annotation.Autowired; import org.springframework.web.bind.annotation.*; import java.io.*; @@ -19,12 +18,11 @@ import java.util.concurrent.*; @RequiredArgsConstructor public class RasaCmdController { - @Autowired - private RasaCmdService rasaCmdService; + private final RasaCmdService rasaCmdService; @ApiOperation("执行训练shell命令") @PostMapping("/trainExec") - public String trainExec(@RequestBody RasaArgument argument) throws IOException, ExecutionException, InterruptedException, TimeoutException { + public String trainExec(@RequestBody RasaArgumentVo argument) throws IOException, ExecutionException, InterruptedException, TimeoutException { if (StrUtil.isEmpty(argument.getFixedModelName())){ throw new BusinessException("fixedModelName参数不能为空!"); @@ -40,7 +38,7 @@ public class RasaCmdController { @ApiOperation("执行启动shell命令") @PostMapping("/runExec") - public String runExec(@RequestBody RasaArgument argument) throws ExecutionException, InterruptedException, TimeoutException { + public String runExec(@RequestBody RasaArgumentVo argument) throws ExecutionException, InterruptedException, TimeoutException { if (StrUtil.isEmpty(argument.getFixedModelName())){ throw new BusinessException("fixedModelName参数不能为空!"); @@ -57,4 +55,15 @@ public class RasaCmdController { } + @ApiOperation("执行启动shell命令") + @PostMapping("/test") + public String test(@RequestBody RasaArgumentVo argument) throws ExecutionException, InterruptedException, TimeoutException { + + + rasaCmdService.test(); + + return " dd"; + + } + } diff --git a/virtual-patient-rasa/src/main/java/com/supervision/rasa/controller/RasaFileController.java b/virtual-patient-rasa/src/main/java/com/supervision/rasa/controller/RasaFileController.java index eafcf0f3..bd51f47a 100644 --- a/virtual-patient-rasa/src/main/java/com/supervision/rasa/controller/RasaFileController.java +++ b/virtual-patient-rasa/src/main/java/com/supervision/rasa/controller/RasaFileController.java @@ -4,6 +4,7 @@ package com.supervision.rasa.controller; import com.supervision.rasa.service.RasaFileService; import io.swagger.annotations.Api; import io.swagger.annotations.ApiOperation; +import jdk.internal.org.objectweb.asm.tree.FieldInsnNode; import lombok.RequiredArgsConstructor; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.web.bind.annotation.*; @@ -17,8 +18,7 @@ import java.io.IOException; @RequiredArgsConstructor public class RasaFileController { - @Autowired - private RasaFileService rasaFileService; + private final RasaFileService rasaFileService; @ApiOperation("接受并保存rasa文件") @PostMapping("/saveRasaFile") diff --git a/virtual-patient-rasa/src/main/java/com/supervision/rasa/controller/RasaTalkController.java b/virtual-patient-rasa/src/main/java/com/supervision/rasa/controller/RasaTalkController.java index 59f33fef..4f4f236e 100644 --- a/virtual-patient-rasa/src/main/java/com/supervision/rasa/controller/RasaTalkController.java +++ b/virtual-patient-rasa/src/main/java/com/supervision/rasa/controller/RasaTalkController.java @@ -1,30 +1,30 @@ package com.supervision.rasa.controller; +import com.supervision.rasa.pojo.vo.RasaTalkVo; import com.supervision.rasa.service.RasaTalkService; import io.swagger.annotations.Api; import io.swagger.annotations.ApiOperation; import lombok.RequiredArgsConstructor; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; import java.util.List; -@Api(tags = "rasa文件保存") +@Api(tags = "ras对话服务") @RestController -@RequestMapping("rasaFile") +@RequestMapping("rasa") @RequiredArgsConstructor public class RasaTalkController { - @Autowired - private RasaTalkService rasaTalkService; + private final RasaTalkService rasaTalkService; @ApiOperation("rasa对话") - @GetMapping("talkRasa") - public List<String> talkRasa(String question, String sessionId){ + @PostMapping("talkRasa") + public List<String> talkRasa(@RequestBody RasaTalkVo rasaTalkVo){ - return null; + return rasaTalkService.talkRasa(rasaTalkVo); } } diff --git a/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/vo/RasaArgument.java b/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/vo/RasaArgumentVo.java similarity index 92% rename from virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/vo/RasaArgument.java rename to virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/vo/RasaArgumentVo.java index 504811e2..f45f8827 100644 --- a/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/vo/RasaArgument.java +++ b/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/vo/RasaArgumentVo.java @@ -3,7 +3,7 @@ package com.supervision.rasa.pojo.vo; import lombok.Data; @Data -public class RasaArgument { +public class RasaArgumentVo { private String config; private String data; diff --git a/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/vo/RasaTalkVo.java b/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/vo/RasaTalkVo.java new file mode 100644 index 00000000..0b20ebdc --- /dev/null +++ b/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/vo/RasaTalkVo.java @@ -0,0 +1,22 @@ +package com.supervision.rasa.pojo.vo; + +import lombok.Data; + +@Data +public class RasaTalkVo { + + /** + * 问题 + */ + private String question; + /** + * 会话标识 + */ + private String sessionId; + + /** + * 模型id + */ + private String modelId; + +} diff --git a/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/RasaCmdService.java b/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/RasaCmdService.java index 908132fa..fca00ca2 100644 --- a/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/RasaCmdService.java +++ b/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/RasaCmdService.java @@ -1,6 +1,6 @@ package com.supervision.rasa.service; -import com.supervision.rasa.pojo.vo.RasaArgument; +import com.supervision.rasa.pojo.vo.RasaArgumentVo; import java.io.IOException; import java.util.concurrent.ExecutionException; @@ -8,8 +8,10 @@ import java.util.concurrent.TimeoutException; public interface RasaCmdService { - String trainExec(RasaArgument argument) throws IOException, ExecutionException, InterruptedException, TimeoutException; + String trainExec(RasaArgumentVo argument) throws IOException, ExecutionException, InterruptedException, TimeoutException; - String runExec( RasaArgument argument) throws ExecutionException, InterruptedException, TimeoutException; + String runExec( RasaArgumentVo argument) throws ExecutionException, InterruptedException, TimeoutException; + + void test(); } diff --git a/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/RasaTalkService.java b/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/RasaTalkService.java index 2630f7fa..7386e8e2 100644 --- a/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/RasaTalkService.java +++ b/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/RasaTalkService.java @@ -1,9 +1,11 @@ package com.supervision.rasa.service; +import com.supervision.rasa.pojo.vo.RasaTalkVo; + import java.util.List; public interface RasaTalkService { - List<String> talkRasa(String question, String sessionId) ; + List<String> talkRasa(RasaTalkVo rasaTalkVo) ; } diff --git a/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/impl/RasaCmdServiceImpl.java b/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/impl/RasaCmdServiceImpl.java index 0b9ccb45..a202b365 100644 --- a/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/impl/RasaCmdServiceImpl.java +++ b/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/impl/RasaCmdServiceImpl.java @@ -3,14 +3,17 @@ package com.supervision.rasa.service.impl; import cn.hutool.core.collection.CollectionUtil; import cn.hutool.core.collection.ListUtil; import cn.hutool.core.util.StrUtil; +import com.supervision.model.RasaModelInfo; import com.supervision.rasa.config.ThreadPoolExecutorConfig; -import com.supervision.rasa.pojo.vo.RasaArgument; +import com.supervision.rasa.pojo.vo.RasaArgumentVo; import com.supervision.rasa.service.RasaCmdService; +import com.supervision.rasa.util.PortUtil; import com.supervision.service.RasaModeService; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; import java.io.*; import java.util.ArrayList; @@ -51,8 +54,12 @@ public class RasaCmdServiceImpl implements RasaCmdService { private final RasaModeService rasaModeService; + private final String TRAN_SUCCESS_MESSAGE = "Your Rasa model is trained and saved at"; + private final String RUN_SUCCESS_MESSAGE = "Rasa server is up and running"; + @Override - public String trainExec(RasaArgument argument) throws ExecutionException, InterruptedException, TimeoutException { + @Transactional + public String trainExec(RasaArgumentVo argument) throws ExecutionException, InterruptedException, TimeoutException { String domain = dataPath+"domain.yml"; List<String> cmds = ListUtil.toList(shellEnv, trainShell,config,dataPath,domain,modelsPath); @@ -62,41 +69,60 @@ public class RasaCmdServiceImpl implements RasaCmdService { List<String> outMessage = execCmd(cmds, s -> false, 300); + //保存 模型信息 + RasaModelInfo rasaModelInfo = new RasaModelInfo(); + rasaModelInfo.setModelId(argument.getModelId()); + rasaModelInfo.setTranStatus(trainIsSuccess(outMessage)?1:0); + rasaModelInfo.setServerStatus(-1); + rasaModeService.saveOrUpdateByModelId(rasaModelInfo); - boolean isSuccess = trainIsSuccess(outMessage); return String.join("\r\n",outMessage); } - private boolean trainIsSuccess(List<String> messageList){ - - if (CollectionUtil.isEmpty(messageList)){ - return false; - } - String keyWord = "Your Rasa model is trained and saved at"; - if (StrUtil.isEmpty(keyWord)){ - return false; - } - return messageList.stream().anyMatch(s->StrUtil.isNotEmpty(s) && s.contains(keyWord)); - } - - private boolean containKey(List<String> messageList,String key){ - return true; - } - - @Override - public String runExec(RasaArgument argument) throws ExecutionException, InterruptedException, TimeoutException { - + public String runExec(RasaArgumentVo argument) throws ExecutionException, InterruptedException, TimeoutException { + + + // 1. 查找模型信息是否为空 + RasaModelInfo dbRasaModelInfo = rasaModeService.queryByModelId(argument.getModelId()); + int unusedPort; + if (null == dbRasaModelInfo || null == dbRasaModelInfo.getPort()){ + unusedPort = PortUtil.findUnusedPort(5050, 100000); + log.info("runExec findUnusedPort is : {}",unusedPort); + argument.setPort(String.valueOf(unusedPort)); + }else { + argument.setPort(String.valueOf(dbRasaModelInfo.getPort())); + // todo:杀掉该端口对应的进程 + } + // 2. 运行模型 String mPath = modelsPath+argument.getFixedModelName()+".tar.gz"; List<String> cmds = ListUtil.toList(shellEnv, runShell,mPath,endpoints,argument.getPort()); log.info("runExec cmd : {}",StrUtil.join(" ",cmds)); - return String.join("\r\n",execCmd(cmds,s-> StrUtil.isNotBlank(s)&& s.contains("Rasa server is up and running"),90)); + + List<String> outMessageList = execCmd(cmds, s -> StrUtil.isNotBlank(s) && s.contains(RUN_SUCCESS_MESSAGE), 300); + + // 3. 更新模型信息 + RasaModelInfo rasaModelInfo = new RasaModelInfo(); + rasaModelInfo.setModelId(argument.getModelId()); + rasaModelInfo.setPort(0); + rasaModelInfo.setServerStatus(runIsSuccess(outMessageList)?1:0); + rasaModeService.saveOrUpdateByModelId(rasaModelInfo); + + return String.join("\r\n",outMessageList); } + @Override + public void test() { + RasaModelInfo rasaModelInfo = new RasaModelInfo(); + rasaModelInfo.setId("1"); + rasaModelInfo.setModelId("1"); + rasaModelInfo.setTranStatus(1); + rasaModeService.saveOrUpdate(rasaModelInfo); + } private List<String> execCmd(List<String> cmds, Predicate<String> endPredicate, long timeOut) throws InterruptedException, ExecutionException, TimeoutException { @@ -127,4 +153,29 @@ public class RasaCmdServiceImpl implements RasaCmdService { return future.get(timeOut, TimeUnit.SECONDS); } + + + + private boolean trainIsSuccess(List<String> messageList){ + + return containKey(messageList,TRAN_SUCCESS_MESSAGE); + } + + + private boolean runIsSuccess(List<String> messageList){ + + return containKey(messageList,RUN_SUCCESS_MESSAGE); + } + + private boolean containKey(List<String> messageList,String keyWord){ + + if (CollectionUtil.isEmpty(messageList)){ + return false; + } + if (StrUtil.isEmpty(keyWord)){ + return false; + } + return messageList.stream().anyMatch(s->StrUtil.isNotEmpty(s) && s.contains(keyWord)); + } + } diff --git a/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/impl/RasaTalkServiceImpl.java b/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/impl/RasaTalkServiceImpl.java index b8c4624d..384b5638 100644 --- a/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/impl/RasaTalkServiceImpl.java +++ b/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/impl/RasaTalkServiceImpl.java @@ -1,16 +1,23 @@ package com.supervision.rasa.service.impl; +import cn.hutool.core.text.CharSequenceUtil; +import cn.hutool.core.util.StrUtil; import cn.hutool.http.HttpUtil; import cn.hutool.json.JSONUtil; +import com.supervision.exception.BusinessException; +import com.supervision.model.RasaModelInfo; import com.supervision.rasa.pojo.dto.RasaReqDTO; import com.supervision.rasa.pojo.dto.RasaResDTO; +import com.supervision.rasa.pojo.vo.RasaTalkVo; import com.supervision.rasa.service.RasaTalkService; +import com.supervision.service.RasaModeService; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Service; import java.util.List; +import java.util.function.Predicate; import java.util.stream.Collectors; @Service @Slf4j @@ -19,14 +26,32 @@ public class RasaTalkServiceImpl implements RasaTalkService { @Value("${rasa.url}") private String rasaUrl; + + private final RasaModeService rasaModeService; @Override - public List<String> talkRasa(String question, String sessionId) { + public List<String> talkRasa(RasaTalkVo rasaTalkVo) { + + if (StrUtil.isEmpty(rasaTalkVo.getModelId())){ + throw new BusinessException("modelId is not allow empty "); + } + RasaModelInfo rasaModelInfo = rasaModeService.queryByModelId(rasaTalkVo.getModelId()); + if (null == rasaModelInfo){ + throw new BusinessException(" not find model info , check if modelId is available "); + }else if (null == rasaModelInfo.getPort()){ + throw new BusinessException(" rasa model port is empty , check if rasa is start "); + } RasaReqDTO rasaReqDTO = new RasaReqDTO(); - rasaReqDTO.setSender(sessionId); - rasaReqDTO.setMessage(question); - String post = HttpUtil.post(rasaUrl, JSONUtil.toJsonStr(rasaReqDTO)); + rasaReqDTO.setSender(rasaTalkVo.getSessionId()); + rasaReqDTO.setMessage(rasaTalkVo.getQuestion()); + String post = HttpUtil.post(getRasaUrl(rasaModelInfo.getPort()), JSONUtil.toJsonStr(rasaReqDTO)); List<RasaResDTO> list = JSONUtil.toList(post, RasaResDTO.class); + return list.stream().map(RasaResDTO::getText).collect(Collectors.toList()); } + + private String getRasaUrl(int port){ + + return StrUtil.format(rasaUrl, port); + } } diff --git a/virtual-patient-rasa/src/main/resources/application.yml b/virtual-patient-rasa/src/main/resources/application.yml index dcbc4595..733a2604 100644 --- a/virtual-patient-rasa/src/main/resources/application.yml +++ b/virtual-patient-rasa/src/main/resources/application.yml @@ -18,7 +18,7 @@ rasa: config: /rasa/config-local.yml # 启动rasa需要的配置文件,在配置文件中配置 train-shell: /home/rasa_manage/train.sh run-shell: /home/rasa_manage/run.sh - url: 192.168.10.137:5005/webhooks/rest/webhook + url: 192.168.10.137:{}/webhooks/rest/webhook spring: profiles: active: dev