You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
fu-hsi-service/src/main/java/com/supervision/police/service/impl/NotePromptServiceImpl.java

179 lines
8.3 KiB
Java

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package com.supervision.police.service.impl;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.StrUtil;
import com.baomidou.mybatisplus.core.metadata.IPage;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.supervision.common.constant.NotePromptConstants;
import com.supervision.common.utils.StringUtils;
import com.supervision.demo.dto.QARecordNodeDTO;
import com.supervision.minio.domain.MinioFile;
import com.supervision.minio.service.MinioService;
import com.supervision.police.domain.*;
import com.supervision.police.dto.LLMExtractDto;
import com.supervision.police.dto.NotePromptDTO;
import com.supervision.police.mapper.NotePromptMapper;
import com.supervision.police.service.*;
import com.supervision.utils.RecordRegexUtil;
import com.supervision.utils.WordReadUtil;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.stream.Collectors;
/**
 * Service for looking up the prompts that apply to a split record fragment and for
 * debugging a single prompt against ad-hoc text or an uploaded file.
 */
@Slf4j
@Service
@RequiredArgsConstructor
public class NotePromptServiceImpl extends ServiceImpl<NotePromptMapper, NotePrompt> implements NotePromptService {

    // NOTE(review): this class mixes constructor injection (final fields) with field
    // injection (@Autowired). It is left untouched here because switching the @Autowired
    // fields to constructor injection could surface circular dependencies - confirm first.
    private final NoteRecordSplitService noteRecordSplitService;

    @Autowired
    private ModelRecordTypeService modelRecordTypeService;

    @Autowired
    private NotePromptMapper notePromptMapper;

    private final NotePromptTypeRelService notePromptTypeRelService;

    @Autowired
    private MinioService minioService;

    @Autowired
    private LLMExtractService llmExtractService;

    @Autowired
    private ChatClient chatClient;

    /**
     * Collects the prompts associated with every record type of the given split fragment.
     *
     * <p>The fragment's record type is a {@code ;}-separated list of type names. Each name is
     * resolved to a type id, the type id to its prompt relations, and the relations to prompts.
     * Prompts are appended per type, so a prompt related to several of the fragment's types
     * appears once for each of them.
     *
     * @param recordSplitId id of the split record fragment
     * @return the matching prompts; an empty list when the fragment does not exist, has no
     *         record type, or none of its types has prompts
     */
    @Override
    public List<NotePrompt> listPromptBySplitId(String recordSplitId) {
        List<NotePrompt> notePromptList = new ArrayList<>();
        // Load the split record fragment first.
        Optional<NoteRecordSplit> optById = noteRecordSplitService.getOptById(recordSplitId);
        if (optById.isEmpty()) {
            log.warn("listPromptBySplitId根据笔录片段id{}未找到切分笔录数据...", recordSplitId);
            return notePromptList;
        }
        NoteRecordSplit recordSplit = optById.get();
        String recordType = recordSplit.getRecordType();
        if (StrUtil.isBlank(recordType)) {
            log.info("listPromptBySplitId:笔录片段:{} 不属于任何分类...", recordSplit.getId());
            // Without a record type there is nothing to resolve; returning here also avoids
            // the NullPointerException that recordType.split(...) would raise for null.
            return notePromptList;
        }
        // Build a lookup of every known type name -> type id (first id wins on duplicate names).
        List<ModelRecordType> allTypeList = modelRecordTypeService.list();
        Map<String, String> allTypeMap = allTypeList.stream()
                .collect(Collectors.toMap(ModelRecordType::getRecordType, ModelRecordType::getId, (k1, k2) -> k1));
        // Walk every type name attached to the fragment.
        for (String typeName : recordType.split(";")) {
            String typeId = allTypeMap.get(typeName);
            if (StrUtil.isBlank(typeId)) {
                log.info("listPromptBySplitId:笔录片段id:{} typeName:{}未在全局分类中找到数据...", recordSplit.getId(), typeName);
                continue;
            }
            // Find the prompt relations for this type; one prompt may be related to several types.
            List<NotePromptTypeRel> promptTypeRelList = notePromptTypeRelService.lambdaQuery()
                    .eq(NotePromptTypeRel::getTypeId, typeId)
                    .select(NotePromptTypeRel::getPromptId)
                    .list();
            if (CollUtil.isEmpty(promptTypeRelList)) {
                log.info("listPromptBySplitId:笔录片段:{}根据typeId:{},typeName:{},未找到对应的提示词信息...", recordSplit.getId(), typeId, typeName);
                continue;
            }
            List<String> promptIdList = promptTypeRelList.stream().map(NotePromptTypeRel::getPromptId).toList();
            List<NotePrompt> list = super.lambdaQuery().in(NotePrompt::getId, promptIdList).list();
            if (CollUtil.isEmpty(list)) {
                log.info("listPromptBySplitId:根据 promptIdList:{},未找到对应的提示词信息...", CollUtil.join(promptIdList, ","));
                continue;
            }
            notePromptList.addAll(list);
        }
        return notePromptList;
    }

    /**
     * Pages through prompts by delegating to
     * {@code NotePromptMapper#selectNotePromptWithMatchNum}.
     *
     * @param page       page number passed to the pagination object
     * @param size       page size passed to the pagination object
     * @param notePrompt filter object forwarded to the mapper
     * @return the page returned by the mapper
     */
    @Override
    public IPage<NotePromptDTO> listPrompt(int page, int size, NotePrompt notePrompt) {
        return notePromptMapper.selectNotePromptWithMatchNum(new Page<>(page, size), notePrompt);
    }

    /**
     * Runs a prompt once against debug text and returns the raw result.
     *
     * <p>The text comes from the uploaded file referenced by {@code fileId} when it can be read
     * (.doc/.docx/.txt), otherwise from the DTO's own text. Depending on the DTO type the text
     * is either sent to the chat model (graph reasoning, first question/answer pair only) or to
     * the attribute extraction service (structural reasoning).
     *
     * @param notePromptDTO prompt definition plus the debug text or file id
     * @return the model/extraction output; an empty string when there is no text, the type is
     *         unknown, or extraction returned nothing; {@code "未找到问答对"} when no
     *         question/answer pair could be split from the text
     */
    @Override
    public String promptDebugging(NotePromptDTO notePromptDTO) {
        String text = resolveDebugText(notePromptDTO);
        if (StringUtils.isEmpty(text)) {
            log.info("未上传文件且调试文本为空...");
            return "";
        }
        String type = notePromptDTO.getType();
        if (NotePromptConstants.TYPE_GRAPH_REASONING.equals(type)) {
            return debugGraphReasoning(notePromptDTO, text);
        }
        if (NotePromptConstants.TYPE_STRUCTURAL_REASONING.equals(type)) {
            return debugStructuralReasoning(notePromptDTO, text);
        }
        log.info("未找到对应的调试类型...【{}】", type);
        return "";
    }

    /**
     * Determines the text to debug against: the uploaded file's content when a readable file is
     * referenced, otherwise the text carried by the DTO.
     */
    private String resolveDebugText(NotePromptDTO notePromptDTO) {
        String text = notePromptDTO.getText();
        String fileId = notePromptDTO.getFileId();
        if (StringUtils.isEmpty(fileId)) {
            return text;
        }
        MinioFile minioFile = minioService.getMinioFile(fileId);
        if (minioFile == null) {
            return text;
        }
        String filename = minioFile.getFilename();
        if (filename == null) {
            log.warn("文件:{} 文件名为空,无法识别文件类型...", fileId);
            return text;
        }
        // Match extensions case-insensitively so that e.g. "A.DOCX" is handled as well.
        String lowerName = filename.toLowerCase(Locale.ROOT);
        if (lowerName.endsWith(".doc") || lowerName.endsWith(".docx")) {
            log.info("当前文件名为:{}当做word文件处理...", filename);
            // The stream is closed here because nothing visible guarantees that readWord does it;
            // closing an already closed stream is harmless.
            try (InputStream wordStream = minioService.getObjectInputStream(minioFile)) {
                text = WordReadUtil.readWord(wordStream);
            } catch (IOException e) {
                log.error("读取文件内容失败", e);
            }
        } else if (lowerName.endsWith(".txt")) {
            log.info("当前文件名为:{}当做txt文件处理...", filename);
            text = readTextFile(minioFile, text);
        } else {
            log.warn("当前文件名为:{}不是支持的文件类型,使用调试文本...", filename);
        }
        return text;
    }

    /**
     * Reads a UTF-8 text file from object storage, joining its lines with {@code \n}; returns
     * {@code fallback} when reading fails.
     */
    private String readTextFile(MinioFile minioFile, String fallback) {
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(minioService.getObjectInputStream(minioFile), StandardCharsets.UTF_8))) {
            return reader.lines().collect(Collectors.joining("\n"));
        } catch (IOException | UncheckedIOException e) {
            // BufferedReader.lines() reports read failures as UncheckedIOException, so both
            // exception types must be caught for the failure to be logged rather than propagated.
            log.error("读取文件内容失败", e);
            return fallback;
        }
    }

    /**
     * Fills the prompt template with the first question/answer pair of the text and sends it to
     * the chat model.
     */
    private String debugGraphReasoning(NotePromptDTO notePromptDTO, String text) {
        List<QARecordNodeDTO> qaList = RecordRegexUtil.recordRegex(text, "");
        log.info("拆分问答对:{}", qaList.size());
        if (qaList.isEmpty()) {
            return "未找到问答对";
        }
        // Debugging only exercises the first question/answer pair.
        QARecordNodeDTO firstPair = qaList.get(0);
        Map<String, String> paramMap = new HashMap<>();
        paramMap.put("headEntityType", notePromptDTO.getStartEntityType());
        paramMap.put("relation", notePromptDTO.getRelType());
        paramMap.put("tailEntityType", notePromptDTO.getEndEntityType());
        paramMap.put("question", firstPair.getQuestion());
        paramMap.put("answer", firstPair.getAnswer());
        paramMap.put("requirement", "");
        log.info("开始尝试提取三元组:{}-{}-{}", notePromptDTO.getStartEntityType(), notePromptDTO.getRelType(), notePromptDTO.getEndEntityType());
        String format = StrUtil.format(notePromptDTO.getPrompt(), paramMap);
        log.info("提示词内容:{}", format);
        ChatResponse call = chatClient.call(new Prompt(new UserMessage(format)));
        String content = call.getResult().getOutput().getContent();
        log.info("三元组提取结果:{}", content);
        return content;
    }

    /**
     * Sends the text, prompt and attribute definitions to the extraction service and returns
     * the string form of the first result, or an empty string when nothing came back.
     */
    private String debugStructuralReasoning(NotePromptDTO notePromptDTO, String text) {
        LLMExtractDto llmExtractDto = new LLMExtractDto();
        llmExtractDto.setText(text);
        llmExtractDto.setPrompt(notePromptDTO.getPrompt());
        llmExtractDto.setExtractAttributes(notePromptDTO.getExtractAttributes());
        List<LLMExtractDto> llmExtractDtos = llmExtractService.extractAttribute(Collections.singletonList(llmExtractDto));
        if (CollUtil.isNotEmpty(llmExtractDtos)) {
            return llmExtractDtos.get(0).toString();
        }
        return "";
    }
}