diff --git a/src/main/java/com/supervision/config/TransactionManagerConfig.java b/src/main/java/com/supervision/config/TransactionManagerConfig.java index 8cdfc54..1696726 100644 --- a/src/main/java/com/supervision/config/TransactionManagerConfig.java +++ b/src/main/java/com/supervision/config/TransactionManagerConfig.java @@ -27,11 +27,4 @@ public class TransactionManagerConfig { return transactionManager; } - @Bean("testTransactionManager") - public DataSourceTransactionManager testTransactionManager(DataSource dataSource) { - DataSourceTransactionManager transactionManager = new DataSourceTransactionManager(); - transactionManager.setDataSource(dataSource); - //可以设置其他事务管理器属性 - return transactionManager; - } } \ No newline at end of file diff --git a/src/main/java/com/supervision/police/domain/ModelRecordType.java b/src/main/java/com/supervision/police/domain/ModelRecordType.java index 795b943..048ae8b 100644 --- a/src/main/java/com/supervision/police/domain/ModelRecordType.java +++ b/src/main/java/com/supervision/police/domain/ModelRecordType.java @@ -6,6 +6,7 @@ import com.baomidou.mybatisplus.annotation.TableField; import com.baomidou.mybatisplus.annotation.TableId; import com.baomidou.mybatisplus.annotation.TableName; import com.fasterxml.jackson.annotation.JsonFormat; +import com.supervision.police.dto.TripleInfoDTO; import lombok.Data; import java.io.Serializable; @@ -44,6 +45,7 @@ public class ModelRecordType implements Serializable { @TableField(exist = false) private List records; + /** * 创建人ID */ diff --git a/src/main/java/com/supervision/police/domain/NotePrompt.java b/src/main/java/com/supervision/police/domain/NotePrompt.java index faa92b3..96881f8 100644 --- a/src/main/java/com/supervision/police/domain/NotePrompt.java +++ b/src/main/java/com/supervision/police/domain/NotePrompt.java @@ -5,10 +5,12 @@ import com.baomidou.mybatisplus.annotation.TableField; import com.baomidou.mybatisplus.annotation.TableId; import com.baomidou.mybatisplus.annotation.TableName; import com.fasterxml.jackson.annotation.JsonFormat; +import com.supervision.police.dto.TripleInfoDTO; import lombok.Data; import java.io.Serializable; import java.time.LocalDateTime; +import java.util.List; @TableName(value = "note_prompt") @Data @@ -32,10 +34,20 @@ public class NotePrompt implements Serializable { private String startEntityType; + private String startEntityTemplate; + private String relType; + private String relTemplate; + private String endEntityType; + private String endEntityTemplate; + + + @TableField(exist = false) + private List tripleList; + /** * 创建人ID */ diff --git a/src/main/java/com/supervision/police/dto/TripleInfoDTO.java b/src/main/java/com/supervision/police/dto/TripleInfoDTO.java new file mode 100644 index 0000000..d28320e --- /dev/null +++ b/src/main/java/com/supervision/police/dto/TripleInfoDTO.java @@ -0,0 +1,24 @@ +package com.supervision.police.dto; + +import lombok.Data; + +@Data +public class TripleInfoDTO { + + + /** + * 类型 + */ + private String type; + + /** + * 占位符名称 + */ + private String templateName; + + + /** + * 占位符值 + */ + private String value; +} diff --git a/src/main/java/com/supervision/police/service/impl/ExtractTripleInfoServiceImpl.java b/src/main/java/com/supervision/police/service/impl/ExtractTripleInfoServiceImpl.java index 730b331..198f778 100644 --- a/src/main/java/com/supervision/police/service/impl/ExtractTripleInfoServiceImpl.java +++ b/src/main/java/com/supervision/police/service/impl/ExtractTripleInfoServiceImpl.java @@ -53,7 +53,7 @@ public class ExtractTripleInfoServiceImpl implements ExtractTripleInfoService { @Async - @Transactional(transactionManager = "testTransactionManager",rollbackFor = Exception.class) + @Transactional(transactionManager = "dataSourceTransactionManager",rollbackFor = Exception.class) public void extractTripleInfo(String caseId, String name, String recordId) { // 首先获取所有切分后的笔录 List recordSplitList = noteRecordSplitService.lambdaQuery().eq(StrUtil.isNotBlank(recordId), NoteRecordSplit::getNoteRecordsId, recordId) diff --git a/src/main/java/com/supervision/police/service/impl/ModelRecordTypeServiceImpl.java b/src/main/java/com/supervision/police/service/impl/ModelRecordTypeServiceImpl.java index 94331ed..80fb67d 100644 --- a/src/main/java/com/supervision/police/service/impl/ModelRecordTypeServiceImpl.java +++ b/src/main/java/com/supervision/police/service/impl/ModelRecordTypeServiceImpl.java @@ -4,12 +4,14 @@ import cn.hutool.core.util.StrUtil; import com.alibaba.druid.util.StringUtils; import com.baomidou.mybatisplus.core.conditions.Wrapper; import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; +import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper; import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; import com.supervision.common.domain.R; import com.supervision.neo4j.domain.CaseNode; import com.supervision.neo4j.domain.Rel; import com.supervision.neo4j.service.Neo4jService; import com.supervision.police.domain.*; +import com.supervision.police.dto.TripleInfoDTO; import com.supervision.police.mapper.ModelRecordTypeMapper; import com.supervision.police.mapper.NoteRecordSplitMapper; import com.supervision.police.service.*; @@ -21,9 +23,11 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.ollama.OllamaChatClient; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; import org.springframework.util.StopWatch; import java.time.LocalDateTime; +import java.util.ArrayList; import java.util.List; import java.util.Optional; @@ -57,13 +61,37 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl noteRecords = noteRecordSplitMapper.selectByRecordType(modelRecordType.getRecordType()); modelRecordType.setRecords(noteRecords); + // grideOptions //提示词 List prompts = notePromptService.lambdaQuery().eq(NotePrompt::getTypeId, modelRecordType.getId()).list(); + for (NotePrompt prompt : prompts) { + prompt.setTripleList(buildTripleInfo(prompt)); + } modelRecordType.setPrompts(prompts); } return list; } + private List buildTripleInfo(NotePrompt notePrompt) { + List list = new ArrayList<>(); + TripleInfoDTO dto = new TripleInfoDTO(); + dto.setType("头节点"); + dto.setTemplateName(notePrompt.getStartEntityTemplate()); + dto.setValue(notePrompt.getStartEntityType()); + list.add(dto); + TripleInfoDTO dto1 = new TripleInfoDTO(); + dto1.setType("关系"); + dto1.setTemplateName(notePrompt.getRelTemplate()); + dto1.setValue(notePrompt.getRelType()); + list.add(dto1); + TripleInfoDTO dto2 = new TripleInfoDTO(); + dto2.setType("尾节点"); + dto2.setTemplateName(notePrompt.getEndEntityTemplate()); + dto2.setValue(notePrompt.getEndEntityType()); + list.add(dto2); + return list; + } + @Override public ModelRecordType queryByName(String content) { Wrapper wrapper = new QueryWrapper().eq("record_type", content); @@ -88,6 +116,7 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl addOrUpdPrompt(NotePrompt prompt) { int i = 0; boolean save; @@ -96,6 +125,24 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl tripleList = prompt.getTripleList(); + for (TripleInfoDTO dto : tripleList) { + if ("头节点".equals(dto.getType())) { + notePromptService.lambdaUpdate().set(NotePrompt::getStartEntityTemplate, dto.getTemplateName()) + .set(NotePrompt::getStartEntityType, dto.getValue()) + .eq(NotePrompt::getId, prompt.getId()).update(); + }else if ("关系".equals(dto.getType())){ + notePromptService.lambdaUpdate().set(NotePrompt::getRelTemplate, dto.getTemplateName()) + .set(NotePrompt::getRelType, dto.getValue()) + .eq(NotePrompt::getId, prompt.getId()).update(); + }else if ("尾节点".equals(dto.getType())){ + notePromptService.lambdaUpdate().set(NotePrompt::getEndEntityTemplate, dto.getTemplateName()) + .set(NotePrompt::getEndEntityType, dto.getValue()) + .eq(NotePrompt::getId, prompt.getId()).update(); + } + } + if (save) { return R.ok("保存成功"); } else { @@ -117,7 +164,7 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl getThreeInfo(String caseId, String name, String recordId) { - if (StrUtil.isBlank(recordId)){ + if (StrUtil.isBlank(recordId)) { throw new RuntimeException("笔录ID不能为空"); } boolean taskStatus = taskExtractStatusCheck(caseId, recordId);