rasa : 对话接口添加Text2vecService容错

dev_2.1.0
xueqingkun 1 year ago
parent 229fe1a379
commit 10acfb65c0

@ -11,4 +11,16 @@ public class Text2vecMatchesReq {
@ApiModelProperty("相似度阈值") @ApiModelProperty("相似度阈值")
private Double threshold; private Double threshold;
public Text2vecMatchesReq() {
}
public Text2vecMatchesReq(String querySentence) {
this.querySentence = querySentence;
}
public Text2vecMatchesReq(String querySentence, Double threshold) {
this.querySentence = querySentence;
this.threshold = threshold;
}
} }

@ -13,6 +13,7 @@ import com.supervision.rasa.pojo.dto.Text2vecMatchesRes;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j; import lombok.extern.log4j.Log4j;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.util.List; import java.util.List;
@ -23,7 +24,8 @@ import java.util.Objects;
@RequiredArgsConstructor @RequiredArgsConstructor
public class Text2vecServiceImpl implements Text2vecService { public class Text2vecServiceImpl implements Text2vecService {
private final String TEXT2VEC_SERVICE_DOMAIN = "http://127.0.0.1:5000/"; @Value("${text2vec.service.domain}")
private String TEXT2VEC_SERVICE_DOMAIN;
private final String UPDATE_DATASET_PATH = "update_dataset"; private final String UPDATE_DATASET_PATH = "update_dataset";
private final String MATCHES_PATH = "matches"; private final String MATCHES_PATH = "matches";
@ -57,7 +59,10 @@ public class Text2vecServiceImpl implements Text2vecService {
String body = HttpUtil.post(url, JSONUtil.toJsonStr(text2vecMatchesReq)); String body = HttpUtil.post(url, JSONUtil.toJsonStr(text2vecMatchesReq));
log.info("updateDataset: res is :{}",body); log.info("updateDataset: res is :{}",body);
JSONObject jsonBody = JSONUtil.parseObj(body);
return JSONUtil.toList(JSONUtil.parseArray(body), Text2vecMatchesRes.class); Assert.isTrue("success".equals(jsonBody.get("status")),"查询失败");
return JSONUtil.toList(JSONUtil.parseArray(jsonBody.get("results")), Text2vecMatchesRes.class);
} }
} }

@ -1,12 +1,15 @@
package com.supervision.rasa.service.impl; package com.supervision.rasa.service.impl;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.StrUtil; import cn.hutool.core.util.StrUtil;
import cn.hutool.http.HttpUtil; import cn.hutool.http.HttpUtil;
import cn.hutool.json.JSONUtil; import cn.hutool.json.JSONUtil;
import com.supervision.exception.BusinessException;
import com.supervision.model.RasaModelInfo; import com.supervision.model.RasaModelInfo;
import com.supervision.rasa.pojo.dto.RasaReqDTO; import com.supervision.rasa.pojo.dto.RasaReqDTO;
import com.supervision.rasa.pojo.dto.RasaResDTO; import com.supervision.rasa.pojo.dto.RasaResDTO;
import com.supervision.rasa.pojo.dto.Text2vecMatchesReq;
import com.supervision.rasa.pojo.dto.Text2vecMatchesRes;
import com.supervision.rasa.service.Text2vecService;
import com.supervision.vo.rasa.RasaTalkVo; import com.supervision.vo.rasa.RasaTalkVo;
import com.supervision.rasa.service.RasaTalkService; import com.supervision.rasa.service.RasaTalkService;
import com.supervision.service.RasaModeService; import com.supervision.service.RasaModeService;
@ -27,6 +30,8 @@ public class RasaTalkServiceImpl implements RasaTalkService {
private String rasaUrl; private String rasaUrl;
private final RasaModeService rasaModeService; private final RasaModeService rasaModeService;
private final Text2vecService text2vecService;
@Override @Override
public List<String> talkRasa(RasaTalkVo rasaTalkVo) { public List<String> talkRasa(RasaTalkVo rasaTalkVo) {
@ -44,8 +49,12 @@ public class RasaTalkServiceImpl implements RasaTalkService {
List<RasaResDTO> list = JSONUtil.toList(post, RasaResDTO.class); List<RasaResDTO> list = JSONUtil.toList(post, RasaResDTO.class);
if (CollUtil.isNotEmpty(list)){
return list.stream().map(RasaResDTO::getText).collect(Collectors.toList()); return list.stream().map(RasaResDTO::getText).collect(Collectors.toList());
} }
return text2vecService.matches(new Text2vecMatchesReq(rasaTalkVo.getQuestion()))
.stream().map(Text2vecMatchesRes::getSentence).collect(Collectors.toList());
}
private String getRasaUrl(int port){ private String getRasaUrl(int port){

Loading…
Cancel
Save