|
|
package com.supervision.pdfqaserver.service.impl;
|
|
|
|
|
|
import cn.hutool.json.JSONObject;
|
|
|
import cn.hutool.json.JSONUtil;
|
|
|
import com.supervision.pdfqaserver.cache.PromptCache;
|
|
|
import com.supervision.pdfqaserver.dao.Neo4jRepository;
|
|
|
import com.supervision.pdfqaserver.domain.ChineseEnglishWords;
|
|
|
import com.supervision.pdfqaserver.domain.DomainMetadata;
|
|
|
import com.supervision.pdfqaserver.dto.neo4j.RelationObject;
|
|
|
import com.supervision.pdfqaserver.service.ChatService;
|
|
|
import com.supervision.pdfqaserver.service.ChineseEnglishWordsService;
|
|
|
import com.supervision.pdfqaserver.service.DomainMetadataService;
|
|
|
import lombok.RequiredArgsConstructor;
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
|
import org.springframework.ai.chat.messages.Message;
|
|
|
import org.springframework.ai.chat.prompt.Prompt;
|
|
|
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
|
|
|
import org.springframework.ai.ollama.OllamaChatModel;
|
|
|
import org.springframework.stereotype.Service;
|
|
|
import reactor.core.publisher.Flux;
|
|
|
|
|
|
import java.util.ArrayList;
|
|
|
import java.util.List;
|
|
|
import java.util.Map;
|
|
|
import java.util.stream.Collectors;
|
|
|
|
|
|
import static com.supervision.pdfqaserver.cache.PromptCache.GENERATE_ANSWER;
|
|
|
import static com.supervision.pdfqaserver.cache.PromptCache.TEXT_TO_CYPHER;
|
|
|
|
|
|
@Slf4j
|
|
|
@Service
|
|
|
@RequiredArgsConstructor
|
|
|
public class ChatServiceImpl implements ChatService {
|
|
|
private static final String PROMPT_PARAM_SOURCE_TYPE_LIST = "sourceTypeList";
|
|
|
private static final String PROMPT_PARAM_RELATION_TYPE_LIST = "relationTypeList";
|
|
|
private static final String PROMPT_PARAM_TARGET_TYPE_LIST = "targetTypeList";
|
|
|
private static final String PROMPT_PARAM_EXAMPLE_TEXT = "example_text";
|
|
|
private static final String PROMPT_PARAM_QUERY = "query";
|
|
|
private static final String CYPHER_QUERIES = "cypherQueries";
|
|
|
|
|
|
private final Neo4jRepository neo4jRepository;
|
|
|
|
|
|
private final OllamaChatModel ollamaChatModel;
|
|
|
private final DomainMetadataService domainMetadataService;
|
|
|
private final ChineseEnglishWordsService chineseEnglishWordsService;
|
|
|
|
|
|
@Override
|
|
|
public Flux<String> knowledgeQA(String userQuery) {
|
|
|
//拼装领域元数据
|
|
|
Map<String, String> chineseEnglishWordsMap = chineseEnglishWordsService.list().stream()
|
|
|
.collect(Collectors.toMap(ChineseEnglishWords::getChineseWord, ChineseEnglishWords::getEnglishWord));
|
|
|
|
|
|
//分别得到sourceType,relation,targetType的group by后的集合
|
|
|
List<String> sourceTypeList = domainMetadataService.list().stream().map(DomainMetadata::getSourceType).distinct().toList();
|
|
|
List<String> relationList = domainMetadataService.list().stream().map(DomainMetadata::getRelation).distinct().toList();
|
|
|
List<String> targetTypeList = domainMetadataService.list().stream().map(DomainMetadata::getTargetType).distinct().toList();
|
|
|
|
|
|
//将三个集合分别结合chineseEnglishWordsMap的key转化为value集合
|
|
|
List<String> sourceTypeListEnList = sourceTypeList.stream().map(chineseEnglishWordsMap::get).toList();
|
|
|
List<String> relationListEnList = relationList.stream().map(chineseEnglishWordsMap::get).toList();
|
|
|
List<String> targetTypeListEnList = targetTypeList.stream().map(chineseEnglishWordsMap::get).toList();
|
|
|
|
|
|
//将三个集合分别转换为英文逗号分隔的字符串
|
|
|
String sourceTypeListEn = String.join(",", sourceTypeListEnList);
|
|
|
String relationListEn = String.join(",", relationListEnList);
|
|
|
String targetTypeListEn = String.join(",", targetTypeListEnList);
|
|
|
log.info("sourceTypeListEn: {}, relationListEn: {}, targetTypeListEn: {}", sourceTypeListEn, relationListEn, targetTypeListEn);
|
|
|
|
|
|
//LLM生成CYPHER
|
|
|
SystemPromptTemplate textToCypherTemplate = new SystemPromptTemplate(PromptCache.promptMap.get(TEXT_TO_CYPHER));
|
|
|
Message textToCypherMessage = textToCypherTemplate.createMessage(Map.of(PROMPT_PARAM_SOURCE_TYPE_LIST, sourceTypeListEn, PROMPT_PARAM_RELATION_TYPE_LIST, relationListEn, PROMPT_PARAM_TARGET_TYPE_LIST, targetTypeListEn, PROMPT_PARAM_QUERY, userQuery));
|
|
|
String cypherJsonStr = ollamaChatModel.call(textToCypherMessage.getText());
|
|
|
log.info(cypherJsonStr);
|
|
|
List<String> cypherQueries;
|
|
|
try {
|
|
|
JSONObject jsonObj = JSONUtil.parseObj(cypherJsonStr);
|
|
|
cypherQueries = jsonObj.getJSONArray(CYPHER_QUERIES)
|
|
|
.toList(String.class);
|
|
|
} catch (Exception e) {
|
|
|
log.error("解析CYPHER JSON字符串失败: {}", e.getMessage());
|
|
|
return Flux.just("查无结果");
|
|
|
}
|
|
|
log.info("转换后的Cypher语句:{}", cypherQueries.toString());
|
|
|
|
|
|
//执行CYPHER查询并汇总结果
|
|
|
List<RelationObject> relationObjects = new ArrayList<>();
|
|
|
if (!cypherQueries.isEmpty()) {
|
|
|
for (String cypher : cypherQueries) {
|
|
|
relationObjects.addAll(neo4jRepository.execute(cypher, null));
|
|
|
}
|
|
|
}
|
|
|
if (relationObjects.isEmpty()) {
|
|
|
return Flux.just("查无结果");
|
|
|
}
|
|
|
log.info("三元组数据: {}", relationObjects);
|
|
|
|
|
|
//生成回答
|
|
|
SystemPromptTemplate generateAnswerTemplate = new SystemPromptTemplate(PromptCache.promptMap.get(GENERATE_ANSWER));
|
|
|
Message generateAnswerMessage = generateAnswerTemplate.createMessage(Map.of(PROMPT_PARAM_EXAMPLE_TEXT, relationObjects, PROMPT_PARAM_QUERY, userQuery));
|
|
|
return ollamaChatModel.stream(new Prompt(generateAnswerMessage)).map(response -> response.getResult().getOutput().getText());
|
|
|
}
|
|
|
}
|