rasa:集成会话接口

dev_v1.0.1
xueqingkun 2 years ago
parent 05aca5c443
commit 28cb30cbc2

@ -2,8 +2,12 @@ package com.supervision.service;
import com.baomidou.mybatisplus.extension.service.IService; import com.baomidou.mybatisplus.extension.service.IService;
import com.supervision.model.RasaModelInfo; import com.supervision.model.RasaModelInfo;
import sun.reflect.generics.tree.VoidDescriptor;
public interface RasaModeService extends IService<RasaModelInfo> { public interface RasaModeService extends IService<RasaModelInfo> {
RasaModelInfo queryByModelId(String modelId); RasaModelInfo queryByModelId(String modelId);
RasaModelInfo saveOrUpdateByModelId(RasaModelInfo rasaModelInfo);
} }

@ -1,7 +1,9 @@
package com.supervision.service.impl; package com.supervision.service.impl;
import cn.hutool.core.util.StrUtil;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.supervision.exception.BusinessException;
import com.supervision.mapper.RasaModelInfoMapper; import com.supervision.mapper.RasaModelInfoMapper;
import com.supervision.model.RasaModelInfo; import com.supervision.model.RasaModelInfo;
import com.supervision.service.RasaModeService; import com.supervision.service.RasaModeService;
@ -18,4 +20,19 @@ public class RasaModeServiceImpl extends ServiceImpl<RasaModelInfoMapper,RasaMod
return getOne(queryWrapper); return getOne(queryWrapper);
} }
@Override
public RasaModelInfo saveOrUpdateByModelId(RasaModelInfo rasaModelInfo) {
if (StrUtil.isEmpty(rasaModelInfo.getModelId())){
throw new BusinessException("modelId is not allow empty...");
}
RasaModelInfo dbModelInfo = this.queryByModelId(rasaModelInfo.getModelId());
if (null != dbModelInfo && StrUtil.isNotEmpty(dbModelInfo.getId())){
rasaModelInfo.setId(dbModelInfo.getId());
super.updateById(rasaModelInfo);
}
super.save(rasaModelInfo);
return rasaModelInfo;
}
} }

@ -1,13 +1,12 @@
package com.supervision.rasa.controller; package com.supervision.rasa.controller;
import cn.hutool.core.util.StrUtil; import cn.hutool.core.util.StrUtil;
import com.supervision.rasa.pojo.vo.RasaArgument; import com.supervision.rasa.pojo.vo.RasaArgumentVo;
import com.supervision.rasa.service.RasaCmdService; import com.supervision.rasa.service.RasaCmdService;
import com.supervision.exception.BusinessException; import com.supervision.exception.BusinessException;
import io.swagger.annotations.Api; import io.swagger.annotations.Api;
import io.swagger.annotations.ApiOperation; import io.swagger.annotations.ApiOperation;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*; import org.springframework.web.bind.annotation.*;
import java.io.*; import java.io.*;
@ -19,12 +18,11 @@ import java.util.concurrent.*;
@RequiredArgsConstructor @RequiredArgsConstructor
public class RasaCmdController { public class RasaCmdController {
@Autowired private final RasaCmdService rasaCmdService;
private RasaCmdService rasaCmdService;
@ApiOperation("执行训练shell命令") @ApiOperation("执行训练shell命令")
@PostMapping("/trainExec") @PostMapping("/trainExec")
public String trainExec(@RequestBody RasaArgument argument) throws IOException, ExecutionException, InterruptedException, TimeoutException { public String trainExec(@RequestBody RasaArgumentVo argument) throws IOException, ExecutionException, InterruptedException, TimeoutException {
if (StrUtil.isEmpty(argument.getFixedModelName())){ if (StrUtil.isEmpty(argument.getFixedModelName())){
throw new BusinessException("fixedModelName参数不能为空"); throw new BusinessException("fixedModelName参数不能为空");
@ -40,7 +38,7 @@ public class RasaCmdController {
@ApiOperation("执行启动shell命令") @ApiOperation("执行启动shell命令")
@PostMapping("/runExec") @PostMapping("/runExec")
public String runExec(@RequestBody RasaArgument argument) throws ExecutionException, InterruptedException, TimeoutException { public String runExec(@RequestBody RasaArgumentVo argument) throws ExecutionException, InterruptedException, TimeoutException {
if (StrUtil.isEmpty(argument.getFixedModelName())){ if (StrUtil.isEmpty(argument.getFixedModelName())){
throw new BusinessException("fixedModelName参数不能为空"); throw new BusinessException("fixedModelName参数不能为空");
@ -57,4 +55,15 @@ public class RasaCmdController {
} }
@ApiOperation("执行启动shell命令")
@PostMapping("/test")
public String test(@RequestBody RasaArgumentVo argument) throws ExecutionException, InterruptedException, TimeoutException {
rasaCmdService.test();
return " dd";
}
} }

@ -4,6 +4,7 @@ package com.supervision.rasa.controller;
import com.supervision.rasa.service.RasaFileService; import com.supervision.rasa.service.RasaFileService;
import io.swagger.annotations.Api; import io.swagger.annotations.Api;
import io.swagger.annotations.ApiOperation; import io.swagger.annotations.ApiOperation;
import jdk.internal.org.objectweb.asm.tree.FieldInsnNode;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*; import org.springframework.web.bind.annotation.*;
@ -17,8 +18,7 @@ import java.io.IOException;
@RequiredArgsConstructor @RequiredArgsConstructor
public class RasaFileController { public class RasaFileController {
@Autowired private final RasaFileService rasaFileService;
private RasaFileService rasaFileService;
@ApiOperation("接受并保存rasa文件") @ApiOperation("接受并保存rasa文件")
@PostMapping("/saveRasaFile") @PostMapping("/saveRasaFile")

@ -1,30 +1,30 @@
package com.supervision.rasa.controller; package com.supervision.rasa.controller;
import com.supervision.rasa.pojo.vo.RasaTalkVo;
import com.supervision.rasa.service.RasaTalkService; import com.supervision.rasa.service.RasaTalkService;
import io.swagger.annotations.Api; import io.swagger.annotations.Api;
import io.swagger.annotations.ApiOperation; import io.swagger.annotations.ApiOperation;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController; import org.springframework.web.bind.annotation.RestController;
import java.util.List; import java.util.List;
@Api(tags = "rasa文件保存") @Api(tags = "ras对话服务")
@RestController @RestController
@RequestMapping("rasaFile") @RequestMapping("rasa")
@RequiredArgsConstructor @RequiredArgsConstructor
public class RasaTalkController { public class RasaTalkController {
@Autowired private final RasaTalkService rasaTalkService;
private RasaTalkService rasaTalkService;
@ApiOperation("rasa对话") @ApiOperation("rasa对话")
@GetMapping("talkRasa") @PostMapping("talkRasa")
public List<String> talkRasa(String question, String sessionId){ public List<String> talkRasa(@RequestBody RasaTalkVo rasaTalkVo){
return null; return rasaTalkService.talkRasa(rasaTalkVo);
} }
} }

@ -3,7 +3,7 @@ package com.supervision.rasa.pojo.vo;
import lombok.Data; import lombok.Data;
@Data @Data
public class RasaArgument { public class RasaArgumentVo {
private String config; private String config;
private String data; private String data;

@ -0,0 +1,22 @@
package com.supervision.rasa.pojo.vo;
import lombok.Data;
@Data
public class RasaTalkVo {
/**
*
*/
private String question;
/**
*
*/
private String sessionId;
/**
* id
*/
private String modelId;
}

@ -1,6 +1,6 @@
package com.supervision.rasa.service; package com.supervision.rasa.service;
import com.supervision.rasa.pojo.vo.RasaArgument; import com.supervision.rasa.pojo.vo.RasaArgumentVo;
import java.io.IOException; import java.io.IOException;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
@ -8,8 +8,10 @@ import java.util.concurrent.TimeoutException;
public interface RasaCmdService { public interface RasaCmdService {
String trainExec(RasaArgument argument) throws IOException, ExecutionException, InterruptedException, TimeoutException; String trainExec(RasaArgumentVo argument) throws IOException, ExecutionException, InterruptedException, TimeoutException;
String runExec( RasaArgument argument) throws ExecutionException, InterruptedException, TimeoutException; String runExec( RasaArgumentVo argument) throws ExecutionException, InterruptedException, TimeoutException;
void test();
} }

@ -1,9 +1,11 @@
package com.supervision.rasa.service; package com.supervision.rasa.service;
import com.supervision.rasa.pojo.vo.RasaTalkVo;
import java.util.List; import java.util.List;
public interface RasaTalkService { public interface RasaTalkService {
List<String> talkRasa(String question, String sessionId) ; List<String> talkRasa(RasaTalkVo rasaTalkVo) ;
} }

@ -3,14 +3,17 @@ package com.supervision.rasa.service.impl;
import cn.hutool.core.collection.CollectionUtil; import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.collection.ListUtil; import cn.hutool.core.collection.ListUtil;
import cn.hutool.core.util.StrUtil; import cn.hutool.core.util.StrUtil;
import com.supervision.model.RasaModelInfo;
import com.supervision.rasa.config.ThreadPoolExecutorConfig; import com.supervision.rasa.config.ThreadPoolExecutorConfig;
import com.supervision.rasa.pojo.vo.RasaArgument; import com.supervision.rasa.pojo.vo.RasaArgumentVo;
import com.supervision.rasa.service.RasaCmdService; import com.supervision.rasa.service.RasaCmdService;
import com.supervision.rasa.util.PortUtil;
import com.supervision.service.RasaModeService; import com.supervision.service.RasaModeService;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import java.io.*; import java.io.*;
import java.util.ArrayList; import java.util.ArrayList;
@ -51,8 +54,12 @@ public class RasaCmdServiceImpl implements RasaCmdService {
private final RasaModeService rasaModeService; private final RasaModeService rasaModeService;
private final String TRAN_SUCCESS_MESSAGE = "Your Rasa model is trained and saved at";
private final String RUN_SUCCESS_MESSAGE = "Rasa server is up and running";
@Override @Override
public String trainExec(RasaArgument argument) throws ExecutionException, InterruptedException, TimeoutException { @Transactional
public String trainExec(RasaArgumentVo argument) throws ExecutionException, InterruptedException, TimeoutException {
String domain = dataPath+"domain.yml"; String domain = dataPath+"domain.yml";
List<String> cmds = ListUtil.toList(shellEnv, trainShell,config,dataPath,domain,modelsPath); List<String> cmds = ListUtil.toList(shellEnv, trainShell,config,dataPath,domain,modelsPath);
@ -62,41 +69,60 @@ public class RasaCmdServiceImpl implements RasaCmdService {
List<String> outMessage = execCmd(cmds, s -> false, 300); List<String> outMessage = execCmd(cmds, s -> false, 300);
//保存 模型信息
RasaModelInfo rasaModelInfo = new RasaModelInfo();
rasaModelInfo.setModelId(argument.getModelId());
rasaModelInfo.setTranStatus(trainIsSuccess(outMessage)?1:0);
rasaModelInfo.setServerStatus(-1);
rasaModeService.saveOrUpdateByModelId(rasaModelInfo);
boolean isSuccess = trainIsSuccess(outMessage);
return String.join("\r\n",outMessage); return String.join("\r\n",outMessage);
} }
private boolean trainIsSuccess(List<String> messageList){
if (CollectionUtil.isEmpty(messageList)){
return false;
}
String keyWord = "Your Rasa model is trained and saved at";
if (StrUtil.isEmpty(keyWord)){
return false;
}
return messageList.stream().anyMatch(s->StrUtil.isNotEmpty(s) && s.contains(keyWord));
}
private boolean containKey(List<String> messageList,String key){
return true;
}
@Override @Override
public String runExec(RasaArgument argument) throws ExecutionException, InterruptedException, TimeoutException { public String runExec(RasaArgumentVo argument) throws ExecutionException, InterruptedException, TimeoutException {
// 1. 查找模型信息是否为空
RasaModelInfo dbRasaModelInfo = rasaModeService.queryByModelId(argument.getModelId());
int unusedPort;
if (null == dbRasaModelInfo || null == dbRasaModelInfo.getPort()){
unusedPort = PortUtil.findUnusedPort(5050, 100000);
log.info("runExec findUnusedPort is : {}",unusedPort);
argument.setPort(String.valueOf(unusedPort));
}else {
argument.setPort(String.valueOf(dbRasaModelInfo.getPort()));
// todo:杀掉该端口对应的进程
}
// 2. 运行模型
String mPath = modelsPath+argument.getFixedModelName()+".tar.gz"; String mPath = modelsPath+argument.getFixedModelName()+".tar.gz";
List<String> cmds = ListUtil.toList(shellEnv, runShell,mPath,endpoints,argument.getPort()); List<String> cmds = ListUtil.toList(shellEnv, runShell,mPath,endpoints,argument.getPort());
log.info("runExec cmd : {}",StrUtil.join(" ",cmds)); log.info("runExec cmd : {}",StrUtil.join(" ",cmds));
return String.join("\r\n",execCmd(cmds,s-> StrUtil.isNotBlank(s)&& s.contains("Rasa server is up and running"),90));
List<String> outMessageList = execCmd(cmds, s -> StrUtil.isNotBlank(s) && s.contains(RUN_SUCCESS_MESSAGE), 300);
// 3. 更新模型信息
RasaModelInfo rasaModelInfo = new RasaModelInfo();
rasaModelInfo.setModelId(argument.getModelId());
rasaModelInfo.setPort(0);
rasaModelInfo.setServerStatus(runIsSuccess(outMessageList)?1:0);
rasaModeService.saveOrUpdateByModelId(rasaModelInfo);
return String.join("\r\n",outMessageList);
} }
@Override
public void test() {
RasaModelInfo rasaModelInfo = new RasaModelInfo();
rasaModelInfo.setId("1");
rasaModelInfo.setModelId("1");
rasaModelInfo.setTranStatus(1);
rasaModeService.saveOrUpdate(rasaModelInfo);
}
private List<String> execCmd(List<String> cmds, Predicate<String> endPredicate, long timeOut) throws InterruptedException, ExecutionException, TimeoutException { private List<String> execCmd(List<String> cmds, Predicate<String> endPredicate, long timeOut) throws InterruptedException, ExecutionException, TimeoutException {
@ -127,4 +153,29 @@ public class RasaCmdServiceImpl implements RasaCmdService {
return future.get(timeOut, TimeUnit.SECONDS); return future.get(timeOut, TimeUnit.SECONDS);
} }
private boolean trainIsSuccess(List<String> messageList){
return containKey(messageList,TRAN_SUCCESS_MESSAGE);
}
private boolean runIsSuccess(List<String> messageList){
return containKey(messageList,RUN_SUCCESS_MESSAGE);
}
private boolean containKey(List<String> messageList,String keyWord){
if (CollectionUtil.isEmpty(messageList)){
return false;
}
if (StrUtil.isEmpty(keyWord)){
return false;
}
return messageList.stream().anyMatch(s->StrUtil.isNotEmpty(s) && s.contains(keyWord));
}
} }

@ -1,16 +1,23 @@
package com.supervision.rasa.service.impl; package com.supervision.rasa.service.impl;
import cn.hutool.core.text.CharSequenceUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.http.HttpUtil; import cn.hutool.http.HttpUtil;
import cn.hutool.json.JSONUtil; import cn.hutool.json.JSONUtil;
import com.supervision.exception.BusinessException;
import com.supervision.model.RasaModelInfo;
import com.supervision.rasa.pojo.dto.RasaReqDTO; import com.supervision.rasa.pojo.dto.RasaReqDTO;
import com.supervision.rasa.pojo.dto.RasaResDTO; import com.supervision.rasa.pojo.dto.RasaResDTO;
import com.supervision.rasa.pojo.vo.RasaTalkVo;
import com.supervision.rasa.service.RasaTalkService; import com.supervision.rasa.service.RasaTalkService;
import com.supervision.service.RasaModeService;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.util.List; import java.util.List;
import java.util.function.Predicate;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@Service @Service
@Slf4j @Slf4j
@ -19,14 +26,32 @@ public class RasaTalkServiceImpl implements RasaTalkService {
@Value("${rasa.url}") @Value("${rasa.url}")
private String rasaUrl; private String rasaUrl;
private final RasaModeService rasaModeService;
@Override @Override
public List<String> talkRasa(String question, String sessionId) { public List<String> talkRasa(RasaTalkVo rasaTalkVo) {
if (StrUtil.isEmpty(rasaTalkVo.getModelId())){
throw new BusinessException("modelId is not allow empty ");
}
RasaModelInfo rasaModelInfo = rasaModeService.queryByModelId(rasaTalkVo.getModelId());
if (null == rasaModelInfo){
throw new BusinessException(" not find model info , check if modelId is available ");
}else if (null == rasaModelInfo.getPort()){
throw new BusinessException(" rasa model port is empty , check if rasa is start ");
}
RasaReqDTO rasaReqDTO = new RasaReqDTO(); RasaReqDTO rasaReqDTO = new RasaReqDTO();
rasaReqDTO.setSender(sessionId); rasaReqDTO.setSender(rasaTalkVo.getSessionId());
rasaReqDTO.setMessage(question); rasaReqDTO.setMessage(rasaTalkVo.getQuestion());
String post = HttpUtil.post(rasaUrl, JSONUtil.toJsonStr(rasaReqDTO)); String post = HttpUtil.post(getRasaUrl(rasaModelInfo.getPort()), JSONUtil.toJsonStr(rasaReqDTO));
List<RasaResDTO> list = JSONUtil.toList(post, RasaResDTO.class); List<RasaResDTO> list = JSONUtil.toList(post, RasaResDTO.class);
return list.stream().map(RasaResDTO::getText).collect(Collectors.toList()); return list.stream().map(RasaResDTO::getText).collect(Collectors.toList());
} }
private String getRasaUrl(int port){
return StrUtil.format(rasaUrl, port);
}
} }

@ -18,7 +18,7 @@ rasa:
config: /rasa/config-local.yml # 启动rasa需要的配置文件在配置文件中配置 config: /rasa/config-local.yml # 启动rasa需要的配置文件在配置文件中配置
train-shell: /home/rasa_manage/train.sh train-shell: /home/rasa_manage/train.sh
run-shell: /home/rasa_manage/run.sh run-shell: /home/rasa_manage/run.sh
url: 192.168.10.137:5005/webhooks/rest/webhook url: 192.168.10.137:{}/webhooks/rest/webhook
spring: spring:
profiles: profiles:
active: dev active: dev

Loading…
Cancel
Save