You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

179 lines
7.3 KiB
Java

package com.supervision.service.impl;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.io.IoUtil;
import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.StrUtil;
import com.baomidou.mybatisplus.core.metadata.IPage;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.supervision.model.GeBytearray;
import com.supervision.model.KgInfo;
import com.supervision.qanything.QanythingService;
import com.supervision.qanything.dto.ChatResult;
import com.supervision.qanything.dto.ResultWrapper;
import com.supervision.qanything.dto.SourceDTO;
import com.supervision.qanything.dto.UploadResult;
import com.supervision.service.GeBytearrayService;
import com.supervision.service.KGService;
import com.supervision.service.KgInfoService;
import com.supervision.vo.kg.ChatReqVo;
import com.supervision.vo.kg.ChatResVo;
import com.supervision.vo.kg.SourceKgInfo;
import com.supervision.vo.kg.UploadDocResVo;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.web.multipart.MultipartFile;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.security.NoSuchAlgorithmException;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
/**
 * {@link KGService} implementation that answers questions through
 * {@link QanythingService} and persists uploaded documents through
 * {@link KgInfoService} and {@link GeBytearrayService}.
 */
@Slf4j
@Service
@RequiredArgsConstructor
public class KGServiceImpl implements KGService {

    /**
     * Knowledge-base id injected from the {@code youdao.qanthing.kbId} property.
     * NOTE(review): no method in this class reads this field; {@code chat} uses
     * hard-coded ids and both {@code uploadDoc} overloads take a {@code kbId}
     * parameter that shadows it.
     */
    @Value("${youdao.qanthing.kbId}")
    private String kbId;

    // Constructor-injected collaborators (declaration order defines the Lombok constructor signature).
    private final QanythingService qanythingService;

    private final KgInfoService kgInfoService;

    private final GeBytearrayService geBytearrayService;

    /**
     * Sends the question to Qanything and converts the answer into a {@link ChatResVo}.
     *
     * @param question the question text; rejected by {@code Assert.notEmpty} when empty
     * @return a success result carrying the response (plus source documents when Qanything
     *         returned any), or {@code ChatResVo.makeError()} when the call failed or
     *         produced no result
     * @throws NoSuchAlgorithmException propagated from the Qanything call
     */
    @Override
    public ChatResVo chat(String question) throws NoSuchAlgorithmException {
        Assert.notEmpty(question, "问题不能为空");

        // NOTE(review): the knowledge-base ids are hard-coded here; the injected kbId is not consulted.
        ArrayList<String> kbIds = CollUtil.newArrayList(
                "KB6ada48988e694ad197af53cc0d1e4b78_240328",
                "KB403a6543629648a3a74b60be5707398f_240328");

        ResultWrapper<ChatResult> wrapper = qanythingService.chat(question, kbIds);
        if (!wrapper.isSuccess()) {
            log.info("Qanything 聊天失败,原因:{}", wrapper.getMsg());
            return ChatResVo.makeError();
        }

        ChatResult result = wrapper.getResult();
        if (Objects.isNull(result)) {
            log.info("Qanything 聊天结果为空");
            return ChatResVo.makeError();
        }

        ChatResVo chatResVo = ChatResVo.makeSuccess(result.getResponse());
        // The source list is only set when Qanything actually returned sources.
        if (CollUtil.isNotEmpty(result.getSource())) {
            chatResVo.setSourceKgInfoList(qaSource2SourceKgInfo(result.getSource()));
        }
        return chatResVo;
    }

    /**
     * Runs {@link #chat(String)} for the request's question and, when the request carries a
     * non-blank label, keeps only the source documents whose label equals it.
     *
     * @param chatReqVo request holding the question and an optional label filter
     * @return the chat result, with its source list filtered by label when applicable
     * @throws NoSuchAlgorithmException propagated from {@link #chat(String)}
     */
    @Override
    public ChatResVo chat(ChatReqVo chatReqVo) throws NoSuchAlgorithmException {
        ChatResVo chatResVo = this.chat(chatReqVo.getQuestion());

        String label = chatReqVo.getLabel();
        boolean nothingToFilter = !chatResVo.isSuccess()
                || CollUtil.isEmpty(chatResVo.getSourceKgInfoList())
                || StrUtil.isBlank(label);
        if (nothingToFilter) {
            return chatResVo;
        }

        List<SourceKgInfo> matchingSources = chatResVo.getSourceKgInfoList().stream()
                .filter(sourceKgInfo -> label.equals(sourceKgInfo.getLabel()))
                .collect(Collectors.toList());
        chatResVo.setSourceKgInfoList(matchingSources);
        return chatResVo;
    }

    /**
     * Converts the sources returned by Qanything into {@link SourceKgInfo} entries.
     *
     * <p>Sources are de-duplicated by file id (sources without a file id are dropped).
     * Files that have a {@link KgInfo} row in the database are converted from that row and
     * listed first; the remaining files are converted directly from the Qanything source.
     *
     * @param sourceDTOList sources returned by Qanything; may be {@code null} or empty
     * @return the converted sources, never {@code null}
     */
    private List<SourceKgInfo> qaSource2SourceKgInfo(List<SourceDTO> sourceDTOList) {
        List<SourceKgInfo> sourceKgInfoList = new ArrayList<>();
        if (CollUtil.isEmpty(sourceDTOList)) {
            return sourceKgInfoList;
        }

        // De-duplicate by file id; entries with a null file id are excluded by the predicate.
        List<SourceDTO> distinctSources = sourceDTOList.stream()
                .filter(distinctPredicate(SourceDTO::getFileId))
                .collect(Collectors.toList());
        if (distinctSources.isEmpty()) {
            // Avoids issuing an IN () query with no values.
            return sourceKgInfoList;
        }

        // Qanything file id -> content. A plain put is used because Collectors.toMap
        // rejects null values and a source may carry no content.
        Map<String, String> contentByFileId = new HashMap<>();
        for (SourceDTO source : distinctSources) {
            contentByFileId.put(source.getFileId(), source.getContent());
        }
        Set<String> qaFileIds = new HashSet<>(contentByFileId.keySet());

        // File ids that are configured in the database and therefore take precedence.
        Set<String> preferFileIds = new HashSet<>();
        List<KgInfo> preferKgInfoList = kgInfoService.lambdaQuery()
                .in(KgInfo::getFileQaDocId, qaFileIds)
                .list();
        for (KgInfo kgInfo : preferKgInfoList) {
            preferFileIds.add(kgInfo.getFileQaDocId());
            // NOTE(review): the guard inspects getContent() but the fallback writes the summary;
            // kept as in the original — confirm whether getSummary() was intended.
            if (StrUtil.isEmpty(kgInfo.getContent())) {
                kgInfo.setSummary(contentByFileId.get(kgInfo.getFileQaDocId()));
            }
            SourceKgInfo preferred = SourceKgInfo.kgInfo2SourceKgInfo(kgInfo);
            // Database-backed entries without a summary are not returned.
            if (StrUtil.isNotEmpty(preferred.getSummary())) {
                sourceKgInfoList.add(preferred);
            }
        }

        // Files not configured in the database fall back to the data returned by Qanything.
        for (SourceDTO source : distinctSources) {
            if (!preferFileIds.contains(source.getFileId())) {
                sourceKgInfoList.add(SourceKgInfo.sourceDTO2SourceKgInfo(source));
            }
        }
        return sourceKgInfoList;
    }

    /**
     * Pages through the {@link KgInfo} rows whose hot flag equals {@code 1}.
     *
     * @param page paging parameters
     * @return the requested page of hot entries
     */
    @Override
    public IPage<KgInfo> hotKG(Page<KgInfo> page) {
        return kgInfoService.lambdaQuery()
                .eq(KgInfo::getHotFlag, 1)
                .page(page);
    }

    /**
     * Uploads the file to Qanything, then stores its bytes and the {@link KgInfo} record.
     *
     * <p>When Qanything reports a failure or returns no upload result, the error is logged
     * and the method returns without saving anything.
     * NOTE(review): callers cannot distinguish that case from success.
     *
     * @param kgInfo record to persist; receives the Qanything file id and the stored byte-array id
     * @param kbId   knowledge-base id passed to Qanything (shadows the injected field)
     * @param file   file to upload and store
     * @throws IOException if the file cannot be read
     */
    @Override
    @Transactional(rollbackFor = Exception.class)
    public void uploadDoc(KgInfo kgInfo, String kbId, File file) throws IOException {
        // Step 1: push the document to Qanything.
        try {
            ResultWrapper<List<UploadResult>> resultWrapper = qanythingService.uploadDoc(kbId, file);
            if (!resultWrapper.isSuccess() || CollUtil.isEmpty(resultWrapper.getResult())) {
                log.error("Qanything 上传失败,原因:{}", resultWrapper.getMsg());
                return;
            }
            List<UploadResult> uploadResults = resultWrapper.getResult();
            if (CollUtil.size(uploadResults) > 1) {
                log.warn("Qanything 上传成功,但是返回结果不止一个,返回结果:{}", uploadResults);
            }
            // Only the first upload result is used.
            kgInfo.setFileQaDocId(CollUtil.getFirst(uploadResults).getFileId());
        } catch (NoSuchAlgorithmException e) {
            log.error("Qanything 签名失败", e);
            throw new RuntimeException(e);
        }

        // Step 2: store the raw file bytes. Files.readAllBytes opens and closes the
        // underlying stream itself, so no stream is left open.
        GeBytearray geBytearray = new GeBytearray();
        geBytearray.setFileName(file.getName());
        geBytearray.setFileSize(file.length());
        geBytearray.setContent(Files.readAllBytes(file.toPath()));
        geBytearrayService.save(geBytearray);

        // Step 3: store the document record, linked to the stored bytes.
        kgInfo.setFileByteId(geBytearray.getId());
        kgInfoService.save(kgInfo);
    }

    /**
     * Not implemented: performs no upload and always returns {@code null}.
     * TODO: implement or remove; callers must currently handle a {@code null} result.
     *
     * @param kbId          unused
     * @param multipartFile unused
     * @return always {@code null}
     */
    @Override
    public UploadDocResVo uploadDoc(String kbId, MultipartFile multipartFile) {
        return null;
    }

    /**
     * Builds a stateful predicate that accepts an element only the first time its key is seen.
     * Elements whose key is {@code null} are always rejected.
     *
     * <p>Each call returns a predicate with its own set of seen keys, so a predicate
     * instance should be used for a single filtering pass.
     *
     * @param function extracts the de-duplication key from an element
     * @param <K>      element type
     * @return predicate that is {@code true} for the first element of each non-null key
     */
    public static <K> Predicate<K> distinctPredicate(Function<K, Object> function) {
        Set<Object> seenKeys = ConcurrentHashMap.newKeySet();
        return element -> {
            Object key = function.apply(element);
            // ConcurrentHashMap rejects null keys, so null must be handled before the lookup.
            return key != null && seenKeys.add(key);
        };
    }
}