From e7f6bd140762fde061ff95ee4924dfce4e44f3e9 Mon Sep 17 00:00:00 2001 From: liu Date: Tue, 9 Apr 2024 12:06:48 +0800 Subject: [PATCH] =?UTF-8?q?KBQA=E4=BB=A3=E7=A0=81=E6=8F=90=E4=BA=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../handler/gpt/ConditionJudgeHandler.java | 10 ++- .../service/impl/AskServiceImpl.java | 68 ++++++++++++++----- 2 files changed, 57 insertions(+), 21 deletions(-) diff --git a/kbqa-graph/src/main/java/com/supervision/handler/gpt/ConditionJudgeHandler.java b/kbqa-graph/src/main/java/com/supervision/handler/gpt/ConditionJudgeHandler.java index abe1a89..3275cbc 100644 --- a/kbqa-graph/src/main/java/com/supervision/handler/gpt/ConditionJudgeHandler.java +++ b/kbqa-graph/src/main/java/com/supervision/handler/gpt/ConditionJudgeHandler.java @@ -2,6 +2,7 @@ package com.supervision.handler.gpt; import cn.hutool.core.collection.CollUtil; import cn.hutool.core.util.StrUtil; +import cn.hutool.json.JSONUtil; import com.supervision.ai.AiUtil; import com.supervision.ai.dto.MessageDTO; import lombok.extern.slf4j.Slf4j; @@ -20,13 +21,16 @@ public class ConditionJudgeHandler { public String conditionJudge(String question, Collection candidateAnswerList, String userAnswer) { List messageList = new ArrayList<>(); - messageList.add(new MessageDTO("system", "你是一个条件判断模型且精通政务事项,我现在给你一个问题,给你候选答案,请你根据用户的实际回答,从候选答案中给我选择一个对应的候选答案.除了候选答案,什么其他的都不要说.")); + messageList.add(new MessageDTO("system", "你是一个条件判断模型且精通政务事项,我现在给你一个问题,给你候选答案列表,请你根据用户的实际回答,从候选答案列表中给我选择对应的候选答案.除了候选答案,什么其他的都不要说.")); messageList.add(new MessageDTO("assistant", "好的")); messageList.add(new MessageDTO("user", StrUtil.format("问题:[{}]", question))); messageList.add(new MessageDTO("assistant", "继续")); - messageList.add(new MessageDTO("user", StrUtil.format("候选答案:[{};未找到匹配答案]", StrUtil.join(";", candidateAnswerList)))); + messageList.add(new MessageDTO("user", StrUtil.format("候选答案列表:[{};未找到匹配答案]", StrUtil.join(";", candidateAnswerList)))); messageList.add(new MessageDTO("assistant", "继续")); messageList.add(new MessageDTO("user", StrUtil.format("用户答案:[{}],现在请给我匹配的候选答案,其他什么都不要说.如果有多个候选答案,用;号分割", userAnswer))); - return AiUtil.chatByMessage(messageList); + log.info("conditionJudge判断候选答案:{}", JSONUtil.toJsonStr(messageList)); + String answer = AiUtil.chatByMessage(messageList); + log.info("conditionJudge判断结果是:{}", answer); + return answer; } } diff --git a/kbqa-graph/src/main/java/com/supervision/service/impl/AskServiceImpl.java b/kbqa-graph/src/main/java/com/supervision/service/impl/AskServiceImpl.java index aa6eccb..d67fdbe 100644 --- a/kbqa-graph/src/main/java/com/supervision/service/impl/AskServiceImpl.java +++ b/kbqa-graph/src/main/java/com/supervision/service/impl/AskServiceImpl.java @@ -145,20 +145,24 @@ public class AskServiceImpl implements AskService { } } String judgeResult = conditionJudgeHandler.conditionJudge(sessionParamDTO.getCurrentEntity().getCurrentQuestion(), possibleAnswerSet, roundTalkReqVO.getUserTalk()); + log.info("GPT判断结果:{}", judgeResult); Set judgeResultSet = new HashSet<>(Arrays.asList(judgeResult.split(";"))); // 筛选路径,如果某个路径的结果不在比较结果中,说明这个结果不对,排除这个路径 pathFilterByJudgeResult(sessionParamDTO.getCurrentEntity().getCurrentEntityType(), judgeResultSet, sessionParamDTO); filterNotMatchNode(sessionParamDTO); - // 加到已匹配的项目 + // 加到已匹配的实体类型,下次不再匹配 sessionParamDTO.getAlreadyMatchEntitySet().add(sessionParamDTO.getCurrentEntity().getCurrentEntityType()); + redisTemplate.opsForValue().set(SESSION_PARAM + sessionId, sessionParamDTO); // 如果排除后只剩一个了,这时跳出多轮问答 if (sessionParamDTO.getWaitMatchItemLeafMap().size() == 1) { sessionParamDTO.setMatchItemLeaf(sessionParamDTO.getWaitMatchItemLeafMap().values().iterator().next()); + redisTemplate.opsForValue().set(SESSION_PARAM + sessionId, sessionParamDTO); break match; } } // 首先获取出现次数最多的实体类型 - String mostFrequentType = sessionParamDTO.getEntityCountMap().entrySet().stream() + Map newCountMap = countCondition(sessionParamDTO); + String mostFrequentType = newCountMap.entrySet().stream() .filter(entry -> !sessionParamDTO.getAlreadyMatchEntitySet().contains(entry.getKey())) .max(Map.Entry.comparingByValue(Integer::compareTo)) .map(Map.Entry::getKey).orElseThrow(() -> new BusinessException("未找到条件判断路径")); @@ -167,7 +171,7 @@ public class AskServiceImpl implements AskService { Optional.ofNullable(question).orElseThrow(() -> new BusinessException("未找到条件判断路径")); sessionParamDTO.setCurrentEntity(EntityQuestionDTO.builder().currentEntityType(mostFrequentType).currentQuestion(question).build()); redisTemplate.opsForValue().set(SESSION_PARAM + sessionId, sessionParamDTO); - // 返回这个问题给前端 + // 返回这个问题给前端(触发了多轮问法) return RoundTalkResVO.builder().sessionId(sessionId).replyQuestion(question).build(); } else { // 获取到唯一节点 @@ -178,7 +182,27 @@ public class AskServiceImpl implements AskService { // 走到这里,说明就只有一个节点了,那么就可以进行下一步了 log.info("走到这里,说明找到了匹配的节点"); - return null; + + return RoundTalkResVO.builder().sessionId(sessionId).replyQuestion("找到了匹配的节点").build(); + } + + private Map countCondition(SessionParamDTO sessionParamDTO) { + Map>> conditionPathMap = sessionParamDTO.getConditionPathMap(); + // 所有的实体类型以及出现次数计数 + Map entityCountMap = new HashMap<>(); + for (Map.Entry>> entry : conditionPathMap.entrySet()) { + // 然后根据条件进行计数 + for (List conditions : entry.getValue()) { + for (Condition condition : conditions) { + // 如果不存在,就添加进计数.如果存在,就+1 + entityCountMap.compute(condition.getEntityType(), (k, v) -> v == null ? 1 : v + 1); + } + } + } + sessionParamDTO.setEntityCountMap(entityCountMap); + redisTemplate.opsForValue().set(SESSION_PARAM + sessionParamDTO.getSessionId(), sessionParamDTO); + return entityCountMap; + } /** @@ -188,18 +212,28 @@ public class AskServiceImpl implements AskService { Map>> conditionPathMap = sessionParamDTO.getConditionPathMap(); for (Map.Entry>> entry : conditionPathMap.entrySet()) { List> conditionPath = entry.getValue(); - // 遍历所有的路径,找到回答 - for (int i = 0; i < conditionPath.size(); i++) { - List conditions = conditionPath.get(i); - // 遍历所有的条件,找到回答 + List> toRemove = new ArrayList<>(); + + // 遍历所有的路径,找到需要移除的路径 + for (List conditions : conditionPath) { + // 遍历所有的条件,判断是否需要移除该路径 + boolean shouldRemove = false; for (Condition condition : conditions) { - // 如果当前对话实体和条件中的实体类型相同,且不在比较结果中,说明这个结果不对,排除这个路径 if (entityType.equals(condition.getEntityType()) && !judgeResultSet.contains(condition.getCondition())) { - conditionPath.remove(i); + shouldRemove = true; + break; } } + + if (shouldRemove) { + toRemove.add(conditions); + } } + + // 移除所有需要移除的路径 + conditionPath.removeAll(toRemove); } + System.out.println(1); } /** @@ -208,11 +242,16 @@ public class AskServiceImpl implements AskService { private void filterNotMatchNode(SessionParamDTO sessionParamDTO) { Set emptyPathNodeIdSet = new HashSet<>(); Map>> conditionPathMap = sessionParamDTO.getConditionPathMap(); - for (Map.Entry>> entry : conditionPathMap.entrySet()) { + Iterator>>> iterator = conditionPathMap.entrySet().iterator(); + + while (iterator.hasNext()) { + Map.Entry>> entry = iterator.next(); if (CollUtil.isEmpty(entry.getValue())) { emptyPathNodeIdSet.add(entry.getKey()); + iterator.remove(); } } + Map waitMatchItemLeafMap = sessionParamDTO.getWaitMatchItemLeafMap(); for (String nodeId : emptyPathNodeIdSet) { if (emptyPathNodeIdSet.contains(nodeId)) { @@ -220,11 +259,4 @@ public class AskServiceImpl implements AskService { } } } - - /** - * 进行比较操作 - */ - private void conditionJudge(String question, Collection candidateAnswerList, String userAnswer) { - - } }