添加智能对话功能代码
parent
797c3257ac
commit
969a7cf534
@ -0,0 +1,46 @@
|
||||
package com.supervision.controller;
|
||||
|
||||
import com.baomidou.mybatisplus.core.metadata.IPage;
|
||||
import com.supervision.domain.UserDetail;
|
||||
import com.supervision.dto.AgentChatReqDTO;
|
||||
import com.supervision.dto.AgentInfoDTO;
|
||||
import com.supervision.dto.ChatResponseDTO;
|
||||
import com.supervision.dto.R;
|
||||
import com.supervision.service.AgentService;
|
||||
import com.supervision.util.UserUtil;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
@Slf4j
|
||||
@RestController
|
||||
@RequestMapping("/agent")
|
||||
@RequiredArgsConstructor
|
||||
public class AgentController {
|
||||
|
||||
private final AgentService agentService;
|
||||
|
||||
/**
|
||||
* 流式智能体问答
|
||||
* @param agentChatReqDTO agentChatReqDTO
|
||||
* @return
|
||||
*/
|
||||
@PostMapping(value = "/streamChat", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
|
||||
public Flux<ChatResponseDTO> streamChat(@RequestBody AgentChatReqDTO agentChatReqDTO) {
|
||||
UserDetail userDetail = UserUtil.currentUser();
|
||||
agentChatReqDTO.setUserId(userDetail.getUserId());
|
||||
return agentService.streamChat(agentChatReqDTO);
|
||||
}
|
||||
|
||||
@GetMapping("/pageList")
|
||||
public R<IPage<AgentInfoDTO>> pageList(@RequestParam (name = "page", required = false,defaultValue = "1") Integer page,
|
||||
@RequestParam (name = "pageSize", required = false,defaultValue = "10") Integer pageSize) {
|
||||
|
||||
IPage<AgentInfoDTO> paged = agentService.pageList(page, pageSize);
|
||||
return R.ok(paged);
|
||||
}
|
||||
|
||||
|
||||
}
|
@ -0,0 +1,26 @@
|
||||
package com.supervision.controller;
|
||||
|
||||
import cn.hutool.core.lang.Assert;
|
||||
import com.supervision.dto.R;
|
||||
import com.supervision.service.AuthService;
|
||||
import com.supervision.vo.LoginReqVo;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.web.bind.annotation.PostMapping;
|
||||
import org.springframework.web.bind.annotation.RequestBody;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
|
||||
@RestController
|
||||
@RequestMapping("/auth")
|
||||
@RequiredArgsConstructor
|
||||
public class AuthController {
|
||||
|
||||
private final AuthService authService;
|
||||
|
||||
@PostMapping("/login")
|
||||
public R<String> login(@RequestBody LoginReqVo loginReqVo) {
|
||||
Assert.notEmpty(loginReqVo.getUsername(), "用户名不能为空");
|
||||
String token = authService.login(loginReqVo.getUsername(), loginReqVo.getPassword());
|
||||
return R.ok(token);
|
||||
}
|
||||
}
|
@ -0,0 +1,32 @@
|
||||
package com.supervision.domain;
|
||||
|
||||
import org.springframework.security.core.GrantedAuthority;
|
||||
import org.springframework.security.core.userdetails.User;
|
||||
|
||||
import java.util.Collection;
|
||||
|
||||
public class UserDetail extends User {
|
||||
private String userId;
|
||||
|
||||
|
||||
public UserDetail(String userId ,String username, String password, Collection<? extends GrantedAuthority> authorities) {
|
||||
super(username, password, authorities);
|
||||
this.userId = userId;
|
||||
}
|
||||
|
||||
public UserDetail(String username, String password, Collection<? extends GrantedAuthority> authorities) {
|
||||
super(username, password, authorities);
|
||||
}
|
||||
|
||||
public UserDetail(String username, String password, boolean enabled, boolean accountNonExpired, boolean credentialsNonExpired, boolean accountNonLocked, Collection<? extends GrantedAuthority> authorities) {
|
||||
super(username, password, enabled, accountNonExpired, credentialsNonExpired, accountNonLocked, authorities);
|
||||
}
|
||||
|
||||
public void setUserId(String userId) {
|
||||
this.userId = userId;
|
||||
}
|
||||
|
||||
public String getUserId() {
|
||||
return userId;
|
||||
}
|
||||
}
|
@ -0,0 +1,27 @@
|
||||
package com.supervision.dto;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class AgentChatReqDTO {
|
||||
|
||||
/**
|
||||
* 智能体id
|
||||
*/
|
||||
private String agentId;
|
||||
|
||||
/**
|
||||
* 问题
|
||||
*/
|
||||
private String query;
|
||||
|
||||
/**
|
||||
* 对话id
|
||||
*/
|
||||
private String conversationId;
|
||||
|
||||
/**
|
||||
* 用户id
|
||||
*/
|
||||
private String userId;
|
||||
}
|
@ -0,0 +1,22 @@
|
||||
package com.supervision.dto;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class ChatResponseDTO {
|
||||
|
||||
/**
|
||||
* 事件类型 message, message_end
|
||||
*/
|
||||
private String event;
|
||||
|
||||
/**
|
||||
* 消息对话id
|
||||
*/
|
||||
private String conversationId;
|
||||
|
||||
/**
|
||||
* 回复内容
|
||||
*/
|
||||
private String answer;
|
||||
}
|
@ -0,0 +1,15 @@
|
||||
package com.supervision.exception;
|
||||
|
||||
public class UnauthorizedException extends RuntimeException {
|
||||
public UnauthorizedException() {
|
||||
super("用户未登录");
|
||||
}
|
||||
|
||||
public UnauthorizedException(String message) {
|
||||
super(message);
|
||||
}
|
||||
|
||||
public UnauthorizedException(String message, Throwable cause) {
|
||||
super(message, cause);
|
||||
}
|
||||
}
|
@ -0,0 +1,25 @@
|
||||
package com.supervision.service;
|
||||
|
||||
import com.baomidou.mybatisplus.core.metadata.IPage;
|
||||
import com.supervision.dto.AgentChatReqDTO;
|
||||
import com.supervision.dto.AgentInfoDTO;
|
||||
import com.supervision.dto.ChatResponseDTO;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
public interface AgentService {
|
||||
|
||||
/**
|
||||
* 流式智能体问答
|
||||
* @param agentChatReqDTO agentChatReqDTO
|
||||
* @return
|
||||
*/
|
||||
Flux<ChatResponseDTO> streamChat(AgentChatReqDTO agentChatReqDTO);
|
||||
|
||||
/**
|
||||
* 分页查询智能体信息
|
||||
* @param page
|
||||
* @param pageSize
|
||||
* @return
|
||||
*/
|
||||
IPage<AgentInfoDTO> pageList(Integer page, Integer pageSize);
|
||||
}
|
@ -0,0 +1,23 @@
|
||||
package com.supervision.service;
|
||||
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.embedding.Embedding;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* @description: AI调用服务
|
||||
*/
|
||||
public interface AiCallService {
|
||||
|
||||
|
||||
String call(String prompt);
|
||||
|
||||
Flux<ChatResponse> stream(Prompt prompt);
|
||||
|
||||
Embedding embedding(String text);
|
||||
|
||||
List<Embedding> embedding(List<String> texts);
|
||||
}
|
@ -0,0 +1,7 @@
|
||||
package com.supervision.service;
|
||||
|
||||
public interface AuthService {
|
||||
|
||||
String login(String username, String password);
|
||||
|
||||
}
|
@ -0,0 +1,54 @@
|
||||
package com.supervision.service;
|
||||
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
|
||||
public class ThinkTagState {
|
||||
private boolean inThink = false;
|
||||
private final StringBuilder buffer = new StringBuilder();
|
||||
|
||||
private ChatResponse chatResponse;
|
||||
|
||||
public ThinkTagState() {
|
||||
System.out.println("ThinkTagState initialized");
|
||||
}
|
||||
|
||||
public ThinkTagState process(String chunk) {
|
||||
StringBuilder output = new StringBuilder();
|
||||
String remaining = chunk;
|
||||
|
||||
while (!remaining.isEmpty()) {
|
||||
if (!inThink) {
|
||||
int startIdx = remaining.indexOf("<think>");
|
||||
if (startIdx == -1) {
|
||||
output.append(remaining);
|
||||
break;
|
||||
}
|
||||
output.append(remaining.substring(0, startIdx));
|
||||
inThink = true;
|
||||
remaining = remaining.substring(startIdx + "<think>".length());
|
||||
} else {
|
||||
int endIdx = remaining.indexOf("</think>");
|
||||
if (endIdx == -1) {
|
||||
break; // 等待后续分块
|
||||
}
|
||||
inThink = false;
|
||||
remaining = remaining.substring(endIdx + "</think>".length());
|
||||
}
|
||||
}
|
||||
|
||||
buffer.append(output);
|
||||
return this;
|
||||
}
|
||||
|
||||
public String getFilteredText() {
|
||||
return buffer.toString();
|
||||
}
|
||||
|
||||
public ChatResponse getChatResponse() {
|
||||
return chatResponse;
|
||||
}
|
||||
|
||||
public void setChatResponse(ChatResponse chatResponse) {
|
||||
this.chatResponse = chatResponse;
|
||||
}
|
||||
}
|
@ -0,0 +1,118 @@
|
||||
package com.supervision.service.impl;
|
||||
|
||||
import cn.hutool.core.collection.CollUtil;
|
||||
import cn.hutool.core.lang.Assert;
|
||||
import cn.hutool.core.lang.UUID;
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import com.baomidou.mybatisplus.core.metadata.IPage;
|
||||
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
|
||||
import com.supervision.constant.PromptTemplate;
|
||||
import com.supervision.domain.AgentDialogueLog;
|
||||
import com.supervision.domain.AgentInfo;
|
||||
import com.supervision.dto.AgentChatReqDTO;
|
||||
import com.supervision.dto.AgentDialogueLogDTO;
|
||||
import com.supervision.dto.AgentInfoDTO;
|
||||
import com.supervision.dto.ChatResponseDTO;
|
||||
import com.supervision.service.*;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.ai.chat.messages.AssistantMessage;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.stereotype.Service;
|
||||
import reactor.core.publisher.Flux;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
public class AgentServiceImpl implements AgentService {
|
||||
|
||||
private final AiCallService aiCallService;
|
||||
|
||||
private final AgentInfoService agentInfoService;
|
||||
|
||||
private final AgentDialogueLogService agentDialogueLogService;
|
||||
@Override
|
||||
public Flux<ChatResponseDTO> streamChat(AgentChatReqDTO agentChatReqDTO) {
|
||||
|
||||
// 查询历史对话(记忆功能)
|
||||
List<Message> messages = buildHistoryMessages(agentChatReqDTO.getConversationId(), 10);
|
||||
// 构建当前对话
|
||||
Message message = buildQueryMessage(agentChatReqDTO.getQuery(), agentChatReqDTO.getAgentId());
|
||||
Assert.notNull(message, "Agent not found for ID: " + agentChatReqDTO.getAgentId());
|
||||
messages.add(message);
|
||||
AgentDialogueLogDTO dialogueLogDTO = new AgentDialogueLogDTO(agentChatReqDTO);
|
||||
StringBuilder aiResponseBuilder = new StringBuilder();
|
||||
Flux<ChatResponseDTO> map = aiCallService.stream(new Prompt(messages))
|
||||
.scan(new ThinkTagState(), (state, response) -> {
|
||||
String chunk = response.getResult().getOutput().getText();
|
||||
state.setChatResponse(response);
|
||||
return state.process(chunk);
|
||||
})
|
||||
.filter(state -> StrUtil.isNotEmpty(state.getFilteredText()))
|
||||
.map(response -> {
|
||||
ChatResponseDTO responseDTO = new ChatResponseDTO();
|
||||
responseDTO.setAnswer(response.getChatResponse().getResult().getOutput().getText());
|
||||
responseDTO.setConversationId(StrUtil.isNotEmpty(agentChatReqDTO.getConversationId()) ? agentChatReqDTO.getConversationId() : UUID.fastUUID().toString());
|
||||
ChatResponseMetadata metadata = response.getChatResponse().getMetadata();
|
||||
Boolean done = metadata.get("done");
|
||||
responseDTO.setEvent(done ? "message_end" : "message");
|
||||
dialogueLogDTO.setConversationId(responseDTO.getConversationId());
|
||||
aiResponseBuilder.append(responseDTO.getAnswer());
|
||||
return responseDTO;
|
||||
}).doOnComplete(() -> {
|
||||
dialogueLogDTO.setSystemOut(aiResponseBuilder.toString());
|
||||
agentDialogueLogService.save(dialogueLogDTO.toDialogueLog());
|
||||
}).doOnError(e -> {
|
||||
log.error("Error during AI chat stream: {}", e.getMessage());
|
||||
dialogueLogDTO.setSystemOut(aiResponseBuilder.toString());
|
||||
dialogueLogDTO.setAnswerType(1);
|
||||
agentDialogueLogService.save(dialogueLogDTO.toDialogueLog());
|
||||
});
|
||||
return map;
|
||||
}
|
||||
|
||||
@Override
|
||||
public IPage<AgentInfoDTO> pageList(Integer page, Integer pageSize) {
|
||||
IPage<AgentInfo> paged = agentInfoService.page(Page.of(page, pageSize));
|
||||
return paged.convert(AgentInfoDTO::new);
|
||||
}
|
||||
|
||||
private Message buildQueryMessage(String query,String agentId) {
|
||||
if (StrUtil.equals(agentId, "1")) {
|
||||
String template = PromptTemplate.GENERAL_INDUSTRY_TEMPLATE;
|
||||
return new UserMessage(StrUtil.format(template, Map.of("query",query)));
|
||||
}else if (StrUtil.equals(agentId, "2")) {
|
||||
String template = PromptTemplate.HEART_GUIDE_TEMPLATE;
|
||||
return new UserMessage(StrUtil.format(template, Map.of("query",query)));
|
||||
}
|
||||
return null;
|
||||
}
|
||||
private List<Message> buildHistoryMessages(String conversationId,int historySize) {
|
||||
List<Message> messages = new ArrayList<>();
|
||||
if (StrUtil.isEmpty(conversationId)){
|
||||
return messages;
|
||||
}
|
||||
IPage<AgentDialogueLog> dialogueLogs = agentDialogueLogService.pageList(Page.of(1, historySize), conversationId);
|
||||
dialogueLogs.getRecords().sort((o1, o2) -> - o2.getCreateTime().compareTo(o1.getCreateTime()));
|
||||
if (CollUtil.isNotEmpty(dialogueLogs.getRecords())) {
|
||||
for (AgentDialogueLog dialogueLog : dialogueLogs.getRecords()) {
|
||||
if (dialogueLog.getAnswerType() == 1){
|
||||
continue;
|
||||
}
|
||||
if (StrUtil.isNotEmpty(dialogueLog.getUserInput()) && historySize > 0) {
|
||||
messages.add(new UserMessage(dialogueLog.getUserInput()));
|
||||
String result = dialogueLog.getSystemOut().replaceAll("(?is)<think\\b[^>]*>(.*?)</think>", "").trim();
|
||||
messages.add(new AssistantMessage(result));
|
||||
historySize--;
|
||||
}
|
||||
}
|
||||
}
|
||||
return messages;
|
||||
}
|
||||
}
|
@ -0,0 +1,48 @@
|
||||
package com.supervision.service.impl;
|
||||
|
||||
import com.supervision.service.AiCallService;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.embedding.Embedding;
|
||||
import org.springframework.ai.embedding.EmbeddingRequest;
|
||||
import org.springframework.ai.embedding.EmbeddingResponse;
|
||||
import org.springframework.ai.ollama.OllamaChatModel;
|
||||
import org.springframework.ai.ollama.OllamaEmbeddingModel;
|
||||
import org.springframework.stereotype.Service;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
public class OllamaCallServiceImpl implements AiCallService {
|
||||
|
||||
private final OllamaChatModel ollamaChatModel;
|
||||
|
||||
private final OllamaEmbeddingModel embeddingModel;
|
||||
@Override
|
||||
public String call(String prompt) {
|
||||
|
||||
return ollamaChatModel.call(prompt);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Flux<ChatResponse> stream(Prompt prompt) {
|
||||
return ollamaChatModel.stream(prompt);
|
||||
}
|
||||
|
||||
public Embedding embedding(String text) {
|
||||
|
||||
EmbeddingResponse embeddingResponse = embeddingModel.call(new EmbeddingRequest(List.of(text),null));
|
||||
return embeddingResponse.getResult();
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Embedding> embedding(List<String> texts) {
|
||||
EmbeddingResponse embeddingResponse = embeddingModel.call(new EmbeddingRequest(texts,null));
|
||||
return embeddingResponse.getResults();
|
||||
}
|
||||
}
|
@ -0,0 +1,23 @@
|
||||
package com.supervision.util;
|
||||
|
||||
import com.supervision.domain.UserDetail;
|
||||
import com.supervision.exception.UnauthorizedException;
|
||||
import org.springframework.security.core.Authentication;
|
||||
import org.springframework.security.core.context.SecurityContext;
|
||||
import org.springframework.security.core.context.SecurityContextHolder;
|
||||
|
||||
public class UserUtil {
|
||||
|
||||
public static UserDetail currentUser() {
|
||||
SecurityContext context = SecurityContextHolder.getContext();
|
||||
if (null == context) {
|
||||
throw new UnauthorizedException("未登录或登录已过期,请重新登录");
|
||||
}
|
||||
Authentication authentication = context.getAuthentication();
|
||||
Object principal = authentication.getPrincipal();
|
||||
if ("anonymousUser".equals(principal)) {
|
||||
throw new UnauthorizedException("未登录或登录已过期,请重新登录");
|
||||
}
|
||||
return (UserDetail) authentication.getPrincipal();
|
||||
}
|
||||
}
|
@ -0,0 +1,11 @@
|
||||
package com.supervision.vo;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class LoginReqVo {
|
||||
|
||||
private String username;
|
||||
|
||||
private String password;
|
||||
}
|
@ -0,0 +1,29 @@
|
||||
package com.supervision;
|
||||
|
||||
import com.supervision.domain.SysByteArray;
|
||||
import com.supervision.service.SysByteArrayService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
|
||||
@Slf4j
|
||||
@SpringBootTest
|
||||
public class PlatformApplicationTest {
|
||||
|
||||
@Autowired
|
||||
private SysByteArrayService sysByteArrayService;
|
||||
@Test
|
||||
void saveByteArray() {
|
||||
SysByteArray sysByteArray = new SysByteArray();
|
||||
sysByteArray.setBytes(new byte[]{1,2,3});
|
||||
boolean save = sysByteArrayService.save(sysByteArray);
|
||||
System.out.println("保存结果: " + save);
|
||||
}
|
||||
|
||||
@Test
|
||||
void queryByteArray() {
|
||||
SysByteArray sysByteArray = sysByteArrayService.getById("1945676008645058562");
|
||||
log.info("查询结果: {}", sysByteArray);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue