You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
81 lines
2.7 KiB
Java
81 lines
2.7 KiB
Java
package com.supervision.pdfqaserver.service;
|
|
|
|
import cn.hutool.core.collection.CollUtil;
|
|
import cn.hutool.core.lang.Assert;
|
|
import cn.hutool.core.util.StrUtil;
|
|
import com.supervision.pdfqaserver.domain.QuestionHandlerMapping;
|
|
import com.supervision.pdfqaserver.dto.TextVectorDTO;
|
|
import jakarta.annotation.PostConstruct;
|
|
import lombok.RequiredArgsConstructor;
|
|
import lombok.extern.slf4j.Slf4j;
|
|
import org.springframework.ai.embedding.Embedding;
|
|
import org.springframework.beans.factory.annotation.Autowired;
|
|
import org.springframework.beans.factory.annotation.Value;
|
|
import org.springframework.context.ApplicationContext;
|
|
import org.springframework.stereotype.Service;
|
|
|
|
import java.util.HashMap;
|
|
import java.util.List;
|
|
import java.util.Map;
|
|
|
|
/**
|
|
* 检索器调度器
|
|
*/
|
|
@Slf4j
|
|
@Service
|
|
@RequiredArgsConstructor
|
|
public class RetrieverDispatcher {
|
|
|
|
private final ApplicationContext applicationContext;
|
|
|
|
private final AiCallService aiCallService;
|
|
|
|
private final TextVectorService textVectorService;
|
|
|
|
private final QuestionHandlerMappingService questionHandlerMappingService;
|
|
|
|
@Value("${retriever.threshold:0.8}")
|
|
private double threshold; // 相似度阈值
|
|
|
|
private final Map<String, Retriever> retrieverMap = new HashMap<>();
|
|
|
|
|
|
/**
|
|
* 根据类型获取对应的检索器
|
|
*
|
|
* @param query 查询内容
|
|
* @return 检索器实例
|
|
*/
|
|
public Retriever mapping(String query) {
|
|
if (StrUtil.isEmpty(query)) {
|
|
log.warn("查询内容为空,无法获取检索器");
|
|
return null;
|
|
}
|
|
Embedding embedding = aiCallService.embedding(query);
|
|
|
|
List<TextVectorDTO> similarByCosine = textVectorService.findSimilarByCosine(embedding.getOutput(), threshold, 1);
|
|
if (CollUtil.isEmpty(similarByCosine)) {
|
|
log.info("问题:{},未找到相似文本向量,匹配阈值:{}", query, threshold);
|
|
return null;
|
|
}
|
|
TextVectorDTO textVectorDTO = CollUtil.getFirst(similarByCosine);
|
|
Assert.notEmpty(textVectorDTO.getCategoryId(), "相似文本向量的分类ID不能为空");
|
|
QuestionHandlerMapping handler = questionHandlerMappingService.findHandlerByCategoryId(textVectorDTO.getCategoryId());
|
|
if (handler == null){
|
|
return null;
|
|
}
|
|
return retrieverMap.get(handler.getHandler());
|
|
}
|
|
|
|
@PostConstruct
|
|
public void init() {
|
|
applicationContext.getBeansOfType(Retriever.class)
|
|
.forEach((name, retriever) -> {
|
|
if (retrieverMap.containsKey(name)) {
|
|
throw new IllegalArgumentException("Retriever with name " + name + " already exists.");
|
|
}
|
|
retrieverMap.put(name, retriever);
|
|
});
|
|
}
|
|
}
|