From 304bb50f899965793f8681c1e38aec67f06f8b97 Mon Sep 17 00:00:00 2001 From: liu Date: Thu, 25 Apr 2024 16:32:47 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BB=A3=E7=A0=81=E6=8F=90=E4=BA=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../supervision/enums/EntityQuestionEnum.java | 2 + .../handler/gpt/ItemExtractHandler.java | 2 +- .../service/impl/AskServiceImpl.java | 64 +++++++++++++------ 3 files changed, 46 insertions(+), 22 deletions(-) diff --git a/kbqa-graph/src/main/java/com/supervision/enums/EntityQuestionEnum.java b/kbqa-graph/src/main/java/com/supervision/enums/EntityQuestionEnum.java index b3022bd..a242948 100644 --- a/kbqa-graph/src/main/java/com/supervision/enums/EntityQuestionEnum.java +++ b/kbqa-graph/src/main/java/com/supervision/enums/EntityQuestionEnum.java @@ -35,4 +35,6 @@ public enum EntityQuestionEnum { } return null; } + + } diff --git a/kbqa-graph/src/main/java/com/supervision/handler/gpt/ItemExtractHandler.java b/kbqa-graph/src/main/java/com/supervision/handler/gpt/ItemExtractHandler.java index 1bd3027..48f7be0 100644 --- a/kbqa-graph/src/main/java/com/supervision/handler/gpt/ItemExtractHandler.java +++ b/kbqa-graph/src/main/java/com/supervision/handler/gpt/ItemExtractHandler.java @@ -42,7 +42,7 @@ public class ItemExtractHandler { "输入:我是南京市户口,可以在深圳办理退休吗?输出:省外户口企业职工退休\n" + "输入:我是澳门人,在深圳好多年了,可以根据城乡居民来办理退休吗?输出:港澳台和外籍人员城乡居民退休\n" + "输入:今天中午吃什么?输出:无关问题\n" + - "请你学习上面示例中的输入和输出的提取方式。现在我给你一句话,请给我输出:\n" + + "请你学习上面示例中的输入和输出的提取方式。现在我给你一句话,请给我输出。\n" + "输入:{}\n" + "输出:"; List messageList = new ArrayList<>(); diff --git a/kbqa-graph/src/main/java/com/supervision/service/impl/AskServiceImpl.java b/kbqa-graph/src/main/java/com/supervision/service/impl/AskServiceImpl.java index f670b16..93cdd2a 100644 --- a/kbqa-graph/src/main/java/com/supervision/service/impl/AskServiceImpl.java +++ b/kbqa-graph/src/main/java/com/supervision/service/impl/AskServiceImpl.java @@ -81,12 +81,18 @@ public class AskServiceImpl implements AskService { sessionParamDTO = new SessionParamDTO(); sessionParamDTO.setOriginalQuestion(roundTalkReqVO.getUserTalk()); sessionParamDTO.setSessionId(sessionId); + sessionParamDTO.setRoleSetId(roundTalkReqVO.getRoleSetId()); sessionParamDTO.setAlreadyMatchEntitySet(new HashSet<>()); sessionParamDTO.setTalkRecord(new HashMap<>()); redisTemplate.opsForValue().set(SESSION_PARAM + sessionId, sessionParamDTO); } else { sessionParamDTO = BeanUtil.toBean(cache, SessionParamDTO.class); } + // 如果前端传了角色设置ID,就更新缓存 + if (StrUtil.isNotBlank(roundTalkReqVO.getRoleSetId())) { + sessionParamDTO.setRoleSetId(roundTalkReqVO.getRoleSetId()); + redisTemplate.opsForValue().set(SESSION_PARAM + sessionId, sessionParamDTO); + } // 判断意图是否为空,如果意图为空,进行识别意图 if (StrUtil.isBlank(sessionParamDTO.getIntent())) { String intent = identifyIntentHandler.identifyIntentExample(roundTalkReqVO.getUserTalk()); @@ -157,7 +163,7 @@ public class AskServiceImpl implements AskService { if (ObjectUtil.isNotEmpty(sessionParamDTO.getCurrentEntity())) { // 如果当前对话实体不为空,说明当前问答就是上一个问题的回复,这个时候,就去大模型中进行匹配并排除路径 - filterPath(sessionParamDTO, roundTalkReqVO.getUserTalk()); + filterPath(sessionParamDTO, sessionParamDTO.getCurrentEntity().getCurrentEntityType(), sessionParamDTO.getCurrentEntity().getCurrentQuestion(), roundTalkReqVO.getUserTalk()); // 如果排除后只剩一个了,这时跳出多轮问答 if (sessionParamDTO.getWaitMatchItemLeafMap().size() == 1) { sessionParamDTO.setMatchItemLeaf(sessionParamDTO.getWaitMatchItemLeafMap().values().iterator().next()); @@ -216,24 +222,31 @@ public class AskServiceImpl implements AskService { return RoundTalkResVO.builder().sessionId(sessionParamDTO.getSessionId()).answerText(answer).build(); } - private void filterPath(SessionParamDTO sessionParamDTO, String userTalk) { + /** + * 根据用户回答进行答案过滤(建立在当前节点存在的情况) + * + * @param sessionParamDTO dto + * @param currentQuestion 当前的回答 + * @param userTalk 用户的回答 + */ + private void filterPath(SessionParamDTO sessionParamDTO, String currentEntityType, String currentQuestion, String userTalk) { // 遍历所有的path,找到回答 - Set possibleAnswerSet = findPossibleAnswerSet(sessionParamDTO); - Set judgeResultSet = conditionJudgeHandler.conditionJudgeAll(sessionParamDTO.getCurrentEntity().getCurrentQuestion(), + Set possibleAnswerSet = findPossibleAnswerSet(sessionParamDTO, currentEntityType); + Set judgeResultSet = conditionJudgeHandler.conditionJudgeAll(currentQuestion, possibleAnswerSet, userTalk); // 筛选路径,如果某个路径的结果不在比较结果中,说明这个结果不对,排除这个路径 - pathFilterByJudgeResult(sessionParamDTO.getCurrentEntity().getCurrentEntityType(), judgeResultSet, sessionParamDTO); + pathFilterByJudgeResult(currentEntityType, judgeResultSet, sessionParamDTO); filterNotMatchNode(sessionParamDTO); // 加到已匹配的实体类型,下次不再匹配 - sessionParamDTO.getAlreadyMatchEntitySet().add(sessionParamDTO.getCurrentEntity().getCurrentEntityType()); + sessionParamDTO.getAlreadyMatchEntitySet().add(currentEntityType); // 保存用户的对话记录 - sessionParamDTO.getTalkRecord().put(sessionParamDTO.getCurrentEntity().getCurrentEntityType(), userTalk); + sessionParamDTO.getTalkRecord().put(currentEntityType, userTalk); // 缓存到Redis中 redisTemplate.opsForValue().set(SESSION_PARAM + sessionParamDTO.getSessionId(), sessionParamDTO); } - private Set findPossibleAnswerSet(SessionParamDTO sessionParamDTO) { + private Set findPossibleAnswerSet(SessionParamDTO sessionParamDTO, String currentEntityType) { Map>> conditionPathMap = sessionParamDTO.getConditionPathMap(); Set possibleAnswerSet = new HashSet<>(); for (Map.Entry>> entry : conditionPathMap.entrySet()) { @@ -243,7 +256,7 @@ public class AskServiceImpl implements AskService { // 遍历所有的条件,找到回答 for (Condition condition : conditions) { // 如果当前对话实体和条件中的实体类型相同,就添加到possibleAnswerSet中 - if (sessionParamDTO.getCurrentEntity().getCurrentEntityType().equals(condition.getEntityType())) { + if (currentEntityType.equals(condition.getEntityType())) { possibleAnswerSet.add(condition.getCondition()); } } @@ -377,24 +390,33 @@ public class AskServiceImpl implements AskService { // 再获取用户填写的缓存 Object userParamObject = redisTemplate.opsForValue().get(USER_PARAM + sessionParamDTO.getRoleSetId()); if (ObjectUtil.isNotEmpty(userParamObject)) { + log.info("用户填写了角色设置,先对用户填写的角色设置进行过滤"); List list = JSONUtil.toList(JSONUtil.toJsonStr(userParamObject), RoleSetNode.class); for (RoleSetNode roleSetNode : list) { // 根据编码找到枚举值 - RetireRoleEnum retireRoleEnum = RetireRoleEnum.valueOf(roleSetNode.getItemEn()); - if (ObjectUtils.isNotEmpty(retireRoleEnum)) { - - Map entityCountMap = sessionParamDTO.getEntityCountMap(); - - // 如果包含,就去尝试排除路径 - if (entityCountMap.containsKey(retireRoleEnum.getZhName()) && ObjectUtil.isNotEmpty(roleSetNode.getValueNum())) { - List> answerList = retireRoleEnum.getAnswerList(); - for (Pair pair : answerList) { - // 如果枚举的key和用户填写的value相等,就排除 - if (pair.getKey().equals(roleSetNode.getValueNum())) { - filterPath(sessionParamDTO, pair.getValue()); + try { + RetireRoleEnum retireRoleEnum = RetireRoleEnum.valueOf(roleSetNode.getItemEn()); + if (ObjectUtils.isNotEmpty(retireRoleEnum)) { + + Map entityCountMap = sessionParamDTO.getEntityCountMap(); + + // 如果包含,就去尝试排除路径 + if (entityCountMap.containsKey(retireRoleEnum.getZhName()) && ObjectUtil.isNotEmpty(roleSetNode.getValueNum())) { + List> answerList = retireRoleEnum.getAnswerList(); + for (Pair pair : answerList) { + // 如果枚举的key和用户填写的value相等,就排除 + if (pair.getKey().equals(roleSetNode.getValueNum())) { + // 获取对应的枚举 + Optional any = Arrays.stream(EntityQuestionEnum.values()).filter(e -> e.getEntityType().equals(retireRoleEnum.getZhName())).findAny(); + // 然后获取结果 + any.ifPresent(entityQuestionEnum -> filterPath(sessionParamDTO, entityQuestionEnum.getEntityType(), entityQuestionEnum.getQuestion(), pair.getValue())); + + } } } } + } catch (IllegalArgumentException e) { + log.info("未找到的枚举,跳过"); } } // 排除之后,再看看是不是确定了