diff --git a/pom.xml b/pom.xml
index 8c3435a..dc143d7 100644
--- a/pom.xml
+++ b/pom.xml
@@ -90,6 +90,16 @@
neo4j-java-driver
5.15.0
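+        <dependency>
+            <groupId>org.commonmark</groupId>
+            <artifactId>commonmark</artifactId>
+            <version>0.21.0</version>
+        </dependency>
+        <dependency>
+            <groupId>org.commonmark</groupId>
+            <artifactId>commonmark-ext-gfm-tables</artifactId>
+            <version>0.21.0</version>
+        </dependency>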
diff --git a/src/main/java/com/supervision/pdfqaserver/dto/EREDTO.java b/src/main/java/com/supervision/pdfqaserver/dto/EREDTO.java
index f54f67b..1f37e60 100644
--- a/src/main/java/com/supervision/pdfqaserver/dto/EREDTO.java
+++ b/src/main/java/com/supervision/pdfqaserver/dto/EREDTO.java
@@ -2,6 +2,7 @@ package com.supervision.pdfqaserver.dto;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.lang.UUID;
+import cn.hutool.core.util.NumberUtil;
import cn.hutool.core.util.RandomUtil;
import cn.hutool.core.util.StrUtil;
import com.alibaba.fastjson.JSONArray;
@@ -109,6 +110,46 @@ public class EREDTO {
return eredto;
}
+
+
+    /**
+     * Builds an EREDTO from a parsed table: one "行" entity per non-empty row, carrying one attribute
+     * per non-blank header/cell pair (the attribute's third argument is "1" for numeric cells, else "0").
+     * @param heads column headers; blank headers are skipped
+     * @param rows body rows; cells are paired with headers by index
+     * @param truncationId id stamped on every row entity
+     * @return the populated DTO, or a fresh unpopulated one when heads or rows is empty
+     */
+    public static EREDTO fromHeadAndRows(List<String> heads, List<List<String>> rows, String truncationId) {
+        EREDTO eredto = new EREDTO();
+        if (CollUtil.isEmpty(heads) || CollUtil.isEmpty(rows)) {
+            return eredto;
+        }
+        // Fix: collect into the DTO's own list; the old local list was never attached, so rows were lost.
+        List<EntityExtractionDTO> entities = eredto.getEntities();
+        for (List<String> row : rows) {
+            if (CollUtil.isEmpty(row)) {
+                continue;
+            }
+            EntityExtractionDTO rowEntity = new EntityExtractionDTO();
+            rowEntity.setEntity("行");
+            // Random suffix keeps table row names from colliding.
+            rowEntity.setName("行-" + RandomUtil.randomString(UUID.randomUUID().toString(), 10));
+            rowEntity.setTruncationId(truncationId);
+            List<TruncationERAttributeDTO> attributes = new ArrayList<>();
+            for (int i = 0; i < heads.size(); i++) {
+                String key = StrUtil.trim(heads.get(i));
+                String value = i < row.size() ? StrUtil.trim(row.get(i)) : "";
+                if (StrUtil.isBlank(key) || StrUtil.isBlank(value)) {
+                    continue;
+                }
+                attributes.add(new TruncationERAttributeDTO(key, value, NumberUtil.isNumber(value) ? "1" : "0"));
+            }
+            rowEntity.setAttributes(attributes);
+            entities.add(rowEntity);
+        }
+        return eredto;
+    }
public static EREDTO fromTableJson(String json,String truncationId) {
EREDTO eredto = new EREDTO();
diff --git a/src/main/java/com/supervision/pdfqaserver/service/TableVisitor.java b/src/main/java/com/supervision/pdfqaserver/service/TableVisitor.java
new file mode 100644
index 0000000..d87aba5
--- /dev/null
+++ b/src/main/java/com/supervision/pdfqaserver/service/TableVisitor.java
@@ -0,0 +1,131 @@
+package com.supervision.pdfqaserver.service;
+
+import org.commonmark.ext.gfm.tables.*;
+import org.commonmark.node.*;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Commonmark visitor that collects the header cells and body rows of GFM tables as plain text.
+ *
+ * <p>State accumulates across the whole traversal: when a document contains several tables, the
+ * header row of the last table replaces earlier ones while body rows of every table are appended
+ * to the same list. Use a fresh instance per parse.
+ */
+public class TableVisitor extends AbstractVisitor {
+    /** True while the children of a TableHead are being visited. */
+    private boolean inHeader = false;
+    /** True while the children of a TableBody are being visited. */
+    private boolean inBody = false;
+    /** Cells of the row currently being visited; null before the first row. */
+    private List<String> currentRow = null;
+
+    /** Text of the most recently visited header row. */
+    private List<String> headers = new ArrayList<>();
+
+    /** Text of every body row visited so far, in document order. */
+    private final List<List<String>> rows = new ArrayList<>();
+
+    /** Intercepts table blocks; every other custom block gets the default traversal. */
+    @Override
+    public void visit(CustomBlock customBlock) {
+        if (customBlock instanceof TableBlock) {
+            handleTableBlock((TableBlock) customBlock);
+        } else {
+            super.visit(customBlock);
+        }
+    }
+
+    /** Dispatches table head/body/row/cell nodes; other custom nodes get the default traversal. */
+    @Override
+    public void visit(CustomNode customNode) {
+        if (customNode instanceof TableHead) {
+            handleTableHead((TableHead) customNode);
+        } else if (customNode instanceof TableBody) {
+            handleTableBody((TableBody) customNode);
+        } else if (customNode instanceof TableRow) {
+            handleTableRow((TableRow) customNode);
+        } else if (customNode instanceof TableCell) {
+            handleTableCell((TableCell) customNode);
+        } else {
+            super.visit(customNode);
+        }
+    }
+
+    /** Clears the section flags, then walks the table's children. */
+    private void handleTableBlock(TableBlock tableBlock) {
+        inHeader = false;
+        inBody = false;
+        visitChildren(tableBlock);
+    }
+
+    /** Marks the header section active while its rows are visited. */
+    private void handleTableHead(TableHead tableHead) {
+        inHeader = true;
+        visitChildren(tableHead);
+        inHeader = false;
+    }
+
+    /** Marks the body section active while its rows are visited. */
+    private void handleTableBody(TableBody tableBody) {
+        inBody = true;
+        visitChildren(tableBody);
+        inBody = false;
+    }
+
+    /** Collects the cells of one row and files it as the header row or as a body row. */
+    private void handleTableRow(TableRow tableRow) {
+        currentRow = new ArrayList<>();
+        visitChildren(tableRow);
+
+        if (inHeader) {
+            this.headers = currentRow;
+        } else if (inBody) {
+            this.rows.add(currentRow);
+        }
+    }
+
+    /** Appends the cell's plain text to the row being collected. */
+    private void handleTableCell(TableCell tableCell) {
+        // getTextContent already walks the whole cell subtree, so no separate visitChildren pass.
+        if (currentRow != null) {
+            currentRow.add(getTextContent(tableCell));
+        }
+    }
+
+    /**
+     * Returns the trimmed text of a node, including text nested inside inline markup.
+     *
+     * <p>Only looking at direct Text children would drop cells such as {@code **Total**} or
+     * {@code `id`}, whose text sits under Emphasis/StrongEmphasis/Link nodes or in a Code node.
+     */
+    private String getTextContent(Node node) {
+        StringBuilder sb = new StringBuilder();
+        appendText(node, sb);
+        return sb.toString().trim();
+    }
+
+    /** Depth-first walk that appends the literal of every Text and Code descendant to sb. */
+    private static void appendText(Node node, StringBuilder sb) {
+        for (Node child = node.getFirstChild(); child != null; child = child.getNext()) {
+            if (child instanceof Text) {
+                sb.append(((Text) child).getLiteral());
+            } else if (child instanceof Code) {
+                sb.append(((Code) child).getLiteral());
+            } else {
+                appendText(child, sb);
+            }
+        }
+    }
+
+    /** Returns the header cells collected so far (the live internal list, not a copy). */
+    public List<String> getTableHeaders() {
+        return headers;
+    }
+
+    /** Returns the body rows collected so far (the live internal list, not a copy). */
+    public List<List<String>> getTableRows() {
+        return rows;
+    }
+}
diff --git a/src/main/java/com/supervision/pdfqaserver/service/impl/KnowledgeGraphServiceImpl.java b/src/main/java/com/supervision/pdfqaserver/service/impl/KnowledgeGraphServiceImpl.java
index f452e52..fbfe65b 100644
--- a/src/main/java/com/supervision/pdfqaserver/service/impl/KnowledgeGraphServiceImpl.java
+++ b/src/main/java/com/supervision/pdfqaserver/service/impl/KnowledgeGraphServiceImpl.java
@@ -208,6 +208,7 @@ public class KnowledgeGraphServiceImpl implements KnowledgeGraphService {
log.info("意图元数据识别完成,切分文档id:{},耗时:{}毫秒", truncateDTO.getId(),interval.intervalMs("makeOutDomainMetadata"));
// 保存意图数据
intentSize ++;
+ index ++;
List intentions = intentionService.batchSaveIfAbsent(intents, pdfInfo.getDomainCategoryId(), pdfId.toString());
for (Intention intention : intentions) {
List metadataDTOS = domainMetadataDTOS.stream()
@@ -215,7 +216,7 @@ public class KnowledgeGraphServiceImpl implements KnowledgeGraphService {
domainMetadataService.batchSaveOrUpdateMetadata(metadataDTOS,intention.getId(), pdfInfo.getDomainCategoryId());
}
}catch (Exception e){
- intentSize ++;
+ index ++;
log.error("切分文档id:{},意图识别失败", truncateDTO.getId(), e);
}
@@ -284,15 +285,13 @@ public class KnowledgeGraphServiceImpl implements KnowledgeGraphService {
try {
if (StrUtil.equals(truncateDTO.getLayoutType(), String.valueOf(LayoutTypeEnum.TABLE.getCode()))){
log.info("切分文档id:{},表格类型数据,不进行意图识别...", truncateDTO.getId());
- /*EREDTO eredto = conversionPipeline.doEre(truncateDTO, new ArrayList<>());
+ EREDTO eredto = conversionPipeline.doEre(truncateDTO, new ArrayList<>());
if (null == eredto){
log.info("切分文档id:{},命名实体识别结果为空...", truncateDTO.getId());
continue;
}
this.saveERE(eredto, truncateDTO.getId());
eredtos.add(eredto);
- */
- continue;
}
timer.start("makeOutTruncationIntent");
diff --git a/src/main/java/com/supervision/pdfqaserver/service/impl/TripleConversionPipelineImpl.java b/src/main/java/com/supervision/pdfqaserver/service/impl/TripleConversionPipelineImpl.java
index 0861d2d..5e5971a 100644
--- a/src/main/java/com/supervision/pdfqaserver/service/impl/TripleConversionPipelineImpl.java
+++ b/src/main/java/com/supervision/pdfqaserver/service/impl/TripleConversionPipelineImpl.java
@@ -18,6 +18,10 @@ import edu.stanford.nlp.pipeline.CoreSentence;
import edu.stanford.nlp.pipeline.StanfordCoreNLP;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
+import org.commonmark.Extension;
+import org.commonmark.ext.gfm.tables.TablesExtension;
+import org.commonmark.node.Node;
+import org.commonmark.parser.Parser;
import org.springframework.stereotype.Service;
import java.util.*;
import java.util.stream.Collectors;
@@ -216,7 +220,7 @@ public class TripleConversionPipelineImpl implements TripleConversionPipeline {
return doTextEre(truncateDTO);
}
- return doTableEre(truncateDTO);
+ return doTableEreFast(truncateDTO);
}
log.warn("doEre:错误的布局类型: {}", truncateDTO.getLayoutType());
return null;
@@ -225,7 +229,7 @@ public class TripleConversionPipelineImpl implements TripleConversionPipeline {
/**
* 切分文档
* 切分规则:
- * 文本类型: 以单句为最小单元,最大字数现在这1000字以内。单句超过1000字取完成的单句。
+     * Text type: a single sentence is the smallest unit and a segment is capped at maxTextLength characters; a single sentence longer than maxTextLength is kept whole.
* 表格类型: 以4行数据为最小单元。
* @param documents 文档列表
* @return
@@ -252,16 +256,14 @@ public class TripleConversionPipelineImpl implements TripleConversionPipeline {
// 创建管道
StanfordCoreNLP pipeline = new StanfordCoreNLP(props);
List truncateDTOS = new ArrayList<>();
- StringBuilder truncateTextBuild = new StringBuilder(1500);
- DocumentDTO documentDTOLast = null;
+ StringBuilder truncateTextBuild = new StringBuilder(minTextLength + maxTextLength);
for (DocumentDTO documentDTO : documentDTOList) {
- documentDTOLast = documentDTO;
String content = documentDTO.getContent();
- if (StrUtil.isEmpty(content)){
+ if (StrUtil.isEmpty(content)) {
continue;
}
Integer layoutType = documentDTO.getLayoutType();
- if (LayoutTypeEnum.TEXT.getCode() == layoutType){
+ if (LayoutTypeEnum.TEXT.getCode() == layoutType) {
// 如果是文本类型的布局,进行合并
CoreDocument document = new CoreDocument(content);
// 分析文本
@@ -293,18 +295,20 @@ public class TripleConversionPipelineImpl implements TripleConversionPipeline {
// 如果是表格类型的布局,进行切分
// 出现表格后,如果truncateTextBuild不为空,单独作为一个片段
if (!truncateTextBuild.isEmpty()) {
- truncateDTOS.add(new TruncateDTO(documentDTO, truncateTextBuild.toString()));
+ TruncateDTO truncateDTO = new TruncateDTO(documentDTO, truncateTextBuild.toString());
+ truncateDTO.setLayoutType(String.valueOf(LayoutTypeEnum.TEXT.getCode()));//强制设置为文本类型
+ truncateDTOS.add(truncateDTO);
}
// 提前抽取表名
TableTitleDTO tableTitleDTO = this.extractTableTitle(documentDTO.getTitle());
- if (null != tableTitleDTO && StrUtil.isNotEmpty(tableTitleDTO.getTitle())){
+ if (null != tableTitleDTO && StrUtil.isNotEmpty(tableTitleDTO.getTitle())) {
documentDTO.setTitle(tableTitleDTO.getTitle());
- }else {
+ } else {
// 生成一个默认的表
- documentDTO.setTitle("tableName-"+ RandomUtil.randomString(10));
+ documentDTO.setTitle("tableName-" + RandomUtil.randomString(10));
}
List tableRows = StrUtil.split(documentDTO.getContent(), "\n").stream().filter(StrUtil::isNotEmpty).collect(Collectors.toList());
- if (tableRows.size()<5){
+ if (tableRows.size() < 5) {
TruncateDTO truncateDTO = new TruncateDTO(documentDTO);
truncateDTOS.add(truncateDTO);
continue;
@@ -312,7 +316,7 @@ public class TripleConversionPipelineImpl implements TripleConversionPipeline {
String tableTitle = tableRows.get(0);
// 标题分割符
String tableTitleSplit = tableRows.get(1);
- List noTitleRows = tableRows.subList(2,tableRows.size()-1);
+ List noTitleRows = tableRows.subList(2, tableRows.size() - 1);
List> rows = CollUtil.split(noTitleRows, 4);
for (List row : rows) {
StringBuilder sb = new StringBuilder();
@@ -331,8 +335,10 @@ public class TripleConversionPipelineImpl implements TripleConversionPipeline {
log.info("sliceDocuments:错误的布局类型: {}", layoutType);
}
}
- if (!truncateTextBuild.isEmpty() && null != documentDTOLast) {
- truncateDTOS.add(new TruncateDTO(documentDTOLast, truncateTextBuild.toString()));
+ if (!truncateTextBuild.isEmpty() && null != CollUtil.getLast(documentDTOList)) {
+ TruncateDTO truncateDTO = new TruncateDTO(CollUtil.getLast(documentDTOList), truncateTextBuild.toString());
+ truncateDTO.setLayoutType(String.valueOf(LayoutTypeEnum.TEXT.getCode()));//强制设置为文本类型
+ truncateDTOS.add(truncateDTO);
}
return truncateDTOS;
}
@@ -465,6 +471,33 @@ public class TripleConversionPipelineImpl implements TripleConversionPipeline {
log.info("doTableEre响应结果:{}", response);
EREDTO eredto = EREDTO.fromTableJson(response, truncateDTO.getId());
// 手动设置表格标题
+ manualSetTableTitle(truncateDTO, eredto);
+ return eredto;
+ }
+
+    /**
+     * Parses the truncation content as a GFM Markdown table and converts it into an EREDTO.
+     *
+     * @param truncateDTO truncation whose content holds the Markdown table
+     * @return the table's EREDTO after manualSetTableTitle, or null when the content is empty
+     */
+    private EREDTO doTableEreFast(TruncateDTO truncateDTO) {
+        log.info("doTableEreFast:开始进行表格实体关系抽取,内容:{}", truncateDTO.getContent());
+        if (StrUtil.isEmpty(truncateDTO.getContent())) {
+            return null;
+        }
+        List<Extension> extensions = Arrays.asList(TablesExtension.create());
+        Parser parser = Parser.builder().extensions(extensions).build();
+        Node document = parser.parse(truncateDTO.getContent());
+        TableVisitor visitor = new TableVisitor();
+        document.accept(visitor);
+        EREDTO eredto = EREDTO.fromHeadAndRows(visitor.getTableHeaders(), visitor.getTableRows(), truncateDTO.getId());
+        // Manually attach the table title.
+        manualSetTableTitle(truncateDTO, eredto);
+        return eredto;
+    }
+
+ private void manualSetTableTitle(TruncateDTO truncateDTO, EREDTO eredto) {
EntityExtractionDTO titleEntity = new EntityExtractionDTO();
titleEntity.setEntity("表");
titleEntity.setTruncationId(truncateDTO.getId());
@@ -478,7 +511,6 @@ public class TripleConversionPipelineImpl implements TripleConversionPipeline {
}
eredto.getEntities().add(titleEntity);
eredto.setRelations(relations);
- return eredto;
}
/**
diff --git a/src/test/java/com/supervision/pdfqaserver/PdfQaServerApplicationTests.java b/src/test/java/com/supervision/pdfqaserver/PdfQaServerApplicationTests.java
index c056497..f5550b6 100644
--- a/src/test/java/com/supervision/pdfqaserver/PdfQaServerApplicationTests.java
+++ b/src/test/java/com/supervision/pdfqaserver/PdfQaServerApplicationTests.java
@@ -17,6 +17,7 @@ import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import static org.neo4j.driver.Values.parameters;
+import org.commonmark.node.*;
@Slf4j
@SpringBootTest
@@ -153,13 +154,13 @@ class PdfQaServerApplicationTests {
@Test
public void metaDataTrainTest() {
- knowledgeGraphService.metaDataTrain(13);
+ knowledgeGraphService.metaDataTrain(15);
}
@Test
void generateGraphBaseTrainTest() {
- knowledgeGraphService.generateGraphBaseTrain(13);
+ knowledgeGraphService.generateGraphBaseTrain(14);
}