From aed68d87e3775e5bb12087f6e606bc48aaa9e32c Mon Sep 17 00:00:00 2001 From: liu Date: Tue, 30 Jul 2024 16:53:27 +0800 Subject: [PATCH] =?UTF-8?q?=E6=8F=90=E4=BA=A4=E4=B8=89=E5=85=83=E7=BB=84?= =?UTF-8?q?=E6=8F=90=E5=8F=96=E6=B5=8B=E8=AF=95=E7=9A=84=E7=9B=B8=E5=85=B3?= =?UTF-8?q?=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../prompt/PromptTestController.java | 40 ++++ .../prompt/service/PromptTestService.java | 18 ++ .../service/impl/PromptTestServiceImpl.java | 212 ++++++++++++++++++ .../prompt/thread/PromptTestThread.java | 63 ++++++ .../supervision/prompt/vo/ExecResultDTO.java | 15 ++ .../supervision/prompt/vo/ParamTestDTO.java | 11 + .../prompt/vo/PromptRequestVO.java | 26 +++ .../supervision/prompt/vo/ResultResVO.java | 15 ++ 8 files changed, 400 insertions(+) create mode 100644 src/main/java/com/supervision/prompt/PromptTestController.java create mode 100644 src/main/java/com/supervision/prompt/service/PromptTestService.java create mode 100644 src/main/java/com/supervision/prompt/service/impl/PromptTestServiceImpl.java create mode 100644 src/main/java/com/supervision/prompt/thread/PromptTestThread.java create mode 100644 src/main/java/com/supervision/prompt/vo/ExecResultDTO.java create mode 100644 src/main/java/com/supervision/prompt/vo/ParamTestDTO.java create mode 100644 src/main/java/com/supervision/prompt/vo/PromptRequestVO.java create mode 100644 src/main/java/com/supervision/prompt/vo/ResultResVO.java diff --git a/src/main/java/com/supervision/prompt/PromptTestController.java b/src/main/java/com/supervision/prompt/PromptTestController.java new file mode 100644 index 0000000..b873474 --- /dev/null +++ b/src/main/java/com/supervision/prompt/PromptTestController.java @@ -0,0 +1,40 @@ +package com.supervision.prompt; + +import cn.hutool.core.util.StrUtil; +import com.supervision.common.domain.R; +import com.supervision.config.BusinessException; +import com.supervision.prompt.service.PromptTestService; +import com.supervision.prompt.vo.ParamTestDTO; +import com.supervision.prompt.vo.PromptRequestVO; +import com.supervision.prompt.vo.ResultResVO; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.apache.ibatis.annotations.Param; +import org.springframework.web.bind.annotation.*; +import org.springframework.web.multipart.MultipartFile; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; + +import java.io.IOException; +import java.util.List; + +@Slf4j +@RestController +@RequestMapping("prompt") +@RequiredArgsConstructor +public class PromptTestController { + + private final PromptTestService promptTestService; + + @PostMapping("test") + public R test(@RequestBody PromptRequestVO promptRequestVO) throws IOException { + return R.ok(promptTestService.test(promptRequestVO)); + } + + @GetMapping("queryResult") + public R queryResult(String uid){ + if (StrUtil.isBlank(uid)){ + throw new BusinessException("uid不能为空"); + } + return R.ok(promptTestService.queryResult(uid)); + } +} \ No newline at end of file diff --git a/src/main/java/com/supervision/prompt/service/PromptTestService.java b/src/main/java/com/supervision/prompt/service/PromptTestService.java new file mode 100644 index 0000000..231c2ac --- /dev/null +++ b/src/main/java/com/supervision/prompt/service/PromptTestService.java @@ -0,0 +1,18 @@ +package com.supervision.prompt.service; + +import com.supervision.prompt.vo.PromptRequestVO; +import com.supervision.prompt.vo.ResultResVO; +import org.apache.ibatis.annotations.Param; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.multipart.MultipartFile; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; + +import java.io.IOException; +import java.util.List; + +public interface PromptTestService { + + boolean test(PromptRequestVO promptRequestVO) throws IOException; + + ResultResVO queryResult(String uid); +} diff --git a/src/main/java/com/supervision/prompt/service/impl/PromptTestServiceImpl.java b/src/main/java/com/supervision/prompt/service/impl/PromptTestServiceImpl.java new file mode 100644 index 0000000..3789b86 --- /dev/null +++ b/src/main/java/com/supervision/prompt/service/impl/PromptTestServiceImpl.java @@ -0,0 +1,212 @@ +package com.supervision.prompt.service.impl; + +import cn.hutool.core.collection.CollUtil; +import cn.hutool.core.collection.ConcurrentHashSet; +import cn.hutool.core.thread.ThreadUtil; +import cn.hutool.core.util.ObjectUtil; +import cn.hutool.core.util.StrUtil; +import cn.hutool.json.JSONObjectIter; +import cn.hutool.json.JSONUtil; +import com.supervision.config.BusinessException; +import com.supervision.minio.domain.MinioFile; +import com.supervision.minio.service.MinioService; +import com.supervision.prompt.service.PromptTestService; +import com.supervision.prompt.thread.PromptTestThread; +import com.supervision.prompt.vo.ParamTestDTO; +import com.supervision.prompt.vo.PromptRequestVO; +import com.supervision.prompt.vo.ExecResultDTO; +import com.supervision.prompt.vo.ResultResVO; +import com.supervision.springaidemo.dto.QARecordNodeDTO; +import com.supervision.springaidemo.util.RecordRegexUtil; +import com.supervision.springaidemo.util.WordReadUtil; +import lombok.Data; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.apache.poi.ss.formula.functions.Rank; +import org.springframework.ai.ollama.OllamaChatClient; +import org.springframework.stereotype.Service; + +import java.io.IOException; +import java.io.InputStream; +import java.lang.reflect.Field; +import java.math.BigDecimal; +import java.util.*; +import java.util.concurrent.*; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; + +@Service +@RequiredArgsConstructor +@Slf4j +public class PromptTestServiceImpl implements PromptTestService { + + private final OllamaChatClient ollamaChatClient; + + private final MinioService minioService; + + private static final Map> resultMap = new ConcurrentHashMap<>(); + + private static final Set statusSet = new ConcurrentHashSet<>(); + + private static final Map taskExecutorMap = new ConcurrentHashMap<>(); + + private static final Map checkExecutorMap = new ConcurrentHashMap<>(); + + @Override + public boolean test(PromptRequestVO promptRequestVO) throws IOException { + if (StrUtil.isBlank(promptRequestVO.getUid())) { + throw new BusinessException("uid不能为空"); + } + if (promptRequestVO.getTaskType() == 1) { + if (StrUtil.isBlank(promptRequestVO.getFileId())) { + throw new BusinessException("文件ID不能为空"); + } + MinioFile minioFile = minioService.getMinioFile(promptRequestVO.getFileId()); + InputStream inputStream = null; + try { + inputStream = minioService.getObjectInputStream(minioFile); + } catch (Exception e) { + throw new BusinessException("从minio中获取文件失败:{}" + e.getMessage()); + } + // 结果集合中设置初始值 + resultMap.put(promptRequestVO.getUid(), new ConcurrentLinkedDeque<>()); + // 首先对笔录进行分割操作 + String context = WordReadUtil.readWord(inputStream); + List qaList = RecordRegexUtil.recordRegex(context, promptRequestVO.getPersonName()); + // 校验,模板中必须存在参数question,answer + if (StrUtil.isBlank(promptRequestVO.getPrompt())) { + throw new BusinessException("模板内容不能为空"); + } + if (!StrUtil.contains(promptRequestVO.getPrompt(), "{question}")) { + throw new BusinessException("模板内容不正确,未找到占位符{question}"); + } + if (!StrUtil.contains(promptRequestVO.getPrompt(), "{answer}")) { + throw new BusinessException("模板内容不正确,未找到占位符{answer}"); + } + for (ParamTestDTO paramTestDTO : promptRequestVO.getParamList()) { + if (StrUtil.isBlank(paramTestDTO.getKey())) { + throw new BusinessException("占位符的键不能为空"); + } + if (StrUtil.isBlank(paramTestDTO.getValue())) { + throw new BusinessException("占位符的值不能为空"); + } + if (!StrUtil.contains(promptRequestVO.getPrompt(), "{" + paramTestDTO.getKey() + "}")) { + throw new BusinessException("模板内容不正确,未找到占位符{" + paramTestDTO.getKey() + "}"); + } + } + // 创建一个线程池 + ExecutorService testExecutor = ThreadUtil.newFixedExecutor(5, Integer.MAX_VALUE, "testExtract", false); + taskExecutorMap.put(promptRequestVO.getUid(), testExecutor); + List> futures = new ArrayList<>(); + for (QARecordNodeDTO qaRecordNodeDTO : qaList) { + List paramList = promptRequestVO.getParamList(); + Map paramMap = paramList.stream().collect(Collectors.toMap(ParamTestDTO::getKey, ParamTestDTO::getValue)); + paramMap.put("question", qaRecordNodeDTO.getQuestion()); + paramMap.put("answer", qaRecordNodeDTO.getAnswer()); + // 这里提交给大模型 + PromptTestThread result = new PromptTestThread(ollamaChatClient, promptRequestVO.getPrompt(), paramMap); + Future submit = testExecutor.submit(result); + futures.add(submit); + } + ExecutorService checkExecutor = ThreadUtil.newFixedExecutor(2, Integer.MAX_VALUE, "checkExtract", false); + checkExecutorMap.put(promptRequestVO.getUid(), checkExecutor); + CheckThread checkThread = new CheckThread(futures, promptRequestVO.getUid()); + checkExecutor.submit(checkThread); + } else if (promptRequestVO.getTaskType() == 2) { + ExecutorService executorService = taskExecutorMap.get(promptRequestVO.getUid()); + executorService.shutdown(); + ExecutorService checkService = checkExecutorMap.get(promptRequestVO.getUid()); + checkService.shutdown(); + } + + return true; + } + + private static class CheckThread extends Thread { + + private final List> futures; + + private final String uid; + + public CheckThread(List> futures, String uid) { + this.futures = futures; + this.uid = uid; + } + + @Override + public void run() { + try { + log.info("休眠5秒,5秒之后再去查询三元组的结果,需要查询的任务数量为:{}", futures.size()); + Thread.sleep(1000 * 5); + } catch (Exception e) { + log.error(e.getMessage(), e); + } + // 计数器 + AtomicInteger atomicInteger = new AtomicInteger(0); + while (futures.size() > 0) { + Iterator> iterator = futures.iterator(); + while (iterator.hasNext()) { + Future future = iterator.next(); + try { + // 如果提取到结果,且不为空,就进行保存 + if (future.isDone()) { + ExecResultDTO result = future.get(); + if (ObjectUtil.isNotEmpty(result) && StrUtil.isNotBlank(result.getResult())) { + ConcurrentLinkedDeque strings = resultMap.get(uid); + // 从尾部添加 + strings.offerLast(result); + } + iterator.remove(); + } + } catch (Exception e) { + log.info("从线程中获取任务失败"); + iterator.remove(); + } + } + try { + int currentCount = atomicInteger.incrementAndGet(); + if (currentCount > 1000) { + log.info("任务执行超时,遍历任务已执行:{}次,任务还剩余:{}个,不再继续执行", currentCount, futures.size()); + // 将还在执行的线程中断 + futures.forEach(future -> { + future.cancel(true); + }); + break; + } + log.info("已检查{}遍,任务剩余{}个,休眠5s后继续检查", currentCount, futures.size()); + Thread.sleep(1000 * 5); + } catch (Exception e) { + log.error(e.getMessage(), e); + } + } + log.info("执行完毕"); + statusSet.add(uid); + } + } + + @Override + public ResultResVO queryResult(String uid) { + ConcurrentLinkedDeque strings = resultMap.get(uid); + List result = new ArrayList<>(); + if (CollUtil.isNotEmpty(strings)) { + while (true) { + ExecResultDTO s = strings.pollFirst(); + if (ObjectUtil.isEmpty(s)) { + break; + } else { + result.add(s); + } + } + } + ResultResVO resultResVO = new ResultResVO(); + resultResVO.setDataList(result); + if (statusSet.contains(uid)) { + resultResVO.setStatus(2); + } else { + resultResVO.setStatus(1); + } + return resultResVO; + } + + +} diff --git a/src/main/java/com/supervision/prompt/thread/PromptTestThread.java b/src/main/java/com/supervision/prompt/thread/PromptTestThread.java new file mode 100644 index 0000000..781ff05 --- /dev/null +++ b/src/main/java/com/supervision/prompt/thread/PromptTestThread.java @@ -0,0 +1,63 @@ +package com.supervision.prompt.thread; + +import cn.hutool.core.collection.CollUtil; +import cn.hutool.core.util.ObjectUtil; +import cn.hutool.core.util.StrUtil; +import cn.hutool.json.JSONUtil; +import com.supervision.police.domain.NotePrompt; +import com.supervision.police.domain.TripleInfo; +import com.supervision.prompt.vo.ExecResultDTO; +import com.supervision.thread.TripleExtractThread; +import lombok.Data; +import lombok.extern.slf4j.Slf4j; +import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.ollama.OllamaChatClient; +import org.springframework.util.StopWatch; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.Callable; + +@Slf4j +public class PromptTestThread implements Callable { + + private final OllamaChatClient chatClient; + + private final Map paramMap; + + private final String prompt; + + public PromptTestThread(OllamaChatClient chatClient, String prompt, Map paramMap) { + this.chatClient = chatClient; + this.paramMap = paramMap; + this.prompt = prompt; + } + + @Override + public ExecResultDTO call() { + try { + StopWatch stopWatch = new StopWatch(); + // 分析三元组 + stopWatch.start(); + String format = StrUtil.format(prompt, paramMap); + ChatResponse call = chatClient.call(new Prompt(new UserMessage(format))); + stopWatch.stop(); + String content = call.getResult().getOutput().getContent(); + log.info("耗时:{},分析的结果是:{}", stopWatch.getTotalTimeSeconds(), content); + ExecResultDTO resultDTO = new ExecResultDTO(); + resultDTO.setResult(content); + resultDTO.setSubmitPrompt(format); + resultDTO.setQuestion(paramMap.get("question")); + resultDTO.setAnswer(paramMap.get("answer")); + + return resultDTO; + } catch (Exception e) { + log.error("提取三元组出现错误", e); + } + return null; + } + +} diff --git a/src/main/java/com/supervision/prompt/vo/ExecResultDTO.java b/src/main/java/com/supervision/prompt/vo/ExecResultDTO.java new file mode 100644 index 0000000..8e5bfa1 --- /dev/null +++ b/src/main/java/com/supervision/prompt/vo/ExecResultDTO.java @@ -0,0 +1,15 @@ +package com.supervision.prompt.vo; + +import lombok.Data; + +@Data +public class ExecResultDTO { + + private String result; + + private String submitPrompt; + + private String question; + + private String answer; +} diff --git a/src/main/java/com/supervision/prompt/vo/ParamTestDTO.java b/src/main/java/com/supervision/prompt/vo/ParamTestDTO.java new file mode 100644 index 0000000..a9ad38f --- /dev/null +++ b/src/main/java/com/supervision/prompt/vo/ParamTestDTO.java @@ -0,0 +1,11 @@ +package com.supervision.prompt.vo; + +import lombok.Data; + +@Data +public class ParamTestDTO { + + private String key; + + private String value; +} diff --git a/src/main/java/com/supervision/prompt/vo/PromptRequestVO.java b/src/main/java/com/supervision/prompt/vo/PromptRequestVO.java new file mode 100644 index 0000000..2f9bd1a --- /dev/null +++ b/src/main/java/com/supervision/prompt/vo/PromptRequestVO.java @@ -0,0 +1,26 @@ +package com.supervision.prompt.vo; + +import io.swagger.annotations.ApiModelProperty; +import lombok.Data; + +import java.util.List; +import java.util.Map; + +@Data +public class PromptRequestVO { + + private Integer taskType; + + private String fileId; + + private List paramList; + + private String prompt; + + private String personName; + + @ApiModelProperty("前端生成的唯一ID") + private String uid; + + +} diff --git a/src/main/java/com/supervision/prompt/vo/ResultResVO.java b/src/main/java/com/supervision/prompt/vo/ResultResVO.java new file mode 100644 index 0000000..32a1b56 --- /dev/null +++ b/src/main/java/com/supervision/prompt/vo/ResultResVO.java @@ -0,0 +1,15 @@ +package com.supervision.prompt.vo; + +import lombok.Data; + +import java.util.List; + +@Data +public class ResultResVO { + + private List dataList; + /* + * 状态 1进行中,2已完成 + */ + private Integer status; +}