提示词配置代码优化bugfix

topo_dev
liu 9 months ago
parent 91b64b7c33
commit 694ca17dd7

@ -45,6 +45,11 @@ public class TripleInfo implements Serializable {
*/
private String recordSplitId;
/**
*
*/
private String submitPrompt;
/**
*
*/

@ -60,13 +60,13 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl<ModelRecordTypeMappe
@Override
public List<TypeDTO> queryTypeListChoose() {
List<ModelRecordType> list = modelRecordTypeMapper.selectByName(null);
return list.stream().map(e -> new TypeDTO(e.getId(),e.getRecordType())).collect(Collectors.toList());
return list.stream().map(e -> new TypeDTO(e.getId(), e.getRecordType())).collect(Collectors.toList());
}
@Override
public List<ModelRecordType> queryType(String name, Integer page, Integer size) {
List<ModelRecordType> list = modelRecordTypeMapper.selectByName(name);
Map<String, TypeDTO> modelRecordTypeMap = list.stream().map(e -> new TypeDTO(e.getId(),e.getRecordType())).collect(Collectors.toMap(TypeDTO::getId, Function.identity()));
Map<String, TypeDTO> modelRecordTypeMap = list.stream().map(e -> new TypeDTO(e.getId(), e.getRecordType())).collect(Collectors.toMap(TypeDTO::getId, Function.identity()));
// 获取类型和三元组提示词的关系表
List<NotePromptTypeRel> relList = notePromptTypeRelService.lambdaQuery().list();
Map<String, List<NotePromptTypeRel>> typeIdMap = relList.stream().collect(Collectors.groupingBy(NotePromptTypeRel::getTypeId));
@ -152,6 +152,8 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl<ModelRecordTypeMappe
int i = 0;
boolean save;
if (StringUtils.isEmpty(prompt.getId())) {
// 新增的时候,校验是否已经存在相同的三元组关系,如果已经存在了相同的三元组关系,不允许添加
checkHasSameTriple(prompt.getStartEntityType(), prompt.getRelType(),prompt.getEndEntityType(),null);
save = notePromptService.save(prompt);
// 新增prompt绑定的分类信息
for (String typeId : typeList) {
@ -161,6 +163,7 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl<ModelRecordTypeMappe
notePromptTypeRelService.save(rel);
}
} else {
checkHasSameTriple(prompt.getStartEntityType(), prompt.getRelType(),prompt.getEndEntityType(),prompt.getId());
save = notePromptService.updateById(prompt);
// 更新prompt绑定的分类信息
// 首先查询已经有的,如果都存在,就不变,如果数据库有,前端没有,就删除,如果前端有,数据库没有,就新增
@ -175,13 +178,13 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl<ModelRecordTypeMappe
notePromptTypeRelService.lambdaUpdate().in(NotePromptTypeRel::getTypeId, deleteIdList).eq(NotePromptTypeRel::getPromptId, prompt.getId()).remove();
}
// 新增(前端有数据库没有的)
frontRelIdList.stream().filter(id -> !existTypeList.contains(id)).forEach(e ->{
frontRelIdList.stream().filter(id -> !existTypeList.contains(id)).forEach(e -> {
NotePromptTypeRel rel = new NotePromptTypeRel();
rel.setPromptId(prompt.getId());
rel.setTypeId(e);
notePromptTypeRelService.save(rel);
});
}else {
} else {
// 如果数据库里面没查到,直接新增,一般不会走这一步
for (String typeId : typeList) {
NotePromptTypeRel rel = new NotePromptTypeRel();
@ -218,6 +221,22 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl<ModelRecordTypeMappe
}
}
private void checkHasSameTriple(String startEntityType, String relType, String endEntityType,String promptId) {
List<NotePrompt> list = notePromptService.lambdaQuery().eq(NotePrompt::getStartEntityType, startEntityType)
.eq(NotePrompt::getRelType, relType).eq(NotePrompt::getEndEntityType, endEntityType).list();
if (CollUtil.isNotEmpty(list) ) {
if (StrUtil.isBlank(promptId)){
throw new RuntimeException("该三元组关系已经存在,请勿重复添加");
}else {
// 校验list查出来的是不是和promptId相等,如果不想等,也报错
if (!list.get(0).getId().equals(promptId)){
throw new RuntimeException("该三元组关系已经存在,请勿重复添加");
}
}
}
}
@Override
@Transactional(transactionManager = "dataSourceTransactionManager", rollbackFor = Exception.class)
public R<?> delPrompt(NotePrompt prompt) {

@ -87,7 +87,6 @@ public class RecordSplitTypeThread implements Callable<Boolean> {
---
json,:{"result":[{"type":"分类1","explain":"分类原因"},{"type":"分类2","explain":"分类原因"}]}
""";
private static final String TYPE_CONTEXT_TEMPLATE = "{分类type:{type},区别点(分类释义):{typeExt}}";
@Override

@ -52,25 +52,27 @@ public class TripleExtractThread implements Callable<TripleInfo> {
/**
* :
* :
*
* 1.
* 2.
* 3.
* 4. !!
* 5. ,,
*
* "{headEntityType}";"{tailEntityType}","{relation}"
* json:{"result":[]}
* ---
* :
* ::"行为人":"伪造",:"合同"
* : :
* "伪造",{"result":[{"headEntity": {"type": "行为人","name":"小明"},"relation": "伪造","tailEntity": {"type": "合同","name": "假的购房合同"}}]}
* example:
*
* : :
* {"result":[{"headEntity": {"type": "{headEntityType}","name":"小明"},"relation": "{relation}","tailEntity": {"type": "{tailEntityType}","name": "法国的高档化妆品"}}]}
* ---
* QA
*
* {question}
* {answer}
* ---
*
* 1. QA
* 2.
* 3.
* 4. ,
* 5. ,,
*
* json:
* {"result":[{"headEntity": {"type": "{headEntityType}","name":"提取到的头实体内容1"},"relation": "{relation}","tailEntity": {"type": "{tailEntityType}","name": "提取到的尾实体内容1"}}]}
*/
@ -87,19 +89,20 @@ public class TripleExtractThread implements Callable<TripleInfo> {
paramMap.put("question", question);
paramMap.put("answer", answer);
String format = StrUtil.format(prompt.getPrompt(), paramMap);
// 对format进行切分,把前面两个---作为systemPrompt
String[] split = format.split("---");
SystemMessage systemMessage = new SystemMessage(split[0]);
UserMessage userMessage = new UserMessage(split[1]);
Prompt ask = new Prompt(List.of(systemMessage, userMessage));
ChatResponse call = chatClient.call(ask);
// // 对format进行切分,把前面两个---作为systemPrompt
// String[] split = format.split("---");
// SystemMessage systemMessage = new SystemMessage(split[0]);
// UserMessage userMessage = new UserMessage(split[1]);
// Prompt ask = new Prompt(List.of(systemMessage, userMessage));
// ChatResponse call = chatClient.call(ask);
ChatResponse call = chatClient.call(new Prompt(new UserMessage(format)));
stopWatch.stop();
String content = call.getResult().getOutput().getContent();
log.info("耗时:{},分析的结果是:{}", stopWatch.getTotalTimeSeconds(), content);
// 获取从提示词中提取到的三元组信息
TripleExtractResult extractResult = JSONUtil.toBean(content, TripleExtractResult.class);
if (ObjectUtil.isEmpty(extractResult) || CollUtil.isEmpty(extractResult.getResult())) {
log.info("提取三元组信息为空,忽略.提取的内容为:{}",content);
log.info("提取三元组信息为空,忽略.提取的内容为:{}", content);
return null;
}
for (TripleExtractNode tripleExtractNode : extractResult.getResult()) {
@ -118,6 +121,7 @@ public class TripleExtractThread implements Callable<TripleInfo> {
tripleInfo.setCaseId(caseId);
tripleInfo.setRecordId(recordId);
tripleInfo.setRecordSplitId(recordSplitId);
tripleInfo.setSubmitPrompt(format);
tripleInfo.setStartNodeType(prompt.getStartEntityType());
tripleInfo.setEndNodeType(prompt.getEndEntityType());
return tripleInfo;

Loading…
Cancel
Save