提示词配置代码优化bugfix

topo_dev
liu
parent 91b64b7c33
commit 694ca17dd7

@ -45,6 +45,11 @@ public class TripleInfo implements Serializable {
*/ */
private String recordSplitId; private String recordSplitId;
/**
*
*/
private String submitPrompt;
/** /**
* *
*/ */

@ -60,13 +60,13 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl<ModelRecordTypeMappe
@Override @Override
public List<TypeDTO> queryTypeListChoose() { public List<TypeDTO> queryTypeListChoose() {
List<ModelRecordType> list = modelRecordTypeMapper.selectByName(null); List<ModelRecordType> list = modelRecordTypeMapper.selectByName(null);
return list.stream().map(e -> new TypeDTO(e.getId(),e.getRecordType())).collect(Collectors.toList()); return list.stream().map(e -> new TypeDTO(e.getId(), e.getRecordType())).collect(Collectors.toList());
} }
@Override @Override
public List<ModelRecordType> queryType(String name, Integer page, Integer size) { public List<ModelRecordType> queryType(String name, Integer page, Integer size) {
List<ModelRecordType> list = modelRecordTypeMapper.selectByName(name); List<ModelRecordType> list = modelRecordTypeMapper.selectByName(name);
Map<String, TypeDTO> modelRecordTypeMap = list.stream().map(e -> new TypeDTO(e.getId(),e.getRecordType())).collect(Collectors.toMap(TypeDTO::getId, Function.identity())); Map<String, TypeDTO> modelRecordTypeMap = list.stream().map(e -> new TypeDTO(e.getId(), e.getRecordType())).collect(Collectors.toMap(TypeDTO::getId, Function.identity()));
// 获取类型和三元组提示词的关系表 // 获取类型和三元组提示词的关系表
List<NotePromptTypeRel> relList = notePromptTypeRelService.lambdaQuery().list(); List<NotePromptTypeRel> relList = notePromptTypeRelService.lambdaQuery().list();
Map<String, List<NotePromptTypeRel>> typeIdMap = relList.stream().collect(Collectors.groupingBy(NotePromptTypeRel::getTypeId)); Map<String, List<NotePromptTypeRel>> typeIdMap = relList.stream().collect(Collectors.groupingBy(NotePromptTypeRel::getTypeId));
@ -152,6 +152,8 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl<ModelRecordTypeMappe
int i = 0; int i = 0;
boolean save; boolean save;
if (StringUtils.isEmpty(prompt.getId())) { if (StringUtils.isEmpty(prompt.getId())) {
// 新增的时候,校验是否已经存在相同的三元组关系,如果已经存在了相同的三元组关系,不允许添加
checkHasSameTriple(prompt.getStartEntityType(), prompt.getRelType(),prompt.getEndEntityType(),null);
save = notePromptService.save(prompt); save = notePromptService.save(prompt);
// 新增prompt绑定的分类信息 // 新增prompt绑定的分类信息
for (String typeId : typeList) { for (String typeId : typeList) {
@ -161,6 +163,7 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl<ModelRecordTypeMappe
notePromptTypeRelService.save(rel); notePromptTypeRelService.save(rel);
} }
} else { } else {
checkHasSameTriple(prompt.getStartEntityType(), prompt.getRelType(),prompt.getEndEntityType(),prompt.getId());
save = notePromptService.updateById(prompt); save = notePromptService.updateById(prompt);
// 更新prompt绑定的分类信息 // 更新prompt绑定的分类信息
// 首先查询已经有的,如果都存在,就不变,如果数据库有,前端没有,就删除,如果前端有,数据库没有,就新增 // 首先查询已经有的,如果都存在,就不变,如果数据库有,前端没有,就删除,如果前端有,数据库没有,就新增
@ -175,13 +178,13 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl<ModelRecordTypeMappe
notePromptTypeRelService.lambdaUpdate().in(NotePromptTypeRel::getTypeId, deleteIdList).eq(NotePromptTypeRel::getPromptId, prompt.getId()).remove(); notePromptTypeRelService.lambdaUpdate().in(NotePromptTypeRel::getTypeId, deleteIdList).eq(NotePromptTypeRel::getPromptId, prompt.getId()).remove();
} }
// 新增(前端有数据库没有的) // 新增(前端有数据库没有的)
frontRelIdList.stream().filter(id -> !existTypeList.contains(id)).forEach(e ->{ frontRelIdList.stream().filter(id -> !existTypeList.contains(id)).forEach(e -> {
NotePromptTypeRel rel = new NotePromptTypeRel(); NotePromptTypeRel rel = new NotePromptTypeRel();
rel.setPromptId(prompt.getId()); rel.setPromptId(prompt.getId());
rel.setTypeId(e); rel.setTypeId(e);
notePromptTypeRelService.save(rel); notePromptTypeRelService.save(rel);
}); });
}else { } else {
// 如果数据库里面没查到,直接新增,一般不会走这一步 // 如果数据库里面没查到,直接新增,一般不会走这一步
for (String typeId : typeList) { for (String typeId : typeList) {
NotePromptTypeRel rel = new NotePromptTypeRel(); NotePromptTypeRel rel = new NotePromptTypeRel();
@ -218,6 +221,22 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl<ModelRecordTypeMappe
} }
} }
private void checkHasSameTriple(String startEntityType, String relType, String endEntityType,String promptId) {
List<NotePrompt> list = notePromptService.lambdaQuery().eq(NotePrompt::getStartEntityType, startEntityType)
.eq(NotePrompt::getRelType, relType).eq(NotePrompt::getEndEntityType, endEntityType).list();
if (CollUtil.isNotEmpty(list) ) {
if (StrUtil.isBlank(promptId)){
throw new RuntimeException("该三元组关系已经存在,请勿重复添加");
}else {
// 校验list查出来的是不是和promptId相等,如果不想等,也报错
if (!list.get(0).getId().equals(promptId)){
throw new RuntimeException("该三元组关系已经存在,请勿重复添加");
}
}
}
}
@Override @Override
@Transactional(transactionManager = "dataSourceTransactionManager", rollbackFor = Exception.class) @Transactional(transactionManager = "dataSourceTransactionManager", rollbackFor = Exception.class)
public R<?> delPrompt(NotePrompt prompt) { public R<?> delPrompt(NotePrompt prompt) {

@ -87,7 +87,6 @@ public class RecordSplitTypeThread implements Callable<Boolean> {
--- ---
json,:{"result":[{"type":"分类1","explain":"分类原因"},{"type":"分类2","explain":"分类原因"}]} json,:{"result":[{"type":"分类1","explain":"分类原因"},{"type":"分类2","explain":"分类原因"}]}
"""; """;
private static final String TYPE_CONTEXT_TEMPLATE = "{分类type:{type},区别点(分类释义):{typeExt}}"; private static final String TYPE_CONTEXT_TEMPLATE = "{分类type:{type},区别点(分类释义):{typeExt}}";
@Override @Override

@ -52,25 +52,27 @@ public class TripleExtractThread implements Callable<TripleInfo> {
/** /**
* : * :
*
* 1.
* 2.
* 3.
* 4. !!
* 5. ,,
*
* "{headEntityType}";"{tailEntityType}","{relation}" * "{headEntityType}";"{tailEntityType}","{relation}"
* json:{"result":[]} * json:{"result":[]}
* --- * ---
* : * example:
* ::"行为人":"伪造",:"合同" *
* : : * : :
* "伪造",{"result":[{"headEntity": {"type": "行为人","name":"小明"},"relation": "伪造","tailEntity": {"type": "合同","name": "假的购房合同"}}]} * {"result":[{"headEntity": {"type": "{headEntityType}","name":"小明"},"relation": "{relation}","tailEntity": {"type": "{tailEntityType}","name": "法国的高档化妆品"}}]}
* --- * ---
* QA *
* {question} * {question}
* {answer} * {answer}
* --- * ---
* *
* 1. QA
* 2.
* 3.
* 4. ,
* 5. ,,
* json: * json:
* {"result":[{"headEntity": {"type": "{headEntityType}","name":"提取到的头实体内容1"},"relation": "{relation}","tailEntity": {"type": "{tailEntityType}","name": "提取到的尾实体内容1"}}]} * {"result":[{"headEntity": {"type": "{headEntityType}","name":"提取到的头实体内容1"},"relation": "{relation}","tailEntity": {"type": "{tailEntityType}","name": "提取到的尾实体内容1"}}]}
*/ */
@ -87,19 +89,20 @@ public class TripleExtractThread implements Callable<TripleInfo> {
paramMap.put("question", question); paramMap.put("question", question);
paramMap.put("answer", answer); paramMap.put("answer", answer);
String format = StrUtil.format(prompt.getPrompt(), paramMap); String format = StrUtil.format(prompt.getPrompt(), paramMap);
// 对format进行切分,把前面两个---作为systemPrompt // // 对format进行切分,把前面两个---作为systemPrompt
String[] split = format.split("---"); // String[] split = format.split("---");
SystemMessage systemMessage = new SystemMessage(split[0]); // SystemMessage systemMessage = new SystemMessage(split[0]);
UserMessage userMessage = new UserMessage(split[1]); // UserMessage userMessage = new UserMessage(split[1]);
Prompt ask = new Prompt(List.of(systemMessage, userMessage)); // Prompt ask = new Prompt(List.of(systemMessage, userMessage));
ChatResponse call = chatClient.call(ask); // ChatResponse call = chatClient.call(ask);
ChatResponse call = chatClient.call(new Prompt(new UserMessage(format)));
stopWatch.stop(); stopWatch.stop();
String content = call.getResult().getOutput().getContent(); String content = call.getResult().getOutput().getContent();
log.info("耗时:{},分析的结果是:{}", stopWatch.getTotalTimeSeconds(), content); log.info("耗时:{},分析的结果是:{}", stopWatch.getTotalTimeSeconds(), content);
// 获取从提示词中提取到的三元组信息 // 获取从提示词中提取到的三元组信息
TripleExtractResult extractResult = JSONUtil.toBean(content, TripleExtractResult.class); TripleExtractResult extractResult = JSONUtil.toBean(content, TripleExtractResult.class);
if (ObjectUtil.isEmpty(extractResult) || CollUtil.isEmpty(extractResult.getResult())) { if (ObjectUtil.isEmpty(extractResult) || CollUtil.isEmpty(extractResult.getResult())) {
log.info("提取三元组信息为空,忽略.提取的内容为:{}",content); log.info("提取三元组信息为空,忽略.提取的内容为:{}", content);
return null; return null;
} }
for (TripleExtractNode tripleExtractNode : extractResult.getResult()) { for (TripleExtractNode tripleExtractNode : extractResult.getResult()) {
@ -118,6 +121,7 @@ public class TripleExtractThread implements Callable<TripleInfo> {
tripleInfo.setCaseId(caseId); tripleInfo.setCaseId(caseId);
tripleInfo.setRecordId(recordId); tripleInfo.setRecordId(recordId);
tripleInfo.setRecordSplitId(recordSplitId); tripleInfo.setRecordSplitId(recordSplitId);
tripleInfo.setSubmitPrompt(format);
tripleInfo.setStartNodeType(prompt.getStartEntityType()); tripleInfo.setStartNodeType(prompt.getStartEntityType());
tripleInfo.setEndNodeType(prompt.getEndEntityType()); tripleInfo.setEndNodeType(prompt.getEndEntityType());
return tripleInfo; return tripleInfo;

Loading…
Cancel
Save