代码功能bug修复

优化问答功能
v_0.0.2
xueqingkun 4 months ago
parent ee4b4adb37
commit eaf043aa07

@ -86,6 +86,8 @@ public class PromptCache {
*/ */
public static final String CLASSIFY_QUERY_INTENT = "CLASSIFY_QUERY_INTENT"; public static final String CLASSIFY_QUERY_INTENT = "CLASSIFY_QUERY_INTENT";
public static final String TEXT_TO_CYPHER_2 = "TEXT_TO_CYPHER_2";
public static final Map<String, String> promptMap = new HashMap<>(); public static final Map<String, String> promptMap = new HashMap<>();
static { static {
@ -107,6 +109,7 @@ public class PromptCache {
promptMap.put(EXTRACT_INTENT_METADATA, EXTRACT_INTENT_METADATA_PROMPT); promptMap.put(EXTRACT_INTENT_METADATA, EXTRACT_INTENT_METADATA_PROMPT);
promptMap.put(EXTRACT_ERE_BASE_INTENT, EXTRACT_ERE_BASE_INTENT_PROMPT); promptMap.put(EXTRACT_ERE_BASE_INTENT, EXTRACT_ERE_BASE_INTENT_PROMPT);
promptMap.put(CLASSIFY_QUERY_INTENT, CLASSIFY_QUERY_INTENT_PROMPT); promptMap.put(CLASSIFY_QUERY_INTENT, CLASSIFY_QUERY_INTENT_PROMPT);
promptMap.put(TEXT_TO_CYPHER_2, TEXT_TO_CYPHER_2_PROMPT);
} }
@ -657,13 +660,9 @@ public class PromptCache {
3. 3.
```json
{"IntentTypeList": ["...", "..."]} {"IntentTypeList": ["...", "..."]}
```
- -
```json x {}
{}
```
3.使...... 3.使......
./no_think ./no_think
@ -685,8 +684,8 @@ public class PromptCache {
1. 1.
2. 2.
3. / 3. /
- type - type
- attributes - attributes
4. JSON使```json ```Markdown 4. JSON使```json ```Markdown
5. 使 5. 使
@ -694,15 +693,15 @@ public class PromptCache {
{ {
"source": { "source": {
"type": "实体类型1", "type": "实体类型1",
"attributes": ["属性1", "属性2"] "attributes": ["属性类型1", "属性类型2",....]
}, },
"relation": { "relation": {
"type": "关系类型", "type": "关系类型",
"attributes": [] "attributes": ["属性类型3"...]
}, },
"target": { "target": {
"type": "实体类型2", "type": "实体类型2",
"attributes": ["属性3"] "attributes": ["属性类型4"...]
}, },
"intent": "匹配的意图标签" "intent": "匹配的意图标签"
}, },
@ -814,4 +813,34 @@ public class PromptCache {
# #
{query} {query}
"""; """;
private static final String TEXT_TO_CYPHER_2_PROMPT = """
"You are a Cyphergenerating assistant. "
"Your sole reference for generating Cypher scripts is the `neo4j_schema` variable.\\n\\n"
"User question:\\n{question}\\n\\n"
"The schema is defined below in JSON format:\\n"
"{schema}\\n\\n"
"Follow these exact steps for every user query:\\n\\n"
"1. Extract Entities from User Query:\\n"
"- Parse the question for domain concepts and use synonyms or contextual cues to map them to schema elements.\\n"
"- Identify candidate **node types**.\\n"
"- Identify candidate **relationship types**.\\n"
"- Identify relevant **properties**.\\n"
"- Identify **constraints or conditions** (comparisons, flags, temporal filters, sharedentity references, etc.).\\n\\n"
"2. Validate Against the Schema:\\n"
"- Ensure every node label, relationship type, and property exists in the schema **exactly** (case and charactersensitive).\\n"
"- If any required element is missing, respond exactly:\\n"
' \\"I could not generate a Cypher script; the required information is not part of the Neo4j schema.\\"\\n\\n'
"3. Construct the MATCH Pattern:\\n"
"- Use only schemavalidated node labels and relationship types.\\n"
"- Reuse a single variable whenever the query implies that two patterns refer to the same node.\\n"
"- Express simple equality predicates in map patterns and move all other filters to a **WHERE** clause.\\n\\n"
"4. Return Clause Strategy:\\n"
"- RETURN every node and relationship mentioned, unless the user explicitly requests specific properties.\\n\\n"
"- The truncationId、name attribute of a node is very important, and each node needs to return truncationId、name .\\n\\n"
"5. Final Cypher Script Generation:\\n"
"- Respond with **only** the final Cypher query—no commentary or extra text.\\n"
"- Use OPTIONAL MATCH only if explicitly required by the user and supported by the schema.\\n"
""";
} }

@ -1,12 +1,16 @@
package com.supervision.pdfqaserver.dao; package com.supervision.pdfqaserver.dao;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.StrUtil;
import com.supervision.pdfqaserver.dto.EntityExtractionDTO;
import com.supervision.pdfqaserver.dto.RelationExtractionDTO;
import com.supervision.pdfqaserver.dto.TruncationERAttributeDTO;
import com.supervision.pdfqaserver.dto.neo4j.NodeData; import com.supervision.pdfqaserver.dto.neo4j.NodeData;
import com.supervision.pdfqaserver.dto.neo4j.RelationObject; import com.supervision.pdfqaserver.dto.neo4j.RelationObject;
import com.supervision.pdfqaserver.dto.neo4j.RelationshipData; import com.supervision.pdfqaserver.dto.neo4j.RelationshipData;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import org.neo4j.driver.Driver; import org.neo4j.driver.*;
import org.neo4j.driver.Result; import org.neo4j.driver.Record;
import org.neo4j.driver.Session;
import org.neo4j.driver.types.Node; import org.neo4j.driver.types.Node;
import org.neo4j.driver.types.Relationship; import org.neo4j.driver.types.Relationship;
import org.springframework.stereotype.Repository; import org.springframework.stereotype.Repository;
@ -59,6 +63,20 @@ public class Neo4jRepository {
} }
/**
* Cypher
* @param cypher Cypher
* @param params
* @return List<Record>
*/
public List<Record> executeCypherNative(String cypher, Map<String, Object> params) {
try (Session session = driver.session()) {
Result run = session.run(cypher, params == null ? Collections.emptyMap() : params);
return run.list();
}
}
/** /**
* *
* uniqueKey * uniqueKey
@ -115,6 +133,113 @@ public class Neo4jRepository {
} }
} }
/**
* schema
* @return
*/
public List<EntityExtractionDTO> getNodeSchema(){
String query = """
CALL db.schema.nodeTypeProperties()
YIELD nodeType, propertyName, propertyTypes
RETURN nodeType, propertyName, propertyTypes
""";
try (Session session = driver.session()) {
List<EntityExtractionDTO> extractionDTOS = new ArrayList<>();
Result result = session.run(query);
for (Record record : result.list()) {
String nodeType = record.get("nodeType").asString();
if (StrUtil.isEmpty(nodeType)){
continue;
}
String propertyName = record.get("propertyName").asString();
List<String> propertyTypes = record.get("propertyTypes").asList(Value::asString);
// 创建属性DTO
TruncationERAttributeDTO attributeDTO = new TruncationERAttributeDTO(propertyName, null, CollUtil.getFirst(propertyTypes));
// 检查是否已存在该节点类型
EntityExtractionDTO existingEntity = extractionDTOS.stream()
.filter(e -> StrUtil.equals(e.getEntityEn(), nodeType))
.findFirst().orElse(null);
if (existingEntity != null) {
// 如果已存在,添加属性
existingEntity.getAttributes().add(attributeDTO);
} else {
// 如果不存在创建新的实体DTO
List<TruncationERAttributeDTO> truncationERAttributeDTOS = new ArrayList<>();
truncationERAttributeDTOS.add(attributeDTO);
EntityExtractionDTO entityExtractionDTO = new EntityExtractionDTO(null,nodeType, null,truncationERAttributeDTOS);
extractionDTOS.add(entityExtractionDTO);
}
}
return extractionDTOS;
}
}
/**
* schema
* @return
*/
public List<RelationExtractionDTO> getRelationSchema(){
String queryProper = """
CALL db.schema.relTypeProperties()
YIELD relType, propertyName, propertyTypes
RETURN relType, propertyName, propertyTypes
""";
Map<String, List<Map<String, String>>> relationProperties = new HashMap<>();
try (Session session = driver.session()) {
Result result = session.run(queryProper);
for (Record record : result.list()) {
String relType = record.get("relType").asString();
if (StrUtil.isEmpty(relType)){
continue;
}
String propertyName = record.get("propertyName").asString();
List<String> propertyTypes = record.get("propertyTypes").asList(Value::asString);
List<Map<String, String>> properties = relationProperties.computeIfAbsent(relType, k -> new ArrayList<>());
boolean noneMatch = properties.stream().noneMatch(
prop -> StrUtil.equals(prop.get("propertyName"), propertyName)
);
if (noneMatch){
Map<String, String> propMap = new HashMap<>();
propMap.put("propertyName", propertyName);
propMap.put("propertyTypes", CollUtil.getFirst(propertyTypes));
properties.add(propMap);
}
}
List<RelationExtractionDTO> relationExtractionDTOS = new ArrayList<>();
String queryEndpoints = """
MATCH (s)-[r:`{rtype}`]->(t)
WITH labels(s)[0] AS src, labels(t)[0] AS tgt
RETURN src, tgt
""";
for (Map.Entry<String, List<Map<String, String>>> entry : relationProperties.entrySet()) {
String relType = entry.getKey();
List<Map<String, String>> properties = entry.getValue();
Result run = session.run(queryEndpoints, parameters("rtype", relType));
for (Record record : run.list()) {
String sourceType = record.get("src").asString();
String targetType = record.get("tgt").asString();
List<TruncationERAttributeDTO> attributeDTOS = properties.stream().map(
prop -> new TruncationERAttributeDTO(prop.get("propertyName"), null, prop.get("propertyTypes"))
).collect(Collectors.toList());
RelationExtractionDTO relationExtractionDTO = new RelationExtractionDTO(null,null,sourceType,
relType,
null,
targetType,
attributeDTOS);
relationExtractionDTOS.add(relationExtractionDTO);
}
}
return relationExtractionDTOS;
}
}
private NodeData mapNode(Node node) { private NodeData mapNode(Node node) {
return new NodeData( return new NodeData(

@ -0,0 +1,117 @@
package com.supervision.pdfqaserver.dto;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONArray;
import cn.hutool.json.JSONObject;
import java.util.ArrayList;
import java.util.List;
/**
* CypherSchemaDTO
*/
public class CypherSchemaDTO {
private List<EntityExtractionDTO> nodes = new ArrayList<>();
private List<RelationExtractionDTO> relations = new ArrayList<>();
public CypherSchemaDTO(List<EntityExtractionDTO> nodes, List<RelationExtractionDTO> relations) {
this.nodes = nodes;
this.relations = relations;
}
/**
* DTO
* @param sourceType
* @param relation
* @param targetType
* @return
*/
public RelationExtractionDTO getRelation(String sourceType, String relation,String targetType) {
for (RelationExtractionDTO relationDTO : relations) {
if (StrUtil.equals(relationDTO.getSourceType(), sourceType) &&
StrUtil.equals(relationDTO.getRelation(), relation) &&
StrUtil.equals(relationDTO.getTargetType(), targetType)) {
return relationDTO;
}
}
return null;
}
/**
* DTO
* @param sourceOrTargetType
* @return
*/
public List<RelationExtractionDTO> getRelations(String sourceOrTargetType) {
List<RelationExtractionDTO> result = new ArrayList<>();
for (RelationExtractionDTO relationDTO : relations) {
if (StrUtil.equals(relationDTO.getSourceType(), sourceOrTargetType) ||
StrUtil.equals(relationDTO.getTargetType(), sourceOrTargetType)) {
result.add(relationDTO);
}
}
return result;
}
/**
* DTO
* @param entity
* @return
*/
public EntityExtractionDTO getNode(String entity) {
for (EntityExtractionDTO node : nodes) {
if (StrUtil.equals(node.getEntity(), entity)) {
return node;
}
}
return null;
}
public String format(){
JSONObject nodeJson = new JSONObject();
for (EntityExtractionDTO node : nodes) {
String entity = node.getEntity();
List<TruncationERAttributeDTO> attributes = node.getAttributes();
JSONObject nodeAttr = nodeJson.getJSONObject(entity);
if (nodeAttr == null) {
nodeAttr = new JSONObject();
nodeJson.set(entity, nodeAttr);
}
for (TruncationERAttributeDTO attribute : attributes) {
boolean none = nodeAttr.entrySet().stream().noneMatch(
entry -> StrUtil.equals(entry.getKey(), attribute.getAttribute()));
if (none){
nodeAttr.set(attribute.getAttribute(), attribute.getDataType());
}
}
}
JSONObject relJson = new JSONObject();
for (RelationExtractionDTO relation : relations) {
String sourceType = relation.getSourceType();
String targetType = relation.getTargetType();
String rela = relation.getRelation();
JSONObject json = relJson.getJSONObject(rela);
if (null == json) {
json = new JSONObject();
relJson.set(rela, json);
}
json.set("_endpoints", new JSONArray(new String[]{sourceType, targetType}));
for (TruncationERAttributeDTO attribute : relation.getAttributes()) {
boolean none = json.entrySet().stream().noneMatch(
entry -> StrUtil.equals(entry.getKey(), attribute.getAttribute())
);
if (none) {
json.set(attribute.getAttribute(), attribute.getDataType());
}
}
}
JSONObject object = new JSONObject()
.set("nodetypes", nodeJson)
.set("relationshiptypes", relJson);
return object.toString();
}
}

@ -0,0 +1,30 @@
package com.supervision.pdfqaserver.dto.neo4j;
import lombok.Data;
import org.neo4j.driver.internal.InternalNode;
import java.util.Collection;
import java.util.Map;
@Data
public class NodeDTO {
private long id;
private String elementId;
private Map<String, Object> properties;
private Collection<String> labels;
public NodeDTO() {
}
public NodeDTO(InternalNode internalNode) {
this.id = internalNode.id();
this.elementId = internalNode.elementId();
this.properties = internalNode.asMap();
this.labels = internalNode.labels();
}
}

@ -0,0 +1,41 @@
package com.supervision.pdfqaserver.dto.neo4j;
import lombok.Data;
import org.neo4j.driver.internal.InternalNode;
import org.neo4j.driver.internal.InternalRelationship;
import org.neo4j.driver.types.Node;
import org.neo4j.driver.types.Path;
import org.neo4j.driver.types.Relationship;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
@Data
public class PathDTO {
private List<NodeDTO> nodes;
private List<RelationshipValueDTO> relationships;
public PathDTO() {
}
public PathDTO(Path path) {
Iterator<Node> nodeIterator = path.nodes().iterator();
List<NodeDTO> nodes = new ArrayList<>();
while (nodeIterator.hasNext()){
Node next = nodeIterator.next();
nodes.add(new NodeDTO((InternalNode) next));
}
this.nodes = nodes;
Iterator<Relationship> iterator = path.relationships().iterator();
List<RelationshipValueDTO> relationships = new ArrayList<>();
while (iterator.hasNext()){
relationships.add(new RelationshipValueDTO((InternalRelationship) iterator.next()));
}
this.relationships = relationships;
}
}

@ -0,0 +1,43 @@
package com.supervision.pdfqaserver.dto.neo4j;
import lombok.Data;
import org.neo4j.driver.internal.InternalRelationship;
import java.util.Map;
@Data
public class RelationshipValueDTO {
private long start;
private String startElementId;
private long end;
private String endElementId;
private String type;
private long id;
private String elementId;
private Map<String,Object> properties;
public RelationshipValueDTO() {
}
public RelationshipValueDTO(InternalRelationship relationship) {
this.start = (int) relationship.startNodeId();
this.startElementId = relationship.startNodeElementId();
this.end = relationship.endNodeId();
this.endElementId = relationship.endNodeElementId();
this.type = relationship.type();
this.id = relationship.id();
this.elementId = relationship.elementId();
this.properties = relationship.asMap();
}
}

@ -1,5 +1,9 @@
package com.supervision.pdfqaserver.service; package com.supervision.pdfqaserver.service;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import reactor.core.publisher.Flux;
/** /**
* @description: AI * @description: AI
*/ */
@ -7,4 +11,6 @@ public interface AiCallService {
String call(String prompt); String call(String prompt);
Flux<ChatResponse> stream(Prompt prompt);
} }

@ -1,6 +1,10 @@
package com.supervision.pdfqaserver.service; package com.supervision.pdfqaserver.service;
import com.supervision.pdfqaserver.dto.CypherSchemaDTO;
import com.supervision.pdfqaserver.dto.DomainMetadataDTO;
import com.supervision.pdfqaserver.dto.EREDTO; import com.supervision.pdfqaserver.dto.EREDTO;
import java.util.List;
import java.util.Map;
/** /**
* Cypher * Cypher
@ -15,19 +19,37 @@ public interface TripleToCypherExecutor {
String generateInsertCypher(EREDTO eredto); String generateInsertCypher(EREDTO eredto);
/** /**
* Cypher * Cypher
* @param query * @param query
* @param domainCategoryId ID
* @return * @return
*/ */
String generateQueryCypher(String query); String generateQueryCypher(String query,String domainCategoryId);
/** /**
* Cypher * Cypher
* @param cypher * @param cypher
* @return * @return
*/ */
void executeCypher(String cypher); List<Map<String, Object>> executeCypher(String cypher);
void saveERE(EREDTO eredto); void saveERE(EREDTO eredto);
/**
* schema
*/
CypherSchemaDTO loadGraphSchema();
/**
* schema
* @param metadataDTOS
* @return
*/
CypherSchemaDTO queryRelationSchema(List<DomainMetadataDTO> metadataDTOS);
} }

@ -39,120 +39,43 @@ public class ChatServiceImpl implements ChatService {
private static final String PROMPT_PARAM_QUERY = "query"; private static final String PROMPT_PARAM_QUERY = "query";
private static final String CYPHER_QUERIES = "cypherQueries"; private static final String CYPHER_QUERIES = "cypherQueries";
private final Neo4jRepository neo4jRepository;
private final OllamaChatModel ollamaChatModel; private final OllamaChatModel ollamaChatModel;
private final DomainMetadataService domainMetadataService;
private final AiCallService aiCallService; private final AiCallService aiCallService;
private final DocumentTruncationService documentTruncationService; private final DocumentTruncationService documentTruncationService;
private final TripleToCypherExecutor tripleToCypherExecutor;
private final IntentionService intentionService;
@Override @Override
public Flux<String> knowledgeQA(String userQuery) { public Flux<String> knowledgeQA(String userQuery) {
List<Intention> intentions = intentionService.listAllPassed(); log.info("用户查询: {}", userQuery);
List<Intention> relations = classifyIntents(userQuery, intentions); // 生成cypher语句
if (CollUtil.isEmpty(relations)){ String cypher = tripleToCypherExecutor.generateQueryCypher(userQuery,null);
log.info("没有匹配到意图,返回查无结果"); log.info("生成CYPHER语句的消息{}", cypher);
if (StrUtil.isEmpty(cypher)){
return Flux.just("查无结果").concatWith(Flux.just("[END]")); return Flux.just("查无结果").concatWith(Flux.just("[END]"));
} }
List<DomainMetadataDTO> domainMetadataDTOS = domainMetadataService.listByIntentionIds(relations.stream().map(Intention::getId).toList()); // 执行cypher语句
if (CollUtil.isEmpty(domainMetadataDTOS)){ List<Map<String, Object>> graphResult = tripleToCypherExecutor.executeCypher(cypher);
log.info("没有匹配到领域元数据,返回查无结果"); if (CollUtil.isEmpty(graphResult)){
return Flux.just("查无结果").concatWith(Flux.just("[END]")); return Flux.just("查无结果").concatWith(Flux.just("[END]"));
} }
//将三个集合分别转换为英文逗号分隔的字符串
String sourceTypeListEn = domainMetadataDTOS.stream().map(DomainMetadataDTO::getSourceType).distinct().collect(Collectors.joining(","));
String relationListEn = domainMetadataDTOS.stream().map(DomainMetadataDTO::getRelation).distinct().collect(Collectors.joining(","));
String targetTypeListEn = domainMetadataDTOS.stream().map(DomainMetadataDTO::getTargetType).distinct().collect(Collectors.joining(","));
//LLM生成CYPHER
SystemPromptTemplate textToCypherTemplate = new SystemPromptTemplate(PromptCache.promptMap.get(TEXT_TO_CYPHER));
Message textToCypherMessage = textToCypherTemplate.createMessage(Map.of(PROMPT_PARAM_SOURCE_TYPE_LIST, sourceTypeListEn,
PROMPT_PARAM_RELATION_TYPE_LIST, relationListEn, PROMPT_PARAM_TARGET_TYPE_LIST,
targetTypeListEn, PROMPT_PARAM_QUERY, userQuery));
log.info("生成CYPHER语句的消息{}", textToCypherMessage);
String cypherJsonStr = ollamaChatModel.call(new Prompt(textToCypherMessage, OllamaOptions.builder().temperature(0.3).build())).getResult().getOutput().getText();
log.info(cypherJsonStr);
log.info(cypherJsonStr.replaceAll("(?is)<think\\b[^>]*>(.*?)</think>", "").trim());
cypherJsonStr = cypherJsonStr.replaceAll("(?is)<think\\b[^>]*>(.*?)</think>", "").trim();
List<String> cypherQueries;
try {
JSONObject jsonObj = JSONUtil.parseObj(cypherJsonStr);
cypherQueries = jsonObj.getJSONArray(CYPHER_QUERIES)
.toList(String.class);
} catch (Exception e) {
log.error("解析CYPHER JSON字符串失败: {}", e.getMessage());
return Flux.just("查无结果").concatWith(Flux.just("[END]"));
}
log.info("转换后的Cypher语句{}", cypherQueries.toString());
//执行CYPHER查询并汇总结果
List<RelationObject> relationObjects = new ArrayList<>();
if (!cypherQueries.isEmpty()) {
for (String cypher : cypherQueries) {
relationObjects.addAll(neo4jRepository.execute(cypher, null));
}
}
if (relationObjects.isEmpty()) {
log.info("cypher没有查询到结果返回查无结果");
return Flux.just("查无结果").concatWith(Flux.just("[END]"));
}
log.info("三元组数据: {}", relationObjects);
//生成回答 //生成回答
SystemPromptTemplate generateAnswerTemplate = new SystemPromptTemplate(PromptCache.promptMap.get(GENERATE_ANSWER)); SystemPromptTemplate generateAnswerTemplate = new SystemPromptTemplate(PromptCache.promptMap.get(GENERATE_ANSWER));
Message generateAnswerMessage = generateAnswerTemplate.createMessage(Map.of(PROMPT_PARAM_EXAMPLE_TEXT, relationObjects, PROMPT_PARAM_QUERY, userQuery)); Message generateAnswerMessage = generateAnswerTemplate.createMessage(Map.of(PROMPT_PARAM_EXAMPLE_TEXT, JSONUtil.toJsonStr(graphResult), PROMPT_PARAM_QUERY, userQuery));
log.info("生成回答的提示词:{}", generateAnswerMessage); log.info("生成回答的提示词:{}", generateAnswerMessage);
return ollamaChatModel.stream(new Prompt(generateAnswerMessage)).map(response -> response.getResult().getOutput().getText()) return aiCallService.stream(new Prompt(generateAnswerMessage))
.concatWith(Flux.just(new JSONObject().set("answerDetails", convertToAnswerDetails(relationObjects)).toString())) .map(response -> response.getResult().getOutput().getText())
.concatWith(Flux.just(new JSONObject().set("answerDetails", convertToAnswerDetails(null)).toString()))
.concatWith(Flux.just("[END]")); .concatWith(Flux.just("[END]"));
} }
/**
*
*
* @param query
* @param intentions
* @return
*/
private List<Intention> classifyIntents(String query, List<Intention> intentions) {
if (StrUtil.isEmpty(query) || CollUtil.isEmpty(intentions)) {
return new ArrayList<>();
}
String prompt = promptMap.get(CLASSIFY_QUERY_INTENT);
List<Intention> result = new ArrayList<>();
log.info("开始分类意图query: {}, intentions size: {}", query, intentions.size());
List<List<Intention>> intentionSplit = CollUtil.split(intentions, 150);
for (List<Intention> intentionList : intentionSplit) {
log.info("分类意图query: {}, intentions size: {}", query, intentionList.size());
String intents = intentionList.stream().map(i -> " - " + i.getDigest() + "\n").collect(Collectors.joining());
Map<String, Object> params = Map.of("query", query,
"intents", intents);
String format = StrUtil.format(prompt, params);
String call = aiCallService.call(format);
if (StrUtil.isEmpty(call)) {
return new ArrayList<>();
}
List<String> digests = JSONUtil.parseArray(call).stream().map(Object::toString).toList();
if (CollUtil.isEmpty(digests)) {
continue;
}
List<Intention> collect = intentionList.stream().filter(i -> digests.contains(i.getDigest())).collect(Collectors.toList());
if (CollUtil.isNotEmpty(collect)) {
result.addAll(collect);
}
}
return result;
}
private List<AnswerDetailDTO> convertToAnswerDetails(List<RelationObject> relationObjects) { private List<AnswerDetailDTO> convertToAnswerDetails(List<RelationObject> relationObjects) {
if (CollUtil.isEmpty(relationObjects)) { if (CollUtil.isEmpty(relationObjects)) {

@ -209,6 +209,7 @@ public class KnowledgeGraphServiceImpl implements KnowledgeGraphService {
// 保存意图数据 // 保存意图数据
intentSize ++; intentSize ++;
index ++; index ++;
List<Intention> intentions = intentionService.batchSaveIfAbsent(intents, pdfInfo.getDomainCategoryId(), pdfId.toString()); List<Intention> intentions = intentionService.batchSaveIfAbsent(intents, pdfInfo.getDomainCategoryId(), pdfId.toString());
for (Intention intention : intentions) { for (Intention intention : intentions) {
List<DomainMetadataDTO> metadataDTOS = domainMetadataDTOS.stream() List<DomainMetadataDTO> metadataDTOS = domainMetadataDTOS.stream()

@ -5,46 +5,110 @@ import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil; import cn.hutool.json.JSONUtil;
import com.supervision.pdfqaserver.cache.PromptCache; import com.supervision.pdfqaserver.cache.PromptCache;
import com.supervision.pdfqaserver.dao.Neo4jRepository; import com.supervision.pdfqaserver.dao.Neo4jRepository;
import com.supervision.pdfqaserver.dto.TruncationERAttributeDTO; import com.supervision.pdfqaserver.domain.Intention;
import com.supervision.pdfqaserver.dto.EREDTO; import com.supervision.pdfqaserver.dto.*;
import com.supervision.pdfqaserver.dto.EntityExtractionDTO; import com.supervision.pdfqaserver.dto.neo4j.NodeDTO;
import com.supervision.pdfqaserver.dto.RelationExtractionDTO; import com.supervision.pdfqaserver.dto.neo4j.PathDTO;
import com.supervision.pdfqaserver.dto.neo4j.RelationshipValueDTO;
import com.supervision.pdfqaserver.service.AiCallService;
import com.supervision.pdfqaserver.service.DomainMetadataService;
import com.supervision.pdfqaserver.service.IntentionService;
import com.supervision.pdfqaserver.service.TripleToCypherExecutor; import com.supervision.pdfqaserver.service.TripleToCypherExecutor;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.ollama.OllamaChatModel; import org.neo4j.driver.Record;
import org.neo4j.driver.internal.InternalNode;
import org.neo4j.driver.internal.InternalRelationship;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import static com.supervision.pdfqaserver.cache.PromptCache.*;
import static com.supervision.pdfqaserver.cache.PromptCache.ERE_TO_INSERT_CYPHER;
@Slf4j @Slf4j
@Service @Service
@RequiredArgsConstructor @RequiredArgsConstructor
public class TripleToCypherExecutorImpl implements TripleToCypherExecutor { public class TripleToCypherExecutorImpl implements TripleToCypherExecutor {
private final OllamaChatModel ollamaChatModel;
private final Neo4jRepository neo4jRepository; private final Neo4jRepository neo4jRepository;
private final IntentionService intentionService;
private static volatile CypherSchemaDTO cypherSchemaDTO;
private final AiCallService aiCallService;
private final DomainMetadataService domainMetadataService;
@Override @Override
public String generateInsertCypher(EREDTO eredto) { public String generateInsertCypher(EREDTO eredto) {
String prompt = PromptCache.promptMap.get(ERE_TO_INSERT_CYPHER); String prompt = PromptCache.promptMap.get(ERE_TO_INSERT_CYPHER);
String call = ollamaChatModel.call(prompt); return aiCallService.call(prompt);
return call;
} }
@Override @Override
public String generateQueryCypher(String query) { public String generateQueryCypher(String query,String domainCategoryId) {
return null; List<Intention> intentions = intentionService.listAllPassed();
List<Intention> relations = classifyIntents(query, intentions);
if (CollUtil.isEmpty(relations)) {
log.info("没有找到匹配的意图query: {}", query);
return null;
}
List<DomainMetadataDTO> domainMetadataDTOS = domainMetadataService.listByIntentionIds(relations.stream().map(Intention::getId).toList());
CypherSchemaDTO schemaDTO = this.queryRelationSchema(domainMetadataDTOS);
String prompt = promptMap.get(TEXT_TO_CYPHER_2);
String format = StrUtil.format(prompt, Map.of("question", query, "schema", schemaDTO.format()));
return aiCallService.call(format);
} }
@Override @Override
public void executeCypher(String cypher) { public List<Map<String, Object>> executeCypher(String cypher) {
List<Record> records = neo4jRepository.executeCypherNative(cypher, null);
return mapRecords(records);
}
private List<Map<String, Object>> mapRecords(List<Record> records) {
List<Map<String, Object>> recordList = new ArrayList<>();
for (Record record : records) {
HashMap<String, Object> map = new HashMap<>();
for (String key : record.keys()) {
org.neo4j.driver.Value value = record.get(key);
String typeName = value.type().name();
if (typeName.equals("NULL")){
map.put(key,null);
}
if (StrUtil.equalsAny(typeName, "BOOLEAN","STRING", "NUMBER", "INTEGER", "FLOAT")){
// MATCH (n)-[r]-() where n.caseId= '1' RETURN n.recordId limit 10
map.put(key,value.asObject());
}
if (typeName.equals("PATH")){
// MATCH p=(n)-[*2]-() where n.caseId= '1' RETURN p limit 10
map.put(key,new PathDTO(value.asPath()));
}
if (typeName.equals("RELATIONSHIP")){
// MATCH (n)-[r]-() where n.caseId= '1' RETURN r limit 10
map.put(key,new RelationshipValueDTO((InternalRelationship) value.asRelationship()));
}
if (typeName.equals("LIST OF ANY?")){
List<RelationshipValueDTO> list = value.asList().stream()
.map(i -> new RelationshipValueDTO((InternalRelationship) i)).toList();
map.put(key,list);
}
if (typeName.equals("NODE")){
// MATCH (n)-[r]-() where n.caseId= '1' RETURN r limit 10
map.put(key,new NodeDTO((InternalNode) value.asNode()));
}
recordList.add(map);
}
}
return recordList;
} }
@Override @Override
@ -101,4 +165,101 @@ public class TripleToCypherExecutorImpl implements TripleToCypherExecutor {
} }
} }
} }
@Override
public CypherSchemaDTO loadGraphSchema() {
List<RelationExtractionDTO> relationSchema = neo4jRepository.getRelationSchema();
List<EntityExtractionDTO> entitySchema = neo4jRepository.getNodeSchema();
return new CypherSchemaDTO(entitySchema, relationSchema);
}
@Override
public CypherSchemaDTO queryRelationSchema(List<DomainMetadataDTO> metadataDTOS) {
if (CollUtil.isEmpty(metadataDTOS)){
return null;
}
if (cypherSchemaDTO == null) {
synchronized (TripleToCypherExecutorImpl.class) {
if (cypherSchemaDTO == null) {
cypherSchemaDTO = this.loadGraphSchema();
}
}
}
List<RelationExtractionDTO> merged = new ArrayList<>();
for (DomainMetadataDTO metadataDTO : metadataDTOS) {
String relation = metadataDTO.getRelation();
String sourceType = metadataDTO.getSourceType();
String targetType = metadataDTO.getTargetType();
if (StrUtil.isEmpty(relation) || StrUtil.isEmpty(sourceType) || StrUtil.isEmpty(targetType)){
log.warn("元数据中关系、源类型或目标类型为空无法查询关系schema: {}", metadataDTO);
continue;
}
RelationExtractionDTO rel = cypherSchemaDTO.getRelation(sourceType, relation, targetType);
if (null == rel){
continue;
}
List<RelationExtractionDTO> relSourceType = cypherSchemaDTO.getRelations(sourceType);
List<RelationExtractionDTO> relTargetType = cypherSchemaDTO.getRelations(targetType);
relSourceType.add(rel);
relSourceType.addAll(relTargetType);
for (RelationExtractionDTO relationExtractionDTO : relSourceType) {
boolean none = merged.stream().noneMatch(i -> StrUtil.equals(i.getRelation(), relationExtractionDTO.getRelation()) &&
StrUtil.equals(i.getSourceType(), relationExtractionDTO.getSourceType()) &&
StrUtil.equals(i.getTargetType(), relationExtractionDTO.getTargetType()));
if (none){
merged.add(relationExtractionDTO);
}
}
}
List<EntityExtractionDTO> entityExtractionDTOS = new ArrayList<>();
for (RelationExtractionDTO relationExtractionDTO : merged) {
EntityExtractionDTO node = cypherSchemaDTO.getNode(relationExtractionDTO.getSourceType());
if (null != node){
boolean none = entityExtractionDTOS.stream().noneMatch(
entity -> StrUtil.equals(entity.getEntity(), node.getEntity())
);
if (none) {
entityExtractionDTOS.add(node);
}
}
}
return new CypherSchemaDTO(
entityExtractionDTOS,
merged
);
}
private List<Intention> classifyIntents(String query, List<Intention> intentions) {
if (StrUtil.isEmpty(query) || CollUtil.isEmpty(intentions)) {
return new ArrayList<>();
}
String prompt = promptMap.get(CLASSIFY_QUERY_INTENT);
List<Intention> result = new ArrayList<>();
log.info("开始分类意图query: {}, intentions size: {}", query, intentions.size());
List<List<Intention>> intentionSplit = CollUtil.split(intentions, 200);
for (List<Intention> intentionList : intentionSplit) {
log.info("分类意图query: {}, intentions size: {}", query, intentionList.size());
String intents = intentionList.stream().map(i -> " - " + i.getDigest() + "\n").collect(Collectors.joining());
Map<String, Object> params = Map.of("query", query,
"intents", intents);
String format = StrUtil.format(prompt, params);
String call = aiCallService.call(format);
if (StrUtil.isEmpty(call)) {
return new ArrayList<>();
}
List<String> digests = JSONUtil.parseArray(call).stream().map(Object::toString).toList();
if (CollUtil.isEmpty(digests)) {
continue;
}
List<Intention> collect = intentionList.stream().filter(i -> digests.contains(i.getDigest())).collect(Collectors.toList());
if (CollUtil.isNotEmpty(collect)) {
result.addAll(collect);
}
}
return result;
}
} }

Loading…
Cancel
Save