package com.supervision.service.impl; import cn.hutool.core.collection.CollUtil; import cn.hutool.core.io.IoUtil; import cn.hutool.core.lang.Assert; import cn.hutool.core.util.StrUtil; import com.baomidou.mybatisplus.core.metadata.IPage; import com.baomidou.mybatisplus.extension.plugins.pagination.Page; import com.supervision.model.GeBytearray; import com.supervision.model.KgInfo; import com.supervision.qanything.QanythingService; import com.supervision.qanything.dto.ChatResult; import com.supervision.qanything.dto.ResultWrapper; import com.supervision.qanything.dto.SourceDTO; import com.supervision.qanything.dto.UploadResult; import com.supervision.service.GeBytearrayService; import com.supervision.service.KGService; import com.supervision.service.KgInfoService; import com.supervision.vo.kg.ChatReqVo; import com.supervision.vo.kg.ChatResVo; import com.supervision.vo.kg.SourceKgInfo; import com.supervision.vo.kg.UploadDocResVo; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; import org.springframework.web.multipart.MultipartFile; import java.io.File; import java.io.IOException; import java.nio.file.Files; import java.security.NoSuchAlgorithmException; import java.util.*; import java.util.concurrent.ConcurrentHashMap; import java.util.function.Function; import java.util.function.Predicate; import java.util.stream.Collectors; @Slf4j @Service @RequiredArgsConstructor public class KGServiceImpl implements KGService { @Value("${youdao.qanthing.kbId}") private String kbId; private final QanythingService qanythingService; private final KgInfoService kgInfoService; private final GeBytearrayService geBytearrayService; @Override public ChatResVo chat(String question) throws NoSuchAlgorithmException { Assert.notEmpty(question, "问题不能为空"); ArrayList kbIds = CollUtil.newArrayList("KB6ada48988e694ad197af53cc0d1e4b78_240328", "KB403a6543629648a3a74b60be5707398f_240328"); ResultWrapper chat = qanythingService.chat(question, kbIds); if (!chat.isSuccess()) { log.info("Qanything 聊天失败,原因:{}",chat.getMsg()); return ChatResVo.makeError(); } ChatResult result = chat.getResult(); if (Objects.isNull(result)){ log.info("Qanything 聊天结果为空"); return ChatResVo.makeError(); } ChatResVo chatResVo = ChatResVo.makeSuccess(result.getResponse()); if (CollUtil.isEmpty(result.getSource())){ return chatResVo; } chatResVo.setSourceKgInfoList(qaSource2SourceKgInfo(result.getSource())); return chatResVo; } @Override public ChatResVo chat(ChatReqVo chatReqVo) throws NoSuchAlgorithmException { ChatResVo chat = this.chat(chatReqVo.getQuestion()); if (!chat.isSuccess() ||CollUtil.isEmpty(chat.getSourceKgInfoList()) || StrUtil.isBlank(chatReqVo.getLabel())){ return chat; } List sourceKgInfoList = chat.getSourceKgInfoList().stream() .filter(sourceKgInfo -> chatReqVo.getLabel().equals(sourceKgInfo.getLabel())).collect(Collectors.toList()); chat.setSourceKgInfoList(sourceKgInfoList); return chat; } private List qaSource2SourceKgInfo(List sourceDTOList){ List sourceKgInfoList = CollUtil.newArrayList(); if (CollUtil.isEmpty(sourceDTOList)){ return sourceKgInfoList; } // 根据fileId进行去重 sourceDTOList = sourceDTOList.stream().filter(distinctPredicate(SourceDTO::getFileId)).collect(Collectors.toList()); // qanything 中文件id与文件内容的映射 Map fileIdMapContent = sourceDTOList.stream().collect(Collectors.toMap(SourceDTO::getFileId, SourceDTO::getContent, (o, n) -> o)); Set qaFileId = sourceDTOList.stream().map(SourceDTO::getFileId).filter(Objects::nonNull).collect(Collectors.toSet()); // 优先选择的文档id Set preferFileIds = new HashSet<>(); if (CollUtil.isNotEmpty(qaFileId)) { // 优先使用数据库中配置的信息 List preferKgInfoList = kgInfoService.lambdaQuery().in(KgInfo::getFileQaDocId, qaFileId).list(); List sourceKgInfos = preferKgInfoList.stream().peek(kgInfo->preferFileIds.add(kgInfo.getFileQaDocId())) .peek(kgInfo->{ if (StrUtil.isEmpty(kgInfo.getContent())){ kgInfo.setSummary(fileIdMapContent.get(kgInfo.getFileQaDocId())); } }).map(SourceKgInfo::kgInfo2SourceKgInfo) .filter(source-> StrUtil.isNotEmpty(source.getSummary())).collect(Collectors.toList()); sourceKgInfoList.addAll(sourceKgInfos); } // 数据库中未配置信息,则是使用Qanything返回的信息 List candidateKgInfoList = sourceDTOList.stream().filter(source -> !preferFileIds.contains(source.getFileId())) .map(SourceKgInfo::sourceDTO2SourceKgInfo).collect(Collectors.toList()); sourceKgInfoList.addAll(candidateKgInfoList); return sourceKgInfoList; } @Override public IPage hotKG(Page page) { return kgInfoService.lambdaQuery().eq(KgInfo::getHotFlag,1).page(page); } @Override @Transactional(rollbackFor = Exception.class) public void uploadDoc(KgInfo kgInfo,String kbId, File file) throws IOException { // 保存到qanthing中 try { ResultWrapper> resultWrapper = qanythingService.uploadDoc(kbId, file); if (!resultWrapper.isSuccess() || CollUtil.isEmpty(resultWrapper.getResult())){ log.error("Qanything 上传失败,原因:{}",resultWrapper.getMsg()); return; } List result = resultWrapper.getResult(); if (CollUtil.size(result)>1){ log.warn("Qanything 上传成功,但是返回结果不止一个,返回结果:{}",result); } kgInfo.setFileQaDocId(CollUtil.getFirst(result).getFileId()); } catch (NoSuchAlgorithmException e) { log.error("Qanything 签名失败",e); throw new RuntimeException(e); } //保存到数据库中 GeBytearray geBytearray = new GeBytearray(); geBytearray.setFileName(file.getName()); geBytearray.setFileSize(file.length()); geBytearray.setContent(IoUtil.readBytes(Files.newInputStream(file.toPath()))); geBytearrayService.save(geBytearray); // 保存文件信息 kgInfo.setFileByteId(geBytearray.getId()); kgInfoService.save(kgInfo); } @Override public UploadDocResVo uploadDoc(String kbId, MultipartFile multipartFile) { return null; } public static Predicate distinctPredicate(Function function) { ConcurrentHashMap map = new ConcurrentHashMap<>(); // 根据key进行去重,并排除为null的数据 return t -> null == map.putIfAbsent(function.apply(t), true); } }