diff --git a/src/main/java/com/supervision/config/TransactionManagerConfig.java b/src/main/java/com/supervision/config/TransactionManagerConfig.java index 0ad1a6f..8cdfc54 100644 --- a/src/main/java/com/supervision/config/TransactionManagerConfig.java +++ b/src/main/java/com/supervision/config/TransactionManagerConfig.java @@ -26,4 +26,12 @@ public class TransactionManagerConfig { //可以设置其他事务管理器属性 return transactionManager; } + + @Bean("testTransactionManager") + public DataSourceTransactionManager testTransactionManager(DataSource dataSource) { + DataSourceTransactionManager transactionManager = new DataSourceTransactionManager(); + transactionManager.setDataSource(dataSource); + //可以设置其他事务管理器属性 + return transactionManager; + } } \ No newline at end of file diff --git a/src/main/java/com/supervision/police/domain/NotePrompt.java b/src/main/java/com/supervision/police/domain/NotePrompt.java index 922a54c..faa92b3 100644 --- a/src/main/java/com/supervision/police/domain/NotePrompt.java +++ b/src/main/java/com/supervision/police/domain/NotePrompt.java @@ -32,6 +32,8 @@ public class NotePrompt implements Serializable { private String startEntityType; + private String relType; + private String endEntityType; /** diff --git a/src/main/java/com/supervision/police/service/impl/ExtractTripleInfoServiceImpl.java b/src/main/java/com/supervision/police/service/impl/ExtractTripleInfoServiceImpl.java index 088a7eb..730b331 100644 --- a/src/main/java/com/supervision/police/service/impl/ExtractTripleInfoServiceImpl.java +++ b/src/main/java/com/supervision/police/service/impl/ExtractTripleInfoServiceImpl.java @@ -21,6 +21,7 @@ import org.springframework.ai.ollama.OllamaChatClient; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.scheduling.annotation.Async; import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; import org.springframework.util.StopWatch; import java.time.LocalDateTime; @@ -52,6 +53,7 @@ public class ExtractTripleInfoServiceImpl implements ExtractTripleInfoService { @Async + @Transactional(transactionManager = "testTransactionManager",rollbackFor = Exception.class) public void extractTripleInfo(String caseId, String name, String recordId) { // 首先获取所有切分后的笔录 List recordSplitList = noteRecordSplitService.lambdaQuery().eq(StrUtil.isNotBlank(recordId), NoteRecordSplit::getNoteRecordsId, recordId) @@ -150,7 +152,9 @@ public class ExtractTripleInfoServiceImpl implements ExtractTripleInfoService { tripleInfoService.lambdaUpdate().eq(TripleInfo::getRecordId, recordId).remove(); // TODO 这里,如果已经生成了图谱,怎么办? // 首先要把这个笔录已经提取过的三元组记录删除掉,删除掉之后才可以重新提取 - tripleInfoService.saveBatch(tripleInfos); + for (TripleInfo tripleInfo : tripleInfos) { + tripleInfoService.save(tripleInfo); + } } if (CollUtil.isNotEmpty(futures)) { // 将任务标记为成功 diff --git a/src/main/java/com/supervision/police/service/impl/ModelRecordTypeServiceImpl.java b/src/main/java/com/supervision/police/service/impl/ModelRecordTypeServiceImpl.java index d9e063e..e664935 100644 --- a/src/main/java/com/supervision/police/service/impl/ModelRecordTypeServiceImpl.java +++ b/src/main/java/com/supervision/police/service/impl/ModelRecordTypeServiceImpl.java @@ -1,5 +1,6 @@ package com.supervision.police.service.impl; +import cn.hutool.core.util.StrUtil; import com.alibaba.druid.util.StringUtils; import com.baomidou.mybatisplus.core.conditions.Wrapper; import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; @@ -116,6 +117,9 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl getThreeInfo(String caseId, String name, String recordId) { + if (StrUtil.isBlank(recordId)){ + throw new RuntimeException("笔录ID不能为空"); + } boolean taskStatus = taskExtractStatusCheck(caseId, recordId); // 如果校验结果为false,则说明需要进行提取三元组操作 if (!taskStatus) { @@ -130,7 +134,8 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl caseTaskRecordOpt = caseTaskRecordService.lambdaQuery().eq(CaseTaskRecord::getType, 2).eq(CaseTaskRecord::getCaseId, caseId).eq(CaseTaskRecord::getRecordId, recordId).oneOpt(); + Optional caseTaskRecordOpt = caseTaskRecordService.lambdaQuery() + .eq(CaseTaskRecord::getType, 2).eq(CaseTaskRecord::getCaseId, caseId).eq(CaseTaskRecord::getRecordId, recordId).oneOpt(); if (caseTaskRecordOpt.isEmpty()) { CaseTaskRecord newCaseTaskRecord = new CaseTaskRecord(); newCaseTaskRecord.setType(2); @@ -167,19 +172,100 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl { this.recordId = recordId; } + + /** + * 三元组提取任务:从给定对话中根据给定实体类型和关系提取对应关系的三元组。 + * 给定的头实体类型为"{headEntityType}";给定的尾实体类型为"{tailEntityType}",给定的关系为"{relation}"。 + * 请仔细分析以下的文本内容,精准找出符合给定关系且头尾实体类型相符的三元组,并进行提取。如果没有识别给定的三元组关系,请返回json:{"result":[]}。 + * --- + * 为您提供一个示例供学习: + * 给定三元组类型为:头实体类型:"行为人"关系:"伪造",尾实体类型:"合同" + * 办案警官问:描述一下事情的经过。 行为人小明答:我做了一份假的购房合同。 + * 本示例中应提取给定关系为"伪造"的三元组,则最终应提取的三元组为{"result":[{"headEntity": {"type": "行为人","name":"小明"},"relation": "伪造","tailEntity": {"type": "合同","name": "假的购房合同"}}]}。 + * --- + * 需要分析提取的QA对如下: + * {question} + * {answer} + * --- + * 在提取三元组时,请务必严格遵循以下要求: + * 1. 精准理解需要分析的QA文本的含义,确保提取的信息准确无误、合理恰当。 + * 2. 只提取给定的实体类型和关系,不要提取给定关系和实体之外的三元组。 + * 3. 尽量遵循常见的语义和逻辑规则,杜绝过度解读或不合理的关系推断。 + * 4. 例子仅供参考,不要简单地返回示例中的结果。 + * 5. 提取之后,再检查一遍,提取的关系和实体是否与给定关系和实体类型对应 + * 返回格式为必须为以下的json格式: + * {"result":[{"headEntity": {"type": "{headEntityType}","name":"提取到的头实体内容1"},"relation": "{relation}","tailEntity": {"type": "{tailEntityType}","name": "提取到的尾实体内容1"}}]} + */ @Override public TripleInfo call() { try { @@ -58,44 +78,59 @@ public class TripleExtractThread implements Callable { // 分析三元组 stopWatch.start(); HashMap paramMap = new HashMap<>(); - paramMap.put("qaRecord", question + answer); + paramMap.put("headEntityType", prompt.getStartEntityType()); + paramMap.put("relation", prompt.getRelType()); + paramMap.put("tailEntityType", prompt.getEndEntityType()); + paramMap.put("question", question); + paramMap.put("answer", answer); Prompt ask = new Prompt(new UserMessage(StrUtil.format(prompt.getPrompt(), paramMap))); - log.info("开始分析:"); ChatResponse call = chatClient.call(ask); stopWatch.stop(); - log.info("耗时:{}", stopWatch.getTotalTimeSeconds()); String content = call.getResult().getOutput().getContent(); - log.info("分析的结果是:{}", content); + log.info("耗时:{},分析的结果是:{}", stopWatch.getTotalTimeSeconds(), content); // 获取从提示词中提取到的三元组信息 - JSONObject jsonObject = new JSONObject(content); - // 修改,经测试,一次提取多个三元组效果较差,改成一次只提取一个三元组 - //JSONArray threeInfo = jsonObject.getJSONArray("result"); - //for (int i = 0; i < threeInfo.length(); i++) { - //JSONObject object = threeInfo.getJSONObject(i); - String entity = jsonObject.getString("主体"); - String relation = jsonObject.getString("关系"); - String value = jsonObject.getString("客体"); - // 类型信息从notePrompt对象中获取 - // String startNodeType = object.getString("startNodeType"); - // String endNodeType = object.getString("endNodeType"); - // 去空,如果存在任何的空值,则忽略 -// if (StrUtil.hasEmpty(startNodeType, entity, endNodeType, property, value)) { -// continue; -// } - if (StrUtil.hasEmpty(entity, relation, value)) { - log.info("提取三元组信息出现空值,忽略,主体:{},关系:{},客体:{}", entity, relation, value); + TripleExtractResult extractResult = JSONUtil.toBean(content, TripleExtractResult.class); + if (ObjectUtil.isEmpty(extractResult) || extractResult.result.isEmpty()) { + log.info("提取三元组信息为空,忽略"); return null; } - // 构建三元组信息 - return new TripleInfo(entity, relation, value, caseId, recordId, recordSplitId, LocalDateTime.now(), prompt.getStartEntityType(), prompt.getEndEntityType()); - //} + for (TripleExtractNode tripleExtractNode : extractResult.getResult()) { + TripleEntity headEntity = tripleExtractNode.getHeadEntity(); + TripleEntity tailEntity = tripleExtractNode.getTailEntity(); + String relation = tripleExtractNode.getRelation(); + if (StrUtil.hasEmpty(headEntity.getName(), relation, tailEntity.getName())) { + log.info("提取三元组信息出现空值,忽略,主体:{},关系:{},客体:{}", headEntity.getName(), relation, tailEntity.getName()); + return null; + } + // 构建三元组信息 + return new TripleInfo(headEntity.getName(), tailEntity.getName(), relation, caseId, recordId, recordSplitId, LocalDateTime.now(), prompt.getStartEntityType(), prompt.getEndEntityType()); + } } catch (Exception e) { log.error("提取三元组出现错误", e); } return null; } + @Data + public static class TripleExtractResult { + private List result; + + } + + @Data + public static class TripleExtractNode { + private TripleEntity headEntity; + private String relation; + private TripleEntity tailEntity; + } + + @Data + public static class TripleEntity { + private String name; + private String type; + } + } diff --git a/src/main/resources/application-dev.yml b/src/main/resources/application-dev.yml index 353f40a..49a0fa1 100644 --- a/src/main/resources/application-dev.yml +++ b/src/main/resources/application-dev.yml @@ -4,11 +4,13 @@ spring: ai: # 文档地址 https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/api/chat/ollama-chat.html ollama: +# base-url: http://192.168.10.70:12434 base-url: http://192.168.10.70:11434 # base-url: http://124.220.94.55:8060 chat: enabled: true options: + #model: qwen2:7b model: llama3-chinese:8b # model: qwen2:72b # 控制模型在请求后加载到内存中的时间(稍微长一点的时间,避免重复加载浪费性能,加快处理速度)