You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

130 lines
5.8 KiB
Java

package com.supervision.pdfqaserver.service.impl;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.date.DateUtil;
import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONArray;
import cn.hutool.json.JSONUtil;
import com.supervision.pdfqaserver.cache.PromptCache;
import com.supervision.pdfqaserver.dto.CypherSchemaDTO;
import com.supervision.pdfqaserver.dto.KeywordSynonymDTO;
import com.supervision.pdfqaserver.dto.TextTerm;
import com.supervision.pdfqaserver.service.*;
import jakarta.annotation.PostConstruct;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import static com.supervision.pdfqaserver.cache.PromptCache.*;
/**
 * {@link Retriever} implementation, registered as the Spring bean {@code dataCompareRetriever},
 * that answers a query by:
 * <ol>
 *   <li>looking up the graph schema (nodes and relations) related to the query,</li>
 *   <li>asking the AI service to generate Cypher statements from that schema,</li>
 *   <li>executing the generated statements and flattening the returned rows.</li>
 * </ol>
 * When the first batch of statements yields no rows, the AI service is asked once more
 * (with the previous statements as context) to produce adjusted statements.
 */
@Slf4j
@Service("dataCompareRetriever")
@RequiredArgsConstructor
public class DataCompareRetriever implements Retriever {

    private final TripleToCypherExecutor tripleToCypherExecutor;
    private final AiCallService aiCallService;
    private final KeywordSynonymService keywordSynonymService;
    private final TextToSegmentService textToSegmentService;

    /** Synonym definitions loaded once in {@link #init()} and used by {@link #rewriteQuery(String)}. */
    private List<KeywordSynonymDTO> synonyms;

    /**
     * Retrieves graph data for the given query through AI-generated Cypher statements.
     *
     * @param query the user question; an empty value short-circuits to an empty result
     * @return the rows of every non-empty Cypher result, concatenated into one list;
     *         an empty list when the query is empty, no schema matches, the AI service
     *         returns nothing, or no statement produces rows
     */
    @Override
    public List<Map<String, Object>> retrieval(String query) {
        log.info("retrieval: 执行数据对比检索器,查询内容:{}", query);
        if (StrUtil.isEmpty(query)) {
            log.warn("查询内容为空,无法执行数据对比检索");
            return new ArrayList<>();
        }

        // Look up the part of the graph schema that is relevant to the question.
        CypherSchemaDTO schemaDTO = tripleToCypherExecutor.queryRelationSchema(query);
        // Guard BEFORE touching the collections so a null schema or null lists cannot cause an NPE.
        if (schemaDTO == null
                || CollUtil.isEmpty(schemaDTO.getRelations())
                || CollUtil.isEmpty(schemaDTO.getNodes())) {
            log.info("没有找到匹配的关系或实体query: {}", query);
            return new ArrayList<>();
        }
        log.info("retrieval: 查询到的关系图谱schema 节点个数:{} ,关系结束{} ", schemaDTO.getNodes().size(), schemaDTO.getRelations().size());
        String schemaText = schemaDTO.format();
        log.info("retrieval: 查询到的关系图谱schema {} ", schemaText);

        // First attempt: let the AI service generate executable Cypher statements.
        String firstPrompt = StrUtil.format(PromptCache.promptMap.get(TEXT_TO_CYPHER_3),
                Map.of("query", query, "schema", schemaText, "env", currentEnvHint()));
        log.info("retrieval: 生成的cypher语句{}", firstPrompt);
        String firstAnswer = aiCallService.call(firstPrompt);
        log.info("retrieval: AI调用返回结果{}", firstAnswer);
        if (StrUtil.isEmpty(firstAnswer)) {
            log.warn("retrieval: AI调用返回结果为空无法执行Cypher查询");
            return new ArrayList<>();
        }
        JSONArray firstCyphers = JSONUtil.parseArray(firstAnswer);
        List<Map<String, Object>> firstRows = executeCyphers(firstCyphers);
        if (!firstRows.isEmpty()) {
            return firstRows;
        }

        // Second attempt: hand the unsuccessful statements back so they can be adjusted.
        log.info("retrieval: 执行Cypher语句无结果重新调整cypher语句{}", query);
        String retryPrompt = StrUtil.format(PromptCache.promptMap.get(TEXT_TO_CYPHER_4),
                Map.of("query", query, "schema", schemaText,
                        "env", currentEnvHint(), "cypher", firstCyphers.toString()));
        log.info("retrieval: 生成cypher的语句{}", retryPrompt);
        String retryAnswer = aiCallService.call(retryPrompt);
        log.info("retrieval: AI调用返回结果{}", retryAnswer);
        // Same guard as the first attempt: an empty answer cannot be parsed as a JSON array.
        if (StrUtil.isEmpty(retryAnswer)) {
            log.warn("retrieval: AI调用返回结果为空无法执行Cypher查询");
            return new ArrayList<>();
        }
        return executeCyphers(JSONUtil.parseArray(retryAnswer));
    }

    /**
     * Rewrites the query by segmenting it and replacing each term with its standard term.
     * Terms for which the synonym service returns {@code null} are kept as they are.
     * The resulting terms are joined without a separator.
     *
     * @param query the text to rewrite; must not be empty
     * @return the query rebuilt from the (possibly replaced) terms
     * @throws IllegalArgumentException if {@code query} is empty (raised by Hutool's {@code Assert.notEmpty})
     */
    @Override
    public String rewriteQuery(String query) {
        Assert.notEmpty(query, "查询内容不能为空");
        List<TextTerm> terms = textToSegmentService.segmentText(query);
        return terms.stream()
                .map(term -> {
                    String standardTerm = keywordSynonymService.getStandardTerm(term.getWord(), synonyms);
                    return standardTerm != null ? standardTerm : term.getWord();
                })
                .collect(Collectors.joining());
    }

    /**
     * Loads all synonym definitions after construction and registers every term, together with
     * the terms of its nested synonyms, in the segmentation dictionary.
     */
    @PostConstruct
    public void init() {
        log.info("DataCompareRetriever initialized");
        // Load the synonym data once; rewriteQuery reuses the same list.
        synonyms = keywordSynonymService.listAllSynonyms();
        if (CollUtil.isEmpty(synonyms)) {
            log.warn("DataCompareRetriever: 未找到任何同义词,不添加字典数据...");
            return;
        }
        for (KeywordSynonymDTO synonym : synonyms) {
            textToSegmentService.addDict(synonym.getTerm(), synonym.getNature(), synonym.getFrequency());
            if (CollUtil.isNotEmpty(synonym.getSynonyms())) {
                for (KeywordSynonymDTO subSynonym : synonym.getSynonyms()) {
                    textToSegmentService.addDict(subSynonym.getTerm(), subSynonym.getNature(), subSynonym.getFrequency());
                }
            }
        }
    }

    /** Builds the "env" prompt variable, which carries the current time. */
    private String currentEnvHint() {
        return "- 当前时间是:" + DateUtil.now();
    }

    /**
     * Executes the given Cypher statements and concatenates the rows of every non-empty result.
     *
     * @param cyphers JSON array whose elements are Cypher statements
     * @return a mutable list of all returned rows; empty when nothing was returned
     */
    private List<Map<String, Object>> executeCyphers(JSONArray cyphers) {
        Map<String, List<Map<String, Object>>> cypherData =
                tripleToCypherExecutor.executeCypher(cyphers.toList(String.class));
        if (CollUtil.isEmpty(cypherData)) {
            return new ArrayList<>();
        }
        return cypherData.values().stream()
                .filter(CollUtil::isNotEmpty)
                .flatMap(List::stream)
                .collect(Collectors.toCollection(ArrayList::new));
    }
}