You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

103 lines
5.4 KiB
Java

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package com.supervision.pdfqaserver.service.impl;
import cn.hutool.json.JSONObject;
import cn.hutool.json.JSONUtil;
import com.supervision.pdfqaserver.cache.PromptCache;
import com.supervision.pdfqaserver.dao.Neo4jRepository;
import com.supervision.pdfqaserver.domain.ChineseEnglishWords;
import com.supervision.pdfqaserver.domain.DomainMetadata;
import com.supervision.pdfqaserver.dto.neo4j.RelationObject;
import com.supervision.pdfqaserver.service.ChatService;
import com.supervision.pdfqaserver.service.ChineseEnglishWordsService;
import com.supervision.pdfqaserver.service.DomainMetadataService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Flux;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import static com.supervision.pdfqaserver.cache.PromptCache.GENERATE_ANSWER;
import static com.supervision.pdfqaserver.cache.PromptCache.TEXT_TO_CYPHER;
@Slf4j
@Service
@RequiredArgsConstructor
public class ChatServiceImpl implements ChatService {
private static final String PROMPT_PARAM_SOURCE_TYPE_LIST = "sourceTypeList";
private static final String PROMPT_PARAM_RELATION_TYPE_LIST = "relationTypeList";
private static final String PROMPT_PARAM_TARGET_TYPE_LIST = "targetTypeList";
private static final String PROMPT_PARAM_EXAMPLE_TEXT = "example_text";
private static final String PROMPT_PARAM_QUERY = "query";
private static final String CYPHER_QUERIES = "cypherQueries";
private final Neo4jRepository neo4jRepository;
private final OllamaChatModel ollamaChatModel;
private final DomainMetadataService domainMetadataService;
private final ChineseEnglishWordsService chineseEnglishWordsService;
@Override
public Flux<String> knowledgeQA(String userQuery) {
//拼装领域元数据
Map<String, String> chineseEnglishWordsMap = chineseEnglishWordsService.list().stream()
.collect(Collectors.toMap(ChineseEnglishWords::getChineseWord, ChineseEnglishWords::getEnglishWord));
//分别得到sourceTyperelationtargetType的group by后的集合
List<String> sourceTypeList = domainMetadataService.list().stream().map(DomainMetadata::getSourceType).distinct().toList();
List<String> relationList = domainMetadataService.list().stream().map(DomainMetadata::getRelation).distinct().toList();
List<String> targetTypeList = domainMetadataService.list().stream().map(DomainMetadata::getTargetType).distinct().toList();
//将三个集合分别结合chineseEnglishWordsMap的key转化为value集合
List<String> sourceTypeListEnList = sourceTypeList.stream().map(chineseEnglishWordsMap::get).toList();
List<String> relationListEnList = relationList.stream().map(chineseEnglishWordsMap::get).toList();
List<String> targetTypeListEnList = targetTypeList.stream().map(chineseEnglishWordsMap::get).toList();
//将三个集合分别转换为英文逗号分隔的字符串
String sourceTypeListEn = String.join(",", sourceTypeListEnList);
String relationListEn = String.join(",", relationListEnList);
String targetTypeListEn = String.join(",", targetTypeListEnList);
log.info("sourceTypeListEn: {}, relationListEn: {}, targetTypeListEn: {}", sourceTypeListEn, relationListEn, targetTypeListEn);
//LLM生成CYPHER
SystemPromptTemplate textToCypherTemplate = new SystemPromptTemplate(PromptCache.promptMap.get(TEXT_TO_CYPHER));
Message textToCypherMessage = textToCypherTemplate.createMessage(Map.of(PROMPT_PARAM_SOURCE_TYPE_LIST, sourceTypeListEn, PROMPT_PARAM_RELATION_TYPE_LIST, relationListEn, PROMPT_PARAM_TARGET_TYPE_LIST, targetTypeListEn, PROMPT_PARAM_QUERY, userQuery));
String cypherJsonStr = ollamaChatModel.call(textToCypherMessage.getText());
log.info(cypherJsonStr);
List<String> cypherQueries;
try {
JSONObject jsonObj = JSONUtil.parseObj(cypherJsonStr);
cypherQueries = jsonObj.getJSONArray(CYPHER_QUERIES)
.toList(String.class);
} catch (Exception e) {
log.error("解析CYPHER JSON字符串失败: {}", e.getMessage());
return Flux.just("查无结果");
}
log.info("转换后的Cypher语句{}", cypherQueries.toString());
//执行CYPHER查询并汇总结果
List<RelationObject> relationObjects = new ArrayList<>();
if (!cypherQueries.isEmpty()) {
for (String cypher : cypherQueries) {
relationObjects.addAll(neo4jRepository.execute(cypher, null));
}
}
if (relationObjects.isEmpty()) {
return Flux.just("查无结果");
}
log.info("三元组数据: {}", relationObjects);
//生成回答
SystemPromptTemplate generateAnswerTemplate = new SystemPromptTemplate(PromptCache.promptMap.get(GENERATE_ANSWER));
Message generateAnswerMessage = generateAnswerTemplate.createMessage(Map.of(PROMPT_PARAM_EXAMPLE_TEXT, relationObjects, PROMPT_PARAM_QUERY, userQuery));
return ollamaChatModel.stream(new Prompt(generateAnswerMessage)).map(response -> response.getResult().getOutput().getText());
}
}