diff --git a/src/main/java/com/supervision/chat/client/LangChainChatConfig.java b/src/main/java/com/supervision/chat/client/LangChainChatConfig.java index c2e7dab..60e3456 100644 --- a/src/main/java/com/supervision/chat/client/LangChainChatConfig.java +++ b/src/main/java/com/supervision/chat/client/LangChainChatConfig.java @@ -1,21 +1,14 @@ package com.supervision.chat.client; -import cn.hutool.core.util.ArrayUtil; -import org.checkerframework.checker.units.qual.C; +import com.supervision.chat.client.dto.ChatResConverter; import org.springframework.beans.factory.annotation.Value; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import org.springframework.web.client.RestClient; import org.springframework.web.client.RestTemplate; -import org.springframework.web.client.support.RestClientAdapter; import org.springframework.web.client.support.RestTemplateAdapter; import org.springframework.web.service.invoker.HttpServiceProxyFactory; import org.springframework.web.util.DefaultUriBuilderFactory; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.concurrent.CompletableFuture; - @Configuration public class LangChainChatConfig { @@ -26,18 +19,18 @@ public class LangChainChatConfig { @Bean public LangChainChatService langChainChatClient() { -// RestClientAdapter adapter = RestClientAdapter.create(restClient); -// HttpServiceProxyFactory factory = HttpServiceProxyFactory.builderFor(adapter).build(); -// -// return factory.createClient(LangChainChatService.class); - RestTemplate restTemplate = new RestTemplate(); restTemplate.setUriTemplateHandler(new DefaultUriBuilderFactory(LangChainChatClientUrl)); + + restTemplate.getMessageConverters().add(new ChatResConverter()); + RestTemplateAdapter adapter = RestTemplateAdapter.create(restTemplate); HttpServiceProxyFactory factory = HttpServiceProxyFactory.builderFor(adapter).build(); - return factory.createClient(LangChainChatService.class); } + + + } diff --git a/src/main/java/com/supervision/chat/client/LangChainChatService.java b/src/main/java/com/supervision/chat/client/LangChainChatService.java index da5b252..aa7ea8e 100644 --- a/src/main/java/com/supervision/chat/client/LangChainChatService.java +++ b/src/main/java/com/supervision/chat/client/LangChainChatService.java @@ -3,9 +3,13 @@ package com.supervision.chat.client; import com.supervision.chat.client.dto.CreateBaseDTO; import com.supervision.chat.client.dto.DeleteFileDTO; import com.supervision.chat.client.dto.LangChainChatRes; -import org.springframework.core.io.Resource; +import com.supervision.chat.client.dto.chat.ChatReqDTO; +import com.supervision.chat.client.dto.chat.ChatResDTO; +import com.supervision.common.domain.R; import org.springframework.http.MediaType; -import org.springframework.web.bind.annotation.*; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RequestParam; +import org.springframework.web.bind.annotation.RequestPart; import org.springframework.web.multipart.MultipartFile; import org.springframework.web.service.annotation.GetExchange; import org.springframework.web.service.annotation.HttpExchange; @@ -19,7 +23,7 @@ public interface LangChainChatService { * @param createBaseDTO 知识库对象 * @return 结果 */ - @PostExchange(url = "create_knowledge_base", contentType = MediaType.APPLICATION_JSON_VALUE) + @PostExchange(url = "/knowledge_base/create_knowledge_base", contentType = MediaType.APPLICATION_JSON_VALUE) LangChainChatRes createBase(@RequestBody CreateBaseDTO createBaseDTO); /** @@ -36,7 +40,7 @@ public interface LangChainChatService { * @param docs {"test.txt":[{"page_content":"custom doc","metadata":{},"type":"Document"}]} 固定值 * @return 调用的结果 */ - @PostExchange(url = "upload_docs", contentType = MediaType.MULTIPART_FORM_DATA_VALUE) + @PostExchange(url = "/knowledge_base/upload_docs", contentType = MediaType.MULTIPART_FORM_DATA_VALUE) LangChainChatRes uploadFile(@RequestPart String knowledge_base_name, @RequestPart MultipartFile files, @RequestPart String text_splitter_type, @@ -53,14 +57,21 @@ public interface LangChainChatService { * @param deleteFileDTO 删除的对象 * @return 返回结果 */ - @PostExchange(url = "delete_docs", contentType = MediaType.APPLICATION_JSON_VALUE) + @PostExchange(url = "/knowledge_base/delete_docs", contentType = MediaType.APPLICATION_JSON_VALUE) LangChainChatRes deleteFile(@RequestBody DeleteFileDTO deleteFileDTO); - @GetExchange(url = "list_files") + @GetExchange(url = "/knowledge_base/list_files") LangChainChatRes queryFileList(@RequestParam String knowledge_base_name); - @PostExchange(url = "/delete_knowledge_base", contentType = MediaType.APPLICATION_JSON_VALUE) + @PostExchange(url = "/knowledge_base/delete_knowledge_base", contentType = MediaType.APPLICATION_JSON_VALUE) LangChainChatRes deleteBase(@RequestBody String knowledge_base_name); + @PostExchange(url = "/comDictionary/queryByType", contentType = MediaType.APPLICATION_JSON_VALUE) + R<?> findDictionaryListByType(); + + + @PostExchange(url = "/chat/knowledge_base_chat", contentType = MediaType.APPLICATION_JSON_VALUE) + ChatResDTO chat(@RequestBody ChatReqDTO chatReq); + } diff --git a/src/main/java/com/supervision/chat/client/dto/ChatResConverter.java b/src/main/java/com/supervision/chat/client/dto/ChatResConverter.java new file mode 100644 index 0000000..649d81e --- /dev/null +++ b/src/main/java/com/supervision/chat/client/dto/ChatResConverter.java @@ -0,0 +1,57 @@ +package com.supervision.chat.client.dto; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.supervision.chat.client.dto.chat.ChatResDTO; +import org.springframework.http.HttpInputMessage; +import org.springframework.http.HttpOutputMessage; +import org.springframework.http.MediaType; +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.http.converter.HttpMessageNotReadableException; +import org.springframework.http.converter.HttpMessageNotWritableException; +import org.springframework.util.StreamUtils; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.List; + +/** + * @description 这个转换类写的很粗糙,只进行了解析处理,后续有需要再继续处理 + * @param + */ +public class ChatResConverter implements HttpMessageConverter { + + private final ObjectMapper objectMapper = new ObjectMapper(); + @Override + public boolean canRead(Class clazz, MediaType mediaType) { + + return clazz.equals(ChatResDTO.class); + } + + @Override + public boolean canWrite(Class clazz, MediaType mediaType) { + return false; + } + + @Override + public List<MediaType> getSupportedMediaTypes() { + // text/event-stream + return List.of(new MediaType("text", "event-stream")); + } + + @Override + public Object read(Class clazz, HttpInputMessage inputMessage) throws IOException, HttpMessageNotReadableException { + String body = StreamUtils.copyToString(inputMessage.getBody(), StandardCharsets.UTF_8); + + // 去除 "data:" 前缀 + if (body.startsWith("data:")) { + body = body.substring(5).trim(); + } + // 将去除前缀后的内容解析为 ResponseData 对象 + return objectMapper.readValue(body, ChatResDTO.class); + } + + @Override + public void write(Object o, MediaType contentType, HttpOutputMessage outputMessage) throws IOException, HttpMessageNotWritableException { + + } +} diff --git a/src/main/java/com/supervision/chat/client/dto/chat/ChatReqDTO.java b/src/main/java/com/supervision/chat/client/dto/chat/ChatReqDTO.java new file mode 100644 index 0000000..b8f7e9b --- /dev/null +++ b/src/main/java/com/supervision/chat/client/dto/chat/ChatReqDTO.java @@ -0,0 +1,63 @@ +package com.supervision.chat.client.dto.chat; + +import com.supervision.police.vo.ChatReqVO; +import io.swagger.v3.oas.annotations.media.Schema; +import lombok.Data; +import java.util.List; + +@Data +public class ChatReqDTO { + + @Schema(description = "用户输入的文本") + private String query; + + @Schema(description = "知识库名称") + private String knowledge_base_name; + + @Schema(description = "匹配向量数") + private int top_k; + + @Schema(description = "知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右") + private int score_threshold; + + @Schema(description = "历史会话") + private List<History> history; + + @Schema(description = "流式输出") + private boolean stream; + + @Schema(description = "LLM 模型名称") + private String model_name; + + @Schema(description = "LLM 模型温度") + private double temperature; + + @Schema(description = "限制LLM生成Token数量,默认None代表模型最大值") + private int max_tokens; + + @Schema(description = "使用的prompt模板名称") + private String prompt_name; + + + private static ChatReqDTO defaultChatReq(){ + ChatReqDTO chatReq = new ChatReqDTO(); + chatReq.setQuery("裴金禄出生日期"); + chatReq.setKnowledge_base_name("裴金禄笔录-薛庆坤测试"); + chatReq.setTop_k(3); + chatReq.setScore_threshold(1); + chatReq.setStream(false); + chatReq.setModel_name("openai-api"); + chatReq.setTemperature(0.7); + chatReq.setMax_tokens(0); + chatReq.setPrompt_name("default"); + return chatReq; + } + + public static ChatReqDTO create(String query, String knowledge_base_name,List<History> history){ + ChatReqDTO chatReq = defaultChatReq(); + chatReq.setQuery(query); + chatReq.setKnowledge_base_name(knowledge_base_name); + return chatReq; + } + +} diff --git a/src/main/java/com/supervision/chat/client/dto/chat/ChatResDTO.java b/src/main/java/com/supervision/chat/client/dto/chat/ChatResDTO.java new file mode 100644 index 0000000..8927c5b --- /dev/null +++ b/src/main/java/com/supervision/chat/client/dto/chat/ChatResDTO.java @@ -0,0 +1,13 @@ +package com.supervision.chat.client.dto.chat; + +import lombok.Data; + +import java.util.List; + +@Data +public class ChatResDTO { + + private String answer; + + private List<String> docs; +} diff --git a/src/main/java/com/supervision/chat/client/dto/chat/History.java b/src/main/java/com/supervision/chat/client/dto/chat/History.java new file mode 100644 index 0000000..a3eb3b3 --- /dev/null +++ b/src/main/java/com/supervision/chat/client/dto/chat/History.java @@ -0,0 +1,12 @@ +package com.supervision.chat.client.dto.chat; + +import lombok.Data; + +@Data +public class History { + + private String role; + + private String content; + +} \ No newline at end of file diff --git a/src/main/java/com/supervision/police/controller/ChatController.java b/src/main/java/com/supervision/police/controller/ChatController.java new file mode 100644 index 0000000..06fbc9f --- /dev/null +++ b/src/main/java/com/supervision/police/controller/ChatController.java @@ -0,0 +1,31 @@ +package com.supervision.police.controller; + +import com.supervision.common.domain.R; +import com.supervision.police.service.ChatService; +import com.supervision.police.vo.ChatReqVO; +import com.supervision.police.vo.ChatResVO; +import io.swagger.v3.oas.annotations.tags.Tag; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; + +@Slf4j +@RestController +@Tag(name = "对话接口") +@RequiredArgsConstructor +@RequestMapping("/robot") +public class ChatController { + + private final ChatService chatService; + + @PostMapping("/chat") + public R<ChatResVO> chat(@RequestBody ChatReqVO chatReqVO) { + + ChatResVO chatResVO = chatService.chat(chatReqVO); + return R.ok(chatResVO); + } + +} diff --git a/src/main/java/com/supervision/police/controller/ModelCaseController.java b/src/main/java/com/supervision/police/controller/ModelCaseController.java index 6ff08b9..a71c2e6 100644 --- a/src/main/java/com/supervision/police/controller/ModelCaseController.java +++ b/src/main/java/com/supervision/police/controller/ModelCaseController.java @@ -7,6 +7,7 @@ import com.supervision.police.domain.ModelCase; import com.supervision.police.dto.ModelCaseBase; import com.supervision.police.dto.ModelCaseDTO; import com.supervision.police.service.ModelCaseService; +import com.supervision.police.vo.ModelCaseVO; import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Parameter; import lombok.RequiredArgsConstructor; @@ -33,7 +34,7 @@ public class ModelCaseController { */ @Operation(summary = "查询案件列表") @PostMapping("/queryList") - public R<IPage<ModelCaseDTO>> queryList(@RequestBody ModelCase modelCase, + public R<IPage<ModelCaseDTO>> queryList(@RequestBody ModelCaseVO modelCase, @RequestParam(required = false, defaultValue = "1") Integer page, @RequestParam(required = false, defaultValue = "20") Integer size) { IPage<ModelCaseDTO> modelCaseDTOIPage = modelCaseService.queryList(modelCase, page, size); diff --git a/src/main/java/com/supervision/police/mapper/ModelCaseMapper.java b/src/main/java/com/supervision/police/mapper/ModelCaseMapper.java index 9ec79b4..ea6df73 100644 --- a/src/main/java/com/supervision/police/mapper/ModelCaseMapper.java +++ b/src/main/java/com/supervision/police/mapper/ModelCaseMapper.java @@ -5,6 +5,7 @@ import com.baomidou.mybatisplus.core.metadata.IPage; import com.supervision.police.domain.ModelCase; import com.supervision.police.dto.AtomicIndexDTO; import com.supervision.police.dto.IndexDetail; +import com.supervision.police.vo.ModelCaseVO; import org.apache.ibatis.annotations.Param; import java.util.List; @@ -17,7 +18,7 @@ import java.util.List; */ public interface ModelCaseMapper extends BaseMapper<ModelCase> { - IPage<ModelCase> selectAll(IPage<ModelCase> iPage, ModelCase modelCase); + IPage<ModelCase> selectAll(IPage<ModelCaseVO> iPage, ModelCaseVO modelCase); int selectMaxIndex(); diff --git a/src/main/java/com/supervision/police/service/ChatService.java b/src/main/java/com/supervision/police/service/ChatService.java new file mode 100644 index 0000000..3b0fc2a --- /dev/null +++ b/src/main/java/com/supervision/police/service/ChatService.java @@ -0,0 +1,8 @@ +package com.supervision.police.service; + +import com.supervision.police.vo.ChatReqVO; +import com.supervision.police.vo.ChatResVO; + +public interface ChatService { + ChatResVO chat(ChatReqVO chatReqVO); +} diff --git a/src/main/java/com/supervision/police/service/ModelCaseService.java b/src/main/java/com/supervision/police/service/ModelCaseService.java index bba2184..b012697 100644 --- a/src/main/java/com/supervision/police/service/ModelCaseService.java +++ b/src/main/java/com/supervision/police/service/ModelCaseService.java @@ -7,6 +7,7 @@ import com.supervision.police.domain.CasePerson; import com.supervision.police.domain.ModelCase; import com.supervision.police.dto.ModelCaseBase; import com.supervision.police.dto.ModelCaseDTO; +import com.supervision.police.vo.ModelCaseVO; import org.springframework.web.multipart.MultipartFile; import java.util.List; @@ -19,7 +20,7 @@ import java.util.List; */ public interface ModelCaseService extends IService<ModelCase> { - IPage<ModelCaseDTO> queryList(ModelCase modelCase, Integer page, Integer size); + IPage<ModelCaseDTO> queryList(ModelCaseVO modelCase, Integer page, Integer size); R<?> checkCaseNo(String caseNo,String caseId); diff --git a/src/main/java/com/supervision/police/service/impl/ChatServiceImpl.java b/src/main/java/com/supervision/police/service/impl/ChatServiceImpl.java new file mode 100644 index 0000000..3ebf301 --- /dev/null +++ b/src/main/java/com/supervision/police/service/impl/ChatServiceImpl.java @@ -0,0 +1,46 @@ +package com.supervision.police.service.impl; + +import cn.hutool.core.lang.Assert; +import com.supervision.chat.client.LangChainChatService; +import com.supervision.chat.client.dto.chat.ChatReqDTO; +import com.supervision.chat.client.dto.chat.ChatResDTO; +import com.supervision.police.domain.ModelCase; +import com.supervision.police.service.ChatService; +import com.supervision.police.service.ModelCaseService; +import com.supervision.police.vo.ChatReqVO; +import com.supervision.police.vo.ChatResVO; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Service; + +@Slf4j +@Service +@RequiredArgsConstructor +public class ChatServiceImpl implements ChatService { + + private final LangChainChatService langChainChatService; + + + private final ModelCaseService modelCaseService; + + + @Override + public ChatResVO chat(ChatReqVO chatReqVO) { + + + Assert.notEmpty(chatReqVO.getQuery(), "query 不能为空"); + Assert.notEmpty(chatReqVO.getCaseId(), "caseId 不能为空"); + + ModelCase modelCase = modelCaseService.getById(chatReqVO.getCaseId()); + Assert.notNull(modelCase, "案件不存在"); + + log.info("chat: caseNo:{},query{}", modelCase.getCaseNo(), chatReqVO.getQuery()); + + ChatResDTO chat = langChainChatService.chat( + ChatReqDTO.create(chatReqVO.getQuery(), modelCase.getCaseNo(),chatReqVO.getHistory())); + + log.info("chat: caseNo:{},query{},answer:{}", modelCase.getCaseNo(), chatReqVO.getQuery(),chat.getAnswer()); + + return new ChatResVO(chat); + } +} diff --git a/src/main/java/com/supervision/police/service/impl/ModelCaseServiceImpl.java b/src/main/java/com/supervision/police/service/impl/ModelCaseServiceImpl.java index 75526fd..19ad6ac 100644 --- a/src/main/java/com/supervision/police/service/impl/ModelCaseServiceImpl.java +++ b/src/main/java/com/supervision/police/service/impl/ModelCaseServiceImpl.java @@ -32,6 +32,7 @@ import com.supervision.police.service.CasePersonService; import com.supervision.police.service.CaseStatusManageService; import com.supervision.police.service.ComDictionaryService; import com.supervision.police.service.ModelCaseService; +import com.supervision.police.vo.ModelCaseVO; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Service; @@ -72,7 +73,7 @@ public class ModelCaseServiceImpl extends ServiceImpl<ModelCaseMapper, ModelCase * @return */ @Override - public IPage<ModelCaseDTO> queryList(ModelCase modelCase, Integer page, Integer size) { + public IPage<ModelCaseDTO> queryList(ModelCaseVO modelCase, Integer page, Integer size) { IPage<ModelCase> modelCaseIPage = modelCaseMapper.selectAll(Page.of(page, size), modelCase); if (CollUtil.isEmpty(modelCaseIPage.getRecords())) { diff --git a/src/main/java/com/supervision/police/vo/ChatReqVO.java b/src/main/java/com/supervision/police/vo/ChatReqVO.java new file mode 100644 index 0000000..a67ed1f --- /dev/null +++ b/src/main/java/com/supervision/police/vo/ChatReqVO.java @@ -0,0 +1,20 @@ +package com.supervision.police.vo; + +import com.supervision.chat.client.dto.chat.History; +import io.swagger.v3.oas.annotations.media.Schema; +import lombok.Data; + +import java.util.List; + +@Data +public class ChatReqVO { + + @Schema(description = "案件id") + private String caseId; + + @Schema(description = "用户输入的文本") + private String query; + + @Schema(description = "历史会话") + private List<History> history; +} diff --git a/src/main/java/com/supervision/police/vo/ChatResVO.java b/src/main/java/com/supervision/police/vo/ChatResVO.java new file mode 100644 index 0000000..5e60f0b --- /dev/null +++ b/src/main/java/com/supervision/police/vo/ChatResVO.java @@ -0,0 +1,27 @@ +package com.supervision.police.vo; + +import com.supervision.chat.client.dto.chat.ChatResDTO; +import lombok.Data; + +import java.util.List; +import java.util.Objects; + +@Data +public class ChatResVO { + + private String answer; + + private List<String> docs; + + + public ChatResVO() { + } + + public ChatResVO(ChatResDTO chatResDTO) { + if (Objects.isNull(chatResDTO)){ + return; + } + this.answer = chatResDTO.getAnswer(); + this.docs = chatResDTO.getDocs(); + } +} diff --git a/src/main/java/com/supervision/police/vo/ModelCaseVO.java b/src/main/java/com/supervision/police/vo/ModelCaseVO.java new file mode 100644 index 0000000..501cdfb --- /dev/null +++ b/src/main/java/com/supervision/police/vo/ModelCaseVO.java @@ -0,0 +1,45 @@ +package com.supervision.police.vo; + +import com.fasterxml.jackson.annotation.JsonFormat; +import io.swagger.v3.oas.annotations.media.Schema; +import lombok.Data; + +import java.time.LocalDateTime; +import java.util.List; + + +@Data +public class ModelCaseVO { + + @Schema(description = "主键") + private String id; + + + @Schema(description = "案件编号") + private String caseNo; + + + @Schema(description = "案件名称") + private String caseName; + + + @Schema(description = "案件类型") + private String caseType; + + @Schema(description = "认定结果") + private List<String> identifyResult; + + @Schema(description = "行为人") + private String lawActor; + + + @Schema(description = "当事人") + private String lawParty; + + @JsonFormat(pattern="yyyy-MM-dd HH:mm:ss",timezone = "GMT+8") + private LocalDateTime updateStartTime; + + @JsonFormat(pattern="yyyy-MM-dd HH:mm:ss",timezone = "GMT+8") + private LocalDateTime updateEndTime; + +} diff --git a/src/main/resources/application-dev.yml b/src/main/resources/application-dev.yml index ff24234..5b6f462 100644 --- a/src/main/resources/application-dev.yml +++ b/src/main/resources/application-dev.yml @@ -77,5 +77,4 @@ logging: org.springframework.ai: TRACE langChain-chat: - url: http://113.128.242.110:7861/knowledge_base/ -# url: http://192.168.10.27:8097/fu-hsi-server/ \ No newline at end of file + url: http://113.128.242.110:7861 diff --git a/src/main/resources/application-test.yml b/src/main/resources/application-test.yml index 1357c2b..43f18c6 100644 --- a/src/main/resources/application-test.yml +++ b/src/main/resources/application-test.yml @@ -66,4 +66,4 @@ logging: org.springframework.ai: TRACE langChain-chat: - url: http://113.128.242.110:7861/knowledge_base/ \ No newline at end of file + url: http://113.128.242.110:7861 \ No newline at end of file diff --git a/src/main/resources/mapper/ModelCaseMapper.xml b/src/main/resources/mapper/ModelCaseMapper.xml index 56d08c4..717187b 100644 --- a/src/main/resources/mapper/ModelCaseMapper.xml +++ b/src/main/resources/mapper/ModelCaseMapper.xml @@ -17,7 +17,10 @@ and FIND_IN_SET(#{modelCase.caseType}, case_type) > 0 </if> <if test="modelCase.identifyResult != null and modelCase.identifyResult != ''"> - and identify_result = #{modelCase.identifyResult} + and identify_result IN + <foreach item="item" collection="modelCase.identifyResult" open="(" separator="," close=")"> + #{item} + </foreach> </if> <if test="modelCase.lawActor != null and modelCase.lawActor != ''"> and law_actor like concat('%', #{modelCase.lawActor}, '%') diff --git a/src/test/java/com/supervision/demo/ModelIndexTest.java b/src/test/java/com/supervision/demo/ModelIndexTest.java index 1f1163f..e600b40 100644 --- a/src/test/java/com/supervision/demo/ModelIndexTest.java +++ b/src/test/java/com/supervision/demo/ModelIndexTest.java @@ -7,6 +7,9 @@ import cn.hutool.json.JSONUtil; import cn.hutool.poi.excel.ExcelReader; import cn.hutool.poi.excel.ExcelUtil; import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; +import com.supervision.chat.client.LangChainChatService; +import com.supervision.chat.client.dto.chat.ChatReqDTO; +import com.supervision.chat.client.dto.chat.ChatResDTO; import com.supervision.police.domain.CasePerson; import com.supervision.police.domain.ModelAtomicIndex; import com.supervision.police.domain.ModelIndex; @@ -75,6 +78,44 @@ public class ModelIndexTest { } } + @Autowired + private LangChainChatService langChainChatService; + + @Test + public void test1(){ + + /** + * { + * "query": "裴金禄出生日期", + * "knowledge_base_name": "裴金禄笔录-薛庆坤测试", + * "top_k": 3, + * "score_threshold": 1, + * "history": [ + * + * ], + * "stream": false, + * "model_name": "openai-api", + * "temperature": 0.7, + * "max_tokens": 0, + * "prompt_name": "default" + * } + */ + ChatReqDTO chatReq = new ChatReqDTO(); + chatReq.setQuery("裴金禄出生日期"); + chatReq.setKnowledge_base_name("裴金禄笔录-薛庆坤测试"); + chatReq.setTop_k(3); + chatReq.setScore_threshold(1); + chatReq.setStream(false); + chatReq.setModel_name("openai-api"); + chatReq.setTemperature(0.7); + chatReq.setMax_tokens(0); + chatReq.setPrompt_name("default"); + ChatResDTO chat = langChainChatService.chat(chatReq); + + + System.out.println(chat); + } + /** * 生成案件人员脚本 */