From a2dbd7a68d3783cb830b9074e3f50070df47b0bc Mon Sep 17 00:00:00 2001 From: xueqingkun Date: Wed, 10 Jan 2024 17:00:34 +0800 Subject: [PATCH] =?UTF-8?q?rasa=20:=20=E6=B7=BB=E5=8A=A0=20text2vec?= =?UTF-8?q?=E6=9C=8D=E5=8A=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../rasa/controller/Text2vecController.java | 37 +++++++++++ .../rasa/pojo/dto/Text2vecDataVo.java | 14 +++++ .../rasa/pojo/dto/Text2vecMatchesReq.java | 14 +++++ .../rasa/pojo/dto/Text2vecMatchesRes.java | 17 +++++ .../rasa/service/Text2vecService.java | 24 +++++++ .../rasa/service/Text2vecServiceImpl.java | 63 +++++++++++++++++++ 6 files changed, 169 insertions(+) create mode 100644 virtual-patient-rasa/src/main/java/com/supervision/rasa/controller/Text2vecController.java create mode 100644 virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/dto/Text2vecDataVo.java create mode 100644 virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/dto/Text2vecMatchesReq.java create mode 100644 virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/dto/Text2vecMatchesRes.java create mode 100644 virtual-patient-rasa/src/main/java/com/supervision/rasa/service/Text2vecService.java create mode 100644 virtual-patient-rasa/src/main/java/com/supervision/rasa/service/Text2vecServiceImpl.java diff --git a/virtual-patient-rasa/src/main/java/com/supervision/rasa/controller/Text2vecController.java b/virtual-patient-rasa/src/main/java/com/supervision/rasa/controller/Text2vecController.java new file mode 100644 index 00000000..3ae58ca9 --- /dev/null +++ b/virtual-patient-rasa/src/main/java/com/supervision/rasa/controller/Text2vecController.java @@ -0,0 +1,37 @@ +package com.supervision.rasa.controller; + +import com.supervision.rasa.pojo.dto.Text2vecDataVo; +import com.supervision.rasa.pojo.dto.Text2vecMatchesReq; +import com.supervision.rasa.pojo.dto.Text2vecMatchesRes; +import com.supervision.rasa.service.Text2vecService; +import io.swagger.annotations.Api; +import io.swagger.annotations.ApiOperation; +import lombok.RequiredArgsConstructor; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; + +import java.util.List; + +@Api(tags = "text2vec服务") +@RestController +@RequestMapping("text2vec") +@RequiredArgsConstructor +public class Text2vecController { + + private final Text2vecService text2vecService; + @ApiOperation("更新数据库") + @PostMapping("updateDataset") + public boolean talkRasa(@RequestBody List text2vecDataVoList){ + + return text2vecService.updateDataset(text2vecDataVoList); + } + + @ApiOperation("获取匹配项") + @PostMapping("matches") + public List matches(@RequestBody Text2vecMatchesReq text2vecMatchesReq){ + + return text2vecService.matches(text2vecMatchesReq); + } +} diff --git a/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/dto/Text2vecDataVo.java b/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/dto/Text2vecDataVo.java new file mode 100644 index 00000000..15b748da --- /dev/null +++ b/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/dto/Text2vecDataVo.java @@ -0,0 +1,14 @@ +package com.supervision.rasa.pojo.dto; + +import io.swagger.annotations.ApiModelProperty; +import lombok.Data; + +@Data +public class Text2vecDataVo { + + @ApiModelProperty("数据id") + private String id; + + @ApiModelProperty("问题") + private String question; +} diff --git a/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/dto/Text2vecMatchesReq.java b/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/dto/Text2vecMatchesReq.java new file mode 100644 index 00000000..63bed238 --- /dev/null +++ b/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/dto/Text2vecMatchesReq.java @@ -0,0 +1,14 @@ +package com.supervision.rasa.pojo.dto; + +import io.swagger.annotations.ApiModelProperty; +import lombok.Data; + +@Data +public class Text2vecMatchesReq { + + @ApiModelProperty("需要被匹配的语句") + private String querySentence; + + @ApiModelProperty("相似度阈值") + private Double threshold; +} diff --git a/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/dto/Text2vecMatchesRes.java b/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/dto/Text2vecMatchesRes.java new file mode 100644 index 00000000..76d423da --- /dev/null +++ b/virtual-patient-rasa/src/main/java/com/supervision/rasa/pojo/dto/Text2vecMatchesRes.java @@ -0,0 +1,17 @@ +package com.supervision.rasa.pojo.dto; + +import io.swagger.annotations.ApiModelProperty; +import lombok.Data; + +@Data +public class Text2vecMatchesRes { + + @ApiModelProperty("id") + private String id; + + @ApiModelProperty("句子") + private String sentence; + + @ApiModelProperty("相似度") + private String similarity; +} diff --git a/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/Text2vecService.java b/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/Text2vecService.java new file mode 100644 index 00000000..9d064b67 --- /dev/null +++ b/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/Text2vecService.java @@ -0,0 +1,24 @@ +package com.supervision.rasa.service; + +import com.supervision.rasa.pojo.dto.Text2vecDataVo; +import com.supervision.rasa.pojo.dto.Text2vecMatchesReq; +import com.supervision.rasa.pojo.dto.Text2vecMatchesRes; + +import java.util.List; + +public interface Text2vecService { + + /** + * 更新数据 + * @param text2vecDataVoList 数据集合 + * @return 是否更新成功 + */ + boolean updateDataset(List text2vecDataVoList); + + /** + * 语句匹配 + * @param text2vecMatchesReq + * @return + */ + List matches(Text2vecMatchesReq text2vecMatchesReq); +} diff --git a/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/Text2vecServiceImpl.java b/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/Text2vecServiceImpl.java new file mode 100644 index 00000000..72490107 --- /dev/null +++ b/virtual-patient-rasa/src/main/java/com/supervision/rasa/service/Text2vecServiceImpl.java @@ -0,0 +1,63 @@ +package com.supervision.rasa.service; + +import cn.hutool.core.lang.Assert; +import cn.hutool.core.util.StrUtil; +import cn.hutool.http.HttpUtil; +import cn.hutool.json.JSON; +import cn.hutool.json.JSONArray; +import cn.hutool.json.JSONObject; +import cn.hutool.json.JSONUtil; +import com.supervision.rasa.pojo.dto.Text2vecDataVo; +import com.supervision.rasa.pojo.dto.Text2vecMatchesReq; +import com.supervision.rasa.pojo.dto.Text2vecMatchesRes; +import lombok.RequiredArgsConstructor; +import lombok.extern.log4j.Log4j; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Service; + +import java.util.List; +import java.util.Objects; + +@Slf4j +@Service +@RequiredArgsConstructor +public class Text2vecServiceImpl implements Text2vecService { + + private final String TEXT2VEC_SERVICE_DOMAIN = "http://127.0.0.1:5000/"; + + private final String UPDATE_DATASET_PATH = "update_dataset"; + private final String MATCHES_PATH = "matches"; + private final String GET_ALL_SIMILARITIES_PATH = "get_all_similarities"; + @Override + public boolean updateDataset(List text2vecDataVoList) { + + Assert.notEmpty(text2vecDataVoList, "数据不能为空"); + text2vecDataVoList.forEach(vo->{ + Assert.notEmpty(vo.getId(), "id不能为空"); + Assert.notEmpty(vo.getQuestion(), "question不能为空"); + }); + + String url = TEXT2VEC_SERVICE_DOMAIN + UPDATE_DATASET_PATH; + log.info("updateDataset: url is : {}",url); + + String body = HttpUtil.post(url, JSONUtil.toJsonStr(text2vecDataVoList)); + log.info("updateDataset: res is :{}",body); + + return "success".equals(JSONUtil.parseObj(body).get("status")); + } + + @Override + public List matches(Text2vecMatchesReq text2vecMatchesReq) { + + Assert.notEmpty(text2vecMatchesReq.getQuerySentence(), "querySentence不能为空"); + + String path = Objects.isNull(text2vecMatchesReq.getThreshold()) ? MATCHES_PATH : GET_ALL_SIMILARITIES_PATH; + String url = TEXT2VEC_SERVICE_DOMAIN + path; + log.info("matches: url is : {}",url); + + String body = HttpUtil.post(url, JSONUtil.toJsonStr(text2vecMatchesReq)); + log.info("updateDataset: res is :{}",body); + + return JSONUtil.toList(JSONUtil.parseArray(body), Text2vecMatchesRes.class); + } +}