Merge remote-tracking branch 'origin/main'

main
xueqingkun 11 months ago
commit df31d7fb2e

@ -2,6 +2,7 @@ package com.supervision.handler.gpt;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil;
import com.supervision.ai.AiUtil;
import com.supervision.ai.dto.MessageDTO;
import lombok.extern.slf4j.Slf4j;
@ -20,13 +21,16 @@ public class ConditionJudgeHandler {
public String conditionJudge(String question, Collection<String> candidateAnswerList, String userAnswer) {
List<MessageDTO> messageList = new ArrayList<>();
messageList.add(new MessageDTO("system", "你是一个条件判断模型且精通政务事项,我现在给你一个问题,给你候选答案,请你根据用户的实际回答,从候选答案中给我选择一个对应的候选答案.除了候选答案,什么其他的都不要说."));
messageList.add(new MessageDTO("system", "你是一个条件判断模型且精通政务事项,我现在给你一个问题,给你候选答案列表,请你根据用户的实际回答,从候选答案列表中给我选择对应的候选答案.除了候选答案,什么其他的都不要说."));
messageList.add(new MessageDTO("assistant", "好的"));
messageList.add(new MessageDTO("user", StrUtil.format("问题:[{}]", question)));
messageList.add(new MessageDTO("assistant", "继续"));
messageList.add(new MessageDTO("user", StrUtil.format("候选答案:[{};未找到匹配答案]", StrUtil.join(";", candidateAnswerList))));
messageList.add(new MessageDTO("user", StrUtil.format("候选答案列表:[{};未找到匹配答案]", StrUtil.join(";", candidateAnswerList))));
messageList.add(new MessageDTO("assistant", "继续"));
messageList.add(new MessageDTO("user", StrUtil.format("用户答案:[{}],现在请给我匹配的候选答案,其他什么都不要说.如果有多个候选答案,用;号分割", userAnswer)));
return AiUtil.chatByMessage(messageList);
log.info("conditionJudge判断候选答案:{}", JSONUtil.toJsonStr(messageList));
String answer = AiUtil.chatByMessage(messageList);
log.info("conditionJudge判断结果是:{}", answer);
return answer;
}
}

@ -145,20 +145,24 @@ public class AskServiceImpl implements AskService {
}
}
String judgeResult = conditionJudgeHandler.conditionJudge(sessionParamDTO.getCurrentEntity().getCurrentQuestion(), possibleAnswerSet, roundTalkReqVO.getUserTalk());
log.info("GPT判断结果:{}", judgeResult);
Set<String> judgeResultSet = new HashSet<>(Arrays.asList(judgeResult.split(";")));
// 筛选路径,如果某个路径的结果不在比较结果中,说明这个结果不对,排除这个路径
pathFilterByJudgeResult(sessionParamDTO.getCurrentEntity().getCurrentEntityType(), judgeResultSet, sessionParamDTO);
filterNotMatchNode(sessionParamDTO);
// 加到已匹配的项目
// 加到已匹配的实体类型,下次不再匹配
sessionParamDTO.getAlreadyMatchEntitySet().add(sessionParamDTO.getCurrentEntity().getCurrentEntityType());
redisTemplate.opsForValue().set(SESSION_PARAM + sessionId, sessionParamDTO);
// 如果排除后只剩一个了,这时跳出多轮问答
if (sessionParamDTO.getWaitMatchItemLeafMap().size() == 1) {
sessionParamDTO.setMatchItemLeaf(sessionParamDTO.getWaitMatchItemLeafMap().values().iterator().next());
redisTemplate.opsForValue().set(SESSION_PARAM + sessionId, sessionParamDTO);
break match;
}
}
// 首先获取出现次数最多的实体类型
String mostFrequentType = sessionParamDTO.getEntityCountMap().entrySet().stream()
Map<String, Integer> newCountMap = countCondition(sessionParamDTO);
String mostFrequentType = newCountMap.entrySet().stream()
.filter(entry -> !sessionParamDTO.getAlreadyMatchEntitySet().contains(entry.getKey()))
.max(Map.Entry.comparingByValue(Integer::compareTo))
.map(Map.Entry::getKey).orElseThrow(() -> new BusinessException("未找到条件判断路径"));
@ -167,7 +171,7 @@ public class AskServiceImpl implements AskService {
Optional.ofNullable(question).orElseThrow(() -> new BusinessException("未找到条件判断路径"));
sessionParamDTO.setCurrentEntity(EntityQuestionDTO.builder().currentEntityType(mostFrequentType).currentQuestion(question).build());
redisTemplate.opsForValue().set(SESSION_PARAM + sessionId, sessionParamDTO);
// 返回这个问题给前端
// 返回这个问题给前端(触发了多轮问法)
return RoundTalkResVO.builder().sessionId(sessionId).replyQuestion(question).build();
} else {
// 获取到唯一节点
@ -178,7 +182,27 @@ public class AskServiceImpl implements AskService {
// 走到这里,说明就只有一个节点了,那么就可以进行下一步了
log.info("走到这里,说明找到了匹配的节点");
return null;
return RoundTalkResVO.builder().sessionId(sessionId).replyQuestion("找到了匹配的节点").build();
}
private Map<String, Integer> countCondition(SessionParamDTO sessionParamDTO) {
Map<String, List<List<Condition>>> conditionPathMap = sessionParamDTO.getConditionPathMap();
// 所有的实体类型以及出现次数计数
Map<String, Integer> entityCountMap = new HashMap<>();
for (Map.Entry<String, List<List<Condition>>> entry : conditionPathMap.entrySet()) {
// 然后根据条件进行计数
for (List<Condition> conditions : entry.getValue()) {
for (Condition condition : conditions) {
// 如果不存在,就添加进计数.如果存在,就+1
entityCountMap.compute(condition.getEntityType(), (k, v) -> v == null ? 1 : v + 1);
}
}
}
sessionParamDTO.setEntityCountMap(entityCountMap);
redisTemplate.opsForValue().set(SESSION_PARAM + sessionParamDTO.getSessionId(), sessionParamDTO);
return entityCountMap;
}
/**
@ -188,18 +212,28 @@ public class AskServiceImpl implements AskService {
Map<String, List<List<Condition>>> conditionPathMap = sessionParamDTO.getConditionPathMap();
for (Map.Entry<String, List<List<Condition>>> entry : conditionPathMap.entrySet()) {
List<List<Condition>> conditionPath = entry.getValue();
// 遍历所有的路径,找到回答
for (int i = 0; i < conditionPath.size(); i++) {
List<Condition> conditions = conditionPath.get(i);
// 遍历所有的条件,找到回答
List<List<Condition>> toRemove = new ArrayList<>();
// 遍历所有的路径,找到需要移除的路径
for (List<Condition> conditions : conditionPath) {
// 遍历所有的条件,判断是否需要移除该路径
boolean shouldRemove = false;
for (Condition condition : conditions) {
// 如果当前对话实体和条件中的实体类型相同,且不在比较结果中,说明这个结果不对,排除这个路径
if (entityType.equals(condition.getEntityType()) && !judgeResultSet.contains(condition.getCondition())) {
conditionPath.remove(i);
shouldRemove = true;
break;
}
}
if (shouldRemove) {
toRemove.add(conditions);
}
}
// 移除所有需要移除的路径
conditionPath.removeAll(toRemove);
}
System.out.println(1);
}
/**
@ -208,11 +242,16 @@ public class AskServiceImpl implements AskService {
private void filterNotMatchNode(SessionParamDTO sessionParamDTO) {
Set<String> emptyPathNodeIdSet = new HashSet<>();
Map<String, List<List<Condition>>> conditionPathMap = sessionParamDTO.getConditionPathMap();
for (Map.Entry<String, List<List<Condition>>> entry : conditionPathMap.entrySet()) {
Iterator<Map.Entry<String, List<List<Condition>>>> iterator = conditionPathMap.entrySet().iterator();
while (iterator.hasNext()) {
Map.Entry<String, List<List<Condition>>> entry = iterator.next();
if (CollUtil.isEmpty(entry.getValue())) {
emptyPathNodeIdSet.add(entry.getKey());
iterator.remove();
}
}
Map<String, ItemLeaf> waitMatchItemLeafMap = sessionParamDTO.getWaitMatchItemLeafMap();
for (String nodeId : emptyPathNodeIdSet) {
if (emptyPathNodeIdSet.contains(nodeId)) {
@ -220,11 +259,4 @@ public class AskServiceImpl implements AskService {
}
}
}
/**
*
*/
private void conditionJudge(String question, Collection<String> candidateAnswerList, String userAnswer) {
}
}

Loading…
Cancel
Save