You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

81 lines
2.7 KiB
Java

package com.supervision.pdfqaserver.service;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.StrUtil;
import com.supervision.pdfqaserver.domain.QuestionHandlerMapping;
import com.supervision.pdfqaserver.dto.TextVectorDTO;
import jakarta.annotation.PostConstruct;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.embedding.Embedding;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.ApplicationContext;
import org.springframework.stereotype.Service;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
 * Retriever dispatcher.
 *
 * <p>Chooses the {@link Retriever} that should handle a question: the question is embedded,
 * the closest stored text vector is looked up, that vector's category id is resolved to a
 * {@link QuestionHandlerMapping}, and the mapping's handler value is used as the key into the
 * registry of retriever beans built in {@link #init()}.
 */
@Slf4j
@Service
@RequiredArgsConstructor
public class RetrieverDispatcher {

    private final ApplicationContext applicationContext;

    private final AiCallService aiCallService;

    private final TextVectorService textVectorService;

    private final QuestionHandlerMappingService questionHandlerMappingService;

    /**
     * Similarity threshold handed to the vector search; falls back to 0.8 when the
     * {@code retriever.threshold} property is not configured.
     */
    @Value("${retriever.threshold:0.8}")
    private double threshold;

    /** Retriever beans keyed by the name reported by the application context; filled in {@link #init()}. */
    private final Map<String, Retriever> retrieverMap = new HashMap<>();

    /**
     * Resolves the retriever responsible for the given question.
     *
     * @param query the question text
     * @return the matching retriever, or {@code null} when the query is blank, no similar text
     *         vector is found, the category has no handler mapping, or the mapped handler name
     *         is not a registered retriever
     * @throws IllegalArgumentException presumably raised by Hutool's {@code Assert.notEmpty}
     *         when the best-matching vector has an empty category id
     */
    public Retriever mapping(String query) {
        // isBlank (not isEmpty): a whitespace-only question carries no content worth embedding.
        if (StrUtil.isBlank(query)) {
            log.warn("查询内容为空,无法获取检索器");
            return null;
        }

        Embedding embedding = aiCallService.embedding(query);
        // The trailing argument 1 presumably limits the search to the single best match.
        List<TextVectorDTO> similarByCosine =
                textVectorService.findSimilarByCosine(embedding.getOutput(), threshold, 1);
        if (CollUtil.isEmpty(similarByCosine)) {
            log.info("问题:{},未找到相似文本向量,匹配阈值:{}", query, threshold);
            return null;
        }

        TextVectorDTO textVectorDTO = CollUtil.getFirst(similarByCosine);
        Assert.notEmpty(textVectorDTO.getCategoryId(), "相似文本向量的分类ID不能为空");

        QuestionHandlerMapping handler =
                questionHandlerMappingService.findHandlerByCategoryId(textVectorDTO.getCategoryId());
        if (handler == null) {
            // Log the miss so a missing mapping can be told apart from "nothing similar found".
            log.warn("问题:{},分类ID:{},未找到对应的处理器映射", query, textVectorDTO.getCategoryId());
            return null;
        }

        Retriever retriever = retrieverMap.get(handler.getHandler());
        if (retriever == null) {
            // The mapping names a handler that is not among the registered retriever beans.
            log.warn("问题:{},处理器:{},未找到已注册的检索器", query, handler.getHandler());
        }
        return retriever;
    }

    /**
     * Registers every {@link Retriever} bean from the application context under its bean name.
     *
     * @throws IllegalArgumentException if a name is already present in the registry
     */
    @PostConstruct
    public void init() {
        applicationContext.getBeansOfType(Retriever.class)
                .forEach((name, retriever) -> {
                    if (retrieverMap.containsKey(name)) {
                        throw new IllegalArgumentException("Retriever with name " + name + " already exists.");
                    }
                    retrieverMap.put(name, retriever);
                });
    }
}