Merge remote-tracking branch 'origin/dev_1.0.0' into dev_1.0.0
commit
068cb6cf06
@ -0,0 +1,40 @@
|
||||
package com.supervision.prompt;
|
||||
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import com.supervision.common.domain.R;
|
||||
import com.supervision.config.BusinessException;
|
||||
import com.supervision.prompt.service.PromptTestService;
|
||||
import com.supervision.prompt.vo.ParamTestDTO;
|
||||
import com.supervision.prompt.vo.PromptRequestVO;
|
||||
import com.supervision.prompt.vo.ResultResVO;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.ibatis.annotations.Param;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
import org.springframework.web.multipart.MultipartFile;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
||||
@Slf4j
|
||||
@RestController
|
||||
@RequestMapping("prompt")
|
||||
@RequiredArgsConstructor
|
||||
public class PromptTestController {
|
||||
|
||||
private final PromptTestService promptTestService;
|
||||
|
||||
@PostMapping("test")
|
||||
public R<Boolean> test(@RequestBody PromptRequestVO promptRequestVO) throws IOException {
|
||||
return R.ok(promptTestService.test(promptRequestVO));
|
||||
}
|
||||
|
||||
@GetMapping("queryResult")
|
||||
public R<ResultResVO> queryResult(String uid){
|
||||
if (StrUtil.isBlank(uid)){
|
||||
throw new BusinessException("uid不能为空");
|
||||
}
|
||||
return R.ok(promptTestService.queryResult(uid));
|
||||
}
|
||||
}
|
@ -0,0 +1,18 @@
|
||||
package com.supervision.prompt.service;
|
||||
|
||||
import com.supervision.prompt.vo.PromptRequestVO;
|
||||
import com.supervision.prompt.vo.ResultResVO;
|
||||
import org.apache.ibatis.annotations.Param;
|
||||
import org.springframework.web.bind.annotation.RequestBody;
|
||||
import org.springframework.web.multipart.MultipartFile;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
||||
public interface PromptTestService {
|
||||
|
||||
boolean test(PromptRequestVO promptRequestVO) throws IOException;
|
||||
|
||||
ResultResVO queryResult(String uid);
|
||||
}
|
@ -0,0 +1,212 @@
|
||||
package com.supervision.prompt.service.impl;
|
||||
|
||||
import cn.hutool.core.collection.CollUtil;
|
||||
import cn.hutool.core.collection.ConcurrentHashSet;
|
||||
import cn.hutool.core.thread.ThreadUtil;
|
||||
import cn.hutool.core.util.ObjectUtil;
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import cn.hutool.json.JSONObjectIter;
|
||||
import cn.hutool.json.JSONUtil;
|
||||
import com.supervision.config.BusinessException;
|
||||
import com.supervision.minio.domain.MinioFile;
|
||||
import com.supervision.minio.service.MinioService;
|
||||
import com.supervision.prompt.service.PromptTestService;
|
||||
import com.supervision.prompt.thread.PromptTestThread;
|
||||
import com.supervision.prompt.vo.ParamTestDTO;
|
||||
import com.supervision.prompt.vo.PromptRequestVO;
|
||||
import com.supervision.prompt.vo.ExecResultDTO;
|
||||
import com.supervision.prompt.vo.ResultResVO;
|
||||
import com.supervision.springaidemo.dto.QARecordNodeDTO;
|
||||
import com.supervision.springaidemo.util.RecordRegexUtil;
|
||||
import com.supervision.springaidemo.util.WordReadUtil;
|
||||
import lombok.Data;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.poi.ss.formula.functions.Rank;
|
||||
import org.springframework.ai.ollama.OllamaChatClient;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.lang.reflect.Field;
|
||||
import java.math.BigDecimal;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.*;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
@Slf4j
|
||||
public class PromptTestServiceImpl implements PromptTestService {
|
||||
|
||||
private final OllamaChatClient ollamaChatClient;
|
||||
|
||||
private final MinioService minioService;
|
||||
|
||||
private static final Map<String, ConcurrentLinkedDeque<ExecResultDTO>> resultMap = new ConcurrentHashMap<>();
|
||||
|
||||
private static final Set<String> statusSet = new ConcurrentHashSet<>();
|
||||
|
||||
private static final Map<String, ExecutorService> taskExecutorMap = new ConcurrentHashMap<>();
|
||||
|
||||
private static final Map<String, ExecutorService> checkExecutorMap = new ConcurrentHashMap<>();
|
||||
|
||||
@Override
|
||||
public boolean test(PromptRequestVO promptRequestVO) throws IOException {
|
||||
if (StrUtil.isBlank(promptRequestVO.getUid())) {
|
||||
throw new BusinessException("uid不能为空");
|
||||
}
|
||||
if (promptRequestVO.getTaskType() == 1) {
|
||||
if (StrUtil.isBlank(promptRequestVO.getFileId())) {
|
||||
throw new BusinessException("文件ID不能为空");
|
||||
}
|
||||
MinioFile minioFile = minioService.getMinioFile(promptRequestVO.getFileId());
|
||||
InputStream inputStream = null;
|
||||
try {
|
||||
inputStream = minioService.getObjectInputStream(minioFile);
|
||||
} catch (Exception e) {
|
||||
throw new BusinessException("从minio中获取文件失败:{}" + e.getMessage());
|
||||
}
|
||||
// 结果集合中设置初始值
|
||||
resultMap.put(promptRequestVO.getUid(), new ConcurrentLinkedDeque<>());
|
||||
// 首先对笔录进行分割操作
|
||||
String context = WordReadUtil.readWord(inputStream);
|
||||
List<QARecordNodeDTO> qaList = RecordRegexUtil.recordRegex(context, promptRequestVO.getPersonName());
|
||||
// 校验,模板中必须存在参数question,answer
|
||||
if (StrUtil.isBlank(promptRequestVO.getPrompt())) {
|
||||
throw new BusinessException("模板内容不能为空");
|
||||
}
|
||||
if (!StrUtil.contains(promptRequestVO.getPrompt(), "{question}")) {
|
||||
throw new BusinessException("模板内容不正确,未找到占位符{question}");
|
||||
}
|
||||
if (!StrUtil.contains(promptRequestVO.getPrompt(), "{answer}")) {
|
||||
throw new BusinessException("模板内容不正确,未找到占位符{answer}");
|
||||
}
|
||||
for (ParamTestDTO paramTestDTO : promptRequestVO.getParamList()) {
|
||||
if (StrUtil.isBlank(paramTestDTO.getKey())) {
|
||||
throw new BusinessException("占位符的键不能为空");
|
||||
}
|
||||
if (StrUtil.isBlank(paramTestDTO.getValue())) {
|
||||
throw new BusinessException("占位符的值不能为空");
|
||||
}
|
||||
if (!StrUtil.contains(promptRequestVO.getPrompt(), "{" + paramTestDTO.getKey() + "}")) {
|
||||
throw new BusinessException("模板内容不正确,未找到占位符{" + paramTestDTO.getKey() + "}");
|
||||
}
|
||||
}
|
||||
// 创建一个线程池
|
||||
ExecutorService testExecutor = ThreadUtil.newFixedExecutor(5, Integer.MAX_VALUE, "testExtract", false);
|
||||
taskExecutorMap.put(promptRequestVO.getUid(), testExecutor);
|
||||
List<Future<ExecResultDTO>> futures = new ArrayList<>();
|
||||
for (QARecordNodeDTO qaRecordNodeDTO : qaList) {
|
||||
List<ParamTestDTO> paramList = promptRequestVO.getParamList();
|
||||
Map<String, String> paramMap = paramList.stream().collect(Collectors.toMap(ParamTestDTO::getKey, ParamTestDTO::getValue));
|
||||
paramMap.put("question", qaRecordNodeDTO.getQuestion());
|
||||
paramMap.put("answer", qaRecordNodeDTO.getAnswer());
|
||||
// 这里提交给大模型
|
||||
PromptTestThread result = new PromptTestThread(ollamaChatClient, promptRequestVO.getPrompt(), paramMap);
|
||||
Future<ExecResultDTO> submit = testExecutor.submit(result);
|
||||
futures.add(submit);
|
||||
}
|
||||
ExecutorService checkExecutor = ThreadUtil.newFixedExecutor(2, Integer.MAX_VALUE, "checkExtract", false);
|
||||
checkExecutorMap.put(promptRequestVO.getUid(), checkExecutor);
|
||||
CheckThread checkThread = new CheckThread(futures, promptRequestVO.getUid());
|
||||
checkExecutor.submit(checkThread);
|
||||
} else if (promptRequestVO.getTaskType() == 2) {
|
||||
ExecutorService executorService = taskExecutorMap.get(promptRequestVO.getUid());
|
||||
executorService.shutdown();
|
||||
ExecutorService checkService = checkExecutorMap.get(promptRequestVO.getUid());
|
||||
checkService.shutdown();
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
private static class CheckThread extends Thread {
|
||||
|
||||
private final List<Future<ExecResultDTO>> futures;
|
||||
|
||||
private final String uid;
|
||||
|
||||
public CheckThread(List<Future<ExecResultDTO>> futures, String uid) {
|
||||
this.futures = futures;
|
||||
this.uid = uid;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
try {
|
||||
log.info("休眠5秒,5秒之后再去查询三元组的结果,需要查询的任务数量为:{}", futures.size());
|
||||
Thread.sleep(1000 * 5);
|
||||
} catch (Exception e) {
|
||||
log.error(e.getMessage(), e);
|
||||
}
|
||||
// 计数器
|
||||
AtomicInteger atomicInteger = new AtomicInteger(0);
|
||||
while (futures.size() > 0) {
|
||||
Iterator<Future<ExecResultDTO>> iterator = futures.iterator();
|
||||
while (iterator.hasNext()) {
|
||||
Future<ExecResultDTO> future = iterator.next();
|
||||
try {
|
||||
// 如果提取到结果,且不为空,就进行保存
|
||||
if (future.isDone()) {
|
||||
ExecResultDTO result = future.get();
|
||||
if (ObjectUtil.isNotEmpty(result) && StrUtil.isNotBlank(result.getResult())) {
|
||||
ConcurrentLinkedDeque<ExecResultDTO> strings = resultMap.get(uid);
|
||||
// 从尾部添加
|
||||
strings.offerLast(result);
|
||||
}
|
||||
iterator.remove();
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.info("从线程中获取任务失败");
|
||||
iterator.remove();
|
||||
}
|
||||
}
|
||||
try {
|
||||
int currentCount = atomicInteger.incrementAndGet();
|
||||
if (currentCount > 1000) {
|
||||
log.info("任务执行超时,遍历任务已执行:{}次,任务还剩余:{}个,不再继续执行", currentCount, futures.size());
|
||||
// 将还在执行的线程中断
|
||||
futures.forEach(future -> {
|
||||
future.cancel(true);
|
||||
});
|
||||
break;
|
||||
}
|
||||
log.info("已检查{}遍,任务剩余{}个,休眠5s后继续检查", currentCount, futures.size());
|
||||
Thread.sleep(1000 * 5);
|
||||
} catch (Exception e) {
|
||||
log.error(e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
log.info("执行完毕");
|
||||
statusSet.add(uid);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public ResultResVO queryResult(String uid) {
|
||||
ConcurrentLinkedDeque<ExecResultDTO> strings = resultMap.get(uid);
|
||||
List<ExecResultDTO> result = new ArrayList<>();
|
||||
if (CollUtil.isNotEmpty(strings)) {
|
||||
while (true) {
|
||||
ExecResultDTO s = strings.pollFirst();
|
||||
if (ObjectUtil.isEmpty(s)) {
|
||||
break;
|
||||
} else {
|
||||
result.add(s);
|
||||
}
|
||||
}
|
||||
}
|
||||
ResultResVO resultResVO = new ResultResVO();
|
||||
resultResVO.setDataList(result);
|
||||
if (statusSet.contains(uid)) {
|
||||
resultResVO.setStatus(2);
|
||||
} else {
|
||||
resultResVO.setStatus(1);
|
||||
}
|
||||
return resultResVO;
|
||||
}
|
||||
|
||||
|
||||
}
|
@ -0,0 +1,63 @@
|
||||
package com.supervision.prompt.thread;
|
||||
|
||||
import cn.hutool.core.collection.CollUtil;
|
||||
import cn.hutool.core.util.ObjectUtil;
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import cn.hutool.json.JSONUtil;
|
||||
import com.supervision.police.domain.NotePrompt;
|
||||
import com.supervision.police.domain.TripleInfo;
|
||||
import com.supervision.prompt.vo.ExecResultDTO;
|
||||
import com.supervision.thread.TripleExtractThread;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.ai.chat.ChatResponse;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.ollama.OllamaChatClient;
|
||||
import org.springframework.util.StopWatch;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.Callable;
|
||||
|
||||
@Slf4j
|
||||
public class PromptTestThread implements Callable<ExecResultDTO> {
|
||||
|
||||
private final OllamaChatClient chatClient;
|
||||
|
||||
private final Map<String, String> paramMap;
|
||||
|
||||
private final String prompt;
|
||||
|
||||
public PromptTestThread(OllamaChatClient chatClient, String prompt, Map<String, String> paramMap) {
|
||||
this.chatClient = chatClient;
|
||||
this.paramMap = paramMap;
|
||||
this.prompt = prompt;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ExecResultDTO call() {
|
||||
try {
|
||||
StopWatch stopWatch = new StopWatch();
|
||||
// 分析三元组
|
||||
stopWatch.start();
|
||||
String format = StrUtil.format(prompt, paramMap);
|
||||
ChatResponse call = chatClient.call(new Prompt(new UserMessage(format)));
|
||||
stopWatch.stop();
|
||||
String content = call.getResult().getOutput().getContent();
|
||||
log.info("耗时:{},分析的结果是:{}", stopWatch.getTotalTimeSeconds(), content);
|
||||
ExecResultDTO resultDTO = new ExecResultDTO();
|
||||
resultDTO.setResult(content);
|
||||
resultDTO.setSubmitPrompt(format);
|
||||
resultDTO.setQuestion(paramMap.get("question"));
|
||||
resultDTO.setAnswer(paramMap.get("answer"));
|
||||
|
||||
return resultDTO;
|
||||
} catch (Exception e) {
|
||||
log.error("提取三元组出现错误", e);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,15 @@
|
||||
package com.supervision.prompt.vo;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class ExecResultDTO {
|
||||
|
||||
private String result;
|
||||
|
||||
private String submitPrompt;
|
||||
|
||||
private String question;
|
||||
|
||||
private String answer;
|
||||
}
|
@ -0,0 +1,11 @@
|
||||
package com.supervision.prompt.vo;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class ParamTestDTO {
|
||||
|
||||
private String key;
|
||||
|
||||
private String value;
|
||||
}
|
@ -0,0 +1,26 @@
|
||||
package com.supervision.prompt.vo;
|
||||
|
||||
import io.swagger.annotations.ApiModelProperty;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@Data
|
||||
public class PromptRequestVO {
|
||||
|
||||
private Integer taskType;
|
||||
|
||||
private String fileId;
|
||||
|
||||
private List<ParamTestDTO> paramList;
|
||||
|
||||
private String prompt;
|
||||
|
||||
private String personName;
|
||||
|
||||
@ApiModelProperty("前端生成的唯一ID")
|
||||
private String uid;
|
||||
|
||||
|
||||
}
|
@ -0,0 +1,15 @@
|
||||
package com.supervision.prompt.vo;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class ResultResVO {
|
||||
|
||||
private List<ExecResultDTO> dataList;
|
||||
/*
|
||||
* 状态 1进行中,2已完成
|
||||
*/
|
||||
private Integer status;
|
||||
}
|
Loading…
Reference in New Issue