package com.supervision.pdfqaserver.service.impl;

import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.BooleanUtil;
import cn.hutool.core.util.RandomUtil;
import cn.hutool.core.util.StrUtil;
import com.supervision.pdfqaserver.cache.PromptCache;
import com.supervision.pdfqaserver.constant.LayoutTypeEnum;
import com.supervision.pdfqaserver.dto.*;
import com.supervision.pdfqaserver.service.TripleConversionPipeline;
import edu.stanford.nlp.pipeline.CoreDocument;
import edu.stanford.nlp.pipeline.CoreSentence;
import edu.stanford.nlp.pipeline.StanfordCoreNLP;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.stereotype.Service;
import java.util.*;
import java.util.stream.Collectors;

@Slf4j
@Service
@RequiredArgsConstructor
public class TripleConversionPipelineImpl implements TripleConversionPipeline {
private final OllamaChatModel ollamaChatModel;
@Override
public List<TruncateDTO> sliceDocuments(List<DocumentDTO> documents) {
// 对pdfAnalysisOutputs进行排序
List<DocumentDTO> documentDTOList = documents.stream().sorted(
// 先对pageNo进行排序再对layoutOrder进行排序
(o1, o2) -> {
if (o1.getPageNo().equals(o2.getPageNo())) {
return Integer.compare(o1.getDisplayOrder(), o2.getDisplayOrder());
2 months ago
}
return Integer.compare(o1.getPageNo(), o2.getPageNo());
}
).toList();
2 months ago
Properties props = new Properties();
props.setProperty("annotators", "tokenize, ssplit");
// 创建管道
StanfordCoreNLP pipeline = new StanfordCoreNLP(props);
List<TruncateDTO> truncateDTOS = new ArrayList<>();
for (DocumentDTO documentDTO : documentDTOList) {
String content = documentDTO.getContent();
if (StrUtil.isEmpty(content)){
continue;
}
Integer layoutType = documentDTO.getLayoutType();
if (LayoutTypeEnum.TEXT.getCode() == layoutType){
// 如果是文本类型的布局,进行合并
CoreDocument document = new CoreDocument(content);
// 分析文本
pipeline.annotate(document);
// 获取句子
for (CoreSentence sentence : document.sentences()) {
TruncateDTO truncateDTO = new TruncateDTO(documentDTO);
truncateDTO.setContent(sentence.text());
truncateDTOS.add(truncateDTO);
}
} else if (LayoutTypeEnum.TABLE.getCode() == layoutType) {
// 如果是表格类型的布局,进行切分
// 提前抽取表名
TableTitleDTO tableTitleDTO = this.extractTableTitle(documentDTO.getTitle());
if (null != tableTitleDTO && StrUtil.isNotEmpty(tableTitleDTO.getTitle())){
documentDTO.setTitle(tableTitleDTO.getTitle());
}else {
// 生成一个默认的表
documentDTO.setTitle("tableName-"+ RandomUtil.randomString(10));
}
List<String> tableRows = StrUtil.split(documentDTO.getContent(), "\n").stream().filter(StrUtil::isNotEmpty).collect(Collectors.toList());
if (tableRows.size()<5){
TruncateDTO truncateDTO = new TruncateDTO(documentDTO);
truncateDTOS.add(truncateDTO);
continue;
}
String tableTitle = tableRows.get(0);
// 标题分割符
String tableTitleSplit = tableRows.get(1);
List<String> noTitleRows = tableRows.subList(2,tableRows.size()-1);
List<List<String>> rows = CollUtil.split(noTitleRows, 4);
for (List<String> row : rows) {
StringBuilder sb = new StringBuilder();
sb.append(tableTitle).append("\n");
sb.append(tableTitleSplit).append("\n");
for (String s : row) {
sb.append(s).append("\n");
}
TruncateDTO truncateDTO = new TruncateDTO(documentDTO);
truncateDTO.setContent(sb.toString());
truncateDTOS.add(truncateDTO);
}
2 months ago
} else {
log.info("sliceDocuments:错误的布局类型: {}", layoutType);
}
}
return truncateDTOS;
2 months ago
}
@Override
public EREDTO doEre(TruncateDTO truncateDTO) {
if (StrUtil.equals(truncateDTO.getLayoutType(),String.valueOf(LayoutTypeEnum.TEXT.getCode()))){
return doTextEre(truncateDTO);
}
2 months ago
if (StrUtil.equals(truncateDTO.getLayoutType(),String.valueOf(LayoutTypeEnum.TABLE.getCode()))){
// 先分析表格是否是描述类型
Boolean classify = this.classify(truncateDTO.getContent());
if (null == classify){
log.info("doEre:表格分类结果为空,切分文档id:{}", truncateDTO.getId());
return null;
}
if (classify){
return doTextEre(truncateDTO);
}
return doTableEre(truncateDTO);
2 months ago
}
log.warn("doEre:错误的布局类型: {}", truncateDTO.getLayoutType());
return null;
}
2 months ago
@Override
public Boolean classify(String content) {
Assert.notEmpty(content, "内容不能为空");
// 对表格内容进行精简,只获取与前四行相关的内容
String[] lines = content.split("\n");
if (lines.length > 5){
StringBuilder sb = new StringBuilder();
for (int i = 0; i < 5; i++) {
sb.append(lines[i]).append("\n");
}
content = sb.toString();
2 months ago
}
log.info("classify:开始进行实体关系分类,内容:{}", content);
String prompt = PromptCache.promptMap.get(PromptCache.CLASSIFY_TABLE);
String format = StrUtil.format(prompt, content);
String response = ollamaChatModel.call(format);
log.info("classify响应结果:{}", response);
return BooleanUtil.toBooleanObject(response);
}
@Override
public TableTitleDTO extractTableTitle(String content) {
TableTitleDTO tableTitleDTO = new TableTitleDTO();
if (StrUtil.isEmpty(content)){
log.warn("extractTableTitle:内容为空");
return tableTitleDTO;
}
String table = PromptCache.promptMap.get(PromptCache.EXTRACT_TABLE_TITLE);
String format = StrUtil.format(table, content);
String response = ollamaChatModel.call(format);
tableTitleDTO.setTitle(response);
return tableTitleDTO;
2 months ago
}
private EREDTO doTextEre(TruncateDTO truncateDTO) {
log.info("doTextEre:开始进行文本实体关系抽取,内容:{}", truncateDTO.getContent());
2 months ago
String prompt = PromptCache.promptMap.get(PromptCache.DOERE_TEXT);
String formatted = StrUtil.format(prompt, truncateDTO.getContent());
2 months ago
String response = ollamaChatModel.call(formatted);
log.info("doTextEre响应结果:{}", response);
2 months ago
return EREDTO.fromTextJson(response, truncateDTO.getId());
}
private EREDTO doTableEre(TruncateDTO truncateDTO) {
log.info("doTableEre:开始进行表格实体关系抽取,内容:{}", truncateDTO.getContent());
2 months ago
String prompt = PromptCache.promptMap.get(PromptCache.DOERE_TABLE);
String formatted = StrUtil.format(prompt, truncateDTO.getContent());
2 months ago
String response = ollamaChatModel.call(formatted);
log.info("doTableEre响应结果:{}", response);
EREDTO eredto = EREDTO.fromTableJson(response, truncateDTO.getId());
// 手动设置表格标题
EntityExtractionDTO titleEntity = new EntityExtractionDTO();
titleEntity.setEntity("表");
titleEntity.setTruncationId(truncateDTO.getId());
titleEntity.setName(truncateDTO.getTitle());
// 添加关系
List<RelationExtractionDTO> relations = new ArrayList<>();
for (EntityExtractionDTO entity : eredto.getEntities()) {
RelationExtractionDTO relationExtractionDTO = new RelationExtractionDTO(truncateDTO.getId(),
titleEntity.getName(), titleEntity.getEntity(), "包含", entity.getName(), entity.getEntity(), entity.getAttributes());
relations.add(relationExtractionDTO);
}
eredto.getEntities().add(titleEntity);
eredto.setRelations(relations);
return eredto;
2 months ago
}
/**
*
*
2 months ago
* @param eredtoList
* @return
*/
@Override
public List<EREDTO> mergeEreResults(List<EREDTO> eredtoList) {
List<EREDTO> merged = new ArrayList<>();
if (CollUtil.isEmpty(eredtoList)){
return merged;
}
// 将表单独拿出来
merged = eredtoList.stream().filter(ere->
ere.getEntities().stream().anyMatch(e->StrUtil.equals(e.getEntity(),"表"))).collect(Collectors.toList());
// 把剩下的数据进行合并计算
eredtoList = eredtoList.stream().filter(ere->
ere.getEntities().stream().noneMatch(e->StrUtil.equals(e.getEntity(),"表"))).collect(Collectors.toList());
2 months ago
Map<String, EntityExtractionDTO> entityMap = new HashMap<>();
Map<String, RelationExtractionDTO> relationMap = new HashMap<>();
2 months ago
for (EREDTO eredto : eredtoList) {
List<EntityExtractionDTO> entities = eredto.getEntities();
if (CollUtil.isNotEmpty(entities)){
for (EntityExtractionDTO entity : entities) {
2 months ago
String key = generateEntityMapKey(entity);
mergeAttribute(entityMap,entity, key);
2 months ago
}
}
List<RelationExtractionDTO> relations = eredto.getRelations();
if (CollUtil.isNotEmpty(relations)){
for (RelationExtractionDTO relation : relations) {
// source和target,re完全相等看作是同一个数据
2 months ago
String relationMapKey = generateRelationMapKey(relation);
mergeAttribute(relationMap,relation, relationMapKey);
2 months ago
}
}
}
2 months ago
// 利用合并后的map生成新的EREDTO
// 优先先把有关系的节点与关系组合在一次
Set<String> relationEntityKey = new HashSet<>();
for (Map.Entry<String, RelationExtractionDTO> relationEntry : relationMap.entrySet()) {
RelationExtractionDTO value = relationEntry.getValue();
EntityExtractionDTO sourceEntity = entityMap.get(StrUtil.join("_",value.getSourceType(), value.getSource()));
2 months ago
if (null == sourceEntity){
log.warn("mergeEreResults:根据entity:{},name:{}未在entityMap中找到头节点映射关系", value.getSourceType(), value.getSource());
continue;
}
EntityExtractionDTO targetEntity = entityMap.get(StrUtil.join("_", value.getTargetType(),value.getTarget()));
2 months ago
if (null == targetEntity){
log.warn("mergeEreResults:根据entity:{},name:{}未在entityMap中找到尾节点映射关系", value.getTargetType(), value.getTarget());
continue;
}
EREDTO eredto = new EREDTO();
eredto.setEntities(List.of(sourceEntity,targetEntity));
eredto.setRelations(List.of(value));
merged.add(eredto);
relationEntityKey.addAll(List.of(generateEntityMapKey(sourceEntity),generateEntityMapKey(targetEntity)));
}
// 将没有关系的节点单独放在一起
List<EntityExtractionDTO> leavedEntities = new ArrayList<>();
for (Map.Entry<String, EntityExtractionDTO> entry : entityMap.entrySet()) {
if (!relationEntityKey.contains(entry.getKey())){
leavedEntities.add(entry.getValue());
}
}
if (CollUtil.isNotEmpty(leavedEntities)){
EREDTO eredto = new EREDTO();
eredto.setEntities(leavedEntities);
merged.add(eredto);
}
2 months ago
return merged;
}
2 months ago
2 months ago
private void mergeAttribute(Map<String, RelationExtractionDTO> entityMap,RelationExtractionDTO relation, String key) {
RelationExtractionDTO cachedRelation = entityMap.get(key);
if (null == cachedRelation){
2 months ago
entityMap.put(key, relation);
}else {
if (CollUtil.isEmpty(relation.getAttributes())){
return;
}
// 合并属性
List<ERAttributeDTO> cachedAttributes = cachedRelation.getAttributes();
if (null == cachedAttributes){
cachedAttributes = new ArrayList<>();
2 months ago
}
for (ERAttributeDTO attribute : relation.getAttributes()) {
String attributeKey = attribute.getAttribute();
String attributeValue = attribute.getValue();
if (StrUtil.isEmpty(attributeKey) || StrUtil.isEmpty(attributeValue)){
continue;
}
// 如果属性已经存在,则不添加
if (cachedAttributes.stream().noneMatch(a -> StrUtil.equals(a.getAttribute(), attributeKey))) {
cachedAttributes.add(attribute);
2 months ago
}
}
}
}
private void mergeAttribute(Map<String, EntityExtractionDTO> entityMap,EntityExtractionDTO entity, String key) {
EntityExtractionDTO cachedEntity = entityMap.get(key);
if (null == cachedEntity){
entityMap.put(key, entity);
}else {
if (CollUtil.isEmpty(entity.getAttributes())){
return;
}
// 合并属性
List<ERAttributeDTO> cachedAttributes = cachedEntity.getAttributes();
if (null == cachedAttributes){
cachedAttributes = new ArrayList<>();
cachedEntity.setAttributes(cachedAttributes);
2 months ago
}
for (ERAttributeDTO attribute : entity.getAttributes()) {
String attributeKey = attribute.getAttribute();
String attributeValue = attribute.getValue();
if (StrUtil.isEmpty(attributeKey) || StrUtil.isEmpty(attributeValue)){
continue;
}
// 如果属性已经存在,则不添加
if (cachedAttributes.stream().noneMatch(a -> StrUtil.equals(a.getAttribute(), attributeKey))) {
cachedAttributes.add(attribute);
2 months ago
}
}
}
}
/**
 * Builds the de-duplication key of an entity: its entity type and its name joined by {@code "_"}.
 *
 * @param entityExtractionDTO the entity to build the key for
 * @return the key in the form {@code <entity>_<name>}
 */
private String generateEntityMapKey(EntityExtractionDTO entityExtractionDTO) {
    StringBuilder key = new StringBuilder();
    key.append(entityExtractionDTO.getEntity());
    key.append("_");
    key.append(entityExtractionDTO.getName());
    return key.toString();
}
private String generateRelationMapKey(RelationExtractionDTO relationExtractionDTO) {
return relationExtractionDTO.getSource()+ "_" + relationExtractionDTO.getRelation() + "_" + relationExtractionDTO.getTarget();
2 months ago
}
}