rasa:集成会话接口

dev_v1.0.1
xueqingkun 2 years ago
parent 05aca5c443
commit 28cb30cbc2

@ -2,8 +2,12 @@ package com.supervision.service;
import com.baomidou.mybatisplus.extension.service.IService;
import com.supervision.model.RasaModelInfo;
import sun.reflect.generics.tree.VoidDescriptor;
public interface RasaModeService extends IService<RasaModelInfo> {
RasaModelInfo queryByModelId(String modelId);
RasaModelInfo saveOrUpdateByModelId(RasaModelInfo rasaModelInfo);
}

@ -1,7 +1,9 @@
package com.supervision.service.impl;
import cn.hutool.core.util.StrUtil;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.supervision.exception.BusinessException;
import com.supervision.mapper.RasaModelInfoMapper;
import com.supervision.model.RasaModelInfo;
import com.supervision.service.RasaModeService;
@ -18,4 +20,19 @@ public class RasaModeServiceImpl extends ServiceImpl<RasaModelInfoMapper,RasaMod
return getOne(queryWrapper);
}
@Override
public RasaModelInfo saveOrUpdateByModelId(RasaModelInfo rasaModelInfo) {
if (StrUtil.isEmpty(rasaModelInfo.getModelId())){
throw new BusinessException("modelId is not allow empty...");
}
RasaModelInfo dbModelInfo = this.queryByModelId(rasaModelInfo.getModelId());
if (null != dbModelInfo && StrUtil.isNotEmpty(dbModelInfo.getId())){
rasaModelInfo.setId(dbModelInfo.getId());
super.updateById(rasaModelInfo);
}
super.save(rasaModelInfo);
return rasaModelInfo;
}
}

@ -1,13 +1,12 @@
package com.supervision.rasa.controller;
import cn.hutool.core.util.StrUtil;
import com.supervision.rasa.pojo.vo.RasaArgument;
import com.supervision.rasa.pojo.vo.RasaArgumentVo;
import com.supervision.rasa.service.RasaCmdService;
import com.supervision.exception.BusinessException;
import io.swagger.annotations.Api;
import io.swagger.annotations.ApiOperation;
import lombok.RequiredArgsConstructor;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*;
import java.io.*;
@ -19,12 +18,11 @@ import java.util.concurrent.*;
@RequiredArgsConstructor
public class RasaCmdController {
@Autowired
private RasaCmdService rasaCmdService;
private final RasaCmdService rasaCmdService;
@ApiOperation("执行训练shell命令")
@PostMapping("/trainExec")
public String trainExec(@RequestBody RasaArgument argument) throws IOException, ExecutionException, InterruptedException, TimeoutException {
public String trainExec(@RequestBody RasaArgumentVo argument) throws IOException, ExecutionException, InterruptedException, TimeoutException {
if (StrUtil.isEmpty(argument.getFixedModelName())){
throw new BusinessException("fixedModelName参数不能为空");
@ -40,7 +38,7 @@ public class RasaCmdController {
@ApiOperation("执行启动shell命令")
@PostMapping("/runExec")
public String runExec(@RequestBody RasaArgument argument) throws ExecutionException, InterruptedException, TimeoutException {
public String runExec(@RequestBody RasaArgumentVo argument) throws ExecutionException, InterruptedException, TimeoutException {
if (StrUtil.isEmpty(argument.getFixedModelName())){
throw new BusinessException("fixedModelName参数不能为空");
@ -57,4 +55,15 @@ public class RasaCmdController {
}
@ApiOperation("执行启动shell命令")
@PostMapping("/test")
public String test(@RequestBody RasaArgumentVo argument) throws ExecutionException, InterruptedException, TimeoutException {
rasaCmdService.test();
return " dd";
}
}

@ -4,6 +4,7 @@ package com.supervision.rasa.controller;
import com.supervision.rasa.service.RasaFileService;
import io.swagger.annotations.Api;
import io.swagger.annotations.ApiOperation;
import jdk.internal.org.objectweb.asm.tree.FieldInsnNode;
import lombok.RequiredArgsConstructor;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*;
@ -17,8 +18,7 @@ import java.io.IOException;
@RequiredArgsConstructor
public class RasaFileController {
@Autowired
private RasaFileService rasaFileService;
private final RasaFileService rasaFileService;
@ApiOperation("接受并保存rasa文件")
@PostMapping("/saveRasaFile")

@ -1,30 +1,30 @@
package com.supervision.rasa.controller;
import com.supervision.rasa.pojo.vo.RasaTalkVo;
import com.supervision.rasa.service.RasaTalkService;
import io.swagger.annotations.Api;
import io.swagger.annotations.ApiOperation;
import lombok.RequiredArgsConstructor;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import java.util.List;
@Api(tags = "rasa文件保存")
@Api(tags = "ras对话服务")
@RestController
@RequestMapping("rasaFile")
@RequestMapping("rasa")
@RequiredArgsConstructor
public class RasaTalkController {
@Autowired
private RasaTalkService rasaTalkService;
private final RasaTalkService rasaTalkService;
@ApiOperation("rasa对话")
@GetMapping("talkRasa")
public List<String> talkRasa(String question, String sessionId){
@PostMapping("talkRasa")
public List<String> talkRasa(@RequestBody RasaTalkVo rasaTalkVo){
return null;
return rasaTalkService.talkRasa(rasaTalkVo);
}
}

@ -3,7 +3,7 @@ package com.supervision.rasa.pojo.vo;
import lombok.Data;
@Data
public class RasaArgument {
public class RasaArgumentVo {
private String config;
private String data;

@ -0,0 +1,22 @@
package com.supervision.rasa.pojo.vo;
import lombok.Data;
@Data
public class RasaTalkVo {
/**
*
*/
private String question;
/**
*
*/
private String sessionId;
/**
* id
*/
private String modelId;
}

@ -1,6 +1,6 @@
package com.supervision.rasa.service;
import com.supervision.rasa.pojo.vo.RasaArgument;
import com.supervision.rasa.pojo.vo.RasaArgumentVo;
import java.io.IOException;
import java.util.concurrent.ExecutionException;
@ -8,8 +8,10 @@ import java.util.concurrent.TimeoutException;
public interface RasaCmdService {
String trainExec(RasaArgument argument) throws IOException, ExecutionException, InterruptedException, TimeoutException;
String trainExec(RasaArgumentVo argument) throws IOException, ExecutionException, InterruptedException, TimeoutException;
String runExec( RasaArgument argument) throws ExecutionException, InterruptedException, TimeoutException;
String runExec( RasaArgumentVo argument) throws ExecutionException, InterruptedException, TimeoutException;
void test();
}

@ -1,9 +1,11 @@
package com.supervision.rasa.service;
import com.supervision.rasa.pojo.vo.RasaTalkVo;
import java.util.List;
public interface RasaTalkService {
List<String> talkRasa(String question, String sessionId) ;
List<String> talkRasa(RasaTalkVo rasaTalkVo) ;
}

@ -3,14 +3,17 @@ package com.supervision.rasa.service.impl;
import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.collection.ListUtil;
import cn.hutool.core.util.StrUtil;
import com.supervision.model.RasaModelInfo;
import com.supervision.rasa.config.ThreadPoolExecutorConfig;
import com.supervision.rasa.pojo.vo.RasaArgument;
import com.supervision.rasa.pojo.vo.RasaArgumentVo;
import com.supervision.rasa.service.RasaCmdService;
import com.supervision.rasa.util.PortUtil;
import com.supervision.service.RasaModeService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import java.io.*;
import java.util.ArrayList;
@ -51,8 +54,12 @@ public class RasaCmdServiceImpl implements RasaCmdService {
private final RasaModeService rasaModeService;
private final String TRAN_SUCCESS_MESSAGE = "Your Rasa model is trained and saved at";
private final String RUN_SUCCESS_MESSAGE = "Rasa server is up and running";
@Override
public String trainExec(RasaArgument argument) throws ExecutionException, InterruptedException, TimeoutException {
@Transactional
public String trainExec(RasaArgumentVo argument) throws ExecutionException, InterruptedException, TimeoutException {
String domain = dataPath+"domain.yml";
List<String> cmds = ListUtil.toList(shellEnv, trainShell,config,dataPath,domain,modelsPath);
@ -62,41 +69,60 @@ public class RasaCmdServiceImpl implements RasaCmdService {
List<String> outMessage = execCmd(cmds, s -> false, 300);
//保存 模型信息
RasaModelInfo rasaModelInfo = new RasaModelInfo();
rasaModelInfo.setModelId(argument.getModelId());
rasaModelInfo.setTranStatus(trainIsSuccess(outMessage)?1:0);
rasaModelInfo.setServerStatus(-1);
rasaModeService.saveOrUpdateByModelId(rasaModelInfo);
boolean isSuccess = trainIsSuccess(outMessage);
return String.join("\r\n",outMessage);
}
private boolean trainIsSuccess(List<String> messageList){
if (CollectionUtil.isEmpty(messageList)){
return false;
}
String keyWord = "Your Rasa model is trained and saved at";
if (StrUtil.isEmpty(keyWord)){
return false;
}
return messageList.stream().anyMatch(s->StrUtil.isNotEmpty(s) && s.contains(keyWord));
}
private boolean containKey(List<String> messageList,String key){
return true;
}
@Override
public String runExec(RasaArgument argument) throws ExecutionException, InterruptedException, TimeoutException {
public String runExec(RasaArgumentVo argument) throws ExecutionException, InterruptedException, TimeoutException {
// 1. 查找模型信息是否为空
RasaModelInfo dbRasaModelInfo = rasaModeService.queryByModelId(argument.getModelId());
int unusedPort;
if (null == dbRasaModelInfo || null == dbRasaModelInfo.getPort()){
unusedPort = PortUtil.findUnusedPort(5050, 100000);
log.info("runExec findUnusedPort is : {}",unusedPort);
argument.setPort(String.valueOf(unusedPort));
}else {
argument.setPort(String.valueOf(dbRasaModelInfo.getPort()));
// todo:杀掉该端口对应的进程
}
// 2. 运行模型
String mPath = modelsPath+argument.getFixedModelName()+".tar.gz";
List<String> cmds = ListUtil.toList(shellEnv, runShell,mPath,endpoints,argument.getPort());
log.info("runExec cmd : {}",StrUtil.join(" ",cmds));
return String.join("\r\n",execCmd(cmds,s-> StrUtil.isNotBlank(s)&& s.contains("Rasa server is up and running"),90));
List<String> outMessageList = execCmd(cmds, s -> StrUtil.isNotBlank(s) && s.contains(RUN_SUCCESS_MESSAGE), 300);
// 3. 更新模型信息
RasaModelInfo rasaModelInfo = new RasaModelInfo();
rasaModelInfo.setModelId(argument.getModelId());
rasaModelInfo.setPort(0);
rasaModelInfo.setServerStatus(runIsSuccess(outMessageList)?1:0);
rasaModeService.saveOrUpdateByModelId(rasaModelInfo);
return String.join("\r\n",outMessageList);
}
@Override
public void test() {
RasaModelInfo rasaModelInfo = new RasaModelInfo();
rasaModelInfo.setId("1");
rasaModelInfo.setModelId("1");
rasaModelInfo.setTranStatus(1);
rasaModeService.saveOrUpdate(rasaModelInfo);
}
private List<String> execCmd(List<String> cmds, Predicate<String> endPredicate, long timeOut) throws InterruptedException, ExecutionException, TimeoutException {
@ -127,4 +153,29 @@ public class RasaCmdServiceImpl implements RasaCmdService {
return future.get(timeOut, TimeUnit.SECONDS);
}
private boolean trainIsSuccess(List<String> messageList){
return containKey(messageList,TRAN_SUCCESS_MESSAGE);
}
private boolean runIsSuccess(List<String> messageList){
return containKey(messageList,RUN_SUCCESS_MESSAGE);
}
private boolean containKey(List<String> messageList,String keyWord){
if (CollectionUtil.isEmpty(messageList)){
return false;
}
if (StrUtil.isEmpty(keyWord)){
return false;
}
return messageList.stream().anyMatch(s->StrUtil.isNotEmpty(s) && s.contains(keyWord));
}
}

@ -1,16 +1,23 @@
package com.supervision.rasa.service.impl;
import cn.hutool.core.text.CharSequenceUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.http.HttpUtil;
import cn.hutool.json.JSONUtil;
import com.supervision.exception.BusinessException;
import com.supervision.model.RasaModelInfo;
import com.supervision.rasa.pojo.dto.RasaReqDTO;
import com.supervision.rasa.pojo.dto.RasaResDTO;
import com.supervision.rasa.pojo.vo.RasaTalkVo;
import com.supervision.rasa.service.RasaTalkService;
import com.supervision.service.RasaModeService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import java.util.List;
import java.util.function.Predicate;
import java.util.stream.Collectors;
@Service
@Slf4j
@ -19,14 +26,32 @@ public class RasaTalkServiceImpl implements RasaTalkService {
@Value("${rasa.url}")
private String rasaUrl;
private final RasaModeService rasaModeService;
@Override
public List<String> talkRasa(String question, String sessionId) {
public List<String> talkRasa(RasaTalkVo rasaTalkVo) {
if (StrUtil.isEmpty(rasaTalkVo.getModelId())){
throw new BusinessException("modelId is not allow empty ");
}
RasaModelInfo rasaModelInfo = rasaModeService.queryByModelId(rasaTalkVo.getModelId());
if (null == rasaModelInfo){
throw new BusinessException(" not find model info , check if modelId is available ");
}else if (null == rasaModelInfo.getPort()){
throw new BusinessException(" rasa model port is empty , check if rasa is start ");
}
RasaReqDTO rasaReqDTO = new RasaReqDTO();
rasaReqDTO.setSender(sessionId);
rasaReqDTO.setMessage(question);
String post = HttpUtil.post(rasaUrl, JSONUtil.toJsonStr(rasaReqDTO));
rasaReqDTO.setSender(rasaTalkVo.getSessionId());
rasaReqDTO.setMessage(rasaTalkVo.getQuestion());
String post = HttpUtil.post(getRasaUrl(rasaModelInfo.getPort()), JSONUtil.toJsonStr(rasaReqDTO));
List<RasaResDTO> list = JSONUtil.toList(post, RasaResDTO.class);
return list.stream().map(RasaResDTO::getText).collect(Collectors.toList());
}
private String getRasaUrl(int port){
return StrUtil.format(rasaUrl, port);
}
}

@ -18,7 +18,7 @@ rasa:
config: /rasa/config-local.yml # 启动rasa需要的配置文件在配置文件中配置
train-shell: /home/rasa_manage/train.sh
run-shell: /home/rasa_manage/run.sh
url: 192.168.10.137:5005/webhooks/rest/webhook
url: 192.168.10.137:{}/webhooks/rest/webhook
spring:
profiles:
active: dev

Loading…
Cancel
Save