rasa 优化训练和启动接口
parent
3b2d6c3adb
commit
7245b593b2
@ -1,102 +0,0 @@
|
|||||||
package com.superversion.rasa.config;
|
|
||||||
|
|
||||||
import cn.hutool.json.JSONUtil;
|
|
||||||
import com.superversion.rasa.domian.GlobalResult;
|
|
||||||
import com.superversion.rasa.exception.BusinessException;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.springframework.core.MethodParameter;
|
|
||||||
import org.springframework.http.HttpStatus;
|
|
||||||
import org.springframework.http.MediaType;
|
|
||||||
import org.springframework.http.converter.HttpMessageConverter;
|
|
||||||
import org.springframework.http.server.ServerHttpRequest;
|
|
||||||
import org.springframework.http.server.ServerHttpResponse;
|
|
||||||
import org.springframework.lang.Nullable;
|
|
||||||
import org.springframework.validation.BindException;
|
|
||||||
import org.springframework.web.bind.MethodArgumentNotValidException;
|
|
||||||
import org.springframework.web.bind.annotation.ExceptionHandler;
|
|
||||||
import org.springframework.web.bind.annotation.RestController;
|
|
||||||
import org.springframework.web.bind.annotation.RestControllerAdvice;
|
|
||||||
import org.springframework.web.servlet.mvc.method.annotation.ResponseBodyAdvice;
|
|
||||||
|
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 统一返回
|
|
||||||
*
|
|
||||||
* @author wb
|
|
||||||
* @date 2022/3/10 13:24
|
|
||||||
*/
|
|
||||||
@Slf4j
|
|
||||||
@RestControllerAdvice(annotations = RestController.class, basePackages = {"com.**.controller"})
|
|
||||||
public class ResponseConfig implements ResponseBodyAdvice<Object> {
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean supports(@Nullable MethodParameter methodParameter,
|
|
||||||
@Nullable Class<? extends HttpMessageConverter<?>> aClass) {
|
|
||||||
assert methodParameter != null;
|
|
||||||
return !methodParameter.getDeclaringClass().getName().contains("swagger");
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Object beforeBodyWrite(Object o, @Nullable MethodParameter methodParameter, @Nullable MediaType mediaType,
|
|
||||||
@Nullable Class<? extends HttpMessageConverter<?>> aClass, @Nullable ServerHttpRequest serverHttpRequest,
|
|
||||||
@Nullable ServerHttpResponse serverHttpResponse) {
|
|
||||||
|
|
||||||
|
|
||||||
if (Objects.isNull(o)) {
|
|
||||||
return JSONUtil.toJsonStr(GlobalResult.ok(null, "success"));
|
|
||||||
}
|
|
||||||
if (o instanceof GlobalResult) {
|
|
||||||
return o;
|
|
||||||
}
|
|
||||||
// 对于String类型的返回值需要进行特殊处理
|
|
||||||
if (o instanceof String) {
|
|
||||||
return JSONUtil.toJsonStr(GlobalResult.ok(o, "success"));
|
|
||||||
}
|
|
||||||
return GlobalResult.ok(o, "success");
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 业务异常处理
|
|
||||||
*
|
|
||||||
* @param exception 业务异常
|
|
||||||
* @return 通用返回值
|
|
||||||
*/
|
|
||||||
@ExceptionHandler(BusinessException.class)
|
|
||||||
public GlobalResult<?> businessExceptionResponse(BusinessException exception) {
|
|
||||||
log.error(exception.getMessage(), exception);
|
|
||||||
return GlobalResult.error(HttpStatus.INTERNAL_SERVER_ERROR.value(), exception.getMessage(), "业务异常");
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 参数验证异常处理
|
|
||||||
*
|
|
||||||
* @param exception 参数验证异常
|
|
||||||
* @return 通用返回值
|
|
||||||
*/
|
|
||||||
@ExceptionHandler({MethodArgumentNotValidException.class, BindException.class})
|
|
||||||
public GlobalResult<?> validationExceptionResponse(MethodArgumentNotValidException exception) {
|
|
||||||
log.error(exception.getMessage(), exception);
|
|
||||||
// 格式化错误信息
|
|
||||||
String errorMsg = exception.getBindingResult().getFieldErrors().stream()
|
|
||||||
.map(e -> e.getField() + ":" + e.getDefaultMessage()).collect(Collectors.joining("、"));
|
|
||||||
return GlobalResult.error(HttpStatus.INTERNAL_SERVER_ERROR.value(), "参数验证异常", errorMsg);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 未知异常处理
|
|
||||||
*
|
|
||||||
* @param exception 未知异常
|
|
||||||
* @return 通用返回值
|
|
||||||
*/
|
|
||||||
@ExceptionHandler(Exception.class)
|
|
||||||
public GlobalResult<?> validationExceptionResponse(Exception exception) {
|
|
||||||
log.error(exception.getMessage(), exception);
|
|
||||||
return GlobalResult.error(HttpStatus.INTERNAL_SERVER_ERROR.value(), "未知错误", exception.getMessage());
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
@ -0,0 +1,30 @@
|
|||||||
|
package com.superversion.rasa.controller;
|
||||||
|
|
||||||
|
import com.superversion.rasa.service.RasaTalkService;
|
||||||
|
import io.swagger.annotations.Api;
|
||||||
|
import io.swagger.annotations.ApiOperation;
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.web.bind.annotation.GetMapping;
|
||||||
|
import org.springframework.web.bind.annotation.RequestMapping;
|
||||||
|
import org.springframework.web.bind.annotation.RestController;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
@Api(tags = "rasa文件保存")
|
||||||
|
@RestController
|
||||||
|
@RequestMapping("rasaFile")
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
public class RasaTalkController {
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private RasaTalkService rasaTalkService;
|
||||||
|
|
||||||
|
@ApiOperation("rasa对话")
|
||||||
|
@GetMapping("talkRasa")
|
||||||
|
public List<String> talkRasa(String question, String sessionId){
|
||||||
|
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -1,49 +0,0 @@
|
|||||||
package com.superversion.rasa.domian;
|
|
||||||
|
|
||||||
import io.swagger.annotations.ApiModel;
|
|
||||||
import io.swagger.annotations.ApiModelProperty;
|
|
||||||
import lombok.Data;
|
|
||||||
import org.springframework.http.HttpStatus;
|
|
||||||
|
|
||||||
@Data
|
|
||||||
@ApiModel
|
|
||||||
public class GlobalResult<T> {
|
|
||||||
|
|
||||||
private int code = 200;
|
|
||||||
|
|
||||||
private String msg = "success";
|
|
||||||
|
|
||||||
@ApiModelProperty
|
|
||||||
private T data;
|
|
||||||
|
|
||||||
|
|
||||||
public static <T> GlobalResult<T> ok() {
|
|
||||||
return ok(null);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static <T> GlobalResult<T> ok(T data) {
|
|
||||||
GlobalResult<T> globalResult = new GlobalResult<>();
|
|
||||||
globalResult.setData(data);
|
|
||||||
return globalResult;
|
|
||||||
}
|
|
||||||
|
|
||||||
public static <T> GlobalResult<T> ok(T data, String message) {
|
|
||||||
GlobalResult<T> globalResult = new GlobalResult<>();
|
|
||||||
globalResult.setMsg(message);
|
|
||||||
globalResult.setData(data);
|
|
||||||
return globalResult;
|
|
||||||
}
|
|
||||||
|
|
||||||
public static <T> GlobalResult<T> error(String msg) {
|
|
||||||
return error(HttpStatus.INTERNAL_SERVER_ERROR.value(), null, msg);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
public static <T> GlobalResult<T> error(int code, T data, String msg) {
|
|
||||||
GlobalResult<T> globalResult = new GlobalResult<>();
|
|
||||||
globalResult.setCode(code);
|
|
||||||
globalResult.setData(data);
|
|
||||||
globalResult.setMsg(msg);
|
|
||||||
return globalResult;
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,76 +0,0 @@
|
|||||||
/*
|
|
||||||
* 文 件 名: CustomException
|
|
||||||
* 版 权:
|
|
||||||
* 描 述: <描述>
|
|
||||||
* 修 改 人: RedName
|
|
||||||
* 修改时间: 2022/8/5
|
|
||||||
* 跟踪单号: <跟踪单号>
|
|
||||||
* 修改单号: <修改单号>
|
|
||||||
* 修改内容: <修改内容>
|
|
||||||
*/
|
|
||||||
package com.superversion.rasa.exception;
|
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.springframework.http.HttpStatus;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* <功能详细描述>
|
|
||||||
* 自定义异常
|
|
||||||
*
|
|
||||||
* @author ljt
|
|
||||||
* @version [版本号, 2022/8/5]
|
|
||||||
* @see [相关类/方法]
|
|
||||||
* @since [产品/模块版本]
|
|
||||||
*/
|
|
||||||
@Slf4j
|
|
||||||
public class BusinessException extends RuntimeException {
|
|
||||||
/**
|
|
||||||
* 异常编码
|
|
||||||
*/
|
|
||||||
private final Integer code;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 异常信息
|
|
||||||
*/
|
|
||||||
private final String message;
|
|
||||||
|
|
||||||
public BusinessException(Throwable cause) {
|
|
||||||
super(cause);
|
|
||||||
this.code = HttpStatus.INTERNAL_SERVER_ERROR.value();
|
|
||||||
this.message = null;
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
public BusinessException(Throwable cause, String message) {
|
|
||||||
super(cause);
|
|
||||||
this.code = HttpStatus.INTERNAL_SERVER_ERROR.value();
|
|
||||||
this.message = message;
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
public BusinessException(String message) {
|
|
||||||
this.code = HttpStatus.INTERNAL_SERVER_ERROR.value();
|
|
||||||
this.message = message;
|
|
||||||
}
|
|
||||||
|
|
||||||
public BusinessException(String message, Integer code) {
|
|
||||||
this.message = message;
|
|
||||||
this.code = code;
|
|
||||||
}
|
|
||||||
|
|
||||||
public BusinessException(String message, Throwable e) {
|
|
||||||
super(message, e);
|
|
||||||
log.error(message, e);
|
|
||||||
this.code = HttpStatus.INTERNAL_SERVER_ERROR.value();
|
|
||||||
this.message = message;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String getMessage() {
|
|
||||||
return message;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Integer getCode() {
|
|
||||||
return code;
|
|
||||||
}
|
|
||||||
}
|
|
@ -0,0 +1,13 @@
|
|||||||
|
package com.superversion.rasa.pojo.dto;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class RasaReqDTO {
|
||||||
|
|
||||||
|
private String sender;
|
||||||
|
|
||||||
|
private String message;
|
||||||
|
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,11 @@
|
|||||||
|
package com.superversion.rasa.pojo.dto;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class RasaResDTO {
|
||||||
|
|
||||||
|
private String recipient_id;
|
||||||
|
|
||||||
|
private String text;
|
||||||
|
}
|
@ -0,0 +1,9 @@
|
|||||||
|
package com.superversion.rasa.service;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public interface RasaTalkService {
|
||||||
|
|
||||||
|
|
||||||
|
List<String> talkRasa(String question, String sessionId) ;
|
||||||
|
}
|
@ -0,0 +1,32 @@
|
|||||||
|
package com.superversion.rasa.service.impl;
|
||||||
|
|
||||||
|
import cn.hutool.http.HttpUtil;
|
||||||
|
import cn.hutool.json.JSONUtil;
|
||||||
|
import com.superversion.rasa.pojo.dto.RasaReqDTO;
|
||||||
|
import com.superversion.rasa.pojo.dto.RasaResDTO;
|
||||||
|
import com.superversion.rasa.service.RasaTalkService;
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.beans.factory.annotation.Value;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
@Service
|
||||||
|
@Slf4j
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
public class RasaTalkServiceImpl implements RasaTalkService {
|
||||||
|
|
||||||
|
@Value("${rasa.url}")
|
||||||
|
private String rasaUrl;
|
||||||
|
@Override
|
||||||
|
public List<String> talkRasa(String question, String sessionId) {
|
||||||
|
|
||||||
|
RasaReqDTO rasaReqDTO = new RasaReqDTO();
|
||||||
|
rasaReqDTO.setSender(sessionId);
|
||||||
|
rasaReqDTO.setMessage(question);
|
||||||
|
String post = HttpUtil.post(rasaUrl, JSONUtil.toJsonStr(rasaReqDTO));
|
||||||
|
List<RasaResDTO> list = JSONUtil.toList(post, RasaResDTO.class);
|
||||||
|
return list.stream().map(RasaResDTO::getText).collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue