From 82e1578c1ebfdd74d40ee3a36ff5cbf51744027f Mon Sep 17 00:00:00 2001 From: liu Date: Wed, 24 Apr 2024 15:16:49 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BB=A3=E7=A0=81=E6=8F=90=E4=BA=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../supervision/controller/AskController.java | 6 +-- .../com/supervision/service/AskService.java | 2 +- .../service/impl/AskServiceImpl.java | 45 ++++++++++++------- nGQL/demo_data.ngql | 2 +- 4 files changed, 33 insertions(+), 22 deletions(-) diff --git a/kbqa-graph/src/main/java/com/supervision/controller/AskController.java b/kbqa-graph/src/main/java/com/supervision/controller/AskController.java index ff7b692..a73ba37 100644 --- a/kbqa-graph/src/main/java/com/supervision/controller/AskController.java +++ b/kbqa-graph/src/main/java/com/supervision/controller/AskController.java @@ -65,10 +65,10 @@ public class AskController { } - @ApiOperation("多轮对话中用户手动填写参数,可能直接返回结果") + @ApiOperation("多轮对话中用户手动填写参数") @PostMapping("saveUserParam") - public RoundTalkResVO saveUserParam(@RequestBody RoleSetReqVO paramReqVO) { - return askService.saveUserParam(paramReqVO); + public void saveUserParam(@RequestBody RoleSetReqVO paramReqVO) { + askService.saveUserParam(paramReqVO); } @ApiOperation("查询多轮对话中用户需要填写的参数") diff --git a/kbqa-graph/src/main/java/com/supervision/service/AskService.java b/kbqa-graph/src/main/java/com/supervision/service/AskService.java index 2340e59..62c5c2f 100644 --- a/kbqa-graph/src/main/java/com/supervision/service/AskService.java +++ b/kbqa-graph/src/main/java/com/supervision/service/AskService.java @@ -11,7 +11,7 @@ public interface AskService { SingleTalkResVO singleTalk(SingleTalkReqVO singleTalkReqVO); - RoundTalkResVO saveUserParam(RoleSetReqVO paramReqVO); + void saveUserParam(RoleSetReqVO paramReqVO); RoleSetResVO queryUserNeedParam(String sessionId); } diff --git a/kbqa-graph/src/main/java/com/supervision/service/impl/AskServiceImpl.java b/kbqa-graph/src/main/java/com/supervision/service/impl/AskServiceImpl.java index 9957e6b..4a82104 100644 --- a/kbqa-graph/src/main/java/com/supervision/service/impl/AskServiceImpl.java +++ b/kbqa-graph/src/main/java/com/supervision/service/impl/AskServiceImpl.java @@ -107,6 +107,7 @@ public class AskServiceImpl implements AskService { // 如果不等于1,说明可能有不确定的节点,这时就要开始找节点 Map waitMatchItemLeafMap = allMatchLeafNode.stream().collect(Collectors.toMap(ItemLeaf::getVid, Function.identity(), (k1, k2) -> k1)); + // 所有的实体类型以及出现次数计数 Map entityCountMap = new HashMap<>(); // 用来存放所有的节点以及节点路径(key是节点ID,value是节点路径) @@ -134,8 +135,11 @@ public class AskServiceImpl implements AskService { sessionParamDTO.setConditionPathMap(conditionPathMap); sessionParamDTO.setEntityCountMap(entityCountMap); redisTemplate.opsForValue().set(SESSION_PARAM + sessionId, sessionParamDTO); + } } + // 多轮问答支持设置角色,如果设置角色了,在这里应该先根据角色来排除一遍路径 + filterUserRoleSet(sessionParamDTO); // 如果判断过实体,这时就要判断是否已经确认了节点,如果没有确认,在这里进行确认 match: if (ObjectUtil.isEmpty(sessionParamDTO.getMatchItemLeaf())) { @@ -147,7 +151,7 @@ public class AskServiceImpl implements AskService { if (sessionParamDTO.getWaitMatchItemLeafMap().size() != 1) { if (ObjectUtil.isNotEmpty(sessionParamDTO.getCurrentEntity())) { - // 如果当前对话实体不为空,说明当前问答就是上一个问题的回复,这个时候,就去GPT中进行匹配并排除路径 + // 如果当前对话实体不为空,说明当前问答就是上一个问题的回复,这个时候,就去大模型中进行匹配并排除路径 filterPath(sessionParamDTO, roundTalkReqVO.getUserTalk()); // 如果排除后只剩一个了,这时跳出多轮问答 if (sessionParamDTO.getWaitMatchItemLeafMap().size() == 1) { @@ -361,16 +365,15 @@ public class AskServiceImpl implements AskService { return SingleTalkResVO.builder().answerText(answer).build(); } - @Override - public RoundTalkResVO saveUserParam(RoleSetReqVO paramReqVO) { - // 缓存到Redis中 - if (CollUtil.isNotEmpty(paramReqVO.getParamMap()) && StrUtil.isNotBlank(paramReqVO.getSessionId())) { - redisTemplate.opsForValue().set(USER_PARAM + paramReqVO.getSessionId(), paramReqVO.getParamMap()); - // 这里获取session进行筛选,将不匹配的路径移除掉 - Object cache = redisTemplate.opsForValue().get(SESSION_PARAM + paramReqVO.getSessionId()); - Optional.ofNullable(cache).orElseThrow(() -> new BusinessException("未找到的会话ID")); - SessionParamDTO sessionParamDTO = BeanUtil.toBean(cache, SessionParamDTO.class); - for (Map.Entry entry : paramReqVO.getParamMap().entrySet()) { + /** + * 对用户的角色设置路径进行排除 + */ + private void filterUserRoleSet(SessionParamDTO sessionParamDTO) { + // 再获取用户填写的缓存 + Object userParamObject = redisTemplate.opsForValue().get(USER_PARAM + sessionParamDTO.getSessionId()); + if (ObjectUtil.isNotEmpty(userParamObject)) { + Map stringObjectMap = BeanUtil.beanToMap(userParamObject, false, true); + for (Map.Entry entry : stringObjectMap.entrySet()) { String key = entry.getKey(); // 去枚举里面去找 RetireRoleEnum roleEnum = Arrays.stream(RetireRoleEnum.values()).filter(e -> e.getCode().equals(key)).findFirst().orElseThrow(() -> new BusinessException("未找到的参数")); @@ -379,19 +382,27 @@ public class AskServiceImpl implements AskService { Map entityCountMap = sessionParamDTO.getEntityCountMap(); // 如果包含,就去尝试排除路径 if (entityCountMap.containsKey(roleEnum.getEntityEnum().getEntityType())) { - filterPath(sessionParamDTO, entry.getValue()); + filterPath(sessionParamDTO, String.valueOf(entry.getValue())); } } } - // 在这里,判断是不是只生一个匹配的了,如果只剩一个匹配的了,就直接回复了,不需要再问了 - // 如果排除后只剩一个了,这时跳出多轮问答 + // 排除之后,再看看是不是确定了 if (sessionParamDTO.getWaitMatchItemLeafMap().size() == 1) { sessionParamDTO.setMatchItemLeaf(sessionParamDTO.getWaitMatchItemLeafMap().values().iterator().next()); - redisTemplate.opsForValue().set(SESSION_PARAM + paramReqVO.getSessionId(), sessionParamDTO); - return afterMatchReturnAnswer(sessionParamDTO); } + // 缓存到Redis中 + redisTemplate.opsForValue().set(SESSION_PARAM + sessionParamDTO.getSessionId(), sessionParamDTO); + } + + } + + @Override + public void saveUserParam(RoleSetReqVO paramReqVO) { + // 缓存到Redis中 + if (CollUtil.isNotEmpty(paramReqVO.getParamMap()) && StrUtil.isNotBlank(paramReqVO.getSessionId())) { + redisTemplate.opsForValue().set(USER_PARAM + paramReqVO.getSessionId(), paramReqVO.getParamMap()); + } - return null; } @Override diff --git a/nGQL/demo_data.ngql b/nGQL/demo_data.ngql index b1b7cd8..695efb2 100644 --- a/nGQL/demo_data.ngql +++ b/nGQL/demo_data.ngql @@ -38,7 +38,7 @@ insert edge `process_condition_edge`() values "1-1-2"->"1-1-2-1":(); insert vertex `condition` ( `condition`, `entity_type`) values "1-1-2-1-1":("城乡居民退休", "退休类型"); insert vertex `condition` ( `condition`, `entity_type`) values "1-1-2-1-1-1":("港澳台和外籍人员", "户口所在地"); insert vertex `condition` ( `condition`, `entity_type`) values "1-1-2-1-1-1-1":("年满60岁", "年龄"); -insert vertex `condition` ( `condition`, `entity_type`) values "1-1-2-1-1-1-1-1":("非居民","缴费年限"); +insert vertex `condition` ( `condition`, `entity_type`) values "1-1-2-1-1-1-1-1":("实际缴费年限+视同缴费年限满15年","缴费年限"); # 建立连接关系,非深圳城乡居民退休条件下的所有办理条件 insert edge `condition_edge`() values "1-1-2-1"->"1-1-2-1-1":(),"1-1-2-1-1"->"1-1-2-1-1-1":(),"1-1-2-1-1-1"->"1-1-2-1-1-1-1":(),"1-1-2-1-1-1-1"->"1-1-2-1-1-1-1-1":();