Merge branch 'dev_1.0.0' into ocr_branch

topo_dev
xueqingkun 9 months ago
commit 4279a55f59

@ -3,7 +3,11 @@ package com.supervision.police.service.impl;
import cn.hutool.core.collection.CollUtil; import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.StrUtil; import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil; import cn.hutool.json.JSONUtil;
import com.supervision.police.domain.*; import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.supervision.police.domain.CasePerson;
import com.supervision.police.domain.NotePrompt;
import com.supervision.police.domain.NoteRecordSplit;
import com.supervision.police.domain.TripleInfo;
import com.supervision.police.service.*; import com.supervision.police.service.*;
import com.supervision.thread.TripleExtractTask; import com.supervision.thread.TripleExtractTask;
import com.supervision.thread.TripleExtractTaskPool; import com.supervision.thread.TripleExtractTaskPool;
@ -11,11 +15,13 @@ import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.ollama.OllamaChatClient; import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional; import org.springframework.transaction.annotation.Transactional;
import java.util.*; import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future; import java.util.concurrent.Future;
import java.util.function.Consumer; import java.util.function.Consumer;
@ -33,6 +39,8 @@ public class ExtractTripleInfoServiceImpl implements ExtractTripleInfoService {
private final OllamaChatClient chatClient; private final OllamaChatClient chatClient;
private final CasePersonService casePersonService;
@Autowired @Autowired
private NoteRecordSplitService noteRecordSplitService; private NoteRecordSplitService noteRecordSplitService;
@ -54,61 +62,65 @@ public class ExtractTripleInfoServiceImpl implements ExtractTripleInfoService {
List<NotePrompt> notePromptList = notePromptService.listPromptBySplitId(recordSplitId); List<NotePrompt> notePromptList = notePromptService.listPromptBySplitId(recordSplitId);
if (CollUtil.isEmpty(notePromptList)){ if (CollUtil.isEmpty(notePromptList)) {
log.warn("extractTripleInfo:笔录片段:{},笔录分类:{} 不属于任何提示词,不进行后续操作...", recordSplit.getId(),recordSplit.getRecordType()); log.warn("extractTripleInfo:笔录片段:{},笔录分类:{} 不属于任何提示词,不进行后续操作...", recordSplit.getId(), recordSplit.getRecordType());
return; return;
} }
QueryWrapper<CasePerson> wrapper = new QueryWrapper<>();
wrapper.eq("case_id", caseId);
wrapper.eq("case_actor_flag", 1);
CasePerson mainActor = casePersonService.getOne(wrapper);
List<TripleExtractTask> taskList = notePromptList.stream() List<TripleExtractTask> taskList = notePromptList.stream()
.filter(prompt -> StrUtil.isNotBlank(prompt.getPrompt())) .filter(prompt -> StrUtil.isNotBlank(prompt.getPrompt()))
.peek(prompt -> { .peek(prompt -> {
caseTaskRecordService.taskCountIncrement(caseId, recordSplit.getNoteRecordId()); caseTaskRecordService.taskCountIncrement(caseId, recordSplit.getNoteRecordId());
log.info("extractTripleInfo:三元组抽取任务数量加1,笔录片段id:{}",prompt.getId()); log.info("extractTripleInfo:三元组抽取任务数量加1,笔录片段id:{}", prompt.getId());
}) })
.map(prompt -> new TripleExtractTask(chatClient, prompt, recordSplit,postExtractTriple())).toList(); .map(prompt -> new TripleExtractTask(chatClient, prompt, recordSplit, postExtractTriple(), mainActor)).toList();
if (CollUtil.isEmpty(taskList)){ if (CollUtil.isEmpty(taskList)) {
log.info("extractTripleInfo:笔录片段:{} 没有可用的提示词,不提交任何任务...", recordSplit.getId()); log.info("extractTripleInfo:笔录片段:{} 没有可用的提示词,不提交任何任务...", recordSplit.getId());
return; return;
} }
List<TripleInfo> tripleInfos = new ArrayList<>(); List<TripleInfo> tripleInfos = new ArrayList<>();
try { try {
log.info("extractTripleInfo:笔录片段:{}抽取任务成功提交{}个任务....",recordSplitId,taskList.size()); log.info("extractTripleInfo:笔录片段:{}抽取任务成功提交{}个任务....", recordSplitId, taskList.size());
List<Future<TripleInfo>> futures = TripleExtractTaskPool.executor.invokeAll(taskList); List<Future<TripleInfo>> futures = TripleExtractTaskPool.executor.invokeAll(taskList);
for (Future<TripleInfo> future : futures) { for (Future<TripleInfo> future : futures) {
try { try {
TripleInfo tripleInfo = future.get(); TripleInfo tripleInfo = future.get();
if (Objects.nonNull(tripleInfo)){ if (Objects.nonNull(tripleInfo)) {
tripleInfos.add(tripleInfo); tripleInfos.add(tripleInfo);
} }
} catch (ExecutionException e) { } catch (ExecutionException e) {
log.error("extractTripleInfo:笔录片段:{}三元组提取任务执行失败...",recordSplitId,e); log.error("extractTripleInfo:笔录片段:{}三元组提取任务执行失败...", recordSplitId, e);
} }
} }
} catch (InterruptedException e) { } catch (InterruptedException e) {
log.error("extractTripleInfo:笔录片段:{}三元组提取任务提交失败...",recordSplitId,e); log.error("extractTripleInfo:笔录片段:{}三元组提取任务提交失败...", recordSplitId, e);
} }
// 如果有提取到三元组信息 // 如果有提取到三元组信息
if (CollUtil.isNotEmpty(tripleInfos)) { if (CollUtil.isNotEmpty(tripleInfos)) {
for (TripleInfo tripleInfo : tripleInfos) { for (TripleInfo tripleInfo : tripleInfos) {
log.info("extractTripleInfo:笔录片段:{}三元组提取任务执行结束...,三元组信息入库:{}", recordSplitId,JSONUtil.toJsonStr(tripleInfo)); log.info("extractTripleInfo:笔录片段:{}三元组提取任务执行结束...,三元组信息入库:{}", recordSplitId, JSONUtil.toJsonStr(tripleInfo));
tripleInfoService.save(tripleInfo); tripleInfoService.save(tripleInfo);
} }
}else { } else {
log.info("extractTripleInfo:笔录片段:{}三元组提取任务执行结束...,未提取到任何三元组信息...", recordSplitId); log.info("extractTripleInfo:笔录片段:{}三元组提取任务执行结束...,未提取到任何三元组信息...", recordSplitId);
} }
} }
private Consumer<NoteRecordSplit> postExtractTriple() { private Consumer<NoteRecordSplit> postExtractTriple() {
return (recordSplit) -> { return (recordSplit) -> {
try{ try {
caseTaskRecordService.finishCountIncrement(recordSplit.getCaseId(), recordSplit.getNoteRecordId()); caseTaskRecordService.finishCountIncrement(recordSplit.getCaseId(), recordSplit.getNoteRecordId());
log.info("postExtractTriple:抽取任务完成数量加1,笔录片段id:{}",recordSplit.getId()); log.info("postExtractTriple:抽取任务完成数量加1,笔录片段id:{}", recordSplit.getId());
}catch (Exception e){ } catch (Exception e) {
log.error("postExtractTriple:笔录片段:{} 抽取任务执行后更新任务状态失败...",recordSplit.getId(),e); log.error("postExtractTriple:笔录片段:{} 抽取任务执行后更新任务状态失败...", recordSplit.getId(), e);
} }
}; };

@ -117,12 +117,17 @@ public class NoteRecordSplitServiceImpl extends ServiceImpl<NoteRecordSplitMappe
if (caseTaskRecordOpt.isEmpty()) { if (caseTaskRecordOpt.isEmpty()) {
log.info("recordProcessTaskStatusCheck:recordId:{}未查询到任务记录, 新建任务记录...",recordId);
CaseTaskRecord newCaseTaskRecord = new CaseTaskRecord(); CaseTaskRecord newCaseTaskRecord = new CaseTaskRecord();
newCaseTaskRecord.setCaseId(caseId); newCaseTaskRecord.setCaseId(caseId);
newCaseTaskRecord.setRecordId(recordId); newCaseTaskRecord.setRecordId(recordId);
newCaseTaskRecord.setStatus(1);
newCaseTaskRecord.setStatus(splitSize > 0 ? 1 : 2);
newCaseTaskRecord.setSubmitTime(LocalDateTime.now()); newCaseTaskRecord.setSubmitTime(LocalDateTime.now());
return caseTaskRecordService.save(newCaseTaskRecord); caseTaskRecordService.save(newCaseTaskRecord);
return newCaseTaskRecord.getStatus().equals(1);
} }
if (0 == splitSize) { if (0 == splitSize) {

@ -4,6 +4,7 @@ import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.ObjectUtil; import cn.hutool.core.util.ObjectUtil;
import cn.hutool.core.util.StrUtil; import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil; import cn.hutool.json.JSONUtil;
import com.supervision.police.domain.CasePerson;
import com.supervision.police.domain.NotePrompt; import com.supervision.police.domain.NotePrompt;
import com.supervision.police.domain.NoteRecordSplit; import com.supervision.police.domain.NoteRecordSplit;
import com.supervision.police.domain.TripleInfo; import com.supervision.police.domain.TripleInfo;
@ -22,6 +23,7 @@ import java.util.function.Consumer;
@Slf4j @Slf4j
public class TripleExtractTask implements Callable<TripleInfo> { public class TripleExtractTask implements Callable<TripleInfo> {
private static final String HEAD_ENTITY_TYPE_ACTOR = "行为人";
private final OllamaChatClient chatClient; private final OllamaChatClient chatClient;
@ -31,12 +33,15 @@ public class TripleExtractTask implements Callable<TripleInfo> {
private final Consumer<NoteRecordSplit> consumer; private final Consumer<NoteRecordSplit> consumer;
private final CasePerson mainActor;
public TripleExtractTask(OllamaChatClient chatClient, NotePrompt prompt, NoteRecordSplit noteRecordSplit, Consumer<NoteRecordSplit> consumer) {
public TripleExtractTask(OllamaChatClient chatClient, NotePrompt prompt, NoteRecordSplit noteRecordSplit, Consumer<NoteRecordSplit> consumer, CasePerson mainActor) {
this.chatClient = chatClient; this.chatClient = chatClient;
this.noteRecordSplit = noteRecordSplit; this.noteRecordSplit = noteRecordSplit;
this.prompt = prompt; this.prompt = prompt;
this.consumer = consumer; this.consumer = consumer;
this.mainActor = mainActor;
} }
@ -48,7 +53,7 @@ public class TripleExtractTask implements Callable<TripleInfo> {
* 3. * 3.
* 4. !! * 4. !!
* 5. ,, * 5. ,,
* * <p>
* "{headEntityType}";"{tailEntityType}","{relation}" * "{headEntityType}";"{tailEntityType}","{relation}"
* json:{"result":[]} * json:{"result":[]}
* --- * ---
@ -61,7 +66,7 @@ public class TripleExtractTask implements Callable<TripleInfo> {
* {question} * {question}
* {answer} * {answer}
* --- * ---
* * <p>
* json: * json:
* {"result":[{"headEntity": {"type": "{headEntityType}","name":"提取到的头实体内容1"},"relation": "{relation}","tailEntity": {"type": "{tailEntityType}","name": "提取到的尾实体内容1"}}]} * {"result":[{"headEntity": {"type": "{headEntityType}","name":"提取到的头实体内容1"},"relation": "{relation}","tailEntity": {"type": "{tailEntityType}","name": "提取到的尾实体内容1"}}]}
*/ */
@ -99,13 +104,13 @@ public class TripleExtractTask implements Callable<TripleInfo> {
} }
} catch (Exception e) { } catch (Exception e) {
log.error("提取三元组出现错误", e); log.error("提取三元组出现错误", e);
}finally { } finally {
consumer.accept(noteRecordSplit); consumer.accept(noteRecordSplit);
} }
return null; return null;
} }
private TripleRecord chat4Triple(NotePrompt prompt,NoteRecordSplit noteRecordSplit) { private TripleRecord chat4Triple(NotePrompt prompt, NoteRecordSplit noteRecordSplit) {
StopWatch stopWatch = new StopWatch(); StopWatch stopWatch = new StopWatch();
// 分析三元组 // 分析三元组
@ -116,6 +121,9 @@ public class TripleExtractTask implements Callable<TripleInfo> {
paramMap.put("tailEntityType", prompt.getEndEntityType()); paramMap.put("tailEntityType", prompt.getEndEntityType());
paramMap.put("question", noteRecordSplit.getQuestion()); paramMap.put("question", noteRecordSplit.getQuestion());
paramMap.put("answer", noteRecordSplit.getAnswer()); paramMap.put("answer", noteRecordSplit.getAnswer());
if (mainActor != null && HEAD_ENTITY_TYPE_ACTOR.equals(prompt.getStartEntityType())) {
paramMap.put("requirement", "当前案件的行为人是" + mainActor.getName() + ",只尝试提取" + mainActor.getName() + "为头结点的三元组。");
}
String format = StrUtil.format(prompt.getPrompt(), paramMap); String format = StrUtil.format(prompt.getPrompt(), paramMap);
log.info("提示词内容:{}", format); log.info("提示词内容:{}", format);
@ -123,11 +131,12 @@ public class TripleExtractTask implements Callable<TripleInfo> {
ChatResponse call = chatClient.call(new Prompt(new UserMessage(format))); ChatResponse call = chatClient.call(new Prompt(new UserMessage(format)));
stopWatch.stop(); stopWatch.stop();
String content = call.getResult().getOutput().getContent(); String content = call.getResult().getOutput().getContent();
log.info("问题:{}耗时:{},三元组提取结果是:{}",noteRecordSplit.getQuestion(),stopWatch.getTotalTimeSeconds(), content); log.info("问题:{}耗时:{},三元组提取结果是:{}", noteRecordSplit.getQuestion(), stopWatch.getTotalTimeSeconds(), content);
return new TripleRecord(format, content); return new TripleRecord(format, content);
} }
record TripleRecord(String question, String answer){} record TripleRecord(String question, String answer) {
}
@Data @Data
public static class TripleExtractResult { public static class TripleExtractResult {

Loading…
Cancel
Save