diff --git a/src/main/java/com/supervision/common/domain/R.java b/src/main/java/com/supervision/common/domain/R.java index 992df99..f15f7d0 100644 --- a/src/main/java/com/supervision/common/domain/R.java +++ b/src/main/java/com/supervision/common/domain/R.java @@ -21,10 +21,14 @@ public class R implements Serializable { public static final String TOTAL_COUNT = "total"; public static final String RESULT_LIST = "result"; - /** 成功 */ + /** + * 成功 + */ public static final int SUCCESS = Constants.SUCCESS; - /** 失败 */ + /** + * 失败 + */ public static final int FAIL = Constants.FAIL; private int code; @@ -37,6 +41,14 @@ public class R implements Serializable { return restResult(null, SUCCESS, null); } + public static R judgeResult(Boolean bo, String successMessage, String failMessage) { + if (bo) { + return restResult(null, SUCCESS, successMessage); + } else { + return restResult(null, FAIL, failMessage); + } + } + public static R okMsg(String msg) { return restResult(null, SUCCESS, msg); } @@ -69,7 +81,9 @@ public class R implements Serializable { return restResult(resultStatusEnum); } - public static R fail(ResultStatusEnum resultStatusEnum, T data) {return restResult(resultStatusEnum,data);} + public static R fail(ResultStatusEnum resultStatusEnum, T data) { + return restResult(resultStatusEnum, data); + } private static R restResult(ResultStatusEnum resultStatusEnum, T data) { R apiResult = new R<>(); @@ -96,7 +110,6 @@ public class R implements Serializable { } - public static Map buildDataMap(List list) { Map dataMap = new HashMap<>(); if (list == null) { diff --git a/src/main/java/com/supervision/neo4j/controller/Neo4jController.java b/src/main/java/com/supervision/neo4j/controller/Neo4jController.java index 1554864..f592460 100644 --- a/src/main/java/com/supervision/neo4j/controller/Neo4jController.java +++ b/src/main/java/com/supervision/neo4j/controller/Neo4jController.java @@ -65,7 +65,8 @@ public class Neo4jController { @PostMapping("/saveRelation") public R saveRelation(@RequestBody Rel rel) { - return neo4jService.saveRelation(rel); + Boolean result = neo4jService.saveRelation(rel); + return R.judgeResult(result, null, "保存失败"); } /*************************************************************************************/ diff --git a/src/main/java/com/supervision/neo4j/service/Neo4jService.java b/src/main/java/com/supervision/neo4j/service/Neo4jService.java index 7306722..9a5f6f9 100644 --- a/src/main/java/com/supervision/neo4j/service/Neo4jService.java +++ b/src/main/java/com/supervision/neo4j/service/Neo4jService.java @@ -23,7 +23,7 @@ public interface Neo4jService { Rel findRelation(Rel rel); - R saveRelation(Rel rel); + Boolean saveRelation(Rel rel); R getNode(String picType, String caseId); diff --git a/src/main/java/com/supervision/neo4j/service/impl/Neo4jServiceImpl.java b/src/main/java/com/supervision/neo4j/service/impl/Neo4jServiceImpl.java index 68da9d5..0961562 100644 --- a/src/main/java/com/supervision/neo4j/service/impl/Neo4jServiceImpl.java +++ b/src/main/java/com/supervision/neo4j/service/impl/Neo4jServiceImpl.java @@ -198,7 +198,7 @@ public class Neo4jServiceImpl implements Neo4jService { } @Override - public R saveRelation(Rel rel) { + public Boolean saveRelation(Rel rel) { Rel res = null; try { Session session = driver.session(); @@ -214,11 +214,7 @@ public class Neo4jServiceImpl implements Neo4jService { } catch (Exception e) { e.printStackTrace(); } - if (rel != null) { - return R.ok(rel); - } else { - return R.fail("保存失败"); - } + return rel != null; } @Override diff --git a/src/main/java/com/supervision/police/controller/RecordController.java b/src/main/java/com/supervision/police/controller/RecordController.java index 53f531f..55f2ae9 100644 --- a/src/main/java/com/supervision/police/controller/RecordController.java +++ b/src/main/java/com/supervision/police/controller/RecordController.java @@ -1,12 +1,16 @@ package com.supervision.police.controller; import com.supervision.common.domain.R; +import com.supervision.police.domain.ModelRecordType; import com.supervision.police.domain.NotePrompt; import com.supervision.police.domain.NoteRecords; +import com.supervision.police.domain.TripleInfo; import com.supervision.police.dto.ListDTO; import com.supervision.police.service.ModelRecordTypeService; import com.supervision.police.service.RecordService; +import io.swagger.annotations.Api; import io.swagger.annotations.ApiOperation; +import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.web.bind.annotation.*; @@ -14,31 +18,32 @@ import org.springframework.web.multipart.MultipartFile; import java.io.IOException; import java.util.List; +import java.util.Map; @RestController @Slf4j @RequestMapping("record") @ApiOperation(value = "笔录接口") +@RequiredArgsConstructor public class RecordController { - @Autowired - public ModelRecordTypeService modelRecordTypeService; + public final ModelRecordTypeService modelRecordTypeService; - @Autowired - public RecordService recordService; + public final RecordService recordService; /** * 查询笔录类型 + * * @param name * @param page * @param size * @return */ @GetMapping("queryType") - public R queryType(@RequestParam(required = false) String name, - @RequestParam(required = false, defaultValue = "1") Integer page, - @RequestParam(required = false, defaultValue = "20") Integer size) { - return modelRecordTypeService.queryType(name, page, size); + public R> queryType(@RequestParam(required = false) String name, + @RequestParam(required = false, defaultValue = "1") Integer page, + @RequestParam(required = false, defaultValue = "20") Integer size) { + return R.ok(modelRecordTypeService.queryType(name, page, size)); } // @PostMapping("saveType") @@ -48,6 +53,7 @@ public class RecordController { /** * 保存提示词 + * * @param prompt * @return */ @@ -58,6 +64,7 @@ public class RecordController { /** * 删除提示词 + * * @param prompt * @return */ @@ -68,56 +75,63 @@ public class RecordController { /** * 获取案件三元组信息 + * * @param caseId * @param name * @param recordId * @return */ @GetMapping("/getThreeInfo") - public R getThreeInfo(@RequestParam String caseId, - @RequestParam String name, - @RequestParam(required = false) String recordId) { - return modelRecordTypeService.getThreeInfo(caseId, name, recordId); + @ApiOperation("获取笔录的三元组信息") + public R> getThreeInfo(@RequestParam String caseId, + @RequestParam String name, + @RequestParam(required = false) String recordId) { + return R.ok(modelRecordTypeService.getThreeInfo(caseId, name, recordId)); } + @ApiOperation("将三元组信息保存到知识图谱") @PostMapping("/addNeo4j") - public R addNeo4j(@RequestBody ListDTO list) { - return modelRecordTypeService.addNeo4j(list.getIds()); + public R addNeo4j(@RequestBody ListDTO list) { + return R.ok(modelRecordTypeService.addNeo4j(list.getIds())); } /** * 上传笔录, 修改 + * * @param records * @return */ @PostMapping("/addOrUpdRecords") - public R uploadRecords(NoteRecords records, - @RequestPart("file") List fileList) throws IOException { - return recordService.uploadRecords(records, fileList); + public R uploadRecords(NoteRecords records, + @RequestPart("file") List fileList) throws IOException { + return R.ok(recordService.uploadRecords(records, fileList)); } /** * 查询笔录,按姓名为父目录 + * * @param noteRecords * @param page * @param size * @return */ @PostMapping("/queryRecords") - public R queryRecords(@RequestBody NoteRecords noteRecords, - @RequestParam(required = false, defaultValue = "1") Integer page, - @RequestParam(required = false, defaultValue = "20") Integer size) { - return recordService.queryRecords(noteRecords, page, size); + public R> queryRecords(@RequestBody NoteRecords noteRecords, + @RequestParam(required = false, defaultValue = "1") Integer page, + @RequestParam(required = false, defaultValue = "20") Integer size) { + return R.ok(recordService.queryRecords(noteRecords, page, size)); } /** * 删除 + * * @param id * @return */ @PostMapping("/delRecords") public R delRecords(@RequestParam String id) { - return recordService.delRecords(id); + recordService.delRecords(id); + return R.ok(); } diff --git a/src/main/java/com/supervision/police/service/ModelRecordTypeService.java b/src/main/java/com/supervision/police/service/ModelRecordTypeService.java index ce7777d..51a799d 100644 --- a/src/main/java/com/supervision/police/service/ModelRecordTypeService.java +++ b/src/main/java/com/supervision/police/service/ModelRecordTypeService.java @@ -4,12 +4,13 @@ import com.baomidou.mybatisplus.extension.service.IService; import com.supervision.common.domain.R; import com.supervision.police.domain.ModelRecordType; import com.supervision.police.domain.NotePrompt; +import com.supervision.police.domain.TripleInfo; import java.util.List; public interface ModelRecordTypeService extends IService { - R queryType(String name, Integer page, Integer size); + List queryType(String name, Integer page, Integer size); ModelRecordType queryByName(String content); @@ -19,7 +20,7 @@ public interface ModelRecordTypeService extends IService { R delPrompt(NotePrompt prompt); - R getThreeInfo(String caseId, String name, String recordId); + List getThreeInfo(String caseId, String name, String recordId); - R addNeo4j(List ids); + String addNeo4j(List ids); } diff --git a/src/main/java/com/supervision/police/service/RecordService.java b/src/main/java/com/supervision/police/service/RecordService.java index a0fe5a9..d7e5afd 100644 --- a/src/main/java/com/supervision/police/service/RecordService.java +++ b/src/main/java/com/supervision/police/service/RecordService.java @@ -8,13 +8,14 @@ import org.springframework.web.multipart.MultipartFile; import java.io.IOException; import java.util.List; +import java.util.Map; public interface RecordService extends IService { - R uploadRecords(NoteRecords records, List fileList) throws IOException; + String uploadRecords(NoteRecords records, List fileList) throws IOException; - R queryRecords(NoteRecords noteRecords, Integer page, Integer size); + Map queryRecords(NoteRecords noteRecords, Integer page, Integer size); - R delRecords(String id); + void delRecords(String id); } diff --git a/src/main/java/com/supervision/police/service/impl/ModelRecordTypeServiceImpl.java b/src/main/java/com/supervision/police/service/impl/ModelRecordTypeServiceImpl.java index e3cafef..4f7b17f 100644 --- a/src/main/java/com/supervision/police/service/impl/ModelRecordTypeServiceImpl.java +++ b/src/main/java/com/supervision/police/service/impl/ModelRecordTypeServiceImpl.java @@ -1,5 +1,6 @@ package com.supervision.police.service.impl; +import cn.hutool.core.util.StrUtil; import com.alibaba.druid.util.StringUtils; import com.baomidou.mybatisplus.core.conditions.Wrapper; import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; @@ -17,6 +18,7 @@ import com.supervision.police.mapper.NoteRecordMapper; import com.supervision.police.mapper.NotePromptMapper; import com.supervision.police.mapper.TripleInfoMapper; import com.supervision.police.service.ModelRecordTypeService; +import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.json.JSONArray; import org.json.JSONObject; @@ -24,6 +26,7 @@ import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.Generation; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.ollama.OllamaChatClient; import org.springframework.beans.factory.annotation.Autowired; @@ -37,31 +40,23 @@ import java.util.List; @Slf4j @Service +@RequiredArgsConstructor public class ModelRecordTypeServiceImpl extends ServiceImpl implements ModelRecordTypeService { - @Autowired - private ModelRecordTypeMapper modelRecordTypeMapper; + private final ModelRecordTypeMapper modelRecordTypeMapper; - @Autowired - private NoteRecordMapper noteRecordMapper; + private final NoteRecordMapper noteRecordMapper; - @Autowired - private NotePromptMapper notePromptMapper; + private final NotePromptMapper notePromptMapper; - private final OllamaChatClient chatClient; - @Autowired - public ModelRecordTypeServiceImpl(OllamaChatClient chatClient) { - this.chatClient = chatClient; - } + private final TripleInfoMapper tripleInfoMapper; - @Autowired - private TripleInfoMapper tripleInfoMapper; + private final Neo4jService neo4jService; - @Autowired - private Neo4jService neo4jService; + private final OllamaChatClient chatClient; @Override - public R queryType(String name, Integer page, Integer size) { + public List queryType(String name, Integer page, Integer size) { // IPage iPage = new Page<>(page, size); // iPage = modelRecordTypeMapper.selectByName(iPage, name); // return R.ok(IPages.buildDataMap(iPage)); @@ -75,7 +70,7 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl prompts = notePromptMapper.queryPrompt(modelRecordType.getId()); modelRecordType.setPrompts(prompts); } - return R.ok(list); + return list; // return R.ok(IPages.buildDataMap(iPage)); } @@ -128,8 +123,7 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl getThreeInfo(String caseId, String name, String recordId) { + private List extractTripleInfo(String caseId, String name, String recordId) { List records = noteRecordMapper.selectRecord(caseId, name, recordId); List tripleInfos = new ArrayList<>(); for (NoteRecord record : records) { @@ -139,17 +133,15 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl messages = new ArrayList<>(List.of(new SystemMessage(prompt.getPrompt() + record.getQuestion() + record.getAnswer()))); - Prompt ask = new Prompt(messages); StopWatch stopWatch = new StopWatch(); + // 分析三元组 + Prompt ask = new Prompt(new UserMessage(prompt.getPrompt() + record.getQuestion() + record.getAnswer())); stopWatch.start(); log.info("开始分析:"); ChatResponse call = chatClient.call(ask); stopWatch.stop(); log.info("耗时:{}", stopWatch.getTotalTimeSeconds()); - Generation result = call.getResult(); - String content = result.getOutput().getContent(); + String content = call.getResult().getOutput().getContent(); log.info("分析的结果是:{}", content); JSONObject jsonObject = new JSONObject(content); JSONArray threeInfo = jsonObject.getJSONArray("result"); @@ -160,9 +152,8 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl getThreeInfo(String caseId, String name, String recordId) { + // TODO 这里应该改成异步的形式,通过异步的形式来进行提取三元组信息,不能每次点击就跑一遍 + return extractTripleInfo(caseId, name, recordId); } @Override - public R addNeo4j(List ids) { + public String addNeo4j(List ids) { List tripleInfos = tripleInfoMapper.selectByIds(ids); int i = 0; for (TripleInfo tripleInfo : tripleInfos) { try { //开始节点 String start = tripleInfo.getStartNode(); + // 首先看是否已经存在了,如果已经存在了,就不添加了 CaseNode startNode = neo4jService.findOneByName(tripleInfo.getCaseId(), tripleInfo.getNoteRecordsId(), tripleInfo.getStartNodeType(), start, "1"); if (startNode == null) { startNode = new CaseNode(start, tripleInfo.getStartNodeType(), tripleInfo.getNoteRecordId(), tripleInfo.getNoteRecordsId(), tripleInfo.getCaseId(), "1"); @@ -203,17 +201,18 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl r = neo4jService.saveRelation(rel); + neo4jService.saveRelation(rel); } tripleInfo.setAddNeo4j("1"); int j = tripleInfoMapper.updateById(tripleInfo); if (j > 0) { i++; } + // TODO 重复添加的OK了,删除的呢? } catch (Exception e) { - e.printStackTrace(); + log.error(e.getMessage(), e); } } - return R.ok("成功插入" + i + "条信息"); + return ("成功插入" + i + "条信息"); } } diff --git a/src/main/java/com/supervision/police/service/impl/RecordServiceImpl.java b/src/main/java/com/supervision/police/service/impl/RecordServiceImpl.java index 1140ef2..5eb56bd 100644 --- a/src/main/java/com/supervision/police/service/impl/RecordServiceImpl.java +++ b/src/main/java/com/supervision/police/service/impl/RecordServiceImpl.java @@ -24,6 +24,7 @@ import com.supervision.police.service.RecordService; import com.supervision.springaidemo.dto.QARecordNodeDTO; import com.supervision.springaidemo.util.RecordRegexUtil; import com.supervision.springaidemo.util.WordReadUtil; +import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.json.JSONObject; import org.springframework.ai.chat.ChatResponse; @@ -46,36 +47,25 @@ import java.util.stream.Collectors; @Slf4j @Service +@RequiredArgsConstructor public class RecordServiceImpl extends ServiceImpl implements RecordService { - @Autowired - private NoteRecordMapper noteRecordMapper; + private final NoteRecordMapper noteRecordMapper; - @Autowired - private NoteRecordsMapper noteRecordsMapper; + private final NoteRecordsMapper noteRecordsMapper; - @Autowired - private MinioService minioService; + private final MinioService minioService; - @Autowired - private ModelCaseMapper modelCaseMapper; + private final ModelCaseMapper modelCaseMapper; - @Autowired - private MinioFileMapper minioFileMapper; + private final MinioFileMapper minioFileMapper; private final OllamaChatClient chatClient; - - @Autowired - public RecordServiceImpl(OllamaChatClient chatClient) { - this.chatClient = chatClient; - } - - @Autowired private ModelRecordTypeMapper modelRecordTypeMapper; @Override // @Transactional(rollbackFor = Exception.class) - public R uploadRecords(NoteRecords records, List fileList) throws IOException { + public String uploadRecords(NoteRecords records, List fileList) throws IOException { //上传文件,获取文件ids List fileIds = new ArrayList<>(); for (MultipartFile file : fileList) { @@ -181,14 +171,14 @@ public class RecordServiceImpl extends ServiceImpl // } } } - return R.okMsg("保存成功"); + return "保存成功"; } else { - return R.fail("保存笔录失败"); + return "保存笔录失败"; } } @Override - public R queryRecords(NoteRecords noteRecords, Integer page, Integer size) { + public Map queryRecords(NoteRecords noteRecords, Integer page, Integer size) { LambdaQueryWrapper wrapper = Wrappers.lambdaQuery(); wrapper.like(StringUtils.isNotEmpty(noteRecords.getName()), NoteRecords::getName, noteRecords.getName()) .eq(NoteRecords::getCaseId, noteRecords.getCaseId()) @@ -213,11 +203,11 @@ public class RecordServiceImpl extends ServiceImpl } } } - return R.ok(IPages.buildDataMap(pager, res.size())); + return IPages.buildDataMap(pager, res.size()); } @Override - public R delRecords(String id) { + public void delRecords(String id) { NoteRecords noteRecords = noteRecordsMapper.selectById(id); noteRecords.setDataStatus(StringUtils.getUUID()); noteRecordsMapper.updateById(noteRecords); @@ -231,7 +221,6 @@ public class RecordServiceImpl extends ServiceImpl minioService.delFile(fileId); } } - return R.ok(); } }