// Source: fu-hsi-service/src/main/java/com/supervision/police/service/impl/NotePromptServiceImpl.java
// (212 lines, 10 KiB, Java)

package com.supervision.police.service.impl;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.core.metadata.IPage;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.supervision.common.constant.NotePromptConstants;
import com.supervision.common.utils.StringUtils;
import com.supervision.demo.dto.QARecordNodeDTO;
import com.supervision.minio.domain.MinioFile;
import com.supervision.minio.service.MinioService;
import com.supervision.police.domain.*;
import com.supervision.police.dto.LLMExtractDto;
import com.supervision.police.dto.NotePromptDTO;
import com.supervision.police.mapper.NotePromptMapper;
import com.supervision.police.service.*;
import com.supervision.thread.TripleExtractTask;
import com.supervision.utils.RecordRegexUtil;
import com.supervision.utils.WordReadUtil;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.stream.Collectors;
import static com.supervision.police.service.impl.ModelRecordTypeServiceImpl.buildTripleInfo;
@Slf4j
@Service
@RequiredArgsConstructor
public class NotePromptServiceImpl extends ServiceImpl<NotePromptMapper, NotePrompt> implements NotePromptService {
private final NoteRecordSplitService noteRecordSplitService;
@Autowired
private ModelRecordTypeService modelRecordTypeService;
@Autowired
private NotePromptMapper notePromptMapper;
private final NotePromptTypeRelService notePromptTypeRelService;
@Autowired
private MinioService minioService;
@Autowired
private LLMExtractService llmExtractService;
@Autowired
private ChatClient chatClient;
@Autowired
private EvidenceCategoryService evidenceCategoryService;
@Override
public List<NotePrompt> listPromptBySplitId(String recordSplitId) {
List<NotePrompt> notePromptList = new ArrayList<>();
// 首先获取所有切分后的笔录
Optional<NoteRecordSplit> optById = noteRecordSplitService.getOptById(recordSplitId);
if (optById.isEmpty()) {
log.warn("listPromptBySplitId根据笔录片段id{}未找到切分笔录数据...", recordSplitId);
return notePromptList;
}
NoteRecordSplit recordSplit = optById.get();
String recordType = recordSplit.getRecordType();
if (StrUtil.isBlank(recordType)) {
log.info("listPromptBySplitId:笔录片段:{} 不属于任何分类...", recordSplit.getId());
}
// 获取所有的分类
List<ModelRecordType> allTypeList = modelRecordTypeService.list();
Map<String, String> allTypeMap = allTypeList.stream().collect(Collectors.toMap(ModelRecordType::getRecordType, ModelRecordType::getId, (k1, k2) -> k1));
// 对切分后的笔录进行遍历
String[] split = recordType.split(";");
for (String typeName : split) {
String typeId = allTypeMap.get(typeName);
if (StrUtil.isBlank(typeId)) {
log.info("listPromptBySplitId:笔录片段id:{} typeName:{}未在全局分类中找到数据...", recordSplit.getId(), typeName);
continue;
}
// 根据笔录类型找到所有的提取三元组的提示词
// 一个提示词可能关联多个类型,要进行拆分操作
List<NotePromptTypeRel> promptTypeRelList = notePromptTypeRelService.lambdaQuery().eq(NotePromptTypeRel::getTypeId, typeId).select(NotePromptTypeRel::getPromptId).list();
if (CollUtil.isEmpty(promptTypeRelList)) {
log.info("listPromptBySplitId:笔录片段:{}根据typeId:{},typeName:{},未找到对应的提示词信息...", recordSplit.getId(), typeId, typeName);
continue;
}
List<String> promptIdList = promptTypeRelList.stream().map(NotePromptTypeRel::getPromptId).toList();
List<NotePrompt> list = super.lambdaQuery().in(NotePrompt::getId, promptIdList).list();
if (CollUtil.isEmpty(list)) {
log.info("listPromptBySplitId:根据 promptIdList:{},未找到对应的提示词信息...", CollUtil.join(promptIdList, ","));
continue;
}
notePromptList.addAll(list);
}
return notePromptList;
}
11 months ago
@Override
public IPage<NotePromptDTO> listPrompt(NotePromptDTO notePromptDTO) {
return notePromptMapper.selectNotePromptWithMatchNum(new Page<>(notePromptDTO.getPage(), notePromptDTO.getSize()), notePromptDTO);
11 months ago
}
@Override
public List promptDebugging(NotePromptDTO notePromptDTO) {
String text = notePromptDTO.getText();
String fileId = notePromptDTO.getFileId();
if (StringUtils.isNotEmpty(fileId)) {
MinioFile minioFile = minioService.getMinioFile(fileId);
if (minioFile != null) {
if (minioFile.getFilename().endsWith(".doc") || minioFile.getFilename().endsWith(".docx")) {
log.info("当前文件名为:{}当做word文件处理...", minioFile.getFilename());
text = WordReadUtil.readWord(minioService.getObjectInputStream(minioFile));
} else if (minioFile.getFilename().endsWith(".txt")) {
log.info("当前文件名为:{}当做txt文件处理...", minioFile.getFilename());
InputStream inputStream = minioService.getObjectInputStream(minioFile);
try (inputStream) {
BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8));
text = reader.lines().collect(Collectors.joining("\n"));
} catch (IOException e) {
log.error("读取文件内容失败", e);
}
}
}
}
if (StringUtils.isEmpty(text)) {
log.info("未上传文件且调试文本为空...");
return null;
}
String type = notePromptDTO.getType();
if (NotePromptConstants.TYPE_GRAPH_REASONING.equals(type)) {
List<QARecordNodeDTO> qaList = RecordRegexUtil.recordRegex(text, "");
log.info("拆分问答对:{}", qaList.size());
if (qaList.isEmpty()) {
return null;
}
QARecordNodeDTO qaRecordNodeDTO = qaList.get(0);
HashMap<String, String> paramMap = new HashMap<>();
paramMap.put("headEntityType", notePromptDTO.getStartEntityType());
paramMap.put("relation", notePromptDTO.getRelType());
paramMap.put("tailEntityType", notePromptDTO.getEndEntityType());
paramMap.put("question", qaRecordNodeDTO.getQuestion());
paramMap.put("answer", qaRecordNodeDTO.getAnswer());
paramMap.put("requirement", "");
log.info("开始尝试提取三元组:{}-{}-{}", notePromptDTO.getStartEntityType(), notePromptDTO.getRelType(), notePromptDTO.getEndEntityType());
String format = StrUtil.format(notePromptDTO.getPrompt(), paramMap);
log.info("提示词内容:{}", format);
ChatResponse call = chatClient.call(new Prompt(new UserMessage(format)));
String content = call.getResult().getOutput().getContent();
log.info("三元组提取结果:{}", content);
return JSONUtil.toBean(content, TripleExtractTask.TripleExtractResult.class).getResult();
} else if (NotePromptConstants.TYPE_STRUCTURAL_REASONING.equals(type)) {
LLMExtractDto llmExtractDto = new LLMExtractDto();
llmExtractDto.setText(text);
llmExtractDto.setPrompt(notePromptDTO.getPrompt());
llmExtractDto.setExtractAttributes(notePromptDTO.getExtractAttributes());
List<LLMExtractDto> llmExtractDtos = llmExtractService.extractAttribute(Collections.singletonList(llmExtractDto));
if (CollUtil.isNotEmpty(llmExtractDtos)) {
return llmExtractDtos.get(0).getExtractAttributes();
}
} else {
log.info("未找到对应的调试类型...【{}】", type);
}
return null;
}
@Override
public NotePromptDTO getById(String id) {
NotePrompt notePrompt = super.getById(id);
NotePromptDTO notePromptDTO = new NotePromptDTO();
BeanUtils.copyProperties(notePrompt, notePromptDTO);
notePromptDTO.setTripleList(buildTripleInfo(notePrompt));
//根据notePrompt的ID调用notePromptTypeRelService查询prompt_id相等的list
List<NotePromptTypeRel> notePromptTypeRels = notePromptTypeRelService.list(new LambdaQueryWrapper<NotePromptTypeRel>().eq(NotePromptTypeRel::getPromptId, notePrompt.getId()));
if (notePromptTypeRels != null && !notePromptTypeRels.isEmpty()) {
notePromptDTO.setTypeList(notePromptTypeRels.stream().map(NotePromptTypeRel::getTypeId).collect(Collectors.toList()));
}
String evidenceCategoryId = notePromptDTO.getEvidenceCategoryId();
if (StringUtils.isNotEmpty(evidenceCategoryId)) {
EvidenceCategory category = evidenceCategoryService.getById(evidenceCategoryId);
if (category != null) {
String parentId = category.getParentId();
notePromptDTO.setEvidenceCategoryIdList(Arrays.asList(parentId, evidenceCategoryId));
}
}
return notePromptDTO;
}
}