rasa 训练和启动接口开发

dev_v1.0.1
xueqingkun 2 years ago
parent 362d33d93d
commit 76170c8a41

@ -0,0 +1,28 @@
package com.superversion.rasa.config;
import java.util.concurrent.*;
public class ThreadPoolExecutorConfig {
private volatile static ThreadPoolExecutor instance = null;
private ThreadPoolExecutorConfig(){}
public static ThreadPoolExecutor getInstance() {
if (instance == null) {
synchronized (ThreadPoolExecutorConfig.class) { // 加锁
if (instance == null) {
int corePoolSize = 5;
int maximumPoolSize = 10;
long keepAliveTime = 100;
BlockingQueue<Runnable> workQueue = new ArrayBlockingQueue<>(20);
RejectedExecutionHandler rejectedExecutionHandler = new ThreadPoolExecutor.AbortPolicy();
instance = new ThreadPoolExecutor(corePoolSize, maximumPoolSize, keepAliveTime, TimeUnit.SECONDS, workQueue, rejectedExecutionHandler);
}
}
}
return instance;
}
}

@ -1,20 +1,15 @@
package com.superversion.rasa.controller; package com.superversion.rasa.controller;
import cn.hutool.core.io.FileUtil; import com.superversion.rasa.pojo.vo.RasaArgument;
import cn.hutool.core.util.ZipUtil; import com.superversion.rasa.service.RasaCmdService;
import com.superversion.rasa.service.RasaFileService;
import io.swagger.annotations.Api; import io.swagger.annotations.Api;
import io.swagger.annotations.ApiOperation; import io.swagger.annotations.ApiOperation;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.*;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartFile;
import java.io.*; import java.io.*;
import java.nio.charset.Charset; import java.util.concurrent.*;
@Api(tags = "rasa文件保存") @Api(tags = "rasa文件保存")
@RestController @RestController
@ -23,14 +18,21 @@ import java.nio.charset.Charset;
public class RasaCmdController { public class RasaCmdController {
@Autowired @Autowired
private RasaFileService rasaFileService; private RasaCmdService rasaCmdService;
@ApiOperation("接受并保存rasa文件") @ApiOperation("执行训练shell命令")
@PostMapping("/exec") @PostMapping("/trainExec")
public String cmdExec(@RequestParam("file") MultipartFile file){ public String trainExec(@RequestBody RasaArgument argument) throws IOException, ExecutionException, InterruptedException, TimeoutException {
return rasaCmdService.trainExec(argument);
return "ok"; }
@ApiOperation("执行启动shell命令")
@PostMapping("/runExec")
public String runExec(@RequestBody RasaArgument argument) throws ExecutionException, InterruptedException, TimeoutException {
return rasaCmdService.runExec(argument);
} }

@ -0,0 +1,17 @@
package com.superversion.rasa.pojo.vo;
import lombok.Data;
@Data
public class RasaArgument {
private String config;
private String data;
private String domain;
private String out;
private String fixedModelName;//fixed-model-name
private String enableApi;//enable-api
private String endpoints;
private String port;
}

@ -1,4 +1,14 @@
package com.superversion.rasa.service; package com.superversion.rasa.service;
import com.superversion.rasa.pojo.vo.RasaArgument;
import java.io.IOException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
public interface RasaCmdService { public interface RasaCmdService {
String trainExec(RasaArgument argument) throws IOException, ExecutionException, InterruptedException, TimeoutException;
String runExec( RasaArgument argument) throws ExecutionException, InterruptedException, TimeoutException;
} }

@ -1,33 +1,106 @@
package com.superversion.rasa.service.impl; package com.superversion.rasa.service.impl;
import cn.hutool.core.collection.ListUtil;
import cn.hutool.core.util.StrUtil;
import com.superversion.rasa.config.ThreadPoolExecutorConfig;
import com.superversion.rasa.pojo.vo.RasaArgument;
import com.superversion.rasa.service.RasaCmdService; import com.superversion.rasa.service.RasaCmdService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import java.io.BufferedReader; import java.io.*;
import java.io.IOException; import java.util.ArrayList;
import java.io.InputStreamReader; import java.util.List;
import java.nio.charset.Charset; import java.util.concurrent.*;
import java.util.function.Predicate;
@Service
@Slf4j
@RequiredArgsConstructor
public class RasaCmdServiceImpl implements RasaCmdService { public class RasaCmdServiceImpl implements RasaCmdService {
public static void main(String[] args) throws IOException, InterruptedException { @Value("${rasa.models-path}")
Runtime runtime = Runtime.getRuntime(); private String modelsPath;
Process process = runtime.exec(" cmd /c ls ");
BufferedReader br = new BufferedReader(new InputStreamReader(process.getInputStream(), Charset.forName("GBK")));
String lineMes;
while ((lineMes = br.readLine()) != null){
System.out.println(lineMes);// 打印输出信息
}
@Value("${rasa.endpoints}")
private String endpoints;
@Value("${rasa.config}")
private String config;
//检查命令是否执行失败。 @Value("${rasa.data-path}")
if (process.waitFor() != 0) { private String dataPath;
if (process.exitValue() == 1)//0表示正常结束1非正常结束
System.err.println("命令执行失败!");
}
br.close();
@Value("${rasa.shell-env:/bin/bash}")
private String shellEnv;
@Value("${rasa.train-shell}")
private String trainShell;
@Value("${rasa.run-shell}")
private String runShell;
@Value("${rasa.shell.work:/home/rasa_manage/}")
private String shellWork;
@Override
public String trainExec(RasaArgument argument) throws ExecutionException, InterruptedException, TimeoutException {
String domain = dataPath+"domain.yml";
List<String> cmds = ListUtil.toList(shellEnv, trainShell,config,dataPath,domain,modelsPath);
cmds.add(argument.getFixedModelName());
log.info("trainExec cmd : {}",StrUtil.join(" ",cmds));
System.out.println("trainExec cmd sout:"+StrUtil.join(" ",cmds));
return String.join("\r\n",execCmd(cmds,s->false,90));
}
@Override
public String runExec(RasaArgument argument) throws ExecutionException, InterruptedException, TimeoutException {
String mPath = modelsPath+argument.getFixedModelName()+".tar.gz";
List<String> cmds = ListUtil.toList(shellEnv, runShell,mPath,endpoints,argument.getPort());
log.info("runExec cmd : {}",StrUtil.join(" ",cmds));
System.out.println("runExec cmd sout:"+StrUtil.join(" ",cmds));
return String.join("\r\n",execCmd(cmds,s-> StrUtil.isNotBlank(s)&& s.contains("Rasa server is up and running"),90));
}
private List<String> execCmd(List<String> cmds, Predicate<String> endPredicate, long timeOut) throws InterruptedException, ExecutionException, TimeoutException {
ProcessBuilder processBuilder = new ProcessBuilder(cmds);
processBuilder.directory(new File(shellWork)); // 设置工作目录
processBuilder.redirectErrorStream(true); // 合并标准输出和错误输出
ThreadPoolExecutor instance = ThreadPoolExecutorConfig.getInstance();
Future<List<String>> future = instance.submit(() -> {
Process process = processBuilder.start(); // 启动进程
InputStream inputStream = process.getInputStream(); // 获取进程的输出流
Reader reader = new InputStreamReader(inputStream, "UTF-8");
BufferedReader bufferedReader = new BufferedReader(reader);
List<String> outString = new ArrayList<>();
String resultLines = bufferedReader.readLine();
while(resultLines != null) {
resultLines = bufferedReader.readLine(); // 读取下一行
log.info("resultLines:{}",resultLines);
if (endPredicate.test(resultLines)){
break;
}
outString.add(resultLines);
}
bufferedReader.close();
return outString;
});
return future.get(timeOut, TimeUnit.SECONDS);
} }
} }

@ -18,12 +18,13 @@ import java.io.IOException;
public class RasaFileServiceImpl implements RasaFileService { public class RasaFileServiceImpl implements RasaFileService {
@Value("${rasa.file-path:/home/rasa}") @Value("${rasa.data-path:/home/rasa/}")
private String rasaFilePath; private String rasaFilePath;
@Value("${rasa.file-name:rasa.zip}") @Value("${rasa.file-name:rasa.zip}")
private String rasaFileName; private String rasaFileName;
@Override @Override
public void saveRasaFile(MultipartFile file) throws IOException { public void saveRasaFile(MultipartFile file) throws IOException {
@ -35,7 +36,7 @@ public class RasaFileServiceImpl implements RasaFileService {
String suffix = "_back"; String suffix = "_back";
String rasaFullPath = String.join(File.separator, rasaFilePath, rasaFileName); String rasaFullPath = String.join(File.separator, rasaFilePath, rasaFileName);
String rasaBackFullPath = String.join(File.separator, rasaFilePath, rasaFileName + suffix); String rasaBackFullPath = rasaFilePath+rasaFileName + suffix;
//1.检查路径下是否存在文件 //1.检查路径下是否存在文件
File oldFile = new File(rasaFullPath); File oldFile = new File(rasaFullPath);

@ -4,4 +4,9 @@ server:
context-path: / context-path: /
rasa: rasa:
file-path: F:\tmp data-path: /rasa/v3_jiazhuangxian/ # 文件解压后存放位置
models-path: /rasa/models/
endpoints: /rasa/endpoints.yml # 启动的配置项,应该是写在配置文件里面
config: /rasa/config-local.yml # 启动rasa需要的配置文件在配置文件中配置
train-shell: /home/rasa_manage/train.sh
run-shell: /home/rasa_manage/run.sh
Loading…
Cancel
Save