You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

176 lines
8.7 KiB
Java

package com.supervision.pdfqaserver.service.impl;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.StrUtil;
1 month ago
import cn.hutool.json.JSONObject;
import cn.hutool.json.JSONUtil;
import com.supervision.pdfqaserver.cache.PromptCache;
import com.supervision.pdfqaserver.dao.Neo4jRepository;
import com.supervision.pdfqaserver.domain.DocumentTruncation;
import com.supervision.pdfqaserver.domain.Intention;
import com.supervision.pdfqaserver.dto.AnswerDetailDTO;
import com.supervision.pdfqaserver.dto.DomainMetadataDTO;
import com.supervision.pdfqaserver.dto.neo4j.RelationObject;
import com.supervision.pdfqaserver.service.*;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.ai.ollama.OllamaChatModel;
1 month ago
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Flux;
1 month ago
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import static com.supervision.pdfqaserver.cache.PromptCache.*;
@Slf4j
@Service
@RequiredArgsConstructor
public class ChatServiceImpl implements ChatService {
1 month ago
private static final String PROMPT_PARAM_SOURCE_TYPE_LIST = "sourceTypeList";
private static final String PROMPT_PARAM_RELATION_TYPE_LIST = "relationTypeList";
private static final String PROMPT_PARAM_TARGET_TYPE_LIST = "targetTypeList";
private static final String PROMPT_PARAM_EXAMPLE_TEXT = "example_text";
private static final String PROMPT_PARAM_QUERY = "query";
private static final String CYPHER_QUERIES = "cypherQueries";
private final Neo4jRepository neo4jRepository;
private final OllamaChatModel ollamaChatModel;
private final DomainMetadataService domainMetadataService;
private final AiCallService aiCallService;
private final DocumentTruncationService documentTruncationService;
private final IntentionService intentionService;
@Override
public Flux<String> knowledgeQA(String userQuery) {
1 month ago
List<Intention> intentions = intentionService.listAllPassed();
List<Intention> relations = classifyIntents(userQuery, intentions);
if (CollUtil.isEmpty(relations)){
log.info("没有匹配到意图,返回查无结果");
return Flux.just("查无结果").concatWith(Flux.just("[END]"));
}
List<DomainMetadataDTO> domainMetadataDTOS = domainMetadataService.listByIntentionIds(relations.stream().map(Intention::getId).toList());
if (CollUtil.isEmpty(domainMetadataDTOS)){
log.info("没有匹配到领域元数据,返回查无结果");
return Flux.just("查无结果").concatWith(Flux.just("[END]"));
}
1 month ago
//将三个集合分别转换为英文逗号分隔的字符串
String sourceTypeListEn = domainMetadataDTOS.stream().map(DomainMetadataDTO::getSourceType).distinct().collect(Collectors.joining(","));
String relationListEn = domainMetadataDTOS.stream().map(DomainMetadataDTO::getRelation).distinct().collect(Collectors.joining(","));
String targetTypeListEn = domainMetadataDTOS.stream().map(DomainMetadataDTO::getTargetType).distinct().collect(Collectors.joining(","));
1 month ago
//LLM生成CYPHER
SystemPromptTemplate textToCypherTemplate = new SystemPromptTemplate(PromptCache.promptMap.get(TEXT_TO_CYPHER));
Message textToCypherMessage = textToCypherTemplate.createMessage(Map.of(PROMPT_PARAM_SOURCE_TYPE_LIST, sourceTypeListEn,
PROMPT_PARAM_RELATION_TYPE_LIST, relationListEn, PROMPT_PARAM_TARGET_TYPE_LIST,
targetTypeListEn, PROMPT_PARAM_QUERY, userQuery));
1 month ago
log.info("生成CYPHER语句的消息{}", textToCypherMessage);
String cypherJsonStr = ollamaChatModel.call(new Prompt(textToCypherMessage, OllamaOptions.builder().temperature(0.3).build())).getResult().getOutput().getText();
1 month ago
log.info(cypherJsonStr);
1 month ago
log.info(cypherJsonStr.replaceAll("(?is)<think\\b[^>]*>(.*?)</think>", "").trim());
cypherJsonStr = cypherJsonStr.replaceAll("(?is)<think\\b[^>]*>(.*?)</think>", "").trim();
1 month ago
List<String> cypherQueries;
try {
JSONObject jsonObj = JSONUtil.parseObj(cypherJsonStr);
cypherQueries = jsonObj.getJSONArray(CYPHER_QUERIES)
.toList(String.class);
} catch (Exception e) {
log.error("解析CYPHER JSON字符串失败: {}", e.getMessage());
1 month ago
return Flux.just("查无结果").concatWith(Flux.just("[END]"));
1 month ago
}
log.info("转换后的Cypher语句{}", cypherQueries.toString());
//执行CYPHER查询并汇总结果
List<RelationObject> relationObjects = new ArrayList<>();
if (!cypherQueries.isEmpty()) {
for (String cypher : cypherQueries) {
relationObjects.addAll(neo4jRepository.execute(cypher, null));
}
}
if (relationObjects.isEmpty()) {
log.info("cypher没有查询到结果返回查无结果");
1 month ago
return Flux.just("查无结果").concatWith(Flux.just("[END]"));
}
1 month ago
log.info("三元组数据: {}", relationObjects);
//生成回答
SystemPromptTemplate generateAnswerTemplate = new SystemPromptTemplate(PromptCache.promptMap.get(GENERATE_ANSWER));
1 month ago
Message generateAnswerMessage = generateAnswerTemplate.createMessage(Map.of(PROMPT_PARAM_EXAMPLE_TEXT, relationObjects, PROMPT_PARAM_QUERY, userQuery));
1 month ago
log.info("生成回答的提示词:{}", generateAnswerMessage);
return ollamaChatModel.stream(new Prompt(generateAnswerMessage)).map(response -> response.getResult().getOutput().getText())
.concatWith(Flux.just(new JSONObject().set("answerDetails", convertToAnswerDetails(relationObjects)).toString()))
.concatWith(Flux.just("[END]"));
}
/**
*
*
* @param query
* @param intentions
* @return
*/
private List<Intention> classifyIntents(String query, List<Intention> intentions) {
if (StrUtil.isEmpty(query) || CollUtil.isEmpty(intentions)) {
return new ArrayList<>();
}
String prompt = promptMap.get(CLASSIFY_QUERY_INTENT);
List<Intention> result = new ArrayList<>();
log.info("开始分类意图query: {}, intentions size: {}", query, intentions.size());
List<List<Intention>> intentionSplit = CollUtil.split(intentions, 150);
for (List<Intention> intentionList : intentionSplit) {
log.info("分类意图query: {}, intentions size: {}", query, intentionList.size());
String intents = intentionList.stream().map(i -> " - " + i.getDigest() + "\n").collect(Collectors.joining());
Map<String, Object> params = Map.of("query", query,
"intents", intents);
String format = StrUtil.format(prompt, params);
String call = aiCallService.call(format);
if (StrUtil.isEmpty(call)) {
return new ArrayList<>();
}
List<String> digests = JSONUtil.parseArray(call).stream().map(Object::toString).toList();
if (CollUtil.isEmpty(digests)) {
continue;
}
List<Intention> collect = intentionList.stream().filter(i -> digests.contains(i.getDigest())).collect(Collectors.toList());
if (CollUtil.isNotEmpty(collect)) {
result.addAll(collect);
}
}
return result;
}
private List<AnswerDetailDTO> convertToAnswerDetails(List<RelationObject> relationObjects) {
if (CollUtil.isEmpty(relationObjects)) {
return new ArrayList<>();
}
List<AnswerDetailDTO> answerDetailDTOList = relationObjects.stream().map(AnswerDetailDTO::new).collect(Collectors.toList());
if (CollUtil.isNotEmpty(answerDetailDTOList)){
List<String> truncateIds = answerDetailDTOList.stream().map(AnswerDetailDTO::getTruncateId).distinct().toList();
if (CollUtil.isEmpty(truncateIds)){
return answerDetailDTOList;
}
List<DocumentTruncation> documentTruncations = documentTruncationService.listByIds(truncateIds);
Map<String, String> contentMap = documentTruncations.stream().collect(Collectors.toMap(DocumentTruncation::getId, DocumentTruncation::getContent));
for (AnswerDetailDTO answerDetailDTO : answerDetailDTOList) {
answerDetailDTO.setTruncateContent(contentMap.get(answerDetailDTO.getTruncateId()));
}
}
return answerDetailDTOList;
}
}