|
|
package com.supervision.pdfqaserver.service.impl;
|
|
|
|
|
|
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONArray;
import cn.hutool.json.JSONObject;
import cn.hutool.json.JSONUtil;
import com.supervision.pdfqaserver.cache.PromptCache;
import com.supervision.pdfqaserver.dao.Neo4jRepository;
import com.supervision.pdfqaserver.domain.DocumentTruncation;
import com.supervision.pdfqaserver.domain.Intention;
import com.supervision.pdfqaserver.dto.AnswerDetailDTO;
import com.supervision.pdfqaserver.dto.DomainMetadataDTO;
import com.supervision.pdfqaserver.dto.neo4j.RelationObject;
import com.supervision.pdfqaserver.service.*;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Flux;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

import static com.supervision.pdfqaserver.cache.PromptCache.*;
|
|
|
|
|
|
@Slf4j
|
|
|
@Service
|
|
|
@RequiredArgsConstructor
|
|
|
public class ChatServiceImpl implements ChatService {
|
|
|
private static final String PROMPT_PARAM_SOURCE_TYPE_LIST = "sourceTypeList";
|
|
|
private static final String PROMPT_PARAM_RELATION_TYPE_LIST = "relationTypeList";
|
|
|
private static final String PROMPT_PARAM_TARGET_TYPE_LIST = "targetTypeList";
|
|
|
private static final String PROMPT_PARAM_EXAMPLE_TEXT = "example_text";
|
|
|
private static final String PROMPT_PARAM_QUERY = "query";
|
|
|
private static final String CYPHER_QUERIES = "cypherQueries";
|
|
|
|
|
|
private final Neo4jRepository neo4jRepository;
|
|
|
|
|
|
private final OllamaChatModel ollamaChatModel;
|
|
|
|
|
|
private final DomainMetadataService domainMetadataService;
|
|
|
|
|
|
private final AiCallService aiCallService;
|
|
|
|
|
|
private final DocumentTruncationService documentTruncationService;
|
|
|
|
|
|
|
|
|
private final IntentionService intentionService;
|
|
|
|
|
|
@Override
|
|
|
public Flux<String> knowledgeQA(String userQuery) {
|
|
|
|
|
|
List<Intention> intentions = intentionService.listAllPassed();
|
|
|
List<Intention> relations = classifyIntents(userQuery, intentions);
|
|
|
if (CollUtil.isEmpty(relations)){
|
|
|
log.info("没有匹配到意图,返回查无结果");
|
|
|
return Flux.just("查无结果").concatWith(Flux.just("[END]"));
|
|
|
}
|
|
|
|
|
|
List<DomainMetadataDTO> domainMetadataDTOS = domainMetadataService.listByIntentionIds(relations.stream().map(Intention::getId).toList());
|
|
|
if (CollUtil.isEmpty(domainMetadataDTOS)){
|
|
|
log.info("没有匹配到领域元数据,返回查无结果");
|
|
|
return Flux.just("查无结果").concatWith(Flux.just("[END]"));
|
|
|
}
|
|
|
//将三个集合分别转换为英文逗号分隔的字符串
|
|
|
String sourceTypeListEn = domainMetadataDTOS.stream().map(DomainMetadataDTO::getSourceType).distinct().collect(Collectors.joining(","));
|
|
|
String relationListEn = domainMetadataDTOS.stream().map(DomainMetadataDTO::getRelation).distinct().collect(Collectors.joining(","));
|
|
|
String targetTypeListEn = domainMetadataDTOS.stream().map(DomainMetadataDTO::getTargetType).distinct().collect(Collectors.joining(","));
|
|
|
|
|
|
//LLM生成CYPHER
|
|
|
SystemPromptTemplate textToCypherTemplate = new SystemPromptTemplate(PromptCache.promptMap.get(TEXT_TO_CYPHER));
|
|
|
Message textToCypherMessage = textToCypherTemplate.createMessage(Map.of(PROMPT_PARAM_SOURCE_TYPE_LIST, sourceTypeListEn,
|
|
|
PROMPT_PARAM_RELATION_TYPE_LIST, relationListEn, PROMPT_PARAM_TARGET_TYPE_LIST,
|
|
|
targetTypeListEn, PROMPT_PARAM_QUERY, userQuery));
|
|
|
log.info("生成CYPHER语句的消息:{}", textToCypherMessage);
|
|
|
String cypherJsonStr = ollamaChatModel.call(new Prompt(textToCypherMessage, OllamaOptions.builder().temperature(0.3).build())).getResult().getOutput().getText();
|
|
|
|
|
|
log.info(cypherJsonStr);
|
|
|
log.info(cypherJsonStr.replaceAll("(?is)<think\\b[^>]*>(.*?)</think>", "").trim());
|
|
|
cypherJsonStr = cypherJsonStr.replaceAll("(?is)<think\\b[^>]*>(.*?)</think>", "").trim();
|
|
|
List<String> cypherQueries;
|
|
|
try {
|
|
|
JSONObject jsonObj = JSONUtil.parseObj(cypherJsonStr);
|
|
|
cypherQueries = jsonObj.getJSONArray(CYPHER_QUERIES)
|
|
|
.toList(String.class);
|
|
|
} catch (Exception e) {
|
|
|
log.error("解析CYPHER JSON字符串失败: {}", e.getMessage());
|
|
|
return Flux.just("查无结果").concatWith(Flux.just("[END]"));
|
|
|
}
|
|
|
log.info("转换后的Cypher语句:{}", cypherQueries.toString());
|
|
|
|
|
|
//执行CYPHER查询并汇总结果
|
|
|
List<RelationObject> relationObjects = new ArrayList<>();
|
|
|
if (!cypherQueries.isEmpty()) {
|
|
|
for (String cypher : cypherQueries) {
|
|
|
relationObjects.addAll(neo4jRepository.execute(cypher, null));
|
|
|
}
|
|
|
}
|
|
|
if (relationObjects.isEmpty()) {
|
|
|
log.info("cypher没有查询到结果,返回查无结果");
|
|
|
return Flux.just("查无结果").concatWith(Flux.just("[END]"));
|
|
|
}
|
|
|
log.info("三元组数据: {}", relationObjects);
|
|
|
|
|
|
//生成回答
|
|
|
SystemPromptTemplate generateAnswerTemplate = new SystemPromptTemplate(PromptCache.promptMap.get(GENERATE_ANSWER));
|
|
|
Message generateAnswerMessage = generateAnswerTemplate.createMessage(Map.of(PROMPT_PARAM_EXAMPLE_TEXT, relationObjects, PROMPT_PARAM_QUERY, userQuery));
|
|
|
log.info("生成回答的提示词:{}", generateAnswerMessage);
|
|
|
return ollamaChatModel.stream(new Prompt(generateAnswerMessage)).map(response -> response.getResult().getOutput().getText())
|
|
|
.concatWith(Flux.just(new JSONObject().set("answerDetails", convertToAnswerDetails(relationObjects)).toString()))
|
|
|
.concatWith(Flux.just("[END]"));
|
|
|
}
|
|
|
|
|
|
/**
|
|
|
*
|
|
|
* 分类查询意图
|
|
|
* @param query 问题
|
|
|
* @param intentions 意图列表
|
|
|
* @return
|
|
|
*/
|
|
|
private List<Intention> classifyIntents(String query, List<Intention> intentions) {
|
|
|
if (StrUtil.isEmpty(query) || CollUtil.isEmpty(intentions)) {
|
|
|
return new ArrayList<>();
|
|
|
}
|
|
|
String prompt = promptMap.get(CLASSIFY_QUERY_INTENT);
|
|
|
List<Intention> result = new ArrayList<>();
|
|
|
log.info("开始分类意图,query: {}, intentions size: {}", query, intentions.size());
|
|
|
List<List<Intention>> intentionSplit = CollUtil.split(intentions, 150);
|
|
|
for (List<Intention> intentionList : intentionSplit) {
|
|
|
log.info("分类意图,query: {}, intentions size: {}", query, intentionList.size());
|
|
|
String intents = intentionList.stream().map(i -> " - " + i.getDigest() + "\n").collect(Collectors.joining());
|
|
|
Map<String, Object> params = Map.of("query", query,
|
|
|
"intents", intents);
|
|
|
String format = StrUtil.format(prompt, params);
|
|
|
String call = aiCallService.call(format);
|
|
|
|
|
|
if (StrUtil.isEmpty(call)) {
|
|
|
return new ArrayList<>();
|
|
|
}
|
|
|
List<String> digests = JSONUtil.parseArray(call).stream().map(Object::toString).toList();
|
|
|
if (CollUtil.isEmpty(digests)) {
|
|
|
continue;
|
|
|
}
|
|
|
List<Intention> collect = intentionList.stream().filter(i -> digests.contains(i.getDigest())).collect(Collectors.toList());
|
|
|
if (CollUtil.isNotEmpty(collect)) {
|
|
|
result.addAll(collect);
|
|
|
}
|
|
|
}
|
|
|
return result;
|
|
|
}
|
|
|
|
|
|
private List<AnswerDetailDTO> convertToAnswerDetails(List<RelationObject> relationObjects) {
|
|
|
if (CollUtil.isEmpty(relationObjects)) {
|
|
|
return new ArrayList<>();
|
|
|
}
|
|
|
List<AnswerDetailDTO> answerDetailDTOList = relationObjects.stream().map(AnswerDetailDTO::new).collect(Collectors.toList());
|
|
|
if (CollUtil.isNotEmpty(answerDetailDTOList)){
|
|
|
List<String> truncateIds = answerDetailDTOList.stream().map(AnswerDetailDTO::getTruncateId).distinct().toList();
|
|
|
if (CollUtil.isEmpty(truncateIds)){
|
|
|
return answerDetailDTOList;
|
|
|
}
|
|
|
List<DocumentTruncation> documentTruncations = documentTruncationService.listByIds(truncateIds);
|
|
|
Map<String, String> contentMap = documentTruncations.stream().collect(Collectors.toMap(DocumentTruncation::getId, DocumentTruncation::getContent));
|
|
|
for (AnswerDetailDTO answerDetailDTO : answerDetailDTOList) {
|
|
|
answerDetailDTO.setTruncateContent(contentMap.get(answerDetailDTO.getTruncateId()));
|
|
|
}
|
|
|
}
|
|
|
return answerDetailDTOList;
|
|
|
}
|
|
|
}
|