You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
virtual-patient/virtual-patient-common/src/main/java/com/supervision/util/AiChatUtil.java

127 lines
4.4 KiB
Java

package com.supervision.util;
import cn.hutool.core.thread.ThreadUtil;
import cn.hutool.json.JSONObject;
import cn.hutool.json.JSONUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.ai.ollama.api.OllamaOptions;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.*;
import java.util.function.Function;
/**
 * Static helpers that send a prompt to the Ollama chat client and decode the textual reply.
 *
 * <p>Every public method submits the model call to a shared executor and then blocks the calling
 * thread until the reply is available. When waiting fails (the task threw, or the waiting thread
 * was interrupted) the failure is logged and {@link Optional#empty()} is returned. Runtime
 * exceptions thrown while decoding the reply are not caught here and propagate to the caller.
 */
@Slf4j
public class AiChatUtil {

    /** Shared executor for model calls; rejected tasks run on the submitting thread (CallerRunsPolicy). */
    private static final ExecutorService chatExecutor =
            ThreadUtil.newFixedExecutor(5, 5, "chat", new ThreadPoolExecutor.CallerRunsPolicy());

    // NOTE(review): looked up when this class is initialized, so presumably the Spring context
    // must already be available the first time AiChatUtil is touched - confirm.
    private static final OllamaChatClient chatClient = SpringBeanUtil.getBean(OllamaChatClient.class);

    /**
     * Single-turn chat: sends one user message and parses the reply as a JSON object.
     *
     * @param chat text of the user message
     * @return the reply parsed as a {@link JSONObject}, or empty if the model call failed
     */
    public static Optional<JSONObject> chat(String chat) {
        Prompt prompt = new Prompt(List.of(new UserMessage(chat)));
        return execute(prompt, content -> JSONUtil.parseObj(content));
    }

    /**
     * Single-turn chat with an explicit seed passed to the model through {@link OllamaOptions}.
     *
     * @param chat text of the user message
     * @param seed value handed to {@code OllamaOptions#setSeed}
     * @return the reply parsed as a {@link JSONObject}, or empty if the model call failed
     */
    public static Optional<JSONObject> chatWithRandom(String chat, Integer seed) {
        OllamaOptions ollamaOptions = new OllamaOptions();
        ollamaOptions.setSeed(seed);
        Prompt prompt = new Prompt(List.of(new UserMessage(chat)), ollamaOptions);
        return execute(prompt, content -> JSONUtil.parseObj(content));
    }

    /**
     * Multi-turn chat: sends a caller-supplied message list and parses the reply as a JSON object.
     *
     * @param messageList messages that make up the prompt
     * @return the reply parsed as a {@link JSONObject}, or empty if the model call failed
     */
    public static Optional<JSONObject> chat(List<Message> messageList) {
        return execute(new Prompt(messageList), content -> JSONUtil.parseObj(content));
    }

    /**
     * Multi-turn chat whose reply is converted into a bean of the requested type.
     *
     * @param messageList messages that make up the prompt
     * @param clazz       target type for the conversion (list types are not supported)
     * @param <T>         target type
     * @return the converted bean, or empty if the model call failed or the conversion produced
     *         {@code null}
     */
    public static <T> Optional<T> chat(List<Message> messageList, Class<T> clazz) {
        return execute(new Prompt(messageList), content -> JSONUtil.toBean(content, clazz));
    }

    /**
     * Single-turn chat whose reply is converted into a bean of the requested type.
     *
     * @param chat  text of the user message
     * @param clazz target type for the conversion (list types are not supported)
     * @param <T>   target type
     * @return the converted bean, or empty if the model call failed or the conversion produced
     *         {@code null}
     */
    public static <T> Optional<T> chat(String chat, Class<T> clazz) {
        Prompt prompt = new Prompt(List.of(new UserMessage(chat)));
        return execute(prompt, content -> JSONUtil.toBean(content, clazz));
    }

    /**
     * Runs the prompt on the shared executor, waits for the reply text and decodes it.
     *
     * @param prompt  prompt to send to the chat client
     * @param decoder converts the raw reply text into the result type
     * @param <T>     result type
     * @return the decoded result, or empty when waiting for the reply failed
     */
    private static <T> Optional<T> execute(Prompt prompt, Function<String, T> decoder) {
        Future<String> pending = chatExecutor.submit(new ChatTask(chatClient, prompt));
        try {
            return Optional.ofNullable(decoder.apply(pending.get()));
        } catch (InterruptedException e) {
            // Restore the interrupt flag so code further up the stack can still observe it.
            Thread.currentThread().interrupt();
            log.error("调用大模型生成失败", e);
        } catch (ExecutionException e) {
            log.error("调用大模型生成失败", e);
        }
        return Optional.empty();
    }

    /** Callable that sends one prompt to the chat client and returns the reply's text content. */
    private record ChatTask(OllamaChatClient chatClient, Prompt prompt) implements Callable<String> {
        @Override
        public String call() {
            ChatResponse response = chatClient.call(prompt);
            return response.getResult().getOutput().getContent();
        }
    }
}