main
liu 11 months ago
parent 54cff96d45
commit 0cbee97102

@ -0,0 +1,91 @@
package com.supervision.nxllm.thread;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil;
import com.supervision.springaidemo.domain.ModelMetric;
import com.supervision.springaidemo.domain.NoteCheckRecord;
import com.supervision.springaidemo.dto.MetricResultDTO;
import com.supervision.springaidemo.service.NoteCheckRecordService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.util.StopWatch;
@Slf4j
public class RunCheckThread implements Runnable {
private final String caseName;
private final OllamaChatClient chatClient;
private final NoteCheckRecordService noteCheckRecordService;
private final Prompt prompt;
private final String fileName;
private final String format;
private final String systemPrompt;
private final ModelMetric modelMetric;
private Integer count;
public RunCheckThread(String caseName, OllamaChatClient chatClient, NoteCheckRecordService noteCheckRecordService, Prompt prompt, String fileName, String format, String systemPrompt, ModelMetric modelMetric, Integer count) {
this.caseName = caseName;
this.chatClient = chatClient;
this.noteCheckRecordService = noteCheckRecordService;
this.prompt = prompt;
this.fileName = fileName;
this.format = format;
this.systemPrompt = systemPrompt;
this.modelMetric = modelMetric;
this.count = count;
}
@Override
public void run() {
try {
StopWatch stopWatch = new StopWatch();
stopWatch.start();
log.info("开始分析:{}",fileName);
ChatResponse call = chatClient.call(prompt);
stopWatch.stop();
log.info("耗时:{}", stopWatch.getTotalTimeSeconds());
Generation result = call.getResult();
String content = result.getOutput().getContent();
log.info("分析的结果是:{}", content);
MetricResultDTO metricResultDTO = JSONUtil.toBean(content, MetricResultDTO.class);
// 如果为空,则再跑一次,最多跑5次
if (StrUtil.isBlank(metricResultDTO.getResult())) {
if (count > 5) {
log.info("{}的{}结果超过5次,不再继续跑了", fileName, modelMetric);
} else {
log.info("{}的{}结果为空,当前跑了{}次,重新提交,再跑一次", fileName, modelMetric, count);
Integer newCount = count++;
RunCheckThread runCheck = new RunCheckThread(caseName, chatClient, noteCheckRecordService, prompt, fileName, format, systemPrompt, modelMetric, newCount);
RunCheckThreadPool.chatExecutor.submit(runCheck);
}
} else {
NoteCheckRecord noteCheckRecord = new NoteCheckRecord();
noteCheckRecord.setCaseName(caseName);
noteCheckRecord.setNoteName(fileName);
noteCheckRecord.setMetricCode(modelMetric.getMetricCode());
noteCheckRecord.setMetricName(modelMetric.getMetricName());
noteCheckRecord.setSystemPrompt(systemPrompt);
noteCheckRecord.setPrompt(format);
noteCheckRecord.setResult(metricResultDTO.getResult());
noteCheckRecord.setOriginalContext(metricResultDTO.getOriginalContext());
noteCheckRecord.setReason(metricResultDTO.getReason());
noteCheckRecordService.save(noteCheckRecord);
}
} catch (Exception e) {
log.error("出现错误", e);
}
}
}

@ -0,0 +1,10 @@
package com.supervision.nxllm.thread;
import cn.hutool.core.thread.ThreadUtil;
import java.util.concurrent.ExecutorService;
public class RunCheckThreadPool {
public static final ExecutorService chatExecutor = ThreadUtil.newFixedExecutor(20, Integer.MAX_VALUE, "chat", false);
}
Loading…
Cancel
Save