代码功能bug修复

优化问答功能
v_0.0.2
xueqingkun 2 weeks ago
parent eaf043aa07
commit 930fcff1f8

@ -100,6 +100,10 @@
<artifactId>commonmark-ext-gfm-tables</artifactId>
<version>0.21.0</version>
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-starter-model-openai</artifactId>
</dependency>
</dependencies>
<build>
<plugins>

@ -838,7 +838,6 @@ x {}
"- Express simple equality predicates in map patterns and move all other filters to a **WHERE** clause.\\n\\n"
"4. Return Clause Strategy:\\n"
"- RETURN every node and relationship mentioned, unless the user explicitly requests specific properties.\\n\\n"
"- The truncationId、name attribute of a node is very important, and each node needs to return truncationId、name .\\n\\n"
"5. Final Cypher Script Generation:\\n"
"- Respond with **only** the final Cypher query—no commentary or extra text.\\n"
"- Use OPTIONAL MATCH only if explicitly required by the user and supported by the schema.\\n"

@ -9,6 +9,7 @@ import com.supervision.pdfqaserver.dto.neo4j.NodeData;
import com.supervision.pdfqaserver.dto.neo4j.RelationObject;
import com.supervision.pdfqaserver.dto.neo4j.RelationshipData;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.neo4j.driver.*;
import org.neo4j.driver.Record;
import org.neo4j.driver.types.Node;
@ -20,6 +21,7 @@ import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import static org.neo4j.driver.Values.parameters;
@Slf4j
@Repository
@RequiredArgsConstructor
public class Neo4jRepository {
@ -153,6 +155,7 @@ public class Neo4jRepository {
if (StrUtil.isEmpty(nodeType)){
continue;
}
nodeType = nodeType.substring(1, nodeType.length()-1).replace("`", "");
String propertyName = record.get("propertyName").asString();
List<String> propertyTypes = record.get("propertyTypes").asList(Value::asString);
@ -160,8 +163,9 @@ public class Neo4jRepository {
TruncationERAttributeDTO attributeDTO = new TruncationERAttributeDTO(propertyName, null, CollUtil.getFirst(propertyTypes));
// 检查是否已存在该节点类型
final String nodeType_f = nodeType;
EntityExtractionDTO existingEntity = extractionDTOS.stream()
.filter(e -> StrUtil.equals(e.getEntityEn(), nodeType))
.filter(e -> StrUtil.equals(e.getEntityEn(), nodeType_f))
.findFirst().orElse(null);
if (existingEntity != null) {
@ -197,6 +201,7 @@ public class Neo4jRepository {
if (StrUtil.isEmpty(relType)){
continue;
}
relType = relType.substring(1, relType.length()-1).replace("`", "");
String propertyName = record.get("propertyName").asString();
List<String> propertyTypes = record.get("propertyTypes").asList(Value::asString);
@ -214,14 +219,15 @@ public class Neo4jRepository {
List<RelationExtractionDTO> relationExtractionDTOS = new ArrayList<>();
String queryEndpoints = """
MATCH (s)-[r:`{rtype}`]->(t)
MATCH (s)-[r: `{rtype}` ]->(t)
WITH labels(s)[0] AS src, labels(t)[0] AS tgt
RETURN src, tgt
""";
for (Map.Entry<String, List<Map<String, String>>> entry : relationProperties.entrySet()) {
String relType = entry.getKey();
List<Map<String, String>> properties = entry.getValue();
Result run = session.run(queryEndpoints, parameters("rtype", relType));
String formatted = StrUtil.format(queryEndpoints,Map.of("rtype",relType));
Result run = session.run(formatted);
for (Record record : run.list()) {
String sourceType = record.get("src").asString();
String targetType = record.get("tgt").asString();

@ -38,6 +38,12 @@ public class ErAttribute implements Serializable {
*/
private String erType;
/**
*
*/
private String erLabel;
/**
*

@ -55,6 +55,14 @@ public class CypherSchemaDTO {
return result;
}
public List<EntityExtractionDTO> getNodes() {
return nodes;
}
public List<RelationExtractionDTO> getRelations() {
return relations;
}
/**
* DTO
* @param entity

@ -80,14 +80,14 @@ public class DomainMetadataDTO {
if (StrUtil.equals(erAttribute.getDomainMetadataId(),this.id)){
if(StrUtil.equals(erAttribute.getErType(),"1")){
// 节点数据
if (StrUtil.equals(erAttribute.getAttrName(),this.sourceType)) {
if (StrUtil.equals(erAttribute.getErLabel(),this.sourceType)) {
this.sourceAttributes.add(new ERAttributeDTO(erAttribute));
}
if (StrUtil.equals(erAttribute.getAttrName(),this.targetType)) {
if (StrUtil.equals(erAttribute.getErLabel(),this.targetType)) {
this.targetAttributes.add(new ERAttributeDTO(erAttribute));
}
}else {
if (StrUtil.equals(erAttribute.getAttrName(),this.relation)) {
if (StrUtil.equals(erAttribute.getErLabel(),this.relation)) {
this.relationAttributes.add(new ERAttributeDTO(erAttribute));
}
}

@ -23,6 +23,11 @@ public class ERAttributeDTO {
*/
private String attrName;
/**
*
*/
private String erLabel;
/**
*
*/
@ -37,14 +42,6 @@ public class ERAttributeDTO {
public ERAttributeDTO() {
}
public ERAttributeDTO(String id, String domainMetadataId, String erName, String attrName, String attrValueType, String erType) {
this.id = id;
this.domainMetadataId = domainMetadataId;
this.erName = erName;
this.attrName = attrName;
this.attrValueType = attrValueType;
this.erType = erType;
}
public ERAttributeDTO(String attrName) {
this.attrName = attrName;
@ -56,6 +53,7 @@ public class ERAttributeDTO {
this.attrName = erAttribute.getAttrName();
this.attrValueType = erAttribute.getAttrValueType();
this.erType = erAttribute.getErType();
this.erLabel = erAttribute.getErLabel();
}
public ErAttribute toErAttribute() {
@ -65,6 +63,7 @@ public class ERAttributeDTO {
erAttribute.setAttrName(this.attrName);
erAttribute.setAttrValueType(this.attrValueType);
erAttribute.setErType(this.erType);
erAttribute.setErLabel(this.erLabel);
return erAttribute;
}
}

@ -232,6 +232,23 @@ public class EREDTO {
}
public void setEn() {
for (EntityExtractionDTO entity : entities) {
entity.setEntityEn(entity.getEntity());
for (TruncationERAttributeDTO attribute : entity.getAttributes()) {
attribute.setAttributeEn(attribute.getAttribute());
}
}
for (RelationExtractionDTO relation : relations) {
relation.setRelationEn(relation.getRelation());
relation.setSourceTypeEn(relation.getSourceType());
relation.setTargetTypeEn(relation.getTargetType());
for (TruncationERAttributeDTO attribute : relation.getAttributes()) {
attribute.setAttributeEn(attribute.getAttribute());
}
}
}
private void setAttributeEn(TruncationERAttributeDTO attribute, List<ChineseEnglishWords> wordsList) {
if (null == attribute || CollUtil.isEmpty(wordsList)){
return;

@ -13,4 +13,6 @@ public interface AiCallService {
String call(String prompt);
Flux<ChatResponse> stream(Prompt prompt);
abstract void embedding(String text);
}

@ -0,0 +1,34 @@
package com.supervision.pdfqaserver.service;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.openai.OpenAiChatModel;
import reactor.core.publisher.Flux;
import org.springframework.stereotype.Service;
@Slf4j
@Service
@RequiredArgsConstructor
public class DeepSeekApiImpl implements AiCallService {
private final OpenAiChatModel ollamaChatModel;
@Override
public String call(String prompt) {
if (prompt.endsWith("./no_think")){
prompt = prompt.replace("./no_think", "");
}
prompt = prompt.replace("./no_think", "");
return ollamaChatModel.call(prompt);
}
@Override
public Flux<ChatResponse> stream(Prompt prompt) {
return ollamaChatModel.stream(prompt);
}
@Override
public void embedding(String text) {
}
}

@ -86,6 +86,7 @@ public class DomainMetadataServiceImpl extends ServiceImpl<DomainMetadataMapper,
for (ERAttributeDTO relationAttribute : relationAttributes) {
relationAttribute.setDomainMetadataId(metadata.getId());
relationAttribute.setErType("2");
relationAttribute.setErLabel(metadata.getRelation());
erAttributeService.saveIfAbsents(relationAttribute.toErAttribute(), metadata.getId());
}
}
@ -96,6 +97,7 @@ public class DomainMetadataServiceImpl extends ServiceImpl<DomainMetadataMapper,
for (ERAttributeDTO nodeAttribute : nodeAttributes) {
nodeAttribute.setDomainMetadataId(metadata.getId());
nodeAttribute.setErType("1");
nodeAttribute.setErLabel(nodeAttribute.getAttrName());
erAttributeService.saveIfAbsents(nodeAttribute.toErAttribute(), metadata.getId());
}
}

@ -26,7 +26,7 @@ public class ErAttributeServiceImpl extends ServiceImpl<ErAttributeMapper, ErAtt
Assert.notEmpty(domainMetadataId, "领域分类id不能为空");
List<ErAttribute> erAttributes = this.listByDomainMetadataId(domainMetadataId);
boolean exists = erAttributes.stream().anyMatch(item -> StrUtil.equals(item.getAttrName(), erAttribute.getAttrName())
&& StrUtil.equals(item.getAttrValueType(), erAttribute.getAttrValueType()));
&& StrUtil.equals(item.getErLabel(), erAttribute.getErLabel()));
if (exists){
log.info("属性已存在,{},不进行保存...", erAttribute.getAttrName());
return;

@ -410,6 +410,7 @@ public class KnowledgeGraphServiceImpl implements KnowledgeGraphService {
if (CollUtil.isEmpty(eredto.getEntities()) && CollUtil.isEmpty(eredto.getRelations())){
continue;
}
eredto.setEn();
try {
tripleToCypherExecutor.saveERE(eredto);
} catch (Exception e) {

@ -3,8 +3,15 @@ package com.supervision.pdfqaserver.service.impl;
import com.supervision.pdfqaserver.service.AiCallService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.embedding.*;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.ai.ollama.OllamaEmbeddingModel;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Flux;
import java.util.List;
@Slf4j
@Service
@ -12,9 +19,25 @@ import org.springframework.stereotype.Service;
public class OllamaCallServiceImpl implements AiCallService {
private final OllamaChatModel ollamaChatModel;
private final OllamaEmbeddingModel embeddingModel;
@Override
public String call(String prompt) {
return ollamaChatModel.call(prompt);
}
@Override
public Flux<ChatResponse> stream(Prompt prompt) {
return ollamaChatModel.stream(prompt);
}
public void embedding(String text) {
EmbeddingResponse embeddingResponse = embeddingModel.call(
new EmbeddingRequest(List.of("Hello World", "World is big and salvation is near"),
OllamaOptions.builder().model("quentinz/bge-large-zh-v1.5:latest").build()));
Embedding result = embeddingResponse.getResult();
System.out.println(result);
}
}

@ -58,9 +58,18 @@ public class TripleToCypherExecutorImpl implements TripleToCypherExecutor {
}
List<DomainMetadataDTO> domainMetadataDTOS = domainMetadataService.listByIntentionIds(relations.stream().map(Intention::getId).toList());
CypherSchemaDTO schemaDTO = this.queryRelationSchema(domainMetadataDTOS);
if (CollUtil.isEmpty(schemaDTO.getRelations()) && CollUtil.isEmpty(schemaDTO.getNodes())) {
log.info("没有找到匹配的关系或实体query: {}", query);
return null;
}
String prompt = promptMap.get(TEXT_TO_CYPHER_2);
String format = StrUtil.format(prompt, Map.of("question", query, "schema", schemaDTO.format()));
return aiCallService.call(format);
String call = aiCallService.call(format);
if (StrUtil.equals(call,"I could not generate a Cypher script; the required information is not part of the Neo4j schema.")){
log.info("大模型没能生成cypherquery: {}", query);
return null;
}
return call;
}
@Override

@ -16,6 +16,12 @@ spring:
max-file-size: 10MB
max-request-size: 100MB
ai:
openai:
baseUrl: https://api.deepseek.com
apiKey: sk-0b2c506c47e74594b5361c0f6844fd25
chat:
options:
model: deepseek-chat
ollama:
baseUrl: http://192.168.10.70:11434
chat:

@ -10,12 +10,13 @@
<result property="attrName" column="attr_name" jdbcType="VARCHAR"/>
<result property="attrValueType" column="attr_value_type" jdbcType="VARCHAR"/>
<result property="erType" column="er_type" jdbcType="VARCHAR"/>
<result property="erLabel" column="er_label" jdbcType="VARCHAR"/>
<result property="createTime" column="create_time" jdbcType="TIMESTAMP"/>
<result property="updateTime" column="update_time" jdbcType="TIMESTAMP"/>
</resultMap>
<sql id="Base_Column_List">
id,domain_metadata_id,
id,domain_metadata_id,er_label,
attr_name,attr_value_type,er_type,
create_time,update_time
</sql>

@ -1,23 +1,23 @@
package com.supervision.pdfqaserver;
import com.supervision.pdfqaserver.constant.DocumentContentTypeEnum;
import com.supervision.pdfqaserver.domain.PdfAnalysisOutput;
import com.supervision.pdfqaserver.dto.CypherSchemaDTO;
import com.supervision.pdfqaserver.dto.EREDTO;
import com.supervision.pdfqaserver.dto.IntentDTO;
import com.supervision.pdfqaserver.dto.TruncateDTO;
import com.supervision.pdfqaserver.service.ChinesEsToEnglishGenerator;
import com.supervision.pdfqaserver.service.KnowledgeGraphService;
import com.supervision.pdfqaserver.service.TripleConversionPipeline;
import com.supervision.pdfqaserver.service.*;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test;
import org.neo4j.driver.*;
import org.neo4j.driver.Record;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import static org.neo4j.driver.Values.parameters;
import org.commonmark.node.*;
@Slf4j
@SpringBootTest
@ -27,15 +27,15 @@ class PdfQaServerApplicationTests {
private KnowledgeGraphService knowledgeGraphService;
@Test
void generateGraphTest() {
knowledgeGraphService.generateGraph("40");
knowledgeGraphService.generateGraph("15");
log.info("finish...");
}
@Test
void testGenerateGraph2() {
List<EREDTO> eredtos = knowledgeGraphService.listPdfEREDTO("17");
List<EREDTO> eredtos = knowledgeGraphService.listPdfEREDTO("16");
knowledgeGraphService.generateGraph(eredtos);
knowledgeGraphService.generateGraphSimple(eredtos);
log.info("finish...");
}
@ -160,8 +160,54 @@ class PdfQaServerApplicationTests {
@Test
void generateGraphBaseTrainTest() {
knowledgeGraphService.generateGraphBaseTrain(14);
knowledgeGraphService.generateGraphBaseTrain(15);
}
@Autowired
private AiCallService aiCallService;
@Test
void aiCallServiceCallTest() {
String call = aiCallService.call("你好");
System.out.println(call);
}
@Test
void resetGraphDataTest() {
knowledgeGraphService.resetGraphData("15");
}
@Autowired
private PdfAnalysisOutputService pdfAnalysisOutputService;
@Test
void queryGraphTest() {
List<PdfAnalysisOutput> pdfAnalysisOutputs = pdfAnalysisOutputService.queryByPdfId(15);
List<PdfAnalysisOutput> newPdfAnalysisOutputs = new ArrayList<>();
for (PdfAnalysisOutput pdfAnalysisOutput : pdfAnalysisOutputs) {
PdfAnalysisOutput pdf = new PdfAnalysisOutput();
pdf.setContent(pdfAnalysisOutput.getContent());
pdf.setPageNo(pdfAnalysisOutput.getPageNo());
pdf.setDisplayOrder(pdfAnalysisOutput.getDisplayOrder());
pdf.setTableTitle(pdfAnalysisOutput.getTableTitle());
pdf.setLayoutType(pdfAnalysisOutput.getLayoutType());
pdf.setPdfId(16);
newPdfAnalysisOutputs.add(pdf);
}
pdfAnalysisOutputService.saveBatch(newPdfAnalysisOutputs);
}
@Autowired
private TripleToCypherExecutor tripleToCypherExecutor;
@Test
void testQueryGraph() {
CypherSchemaDTO schemaDTO = tripleToCypherExecutor.loadGraphSchema();
System.out.println(schemaDTO);
}
@Test
void testQueryGraph2() {
aiCallService.embedding("");
System.out.println("done");
}
}

Loading…
Cancel
Save