rasa训练

dev_v1.0.1
liu 2 years ago
parent c38c926a5b
commit 5d9063603f

@ -1,18 +0,0 @@
package com.supervision.mapper;
import com.supervision.model.Lock;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
/**
* @author flevance
* @description lock(,redis,使)Mapper
* @createDate 2023-11-01 13:11:34
* @Entity com.supervision.model.Lock
*/
public interface LockMapper extends BaseMapper<Lock> {
}

@ -31,10 +31,10 @@ public class AskDiseaseQuestionAnswer extends Model<AskDiseaseQuestionAnswer> im
private String id;
/**
* ID
* ID
*/
@ApiModelProperty("病ID")
private String diseaseId;
@ApiModelProperty("ID")
private String patientId;
/**
* ID(not null,template_questioncode,desc,question)

@ -1,45 +0,0 @@
package com.supervision.model;
import com.baomidou.mybatisplus.annotation.IdType;
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import java.io.Serializable;
import lombok.Data;
/**
* ,redis,使
* @TableName lock
*/
@TableName(value ="lock")
@Data
public class Lock implements Serializable {
/**
*
*/
@TableId
private String id;
/**
*
*/
private String lockCode;
/**
*
*/
private Long timestamp;
/**
*
*/
private Long expireMs;
/**
* 线ID,
*/
private String threadId;
@TableField(exist = false)
private static final long serialVersionUID = 1L;
}

@ -1,13 +0,0 @@
package com.supervision.service;
import com.supervision.model.Lock;
import com.baomidou.mybatisplus.extension.service.IService;
/**
* @author flevance
* @description lock(,redis,使)Service
* @createDate 2023-11-01 13:11:34
*/
public interface LockService extends IService<Lock> {
}

@ -1,22 +0,0 @@
package com.supervision.service.impl;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.supervision.model.Lock;
import com.supervision.service.LockService;
import com.supervision.mapper.LockMapper;
import org.springframework.stereotype.Service;
/**
* @author flevance
* @description lock(,redis,使)Service
* @createDate 2023-11-01 13:11:34
*/
@Service
public class LockServiceImpl extends ServiceImpl<LockMapper, Lock>
implements LockService{
}

@ -6,7 +6,7 @@
<resultMap id="BaseResultMap" type="com.supervision.model.AskDiseaseQuestionAnswer">
<id property="id" column="id" jdbcType="VARCHAR"/>
<result property="diseaseId" column="disease_id" jdbcType="VARCHAR"/>
<result property="patientId" column="patient_id" jdbcType="VARCHAR"/>
<result property="templateQuestionId" column="template_question_id" jdbcType="VARCHAR"/>
<result property="code" column="code" jdbcType="VARCHAR"/>
<result property="description" column="description" jdbcType="VARCHAR"/>
@ -19,7 +19,7 @@
</resultMap>
<sql id="Base_Column_List">
id,disease_id,template_question_id,
id,patient_id,template_question_id,
code,description,question,
answer,create_user_id,create_time,
update_user_id,update_time

@ -1,19 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE mapper
PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN"
"http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<mapper namespace="com.supervision.mapper.LockMapper">
<resultMap id="BaseResultMap" type="com.supervision.model.Lock">
<id property="id" column="id" jdbcType="VARCHAR"/>
<result property="lockCode" column="lock_code" jdbcType="VARCHAR"/>
<result property="timestamp" column="timestamp" jdbcType="INTEGER"/>
<result property="expireMs" column="expire_ms" jdbcType="INTEGER"/>
<result property="threadId" column="thread_id" jdbcType="VARCHAR"/>
</resultMap>
<sql id="Base_Column_List">
id,lock_code,timestamp,
expire_ms,thread_id
</sql>
</mapper>

@ -11,6 +11,7 @@ import org.springframework.scheduling.annotation.EnableScheduling;
@SpringBootApplication
@EnableScheduling
@MapperScan(basePackages = {"com.supervision.**.mapper"})
// 排除JWT权限校验
@ComponentScan(basePackages = {"com.supervision"},excludeFilters = @ComponentScan.Filter(type = FilterType.ASSIGNABLE_TYPE, classes = {WebConfig.class}))
public class VirtualPatientRasaApplication {

@ -2,6 +2,7 @@ package com.supervision.rasa.controller;
import cn.hutool.core.util.StrUtil;
import com.supervision.exception.BusinessException;
import com.supervision.rasa.service.RasaFileService;
import io.swagger.annotations.Api;
import io.swagger.annotations.ApiOperation;
@ -28,14 +29,13 @@ public class RasaFileController {
public String saveRasaFile(@RequestParam("file") MultipartFile file, @RequestParam("modelId") String modelId) throws IOException {
if (file == null || file.isEmpty()) {
return "file is empty";
throw new BusinessException("file is empty");
}
if (StrUtil.isEmpty(modelId)){
return "modelId is empty";
throw new BusinessException("modelId is empty");
}
rasaFileService.saveRasaFile(file,modelId);
return "succss";
return "success";
}
}

@ -8,10 +8,7 @@ import com.supervision.vo.ask.DiagnosisPrimaryVO;
import io.swagger.annotations.Api;
import io.swagger.annotations.ApiOperation;
import lombok.RequiredArgsConstructor;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.bind.annotation.*;
import java.util.List;
@ -36,7 +33,7 @@ public class AskPrimaryController {
}
@ApiOperation("保存初步诊断")
@GetMapping("savePrimary")
@PostMapping("savePrimary")
public void savePrimary(@RequestBody DiagnosisPrimary reqVO){
askPrimaryService.savePrimary(reqVO);
}

@ -1,7 +1,12 @@
package com.supervision.controller;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.supervision.domain.GlobalResult;
import com.supervision.service.RasaService;
import io.swagger.annotations.ApiModelProperty;
import io.swagger.annotations.ApiOperation;
import lombok.RequiredArgsConstructor;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
@ -13,9 +18,22 @@ public class RasaController {
private final RasaService rasaService;
@ApiOperation("生成rasa的yml文件")
@GetMapping("generateRasaYml")
public void generateRasaYml(String diseaseId){
rasaService.generateRasaYml(diseaseId);
public GlobalResult<String> generateRasaYml(String patientId) {
return rasaService.generateRasaYml(patientId);
}
@ApiOperation("训练rasa")
@GetMapping("trainRasa")
public GlobalResult<String> trainRasa(String patientId) throws JsonProcessingException {
return rasaService.trainRasa(patientId);
}
@ApiOperation("运行Rasa")
@GetMapping("runRasa")
public GlobalResult<String> runRasa(String patientId) throws JsonProcessingException {
return rasaService.runRasa(patientId);
}

@ -118,7 +118,7 @@ public class TestController {
AskDiseaseQuestionAnswer askDiseaseQuestionAnswer = new AskDiseaseQuestionAnswer();
askDiseaseQuestionAnswer.setDiseaseId("1");
askDiseaseQuestionAnswer.setPatientId("1");
askDiseaseQuestionAnswer.setTemplateQuestionId(templateQuestion.getId());
askDiseaseQuestionAnswer.setAnswer(ListUtil.of(answer));
askDiseaseQuestionAnswer.insert();

@ -1,6 +1,13 @@
package com.supervision.service;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.supervision.domain.GlobalResult;
public interface RasaService {
void generateRasaYml(String diseaseId);
GlobalResult<String> generateRasaYml(String diseaseId);
GlobalResult<String> trainRasa(String patientId) throws JsonProcessingException;
GlobalResult<String> runRasa(String patientId) throws JsonProcessingException;
}

@ -4,9 +4,18 @@ import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.io.FileUtil;
import cn.hutool.core.io.IoUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.http.HttpRequest;
import cn.hutool.http.HttpResponse;
import cn.hutool.http.HttpUtil;
import cn.hutool.json.JSONUtil;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.supervision.domain.GlobalResult;
import com.supervision.exception.BusinessException;
import com.supervision.model.*;
import com.supervision.pojo.paddlespeech.res.PaddleSpeechResDTO;
import com.supervision.pojo.paddlespeech.res.TtsResultDTO;
import com.supervision.pojo.rasa.train.DomainYmlTemplate;
import com.supervision.pojo.rasa.train.NluYmlTemplate;
import com.supervision.pojo.rasa.train.QuestionAnswerDTO;
@ -16,6 +25,7 @@ import freemarker.template.Configuration;
import freemarker.template.Template;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import java.io.ByteArrayInputStream;
@ -43,8 +53,19 @@ public class RasaServiceImpl implements RasaService {
private final ConfigAncillaryItemService configAncillaryItemService;
private static final ObjectMapper objectMapper = new ObjectMapper();
@Value("${rasa.base-url}${rasa.saveRasaFile}")
private String saveRasaFileUrl;
@Value("${rasa.base-url}${rasa.train}")
private String trainRasaUrl;
@Value("${rasa.base-url}${rasa.run}")
private String runRasaUrl;
@Override
public void generateRasaYml(String diseaseId) {
public GlobalResult<String> generateRasaYml(String patientId) {
Map<String, File> ymalFileMap = new HashMap<>();
// 默认问答MAP
@ -53,29 +74,42 @@ public class RasaServiceImpl implements RasaService {
List<RuleYmlTemplate.Rule> ruleList = new ArrayList<>();
// 开始生成各种yaml文件
generateNlu(diseaseId, questionCodeAndIdMap, ymalFileMap);
generateNlu(patientId, questionCodeAndIdMap, ymalFileMap);
generateDomain(questionCodeAndIdMap, ruleList, ymalFileMap);
generateRule(ruleList, ymalFileMap);
// 生成压缩文件
List<File> tempFile = new ArrayList<>();
File tempZipFile = FileUtil.createTempFile(".zip", true);
ByteArrayOutputStream bos = new ByteArrayOutputStream();
try (ZipOutputStream zipOutputStream = new ZipOutputStream(bos)) {
for (Map.Entry<String, File> fileEntry : ymalFileMap.entrySet()) {
zipOutputStream.putNextEntry(new ZipEntry(fileEntry.getKey()));
IoUtil.copy(FileUtil.getInputStream(fileEntry.getValue()), zipOutputStream);
zipOutputStream.closeEntry();
tempFile.add(fileEntry.getValue());
}
zipOutputStream.finish();
// 调用接口传文件
HttpRequest request = HttpRequest.post(saveRasaFileUrl);
IoUtil.copy(new ByteArrayInputStream(bos.toByteArray()), FileUtil.getOutputStream(tempZipFile));
request.form("file", tempZipFile);
request.form("modelId", patientId);
HttpResponse response = request.execute();
String responseBody = response.body();
log.info(responseBody);
return objectMapper.readValue(responseBody, new TypeReference<GlobalResult<String>>() {
});
} catch (Exception e) {
log.error("生成ZIP文件失败", e);
throw new BusinessException("生成ZIP文件失败");
} finally {
// 最后把临时文件删除
tempFile.forEach(FileUtil::del);
FileUtil.del(tempZipFile);
}
// TODO 这是压缩文件的字节流,这里需要把自己流调用Python程序
byte[] byteArray = bos.toByteArray();
File file = new File("rasa.zip");
IoUtil.copy(new ByteArrayInputStream(byteArray), FileUtil.getOutputStream(file));
}
private void generateNlu(String diseaseId,
private void generateNlu(String patientId,
Map<String, QuestionAnswerDTO> intentCodeAndIdMap,
Map<String, File> ymalFileMap) {
// 首先生成根据意图查找到nlu文件
@ -98,7 +132,7 @@ public class RasaServiceImpl implements RasaService {
}
// 然后处理该疾病对应的意图
List<AskDiseaseQuestionAnswer> diseaseQuestionAnswerList = askDiseaseQuestionAnswerService.lambdaQuery()
.eq(AskDiseaseQuestionAnswer::getDiseaseId, diseaseId).list();
.eq(AskDiseaseQuestionAnswer::getPatientId, patientId).list();
// 使用通用模板的
Map<String, AskTemplateQuestion> templateQuestionMap = new HashMap<>();
// 根据默认意图找到所有的问题
@ -215,7 +249,7 @@ public class RasaServiceImpl implements RasaService {
configuration.setClassForTemplateLoading(clazz, "/templates"); // 模板文件的所在目录
// 获取模板
Template template = configuration.getTemplate(ftlName);
File tempFile = FileUtil.createTempFile();
File tempFile = FileUtil.createTempFile(".yml", true);
// 创建输出文件
try (PrintWriter out = new PrintWriter(tempFile);) {
// 填充并生成输出
@ -230,4 +264,22 @@ public class RasaServiceImpl implements RasaService {
}
@Override
public GlobalResult<String> trainRasa(String patientId) throws JsonProcessingException {
Map<String, Object> param = new HashMap<>();
param.put("modelId", patientId);
String responseBody = HttpUtil.post(trainRasaUrl, param);
return objectMapper.readValue(responseBody, new TypeReference<GlobalResult<String>>() {
});
}
@Override
public GlobalResult<String> runRasa(String patientId) throws JsonProcessingException {
Map<String, Object> param = new HashMap<>();
param.put("modelId", patientId);
String responseBody = HttpUtil.post(runRasaUrl, param);
return objectMapper.readValue(responseBody, new TypeReference<GlobalResult<String>>() {
});
}
}

@ -16,14 +16,15 @@ import java.util.stream.Collectors;
@Slf4j
public class RasaUtil {
private static final String RASA_URL = SpringBeanUtil.getBean(Environment.class).getProperty("rasa.url");
private static final Environment environment = SpringBeanUtil.getBean(Environment.class);
private static final String RASA_TALK_URL = environment.getProperty("rasa.base-url") + environment.getProperty("rasa.talk");
public static String talkRasa(String question, String sessionId, String patientId) {
RasaTalkVo rasaTalkVo = new RasaTalkVo();
rasaTalkVo.setQuestion(question);
rasaTalkVo.setSessionId(sessionId);
rasaTalkVo.setModelId(patientId);
String post = HttpUtil.post(RASA_URL, JSONUtil.toJsonStr(rasaTalkVo));
String post = HttpUtil.post(RASA_TALK_URL, JSONUtil.toJsonStr(rasaTalkVo));
List<String> list = JSONUtil.toList(post, String.class);
log.info("调用rasa对话返回结果:{}",post);
if (CollUtil.isEmpty(list)){

@ -62,7 +62,11 @@ paddle-speech:
tts: http://192.168.10.137:8090/paddlespeech/tts
asr: http://192.168.10.137:8090/paddlespeech/asr
rasa:
url: http://192.168.10.137:8890/rasa/talkRasa
base-url: http://192.168.10.137:8890/
talk: rasa/talkRasa
saveRasaFile: rasaFile/saveRasaFile
train: rasaCmd/trainExec
run: rasaCmd/runExec
human:
base-url: https://digital-human.jd.com
room-id: /getRoomId

Loading…
Cancel
Save