rasa:优化代码

dev_v1.0.1
xueqingkun 2 years ago
parent 3c5c235311
commit 3e6344f93b

@ -1,6 +1,7 @@
package com.supervision.rasa.controller;
import cn.hutool.core.util.StrUtil;
import com.supervision.rasa.service.RasaFileService;
import io.swagger.annotations.Api;
import io.swagger.annotations.ApiOperation;
@ -20,14 +21,17 @@ public class RasaFileController {
@ApiOperation("接受并保存rasa文件")
@PostMapping("/saveRasaFile")
public String saveRasaFile(@RequestParam("file") MultipartFile file) throws IOException {
public String saveRasaFile(@RequestParam("file") MultipartFile file, @RequestParam("modelId") String modelId) throws IOException {
if (file == null || file.isEmpty()) {
return "file is empty";
}
if (StrUtil.isEmpty(modelId)){
return "modelId is empty";
}
rasaFileService.saveRasaFile(file);
return "ok";
rasaFileService.saveRasaFile(file,modelId);
return "succss";
}
}

@ -6,5 +6,5 @@ import java.io.IOException;
public interface RasaFileService {
void saveRasaFile(MultipartFile file) throws IOException;
void saveRasaFile(MultipartFile file,String modelId) throws IOException;
}

@ -60,8 +60,13 @@ public class RasaCmdServiceImpl implements RasaCmdService {
@Transactional
public String trainExec(RasaCmdArgumentVo argument) throws ExecutionException, InterruptedException, TimeoutException {
String domain = dataPath+"domain.yml";
List<String> cmds = ListUtil.toList(shellEnv, trainShell,config,dataPath,domain,modelsPath);
// /rasa/v3_jiazhuangxian/domain.yml domain的路径应该是从zip文件中加压出来的文件的路径后面拼上/domain.yml
String domain = String.join(File.separator,dataPath,argument.getModelId(),"domain.yml");
// /rasa/v3_jiazhuangxian/ yml文件的路径应该是从zip文件中加压出来的文件的路径,在配置文件中配置
String localDataPath = String.join(File.separator,dataPath,argument.getModelId());
// /rasa/models 生成出来的模型的存放路径,也写在配置文件里面
String localModelsPath = String.join(File.separator,modelsPath,argument.getModelId());
List<String> cmds = ListUtil.toList(shellEnv, trainShell,config,localDataPath,domain,localModelsPath);
cmds.add(argument.getFixedModelName());
@ -90,8 +95,10 @@ public class RasaCmdServiceImpl implements RasaCmdService {
log.info("runExec findUnusedPort is : {}",port);
// 2. 运行模型
String mPath = modelsPath+argument.getFixedModelName()+".tar.gz";
List<String> cmds = ListUtil.toList(shellEnv, runShell,mPath,endpoints,String.valueOf(port));
// 2.1 -m 参数对应的值 /rasa/models/aaa1111.tar.gz 指run的模型的路径/rasa/models应该来自于配置文件和训练时的--out是同一配置项
// aaa1111.tar.gz这个前面的文件名应该是--fixed-model-name指定的.tar.gz是文件后缀代码拼接
String localPath = String.join(File.separator,modelsPath,argument.getModelId(),argument.getFixedModelName(),".tar.gz");
List<String> cmds = ListUtil.toList(shellEnv, runShell,localPath,endpoints,String.valueOf(port));
log.info("runExec cmd : {}",StrUtil.join(" ",cmds));
@ -102,7 +109,7 @@ public class RasaCmdServiceImpl implements RasaCmdService {
rasaModelInfo.setModelId(argument.getModelId());
rasaModelInfo.setPort(port);
rasaModelInfo.setServerStatus(runIsSuccess(outMessageList)?1:0);
rasaModelInfo.setCmd(ListUtil.toList(shellEnv, runShell,mPath,endpoints));
rasaModelInfo.setCmd(ListUtil.toList(shellEnv, runShell,localPath,endpoints));
rasaModeService.saveOrUpdateByModelId(rasaModelInfo);
return String.join("\r\n",outMessageList);

@ -18,7 +18,7 @@ import java.io.IOException;
public class RasaFileServiceImpl implements RasaFileService {
@Value("${rasa.data-path:/home/rasa/}")
@Value("${rasa.data-path:/home/rasa/model_resource/}")
private String rasaFilePath;
@Value("${rasa.file-name:rasa.zip}")
@ -26,18 +26,19 @@ public class RasaFileServiceImpl implements RasaFileService {
@Override
public void saveRasaFile(MultipartFile file) throws IOException {
public void saveRasaFile(MultipartFile file,String modelId) throws IOException {
String suffix = "_back";
String rasaFullPath = String.join(File.separator, rasaFilePath,modelId, rasaFileName);
String rasaBackFullPath = rasaFullPath + suffix;
//初始化目录
File dir = new File(rasaFilePath);
File dir = new File(String.join(File.separator, rasaFilePath,modelId));
if (!dir.exists()){
FileUtil.mkdir(dir);
}
String suffix = "_back";
String rasaFullPath = String.join(File.separator, rasaFilePath, rasaFileName);
String rasaBackFullPath = rasaFilePath+rasaFileName + suffix;
//1.检查路径下是否存在文件
File oldFile = new File(rasaFullPath);
if (oldFile.exists()){
@ -51,7 +52,7 @@ public class RasaFileServiceImpl implements RasaFileService {
file.transferTo(new File(rasaFullPath));
//3.解压文件
ZipUtil.unzip(rasaFullPath);
ZipUtil.unzip(rasaFullPath,String.join(File.separator, rasaFilePath,modelId));
//4.删除备份文件
FileUtil.del(rasaBackFullPath);

@ -44,7 +44,14 @@ public class RasaTalkServiceImpl implements RasaTalkService {
RasaReqDTO rasaReqDTO = new RasaReqDTO();
rasaReqDTO.setSender(rasaTalkVo.getSessionId());
rasaReqDTO.setMessage(rasaTalkVo.getQuestion());
String post = HttpUtil.post(getRasaUrl(rasaModelInfo.getPort()), JSONUtil.toJsonStr(rasaReqDTO));
String rasaUrl = getRasaUrl(rasaModelInfo.getPort());
log.info("talkRasa: url is: {}",rasaUrl);
String post = HttpUtil.post(rasaUrl, JSONUtil.toJsonStr(rasaReqDTO));
List<RasaResDTO> list = JSONUtil.toList(post, RasaResDTO.class);
return list.stream().map(RasaResDTO::getText).collect(Collectors.toList());

@ -12,7 +12,8 @@ server:
direct-buffers: true
rasa:
data-path: /rasa/v3_jiazhuangxian/ # 文件解压后存放位置
#data-path: /home/rasa/model_resource/ # 文件解压后存放位置
data-path: F:\tmp\model_resource\ # 文件解压后存放位置
models-path: /rasa/models/
endpoints: /rasa/endpoints.yml # 启动的配置项,应该是写在配置文件里面
config: /rasa/config-local.yml # 启动rasa需要的配置文件在配置文件中配置

Loading…
Cancel
Save