1. 添加对话接口

2. 修复bug
topo_dev
xueqingkun 10 months ago
parent 0e8acc7aa8
commit 0a00a61fcb

@ -1,21 +1,14 @@
package com.supervision.chat.client; package com.supervision.chat.client;
import cn.hutool.core.util.ArrayUtil; import com.supervision.chat.client.dto.ChatResConverter;
import org.checkerframework.checker.units.qual.C;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import org.springframework.web.client.RestClient;
import org.springframework.web.client.RestTemplate; import org.springframework.web.client.RestTemplate;
import org.springframework.web.client.support.RestClientAdapter;
import org.springframework.web.client.support.RestTemplateAdapter; import org.springframework.web.client.support.RestTemplateAdapter;
import org.springframework.web.service.invoker.HttpServiceProxyFactory; import org.springframework.web.service.invoker.HttpServiceProxyFactory;
import org.springframework.web.util.DefaultUriBuilderFactory; import org.springframework.web.util.DefaultUriBuilderFactory;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.concurrent.CompletableFuture;
@Configuration @Configuration
public class LangChainChatConfig { public class LangChainChatConfig {
@ -26,18 +19,18 @@ public class LangChainChatConfig {
@Bean @Bean
public LangChainChatService langChainChatClient() { public LangChainChatService langChainChatClient() {
// RestClientAdapter adapter = RestClientAdapter.create(restClient);
// HttpServiceProxyFactory factory = HttpServiceProxyFactory.builderFor(adapter).build();
//
// return factory.createClient(LangChainChatService.class);
RestTemplate restTemplate = new RestTemplate(); RestTemplate restTemplate = new RestTemplate();
restTemplate.setUriTemplateHandler(new DefaultUriBuilderFactory(LangChainChatClientUrl)); restTemplate.setUriTemplateHandler(new DefaultUriBuilderFactory(LangChainChatClientUrl));
restTemplate.getMessageConverters().add(new ChatResConverter());
RestTemplateAdapter adapter = RestTemplateAdapter.create(restTemplate); RestTemplateAdapter adapter = RestTemplateAdapter.create(restTemplate);
HttpServiceProxyFactory factory = HttpServiceProxyFactory.builderFor(adapter).build(); HttpServiceProxyFactory factory = HttpServiceProxyFactory.builderFor(adapter).build();
return factory.createClient(LangChainChatService.class); return factory.createClient(LangChainChatService.class);
} }
} }

@ -3,9 +3,13 @@ package com.supervision.chat.client;
import com.supervision.chat.client.dto.CreateBaseDTO; import com.supervision.chat.client.dto.CreateBaseDTO;
import com.supervision.chat.client.dto.DeleteFileDTO; import com.supervision.chat.client.dto.DeleteFileDTO;
import com.supervision.chat.client.dto.LangChainChatRes; import com.supervision.chat.client.dto.LangChainChatRes;
import org.springframework.core.io.Resource; import com.supervision.chat.client.dto.chat.ChatReqDTO;
import com.supervision.chat.client.dto.chat.ChatResDTO;
import com.supervision.common.domain.R;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
import org.springframework.web.bind.annotation.*; import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RequestPart;
import org.springframework.web.multipart.MultipartFile; import org.springframework.web.multipart.MultipartFile;
import org.springframework.web.service.annotation.GetExchange; import org.springframework.web.service.annotation.GetExchange;
import org.springframework.web.service.annotation.HttpExchange; import org.springframework.web.service.annotation.HttpExchange;
@ -19,7 +23,7 @@ public interface LangChainChatService {
* @param createBaseDTO * @param createBaseDTO
* @return * @return
*/ */
@PostExchange(url = "create_knowledge_base", contentType = MediaType.APPLICATION_JSON_VALUE) @PostExchange(url = "/knowledge_base/create_knowledge_base", contentType = MediaType.APPLICATION_JSON_VALUE)
LangChainChatRes createBase(@RequestBody CreateBaseDTO createBaseDTO); LangChainChatRes createBase(@RequestBody CreateBaseDTO createBaseDTO);
/** /**
@ -36,7 +40,7 @@ public interface LangChainChatService {
* @param docs {"test.txt":[{"page_content":"custom doc","metadata":{},"type":"Document"}]} * @param docs {"test.txt":[{"page_content":"custom doc","metadata":{},"type":"Document"}]}
* @return * @return
*/ */
@PostExchange(url = "upload_docs", contentType = MediaType.MULTIPART_FORM_DATA_VALUE) @PostExchange(url = "/knowledge_base/upload_docs", contentType = MediaType.MULTIPART_FORM_DATA_VALUE)
LangChainChatRes uploadFile(@RequestPart String knowledge_base_name, LangChainChatRes uploadFile(@RequestPart String knowledge_base_name,
@RequestPart MultipartFile files, @RequestPart MultipartFile files,
@RequestPart String text_splitter_type, @RequestPart String text_splitter_type,
@ -53,14 +57,21 @@ public interface LangChainChatService {
* @param deleteFileDTO * @param deleteFileDTO
* @return * @return
*/ */
@PostExchange(url = "delete_docs", contentType = MediaType.APPLICATION_JSON_VALUE) @PostExchange(url = "/knowledge_base/delete_docs", contentType = MediaType.APPLICATION_JSON_VALUE)
LangChainChatRes deleteFile(@RequestBody DeleteFileDTO deleteFileDTO); LangChainChatRes deleteFile(@RequestBody DeleteFileDTO deleteFileDTO);
@GetExchange(url = "list_files") @GetExchange(url = "/knowledge_base/list_files")
LangChainChatRes queryFileList(@RequestParam String knowledge_base_name); LangChainChatRes queryFileList(@RequestParam String knowledge_base_name);
@PostExchange(url = "/delete_knowledge_base", contentType = MediaType.APPLICATION_JSON_VALUE) @PostExchange(url = "/knowledge_base/delete_knowledge_base", contentType = MediaType.APPLICATION_JSON_VALUE)
LangChainChatRes deleteBase(@RequestBody String knowledge_base_name); LangChainChatRes deleteBase(@RequestBody String knowledge_base_name);
@PostExchange(url = "/comDictionary/queryByType", contentType = MediaType.APPLICATION_JSON_VALUE)
R<?> findDictionaryListByType();
@PostExchange(url = "/chat/knowledge_base_chat", contentType = MediaType.APPLICATION_JSON_VALUE)
ChatResDTO chat(@RequestBody ChatReqDTO chatReq);
} }

@ -0,0 +1,57 @@
package com.supervision.chat.client.dto;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.supervision.chat.client.dto.chat.ChatResDTO;
import org.springframework.http.HttpInputMessage;
import org.springframework.http.HttpOutputMessage;
import org.springframework.http.MediaType;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.http.converter.HttpMessageNotReadableException;
import org.springframework.http.converter.HttpMessageNotWritableException;
import org.springframework.util.StreamUtils;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.List;
/**
* @description
* @param
*/
public class ChatResConverter implements HttpMessageConverter {
private final ObjectMapper objectMapper = new ObjectMapper();
@Override
public boolean canRead(Class clazz, MediaType mediaType) {
return clazz.equals(ChatResDTO.class);
}
@Override
public boolean canWrite(Class clazz, MediaType mediaType) {
return false;
}
@Override
public List<MediaType> getSupportedMediaTypes() {
// text/event-stream
return List.of(new MediaType("text", "event-stream"));
}
@Override
public Object read(Class clazz, HttpInputMessage inputMessage) throws IOException, HttpMessageNotReadableException {
String body = StreamUtils.copyToString(inputMessage.getBody(), StandardCharsets.UTF_8);
// 去除 "data:" 前缀
if (body.startsWith("data:")) {
body = body.substring(5).trim();
}
// 将去除前缀后的内容解析为 ResponseData 对象
return objectMapper.readValue(body, ChatResDTO.class);
}
@Override
public void write(Object o, MediaType contentType, HttpOutputMessage outputMessage) throws IOException, HttpMessageNotWritableException {
}
}

@ -0,0 +1,63 @@
package com.supervision.chat.client.dto.chat;
import com.supervision.police.vo.ChatReqVO;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import java.util.List;
@Data
public class ChatReqDTO {
@Schema(description = "用户输入的文本")
private String query;
@Schema(description = "知识库名称")
private String knowledge_base_name;
@Schema(description = "匹配向量数")
private int top_k;
@Schema(description = "知识库匹配相关度阈值取值范围在0-1之间SCORE越小相关度越高取到1相当于不筛选建议设置在0.5左右")
private int score_threshold;
@Schema(description = "历史会话")
private List<History> history;
@Schema(description = "流式输出")
private boolean stream;
@Schema(description = "LLM 模型名称")
private String model_name;
@Schema(description = "LLM 模型温度")
private double temperature;
@Schema(description = "限制LLM生成Token数量默认None代表模型最大值")
private int max_tokens;
@Schema(description = "使用的prompt模板名称")
private String prompt_name;
private static ChatReqDTO defaultChatReq(){
ChatReqDTO chatReq = new ChatReqDTO();
chatReq.setQuery("裴金禄出生日期");
chatReq.setKnowledge_base_name("裴金禄笔录-薛庆坤测试");
chatReq.setTop_k(3);
chatReq.setScore_threshold(1);
chatReq.setStream(false);
chatReq.setModel_name("openai-api");
chatReq.setTemperature(0.7);
chatReq.setMax_tokens(0);
chatReq.setPrompt_name("default");
return chatReq;
}
public static ChatReqDTO create(String query, String knowledge_base_name,List<History> history){
ChatReqDTO chatReq = defaultChatReq();
chatReq.setQuery(query);
chatReq.setKnowledge_base_name(knowledge_base_name);
return chatReq;
}
}

@ -0,0 +1,13 @@
package com.supervision.chat.client.dto.chat;
import lombok.Data;
import java.util.List;
@Data
public class ChatResDTO {
private String answer;
private List<String> docs;
}

@ -0,0 +1,12 @@
package com.supervision.chat.client.dto.chat;
import lombok.Data;
@Data
public class History {
private String role;
private String content;
}

@ -0,0 +1,31 @@
package com.supervision.police.controller;
import com.supervision.common.domain.R;
import com.supervision.police.service.ChatService;
import com.supervision.police.vo.ChatReqVO;
import com.supervision.police.vo.ChatResVO;
import io.swagger.v3.oas.annotations.tags.Tag;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
@Slf4j
@RestController
@Tag(name = "对话接口")
@RequiredArgsConstructor
@RequestMapping("/robot")
public class ChatController {
private final ChatService chatService;
@PostMapping("/chat")
public R<ChatResVO> chat(@RequestBody ChatReqVO chatReqVO) {
ChatResVO chatResVO = chatService.chat(chatReqVO);
return R.ok(chatResVO);
}
}

@ -7,6 +7,7 @@ import com.supervision.police.domain.ModelCase;
import com.supervision.police.dto.ModelCaseBase; import com.supervision.police.dto.ModelCaseBase;
import com.supervision.police.dto.ModelCaseDTO; import com.supervision.police.dto.ModelCaseDTO;
import com.supervision.police.service.ModelCaseService; import com.supervision.police.service.ModelCaseService;
import com.supervision.police.vo.ModelCaseVO;
import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter; import io.swagger.v3.oas.annotations.Parameter;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
@ -33,7 +34,7 @@ public class ModelCaseController {
*/ */
@Operation(summary = "查询案件列表") @Operation(summary = "查询案件列表")
@PostMapping("/queryList") @PostMapping("/queryList")
public R<IPage<ModelCaseDTO>> queryList(@RequestBody ModelCase modelCase, public R<IPage<ModelCaseDTO>> queryList(@RequestBody ModelCaseVO modelCase,
@RequestParam(required = false, defaultValue = "1") Integer page, @RequestParam(required = false, defaultValue = "1") Integer page,
@RequestParam(required = false, defaultValue = "20") Integer size) { @RequestParam(required = false, defaultValue = "20") Integer size) {
IPage<ModelCaseDTO> modelCaseDTOIPage = modelCaseService.queryList(modelCase, page, size); IPage<ModelCaseDTO> modelCaseDTOIPage = modelCaseService.queryList(modelCase, page, size);

@ -5,6 +5,7 @@ import com.baomidou.mybatisplus.core.metadata.IPage;
import com.supervision.police.domain.ModelCase; import com.supervision.police.domain.ModelCase;
import com.supervision.police.dto.AtomicIndexDTO; import com.supervision.police.dto.AtomicIndexDTO;
import com.supervision.police.dto.IndexDetail; import com.supervision.police.dto.IndexDetail;
import com.supervision.police.vo.ModelCaseVO;
import org.apache.ibatis.annotations.Param; import org.apache.ibatis.annotations.Param;
import java.util.List; import java.util.List;
@ -17,7 +18,7 @@ import java.util.List;
*/ */
public interface ModelCaseMapper extends BaseMapper<ModelCase> { public interface ModelCaseMapper extends BaseMapper<ModelCase> {
IPage<ModelCase> selectAll(IPage<ModelCase> iPage, ModelCase modelCase); IPage<ModelCase> selectAll(IPage<ModelCaseVO> iPage, ModelCaseVO modelCase);
int selectMaxIndex(); int selectMaxIndex();

@ -0,0 +1,8 @@
package com.supervision.police.service;
import com.supervision.police.vo.ChatReqVO;
import com.supervision.police.vo.ChatResVO;
public interface ChatService {
ChatResVO chat(ChatReqVO chatReqVO);
}

@ -7,6 +7,7 @@ import com.supervision.police.domain.CasePerson;
import com.supervision.police.domain.ModelCase; import com.supervision.police.domain.ModelCase;
import com.supervision.police.dto.ModelCaseBase; import com.supervision.police.dto.ModelCaseBase;
import com.supervision.police.dto.ModelCaseDTO; import com.supervision.police.dto.ModelCaseDTO;
import com.supervision.police.vo.ModelCaseVO;
import org.springframework.web.multipart.MultipartFile; import org.springframework.web.multipart.MultipartFile;
import java.util.List; import java.util.List;
@ -19,7 +20,7 @@ import java.util.List;
*/ */
public interface ModelCaseService extends IService<ModelCase> { public interface ModelCaseService extends IService<ModelCase> {
IPage<ModelCaseDTO> queryList(ModelCase modelCase, Integer page, Integer size); IPage<ModelCaseDTO> queryList(ModelCaseVO modelCase, Integer page, Integer size);
R<?> checkCaseNo(String caseNo,String caseId); R<?> checkCaseNo(String caseNo,String caseId);

@ -0,0 +1,46 @@
package com.supervision.police.service.impl;
import cn.hutool.core.lang.Assert;
import com.supervision.chat.client.LangChainChatService;
import com.supervision.chat.client.dto.chat.ChatReqDTO;
import com.supervision.chat.client.dto.chat.ChatResDTO;
import com.supervision.police.domain.ModelCase;
import com.supervision.police.service.ChatService;
import com.supervision.police.service.ModelCaseService;
import com.supervision.police.vo.ChatReqVO;
import com.supervision.police.vo.ChatResVO;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
@Slf4j
@Service
@RequiredArgsConstructor
public class ChatServiceImpl implements ChatService {
private final LangChainChatService langChainChatService;
private final ModelCaseService modelCaseService;
@Override
public ChatResVO chat(ChatReqVO chatReqVO) {
Assert.notEmpty(chatReqVO.getQuery(), "query 不能为空");
Assert.notEmpty(chatReqVO.getCaseId(), "caseId 不能为空");
ModelCase modelCase = modelCaseService.getById(chatReqVO.getCaseId());
Assert.notNull(modelCase, "案件不存在");
log.info("chat: caseNo:{},query{}", modelCase.getCaseNo(), chatReqVO.getQuery());
ChatResDTO chat = langChainChatService.chat(
ChatReqDTO.create(chatReqVO.getQuery(), modelCase.getCaseNo(),chatReqVO.getHistory()));
log.info("chat: caseNo:{},query{},answer:{}", modelCase.getCaseNo(), chatReqVO.getQuery(),chat.getAnswer());
return new ChatResVO(chat);
}
}

@ -32,6 +32,7 @@ import com.supervision.police.service.CasePersonService;
import com.supervision.police.service.CaseStatusManageService; import com.supervision.police.service.CaseStatusManageService;
import com.supervision.police.service.ComDictionaryService; import com.supervision.police.service.ComDictionaryService;
import com.supervision.police.service.ModelCaseService; import com.supervision.police.service.ModelCaseService;
import com.supervision.police.vo.ModelCaseVO;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
@ -72,7 +73,7 @@ public class ModelCaseServiceImpl extends ServiceImpl<ModelCaseMapper, ModelCase
* @return * @return
*/ */
@Override @Override
public IPage<ModelCaseDTO> queryList(ModelCase modelCase, Integer page, Integer size) { public IPage<ModelCaseDTO> queryList(ModelCaseVO modelCase, Integer page, Integer size) {
IPage<ModelCase> modelCaseIPage = modelCaseMapper.selectAll(Page.of(page, size), modelCase); IPage<ModelCase> modelCaseIPage = modelCaseMapper.selectAll(Page.of(page, size), modelCase);
if (CollUtil.isEmpty(modelCaseIPage.getRecords())) { if (CollUtil.isEmpty(modelCaseIPage.getRecords())) {

@ -0,0 +1,20 @@
package com.supervision.police.vo;
import com.supervision.chat.client.dto.chat.History;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import java.util.List;
@Data
public class ChatReqVO {
@Schema(description = "案件id")
private String caseId;
@Schema(description = "用户输入的文本")
private String query;
@Schema(description = "历史会话")
private List<History> history;
}

@ -0,0 +1,27 @@
package com.supervision.police.vo;
import com.supervision.chat.client.dto.chat.ChatResDTO;
import lombok.Data;
import java.util.List;
import java.util.Objects;
@Data
public class ChatResVO {
private String answer;
private List<String> docs;
public ChatResVO() {
}
public ChatResVO(ChatResDTO chatResDTO) {
if (Objects.isNull(chatResDTO)){
return;
}
this.answer = chatResDTO.getAnswer();
this.docs = chatResDTO.getDocs();
}
}

@ -0,0 +1,45 @@
package com.supervision.police.vo;
import com.fasterxml.jackson.annotation.JsonFormat;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import java.time.LocalDateTime;
import java.util.List;
@Data
public class ModelCaseVO {
@Schema(description = "主键")
private String id;
@Schema(description = "案件编号")
private String caseNo;
@Schema(description = "案件名称")
private String caseName;
@Schema(description = "案件类型")
private String caseType;
@Schema(description = "认定结果")
private List<String> identifyResult;
@Schema(description = "行为人")
private String lawActor;
@Schema(description = "当事人")
private String lawParty;
@JsonFormat(pattern="yyyy-MM-dd HH:mm:ss",timezone = "GMT+8")
private LocalDateTime updateStartTime;
@JsonFormat(pattern="yyyy-MM-dd HH:mm:ss",timezone = "GMT+8")
private LocalDateTime updateEndTime;
}

@ -77,5 +77,4 @@ logging:
org.springframework.ai: TRACE org.springframework.ai: TRACE
langChain-chat: langChain-chat:
url: http://113.128.242.110:7861/knowledge_base/ url: http://113.128.242.110:7861
# url: http://192.168.10.27:8097/fu-hsi-server/

@ -66,4 +66,4 @@ logging:
org.springframework.ai: TRACE org.springframework.ai: TRACE
langChain-chat: langChain-chat:
url: http://113.128.242.110:7861/knowledge_base/ url: http://113.128.242.110:7861

@ -17,7 +17,10 @@
and FIND_IN_SET(#{modelCase.caseType}, case_type) > 0 and FIND_IN_SET(#{modelCase.caseType}, case_type) > 0
</if> </if>
<if test="modelCase.identifyResult != null and modelCase.identifyResult != ''"> <if test="modelCase.identifyResult != null and modelCase.identifyResult != ''">
and identify_result = #{modelCase.identifyResult} and identify_result IN
<foreach item="item" collection="modelCase.identifyResult" open="(" separator="," close=")">
#{item}
</foreach>
</if> </if>
<if test="modelCase.lawActor != null and modelCase.lawActor != ''"> <if test="modelCase.lawActor != null and modelCase.lawActor != ''">
and law_actor like concat('%', #{modelCase.lawActor}, '%') and law_actor like concat('%', #{modelCase.lawActor}, '%')

@ -7,6 +7,9 @@ import cn.hutool.json.JSONUtil;
import cn.hutool.poi.excel.ExcelReader; import cn.hutool.poi.excel.ExcelReader;
import cn.hutool.poi.excel.ExcelUtil; import cn.hutool.poi.excel.ExcelUtil;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.supervision.chat.client.LangChainChatService;
import com.supervision.chat.client.dto.chat.ChatReqDTO;
import com.supervision.chat.client.dto.chat.ChatResDTO;
import com.supervision.police.domain.CasePerson; import com.supervision.police.domain.CasePerson;
import com.supervision.police.domain.ModelAtomicIndex; import com.supervision.police.domain.ModelAtomicIndex;
import com.supervision.police.domain.ModelIndex; import com.supervision.police.domain.ModelIndex;
@ -75,6 +78,44 @@ public class ModelIndexTest {
} }
} }
@Autowired
private LangChainChatService langChainChatService;
@Test
public void test1(){
/**
* {
* "query": "裴金禄出生日期",
* "knowledge_base_name": "裴金禄笔录-薛庆坤测试",
* "top_k": 3,
* "score_threshold": 1,
* "history": [
*
* ],
* "stream": false,
* "model_name": "openai-api",
* "temperature": 0.7,
* "max_tokens": 0,
* "prompt_name": "default"
* }
*/
ChatReqDTO chatReq = new ChatReqDTO();
chatReq.setQuery("裴金禄出生日期");
chatReq.setKnowledge_base_name("裴金禄笔录-薛庆坤测试");
chatReq.setTop_k(3);
chatReq.setScore_threshold(1);
chatReq.setStream(false);
chatReq.setModel_name("openai-api");
chatReq.setTemperature(0.7);
chatReq.setMax_tokens(0);
chatReq.setPrompt_name("default");
ChatResDTO chat = langChainChatService.chat(chatReq);
System.out.println(chat);
}
/** /**
* *
*/ */

Loading…
Cancel
Save