You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

297 lines
16 KiB
Java

package com.supervision.service.impl;
import cn.hutool.core.bean.BeanUtil;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.ObjectUtil;
import cn.hutool.core.util.StrUtil;
import com.supervision.dto.roundAsk.EntityQuestionDTO;
import com.supervision.dto.roundAsk.ItemNodeDTO;
import com.supervision.dto.roundAsk.SessionParamDTO;
import com.supervision.enums.EntityQuestionEnum;
import com.supervision.enums.IdentifyIntentEnum;
import com.supervision.exception.BusinessException;
import com.supervision.handler.gpt.AnswerQuestionHandler;
import com.supervision.handler.gpt.ConditionJudgeHandler;
import com.supervision.handler.gpt.IdentifyIntentHandler;
import com.supervision.handler.gpt.ItemExtractHandler;
import com.supervision.handler.graph.FindConditionPathHandler;
import com.supervision.handler.graph.FindItemDetailHandler;
import com.supervision.handler.graph.FindItemNodeHandler;
import com.supervision.ngbatis.domain.tag.Condition;
import com.supervision.ngbatis.domain.tag.ItemLeaf;
import com.supervision.service.AskService;
import com.supervision.vo.RoundTalkReqVO;
import com.supervision.vo.RoundTalkResVO;
import com.supervision.vo.UserParamReqVO;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Service;
import org.springframework.web.bind.annotation.RequestBody;
import java.util.*;
import java.util.function.Function;
import java.util.stream.Collectors;
/**
 * Multi-turn knowledge-graph question answering service.
 *
 * <p>Per-session conversation state ({@link SessionParamDTO}) is kept in Redis under
 * {@code SESSION_PARAM + sessionId} and is re-read at the start of every round. Each round
 * advances the session through these stages:
 * <ol>
 *   <li>identify the user's intent (once per session);</li>
 *   <li>extract entities from the original question and look up matching leaf nodes in the graph;</li>
 *   <li>if several leaf nodes match, ask follow-up questions about condition entity types and
 *       prune candidate condition paths using the user's answers until one leaf remains;</li>
 *   <li>fetch the leaf's details for the intent and produce the final answer.</li>
 * </ol>
 */
@Slf4j
@Service
@RequiredArgsConstructor
public class AskServiceImpl implements AskService {

    private final RedisTemplate<String, Object> redisTemplate;
    private final IdentifyIntentHandler identifyIntentHandler;
    private final ItemExtractHandler itemExtractHandler;
    private final FindItemNodeHandler findItemNodeHandler;
    private final FindConditionPathHandler findConditionPathHandler;
    private final ConditionJudgeHandler conditionJudgeHandler;
    private final FindItemDetailHandler findItemDetailHandler;
    private final AnswerQuestionHandler answerQuestionHandler;

    /** Redis key prefix for per-session conversation state. */
    private static final String SESSION_PARAM = "KBQA:ASK:SESSION_PARAM:";
    /** Redis key prefix for the parameter list stored by {@link #saveUserParam}. */
    private static final String USER_PARAM = "KBQA:ASK:USER_PARAM:";

    /**
     * Processes one round of the conversation.
     *
     * @param roundTalkReqVO request carrying the session id and the user's current utterance
     * @return either a follow-up question (when the matching leaf node is still ambiguous) or the
     *         final reply for the session
     * @throws BusinessException when no condition path can be used to disambiguate, no follow-up
     *         question exists for the next entity type, or the intent is not supported
     */
    @Override
    public RoundTalkResVO roundTalk(RoundTalkReqVO roundTalkReqVO) {
        String sessionId = roundTalkReqVO.getSessionId();
        String userTalk = roundTalkReqVO.getUserTalk();
        SessionParamDTO sessionParamDTO = loadOrCreateSession(sessionId, userTalk);

        // Stage 1: identify the intent once per session, from the current utterance.
        if (StrUtil.isBlank(sessionParamDTO.getIntent())) {
            sessionParamDTO.setIntent(identifyIntentHandler.identifyIntent(userTalk));
            saveSession(sessionId, sessionParamDTO);
        }

        // Stage 2: extract entities (only if not done yet) and locate candidate leaf nodes.
        if (CollUtil.isEmpty(sessionParamDTO.getEntityValueByExtract())) {
            extractEntitiesAndCandidates(sessionParamDTO);
            saveSession(sessionId, sessionParamDTO);
        }

        // Stage 3: if no leaf node is confirmed yet, narrow the candidates via follow-up questions.
        if (ObjectUtil.isEmpty(sessionParamDTO.getMatchItemLeaf())) {
            Optional<RoundTalkResVO> followUp = confirmLeafNode(sessionId, sessionParamDTO, userTalk);
            if (followUp.isPresent()) {
                return followUp.get();
            }
        }

        // Stage 4: a single leaf node is matched; answer according to the intent.
        log.info("走到这里,说明找到了匹配的节点,开始根据用户的意图生成");
        return answerByIntent(sessionId, sessionParamDTO);
    }

    /** Reads the session state from Redis, or creates and stores a fresh one for a new session. */
    private SessionParamDTO loadOrCreateSession(String sessionId, String userTalk) {
        Object cache = redisTemplate.opsForValue().get(SESSION_PARAM + sessionId);
        if (ObjectUtil.isNotEmpty(cache)) {
            return BeanUtil.toBean(cache, SessionParamDTO.class);
        }
        SessionParamDTO sessionParamDTO = new SessionParamDTO();
        // The first utterance of a session is kept as the question that is ultimately answered.
        sessionParamDTO.setOriginalQuestion(userTalk);
        sessionParamDTO.setSessionId(sessionId);
        sessionParamDTO.setAlreadyMatchEntitySet(new HashSet<>());
        saveSession(sessionId, sessionParamDTO);
        return sessionParamDTO;
    }

    /** Writes the session state back to Redis under the session key. */
    private void saveSession(String sessionId, SessionParamDTO sessionParamDTO) {
        redisTemplate.opsForValue().set(SESSION_PARAM + sessionId, sessionParamDTO);
    }

    /**
     * Extracts entity values from the original question (using every item node name in the graph
     * as the candidate vocabulary) and looks up the matching leaf nodes. A single match is stored
     * directly as the confirmed leaf; otherwise condition paths are prepared for disambiguation.
     */
    private void extractEntitiesAndCandidates(SessionParamDTO sessionParamDTO) {
        List<String> allItemNode = findItemNodeHandler.findAllItemNode();
        List<String> extractValue =
                itemExtractHandler.itemExtractByPossibleItem(sessionParamDTO.getOriginalQuestion(), allItemNode);
        sessionParamDTO.setEntityValueByExtract(extractValue);
        List<ItemLeaf> allMatchLeafNode = findItemNodeHandler.findAllMatchLeafNode(extractValue);
        if (allMatchLeafNode.size() == 1) {
            sessionParamDTO.setMatchItemLeaf(allMatchLeafNode.get(0));
        } else {
            initCandidatePaths(sessionParamDTO, allMatchLeafNode);
        }
    }

    /**
     * Builds the candidate state for disambiguation. It stores leaf nodes keyed by vid (the first
     * leaf wins for a duplicate vid), their condition paths, and the entity-type occurrence counts.
     * Leaf nodes without any condition path are dropped from the candidates.
     */
    private void initCandidatePaths(SessionParamDTO sessionParamDTO, List<ItemLeaf> candidateLeaves) {
        Map<String, ItemLeaf> waitMatchItemLeafMap = candidateLeaves.stream()
                .collect(Collectors.toMap(ItemLeaf::getVid, Function.identity(), (first, duplicate) -> first));
        Map<String, List<List<Condition>>> conditionPathMap = new HashMap<>();
        // Remove through the iterator: removing via the map while iterating keySet() would throw
        // ConcurrentModificationException.
        Iterator<String> leafIdIterator = waitMatchItemLeafMap.keySet().iterator();
        while (leafIdIterator.hasNext()) {
            String leafId = leafIdIterator.next();
            List<List<Condition>> conditionPath = findConditionPathHandler.findConditionPath(leafId);
            if (CollUtil.isEmpty(conditionPath)) {
                leafIdIterator.remove();
            } else {
                conditionPathMap.put(leafId, conditionPath);
            }
        }
        sessionParamDTO.setWaitMatchItemLeafMap(waitMatchItemLeafMap);
        sessionParamDTO.setConditionPathMap(conditionPathMap);
        sessionParamDTO.setEntityCountMap(countEntityTypes(conditionPathMap));
    }

    /**
     * Tries to confirm a single leaf node for the session.
     *
     * @return empty when a leaf node has been confirmed (and stored on the session), or a
     *         follow-up question response that must be returned to the caller
     * @throws BusinessException when there are no condition paths to disambiguate with
     */
    private Optional<RoundTalkResVO> confirmLeafNode(String sessionId, SessionParamDTO sessionParamDTO, String userTalk) {
        if (CollUtil.isEmpty(sessionParamDTO.getConditionPathMap())) {
            throw new BusinessException("未找到条件判断路径");
        }
        Map<String, ItemLeaf> waitMatchItemLeafMap = sessionParamDTO.getWaitMatchItemLeafMap();
        if (waitMatchItemLeafMap.size() == 1) {
            ItemLeaf itemLeaf = waitMatchItemLeafMap.values().stream().findFirst()
                    .orElseThrow(() -> new BusinessException("未找到条件判断路径"));
            sessionParamDTO.setMatchItemLeaf(itemLeaf);
            saveSession(sessionId, sessionParamDTO);
            return Optional.empty();
        }
        EntityQuestionDTO currentEntity = sessionParamDTO.getCurrentEntity();
        if (ObjectUtil.isNotEmpty(currentEntity)) {
            // A follow-up question was asked in the previous round, so this utterance is its answer.
            applyUserAnswer(sessionId, sessionParamDTO, currentEntity, userTalk);
            Map<String, ItemLeaf> remaining = sessionParamDTO.getWaitMatchItemLeafMap();
            if (remaining.size() == 1) {
                sessionParamDTO.setMatchItemLeaf(remaining.values().iterator().next());
                saveSession(sessionId, sessionParamDTO);
                return Optional.empty();
            }
        }
        return Optional.of(askNextQuestion(sessionId, sessionParamDTO));
    }

    /**
     * Judges the user's answer to the previous follow-up question against the condition values of
     * that entity type. It then prunes the paths and leaf nodes that the answer rules out, and marks
     * the entity type as already asked.
     */
    private void applyUserAnswer(String sessionId, SessionParamDTO sessionParamDTO,
                                 EntityQuestionDTO currentEntity, String userTalk) {
        String entityType = currentEntity.getCurrentEntityType();
        Set<String> possibleAnswerSet = collectPossibleAnswers(sessionParamDTO.getConditionPathMap(), entityType);
        String judgeResult =
                conditionJudgeHandler.conditionJudge(currentEntity.getCurrentQuestion(), possibleAnswerSet, userTalk);
        log.info("GPT判断结果:{}", judgeResult);
        Set<String> judgeResultSet = parseJudgeResult(judgeResult);
        pathFilterByJudgeResult(entityType, judgeResultSet, sessionParamDTO);
        filterNotMatchNode(sessionParamDTO);
        // Never ask about this entity type again in this session.
        sessionParamDTO.getAlreadyMatchEntitySet().add(entityType);
        saveSession(sessionId, sessionParamDTO);
    }

    /** Collects every condition value of the given entity type across all candidate paths. */
    private static Set<String> collectPossibleAnswers(Map<String, List<List<Condition>>> conditionPathMap,
                                                      String entityType) {
        Set<String> possibleAnswerSet = new HashSet<>();
        for (List<List<Condition>> conditionPath : conditionPathMap.values()) {
            for (List<Condition> conditions : conditionPath) {
                for (Condition condition : conditions) {
                    if (entityType.equals(condition.getEntityType())) {
                        possibleAnswerSet.add(condition.getCondition());
                    }
                }
            }
        }
        return possibleAnswerSet;
    }

    /**
     * Parses the judge handler's result, a ';'-separated list of matched condition values.
     * Entries are trimmed and empty entries are dropped. A null/blank result yields an empty set,
     * meaning no condition value was matched.
     */
    private static Set<String> parseJudgeResult(String judgeResult) {
        if (StrUtil.isBlank(judgeResult)) {
            return new HashSet<>();
        }
        return Arrays.stream(judgeResult.split(";"))
                .map(String::trim)
                .filter(value -> !value.isEmpty())
                .collect(Collectors.toCollection(HashSet::new));
    }

    /**
     * Picks the most frequent not-yet-asked entity type among the remaining paths, records it as
     * the current entity, and returns its question as the reply.
     *
     * @throws BusinessException when no un-asked entity type remains or it has no question
     */
    private RoundTalkResVO askNextQuestion(String sessionId, SessionParamDTO sessionParamDTO) {
        Map<String, Integer> newCountMap = countCondition(sessionParamDTO);
        Set<String> alreadyMatched = sessionParamDTO.getAlreadyMatchEntitySet();
        String mostFrequentType = newCountMap.entrySet().stream()
                .filter(entry -> !alreadyMatched.contains(entry.getKey()))
                .max(Map.Entry.comparingByValue())
                .map(Map.Entry::getKey)
                .orElseThrow(() -> new BusinessException("未找到条件判断路径"));
        String question = EntityQuestionEnum.getQuestionByEntityType(mostFrequentType);
        if (question == null) {
            throw new BusinessException("未找到条件判断路径");
        }
        sessionParamDTO.setCurrentEntity(EntityQuestionDTO.builder()
                .currentEntityType(mostFrequentType)
                .currentQuestion(question)
                .build());
        saveSession(sessionId, sessionParamDTO);
        return reply(sessionId, question);
    }

    /**
     * Produces the final reply for the confirmed leaf node according to the session intent.
     *
     * @throws BusinessException when the intent has no matching {@link IdentifyIntentEnum}
     */
    private RoundTalkResVO answerByIntent(String sessionId, SessionParamDTO sessionParamDTO) {
        IdentifyIntentEnum intentEnum = IdentifyIntentEnum.getEnumByIntent(sessionParamDTO.getIntent());
        if (intentEnum == null) {
            throw new BusinessException("暂不支持该意图的问答");
        }
        ItemLeaf matchItemLeaf = sessionParamDTO.getMatchItemLeaf();
        List<String> itemDetail = findItemDetailHandler.findItemDetail(
                matchItemLeaf.getVid(), intentEnum.getTagType(), intentEnum.getEdgeType());
        if (CollUtil.isEmpty(itemDetail)) {
            return reply(sessionId, "暂不支持该意图的问答");
        }
        String answer = answerQuestionHandler.answerQuestion(sessionParamDTO.getOriginalQuestion(), itemDetail);
        return reply(sessionId, StrUtil.isBlank(answer) ? "暂时还不会回答这个问题哦" : answer);
    }

    /** Builds a response carrying the given reply text for the session. */
    private static RoundTalkResVO reply(String sessionId, String text) {
        return RoundTalkResVO.builder().sessionId(sessionId).replyQuestion(text).build();
    }

    /**
     * Recounts entity-type occurrences over the remaining condition paths, stores the counts on
     * the session, and persists the session (keyed by the DTO's own session id).
     */
    private Map<String, Integer> countCondition(SessionParamDTO sessionParamDTO) {
        Map<String, Integer> entityCountMap = countEntityTypes(sessionParamDTO.getConditionPathMap());
        sessionParamDTO.setEntityCountMap(entityCountMap);
        redisTemplate.opsForValue().set(SESSION_PARAM + sessionParamDTO.getSessionId(), sessionParamDTO);
        return entityCountMap;
    }

    /** Counts how often each entity type appears across every condition of every path. */
    private static Map<String, Integer> countEntityTypes(Map<String, List<List<Condition>>> conditionPathMap) {
        Map<String, Integer> entityCountMap = new HashMap<>();
        for (List<List<Condition>> conditionPath : conditionPathMap.values()) {
            for (List<Condition> conditions : conditionPath) {
                for (Condition condition : conditions) {
                    entityCountMap.merge(condition.getEntityType(), 1, Integer::sum);
                }
            }
        }
        return entityCountMap;
    }

    /**
     * Removes every path that contains a condition of {@code entityType} whose value is not in
     * {@code judgeResultSet}. Paths that do not mention the entity type are kept.
     */
    private void pathFilterByJudgeResult(String entityType, Set<String> judgeResultSet, SessionParamDTO sessionParamDTO) {
        for (List<List<Condition>> conditionPath : sessionParamDTO.getConditionPathMap().values()) {
            conditionPath.removeIf(conditions -> conditions.stream().anyMatch(condition ->
                    entityType.equals(condition.getEntityType())
                            && !judgeResultSet.contains(condition.getCondition())));
        }
    }

    /**
     * Drops leaf nodes whose path list has become empty. It removes them from both the
     * condition-path map and the map of leaf nodes still awaiting a match.
     */
    private void filterNotMatchNode(SessionParamDTO sessionParamDTO) {
        Set<String> emptyPathNodeIdSet = new HashSet<>();
        Iterator<Map.Entry<String, List<List<Condition>>>> iterator =
                sessionParamDTO.getConditionPathMap().entrySet().iterator();
        while (iterator.hasNext()) {
            Map.Entry<String, List<List<Condition>>> entry = iterator.next();
            if (CollUtil.isEmpty(entry.getValue())) {
                emptyPathNodeIdSet.add(entry.getKey());
                iterator.remove();
            }
        }
        sessionParamDTO.getWaitMatchItemLeafMap().keySet().removeAll(emptyPathNodeIdSet);
    }

    /**
     * Stores the user-supplied parameter list in Redis under {@code USER_PARAM + sessionId}.
     * Nothing is stored when the list is empty or the session id is blank.
     */
    @Override
    public void saveUserParam(UserParamReqVO paramReqVO) {
        if (CollUtil.isNotEmpty(paramReqVO.getParamList()) && StrUtil.isNotBlank(paramReqVO.getSessionId())) {
            redisTemplate.opsForValue().set(USER_PARAM + paramReqVO.getSessionId(), paramReqVO.getParamList());
        }
    }
}