diff --git a/pom.xml b/pom.xml index dc143d7..a1ec29c 100644 --- a/pom.xml +++ b/pom.xml @@ -100,6 +100,10 @@ commonmark-ext-gfm-tables 0.21.0 + + org.springframework.ai + spring-ai-starter-model-openai + diff --git a/src/main/java/com/supervision/pdfqaserver/cache/PromptCache.java b/src/main/java/com/supervision/pdfqaserver/cache/PromptCache.java index d355d8b..9ff1499 100644 --- a/src/main/java/com/supervision/pdfqaserver/cache/PromptCache.java +++ b/src/main/java/com/supervision/pdfqaserver/cache/PromptCache.java @@ -838,7 +838,6 @@ x {} "- Express simple equality predicates in map patterns and move all other filters to a **WHERE** clause.\\n\\n" "4. Return Clause Strategy:\\n" "- RETURN every node and relationship mentioned, unless the user explicitly requests specific properties.\\n\\n" - "- The truncationId、name attribute of a node is very important, and each node needs to return truncationId、name .\\n\\n" "5. Final Cypher Script Generation:\\n" "- Respond with **only** the final Cypher query—no commentary or extra text.\\n" "- Use OPTIONAL MATCH only if explicitly required by the user and supported by the schema.\\n" diff --git a/src/main/java/com/supervision/pdfqaserver/dao/Neo4jRepository.java b/src/main/java/com/supervision/pdfqaserver/dao/Neo4jRepository.java index 1ccd977..0141dfb 100644 --- a/src/main/java/com/supervision/pdfqaserver/dao/Neo4jRepository.java +++ b/src/main/java/com/supervision/pdfqaserver/dao/Neo4jRepository.java @@ -9,6 +9,7 @@ import com.supervision.pdfqaserver.dto.neo4j.NodeData; import com.supervision.pdfqaserver.dto.neo4j.RelationObject; import com.supervision.pdfqaserver.dto.neo4j.RelationshipData; import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; import org.neo4j.driver.*; import org.neo4j.driver.Record; import org.neo4j.driver.types.Node; @@ -20,6 +21,7 @@ import java.util.stream.Collectors; import java.util.stream.StreamSupport; import static org.neo4j.driver.Values.parameters; +@Slf4j @Repository @RequiredArgsConstructor public class Neo4jRepository { @@ -153,6 +155,7 @@ public class Neo4jRepository { if (StrUtil.isEmpty(nodeType)){ continue; } + nodeType = nodeType.substring(1, nodeType.length()-1).replace("`", ""); String propertyName = record.get("propertyName").asString(); List propertyTypes = record.get("propertyTypes").asList(Value::asString); @@ -160,8 +163,9 @@ public class Neo4jRepository { TruncationERAttributeDTO attributeDTO = new TruncationERAttributeDTO(propertyName, null, CollUtil.getFirst(propertyTypes)); // 检查是否已存在该节点类型 + final String nodeType_f = nodeType; EntityExtractionDTO existingEntity = extractionDTOS.stream() - .filter(e -> StrUtil.equals(e.getEntityEn(), nodeType)) + .filter(e -> StrUtil.equals(e.getEntityEn(), nodeType_f)) .findFirst().orElse(null); if (existingEntity != null) { @@ -197,6 +201,7 @@ public class Neo4jRepository { if (StrUtil.isEmpty(relType)){ continue; } + relType = relType.substring(1, relType.length()-1).replace("`", ""); String propertyName = record.get("propertyName").asString(); List propertyTypes = record.get("propertyTypes").asList(Value::asString); @@ -214,14 +219,15 @@ public class Neo4jRepository { List relationExtractionDTOS = new ArrayList<>(); String queryEndpoints = """ - MATCH (s)-[r:`{rtype}`]->(t) + MATCH (s)-[r: `{rtype}` ]->(t) WITH labels(s)[0] AS src, labels(t)[0] AS tgt RETURN src, tgt """; for (Map.Entry>> entry : relationProperties.entrySet()) { String relType = entry.getKey(); List> properties = entry.getValue(); - Result run = session.run(queryEndpoints, parameters("rtype", relType)); + String formatted = StrUtil.format(queryEndpoints,Map.of("rtype",relType)); + Result run = session.run(formatted); for (Record record : run.list()) { String sourceType = record.get("src").asString(); String targetType = record.get("tgt").asString(); diff --git a/src/main/java/com/supervision/pdfqaserver/domain/ErAttribute.java b/src/main/java/com/supervision/pdfqaserver/domain/ErAttribute.java index 26380b6..fc1e063 100644 --- a/src/main/java/com/supervision/pdfqaserver/domain/ErAttribute.java +++ b/src/main/java/com/supervision/pdfqaserver/domain/ErAttribute.java @@ -38,6 +38,12 @@ public class ErAttribute implements Serializable { */ private String erType; + /** + * 节点关系标签 + */ + private String erLabel; + + /** * 创建时间 diff --git a/src/main/java/com/supervision/pdfqaserver/dto/CypherSchemaDTO.java b/src/main/java/com/supervision/pdfqaserver/dto/CypherSchemaDTO.java index f9d603d..d9a21ff 100644 --- a/src/main/java/com/supervision/pdfqaserver/dto/CypherSchemaDTO.java +++ b/src/main/java/com/supervision/pdfqaserver/dto/CypherSchemaDTO.java @@ -55,6 +55,14 @@ public class CypherSchemaDTO { return result; } + public List getNodes() { + return nodes; + } + + public List getRelations() { + return relations; + } + /** * 根据实体名获取关系抽取DTO列表 * @param entity diff --git a/src/main/java/com/supervision/pdfqaserver/dto/DomainMetadataDTO.java b/src/main/java/com/supervision/pdfqaserver/dto/DomainMetadataDTO.java index b332993..a11ec43 100644 --- a/src/main/java/com/supervision/pdfqaserver/dto/DomainMetadataDTO.java +++ b/src/main/java/com/supervision/pdfqaserver/dto/DomainMetadataDTO.java @@ -80,14 +80,14 @@ public class DomainMetadataDTO { if (StrUtil.equals(erAttribute.getDomainMetadataId(),this.id)){ if(StrUtil.equals(erAttribute.getErType(),"1")){ // 节点数据 - if (StrUtil.equals(erAttribute.getAttrName(),this.sourceType)) { + if (StrUtil.equals(erAttribute.getErLabel(),this.sourceType)) { this.sourceAttributes.add(new ERAttributeDTO(erAttribute)); } - if (StrUtil.equals(erAttribute.getAttrName(),this.targetType)) { + if (StrUtil.equals(erAttribute.getErLabel(),this.targetType)) { this.targetAttributes.add(new ERAttributeDTO(erAttribute)); } }else { - if (StrUtil.equals(erAttribute.getAttrName(),this.relation)) { + if (StrUtil.equals(erAttribute.getErLabel(),this.relation)) { this.relationAttributes.add(new ERAttributeDTO(erAttribute)); } } diff --git a/src/main/java/com/supervision/pdfqaserver/dto/ERAttributeDTO.java b/src/main/java/com/supervision/pdfqaserver/dto/ERAttributeDTO.java index 6a2c84c..93bd007 100644 --- a/src/main/java/com/supervision/pdfqaserver/dto/ERAttributeDTO.java +++ b/src/main/java/com/supervision/pdfqaserver/dto/ERAttributeDTO.java @@ -23,6 +23,11 @@ public class ERAttributeDTO { */ private String attrName; + /** + * 节点、关系标签 + */ + private String erLabel; + /** * 属性值类型 */ @@ -37,14 +42,6 @@ public class ERAttributeDTO { public ERAttributeDTO() { } - public ERAttributeDTO(String id, String domainMetadataId, String erName, String attrName, String attrValueType, String erType) { - this.id = id; - this.domainMetadataId = domainMetadataId; - this.erName = erName; - this.attrName = attrName; - this.attrValueType = attrValueType; - this.erType = erType; - } public ERAttributeDTO(String attrName) { this.attrName = attrName; @@ -56,6 +53,7 @@ public class ERAttributeDTO { this.attrName = erAttribute.getAttrName(); this.attrValueType = erAttribute.getAttrValueType(); this.erType = erAttribute.getErType(); + this.erLabel = erAttribute.getErLabel(); } public ErAttribute toErAttribute() { @@ -65,6 +63,7 @@ public class ERAttributeDTO { erAttribute.setAttrName(this.attrName); erAttribute.setAttrValueType(this.attrValueType); erAttribute.setErType(this.erType); + erAttribute.setErLabel(this.erLabel); return erAttribute; } } diff --git a/src/main/java/com/supervision/pdfqaserver/dto/EREDTO.java b/src/main/java/com/supervision/pdfqaserver/dto/EREDTO.java index bf97460..1dbbc70 100644 --- a/src/main/java/com/supervision/pdfqaserver/dto/EREDTO.java +++ b/src/main/java/com/supervision/pdfqaserver/dto/EREDTO.java @@ -232,6 +232,23 @@ public class EREDTO { } + public void setEn() { + for (EntityExtractionDTO entity : entities) { + entity.setEntityEn(entity.getEntity()); + for (TruncationERAttributeDTO attribute : entity.getAttributes()) { + attribute.setAttributeEn(attribute.getAttribute()); + } + } + for (RelationExtractionDTO relation : relations) { + relation.setRelationEn(relation.getRelation()); + relation.setSourceTypeEn(relation.getSourceType()); + relation.setTargetTypeEn(relation.getTargetType()); + for (TruncationERAttributeDTO attribute : relation.getAttributes()) { + attribute.setAttributeEn(attribute.getAttribute()); + } + } + } + private void setAttributeEn(TruncationERAttributeDTO attribute, List wordsList) { if (null == attribute || CollUtil.isEmpty(wordsList)){ return; diff --git a/src/main/java/com/supervision/pdfqaserver/service/AiCallService.java b/src/main/java/com/supervision/pdfqaserver/service/AiCallService.java index b82861f..523474a 100644 --- a/src/main/java/com/supervision/pdfqaserver/service/AiCallService.java +++ b/src/main/java/com/supervision/pdfqaserver/service/AiCallService.java @@ -13,4 +13,6 @@ public interface AiCallService { String call(String prompt); Flux stream(Prompt prompt); + + abstract void embedding(String text); } diff --git a/src/main/java/com/supervision/pdfqaserver/service/DeepSeekApiImpl.java b/src/main/java/com/supervision/pdfqaserver/service/DeepSeekApiImpl.java new file mode 100644 index 0000000..a2cf87b --- /dev/null +++ b/src/main/java/com/supervision/pdfqaserver/service/DeepSeekApiImpl.java @@ -0,0 +1,34 @@ +package com.supervision.pdfqaserver.service; + +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.openai.OpenAiChatModel; +import reactor.core.publisher.Flux; +import org.springframework.stereotype.Service; +@Slf4j +@Service +@RequiredArgsConstructor +public class DeepSeekApiImpl implements AiCallService { + private final OpenAiChatModel ollamaChatModel; + @Override + public String call(String prompt) { + + if (prompt.endsWith("./no_think")){ + prompt = prompt.replace("./no_think", ""); + } + prompt = prompt.replace("./no_think", ""); + return ollamaChatModel.call(prompt); + } + + @Override + public Flux stream(Prompt prompt) { + return ollamaChatModel.stream(prompt); + } + + @Override + public void embedding(String text) { + + } +} diff --git a/src/main/java/com/supervision/pdfqaserver/service/KnowledgeGraphService.java b/src/main/java/com/supervision/pdfqaserver/service/KnowledgeGraphService.java index 0f95dd7..a5b579b 100644 --- a/src/main/java/com/supervision/pdfqaserver/service/KnowledgeGraphService.java +++ b/src/main/java/com/supervision/pdfqaserver/service/KnowledgeGraphService.java @@ -45,7 +45,7 @@ public interface KnowledgeGraphService { * @param eredtoList */ - void generateGraphSimple(List eredtoList); + void generateGraphSimple(List eredtoList); List truncateERE(List truncateDTOS); diff --git a/src/main/java/com/supervision/pdfqaserver/service/impl/DomainMetadataServiceImpl.java b/src/main/java/com/supervision/pdfqaserver/service/impl/DomainMetadataServiceImpl.java index f05d911..65fefff 100644 --- a/src/main/java/com/supervision/pdfqaserver/service/impl/DomainMetadataServiceImpl.java +++ b/src/main/java/com/supervision/pdfqaserver/service/impl/DomainMetadataServiceImpl.java @@ -86,6 +86,7 @@ public class DomainMetadataServiceImpl extends ServiceImpl erAttributes = this.listByDomainMetadataId(domainMetadataId); boolean exists = erAttributes.stream().anyMatch(item -> StrUtil.equals(item.getAttrName(), erAttribute.getAttrName()) - && StrUtil.equals(item.getAttrValueType(), erAttribute.getAttrValueType())); + && StrUtil.equals(item.getErLabel(), erAttribute.getErLabel())); if (exists){ log.info("属性已存在,{},不进行保存...", erAttribute.getAttrName()); return; diff --git a/src/main/java/com/supervision/pdfqaserver/service/impl/KnowledgeGraphServiceImpl.java b/src/main/java/com/supervision/pdfqaserver/service/impl/KnowledgeGraphServiceImpl.java index 370f35e..9f18a31 100644 --- a/src/main/java/com/supervision/pdfqaserver/service/impl/KnowledgeGraphServiceImpl.java +++ b/src/main/java/com/supervision/pdfqaserver/service/impl/KnowledgeGraphServiceImpl.java @@ -410,6 +410,7 @@ public class KnowledgeGraphServiceImpl implements KnowledgeGraphService { if (CollUtil.isEmpty(eredto.getEntities()) && CollUtil.isEmpty(eredto.getRelations())){ continue; } + eredto.setEn(); try { tripleToCypherExecutor.saveERE(eredto); } catch (Exception e) { diff --git a/src/main/java/com/supervision/pdfqaserver/service/impl/OllamaCallServiceImpl.java b/src/main/java/com/supervision/pdfqaserver/service/impl/OllamaCallServiceImpl.java index 40a46d9..a3de512 100644 --- a/src/main/java/com/supervision/pdfqaserver/service/impl/OllamaCallServiceImpl.java +++ b/src/main/java/com/supervision/pdfqaserver/service/impl/OllamaCallServiceImpl.java @@ -3,8 +3,15 @@ package com.supervision.pdfqaserver.service.impl; import com.supervision.pdfqaserver.service.AiCallService; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.embedding.*; import org.springframework.ai.ollama.OllamaChatModel; +import org.springframework.ai.ollama.OllamaEmbeddingModel; +import org.springframework.ai.ollama.api.OllamaOptions; import org.springframework.stereotype.Service; +import reactor.core.publisher.Flux; +import java.util.List; @Slf4j @Service @@ -12,9 +19,25 @@ import org.springframework.stereotype.Service; public class OllamaCallServiceImpl implements AiCallService { private final OllamaChatModel ollamaChatModel; + + private final OllamaEmbeddingModel embeddingModel; @Override public String call(String prompt) { return ollamaChatModel.call(prompt); } + + @Override + public Flux stream(Prompt prompt) { + return ollamaChatModel.stream(prompt); + } + + public void embedding(String text) { + + EmbeddingResponse embeddingResponse = embeddingModel.call( + new EmbeddingRequest(List.of("Hello World", "World is big and salvation is near"), + OllamaOptions.builder().model("quentinz/bge-large-zh-v1.5:latest").build())); + Embedding result = embeddingResponse.getResult(); + System.out.println(result); + } } diff --git a/src/main/java/com/supervision/pdfqaserver/service/impl/TripleToCypherExecutorImpl.java b/src/main/java/com/supervision/pdfqaserver/service/impl/TripleToCypherExecutorImpl.java index 56985b0..5c10579 100644 --- a/src/main/java/com/supervision/pdfqaserver/service/impl/TripleToCypherExecutorImpl.java +++ b/src/main/java/com/supervision/pdfqaserver/service/impl/TripleToCypherExecutorImpl.java @@ -58,9 +58,18 @@ public class TripleToCypherExecutorImpl implements TripleToCypherExecutor { } List domainMetadataDTOS = domainMetadataService.listByIntentionIds(relations.stream().map(Intention::getId).toList()); CypherSchemaDTO schemaDTO = this.queryRelationSchema(domainMetadataDTOS); + if (CollUtil.isEmpty(schemaDTO.getRelations()) && CollUtil.isEmpty(schemaDTO.getNodes())) { + log.info("没有找到匹配的关系或实体,query: {}", query); + return null; + } String prompt = promptMap.get(TEXT_TO_CYPHER_2); String format = StrUtil.format(prompt, Map.of("question", query, "schema", schemaDTO.format())); - return aiCallService.call(format); + String call = aiCallService.call(format); + if (StrUtil.equals(call,"I could not generate a Cypher script; the required information is not part of the Neo4j schema.")){ + log.info("大模型没能生成cypher,query: {}", query); + return null; + } + return call; } @Override diff --git a/src/main/resources/application.yml b/src/main/resources/application.yml index 381ee8e..ca796c8 100644 --- a/src/main/resources/application.yml +++ b/src/main/resources/application.yml @@ -16,6 +16,12 @@ spring: max-file-size: 10MB max-request-size: 100MB ai: + openai: + baseUrl: https://api.deepseek.com + apiKey: sk-0b2c506c47e74594b5361c0f6844fd25 + chat: + options: + model: deepseek-chat ollama: baseUrl: http://192.168.10.70:11434 chat: diff --git a/src/main/resources/mapper/ErAttributeMapper.xml b/src/main/resources/mapper/ErAttributeMapper.xml index ea156cd..13fd2f1 100644 --- a/src/main/resources/mapper/ErAttributeMapper.xml +++ b/src/main/resources/mapper/ErAttributeMapper.xml @@ -10,12 +10,13 @@ + - id,domain_metadata_id, + id,domain_metadata_id,er_label, attr_name,attr_value_type,er_type, create_time,update_time diff --git a/src/test/java/com/supervision/pdfqaserver/PdfQaServerApplicationTests.java b/src/test/java/com/supervision/pdfqaserver/PdfQaServerApplicationTests.java index f5550b6..d782961 100644 --- a/src/test/java/com/supervision/pdfqaserver/PdfQaServerApplicationTests.java +++ b/src/test/java/com/supervision/pdfqaserver/PdfQaServerApplicationTests.java @@ -1,23 +1,23 @@ package com.supervision.pdfqaserver; import com.supervision.pdfqaserver.constant.DocumentContentTypeEnum; +import com.supervision.pdfqaserver.domain.PdfAnalysisOutput; +import com.supervision.pdfqaserver.dto.CypherSchemaDTO; import com.supervision.pdfqaserver.dto.EREDTO; import com.supervision.pdfqaserver.dto.IntentDTO; import com.supervision.pdfqaserver.dto.TruncateDTO; -import com.supervision.pdfqaserver.service.ChinesEsToEnglishGenerator; -import com.supervision.pdfqaserver.service.KnowledgeGraphService; -import com.supervision.pdfqaserver.service.TripleConversionPipeline; +import com.supervision.pdfqaserver.service.*; import lombok.extern.slf4j.Slf4j; import org.junit.jupiter.api.Test; import org.neo4j.driver.*; import org.neo4j.driver.Record; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.context.SpringBootTest; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.stream.Collectors; import static org.neo4j.driver.Values.parameters; -import org.commonmark.node.*; @Slf4j @SpringBootTest @@ -27,15 +27,15 @@ class PdfQaServerApplicationTests { private KnowledgeGraphService knowledgeGraphService; @Test void generateGraphTest() { - knowledgeGraphService.generateGraph("40"); + knowledgeGraphService.generateGraph("15"); log.info("finish..."); } @Test void testGenerateGraph2() { - List eredtos = knowledgeGraphService.listPdfEREDTO("17"); + List eredtos = knowledgeGraphService.listPdfEREDTO("16"); - knowledgeGraphService.generateGraph(eredtos); + knowledgeGraphService.generateGraphSimple(eredtos); log.info("finish..."); } @@ -160,8 +160,54 @@ class PdfQaServerApplicationTests { @Test void generateGraphBaseTrainTest() { - knowledgeGraphService.generateGraphBaseTrain(14); + knowledgeGraphService.generateGraphBaseTrain(15); } + @Autowired + private AiCallService aiCallService; + @Test + void aiCallServiceCallTest() { + + String call = aiCallService.call("你好"); + System.out.println(call); + } + + @Test + void resetGraphDataTest() { + knowledgeGraphService.resetGraphData("15"); + } + + @Autowired + private PdfAnalysisOutputService pdfAnalysisOutputService; + @Test + void queryGraphTest() { + List pdfAnalysisOutputs = pdfAnalysisOutputService.queryByPdfId(15); + List newPdfAnalysisOutputs = new ArrayList<>(); + for (PdfAnalysisOutput pdfAnalysisOutput : pdfAnalysisOutputs) { + PdfAnalysisOutput pdf = new PdfAnalysisOutput(); + pdf.setContent(pdfAnalysisOutput.getContent()); + pdf.setPageNo(pdfAnalysisOutput.getPageNo()); + pdf.setDisplayOrder(pdfAnalysisOutput.getDisplayOrder()); + pdf.setTableTitle(pdfAnalysisOutput.getTableTitle()); + pdf.setLayoutType(pdfAnalysisOutput.getLayoutType()); + pdf.setPdfId(16); + newPdfAnalysisOutputs.add(pdf); + } + pdfAnalysisOutputService.saveBatch(newPdfAnalysisOutputs); + } + + @Autowired + private TripleToCypherExecutor tripleToCypherExecutor; + @Test + void testQueryGraph() { + CypherSchemaDTO schemaDTO = tripleToCypherExecutor.loadGraphSchema(); + System.out.println(schemaDTO); + } + + @Test + void testQueryGraph2() { + aiCallService.embedding(""); + System.out.println("done"); + } }