You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

402 lines
18 KiB
Java

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package com.supervision.pdfqaserver.service.impl;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.BooleanUtil;
import cn.hutool.core.util.RandomUtil;
import cn.hutool.core.util.StrUtil;
import com.supervision.pdfqaserver.cache.PromptCache;
import com.supervision.pdfqaserver.constant.DocumentContentTypeEnum;
import com.supervision.pdfqaserver.constant.LayoutTypeEnum;
import com.supervision.pdfqaserver.dto.*;
import com.supervision.pdfqaserver.service.TripleConversionPipeline;
import edu.stanford.nlp.pipeline.CoreDocument;
import edu.stanford.nlp.pipeline.CoreSentence;
import edu.stanford.nlp.pipeline.StanfordCoreNLP;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.stereotype.Service;
import java.util.*;
import java.util.stream.Collectors;
/**
 * {@link TripleConversionPipeline} implementation that slices layout-analysed documents into
 * chunks ({@link TruncateDTO}) and runs entity/relation extraction (ERE) on those chunks through
 * an {@link OllamaChatModel}, using prompt templates looked up in {@link PromptCache}.
 */
@Slf4j
@Service
@RequiredArgsConstructor
public class TripleConversionPipelineImpl implements TripleConversionPipeline {

    /** Entity type of the synthetic node that {@code doTableEre} adds to represent the table itself. */
    private static final String TABLE_ENTITY_TYPE = "表";

    /** A single sentence at least this long is emitted as a chunk of its own. */
    private static final int MAX_TEXT_LENGTH = 1000;

    /** Accumulated sentences are flushed as a chunk once the buffer reaches this length. */
    private static final int MIN_TEXT_LENGTH = 800;

    /** Initial capacity of the sentence buffer. */
    private static final int INITIAL_BUFFER_SIZE = 1500;

    /** Tables with fewer non-empty lines than this are kept as a single chunk. */
    private static final int MIN_SPLITTABLE_TABLE_LINES = 5;

    /** Number of data rows per table chunk (header and separator lines are repeated in each chunk). */
    private static final int TABLE_ROWS_PER_SLICE = 4;

    /** Number of leading table lines sent to the classification prompt. */
    private static final int CLASSIFY_PREVIEW_LINES = 5;

    private final OllamaChatModel ollamaChatModel;

    /** Not implemented yet; always returns {@code null}. */
    @Override
    public DocumentContentTypeEnum makeOutPdfContentType(Integer pdfId) {
        return null;
    }

    /** Not implemented yet; always returns {@code null}. */
    @Override
    public String makeOutPdfIndustry(Integer pdfId) {
        return null;
    }

    /** Not implemented yet; always returns {@code null}. */
    @Override
    public List<String> makeOutTruncationIntent(TruncateDTO truncate) {
        return null;
    }

    /** Not implemented yet; always returns {@code null}. */
    @Override
    public List<IntentDTO> makeOutTruncationIntent(TruncateDTO truncate, List<IntentDTO> intents) {
        return null;
    }

    /** Not implemented yet; always returns {@code null}. */
    @Override
    public List<DomainMetadataDTO> makeOutDomainMetadata(TruncateDTO truncate, List<String> intents) {
        return null;
    }

    /** Not implemented yet; always returns {@code null}. */
    @Override
    public EREDTO doEre(TruncateDTO truncateDTO, List<IntentDTO> intents) {
        return null;
    }

    /**
     * Slices documents into chunks for entity/relation extraction.
     * <p>
     * Documents are processed sorted by page number, then by display order.
     * <ul>
     *   <li>Text layouts: the content is split into sentences with CoreNLP and sentences are
     *   concatenated until the chunk reaches {@value #MIN_TEXT_LENGTH} characters. A sentence of
     *   {@value #MAX_TEXT_LENGTH} or more characters becomes a chunk of its own. Chunks never span
     *   two documents, and text stays in its original order.</li>
     *   <li>Table layouts: the title is resolved via {@link #extractTableTitle} (or a random
     *   {@code tableName-xxxx} fallback is set on the document). Tables with at least
     *   {@value #MIN_SPLITTABLE_TABLE_LINES} non-empty lines are split into chunks of
     *   {@value #TABLE_ROWS_PER_SLICE} data rows, each prefixed with the header and separator line;
     *   smaller tables become a single chunk.</li>
     * </ul>
     * Documents with empty content are skipped; unknown layout types are logged and skipped.
     *
     * @param documents documents to slice; {@code null} or empty yields an empty list
     * @return the produced chunks
     */
    @Override
    public List<TruncateDTO> sliceDocuments(List<DocumentDTO> documents) {
        List<TruncateDTO> truncateDTOS = new ArrayList<>();
        if (CollUtil.isEmpty(documents)) {
            return truncateDTOS;
        }
        List<DocumentDTO> sortedDocuments = documents.stream()
                .sorted(Comparator.comparing(DocumentDTO::getPageNo)
                        .thenComparing(DocumentDTO::getDisplayOrder))
                .toList();

        Properties props = new Properties();
        props.setProperty("annotators", "tokenize, ssplit");
        StanfordCoreNLP pipeline = new StanfordCoreNLP(props);

        for (DocumentDTO documentDTO : sortedDocuments) {
            if (StrUtil.isEmpty(documentDTO.getContent())) {
                continue;
            }
            Integer layoutType = documentDTO.getLayoutType();
            // Objects.equals: no NPE on a null layout type and no reference comparison of boxed Integers.
            if (Objects.equals(LayoutTypeEnum.TEXT.getCode(), layoutType)) {
                sliceText(pipeline, documentDTO, truncateDTOS);
            } else if (Objects.equals(LayoutTypeEnum.TABLE.getCode(), layoutType)) {
                sliceTable(documentDTO, truncateDTOS);
            } else {
                log.info("sliceDocuments:错误的布局类型: {}", layoutType);
            }
        }
        return truncateDTOS;
    }

    /** Splits one text document into sentence-based chunks and appends them to {@code out}. */
    private void sliceText(StanfordCoreNLP pipeline, DocumentDTO documentDTO, List<TruncateDTO> out) {
        CoreDocument document = new CoreDocument(documentDTO.getContent());
        pipeline.annotate(document);
        // The buffer is local to this document so leftover text is never carried into the next one.
        StringBuilder buffer = new StringBuilder(INITIAL_BUFFER_SIZE);
        for (CoreSentence sentence : document.sentences()) {
            String text = sentence.text();
            if (StrUtil.isEmpty(text)) {
                continue;
            }
            if (text.length() >= MAX_TEXT_LENGTH) {
                // Flush whatever precedes the long sentence first, so chunks stay contiguous and ordered.
                if (!buffer.isEmpty()) {
                    out.add(new TruncateDTO(documentDTO, buffer.toString()));
                    buffer.setLength(0);
                }
                out.add(new TruncateDTO(documentDTO, text));
                continue;
            }
            buffer.append(text);
            if (buffer.length() >= MIN_TEXT_LENGTH) {
                out.add(new TruncateDTO(documentDTO, buffer.toString()));
                buffer.setLength(0);
            }
        }
        if (!buffer.isEmpty()) {
            out.add(new TruncateDTO(documentDTO, buffer.toString()));
        }
    }

    /**
     * Resolves the table title (mutating {@code documentDTO}'s title) and splits the table into
     * row-based chunks that each repeat the first two lines (header and separator).
     */
    private void sliceTable(DocumentDTO documentDTO, List<TruncateDTO> out) {
        TableTitleDTO tableTitleDTO = this.extractTableTitle(documentDTO.getTitle());
        if (tableTitleDTO != null && StrUtil.isNotEmpty(tableTitleDTO.getTitle())) {
            documentDTO.setTitle(tableTitleDTO.getTitle());
        } else {
            // Fall back to a generated table name.
            documentDTO.setTitle("tableName-" + RandomUtil.randomString(10));
        }
        List<String> tableRows = StrUtil.split(documentDTO.getContent(), "\n").stream()
                .filter(StrUtil::isNotEmpty)
                .toList();
        if (tableRows.size() < MIN_SPLITTABLE_TABLE_LINES) {
            out.add(new TruncateDTO(documentDTO));
            return;
        }
        String headerRow = tableRows.get(0);
        String separatorRow = tableRows.get(1);
        // All data rows, including the last one.
        List<String> dataRows = tableRows.subList(2, tableRows.size());
        for (List<String> rows : CollUtil.split(dataRows, TABLE_ROWS_PER_SLICE)) {
            StringBuilder sb = new StringBuilder();
            sb.append(headerRow).append("\n");
            sb.append(separatorRow).append("\n");
            for (String row : rows) {
                sb.append(row).append("\n");
            }
            TruncateDTO truncateDTO = new TruncateDTO(documentDTO);
            truncateDTO.setContent(sb.toString());
            out.add(truncateDTO);
        }
    }

    /**
     * Runs entity/relation extraction on a chunk, choosing the strategy by its layout type.
     * Text chunks use the text prompt. Table chunks are first classified: descriptive tables use
     * the text prompt, others the table prompt.
     *
     * @param truncateDTO the chunk to process
     * @return the extraction result, or {@code null} when the table classification is unrecognised
     *         or the layout type is unknown
     */
    @Override
    public EREDTO doEre(TruncateDTO truncateDTO) {
        if (StrUtil.equals(truncateDTO.getLayoutType(), String.valueOf(LayoutTypeEnum.TEXT.getCode()))) {
            return doTextEre(truncateDTO);
        }
        if (StrUtil.equals(truncateDTO.getLayoutType(), String.valueOf(LayoutTypeEnum.TABLE.getCode()))) {
            // First decide whether the table is descriptive (then it is handled like text).
            Boolean descriptive = this.classify(truncateDTO.getContent());
            if (descriptive == null) {
                log.info("doEre:表格分类结果为空,切分文档id:{}", truncateDTO.getId());
                return null;
            }
            return descriptive ? doTextEre(truncateDTO) : doTableEre(truncateDTO);
        }
        log.warn("doEre:错误的布局类型: {}", truncateDTO.getLayoutType());
        return null;
    }

    /**
     * Asks the model whether the given table content is descriptive. Only the first
     * {@value #CLASSIFY_PREVIEW_LINES} lines are sent, to keep the prompt small.
     *
     * @param content table content; must not be empty
     * @return the model answer parsed with {@link BooleanUtil#toBooleanObject(String)}, or
     *         {@code null} if the answer is not recognisable as a boolean
     * @throws IllegalArgumentException if {@code content} is empty
     */
    @Override
    public Boolean classify(String content) {
        Assert.notEmpty(content, "内容不能为空");
        String[] lines = content.split("\n");
        if (lines.length > CLASSIFY_PREVIEW_LINES) {
            StringBuilder sb = new StringBuilder();
            for (int i = 0; i < CLASSIFY_PREVIEW_LINES; i++) {
                sb.append(lines[i]).append("\n");
            }
            content = sb.toString();
        }
        log.info("classify:开始进行实体关系分类,内容:{}", content);
        String prompt = PromptCache.promptMap.get(PromptCache.CLASSIFY_TABLE);
        String format = StrUtil.format(prompt, content);
        String response = ollamaChatModel.call(format);
        log.info("classify响应结果:{}", response);
        return BooleanUtil.toBooleanObject(response);
    }

    /**
     * Asks the model to extract a table title from the given text.
     *
     * @param content text to extract the title from
     * @return a DTO whose title is the raw model response; an empty DTO (no title) when
     *         {@code content} is empty. Never {@code null}.
     */
    @Override
    public TableTitleDTO extractTableTitle(String content) {
        TableTitleDTO tableTitleDTO = new TableTitleDTO();
        if (StrUtil.isEmpty(content)) {
            log.warn("extractTableTitle:内容为空");
            return tableTitleDTO;
        }
        String table = PromptCache.promptMap.get(PromptCache.EXTRACT_TABLE_TITLE);
        String format = StrUtil.format(table, content);
        String response = ollamaChatModel.call(format);
        tableTitleDTO.setTitle(response);
        return tableTitleDTO;
    }

    /** Extracts entities/relations from a chunk using the text prompt. */
    private EREDTO doTextEre(TruncateDTO truncateDTO) {
        log.info("doTextEre:开始进行文本实体关系抽取,内容:{}", truncateDTO.getContent());
        String prompt = PromptCache.promptMap.get(PromptCache.DOERE_TEXT);
        String formatted = StrUtil.format(prompt, truncateDTO.getContent());
        String response = ollamaChatModel.call(formatted);
        log.info("doTextEre响应结果:{}", response);
        return EREDTO.fromTextJson(response, truncateDTO.getId());
    }

    /**
     * Extracts entities from a table chunk using the table prompt, then adds a synthetic
     * {@value #TABLE_ENTITY_TYPE} entity named after the chunk title. The result's relations are
     * replaced with one "包含" (contains) relation from that table entity to every extracted entity.
     *
     * @return the result, or {@code null} if the response could not be parsed into a result
     */
    private EREDTO doTableEre(TruncateDTO truncateDTO) {
        log.info("doTableEre:开始进行表格实体关系抽取,内容:{}", truncateDTO.getContent());
        String prompt = PromptCache.promptMap.get(PromptCache.DOERE_TABLE);
        String formatted = StrUtil.format(prompt, truncateDTO.getContent());
        String response = ollamaChatModel.call(formatted);
        log.info("doTableEre响应结果:{}", response);
        EREDTO eredto = EREDTO.fromTableJson(response, truncateDTO.getId());
        if (eredto == null) {
            log.warn("doTableEre:表格实体关系抽取结果为空,切分文档id:{}", truncateDTO.getId());
            return null;
        }
        // Copy into a mutable list; the parsed list may be null or unmodifiable (depends on fromTableJson).
        List<EntityExtractionDTO> entities = new ArrayList<>();
        if (eredto.getEntities() != null) {
            entities.addAll(eredto.getEntities());
        }
        EntityExtractionDTO titleEntity = new EntityExtractionDTO();
        titleEntity.setEntity(TABLE_ENTITY_TYPE);
        titleEntity.setTruncationId(truncateDTO.getId());
        titleEntity.setName(truncateDTO.getTitle());
        // Each relation carries the target entity's attribute list (same list instance).
        List<RelationExtractionDTO> relations = new ArrayList<>(entities.size());
        for (EntityExtractionDTO entity : entities) {
            relations.add(new RelationExtractionDTO(truncateDTO.getId(),
                    titleEntity.getName(), titleEntity.getEntity(), "包含",
                    entity.getName(), entity.getEntity(), entity.getAttributes()));
        }
        entities.add(titleEntity);
        eredto.setEntities(entities);
        eredto.setRelations(relations);
        return eredto;
    }

    /**
     * Merges extraction results by collapsing duplicate entities and relations and merging
     * their attributes.
     * <p>
     * Entities are identical when type and name match. Relations are identical when source
     * (type and name), relation name and target (type and name) match. On merge, the
     * first-seen object is kept. Attributes from later duplicates are added only if their name
     * is not present yet; attributes with an empty name or value are ignored. Results that contain
     * a {@value #TABLE_ENTITY_TYPE} entity are passed through unmerged; {@code null} results are
     * ignored.
     *
     * @param eredtoList extraction results to merge
     * @return table results first, then one result per merged relation (with its two endpoint
     *         entities), then one result holding all entities not used by any emitted relation
     */
    @Override
    public List<EREDTO> mergeEreResults(List<EREDTO> eredtoList) {
        List<EREDTO> merged = new ArrayList<>();
        if (CollUtil.isEmpty(eredtoList)) {
            return merged;
        }
        Map<Boolean, List<EREDTO>> byTable = eredtoList.stream()
                .filter(Objects::nonNull)
                .collect(Collectors.partitioningBy(this::containsTableEntity));
        // Tables do not take part in merging.
        merged.addAll(byTable.get(true));

        // LinkedHashMap keeps the output order deterministic (first occurrence order).
        Map<String, EntityExtractionDTO> entityMap = new LinkedHashMap<>();
        Map<String, RelationExtractionDTO> relationMap = new LinkedHashMap<>();
        for (EREDTO eredto : byTable.get(false)) {
            List<EntityExtractionDTO> entities = eredto.getEntities();
            if (CollUtil.isNotEmpty(entities)) {
                for (EntityExtractionDTO entity : entities) {
                    mergeAttribute(entityMap, entity, generateEntityMapKey(entity));
                }
            }
            List<RelationExtractionDTO> relations = eredto.getRelations();
            if (CollUtil.isNotEmpty(relations)) {
                for (RelationExtractionDTO relation : relations) {
                    mergeAttribute(relationMap, relation, generateRelationMapKey(relation));
                }
            }
        }

        // Group each relation with its two endpoint entities first.
        Set<String> relatedEntityKeys = new HashSet<>();
        for (RelationExtractionDTO relation : relationMap.values()) {
            EntityExtractionDTO sourceEntity = entityMap.get(entityKey(relation.getSourceType(), relation.getSource()));
            if (sourceEntity == null) {
                log.warn("mergeEreResults:根据entity:{},name:{}未在entityMap中找到头节点映射关系", relation.getSourceType(), relation.getSource());
                continue;
            }
            EntityExtractionDTO targetEntity = entityMap.get(entityKey(relation.getTargetType(), relation.getTarget()));
            if (targetEntity == null) {
                log.warn("mergeEreResults:根据entity:{},name:{}未在entityMap中找到尾节点映射关系", relation.getTargetType(), relation.getTarget());
                continue;
            }
            EREDTO eredto = new EREDTO();
            eredto.setEntities(List.of(sourceEntity, targetEntity));
            eredto.setRelations(List.of(relation));
            merged.add(eredto);
            relatedEntityKeys.add(generateEntityMapKey(sourceEntity));
            relatedEntityKeys.add(generateEntityMapKey(targetEntity));
        }

        // Entities not used by any emitted relation go into one trailing result.
        List<EntityExtractionDTO> leavedEntities = new ArrayList<>();
        for (Map.Entry<String, EntityExtractionDTO> entry : entityMap.entrySet()) {
            if (!relatedEntityKeys.contains(entry.getKey())) {
                leavedEntities.add(entry.getValue());
            }
        }
        if (CollUtil.isNotEmpty(leavedEntities)) {
            EREDTO eredto = new EREDTO();
            eredto.setEntities(leavedEntities);
            merged.add(eredto);
        }
        return merged;
    }

    /** Returns whether the result contains at least one {@value #TABLE_ENTITY_TYPE} entity; null-safe on the entity list. */
    private boolean containsTableEntity(EREDTO eredto) {
        List<EntityExtractionDTO> entities = eredto.getEntities();
        return entities != null
                && entities.stream().anyMatch(e -> StrUtil.equals(e.getEntity(), TABLE_ENTITY_TYPE));
    }

    /**
     * Stores {@code relation} under {@code key}, or merges its attributes into the relation
     * already stored there.
     */
    private void mergeAttribute(Map<String, RelationExtractionDTO> relationMap, RelationExtractionDTO relation, String key) {
        RelationExtractionDTO cachedRelation = relationMap.putIfAbsent(key, relation);
        if (cachedRelation == null || CollUtil.isEmpty(relation.getAttributes())) {
            return;
        }
        List<TruncationERAttributeDTO> cachedAttributes = cachedRelation.getAttributes();
        if (cachedAttributes == null) {
            cachedAttributes = new ArrayList<>();
            // Attach the new list; otherwise the merged attributes would be lost.
            cachedRelation.setAttributes(cachedAttributes);
        }
        mergeAttributeList(cachedAttributes, relation.getAttributes());
    }

    /**
     * Stores {@code entity} under {@code key}, or merges its attributes into the entity already
     * stored there.
     */
    private void mergeAttribute(Map<String, EntityExtractionDTO> entityMap, EntityExtractionDTO entity, String key) {
        EntityExtractionDTO cachedEntity = entityMap.putIfAbsent(key, entity);
        if (cachedEntity == null || CollUtil.isEmpty(entity.getAttributes())) {
            return;
        }
        List<TruncationERAttributeDTO> cachedAttributes = cachedEntity.getAttributes();
        if (cachedAttributes == null) {
            cachedAttributes = new ArrayList<>();
            cachedEntity.setAttributes(cachedAttributes);
        }
        mergeAttributeList(cachedAttributes, entity.getAttributes());
    }

    /**
     * Appends to {@code target} every attribute from {@code source} whose name is not yet present
     * in {@code target}; attributes with an empty name or value are skipped.
     */
    private static void mergeAttributeList(List<TruncationERAttributeDTO> target, List<TruncationERAttributeDTO> source) {
        for (TruncationERAttributeDTO attribute : source) {
            String attributeKey = attribute.getAttribute();
            if (StrUtil.isEmpty(attributeKey) || StrUtil.isEmpty(attribute.getValue())) {
                continue;
            }
            boolean exists = target.stream().anyMatch(a -> StrUtil.equals(a.getAttribute(), attributeKey));
            if (!exists) {
                target.add(attribute);
            }
        }
    }

    /** Builds the identity key of an entity from its type and name. */
    private static String entityKey(String type, String name) {
        return type + "_" + name;
    }

    /** Identity key of an entity: type and name. */
    private String generateEntityMapKey(EntityExtractionDTO entityExtractionDTO) {
        return entityKey(entityExtractionDTO.getEntity(), entityExtractionDTO.getName());
    }

    /**
     * Identity key of a relation: source (type and name), relation name, and target (type and
     * name). Endpoint types are included so that the relation identity matches entity identity.
     */
    private String generateRelationMapKey(RelationExtractionDTO relationExtractionDTO) {
        return entityKey(relationExtractionDTO.getSourceType(), relationExtractionDTO.getSource())
                + "_" + relationExtractionDTO.getRelation()
                + "_" + entityKey(relationExtractionDTO.getTargetType(), relationExtractionDTO.getTarget());
    }
}