Merge remote-tracking branch 'origin/dev_1.0.0' into dev_1.0.0

topo_dev
xueqingkun 9 months ago
commit 068cb6cf06

@ -0,0 +1,40 @@
package com.supervision.prompt;
import cn.hutool.core.util.StrUtil;
import com.supervision.common.domain.R;
import com.supervision.config.BusinessException;
import com.supervision.prompt.service.PromptTestService;
import com.supervision.prompt.vo.ParamTestDTO;
import com.supervision.prompt.vo.PromptRequestVO;
import com.supervision.prompt.vo.ResultResVO;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.annotations.Param;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.io.IOException;
import java.util.List;
@Slf4j
@RestController
@RequestMapping("prompt")
@RequiredArgsConstructor
public class PromptTestController {
private final PromptTestService promptTestService;
@PostMapping("test")
public R<Boolean> test(@RequestBody PromptRequestVO promptRequestVO) throws IOException {
return R.ok(promptTestService.test(promptRequestVO));
}
@GetMapping("queryResult")
public R<ResultResVO> queryResult(String uid){
if (StrUtil.isBlank(uid)){
throw new BusinessException("uid不能为空");
}
return R.ok(promptTestService.queryResult(uid));
}
}

@ -0,0 +1,18 @@
package com.supervision.prompt.service;
import com.supervision.prompt.vo.PromptRequestVO;
import com.supervision.prompt.vo.ResultResVO;
import org.apache.ibatis.annotations.Param;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.multipart.MultipartFile;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.io.IOException;
import java.util.List;
public interface PromptTestService {
boolean test(PromptRequestVO promptRequestVO) throws IOException;
ResultResVO queryResult(String uid);
}

@ -0,0 +1,212 @@
package com.supervision.prompt.service.impl;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.collection.ConcurrentHashSet;
import cn.hutool.core.thread.ThreadUtil;
import cn.hutool.core.util.ObjectUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONObjectIter;
import cn.hutool.json.JSONUtil;
import com.supervision.config.BusinessException;
import com.supervision.minio.domain.MinioFile;
import com.supervision.minio.service.MinioService;
import com.supervision.prompt.service.PromptTestService;
import com.supervision.prompt.thread.PromptTestThread;
import com.supervision.prompt.vo.ParamTestDTO;
import com.supervision.prompt.vo.PromptRequestVO;
import com.supervision.prompt.vo.ExecResultDTO;
import com.supervision.prompt.vo.ResultResVO;
import com.supervision.springaidemo.dto.QARecordNodeDTO;
import com.supervision.springaidemo.util.RecordRegexUtil;
import com.supervision.springaidemo.util.WordReadUtil;
import lombok.Data;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.poi.ss.formula.functions.Rank;
import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.stereotype.Service;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Field;
import java.math.BigDecimal;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
@Service
@RequiredArgsConstructor
@Slf4j
public class PromptTestServiceImpl implements PromptTestService {
private final OllamaChatClient ollamaChatClient;
private final MinioService minioService;
private static final Map<String, ConcurrentLinkedDeque<ExecResultDTO>> resultMap = new ConcurrentHashMap<>();
private static final Set<String> statusSet = new ConcurrentHashSet<>();
private static final Map<String, ExecutorService> taskExecutorMap = new ConcurrentHashMap<>();
private static final Map<String, ExecutorService> checkExecutorMap = new ConcurrentHashMap<>();
@Override
public boolean test(PromptRequestVO promptRequestVO) throws IOException {
if (StrUtil.isBlank(promptRequestVO.getUid())) {
throw new BusinessException("uid不能为空");
}
if (promptRequestVO.getTaskType() == 1) {
if (StrUtil.isBlank(promptRequestVO.getFileId())) {
throw new BusinessException("文件ID不能为空");
}
MinioFile minioFile = minioService.getMinioFile(promptRequestVO.getFileId());
InputStream inputStream = null;
try {
inputStream = minioService.getObjectInputStream(minioFile);
} catch (Exception e) {
throw new BusinessException("从minio中获取文件失败:{}" + e.getMessage());
}
// 结果集合中设置初始值
resultMap.put(promptRequestVO.getUid(), new ConcurrentLinkedDeque<>());
// 首先对笔录进行分割操作
String context = WordReadUtil.readWord(inputStream);
List<QARecordNodeDTO> qaList = RecordRegexUtil.recordRegex(context, promptRequestVO.getPersonName());
// 校验,模板中必须存在参数question,answer
if (StrUtil.isBlank(promptRequestVO.getPrompt())) {
throw new BusinessException("模板内容不能为空");
}
if (!StrUtil.contains(promptRequestVO.getPrompt(), "{question}")) {
throw new BusinessException("模板内容不正确,未找到占位符{question}");
}
if (!StrUtil.contains(promptRequestVO.getPrompt(), "{answer}")) {
throw new BusinessException("模板内容不正确,未找到占位符{answer}");
}
for (ParamTestDTO paramTestDTO : promptRequestVO.getParamList()) {
if (StrUtil.isBlank(paramTestDTO.getKey())) {
throw new BusinessException("占位符的键不能为空");
}
if (StrUtil.isBlank(paramTestDTO.getValue())) {
throw new BusinessException("占位符的值不能为空");
}
if (!StrUtil.contains(promptRequestVO.getPrompt(), "{" + paramTestDTO.getKey() + "}")) {
throw new BusinessException("模板内容不正确,未找到占位符{" + paramTestDTO.getKey() + "}");
}
}
// 创建一个线程池
ExecutorService testExecutor = ThreadUtil.newFixedExecutor(5, Integer.MAX_VALUE, "testExtract", false);
taskExecutorMap.put(promptRequestVO.getUid(), testExecutor);
List<Future<ExecResultDTO>> futures = new ArrayList<>();
for (QARecordNodeDTO qaRecordNodeDTO : qaList) {
List<ParamTestDTO> paramList = promptRequestVO.getParamList();
Map<String, String> paramMap = paramList.stream().collect(Collectors.toMap(ParamTestDTO::getKey, ParamTestDTO::getValue));
paramMap.put("question", qaRecordNodeDTO.getQuestion());
paramMap.put("answer", qaRecordNodeDTO.getAnswer());
// 这里提交给大模型
PromptTestThread result = new PromptTestThread(ollamaChatClient, promptRequestVO.getPrompt(), paramMap);
Future<ExecResultDTO> submit = testExecutor.submit(result);
futures.add(submit);
}
ExecutorService checkExecutor = ThreadUtil.newFixedExecutor(2, Integer.MAX_VALUE, "checkExtract", false);
checkExecutorMap.put(promptRequestVO.getUid(), checkExecutor);
CheckThread checkThread = new CheckThread(futures, promptRequestVO.getUid());
checkExecutor.submit(checkThread);
} else if (promptRequestVO.getTaskType() == 2) {
ExecutorService executorService = taskExecutorMap.get(promptRequestVO.getUid());
executorService.shutdown();
ExecutorService checkService = checkExecutorMap.get(promptRequestVO.getUid());
checkService.shutdown();
}
return true;
}
private static class CheckThread extends Thread {
private final List<Future<ExecResultDTO>> futures;
private final String uid;
public CheckThread(List<Future<ExecResultDTO>> futures, String uid) {
this.futures = futures;
this.uid = uid;
}
@Override
public void run() {
try {
log.info("休眠5秒,5秒之后再去查询三元组的结果,需要查询的任务数量为:{}", futures.size());
Thread.sleep(1000 * 5);
} catch (Exception e) {
log.error(e.getMessage(), e);
}
// 计数器
AtomicInteger atomicInteger = new AtomicInteger(0);
while (futures.size() > 0) {
Iterator<Future<ExecResultDTO>> iterator = futures.iterator();
while (iterator.hasNext()) {
Future<ExecResultDTO> future = iterator.next();
try {
// 如果提取到结果,且不为空,就进行保存
if (future.isDone()) {
ExecResultDTO result = future.get();
if (ObjectUtil.isNotEmpty(result) && StrUtil.isNotBlank(result.getResult())) {
ConcurrentLinkedDeque<ExecResultDTO> strings = resultMap.get(uid);
// 从尾部添加
strings.offerLast(result);
}
iterator.remove();
}
} catch (Exception e) {
log.info("从线程中获取任务失败");
iterator.remove();
}
}
try {
int currentCount = atomicInteger.incrementAndGet();
if (currentCount > 1000) {
log.info("任务执行超时,遍历任务已执行:{}次,任务还剩余:{}个,不再继续执行", currentCount, futures.size());
// 将还在执行的线程中断
futures.forEach(future -> {
future.cancel(true);
});
break;
}
log.info("已检查{}遍,任务剩余{}个,休眠5s后继续检查", currentCount, futures.size());
Thread.sleep(1000 * 5);
} catch (Exception e) {
log.error(e.getMessage(), e);
}
}
log.info("执行完毕");
statusSet.add(uid);
}
}
@Override
public ResultResVO queryResult(String uid) {
ConcurrentLinkedDeque<ExecResultDTO> strings = resultMap.get(uid);
List<ExecResultDTO> result = new ArrayList<>();
if (CollUtil.isNotEmpty(strings)) {
while (true) {
ExecResultDTO s = strings.pollFirst();
if (ObjectUtil.isEmpty(s)) {
break;
} else {
result.add(s);
}
}
}
ResultResVO resultResVO = new ResultResVO();
resultResVO.setDataList(result);
if (statusSet.contains(uid)) {
resultResVO.setStatus(2);
} else {
resultResVO.setStatus(1);
}
return resultResVO;
}
}

@ -0,0 +1,63 @@
package com.supervision.prompt.thread;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.ObjectUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil;
import com.supervision.police.domain.NotePrompt;
import com.supervision.police.domain.TripleInfo;
import com.supervision.prompt.vo.ExecResultDTO;
import com.supervision.thread.TripleExtractThread;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.util.StopWatch;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
@Slf4j
public class PromptTestThread implements Callable<ExecResultDTO> {
private final OllamaChatClient chatClient;
private final Map<String, String> paramMap;
private final String prompt;
public PromptTestThread(OllamaChatClient chatClient, String prompt, Map<String, String> paramMap) {
this.chatClient = chatClient;
this.paramMap = paramMap;
this.prompt = prompt;
}
@Override
public ExecResultDTO call() {
try {
StopWatch stopWatch = new StopWatch();
// 分析三元组
stopWatch.start();
String format = StrUtil.format(prompt, paramMap);
ChatResponse call = chatClient.call(new Prompt(new UserMessage(format)));
stopWatch.stop();
String content = call.getResult().getOutput().getContent();
log.info("耗时:{},分析的结果是:{}", stopWatch.getTotalTimeSeconds(), content);
ExecResultDTO resultDTO = new ExecResultDTO();
resultDTO.setResult(content);
resultDTO.setSubmitPrompt(format);
resultDTO.setQuestion(paramMap.get("question"));
resultDTO.setAnswer(paramMap.get("answer"));
return resultDTO;
} catch (Exception e) {
log.error("提取三元组出现错误", e);
}
return null;
}
}

@ -0,0 +1,15 @@
package com.supervision.prompt.vo;
import lombok.Data;
@Data
public class ExecResultDTO {
private String result;
private String submitPrompt;
private String question;
private String answer;
}

@ -0,0 +1,11 @@
package com.supervision.prompt.vo;
import lombok.Data;
@Data
public class ParamTestDTO {
private String key;
private String value;
}

@ -0,0 +1,26 @@
package com.supervision.prompt.vo;
import io.swagger.annotations.ApiModelProperty;
import lombok.Data;
import java.util.List;
import java.util.Map;
@Data
public class PromptRequestVO {
private Integer taskType;
private String fileId;
private List<ParamTestDTO> paramList;
private String prompt;
private String personName;
@ApiModelProperty("前端生成的唯一ID")
private String uid;
}

@ -0,0 +1,15 @@
package com.supervision.prompt.vo;
import lombok.Data;
import java.util.List;
@Data
public class ResultResVO {
private List<ExecResultDTO> dataList;
/*
* 1,2
*/
private Integer status;
}
Loading…
Cancel
Save