KBQA代码提交

main
liu 11 months ago
parent 746b38e04c
commit f662e3b1db

@ -1,6 +1,7 @@
package com.supervision.handler.gpt;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.BooleanUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil;
import com.supervision.ai.AiUtil;
@ -8,9 +9,7 @@ import com.supervision.ai.dto.MessageDTO;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.*;
/**
* handler
@ -19,7 +18,15 @@ import java.util.List;
@Component
public class ConditionJudgeHandler {
public String conditionJudge(String question, Collection<String> candidateAnswerList, String userAnswer) {
/**
* ,
*
* @param question
* @param candidateAnswerList
* @param userAnswer
* @return
*/
public Set<String> conditionJudge(String question, Collection<String> candidateAnswerList, String userAnswer) {
List<MessageDTO> messageList = new ArrayList<>();
messageList.add(new MessageDTO("system", "现在要做社会保障业务分类,我现在给你一个问题,给你候选答案列表,请你根据用户的实际回答,从候选答案列表中给我选择对应的候选答案.除了候选答案,什么其他的都不要说."));
messageList.add(new MessageDTO("assistant", "好的"));
@ -29,8 +36,36 @@ public class ConditionJudgeHandler {
messageList.add(new MessageDTO("assistant", "继续"));
messageList.add(new MessageDTO("user", StrUtil.format("用户答案:[{}],现在请给我匹配的候选答案,其他什么都不要说.如果有多个候选答案,用;号分割", userAnswer)));
log.info("conditionJudge判断候选答案:{}", JSONUtil.toJsonStr(messageList));
String answer = AiUtil.chatByMessage(messageList);
log.info("conditionJudge判断结果是:{}", answer);
return answer;
String judgeResult = AiUtil.chatByMessage(messageList);
log.info("conditionJudge判断结果是:{}", judgeResult);
return new HashSet<>(Arrays.asList(judgeResult.split(";")));
}
/**
*
*
* @param question
* @param candidateAnswerList
* @param userAnswer
*/
public Set<String> newConditionJudge(String question, Collection<String> candidateAnswerList, String userAnswer, String conditionType) {
Set<String> judgeResultSet = new HashSet<>();
String template = "当我问用户:{},用户给我的回答是:[{}]\n" +
"基于用户的回答,判断一下用户{}是否满足[{}]满足就只回复true反之只回复false";
for (String candidateAnswer : candidateAnswerList) {
String judgeResult = StrUtil.format(template, question, userAnswer, conditionType, candidateAnswer);
String answer = AiUtil.chat(judgeResult);
log.info("conditionJudge判断条件:\n{},\n结果是:{}", judgeResult, answer);
try {
if (BooleanUtil.toBoolean(answer)) {
judgeResultSet.add(candidateAnswer);
}
} catch (Exception e) {
log.info("{}非布尔类型,不统计在内", answer);
}
}
return judgeResultSet;
}
}

@ -44,9 +44,14 @@ public class IdentifyIntentHandler {
String intent = AiUtil.chatByMessage(messageList);
log.info("identifyIntent意图识别结果为:{}", intent);
// 尝试转为JSON的形式
if (StrUtil.isBlank(intent) || StrUtil.equals("未识别", intent)) {
if (StrUtil.isBlank(intent) || StrUtil.equals("未识别", intent) || intent.contains("未识别")) {
throw new IdentifyIntentException("意图未识别");
}
return intent;
for (IdentifyIntentEnum value : IdentifyIntentEnum.values()) {
if (intent.contains(value.getIntent())){
return value.getIntent();
}
}
throw new IdentifyIntentException("意图未识别");
}
}

@ -164,9 +164,10 @@ public class AskServiceImpl implements AskService {
}
}
}
String judgeResult = conditionJudgeHandler.conditionJudge(sessionParamDTO.getCurrentEntity().getCurrentQuestion(), possibleAnswerSet, roundTalkReqVO.getUserTalk());
log.info("GPT判断结果:{}", judgeResult);
Set<String> judgeResultSet = new HashSet<>(Arrays.asList(judgeResult.split(";")));
Set<String> judgeResultSet = conditionJudgeHandler.newConditionJudge(sessionParamDTO.getCurrentEntity().getCurrentQuestion(),
possibleAnswerSet,
roundTalkReqVO.getUserTalk(),
sessionParamDTO.getCurrentEntity().getCurrentEntityType());
// 筛选路径,如果某个路径的结果不在比较结果中,说明这个结果不对,排除这个路径
pathFilterByJudgeResult(sessionParamDTO.getCurrentEntity().getCurrentEntityType(), judgeResultSet, sessionParamDTO);
filterNotMatchNode(sessionParamDTO);

@ -0,0 +1,23 @@
package com.supervision;
import cn.hutool.core.collection.CollUtil;
import com.supervision.handler.gpt.ConditionJudgeHandler;
import lombok.extern.slf4j.Slf4j;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
import java.util.ArrayList;
@Slf4j
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
@RunWith(SpringJUnit4ClassRunner.class)
public class GlmTest {
@Autowired
private ConditionJudgeHandler conditionJudgeHandler;
}
Loading…
Cancel
Save