diff --git a/src/main/java/com/supervision/chat/client/CustomMultipartFile.java b/src/main/java/com/supervision/chat/client/CustomMultipartFile.java new file mode 100644 index 0000000..f7d2838 --- /dev/null +++ b/src/main/java/com/supervision/chat/client/CustomMultipartFile.java @@ -0,0 +1,75 @@ +package com.supervision.chat.client; + +import org.springframework.lang.NonNull; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.FileCopyUtils; +import org.springframework.web.multipart.MultipartFile; + +import java.io.ByteArrayInputStream; +import java.io.File; +import java.io.IOException; +import java.io.InputStream; + +public class CustomMultipartFile implements MultipartFile { + + private final String name; + private final String originalFilename; + @Nullable + private final String contentType; + private final byte[] content; + + public CustomMultipartFile(String name, @Nullable byte[] content) { + this(name, name, (String)null, (byte[])content); + } + + public CustomMultipartFile(String name, InputStream contentStream) throws IOException { + this(name, name, (String)null, (byte[]) FileCopyUtils.copyToByteArray(contentStream)); + } + + public CustomMultipartFile(String name, @Nullable String originalFilename, @Nullable String contentType, @Nullable byte[] content) { + Assert.hasLength(name, "Name must not be empty"); + this.name = name; + this.originalFilename = originalFilename != null ? originalFilename : ""; + this.contentType = contentType; + this.content = content != null ? content : new byte[0]; + } + + public CustomMultipartFile(String name, @Nullable String originalFilename, @Nullable String contentType, InputStream contentStream) throws IOException { + this(name, originalFilename, contentType, FileCopyUtils.copyToByteArray(contentStream)); + } + + public String getName() { + return this.name; + } + + @NonNull + public String getOriginalFilename() { + return this.originalFilename; + } + + @Nullable + public String getContentType() { + return this.contentType; + } + + public boolean isEmpty() { + return this.content.length == 0; + } + + public long getSize() { + return (long)this.content.length; + } + + public byte[] getBytes() throws IOException { + return this.content; + } + + public InputStream getInputStream() throws IOException { + return new ByteArrayInputStream(this.content); + } + + public void transferTo(File dest) throws IOException, IllegalStateException { + FileCopyUtils.copy(this.content, dest); + } +} diff --git a/src/main/java/com/supervision/chat/client/LangChainChatService.java b/src/main/java/com/supervision/chat/client/LangChainChatService.java index d7e56bf..0f0f8fc 100644 --- a/src/main/java/com/supervision/chat/client/LangChainChatService.java +++ b/src/main/java/com/supervision/chat/client/LangChainChatService.java @@ -1,6 +1,7 @@ package com.supervision.chat.client; import com.supervision.chat.client.dto.CreateBaseDTO; +import com.supervision.chat.client.dto.DeleteFileDTO; import com.supervision.chat.client.dto.LangChainChatRes; import org.springframework.core.io.Resource; import org.springframework.http.MediaType; @@ -14,20 +15,47 @@ import org.springframework.web.service.annotation.PostExchange; @HttpExchange public interface LangChainChatService { + /** + * 创建知识库 + * @param createBaseDTO 知识库对象 + * @return 结果 + */ @PostExchange(url = "create_knowledge_base", contentType = MediaType.APPLICATION_JSON_VALUE) - LangChainChatRes chat(@RequestBody CreateBaseDTO createBaseDTO); + LangChainChatRes createBase(@RequestBody CreateBaseDTO createBaseDTO); + /** + * 上传文档接口 + * @param knowledge_base_name 需要上传的知识库库名 + * @param files 文件,multipartFile + * @param text_splitter_type 问讯笔录 + * @param to_vector_store true 固定值 + * @param override false 固定值 + * @param not_refresh_vs_cache false 固定值 + * @param chunk_size 250 固定值 + * @param chunk_overlap 50 固定值 + * @param zh_title_enhance false 固定值 + * @param docs {"test.txt":[{"page_content":"custom doc","metadata":{},"type":"Document"}]} 固定值 + * @return 调用的结果 + */ @PostExchange(url = "upload_docs", contentType = MediaType.MULTIPART_FORM_DATA_VALUE) - void uploadFile(@RequestPart String knowledge_base_name, - @RequestPart Resource files, - @RequestPart String to_vector_store, - @RequestPart String override, - @RequestPart String not_refresh_vs_cache, - @RequestPart Integer chunk_size, - @RequestPart Integer chunk_overlap, - @RequestPart String zh_title_enhance, - @RequestPart String text_splitter_type, - @RequestPart String docs); + LangChainChatRes uploadFile(@RequestPart String knowledge_base_name, + @RequestPart MultipartFile files, + @RequestPart String text_splitter_type, + @RequestPart String to_vector_store, + @RequestPart String override, + @RequestPart String not_refresh_vs_cache, + @RequestPart Integer chunk_size, + @RequestPart Integer chunk_overlap, + @RequestPart String zh_title_enhance, + @RequestPart String docs); + + /** + * 删除文件 + * @param deleteFileDTO 删除的对象 + * @return 返回结果 + */ + @PostExchange(url = "delete_docs", contentType = MediaType.APPLICATION_JSON_VALUE) + LangChainChatRes deleteFile(@RequestBody DeleteFileDTO deleteFileDTO); } diff --git a/src/main/java/com/supervision/chat/client/dto/DeleteFileDTO.java b/src/main/java/com/supervision/chat/client/dto/DeleteFileDTO.java new file mode 100644 index 0000000..0b9fb7a --- /dev/null +++ b/src/main/java/com/supervision/chat/client/dto/DeleteFileDTO.java @@ -0,0 +1,29 @@ +package com.supervision.chat.client.dto; + +import lombok.Data; + +import java.util.ArrayList; +import java.util.List; + +@Data +public class DeleteFileDTO { + + private String knowledge_base_name; + + private List file_names; + + private Boolean delete_content = false; + + private Boolean not_refresh_vs_cache = false; + + public static DeleteFileDTO create(String knowledge_base_name, String file_name) { + DeleteFileDTO deleteFileDTO = new DeleteFileDTO(); + deleteFileDTO.setKnowledge_base_name(knowledge_base_name); + List file_names = new ArrayList<>(); + file_names.add(file_name); + deleteFileDTO.setFile_names(file_names); + return deleteFileDTO; + } + + +} diff --git a/src/main/java/com/supervision/chat/client/dto/LangChainChatRes.java b/src/main/java/com/supervision/chat/client/dto/LangChainChatRes.java index 130513f..7e9f614 100644 --- a/src/main/java/com/supervision/chat/client/dto/LangChainChatRes.java +++ b/src/main/java/com/supervision/chat/client/dto/LangChainChatRes.java @@ -9,5 +9,5 @@ public class LangChainChatRes { private String msg; - private String data; + private Object data; } diff --git a/src/main/java/com/supervision/chat/controller/TestController.java b/src/main/java/com/supervision/chat/controller/TestController.java index 43a5616..efc6442 100644 --- a/src/main/java/com/supervision/chat/controller/TestController.java +++ b/src/main/java/com/supervision/chat/controller/TestController.java @@ -1,14 +1,22 @@ package com.supervision.chat.controller; +import cn.hutool.core.io.FileUtil; import cn.hutool.json.JSONUtil; +import com.supervision.chat.client.CustomMultipartFile; import com.supervision.chat.client.LangChainChatService; import com.supervision.chat.client.dto.CreateBaseDTO; +import com.supervision.chat.client.dto.DeleteFileDTO; import com.supervision.chat.client.dto.LangChainChatRes; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; +import org.springframework.core.io.InputStreamResource; +import org.springframework.core.io.Resource; +import org.springframework.mock.web.MockMultipartFile; import org.springframework.web.bind.annotation.*; +import org.springframework.web.multipart.MultipartFile; import org.springframework.web.service.annotation.PostExchange; +import java.io.IOException; import java.util.concurrent.Executors; @Slf4j @@ -20,11 +28,33 @@ public class TestController { private final LangChainChatService langChainChatClient; @GetMapping("test") - public void test(){ + public void test() { CreateBaseDTO createBaseDTO = new CreateBaseDTO(); createBaseDTO.setKnowledge_base_name("11111111"); - LangChainChatRes chat = langChainChatClient.chat(createBaseDTO); + LangChainChatRes chat = langChainChatClient.createBase(createBaseDTO); log.info(JSONUtil.toJsonStr(chat)); } + @PostMapping("uploadFile") + public void testUploadFile(@RequestPart("file") MultipartFile file) throws IOException { + CustomMultipartFile mockMultipartFile = new CustomMultipartFile(file.getOriginalFilename(), file.getInputStream()); + LangChainChatRes langChainChatRes = langChainChatClient.uploadFile("11111111", + mockMultipartFile, + "问讯笔录", + "true", + "false", + "false", + 250, + 50, + "false", + "{\"test.txt\":[{\"page_content\":\"custom doc\",\"metadata\":{},\"type\":\"Document\"}]}"); + log.info(JSONUtil.toJsonStr(langChainChatRes)); + } + + @GetMapping("deleteFile") + public void testDeleteFile(String knowledgeBaseName, String fileName) { + LangChainChatRes langChainChatRes = langChainChatClient.deleteFile(DeleteFileDTO.create(knowledgeBaseName, fileName)); + log.info(JSONUtil.toJsonStr(langChainChatRes)); + } + } diff --git a/src/main/java/com/supervision/demo/controller/ExampleChatController.java b/src/main/java/com/supervision/demo/controller/ExampleChatController.java index 4ebdf12..c22f1a7 100644 --- a/src/main/java/com/supervision/demo/controller/ExampleChatController.java +++ b/src/main/java/com/supervision/demo/controller/ExampleChatController.java @@ -14,6 +14,7 @@ import com.supervision.utils.RecordRegexUtil; import com.supervision.utils.WordReadUtil; import lombok.extern.slf4j.Slf4j; import org.apache.ibatis.annotations.Param; +import org.json.JSONException; import org.json.JSONObject; import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.Generation; @@ -92,7 +93,7 @@ public class ExampleChatController { """; @GetMapping("exampleChat") - public void exampleChat() { + public void exampleChat() throws JSONException { File file = FileUtil.file("E:\\jc\\宁夏\\Fw_裴金禄\\裴金禄第一次.docx"); String context = WordReadUtil.readWord(file.getPath()); List qaList = RecordRegexUtil.recordRegex(context, "裴金禄"); diff --git a/src/main/java/com/supervision/police/service/impl/ModelCaseServiceImpl.java b/src/main/java/com/supervision/police/service/impl/ModelCaseServiceImpl.java index d8854ee..e21ac48 100644 --- a/src/main/java/com/supervision/police/service/impl/ModelCaseServiceImpl.java +++ b/src/main/java/com/supervision/police/service/impl/ModelCaseServiceImpl.java @@ -11,6 +11,9 @@ import com.baomidou.mybatisplus.core.metadata.IPage; import com.baomidou.mybatisplus.core.toolkit.Wrappers; import com.baomidou.mybatisplus.extension.plugins.pagination.Page; import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; +import com.supervision.chat.client.LangChainChatService; +import com.supervision.chat.client.dto.CreateBaseDTO; +import com.supervision.chat.client.dto.LangChainChatRes; import com.supervision.common.domain.R; import com.supervision.common.enums.ResultStatusEnum; import com.supervision.common.exception.CustomException; @@ -57,6 +60,8 @@ public class ModelCaseServiceImpl extends ServiceImpl 0) { - // TODO 这里需要调用知识库的接口,去保存知识库 return R.okMsg("保存成功"); } else { return R.fail("保存失败"); @@ -161,7 +173,6 @@ public class ModelCaseServiceImpl extends ServiceImpl del(String id) { ModelCase modelCase = modelCaseMapper.selectById(id); diff --git a/src/main/java/com/supervision/police/service/impl/NoteRecordSplitServiceImpl.java b/src/main/java/com/supervision/police/service/impl/NoteRecordSplitServiceImpl.java index 6e34dfe..786d1fa 100644 --- a/src/main/java/com/supervision/police/service/impl/NoteRecordSplitServiceImpl.java +++ b/src/main/java/com/supervision/police/service/impl/NoteRecordSplitServiceImpl.java @@ -6,9 +6,14 @@ import cn.hutool.core.util.StrUtil; import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import com.baomidou.mybatisplus.core.toolkit.Wrappers; import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; +import com.supervision.chat.client.CustomMultipartFile; +import com.supervision.chat.client.LangChainChatService; +import com.supervision.chat.client.dto.DeleteFileDTO; +import com.supervision.chat.client.dto.LangChainChatRes; import com.supervision.common.utils.IPages; import com.supervision.common.utils.ListUtils; import com.supervision.common.utils.StringUtils; +import com.supervision.config.BusinessException; import com.supervision.minio.domain.MinioFile; import com.supervision.minio.mapper.MinioFileMapper; import com.supervision.minio.service.MinioService; @@ -25,6 +30,7 @@ import com.supervision.utils.WordReadUtil; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.mock.web.MockMultipartFile; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; @@ -64,6 +70,8 @@ public class NoteRecordSplitServiceImpl extends ServiceImpl qaList = RecordRegexUtil.recordRegex(context, record.getName()); List splitList = new ArrayList<>(); @@ -135,6 +159,8 @@ public class NoteRecordSplitServiceImpl extends ServiceImpl