From 0cbee971023ab54ef6f9155087256b8d28ea06d6 Mon Sep 17 00:00:00 2001 From: liu <liujiatong112@163.com> Date: Thu, 23 May 2024 11:06:00 +0800 Subject: [PATCH] demo --- .../nxllm/thread/RunCheckThread.java | 91 +++++++++++++++++++ .../nxllm/thread/RunCheckThreadPool.java | 10 ++ 2 files changed, 101 insertions(+) create mode 100644 src/main/java/com/supervision/nxllm/thread/RunCheckThread.java create mode 100644 src/main/java/com/supervision/nxllm/thread/RunCheckThreadPool.java diff --git a/src/main/java/com/supervision/nxllm/thread/RunCheckThread.java b/src/main/java/com/supervision/nxllm/thread/RunCheckThread.java new file mode 100644 index 0000000..c307c30 --- /dev/null +++ b/src/main/java/com/supervision/nxllm/thread/RunCheckThread.java @@ -0,0 +1,91 @@ +package com.supervision.nxllm.thread; + +import cn.hutool.core.util.StrUtil; +import cn.hutool.json.JSONUtil; +import com.supervision.springaidemo.domain.ModelMetric; +import com.supervision.springaidemo.domain.NoteCheckRecord; +import com.supervision.springaidemo.dto.MetricResultDTO; +import com.supervision.springaidemo.service.NoteCheckRecordService; +import lombok.extern.slf4j.Slf4j; +import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.ollama.OllamaChatClient; +import org.springframework.util.StopWatch; + +@Slf4j +public class RunCheckThread implements Runnable { + + private final String caseName; + + private final OllamaChatClient chatClient; + + private final NoteCheckRecordService noteCheckRecordService; + + private final Prompt prompt; + + private final String fileName; + + private final String format; + + private final String systemPrompt; + + private final ModelMetric modelMetric; + + + private Integer count; + + public RunCheckThread(String caseName, OllamaChatClient chatClient, NoteCheckRecordService noteCheckRecordService, Prompt prompt, String fileName, String format, String systemPrompt, ModelMetric modelMetric, Integer count) { + this.caseName = caseName; + this.chatClient = chatClient; + this.noteCheckRecordService = noteCheckRecordService; + this.prompt = prompt; + this.fileName = fileName; + this.format = format; + this.systemPrompt = systemPrompt; + this.modelMetric = modelMetric; + this.count = count; + } + + @Override + public void run() { + try { + StopWatch stopWatch = new StopWatch(); + stopWatch.start(); + log.info("开始分析:{}",fileName); + ChatResponse call = chatClient.call(prompt); + stopWatch.stop(); + log.info("耗时:{}", stopWatch.getTotalTimeSeconds()); + Generation result = call.getResult(); + + String content = result.getOutput().getContent(); + log.info("分析的结果是:{}", content); + MetricResultDTO metricResultDTO = JSONUtil.toBean(content, MetricResultDTO.class); + // 如果为空,则再跑一次,最多跑5次 + if (StrUtil.isBlank(metricResultDTO.getResult())) { + if (count > 5) { + log.info("{}的{}结果超过5次,不再继续跑了", fileName, modelMetric); + } else { + log.info("{}的{}结果为空,当前跑了{}次,重新提交,再跑一次", fileName, modelMetric, count); + Integer newCount = count++; + RunCheckThread runCheck = new RunCheckThread(caseName, chatClient, noteCheckRecordService, prompt, fileName, format, systemPrompt, modelMetric, newCount); + RunCheckThreadPool.chatExecutor.submit(runCheck); + } + } else { + NoteCheckRecord noteCheckRecord = new NoteCheckRecord(); + noteCheckRecord.setCaseName(caseName); + noteCheckRecord.setNoteName(fileName); + noteCheckRecord.setMetricCode(modelMetric.getMetricCode()); + noteCheckRecord.setMetricName(modelMetric.getMetricName()); + noteCheckRecord.setSystemPrompt(systemPrompt); + noteCheckRecord.setPrompt(format); + noteCheckRecord.setResult(metricResultDTO.getResult()); + noteCheckRecord.setOriginalContext(metricResultDTO.getOriginalContext()); + noteCheckRecord.setReason(metricResultDTO.getReason()); + noteCheckRecordService.save(noteCheckRecord); + } + } catch (Exception e) { + log.error("出现错误", e); + } + } +} diff --git a/src/main/java/com/supervision/nxllm/thread/RunCheckThreadPool.java b/src/main/java/com/supervision/nxllm/thread/RunCheckThreadPool.java new file mode 100644 index 0000000..21e48fd --- /dev/null +++ b/src/main/java/com/supervision/nxllm/thread/RunCheckThreadPool.java @@ -0,0 +1,10 @@ +package com.supervision.nxllm.thread; + +import cn.hutool.core.thread.ThreadUtil; + +import java.util.concurrent.ExecutorService; + +public class RunCheckThreadPool { + + public static final ExecutorService chatExecutor = ThreadUtil.newFixedExecutor(20, Integer.MAX_VALUE, "chat", false); +}