From 10acfb65c0d4ecd30562dac4c7fbd43a91fa709b Mon Sep 17 00:00:00 2001 From: xueqingkun Date: Wed, 10 Jan 2024 17:52:32 +0800 Subject: [PATCH] =?UTF-8?q?rasa=20:=20=E5=AF=B9=E8=AF=9D=E6=8E=A5=E5=8F=A3?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0Text2vecService=E5=AE=B9=E9=94=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../rasa/pojo/dto/Text2vecMatchesReq.java | 12 ++++++++++++ .../rasa/service/Text2vecServiceImpl.java | 9 +++++++-- .../rasa/service/impl/RasaTalkServiceImpl.java | 13 +++++++++++-- 3 files changed, 30 insertions(+), 4 deletions(-) diff --git a/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/dto/Text2vecMatchesReq.java b/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/dto/Text2vecMatchesReq.java index 63bed238..8292e709 100644 --- a/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/dto/Text2vecMatchesReq.java +++ b/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/dto/Text2vecMatchesReq.java @@ -11,4 +11,16 @@ public class Text2vecMatchesReq { @ApiModelProperty("相似度阈值") private Double threshold; + + public Text2vecMatchesReq() { + } + + public Text2vecMatchesReq(String querySentence) { + this.querySentence = querySentence; + } + + public Text2vecMatchesReq(String querySentence, Double threshold) { + this.querySentence = querySentence; + this.threshold = threshold; + } } diff --git a/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/Text2vecServiceImpl.java b/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/Text2vecServiceImpl.java index 72490107..b20b6c4f 100644 --- a/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/Text2vecServiceImpl.java +++ b/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/Text2vecServiceImpl.java @@ -13,6 +13,7 @@ import com.supervision.rasa.pojo.dto.Text2vecMatchesRes; import lombok.RequiredArgsConstructor; import lombok.extern.log4j.Log4j; import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Service; import java.util.List; @@ -23,7 +24,8 @@ import java.util.Objects; @RequiredArgsConstructor public class Text2vecServiceImpl implements Text2vecService { - private final String TEXT2VEC_SERVICE_DOMAIN = "http://127.0.0.1:5000/"; + @Value("${text2vec.service.domain}") + private String TEXT2VEC_SERVICE_DOMAIN; private final String UPDATE_DATASET_PATH = "update_dataset"; private final String MATCHES_PATH = "matches"; @@ -57,7 +59,10 @@ public class Text2vecServiceImpl implements Text2vecService { String body = HttpUtil.post(url, JSONUtil.toJsonStr(text2vecMatchesReq)); log.info("updateDataset: res is :{}",body); + JSONObject jsonBody = JSONUtil.parseObj(body); - return JSONUtil.toList(JSONUtil.parseArray(body), Text2vecMatchesRes.class); + Assert.isTrue("success".equals(jsonBody.get("status")),"查询失败"); + + return JSONUtil.toList(JSONUtil.parseArray(jsonBody.get("results")), Text2vecMatchesRes.class); } } diff --git a/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/impl/RasaTalkServiceImpl.java b/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/impl/RasaTalkServiceImpl.java index b06a7890..ea95bd0e 100644 --- a/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/impl/RasaTalkServiceImpl.java +++ b/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/impl/RasaTalkServiceImpl.java @@ -1,12 +1,15 @@ package com.supervision.rasa.service.impl; +import cn.hutool.core.collection.CollUtil; import cn.hutool.core.util.StrUtil; import cn.hutool.http.HttpUtil; import cn.hutool.json.JSONUtil; -import com.supervision.exception.BusinessException; import com.supervision.model.RasaModelInfo; import com.supervision.rasa.pojo.dto.RasaReqDTO; import com.supervision.rasa.pojo.dto.RasaResDTO; +import com.supervision.rasa.pojo.dto.Text2vecMatchesReq; +import com.supervision.rasa.pojo.dto.Text2vecMatchesRes; +import com.supervision.rasa.service.Text2vecService; import com.supervision.vo.rasa.RasaTalkVo; import com.supervision.rasa.service.RasaTalkService; import com.supervision.service.RasaModeService; @@ -27,6 +30,8 @@ public class RasaTalkServiceImpl implements RasaTalkService { private String rasaUrl; private final RasaModeService rasaModeService; + + private final Text2vecService text2vecService; @Override public List talkRasa(RasaTalkVo rasaTalkVo) { @@ -44,7 +49,11 @@ public class RasaTalkServiceImpl implements RasaTalkService { List list = JSONUtil.toList(post, RasaResDTO.class); - return list.stream().map(RasaResDTO::getText).collect(Collectors.toList()); + if (CollUtil.isNotEmpty(list)){ + return list.stream().map(RasaResDTO::getText).collect(Collectors.toList()); + } + return text2vecService.matches(new Text2vecMatchesReq(rasaTalkVo.getQuestion())) + .stream().map(Text2vecMatchesRes::getSentence).collect(Collectors.toList()); } private String getRasaUrl(int port){