提示词配置代码优化

topo_dev
liu 9 months ago
parent 79472437ec
commit 9277af78de

@ -27,11 +27,4 @@ public class TransactionManagerConfig {
return transactionManager; return transactionManager;
} }
@Bean("testTransactionManager")
public DataSourceTransactionManager testTransactionManager(DataSource dataSource) {
DataSourceTransactionManager transactionManager = new DataSourceTransactionManager();
transactionManager.setDataSource(dataSource);
//可以设置其他事务管理器属性
return transactionManager;
}
} }

@ -6,6 +6,7 @@ import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableId; import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName; import com.baomidou.mybatisplus.annotation.TableName;
import com.fasterxml.jackson.annotation.JsonFormat; import com.fasterxml.jackson.annotation.JsonFormat;
import com.supervision.police.dto.TripleInfoDTO;
import lombok.Data; import lombok.Data;
import java.io.Serializable; import java.io.Serializable;
@ -44,6 +45,7 @@ public class ModelRecordType implements Serializable {
@TableField(exist = false) @TableField(exist = false)
private List<NoteRecordSplit> records; private List<NoteRecordSplit> records;
/** /**
* ID * ID
*/ */

@ -5,10 +5,12 @@ import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableId; import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName; import com.baomidou.mybatisplus.annotation.TableName;
import com.fasterxml.jackson.annotation.JsonFormat; import com.fasterxml.jackson.annotation.JsonFormat;
import com.supervision.police.dto.TripleInfoDTO;
import lombok.Data; import lombok.Data;
import java.io.Serializable; import java.io.Serializable;
import java.time.LocalDateTime; import java.time.LocalDateTime;
import java.util.List;
@TableName(value = "note_prompt") @TableName(value = "note_prompt")
@Data @Data
@ -32,10 +34,20 @@ public class NotePrompt implements Serializable {
private String startEntityType; private String startEntityType;
private String startEntityTemplate;
private String relType; private String relType;
private String relTemplate;
private String endEntityType; private String endEntityType;
private String endEntityTemplate;
@TableField(exist = false)
private List<TripleInfoDTO> tripleList;
/** /**
* ID * ID
*/ */

@ -0,0 +1,24 @@
package com.supervision.police.dto;
import lombok.Data;
@Data
public class TripleInfoDTO {
/**
*
*/
private String type;
/**
*
*/
private String templateName;
/**
*
*/
private String value;
}

@ -53,7 +53,7 @@ public class ExtractTripleInfoServiceImpl implements ExtractTripleInfoService {
@Async @Async
@Transactional(transactionManager = "testTransactionManager",rollbackFor = Exception.class) @Transactional(transactionManager = "dataSourceTransactionManager",rollbackFor = Exception.class)
public void extractTripleInfo(String caseId, String name, String recordId) { public void extractTripleInfo(String caseId, String name, String recordId) {
// 首先获取所有切分后的笔录 // 首先获取所有切分后的笔录
List<NoteRecordSplit> recordSplitList = noteRecordSplitService.lambdaQuery().eq(StrUtil.isNotBlank(recordId), NoteRecordSplit::getNoteRecordsId, recordId) List<NoteRecordSplit> recordSplitList = noteRecordSplitService.lambdaQuery().eq(StrUtil.isNotBlank(recordId), NoteRecordSplit::getNoteRecordsId, recordId)

@ -4,12 +4,14 @@ import cn.hutool.core.util.StrUtil;
import com.alibaba.druid.util.StringUtils; import com.alibaba.druid.util.StringUtils;
import com.baomidou.mybatisplus.core.conditions.Wrapper; import com.baomidou.mybatisplus.core.conditions.Wrapper;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.supervision.common.domain.R; import com.supervision.common.domain.R;
import com.supervision.neo4j.domain.CaseNode; import com.supervision.neo4j.domain.CaseNode;
import com.supervision.neo4j.domain.Rel; import com.supervision.neo4j.domain.Rel;
import com.supervision.neo4j.service.Neo4jService; import com.supervision.neo4j.service.Neo4jService;
import com.supervision.police.domain.*; import com.supervision.police.domain.*;
import com.supervision.police.dto.TripleInfoDTO;
import com.supervision.police.mapper.ModelRecordTypeMapper; import com.supervision.police.mapper.ModelRecordTypeMapper;
import com.supervision.police.mapper.NoteRecordSplitMapper; import com.supervision.police.mapper.NoteRecordSplitMapper;
import com.supervision.police.service.*; import com.supervision.police.service.*;
@ -21,9 +23,11 @@ import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.OllamaChatClient; import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.util.StopWatch; import org.springframework.util.StopWatch;
import java.time.LocalDateTime; import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Optional; import java.util.Optional;
@ -57,13 +61,37 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl<ModelRecordTypeMappe
//笔录内容 //笔录内容
List<NoteRecordSplit> noteRecords = noteRecordSplitMapper.selectByRecordType(modelRecordType.getRecordType()); List<NoteRecordSplit> noteRecords = noteRecordSplitMapper.selectByRecordType(modelRecordType.getRecordType());
modelRecordType.setRecords(noteRecords); modelRecordType.setRecords(noteRecords);
// grideOptions
//提示词 //提示词
List<NotePrompt> prompts = notePromptService.lambdaQuery().eq(NotePrompt::getTypeId, modelRecordType.getId()).list(); List<NotePrompt> prompts = notePromptService.lambdaQuery().eq(NotePrompt::getTypeId, modelRecordType.getId()).list();
for (NotePrompt prompt : prompts) {
prompt.setTripleList(buildTripleInfo(prompt));
}
modelRecordType.setPrompts(prompts); modelRecordType.setPrompts(prompts);
} }
return list; return list;
} }
private List<TripleInfoDTO> buildTripleInfo(NotePrompt notePrompt) {
List<TripleInfoDTO> list = new ArrayList<>();
TripleInfoDTO dto = new TripleInfoDTO();
dto.setType("头节点");
dto.setTemplateName(notePrompt.getStartEntityTemplate());
dto.setValue(notePrompt.getStartEntityType());
list.add(dto);
TripleInfoDTO dto1 = new TripleInfoDTO();
dto1.setType("关系");
dto1.setTemplateName(notePrompt.getRelTemplate());
dto1.setValue(notePrompt.getRelType());
list.add(dto1);
TripleInfoDTO dto2 = new TripleInfoDTO();
dto2.setType("尾节点");
dto2.setTemplateName(notePrompt.getEndEntityTemplate());
dto2.setValue(notePrompt.getEndEntityType());
list.add(dto2);
return list;
}
@Override @Override
public ModelRecordType queryByName(String content) { public ModelRecordType queryByName(String content) {
Wrapper<ModelRecordType> wrapper = new QueryWrapper<ModelRecordType>().eq("record_type", content); Wrapper<ModelRecordType> wrapper = new QueryWrapper<ModelRecordType>().eq("record_type", content);
@ -88,6 +116,7 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl<ModelRecordTypeMappe
} }
@Override @Override
@Transactional(transactionManager = "dataSourceTransactionManager",rollbackFor = Exception.class)
public R<?> addOrUpdPrompt(NotePrompt prompt) { public R<?> addOrUpdPrompt(NotePrompt prompt) {
int i = 0; int i = 0;
boolean save; boolean save;
@ -96,6 +125,24 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl<ModelRecordTypeMappe
} else { } else {
save = notePromptService.updateById(prompt); save = notePromptService.updateById(prompt);
} }
// 更新类型字段
List<TripleInfoDTO> tripleList = prompt.getTripleList();
for (TripleInfoDTO dto : tripleList) {
if ("头节点".equals(dto.getType())) {
notePromptService.lambdaUpdate().set(NotePrompt::getStartEntityTemplate, dto.getTemplateName())
.set(NotePrompt::getStartEntityType, dto.getValue())
.eq(NotePrompt::getId, prompt.getId()).update();
}else if ("关系".equals(dto.getType())){
notePromptService.lambdaUpdate().set(NotePrompt::getRelTemplate, dto.getTemplateName())
.set(NotePrompt::getRelType, dto.getValue())
.eq(NotePrompt::getId, prompt.getId()).update();
}else if ("尾节点".equals(dto.getType())){
notePromptService.lambdaUpdate().set(NotePrompt::getEndEntityTemplate, dto.getTemplateName())
.set(NotePrompt::getEndEntityType, dto.getValue())
.eq(NotePrompt::getId, prompt.getId()).update();
}
}
if (save) { if (save) {
return R.ok("保存成功"); return R.ok("保存成功");
} else { } else {
@ -117,7 +164,7 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl<ModelRecordTypeMappe
@Override @Override
public List<TripleInfo> getThreeInfo(String caseId, String name, String recordId) { public List<TripleInfo> getThreeInfo(String caseId, String name, String recordId) {
if (StrUtil.isBlank(recordId)){ if (StrUtil.isBlank(recordId)) {
throw new RuntimeException("笔录ID不能为空"); throw new RuntimeException("笔录ID不能为空");
} }
boolean taskStatus = taskExtractStatusCheck(caseId, recordId); boolean taskStatus = taskExtractStatusCheck(caseId, recordId);

Loading…
Cancel
Save