From c00a2c2da0d1dc3874aa0a4121c4aed067d6648d Mon Sep 17 00:00:00 2001 From: liu Date: Tue, 16 Jul 2024 16:05:23 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../police/controller/RecordController.java | 58 ++++++++++++------- .../service/ModelRecordTypeService.java | 7 ++- .../police/service/RecordService.java | 7 ++- .../impl/ModelRecordTypeServiceImpl.java | 55 +++++++++--------- .../service/impl/RecordServiceImpl.java | 37 +++++------- 5 files changed, 83 insertions(+), 81 deletions(-) diff --git a/src/main/java/com/supervision/police/controller/RecordController.java b/src/main/java/com/supervision/police/controller/RecordController.java index 53f531f..55f2ae9 100644 --- a/src/main/java/com/supervision/police/controller/RecordController.java +++ b/src/main/java/com/supervision/police/controller/RecordController.java @@ -1,12 +1,16 @@ package com.supervision.police.controller; import com.supervision.common.domain.R; +import com.supervision.police.domain.ModelRecordType; import com.supervision.police.domain.NotePrompt; import com.supervision.police.domain.NoteRecords; +import com.supervision.police.domain.TripleInfo; import com.supervision.police.dto.ListDTO; import com.supervision.police.service.ModelRecordTypeService; import com.supervision.police.service.RecordService; +import io.swagger.annotations.Api; import io.swagger.annotations.ApiOperation; +import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.web.bind.annotation.*; @@ -14,31 +18,32 @@ import org.springframework.web.multipart.MultipartFile; import java.io.IOException; import java.util.List; +import java.util.Map; @RestController @Slf4j @RequestMapping("record") @ApiOperation(value = "笔录接口") +@RequiredArgsConstructor public class RecordController { - @Autowired - public ModelRecordTypeService modelRecordTypeService; + public final ModelRecordTypeService modelRecordTypeService; - @Autowired - public RecordService recordService; + public final RecordService recordService; /** * 查询笔录类型 + * * @param name * @param page * @param size * @return */ @GetMapping("queryType") - public R queryType(@RequestParam(required = false) String name, - @RequestParam(required = false, defaultValue = "1") Integer page, - @RequestParam(required = false, defaultValue = "20") Integer size) { - return modelRecordTypeService.queryType(name, page, size); + public R> queryType(@RequestParam(required = false) String name, + @RequestParam(required = false, defaultValue = "1") Integer page, + @RequestParam(required = false, defaultValue = "20") Integer size) { + return R.ok(modelRecordTypeService.queryType(name, page, size)); } // @PostMapping("saveType") @@ -48,6 +53,7 @@ public class RecordController { /** * 保存提示词 + * * @param prompt * @return */ @@ -58,6 +64,7 @@ public class RecordController { /** * 删除提示词 + * * @param prompt * @return */ @@ -68,56 +75,63 @@ public class RecordController { /** * 获取案件三元组信息 + * * @param caseId * @param name * @param recordId * @return */ @GetMapping("/getThreeInfo") - public R getThreeInfo(@RequestParam String caseId, - @RequestParam String name, - @RequestParam(required = false) String recordId) { - return modelRecordTypeService.getThreeInfo(caseId, name, recordId); + @ApiOperation("获取笔录的三元组信息") + public R> getThreeInfo(@RequestParam String caseId, + @RequestParam String name, + @RequestParam(required = false) String recordId) { + return R.ok(modelRecordTypeService.getThreeInfo(caseId, name, recordId)); } + @ApiOperation("将三元组信息保存到知识图谱") @PostMapping("/addNeo4j") - public R addNeo4j(@RequestBody ListDTO list) { - return modelRecordTypeService.addNeo4j(list.getIds()); + public R addNeo4j(@RequestBody ListDTO list) { + return R.ok(modelRecordTypeService.addNeo4j(list.getIds())); } /** * 上传笔录, 修改 + * * @param records * @return */ @PostMapping("/addOrUpdRecords") - public R uploadRecords(NoteRecords records, - @RequestPart("file") List fileList) throws IOException { - return recordService.uploadRecords(records, fileList); + public R uploadRecords(NoteRecords records, + @RequestPart("file") List fileList) throws IOException { + return R.ok(recordService.uploadRecords(records, fileList)); } /** * 查询笔录,按姓名为父目录 + * * @param noteRecords * @param page * @param size * @return */ @PostMapping("/queryRecords") - public R queryRecords(@RequestBody NoteRecords noteRecords, - @RequestParam(required = false, defaultValue = "1") Integer page, - @RequestParam(required = false, defaultValue = "20") Integer size) { - return recordService.queryRecords(noteRecords, page, size); + public R> queryRecords(@RequestBody NoteRecords noteRecords, + @RequestParam(required = false, defaultValue = "1") Integer page, + @RequestParam(required = false, defaultValue = "20") Integer size) { + return R.ok(recordService.queryRecords(noteRecords, page, size)); } /** * 删除 + * * @param id * @return */ @PostMapping("/delRecords") public R delRecords(@RequestParam String id) { - return recordService.delRecords(id); + recordService.delRecords(id); + return R.ok(); } diff --git a/src/main/java/com/supervision/police/service/ModelRecordTypeService.java b/src/main/java/com/supervision/police/service/ModelRecordTypeService.java index ce7777d..51a799d 100644 --- a/src/main/java/com/supervision/police/service/ModelRecordTypeService.java +++ b/src/main/java/com/supervision/police/service/ModelRecordTypeService.java @@ -4,12 +4,13 @@ import com.baomidou.mybatisplus.extension.service.IService; import com.supervision.common.domain.R; import com.supervision.police.domain.ModelRecordType; import com.supervision.police.domain.NotePrompt; +import com.supervision.police.domain.TripleInfo; import java.util.List; public interface ModelRecordTypeService extends IService { - R queryType(String name, Integer page, Integer size); + List queryType(String name, Integer page, Integer size); ModelRecordType queryByName(String content); @@ -19,7 +20,7 @@ public interface ModelRecordTypeService extends IService { R delPrompt(NotePrompt prompt); - R getThreeInfo(String caseId, String name, String recordId); + List getThreeInfo(String caseId, String name, String recordId); - R addNeo4j(List ids); + String addNeo4j(List ids); } diff --git a/src/main/java/com/supervision/police/service/RecordService.java b/src/main/java/com/supervision/police/service/RecordService.java index a0fe5a9..d7e5afd 100644 --- a/src/main/java/com/supervision/police/service/RecordService.java +++ b/src/main/java/com/supervision/police/service/RecordService.java @@ -8,13 +8,14 @@ import org.springframework.web.multipart.MultipartFile; import java.io.IOException; import java.util.List; +import java.util.Map; public interface RecordService extends IService { - R uploadRecords(NoteRecords records, List fileList) throws IOException; + String uploadRecords(NoteRecords records, List fileList) throws IOException; - R queryRecords(NoteRecords noteRecords, Integer page, Integer size); + Map queryRecords(NoteRecords noteRecords, Integer page, Integer size); - R delRecords(String id); + void delRecords(String id); } diff --git a/src/main/java/com/supervision/police/service/impl/ModelRecordTypeServiceImpl.java b/src/main/java/com/supervision/police/service/impl/ModelRecordTypeServiceImpl.java index 1b1a7f8..7572dc7 100644 --- a/src/main/java/com/supervision/police/service/impl/ModelRecordTypeServiceImpl.java +++ b/src/main/java/com/supervision/police/service/impl/ModelRecordTypeServiceImpl.java @@ -17,6 +17,7 @@ import com.supervision.police.mapper.NoteRecordMapper; import com.supervision.police.mapper.NotePromptMapper; import com.supervision.police.mapper.TripleInfoMapper; import com.supervision.police.service.ModelRecordTypeService; +import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.json.JSONArray; import org.json.JSONObject; @@ -24,6 +25,7 @@ import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.Generation; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.ollama.OllamaChatClient; import org.springframework.beans.factory.annotation.Autowired; @@ -36,31 +38,23 @@ import java.util.List; @Slf4j @Service +@RequiredArgsConstructor public class ModelRecordTypeServiceImpl extends ServiceImpl implements ModelRecordTypeService { - @Autowired - private ModelRecordTypeMapper modelRecordTypeMapper; + private final ModelRecordTypeMapper modelRecordTypeMapper; - @Autowired - private NoteRecordMapper noteRecordMapper; + private final NoteRecordMapper noteRecordMapper; - @Autowired - private NotePromptMapper notePromptMapper; + private final NotePromptMapper notePromptMapper; - private final OllamaChatClient chatClient; - @Autowired - public ModelRecordTypeServiceImpl(OllamaChatClient chatClient) { - this.chatClient = chatClient; - } + private final TripleInfoMapper tripleInfoMapper; - @Autowired - private TripleInfoMapper tripleInfoMapper; + private final Neo4jService neo4jService; - @Autowired - private Neo4jService neo4jService; + private final OllamaChatClient chatClient; @Override - public R queryType(String name, Integer page, Integer size) { + public List queryType(String name, Integer page, Integer size) { // IPage iPage = new Page<>(page, size); // iPage = modelRecordTypeMapper.selectByName(iPage, name); // return R.ok(IPages.buildDataMap(iPage)); @@ -74,7 +68,7 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl prompts = notePromptMapper.queryPrompt(modelRecordType.getId()); modelRecordType.setPrompts(prompts); } - return R.ok(list); + return list; // return R.ok(IPages.buildDataMap(iPage)); } @@ -127,8 +121,7 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl getThreeInfo(String caseId, String name, String recordId) { + private List extractThreeInfo(String caseId, String name, String recordId) { List records = noteRecordMapper.selectRecord(caseId, name, recordId); List tripleInfos = new ArrayList<>(); for (NoteRecord record : records) { @@ -138,17 +131,15 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl messages = new ArrayList<>(List.of(new SystemMessage(prompt.getPrompt() + record.getQuestion() + record.getAnswer()))); - Prompt ask = new Prompt(messages); StopWatch stopWatch = new StopWatch(); + // 分析三元组 + Prompt ask = new Prompt(new UserMessage(prompt.getPrompt() + record.getQuestion() + record.getAnswer())); stopWatch.start(); log.info("开始分析:"); ChatResponse call = chatClient.call(ask); stopWatch.stop(); log.info("耗时:{}", stopWatch.getTotalTimeSeconds()); - Generation result = call.getResult(); - String content = result.getOutput().getContent(); + String content = call.getResult().getOutput().getContent(); log.info("分析的结果是:{}", content); JSONObject jsonObject = new JSONObject(content); JSONArray threeInfo = jsonObject.getJSONArray("result"); @@ -169,15 +160,21 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl getThreeInfo(String caseId, String name, String recordId) { + // TODO 这里应该改成异步的形式,通过异步的形式来进行提取三元组信息,不能每次点击就跑一遍 + return extractThreeInfo(caseId, name, recordId); } @Override - public R addNeo4j(List ids) { + public String addNeo4j(List ids) { List tripleInfos = tripleInfoMapper.selectByIds(ids); int i = 0; for (TripleInfo tripleInfo : tripleInfos) { @@ -210,9 +207,9 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl implements RecordService { - @Autowired - private NoteRecordMapper noteRecordMapper; + private final NoteRecordMapper noteRecordMapper; - @Autowired - private NoteRecordsMapper noteRecordsMapper; + private final NoteRecordsMapper noteRecordsMapper; - @Autowired - private MinioService minioService; + private final MinioService minioService; - @Autowired - private ModelCaseMapper modelCaseMapper; + private final ModelCaseMapper modelCaseMapper; - @Autowired - private MinioFileMapper minioFileMapper; + private final MinioFileMapper minioFileMapper; private final OllamaChatClient chatClient; - - @Autowired - public RecordServiceImpl(OllamaChatClient chatClient) { - this.chatClient = chatClient; - } - - @Autowired private ModelRecordTypeMapper modelRecordTypeMapper; @Override // @Transactional(rollbackFor = Exception.class) - public R uploadRecords(NoteRecords records, List fileList) throws IOException { + public String uploadRecords(NoteRecords records, List fileList) throws IOException { //上传文件,获取文件ids List fileIds = new ArrayList<>(); for (MultipartFile file : fileList) { @@ -180,14 +170,14 @@ public class RecordServiceImpl extends ServiceImpl // } } } - return R.okMsg("保存成功"); + return "保存成功"; } else { - return R.fail("保存笔录失败"); + return "保存笔录失败"; } } @Override - public R queryRecords(NoteRecords noteRecords, Integer page, Integer size) { + public Map queryRecords(NoteRecords noteRecords, Integer page, Integer size) { LambdaQueryWrapper wrapper = Wrappers.lambdaQuery(); wrapper.like(StringUtils.isNotEmpty(noteRecords.getName()), NoteRecords::getName, noteRecords.getName()) .eq(NoteRecords::getCaseId, noteRecords.getCaseId()) @@ -212,11 +202,11 @@ public class RecordServiceImpl extends ServiceImpl } } } - return R.ok(IPages.buildDataMap(pager, res.size())); + return IPages.buildDataMap(pager, res.size()); } @Override - public R delRecords(String id) { + public void delRecords(String id) { NoteRecords noteRecords = noteRecordsMapper.selectById(id); noteRecords.setDataStatus(StringUtils.getUUID()); noteRecordsMapper.updateById(noteRecords); @@ -230,7 +220,6 @@ public class RecordServiceImpl extends ServiceImpl minioService.delFile(fileId); } } - return R.ok(); } }