|
|
|
@ -4,8 +4,12 @@ import cn.hutool.core.bean.BeanUtil;
|
|
|
|
|
import cn.hutool.core.collection.CollUtil;
|
|
|
|
|
import cn.hutool.core.util.ObjectUtil;
|
|
|
|
|
import cn.hutool.core.util.StrUtil;
|
|
|
|
|
import com.supervision.dto.roundAsk.EntityQuestionDTO;
|
|
|
|
|
import com.supervision.dto.roundAsk.ItemNodeDTO;
|
|
|
|
|
import com.supervision.dto.roundAsk.SessionParamDTO;
|
|
|
|
|
import com.supervision.enums.EntityQuestionEnum;
|
|
|
|
|
import com.supervision.exception.BusinessException;
|
|
|
|
|
import com.supervision.handler.gpt.ConditionJudgeHandler;
|
|
|
|
|
import com.supervision.handler.gpt.IdentifyIntentHandler;
|
|
|
|
|
import com.supervision.handler.gpt.ItemExtractHandler;
|
|
|
|
|
import com.supervision.handler.graph.FindConditionPathHandler;
|
|
|
|
@ -13,10 +17,13 @@ import com.supervision.handler.graph.FindItemNodeHandler;
|
|
|
|
|
import com.supervision.ngbatis.domain.tag.Condition;
|
|
|
|
|
import com.supervision.ngbatis.domain.tag.ItemLeaf;
|
|
|
|
|
import com.supervision.service.AskService;
|
|
|
|
|
import com.supervision.vo.RoundTalkReqVO;
|
|
|
|
|
import com.supervision.vo.RoundTalkResVO;
|
|
|
|
|
import lombok.RequiredArgsConstructor;
|
|
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
|
|
|
import org.springframework.data.redis.core.RedisTemplate;
|
|
|
|
|
import org.springframework.stereotype.Service;
|
|
|
|
|
import org.springframework.web.bind.annotation.RequestBody;
|
|
|
|
|
|
|
|
|
|
import java.util.*;
|
|
|
|
|
import java.util.function.Function;
|
|
|
|
@ -37,32 +44,38 @@ public class AskServiceImpl implements AskService {
|
|
|
|
|
|
|
|
|
|
private final FindConditionPathHandler findConditionPathHandler;
|
|
|
|
|
|
|
|
|
|
private final ConditionJudgeHandler conditionJudgeHandler;
|
|
|
|
|
|
|
|
|
|
private static final String SESSION_PARAM = "KBQA:ASK:SESSION_PARAM:";
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public void ask(String sessionId, String question) {
|
|
|
|
|
public RoundTalkResVO roundTalk(RoundTalkReqVO roundTalkReqVO) {
|
|
|
|
|
String sessionId = roundTalkReqVO.getSessionId();
|
|
|
|
|
// 去Redis中,首先判断session处在哪个阶段,是否有识别的意图
|
|
|
|
|
Object cache = redisTemplate.opsForValue().get(SESSION_PARAM + sessionId);
|
|
|
|
|
SessionParamDTO sessionParamDTO;
|
|
|
|
|
if (ObjectUtil.isEmpty(cache)) {
|
|
|
|
|
sessionParamDTO = new SessionParamDTO();
|
|
|
|
|
sessionParamDTO.setOriginalQuestion(question);
|
|
|
|
|
sessionParamDTO.setOriginalQuestion(roundTalkReqVO.getUserTalk());
|
|
|
|
|
sessionParamDTO.setSessionId(sessionId);
|
|
|
|
|
sessionParamDTO.setAlreadyMatchEntitySet(new HashSet<>());
|
|
|
|
|
redisTemplate.opsForValue().set(SESSION_PARAM + sessionId, sessionParamDTO);
|
|
|
|
|
} else {
|
|
|
|
|
sessionParamDTO = BeanUtil.toBean(cache, SessionParamDTO.class);
|
|
|
|
|
}
|
|
|
|
|
// 判断意图是否为空,如果意图为空,进行识别意图
|
|
|
|
|
if (StrUtil.isBlank(sessionParamDTO.getIntent())) {
|
|
|
|
|
String intent = identifyIntentHandler.identifyIntent(question);
|
|
|
|
|
String intent = identifyIntentHandler.identifyIntent(roundTalkReqVO.getUserTalk());
|
|
|
|
|
sessionParamDTO.setIntent(intent);
|
|
|
|
|
redisTemplate.opsForValue().set(SESSION_PARAM + sessionId, sessionParamDTO);
|
|
|
|
|
}
|
|
|
|
|
// 识别出来意图之后,再去判断是否识别过实体
|
|
|
|
|
if (StrUtil.isBlank(sessionParamDTO.getEntityValueByExtract())) {
|
|
|
|
|
// 识别实体
|
|
|
|
|
String extractValue = itemExtractHandler.itemExtract(sessionParamDTO.getOriginalQuestion());
|
|
|
|
|
if (CollUtil.isEmpty(sessionParamDTO.getEntityValueByExtract())) {
|
|
|
|
|
// 识别实体(先从图中获取所有的节点名称,然后识别)
|
|
|
|
|
List<String> allItemNode = findItemNodeHandler.findAllItemNode();
|
|
|
|
|
List<String> extractValue = itemExtractHandler.itemExtractByPossibleItem(sessionParamDTO.getOriginalQuestion(), allItemNode);
|
|
|
|
|
|
|
|
|
|
sessionParamDTO.setEntityValueByExtract(extractValue);
|
|
|
|
|
// 根据提取的内容,开始在知识图谱中寻找节点(首先找叶子节点,如果叶子节点有数据,直接返回,如果叶子节点没数据,再去找分支节点)
|
|
|
|
|
List<ItemLeaf> allMatchLeafNode = findItemNodeHandler.findAllMatchLeafNode(extractValue);
|
|
|
|
@ -74,18 +87,21 @@ public class AskServiceImpl implements AskService {
|
|
|
|
|
} else {
|
|
|
|
|
// 如果不等于1,说明可能有不确定的节点,这时就要开始找节点
|
|
|
|
|
Map<String, ItemLeaf> waitMatchItemLeafMap = allMatchLeafNode.stream().collect(Collectors.toMap(ItemLeaf::getVid, Function.identity(), (k1, k2) -> k1));
|
|
|
|
|
sessionParamDTO.setWaitMatchItemLeafMap(waitMatchItemLeafMap);
|
|
|
|
|
// 开始寻找条件路径
|
|
|
|
|
Set<String> itemLeafIdSet = waitMatchItemLeafMap.keySet();
|
|
|
|
|
|
|
|
|
|
// 所有的实体类型以及出现次数计数
|
|
|
|
|
Map<String, Integer> entityCountMap = new HashMap<>();
|
|
|
|
|
// 用来存放所有的节点以及节点路径(key是节点ID,value是节点路径)
|
|
|
|
|
Map<String, List<List<Condition>>> conditionPathMap = new HashMap<>();
|
|
|
|
|
|
|
|
|
|
for (String leafId : itemLeafIdSet) {
|
|
|
|
|
// 进行遍历.寻找所有的路径
|
|
|
|
|
for (String leafId : waitMatchItemLeafMap.keySet()) {
|
|
|
|
|
List<List<Condition>> conditionPath = findConditionPathHandler.findConditionPath(leafId);
|
|
|
|
|
if (CollUtil.isEmpty(conditionPath)) {
|
|
|
|
|
waitMatchItemLeafMap.remove(leafId);
|
|
|
|
|
} else {
|
|
|
|
|
// 如果路径不为空,则放到缓存中去
|
|
|
|
|
conditionPathMap.put(leafId, conditionPath);
|
|
|
|
|
// 然后根据条件进行计数
|
|
|
|
|
for (List<Condition> conditions : conditionPath) {
|
|
|
|
|
for (Condition condition : conditions) {
|
|
|
|
|
// 如果不存在,就添加进计数.如果存在,就+1
|
|
|
|
@ -94,7 +110,121 @@ public class AskServiceImpl implements AskService {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// 缓存到Redis中
|
|
|
|
|
sessionParamDTO.setWaitMatchItemLeafMap(waitMatchItemLeafMap);
|
|
|
|
|
sessionParamDTO.setConditionPathMap(conditionPathMap);
|
|
|
|
|
sessionParamDTO.setEntityCountMap(entityCountMap);
|
|
|
|
|
redisTemplate.opsForValue().set(SESSION_PARAM + sessionId, sessionParamDTO);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// 如果判断过实体,这时就要判断是否已经确认了节点,如果没有确认,在这里进行确认
|
|
|
|
|
match:
|
|
|
|
|
if (ObjectUtil.isEmpty(sessionParamDTO.getMatchItemLeaf())) {
|
|
|
|
|
// 如果没有确定节点,且没有路径可供匹配,抛出异常
|
|
|
|
|
if (CollUtil.isEmpty(sessionParamDTO.getConditionPathMap())) {
|
|
|
|
|
throw new BusinessException("未找到条件判断路径");
|
|
|
|
|
}
|
|
|
|
|
// 判断待匹配的节点是不是只有一个了,如果有多个,就从路径中选择一个问题问前端
|
|
|
|
|
if (sessionParamDTO.getWaitMatchItemLeafMap().size() != 1) {
|
|
|
|
|
if (ObjectUtil.isNotEmpty(sessionParamDTO.getCurrentEntity())) {
|
|
|
|
|
// 如果当前对话实体不为空,说明当前问答就是上一个问题的回复,这个时候,就去GPT中进行匹配
|
|
|
|
|
// 遍历所有的path,找到回答
|
|
|
|
|
Map<String, List<List<Condition>>> conditionPathMap = sessionParamDTO.getConditionPathMap();
|
|
|
|
|
Set<String> possibleAnswerSet = new HashSet<>();
|
|
|
|
|
for (Map.Entry<String, List<List<Condition>>> entry : conditionPathMap.entrySet()) {
|
|
|
|
|
List<List<Condition>> conditionPath = entry.getValue();
|
|
|
|
|
// 遍历所有的路径,找到回答
|
|
|
|
|
for (List<Condition> conditions : conditionPath) {
|
|
|
|
|
// 遍历所有的条件,找到回答
|
|
|
|
|
for (Condition condition : conditions) {
|
|
|
|
|
// 如果当前对话实体和条件中的实体类型相同,就添加到possibleAnswerSet中
|
|
|
|
|
if (sessionParamDTO.getCurrentEntity().getCurrentEntityType().equals(condition.getEntityType())) {
|
|
|
|
|
possibleAnswerSet.add(condition.getCondition());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
String judgeResult = conditionJudgeHandler.conditionJudge(sessionParamDTO.getCurrentEntity().getCurrentQuestion(), possibleAnswerSet, roundTalkReqVO.getUserTalk());
|
|
|
|
|
Set<String> judgeResultSet = new HashSet<>(Arrays.asList(judgeResult.split(";")));
|
|
|
|
|
// 筛选路径,如果某个路径的结果不在比较结果中,说明这个结果不对,排除这个路径
|
|
|
|
|
pathFilterByJudgeResult(sessionParamDTO.getCurrentEntity().getCurrentEntityType(), judgeResultSet, sessionParamDTO);
|
|
|
|
|
filterNotMatchNode(sessionParamDTO);
|
|
|
|
|
// 加到已匹配的项目
|
|
|
|
|
sessionParamDTO.getAlreadyMatchEntitySet().add(sessionParamDTO.getCurrentEntity().getCurrentEntityType());
|
|
|
|
|
// 如果排除后只剩一个了,这时跳出多轮问答
|
|
|
|
|
if (sessionParamDTO.getWaitMatchItemLeafMap().size() == 1) {
|
|
|
|
|
sessionParamDTO.setMatchItemLeaf(sessionParamDTO.getWaitMatchItemLeafMap().values().iterator().next());
|
|
|
|
|
break match;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// 首先获取出现次数最多的实体类型
|
|
|
|
|
String mostFrequentType = sessionParamDTO.getEntityCountMap().entrySet().stream()
|
|
|
|
|
.filter(entry -> !sessionParamDTO.getAlreadyMatchEntitySet().contains(entry.getKey()))
|
|
|
|
|
.max(Map.Entry.comparingByValue(Integer::compareTo))
|
|
|
|
|
.map(Map.Entry::getKey).orElseThrow(() -> new BusinessException("未找到条件判断路径"));
|
|
|
|
|
// 获取这个类型对应的问题
|
|
|
|
|
String question = EntityQuestionEnum.getQuestionByEntityType(mostFrequentType);
|
|
|
|
|
Optional.ofNullable(question).orElseThrow(() -> new BusinessException("未找到条件判断路径"));
|
|
|
|
|
sessionParamDTO.setCurrentEntity(EntityQuestionDTO.builder().currentEntityType(mostFrequentType).currentQuestion(question).build());
|
|
|
|
|
redisTemplate.opsForValue().set(SESSION_PARAM + sessionId, sessionParamDTO);
|
|
|
|
|
// 返回这个问题给前端
|
|
|
|
|
return RoundTalkResVO.builder().sessionId(sessionId).replyQuestion(question).build();
|
|
|
|
|
} else {
|
|
|
|
|
// 获取到唯一节点
|
|
|
|
|
ItemLeaf itemLeaf = sessionParamDTO.getWaitMatchItemLeafMap().values().stream().findFirst().orElseThrow(() -> new BusinessException("未找到条件判断路径"));
|
|
|
|
|
sessionParamDTO.setMatchItemLeaf(itemLeaf);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// 走到这里,说明就只有一个节点了,那么就可以进行下一步了
|
|
|
|
|
log.info("走到这里,说明找到了匹配的节点");
|
|
|
|
|
|
|
|
|
|
return null;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* 根据比较结果进行路径筛选,把不符合的路径进行移除
|
|
|
|
|
*/
|
|
|
|
|
private void pathFilterByJudgeResult(String entityType, Set<String> judgeResultSet, SessionParamDTO sessionParamDTO) {
|
|
|
|
|
Map<String, List<List<Condition>>> conditionPathMap = sessionParamDTO.getConditionPathMap();
|
|
|
|
|
for (Map.Entry<String, List<List<Condition>>> entry : conditionPathMap.entrySet()) {
|
|
|
|
|
List<List<Condition>> conditionPath = entry.getValue();
|
|
|
|
|
// 遍历所有的路径,找到回答
|
|
|
|
|
for (int i = 0; i < conditionPath.size(); i++) {
|
|
|
|
|
List<Condition> conditions = conditionPath.get(i);
|
|
|
|
|
// 遍历所有的条件,找到回答
|
|
|
|
|
for (Condition condition : conditions) {
|
|
|
|
|
// 如果当前对话实体和条件中的实体类型相同,且不在比较结果中,说明这个结果不对,排除这个路径
|
|
|
|
|
if (entityType.equals(condition.getEntityType()) && !judgeResultSet.contains(condition.getCondition())) {
|
|
|
|
|
conditionPath.remove(i);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* 过滤掉不匹配的路径
|
|
|
|
|
*/
|
|
|
|
|
private void filterNotMatchNode(SessionParamDTO sessionParamDTO) {
|
|
|
|
|
Set<String> emptyPathNodeIdSet = new HashSet<>();
|
|
|
|
|
Map<String, List<List<Condition>>> conditionPathMap = sessionParamDTO.getConditionPathMap();
|
|
|
|
|
for (Map.Entry<String, List<List<Condition>>> entry : conditionPathMap.entrySet()) {
|
|
|
|
|
if (CollUtil.isEmpty(entry.getValue())) {
|
|
|
|
|
emptyPathNodeIdSet.add(entry.getKey());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
Map<String, ItemLeaf> waitMatchItemLeafMap = sessionParamDTO.getWaitMatchItemLeafMap();
|
|
|
|
|
for (String nodeId : emptyPathNodeIdSet) {
|
|
|
|
|
if (emptyPathNodeIdSet.contains(nodeId)) {
|
|
|
|
|
waitMatchItemLeafMap.remove(nodeId);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* 进行比较操作
|
|
|
|
|
*/
|
|
|
|
|
private void conditionJudge(String question, Collection<String> candidateAnswerList, String userAnswer) {
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|