You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

176 lines
8.7 KiB
Java

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package com.supervision.pdfqaserver.service.impl;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONObject;
import cn.hutool.json.JSONUtil;
import com.supervision.pdfqaserver.cache.PromptCache;
import com.supervision.pdfqaserver.dao.Neo4jRepository;
import com.supervision.pdfqaserver.domain.DocumentTruncation;
import com.supervision.pdfqaserver.domain.Intention;
import com.supervision.pdfqaserver.dto.AnswerDetailDTO;
import com.supervision.pdfqaserver.dto.DomainMetadataDTO;
import com.supervision.pdfqaserver.dto.neo4j.RelationObject;
import com.supervision.pdfqaserver.service.*;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Flux;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import static com.supervision.pdfqaserver.cache.PromptCache.*;
/**
 * Knowledge-graph question answering backed by an Ollama chat model and Neo4j.
 *
 * <p>A question is matched against the intentions returned by
 * {@link IntentionService#listAllPassed()}, the domain metadata of the matched intentions is used to
 * prompt the model for Cypher statements, the statements are executed through
 * {@link Neo4jRepository}, and the resulting relation objects are handed back to the model to
 * produce a streamed answer.
 */
@Slf4j
@Service
@RequiredArgsConstructor
public class ChatServiceImpl implements ChatService {

    private static final String PROMPT_PARAM_SOURCE_TYPE_LIST = "sourceTypeList";
    private static final String PROMPT_PARAM_RELATION_TYPE_LIST = "relationTypeList";
    private static final String PROMPT_PARAM_TARGET_TYPE_LIST = "targetTypeList";
    private static final String PROMPT_PARAM_EXAMPLE_TEXT = "example_text";
    private static final String PROMPT_PARAM_QUERY = "query";
    private static final String CYPHER_QUERIES = "cypherQueries";

    /** Regex removing a {@code <think>...</think>} section from a model reply before JSON parsing. */
    private static final String THINK_BLOCK_REGEX = "(?is)<think\\b[^>]*>(.*?)</think>";
    /** Text emitted when no answer can be produced. */
    private static final String NO_RESULT_TEXT = "查无结果";
    /** Last element of every stream returned by {@link #knowledgeQA(String)}. */
    private static final String END_MARKER = "[END]";
    /** Maximum number of intentions sent to the classifier in a single prompt. */
    private static final int INTENT_BATCH_SIZE = 150;

    private final Neo4jRepository neo4jRepository;
    private final OllamaChatModel ollamaChatModel;
    private final DomainMetadataService domainMetadataService;
    private final AiCallService aiCallService;
    private final DocumentTruncationService documentTruncationService;
    private final IntentionService intentionService;

    /**
     * Answers a user question from the knowledge graph.
     *
     * <p>On success the stream emits the model's answer chunks, then one JSON object string with
     * the key {@code answerDetails}, then {@code [END]}. When no intention, no domain metadata,
     * no parsable Cypher or no query result is available, the stream emits {@code 查无结果}
     * followed by {@code [END]}.
     *
     * @param userQuery the user's question
     * @return the answer stream described above
     */
    @Override
    public Flux<String> knowledgeQA(String userQuery) {
        List<Intention> intentions = intentionService.listAllPassed();
        List<Intention> relations = classifyIntents(userQuery, intentions);
        if (CollUtil.isEmpty(relations)) {
            log.info("没有匹配到意图,返回查无结果");
            return noResult();
        }
        List<DomainMetadataDTO> domainMetadataDTOS = domainMetadataService.listByIntentionIds(
                relations.stream().map(Intention::getId).toList());
        if (CollUtil.isEmpty(domainMetadataDTOS)) {
            log.info("没有匹配到领域元数据,返回查无结果");
            return noResult();
        }

        // Flatten each of the three metadata columns into a comma-separated list of distinct values.
        String sourceTypeListEn = domainMetadataDTOS.stream()
                .map(DomainMetadataDTO::getSourceType).distinct().collect(Collectors.joining(","));
        String relationListEn = domainMetadataDTOS.stream()
                .map(DomainMetadataDTO::getRelation).distinct().collect(Collectors.joining(","));
        String targetTypeListEn = domainMetadataDTOS.stream()
                .map(DomainMetadataDTO::getTargetType).distinct().collect(Collectors.joining(","));

        // Ask the model to translate the question into Cypher.
        SystemPromptTemplate textToCypherTemplate = new SystemPromptTemplate(PromptCache.promptMap.get(TEXT_TO_CYPHER));
        Message textToCypherMessage = textToCypherTemplate.createMessage(Map.of(
                PROMPT_PARAM_SOURCE_TYPE_LIST, sourceTypeListEn,
                PROMPT_PARAM_RELATION_TYPE_LIST, relationListEn,
                PROMPT_PARAM_TARGET_TYPE_LIST, targetTypeListEn,
                PROMPT_PARAM_QUERY, userQuery));
        log.info("生成CYPHER语句的消息{}", textToCypherMessage);
        String rawReply = ollamaChatModel
                .call(new Prompt(textToCypherMessage, OllamaOptions.builder().temperature(0.3).build()))
                .getResult().getOutput().getText();
        log.info("{}", rawReply);
        if (StrUtil.isBlank(rawReply)) {
            // A null reply would otherwise fail with an uncaught NullPointerException below.
            log.info("LLM未返回CYPHER内容,返回查无结果");
            return noResult();
        }
        String cypherJsonStr = rawReply.replaceAll(THINK_BLOCK_REGEX, "").trim();
        log.info("{}", cypherJsonStr);

        List<String> cypherQueries;
        try {
            JSONObject jsonObj = JSONUtil.parseObj(cypherJsonStr);
            // A missing "cypherQueries" key yields null here; the resulting exception is handled below.
            cypherQueries = jsonObj.getJSONArray(CYPHER_QUERIES).toList(String.class);
        } catch (Exception e) {
            // The throwable is passed as the last argument so the stack trace is logged as well.
            log.error("解析CYPHER JSON字符串失败: {}", e.getMessage(), e);
            return noResult();
        }
        log.info("转换后的Cypher语句{}", cypherQueries);

        // Run every generated statement and merge the results; execution errors propagate to the caller.
        List<RelationObject> relationObjects = new ArrayList<>();
        for (String cypher : cypherQueries) {
            relationObjects.addAll(neo4jRepository.execute(cypher, null));
        }
        if (relationObjects.isEmpty()) {
            log.info("cypher没有查询到结果返回查无结果");
            return noResult();
        }
        log.info("三元组数据: {}", relationObjects);

        // Generate the final answer from the retrieved relation objects.
        SystemPromptTemplate generateAnswerTemplate = new SystemPromptTemplate(PromptCache.promptMap.get(GENERATE_ANSWER));
        Message generateAnswerMessage = generateAnswerTemplate.createMessage(
                Map.of(PROMPT_PARAM_EXAMPLE_TEXT, relationObjects, PROMPT_PARAM_QUERY, userQuery));
        log.info("生成回答的提示词:{}", generateAnswerMessage);
        return ollamaChatModel.stream(new Prompt(generateAnswerMessage))
                .map(response -> response.getResult().getOutput().getText())
                .concatWith(Flux.just(new JSONObject().set("answerDetails", convertToAnswerDetails(relationObjects)).toString()))
                .concatWith(Flux.just(END_MARKER));
    }

    /**
     * Builds the stream returned whenever no answer can be produced.
     *
     * @return a stream emitting {@code 查无结果} followed by {@code [END]}
     */
    private static Flux<String> noResult() {
        return Flux.just(NO_RESULT_TEXT).concatWith(Flux.just(END_MARKER));
    }

    /**
     * Selects the intentions that match a question.
     *
     * <p>The intentions are sent to {@link AiCallService} in batches of {@value #INTENT_BATCH_SIZE};
     * each reply is read as a JSON array of digests, and the intentions of that batch whose digest
     * appears in the array are kept. A batch whose reply is blank or is not a JSON array is skipped
     * so that matches found in other batches are still returned.
     *
     * @param query      the user's question
     * @param intentions the candidate intentions
     * @return the matching intentions; empty when {@code query} or {@code intentions} is empty or nothing matches
     */
    private List<Intention> classifyIntents(String query, List<Intention> intentions) {
        if (StrUtil.isEmpty(query) || CollUtil.isEmpty(intentions)) {
            return new ArrayList<>();
        }
        String prompt = promptMap.get(CLASSIFY_QUERY_INTENT);
        List<Intention> result = new ArrayList<>();
        log.info("开始分类意图query: {}, intentions size: {}", query, intentions.size());
        for (List<Intention> intentionList : CollUtil.split(intentions, INTENT_BATCH_SIZE)) {
            log.info("分类意图query: {}, intentions size: {}", query, intentionList.size());
            String intents = intentionList.stream()
                    .map(i -> " - " + i.getDigest() + "\n")
                    .collect(Collectors.joining());
            Map<String, Object> params = Map.of("query", query, "intents", intents);
            String reply = aiCallService.call(StrUtil.format(prompt, params));
            if (StrUtil.isBlank(reply)) {
                // Skip only this batch; returning here would discard matches from earlier batches.
                log.warn("意图分类返回为空,跳过该批次");
                continue;
            }
            List<String> digests = parseDigests(reply);
            if (CollUtil.isEmpty(digests)) {
                continue;
            }
            intentionList.stream()
                    .filter(i -> digests.contains(i.getDigest()))
                    .forEach(result::add);
        }
        return result;
    }

    /**
     * Reads a classifier reply as a JSON array and returns the string form of each element.
     *
     * @param reply the non-blank classifier reply
     * @return the array elements as strings, or an empty list when the reply cannot be parsed
     */
    private List<String> parseDigests(String reply) {
        try {
            return JSONUtil.parseArray(reply).stream().map(Object::toString).toList();
        } catch (RuntimeException e) {
            log.warn("解析意图分类结果失败,跳过该批次: {}", reply, e);
            return new ArrayList<>();
        }
    }

    /**
     * Builds one {@link AnswerDetailDTO} per relation object and fills in the content of the
     * document truncation each one refers to.
     *
     * <p>Details without a truncate id, or whose truncation is not found or has no content, keep a
     * {@code null} truncate content.
     *
     * @param relationObjects the relation objects returned by the Cypher queries
     * @return the answer details, in the same order as {@code relationObjects}; empty when the input is empty
     */
    private List<AnswerDetailDTO> convertToAnswerDetails(List<RelationObject> relationObjects) {
        if (CollUtil.isEmpty(relationObjects)) {
            return new ArrayList<>();
        }
        List<AnswerDetailDTO> answerDetailDTOList = relationObjects.stream()
                .map(AnswerDetailDTO::new)
                .collect(Collectors.toList());
        // Null or blank ids cannot match a truncation, so they are left out of the lookup.
        List<String> truncateIds = answerDetailDTOList.stream()
                .map(AnswerDetailDTO::getTruncateId)
                .filter(StrUtil::isNotBlank)
                .distinct()
                .toList();
        if (truncateIds.isEmpty()) {
            return answerDetailDTOList;
        }
        List<DocumentTruncation> documentTruncations = documentTruncationService.listByIds(truncateIds);
        if (CollUtil.isEmpty(documentTruncations)) {
            return answerDetailDTOList;
        }
        // Collectors.toMap rejects null values and duplicate keys: drop null content and keep the first entry per id.
        Map<String, String> contentMap = documentTruncations.stream()
                .filter(truncation -> truncation != null && truncation.getContent() != null)
                .collect(Collectors.toMap(DocumentTruncation::getId, DocumentTruncation::getContent,
                        (first, second) -> first));
        for (AnswerDetailDTO answerDetailDTO : answerDetailDTOList) {
            answerDetailDTO.setTruncateContent(contentMap.get(answerDetailDTO.getTruncateId()));
        }
        return answerDetailDTOList;
    }
}