You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

167 lines
8.3 KiB
Java

package com.supervision.pdfqaserver.service.impl;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.StrUtil;
4 months ago
import cn.hutool.json.JSONObject;
import cn.hutool.json.JSONUtil;
import com.supervision.pdfqaserver.cache.PromptCache;
import com.supervision.pdfqaserver.domain.DocumentTruncation;
import com.supervision.pdfqaserver.dto.AnswerDetailDTO;
import com.supervision.pdfqaserver.dto.neo4j.NodeDTO;
import com.supervision.pdfqaserver.dto.neo4j.RelationshipValueDTO;
import com.supervision.pdfqaserver.service.*;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Flux;
4 months ago
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import static com.supervision.pdfqaserver.cache.PromptCache.*;
@Slf4j
@Service
@RequiredArgsConstructor
public class ChatServiceImpl implements ChatService {
4 months ago
private static final String PROMPT_PARAM_SOURCE_TYPE_LIST = "sourceTypeList";
private static final String PROMPT_PARAM_RELATION_TYPE_LIST = "relationTypeList";
private static final String PROMPT_PARAM_TARGET_TYPE_LIST = "targetTypeList";
private static final String PROMPT_PARAM_EXAMPLE_TEXT = "example_text";
private static final String PROMPT_PARAM_QUERY = "query";
private static final String CYPHER_QUERIES = "cypherQueries";
private final AiCallService aiCallService;
private final DocumentTruncationService documentTruncationService;
private final TripleToCypherExecutor tripleToCypherExecutor;
@Override
public Flux<String> knowledgeQA(String userQuery) {
4 months ago
log.info("用户查询: {}", userQuery);
// 生成cypher语句
String cypher = tripleToCypherExecutor.generateQueryCypher(userQuery,null);
log.info("生成CYPHER语句的消息{}", cypher);
if (StrUtil.isEmpty(cypher)){
return Flux.just("查无结果").concatWith(Flux.just("[END]"));
}
// 执行cypher语句
List<Map<String, Object>> graphResult = tripleToCypherExecutor.executeCypher(cypher);
if (CollUtil.isEmpty(graphResult)){
return Flux.just("查无结果").concatWith(Flux.just("[END]"));
}
//生成回答
SystemPromptTemplate generateAnswerTemplate = new SystemPromptTemplate(PromptCache.promptMap.get(GENERATE_ANSWER));
Message generateAnswerMessage = generateAnswerTemplate.createMessage(Map.of(PROMPT_PARAM_EXAMPLE_TEXT, JSONUtil.toJsonStr(graphResult), PROMPT_PARAM_QUERY, userQuery));
4 months ago
log.info("生成回答的提示词:{}", generateAnswerMessage);
return aiCallService.stream(new Prompt(generateAnswerMessage))
.map(response -> response.getResult().getOutput().getText())
.concatWith(Flux.just(new JSONObject().set("answerDetails", convertToAnswerDetails(graphResult)).toString()))
.concatWith(Flux.just("[END]"));
}
private List<AnswerDetailDTO> convertToAnswerDetails(List<Map<String, Object>> graphResult) {
if (CollUtil.isEmpty(graphResult)){
return new ArrayList<>();
}
List<AnswerDetailDTO> answerDetailDTOS = new ArrayList<>();
for (Map<String, Object> map : graphResult) {
Long start = null;
Long end = null;
for (Map.Entry<String, Object> entry : map.entrySet()) {
// 先找到头节点和尾节点id
if (entry.getValue() instanceof RelationshipValueDTO value){
start = value.getStart();
end = value.getEnd();
break;
}
}
AnswerDetailDTO answerDetailDTO = new AnswerDetailDTO();
if (null == start) {
// 没有关系类型
for (Map.Entry<String, Object> entry : map.entrySet()) {
// 处理头节点
if(entry.getValue() instanceof NodeDTO nodeDTO){
Map<String, Object> properties = nodeDTO.getProperties();
if (StrUtil.isEmpty(answerDetailDTO.getSourceType())){
answerDetailDTO.setSourceName((String) properties.get("name"));
answerDetailDTO.setSourceType(CollUtil.getFirst(nodeDTO.getLabels())); // 假设第一个标签是源类型
// 设置truncationId属性
answerDetailDTO.setTruncateId((String) properties.get("truncationId"));
}else {
answerDetailDTO.setTargetName((String) properties.get("name"));
answerDetailDTO.setTargetType(CollUtil.getFirst(nodeDTO.getLabels())); // 假设第一个标签是目标类型
}
}
}
answerDetailDTOS.add(answerDetailDTO);
}else {
// 有关系节点
for (Map.Entry<String, Object> entry : map.entrySet()) {
// 处理头节点
if(entry.getValue() instanceof NodeDTO nodeDTO){
if (start.equals(nodeDTO.getId())){
Map<String, Object> properties = nodeDTO.getProperties();
answerDetailDTO.setSourceName((String) properties.get("name"));
answerDetailDTO.setSourceType(CollUtil.getFirst(nodeDTO.getLabels())); // 假设第一个标签是源类型
// 设置truncationId属性
answerDetailDTO.setTruncateId((String) properties.get("truncationId"));
}
if (end.equals(nodeDTO.getId())){
Map<String, Object> properties = nodeDTO.getProperties();
answerDetailDTO.setTargetName((String) properties.get("name"));
answerDetailDTO.setTargetType(CollUtil.getFirst(nodeDTO.getLabels())); // 假设第一个标签是目标类型
}
}
if (entry.getValue() instanceof RelationshipValueDTO value) {
// 处理关系
if (start.equals(value.getStart()) || end.equals(value.getEnd())) {
answerDetailDTO.setRelation(value.getType());
}
}
}
answerDetailDTOS.add(answerDetailDTO);
}
}
List<AnswerDetailDTO> distinct = new ArrayList<>();
if (CollUtil.isNotEmpty(answerDetailDTOS)){
//去重answerDetailDTOS
for (AnswerDetailDTO answerDetailDTO : answerDetailDTOS) {
boolean noned = distinct.stream().noneMatch(i ->
StrUtil.equals(i.getSourceName(), answerDetailDTO.getSourceName()) &&
StrUtil.equals(i.getTargetName(), answerDetailDTO.getTargetName()) &&
StrUtil.equals(i.getRelation(), answerDetailDTO.getRelation()) &&
StrUtil.equals(i.getSourceType(), answerDetailDTO.getSourceType()) &&
StrUtil.equals(i.getTargetType(), answerDetailDTO.getTargetType()) &&
StrUtil.equals(i.getTruncateId(), answerDetailDTO.getTruncateId())
);
if (noned){
distinct.add(answerDetailDTO);
}
}
List<String> truncateIds = distinct.stream().map(AnswerDetailDTO::getTruncateId).distinct().toList();
if (CollUtil.isEmpty(truncateIds)){
return answerDetailDTOS;
}
List<DocumentTruncation> documentTruncations = documentTruncationService.listByIds(truncateIds);
Map<String, String> contentMap = documentTruncations.stream().collect(Collectors.toMap(DocumentTruncation::getId, DocumentTruncation::getContent));
for (AnswerDetailDTO answerDetailDTO : distinct) {
answerDetailDTO.setTruncateContent(contentMap.get(answerDetailDTO.getTruncateId()));
}
}
return distinct;
}
}