问答功能优化-初始化表
parent
2ea04d7325
commit
fe1a6f1b1b
@ -0,0 +1,54 @@
|
||||
package com.supervision.pdfqaserver.domain;
|
||||
|
||||
import com.baomidou.mybatisplus.annotation.*;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.time.LocalDateTime;
|
||||
|
||||
import com.supervision.pdfqaserver.config.VectorTypeHandler;
|
||||
import lombok.Data;
|
||||
|
||||
/**
|
||||
* 节点关系向量表
|
||||
* @TableName node_relation_vector
|
||||
*/
|
||||
@TableName(value ="node_relation_vector")
|
||||
@Data
|
||||
public class NodeRelationVector implements Serializable {
|
||||
/**
|
||||
* 主键
|
||||
*/
|
||||
@TableId
|
||||
private String id;
|
||||
|
||||
/**
|
||||
* 文本内容
|
||||
*/
|
||||
private String content;
|
||||
|
||||
/**
|
||||
* 向量值
|
||||
*/
|
||||
@TableField(typeHandler = VectorTypeHandler.class)
|
||||
private float[] embedding;
|
||||
|
||||
/**
|
||||
* 内容类型 N:节点 R:关系 ER:三元组
|
||||
*/
|
||||
private String contentType;
|
||||
|
||||
/**
|
||||
* 创建时间
|
||||
*/
|
||||
@TableField(fill = FieldFill.INSERT)
|
||||
private LocalDateTime createTime;
|
||||
|
||||
/**
|
||||
* 更新时间
|
||||
*/
|
||||
@TableField(fill = FieldFill.INSERT_UPDATE)
|
||||
private LocalDateTime updateTime;
|
||||
|
||||
@TableField(exist = false)
|
||||
private static final long serialVersionUID = 1L;
|
||||
}
|
@ -0,0 +1,43 @@
|
||||
package com.supervision.pdfqaserver.dto;
|
||||
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class TextTerm {
|
||||
|
||||
/**
|
||||
* 词
|
||||
*/
|
||||
public String word;
|
||||
|
||||
/**
|
||||
* 标签
|
||||
*/
|
||||
public String label;
|
||||
|
||||
private float[] embedding;
|
||||
|
||||
public String getLabelValue() {
|
||||
if (StrUtil.equalsAny(label,"n","nl","nr","ns","nsf","nz")){
|
||||
return word;
|
||||
}
|
||||
if (StrUtil.equals(label,"nt")){
|
||||
return "机构";
|
||||
}
|
||||
if (StrUtil.equalsAny(label,"ntc","公司")){
|
||||
return "公司";
|
||||
}
|
||||
if (StrUtil.equals(label,"ntcf")){
|
||||
return "工厂";
|
||||
}
|
||||
if (StrUtil.equals(label,"nto")){
|
||||
return "政府机构";
|
||||
}
|
||||
if (StrUtil.equals(label,"企业")){
|
||||
return "企业";
|
||||
}
|
||||
return null;
|
||||
|
||||
}
|
||||
}
|
@ -0,0 +1,22 @@
|
||||
package com.supervision.pdfqaserver.mapper;
|
||||
|
||||
import com.supervision.pdfqaserver.domain.NodeRelationVector;
|
||||
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* @author Administrator
|
||||
* @description 针对表【node_relation_vector(节点关系向量表)】的数据库操作Mapper
|
||||
* @createDate 2025-06-18 13:38:02
|
||||
* @Entity com.supervision.pdfqaserver.domain.NodeRelationVector
|
||||
*/
|
||||
public interface NodeRelationVectorMapper extends BaseMapper<NodeRelationVector> {
|
||||
|
||||
List<NodeRelationVector> findSimilarByCosine(float[] embedding, double threshold, List<String> contentType, int limit);
|
||||
|
||||
Double matchContentScore(float[] embedding, String content);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
@ -0,0 +1,27 @@
|
||||
package com.supervision.pdfqaserver.service;
|
||||
|
||||
import com.supervision.pdfqaserver.domain.NodeRelationVector;
|
||||
import com.baomidou.mybatisplus.extension.service.IService;
|
||||
import com.supervision.pdfqaserver.dto.CypherSchemaDTO;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* @author Administrator
|
||||
* @description 针对表【node_relation_vector(节点关系向量表)】的数据库操作Service
|
||||
* @createDate 2025-06-18 13:38:02
|
||||
*/
|
||||
public interface NodeRelationVectorService extends IService<NodeRelationVector> {
|
||||
|
||||
void refreshSchemaSegmentVector(CypherSchemaDTO cypherSchemaDTO);
|
||||
|
||||
List<NodeRelationVector> matchSimilarByCosine(float[] embedding, double threshold , List<String> contentType, int limit);
|
||||
|
||||
/**
|
||||
* 计算内容匹配分数
|
||||
* @param embedding 向量
|
||||
* @param content 内容
|
||||
* @return
|
||||
*/
|
||||
Double matchContentScore(float[] embedding, String content);
|
||||
}
|
@ -0,0 +1,17 @@
|
||||
package com.supervision.pdfqaserver.service;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* 检索器接口
|
||||
*/
|
||||
public interface Retriever {
|
||||
|
||||
/**
|
||||
* 检索数据
|
||||
* @param query 问题
|
||||
* @return 结果数据
|
||||
*/
|
||||
List<Map<String, Object>> retrieval(String query);
|
||||
}
|
@ -0,0 +1,80 @@
|
||||
package com.supervision.pdfqaserver.service;
|
||||
|
||||
import cn.hutool.core.collection.CollUtil;
|
||||
import cn.hutool.core.lang.Assert;
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import com.supervision.pdfqaserver.domain.QuestionHandlerMapping;
|
||||
import com.supervision.pdfqaserver.dto.TextVectorDTO;
|
||||
import jakarta.annotation.PostConstruct;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.ai.embedding.Embedding;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.context.ApplicationContext;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* 检索器调度器
|
||||
*/
|
||||
@Slf4j
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
public class RetrieverDispatcher {
|
||||
|
||||
private final ApplicationContext applicationContext;
|
||||
|
||||
private final AiCallService aiCallService;
|
||||
|
||||
private final TextVectorService textVectorService;
|
||||
|
||||
private final QuestionHandlerMappingService questionHandlerMappingService;
|
||||
|
||||
@Value("${retriever.threshold:0.8}")
|
||||
private double threshold; // 相似度阈值
|
||||
|
||||
private final Map<String, Retriever> retrieverMap = new HashMap<>();
|
||||
|
||||
|
||||
/**
|
||||
* 根据类型获取对应的检索器
|
||||
*
|
||||
* @param query 查询内容
|
||||
* @return 检索器实例
|
||||
*/
|
||||
public Retriever mapping(String query) {
|
||||
if (StrUtil.isEmpty(query)) {
|
||||
log.warn("查询内容为空,无法获取检索器");
|
||||
return null;
|
||||
}
|
||||
Embedding embedding = aiCallService.embedding(query);
|
||||
|
||||
List<TextVectorDTO> similarByCosine = textVectorService.findSimilarByCosine(embedding.getOutput(), threshold, 1);
|
||||
if (CollUtil.isEmpty(similarByCosine)) {
|
||||
log.info("问题:{},未找到相似文本向量,匹配阈值:{}", query, threshold);
|
||||
return null;
|
||||
}
|
||||
TextVectorDTO textVectorDTO = CollUtil.getFirst(similarByCosine);
|
||||
Assert.notEmpty(textVectorDTO.getCategoryId(), "相似文本向量的分类ID不能为空");
|
||||
QuestionHandlerMapping handler = questionHandlerMappingService.findHandlerByCategoryId(textVectorDTO.getCategoryId());
|
||||
if (handler == null){
|
||||
return null;
|
||||
}
|
||||
return retrieverMap.get(handler.getHandler());
|
||||
}
|
||||
|
||||
@PostConstruct
|
||||
public void init() {
|
||||
applicationContext.getBeansOfType(Retriever.class)
|
||||
.forEach((name, retriever) -> {
|
||||
if (retrieverMap.containsKey(name)) {
|
||||
throw new IllegalArgumentException("Retriever with name " + name + " already exists.");
|
||||
}
|
||||
retrieverMap.put(name, retriever);
|
||||
});
|
||||
}
|
||||
}
|
@ -0,0 +1,108 @@
|
||||
package com.supervision.pdfqaserver.service.impl;
|
||||
|
||||
import cn.hutool.core.collection.CollUtil;
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
|
||||
import com.supervision.pdfqaserver.domain.NodeRelationVector;
|
||||
import com.supervision.pdfqaserver.dto.*;
|
||||
import com.supervision.pdfqaserver.service.AiCallService;
|
||||
import com.supervision.pdfqaserver.service.NodeRelationVectorService;
|
||||
import com.supervision.pdfqaserver.mapper.NodeRelationVectorMapper;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.ai.embedding.Embedding;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* @author Administrator
|
||||
* @description 针对表【node_relation_vector(节点关系向量表)】的数据库操作Service实现
|
||||
* @createDate 2025-06-18 13:38:02
|
||||
*/
|
||||
@Slf4j
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
public class NodeRelationVectorServiceImpl extends ServiceImpl<NodeRelationVectorMapper, NodeRelationVector>
|
||||
implements NodeRelationVectorService{
|
||||
|
||||
private final AiCallService aiCallService;
|
||||
@Override
|
||||
@Transactional(rollbackFor = Exception.class)
|
||||
public void refreshSchemaSegmentVector(CypherSchemaDTO cypherSchemaDTO) {
|
||||
|
||||
// 删除旧的向量数据
|
||||
super.lambdaUpdate().remove();
|
||||
// 重新插入新的向量数据
|
||||
List<EntityExtractionDTO> nodes = cypherSchemaDTO.getNodes();
|
||||
List<RelationExtractionDTO> relations = cypherSchemaDTO.getRelations();
|
||||
List<NodeRelationVector> allRelationVectors = new ArrayList<>();
|
||||
List<String> texts = new ArrayList<>();
|
||||
for (List<RelationExtractionDTO> relationSplit : CollUtil.split(relations, 200)) {
|
||||
List<String> rs = relationSplit.stream().map(RelationExtractionDTO::getRelation).toList();
|
||||
List<Embedding> embedding = aiCallService.embedding(rs);
|
||||
for (Embedding embed : embedding) {
|
||||
if (texts.contains(rs.get(embed.getIndex()))){
|
||||
continue;
|
||||
}
|
||||
texts.add(rs.get(embed.getIndex()));
|
||||
NodeRelationVector vector = new NodeRelationVector();
|
||||
vector.setContent(rs.get(embed.getIndex()));
|
||||
vector.setEmbedding(embed.getOutput());
|
||||
vector.setContentType("R");// 关系
|
||||
allRelationVectors.add(vector);
|
||||
}
|
||||
List<String> ers = relationSplit.stream()
|
||||
.map(r -> StrUtil.join(" ", r.getSourceType(), r.getRelation(),r.getTargetType())).toList();
|
||||
List<Embedding> erEmbeddings = aiCallService.embedding(ers);
|
||||
for (Embedding embed : erEmbeddings) {
|
||||
if (texts.contains(ers.get(embed.getIndex()))) {
|
||||
continue;
|
||||
}
|
||||
texts.add(ers.get(embed.getIndex()));
|
||||
NodeRelationVector vector = new NodeRelationVector();
|
||||
vector.setContent(ers.get(embed.getIndex()));
|
||||
vector.setEmbedding(embed.getOutput());
|
||||
vector.setContentType("ER");
|
||||
allRelationVectors.add(vector);
|
||||
}
|
||||
}
|
||||
super.saveBatch(allRelationVectors);
|
||||
List<NodeRelationVector> allNodeVectors = new ArrayList<>();
|
||||
texts = new ArrayList<>();
|
||||
for (List<EntityExtractionDTO> entitySplit : CollUtil.split(nodes, 200)) {
|
||||
List<String> es = entitySplit.stream().map(EntityExtractionDTO::getEntity).toList();
|
||||
List<Embedding> embedding = aiCallService.embedding(es);
|
||||
for (Embedding embed : embedding) {
|
||||
if (texts.contains(es.get(embed.getIndex()))) {
|
||||
continue;
|
||||
}
|
||||
texts.add(es.get(embed.getIndex()));
|
||||
NodeRelationVector vector = new NodeRelationVector();
|
||||
vector.setContent(es.get(embed.getIndex()));
|
||||
vector.setEmbedding(embed.getOutput());
|
||||
vector.setContentType("N");
|
||||
allNodeVectors.add(vector);
|
||||
}
|
||||
}
|
||||
super.saveBatch(allNodeVectors);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<NodeRelationVector> matchSimilarByCosine(float[] embedding, double threshold, List<String> contentType, int limit) {
|
||||
return super.getBaseMapper().findSimilarByCosine(embedding, threshold, contentType, limit);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Double matchContentScore(float[] embedding, String content) {
|
||||
if (StrUtil.isEmpty(content) || embedding == null || embedding.length == 0) {
|
||||
return 0.0;
|
||||
}
|
||||
return super.getBaseMapper().matchContentScore(embedding, content);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
@ -0,0 +1,49 @@
|
||||
package com.supervision.pdfqaserver.service.impl;
|
||||
|
||||
import cn.hutool.core.collection.CollUtil;
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import com.hankcs.hanlp.HanLP;
|
||||
import com.hankcs.hanlp.dictionary.CustomDictionary;
|
||||
import com.hankcs.hanlp.seg.Segment;
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.supervision.pdfqaserver.dto.TextTerm;
|
||||
import com.supervision.pdfqaserver.service.TextToSegmentService;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
public class TextToSegmentServiceImpl implements TextToSegmentService {
|
||||
@Override
|
||||
public List<TextTerm> segmentText(String text) {
|
||||
if (StrUtil.isEmpty(text)){
|
||||
return new ArrayList<>();
|
||||
}
|
||||
Segment segment = HanLP.newSegment()
|
||||
.enableOrganizationRecognize(true)
|
||||
.enablePlaceRecognize(true)
|
||||
.enableNumberQuantifierRecognize(true);
|
||||
|
||||
List<Term> seg = segment.seg(text);
|
||||
if (CollUtil.isEmpty(seg)){
|
||||
return new ArrayList<>();
|
||||
}
|
||||
List<TextTerm> terms = new ArrayList<>();
|
||||
for (Term term : seg) {
|
||||
TextTerm textTerm = new TextTerm();
|
||||
textTerm.setWord(term.word);
|
||||
textTerm.setLabel(term.nature.toString());
|
||||
terms.add(textTerm);
|
||||
}
|
||||
return terms;
|
||||
}
|
||||
@Override
|
||||
public void addDict(String word, String label,int frequency) {
|
||||
CustomDictionary.insert(word, label + " " + frequency);
|
||||
}
|
||||
}
|
@ -0,0 +1,57 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE mapper
|
||||
PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN"
|
||||
"http://mybatis.org/dtd/mybatis-3-mapper.dtd">
|
||||
<mapper namespace="com.supervision.pdfqaserver.mapper.NodeRelationVectorMapper">
|
||||
|
||||
<resultMap id="BaseResultMap" type="com.supervision.pdfqaserver.domain.NodeRelationVector">
|
||||
<id property="id" column="id" jdbcType="VARCHAR"/>
|
||||
<result property="content" column="content" jdbcType="VARCHAR"/>
|
||||
<result property="embedding" column="embedding" jdbcType="OTHER" typeHandler="com.supervision.pdfqaserver.config.VectorTypeHandler"/>
|
||||
<result property="contentType" column="content_type" jdbcType="VARCHAR"/>
|
||||
<result property="createTime" column="create_time" jdbcType="TIMESTAMP"/>
|
||||
<result property="updateTime" column="update_time" jdbcType="TIMESTAMP"/>
|
||||
</resultMap>
|
||||
|
||||
<sql id="Base_Column_List">
|
||||
id,content,embedding,
|
||||
content_type,create_time,update_time
|
||||
</sql>
|
||||
|
||||
<select id="findSimilarByCosine" resultType="com.supervision.pdfqaserver.domain.NodeRelationVector">
|
||||
SELECT * FROM (
|
||||
SELECT
|
||||
id,
|
||||
content,
|
||||
embedding,
|
||||
content_type,
|
||||
1 - (embedding <![CDATA[<=>]]> #{embedding, typeHandler=com.supervision.pdfqaserver.config.VectorTypeHandler}) AS similarityScore
|
||||
FROM node_relation_vector
|
||||
) t
|
||||
WHERE t.similarityScore > #{threshold}
|
||||
<if test="contentType != null and contentType.size() > 0">
|
||||
AND content_type IN
|
||||
<foreach item="item" collection="contentType" open="(" separator="," close=")">
|
||||
#{item}
|
||||
</foreach>
|
||||
</if>
|
||||
ORDER BY t.similarityScore DESC
|
||||
LIMIT #{limit}
|
||||
</select>
|
||||
<select id="matchContentScore" resultType="java.lang.Double">
|
||||
SELECT
|
||||
CASE
|
||||
WHEN #{embedding} IS NULL THEN 0
|
||||
WHEN #{content} IS NULL THEN 0
|
||||
ELSE COALESCE(
|
||||
1 - (embedding <![CDATA[<=>]]>
|
||||
#{embedding, typeHandler=com.supervision.pdfqaserver.config.VectorTypeHandler}),
|
||||
0
|
||||
)
|
||||
END AS similarityScore
|
||||
FROM node_relation_vector
|
||||
WHERE content = #{content}
|
||||
LIMIT 1
|
||||
</select>
|
||||
|
||||
</mapper>
|
Loading…
Reference in New Issue