提示词适配

master
daixiaoyi 1 month ago
parent 5a75f9b810
commit 5677140172

@ -191,15 +191,76 @@ public class PromptCache {
""";
private static final String TEXT_TO_CYPHER_PROMPT = """
CYPHER
{domainMetadata}
{userQuery}
Neo4j Cypher
---
****
- **relationType**
{relationTypeList}
- **sourceType**
{sourceTypeList}
- **targetType**
{targetTypeList}
---
****
1. `Cypher `
2. 使 `WHERE`
3. ****
4. Cypher
5. relationTypesourceType targetType
```
```
---
****
1. - ****
- ** Cypher **
\\{
"cypherQueries": [
"MATCH (c:Company)-[r:HAS_LEGAL_REP]->(t) RETURN c, r, t",
"MATCH (c:Company)-[r:HAS_PHONE]->(t) RETURN c, r, t",
.....
]
\\}
2. - ****
- ** Cypher **
\\{
"cypherQueries": [
"MATCH (c:Company)-[r:IssueDocument]->(t:FinancialBill) RETURN c, r, t",
.....
]
\\}
{query}
Cypher
""";
private static final String GENERATE_ANSWER_PROMPT = """
{tripleMetaData}
{userQuery}
{query}
{example_text}
1. 使
2.
3.
"您好!当前系统功能聚焦于审计报告相关内容分析,您的问题暂不在支持范围内。如需查询财务数据、票据详情或其他审计相关信息,请提供具体问题,我们将全力协助。"
""";
private static final String CHINESE_TO_ENGLISH_PROMPT = """

@ -26,20 +26,13 @@ public class OllamaChatModelAspect {
@Around("execution(* org.springframework.ai.chat.model.ChatModel.call(..))")
public Object aroundMethodExecution(ProceedingJoinPoint joinPoint) throws Throwable {
String signature = joinPoint.getSignature().toString();
if (StrUtil.equals(model,"qwen3:30b-a3b") && StrUtil.equals(signature, callStringMessage)) {
Object[] args = joinPoint.getArgs();
if (args.length > 0) {
String arg = (String) args[0];
args[0] = arg + "\n /no_think";
}
}
// 执行原方法
Object result = joinPoint.proceed();
if (StrUtil.equals(model,"qwen3:30b-a3b") && StrUtil.equals(signature, callStringMessage)) {
if (StrUtil.equals(model,"qwen3:30b-a3b") ) {
if(StrUtil.equals(signature, callStringMessage)){
result = ((String) result).replaceAll("(?is)<think\\b[^>]*>(.*?)</think>", "").trim();
}
}
return result;
}
}

@ -39,9 +39,9 @@ public class Neo4jRepository {
while (result.hasNext()) {
org.neo4j.driver.Record record = result.next();
// 从 Record 中取出三部分
Node a = record.get("startNode").asNode();
Node a = record.get("c").asNode();
Relationship r = record.get("r").asRelationship();
Node b = record.get("endNode").asNode();
Node b = record.get("t").asNode();
// 转成我们的 DTO
NodeData start = mapNode(a);

@ -1,5 +1,7 @@
package com.supervision.pdfqaserver.service.impl;
import cn.hutool.json.JSONObject;
import cn.hutool.json.JSONUtil;
import com.supervision.pdfqaserver.cache.PromptCache;
import com.supervision.pdfqaserver.dao.Neo4jRepository;
import com.supervision.pdfqaserver.domain.ChineseEnglishWords;
@ -11,14 +13,13 @@ import com.supervision.pdfqaserver.service.DomainMetadataService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Flux;
import java.util.HashMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
@ -30,14 +31,16 @@ import static com.supervision.pdfqaserver.cache.PromptCache.TEXT_TO_CYPHER;
@Service
@RequiredArgsConstructor
public class ChatServiceImpl implements ChatService {
private static final String PROMPT_PARAM_DOMAIN_METADATA = "domainMetadata";
private static final String PROMPT_PARAM_TRIPLE_METADATA = "tripleMetaData";
private static final String PROMPT_PARAM_USER_QUERY = "userQuery";
private static final String PROMPT_PARAM_SOURCE_TYPE_LIST = "sourceTypeList";
private static final String PROMPT_PARAM_RELATION_TYPE_LIST = "relationTypeList";
private static final String PROMPT_PARAM_TARGET_TYPE_LIST = "targetTypeList";
private static final String PROMPT_PARAM_EXAMPLE_TEXT = "example_text";
private static final String PROMPT_PARAM_QUERY = "query";
private static final String CYPHER_QUERIES = "cypherQueries";
private final Neo4jRepository neo4jRepository;
private final OllamaChatModel ollamaChatModel;
private final DomainMetadataService domainMetadataService;
private final ChineseEnglishWordsService chineseEnglishWordsService;
@ -46,33 +49,54 @@ public class ChatServiceImpl implements ChatService {
//拼装领域元数据
Map<String, String> chineseEnglishWordsMap = chineseEnglishWordsService.list().stream()
.collect(Collectors.toMap(ChineseEnglishWords::getChineseWord, ChineseEnglishWords::getEnglishWord));
List<Map<String, String>> domainMappings = domainMetadataService.list().stream().map(domainMetadata -> {
Map<String, String> mapping = new HashMap<>();
mapping.put("source", domainMetadata.getSourceType());
mapping.put("sourceType", chineseEnglishWordsMap.get(domainMetadata.getSourceType()));
mapping.put("relation", domainMetadata.getRelation());
mapping.put("relationType", chineseEnglishWordsMap.get(domainMetadata.getRelation()));
mapping.put("target", domainMetadata.getTargetType());
mapping.put("targetType", chineseEnglishWordsMap.get(domainMetadata.getTargetType()));
return mapping;
}).toList();
log.info("domainMappings: {}", domainMappings);
//生成CYPHER
//分别得到sourceTyperelationtargetType的group by后的集合
List<String> sourceTypeList = domainMetadataService.list().stream().map(DomainMetadata::getSourceType).distinct().toList();
List<String> relationList = domainMetadataService.list().stream().map(DomainMetadata::getRelation).distinct().toList();
List<String> targetTypeList = domainMetadataService.list().stream().map(DomainMetadata::getTargetType).distinct().toList();
//将三个集合分别结合chineseEnglishWordsMap的key转化为value集合
List<String> sourceTypeListEnList = sourceTypeList.stream().map(chineseEnglishWordsMap::get).toList();
List<String> relationListEnList = relationList.stream().map(chineseEnglishWordsMap::get).toList();
List<String> targetTypeListEnList = targetTypeList.stream().map(chineseEnglishWordsMap::get).toList();
//将三个集合分别转换为英文逗号分隔的字符串
String sourceTypeListEn = String.join(",", sourceTypeListEnList);
String relationListEn = String.join(",", relationListEnList);
String targetTypeListEn = String.join(",", targetTypeListEnList);
log.info("sourceTypeListEn: {}, relationListEn: {}, targetTypeListEn: {}", sourceTypeListEn, relationListEn, targetTypeListEn);
//LLM生成CYPHER
SystemPromptTemplate textToCypherTemplate = new SystemPromptTemplate(PromptCache.promptMap.get(TEXT_TO_CYPHER));
Message textToCypherMessage = textToCypherTemplate.createMessage(Map.of(PROMPT_PARAM_DOMAIN_METADATA, domainMappings, PROMPT_PARAM_USER_QUERY, userQuery));
ChatResponse textToCypherResponse = ollamaChatModel.call(new Prompt(textToCypherMessage));
String queryCypher = "MATCH (startNode:公司)-[r]->(endNode) RETURN startNode,r,endNode";
log.info(textToCypherResponse.getResult().getOutput().getText());
// String queryCypher = textToCypherResponse.getResult().getOutput().getText();
List<RelationObject> relationObjects = neo4jRepository.execute(queryCypher, null);
Message textToCypherMessage = textToCypherTemplate.createMessage(Map.of(PROMPT_PARAM_SOURCE_TYPE_LIST, sourceTypeListEn, PROMPT_PARAM_RELATION_TYPE_LIST, relationListEn, PROMPT_PARAM_TARGET_TYPE_LIST, targetTypeListEn, PROMPT_PARAM_QUERY, userQuery));
String cypherJsonStr = ollamaChatModel.call(textToCypherMessage.getText());
log.info(cypherJsonStr);
List<String> cypherQueries;
try {
JSONObject jsonObj = JSONUtil.parseObj(cypherJsonStr);
cypherQueries = jsonObj.getJSONArray(CYPHER_QUERIES)
.toList(String.class);
} catch (Exception e) {
log.error("解析CYPHER JSON字符串失败: {}", e.getMessage());
return Flux.just("查无结果");
}
log.info("转换后的Cypher语句{}", cypherQueries.toString());
//执行CYPHER查询并汇总结果
List<RelationObject> relationObjects = new ArrayList<>();
if (!cypherQueries.isEmpty()) {
for (String cypher : cypherQueries) {
relationObjects.addAll(neo4jRepository.execute(cypher, null));
}
}
if (relationObjects.isEmpty()) {
return Flux.just("没有找到相关数据");
return Flux.just("查无结果");
}
log.info("relationObjects: {}", relationObjects);
log.info("三元组数据: {}", relationObjects);
//生成回答
SystemPromptTemplate generateAnswerTemplate = new SystemPromptTemplate(PromptCache.promptMap.get(GENERATE_ANSWER));
Message generateAnswerMessage = generateAnswerTemplate.createMessage(Map.of(PROMPT_PARAM_TRIPLE_METADATA, relationObjects, PROMPT_PARAM_USER_QUERY, userQuery));
Message generateAnswerMessage = generateAnswerTemplate.createMessage(Map.of(PROMPT_PARAM_EXAMPLE_TEXT, relationObjects, PROMPT_PARAM_QUERY, userQuery));
return ollamaChatModel.stream(new Prompt(generateAnswerMessage)).map(response -> response.getResult().getOutput().getText());
}
}

Loading…
Cancel
Save