diff --git a/src/main/java/com/supervision/police/service/impl/ExtractTripleInfoServiceImpl.java b/src/main/java/com/supervision/police/service/impl/ExtractTripleInfoServiceImpl.java index e9dabab..997de46 100644 --- a/src/main/java/com/supervision/police/service/impl/ExtractTripleInfoServiceImpl.java +++ b/src/main/java/com/supervision/police/service/impl/ExtractTripleInfoServiceImpl.java @@ -3,7 +3,11 @@ package com.supervision.police.service.impl; import cn.hutool.core.collection.CollUtil; import cn.hutool.core.util.StrUtil; import cn.hutool.json.JSONUtil; -import com.supervision.police.domain.*; +import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; +import com.supervision.police.domain.CasePerson; +import com.supervision.police.domain.NotePrompt; +import com.supervision.police.domain.NoteRecordSplit; +import com.supervision.police.domain.TripleInfo; import com.supervision.police.service.*; import com.supervision.thread.TripleExtractTask; import com.supervision.thread.TripleExtractTaskPool; @@ -11,11 +15,13 @@ import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.ai.ollama.OllamaChatClient; import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.scheduling.annotation.Async; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; -import java.util.*; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.Optional; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; import java.util.function.Consumer; @@ -33,6 +39,8 @@ public class ExtractTripleInfoServiceImpl implements ExtractTripleInfoService { private final OllamaChatClient chatClient; + private final CasePersonService casePersonService; + @Autowired private NoteRecordSplitService noteRecordSplitService; @@ -54,61 +62,65 @@ public class ExtractTripleInfoServiceImpl implements ExtractTripleInfoService { List notePromptList = notePromptService.listPromptBySplitId(recordSplitId); - if (CollUtil.isEmpty(notePromptList)){ - log.warn("extractTripleInfo:笔录片段:{},笔录分类:{} 不属于任何提示词,不进行后续操作...", recordSplit.getId(),recordSplit.getRecordType()); + if (CollUtil.isEmpty(notePromptList)) { + log.warn("extractTripleInfo:笔录片段:{},笔录分类:{} 不属于任何提示词,不进行后续操作...", recordSplit.getId(), recordSplit.getRecordType()); return; } + QueryWrapper wrapper = new QueryWrapper<>(); + wrapper.eq("case_id", caseId); + wrapper.eq("case_actor_flag", 1); + CasePerson mainActor = casePersonService.getOne(wrapper); List taskList = notePromptList.stream() .filter(prompt -> StrUtil.isNotBlank(prompt.getPrompt())) .peek(prompt -> { caseTaskRecordService.taskCountIncrement(caseId, recordSplit.getNoteRecordId()); - log.info("extractTripleInfo:三元组抽取任务数量加1,笔录片段id:{}",prompt.getId()); + log.info("extractTripleInfo:三元组抽取任务数量加1,笔录片段id:{}", prompt.getId()); }) - .map(prompt -> new TripleExtractTask(chatClient, prompt, recordSplit,postExtractTriple())).toList(); + .map(prompt -> new TripleExtractTask(chatClient, prompt, recordSplit, postExtractTriple(), mainActor)).toList(); - if (CollUtil.isEmpty(taskList)){ + if (CollUtil.isEmpty(taskList)) { log.info("extractTripleInfo:笔录片段:{} 没有可用的提示词,不提交任何任务...", recordSplit.getId()); return; } List tripleInfos = new ArrayList<>(); try { - log.info("extractTripleInfo:笔录片段:{}抽取任务成功提交{}个任务....",recordSplitId,taskList.size()); + log.info("extractTripleInfo:笔录片段:{}抽取任务成功提交{}个任务....", recordSplitId, taskList.size()); List> futures = TripleExtractTaskPool.executor.invokeAll(taskList); for (Future future : futures) { try { TripleInfo tripleInfo = future.get(); - if (Objects.nonNull(tripleInfo)){ + if (Objects.nonNull(tripleInfo)) { tripleInfos.add(tripleInfo); } } catch (ExecutionException e) { - log.error("extractTripleInfo:笔录片段:{}三元组提取任务执行失败...",recordSplitId,e); + log.error("extractTripleInfo:笔录片段:{}三元组提取任务执行失败...", recordSplitId, e); } } } catch (InterruptedException e) { - log.error("extractTripleInfo:笔录片段:{}三元组提取任务提交失败...",recordSplitId,e); + log.error("extractTripleInfo:笔录片段:{}三元组提取任务提交失败...", recordSplitId, e); } // 如果有提取到三元组信息 if (CollUtil.isNotEmpty(tripleInfos)) { for (TripleInfo tripleInfo : tripleInfos) { - log.info("extractTripleInfo:笔录片段:{}三元组提取任务执行结束...,三元组信息入库:{}", recordSplitId,JSONUtil.toJsonStr(tripleInfo)); + log.info("extractTripleInfo:笔录片段:{}三元组提取任务执行结束...,三元组信息入库:{}", recordSplitId, JSONUtil.toJsonStr(tripleInfo)); tripleInfoService.save(tripleInfo); } - }else { + } else { log.info("extractTripleInfo:笔录片段:{}三元组提取任务执行结束...,未提取到任何三元组信息...", recordSplitId); } } private Consumer postExtractTriple() { return (recordSplit) -> { - try{ + try { caseTaskRecordService.finishCountIncrement(recordSplit.getCaseId(), recordSplit.getNoteRecordId()); - log.info("postExtractTriple:抽取任务完成数量加1,笔录片段id:{}",recordSplit.getId()); - }catch (Exception e){ - log.error("postExtractTriple:笔录片段:{} 抽取任务执行后更新任务状态失败...",recordSplit.getId(),e); + log.info("postExtractTriple:抽取任务完成数量加1,笔录片段id:{}", recordSplit.getId()); + } catch (Exception e) { + log.error("postExtractTriple:笔录片段:{} 抽取任务执行后更新任务状态失败...", recordSplit.getId(), e); } }; diff --git a/src/main/java/com/supervision/police/service/impl/NoteRecordSplitServiceImpl.java b/src/main/java/com/supervision/police/service/impl/NoteRecordSplitServiceImpl.java index 1b87be3..22373dd 100644 --- a/src/main/java/com/supervision/police/service/impl/NoteRecordSplitServiceImpl.java +++ b/src/main/java/com/supervision/police/service/impl/NoteRecordSplitServiceImpl.java @@ -117,12 +117,17 @@ public class NoteRecordSplitServiceImpl extends ServiceImpl 0 ? 1 : 2); newCaseTaskRecord.setSubmitTime(LocalDateTime.now()); - return caseTaskRecordService.save(newCaseTaskRecord); + caseTaskRecordService.save(newCaseTaskRecord); + + return newCaseTaskRecord.getStatus().equals(1); + } if (0 == splitSize) { diff --git a/src/main/java/com/supervision/thread/TripleExtractTask.java b/src/main/java/com/supervision/thread/TripleExtractTask.java index 3f489a6..fc7cbfd 100644 --- a/src/main/java/com/supervision/thread/TripleExtractTask.java +++ b/src/main/java/com/supervision/thread/TripleExtractTask.java @@ -4,6 +4,7 @@ import cn.hutool.core.collection.CollUtil; import cn.hutool.core.util.ObjectUtil; import cn.hutool.core.util.StrUtil; import cn.hutool.json.JSONUtil; +import com.supervision.police.domain.CasePerson; import com.supervision.police.domain.NotePrompt; import com.supervision.police.domain.NoteRecordSplit; import com.supervision.police.domain.TripleInfo; @@ -22,6 +23,7 @@ import java.util.function.Consumer; @Slf4j public class TripleExtractTask implements Callable { + private static final String HEAD_ENTITY_TYPE_ACTOR = "行为人"; private final OllamaChatClient chatClient; @@ -31,12 +33,15 @@ public class TripleExtractTask implements Callable { private final Consumer consumer; + private final CasePerson mainActor; - public TripleExtractTask(OllamaChatClient chatClient, NotePrompt prompt, NoteRecordSplit noteRecordSplit, Consumer consumer) { + + public TripleExtractTask(OllamaChatClient chatClient, NotePrompt prompt, NoteRecordSplit noteRecordSplit, Consumer consumer, CasePerson mainActor) { this.chatClient = chatClient; this.noteRecordSplit = noteRecordSplit; this.prompt = prompt; this.consumer = consumer; + this.mainActor = mainActor; } @@ -48,7 +53,7 @@ public class TripleExtractTask implements Callable { * 3. 尽量遵循常见的语义和逻辑规则,杜绝过度解读或不合理的关系推断。 * 4. 不要提取例子中的实体和关系,提取的结果一定来自需要分析的文本内容!! * 5. 提取之后,再检查一遍,提取的关系和实体是否与给定关系和实体类型对应 - * + *

* 给定的头实体类型为"{headEntityType}";给定的尾实体类型为"{tailEntityType}",给定的关系为"{relation}"。 * 请仔细分析以下的文本内容,精准找出符合给定关系且头尾实体类型相符的三元组,并进行提取。如果没有识别给定的三元组关系,请返回json:{"result":[]}。 * --- @@ -61,7 +66,7 @@ public class TripleExtractTask implements Callable { * {question} * {answer} * --- - * + *

* 返回格式为必须为以下的json格式: * {"result":[{"headEntity": {"type": "{headEntityType}","name":"提取到的头实体内容1"},"relation": "{relation}","tailEntity": {"type": "{tailEntityType}","name": "提取到的尾实体内容1"}}]} */ @@ -99,13 +104,13 @@ public class TripleExtractTask implements Callable { } } catch (Exception e) { log.error("提取三元组出现错误", e); - }finally { + } finally { consumer.accept(noteRecordSplit); } return null; } - private TripleRecord chat4Triple(NotePrompt prompt,NoteRecordSplit noteRecordSplit) { + private TripleRecord chat4Triple(NotePrompt prompt, NoteRecordSplit noteRecordSplit) { StopWatch stopWatch = new StopWatch(); // 分析三元组 @@ -116,6 +121,9 @@ public class TripleExtractTask implements Callable { paramMap.put("tailEntityType", prompt.getEndEntityType()); paramMap.put("question", noteRecordSplit.getQuestion()); paramMap.put("answer", noteRecordSplit.getAnswer()); + if (mainActor != null && HEAD_ENTITY_TYPE_ACTOR.equals(prompt.getStartEntityType())) { + paramMap.put("requirement", "当前案件的行为人是" + mainActor.getName() + ",只尝试提取" + mainActor.getName() + "为头结点的三元组。"); + } String format = StrUtil.format(prompt.getPrompt(), paramMap); log.info("提示词内容:{}", format); @@ -123,11 +131,12 @@ public class TripleExtractTask implements Callable { ChatResponse call = chatClient.call(new Prompt(new UserMessage(format))); stopWatch.stop(); String content = call.getResult().getOutput().getContent(); - log.info("问题:{}耗时:{},三元组提取结果是:{}",noteRecordSplit.getQuestion(),stopWatch.getTotalTimeSeconds(), content); + log.info("问题:{}耗时:{},三元组提取结果是:{}", noteRecordSplit.getQuestion(), stopWatch.getTotalTimeSeconds(), content); return new TripleRecord(format, content); } - record TripleRecord(String question, String answer){} + record TripleRecord(String question, String answer) { + } @Data public static class TripleExtractResult {