Compare commits

...

2 Commits

Author SHA1 Message Date
xueqingkun ee4b4adb37 代码功能bug修复
优化问答功能
2 weeks ago
xueqingkun 2003b2ab5f domain_metadata 表删除domainType字段,该字段与domainCategoryId字段同义 2 weeks ago

@ -8,17 +8,44 @@ import java.util.Map;
*/
public class PromptCache {
/**
 * Prompt key {@code DOERE_TEXT}.
 * NOTE(review): the original description is missing in this view; presumably the
 * entity/relation extraction prompt for plain text - confirm.
 */
public static final String DOERE_TEXT = "DOERE_TEXT";
/**
 * Prompt key {@code DOERE_TABLE}.
 * NOTE(review): presumably the entity/relation extraction prompt for table content - confirm.
 */
public static final String DOERE_TABLE = "DOERE_TABLE";
/**
 * Prompt key for the template that turns a user question into a Cypher statement.
 * ChatServiceImpl looks this key up in {@code promptMap} before asking the model for Cypher.
 */
public static final String TEXT_TO_CYPHER = "TEXT_TO_CYPHER";
/**
 * Prompt key for the template used to generate the final answer.
 * ChatServiceImpl looks this key up in {@code promptMap} when building the answer prompt.
 */
public static final String GENERATE_ANSWER = "GENERATE_ANSWER";
/**
 * Prompt key {@code CHINESE_TO_ENGLISH}.
 * NOTE(review): presumably the Chinese-to-English translation prompt - confirm.
 */
public static final String CHINESE_TO_ENGLISH = "CHINESE_TO_ENGLISH";
/**
 * Prompt key {@code ERE_TO_INSERT_CYPHER}.
 * NOTE(review): presumably the prompt that converts extraction results into Cypher
 * insert statements - confirm.
 */
public static final String ERE_TO_INSERT_CYPHER = "ERE_TO_INSERT_CYPHER";
/**
 * Prompt key {@code CLASSIFY_TABLE}.
 * NOTE(review): presumably the table classification prompt - confirm.
 */
public static final String CLASSIFY_TABLE = "CLASSIFY_TABLE";
/**
 * Prompt key {@code EXTRACT_TABLE_TITLE}.
 * NOTE(review): presumably the table title extraction prompt - confirm.
 */
public static final String EXTRACT_TABLE_TITLE = "EXTRACT_TABLE_TITLE";
/**
@ -53,6 +80,12 @@ public class PromptCache {
*/
public static final String EXTRACT_ERE_BASE_INTENT = "EXTRACT_ERE_BASE_INTENT";
/**
 * Prompt key for the query-intent classification template.
 * Registered in {@code promptMap} with CLASSIFY_QUERY_INTENT_PROMPT as its value.
 */
public static final String CLASSIFY_QUERY_INTENT = "CLASSIFY_QUERY_INTENT";
// Registry of prompt key -> prompt template text; populated in the static initializer below.
public static final Map<String, String> promptMap = new HashMap<>();
static {
@ -73,6 +106,7 @@ public class PromptCache {
promptMap.put(CLASSIFY_INTENT_TRAIN, CLASSIFY_INTENT_TRAIN_PROMPT);
promptMap.put(EXTRACT_INTENT_METADATA, EXTRACT_INTENT_METADATA_PROMPT);
promptMap.put(EXTRACT_ERE_BASE_INTENT, EXTRACT_ERE_BASE_INTENT_PROMPT);
promptMap.put(CLASSIFY_QUERY_INTENT, CLASSIFY_QUERY_INTENT_PROMPT);
}
@ -759,4 +793,25 @@ public class PromptCache {
-
- JSON使```json ```Markdown./no_think
""";
/**
 * Template for classifying a user query against a list of candidate intents.
 * Contains two placeholders, {intents} and {query}; ChatServiceImpl.classifyIntents fills
 * them with StrUtil.format and parses the model reply as a JSON array of intent digests.
 * NOTE(review): several instruction lines look truncated in this view (for example the bare
 * "1." and "3." items) - verify against the authoritative prompt text before editing.
 */
private static final String CLASSIFY_QUERY_INTENT_PROMPT = """
JSON使
#
{intents}
#
1.
2. 使
3.
4. JSON使```json ```Markdown
#
"我昨天买的鞋子怎么还没发货?"
["订单查询"]
#
{query}
""";
}

@ -18,12 +18,6 @@ public class DomainMetadata implements Serializable {
@TableId
private String id;
/**
*
*/
@Deprecated
private String domainType;
/**
*
*/

@ -0,0 +1,94 @@
package com.supervision.pdfqaserver.dto;
import cn.hutool.core.collection.CollUtil;
import com.supervision.pdfqaserver.dto.neo4j.NodeData;
import com.supervision.pdfqaserver.dto.neo4j.RelationObject;
import com.supervision.pdfqaserver.dto.neo4j.RelationshipData;
import lombok.Data;
/**
 * Flat view of one graph triple (start node, relationship, end node) that is returned
 * alongside a generated answer so the client can show where the answer came from.
 */
@Data
public class AnswerDetailDTO {

    /**
     * Id of the document truncation, read from the start node's {@code truncationId} property.
     */
    private String truncateId;

    /**
     * Labels of the start (source) node, joined with commas.
     */
    private String sourceType;

    /**
     * {@code name} property of the start (source) node.
     */
    private String sourceName;

    /**
     * Labels of the end (target) node, joined with commas.
     */
    private String targetType;

    /**
     * {@code name} property of the end (target) node.
     */
    private String targetName;

    /**
     * Type of the relationship connecting the two nodes.
     */
    private String relation;

    /**
     * Text content of the truncation; not set by the constructors of this class,
     * callers fill it in afterwards.
     */
    private String truncateContent;

    /**
     * PDF ID; not set by the constructors of this class.
     */
    private String pdfId;

    /**
     * PDF name; not set by the constructors of this class.
     */
    private String pdfName;

    /**
     * Creates an empty instance with every field left {@code null}.
     */
    public AnswerDetailDTO() {
    }

    /**
     * Builds the detail from a graph triple. When the triple, or any of its start node,
     * end node or relationship, is {@code null}, every field is left {@code null}.
     *
     * @param relationObject triple returned by the graph query
     */
    public AnswerDetailDTO(RelationObject relationObject) {
        if (null == relationObject) {
            return;
        }
        NodeData startNode = relationObject.startNode();
        NodeData endNode = relationObject.endNode();
        RelationshipData relationship = relationObject.relationship();
        if (null == startNode || null == endNode || null == relationship) {
            return;
        }

        // Source side: everything comes from the START node.
        if (CollUtil.isNotEmpty(startNode.properties())) {
            Object truncationId = startNode.properties().get("truncationId");
            if (null != truncationId) {
                this.truncateId = truncationId.toString();
            }
            Object startName = startNode.properties().get("name");
            if (null != startName) {
                this.sourceName = startName.toString();
            }
        }
        if (CollUtil.isNotEmpty(startNode.labels())) {
            // Previously the end node's labels were used here, swapping source and target types.
            this.sourceType = String.join(",", startNode.labels());
        }

        // Target side: everything comes from the END node.
        if (CollUtil.isNotEmpty(endNode.labels())) {
            this.targetType = String.join(",", endNode.labels());
        }
        if (CollUtil.isNotEmpty(endNode.properties())) {
            // Null-check the same value that is dereferenced; the old code checked the start
            // node's name and then called toString() on the end node's name (possible NPE).
            Object endName = endNode.properties().get("name");
            if (null != endName) {
                this.targetName = endName.toString();
            }
        }

        // The relationship type does not depend on whether the relationship has properties.
        this.relation = relationship.type();
    }
}

@ -146,6 +146,7 @@ public class EREDTO {
entityExtractionDTO.setAttributes(truncationErAttributeDTOS);
entities.add(entityExtractionDTO);
}
eredto.setEntities(entities);
return eredto;

@ -1,9 +1,7 @@
package com.supervision.pdfqaserver.service;
import cn.hutool.core.util.StrUtil;
import com.supervision.pdfqaserver.domain.Intention;
import com.baomidou.mybatisplus.extension.service.IService;
import java.util.List;
/**
@ -33,4 +31,12 @@ public interface IntentionService extends IService<Intention> {
Intention queryByDigestAndDomainCategoryId(String digest, String domainCategoryId);
List<Intention> queryByDomainCategoryId(String domainCategoryId);
/**
 * Lists all intentions regarded as "passed".
 * NOTE(review): the implementation visible in IntentionServiceImpl selects rows whose
 * generationType equals "0"; confirm that this value is what "passed" means.
 * @return the matching intentions
 */
List<Intention> listAllPassed();
}

@ -40,6 +40,13 @@ public interface KnowledgeGraphService {
void generateGraph(List<EREDTO> eredtoList);
/**
 * Generates the knowledge graph from extraction results using the simplified flow.
 * The implementation in KnowledgeGraphServiceImpl merges the results and saves every merged
 * item that has entities or relations; a failure on one item is logged and the remaining
 * items are still processed.
 * @param eredtoList entity/relation extraction results to write to the graph
 */
void generateGraphSimple(List<EREDTO> eredtoList);
List<EREDTO> truncateERE(List<TruncateDTO> truncateDTOS);

@ -1,14 +1,17 @@
package com.supervision.pdfqaserver.service.impl;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONObject;
import cn.hutool.json.JSONUtil;
import com.supervision.pdfqaserver.cache.PromptCache;
import com.supervision.pdfqaserver.dao.Neo4jRepository;
import com.supervision.pdfqaserver.domain.DomainMetadata;
import com.supervision.pdfqaserver.domain.DocumentTruncation;
import com.supervision.pdfqaserver.domain.Intention;
import com.supervision.pdfqaserver.dto.AnswerDetailDTO;
import com.supervision.pdfqaserver.dto.DomainMetadataDTO;
import com.supervision.pdfqaserver.dto.neo4j.RelationObject;
import com.supervision.pdfqaserver.service.ChatService;
import com.supervision.pdfqaserver.service.ChineseEnglishWordsService;
import com.supervision.pdfqaserver.service.DomainMetadataService;
import com.supervision.pdfqaserver.service.*;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.messages.Message;
@ -18,13 +21,12 @@ import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Flux;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import static com.supervision.pdfqaserver.cache.PromptCache.GENERATE_ANSWER;
import static com.supervision.pdfqaserver.cache.PromptCache.TEXT_TO_CYPHER;
import static com.supervision.pdfqaserver.cache.PromptCache.*;
@Slf4j
@Service
@ -40,26 +42,44 @@ public class ChatServiceImpl implements ChatService {
private final Neo4jRepository neo4jRepository;
private final OllamaChatModel ollamaChatModel;
private final DomainMetadataService domainMetadataService;
private final ChineseEnglishWordsService chineseEnglishWordsService;
private final AiCallService aiCallService;
private final DocumentTruncationService documentTruncationService;
private final IntentionService intentionService;
@Override
public Flux<String> knowledgeQA(String userQuery) {
//分别得到sourceTyperelationtargetType的group by后的集合
List<String> sourceTypeList = domainMetadataService.list().stream().map(DomainMetadata::getSourceType).distinct().toList();
List<String> relationList = domainMetadataService.list().stream().map(DomainMetadata::getRelation).distinct().toList();
List<String> targetTypeList = domainMetadataService.list().stream().map(DomainMetadata::getTargetType).distinct().toList();
List<Intention> intentions = intentionService.listAllPassed();
List<Intention> relations = classifyIntents(userQuery, intentions);
if (CollUtil.isEmpty(relations)){
log.info("没有匹配到意图,返回查无结果");
return Flux.just("查无结果").concatWith(Flux.just("[END]"));
}
List<DomainMetadataDTO> domainMetadataDTOS = domainMetadataService.listByIntentionIds(relations.stream().map(Intention::getId).toList());
if (CollUtil.isEmpty(domainMetadataDTOS)){
log.info("没有匹配到领域元数据,返回查无结果");
return Flux.just("查无结果").concatWith(Flux.just("[END]"));
}
//将三个集合分别转换为英文逗号分隔的字符串
String sourceTypeListEn = String.join(",", sourceTypeList);
String relationListEn = String.join(",", relationList);
String targetTypeListEn = String.join(",", targetTypeList);
String sourceTypeListEn = domainMetadataDTOS.stream().map(DomainMetadataDTO::getSourceType).distinct().collect(Collectors.joining(","));
String relationListEn = domainMetadataDTOS.stream().map(DomainMetadataDTO::getRelation).distinct().collect(Collectors.joining(","));
String targetTypeListEn = domainMetadataDTOS.stream().map(DomainMetadataDTO::getTargetType).distinct().collect(Collectors.joining(","));
//LLM生成CYPHER
SystemPromptTemplate textToCypherTemplate = new SystemPromptTemplate(PromptCache.promptMap.get(TEXT_TO_CYPHER));
Message textToCypherMessage = textToCypherTemplate.createMessage(Map.of(PROMPT_PARAM_SOURCE_TYPE_LIST, sourceTypeListEn, PROMPT_PARAM_RELATION_TYPE_LIST, relationListEn, PROMPT_PARAM_TARGET_TYPE_LIST, targetTypeListEn, PROMPT_PARAM_QUERY, userQuery));
Message textToCypherMessage = textToCypherTemplate.createMessage(Map.of(PROMPT_PARAM_SOURCE_TYPE_LIST, sourceTypeListEn,
PROMPT_PARAM_RELATION_TYPE_LIST, relationListEn, PROMPT_PARAM_TARGET_TYPE_LIST,
targetTypeListEn, PROMPT_PARAM_QUERY, userQuery));
log.info("生成CYPHER语句的消息{}", textToCypherMessage);
String cypherJsonStr = ollamaChatModel.call(new Prompt(textToCypherMessage, OllamaOptions.builder().temperature(3.0).build())).getResult().getOutput().getText();
String cypherJsonStr = ollamaChatModel.call(new Prompt(textToCypherMessage, OllamaOptions.builder().temperature(0.3).build())).getResult().getOutput().getText();
log.info(cypherJsonStr);
log.info(cypherJsonStr.replaceAll("(?is)<think\\b[^>]*>(.*?)</think>", "").trim());
cypherJsonStr = cypherJsonStr.replaceAll("(?is)<think\\b[^>]*>(.*?)</think>", "").trim();
@ -82,6 +102,7 @@ public class ChatServiceImpl implements ChatService {
}
}
if (relationObjects.isEmpty()) {
log.info("cypher没有查询到结果返回查无结果");
return Flux.just("查无结果").concatWith(Flux.just("[END]"));
}
log.info("三元组数据: {}", relationObjects);
@ -90,6 +111,65 @@ public class ChatServiceImpl implements ChatService {
SystemPromptTemplate generateAnswerTemplate = new SystemPromptTemplate(PromptCache.promptMap.get(GENERATE_ANSWER));
Message generateAnswerMessage = generateAnswerTemplate.createMessage(Map.of(PROMPT_PARAM_EXAMPLE_TEXT, relationObjects, PROMPT_PARAM_QUERY, userQuery));
log.info("生成回答的提示词:{}", generateAnswerMessage);
return ollamaChatModel.stream(new Prompt(generateAnswerMessage)).map(response -> response.getResult().getOutput().getText()).concatWith(Flux.just("[END]"));
return ollamaChatModel.stream(new Prompt(generateAnswerMessage)).map(response -> response.getResult().getOutput().getText())
.concatWith(Flux.just(new JSONObject().set("answerDetails", convertToAnswerDetails(relationObjects)).toString()))
.concatWith(Flux.just("[END]"));
}
/**
 * Asks the model which of the candidate intentions match the user query.
 * Candidates are sent in batches of 150; for each batch the model reply is parsed as a
 * JSON array of digests and the intentions with a matching digest are collected.
 *
 * @param query      user question; when empty, an empty list is returned
 * @param intentions candidate intentions; when empty, an empty list is returned
 * @return the intentions matched across all batches (possibly empty, never null)
 */
private List<Intention> classifyIntents(String query, List<Intention> intentions) {
    if (StrUtil.isEmpty(query) || CollUtil.isEmpty(intentions)) {
        return new ArrayList<>();
    }
    String prompt = promptMap.get(CLASSIFY_QUERY_INTENT);
    List<Intention> result = new ArrayList<>();
    log.info("开始分类意图query: {}, intentions size: {}", query, intentions.size());
    // Batch the candidates so a single prompt does not carry every intention at once.
    List<List<Intention>> intentionSplit = CollUtil.split(intentions, 150);
    for (List<Intention> intentionList : intentionSplit) {
        log.info("分类意图query: {}, intentions size: {}", query, intentionList.size());
        String intents = intentionList.stream().map(i -> " - " + i.getDigest() + "\n").collect(Collectors.joining());
        Map<String, Object> params = Map.of("query", query,
                "intents", intents);
        String format = StrUtil.format(prompt, params);
        String call = aiCallService.call(format);
        if (StrUtil.isEmpty(call)) {
            // Skip only this batch. Returning an empty list here would throw away the
            // matches already collected from earlier batches.
            log.warn("classifyIntents: empty model response for a batch of {} intentions, skipping batch", intentionList.size());
            continue;
        }
        List<String> digests = JSONUtil.parseArray(call).stream().map(Object::toString).toList();
        if (CollUtil.isEmpty(digests)) {
            continue;
        }
        List<Intention> collect = intentionList.stream().filter(i -> digests.contains(i.getDigest())).collect(Collectors.toList());
        if (CollUtil.isNotEmpty(collect)) {
            result.addAll(collect);
        }
    }
    return result;
}
/**
 * Converts graph triples into answer details and attaches the text of the document
 * truncation each detail refers to.
 *
 * @param relationObjects triples returned by the graph query; may be null or empty
 * @return one detail per triple (empty list when there are no triples)
 */
private List<AnswerDetailDTO> convertToAnswerDetails(List<RelationObject> relationObjects) {
    if (CollUtil.isEmpty(relationObjects)) {
        return new ArrayList<>();
    }
    List<AnswerDetailDTO> answerDetailDTOList = relationObjects.stream().map(AnswerDetailDTO::new).collect(Collectors.toList());
    // Details built from nodes without a truncationId carry a null id; never send those to the lookup.
    List<String> truncateIds = answerDetailDTOList.stream()
            .map(AnswerDetailDTO::getTruncateId)
            .filter(StrUtil::isNotEmpty)
            .distinct()
            .toList();
    if (CollUtil.isEmpty(truncateIds)) {
        return answerDetailDTOList;
    }
    List<DocumentTruncation> documentTruncations = documentTruncationService.listByIds(truncateIds);
    if (CollUtil.isEmpty(documentTruncations)) {
        return answerDetailDTOList;
    }
    // Collectors.toMap rejects null values and duplicate keys: drop rows without id/content
    // and keep the first row on a key collision.
    Map<String, String> contentMap = documentTruncations.stream()
            .filter(truncation -> truncation.getId() != null && truncation.getContent() != null)
            .collect(Collectors.toMap(DocumentTruncation::getId, DocumentTruncation::getContent, (first, second) -> first));
    for (AnswerDetailDTO answerDetailDTO : answerDetailDTOList) {
        // Ids with no matching truncation simply leave the content null.
        answerDetailDTO.setTruncateContent(contentMap.get(answerDetailDTO.getTruncateId()));
    }
    return answerDetailDTOList;
}
}

@ -65,6 +65,11 @@ public class IntentionServiceImpl extends ServiceImpl<IntentionMapper, Intention
public List<Intention> queryByDomainCategoryId(String domainCategoryId) {
    // Select every intention whose domain category id equals the given value.
    List<Intention> matches = super.lambdaQuery()
            .eq(Intention::getDomainCategoryId, domainCategoryId)
            .list();
    return matches;
}
@Override
public List<Intention> listAllPassed() {
    // Select every intention whose generationType equals "0".
    // NOTE(review): presumably "0" marks a passed/approved intention - confirm.
    final String generationType = "0";
    List<Intention> passed = super.lambdaQuery()
            .eq(Intention::getGenerationType, generationType)
            .list();
    return passed;
}
}

@ -321,7 +321,7 @@ public class KnowledgeGraphServiceImpl implements KnowledgeGraphService {
log.info("开始生成知识图谱...");
timer.start("generateGraph");
generateGraph(eredtos);
generateGraphSimple(eredtos);
log.info("生成知识图谱完成,耗时:{}秒", timer.intervalSecond("generateGraph"));
}
@ -347,7 +347,6 @@ public class KnowledgeGraphServiceImpl implements KnowledgeGraphService {
}
for (RelationExtractionDTO relation : relations) {
DomainMetadata domainMetadata = relation.toDomainMetadata();
domainMetadata.setDomainType("1");
domainMetadata.setGenerationType(DomainMetaGenerationEnum.SYSTEM_AUTO_GENERATION.getCode());
domainMetadataService.saveIfNotExists(domainMetadata);
}
@ -400,6 +399,24 @@ public class KnowledgeGraphServiceImpl implements KnowledgeGraphService {
}
}
/**
 * Writes extraction results to the graph using the simplified flow: merge the results,
 * then save every merged item that has at least one entity or relation.
 * A failure while saving one item is logged and does not stop the remaining items.
 *
 * @param eredtoList extraction results; nothing is done when null or empty
 */
@Override
public void generateGraphSimple(List<EREDTO> eredtoList) {
    if (CollUtil.isEmpty(eredtoList)) {
        return;
    }
    log.info("开始合并实体关系抽取结果...");
    List<EREDTO> mergedList = tripleConversionPipeline.mergeEreResults(eredtoList);
    if (CollUtil.isEmpty(mergedList)) {
        // Also covers a null merge result, which would otherwise fail on size() below.
        log.info("合并实体关系抽取结果完成,合并后个数:{}", 0);
        return;
    }
    log.info("合并实体关系抽取结果完成,合并后个数:{}", mergedList.size());
    for (EREDTO eredto : mergedList) {
        if (CollUtil.isEmpty(eredto.getEntities()) && CollUtil.isEmpty(eredto.getRelations())) {
            continue;
        }
        try {
            tripleToCypherExecutor.saveERE(eredto);
        } catch (Exception e) {
            // Best-effort per item, but a failed save is an error, not an informational event.
            log.error("生成cypher语句失败,切分文档id:{}", JSONUtil.toJsonStr(eredto), e);
        }
    }
}
private static List<ChineseEnglishWords> getChineseEnglishWords(EREDTO eredto) {
List<ChineseEnglishWords> allWords;
allWords = eredto.getEntities().stream().flatMap(entity -> {

@ -87,7 +87,6 @@ public class TripleToCypherExecutorImpl implements TripleToCypherExecutor {
Map<String, Object> attributes = relation.getAttributes().stream().collect(Collectors.toMap(
TruncationERAttributeDTO::getAttributeEn, TruncationERAttributeDTO::getValue
));
attributes.put("sourceType", relation.getSourceType());
attributes.put("truncationId", relation.getTruncationId());
for (Long sourceNodeId : sourceNodeIds) {
for (Long targetNodeId : targetNodeIds) {

@ -46,6 +46,7 @@ public class TruncationEntityExtractionServiceImpl extends ServiceImpl<Truncatio
for (TruncationERAttributeDTO attribute : attributes) {
attribute.setTerId(tee.getId());
TruncationErAttribute era = attribute.toTruncationErAttribute();
era.setAssociationType("0"); // 0: 实体
truncationErAttributeService.save(era);
}
}

@ -42,6 +42,7 @@ public class TruncationRelationExtractionServiceImpl extends ServiceImpl<Truncat
}
for (TruncationERAttributeDTO attribute : relation.getAttributes()) {
TruncationErAttribute era = attribute.toTruncationErAttribute();
era.setAssociationType("1"); // 1: 关系
era.setTerId(re.getId());
truncationErAttributeService.save(era);
}

@ -6,7 +6,6 @@
<resultMap id="BaseResultMap" type="com.supervision.pdfqaserver.domain.DomainMetadata">
<id property="id" column="id" jdbcType="VARCHAR"/>
<result property="domainType" column="domain_type" jdbcType="VARCHAR"/>
<result property="sourceType" column="source_type" jdbcType="VARCHAR"/>
<result property="relation" column="relation" jdbcType="VARCHAR"/>
<result property="targetType" column="target_type" jdbcType="VARCHAR"/>
@ -17,7 +16,7 @@
</resultMap>
<sql id="Base_Column_List">
id,domain_type,source_type,
id,source_type,
relation,target_type,generation_type,domain_category_id,
create_time,update_time
</sql>

Loading…
Cancel
Save