rasa 优化训练和启动接口

dev_v1.0.1
xueqingkun 2 years ago
parent 3b2d6c3adb
commit 7245b593b2

@ -11,7 +11,6 @@
<groupId>com.superversion</groupId>
<artifactId>virtual-patient-rasa</artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>virtual-patient-rasa</name>
<description>virtual-patient-rasa</description>
<packaging>jar</packaging>
@ -21,6 +20,13 @@
</properties>
<dependencies>
<dependency>
<groupId>com.supervision</groupId>
<artifactId>virtual-patient-common</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>

@ -1,9 +1,13 @@
package com.superversion.rasa;
import com.supervision.config.WebConfig;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.ComponentScan;
import org.springframework.context.annotation.FilterType;
@SpringBootApplication
@ComponentScan(basePackages = {"com.superversion.rasa"},excludeFilters = @ComponentScan.Filter(type = FilterType.ASSIGNABLE_TYPE, classes = {WebConfig.class}))
public class VirtualPatientRasaApplication {
public static void main(String[] args) {

@ -1,102 +0,0 @@
package com.superversion.rasa.config;
import cn.hutool.json.JSONUtil;
import com.superversion.rasa.domian.GlobalResult;
import com.superversion.rasa.exception.BusinessException;
import lombok.extern.slf4j.Slf4j;
import org.springframework.core.MethodParameter;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.lang.Nullable;
import org.springframework.validation.BindException;
import org.springframework.web.bind.MethodArgumentNotValidException;
import org.springframework.web.bind.annotation.ExceptionHandler;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.bind.annotation.RestControllerAdvice;
import org.springframework.web.servlet.mvc.method.annotation.ResponseBodyAdvice;
import java.util.Objects;
import java.util.stream.Collectors;
/**
*
*
* @author wb
* @date 2022/3/10 13:24
*/
@Slf4j
@RestControllerAdvice(annotations = RestController.class, basePackages = {"com.**.controller"})
public class ResponseConfig implements ResponseBodyAdvice<Object> {
@Override
public boolean supports(@Nullable MethodParameter methodParameter,
@Nullable Class<? extends HttpMessageConverter<?>> aClass) {
assert methodParameter != null;
return !methodParameter.getDeclaringClass().getName().contains("swagger");
}
@Override
public Object beforeBodyWrite(Object o, @Nullable MethodParameter methodParameter, @Nullable MediaType mediaType,
@Nullable Class<? extends HttpMessageConverter<?>> aClass, @Nullable ServerHttpRequest serverHttpRequest,
@Nullable ServerHttpResponse serverHttpResponse) {
if (Objects.isNull(o)) {
return JSONUtil.toJsonStr(GlobalResult.ok(null, "success"));
}
if (o instanceof GlobalResult) {
return o;
}
// 对于String类型的返回值需要进行特殊处理
if (o instanceof String) {
return JSONUtil.toJsonStr(GlobalResult.ok(o, "success"));
}
return GlobalResult.ok(o, "success");
}
/**
*
*
* @param exception
* @return
*/
@ExceptionHandler(BusinessException.class)
public GlobalResult<?> businessExceptionResponse(BusinessException exception) {
log.error(exception.getMessage(), exception);
return GlobalResult.error(HttpStatus.INTERNAL_SERVER_ERROR.value(), exception.getMessage(), "业务异常");
}
/**
*
*
* @param exception
* @return
*/
@ExceptionHandler({MethodArgumentNotValidException.class, BindException.class})
public GlobalResult<?> validationExceptionResponse(MethodArgumentNotValidException exception) {
log.error(exception.getMessage(), exception);
// 格式化错误信息
String errorMsg = exception.getBindingResult().getFieldErrors().stream()
.map(e -> e.getField() + ":" + e.getDefaultMessage()).collect(Collectors.joining("、"));
return GlobalResult.error(HttpStatus.INTERNAL_SERVER_ERROR.value(), "参数验证异常", errorMsg);
}
/**
*
*
* @param exception
* @return
*/
@ExceptionHandler(Exception.class)
public GlobalResult<?> validationExceptionResponse(Exception exception) {
log.error(exception.getMessage(), exception);
return GlobalResult.error(HttpStatus.INTERNAL_SERVER_ERROR.value(), "未知错误", exception.getMessage());
}
}

@ -1,9 +1,9 @@
package com.superversion.rasa.controller;
import cn.hutool.core.util.StrUtil;
import com.superversion.rasa.exception.BusinessException;
import com.superversion.rasa.pojo.vo.RasaArgument;
import com.superversion.rasa.service.RasaCmdService;
import com.supervision.exception.BusinessException;
import io.swagger.annotations.Api;
import io.swagger.annotations.ApiOperation;
import lombok.RequiredArgsConstructor;

@ -0,0 +1,30 @@
package com.superversion.rasa.controller;
import com.superversion.rasa.service.RasaTalkService;
import io.swagger.annotations.Api;
import io.swagger.annotations.ApiOperation;
import lombok.RequiredArgsConstructor;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import java.util.List;
@Api(tags = "rasa文件保存")
@RestController
@RequestMapping("rasaFile")
@RequiredArgsConstructor
public class RasaTalkController {
@Autowired
private RasaTalkService rasaTalkService;
@ApiOperation("rasa对话")
@GetMapping("talkRasa")
public List<String> talkRasa(String question, String sessionId){
return null;
}
}

@ -1,49 +0,0 @@
package com.superversion.rasa.domian;
import io.swagger.annotations.ApiModel;
import io.swagger.annotations.ApiModelProperty;
import lombok.Data;
import org.springframework.http.HttpStatus;
@Data
@ApiModel
public class GlobalResult<T> {
private int code = 200;
private String msg = "success";
@ApiModelProperty
private T data;
public static <T> GlobalResult<T> ok() {
return ok(null);
}
public static <T> GlobalResult<T> ok(T data) {
GlobalResult<T> globalResult = new GlobalResult<>();
globalResult.setData(data);
return globalResult;
}
public static <T> GlobalResult<T> ok(T data, String message) {
GlobalResult<T> globalResult = new GlobalResult<>();
globalResult.setMsg(message);
globalResult.setData(data);
return globalResult;
}
public static <T> GlobalResult<T> error(String msg) {
return error(HttpStatus.INTERNAL_SERVER_ERROR.value(), null, msg);
}
public static <T> GlobalResult<T> error(int code, T data, String msg) {
GlobalResult<T> globalResult = new GlobalResult<>();
globalResult.setCode(code);
globalResult.setData(data);
globalResult.setMsg(msg);
return globalResult;
}
}

@ -1,76 +0,0 @@
/*
* : CustomException
* :
* : <>
* : RedName
* : 2022/8/5
* : <>
* : <>
* : <>
*/
package com.superversion.rasa.exception;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.HttpStatus;
/**
* <>
*
*
* @author ljt
* @version [, 2022/8/5]
* @see [/]
* @since [/]
*/
@Slf4j
public class BusinessException extends RuntimeException {
/**
*
*/
private final Integer code;
/**
*
*/
private final String message;
public BusinessException(Throwable cause) {
super(cause);
this.code = HttpStatus.INTERNAL_SERVER_ERROR.value();
this.message = null;
}
public BusinessException(Throwable cause, String message) {
super(cause);
this.code = HttpStatus.INTERNAL_SERVER_ERROR.value();
this.message = message;
}
public BusinessException(String message) {
this.code = HttpStatus.INTERNAL_SERVER_ERROR.value();
this.message = message;
}
public BusinessException(String message, Integer code) {
this.message = message;
this.code = code;
}
public BusinessException(String message, Throwable e) {
super(message, e);
log.error(message, e);
this.code = HttpStatus.INTERNAL_SERVER_ERROR.value();
this.message = message;
}
@Override
public String getMessage() {
return message;
}
public Integer getCode() {
return code;
}
}

@ -0,0 +1,13 @@
package com.superversion.rasa.pojo.dto;
import lombok.Data;
@Data
public class RasaReqDTO {
private String sender;
private String message;
}

@ -0,0 +1,11 @@
package com.superversion.rasa.pojo.dto;
import lombok.Data;
@Data
public class RasaResDTO {
private String recipient_id;
private String text;
}

@ -0,0 +1,9 @@
package com.superversion.rasa.service;
import java.util.List;
public interface RasaTalkService {
List<String> talkRasa(String question, String sessionId) ;
}

@ -75,19 +75,19 @@ public class RasaCmdServiceImpl implements RasaCmdService {
private List<String> execCmd(List<String> cmds, Predicate<String> endPredicate, long timeOut) throws InterruptedException, ExecutionException, TimeoutException {
ProcessBuilder processBuilder = new ProcessBuilder(cmds);
processBuilder.directory(new File(shellWork)); // 设置工作目录
processBuilder.redirectErrorStream(true); // 合并标准输出和错误输出
processBuilder.directory(new File(shellWork));
processBuilder.redirectErrorStream(true);
ThreadPoolExecutor instance = ThreadPoolExecutorConfig.getInstance();
Future<List<String>> future = instance.submit(() -> {
Process process = processBuilder.start(); // 启动进程
InputStream inputStream = process.getInputStream(); // 获取进程的输出流
Process process = processBuilder.start();
InputStream inputStream = process.getInputStream();
Reader reader = new InputStreamReader(inputStream, "UTF-8");
BufferedReader bufferedReader = new BufferedReader(reader);
List<String> outString = new ArrayList<>();
String resultLines = bufferedReader.readLine();
while( resultLines != null) {
resultLines = bufferedReader.readLine(); // 读取下一行
resultLines = bufferedReader.readLine();
log.info("resultLines:{}",resultLines);
outString.add(resultLines);
if (endPredicate.test(resultLines)){

@ -0,0 +1,32 @@
package com.superversion.rasa.service.impl;
import cn.hutool.http.HttpUtil;
import cn.hutool.json.JSONUtil;
import com.superversion.rasa.pojo.dto.RasaReqDTO;
import com.superversion.rasa.pojo.dto.RasaResDTO;
import com.superversion.rasa.service.RasaTalkService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import java.util.List;
import java.util.stream.Collectors;
@Service
@Slf4j
@RequiredArgsConstructor
public class RasaTalkServiceImpl implements RasaTalkService {
@Value("${rasa.url}")
private String rasaUrl;
@Override
public List<String> talkRasa(String question, String sessionId) {
RasaReqDTO rasaReqDTO = new RasaReqDTO();
rasaReqDTO.setSender(sessionId);
rasaReqDTO.setMessage(question);
String post = HttpUtil.post(rasaUrl, JSONUtil.toJsonStr(rasaReqDTO));
List<RasaResDTO> list = JSONUtil.toList(post, RasaResDTO.class);
return list.stream().map(RasaResDTO::getText).collect(Collectors.toList());
}
}

@ -1,7 +1,15 @@
server:
port: 8082
port: 8890
servlet:
context-path: /
undertow:
# HTTP post内容的最大大小。当值为-1时默认值为大小是无限的
max-http-post-size: -1
# 以下的配置会影响buffer,这些buffer会用于服务器连接的IO操作,有点类似netty的池化内存管理
# 每块buffer的空间大小,越小的空间被利用越充分
buffer-size: 512
# 是否分配的直接内存
direct-buffers: true
rasa:
data-path: /rasa/v3_jiazhuangxian/ # 文件解压后存放位置
@ -9,4 +17,49 @@ rasa:
endpoints: /rasa/endpoints.yml # 启动的配置项,应该是写在配置文件里面
config: /rasa/config-local.yml # 启动rasa需要的配置文件在配置文件中配置
train-shell: /home/rasa_manage/train.sh
run-shell: /home/rasa_manage/run.sh
run-shell: /home/rasa_manage/run.sh
url: 192.168.10.137:5005/webhooks/rest/webhook
spring:
profiles:
active: dev
application:
name: virtual-patient
servlet:
multipart:
max-file-size: 100MB
max-request-size: 100MB
##数据源配置
datasource:
type: com.alibaba.druid.pool.DruidDataSource
druid:
driver-class-name: com.mysql.cj.jdbc.Driver
url: jdbc:mysql://192.168.10.138:3306/virtual_patient?useUnicode=true&characterEncoding=utf-8&useSSL=true&nullCatalogMeansCurrent=true&serverTimezone=GMT%2B8
username: root
password: '123456'
initial-size: 5 # 初始化大小
min-idle: 10 # 最小连接数
max-active: 20 # 最大连接数
max-wait: 60000 # 获取连接时的最大等待时间
min-evictable-idle-time-millis: 300000 # 一个连接在池中最小生存的时间,单位是毫秒
time-between-eviction-runs-millis: 60000 # 多久才进行一次检测需要关闭的空闲连接,单位是毫秒
filters: stat,wall # 配置扩展插件stat-监控统计log4j-日志wall-防火墙防止SQL注入去掉后监控界面的sql无法统计
validation-query: SELECT 1 # 检测连接是否有效的 SQL语句为空时以下三个配置均无效
test-on-borrow: true # 申请连接时执行validationQuery检测连接是否有效默认true开启后会降低性能
test-on-return: true # 归还连接时执行validationQuery检测连接是否有效默认false开启后会降低性能
test-while-idle: true # 申请连接时如果空闲时间大于timeBetweenEvictionRunsMillis执行validationQuery检测连接是否有效默认false建议开启不影响性能
stat-view-servlet:
enabled: false # 是否开启 StatViewServlet
filter:
stat:
enabled: true # 是否开启 FilterStat默认true
log-slow-sql: true # 是否开启 慢SQL 记录默认false
slow-sql-millis: 5000 # 慢 SQL 的标准,默认 3000单位毫秒
merge-sql: false # 合并多个连接池的监控数据默认false
mybatis-plus:
mapper-locations: classpath*:mapper/**/*.xml
configuration:
log-impl: org.apache.ibatis.logging.stdout.StdOutImpl
Loading…
Cancel
Save