You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

179 lines
7.3 KiB
Java

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package com.supervision.service.impl;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.io.IoUtil;
import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.StrUtil;
import com.baomidou.mybatisplus.core.metadata.IPage;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.supervision.model.GeBytearray;
import com.supervision.model.KgInfo;
import com.supervision.qanything.QanythingService;
import com.supervision.qanything.dto.ChatResult;
import com.supervision.qanything.dto.ResultWrapper;
import com.supervision.qanything.dto.SourceDTO;
import com.supervision.qanything.dto.UploadResult;
import com.supervision.service.GeBytearrayService;
import com.supervision.service.KGService;
import com.supervision.service.KgInfoService;
import com.supervision.vo.kg.ChatReqVo;
import com.supervision.vo.kg.ChatResVo;
import com.supervision.vo.kg.SourceKgInfo;
import com.supervision.vo.kg.UploadDocResVo;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.web.multipart.MultipartFile;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.security.NoSuchAlgorithmException;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
@Slf4j
@Service
@RequiredArgsConstructor
public class KGServiceImpl implements KGService {

    /**
     * Knowledge-base ids queried by {@link #chat(String)}. A fresh mutable list is built from
     * this array on every call so callers of the QAnything client never share state.
     * NOTE(review): these ids are hard-coded and the configured {@code kbId} field is not used
     * by chat — confirm whether chat should read the configured value instead.
     */
    private static final String[] CHAT_KB_IDS = {
            "KB6ada48988e694ad197af53cc0d1e4b78_240328",
            "KB403a6543629648a3a74b60be5707398f_240328"
    };

    /** Value of {@code KgInfo.hotFlag} that marks an entry as "hot" for {@link #hotKG(Page)}. */
    private static final int HOT_FLAG = 1;

    /**
     * Knowledge-base id injected from {@code youdao.qanthing.kbId}.
     * NOTE(review): no method in this class reads it ({@code uploadDoc} takes its own
     * {@code kbId} parameter, which shadows this field) — confirm it is still needed.
     */
    @Value("${youdao.qanthing.kbId}")
    private String kbId;

    private final QanythingService qanythingService;
    private final KgInfoService kgInfoService;
    private final GeBytearrayService geBytearrayService;

    /**
     * Asks QAnything a question against the {@link #CHAT_KB_IDS} knowledge bases.
     *
     * @param question the question text; must not be empty
     * @return a success result carrying the answer and, when QAnything reports any, the
     *         de-duplicated source documents; an error result when the call fails or
     *         returns no result
     * @throws IllegalArgumentException if {@code question} is empty (hutool {@code Assert})
     * @throws NoSuchAlgorithmException propagated from the QAnything client
     */
    @Override
    public ChatResVo chat(String question) throws NoSuchAlgorithmException {
        Assert.notEmpty(question, "问题不能为空");
        ArrayList<String> kbIds = CollUtil.newArrayList(CHAT_KB_IDS);
        ResultWrapper<ChatResult> response = qanythingService.chat(question, kbIds);
        if (!response.isSuccess()) {
            log.info("Qanything 聊天失败,原因:{}", response.getMsg());
            return ChatResVo.makeError();
        }
        ChatResult result = response.getResult();
        if (Objects.isNull(result)) {
            log.info("Qanything 聊天结果为空");
            return ChatResVo.makeError();
        }
        ChatResVo chatResVo = ChatResVo.makeSuccess(result.getResponse());
        if (CollUtil.isNotEmpty(result.getSource())) {
            chatResVo.setSourceKgInfoList(qaSource2SourceKgInfo(result.getSource()));
        }
        return chatResVo;
    }

    /**
     * Same as {@link #chat(String)}, but when a label is supplied only the source documents
     * whose label equals it are kept.
     *
     * @param chatReqVo request carrying the question and an optional label filter
     * @return the chat result, with its source list filtered by label when applicable
     * @throws NoSuchAlgorithmException propagated from {@link #chat(String)}
     */
    @Override
    public ChatResVo chat(ChatReqVo chatReqVo) throws NoSuchAlgorithmException {
        ChatResVo chat = this.chat(chatReqVo.getQuestion());
        // Nothing to filter: failed call, no sources, or no label requested.
        if (!chat.isSuccess()
                || CollUtil.isEmpty(chat.getSourceKgInfoList())
                || StrUtil.isBlank(chatReqVo.getLabel())) {
            return chat;
        }
        List<SourceKgInfo> sourceKgInfoList = chat.getSourceKgInfoList().stream()
                .filter(sourceKgInfo -> chatReqVo.getLabel().equals(sourceKgInfo.getLabel()))
                .collect(Collectors.toList());
        chat.setSourceKgInfoList(sourceKgInfoList);
        return chat;
    }

    /**
     * Converts the sources returned by QAnything into {@link SourceKgInfo}s.
     *
     * <p>Sources are de-duplicated by file id (first occurrence wins), and sources without a
     * file id are dropped. For each remaining file id, a matching {@link KgInfo} row in the
     * database is preferred. Entries from those rows come first, and rows whose resulting
     * summary is empty are skipped. Sources whose file id matched no database row fall back
     * to the QAnything data, in their original order.
     *
     * @param sourceDTOList sources returned by QAnything; may be {@code null} or empty
     * @return a new mutable list, never {@code null}
     */
    private List<SourceKgInfo> qaSource2SourceKgInfo(List<SourceDTO> sourceDTOList) {
        List<SourceKgInfo> sourceKgInfoList = CollUtil.newArrayList();
        if (CollUtil.isEmpty(sourceDTOList)) {
            return sourceKgInfoList;
        }
        List<SourceDTO> distinctSources = sourceDTOList.stream()
                .filter(distinctPredicate(SourceDTO::getFileId))
                .collect(Collectors.toList());

        // QAnything file id -> chunk content. Built by hand because Collectors.toMap throws
        // on null values, and a source's content may be null.
        Map<String, String> fileIdMapContent = new HashMap<>();
        for (SourceDTO source : distinctSources) {
            fileIdMapContent.put(source.getFileId(), source.getContent());
        }

        // Ids that matched a database row; those sources are not used as a fallback.
        Set<String> preferFileIds = new HashSet<>();
        // Guard: an empty IN (...) list must not reach the database query.
        if (!fileIdMapContent.isEmpty()) {
            List<KgInfo> preferKgInfoList = kgInfoService.lambdaQuery()
                    .in(KgInfo::getFileQaDocId, fileIdMapContent.keySet())
                    .list();
            for (KgInfo kgInfo : preferKgInfoList) {
                preferFileIds.add(kgInfo.getFileQaDocId());
                // NOTE(review): the check reads getContent() but the write targets setSummary();
                // confirm this asymmetry is intended.
                if (StrUtil.isEmpty(kgInfo.getContent())) {
                    kgInfo.setSummary(fileIdMapContent.get(kgInfo.getFileQaDocId()));
                }
                SourceKgInfo sourceKgInfo = SourceKgInfo.kgInfo2SourceKgInfo(kgInfo);
                if (StrUtil.isNotEmpty(sourceKgInfo.getSummary())) {
                    sourceKgInfoList.add(sourceKgInfo);
                }
            }
        }

        // Fallback: use the QAnything data for files with no database row.
        for (SourceDTO source : distinctSources) {
            if (!preferFileIds.contains(source.getFileId())) {
                sourceKgInfoList.add(SourceKgInfo.sourceDTO2SourceKgInfo(source));
            }
        }
        return sourceKgInfoList;
    }

    /**
     * Pages through the knowledge entries flagged as hot.
     *
     * @param page paging request
     * @return the requested page of {@link KgInfo} rows whose hot flag equals {@link #HOT_FLAG}
     */
    @Override
    public IPage<KgInfo> hotKG(Page<KgInfo> page) {
        return kgInfoService.lambdaQuery().eq(KgInfo::getHotFlag, HOT_FLAG).page(page);
    }

    /**
     * Uploads {@code file} to QAnything, then stores its bytes and the {@code kgInfo} record.
     *
     * <p>If QAnything reports a failure or an empty result, the failure is logged and the
     * method returns without saving anything.
     * NOTE(review): the caller gets no signal in that case. Also, the QAnything upload is
     * presumably not undone by a database rollback — confirm whether orphaned documents
     * need cleanup.
     *
     * @param kgInfo knowledge entry to persist; its QAnything doc id and byte-array id are set here
     * @param kbId   QAnything knowledge-base id to upload into
     * @param file   file to upload and store
     * @throws IOException      if the file cannot be read
     * @throws RuntimeException wrapping {@link NoSuchAlgorithmException} from the QAnything client
     */
    @Override
    @Transactional(rollbackFor = Exception.class)
    public void uploadDoc(KgInfo kgInfo, String kbId, File file) throws IOException {
        // 1. Upload to QAnything and record the document id it assigns.
        try {
            ResultWrapper<List<UploadResult>> resultWrapper = qanythingService.uploadDoc(kbId, file);
            if (!resultWrapper.isSuccess() || CollUtil.isEmpty(resultWrapper.getResult())) {
                log.error("Qanything 上传失败,原因:{}", resultWrapper.getMsg());
                return;
            }
            List<UploadResult> result = resultWrapper.getResult();
            if (CollUtil.size(result) > 1) {
                log.warn("Qanything 上传成功,但是返回结果不止一个,返回结果:{}", result);
            }
            kgInfo.setFileQaDocId(CollUtil.getFirst(result).getFileId());
        } catch (NoSuchAlgorithmException e) {
            log.error("Qanything 签名失败", e);
            throw new RuntimeException(e);
        }
        // 2. Store the raw file bytes. Files.readAllBytes opens and closes the file itself,
        //    so no stream is left open.
        GeBytearray geBytearray = new GeBytearray();
        geBytearray.setFileName(file.getName());
        geBytearray.setFileSize(file.length());
        geBytearray.setContent(Files.readAllBytes(file.toPath()));
        geBytearrayService.save(geBytearray);
        // 3. Store the knowledge entry, linked to the stored bytes.
        kgInfo.setFileByteId(geBytearray.getId());
        kgInfoService.save(kgInfo);
    }

    /**
     * Multipart upload entry point.
     * TODO: not implemented — this method always returns {@code null}.
     *
     * @param kbId          QAnything knowledge-base id
     * @param multipartFile uploaded file
     * @return always {@code null} at present
     */
    @Override
    public UploadDocResVo uploadDoc(String kbId, MultipartFile multipartFile) {
        return null;
    }

    /**
     * Returns a stateful predicate that accepts an element only the first time its key is
     * seen, and rejects elements whose key is {@code null}.
     *
     * <p>A null key must be filtered explicitly: {@link ConcurrentHashMap} does not accept
     * null keys and would otherwise throw {@link NullPointerException}. The predicate keeps
     * internal state, so create a new one for each stream.
     *
     * @param keyExtractor maps an element to its de-duplication key
     * @param <K>          element type
     * @return a predicate that is {@code true} for the first element with each non-null key
     */
    public static <K> Predicate<K> distinctPredicate(Function<? super K, ?> keyExtractor) {
        Set<Object> seen = ConcurrentHashMap.newKeySet();
        return t -> {
            Object key = keyExtractor.apply(t);
            return key != null && seen.add(key);
        };
    }
}