diff --git a/pom.xml b/pom.xml index bd1ebf9..913df1f 100644 --- a/pom.xml +++ b/pom.xml @@ -109,6 +109,10 @@ 0.12.5 runtime + + org.springframework.ai + spring-ai-starter-model-ollama + diff --git a/src/main/java/com/supervision/config/ExceptionHandlerConfig.java b/src/main/java/com/supervision/config/ExceptionHandlerConfig.java index c3fe1ed..2857afd 100644 --- a/src/main/java/com/supervision/config/ExceptionHandlerConfig.java +++ b/src/main/java/com/supervision/config/ExceptionHandlerConfig.java @@ -3,6 +3,7 @@ package com.supervision.config; import com.supervision.constant.ResultStatusEnum; import com.supervision.dto.R; import com.supervision.exception.BusinessException; +import com.supervision.exception.UnauthorizedException; import lombok.extern.slf4j.Slf4j; import org.springframework.context.annotation.Configuration; import org.springframework.web.bind.annotation.ExceptionHandler; @@ -44,6 +45,15 @@ public class ExceptionHandlerConfig { return R.fail(511, exception.getMessage()); } + @ExceptionHandler(UnauthorizedException.class) + public R unauthorizedExceptionResponse(UnauthorizedException exception) { + log.error("=========运行异常=========>>>"); + log.error(exception.getMessage(), exception); + log.error("<<<=========运行异常========="); + + return R.fail(400, exception.getMessage()); + } + @ExceptionHandler(RuntimeException.class) public R manualValidationExceptionResponse(RuntimeException exception) { log.error("=========运行异常=========>>>"); diff --git a/src/main/java/com/supervision/config/SecurityConfig.java b/src/main/java/com/supervision/config/SecurityConfig.java index d0d4884..0ce1a23 100644 --- a/src/main/java/com/supervision/config/SecurityConfig.java +++ b/src/main/java/com/supervision/config/SecurityConfig.java @@ -36,7 +36,7 @@ public class SecurityConfig { http .csrf(AbstractHttpConfigurer::disable) // 禁用CSRF .authorizeHttpRequests(auth -> auth - .requestMatchers("/auth/**").permitAll() + .requestMatchers("/auth/**","/agent/streamChat").permitAll() .anyRequest().authenticated() ) .sessionManagement(session -> session diff --git a/src/main/java/com/supervision/constant/PromptTemplate.java b/src/main/java/com/supervision/constant/PromptTemplate.java new file mode 100644 index 0000000..ec5b6b0 --- /dev/null +++ b/src/main/java/com/supervision/constant/PromptTemplate.java @@ -0,0 +1,15 @@ +package com.supervision.constant; + +public class PromptTemplate { + + + public static final String GENERAL_INDUSTRY_TEMPLATE = """ + 你是一个通用行业的智能体,擅长处理各种行业的任务。请根据用户的需求,提供专业的建议和解决方案。 + 用户需求:{query} + """; + + public static final String HEART_GUIDE_TEMPLATE = """ + 你是一个心灵导师,擅长提供情感支持和心理指导。请根据用户的需求,提供温暖和关怀的建议。 + 用户需求:{query} + """; +} diff --git a/src/main/java/com/supervision/controller/AgentController.java b/src/main/java/com/supervision/controller/AgentController.java new file mode 100644 index 0000000..f03511b --- /dev/null +++ b/src/main/java/com/supervision/controller/AgentController.java @@ -0,0 +1,46 @@ +package com.supervision.controller; + +import com.baomidou.mybatisplus.core.metadata.IPage; +import com.supervision.domain.UserDetail; +import com.supervision.dto.AgentChatReqDTO; +import com.supervision.dto.AgentInfoDTO; +import com.supervision.dto.ChatResponseDTO; +import com.supervision.dto.R; +import com.supervision.service.AgentService; +import com.supervision.util.UserUtil; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.http.MediaType; +import org.springframework.web.bind.annotation.*; +import reactor.core.publisher.Flux; + +@Slf4j +@RestController +@RequestMapping("/agent") +@RequiredArgsConstructor +public class AgentController { + + private final AgentService agentService; + + /** + * 流式智能体问答 + * @param agentChatReqDTO agentChatReqDTO + * @return + */ + @PostMapping(value = "/streamChat", produces = MediaType.TEXT_EVENT_STREAM_VALUE) + public Flux streamChat(@RequestBody AgentChatReqDTO agentChatReqDTO) { + UserDetail userDetail = UserUtil.currentUser(); + agentChatReqDTO.setUserId(userDetail.getUserId()); + return agentService.streamChat(agentChatReqDTO); + } + + @GetMapping("/pageList") + public R> pageList(@RequestParam (name = "page", required = false,defaultValue = "1") Integer page, + @RequestParam (name = "pageSize", required = false,defaultValue = "10") Integer pageSize) { + + IPage paged = agentService.pageList(page, pageSize); + return R.ok(paged); + } + + +} diff --git a/src/main/java/com/supervision/controller/AuthController.java b/src/main/java/com/supervision/controller/AuthController.java new file mode 100644 index 0000000..f9c0fa1 --- /dev/null +++ b/src/main/java/com/supervision/controller/AuthController.java @@ -0,0 +1,26 @@ +package com.supervision.controller; + +import cn.hutool.core.lang.Assert; +import com.supervision.dto.R; +import com.supervision.service.AuthService; +import com.supervision.vo.LoginReqVo; +import lombok.RequiredArgsConstructor; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; + +@RestController +@RequestMapping("/auth") +@RequiredArgsConstructor +public class AuthController { + + private final AuthService authService; + + @PostMapping("/login") + public R login(@RequestBody LoginReqVo loginReqVo) { + Assert.notEmpty(loginReqVo.getUsername(), "用户名不能为空"); + String token = authService.login(loginReqVo.getUsername(), loginReqVo.getPassword()); + return R.ok(token); + } +} diff --git a/src/main/java/com/supervision/domain/AgentDialogueLog.java b/src/main/java/com/supervision/domain/AgentDialogueLog.java index 0fe9335..d5a0d42 100644 --- a/src/main/java/com/supervision/domain/AgentDialogueLog.java +++ b/src/main/java/com/supervision/domain/AgentDialogueLog.java @@ -1,11 +1,9 @@ package com.supervision.domain; -import com.baomidou.mybatisplus.annotation.IdType; -import com.baomidou.mybatisplus.annotation.TableField; -import com.baomidou.mybatisplus.annotation.TableId; -import com.baomidou.mybatisplus.annotation.TableName; +import com.baomidou.mybatisplus.annotation.*; + import java.io.Serializable; -import java.util.Date; +import java.time.LocalDateTime; import lombok.Data; /** @@ -59,12 +57,14 @@ public class AgentDialogueLog implements Serializable { /** * 创建时间 */ - private Date createTime; + @TableField(fill = FieldFill.INSERT) + private LocalDateTime createTime; /** * 更新时间 */ - private Date updateTime; + @TableField(fill = FieldFill.INSERT_UPDATE) + private LocalDateTime updateTime; @TableField(exist = false) private static final long serialVersionUID = 1L; diff --git a/src/main/java/com/supervision/domain/AgentInfo.java b/src/main/java/com/supervision/domain/AgentInfo.java index b99d220..3b6c9b5 100644 --- a/src/main/java/com/supervision/domain/AgentInfo.java +++ b/src/main/java/com/supervision/domain/AgentInfo.java @@ -1,11 +1,9 @@ package com.supervision.domain; -import com.baomidou.mybatisplus.annotation.IdType; -import com.baomidou.mybatisplus.annotation.TableField; -import com.baomidou.mybatisplus.annotation.TableId; -import com.baomidou.mybatisplus.annotation.TableName; +import com.baomidou.mybatisplus.annotation.*; + import java.io.Serializable; -import java.util.Date; +import java.time.LocalDateTime; import lombok.Data; /** @@ -44,12 +42,14 @@ public class AgentInfo implements Serializable { /** * 创建时间 */ - private Date createTime; + @TableField(fill = FieldFill.INSERT) + private LocalDateTime createTime; /** * 更新时间 */ - private Date updateTime; + @TableField(fill = FieldFill.INSERT_UPDATE) + private LocalDateTime updateTime; @TableField(exist = false) private static final long serialVersionUID = 1L; diff --git a/src/main/java/com/supervision/domain/DigitalHuman.java b/src/main/java/com/supervision/domain/DigitalHuman.java index 5c2cc4b..f61c8a1 100644 --- a/src/main/java/com/supervision/domain/DigitalHuman.java +++ b/src/main/java/com/supervision/domain/DigitalHuman.java @@ -1,11 +1,9 @@ package com.supervision.domain; -import com.baomidou.mybatisplus.annotation.IdType; -import com.baomidou.mybatisplus.annotation.TableField; -import com.baomidou.mybatisplus.annotation.TableId; -import com.baomidou.mybatisplus.annotation.TableName; +import com.baomidou.mybatisplus.annotation.*; + import java.io.Serializable; -import java.util.Date; +import java.time.LocalDateTime; import lombok.Data; /** @@ -60,12 +58,14 @@ public class DigitalHuman implements Serializable { /** * 创建时间 */ - private Date createTime; + @TableField(fill = FieldFill.INSERT) + private LocalDateTime createTime; /** * 更新时间 */ - private Date updateTime; + @TableField(fill = FieldFill.INSERT_UPDATE) + private LocalDateTime updateTime; @TableField(exist = false) private static final long serialVersionUID = 1L; diff --git a/src/main/java/com/supervision/domain/DigitalHumanDialogueLog.java b/src/main/java/com/supervision/domain/DigitalHumanDialogueLog.java index 2dfff20..87d9f16 100644 --- a/src/main/java/com/supervision/domain/DigitalHumanDialogueLog.java +++ b/src/main/java/com/supervision/domain/DigitalHumanDialogueLog.java @@ -1,11 +1,9 @@ package com.supervision.domain; -import com.baomidou.mybatisplus.annotation.IdType; -import com.baomidou.mybatisplus.annotation.TableField; -import com.baomidou.mybatisplus.annotation.TableId; -import com.baomidou.mybatisplus.annotation.TableName; +import com.baomidou.mybatisplus.annotation.*; + import java.io.Serializable; -import java.util.Date; +import java.time.LocalDateTime; import lombok.Data; /** @@ -54,12 +52,14 @@ public class DigitalHumanDialogueLog implements Serializable { /** * 创建时间 */ - private Date createTime; + @TableField(fill = FieldFill.INSERT) + private LocalDateTime createTime; /** * 更新时间 */ - private Date updateTime; + @TableField(fill = FieldFill.INSERT_UPDATE) + private LocalDateTime updateTime; @TableField(exist = false) private static final long serialVersionUID = 1L; diff --git a/src/main/java/com/supervision/domain/SysByteArray.java b/src/main/java/com/supervision/domain/SysByteArray.java index 7d18666..2e3c204 100644 --- a/src/main/java/com/supervision/domain/SysByteArray.java +++ b/src/main/java/com/supervision/domain/SysByteArray.java @@ -1,11 +1,8 @@ package com.supervision.domain; -import com.baomidou.mybatisplus.annotation.IdType; -import com.baomidou.mybatisplus.annotation.TableField; -import com.baomidou.mybatisplus.annotation.TableId; -import com.baomidou.mybatisplus.annotation.TableName; +import com.baomidou.mybatisplus.annotation.*; import java.io.Serializable; -import java.util.Date; +import java.time.LocalDateTime; import lombok.Data; /** @@ -24,17 +21,19 @@ public class SysByteArray implements Serializable { /** * 二进制数据 */ - private Object bytes; + private byte[] bytes; /** * 创建时间 */ - private Date createTime; + @TableField(fill = FieldFill.INSERT) + private LocalDateTime createTime; /** * 更新时间 */ - private Date updateTime; + @TableField(fill = FieldFill.INSERT_UPDATE) + private LocalDateTime updateTime; @TableField(exist = false) private static final long serialVersionUID = 1L; diff --git a/src/main/java/com/supervision/domain/SysUser.java b/src/main/java/com/supervision/domain/SysUser.java index 7295f88..b8e5729 100644 --- a/src/main/java/com/supervision/domain/SysUser.java +++ b/src/main/java/com/supervision/domain/SysUser.java @@ -1,11 +1,9 @@ package com.supervision.domain; -import com.baomidou.mybatisplus.annotation.IdType; -import com.baomidou.mybatisplus.annotation.TableField; -import com.baomidou.mybatisplus.annotation.TableId; -import com.baomidou.mybatisplus.annotation.TableName; +import com.baomidou.mybatisplus.annotation.*; + import java.io.Serializable; -import java.util.Date; +import java.time.LocalDateTime; import lombok.Data; /** @@ -21,11 +19,6 @@ public class SysUser implements Serializable { @TableId private String id; - /** - * 应用id - */ - private String appId; - /** * 用户名 */ @@ -42,24 +35,16 @@ public class SysUser implements Serializable { private String status; /** - * - */ - private String createUserId; - - /** - * + * 创建时间 */ - private Date createTime; + @TableField(fill = FieldFill.INSERT) + private LocalDateTime createTime; /** * 更新时间 */ - private Date updateTime; - - /** - * - */ - private String updateUserId; + @TableField(fill = FieldFill.INSERT_UPDATE) + private LocalDateTime updateTime; @TableField(exist = false) private static final long serialVersionUID = 1L; diff --git a/src/main/java/com/supervision/domain/UserDetail.java b/src/main/java/com/supervision/domain/UserDetail.java new file mode 100644 index 0000000..b45aa51 --- /dev/null +++ b/src/main/java/com/supervision/domain/UserDetail.java @@ -0,0 +1,32 @@ +package com.supervision.domain; + +import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.userdetails.User; + +import java.util.Collection; + +public class UserDetail extends User { + private String userId; + + + public UserDetail(String userId ,String username, String password, Collection authorities) { + super(username, password, authorities); + this.userId = userId; + } + + public UserDetail(String username, String password, Collection authorities) { + super(username, password, authorities); + } + + public UserDetail(String username, String password, boolean enabled, boolean accountNonExpired, boolean credentialsNonExpired, boolean accountNonLocked, Collection authorities) { + super(username, password, enabled, accountNonExpired, credentialsNonExpired, accountNonLocked, authorities); + } + + public void setUserId(String userId) { + this.userId = userId; + } + + public String getUserId() { + return userId; + } +} diff --git a/src/main/java/com/supervision/domain/VoiceInfo.java b/src/main/java/com/supervision/domain/VoiceInfo.java index 62fb546..c677a90 100644 --- a/src/main/java/com/supervision/domain/VoiceInfo.java +++ b/src/main/java/com/supervision/domain/VoiceInfo.java @@ -1,11 +1,9 @@ package com.supervision.domain; -import com.baomidou.mybatisplus.annotation.IdType; -import com.baomidou.mybatisplus.annotation.TableField; -import com.baomidou.mybatisplus.annotation.TableId; -import com.baomidou.mybatisplus.annotation.TableName; +import com.baomidou.mybatisplus.annotation.*; + import java.io.Serializable; -import java.util.Date; +import java.time.LocalDateTime; import lombok.Data; /** @@ -34,12 +32,14 @@ public class VoiceInfo implements Serializable { /** * 创建时间 */ - private Date createTime; + @TableField(fill = FieldFill.INSERT) + private LocalDateTime createTime; /** * 更新时间 */ - private Date updateTime; + @TableField(fill = FieldFill.INSERT_UPDATE) + private LocalDateTime updateTime; @TableField(exist = false) private static final long serialVersionUID = 1L; diff --git a/src/main/java/com/supervision/dto/AgentChatReqDTO.java b/src/main/java/com/supervision/dto/AgentChatReqDTO.java new file mode 100644 index 0000000..30429bd --- /dev/null +++ b/src/main/java/com/supervision/dto/AgentChatReqDTO.java @@ -0,0 +1,27 @@ +package com.supervision.dto; + +import lombok.Data; + +@Data +public class AgentChatReqDTO { + + /** + * 智能体id + */ + private String agentId; + + /** + * 问题 + */ + private String query; + + /** + * 对话id + */ + private String conversationId; + + /** + * 用户id + */ + private String userId; +} diff --git a/src/main/java/com/supervision/dto/AgentDialogueLogDTO.java b/src/main/java/com/supervision/dto/AgentDialogueLogDTO.java new file mode 100644 index 0000000..4408cdb --- /dev/null +++ b/src/main/java/com/supervision/dto/AgentDialogueLogDTO.java @@ -0,0 +1,87 @@ +package com.supervision.dto; + +import cn.hutool.core.lang.UUID; +import cn.hutool.core.util.StrUtil; +import com.supervision.domain.AgentDialogueLog; +import lombok.Data; +import java.time.Duration; +import java.time.LocalDateTime; + +@Data +public class AgentDialogueLogDTO { + + /** + * 对话日志id + */ + private String id; + + /** + * 用户id + */ + private String userId; + + /** + * 对话id + */ + private String conversationId; + + /** + * 用户输入 + */ + private String userInput; + + /** + * 系统输出 + */ + private String systemOut; + + /** + * 耗时 + */ + private Long timeCost; + + /** + * 回答内容类型 0:正常回答 1:异常回答 + */ + private Integer answerType = 0; + + /** + * 智能体id + */ + private String agentId; + + /** + * 问题创建时间 + */ + private LocalDateTime createTime; + + public AgentDialogueLogDTO() { + } + + public AgentDialogueLogDTO(AgentChatReqDTO agentChatReqDTO) { + this.userId = agentChatReqDTO.getUserId(); + if (StrUtil.isEmpty(agentChatReqDTO.getConversationId())){ + this.conversationId = UUID.fastUUID().toString();// 生成一个新的对话ID + } else { + this.conversationId = agentChatReqDTO.getConversationId(); + } + this.agentId = agentChatReqDTO.getAgentId(); + this.userInput = agentChatReqDTO.getQuery(); + this.createTime = LocalDateTime.now(); + } + + public AgentDialogueLog toDialogueLog(){ + AgentDialogueLog dialogueLog = new AgentDialogueLog(); + dialogueLog.setId(this.id); + dialogueLog.setUserId(this.userId); + dialogueLog.setAgentId(this.agentId); + dialogueLog.setAnswerType(this.answerType); + dialogueLog.setConversationId(this.conversationId); + dialogueLog.setUserInput(this.userInput); + dialogueLog.setSystemOut(this.systemOut); + if (null != this.createTime){ + dialogueLog.setTimeCost(Duration.between(this.createTime, LocalDateTime.now()).toMillis()); + } + return dialogueLog; + } +} diff --git a/src/main/java/com/supervision/dto/AgentInfoDTO.java b/src/main/java/com/supervision/dto/AgentInfoDTO.java new file mode 100644 index 0000000..082b762 --- /dev/null +++ b/src/main/java/com/supervision/dto/AgentInfoDTO.java @@ -0,0 +1,44 @@ +package com.supervision.dto; + +import com.supervision.domain.AgentInfo; +import lombok.Data; + +@Data +public class AgentInfoDTO { + + /** + * 智能体id + */ + private String id; + + /** + * 智能体code + */ + private String agentCode; + + /** + * 智能体名称 + */ + private String agentName; + + /** + * 智能体描述 + */ + private String agentDesc; + + /** + * 对接状态 0:未对接 1:已对接 + */ + private String adapterStatus; + + public AgentInfoDTO() { + } + + public AgentInfoDTO(AgentInfo agentInfo) { + this.id = agentInfo.getId(); + this.agentCode = agentInfo.getAgentCode(); + this.agentName = agentInfo.getAgentName(); + this.agentDesc = agentInfo.getAgentDesc(); + this.adapterStatus = agentInfo.getAdapterStatus(); + } +} diff --git a/src/main/java/com/supervision/dto/ChatResponseDTO.java b/src/main/java/com/supervision/dto/ChatResponseDTO.java new file mode 100644 index 0000000..f9b67de --- /dev/null +++ b/src/main/java/com/supervision/dto/ChatResponseDTO.java @@ -0,0 +1,22 @@ +package com.supervision.dto; + +import lombok.Data; + +@Data +public class ChatResponseDTO { + + /** + * 事件类型 message, message_end + */ + private String event; + + /** + * 消息对话id + */ + private String conversationId; + + /** + * 回复内容 + */ + private String answer; +} diff --git a/src/main/java/com/supervision/exception/UnauthorizedException.java b/src/main/java/com/supervision/exception/UnauthorizedException.java new file mode 100644 index 0000000..5fbd232 --- /dev/null +++ b/src/main/java/com/supervision/exception/UnauthorizedException.java @@ -0,0 +1,15 @@ +package com.supervision.exception; + +public class UnauthorizedException extends RuntimeException { + public UnauthorizedException() { + super("用户未登录"); + } + + public UnauthorizedException(String message) { + super(message); + } + + public UnauthorizedException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/src/main/java/com/supervision/service/AgentDialogueLogService.java b/src/main/java/com/supervision/service/AgentDialogueLogService.java index 788bf28..8e9cdb7 100644 --- a/src/main/java/com/supervision/service/AgentDialogueLogService.java +++ b/src/main/java/com/supervision/service/AgentDialogueLogService.java @@ -1,5 +1,7 @@ package com.supervision.service; +import com.baomidou.mybatisplus.core.metadata.IPage; +import com.baomidou.mybatisplus.extension.plugins.pagination.Page; import com.supervision.domain.AgentDialogueLog; import com.baomidou.mybatisplus.extension.service.IService; @@ -10,4 +12,11 @@ import com.baomidou.mybatisplus.extension.service.IService; */ public interface AgentDialogueLogService extends IService { + /** + * 分页查询智能体对话日志 + * @param page 分页参数 + * @param conversationId 对话ID + * @return + */ + IPage pageList(Page page, String conversationId); } diff --git a/src/main/java/com/supervision/service/AgentService.java b/src/main/java/com/supervision/service/AgentService.java new file mode 100644 index 0000000..6e174ab --- /dev/null +++ b/src/main/java/com/supervision/service/AgentService.java @@ -0,0 +1,25 @@ +package com.supervision.service; + +import com.baomidou.mybatisplus.core.metadata.IPage; +import com.supervision.dto.AgentChatReqDTO; +import com.supervision.dto.AgentInfoDTO; +import com.supervision.dto.ChatResponseDTO; +import reactor.core.publisher.Flux; + +public interface AgentService { + + /** + * 流式智能体问答 + * @param agentChatReqDTO agentChatReqDTO + * @return + */ + Flux streamChat(AgentChatReqDTO agentChatReqDTO); + + /** + * 分页查询智能体信息 + * @param page + * @param pageSize + * @return + */ + IPage pageList(Integer page, Integer pageSize); +} diff --git a/src/main/java/com/supervision/service/AiCallService.java b/src/main/java/com/supervision/service/AiCallService.java new file mode 100644 index 0000000..1dc46f4 --- /dev/null +++ b/src/main/java/com/supervision/service/AiCallService.java @@ -0,0 +1,23 @@ +package com.supervision.service; + +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.embedding.Embedding; +import reactor.core.publisher.Flux; + +import java.util.List; + +/** + * @description: AI调用服务 + */ +public interface AiCallService { + + + String call(String prompt); + + Flux stream(Prompt prompt); + + Embedding embedding(String text); + + List embedding(List texts); +} diff --git a/src/main/java/com/supervision/service/AuthService.java b/src/main/java/com/supervision/service/AuthService.java new file mode 100644 index 0000000..16ea83f --- /dev/null +++ b/src/main/java/com/supervision/service/AuthService.java @@ -0,0 +1,7 @@ +package com.supervision.service; + +public interface AuthService { + + String login(String username, String password); + +} diff --git a/src/main/java/com/supervision/service/ThinkTagState.java b/src/main/java/com/supervision/service/ThinkTagState.java new file mode 100644 index 0000000..99498c0 --- /dev/null +++ b/src/main/java/com/supervision/service/ThinkTagState.java @@ -0,0 +1,54 @@ +package com.supervision.service; + +import org.springframework.ai.chat.model.ChatResponse; + +public class ThinkTagState { + private boolean inThink = false; + private final StringBuilder buffer = new StringBuilder(); + + private ChatResponse chatResponse; + + public ThinkTagState() { + System.out.println("ThinkTagState initialized"); + } + + public ThinkTagState process(String chunk) { + StringBuilder output = new StringBuilder(); + String remaining = chunk; + + while (!remaining.isEmpty()) { + if (!inThink) { + int startIdx = remaining.indexOf(""); + if (startIdx == -1) { + output.append(remaining); + break; + } + output.append(remaining.substring(0, startIdx)); + inThink = true; + remaining = remaining.substring(startIdx + "".length()); + } else { + int endIdx = remaining.indexOf(""); + if (endIdx == -1) { + break; // 等待后续分块 + } + inThink = false; + remaining = remaining.substring(endIdx + "".length()); + } + } + + buffer.append(output); + return this; + } + + public String getFilteredText() { + return buffer.toString(); + } + + public ChatResponse getChatResponse() { + return chatResponse; + } + + public void setChatResponse(ChatResponse chatResponse) { + this.chatResponse = chatResponse; + } +} diff --git a/src/main/java/com/supervision/service/impl/AgentDialogueLogServiceImpl.java b/src/main/java/com/supervision/service/impl/AgentDialogueLogServiceImpl.java index 50849c2..2836dd9 100644 --- a/src/main/java/com/supervision/service/impl/AgentDialogueLogServiceImpl.java +++ b/src/main/java/com/supervision/service/impl/AgentDialogueLogServiceImpl.java @@ -1,5 +1,7 @@ package com.supervision.service.impl; +import com.baomidou.mybatisplus.core.metadata.IPage; +import com.baomidou.mybatisplus.extension.plugins.pagination.Page; import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; import com.supervision.domain.AgentDialogueLog; import com.supervision.service.AgentDialogueLogService; @@ -15,6 +17,13 @@ import org.springframework.stereotype.Service; public class AgentDialogueLogServiceImpl extends ServiceImpl implements AgentDialogueLogService{ + @Override + public IPage pageList(Page page, String conversationId) { + return this.lambdaQuery() + .eq(AgentDialogueLog::getConversationId, conversationId) + .orderByDesc(AgentDialogueLog::getCreateTime) + .page(page); + } } diff --git a/src/main/java/com/supervision/service/impl/AgentInfoServiceImpl.java b/src/main/java/com/supervision/service/impl/AgentInfoServiceImpl.java index f3ee58a..994ffc6 100644 --- a/src/main/java/com/supervision/service/impl/AgentInfoServiceImpl.java +++ b/src/main/java/com/supervision/service/impl/AgentInfoServiceImpl.java @@ -2,6 +2,7 @@ package com.supervision.service.impl; import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; import com.supervision.domain.AgentInfo; +import com.supervision.dto.AgentDialogueLogDTO; import com.supervision.service.AgentInfoService; import com.supervision.mapper.AgentInfoMapper; import org.springframework.stereotype.Service; diff --git a/src/main/java/com/supervision/service/impl/AgentServiceImpl.java b/src/main/java/com/supervision/service/impl/AgentServiceImpl.java new file mode 100644 index 0000000..50c23bf --- /dev/null +++ b/src/main/java/com/supervision/service/impl/AgentServiceImpl.java @@ -0,0 +1,118 @@ +package com.supervision.service.impl; + +import cn.hutool.core.collection.CollUtil; +import cn.hutool.core.lang.Assert; +import cn.hutool.core.lang.UUID; +import cn.hutool.core.util.StrUtil; +import com.baomidou.mybatisplus.core.metadata.IPage; +import com.baomidou.mybatisplus.extension.plugins.pagination.Page; +import com.supervision.constant.PromptTemplate; +import com.supervision.domain.AgentDialogueLog; +import com.supervision.domain.AgentInfo; +import com.supervision.dto.AgentChatReqDTO; +import com.supervision.dto.AgentDialogueLogDTO; +import com.supervision.dto.AgentInfoDTO; +import com.supervision.dto.ChatResponseDTO; +import com.supervision.service.*; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.metadata.ChatResponseMetadata; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.stereotype.Service; +import reactor.core.publisher.Flux; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +@Slf4j +@Service +@RequiredArgsConstructor +public class AgentServiceImpl implements AgentService { + + private final AiCallService aiCallService; + + private final AgentInfoService agentInfoService; + + private final AgentDialogueLogService agentDialogueLogService; + @Override + public Flux streamChat(AgentChatReqDTO agentChatReqDTO) { + + // 查询历史对话(记忆功能) + List messages = buildHistoryMessages(agentChatReqDTO.getConversationId(), 10); + // 构建当前对话 + Message message = buildQueryMessage(agentChatReqDTO.getQuery(), agentChatReqDTO.getAgentId()); + Assert.notNull(message, "Agent not found for ID: " + agentChatReqDTO.getAgentId()); + messages.add(message); + AgentDialogueLogDTO dialogueLogDTO = new AgentDialogueLogDTO(agentChatReqDTO); + StringBuilder aiResponseBuilder = new StringBuilder(); + Flux map = aiCallService.stream(new Prompt(messages)) + .scan(new ThinkTagState(), (state, response) -> { + String chunk = response.getResult().getOutput().getText(); + state.setChatResponse(response); + return state.process(chunk); + }) + .filter(state -> StrUtil.isNotEmpty(state.getFilteredText())) + .map(response -> { + ChatResponseDTO responseDTO = new ChatResponseDTO(); + responseDTO.setAnswer(response.getChatResponse().getResult().getOutput().getText()); + responseDTO.setConversationId(StrUtil.isNotEmpty(agentChatReqDTO.getConversationId()) ? agentChatReqDTO.getConversationId() : UUID.fastUUID().toString()); + ChatResponseMetadata metadata = response.getChatResponse().getMetadata(); + Boolean done = metadata.get("done"); + responseDTO.setEvent(done ? "message_end" : "message"); + dialogueLogDTO.setConversationId(responseDTO.getConversationId()); + aiResponseBuilder.append(responseDTO.getAnswer()); + return responseDTO; + }).doOnComplete(() -> { + dialogueLogDTO.setSystemOut(aiResponseBuilder.toString()); + agentDialogueLogService.save(dialogueLogDTO.toDialogueLog()); + }).doOnError(e -> { + log.error("Error during AI chat stream: {}", e.getMessage()); + dialogueLogDTO.setSystemOut(aiResponseBuilder.toString()); + dialogueLogDTO.setAnswerType(1); + agentDialogueLogService.save(dialogueLogDTO.toDialogueLog()); + }); + return map; + } + + @Override + public IPage pageList(Integer page, Integer pageSize) { + IPage paged = agentInfoService.page(Page.of(page, pageSize)); + return paged.convert(AgentInfoDTO::new); + } + + private Message buildQueryMessage(String query,String agentId) { + if (StrUtil.equals(agentId, "1")) { + String template = PromptTemplate.GENERAL_INDUSTRY_TEMPLATE; + return new UserMessage(StrUtil.format(template, Map.of("query",query))); + }else if (StrUtil.equals(agentId, "2")) { + String template = PromptTemplate.HEART_GUIDE_TEMPLATE; + return new UserMessage(StrUtil.format(template, Map.of("query",query))); + } + return null; + } + private List buildHistoryMessages(String conversationId,int historySize) { + List messages = new ArrayList<>(); + if (StrUtil.isEmpty(conversationId)){ + return messages; + } + IPage dialogueLogs = agentDialogueLogService.pageList(Page.of(1, historySize), conversationId); + dialogueLogs.getRecords().sort((o1, o2) -> - o2.getCreateTime().compareTo(o1.getCreateTime())); + if (CollUtil.isNotEmpty(dialogueLogs.getRecords())) { + for (AgentDialogueLog dialogueLog : dialogueLogs.getRecords()) { + if (dialogueLog.getAnswerType() == 1){ + continue; + } + if (StrUtil.isNotEmpty(dialogueLog.getUserInput()) && historySize > 0) { + messages.add(new UserMessage(dialogueLog.getUserInput())); + String result = dialogueLog.getSystemOut().replaceAll("(?is)]*>(.*?)", "").trim(); + messages.add(new AssistantMessage(result)); + historySize--; + } + } + } + return messages; + } +} diff --git a/src/main/java/com/supervision/service/impl/AuthServiceImpl.java b/src/main/java/com/supervision/service/impl/AuthServiceImpl.java new file mode 100644 index 0000000..b78bd51 --- /dev/null +++ b/src/main/java/com/supervision/service/impl/AuthServiceImpl.java @@ -0,0 +1,38 @@ +package com.supervision.service.impl; + +import com.supervision.service.AuthService; +import com.supervision.util.JwtUtils; +import lombok.RequiredArgsConstructor; +import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.authentication.BadCredentialsException; +import org.springframework.security.authentication.DisabledException; +import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.userdetails.UserDetails; +import org.springframework.stereotype.Service; + +@Service +@RequiredArgsConstructor +public class AuthServiceImpl implements AuthService { + + private final AuthenticationManager authenticationManager; + private final JwtUtils jwtUtils; + + /** + * 登录认证,返回JWT Token + */ + @Override + public String login(String username, String password) { + try { + Authentication authentication = authenticationManager.authenticate( + new UsernamePasswordAuthenticationToken(username, password) + ); + UserDetails userDetails = (UserDetails) authentication.getPrincipal(); + return jwtUtils.generateToken(userDetails); + } catch (BadCredentialsException e) { + throw new RuntimeException("用户名或密码错误"); + } catch (DisabledException e) { + throw new RuntimeException("用户已被禁用"); + } + } +} diff --git a/src/main/java/com/supervision/service/impl/OllamaCallServiceImpl.java b/src/main/java/com/supervision/service/impl/OllamaCallServiceImpl.java new file mode 100644 index 0000000..d6a5658 --- /dev/null +++ b/src/main/java/com/supervision/service/impl/OllamaCallServiceImpl.java @@ -0,0 +1,48 @@ +package com.supervision.service.impl; + +import com.supervision.service.AiCallService; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.embedding.Embedding; +import org.springframework.ai.embedding.EmbeddingRequest; +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.ollama.OllamaChatModel; +import org.springframework.ai.ollama.OllamaEmbeddingModel; +import org.springframework.stereotype.Service; +import reactor.core.publisher.Flux; + +import java.util.List; + +@Slf4j +@Service +@RequiredArgsConstructor +public class OllamaCallServiceImpl implements AiCallService { + + private final OllamaChatModel ollamaChatModel; + + private final OllamaEmbeddingModel embeddingModel; + @Override + public String call(String prompt) { + + return ollamaChatModel.call(prompt); + } + + @Override + public Flux stream(Prompt prompt) { + return ollamaChatModel.stream(prompt); + } + + public Embedding embedding(String text) { + + EmbeddingResponse embeddingResponse = embeddingModel.call(new EmbeddingRequest(List.of(text),null)); + return embeddingResponse.getResult(); + } + + @Override + public List embedding(List texts) { + EmbeddingResponse embeddingResponse = embeddingModel.call(new EmbeddingRequest(texts,null)); + return embeddingResponse.getResults(); + } +} diff --git a/src/main/java/com/supervision/service/impl/UserDetailsServiceImpl.java b/src/main/java/com/supervision/service/impl/UserDetailsServiceImpl.java index 55436fb..5d3b71c 100644 --- a/src/main/java/com/supervision/service/impl/UserDetailsServiceImpl.java +++ b/src/main/java/com/supervision/service/impl/UserDetailsServiceImpl.java @@ -1,11 +1,11 @@ package com.supervision.service.impl; import com.supervision.domain.SysUser; +import com.supervision.domain.UserDetail; import com.supervision.service.SysUserService; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.security.core.GrantedAuthority; -import org.springframework.security.core.userdetails.User; import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.core.userdetails.UserDetailsService; import org.springframework.security.core.userdetails.UsernameNotFoundException; @@ -33,6 +33,6 @@ public class UserDetailsServiceImpl implements UserDetailsService { if (sysUser == null) { throw new UsernameNotFoundException("用户不存在: " + username); } - return new User(sysUser.getUserName(), sysUser.getPassword(), authorities); + return new UserDetail(sysUser.getId(),sysUser.getUserName(), sysUser.getPassword(), authorities); } } diff --git a/src/main/java/com/supervision/util/UserUtil.java b/src/main/java/com/supervision/util/UserUtil.java new file mode 100644 index 0000000..6fb0591 --- /dev/null +++ b/src/main/java/com/supervision/util/UserUtil.java @@ -0,0 +1,23 @@ +package com.supervision.util; + +import com.supervision.domain.UserDetail; +import com.supervision.exception.UnauthorizedException; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextHolder; + +public class UserUtil { + + public static UserDetail currentUser() { + SecurityContext context = SecurityContextHolder.getContext(); + if (null == context) { + throw new UnauthorizedException("未登录或登录已过期,请重新登录"); + } + Authentication authentication = context.getAuthentication(); + Object principal = authentication.getPrincipal(); + if ("anonymousUser".equals(principal)) { + throw new UnauthorizedException("未登录或登录已过期,请重新登录"); + } + return (UserDetail) authentication.getPrincipal(); + } +} diff --git a/src/main/java/com/supervision/vo/LoginReqVo.java b/src/main/java/com/supervision/vo/LoginReqVo.java new file mode 100644 index 0000000..0f3f2be --- /dev/null +++ b/src/main/java/com/supervision/vo/LoginReqVo.java @@ -0,0 +1,11 @@ +package com.supervision.vo; + +import lombok.Data; + +@Data +public class LoginReqVo { + + private String username; + + private String password; +} diff --git a/src/main/resources/application.yml b/src/main/resources/application.yml index babaa9e..7ed20fd 100644 --- a/src/main/resources/application.yml +++ b/src/main/resources/application.yml @@ -15,10 +15,30 @@ spring: multipart: max-file-size: 10MB max-request-size: 100MB + ai: + openai: + baseUrl: https://api.deepseek.com + apiKey: sk-0b2c506c47e74594b5361c0f6844fd25 + chat: + options: + model: deepseek-chat + ollama: + baseUrl: http://192.168.10.70:11434 + chat: + model: qwen3:30b-a3b + #model: qwen3:32b + options: + max_tokens: 51200 + top_p: 0.9 + top_k: 40 + temperature: 0.7 + timeout: 180000 + embedding: + model: dengcao/Qwen3-Embedding-0.6B:F16 mybatis-plus: mapper-locations: classpath*:mapper/*.xml configuration: log-impl: org.apache.ibatis.logging.stdout.StdOutImpl jwt: secret: "DlHaPUePiN6MyvpMpsMq/t6swzMHqtrRFd2YnofKz4k=" # JWT密钥 使用官方推荐方式生成 Base64.getEncoder().encodeToString(Keys.secretKeyFor(SignatureAlgorithm.HS256).getEncoded()); - expiration: 3600000 # 1小时:3600000 1天:86400000 \ No newline at end of file + expiration: 86400000 # 1小时:3600000 1天:86400000 \ No newline at end of file diff --git a/src/main/resources/mapper/SysByteArrayMapper.xml b/src/main/resources/mapper/SysByteArrayMapper.xml index 3fbc650..05c6e60 100644 --- a/src/main/resources/mapper/SysByteArrayMapper.xml +++ b/src/main/resources/mapper/SysByteArrayMapper.xml @@ -6,7 +6,7 @@ - + diff --git a/src/main/resources/mapper/SysUserMapper.xml b/src/main/resources/mapper/SysUserMapper.xml index f843e1a..f69fe43 100644 --- a/src/main/resources/mapper/SysUserMapper.xml +++ b/src/main/resources/mapper/SysUserMapper.xml @@ -6,19 +6,16 @@ - - - - id,app_id,user_name, - password,status,create_user_id, - create_time,update_time,update_user_id + id,user_name, + password,status, + create_time,update_time diff --git a/src/test/java/com/supervision/PlatformApplicationTest.java b/src/test/java/com/supervision/PlatformApplicationTest.java new file mode 100644 index 0000000..410b146 --- /dev/null +++ b/src/test/java/com/supervision/PlatformApplicationTest.java @@ -0,0 +1,29 @@ +package com.supervision; + +import com.supervision.domain.SysByteArray; +import com.supervision.service.SysByteArrayService; +import lombok.extern.slf4j.Slf4j; +import org.junit.jupiter.api.Test; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; + +@Slf4j +@SpringBootTest +public class PlatformApplicationTest { + + @Autowired + private SysByteArrayService sysByteArrayService; + @Test + void saveByteArray() { + SysByteArray sysByteArray = new SysByteArray(); + sysByteArray.setBytes(new byte[]{1,2,3}); + boolean save = sysByteArrayService.save(sysByteArray); + System.out.println("保存结果: " + save); + } + + @Test + void queryByteArray() { + SysByteArray sysByteArray = sysByteArrayService.getById("1945676008645058562"); + log.info("查询结果: {}", sysByteArray); + } +}