问答功能优化-初始化表

v_0.0.2
xueqingkun 1 day ago
parent fe1a6f1b1b
commit e181c00c40

@ -850,7 +850,7 @@ public class PromptCache {
1.
- 线schema
- 线schema********
-
-
-
@ -858,6 +858,7 @@ public class PromptCache {
2.
-
-
- ********
3. MATCH
@ -873,58 +874,49 @@ public class PromptCache {
5. Cypher
- Cypher```cypher ```
- ****MATCH ****
- neo4j_schemacypher便
- ********
- schemacypher
- cypher:['cypher1','...']
- cyphercypher/no_think
""";
private static final String TEXT_TO_CYPHER_4_PROMPT = """
Cyphercyphercypher便`neo4j_schema`cypher
Cyphercyphercypher便`neo4j_schema`使cypher
```text
{query}
```
neo4j_schemaJSON
```shema
{shema}
{schema}
```
#
${env}
{env}
# cypher
# 使cypher
```json
{cypher}
```
1.
- 线schema
1.cypher
- ****
2.
- 线schema
-
-
-
-
- cypher
2.
-
- ********
3. MATCH
- 使,****
- ****`-[r:REL_TYPE]->`
- 使
- WHERE
- cypher
2. MATCH
- MATCHWHERE
4. RETURN
3. RETURN
- ********
- ****
5. Cypher
4. Cypher
- Cypher```cypher ```
- ****MATCH ****
- neo4j_schemacypher便
- cypher:['cypher1','...']
- cyphercypher/no_think
""";

@ -3,9 +3,9 @@ package com.supervision.pdfqaserver.dto;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONArray;
import cn.hutool.json.JSONObject;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
/**
* CypherSchemaDTO
@ -118,7 +118,24 @@ public class CypherSchemaDTO {
json = new JSONObject();
relJson.set(rela, json);
}
json.set("_endpoints", new JSONArray(new String[]{sourceType, targetType}));
JSONArray endpoints = json.getJSONArray("_endpoints");
if (null == endpoints){
endpoints = new JSONArray();
endpoints.add(Map.of("sourceType", sourceType, "targetType", targetType));
json.set("_endpoints", endpoints);
}else {
boolean absent = false;
for (Object endpoint : endpoints) {
Map<String,Object> nodes = (Map<String, Object>) endpoint;
if (sourceType.equals(nodes.get("sourceType"))|| sourceType.equals(nodes.get("targetType"))){
absent = true;
break;
}
}
if (absent){
endpoints.add(Map.of("sourceType", sourceType, "targetType", targetType));
}
}
for (TruncationERAttributeDTO attribute : relation.getAttributes()) {
if ("truncationId".equals(attribute.getAttribute())){
continue;
@ -127,7 +144,6 @@ public class CypherSchemaDTO {
entry -> StrUtil.equals(entry.getKey(), attribute.getAttribute())
);
if (none) {
json.set(attribute.getAttribute(), attribute.getDataType());
}
}

@ -1,7 +1,9 @@
package com.supervision.pdfqaserver.dto;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.StrUtil;
import lombok.Data;
import java.util.List;
@Data
public class TextTerm {
@ -18,7 +20,10 @@ public class TextTerm {
private float[] embedding;
public String getLabelValue() {
public String getLabelValue(List<String> keyWords) {
if (CollUtil.isNotEmpty(keyWords) && keyWords.contains(word)) {
return word;
}
if (StrUtil.equalsAny(label,"n","nl","nr","ns","nsf","nz")){
return word;
}

@ -1,5 +1,6 @@
package com.supervision.pdfqaserver.service.impl;
import cn.hutool.core.bean.BeanUtil;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONObject;
@ -18,6 +19,7 @@ import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Flux;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
@ -58,14 +60,14 @@ public class ChatServiceImpl implements ChatService {
// 执行cypher语句
List<Map<String, Object>> graphResult = tripleToCypherExecutor.executeCypher(cypher);
*/
List<Map<String, Object>> graphResult = compareRetriever.retrieval(userQuery);
if (CollUtil.isEmpty(graphResult)){
return Flux.just("查无结果").concatWith(Flux.just("[END]"));
}*/
List<Map<String, Object>> graphResult = compareRetriever.retrieval(userQuery);
}
//生成回答
SystemPromptTemplate generateAnswerTemplate = new SystemPromptTemplate(PromptCache.promptMap.get(GENERATE_ANSWER));
Message generateAnswerMessage = generateAnswerTemplate.createMessage(Map.of(PROMPT_PARAM_EXAMPLE_TEXT, JSONUtil.toJsonStr(graphResult), PROMPT_PARAM_QUERY, userQuery));
Message generateAnswerMessage = generateAnswerTemplate.createMessage(Map.of(PROMPT_PARAM_EXAMPLE_TEXT, JSONUtil.toJsonStr(clearGraphElements(graphResult)), PROMPT_PARAM_QUERY, userQuery));
log.info("生成回答的提示词:{}", generateAnswerMessage);
return aiCallService.stream(new Prompt(generateAnswerMessage))
.map(response -> response.getResult().getOutput().getText())
@ -167,4 +169,34 @@ public class ChatServiceImpl implements ChatService {
}
return distinct;
}
/**
*
* @param graphElements
* @return
*/
public List<Map<String, Object>> clearGraphElements(List<Map<String, Object>> graphElements) {
if (CollUtil.isEmpty(graphElements)) {
return graphElements;
}
List<Map<String, Object>> result = new ArrayList<>(graphElements.size());
for (Map<String, Object> originalMap : graphElements) {
Map<String, Object> newMap = new HashMap<>();
for (Map.Entry<String, Object> entry : originalMap.entrySet()) {
String key = entry.getKey();
Object value = entry.getValue();
if (value instanceof NodeDTO nodeDTO){
NodeDTO newNodeDTO = BeanUtil.copyProperties(nodeDTO, NodeDTO.class);
newNodeDTO.clearGraphElement(); // 清理图谱元素
newMap.put(key, newNodeDTO);
} else if (value instanceof RelationshipValueDTO relationshipValueDTO) {
RelationshipValueDTO newRelationshipValueDTO = BeanUtil.copyProperties(relationshipValueDTO, RelationshipValueDTO.class);
newRelationshipValueDTO.clearGraphElement(); // 清理图谱元素
newMap.put(key, newRelationshipValueDTO);
}
}
result.add(newMap);
}
return result;
}
}

@ -62,7 +62,7 @@ public class DataCompareRetriever implements Retriever {
boolean allEmpty = cypherData.values().stream().noneMatch(CollUtil::isNotEmpty);
if (!allEmpty){
cypherData.values().stream().filter(CollUtil::isNotEmpty).forEach(result::addAll);
return clearGraphElements(result);
return result;
}
}
if (CollUtil.isEmpty(result)){
@ -70,8 +70,8 @@ public class DataCompareRetriever implements Retriever {
prompt = PromptCache.promptMap.get(TEXT_TO_CYPHER_4);
format = StrUtil.format(prompt,
Map.of("query", query, "schema", schemaDTO.format(),
"env", "- 当前时间是:" + DateUtil.now()),"cypher",js.toString());
log.info("retrieval: 生成cypher语句{}", format);
"env", "- 当前时间是:" + DateUtil.now(),"cypher",js.toString()));
log.info("retrieval: 生成cypher语句:{}", format);
call = aiCallService.call(format);
log.info("retrieval: AI调用返回结果{}", call);
@ -81,33 +81,11 @@ public class DataCompareRetriever implements Retriever {
boolean allEmpty2 = cypherData.values().stream().noneMatch(CollUtil::isNotEmpty);
if (!allEmpty2){
cypherData.values().stream().filter(CollUtil::isNotEmpty).forEach(result::addAll);
return clearGraphElements(result);
return result;
}
}
}
return clearGraphElements(result);
}
/**
*
* @param graphElements
* @return
*/
private List<Map<String, Object>> clearGraphElements(List<Map<String, Object>> graphElements) {
if (CollUtil.isEmpty(graphElements)){
return graphElements;
}
for (Map<String, Object> element : graphElements) {
for (Object value : element.values()) {
if (value instanceof RelationshipValueDTO relationshipValueDTO) {
relationshipValueDTO.clearGraphElement();
}
if (value instanceof com.supervision.pdfqaserver.dto.neo4j.NodeDTO nodeDTO) {
nodeDTO.clearGraphElement();
}
}
}
return graphElements;
return result;
}
}

@ -26,6 +26,7 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static com.supervision.pdfqaserver.cache.PromptCache.*;
@Slf4j
@ -272,15 +273,16 @@ public class TripleToCypherExecutorImpl implements TripleToCypherExecutor {
log.info("queryRelationSchema: 分词结果:{}", terms);
log.info("queryRelationSchema: 开始进行文本标签向量匹配...");
List<NodeRelationVector> matchedText = new ArrayList<>();
List<String> keywords = mergeNodeAndRelationLabel();
for (TextTerm term : terms) {
if (StrUtil.isEmpty(term.getLabelValue())){
if (StrUtil.isEmpty(term.getLabelValue(keywords))){
log.info("queryRelationSchema: 分词结果`{}`不是关键标签,跳过...", term.getWord());
continue;
}
Embedding embedding = aiCallService.embedding(term.getLabelValue());
Embedding embedding = aiCallService.embedding(term.getLabelValue(keywords));
term.setEmbedding(embedding.getOutput());
List<NodeRelationVector> textVectorDTOS = nodeRelationVectorService.matchSimilarByCosine(embedding.getOutput(), 0.9, List.of("N","R"),3); // 继续过滤
log.info("retrieval: 文本:{}匹配到的文本向量:{}", term.getWord() ,textVectorDTOS.stream().map(NodeRelationVector::getContent).collect(Collectors.joining(" ")));
log.info("retrieval: 文本:`{}`匹配到的文本向量:`{}`", term.getWord() ,textVectorDTOS.stream().map(NodeRelationVector::getContent).collect(Collectors.joining(" ")));
matchedText.addAll(textVectorDTOS);
}
if (CollUtil.isEmpty(matchedText)){
@ -306,7 +308,7 @@ public class TripleToCypherExecutorImpl implements TripleToCypherExecutor {
// 对查询到的关系进行重排序
List<Pair<Double, RelationExtractionDTO>> pairs = new ArrayList<>();
TimeInterval timeInterval = new TimeInterval();
String join = terms.stream().map(TextTerm::getLabelValue).filter(StrUtil::isNotEmpty).collect(Collectors.joining());
String join = terms.stream().map(t->t.getLabelValue(keywords)).filter(StrUtil::isNotEmpty).collect(Collectors.joining());
Embedding embedding = aiCallService.embedding(join);
for (RelationExtractionDTO relation : merged) {
String content = relation.getSourceType() + " " + relation.getRelation() + " " + relation.getTargetType();
@ -319,7 +321,7 @@ public class TripleToCypherExecutorImpl implements TripleToCypherExecutor {
}
log.info("queryRelationSchema: 关系排序耗时:{}ms", timeInterval.intervalMs());
merged = pairs.stream().sorted((p1, p2) -> Double.compare(p2.getKey(), p1.getKey())).limit(5).map(Pair::getValue).toList();
merged = pairs.stream().sorted((p1, p2) -> Double.compare(p2.getKey(), p1.getKey())).limit(4).map(Pair::getValue).toList();
List<EntityExtractionDTO> entityExtractionDTOS = new ArrayList<>();
for (RelationExtractionDTO relationExtractionDTO : merged) {
EntityExtractionDTO sourceNode = cypherSchemaDTO.getNode(relationExtractionDTO.getSourceType());
@ -392,4 +394,14 @@ public class TripleToCypherExecutorImpl implements TripleToCypherExecutor {
}
}
}
private List<String> mergeNodeAndRelationLabel() {
loadCypherSchemaIfAbsent();
if (CollUtil.isEmpty(cypherSchemaDTO.getRelations())) {
log.warn("图谱schema数据为空无法合并节点和关系标签");
return new ArrayList<>();
}
return cypherSchemaDTO.getRelations().stream()
.flatMap(r -> Stream.of(r.getSourceType(), r.getRelation(), r.getTargetType())).distinct().collect(Collectors.toList());
}
}

Loading…
Cancel
Save