From 0cbee971023ab54ef6f9155087256b8d28ea06d6 Mon Sep 17 00:00:00 2001
From: liu <liujiatong112@163.com>
Date: Thu, 23 May 2024 11:06:00 +0800
Subject: [PATCH] demo

---
 .../nxllm/thread/RunCheckThread.java          | 91 +++++++++++++++++++
 .../nxllm/thread/RunCheckThreadPool.java      | 10 ++
 2 files changed, 101 insertions(+)
 create mode 100644 src/main/java/com/supervision/nxllm/thread/RunCheckThread.java
 create mode 100644 src/main/java/com/supervision/nxllm/thread/RunCheckThreadPool.java

diff --git a/src/main/java/com/supervision/nxllm/thread/RunCheckThread.java b/src/main/java/com/supervision/nxllm/thread/RunCheckThread.java
new file mode 100644
index 0000000..c307c30
--- /dev/null
+++ b/src/main/java/com/supervision/nxllm/thread/RunCheckThread.java
@@ -0,0 +1,91 @@
+package com.supervision.nxllm.thread;
+
+import cn.hutool.core.util.StrUtil;
+import cn.hutool.json.JSONUtil;
+import com.supervision.springaidemo.domain.ModelMetric;
+import com.supervision.springaidemo.domain.NoteCheckRecord;
+import com.supervision.springaidemo.dto.MetricResultDTO;
+import com.supervision.springaidemo.service.NoteCheckRecordService;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.ai.chat.ChatResponse;
+import org.springframework.ai.chat.Generation;
+import org.springframework.ai.chat.prompt.Prompt;
+import org.springframework.ai.ollama.OllamaChatClient;
+import org.springframework.util.StopWatch;
+
+@Slf4j
+public class RunCheckThread implements Runnable {
+
+    private final String caseName;
+
+    private final OllamaChatClient chatClient;
+
+    private final NoteCheckRecordService noteCheckRecordService;
+
+    private final Prompt prompt;
+
+    private final String fileName;
+
+    private final String format;
+
+    private final String systemPrompt;
+
+    private final ModelMetric modelMetric;
+
+
+    private Integer count;
+
+    public RunCheckThread(String caseName, OllamaChatClient chatClient, NoteCheckRecordService noteCheckRecordService, Prompt prompt, String fileName, String format, String systemPrompt, ModelMetric modelMetric, Integer count) {
+        this.caseName = caseName;
+        this.chatClient = chatClient;
+        this.noteCheckRecordService = noteCheckRecordService;
+        this.prompt = prompt;
+        this.fileName = fileName;
+        this.format = format;
+        this.systemPrompt = systemPrompt;
+        this.modelMetric = modelMetric;
+        this.count = count;
+    }
+
+    @Override
+    public void run() {
+        try {
+            StopWatch stopWatch = new StopWatch();
+            stopWatch.start();
+            log.info("开始分析:{}",fileName);
+            ChatResponse call = chatClient.call(prompt);
+            stopWatch.stop();
+            log.info("耗时:{}", stopWatch.getTotalTimeSeconds());
+            Generation result = call.getResult();
+
+            String content = result.getOutput().getContent();
+            log.info("分析的结果是:{}", content);
+            MetricResultDTO metricResultDTO = JSONUtil.toBean(content, MetricResultDTO.class);
+            // 如果为空,则再跑一次,最多跑5次
+            if (StrUtil.isBlank(metricResultDTO.getResult())) {
+                if (count > 5) {
+                    log.info("{}的{}结果超过5次,不再继续跑了", fileName, modelMetric);
+                } else {
+                    log.info("{}的{}结果为空,当前跑了{}次,重新提交,再跑一次", fileName, modelMetric, count);
+                    Integer newCount = count++;
+                    RunCheckThread runCheck = new RunCheckThread(caseName, chatClient, noteCheckRecordService, prompt, fileName, format, systemPrompt, modelMetric, newCount);
+                    RunCheckThreadPool.chatExecutor.submit(runCheck);
+                }
+            } else {
+                NoteCheckRecord noteCheckRecord = new NoteCheckRecord();
+                noteCheckRecord.setCaseName(caseName);
+                noteCheckRecord.setNoteName(fileName);
+                noteCheckRecord.setMetricCode(modelMetric.getMetricCode());
+                noteCheckRecord.setMetricName(modelMetric.getMetricName());
+                noteCheckRecord.setSystemPrompt(systemPrompt);
+                noteCheckRecord.setPrompt(format);
+                noteCheckRecord.setResult(metricResultDTO.getResult());
+                noteCheckRecord.setOriginalContext(metricResultDTO.getOriginalContext());
+                noteCheckRecord.setReason(metricResultDTO.getReason());
+                noteCheckRecordService.save(noteCheckRecord);
+            }
+        } catch (Exception e) {
+            log.error("出现错误", e);
+        }
+    }
+}
diff --git a/src/main/java/com/supervision/nxllm/thread/RunCheckThreadPool.java b/src/main/java/com/supervision/nxllm/thread/RunCheckThreadPool.java
new file mode 100644
index 0000000..21e48fd
--- /dev/null
+++ b/src/main/java/com/supervision/nxllm/thread/RunCheckThreadPool.java
@@ -0,0 +1,10 @@
+package com.supervision.nxllm.thread;
+
+import cn.hutool.core.thread.ThreadUtil;
+
+import java.util.concurrent.ExecutorService;
+
+public class RunCheckThreadPool {
+
+    public static final ExecutorService chatExecutor = ThreadUtil.newFixedExecutor(20, Integer.MAX_VALUE, "chat", false);
+}