rasa 草稿

dev_v1.0.1
xueqingkun 2 years ago
parent 7245b593b2
commit 32c64a742d

@ -0,0 +1,11 @@
package com.supervision.mapper;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import com.supervision.model.Process;
import com.supervision.model.RasaModelInfo;
/**
*
*/
public interface RasaModelInfoMapper extends BaseMapper<RasaModelInfo> {
}

@ -0,0 +1,83 @@
package com.supervision.model;
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import io.swagger.annotations.ApiModel;
import io.swagger.annotations.ApiModelProperty;
import lombok.Data;
import java.io.Serializable;
import java.time.LocalDateTime;
/**
*
* @TableName vp_rasa_model_info
*/
@TableName(value ="vp_rasa_model_info")
@Data
@ApiModel
public class RasaModelInfo implements Serializable {
/**
*
*/
@TableId
private String id;
/**
* ID vp_disease
*/
@ApiModelProperty("模型ID")
private String modelId;
/**
*
*/
@ApiModelProperty("中文注释")
private String description;
/**
*
*/
@ApiModelProperty("模型对应的端口号")
private Integer port;
/**
* 0: 1:
*/
@ApiModelProperty("训练状态")
private Integer tranStatus;
/**
* -1: 0: 1:
*/
@ApiModelProperty("启动状态")
private Integer serverStatus;
/**
* ID
*/
private String createUserId;
/**
*
*/
private LocalDateTime createTime;
/**
*
*/
private String updateUserId;
/**
*
*/
private LocalDateTime updateTime;
@TableField(exist = false)
private static final long serialVersionUID = 1L;
}

@ -0,0 +1,9 @@
package com.supervision.service;
import com.baomidou.mybatisplus.extension.service.IService;
import com.supervision.model.RasaModelInfo;
public interface RasaModeService extends IService<RasaModelInfo> {
RasaModelInfo queryByModelId(String modelId);
}

@ -0,0 +1,21 @@
package com.supervision.service.impl;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.supervision.mapper.RasaModelInfoMapper;
import com.supervision.model.RasaModelInfo;
import com.supervision.service.RasaModeService;
import org.springframework.stereotype.Service;
@Service
public class RasaModeServiceImpl extends ServiceImpl<RasaModelInfoMapper,RasaModelInfo> implements RasaModeService {
@Override
public RasaModelInfo queryByModelId(String modelId){
LambdaQueryWrapper<RasaModelInfo> queryWrapper = new LambdaQueryWrapper<>();
queryWrapper.eq(RasaModelInfo::getModelId, modelId).last("LIMIT 1");
return getOne(queryWrapper);
}
}

@ -0,0 +1,26 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE mapper
PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN"
"http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<mapper namespace="com.supervision.mapper.RasaModelInfoMapper">
<resultMap id="BaseResultMap" type="com.supervision.model.RasaModelInfo">
<id property="id" column="id" jdbcType="VARCHAR"/>
<result property="modelId" column="model_id" jdbcType="VARCHAR"/>
<result property="description" column="description" jdbcType="VARCHAR"/>
<result property="port" column="port" jdbcType="INTEGER"/>
<result property="tranStatus" column="tran_status" jdbcType="INTEGER"/>
<result property="serverStatus" column="server_status" jdbcType="INTEGER"/>
<result property="createUserId" column="create_user_id" jdbcType="VARCHAR"/>
<result property="createTime" column="create_time" jdbcType="TIMESTAMP"/>
<result property="updateUserId" column="update_user_id" jdbcType="VARCHAR"/>
<result property="updateTime" column="update_time" jdbcType="TIMESTAMP"/>
</resultMap>
<sql id="Base_Column_List">
id,model_id,description,
port,tran_status,server_status,create_user_id,create_time,
update_user_id,update_time
</sql>
</mapper>

@ -27,6 +27,12 @@
<version>${project.version}</version> <version>${project.version}</version>
</dependency> </dependency>
<dependency>
<groupId>com.supervision</groupId>
<artifactId>virtual-patient-model</artifactId>
<version>${project.version}</version>
</dependency>
<dependency> <dependency>
<groupId>org.springframework.boot</groupId> <groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId> <artifactId>spring-boot-starter-web</artifactId>

@ -1,13 +1,15 @@
package com.superversion.rasa; package com.superversion.rasa;
import com.supervision.config.WebConfig; import com.supervision.config.WebConfig;
import org.mybatis.spring.annotation.MapperScan;
import org.springframework.boot.SpringApplication; import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.ComponentScan; import org.springframework.context.annotation.ComponentScan;
import org.springframework.context.annotation.FilterType; import org.springframework.context.annotation.FilterType;
@SpringBootApplication @SpringBootApplication
@ComponentScan(basePackages = {"com.superversion.rasa"},excludeFilters = @ComponentScan.Filter(type = FilterType.ASSIGNABLE_TYPE, classes = {WebConfig.class})) @MapperScan(basePackages = {"com.supervision.**.mapper"})
@ComponentScan(basePackages = {"com.superversion"},excludeFilters = @ComponentScan.Filter(type = FilterType.ASSIGNABLE_TYPE, classes = {WebConfig.class}))
public class VirtualPatientRasaApplication { public class VirtualPatientRasaApplication {
public static void main(String[] args) { public static void main(String[] args) {

@ -11,6 +11,7 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*; import org.springframework.web.bind.annotation.*;
import java.io.*; import java.io.*;
import java.net.ServerSocket;
import java.util.concurrent.*; import java.util.concurrent.*;
@Api(tags = "rasa文件保存") @Api(tags = "rasa文件保存")
@ -26,6 +27,14 @@ public class RasaCmdController {
@PostMapping("/trainExec") @PostMapping("/trainExec")
public String trainExec(@RequestBody RasaArgument argument) throws IOException, ExecutionException, InterruptedException, TimeoutException { public String trainExec(@RequestBody RasaArgument argument) throws IOException, ExecutionException, InterruptedException, TimeoutException {
if (StrUtil.isEmpty(argument.getFixedModelName())){
throw new BusinessException("fixedModelName参数不能为空");
}
if (StrUtil.isEmpty(argument.getModelId())){
throw new BusinessException("modelId参数不能为空! ");
}
return rasaCmdService.trainExec(argument); return rasaCmdService.trainExec(argument);
} }
@ -34,6 +43,12 @@ public class RasaCmdController {
@PostMapping("/runExec") @PostMapping("/runExec")
public String runExec(@RequestBody RasaArgument argument) throws ExecutionException, InterruptedException, TimeoutException { public String runExec(@RequestBody RasaArgument argument) throws ExecutionException, InterruptedException, TimeoutException {
if (StrUtil.isEmpty(argument.getFixedModelName())){
throw new BusinessException("fixedModelName参数不能为空");
}
if (StrUtil.isEmpty(argument.getModelId())){
throw new BusinessException("modelId参数不能为空! ");
}
String outString = rasaCmdService.runExec(argument); String outString = rasaCmdService.runExec(argument);
if (StrUtil.isEmptyIfStr(outString) || !outString.contains("Rasa server is up and running")){ if (StrUtil.isEmptyIfStr(outString) || !outString.contains("Rasa server is up and running")){
throw new BusinessException("任务执行异常。详细日志:"+outString); throw new BusinessException("任务执行异常。详细日志:"+outString);
@ -42,5 +57,13 @@ public class RasaCmdController {
} }
@ApiOperation("执行启动shell命令")
@PostMapping("/test")
public String test(@RequestBody RasaArgument argument) throws ExecutionException, InterruptedException, TimeoutException {
rasaCmdService.test();
return "outString";
}
} }

@ -13,5 +13,6 @@ public class RasaArgument {
private String enableApi;//enable-api private String enableApi;//enable-api
private String endpoints; private String endpoints;
private String port; private String port;
private String modelId;
} }

@ -11,4 +11,6 @@ public interface RasaCmdService {
String trainExec(RasaArgument argument) throws IOException, ExecutionException, InterruptedException, TimeoutException; String trainExec(RasaArgument argument) throws IOException, ExecutionException, InterruptedException, TimeoutException;
String runExec( RasaArgument argument) throws ExecutionException, InterruptedException, TimeoutException; String runExec( RasaArgument argument) throws ExecutionException, InterruptedException, TimeoutException;
void test();
} }

@ -1,12 +1,20 @@
package com.superversion.rasa.service.impl; package com.superversion.rasa.service.impl;
import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.collection.ListUtil; import cn.hutool.core.collection.ListUtil;
import cn.hutool.core.util.StrUtil; import cn.hutool.core.util.StrUtil;
import com.superversion.rasa.config.ThreadPoolExecutorConfig; import com.superversion.rasa.config.ThreadPoolExecutorConfig;
import com.superversion.rasa.pojo.vo.RasaArgument; import com.superversion.rasa.pojo.vo.RasaArgument;
import com.superversion.rasa.service.RasaCmdService; import com.superversion.rasa.service.RasaCmdService;
import com.supervision.model.RasaModelInfo;
import com.supervision.model.User;
import com.supervision.service.RasaModeService;
import com.supervision.service.UserService;
import com.supervision.service.impl.UserServiceImpl;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.omg.CORBA.TRANSACTION_MODE;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
@ -47,6 +55,8 @@ public class RasaCmdServiceImpl implements RasaCmdService {
@Value("${rasa.shell.work:/home/rasa_manage/}") @Value("${rasa.shell.work:/home/rasa_manage/}")
private String shellWork; private String shellWork;
private final UserService userService;
@Override @Override
public String trainExec(RasaArgument argument) throws ExecutionException, InterruptedException, TimeoutException { public String trainExec(RasaArgument argument) throws ExecutionException, InterruptedException, TimeoutException {
@ -55,9 +65,27 @@ public class RasaCmdServiceImpl implements RasaCmdService {
cmds.add(argument.getFixedModelName()); cmds.add(argument.getFixedModelName());
log.info("trainExec cmd : {}",StrUtil.join(" ",cmds)); log.info("trainExec cmd : {}",StrUtil.join(" ",cmds));
return String.join("\r\n",execCmd(cmds,s->false,90));
List<String> outMessage = execCmd(cmds, s -> false, 300);
boolean isSuccess = trainIsSuccess(outMessage);
return String.join("\r\n",outMessage);
} }
private boolean trainIsSuccess(List<String> messageList){
if (CollectionUtil.isEmpty(messageList)){
return false;
}
String keyWord = "Your Rasa model is trained and saved at";
if (StrUtil.isEmpty(keyWord)){
return false;
}
return messageList.stream().anyMatch(s->StrUtil.isNotEmpty(s) && s.contains(keyWord));
}
@Override @Override
@ -71,6 +99,13 @@ public class RasaCmdServiceImpl implements RasaCmdService {
return String.join("\r\n",execCmd(cmds,s-> StrUtil.isNotBlank(s)&& s.contains("Rasa server is up and running"),90)); return String.join("\r\n",execCmd(cmds,s-> StrUtil.isNotBlank(s)&& s.contains("Rasa server is up and running"),90));
} }
@Override
public void test() {
List<User> list = userService.list();
//RasaModelInfo rasaModelInfo = rasaModeService.queryByModelId("1");
System.out.println(list);
}
private List<String> execCmd(List<String> cmds, Predicate<String> endPredicate, long timeOut) throws InterruptedException, ExecutionException, TimeoutException { private List<String> execCmd(List<String> cmds, Predicate<String> endPredicate, long timeOut) throws InterruptedException, ExecutionException, TimeoutException {

@ -0,0 +1,40 @@
package com.superversion.rasa.util;
import jdk.nashorn.internal.runtime.regexp.JoniRegExp;
import lombok.extern.slf4j.Slf4j;
import java.io.IOException;
import java.net.Socket;
import java.util.TreeMap;
import java.util.function.Predicate;
@Slf4j
public class PortUtil {
public static boolean portIsActive(int port){
try {
Socket socket = new Socket("localhost", port);
socket.close();
return true;
} catch (IOException e) {
log.info("portIsActive: port:{} connect error",port);
}
return false;
}
public static int findUnusedPort(int minPort,int maxPort){
if (maxPort < minPort){
return -1;
}
for (int port = minPort; port < maxPort; port++) {
if (!portIsActive(port)){
return port;
}
}
return -1;
}
}

@ -7,10 +7,10 @@ import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil; import cn.hutool.json.JSONUtil;
import com.supervision.exception.BusinessException; import com.supervision.exception.BusinessException;
import com.supervision.model.*; import com.supervision.model.*;
import com.supervision.rasa.dto.train.DomainYmlTemplate; import com.supervision.pojo.rasa.train.DomainYmlTemplate;
import com.supervision.rasa.dto.train.QuestionAnswerDTO; import com.supervision.pojo.rasa.train.NluYmlTemplate;
import com.supervision.rasa.dto.train.NluYmlTemplate; import com.supervision.pojo.rasa.train.QuestionAnswerDTO;
import com.supervision.rasa.dto.train.RuleYmlTemplate; import com.supervision.pojo.rasa.train.RuleYmlTemplate;
import com.supervision.service.*; import com.supervision.service.*;
import freemarker.template.Configuration; import freemarker.template.Configuration;
import freemarker.template.Template; import freemarker.template.Template;

Loading…
Cancel
Save