You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

130 lines
5.8 KiB
Java

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package com.supervision.pdfqaserver.service.impl;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.date.DateUtil;
import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONArray;
import cn.hutool.json.JSONUtil;
import com.supervision.pdfqaserver.cache.PromptCache;
import com.supervision.pdfqaserver.dto.CypherSchemaDTO;
import com.supervision.pdfqaserver.dto.KeywordSynonymDTO;
import com.supervision.pdfqaserver.dto.TextTerm;
import com.supervision.pdfqaserver.service.*;
import jakarta.annotation.PostConstruct;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import static com.supervision.pdfqaserver.cache.PromptCache.*;
/**
 * Data-comparison retriever.
 *
 * <p>Looks up the relation-graph schema relevant to a query and asks the LLM to translate the
 * query into Cypher statements (prompt {@code TEXT_TO_CYPHER_3}). It then executes those
 * statements. If no rows come back, it asks the LLM once more to revise the statements (prompt
 * {@code TEXT_TO_CYPHER_4}) before giving up.
 */
@Slf4j
@Service("dataCompareRetriever")
@RequiredArgsConstructor
public class DataCompareRetriever implements Retriever {

    private final TripleToCypherExecutor tripleToCypherExecutor;
    private final AiCallService aiCallService;
    private final KeywordSynonymService keywordSynonymService;
    private final TextToSegmentService textToSegmentService;

    /** Synonym table loaded in {@link #init()} and consulted by {@link #rewriteQuery(String)}. */
    private List<KeywordSynonymDTO> synonyms;

    /**
     * Retrieves graph rows for {@code query} via LLM-generated Cypher.
     *
     * @param query natural-language question; an empty query yields an empty list
     * @return the concatenated rows of every Cypher statement that produced results. Returns an
     *     empty mutable list when no schema matches, the LLM returns nothing, or neither the first
     *     nor the revised statements yield rows
     */
    @Override
    public List<Map<String, Object>> retrieval(String query) {
        log.info("retrieval: 执行数据对比检索器,查询内容:{}", query);
        if (StrUtil.isEmpty(query)) {
            log.warn("查询内容为空,无法执行数据对比检索");
            return new ArrayList<>();
        }
        // Look up the part of the relation-graph schema relevant to the query.
        CypherSchemaDTO schemaDTO = tripleToCypherExecutor.queryRelationSchema(query);
        // Validate before dereferencing getNodes()/getRelations() so a null collection cannot NPE.
        if (schemaDTO == null
                || CollUtil.isEmpty(schemaDTO.getRelations())
                || CollUtil.isEmpty(schemaDTO.getNodes())) {
            log.info("没有找到匹配的关系或实体query: {}", query);
            return new ArrayList<>();
        }
        String schemaText = schemaDTO.format();
        log.info("retrieval: 查询到的关系图谱schema 节点个数:{} ,关系个数{} ",
                schemaDTO.getNodes().size(), schemaDTO.getRelations().size());
        log.info("retrieval: 查询到的关系图谱schema {} ", schemaText);

        // First attempt: let the LLM generate executable Cypher statements.
        String prompt = StrUtil.format(PromptCache.promptMap.get(TEXT_TO_CYPHER_3),
                Map.of("query", query, "schema", schemaText, "env", currentEnv()));
        log.info("retrieval: 生成的cypher语句{}", prompt);
        String call = aiCallService.call(prompt);
        log.info("retrieval: AI调用返回结果{}", call);
        if (StrUtil.isBlank(call)) {
            log.warn("retrieval: AI调用返回结果为空无法执行Cypher查询");
            return new ArrayList<>();
        }
        // The LLM reply is expected to be a JSON array of Cypher strings.
        JSONArray cyphers = JSONUtil.parseArray(call);
        List<Map<String, Object>> result = executeAndCollect(cyphers);
        if (!result.isEmpty()) {
            return result;
        }

        // Second attempt: feed the previous statements back so the LLM can revise them.
        log.info("retrieval: 执行Cypher语句无结果重新调整cypher语句{}", query);
        String retryPrompt = StrUtil.format(PromptCache.promptMap.get(TEXT_TO_CYPHER_4),
                Map.of("query", query, "schema", schemaText,
                        "env", currentEnv(), "cypher", cyphers.toString()));
        log.info("retrieval: 生成cypher的语句{}", retryPrompt);
        String retryCall = aiCallService.call(retryPrompt);
        log.info("retrieval: AI调用返回结果{}", retryCall);
        if (StrUtil.isBlank(retryCall)) {
            // Guard the retry the same way as the first call; a blank reply cannot be parsed.
            log.warn("retrieval: AI调用返回结果为空无法执行Cypher查询");
            return result;
        }
        return executeAndCollect(JSONUtil.parseArray(retryCall));
    }

    /**
     * Executes the given Cypher statements and concatenates the rows of every statement whose
     * result list is non-empty, preserving the iteration order of the executor's result map.
     *
     * @param cyphers JSON array of Cypher statement strings
     * @return a mutable list of rows; empty if the executor returned nothing usable
     */
    private List<Map<String, Object>> executeAndCollect(JSONArray cyphers) {
        Map<String, List<Map<String, Object>>> cypherData =
                tripleToCypherExecutor.executeCypher(cyphers.toList(String.class));
        if (CollUtil.isEmpty(cypherData)) {
            return new ArrayList<>();
        }
        return cypherData.values().stream()
                .filter(CollUtil::isNotEmpty)
                .flatMap(List::stream)
                .collect(Collectors.toCollection(ArrayList::new));
    }

    /** Builds the "environment" prompt fragment that tells the LLM the current time. */
    private static String currentEnv() {
        return "- 当前时间是:" + DateUtil.now();
    }

    /**
     * Rewrites {@code query} by segmenting it and replacing each word with its standard term.
     * A word is left unchanged when {@code keywordSynonymService} has no standard term for it.
     * The resulting words are joined with no separator.
     *
     * @param query text to rewrite; must not be empty (rejected via hutool {@code Assert.notEmpty})
     * @return the rewritten query
     */
    @Override
    public String rewriteQuery(String query) {
        Assert.notEmpty(query, "查询内容不能为空");
        List<TextTerm> terms = textToSegmentService.segmentText(query);
        return terms.stream()
                .map(term -> {
                    String standardTerm = keywordSynonymService.getStandardTerm(term.getWord(), synonyms);
                    return standardTerm != null ? standardTerm : term.getWord();
                })
                .collect(Collectors.joining());
    }

    /**
     * Loads all synonyms once after construction and registers each term into the segmentation
     * dictionary via {@code textToSegmentService.addDict}. This covers both the top-level terms
     * and their nested synonyms.
     */
    @PostConstruct
    public void init() {
        log.info("DataCompareRetriever initialized");
        // Load the synonym table used later by rewriteQuery.
        synonyms = keywordSynonymService.listAllSynonyms();
        if (CollUtil.isEmpty(synonyms)) {
            log.warn("DataCompareRetriever: 未找到任何同义词,不添加字典数据...");
            return;
        }
        for (KeywordSynonymDTO synonym : synonyms) {
            textToSegmentService.addDict(synonym.getTerm(), synonym.getNature(), synonym.getFrequency());
            if (CollUtil.isNotEmpty(synonym.getSynonyms())) {
                for (KeywordSynonymDTO subSynonym : synonym.getSynonyms()) {
                    textToSegmentService.addDict(subSynonym.getTerm(), subSynonym.getNature(), subSynonym.getFrequency());
                }
            }
        }
    }
}