From fe1a6f1b1b3cb4f3fef2647735afad956fb71649 Mon Sep 17 00:00:00 2001 From: xueqingkun Date: Wed, 18 Jun 2025 17:49:59 +0800 Subject: [PATCH] =?UTF-8?q?=E9=97=AE=E7=AD=94=E5=8A=9F=E8=83=BD=E4=BC=98?= =?UTF-8?q?=E5=8C=96-=E5=88=9D=E5=A7=8B=E5=8C=96=E8=A1=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pom.xml | 5 + .../pdfqaserver/cache/PromptCache.java | 101 ++++++++++++ .../pdfqaserver/dao/Neo4jRepository.java | 53 ++++-- .../domain/NodeRelationVector.java | 54 +++++++ .../pdfqaserver/dto/CypherSchemaDTO.java | 26 ++- .../supervision/pdfqaserver/dto/TextTerm.java | 43 +++++ .../pdfqaserver/dto/neo4j/NodeDTO.java | 7 +- .../dto/neo4j/RelationshipValueDTO.java | 18 ++- .../mapper/NodeRelationVectorMapper.java | 22 +++ .../pdfqaserver/mapper/TextVectorMapper.java | 5 +- .../pdfqaserver/service/AiCallService.java | 3 + .../pdfqaserver/service/DeepSeekApiImpl.java | 8 +- .../service/NodeRelationVectorService.java | 27 ++++ .../service/QuestionCategoryService.java | 7 + .../QuestionHandlerMappingService.java | 8 + .../pdfqaserver/service/Retriever.java | 17 ++ .../service/RetrieverDispatcher.java | 80 +++++++++ .../service/TextToSegmentService.java | 23 +++ .../service/TripleToCypherExecutor.java | 16 +- .../service/impl/ChatServiceImpl.java | 8 +- .../service/impl/DataCompareRetriever.java | 113 +++++++++++++ .../impl/NodeRelationVectorServiceImpl.java | 108 +++++++++++++ .../service/impl/OllamaCallServiceImpl.java | 6 + .../impl/QuestionCategoryServiceImpl.java | 4 + .../QuestionHandlerMappingServiceImpl.java | 32 ++++ .../impl/TextToSegmentServiceImpl.java | 49 ++++++ .../service/impl/TextVectorServiceImpl.java | 8 +- .../impl/TripleToCypherExecutorImpl.java | 153 ++++++++++++++++-- .../mapper/NodeRelationVectorMapper.xml | 57 +++++++ .../resources/mapper/TextVectorMapper.xml | 3 + .../PdfQaServerApplicationTests.java | 39 +++-- 31 files changed, 1050 insertions(+), 53 deletions(-) create mode 100644 
src/main/java/com/supervision/pdfqaserver/domain/NodeRelationVector.java create mode 100644 src/main/java/com/supervision/pdfqaserver/dto/TextTerm.java create mode 100644 src/main/java/com/supervision/pdfqaserver/mapper/NodeRelationVectorMapper.java create mode 100644 src/main/java/com/supervision/pdfqaserver/service/NodeRelationVectorService.java create mode 100644 src/main/java/com/supervision/pdfqaserver/service/Retriever.java create mode 100644 src/main/java/com/supervision/pdfqaserver/service/RetrieverDispatcher.java create mode 100644 src/main/java/com/supervision/pdfqaserver/service/TextToSegmentService.java create mode 100644 src/main/java/com/supervision/pdfqaserver/service/impl/DataCompareRetriever.java create mode 100644 src/main/java/com/supervision/pdfqaserver/service/impl/NodeRelationVectorServiceImpl.java create mode 100644 src/main/java/com/supervision/pdfqaserver/service/impl/TextToSegmentServiceImpl.java create mode 100644 src/main/resources/mapper/NodeRelationVectorMapper.xml diff --git a/pom.xml b/pom.xml index a1ec29c..b7e1deb 100644 --- a/pom.xml +++ b/pom.xml @@ -104,6 +104,11 @@ org.springframework.ai spring-ai-starter-model-openai + + com.hankcs + hanlp + portable-1.8.6 + diff --git a/src/main/java/com/supervision/pdfqaserver/cache/PromptCache.java b/src/main/java/com/supervision/pdfqaserver/cache/PromptCache.java index 559e59a..b3a98df 100644 --- a/src/main/java/com/supervision/pdfqaserver/cache/PromptCache.java +++ b/src/main/java/com/supervision/pdfqaserver/cache/PromptCache.java @@ -87,6 +87,11 @@ public class PromptCache { public static final String CLASSIFY_QUERY_INTENT = "CLASSIFY_QUERY_INTENT"; public static final String TEXT_TO_CYPHER_2 = "TEXT_TO_CYPHER_2"; + /** + * 将文本转换为Cypher查询语句(版本3) + */ + public static final String TEXT_TO_CYPHER_3 = "TEXT_TO_CYPHER_3"; + public static final String TEXT_TO_CYPHER_4 = "TEXT_TO_CYPHER_4"; public static final Map promptMap = new HashMap<>(); @@ -110,6 +115,8 @@ public class PromptCache { 
promptMap.put(EXTRACT_ERE_BASE_INTENT, EXTRACT_ERE_BASE_INTENT_PROMPT); promptMap.put(CLASSIFY_QUERY_INTENT, CLASSIFY_QUERY_INTENT_PROMPT); promptMap.put(TEXT_TO_CYPHER_2, TEXT_TO_CYPHER_2_PROMPT); + promptMap.put(TEXT_TO_CYPHER_3, TEXT_TO_CYPHER_3_PROMPT); + promptMap.put(TEXT_TO_CYPHER_4, TEXT_TO_CYPHER_4_PROMPT); } @@ -827,4 +834,98 @@ public class PromptCache { - **确保**MATCH 子句包含**关系变量** - 不要做出任何解释,不要对cypher进行任何包装,直接输出生成的cypher语句/no_think """; + + private static final String TEXT_TO_CYPHER_3_PROMPT = """ + 您是一个生成Cypher查询语句的助手。生成Cypher脚本时,唯一参考的是`neo4j_schema`和环境变量。 + 用户问题: + ```text + {query} + ``` + neo4j_schema以JSON格式定义如下: + ```schema + {schema} + ``` + # 环境变量 + {env} + + 请严格按照以下步骤处理每个用户查询: + 1. 从用户查询中提取实体: + - 解析问题中的领域概念,并通过同义词或上下文线索将其映射到schema中的节点或者关系元素上 + - 识别候选节点类型 + - 识别候选关系类型 + - 识别相关属性 + - 识别约束条件(比较操作、标志位、时间过滤器、共享实体引用等) + + 2. 验证模式匹配性: + - 确保每个节点标签、关系类型和属性在模式中完全存在(区分大小写和字符) + - 优先从**节点属性**中直接**获取数据** + + 3. 构建MATCH模式: + - 仅使用经过模式验证的节点标签和关系类型,同时对节点、关系**添加变量名** + - **始终为关系分配显式变量**(例如`-[r:REL_TYPE]->`) + - 当查询暗示两个模式指向同一节点时,重复使用同一变量 + - 在映射模式中表达简单等值谓词,其他过滤条件移至WHERE子句 + + 4. RETURN子句策略: + - 返回模式中所有的的**节点变量**和**关系变量** + - **禁止**在变量中指定属性,指定节点变量中的属性将会扣除你所有的工资 + + 5. 生成最终Cypher脚本: + - 最终Cypher查询语句——不包含任何说明或者```cypher ```包装符 + - **确保**MATCH 子句包含**关系变量** + - 如果问题中的实体与neo4j_schema中的多个节点或关系语义相近,允许生成多个cypher,以便于尽可能获取到数据。 + - 响应结果是一个数组,每一个数组元素是一条cypher语句。示例:['cypher1','...'] + - 不要做出任何解释,不要对cypher进行任何包装,直接输出生成的cypher语句/no_think + """; + + private static final String TEXT_TO_CYPHER_4_PROMPT = """ + 您是一个Cypher语句修改助手。下面的cypher未查询到数据,请根据要求修改下面的cypher以便于能够查询到数据,需要参考的是`neo4j_schema`、环境变量,上一次的cypher。 + + 用户问题: + ```text + {query} + ``` + neo4j_schema以JSON格式定义如下: + ```shema + {shema} + ``` + # 环境变量 + ${env} + + # 上一次查询的cypher语句 + ```json + {cypher} + ``` + + 请严格按照以下步骤处理每个用户查询: + 1. 从用户查询中提取实体: + - 解析问题中的领域概念,并通过同义词或上下文线索将其映射到schema中的节点或者关系元素上 + - 识别候选节点类型 + - 识别候选关系类型 + - 识别相关属性 + - 识别约束条件(比较操作、标志位、时间过滤器、共享实体引用等) + - 结合上一次查询的cypher语句分析未查询到数据可能的原因。 + + 2. 
验证模式匹配性: + - 确保每个节点标签、关系类型和属性在模式中完全存在(区分大小写和字符) + - 优先从**节点属性**中直接**获取数据** + + 3. 构建MATCH模式: + - 仅使用经过模式验证的节点标签和关系类型,同时对节点、关系**添加变量名** + - **始终为关系分配显式变量**(例如`-[r:REL_TYPE]->`) + - 当查询暗示两个模式指向同一节点时,重复使用同一变量 + - 在映射模式中表达简单等值谓词,其他过滤条件移至WHERE子句 + - 结合分析未查询到数据的原因修改cypher + + 4. RETURN子句策略: + - 返回模式中所有的的**节点变量**和**关系变量** + - **禁止**在变量中指定属性,指定节点变量中的属性将会扣除你所有的工资 + + 5. 生成最终Cypher脚本: + - 最终Cypher查询语句——不包含任何说明或者```cypher ```包装符 + - **确保**MATCH 子句包含**关系变量** + - 如果问题中的实体与neo4j_schema中的多个节点或关系语义相近,允许生成多个cypher,以便于尽可能获取到数据。 + - 响应结果是一个数组,每一个数组元素是一条cypher语句。示例:['cypher1','...'] + - 不要做出任何解释,不要对cypher进行任何包装,直接输出生成的cypher语句/no_think + """; } diff --git a/src/main/java/com/supervision/pdfqaserver/dao/Neo4jRepository.java b/src/main/java/com/supervision/pdfqaserver/dao/Neo4jRepository.java index 0141dfb..aa724b3 100644 --- a/src/main/java/com/supervision/pdfqaserver/dao/Neo4jRepository.java +++ b/src/main/java/com/supervision/pdfqaserver/dao/Neo4jRepository.java @@ -165,12 +165,19 @@ public class Neo4jRepository { // 检查是否已存在该节点类型 final String nodeType_f = nodeType; EntityExtractionDTO existingEntity = extractionDTOS.stream() - .filter(e -> StrUtil.equals(e.getEntityEn(), nodeType_f)) + .filter(e -> StrUtil.equals(e.getEntity(), nodeType_f)) .findFirst().orElse(null); if (existingEntity != null) { // 如果已存在,添加属性 - existingEntity.getAttributes().add(attributeDTO); + List attributes = existingEntity.getAttributes(); + boolean noneMatch = attributes.stream().noneMatch( + attr -> StrUtil.equals(attr.getAttribute(), attributeDTO.getAttribute()) + ); + if (noneMatch) { + // 如果属性不存在,添加属性 + attributes.add(attributeDTO); + } } else { // 如果不存在,创建新的实体DTO List truncationERAttributeDTOS = new ArrayList<>(); @@ -187,7 +194,7 @@ public class Neo4jRepository { * 获取关系的schema * @return */ - public List getRelationSchema(){ + public List getRelationSchema() { String queryProper = """ CALL db.schema.relTypeProperties() YIELD relType, propertyName, propertyTypes @@ -198,10 +205,10 @@ public class 
Neo4jRepository { Result result = session.run(queryProper); for (Record record : result.list()) { String relType = record.get("relType").asString(); - if (StrUtil.isEmpty(relType)){ + if (StrUtil.isEmpty(relType)) { continue; } - relType = relType.substring(1, relType.length()-1).replace("`", ""); + relType = relType.substring(1, relType.length() - 1).replace("`", ""); String propertyName = record.get("propertyName").asString(); List propertyTypes = record.get("propertyTypes").asList(Value::asString); @@ -209,7 +216,7 @@ public class Neo4jRepository { boolean noneMatch = properties.stream().noneMatch( prop -> StrUtil.equals(prop.get("propertyName"), propertyName) ); - if (noneMatch){ + if (noneMatch) { Map propMap = new HashMap<>(); propMap.put("propertyName", propertyName); propMap.put("propertyTypes", CollUtil.getFirst(propertyTypes)); @@ -219,14 +226,14 @@ public class Neo4jRepository { List relationExtractionDTOS = new ArrayList<>(); String queryEndpoints = """ - MATCH (s)-[r: `{rtype}` ]->(t) - WITH labels(s)[0] AS src, labels(t)[0] AS tgt - RETURN src, tgt - """; + MATCH (s)-[r: `{rtype}` ]->(t) + WITH labels(s)[0] AS src, labels(t)[0] AS tgt + RETURN src, tgt + """; for (Map.Entry>> entry : relationProperties.entrySet()) { String relType = entry.getKey(); List> properties = entry.getValue(); - String formatted = StrUtil.format(queryEndpoints,Map.of("rtype",relType)); + String formatted = StrUtil.format(queryEndpoints, Map.of("rtype", relType)); Result run = session.run(formatted); for (Record record : run.list()) { String sourceType = record.get("src").asString(); @@ -234,12 +241,32 @@ public class Neo4jRepository { List attributeDTOS = properties.stream().map( prop -> new TruncationERAttributeDTO(prop.get("propertyName"), null, prop.get("propertyTypes")) ).collect(Collectors.toList()); - RelationExtractionDTO relationExtractionDTO = new RelationExtractionDTO(null,null,sourceType, + RelationExtractionDTO relationExtractionDTO = new 
RelationExtractionDTO(null, null, sourceType, relType, null, targetType, attributeDTOS); - relationExtractionDTOS.add(relationExtractionDTO); + // 合并关系数据 + Optional optional = relationExtractionDTOS.stream().filter(rel -> + StrUtil.equals(rel.getSourceType(), sourceType) && + StrUtil.equals(rel.getRelation(), relType) && + StrUtil.equals(rel.getTargetType(), targetType)).findFirst(); + + if (optional.isPresent()) { + List attributes = optional.get().getAttributes(); + for (TruncationERAttributeDTO attribute : attributeDTOS) { + boolean noneMatch = attributes.stream().noneMatch( + attr -> StrUtil.equals(attr.getAttribute(), attribute.getAttribute()) + ); + if (noneMatch) { + attributes.add(attribute); + } + } + } else { + // 如果不存在,直接添加 + relationExtractionDTO.setAttributes(attributeDTOS); + relationExtractionDTOS.add(relationExtractionDTO); + } } } return relationExtractionDTOS; diff --git a/src/main/java/com/supervision/pdfqaserver/domain/NodeRelationVector.java b/src/main/java/com/supervision/pdfqaserver/domain/NodeRelationVector.java new file mode 100644 index 0000000..31e1fca --- /dev/null +++ b/src/main/java/com/supervision/pdfqaserver/domain/NodeRelationVector.java @@ -0,0 +1,54 @@ +package com.supervision.pdfqaserver.domain; + +import com.baomidou.mybatisplus.annotation.*; + +import java.io.Serializable; +import java.time.LocalDateTime; + +import com.supervision.pdfqaserver.config.VectorTypeHandler; +import lombok.Data; + +/** + * 节点关系向量表 + * @TableName node_relation_vector + */ +@TableName(value ="node_relation_vector") +@Data +public class NodeRelationVector implements Serializable { + /** + * 主键 + */ + @TableId + private String id; + + /** + * 文本内容 + */ + private String content; + + /** + * 向量值 + */ + @TableField(typeHandler = VectorTypeHandler.class) + private float[] embedding; + + /** + * 内容类型 N:节点 R:关系 ER:三元组 + */ + private String contentType; + + /** + * 创建时间 + */ + @TableField(fill = FieldFill.INSERT) + private LocalDateTime createTime; + + /** + * 
更新时间 + */ + @TableField(fill = FieldFill.INSERT_UPDATE) + private LocalDateTime updateTime; + + @TableField(exist = false) + private static final long serialVersionUID = 1L; +} \ No newline at end of file diff --git a/src/main/java/com/supervision/pdfqaserver/dto/CypherSchemaDTO.java b/src/main/java/com/supervision/pdfqaserver/dto/CypherSchemaDTO.java index d9a21ff..77c8c19 100644 --- a/src/main/java/com/supervision/pdfqaserver/dto/CypherSchemaDTO.java +++ b/src/main/java/com/supervision/pdfqaserver/dto/CypherSchemaDTO.java @@ -41,15 +41,24 @@ public class CypherSchemaDTO { /** * 根据源节点类型或目标节点类型获取关系抽取DTO列表 - * @param sourceOrTargetType + * @param str 源节点类型或目标节点类型或关系 * @return */ - public List getRelations(String sourceOrTargetType) { + public List getRelations(String str) { List result = new ArrayList<>(); for (RelationExtractionDTO relationDTO : relations) { - if (StrUtil.equals(relationDTO.getSourceType(), sourceOrTargetType) || - StrUtil.equals(relationDTO.getTargetType(), sourceOrTargetType)) { - result.add(relationDTO); + if (StrUtil.equals(relationDTO.getSourceType(), str) || + StrUtil.equals(relationDTO.getTargetType(), str) || + StrUtil.equals(relationDTO.getRelation(), str)) { + + boolean noneMatch = result.stream().noneMatch( + r -> StrUtil.equals(r.getSourceType(), relationDTO.getSourceType()) && + StrUtil.equals(r.getRelation(), relationDTO.getRelation()) && + StrUtil.equals(r.getTargetType(), relationDTO.getTargetType()) + ); + if (noneMatch){ + result.add(relationDTO); + } } } return result; @@ -90,6 +99,9 @@ public class CypherSchemaDTO { for (TruncationERAttributeDTO attribute : attributes) { boolean none = nodeAttr.entrySet().stream().noneMatch( entry -> StrUtil.equals(entry.getKey(), attribute.getAttribute())); + if ("truncationId".equals(attribute.getAttribute())){ + continue; + } if (none){ nodeAttr.set(attribute.getAttribute(), attribute.getDataType()); } @@ -108,10 +120,14 @@ public class CypherSchemaDTO { } json.set("_endpoints", new 
JSONArray(new String[]{sourceType, targetType})); for (TruncationERAttributeDTO attribute : relation.getAttributes()) { + if ("truncationId".equals(attribute.getAttribute())){ + continue; + } boolean none = json.entrySet().stream().noneMatch( entry -> StrUtil.equals(entry.getKey(), attribute.getAttribute()) ); if (none) { + json.set(attribute.getAttribute(), attribute.getDataType()); } } diff --git a/src/main/java/com/supervision/pdfqaserver/dto/TextTerm.java b/src/main/java/com/supervision/pdfqaserver/dto/TextTerm.java new file mode 100644 index 0000000..1edd9b8 --- /dev/null +++ b/src/main/java/com/supervision/pdfqaserver/dto/TextTerm.java @@ -0,0 +1,43 @@ +package com.supervision.pdfqaserver.dto; + +import cn.hutool.core.util.StrUtil; +import lombok.Data; + +@Data +public class TextTerm { + + /** + * 词 + */ + public String word; + + /** + * 标签 + */ + public String label; + + private float[] embedding; + + public String getLabelValue() { + if (StrUtil.equalsAny(label,"n","nl","nr","ns","nsf","nz")){ + return word; + } + if (StrUtil.equals(label,"nt")){ + return "机构"; + } + if (StrUtil.equalsAny(label,"ntc","公司")){ + return "公司"; + } + if (StrUtil.equals(label,"ntcf")){ + return "工厂"; + } + if (StrUtil.equals(label,"nto")){ + return "政府机构"; + } + if (StrUtil.equals(label,"企业")){ + return "企业"; + } + return null; + + } +} diff --git a/src/main/java/com/supervision/pdfqaserver/dto/neo4j/NodeDTO.java b/src/main/java/com/supervision/pdfqaserver/dto/neo4j/NodeDTO.java index 922c832..5f145aa 100644 --- a/src/main/java/com/supervision/pdfqaserver/dto/neo4j/NodeDTO.java +++ b/src/main/java/com/supervision/pdfqaserver/dto/neo4j/NodeDTO.java @@ -9,7 +9,7 @@ import java.util.Map; @Data public class NodeDTO { - private long id; + private Long id; private String elementId; @@ -27,4 +27,9 @@ public class NodeDTO { this.properties = internalNode.asMap(); this.labels = internalNode.labels(); } + + public void clearGraphElement(){ + this.id = null; + this.elementId = null; + } } 
diff --git a/src/main/java/com/supervision/pdfqaserver/dto/neo4j/RelationshipValueDTO.java b/src/main/java/com/supervision/pdfqaserver/dto/neo4j/RelationshipValueDTO.java index 5ab7b46..0398cfa 100644 --- a/src/main/java/com/supervision/pdfqaserver/dto/neo4j/RelationshipValueDTO.java +++ b/src/main/java/com/supervision/pdfqaserver/dto/neo4j/RelationshipValueDTO.java @@ -9,17 +9,17 @@ import java.util.Map; public class RelationshipValueDTO { - private long start; + private Long start; private String startElementId; - private long end; + private Long end; private String endElementId; private String type; - private long id; + private Long id; private String elementId; @@ -30,7 +30,7 @@ public class RelationshipValueDTO { } public RelationshipValueDTO(InternalRelationship relationship) { - this.start = (int) relationship.startNodeId(); + this.start = relationship.startNodeId(); this.startElementId = relationship.startNodeElementId(); this.end = relationship.endNodeId(); this.endElementId = relationship.endNodeElementId(); @@ -40,4 +40,14 @@ public class RelationshipValueDTO { this.properties = relationship.asMap(); } + + + public void clearGraphElement() { + this.id = null; + this.elementId = null; + this.start = null; + this.startElementId = null; + this.end = null; + this.endElementId = null; + } } diff --git a/src/main/java/com/supervision/pdfqaserver/mapper/NodeRelationVectorMapper.java b/src/main/java/com/supervision/pdfqaserver/mapper/NodeRelationVectorMapper.java new file mode 100644 index 0000000..bcc5900 --- /dev/null +++ b/src/main/java/com/supervision/pdfqaserver/mapper/NodeRelationVectorMapper.java @@ -0,0 +1,22 @@ +package com.supervision.pdfqaserver.mapper; + +import com.supervision.pdfqaserver.domain.NodeRelationVector; +import com.baomidou.mybatisplus.core.mapper.BaseMapper; +import java.util.List; + +/** +* @author Administrator +* @description 针对表【node_relation_vector(节点关系向量表)】的数据库操作Mapper +* @createDate 2025-06-18 13:38:02 +* @Entity 
com.supervision.pdfqaserver.domain.NodeRelationVector +*/ +public interface NodeRelationVectorMapper extends BaseMapper { + + List findSimilarByCosine(float[] embedding, double threshold, List contentType, int limit); + + Double matchContentScore(float[] embedding, String content); +} + + + + diff --git a/src/main/java/com/supervision/pdfqaserver/mapper/TextVectorMapper.java b/src/main/java/com/supervision/pdfqaserver/mapper/TextVectorMapper.java index 321ff3d..ff8eafe 100644 --- a/src/main/java/com/supervision/pdfqaserver/mapper/TextVectorMapper.java +++ b/src/main/java/com/supervision/pdfqaserver/mapper/TextVectorMapper.java @@ -15,7 +15,10 @@ import java.util.List; */ public interface TextVectorMapper extends BaseMapper { - List findSimilarByCosine(@Param("embedding")float[] embedding, @Param("threshold") double threshold, @Param("limit")int limit); + List findSimilarByCosine(@Param("embedding")float[] embedding, + @Param("threshold") double threshold, + @Param("categoryId") String categoryId, + @Param("limit")int limit); } diff --git a/src/main/java/com/supervision/pdfqaserver/service/AiCallService.java b/src/main/java/com/supervision/pdfqaserver/service/AiCallService.java index d1ee5e2..5d861d1 100644 --- a/src/main/java/com/supervision/pdfqaserver/service/AiCallService.java +++ b/src/main/java/com/supervision/pdfqaserver/service/AiCallService.java @@ -4,6 +4,7 @@ import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.embedding.Embedding; import reactor.core.publisher.Flux; +import java.util.List; /** * @description: AI调用服务 @@ -16,4 +17,6 @@ public interface AiCallService { Flux stream(Prompt prompt); Embedding embedding(String text); + + List embedding(List texts); } diff --git a/src/main/java/com/supervision/pdfqaserver/service/DeepSeekApiImpl.java b/src/main/java/com/supervision/pdfqaserver/service/DeepSeekApiImpl.java index 2fddf1d..1780a4b 100644 --- 
a/src/main/java/com/supervision/pdfqaserver/service/DeepSeekApiImpl.java +++ b/src/main/java/com/supervision/pdfqaserver/service/DeepSeekApiImpl.java @@ -7,7 +7,8 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.embedding.Embedding; import org.springframework.ai.openai.OpenAiChatModel; import reactor.core.publisher.Flux; -import org.springframework.stereotype.Service; +import java.util.List; + @Slf4j //@Service @RequiredArgsConstructor @@ -33,4 +34,9 @@ public class DeepSeekApiImpl implements AiCallService { return null; } + + @Override + public List embedding(List texts) { + return null; + } } diff --git a/src/main/java/com/supervision/pdfqaserver/service/NodeRelationVectorService.java b/src/main/java/com/supervision/pdfqaserver/service/NodeRelationVectorService.java new file mode 100644 index 0000000..f36bebe --- /dev/null +++ b/src/main/java/com/supervision/pdfqaserver/service/NodeRelationVectorService.java @@ -0,0 +1,27 @@ +package com.supervision.pdfqaserver.service; + +import com.supervision.pdfqaserver.domain.NodeRelationVector; +import com.baomidou.mybatisplus.extension.service.IService; +import com.supervision.pdfqaserver.dto.CypherSchemaDTO; + +import java.util.List; + +/** +* @author Administrator +* @description 针对表【node_relation_vector(节点关系向量表)】的数据库操作Service +* @createDate 2025-06-18 13:38:02 +*/ +public interface NodeRelationVectorService extends IService { + + void refreshSchemaSegmentVector(CypherSchemaDTO cypherSchemaDTO); + + List matchSimilarByCosine(float[] embedding, double threshold , List contentType, int limit); + + /** + * 计算内容匹配分数 + * @param embedding 向量 + * @param content 内容 + * @return + */ + Double matchContentScore(float[] embedding, String content); +} diff --git a/src/main/java/com/supervision/pdfqaserver/service/QuestionCategoryService.java b/src/main/java/com/supervision/pdfqaserver/service/QuestionCategoryService.java index fe3ad72..e8ecc1c 100644 --- 
a/src/main/java/com/supervision/pdfqaserver/service/QuestionCategoryService.java +++ b/src/main/java/com/supervision/pdfqaserver/service/QuestionCategoryService.java @@ -10,4 +10,11 @@ import com.baomidou.mybatisplus.extension.service.IService; */ public interface QuestionCategoryService extends IService { + + /** + * 根据分类ID查询分类信息 + * @param categoryId 分类ID + * @return 分类信息 + */ + QuestionCategory findCategoryById(String categoryId); } diff --git a/src/main/java/com/supervision/pdfqaserver/service/QuestionHandlerMappingService.java b/src/main/java/com/supervision/pdfqaserver/service/QuestionHandlerMappingService.java index dc192fa..0feb8fd 100644 --- a/src/main/java/com/supervision/pdfqaserver/service/QuestionHandlerMappingService.java +++ b/src/main/java/com/supervision/pdfqaserver/service/QuestionHandlerMappingService.java @@ -10,4 +10,12 @@ import com.baomidou.mybatisplus.extension.service.IService; */ public interface QuestionHandlerMappingService extends IService { + + /** + * 根据问题分类ID查询对应的处理器映射 + * @param categoryId 问题分类ID + * @return 处理器映射 + */ + QuestionHandlerMapping findHandlerByCategoryId(String categoryId); + } diff --git a/src/main/java/com/supervision/pdfqaserver/service/Retriever.java b/src/main/java/com/supervision/pdfqaserver/service/Retriever.java new file mode 100644 index 0000000..3e128e7 --- /dev/null +++ b/src/main/java/com/supervision/pdfqaserver/service/Retriever.java @@ -0,0 +1,17 @@ +package com.supervision.pdfqaserver.service; + +import java.util.List; +import java.util.Map; + +/** + * 检索器接口 + */ +public interface Retriever { + + /** + * 检索数据 + * @param query 问题 + * @return 结果数据 + */ + List> retrieval(String query); +} diff --git a/src/main/java/com/supervision/pdfqaserver/service/RetrieverDispatcher.java b/src/main/java/com/supervision/pdfqaserver/service/RetrieverDispatcher.java new file mode 100644 index 0000000..661e746 --- /dev/null +++ b/src/main/java/com/supervision/pdfqaserver/service/RetrieverDispatcher.java @@ -0,0 +1,80 @@ 
+package com.supervision.pdfqaserver.service; + +import cn.hutool.core.collection.CollUtil; +import cn.hutool.core.lang.Assert; +import cn.hutool.core.util.StrUtil; +import com.supervision.pdfqaserver.domain.QuestionHandlerMapping; +import com.supervision.pdfqaserver.dto.TextVectorDTO; +import jakarta.annotation.PostConstruct; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.ai.embedding.Embedding; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.context.ApplicationContext; +import org.springframework.stereotype.Service; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * 检索器调度器 + */ +@Slf4j +@Service +@RequiredArgsConstructor +public class RetrieverDispatcher { + + private final ApplicationContext applicationContext; + + private final AiCallService aiCallService; + + private final TextVectorService textVectorService; + + private final QuestionHandlerMappingService questionHandlerMappingService; + + @Value("${retriever.threshold:0.8}") + private double threshold; // 相似度阈值 + + private final Map retrieverMap = new HashMap<>(); + + + /** + * 根据类型获取对应的检索器 + * + * @param query 查询内容 + * @return 检索器实例 + */ + public Retriever mapping(String query) { + if (StrUtil.isEmpty(query)) { + log.warn("查询内容为空,无法获取检索器"); + return null; + } + Embedding embedding = aiCallService.embedding(query); + + List similarByCosine = textVectorService.findSimilarByCosine(embedding.getOutput(), threshold, 1); + if (CollUtil.isEmpty(similarByCosine)) { + log.info("问题:{},未找到相似文本向量,匹配阈值:{}", query, threshold); + return null; + } + TextVectorDTO textVectorDTO = CollUtil.getFirst(similarByCosine); + Assert.notEmpty(textVectorDTO.getCategoryId(), "相似文本向量的分类ID不能为空"); + QuestionHandlerMapping handler = questionHandlerMappingService.findHandlerByCategoryId(textVectorDTO.getCategoryId()); + if (handler == null){ + 
return null; + } + return retrieverMap.get(handler.getHandler()); + } + + @PostConstruct + public void init() { + applicationContext.getBeansOfType(Retriever.class) + .forEach((name, retriever) -> { + if (retrieverMap.containsKey(name)) { + throw new IllegalArgumentException("Retriever with name " + name + " already exists."); + } + retrieverMap.put(name, retriever); + }); + } +} diff --git a/src/main/java/com/supervision/pdfqaserver/service/TextToSegmentService.java b/src/main/java/com/supervision/pdfqaserver/service/TextToSegmentService.java new file mode 100644 index 0000000..3e00928 --- /dev/null +++ b/src/main/java/com/supervision/pdfqaserver/service/TextToSegmentService.java @@ -0,0 +1,23 @@ +package com.supervision.pdfqaserver.service; + +import com.supervision.pdfqaserver.dto.TextTerm; + +import java.util.List; + +public interface TextToSegmentService { + + /** + * 对文本进行分词 + * @param text 需要分词的文本 + * @return 分词结果列表 + */ + List segmentText(String text); + + /** + * 添加自定义词典 覆盖模式,如果词典中已存在该词,则更新其标签和频率 + * @param word 需要添加的词 + * @param label 词的标签 + * @param frequency 词的频率 数值越大,优先级越高 + */ + void addDict(String word, String label,int frequency); +} diff --git a/src/main/java/com/supervision/pdfqaserver/service/TripleToCypherExecutor.java b/src/main/java/com/supervision/pdfqaserver/service/TripleToCypherExecutor.java index efb1586..ac75589 100644 --- a/src/main/java/com/supervision/pdfqaserver/service/TripleToCypherExecutor.java +++ b/src/main/java/com/supervision/pdfqaserver/service/TripleToCypherExecutor.java @@ -35,6 +35,8 @@ public interface TripleToCypherExecutor { */ List> executeCypher(String cypher); + Map>> executeCypher(List cypher); + void saveERE(EREDTO eredto); @@ -43,13 +45,25 @@ public interface TripleToCypherExecutor { */ CypherSchemaDTO loadGraphSchema(); + /** + * 刷新图谱的schema分词向量 + */ + void refreshSchemaSegmentVector(); + /** * 根据领域元数据查询关联的关系图谱的schema - * @param metadataDTOS + * @param metadataDTOS 领域元数据列表 * @return */ CypherSchemaDTO 
queryRelationSchema(List metadataDTOS); + /** + * 查询关系图谱的schema + * @param query 用户查询语句 + * @return schema + */ + CypherSchemaDTO queryRelationSchema(String query); + } diff --git a/src/main/java/com/supervision/pdfqaserver/service/impl/ChatServiceImpl.java b/src/main/java/com/supervision/pdfqaserver/service/impl/ChatServiceImpl.java index 6b1f7eb..b08e0aa 100644 --- a/src/main/java/com/supervision/pdfqaserver/service/impl/ChatServiceImpl.java +++ b/src/main/java/com/supervision/pdfqaserver/service/impl/ChatServiceImpl.java @@ -42,13 +42,15 @@ public class ChatServiceImpl implements ChatService { private final TripleToCypherExecutor tripleToCypherExecutor; + private final DataCompareRetriever compareRetriever; + @Override public Flux knowledgeQA(String userQuery) { log.info("用户查询: {}", userQuery); // 生成cypher语句 - String cypher = tripleToCypherExecutor.generateQueryCypher(userQuery,null); + /*String cypher = tripleToCypherExecutor.generateQueryCypher(userQuery,null); log.info("生成CYPHER语句的消息:{}", cypher); if (StrUtil.isEmpty(cypher)){ return Flux.just("查无结果").concatWith(Flux.just("[END]")); @@ -58,7 +60,9 @@ public class ChatServiceImpl implements ChatService { List> graphResult = tripleToCypherExecutor.executeCypher(cypher); if (CollUtil.isEmpty(graphResult)){ return Flux.just("查无结果").concatWith(Flux.just("[END]")); - } + }*/ + List> graphResult = compareRetriever.retrieval(userQuery); + //生成回答 SystemPromptTemplate generateAnswerTemplate = new SystemPromptTemplate(PromptCache.promptMap.get(GENERATE_ANSWER)); Message generateAnswerMessage = generateAnswerTemplate.createMessage(Map.of(PROMPT_PARAM_EXAMPLE_TEXT, JSONUtil.toJsonStr(graphResult), PROMPT_PARAM_QUERY, userQuery)); diff --git a/src/main/java/com/supervision/pdfqaserver/service/impl/DataCompareRetriever.java b/src/main/java/com/supervision/pdfqaserver/service/impl/DataCompareRetriever.java new file mode 100644 index 0000000..cdfa05c --- /dev/null +++ 
b/src/main/java/com/supervision/pdfqaserver/service/impl/DataCompareRetriever.java
@@ -0,0 +1,137 @@
+package com.supervision.pdfqaserver.service.impl;
+
+import cn.hutool.core.collection.CollUtil;
+import cn.hutool.core.date.DateUtil;
+import cn.hutool.core.util.StrUtil;
+import cn.hutool.json.JSONArray;
+import cn.hutool.json.JSONUtil;
+import com.supervision.pdfqaserver.cache.PromptCache;
+import com.supervision.pdfqaserver.dto.CypherSchemaDTO;
+import com.supervision.pdfqaserver.dto.neo4j.RelationshipValueDTO;
+import com.supervision.pdfqaserver.service.*;
+import lombok.RequiredArgsConstructor;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.stereotype.Service;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import static com.supervision.pdfqaserver.cache.PromptCache.*;
+
+/**
+ * Data-comparison retriever: asks the model to turn a question into Cypher statements,
+ * executes them and returns the resulting rows.
+ */
+@Slf4j
+@Service("dataCompareRetriever")
+@RequiredArgsConstructor
+public class DataCompareRetriever implements Retriever {
+
+    private final TripleToCypherExecutor tripleToCypherExecutor;
+
+    private final AiCallService aiCallService;
+
+    /**
+     * Answers {@code query} from the knowledge graph.
+     * <p>
+     * A first prompt ({@code TEXT_TO_CYPHER_3}) asks the model for Cypher statements. If none of
+     * them yields a row, a second prompt ({@code TEXT_TO_CYPHER_4}) is sent together with the
+     * statements that came back empty so the model can adjust them.
+     *
+     * @param query user question; an empty question yields an empty result
+     * @return rows produced by the executed statements, possibly empty
+     */
+    @Override
+    public List<Map<String, Object>> retrieval(String query) {
+        log.info("retrieval: 执行数据对比检索器,查询内容:{}", query);
+        if (StrUtil.isEmpty(query)) {
+            log.warn("查询内容为空,无法执行数据对比检索");
+            return new ArrayList<>();
+        }
+        // Resolve the part of the graph schema that is relevant to the question.
+        CypherSchemaDTO schemaDTO = tripleToCypherExecutor.queryRelationSchema(query);
+        log.info("retrieval: 查询到的关系图谱schema 节点个数:{} ,关系结束{} ", schemaDTO.getNodes().size(), schemaDTO.getRelations().size());
+        log.info("retrieval: 查询到的关系图谱schema :{} ", schemaDTO.format());
+        if (CollUtil.isEmpty(schemaDTO.getRelations()) || CollUtil.isEmpty(schemaDTO.getNodes())) {
+            log.info("没有找到匹配的关系或实体,query: {}", query);
+            return new ArrayList<>();
+        }
+        // First attempt: let the model write executable Cypher.
+        String prompt = PromptCache.promptMap.get(TEXT_TO_CYPHER_3);
+        String format = StrUtil.format(prompt, Map.of("query", query, "schema", schemaDTO.format(), "env", "- 当前时间是:" + DateUtil.now()));
+        log.info("retrieval: 生成的cypher语句:{}", format);
+        String call = aiCallService.call(format);
+        log.info("retrieval: AI调用返回结果:{}", call);
+        if (StrUtil.isEmpty(call)) {
+            log.warn("retrieval: AI调用返回结果为空,无法执行Cypher查询");
+            return new ArrayList<>();
+        }
+        JSONArray js = JSONUtil.parseArray(call);
+        List<Map<String, Object>> result = collectRows(tripleToCypherExecutor.executeCypher(js.toList(String.class)));
+        if (CollUtil.isNotEmpty(result)) {
+            return clearGraphElements(result);
+        }
+
+        // Second attempt: hand the empty-handed statements back to the model for adjustment.
+        log.info("retrieval: 执行Cypher语句无结果,重新调整cypher语句:{}", query);
+        prompt = PromptCache.promptMap.get(TEXT_TO_CYPHER_4);
+        // "cypher" has to be an entry of the named-argument map: passed as trailing arguments it
+        // selects the positional format overload, which does not substitute named placeholders.
+        format = StrUtil.format(prompt,
+                Map.of("query", query, "schema", schemaDTO.format(),
+                        "env", "- 当前时间是:" + DateUtil.now(), "cypher", js.toString()));
+        log.info("retrieval: 生成的cypher语句:{}", format);
+        call = aiCallService.call(format);
+        log.info("retrieval: AI调用返回结果:{}", call);
+        if (StrUtil.isEmpty(call)) {
+            log.warn("retrieval: AI调用返回结果为空,无法执行Cypher查询");
+            return result;
+        }
+        js = JSONUtil.parseArray(call);
+        result = collectRows(tripleToCypherExecutor.executeCypher(js.toList(String.class)));
+        return clearGraphElements(result);
+    }
+
+    /**
+     * Flattens the per-statement results into one row list, skipping statements whose
+     * result is {@code null} or empty.
+     *
+     * @param cypherData rows keyed by the statement that produced them; may be {@code null} or empty
+     * @return a new mutable list holding all rows
+     */
+    private List<Map<String, Object>> collectRows(Map<String, List<Map<String, Object>>> cypherData) {
+        List<Map<String, Object>> rows = new ArrayList<>();
+        if (CollUtil.isEmpty(cypherData)) {
+            return rows;
+        }
+        for (List<Map<String, Object>> statementRows : cypherData.values()) {
+            if (CollUtil.isNotEmpty(statementRows)) {
+                rows.addAll(statementRows);
+            }
+        }
+        return rows;
+    }
+
+    /**
+     * Calls {@code clearGraphElement()} on every relationship and node value found in the rows.
+     *
+     * @param graphElements rows to clean in place
+     * @return the same list instance
+     */
+    private List<Map<String, Object>> clearGraphElements(List<Map<String, Object>> graphElements) {
+        if (CollUtil.isEmpty(graphElements)){
+            return graphElements;
+        }
+        for (Map<String, Object> element : graphElements) {
+            for (Object value : element.values()) {
+                if (value instanceof RelationshipValueDTO relationshipValueDTO) {
+                    relationshipValueDTO.clearGraphElement();
+                }
+                if (value instanceof com.supervision.pdfqaserver.dto.neo4j.NodeDTO nodeDTO) {
+                    nodeDTO.clearGraphElement();
+                }
+            }
+        }
+        return graphElements;
+    }
+}
diff --git a/src/main/java/com/supervision/pdfqaserver/service/impl/NodeRelationVectorServiceImpl.java b/src/main/java/com/supervision/pdfqaserver/service/impl/NodeRelationVectorServiceImpl.java
new file mode 100644
index 0000000..d5facef
--- /dev/null
+++ b/src/main/java/com/supervision/pdfqaserver/service/impl/NodeRelationVectorServiceImpl.java
@@ -0,0 +1,120 @@
+package com.supervision.pdfqaserver.service.impl;
+
+import cn.hutool.core.collection.CollUtil;
+import cn.hutool.core.util.StrUtil;
+import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
+import com.supervision.pdfqaserver.domain.NodeRelationVector;
+import com.supervision.pdfqaserver.dto.*;
+import com.supervision.pdfqaserver.service.AiCallService;
+import com.supervision.pdfqaserver.service.NodeRelationVectorService;
+import com.supervision.pdfqaserver.mapper.NodeRelationVectorMapper;
+import lombok.RequiredArgsConstructor;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.ai.embedding.Embedding;
+import org.springframework.stereotype.Service;
+import org.springframework.transaction.annotation.Transactional;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+/**
+* @author Administrator
+* @description Service implementation for database operations on table node_relation_vector (node/relation vector table)
+* @createDate 2025-06-18 13:38:02
+*/
+@Slf4j
+@Service
+@RequiredArgsConstructor
+public class NodeRelationVectorServiceImpl extends ServiceImpl<NodeRelationVectorMapper, NodeRelationVector>
+    implements NodeRelationVectorService{
+
+    private final AiCallService aiCallService;
+
+    /**
+     * Rebuilds the vector table from the given graph schema inside one transaction.
+     * <p>
+     * All existing rows are removed first. Relations are embedded in chunks of 200, once as the
+     * bare relation name (content type {@code "R"}) and once as
+     * {@code "sourceType relation targetType"} (content type {@code "ER"}); entity names are
+     * embedded as content type {@code "N"}. A text is stored only the first time it occurs.
+     *
+     * @param cypherSchemaDTO schema whose nodes and relations are embedded
+     */
+    @Override
+    @Transactional(rollbackFor = Exception.class)
+    public void refreshSchemaSegmentVector(CypherSchemaDTO cypherSchemaDTO) {
+
+        // Remove the old vector rows.
+        super.lambdaUpdate().remove();
+        // Insert the new vector rows.
+        List<EntityExtractionDTO> nodes = cypherSchemaDTO.getNodes();
+        List<RelationExtractionDTO> relations = cypherSchemaDTO.getRelations();
+
+        List<NodeRelationVector> allRelationVectors = new ArrayList<>();
+        // "R" and "ER" texts share one de-duplication set; node texts get their own below.
+        Set<String> seenRelationTexts = new HashSet<>();
+        for (List<RelationExtractionDTO> relationSplit : CollUtil.split(relations, 200)) {
+            List<String> rs = relationSplit.stream().map(RelationExtractionDTO::getRelation).toList();
+            allRelationVectors.addAll(embedDistinct(rs, "R", seenRelationTexts));
+            List<String> ers = relationSplit.stream()
+                    .map(r -> StrUtil.join(" ", r.getSourceType(), r.getRelation(),r.getTargetType())).toList();
+            allRelationVectors.addAll(embedDistinct(ers, "ER", seenRelationTexts));
+        }
+        super.saveBatch(allRelationVectors);
+
+        List<NodeRelationVector> allNodeVectors = new ArrayList<>();
+        Set<String> seenNodeTexts = new HashSet<>();
+        for (List<EntityExtractionDTO> entitySplit : CollUtil.split(nodes, 200)) {
+            List<String> es = entitySplit.stream().map(EntityExtractionDTO::getEntity).toList();
+            allNodeVectors.addAll(embedDistinct(es, "N", seenNodeTexts));
+        }
+        super.saveBatch(allNodeVectors);
+    }
+
+    /**
+     * Embeds {@code texts} in one call and wraps every text not seen before in a vector row.
+     *
+     * @param texts texts to embed
+     * @param contentType content type written to each row
+     * @param seen texts already stored; updated with the texts accepted here
+     * @return rows for the newly seen texts, in the order the embeddings were returned
+     */
+    private List<NodeRelationVector> embedDistinct(List<String> texts, String contentType, Set<String> seen) {
+        List<NodeRelationVector> vectors = new ArrayList<>();
+        for (Embedding embed : aiCallService.embedding(texts)) {
+            // The embedding index is used to look up the input text it belongs to.
+            String text = texts.get(embed.getIndex());
+            if (!seen.add(text)) {
+                continue;
+            }
+            NodeRelationVector vector = new NodeRelationVector();
+            vector.setContent(text);
+            vector.setEmbedding(embed.getOutput());
+            vector.setContentType(contentType);
+            vectors.add(vector);
+        }
+        return vectors;
+    }
+
+    /**
+     * Delegates to the mapper's {@code findSimilarByCosine} query with the same arguments.
+     */
+    @Override
+    public List<NodeRelationVector> matchSimilarByCosine(float[] embedding, double threshold, List<String> contentType, int limit) {
+        return super.getBaseMapper().findSimilarByCosine(embedding, threshold, contentType, limit);
+    }
+
+    /**
+     * Scores {@code content} against {@code embedding} through the mapper.
+     *
+     * @return {@code 0.0} when the content or the embedding is empty, otherwise the mapper's result
+     */
+    @Override
+    public Double matchContentScore(float[] embedding, String content) {
+        if (StrUtil.isEmpty(content) || embedding == null || embedding.length == 0) {
+            return 0.0;
+        }
+        return super.getBaseMapper().matchContentScore(embedding, content);
+    }
+}
diff --git a/src/main/java/com/supervision/pdfqaserver/service/impl/OllamaCallServiceImpl.java b/src/main/java/com/supervision/pdfqaserver/service/impl/OllamaCallServiceImpl.java
index b6b657e..f7a75a3 100644
--- a/src/main/java/com/supervision/pdfqaserver/service/impl/OllamaCallServiceImpl.java
+++ b/src/main/java/com/supervision/pdfqaserver/service/impl/OllamaCallServiceImpl.java
@@ -37,4 +37,16 @@ public class OllamaCallServiceImpl implements AiCallService {
         EmbeddingResponse embeddingResponse = embeddingModel.call(new EmbeddingRequest(List.of(text),null));
         return embeddingResponse.getResult();
     }
+
+    /**
+     * Embeds several texts with a single model call.
+     *
+     * @param texts texts to embed
+     * @return the embeddings contained in the model response
+     */
+    @Override
+    public List<Embedding> embedding(List<String> texts) {
+        EmbeddingResponse embeddingResponse = embeddingModel.call(new EmbeddingRequest(texts,null));
+        return embeddingResponse.getResults();
+    }
 }
diff --git a/src/main/java/com/supervision/pdfqaserver/service/impl/QuestionCategoryServiceImpl.java b/src/main/java/com/supervision/pdfqaserver/service/impl/QuestionCategoryServiceImpl.java
index 446e91d..26d68cc 100644
--- a/src/main/java/com/supervision/pdfqaserver/service/impl/QuestionCategoryServiceImpl.java
+++ b/src/main/java/com/supervision/pdfqaserver/service/impl/QuestionCategoryServiceImpl.java
@@ -15,6 +15,10 @@ import org.springframework.stereotype.Service;
 public class QuestionCategoryServiceImpl extends ServiceImpl
     implements QuestionCategoryService{
 
+    @Override
+    public QuestionCategory findCategoryById(String categoryId) {
+        return super.getById(categoryId);
+    }
 }
 
 
diff --git a/src/main/java/com/supervision/pdfqaserver/service/impl/QuestionHandlerMappingServiceImpl.java b/src/main/java/com/supervision/pdfqaserver/service/impl/QuestionHandlerMappingServiceImpl.java
index 18d1817..d61f678 100644
--- 
a/src/main/java/com/supervision/pdfqaserver/service/impl/QuestionHandlerMappingServiceImpl.java
+++ b/src/main/java/com/supervision/pdfqaserver/service/impl/QuestionHandlerMappingServiceImpl.java
@@ -1,9 +1,14 @@
 package com.supervision.pdfqaserver.service.impl;
 
+import cn.hutool.core.util.StrUtil;
 import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
+import com.supervision.pdfqaserver.domain.QuestionCategory;
 import com.supervision.pdfqaserver.domain.QuestionHandlerMapping;
+import com.supervision.pdfqaserver.service.QuestionCategoryService;
 import com.supervision.pdfqaserver.service.QuestionHandlerMappingService;
 import com.supervision.pdfqaserver.mapper.QuestionHandlerMappingMapper;
+import lombok.RequiredArgsConstructor;
+import lombok.extern.slf4j.Slf4j;
 import org.springframework.stereotype.Service;
 
 /**
@@ -11,10 +16,45 @@ import org.springframework.stereotype.Service;
  * @description 针对表【question_handler_mapping(问题处理器映射表)】的数据库操作Service实现
  * @createDate 2025-06-13 11:29:01
  */
+@Slf4j
 @Service
+@RequiredArgsConstructor
 public class QuestionHandlerMappingServiceImpl extends ServiceImpl
     implements QuestionHandlerMappingService{
 
+    private final QuestionCategoryService categoryService;
+
+    /**
+     * Finds the handler mapping for a category, falling back to its ancestors.
+     * <p>
+     * The category itself is tried first; while no mapping exists the lookup moves on to the
+     * parent category until a mapping is found or the chain ends.
+     *
+     * @param categoryId id of the category to start from
+     * @return the first mapping found on the way up, or {@code null} if there is none
+     */
+    @Override
+    public QuestionHandlerMapping findHandlerByCategoryId(String categoryId) {
+        // Ids already tried: a parent chain that loops back on itself must not spin forever.
+        java.util.Set<String> visited = new java.util.HashSet<>();
+        String currentId = categoryId;
+        while (StrUtil.isNotEmpty(currentId) && visited.add(currentId)) {
+            QuestionHandlerMapping one = super.lambdaQuery().eq(QuestionHandlerMapping::getQuestionCategoryId, currentId).one();
+            if (null != one) {
+                return one;
+            }
+            log.info("根据分类id:{}未找到处理器映射,尝试查询分类器上级关联数据", currentId);
+            QuestionCategory category = categoryService.findCategoryById(currentId);
+            // The category row itself may be missing; treat that like "no parent".
+            if (null == category || StrUtil.isEmpty(category.getParentId())) {
+                log.info("分类id:{} 没有父级id,不进行查询", currentId);
+                return null;
+            }
+            log.info("分类id:{} 的父级id为:{}", currentId, category.getParentId());
+            currentId = category.getParentId();
+        }
+        return null;
+    }
 }
 
 
diff --git 
a/src/main/java/com/supervision/pdfqaserver/service/impl/TextToSegmentServiceImpl.java b/src/main/java/com/supervision/pdfqaserver/service/impl/TextToSegmentServiceImpl.java
new file mode 100644
index 0000000..f02e688
--- /dev/null
+++ b/src/main/java/com/supervision/pdfqaserver/service/impl/TextToSegmentServiceImpl.java
@@ -0,0 +1,69 @@
+package com.supervision.pdfqaserver.service.impl;
+
+import cn.hutool.core.collection.CollUtil;
+import cn.hutool.core.util.StrUtil;
+import com.hankcs.hanlp.HanLP;
+import com.hankcs.hanlp.dictionary.CustomDictionary;
+import com.hankcs.hanlp.seg.Segment;
+import com.hankcs.hanlp.seg.common.Term;
+import com.supervision.pdfqaserver.dto.TextTerm;
+import com.supervision.pdfqaserver.service.TextToSegmentService;
+import lombok.RequiredArgsConstructor;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.stereotype.Service;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Word segmentation backed by HanLP.
+ */
+@Slf4j
+@Service
+@RequiredArgsConstructor
+public class TextToSegmentServiceImpl implements TextToSegmentService {
+
+    /**
+     * Splits {@code text} into terms using a HanLP segmenter with organization, place and
+     * number-quantifier recognition enabled.
+     *
+     * @param text text to segment
+     * @return one {@link TextTerm} per HanLP term; empty when the text is empty or nothing was produced
+     */
+    @Override
+    public List<TextTerm> segmentText(String text) {
+        if (StrUtil.isEmpty(text)){
+            return new ArrayList<>();
+        }
+        Segment segment = HanLP.newSegment()
+                .enableOrganizationRecognize(true)
+                .enablePlaceRecognize(true)
+                .enableNumberQuantifierRecognize(true);
+
+        List<Term> seg = segment.seg(text);
+        if (CollUtil.isEmpty(seg)){
+            return new ArrayList<>();
+        }
+        List<TextTerm> terms = new ArrayList<>(seg.size());
+        for (Term term : seg) {
+            TextTerm textTerm = new TextTerm();
+            textTerm.setWord(term.word);
+            // A term without a nature gets no label instead of failing the whole segmentation.
+            textTerm.setLabel(term.nature == null ? null : term.nature.toString());
+            terms.add(textTerm);
+        }
+        return terms;
+    }
+
+    /**
+     * Inserts a word into HanLP's {@code CustomDictionary}.
+     *
+     * @param word word to add
+     * @param label label stored for the word
+     * @param frequency frequency stored together with the label
+     */
+    @Override
+    public void addDict(String word, String label,int frequency) {
+        CustomDictionary.insert(word, label + " " + frequency);
+    }
+}
diff --git a/src/main/java/com/supervision/pdfqaserver/service/impl/TextVectorServiceImpl.java b/src/main/java/com/supervision/pdfqaserver/service/impl/TextVectorServiceImpl.java
index 739d5de..e0e10c1 100644
--- 
a/src/main/java/com/supervision/pdfqaserver/service/impl/TextVectorServiceImpl.java
+++ b/src/main/java/com/supervision/pdfqaserver/service/impl/TextVectorServiceImpl.java
@@ -2,11 +2,11 @@ package com.supervision.pdfqaserver.service.impl;
 
 import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
 import com.supervision.pdfqaserver.domain.TextVector;
-import com.supervision.pdfqaserver.dto.TextVectorDTO;
+import com.supervision.pdfqaserver.dto.*;
 import com.supervision.pdfqaserver.service.TextVectorService;
 import com.supervision.pdfqaserver.mapper.TextVectorMapper;
+import lombok.RequiredArgsConstructor;
 import org.springframework.stereotype.Service;
-
 import java.util.List;
 
 /**
@@ -15,12 +15,12 @@ import java.util.List;
  * @createDate 2025-06-11 16:40:57
  */
 @Service
+@RequiredArgsConstructor
 public class TextVectorServiceImpl extends ServiceImpl
     implements TextVectorService{
 
-
     @Override
     public List findSimilarByCosine(float[] embedding, double threshold , int limit) {
-        return super.getBaseMapper().findSimilarByCosine(embedding, threshold,limit);
+        return super.getBaseMapper().findSimilarByCosine(embedding, threshold,null,limit);
     }
 }
diff --git a/src/main/java/com/supervision/pdfqaserver/service/impl/TripleToCypherExecutorImpl.java b/src/main/java/com/supervision/pdfqaserver/service/impl/TripleToCypherExecutorImpl.java
index e8d349b..3d54bf6 100644
--- a/src/main/java/com/supervision/pdfqaserver/service/impl/TripleToCypherExecutorImpl.java
+++ b/src/main/java/com/supervision/pdfqaserver/service/impl/TripleToCypherExecutorImpl.java
@@ -1,24 +1,25 @@
 package com.supervision.pdfqaserver.service.impl;
 
 import cn.hutool.core.collection.CollUtil;
+import cn.hutool.core.date.TimeInterval;
+import cn.hutool.core.lang.Pair;
 import cn.hutool.core.util.StrUtil;
 import cn.hutool.json.JSONUtil;
 import com.supervision.pdfqaserver.cache.PromptCache;
 import com.supervision.pdfqaserver.dao.Neo4jRepository;
 import com.supervision.pdfqaserver.domain.Intention;
+import com.supervision.pdfqaserver.domain.NodeRelationVector;
 import com.supervision.pdfqaserver.dto.*;
 import com.supervision.pdfqaserver.dto.neo4j.NodeDTO;
 import com.supervision.pdfqaserver.dto.neo4j.PathDTO;
 import com.supervision.pdfqaserver.dto.neo4j.RelationshipValueDTO;
-import com.supervision.pdfqaserver.service.AiCallService;
-import com.supervision.pdfqaserver.service.DomainMetadataService;
-import com.supervision.pdfqaserver.service.IntentionService;
-import com.supervision.pdfqaserver.service.TripleToCypherExecutor;
+import com.supervision.pdfqaserver.service.*;
 import lombok.RequiredArgsConstructor;
 import lombok.extern.slf4j.Slf4j;
 import org.neo4j.driver.Record;
 import org.neo4j.driver.internal.InternalNode;
 import org.neo4j.driver.internal.InternalRelationship;
+import org.springframework.ai.embedding.Embedding;
 import org.springframework.stereotype.Service;
 import java.util.ArrayList;
 import java.util.HashMap;
@@ -41,6 +42,12 @@ public class TripleToCypherExecutorImpl implements TripleToCypherExecutor {
 
     private final DomainMetadataService domainMetadataService;
 
+    private final TextVectorService textVectorService;
+
+    private final NodeRelationVectorService nodeRelationVectorService;
+
+    private final TextToSegmentService textToSegmentService;
+
     @Override
     public String generateInsertCypher(EREDTO eredto) {
 
@@ -80,6 +87,28 @@ public class TripleToCypherExecutorImpl implements TripleToCypherExecutor {
         return mapRecords(records);
     }
 
+    /**
+     * Runs each Cypher statement independently and collects the rows per statement.
+     * A statement that throws is logged and mapped to {@code null}; the remaining ones still run.
+     *
+     * @param cypher statements to execute
+     * @return rows keyed by the statement text
+     */
+    @Override
+    public Map<String, List<Map<String, Object>>> executeCypher(List<String> cypher) {
+        Map<String, List<Map<String, Object>>> result = new HashMap<>();
+        for (String c : cypher){
+            List<Map<String, Object>> maps = null;
+            try {
+                maps = executeCypher(c);
+            } catch (Exception e) {
+                log.warn("执行Cypher语句失败,语句:{},错误信息:{}", c, e.getMessage(), e);
+            }
+            result.put(c, maps);
+        }
+        return result;
+    }
+
     private List> mapRecords(List records) {
 
         List> recordList = new ArrayList<>();
@@ -185,18 +214,28 @@ public class TripleToCypherExecutorImpl implements TripleToCypherExecutor {
         return new CypherSchemaDTO(entitySchema, relationSchema);
     }
 
+    /**
+     * Loads the graph schema if necessary and passes it to the vector service's
+     * {@code refreshSchemaSegmentVector}; does nothing when no schema is available.
+     */
+    @Override
+    public void refreshSchemaSegmentVector() {
+        loadCypherSchemaIfAbsent();
+        if (cypherSchemaDTO == null) {
+            log.warn("图谱schema数据为空,不用刷新分词向量...");
+            return;
+        }
+        log.info("开始刷新图谱schema分词向量...");
+        nodeRelationVectorService.refreshSchemaSegmentVector(cypherSchemaDTO);
+        log.info("图谱schema分词向量刷新完成...");
+    }
+
     @Override
     public CypherSchemaDTO queryRelationSchema(List metadataDTOS) {
         if (CollUtil.isEmpty(metadataDTOS)){
             return null;
         }
-        if (cypherSchemaDTO == null) {
-            synchronized (TripleToCypherExecutorImpl.class) {
-                if (cypherSchemaDTO == null) {
-                    cypherSchemaDTO = this.loadGraphSchema();
-                }
-            }
-        }
+        loadCypherSchemaIfAbsent();
         List merged = new ArrayList<>();
         for (DomainMetadataDTO metadataDTO : metadataDTOS) {
             String relation = metadataDTO.getRelation();
@@ -231,6 +270,105 @@ public class TripleToCypherExecutorImpl implements TripleToCypherExecutor {
         );
     }
 
+    /**
+     * Selects the part of the graph schema that is relevant to a question.
+     * <p>
+     * The question is segmented, each labelled term is matched against the stored node and
+     * relation vectors, the relations found for the matches are re-ranked against the combined
+     * labels, and the five best relations plus their endpoint node types are returned.
+     *
+     * @param query user question
+     * @return the matching schema slice; empty nodes and relations when nothing matches
+     */
+    @Override
+    public CypherSchemaDTO queryRelationSchema(String query) {
+        if (StrUtil.isEmpty(query)){
+            return new CypherSchemaDTO(List.of(), List.of());
+        }
+        // Segment the question into terms.
+        List<TextTerm> terms = textToSegmentService.segmentText(query);
+        if (CollUtil.isEmpty(terms)){
+            return new CypherSchemaDTO(List.of(), List.of());
+        }
+        log.info("queryRelationSchema: 分词结果:{}", terms);
+        log.info("queryRelationSchema: 开始进行文本标签向量匹配...");
+        List<NodeRelationVector> matchedText = new ArrayList<>();
+        for (TextTerm term : terms) {
+            if (StrUtil.isEmpty(term.getLabelValue())){
+                log.info("queryRelationSchema: 分词结果`{}`不是关键标签,跳过...", term.getWord());
+                continue;
+            }
+            Embedding embedding = aiCallService.embedding(term.getLabelValue());
+            term.setEmbedding(embedding.getOutput());
+            List<NodeRelationVector> textVectorDTOS = nodeRelationVectorService.matchSimilarByCosine(embedding.getOutput(), 0.9, List.of("N","R"),3); // filtered further below
+            log.info("retrieval: 文本:{}匹配到的文本向量:{}", term.getWord() ,textVectorDTOS.stream().map(NodeRelationVector::getContent).collect(Collectors.joining(" ")));
+            matchedText.addAll(textVectorDTOS);
+        }
+        if (CollUtil.isEmpty(matchedText)){
+            log.info("retrieval: 未找到匹配的文本向量");
+            return new CypherSchemaDTO(List.of(), List.of());
+        }
+        loadCypherSchemaIfAbsent();
+        List<RelationExtractionDTO> merged = new ArrayList<>();
+        for (NodeRelationVector textVectorDTO : matchedText) {
+            String content = textVectorDTO.getContent();
+            List<RelationExtractionDTO> relations = cypherSchemaDTO.getRelations(content);
+            for (RelationExtractionDTO relation : relations) {
+                boolean noneMatch = merged.stream().noneMatch(i ->
+                        StrUtil.equals(i.getSourceType(), relation.getSourceType()) &&
+                        StrUtil.equals(i.getRelation(), relation.getRelation()) &&
+                        StrUtil.equals(i.getTargetType(), relation.getTargetType())
+                );
+                if (noneMatch){
+                    merged.add(relation);
+                }
+            }
+        }
+        // Re-rank the relations that were found.
+        List<Pair<Double, RelationExtractionDTO>> pairs = new ArrayList<>();
+        TimeInterval timeInterval = new TimeInterval();
+        String join = terms.stream().map(TextTerm::getLabelValue).filter(StrUtil::isNotEmpty).collect(Collectors.joining());
+        Embedding embedding = aiCallService.embedding(join);
+        for (RelationExtractionDTO relation : merged) {
+            String content = relation.getSourceType() + " " + relation.getRelation() + " " + relation.getTargetType();
+            Double score = nodeRelationVectorService.matchContentScore(embedding.getOutput(),content); // scored through a database query for now; total cost is currently under one second
+            if (null == score){
+                continue;
+            }
+            log.info("queryRelationSchema: 关系`{}`的匹配分数:{}", content, score);
+            pairs.add(Pair.of(score, relation));
+        }
+        log.info("queryRelationSchema: 关系排序耗时:{}ms", timeInterval.intervalMs());
+
+        merged = pairs.stream().sorted((p1, p2) -> Double.compare(p2.getKey(), p1.getKey())).limit(5).map(Pair::getValue).toList();
+        List<EntityExtractionDTO> entityExtractionDTOS = new ArrayList<>();
+        for (RelationExtractionDTO relationExtractionDTO : merged) {
+            EntityExtractionDTO sourceNode = cypherSchemaDTO.getNode(relationExtractionDTO.getSourceType());
+            EntityExtractionDTO targetNode = cypherSchemaDTO.getNode(relationExtractionDTO.getTargetType());
+            if (null != sourceNode){
+                boolean none = entityExtractionDTOS.stream().noneMatch(
+                        entity -> StrUtil.equals(entity.getEntity(), sourceNode.getEntity())
+                );
+                if (none) {
+                    entityExtractionDTOS.add(sourceNode);
+                }
+            }
+            if (null != targetNode){
+                boolean none = entityExtractionDTOS.stream().noneMatch(
+                        entity -> StrUtil.equals(entity.getEntity(), targetNode.getEntity())
+                );
+                if (none) {
+                    entityExtractionDTOS.add(targetNode);
+                }
+            }
+        }
+
+        return new CypherSchemaDTO(
+                entityExtractionDTOS,
+                merged
+        );
+    }
+
     private List classifyIntents(String query, List intentions) {
         if (StrUtil.isEmpty(query) || CollUtil.isEmpty(intentions)) {
             return new ArrayList<>();
@@ -261,4 +399,18 @@ public class TripleToCypherExecutorImpl implements TripleToCypherExecutor {
         }
         return result;
     }
+
+    /**
+     * Loads the graph schema into {@code cypherSchemaDTO} via {@code loadGraphSchema()} on first use.
+     * The null check is repeated inside the class-level lock before loading.
+     */
+    private void loadCypherSchemaIfAbsent() {
+        if (cypherSchemaDTO == null) {
+            synchronized (TripleToCypherExecutorImpl.class) {
+                if (cypherSchemaDTO == null) {
+                    cypherSchemaDTO = this.loadGraphSchema();
+                }
+            }
+        }
+    }
 }
diff --git a/src/main/resources/mapper/NodeRelationVectorMapper.xml b/src/main/resources/mapper/NodeRelationVectorMapper.xml
new file mode 100644
index 0000000..b024693
--- /dev/null
+++ b/src/main/resources/mapper/NodeRelationVectorMapper.xml
@@ -0,0 +1,57 @@
+ + + + + + + + + + + + + + + id,content,embedding, + content_type,create_time,update_time + + + + + +
diff --git a/src/main/resources/mapper/TextVectorMapper.xml b/src/main/resources/mapper/TextVectorMapper.xml
index 7198ecc..6e849ff 100644
--- a/src/main/resources/mapper/TextVectorMapper.xml
+++ b/src/main/resources/mapper/TextVectorMapper.xml
@@ -28,6 +28,9 @@
         FROM text_vector
         ) t
         WHERE t.similarityScore > #{threshold}
+
+        AND t.category_id = #{categoryId}
+
         ORDER BY t.similarityScore DESC
         LIMIT #{limit}
 
diff --git a/src/test/java/com/supervision/pdfqaserver/PdfQaServerApplicationTests.java b/src/test/java/com/supervision/pdfqaserver/PdfQaServerApplicationTests.java
index 25d4ff3..7fb3298 100644
--- 
a/src/test/java/com/supervision/pdfqaserver/PdfQaServerApplicationTests.java
+++ b/src/test/java/com/supervision/pdfqaserver/PdfQaServerApplicationTests.java
@@ -1,9 +1,11 @@
 package com.supervision.pdfqaserver;
 
+import cn.hutool.core.date.TimeInterval;
 import cn.hutool.core.util.NumberUtil;
 import cn.hutool.core.util.StrUtil;
 import cn.hutool.json.JSONArray;
 import cn.hutool.json.JSONObject;
+import cn.hutool.json.JSONUtil;
 import com.supervision.pdfqaserver.constant.DocumentContentTypeEnum;
 import com.supervision.pdfqaserver.domain.PdfAnalysisOutput;
 import com.supervision.pdfqaserver.domain.TextVector;
@@ -16,7 +18,6 @@ import org.neo4j.driver.Record;
 import org.springframework.ai.embedding.Embedding;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.boot.test.context.SpringBootTest;
-
 import java.util.*;
 import java.util.stream.Collectors;
 import static org.neo4j.driver.Values.parameters;
@@ -304,19 +305,18 @@ class PdfQaServerApplicationTests {
     @Test
     public void textVectorTest() {
         String texts = """
-                公司办公地点是哪里
-                请问贵公司的注册办公地址在哪里?
-                能否告知公司总部所在地的具体位置?
-                公司的主要办公场所设在什么地方?
-                贵司的办公场所位于哪个城市/区域?
-                请提供公司目前办公地点的详细地址?
+                公司
+                集团
+                营收
+                金额
+                时间
                 """;
         String[] split = texts.split("\n");
         List list = Arrays.stream(split).toList();
         for (String text : list) {
             TextVector textVector = new TextVector();
             textVector.setContent(text.trim());
-            textVector.setCategoryId("查询办公地点");
+            textVector.setCategoryId("分词");
             float[] output = aiCallService.embedding(textVector.getContent()).getOutput();
             textVector.setEmbedding(output);
             textVectorService.save(textVector);
@@ -325,9 +325,10 @@ class PdfQaServerApplicationTests {
     }
     @Test
     public void textVectorTest2() {
-        String queryText = "告诉我龙源电力的办公地点?";
+        // Question: what was Longyuan Power Group's revenue over the last three years?
+        String queryText = "龙源电力集团近三年营收情况是多少";
         float[] output = aiCallService.embedding(queryText).getOutput();
-        List similarByCosine = textVectorService.findSimilarByCosine(output, 0.3f, 10);
+        List similarByCosine = textVectorService.findSimilarByCosine(output, 0.1f, 5);
         similarByCosine = similarByCosine.stream().sorted(Comparator.comparingDouble(TextVectorDTO::getSimilarityScore).reversed()).collect(Collectors.toList());
         log.info("<<<===========================>>>" );
         for (TextVectorDTO vectorDTO : similarByCosine) {
@@ -337,4 +338,22 @@ class PdfQaServerApplicationTests {
             System.out.printf("%s\t%s\t%s\t%s%n",queryText, categoryId , NumberUtil.decimalFormat("0.0000",similarityScore),content);
         }
     }
+
+    @Autowired
+    private Retriever retriever;
+
+    @Autowired
+    private TextToSegmentService textToSegmentService;
+    @Test
+    public void textVectorTest3() {
+
+        // NOTE(review): presumably tripleToCypherExecutor.refreshSchemaSegmentVector() has to run once beforehand - confirm.
+
+
+        TimeInterval timer = new TimeInterval();
+        textToSegmentService.addDict("龙源电力集团","企业",1000);
+        List<Map<String, Object>> retrieval = retriever.retrieval("龙源电力集团近三年营收情况是多少");
+        System.out.println(JSONUtil.toJsonStr(retrieval));
+        log.info("<<<===========================>>> 耗时: {} 毫秒", timer.intervalMs());
+    }
 }