KBQA代码提交

main
liu 11 months ago
parent 31750e3f0a
commit 746b38e04c

@ -11,6 +11,8 @@ import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.bind.annotation.*;
import java.util.Set;
@Slf4j
@RestController
@RequestMapping("ask")
@ -36,7 +38,7 @@ public class AskController {
@ApiOperation("查询多轮对话中用户需要填写的参数")
@GetMapping("queryUserNeedParam")
public void queryUserNeedParam(String sessionId) {
public Set<String> queryUserNeedParam(String sessionId) {
return askService.queryUserNeedParam(sessionId);
}
}

@ -21,7 +21,7 @@ public class ConditionJudgeHandler {
public String conditionJudge(String question, Collection<String> candidateAnswerList, String userAnswer) {
List<MessageDTO> messageList = new ArrayList<>();
messageList.add(new MessageDTO("system", "你是一个条件判断模型且精通政务事项,我现在给你一个问题,给你候选答案列表,请你根据用户的实际回答,从候选答案列表中给我选择对应的候选答案.除了候选答案,什么其他的都不要说."));
messageList.add(new MessageDTO("system", "现在要做社会保障业务分类,我现在给你一个问题,给你候选答案列表,请你根据用户的实际回答,从候选答案列表中给我选择对应的候选答案.除了候选答案,什么其他的都不要说."));
messageList.add(new MessageDTO("assistant", "好的"));
messageList.add(new MessageDTO("user", StrUtil.format("问题:[{}]", question)));
messageList.add(new MessageDTO("assistant", "继续"));

@ -63,5 +63,36 @@ public class ItemExtractHandler {
return Collections.singletonList("退休");
}
public String itemExtractBusiness(String question) {
List<MessageDTO> messageList = new ArrayList<>();
messageList.add(new MessageDTO("user", "我现在是要进行实体抽取任务,并且精通社保业务,现在需要抽取[业务]这个实体类型的内容。\n" +
"\n" +
"下面是一些示例:\n" +
"\n" +
"输入:女性工人一般多少岁可以退休?\n" +
"输出:{\"business\":\"退休\"}\n" +
"输入:办理退休一般需要什么条件?\n" +
"输出:{\"business\":\"退休\"}\n" +
"输入:身份证挂失怎么办?\n" +
"输出:{\"business\":\"身份证挂失\"}\n" +
"输入:退休金没发,是什么原因?\n" +
"输出:{\"business\":\"退休金发放\"}\n" +
"\n" +
"现在有一句话,请进行抽取。\n" +
"输入:" + question + "\n" +
"输出:"));
log.info("itemExtractBusiness查询语句为:{}", JSONUtil.toJsonStr(messageList));
String item = AiUtil.chatByMessage(messageList);
log.info("itemExtractBusiness查询结果为:{}", item);
boolean typeJSON = JSONUtil.isTypeJSON(item);
if (typeJSON) {
String business = JSONUtil.parseObj(item).getStr("business");
if (StrUtil.isNotBlank(business)) {
return business;
}
}
throw new ItemExtractException("未从问题中找到业务事项");
}
}

@ -6,10 +6,13 @@ import com.supervision.vo.UserParamReqVO;
import org.springframework.web.bind.annotation.RequestBody;
import java.util.List;
import java.util.Set;
public interface AskService {
RoundTalkResVO roundTalk(RoundTalkReqVO roundTalkReqVO);
void saveUserParam(UserParamReqVO paramReqVO);
Set<String> queryUserNeedParam(String sessionId);
}

@ -33,6 +33,9 @@ import java.util.*;
import java.util.function.Function;
import java.util.stream.Collectors;
/**
*
*/
@Slf4j
@Service
@RequiredArgsConstructor
@ -59,6 +62,12 @@ public class AskServiceImpl implements AskService {
private static final String USER_PARAM = "KBQA:ASK:USER_PARAM:";
/**
*
*
* @param roundTalkReqVO
* @return
*/
@Override
public RoundTalkResVO roundTalk(RoundTalkReqVO roundTalkReqVO) {
String sessionId = roundTalkReqVO.getSessionId();
@ -84,11 +93,12 @@ public class AskServiceImpl implements AskService {
if (CollUtil.isEmpty(sessionParamDTO.getEntityValueByExtract())) {
// 识别实体(先从图中获取所有的节点名称,然后识别)
List<String> allItemNode = findItemNodeHandler.findAllItemNode();
List<String> extractValue = itemExtractHandler.itemExtractByPossibleItem(sessionParamDTO.getOriginalQuestion(), allItemNode);
sessionParamDTO.setEntityValueByExtract(extractValue);
//List<String> extractValue = itemExtractHandler.itemExtractByPossibleItem(sessionParamDTO.getOriginalQuestion(), allItemNode);
// 换了另外一种匹配方式
String extractValue = itemExtractHandler.itemExtractBusiness(sessionParamDTO.getOriginalQuestion());
sessionParamDTO.setEntityValueByExtract(Collections.singletonList(extractValue));
// 根据提取的内容,开始在知识图谱中寻找节点(首先找叶子节点,如果叶子节点有数据,直接返回,如果叶子节点没数据,再去找分支节点)
List<ItemLeaf> allMatchLeafNode = findItemNodeHandler.findAllMatchLeafNode(extractValue);
List<ItemLeaf> allMatchLeafNode = findItemNodeHandler.findAllMatchLeafNode(Collections.singletonList(extractValue));
// 如果找到的节点只有1个,那么说明问的就是这个节点,那么直接缓存起来进行下一步
if (allMatchLeafNode.size() == 1) {
ItemLeaf itemLeaf = allMatchLeafNode.get(0);
@ -258,7 +268,6 @@ public class AskServiceImpl implements AskService {
// 移除所有需要移除的路径
conditionPath.removeAll(toRemove);
}
System.out.println(1);
}
/**
@ -293,4 +302,15 @@ public class AskServiceImpl implements AskService {
}
}
@Override
public Set<String> queryUserNeedParam(String sessionId) {
// 首先根据session获取需要判断的条件
Object sessionCache = redisTemplate.opsForValue().get(SESSION_PARAM + sessionId);
SessionParamDTO sessionParamDTO = BeanUtil.toBean(sessionCache, SessionParamDTO.class);
// 然后获取里面的数据
return sessionParamDTO.getEntityCountMap().keySet();
}
}

@ -8,4 +8,5 @@ public class RoundTalkReqVO {
private String sessionId;
private String userTalk;
}

Loading…
Cancel
Save