问答功能优化-初始化表

v_0.0.2
xueqingkun 1 day ago
parent 2ea04d7325
commit fe1a6f1b1b

@ -104,6 +104,11 @@
<groupId>org.springframework.ai</groupId> <groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-starter-model-openai</artifactId> <artifactId>spring-ai-starter-model-openai</artifactId>
</dependency> </dependency>
<dependency>
<groupId>com.hankcs</groupId>
<artifactId>hanlp</artifactId>
<version>portable-1.8.6</version>
</dependency>
</dependencies> </dependencies>
<build> <build>
<plugins> <plugins>

@ -87,6 +87,11 @@ public class PromptCache {
public static final String CLASSIFY_QUERY_INTENT = "CLASSIFY_QUERY_INTENT"; public static final String CLASSIFY_QUERY_INTENT = "CLASSIFY_QUERY_INTENT";
public static final String TEXT_TO_CYPHER_2 = "TEXT_TO_CYPHER_2"; public static final String TEXT_TO_CYPHER_2 = "TEXT_TO_CYPHER_2";
/**
 * Prompt key for the first-pass text-to-Cypher template; DataCompareRetriever looks this
 * key up in {@code promptMap} for its initial attempt.
 */
public static final String TEXT_TO_CYPHER_3 = "TEXT_TO_CYPHER_3";
/**
 * Prompt key for the follow-up text-to-Cypher template; DataCompareRetriever looks this
 * key up in {@code promptMap} when the first batch of statements returned no data.
 */
public static final String TEXT_TO_CYPHER_4 = "TEXT_TO_CYPHER_4";
public static final Map<String, String> promptMap = new HashMap<>(); public static final Map<String, String> promptMap = new HashMap<>();
@ -110,6 +115,8 @@ public class PromptCache {
promptMap.put(EXTRACT_ERE_BASE_INTENT, EXTRACT_ERE_BASE_INTENT_PROMPT); promptMap.put(EXTRACT_ERE_BASE_INTENT, EXTRACT_ERE_BASE_INTENT_PROMPT);
promptMap.put(CLASSIFY_QUERY_INTENT, CLASSIFY_QUERY_INTENT_PROMPT); promptMap.put(CLASSIFY_QUERY_INTENT, CLASSIFY_QUERY_INTENT_PROMPT);
promptMap.put(TEXT_TO_CYPHER_2, TEXT_TO_CYPHER_2_PROMPT); promptMap.put(TEXT_TO_CYPHER_2, TEXT_TO_CYPHER_2_PROMPT);
promptMap.put(TEXT_TO_CYPHER_3, TEXT_TO_CYPHER_3_PROMPT);
promptMap.put(TEXT_TO_CYPHER_4, TEXT_TO_CYPHER_4_PROMPT);
} }
@ -827,4 +834,98 @@ public class PromptCache {
- ****MATCH **** - ****MATCH ****
- cyphercypher/no_think - cyphercypher/no_think
"""; """;
/**
 * First-pass text-to-Cypher prompt template.
 * Contains the named placeholders {query}, {schema} and {env}; DataCompareRetriever
 * fills them with the user question, the formatted graph schema and a current-time line.
 */
private static final String TEXT_TO_CYPHER_3_PROMPT = """
CypherCypher`neo4j_schema`
```text
{query}
```
neo4j_schemaJSON
```schema
{schema}
```
#
{env}
1.
- 线schema
-
-
-
-
2.
-
- ********
3. MATCH
- 使,****
- ****`-[r:REL_TYPE]->`
- 使
- WHERE
4. RETURN
- ********
- ****
5. Cypher
- Cypher```cypher ```
- ****MATCH ****
- neo4j_schemacypher便
- cypher:['cypher1','...']
- cyphercypher/no_think
""";
/**
 * Follow-up text-to-Cypher prompt template, used after a first batch of statements
 * returned no data.
 * Named placeholders: {query}, {schema}, {env} and {cypher} (the statements that were
 * already tried). The placeholder names must match the map keys passed by the caller:
 * the schema placeholder was previously misspelled and therefore never substituted, and
 * the env placeholder carried a stray '$' that survived substitution.
 */
private static final String TEXT_TO_CYPHER_4_PROMPT = """
Cyphercyphercypher便`neo4j_schema`cypher
```text
{query}
```
neo4j_schemaJSON
```schema
{schema}
```
#
{env}
# cypher
```json
{cypher}
```
1.
- 线schema
-
-
-
-
- cypher
2.
-
- ********
3. MATCH
- 使,****
- ****`-[r:REL_TYPE]->`
- 使
- WHERE
- cypher
4. RETURN
- ********
- ****
5. Cypher
- Cypher```cypher ```
- ****MATCH ****
- neo4j_schemacypher便
- cypher:['cypher1','...']
- cyphercypher/no_think
""";
} }

@ -165,12 +165,19 @@ public class Neo4jRepository {
// 检查是否已存在该节点类型 // 检查是否已存在该节点类型
final String nodeType_f = nodeType; final String nodeType_f = nodeType;
EntityExtractionDTO existingEntity = extractionDTOS.stream() EntityExtractionDTO existingEntity = extractionDTOS.stream()
.filter(e -> StrUtil.equals(e.getEntityEn(), nodeType_f)) .filter(e -> StrUtil.equals(e.getEntity(), nodeType_f))
.findFirst().orElse(null); .findFirst().orElse(null);
if (existingEntity != null) { if (existingEntity != null) {
// 如果已存在,添加属性 // 如果已存在,添加属性
existingEntity.getAttributes().add(attributeDTO); List<TruncationERAttributeDTO> attributes = existingEntity.getAttributes();
boolean noneMatch = attributes.stream().noneMatch(
attr -> StrUtil.equals(attr.getAttribute(), attributeDTO.getAttribute())
);
if (noneMatch) {
// 如果属性不存在,添加属性
attributes.add(attributeDTO);
}
} else { } else {
// 如果不存在创建新的实体DTO // 如果不存在创建新的实体DTO
List<TruncationERAttributeDTO> truncationERAttributeDTOS = new ArrayList<>(); List<TruncationERAttributeDTO> truncationERAttributeDTOS = new ArrayList<>();
@ -187,7 +194,7 @@ public class Neo4jRepository {
* schema * schema
* @return * @return
*/ */
public List<RelationExtractionDTO> getRelationSchema(){ public List<RelationExtractionDTO> getRelationSchema() {
String queryProper = """ String queryProper = """
CALL db.schema.relTypeProperties() CALL db.schema.relTypeProperties()
YIELD relType, propertyName, propertyTypes YIELD relType, propertyName, propertyTypes
@ -198,10 +205,10 @@ public class Neo4jRepository {
Result result = session.run(queryProper); Result result = session.run(queryProper);
for (Record record : result.list()) { for (Record record : result.list()) {
String relType = record.get("relType").asString(); String relType = record.get("relType").asString();
if (StrUtil.isEmpty(relType)){ if (StrUtil.isEmpty(relType)) {
continue; continue;
} }
relType = relType.substring(1, relType.length()-1).replace("`", ""); relType = relType.substring(1, relType.length() - 1).replace("`", "");
String propertyName = record.get("propertyName").asString(); String propertyName = record.get("propertyName").asString();
List<String> propertyTypes = record.get("propertyTypes").asList(Value::asString); List<String> propertyTypes = record.get("propertyTypes").asList(Value::asString);
@ -209,7 +216,7 @@ public class Neo4jRepository {
boolean noneMatch = properties.stream().noneMatch( boolean noneMatch = properties.stream().noneMatch(
prop -> StrUtil.equals(prop.get("propertyName"), propertyName) prop -> StrUtil.equals(prop.get("propertyName"), propertyName)
); );
if (noneMatch){ if (noneMatch) {
Map<String, String> propMap = new HashMap<>(); Map<String, String> propMap = new HashMap<>();
propMap.put("propertyName", propertyName); propMap.put("propertyName", propertyName);
propMap.put("propertyTypes", CollUtil.getFirst(propertyTypes)); propMap.put("propertyTypes", CollUtil.getFirst(propertyTypes));
@ -219,14 +226,14 @@ public class Neo4jRepository {
List<RelationExtractionDTO> relationExtractionDTOS = new ArrayList<>(); List<RelationExtractionDTO> relationExtractionDTOS = new ArrayList<>();
String queryEndpoints = """ String queryEndpoints = """
MATCH (s)-[r: `{rtype}` ]->(t) MATCH (s)-[r: `{rtype}` ]->(t)
WITH labels(s)[0] AS src, labels(t)[0] AS tgt WITH labels(s)[0] AS src, labels(t)[0] AS tgt
RETURN src, tgt RETURN src, tgt
"""; """;
for (Map.Entry<String, List<Map<String, String>>> entry : relationProperties.entrySet()) { for (Map.Entry<String, List<Map<String, String>>> entry : relationProperties.entrySet()) {
String relType = entry.getKey(); String relType = entry.getKey();
List<Map<String, String>> properties = entry.getValue(); List<Map<String, String>> properties = entry.getValue();
String formatted = StrUtil.format(queryEndpoints,Map.of("rtype",relType)); String formatted = StrUtil.format(queryEndpoints, Map.of("rtype", relType));
Result run = session.run(formatted); Result run = session.run(formatted);
for (Record record : run.list()) { for (Record record : run.list()) {
String sourceType = record.get("src").asString(); String sourceType = record.get("src").asString();
@ -234,12 +241,32 @@ public class Neo4jRepository {
List<TruncationERAttributeDTO> attributeDTOS = properties.stream().map( List<TruncationERAttributeDTO> attributeDTOS = properties.stream().map(
prop -> new TruncationERAttributeDTO(prop.get("propertyName"), null, prop.get("propertyTypes")) prop -> new TruncationERAttributeDTO(prop.get("propertyName"), null, prop.get("propertyTypes"))
).collect(Collectors.toList()); ).collect(Collectors.toList());
RelationExtractionDTO relationExtractionDTO = new RelationExtractionDTO(null,null,sourceType, RelationExtractionDTO relationExtractionDTO = new RelationExtractionDTO(null, null, sourceType,
relType, relType,
null, null,
targetType, targetType,
attributeDTOS); attributeDTOS);
relationExtractionDTOS.add(relationExtractionDTO); // 合并关系数据
Optional<RelationExtractionDTO> optional = relationExtractionDTOS.stream().filter(rel ->
StrUtil.equals(rel.getSourceType(), sourceType) &&
StrUtil.equals(rel.getRelation(), relType) &&
StrUtil.equals(rel.getTargetType(), targetType)).findFirst();
if (optional.isPresent()) {
List<TruncationERAttributeDTO> attributes = optional.get().getAttributes();
for (TruncationERAttributeDTO attribute : attributeDTOS) {
boolean noneMatch = attributes.stream().noneMatch(
attr -> StrUtil.equals(attr.getAttribute(), attribute.getAttribute())
);
if (noneMatch) {
attributes.add(attribute);
}
}
} else {
// 如果不存在,直接添加
relationExtractionDTO.setAttributes(attributeDTOS);
relationExtractionDTOS.add(relationExtractionDTO);
}
} }
} }
return relationExtractionDTOS; return relationExtractionDTOS;

@ -0,0 +1,54 @@
package com.supervision.pdfqaserver.domain;
import com.baomidou.mybatisplus.annotation.*;
import java.io.Serializable;
import java.time.LocalDateTime;
import com.supervision.pdfqaserver.config.VectorTypeHandler;
import lombok.Data;
/**
 * Persistent entity holding one embedded piece of graph-schema text.
 * Mapped to table {@code node_relation_vector}.
 * @TableName node_relation_vector
 */
@TableName(value ="node_relation_vector")
@Data
public class NodeRelationVector implements Serializable {
/**
 * Primary key.
 */
@TableId
private String id;
/**
 * The text that was embedded.
 */
private String content;
/**
 * Embedding vector of {@link #content}; converted to/from the column by VectorTypeHandler.
 */
@TableField(typeHandler = VectorTypeHandler.class)
private float[] embedding;
/**
 * Kind of content. NodeRelationVectorServiceImpl writes "N" for a node entity name,
 * "R" for a relation name and "ER" for the joined "source relation target" text.
 */
private String contentType;
/**
 * Creation timestamp, auto-filled on insert.
 */
@TableField(fill = FieldFill.INSERT)
private LocalDateTime createTime;
/**
 * Last-update timestamp, auto-filled on insert and on update.
 */
@TableField(fill = FieldFill.INSERT_UPDATE)
private LocalDateTime updateTime;
// Not a table column.
@TableField(exist = false)
private static final long serialVersionUID = 1L;
}

@ -41,15 +41,24 @@ public class CypherSchemaDTO {
/** /**
* DTO * DTO
* @param sourceOrTargetType * @param str
* @return * @return
*/ */
public List<RelationExtractionDTO> getRelations(String sourceOrTargetType) { public List<RelationExtractionDTO> getRelations(String str) {
List<RelationExtractionDTO> result = new ArrayList<>(); List<RelationExtractionDTO> result = new ArrayList<>();
for (RelationExtractionDTO relationDTO : relations) { for (RelationExtractionDTO relationDTO : relations) {
if (StrUtil.equals(relationDTO.getSourceType(), sourceOrTargetType) || if (StrUtil.equals(relationDTO.getSourceType(), str) ||
StrUtil.equals(relationDTO.getTargetType(), sourceOrTargetType)) { StrUtil.equals(relationDTO.getTargetType(), str) ||
result.add(relationDTO); StrUtil.equals(relationDTO.getRelation(), str)) {
boolean noneMatch = result.stream().noneMatch(
r -> StrUtil.equals(r.getSourceType(), relationDTO.getSourceType()) &&
StrUtil.equals(r.getRelation(), relationDTO.getRelation()) &&
StrUtil.equals(r.getTargetType(), relationDTO.getTargetType())
);
if (noneMatch){
result.add(relationDTO);
}
} }
} }
return result; return result;
@ -90,6 +99,9 @@ public class CypherSchemaDTO {
for (TruncationERAttributeDTO attribute : attributes) { for (TruncationERAttributeDTO attribute : attributes) {
boolean none = nodeAttr.entrySet().stream().noneMatch( boolean none = nodeAttr.entrySet().stream().noneMatch(
entry -> StrUtil.equals(entry.getKey(), attribute.getAttribute())); entry -> StrUtil.equals(entry.getKey(), attribute.getAttribute()));
if ("truncationId".equals(attribute.getAttribute())){
continue;
}
if (none){ if (none){
nodeAttr.set(attribute.getAttribute(), attribute.getDataType()); nodeAttr.set(attribute.getAttribute(), attribute.getDataType());
} }
@ -108,10 +120,14 @@ public class CypherSchemaDTO {
} }
json.set("_endpoints", new JSONArray(new String[]{sourceType, targetType})); json.set("_endpoints", new JSONArray(new String[]{sourceType, targetType}));
for (TruncationERAttributeDTO attribute : relation.getAttributes()) { for (TruncationERAttributeDTO attribute : relation.getAttributes()) {
if ("truncationId".equals(attribute.getAttribute())){
continue;
}
boolean none = json.entrySet().stream().noneMatch( boolean none = json.entrySet().stream().noneMatch(
entry -> StrUtil.equals(entry.getKey(), attribute.getAttribute()) entry -> StrUtil.equals(entry.getKey(), attribute.getAttribute())
); );
if (none) { if (none) {
json.set(attribute.getAttribute(), attribute.getDataType()); json.set(attribute.getAttribute(), attribute.getDataType());
} }
} }

@ -0,0 +1,43 @@
package com.supervision.pdfqaserver.dto;
import cn.hutool.core.util.StrUtil;
import lombok.Data;
@Data
public class TextTerm {
/**
*
*/
public String word;
/**
*
*/
public String label;
private float[] embedding;
public String getLabelValue() {
if (StrUtil.equalsAny(label,"n","nl","nr","ns","nsf","nz")){
return word;
}
if (StrUtil.equals(label,"nt")){
return "机构";
}
if (StrUtil.equalsAny(label,"ntc","公司")){
return "公司";
}
if (StrUtil.equals(label,"ntcf")){
return "工厂";
}
if (StrUtil.equals(label,"nto")){
return "政府机构";
}
if (StrUtil.equals(label,"企业")){
return "企业";
}
return null;
}
}

@ -9,7 +9,7 @@ import java.util.Map;
@Data @Data
public class NodeDTO { public class NodeDTO {
private long id; private Long id;
private String elementId; private String elementId;
@ -27,4 +27,9 @@ public class NodeDTO {
this.properties = internalNode.asMap(); this.properties = internalNode.asMap();
this.labels = internalNode.labels(); this.labels = internalNode.labels();
} }
/**
 * Resets the node's graph identifiers ({@code elementId} and {@code id}) to {@code null};
 * labels and properties are left untouched.
 */
public void clearGraphElement() {
    this.elementId = null;
    this.id = null;
}
} }

@ -9,17 +9,17 @@ import java.util.Map;
public class RelationshipValueDTO { public class RelationshipValueDTO {
private long start; private Long start;
private String startElementId; private String startElementId;
private long end; private Long end;
private String endElementId; private String endElementId;
private String type; private String type;
private long id; private Long id;
private String elementId; private String elementId;
@ -30,7 +30,7 @@ public class RelationshipValueDTO {
} }
public RelationshipValueDTO(InternalRelationship relationship) { public RelationshipValueDTO(InternalRelationship relationship) {
this.start = (int) relationship.startNodeId(); this.start = relationship.startNodeId();
this.startElementId = relationship.startNodeElementId(); this.startElementId = relationship.startNodeElementId();
this.end = relationship.endNodeId(); this.end = relationship.endNodeId();
this.endElementId = relationship.endNodeElementId(); this.endElementId = relationship.endNodeElementId();
@ -40,4 +40,14 @@ public class RelationshipValueDTO {
this.properties = relationship.asMap(); this.properties = relationship.asMap();
} }
/**
 * Resets every graph identifier of this relationship to {@code null}: its own id pair and
 * the id pairs of both endpoints. Type and properties are left untouched.
 */
public void clearGraphElement() {
    // start endpoint
    this.start = null;
    this.startElementId = null;
    // end endpoint
    this.end = null;
    this.endElementId = null;
    // the relationship itself
    this.id = null;
    this.elementId = null;
}
} }

@ -0,0 +1,22 @@
package com.supervision.pdfqaserver.mapper;
import com.supervision.pdfqaserver.domain.NodeRelationVector;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import java.util.List;
/**
 * MyBatis-Plus mapper for the {@code node_relation_vector} table.
 * The SQL behind the two custom methods is not part of this file; the parameter
 * descriptions below are inferred from the names - TODO confirm against the mapper XML.
 * NOTE(review): TextVectorMapper annotates its multi-argument query with {@code @Param};
 * these methods do not, so binding relies on compiled parameter names - confirm.
 * @author Administrator
 * @description node_relation_vector()Mapper
 * @createDate 2025-06-18 13:38:02
 * @Entity com.supervision.pdfqaserver.domain.NodeRelationVector
 */
public interface NodeRelationVectorMapper extends BaseMapper<NodeRelationVector> {
/**
 * Finds stored vectors similar to the given embedding.
 * @param embedding query vector
 * @param threshold similarity threshold (direction and scale defined by the SQL - TODO confirm)
 * @param contentType content types to search
 * @param limit maximum number of rows
 * @return matching rows
 */
List<NodeRelationVector> findSimilarByCosine(float[] embedding, double threshold, List<String> contentType, int limit);
/**
 * Scores the given embedding against the stored row(s) for {@code content}.
 * @param embedding query vector
 * @param content stored content to compare with
 * @return the score computed by the SQL; presumably null when no row matches - TODO confirm
 */
Double matchContentScore(float[] embedding, String content);
}

@ -15,7 +15,10 @@ import java.util.List;
*/ */
public interface TextVectorMapper extends BaseMapper<TextVector> { public interface TextVectorMapper extends BaseMapper<TextVector> {
List<TextVectorDTO> findSimilarByCosine(@Param("embedding")float[] embedding, @Param("threshold") double threshold, @Param("limit")int limit); List<TextVectorDTO> findSimilarByCosine(@Param("embedding")float[] embedding,
@Param("threshold") double threshold,
@Param("categoryId") String categoryId,
@Param("limit")int limit);
} }

@ -4,6 +4,7 @@ import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.embedding.Embedding; import org.springframework.ai.embedding.Embedding;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
import java.util.List;
/** /**
* @description: AI * @description: AI
@ -16,4 +17,6 @@ public interface AiCallService {
Flux<ChatResponse> stream(Prompt prompt); Flux<ChatResponse> stream(Prompt prompt);
Embedding embedding(String text); Embedding embedding(String text);
List<Embedding> embedding(List<String> texts);
} }

@ -7,7 +7,8 @@ import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.embedding.Embedding; import org.springframework.ai.embedding.Embedding;
import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.openai.OpenAiChatModel;
import reactor.core.publisher.Flux; import reactor.core.publisher.Flux;
import org.springframework.stereotype.Service; import java.util.List;
@Slf4j @Slf4j
//@Service //@Service
@RequiredArgsConstructor @RequiredArgsConstructor
@ -33,4 +34,9 @@ public class DeepSeekApiImpl implements AiCallService {
return null; return null;
} }
/**
 * Batch embedding is not implemented for this backend.
 *
 * @param texts texts to embed (ignored)
 * @return always an empty, immutable list rather than {@code null}, so callers that
 *         iterate the result do not fail with a NullPointerException
 */
@Override
public List<Embedding> embedding(List<String> texts) {
    return List.of();
}
} }

@ -0,0 +1,27 @@
package com.supervision.pdfqaserver.service;
import com.supervision.pdfqaserver.domain.NodeRelationVector;
import com.baomidou.mybatisplus.extension.service.IService;
import com.supervision.pdfqaserver.dto.CypherSchemaDTO;
import java.util.List;
/**
 * Service for the {@code node_relation_vector} table: rebuilding the embedded schema
 * texts and querying them by vector similarity.
 * @author Administrator
 * @description node_relation_vector()Service
 * @createDate 2025-06-18 13:38:02
 */
public interface NodeRelationVectorService extends IService<NodeRelationVector> {
/**
 * Rebuilds the stored vectors from the given schema. The visible implementation deletes
 * all existing rows, then embeds and stores relation names ("R"), joined
 * "source relation target" texts ("ER") and node entity names ("N").
 * @param cypherSchemaDTO schema whose nodes and relations are embedded
 */
void refreshSchemaSegmentVector(CypherSchemaDTO cypherSchemaDTO);
/**
 * Finds stored vectors similar to the given embedding; the visible implementation
 * delegates directly to the mapper query of the same purpose.
 * @param embedding query vector
 * @param threshold similarity threshold passed through to the mapper
 * @param contentType content types to search
 * @param limit maximum number of rows
 * @return matching rows
 */
List<NodeRelationVector> matchSimilarByCosine(float[] embedding, double threshold , List<String> contentType, int limit);
/**
 * Scores an embedding against stored content.
 * @param embedding query vector
 * @param content stored content to compare with
 * @return the mapper's score; the visible implementation returns 0.0 without querying
 *         when content is empty or the embedding is null/empty
 */
Double matchContentScore(float[] embedding, String content);
}

@ -10,4 +10,11 @@ import com.baomidou.mybatisplus.extension.service.IService;
*/ */
public interface QuestionCategoryService extends IService<QuestionCategory> { public interface QuestionCategoryService extends IService<QuestionCategory> {
/**
 * Looks up a question category by its ID.
 * @param categoryId ID of the category to load
 * @return the category found for that ID (the visible implementation delegates to {@code getById})
 */
QuestionCategory findCategoryById(String categoryId);
} }

@ -10,4 +10,12 @@ import com.baomidou.mybatisplus.extension.service.IService;
*/ */
public interface QuestionHandlerMappingService extends IService<QuestionHandlerMapping> { public interface QuestionHandlerMappingService extends IService<QuestionHandlerMapping> {
/**
 * Looks up the handler mapping configured for a question category.
 * @param categoryId ID of the question category
 * @return the handler mapping; RetrieverDispatcher treats a null result as "no handler
 *         configured" (the implementation is not shown here - TODO confirm it can return null)
 */
QuestionHandlerMapping findHandlerByCategoryId(String categoryId);
} }

@ -0,0 +1,17 @@
package com.supervision.pdfqaserver.service;
import java.util.List;
import java.util.Map;
/**
 * Strategy interface for answering a user question with retrieved data.
 * Implementations are collected by bean name in RetrieverDispatcher.
 */
public interface Retriever {
/**
 * Retrieves data relevant to the question.
 * @param query the user question
 * @return retrieved rows, each row as a column-name to value map
 *         (DataCompareRetriever returns an empty list when nothing is found)
 */
List<Map<String, Object>> retrieval(String query);
}

@ -0,0 +1,80 @@
package com.supervision.pdfqaserver.service;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.StrUtil;
import com.supervision.pdfqaserver.domain.QuestionHandlerMapping;
import com.supervision.pdfqaserver.dto.TextVectorDTO;
import jakarta.annotation.PostConstruct;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.embedding.Embedding;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.ApplicationContext;
import org.springframework.stereotype.Service;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
 * Chooses the {@link Retriever} that should handle a user question.
 * <p>
 * The question is embedded, the most similar stored text vector is looked up, and the
 * category of that vector is resolved to a handler name, which is used as the key into
 * the map of {@link Retriever} beans (keyed by Spring bean name).
 */
@Slf4j
@Service
@RequiredArgsConstructor
public class RetrieverDispatcher {

    private final ApplicationContext applicationContext;

    private final AiCallService aiCallService;

    private final TextVectorService textVectorService;

    private final QuestionHandlerMappingService questionHandlerMappingService;

    // Similarity threshold passed to the vector search; defaults to 0.8.
    @Value("${retriever.threshold:0.8}")
    private double threshold;

    // Retriever beans keyed by bean name, filled once in init().
    private final Map<String, Retriever> retrieverMap = new HashMap<>();

    /**
     * Resolves the retriever for a question.
     *
     * @param query the user question
     * @return the matching retriever, or {@code null} when the question is empty, cannot be
     *         embedded, has no similar stored text, or its category has no registered handler
     */
    public Retriever mapping(String query) {
        if (StrUtil.isEmpty(query)) {
            log.warn("查询内容为空,无法获取检索器");
            return null;
        }
        Embedding embedding = aiCallService.embedding(query);
        // Some AiCallService implementations return null here; bail out instead of
        // failing with a NullPointerException on getOutput().
        if (embedding == null || embedding.getOutput() == null) {
            log.warn("问题:{},向量化结果为空,无法获取检索器", query);
            return null;
        }
        List<TextVectorDTO> similarByCosine =
                textVectorService.findSimilarByCosine(embedding.getOutput(), threshold, 1);
        if (CollUtil.isEmpty(similarByCosine)) {
            log.info("问题:{},未找到相似文本向量,匹配阈值:{}", query, threshold);
            return null;
        }
        TextVectorDTO textVectorDTO = CollUtil.getFirst(similarByCosine);
        Assert.notEmpty(textVectorDTO.getCategoryId(), "相似文本向量的分类ID不能为空");

        QuestionHandlerMapping handler =
                questionHandlerMappingService.findHandlerByCategoryId(textVectorDTO.getCategoryId());
        if (handler == null) {
            return null;
        }
        return retrieverMap.get(handler.getHandler());
    }

    /**
     * Registers every {@link Retriever} bean under its bean name.
     *
     * @throws IllegalArgumentException if a name is already registered
     */
    @PostConstruct
    public void init() {
        applicationContext.getBeansOfType(Retriever.class)
                .forEach((name, retriever) -> {
                    if (retrieverMap.containsKey(name)) {
                        throw new IllegalArgumentException("Retriever with name " + name + " already exists.");
                    }
                    retrieverMap.put(name, retriever);
                });
    }
}

@ -0,0 +1,23 @@
package com.supervision.pdfqaserver.service;
import com.supervision.pdfqaserver.dto.TextTerm;
import java.util.List;
/**
 * Splits text into labelled terms and lets callers extend the segmentation dictionary.
 * NOTE(review): presumably backed by the HanLP dependency added alongside this interface;
 * the implementation is not shown here - confirm.
 */
public interface TextToSegmentService {
/**
 * Segments the text into terms.
 * @param text text to segment
 * @return the terms in the order produced by the segmenter
 */
List<TextTerm> segmentText(String text);
/**
 * Adds a custom word to the segmentation dictionary.
 * @param word word to add
 * @param label label to attach to the word
 * @param frequency frequency/weight of the word in the dictionary
 */
void addDict(String word, String label,int frequency);
}

@ -35,6 +35,8 @@ public interface TripleToCypherExecutor {
*/ */
List<Map<String, Object>> executeCypher(String cypher); List<Map<String, Object>> executeCypher(String cypher);
/**
 * Executes several Cypher statements.
 * @param cypher statements to execute
 * @return result rows grouped per entry; presumably keyed by the executed statement -
 *         TODO confirm against the implementation
 */
Map<String, List<Map<String, Object>>> executeCypher(List<String> cypher);
void saveERE(EREDTO eredto); void saveERE(EREDTO eredto);
@ -43,13 +45,25 @@ public interface TripleToCypherExecutor {
*/ */
CypherSchemaDTO loadGraphSchema(); CypherSchemaDTO loadGraphSchema();
/**
 * Refreshes the stored vectors of the graph schema segments.
 * NOTE(review): presumably loads the current schema and hands it to
 * NodeRelationVectorService#refreshSchemaSegmentVector - implementation not shown, confirm.
 */
void refreshSchemaSegmentVector();
/** /**
* schema * schema
* @param metadataDTOS * @param metadataDTOS
* @return * @return
*/ */
CypherSchemaDTO queryRelationSchema(List<DomainMetadataDTO> metadataDTOS); CypherSchemaDTO queryRelationSchema(List<DomainMetadataDTO> metadataDTOS);
/**
 * Finds the part of the graph schema that is relevant to a free-text question.
 * @param query the user question
 * @return the matching schema; DataCompareRetriever treats a result without nodes or
 *         without relations as "nothing matched"
 */
CypherSchemaDTO queryRelationSchema(String query);
} }

@ -42,13 +42,15 @@ public class ChatServiceImpl implements ChatService {
private final TripleToCypherExecutor tripleToCypherExecutor; private final TripleToCypherExecutor tripleToCypherExecutor;
private final DataCompareRetriever compareRetriever;
@Override @Override
public Flux<String> knowledgeQA(String userQuery) { public Flux<String> knowledgeQA(String userQuery) {
log.info("用户查询: {}", userQuery); log.info("用户查询: {}", userQuery);
// 生成cypher语句 // 生成cypher语句
String cypher = tripleToCypherExecutor.generateQueryCypher(userQuery,null); /*String cypher = tripleToCypherExecutor.generateQueryCypher(userQuery,null);
log.info("生成CYPHER语句的消息{}", cypher); log.info("生成CYPHER语句的消息{}", cypher);
if (StrUtil.isEmpty(cypher)){ if (StrUtil.isEmpty(cypher)){
return Flux.just("查无结果").concatWith(Flux.just("[END]")); return Flux.just("查无结果").concatWith(Flux.just("[END]"));
@ -58,7 +60,9 @@ public class ChatServiceImpl implements ChatService {
List<Map<String, Object>> graphResult = tripleToCypherExecutor.executeCypher(cypher); List<Map<String, Object>> graphResult = tripleToCypherExecutor.executeCypher(cypher);
if (CollUtil.isEmpty(graphResult)){ if (CollUtil.isEmpty(graphResult)){
return Flux.just("查无结果").concatWith(Flux.just("[END]")); return Flux.just("查无结果").concatWith(Flux.just("[END]"));
} }*/
List<Map<String, Object>> graphResult = compareRetriever.retrieval(userQuery);
//生成回答 //生成回答
SystemPromptTemplate generateAnswerTemplate = new SystemPromptTemplate(PromptCache.promptMap.get(GENERATE_ANSWER)); SystemPromptTemplate generateAnswerTemplate = new SystemPromptTemplate(PromptCache.promptMap.get(GENERATE_ANSWER));
Message generateAnswerMessage = generateAnswerTemplate.createMessage(Map.of(PROMPT_PARAM_EXAMPLE_TEXT, JSONUtil.toJsonStr(graphResult), PROMPT_PARAM_QUERY, userQuery)); Message generateAnswerMessage = generateAnswerTemplate.createMessage(Map.of(PROMPT_PARAM_EXAMPLE_TEXT, JSONUtil.toJsonStr(graphResult), PROMPT_PARAM_QUERY, userQuery));

@ -0,0 +1,113 @@
package com.supervision.pdfqaserver.service.impl;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.date.DateUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONArray;
import cn.hutool.json.JSONUtil;
import com.supervision.pdfqaserver.cache.PromptCache;
import com.supervision.pdfqaserver.dto.CypherSchemaDTO;
import com.supervision.pdfqaserver.dto.neo4j.RelationshipValueDTO;
import com.supervision.pdfqaserver.service.*;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import static com.supervision.pdfqaserver.cache.PromptCache.*;
/**
 * {@link Retriever} that answers a question by letting the LLM write Cypher statements
 * against the relevant part of the graph schema and executing them.
 * <p>
 * Flow: resolve the schema fragment for the question, render the {@code TEXT_TO_CYPHER_3}
 * prompt, execute the returned statements, and - if none of them yields data - retry once
 * with the {@code TEXT_TO_CYPHER_4} prompt, which also receives the statements already tried.
 */
@Slf4j
@Service("dataCompareRetriever")
@RequiredArgsConstructor
public class DataCompareRetriever implements Retriever {

    private final TripleToCypherExecutor tripleToCypherExecutor;

    private final AiCallService aiCallService;

    /**
     * Retrieves graph data for the question.
     *
     * @param query the user question
     * @return the rows returned by the generated Cypher statements with graph identifiers
     *         cleared; an empty list when the question is empty, no schema matches, the LLM
     *         returns nothing, or no statement yields data
     */
    @Override
    public List<Map<String, Object>> retrieval(String query) {
        log.info("retrieval: 执行数据对比检索器,查询内容:{}", query);
        if (StrUtil.isEmpty(query)) {
            log.warn("查询内容为空,无法执行数据对比检索");
            return new ArrayList<>();
        }

        // Narrow the graph schema down to what is relevant for this question. The emptiness
        // check comes before any size()/format() call so a missing list cannot cause an NPE.
        CypherSchemaDTO schemaDTO = tripleToCypherExecutor.queryRelationSchema(query);
        if (schemaDTO == null
                || CollUtil.isEmpty(schemaDTO.getRelations())
                || CollUtil.isEmpty(schemaDTO.getNodes())) {
            log.info("没有找到匹配的关系或实体query: {}", query);
            return new ArrayList<>();
        }
        log.info("retrieval: 查询到的关系图谱schema 节点个数:{} ,关系结束{} ",
                schemaDTO.getNodes().size(), schemaDTO.getRelations().size());
        String schema = schemaDTO.format();
        log.info("retrieval: 查询到的关系图谱schema {} ", schema);

        // First attempt: ask the LLM for executable Cypher statements.
        String env = "- 当前时间是:" + DateUtil.now();
        String firstPrompt = StrUtil.format(PromptCache.promptMap.get(TEXT_TO_CYPHER_3),
                Map.of("query", query, "schema", schema, "env", env));
        log.info("retrieval: 生成的cypher语句{}", firstPrompt);
        String firstAnswer = aiCallService.call(firstPrompt);
        log.info("retrieval: AI调用返回结果{}", firstAnswer);
        if (StrUtil.isEmpty(firstAnswer)) {
            log.warn("retrieval: AI调用返回结果为空无法执行Cypher查询");
            return new ArrayList<>();
        }
        JSONArray firstCyphers = JSONUtil.parseArray(firstAnswer);
        List<Map<String, Object>> rows = executeAndCollect(firstCyphers);
        if (CollUtil.isNotEmpty(rows)) {
            return clearGraphElements(rows);
        }

        // Second attempt: hand the failed statements back to the LLM. All four named
        // placeholders must travel in ONE map; passing "cypher" as trailing arguments
        // selects the positional "{}" overload and leaves the template unfilled.
        log.info("retrieval: 执行Cypher语句无结果重新调整cypher语句{}", query);
        String retryPrompt = StrUtil.format(PromptCache.promptMap.get(TEXT_TO_CYPHER_4),
                Map.of("query", query, "schema", schema, "env", env,
                        "cypher", firstCyphers.toString()));
        log.info("retrieval: 生成的cypher语句{}", retryPrompt);
        String retryAnswer = aiCallService.call(retryPrompt);
        log.info("retrieval: AI调用返回结果{}", retryAnswer);
        if (StrUtil.isEmpty(retryAnswer)) {
            log.warn("retrieval: AI调用返回结果为空无法执行Cypher查询");
            return new ArrayList<>();
        }
        return clearGraphElements(executeAndCollect(JSONUtil.parseArray(retryAnswer)));
    }

    /**
     * Executes the given statements and flattens all non-empty result lists into one list.
     *
     * @param cyphers JSON array of Cypher statement strings
     * @return the combined rows, empty when no statement returned data
     */
    private List<Map<String, Object>> executeAndCollect(JSONArray cyphers) {
        List<Map<String, Object>> rows = new ArrayList<>();
        Map<String, List<Map<String, Object>>> cypherData =
                tripleToCypherExecutor.executeCypher(cyphers.toList(String.class));
        if (CollUtil.isNotEmpty(cypherData)) {
            cypherData.values().stream().filter(CollUtil::isNotEmpty).forEach(rows::addAll);
        }
        return rows;
    }

    /**
     * Clears the graph identifiers of every node and relationship value in the rows, in place.
     *
     * @param graphElements result rows
     * @return the same list instance
     */
    private List<Map<String, Object>> clearGraphElements(List<Map<String, Object>> graphElements) {
        if (CollUtil.isEmpty(graphElements)) {
            return graphElements;
        }
        for (Map<String, Object> element : graphElements) {
            for (Object value : element.values()) {
                if (value instanceof RelationshipValueDTO relationshipValueDTO) {
                    relationshipValueDTO.clearGraphElement();
                }
                if (value instanceof com.supervision.pdfqaserver.dto.neo4j.NodeDTO nodeDTO) {
                    nodeDTO.clearGraphElement();
                }
            }
        }
        return graphElements;
    }
}

@ -0,0 +1,108 @@
package com.supervision.pdfqaserver.service.impl;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.StrUtil;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.supervision.pdfqaserver.domain.NodeRelationVector;
import com.supervision.pdfqaserver.dto.*;
import com.supervision.pdfqaserver.service.AiCallService;
import com.supervision.pdfqaserver.service.NodeRelationVectorService;
import com.supervision.pdfqaserver.mapper.NodeRelationVectorMapper;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.embedding.Embedding;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
/**
 * Default {@link NodeRelationVectorService}: stores embeddings of graph-schema texts in
 * {@code node_relation_vector} and queries them through the mapper.
 * @author Administrator
 * @description node_relation_vector()Service
 * @createDate 2025-06-18 13:38:02
 */
@Slf4j
@Service
@RequiredArgsConstructor
public class NodeRelationVectorServiceImpl extends ServiceImpl<NodeRelationVectorMapper, NodeRelationVector>
        implements NodeRelationVectorService {

    /** Number of texts sent to the embedding model per call. */
    private static final int EMBEDDING_BATCH_SIZE = 200;

    /** Content type of a relation name. */
    private static final String TYPE_RELATION = "R";

    /** Content type of a joined "source relation target" text. */
    private static final String TYPE_ENTITY_RELATION = "ER";

    /** Content type of a node entity name. */
    private static final String TYPE_NODE = "N";

    private final AiCallService aiCallService;

    /**
     * Replaces all stored vectors with fresh embeddings of the given schema.
     * Relation names and joined relation texts share one de-duplication set; node names
     * use their own, so a text is stored at most once per group.
     *
     * @param cypherSchemaDTO schema whose relations and nodes are embedded
     */
    @Override
    @Transactional(rollbackFor = Exception.class)
    public void refreshSchemaSegmentVector(CypherSchemaDTO cypherSchemaDTO) {
        // Drop the old vectors; the inserts below run in the same transaction.
        super.lambdaUpdate().remove();

        List<EntityExtractionDTO> nodes = cypherSchemaDTO.getNodes();
        List<RelationExtractionDTO> relations = cypherSchemaDTO.getRelations();

        List<NodeRelationVector> allRelationVectors = new ArrayList<>();
        Set<String> seenRelationTexts = new HashSet<>();
        for (List<RelationExtractionDTO> relationSplit : CollUtil.split(relations, EMBEDDING_BATCH_SIZE)) {
            List<String> rs = relationSplit.stream().map(RelationExtractionDTO::getRelation).toList();
            allRelationVectors.addAll(embedDistinct(rs, TYPE_RELATION, seenRelationTexts));

            List<String> ers = relationSplit.stream()
                    .map(r -> StrUtil.join(" ", r.getSourceType(), r.getRelation(), r.getTargetType())).toList();
            allRelationVectors.addAll(embedDistinct(ers, TYPE_ENTITY_RELATION, seenRelationTexts));
        }
        super.saveBatch(allRelationVectors);

        List<NodeRelationVector> allNodeVectors = new ArrayList<>();
        Set<String> seenNodeTexts = new HashSet<>();
        for (List<EntityExtractionDTO> entitySplit : CollUtil.split(nodes, EMBEDDING_BATCH_SIZE)) {
            List<String> es = entitySplit.stream().map(EntityExtractionDTO::getEntity).toList();
            allNodeVectors.addAll(embedDistinct(es, TYPE_NODE, seenNodeTexts));
        }
        super.saveBatch(allNodeVectors);
    }

    /**
     * Embeds one batch of texts and builds a vector row for every text not seen before.
     *
     * @param contents texts to embed; each embedding's index is used to look its text up here
     * @param contentType content type written to every created row
     * @param seen texts already turned into rows; updated with the texts accepted here
     * @return the newly created rows, in the order the embeddings were returned
     */
    private List<NodeRelationVector> embedDistinct(List<String> contents, String contentType, Set<String> seen) {
        List<NodeRelationVector> vectors = new ArrayList<>();
        for (Embedding embed : aiCallService.embedding(contents)) {
            String content = contents.get(embed.getIndex());
            // Set.add returns false for a text that already has a row.
            if (!seen.add(content)) {
                continue;
            }
            NodeRelationVector vector = new NodeRelationVector();
            vector.setContent(content);
            vector.setEmbedding(embed.getOutput());
            vector.setContentType(contentType);
            vectors.add(vector);
        }
        return vectors;
    }

    @Override
    public List<NodeRelationVector> matchSimilarByCosine(float[] embedding, double threshold, List<String> contentType, int limit) {
        return super.getBaseMapper().findSimilarByCosine(embedding, threshold, contentType, limit);
    }

    /**
     * Scores the embedding against stored content.
     *
     * @return 0.0 when content is empty or the embedding is null/empty, otherwise the mapper's score
     */
    @Override
    public Double matchContentScore(float[] embedding, String content) {
        if (StrUtil.isEmpty(content) || embedding == null || embedding.length == 0) {
            return 0.0;
        }
        return super.getBaseMapper().matchContentScore(embedding, content);
    }
}

@ -37,4 +37,10 @@ public class OllamaCallServiceImpl implements AiCallService {
EmbeddingResponse embeddingResponse = embeddingModel.call(new EmbeddingRequest(List.of(text),null)); EmbeddingResponse embeddingResponse = embeddingModel.call(new EmbeddingRequest(List.of(text),null));
return embeddingResponse.getResult(); return embeddingResponse.getResult();
} }
/**
 * Embeds several texts with a single call to the embedding model.
 *
 * @param texts texts to embed
 * @return the embeddings returned by the model for this request
 */
@Override
public List<Embedding> embedding(List<String> texts) {
    EmbeddingRequest batchRequest = new EmbeddingRequest(texts, null);
    return embeddingModel.call(batchRequest).getResults();
}
} }

@ -15,6 +15,10 @@ import org.springframework.stereotype.Service;
public class QuestionCategoryServiceImpl extends ServiceImpl<QuestionCategoryMapper, QuestionCategory> public class QuestionCategoryServiceImpl extends ServiceImpl<QuestionCategoryMapper, QuestionCategory>
implements QuestionCategoryService{ implements QuestionCategoryService{
/**
 * Looks up a question category by primary key.
 *
 * @param categoryId id passed to {@code getById}
 * @return the result of {@code getById} for that id
 */
@Override
public QuestionCategory findCategoryById(String categoryId) {
    QuestionCategory category = super.getById(categoryId);
    return category;
}
} }

@ -1,9 +1,14 @@
package com.supervision.pdfqaserver.service.impl; package com.supervision.pdfqaserver.service.impl;
import cn.hutool.core.util.StrUtil;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.supervision.pdfqaserver.domain.QuestionCategory;
import com.supervision.pdfqaserver.domain.QuestionHandlerMapping; import com.supervision.pdfqaserver.domain.QuestionHandlerMapping;
import com.supervision.pdfqaserver.service.QuestionCategoryService;
import com.supervision.pdfqaserver.service.QuestionHandlerMappingService; import com.supervision.pdfqaserver.service.QuestionHandlerMappingService;
import com.supervision.pdfqaserver.mapper.QuestionHandlerMappingMapper; import com.supervision.pdfqaserver.mapper.QuestionHandlerMappingMapper;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
/** /**
@ -11,10 +16,37 @@ import org.springframework.stereotype.Service;
* @description question_handler_mapping()Service * @description question_handler_mapping()Service
* @createDate 2025-06-13 11:29:01 * @createDate 2025-06-13 11:29:01
*/ */
@Slf4j
@Service @Service
@RequiredArgsConstructor
public class QuestionHandlerMappingServiceImpl extends ServiceImpl<QuestionHandlerMappingMapper, QuestionHandlerMapping> public class QuestionHandlerMappingServiceImpl extends ServiceImpl<QuestionHandlerMappingMapper, QuestionHandlerMapping>
implements QuestionHandlerMappingService{ implements QuestionHandlerMappingService{
private final QuestionCategoryService categoryService;
/**
 * Resolves the handler mapping for a category, falling back to its ancestors.
 * <p>
 * Starting at {@code categoryId}, the mapping table is queried for the current id. When no
 * mapping exists the category's parent id is used for the next attempt, and so on up the chain.
 *
 * @param categoryId id of the category to start from
 * @return the first mapping found along the parent chain, or {@code null} when the id is empty,
 *         a category on the chain does not exist, the chain ends without a mapping, or the
 *         chain loops back onto itself
 */
@Override
public QuestionHandlerMapping findHandlerByCategoryId(String categoryId) {
    // Ids already tried; guards against an endless loop when parent ids form a cycle.
    // Fully qualified so the file's import block does not need to change.
    java.util.Set<String> visited = new java.util.HashSet<>();
    String currentId = categoryId;
    while (StrUtil.isNotEmpty(currentId)) {
        if (!visited.add(currentId)) {
            log.warn("分类id:{} 的父级链存在循环引用,终止查询", currentId);
            return null;
        }
        QuestionHandlerMapping mapping = super.lambdaQuery()
                .eq(QuestionHandlerMapping::getQuestionCategoryId, currentId)
                .one();
        if (mapping != null) {
            return mapping;
        }
        log.info("根据分类id:{}未找到处理器映射,尝试查询分类器上级关联数据", currentId);
        QuestionCategory category = categoryService.findCategoryById(currentId);
        if (category == null) {
            // The id points at no category row, so there is no parent to climb to.
            log.warn("分类id:{} 对应的分类不存在,终止查询", currentId);
            return null;
        }
        if (StrUtil.isEmpty(category.getParentId())) {
            log.info("分类id:{} 没有父级id不进行查询", currentId);
            return null;
        }
        log.info("分类id:{} 的父级id为:{}", currentId, category.getParentId());
        currentId = category.getParentId();
    }
    return null;
}
} }

@ -0,0 +1,49 @@
package com.supervision.pdfqaserver.service.impl;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.StrUtil;
import com.hankcs.hanlp.HanLP;
import com.hankcs.hanlp.dictionary.CustomDictionary;
import com.hankcs.hanlp.seg.Segment;
import com.hankcs.hanlp.seg.common.Term;
import com.supervision.pdfqaserver.dto.TextTerm;
import com.supervision.pdfqaserver.service.TextToSegmentService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.List;
/**
 * {@link TextToSegmentService} implementation that delegates to HanLP.
 */
@Slf4j
@Service
@RequiredArgsConstructor
public class TextToSegmentServiceImpl implements TextToSegmentService {

    /**
     * Segments text with a HanLP segment that has organization, place and
     * number-quantifier recognition switched on.
     *
     * @param text text to segment
     * @return one {@link TextTerm} per HanLP term, in segmentation order; a new empty list
     *         when the text is empty or HanLP yields no terms
     */
    @Override
    public List<TextTerm> segmentText(String text) {
        List<TextTerm> result = new ArrayList<>();
        if (StrUtil.isEmpty(text)) {
            return result;
        }
        Segment segmenter = HanLP.newSegment()
                .enableOrganizationRecognize(true)
                .enablePlaceRecognize(true)
                .enableNumberQuantifierRecognize(true);
        List<Term> tokens = segmenter.seg(text);
        if (CollUtil.isEmpty(tokens)) {
            return result;
        }
        for (Term token : tokens) {
            result.add(toTextTerm(token));
        }
        return result;
    }

    /**
     * Inserts a word into HanLP's custom dictionary.
     *
     * @param word      the word to insert
     * @param label     label stored for the word
     * @param frequency frequency stored next to the label
     */
    @Override
    public void addDict(String word, String label,int frequency) {
        // HanLP receives label and frequency as one space-separated string.
        String labelWithFrequency = label + " " + frequency;
        CustomDictionary.insert(word, labelWithFrequency);
    }

    /** Copies a HanLP term's word and the string form of its nature into a {@link TextTerm}. */
    private static TextTerm toTextTerm(Term token) {
        TextTerm item = new TextTerm();
        item.setWord(token.word);
        item.setLabel(token.nature.toString());
        return item;
    }
}

@ -2,11 +2,11 @@ package com.supervision.pdfqaserver.service.impl;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.supervision.pdfqaserver.domain.TextVector; import com.supervision.pdfqaserver.domain.TextVector;
import com.supervision.pdfqaserver.dto.TextVectorDTO; import com.supervision.pdfqaserver.dto.*;
import com.supervision.pdfqaserver.service.TextVectorService; import com.supervision.pdfqaserver.service.TextVectorService;
import com.supervision.pdfqaserver.mapper.TextVectorMapper; import com.supervision.pdfqaserver.mapper.TextVectorMapper;
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.util.List; import java.util.List;
/** /**
@ -15,12 +15,12 @@ import java.util.List;
* @createDate 2025-06-11 16:40:57 * @createDate 2025-06-11 16:40:57
*/ */
@Service @Service
@RequiredArgsConstructor
public class TextVectorServiceImpl extends ServiceImpl<TextVectorMapper, TextVector> public class TextVectorServiceImpl extends ServiceImpl<TextVectorMapper, TextVector>
implements TextVectorService{ implements TextVectorService{
@Override @Override
public List<TextVectorDTO> findSimilarByCosine(float[] embedding, double threshold , int limit) { public List<TextVectorDTO> findSimilarByCosine(float[] embedding, double threshold , int limit) {
return super.getBaseMapper().findSimilarByCosine(embedding, threshold,limit); return super.getBaseMapper().findSimilarByCosine(embedding, threshold,null,limit);
} }
} }

@ -1,24 +1,25 @@
package com.supervision.pdfqaserver.service.impl; package com.supervision.pdfqaserver.service.impl;
import cn.hutool.core.collection.CollUtil; import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.date.TimeInterval;
import cn.hutool.core.lang.Pair;
import cn.hutool.core.util.StrUtil; import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil; import cn.hutool.json.JSONUtil;
import com.supervision.pdfqaserver.cache.PromptCache; import com.supervision.pdfqaserver.cache.PromptCache;
import com.supervision.pdfqaserver.dao.Neo4jRepository; import com.supervision.pdfqaserver.dao.Neo4jRepository;
import com.supervision.pdfqaserver.domain.Intention; import com.supervision.pdfqaserver.domain.Intention;
import com.supervision.pdfqaserver.domain.NodeRelationVector;
import com.supervision.pdfqaserver.dto.*; import com.supervision.pdfqaserver.dto.*;
import com.supervision.pdfqaserver.dto.neo4j.NodeDTO; import com.supervision.pdfqaserver.dto.neo4j.NodeDTO;
import com.supervision.pdfqaserver.dto.neo4j.PathDTO; import com.supervision.pdfqaserver.dto.neo4j.PathDTO;
import com.supervision.pdfqaserver.dto.neo4j.RelationshipValueDTO; import com.supervision.pdfqaserver.dto.neo4j.RelationshipValueDTO;
import com.supervision.pdfqaserver.service.AiCallService; import com.supervision.pdfqaserver.service.*;
import com.supervision.pdfqaserver.service.DomainMetadataService;
import com.supervision.pdfqaserver.service.IntentionService;
import com.supervision.pdfqaserver.service.TripleToCypherExecutor;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.neo4j.driver.Record; import org.neo4j.driver.Record;
import org.neo4j.driver.internal.InternalNode; import org.neo4j.driver.internal.InternalNode;
import org.neo4j.driver.internal.InternalRelationship; import org.neo4j.driver.internal.InternalRelationship;
import org.springframework.ai.embedding.Embedding;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
@ -41,6 +42,12 @@ public class TripleToCypherExecutorImpl implements TripleToCypherExecutor {
private final DomainMetadataService domainMetadataService; private final DomainMetadataService domainMetadataService;
private final TextVectorService textVectorService;
private final NodeRelationVectorService nodeRelationVectorService;
private final TextToSegmentService textToSegmentService;
@Override @Override
public String generateInsertCypher(EREDTO eredto) { public String generateInsertCypher(EREDTO eredto) {
@ -80,6 +87,21 @@ public class TripleToCypherExecutorImpl implements TripleToCypherExecutor {
return mapRecords(records); return mapRecords(records);
} }
/**
 * Executes several Cypher statements one by one on a best-effort basis.
 * <p>
 * A statement that fails does not abort the batch: the failure is logged and the statement is
 * recorded with a {@code null} value, so callers can tell "failed" from "returned no rows".
 *
 * @param cypher statements to run; {@code null} or empty yields an empty map
 * @return map from each statement to its result rows, or to {@code null} when it failed
 */
@Override
public Map<String, List<Map<String, Object>>> executeCypher(List<String> cypher) {
    Map<String, List<Map<String, Object>>> result = new HashMap<>();
    if (CollUtil.isEmpty(cypher)) {
        return result;
    }
    for (String statement : cypher) {
        List<Map<String, Object>> rows = null;
        try {
            rows = executeCypher(statement);
        } catch (Exception e) {
            // Deliberately not rethrown; the trailing throwable keeps the stack trace in the log.
            log.warn("执行Cypher语句失败语句{},错误信息:{}", statement, e.getMessage(), e);
        }
        result.put(statement, rows);
    }
    return result;
}
private List<Map<String, Object>> mapRecords(List<Record> records) { private List<Map<String, Object>> mapRecords(List<Record> records) {
List<Map<String, Object>> recordList = new ArrayList<>(); List<Map<String, Object>> recordList = new ArrayList<>();
@ -185,18 +207,24 @@ public class TripleToCypherExecutorImpl implements TripleToCypherExecutor {
return new CypherSchemaDTO(entitySchema, relationSchema); return new CypherSchemaDTO(entitySchema, relationSchema);
} }
/**
 * Loads the graph schema if it is not loaded yet, then hands it to
 * {@code nodeRelationVectorService.refreshSchemaSegmentVector}. Only logs a warning and
 * returns when the schema is still {@code null} after the load attempt.
 */
@Override
public void refreshSchemaSegmentVector() {
    loadCypherSchemaIfAbsent();
    if (cypherSchemaDTO != null) {
        log.info("开始刷新图谱schema分词向量...");
        nodeRelationVectorService.refreshSchemaSegmentVector(cypherSchemaDTO);
        log.info("图谱schema分词向量刷新完成...");
        return;
    }
    log.warn("图谱schema数据为空,不用刷新分词向量...");
}
@Override @Override
public CypherSchemaDTO queryRelationSchema(List<DomainMetadataDTO> metadataDTOS) { public CypherSchemaDTO queryRelationSchema(List<DomainMetadataDTO> metadataDTOS) {
if (CollUtil.isEmpty(metadataDTOS)){ if (CollUtil.isEmpty(metadataDTOS)){
return null; return null;
} }
if (cypherSchemaDTO == null) { loadCypherSchemaIfAbsent();
synchronized (TripleToCypherExecutorImpl.class) {
if (cypherSchemaDTO == null) {
cypherSchemaDTO = this.loadGraphSchema();
}
}
}
List<RelationExtractionDTO> merged = new ArrayList<>(); List<RelationExtractionDTO> merged = new ArrayList<>();
for (DomainMetadataDTO metadataDTO : metadataDTOS) { for (DomainMetadataDTO metadataDTO : metadataDTOS) {
String relation = metadataDTO.getRelation(); String relation = metadataDTO.getRelation();
@ -231,6 +259,95 @@ public class TripleToCypherExecutorImpl implements TripleToCypherExecutor {
); );
} }
/**
 * Selects the slice of the graph schema that looks relevant to a natural-language question.
 * <ol>
 *   <li>Segment the question into terms.</li>
 *   <li>For every term with a label value, embed that value and fetch up to 3 stored
 *       vectors of content type "N" or "R" whose similarity score exceeds 0.9.</li>
 *   <li>Expand every matched content into schema relations, de-duplicated on
 *       (source type, relation, target type).</li>
 *   <li>Score each relation against the embedding of all label values joined together,
 *       and keep the 5 highest-scoring relations.</li>
 *   <li>Collect the distinct source and target node definitions of those relations.</li>
 * </ol>
 *
 * @param query the user question
 * @return the selected nodes and relations; a schema with two empty lists when the question is
 *         empty, yields no terms, matches no vectors, or no graph schema could be loaded
 */
@Override
public CypherSchemaDTO queryRelationSchema(String query) {
    if (StrUtil.isEmpty(query)){
        return new CypherSchemaDTO(List.of(), List.of());
    }
    // Segment the question into terms.
    List<TextTerm> terms = textToSegmentService.segmentText(query);
    if (CollUtil.isEmpty(terms)){
        return new CypherSchemaDTO(List.of(), List.of());
    }
    log.info("queryRelationSchema: 分词结果:{}", terms);
    log.info("queryRelationSchema: 开始进行文本标签向量匹配...");
    List<NodeRelationVector> matchedText = new ArrayList<>();
    for (TextTerm term : terms) {
        if (StrUtil.isEmpty(term.getLabelValue())){
            log.info("queryRelationSchema: 分词结果`{}`不是关键标签,跳过...", term.getWord());
            continue;
        }
        Embedding embedding = aiCallService.embedding(term.getLabelValue());
        term.setEmbedding(embedding.getOutput());
        // TODO: filter these matches further.
        List<NodeRelationVector> textVectorDTOS = nodeRelationVectorService.matchSimilarByCosine(embedding.getOutput(), 0.9, List.of("N","R"),3);
        log.info("retrieval: 文本:{}匹配到的文本向量:{}", term.getWord() ,textVectorDTOS.stream().map(NodeRelationVector::getContent).collect(Collectors.joining(" ")));
        matchedText.addAll(textVectorDTOS);
    }
    if (CollUtil.isEmpty(matchedText)){
        log.info("retrieval: 未找到匹配的文本向量");
        return new CypherSchemaDTO(List.of(), List.of());
    }

    loadCypherSchemaIfAbsent();
    if (cypherSchemaDTO == null) {
        // The schema can still be missing after a load attempt; without it nothing can be expanded.
        log.warn("queryRelationSchema: 图谱schema数据为空,无法匹配关系");
        return new CypherSchemaDTO(List.of(), List.of());
    }

    // Expand matched contents into schema relations, keeping one entry per triple.
    List<RelationExtractionDTO> merged = new ArrayList<>();
    for (NodeRelationVector textVectorDTO : matchedText) {
        String content = textVectorDTO.getContent();
        List<RelationExtractionDTO> relations = cypherSchemaDTO.getRelations(content);
        if (CollUtil.isEmpty(relations)) {
            continue;
        }
        for (RelationExtractionDTO relation : relations) {
            boolean noneMatch = merged.stream().noneMatch(i ->
                    StrUtil.equals(i.getSourceType(), relation.getSourceType()) &&
                    StrUtil.equals(i.getRelation(), relation.getRelation()) &&
                    StrUtil.equals(i.getTargetType(), relation.getTargetType())
            );
            if (noneMatch){
                merged.add(relation);
            }
        }
    }

    // Re-rank the candidate relations against the whole question's label values.
    List<Pair<Double, RelationExtractionDTO>> pairs = new ArrayList<>();
    TimeInterval timeInterval = new TimeInterval();
    String join = terms.stream().map(TextTerm::getLabelValue).filter(StrUtil::isNotEmpty).collect(Collectors.joining());
    Embedding embedding = aiCallService.embedding(join);
    for (RelationExtractionDTO relation : merged) {
        String content = relation.getSourceType() + " " + relation.getRelation() + " " + relation.getTargetType();
        // For now the score comes from a database query; total time is currently under one second.
        Double score = nodeRelationVectorService.matchContentScore(embedding.getOutput(),content);
        if (null == score){
            continue;
        }
        log.info("queryRelationSchema: 关系`{}`的匹配分数:{}", content, score);
        pairs.add(Pair.of(score, relation));
    }
    log.info("queryRelationSchema: 关系排序耗时:{}ms", timeInterval.intervalMs());
    // Highest score first, top 5 only.
    merged = pairs.stream().sorted((p1, p2) -> Double.compare(p2.getKey(), p1.getKey())).limit(5).map(Pair::getValue).toList();

    List<EntityExtractionDTO> entityExtractionDTOS = new ArrayList<>();
    for (RelationExtractionDTO relationExtractionDTO : merged) {
        addSchemaNodeIfAbsent(entityExtractionDTOS, cypherSchemaDTO.getNode(relationExtractionDTO.getSourceType()));
        addSchemaNodeIfAbsent(entityExtractionDTOS, cypherSchemaDTO.getNode(relationExtractionDTO.getTargetType()));
    }
    return new CypherSchemaDTO(
            entityExtractionDTOS,
            merged
    );
}

/** Appends {@code node} unless it is null or an entry with the same entity name is already present. */
private static void addSchemaNodeIfAbsent(List<EntityExtractionDTO> nodes, EntityExtractionDTO node) {
    if (node == null) {
        return;
    }
    boolean absent = nodes.stream().noneMatch(
            entity -> StrUtil.equals(entity.getEntity(), node.getEntity())
    );
    if (absent) {
        nodes.add(node);
    }
}
private List<Intention> classifyIntents(String query, List<Intention> intentions) { private List<Intention> classifyIntents(String query, List<Intention> intentions) {
if (StrUtil.isEmpty(query) || CollUtil.isEmpty(intentions)) { if (StrUtil.isEmpty(query) || CollUtil.isEmpty(intentions)) {
return new ArrayList<>(); return new ArrayList<>();
@ -261,4 +378,18 @@ public class TripleToCypherExecutorImpl implements TripleToCypherExecutor {
} }
return result; return result;
} }
/**
 * Populates {@code cypherSchemaDTO} from {@code loadGraphSchema()} when it is still null.
 * The null check is repeated inside a block synchronized on the class object before loading.
 * NOTE(review): the field declaration is not visible here; confirm it is volatile.
 */
private void loadCypherSchemaIfAbsent() {
    if (cypherSchemaDTO != null) {
        return;
    }
    synchronized (TripleToCypherExecutorImpl.class) {
        if (cypherSchemaDTO == null) {
            cypherSchemaDTO = this.loadGraphSchema();
        }
    }
}
} }

@ -0,0 +1,57 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE mapper
PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN"
"http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<mapper namespace="com.supervision.pdfqaserver.mapper.NodeRelationVectorMapper">
<!-- Column-to-property mapping for NodeRelationVector. The embedding column is read through
     VectorTypeHandler. NOTE(review): neither select visible in this mapper references this
     result map; confirm whether it is used elsewhere. -->
<resultMap id="BaseResultMap" type="com.supervision.pdfqaserver.domain.NodeRelationVector">
<id property="id" column="id" jdbcType="VARCHAR"/>
<result property="content" column="content" jdbcType="VARCHAR"/>
<result property="embedding" column="embedding" jdbcType="OTHER" typeHandler="com.supervision.pdfqaserver.config.VectorTypeHandler"/>
<result property="contentType" column="content_type" jdbcType="VARCHAR"/>
<result property="createTime" column="create_time" jdbcType="TIMESTAMP"/>
<result property="updateTime" column="update_time" jdbcType="TIMESTAMP"/>
</resultMap>
<!-- Reusable fragment listing every column of node_relation_vector. It is not included by the
     select statements visible in this mapper. -->
<sql id="Base_Column_List">
id,content,embedding,
content_type,create_time,update_time
</sql>
<!-- Returns rows of node_relation_vector ranked by similarity to the embedding parameter.
     The inner query computes similarityScore as 1 minus the result of the <=> operator between
     the stored embedding and the parameter (presumably the pgvector cosine-distance operator;
     confirm against the database extension in use). The outer query keeps rows whose score is
     strictly greater than threshold, optionally restricts content_type to the given list,
     orders by score descending and returns at most limit rows.
     NOTE(review): rows are mapped via resultType rather than BaseResultMap, so the embedding
     column is not explicitly routed through VectorTypeHandler here; verify that the embedding
     and contentType properties are populated as expected. -->
<select id="findSimilarByCosine" resultType="com.supervision.pdfqaserver.domain.NodeRelationVector">
SELECT * FROM (
SELECT
id,
content,
embedding,
content_type,
1 - (embedding <![CDATA[<=>]]> #{embedding, typeHandler=com.supervision.pdfqaserver.config.VectorTypeHandler}) AS similarityScore
FROM node_relation_vector
) t
WHERE t.similarityScore > #{threshold}
<if test="contentType != null and contentType.size() > 0">
AND content_type IN
<foreach item="item" collection="contentType" open="(" separator="," close=")">
#{item}
</foreach>
</if>
ORDER BY t.similarityScore DESC
LIMIT #{limit}
</select>
<!-- Returns one similarity score for the row whose content equals the content parameter.
     The score is 0 when either parameter is NULL; otherwise it is 1 minus the <=> result
     between the stored embedding and the parameter, with a NULL result coalesced to 0.
     At most one row is read (LIMIT 1); when no row has that content the statement yields no
     row at all, so callers should expect a null result.
     NOTE(review): the first WHEN binds the embedding parameter without VectorTypeHandler,
     unlike the comparison below it; confirm the driver accepts that binding. -->
<select id="matchContentScore" resultType="java.lang.Double">
SELECT
CASE
WHEN #{embedding} IS NULL THEN 0
WHEN #{content} IS NULL THEN 0
ELSE COALESCE(
1 - (embedding <![CDATA[<=>]]>
#{embedding, typeHandler=com.supervision.pdfqaserver.config.VectorTypeHandler}),
0
)
END AS similarityScore
FROM node_relation_vector
WHERE content = #{content}
LIMIT 1
</select>
</mapper>

@ -28,6 +28,9 @@
FROM text_vector FROM text_vector
) t ) t
WHERE t.similarityScore > #{threshold} WHERE t.similarityScore > #{threshold}
<if test="categoryId != null and categoryId != ''">
AND t.category_id = #{categoryId}
</if>
ORDER BY t.similarityScore DESC ORDER BY t.similarityScore DESC
LIMIT #{limit} LIMIT #{limit}
</select> </select>

@ -1,9 +1,11 @@
package com.supervision.pdfqaserver; package com.supervision.pdfqaserver;
import cn.hutool.core.date.TimeInterval;
import cn.hutool.core.util.NumberUtil; import cn.hutool.core.util.NumberUtil;
import cn.hutool.core.util.StrUtil; import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONArray; import cn.hutool.json.JSONArray;
import cn.hutool.json.JSONObject; import cn.hutool.json.JSONObject;
import cn.hutool.json.JSONUtil;
import com.supervision.pdfqaserver.constant.DocumentContentTypeEnum; import com.supervision.pdfqaserver.constant.DocumentContentTypeEnum;
import com.supervision.pdfqaserver.domain.PdfAnalysisOutput; import com.supervision.pdfqaserver.domain.PdfAnalysisOutput;
import com.supervision.pdfqaserver.domain.TextVector; import com.supervision.pdfqaserver.domain.TextVector;
@ -16,7 +18,6 @@ import org.neo4j.driver.Record;
import org.springframework.ai.embedding.Embedding; import org.springframework.ai.embedding.Embedding;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest; import org.springframework.boot.test.context.SpringBootTest;
import java.util.*; import java.util.*;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import static org.neo4j.driver.Values.parameters; import static org.neo4j.driver.Values.parameters;
@ -304,19 +305,18 @@ class PdfQaServerApplicationTests {
@Test @Test
public void textVectorTest() { public void textVectorTest() {
String texts = """ String texts = """
/
"""; """;
String[] split = texts.split("\n"); String[] split = texts.split("\n");
List<String> list = Arrays.stream(split).toList(); List<String> list = Arrays.stream(split).toList();
for (String text : list) { for (String text : list) {
TextVector textVector = new TextVector(); TextVector textVector = new TextVector();
textVector.setContent(text.trim()); textVector.setContent(text.trim());
textVector.setCategoryId("查询办公地点"); textVector.setCategoryId("分词");
float[] output = aiCallService.embedding(textVector.getContent()).getOutput(); float[] output = aiCallService.embedding(textVector.getContent()).getOutput();
textVector.setEmbedding(output); textVector.setEmbedding(output);
textVectorService.save(textVector); textVectorService.save(textVector);
@ -325,9 +325,10 @@ class PdfQaServerApplicationTests {
} }
@Test @Test
public void textVectorTest2() { public void textVectorTest2() {
String queryText = "告诉我龙源电力的办公地点?"; // 龙源电力集团近三年营收情况是多少
String queryText = "龙源电力集团近三年营收情况是多少";
float[] output = aiCallService.embedding(queryText).getOutput(); float[] output = aiCallService.embedding(queryText).getOutput();
List<TextVectorDTO> similarByCosine = textVectorService.findSimilarByCosine(output, 0.3f, 10); List<TextVectorDTO> similarByCosine = textVectorService.findSimilarByCosine(output, 0.1f, 5);
similarByCosine = similarByCosine.stream().sorted(Comparator.comparingDouble(TextVectorDTO::getSimilarityScore).reversed()).collect(Collectors.toList()); similarByCosine = similarByCosine.stream().sorted(Comparator.comparingDouble(TextVectorDTO::getSimilarityScore).reversed()).collect(Collectors.toList());
log.info("<<<===========================>>>" ); log.info("<<<===========================>>>" );
for (TextVectorDTO vectorDTO : similarByCosine) { for (TextVectorDTO vectorDTO : similarByCosine) {
@ -337,4 +338,22 @@ class PdfQaServerApplicationTests {
System.out.printf("%s\t%s\t%s\t%s%n",queryText, categoryId , NumberUtil.decimalFormat("0.0000",similarityScore),content); System.out.printf("%s\t%s\t%s\t%s%n",queryText, categoryId , NumberUtil.decimalFormat("0.0000",similarityScore),content);
} }
} }
@Autowired
private Retriever retriever;
@Autowired
private TextToSegmentService textToSegmentService;
/**
 * Manual end-to-end check: registers a custom dictionary word, runs one retrieval for a fixed
 * question, prints the result as JSON and logs the elapsed time.
 */
@Test
public void textVectorTest3() {
    // Left disabled; enable to rebuild the schema vectors before retrieving:
    // tripleToCypherExecutor.refreshSchemaSegmentVector();
    TimeInterval stopwatch = new TimeInterval();
    textToSegmentService.addDict("龙源电力集团","企业",1000);
    String question = "龙源电力集团近三年营收情况是多少";
    List<Map<String, Object>> rows = retriever.retrieval(question);
    String json = JSONUtil.toJsonStr(rows);
    System.out.println(json);
    log.info("<<<===========================>>> 耗时: {} 毫秒", stopwatch.intervalMs());
}
} }

Loading…
Cancel
Save