diff --git a/src/main/java/com/supervision/SpringAiDemoApplication.java b/src/main/java/com/supervision/SpringAiDemoApplication.java index f4e0920..548e7ca 100644 --- a/src/main/java/com/supervision/SpringAiDemoApplication.java +++ b/src/main/java/com/supervision/SpringAiDemoApplication.java @@ -3,7 +3,9 @@ package com.supervision; import org.mybatis.spring.annotation.MapperScan; import org.springframework.boot.SpringApplication; import org.springframework.boot.autoconfigure.SpringBootApplication; +import org.springframework.scheduling.annotation.EnableAsync; +@EnableAsync @MapperScan(basePackages = {"com.supervision.**.mapper"}) @SpringBootApplication(scanBasePackages = {"com.supervision.**"}) public class SpringAiDemoApplication { diff --git a/src/main/java/com/supervision/neo4j/controller/Neo4jController.java b/src/main/java/com/supervision/neo4j/controller/Neo4jController.java index 634d1ac..e7ca8e2 100644 --- a/src/main/java/com/supervision/neo4j/controller/Neo4jController.java +++ b/src/main/java/com/supervision/neo4j/controller/Neo4jController.java @@ -89,10 +89,10 @@ public class Neo4jController { return neo4jService.getNode(picType, caseId); } - @GetMapping("/test") - public R test() { - return neo4jService.test(); - } +// @GetMapping("/test") +// public R test() { +// return neo4jService.test(); +// } @ApiOperation("构建抽象图谱") @GetMapping("createAbstractGraph") @@ -106,4 +106,10 @@ public class Neo4jController { neo4jService.deleteAbstractGraph(); } + @ApiOperation("mock测试数据") + @GetMapping("mockTestGraph") + public void mockTestGraph(String path, String sheetName, String recordId, String recordSplitId, String caseId) { + neo4jService.mockTestGraph(path, sheetName, recordId, recordSplitId, caseId); + } + } diff --git a/src/main/java/com/supervision/neo4j/domain/CaseNode.java b/src/main/java/com/supervision/neo4j/domain/CaseNode.java index 0e0e6ce..91af38c 100644 --- a/src/main/java/com/supervision/neo4j/domain/CaseNode.java +++ b/src/main/java/com/supervision/neo4j/domain/CaseNode.java @@ -18,9 +18,9 @@ public class CaseNode { private String nodeType; - private String recordId; + private String recordSplitId; - private String recordsId; + private String recordId; private String caseId; @@ -37,20 +37,20 @@ public class CaseNode { this.name = name; } - public CaseNode(String name, String nodeType, String recordId, String recordsId, String caseId, String picType) { + public CaseNode(String name, String nodeType, String recordSplitId, String recordId, String caseId, String picType) { this.name = name; this.nodeType = nodeType; + this.recordSplitId = recordSplitId; this.recordId = recordId; - this.recordsId = recordsId; this.caseId = caseId; this.picType = picType; } - public CaseNode(Long id, String name, String nodeType, String recordId, String caseId, String picType) { + public CaseNode(Long id, String name, String nodeType, String recordSplitId, String caseId, String picType) { this.id = id; this.name = name; this.nodeType = nodeType; - this.recordId = recordId; + this.recordSplitId = recordSplitId; this.caseId = caseId; this.picType = picType; } diff --git a/src/main/java/com/supervision/neo4j/dto/WebRelDTO.java b/src/main/java/com/supervision/neo4j/dto/WebRelDTO.java new file mode 100644 index 0000000..20c185f --- /dev/null +++ b/src/main/java/com/supervision/neo4j/dto/WebRelDTO.java @@ -0,0 +1,19 @@ +package com.supervision.neo4j.dto; + +import lombok.Data; + +@Data +public class WebRelDTO { + + private long source; + + private long target; + + private String name; + + public WebRelDTO(long source, long target, String name) { + this.source = source; + this.target = target; + this.name = name; + } +} diff --git a/src/main/java/com/supervision/neo4j/service/Neo4jService.java b/src/main/java/com/supervision/neo4j/service/Neo4jService.java index 46ea25f..e1e8f15 100644 --- a/src/main/java/com/supervision/neo4j/service/Neo4jService.java +++ b/src/main/java/com/supervision/neo4j/service/Neo4jService.java @@ -21,6 +21,7 @@ public interface Neo4jService { CaseNode findById(Long id); List findByName(String caseId, String recordId, String nodeType, String name, String picType); + CaseNode findOneByName(String caseId, String recordId, String nodeType, String name, String picType); Rel findRelation(Rel rel); @@ -29,9 +30,11 @@ public interface Neo4jService { R getNode(String picType, String caseId); - R test(); + // R test(); void deleteAbstractGraph(); - void createAbstractGraph(String path,String sheetName); + void createAbstractGraph(String path, String sheetName); + + void mockTestGraph(String path, String sheetName, String recordId, String recordSplitId,String caseId); } diff --git a/src/main/java/com/supervision/neo4j/service/impl/Neo4jServiceImpl.java b/src/main/java/com/supervision/neo4j/service/impl/Neo4jServiceImpl.java index 73981ab..3d6a52e 100644 --- a/src/main/java/com/supervision/neo4j/service/impl/Neo4jServiceImpl.java +++ b/src/main/java/com/supervision/neo4j/service/impl/Neo4jServiceImpl.java @@ -7,6 +7,7 @@ import com.supervision.common.domain.R; import com.supervision.common.utils.StringUtils; import com.supervision.neo4j.domain.CaseNode; import com.supervision.neo4j.domain.Rel; +import com.supervision.neo4j.dto.WebRelDTO; import com.supervision.neo4j.service.Neo4jService; import com.supervision.neo4j.utils.Neo4jUtils; import lombok.Data; @@ -39,7 +40,7 @@ public class Neo4jServiceImpl implements Neo4jService { if (StringUtils.isEmpty(caseNode.getName()) || StringUtils.isEmpty(caseNode.getNodeType())) { throw new RuntimeException("未传节点名称或节点类型或图谱类型!"); } - List byName = findByName(caseNode.getCaseId(), caseNode.getRecordsId(), caseNode.getNodeType(), caseNode.getName(), caseNode.getPicType()); + List byName = findByName(caseNode.getCaseId(), caseNode.getRecordId(), caseNode.getNodeType(), caseNode.getName(), caseNode.getPicType()); if (byName != null && !byName.isEmpty()) { throw new RuntimeException("名称已存在!"); } @@ -50,14 +51,14 @@ public class Neo4jServiceImpl implements Neo4jService { Map params = new HashMap<>(); cql.append("CREATE (n:").append(caseNode.getNodeType()).append("{name:$name"); params.put("name", caseNode.getName()); + if (StringUtils.isNotEmpty(caseNode.getRecordId())) { + cql.append(", recordSplitId:$recordSplitId"); + params.put("recordSplitId", caseNode.getRecordSplitId()); + } if (StringUtils.isNotEmpty(caseNode.getRecordId())) { cql.append(", recordId:$recordId"); params.put("recordId", caseNode.getRecordId()); } - if (StringUtils.isNotEmpty(caseNode.getRecordsId())) { - cql.append(", recordsId:$recordsId"); - params.put("recordsId", caseNode.getRecordsId()); - } if (StringUtils.isNotEmpty(caseNode.getCaseId())) { cql.append(", caseId:$caseId"); params.put("caseId", caseNode.getCaseId()); @@ -134,7 +135,7 @@ public class Neo4jServiceImpl implements Neo4jService { } @Override - public List findByName(String caseId, String recordsId, String nodeType, String name, String picType) { + public List findByName(String caseId, String recordId, String nodeType, String name, String picType) { List list = new ArrayList<>(); try { Session session = driver.session(); @@ -149,9 +150,9 @@ public class Neo4jServiceImpl implements Neo4jService { cql.append(" and n.caseId = "); cql.append(caseId); } - if (StringUtils.isNotEmpty(recordsId)) { - cql.append(" and n.recordsId = "); - cql.append(recordsId); + if (StringUtils.isNotEmpty(recordId)) { + cql.append(" and n.recordId = "); + cql.append(recordId); } if (StringUtils.isNotEmpty(name)) { cql.append(" and n.name = '"); @@ -172,7 +173,7 @@ public class Neo4jServiceImpl implements Neo4jService { } @Override - public CaseNode findOneByName(String caseId, String recordsId, String nodeType, String name, String picType) { + public CaseNode findOneByName(String caseId, String recordId, String nodeType, String name, String picType) { CaseNode node = null; try { Session session = driver.session(); @@ -188,9 +189,9 @@ public class Neo4jServiceImpl implements Neo4jService { cql.append(" and n.caseId = $caseId"); params.put("caseId", caseId); } - if (StringUtils.isNotEmpty(recordsId)) { - cql.append(" and n.recordsId = $recordsId"); - params.put("recordsId", recordsId); + if (StringUtils.isNotEmpty(recordId)) { + cql.append(" and n.recordId = $recordId"); + params.put("recordId", recordId); } if (StringUtils.isNotEmpty(name)) { cql.append(" and n.name = $name"); @@ -255,8 +256,8 @@ public class Neo4jServiceImpl implements Neo4jService { @Override public R getNode(String picType, String caseId) { Map map = new HashMap<>(); - List list = new ArrayList<>(); - List> nodes = new ArrayList<>(); + List list = new ArrayList<>(); + List> nodes = new ArrayList<>(); try { Session session = driver.session(); Map params = new HashMap<>(); @@ -266,23 +267,23 @@ public class Neo4jServiceImpl implements Neo4jService { " RETURN id(rel) as id, n.name as source, id(n) as sourceId, type(rel) as name, r.name as target, id(r) as targetId", params); while (run.hasNext()) { Record record = run.next(); - long id = record.get("id").asLong(); - String source = record.get("source").asString(); + //long id = record.get("id").asLong(); + //String source = record.get("source").asString(); long sourceId = record.get("sourceId").asLong(); String name = record.get("name").asString(); - String target = record.get("target").asString(); + //String target = record.get("target").asString(); long targetId = record.get("targetId").asLong(); - list.add(new Rel(id, source, sourceId, name, target, targetId)); + list.add(new WebRelDTO(sourceId, targetId, name)); } Result node = session.run("MATCH (n) where n.picType = $picType and n.caseId = $caseId RETURN id(n) as id, n.name as name", params); while (node.hasNext()) { Record record = node.next(); String name = record.get("name").asString(); - long idlong = record.get("id").asLong(); - Map nodeMap = new HashMap<>(); + long idLong = record.get("id").asLong(); + Map nodeMap = new HashMap<>(); nodeMap.put("name", name); nodeMap.put("entityName", name); -// nodeMap.put("id", idlong + ""); + nodeMap.put("id", idLong); nodes.add(nodeMap); } } catch (Exception e) { @@ -293,28 +294,29 @@ public class Neo4jServiceImpl implements Neo4jService { return R.ok(map); } - @Override - public R test() { - Session session = driver.session(); - Map params = new HashMap<>(); - params.put("lawActor", "行为人"); - params.put("lawParty", "aaaaaa"); - Result run = session.run("MATCH (m:LawActor), (n:FictionalOrgan) where m.name=$lawActor OPTIONAL MATCH (m)-[r:`冒充`]->(n) RETURN id(m) as startId, id(n) as endId, id(r) as relId, m.recordId as recordId, m.recordsId as recordsId", params); - while (run.hasNext()) { - Record record = run.next(); - String id = Neo4jUtils.valueTransportString(record.get("startId")); - String endId = Neo4jUtils.valueTransportString(record.get("endId")); - String relId = Neo4jUtils.valueTransportString(record.get("relId")); - System.out.println("************" + id); - System.out.println("************" + endId); - System.out.println("************" + relId); - } - return R.ok("222"); - } +// @Override +// public R test() { +// Session session = driver.session(); +// Map params = new HashMap<>(); +// params.put("lawActor", "行为人"); +// params.put("lawParty", "aaaaaa"); +// Result run = session.run("MATCH (m:LawActor), (n:FictionalOrgan) where m.name=$lawActor OPTIONAL MATCH (m)-[r:`冒充`]->(n) RETURN id(m) as startId, id(n) as endId, id(r) as relId, m.recordId as recordId, m.recordsId as recordsId", params); +// while (run.hasNext()) { +// Record record = run.next(); +// +// String id = Neo4jUtils.valueTransportString(record.get("startId")); +// String endId = Neo4jUtils.valueTransportString(record.get("endId")); +// String relId = Neo4jUtils.valueTransportString(record.get("relId")); +// System.out.println("************" + id); +// System.out.println("************" + endId); +// System.out.println("************" + relId); +// } +// return R.ok("222"); +// } @Override - public void createAbstractGraph(String path,String sheetName) { + public void createAbstractGraph(String path, String sheetName) { // 首先从数据库中读到数据 ExcelReader reader = ExcelUtil.getReader(path, sheetName); List abstractGraphExcelHeaders = reader.readAll(AbstractGraphExcelHeader.class); @@ -382,4 +384,45 @@ public class Neo4jServiceImpl implements Neo4jService { private String relation; private String to; } + + @Data + private static class MockDataGraphExcelHeader { + private String fromType; + private String from; + private String relation; + private String to; + private String toType; + } + + @Override + public void mockTestGraph(String path, String sheetName, String recordId, String recordSplitId, String caseId) { + // 首先从数据库中读到数据 + ExcelReader reader = ExcelUtil.getReader(path, sheetName); + List mockDataGraphExcelList = reader.readAll(MockDataGraphExcelHeader.class); + Map nodeMap = new HashMap<>(); + Map relMap = new HashMap<>(); + for (MockDataGraphExcelHeader mockData : mockDataGraphExcelList) { + // from + if (!nodeMap.containsKey(mockData.getFrom())) { + CaseNode caseNode = new CaseNode(mockData.getFrom(), mockData.getFromType(), recordSplitId, recordId, caseId, "1"); + log.info("点:{}插入成功", mockData.getFrom()); + CaseNode save = save(caseNode); + nodeMap.put(mockData.getFrom(), save); + } + // to + if (!nodeMap.containsKey(mockData.getTo())) { + CaseNode caseNode = new CaseNode(mockData.getTo(), mockData.getToType(), recordSplitId, recordId, caseId, "1"); + CaseNode save = save(caseNode); + log.info("点:{}插入成功", mockData.getTo()); + nodeMap.put(mockData.getTo(), save); + } + // relation + if (!relMap.containsKey(mockData.getFrom() + "->" + mockData.getRelation() + "->" + mockData.getTo())) { + Rel rel = new Rel(nodeMap.get(mockData.getFrom()).getId(), mockData.getRelation(), nodeMap.get(mockData.getTo()).getId(), "1"); + saveRelation(rel); + log.info("关系:{}插入成功", (mockData.getFrom() + "->" + mockData.getRelation() + "->" + mockData.getTo())); + relMap.put(mockData.getFrom() + "->" + mockData.getRelation() + "->" + mockData.getTo(), rel); + } + } + } } diff --git a/src/main/java/com/supervision/police/controller/RecordController.java b/src/main/java/com/supervision/police/controller/RecordController.java index 368cbfd..c19d86a 100644 --- a/src/main/java/com/supervision/police/controller/RecordController.java +++ b/src/main/java/com/supervision/police/controller/RecordController.java @@ -7,7 +7,7 @@ import com.supervision.police.domain.NoteRecord; import com.supervision.police.domain.TripleInfo; import com.supervision.police.dto.ListDTO; import com.supervision.police.service.ModelRecordTypeService; -import com.supervision.police.service.RecordService; +import com.supervision.police.service.NoteRecordSplitService; import io.swagger.annotations.ApiOperation; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; @@ -27,7 +27,7 @@ public class RecordController { public final ModelRecordTypeService modelRecordTypeService; - public final RecordService recordService; + public final NoteRecordSplitService noteRecordSplitService; /** * 查询笔录类型 @@ -107,7 +107,7 @@ public class RecordController { @PostMapping("/addOrUpdRecords") public R uploadRecords(NoteRecord records, @RequestPart("file") List fileList) throws IOException { - return R.ok(recordService.uploadRecords(records, fileList)); + return R.ok(noteRecordSplitService.uploadRecords(records, fileList)); } /** @@ -122,7 +122,7 @@ public class RecordController { public R> queryRecords(@RequestBody NoteRecord noteRecord, @RequestParam(required = false, defaultValue = "1") Integer page, @RequestParam(required = false, defaultValue = "20") Integer size) { - return R.ok(recordService.queryRecords(noteRecord, page, size)); + return R.ok(noteRecordSplitService.queryRecords(noteRecord, page, size)); } /** @@ -133,7 +133,7 @@ public class RecordController { */ @PostMapping("/delRecords") public R delRecords(@RequestParam String id) { - recordService.delRecords(id); + noteRecordSplitService.delRecords(id); return R.ok(); } diff --git a/src/main/java/com/supervision/police/domain/TripleInfo.java b/src/main/java/com/supervision/police/domain/TripleInfo.java index a123fbf..193a9c8 100644 --- a/src/main/java/com/supervision/police/domain/TripleInfo.java +++ b/src/main/java/com/supervision/police/domain/TripleInfo.java @@ -36,19 +36,14 @@ public class TripleInfo implements Serializable { */ private String relation; - /** - * 笔录片段id - */ - private String noteRecordId; + private String caseId; - @TableField(exist = false) - private String noteRecordsId; + private String recordId; /** - * 案件id + * 笔录片段id */ - @TableField(exist = false) - private String caseId; + private String recordSplitId; /** * 是否生成图谱 @@ -74,7 +69,7 @@ public class TripleInfo implements Serializable { * 创建时间 */ @TableField(fill = FieldFill.INSERT_UPDATE) - @JsonFormat(pattern="yyyy-MM-dd HH:mm:ss",timezone = "GMT+8") + @JsonFormat(pattern = "yyyy-MM-dd HH:mm:ss", timezone = "GMT+8") private LocalDateTime createTime; /** @@ -87,7 +82,7 @@ public class TripleInfo implements Serializable { * 更新时间 */ @TableField(fill = FieldFill.INSERT_UPDATE) - @JsonFormat(pattern="yyyy-MM-dd HH:mm:ss",timezone = "GMT+8") + @JsonFormat(pattern = "yyyy-MM-dd HH:mm:ss", timezone = "GMT+8") private LocalDateTime updateTime; @TableField(exist = false) @@ -97,11 +92,13 @@ public class TripleInfo implements Serializable { } // todo - public TripleInfo(String startNode, String endNode, String relation, String noteRecordId, LocalDateTime createTime, String startNodeType, String endNodeType) { + public TripleInfo(String startNode, String endNode, String relation,String caseId, String recordId, String recordSplitId, LocalDateTime createTime, String startNodeType, String endNodeType) { this.startNode = startNode; this.endNode = endNode; this.relation = relation; - this.noteRecordId = noteRecordId; + this.caseId = caseId; + this.recordId = recordId; + this.recordSplitId = recordSplitId; this.createTime = createTime; this.startNodeType = startNodeType; this.endNodeType = endNodeType; diff --git a/src/main/java/com/supervision/police/mapper/NotePromptMapper.java b/src/main/java/com/supervision/police/mapper/NotePromptMapper.java index ad857fb..5288516 100644 --- a/src/main/java/com/supervision/police/mapper/NotePromptMapper.java +++ b/src/main/java/com/supervision/police/mapper/NotePromptMapper.java @@ -8,6 +8,5 @@ import java.util.List; public interface NotePromptMapper extends BaseMapper { - List queryPrompt(@Param("typeId") String typeId); } diff --git a/src/main/java/com/supervision/police/mapper/TripleInfoMapper.java b/src/main/java/com/supervision/police/mapper/TripleInfoMapper.java index 604d297..95864dc 100644 --- a/src/main/java/com/supervision/police/mapper/TripleInfoMapper.java +++ b/src/main/java/com/supervision/police/mapper/TripleInfoMapper.java @@ -8,6 +8,5 @@ import java.util.List; public interface TripleInfoMapper extends BaseMapper { - List selectByIds(@Param("ids") List ids); } diff --git a/src/main/java/com/supervision/police/service/ExtractTripleInfoService.java b/src/main/java/com/supervision/police/service/ExtractTripleInfoService.java new file mode 100644 index 0000000..d9805e3 --- /dev/null +++ b/src/main/java/com/supervision/police/service/ExtractTripleInfoService.java @@ -0,0 +1,6 @@ +package com.supervision.police.service; + +public interface ExtractTripleInfoService { + + void extractTripleInfo(String caseId, String name, String recordId); +} diff --git a/src/main/java/com/supervision/police/service/NoteCheckRecordService.java b/src/main/java/com/supervision/police/service/NoteCheckRecordService.java index afe39e2..9a3c099 100644 --- a/src/main/java/com/supervision/police/service/NoteCheckRecordService.java +++ b/src/main/java/com/supervision/police/service/NoteCheckRecordService.java @@ -1,7 +1,11 @@ package com.supervision.police.service; +import com.supervision.police.domain.NoteRecordSplit; import com.supervision.springaidemo.domain.NoteCheckRecord; import com.baomidou.mybatisplus.extension.service.IService; +import org.apache.ibatis.annotations.Param; + +import java.util.List; /** * @author flevance @@ -10,4 +14,6 @@ import com.baomidou.mybatisplus.extension.service.IService; */ public interface NoteCheckRecordService extends IService { + + } diff --git a/src/main/java/com/supervision/police/service/NotePromptService.java b/src/main/java/com/supervision/police/service/NotePromptService.java new file mode 100644 index 0000000..d122688 --- /dev/null +++ b/src/main/java/com/supervision/police/service/NotePromptService.java @@ -0,0 +1,7 @@ +package com.supervision.police.service; + +import com.baomidou.mybatisplus.extension.service.IService; +import com.supervision.police.domain.NotePrompt; + +public interface NotePromptService extends IService { +} diff --git a/src/main/java/com/supervision/police/service/RecordService.java b/src/main/java/com/supervision/police/service/NoteRecordSplitService.java similarity index 87% rename from src/main/java/com/supervision/police/service/RecordService.java rename to src/main/java/com/supervision/police/service/NoteRecordSplitService.java index 67e1724..78984b4 100644 --- a/src/main/java/com/supervision/police/service/RecordService.java +++ b/src/main/java/com/supervision/police/service/NoteRecordSplitService.java @@ -9,7 +9,7 @@ import java.io.IOException; import java.util.List; import java.util.Map; -public interface RecordService extends IService { +public interface NoteRecordSplitService extends IService { String uploadRecords(NoteRecord records, List fileList) throws IOException; diff --git a/src/main/java/com/supervision/police/service/TripleInfoService.java b/src/main/java/com/supervision/police/service/TripleInfoService.java new file mode 100644 index 0000000..29ba353 --- /dev/null +++ b/src/main/java/com/supervision/police/service/TripleInfoService.java @@ -0,0 +1,7 @@ +package com.supervision.police.service; + +import com.baomidou.mybatisplus.extension.service.IService; +import com.supervision.police.domain.TripleInfo; + +public interface TripleInfoService extends IService { +} diff --git a/src/main/java/com/supervision/police/service/impl/ExtractTripleInfoServiceImpl.java b/src/main/java/com/supervision/police/service/impl/ExtractTripleInfoServiceImpl.java new file mode 100644 index 0000000..db021af --- /dev/null +++ b/src/main/java/com/supervision/police/service/impl/ExtractTripleInfoServiceImpl.java @@ -0,0 +1,100 @@ +package com.supervision.police.service.impl; + +import cn.hutool.core.util.StrUtil; +import com.alibaba.druid.util.StringUtils; +import com.supervision.police.domain.NotePrompt; +import com.supervision.police.domain.NoteRecordSplit; +import com.supervision.police.domain.TripleInfo; +import com.supervision.police.mapper.NotePromptMapper; +import com.supervision.police.mapper.NoteRecordSplitMapper; +import com.supervision.police.mapper.TripleInfoMapper; +import com.supervision.police.service.*; +import com.supervision.thread.TripleExtractThread; +import com.supervision.thread.TripleExtractThreadPool; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.json.JSONArray; +import org.json.JSONObject; +import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.ollama.OllamaChatClient; +import org.springframework.scheduling.annotation.Async; +import org.springframework.stereotype.Service; +import org.springframework.util.StopWatch; + +import java.time.LocalDateTime; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.Future; + +@Slf4j +@Service +@RequiredArgsConstructor +public class ExtractTripleInfoServiceImpl implements ExtractTripleInfoService { + + private final NoteRecordSplitMapper noteRecordSplitMapper; + + private final NotePromptService notePromptService; + + private final TripleInfoService tripleInfoService; + + private final OllamaChatClient chatClient; + + + @Async + public void extractTripleInfo(String caseId, String name, String recordId) { + // 首先获取所有切分后的笔录 + List recordSplitList = noteRecordSplitMapper.selectRecord(caseId, name, recordId); + List tripleInfos = new ArrayList<>(); + List> futures = new ArrayList<>(); + // 对切分后的笔录进行遍历 + for (NoteRecordSplit recordSplit : recordSplitList) { + // 根据笔录类型找到所有的提取三元组的提示词 + List prompts = notePromptService.lambdaQuery().eq(NotePrompt::getTypeId, recordSplit.getRecordTypeId()).list(); + // 遍历提示词进行提取 + for (NotePrompt prompt : prompts) { + if (StringUtils.isEmpty(prompt.getPrompt())) { + continue; + } + try { + log.info("提交任务到线程池中进行三元组提取"); + Future submit = TripleExtractThreadPool.chatExecutor.submit(new TripleExtractThread(chatClient, caseId, recordId, recordSplit.getId(), prompt.getPrompt(), recordSplit.getQuestion(), recordSplit.getAnswer())); + futures.add(submit); + } catch (Exception e) { + log.error(e.getMessage(), e); + } + } + } + while (futures.size() > 0) { + Iterator> iterator = futures.iterator(); + while (iterator.hasNext()) { + Future future = iterator.next(); + try { + // 如果提取到结果,且不为空,就进行保存 + if (future.isDone()) { + TripleInfo tripleInfo = future.get(); + if (tripleInfo != null) { + tripleInfos.add(tripleInfo); + } + iterator.remove(); + } + } catch (Exception e) { + log.info("从线程中获取任务失败"); + iterator.remove(); + } + } + try { + log.info("检查一遍,休眠1s后继续检查"); + Thread.sleep(1000); + } catch (Exception e) { + log.error(e.getMessage(), e); + } + } + // 首先清除 + tripleInfoService.lambdaUpdate().eq(TripleInfo::getRecordId, recordId).remove(); + // 首先要把这个笔录已经提取过的三元组记录删除掉,删除掉之后才可以重新提取 + tripleInfoService.saveBatch(tripleInfos); + } +} diff --git a/src/main/java/com/supervision/police/service/impl/ModelRecordTypeServiceImpl.java b/src/main/java/com/supervision/police/service/impl/ModelRecordTypeServiceImpl.java index bcf45af..078b10e 100644 --- a/src/main/java/com/supervision/police/service/impl/ModelRecordTypeServiceImpl.java +++ b/src/main/java/com/supervision/police/service/impl/ModelRecordTypeServiceImpl.java @@ -1,6 +1,5 @@ package com.supervision.police.service.impl; -import cn.hutool.core.util.StrUtil; import com.alibaba.druid.util.StringUtils; import com.baomidou.mybatisplus.core.conditions.Wrapper; import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; @@ -9,19 +8,12 @@ import com.supervision.common.domain.R; import com.supervision.neo4j.domain.CaseNode; import com.supervision.neo4j.domain.Rel; import com.supervision.neo4j.service.Neo4jService; -import com.supervision.police.domain.ModelRecordType; -import com.supervision.police.domain.NoteRecordSplit; -import com.supervision.police.domain.NotePrompt; -import com.supervision.police.domain.TripleInfo; +import com.supervision.police.domain.*; import com.supervision.police.mapper.ModelRecordTypeMapper; import com.supervision.police.mapper.NoteRecordSplitMapper; -import com.supervision.police.mapper.NotePromptMapper; -import com.supervision.police.mapper.TripleInfoMapper; -import com.supervision.police.service.ModelRecordTypeService; +import com.supervision.police.service.*; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; -import org.json.JSONArray; -import org.json.JSONObject; import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.Prompt; @@ -30,8 +22,8 @@ import org.springframework.stereotype.Service; import org.springframework.util.StopWatch; import java.time.LocalDateTime; -import java.util.ArrayList; import java.util.List; +import java.util.Optional; @Slf4j @Service @@ -42,19 +34,20 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl queryType(String name, Integer page, Integer size) { -// IPage iPage = new Page<>(page, size); -// iPage = modelRecordTypeMapper.selectByName(iPage, name); -// return R.ok(IPages.buildDataMap(iPage)); List list = modelRecordTypeMapper.selectByName(name); for (ModelRecordType modelRecordType : list) { @@ -62,11 +55,10 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl noteRecords = noteRecordSplitMapper.selectByRecordType(modelRecordType.getRecordType()); modelRecordType.setRecords(noteRecords); //提示词 - List prompts = notePromptMapper.queryPrompt(modelRecordType.getId()); + List prompts = notePromptService.lambdaQuery().eq(NotePrompt::getTypeId, modelRecordType.getId()).list(); modelRecordType.setPrompts(prompts); } return list; -// return R.ok(IPages.buildDataMap(iPage)); } @Override @@ -95,12 +87,13 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl addOrUpdPrompt(NotePrompt prompt) { int i = 0; + boolean save; if (StringUtils.isEmpty(prompt.getId())) { - i = notePromptMapper.insert(prompt); + save = notePromptService.save(prompt); } else { - i = notePromptMapper.updateById(prompt); + save = notePromptService.updateById(prompt); } - if (i > 0) { + if (save) { return R.ok("保存成功"); } else { return R.fail("保存失败"); @@ -110,8 +103,8 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl delPrompt(NotePrompt prompt) { String id = prompt.getId(); - int i = notePromptMapper.deleteById(id); - if (i > 0) { + boolean removeById = notePromptService.removeById(id); + if (removeById) { return R.ok("删除成功"); } else { return R.fail("删除失败"); @@ -121,8 +114,52 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl getThreeInfo(String caseId, String name, String recordId) { - // TODO 这里应该改成异步的形式,通过异步的形式来进行提取三元组信息,不能每次点击就跑一遍 - return extractTripleInfo(caseId, name, recordId); + boolean taskStatus = taskExtractStatusCheck(caseId, recordId); + // 如果校验结果为false,则说明需要进行提取三元组操作 + if (!taskStatus) { + extractTripleInfo.extractTripleInfo(caseId, name, recordId); + } + // 这里进行查询 + return tripleInfoService.lambdaQuery().eq(TripleInfo::getRecordId, recordId).list(); + } + + /** + * 提取任务校验,校验是否已经存在相关的人物,如果存在相关的任务,就不再继续执行了,直接告诉任务正在执行中 + */ + private boolean taskExtractStatusCheck(String caseId, String recordId) { + // 首先查询是否存在任务,如果不存在,就新建 + Optional caseTaskRecordOpt = caseTaskRecordService.lambdaQuery().eq(CaseTaskRecord::getCaseId, caseId).eq(CaseTaskRecord::getRecordId, recordId).oneOpt(); + if (caseTaskRecordOpt.isEmpty()) { + CaseTaskRecord newCaseTaskRecord = new CaseTaskRecord(); + newCaseTaskRecord.setCaseId(caseId); + newCaseTaskRecord.setRecordId(recordId); + newCaseTaskRecord.setStatus(1); + newCaseTaskRecord.setSubmitTime(LocalDateTime.now()); + caseTaskRecordService.save(newCaseTaskRecord); + return false; + } else { + + // 如果存在,则校验时间是否已经超过1天,如果超过了1天还没有执行完毕,就重新提交这个任务 + CaseTaskRecord caseTaskRecord = caseTaskRecordOpt.get(); + if (caseTaskRecordOpt.get().getStatus() == 1 && LocalDateTime.now().isAfter(caseTaskRecord.getSubmitTime().plusDays(1))) { + // 如果已经超过1天,则重新提交任务 + caseTaskRecord.setStatus(1); + caseTaskRecord.setSubmitTime(LocalDateTime.now()); + caseTaskRecordService.updateById(caseTaskRecord); + return false; + } else if (caseTaskRecordOpt.get().getStatus() == 2) { + return true; + } else if (caseTaskRecordOpt.get().getStatus() == 3) { + caseTaskRecord.setStatus(1); + caseTaskRecord.setSubmitTime(LocalDateTime.now()); + caseTaskRecordService.updateById(caseTaskRecord); + return false; + } else { + // 如果没有超过1天,则返回正在执行中 + throw new RuntimeException("任务正在执行中"); + + } + } } @Override @@ -152,95 +189,45 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl extractTripleInfo(String caseId, String name, String recordId) { - // 首先获取所有切分后的笔录 - List recordSplitList = noteRecordSplitMapper.selectRecord(caseId, name, recordId); - List tripleInfos = new ArrayList<>(); - // 对切分后的笔录进行遍历 - for (NoteRecordSplit record : recordSplitList) { - // 根据笔录类型找到所有的提取三元组的提示词 - List prompts = notePromptMapper.queryPrompt(record.getRecordTypeId()); - // 遍历提示词进行提取 - for (NotePrompt prompt : prompts) { - if (StringUtils.isEmpty(prompt.getPrompt())) { - continue; - } - try { - StopWatch stopWatch = new StopWatch(); - // 分析三元组 - Prompt ask = new Prompt(new UserMessage(prompt.getPrompt() + record.getQuestion() + record.getAnswer())); - stopWatch.start(); - log.info("开始分析:"); - ChatResponse call = chatClient.call(ask); - stopWatch.stop(); - log.info("耗时:{}", stopWatch.getTotalTimeSeconds()); - String content = call.getResult().getOutput().getContent(); - log.info("分析的结果是:{}", content); - // 获取从提示词中提取到的三元组信息 - JSONObject jsonObject = new JSONObject(content); - JSONArray threeInfo = jsonObject.getJSONArray("result"); - for (int i = 0; i < threeInfo.length(); i++) { - JSONObject object = threeInfo.getJSONObject(i); - String startNodeType = object.getString("startNodeType"); - String entity = object.getString("entity"); - String endNodeType = object.getString("endNodeType"); - String property = object.getString("property"); - String value = object.getString("value"); - // 去空,如果存在任何的空值,则忽略 - if (StrUtil.hasEmpty(startNodeType, entity, endNodeType, property, value)) { - continue; - } - // 将三元组信息进行保存操作 - TripleInfo tripleInfo = new TripleInfo(entity, property, value, record.getId(), LocalDateTime.now(), startNodeType, endNodeType); - tripleInfoMapper.insert(tripleInfo); - tripleInfos.add(tripleInfo); - } - } catch (Exception e) { - log.error(e.getMessage(), e); - } - } - } - return tripleInfos; - } @Override public String addNeo4j(List ids) { - List tripleInfos = tripleInfoMapper.selectByIds(ids); + List tripleInfos = tripleInfoService.listByIds(ids); int i = 0; for (TripleInfo tripleInfo : tripleInfos) { try { //开始节点 String start = tripleInfo.getStartNode(); // 首先看是否已经存在了,如果已经存在了,就不添加了 - CaseNode startNode = neo4jService.findOneByName(tripleInfo.getCaseId(), tripleInfo.getNoteRecordsId(), tripleInfo.getStartNodeType(), start, "1"); + CaseNode startNode = neo4jService.findOneByName(tripleInfo.getCaseId(), tripleInfo.getRecordId(), tripleInfo.getStartNodeType(), start, "1"); if (startNode == null) { - startNode = new CaseNode(start, tripleInfo.getStartNodeType(), tripleInfo.getNoteRecordId(), tripleInfo.getNoteRecordsId(), tripleInfo.getCaseId(), "1"); + startNode = new CaseNode(start, tripleInfo.getStartNodeType(), tripleInfo.getRecordSplitId(), tripleInfo.getRecordId(), tripleInfo.getCaseId(), "1"); CaseNode save = neo4jService.save(startNode); startNode.setId(save.getId()); } //结束节点 String end = tripleInfo.getEndNode(); - CaseNode endNode = neo4jService.findOneByName(tripleInfo.getCaseId(), tripleInfo.getNoteRecordsId(), tripleInfo.getEndNodeType(), end, "1"); + CaseNode endNode = neo4jService.findOneByName(tripleInfo.getCaseId(), tripleInfo.getRecordId(), tripleInfo.getEndNodeType(), end, "1"); if (endNode == null) { - endNode = new CaseNode(end, tripleInfo.getEndNodeType(), tripleInfo.getNoteRecordId(), tripleInfo.getNoteRecordsId(), tripleInfo.getCaseId(), "1"); + endNode = new CaseNode(end, tripleInfo.getEndNodeType(), tripleInfo.getRecordSplitId(), tripleInfo.getRecordId(), tripleInfo.getCaseId(), "1"); CaseNode save = neo4jService.save(endNode); endNode.setId(save.getId()); } @@ -251,8 +238,8 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl 0) { + boolean updateResult = tripleInfoService.updateById(tripleInfo); + if (updateResult) { i++; } // TODO 重复添加的OK了,删除的呢? diff --git a/src/main/java/com/supervision/police/service/impl/NoteCheckRecordServiceImpl.java b/src/main/java/com/supervision/police/service/impl/NoteCheckRecordServiceImpl.java index 03c254c..e358880 100644 --- a/src/main/java/com/supervision/police/service/impl/NoteCheckRecordServiceImpl.java +++ b/src/main/java/com/supervision/police/service/impl/NoteCheckRecordServiceImpl.java @@ -1,11 +1,14 @@ package com.supervision.police.service.impl; import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; +import com.supervision.police.domain.NoteRecordSplit; import com.supervision.springaidemo.domain.NoteCheckRecord; import com.supervision.police.service.NoteCheckRecordService; import com.supervision.police.mapper.NoteCheckRecordMapper; import org.springframework.stereotype.Service; +import java.util.List; + /** * @author flevance * @description 针对表【note_check_record(案件执行验证结果)】的数据库操作Service实现 @@ -15,6 +18,7 @@ import org.springframework.stereotype.Service; public class NoteCheckRecordServiceImpl extends ServiceImpl implements NoteCheckRecordService{ + } diff --git a/src/main/java/com/supervision/police/service/impl/NotePromptServiceImpl.java b/src/main/java/com/supervision/police/service/impl/NotePromptServiceImpl.java new file mode 100644 index 0000000..0ad8661 --- /dev/null +++ b/src/main/java/com/supervision/police/service/impl/NotePromptServiceImpl.java @@ -0,0 +1,12 @@ +package com.supervision.police.service.impl; + +import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; +import com.supervision.police.domain.NotePrompt; +import com.supervision.police.mapper.NotePromptMapper; +import com.supervision.police.service.NotePromptService; +import org.springframework.stereotype.Service; + +@Service +public class NotePromptServiceImpl extends ServiceImpl implements NotePromptService { + +} diff --git a/src/main/java/com/supervision/police/service/impl/RecordServiceImpl.java b/src/main/java/com/supervision/police/service/impl/NoteRecordSplitServiceImpl.java similarity index 98% rename from src/main/java/com/supervision/police/service/impl/RecordServiceImpl.java rename to src/main/java/com/supervision/police/service/impl/NoteRecordSplitServiceImpl.java index 9013d45..013aa1f 100644 --- a/src/main/java/com/supervision/police/service/impl/RecordServiceImpl.java +++ b/src/main/java/com/supervision/police/service/impl/NoteRecordSplitServiceImpl.java @@ -18,7 +18,7 @@ import com.supervision.police.mapper.ModelCaseMapper; import com.supervision.police.mapper.ModelRecordTypeMapper; import com.supervision.police.mapper.NoteRecordSplitMapper; import com.supervision.police.mapper.NoteRecordMapper; -import com.supervision.police.service.RecordService; +import com.supervision.police.service.NoteRecordSplitService; import com.supervision.springaidemo.dto.QARecordNodeDTO; import com.supervision.springaidemo.util.RecordRegexUtil; import com.supervision.springaidemo.util.WordReadUtil; @@ -41,7 +41,7 @@ import java.util.stream.Collectors; @Slf4j @Service @RequiredArgsConstructor -public class RecordServiceImpl extends ServiceImpl implements RecordService { +public class NoteRecordSplitServiceImpl extends ServiceImpl implements NoteRecordSplitService { private final NoteRecordSplitMapper noteRecordSplitMapper; diff --git a/src/main/java/com/supervision/police/service/impl/TripleInfoServiceImpl.java b/src/main/java/com/supervision/police/service/impl/TripleInfoServiceImpl.java new file mode 100644 index 0000000..347f093 --- /dev/null +++ b/src/main/java/com/supervision/police/service/impl/TripleInfoServiceImpl.java @@ -0,0 +1,12 @@ +package com.supervision.police.service.impl; + +import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; +import com.supervision.police.domain.TripleInfo; +import com.supervision.police.mapper.TripleInfoMapper; +import com.supervision.police.service.TripleInfoService; +import org.springframework.stereotype.Service; + +@Service +public class TripleInfoServiceImpl extends ServiceImpl implements TripleInfoService { + +} diff --git a/src/main/java/com/supervision/springaidemo/controller/ChatController.java b/src/main/java/com/supervision/springaidemo/controller/ChatController.java deleted file mode 100644 index 68c151c..0000000 --- a/src/main/java/com/supervision/springaidemo/controller/ChatController.java +++ /dev/null @@ -1,336 +0,0 @@ -package com.supervision.springaidemo.controller; - -import cn.hutool.core.collection.CollUtil; -import cn.hutool.core.collection.ListUtil; -import cn.hutool.core.io.FileUtil; -import cn.hutool.core.util.StrUtil; -import cn.hutool.json.JSONUtil; -import com.supervision.springaidemo.domain.ModelMetric; -import com.supervision.springaidemo.domain.NoteCheckRecord; -import com.supervision.springaidemo.dto.MetricResultDTO; -import com.supervision.springaidemo.service.ModelMetricService; -import com.supervision.police.service.NoteCheckRecordService; -import com.supervision.springaidemo.thread.RunCheckThread; -import com.supervision.springaidemo.thread.RunCheckThreadPool; -import com.supervision.springaidemo.util.WordReadUtil; -import lombok.extern.slf4j.Slf4j; -import org.springframework.ai.chat.ChatResponse; -import org.springframework.ai.chat.Generation; -import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.SystemMessage; -import org.springframework.ai.chat.messages.UserMessage; -import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.ollama.OllamaChatClient; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.web.bind.annotation.GetMapping; -import org.springframework.web.bind.annotation.RestController; - -import java.io.BufferedReader; -import java.io.File; -import java.util.*; - -@RestController -@Slf4j -public class ChatController { - - private final OllamaChatClient chatClient; - - @Autowired - private ModelMetricService modelMetricService; - - @Autowired - private NoteCheckRecordService noteCheckRecordService; - - // 使用多线程进行提交 - - - @Autowired - public ChatController(OllamaChatClient chatClient) { - this.chatClient = chatClient; - } - - @GetMapping("/ai/chat") - public void generate() { - String template = """ - 一般嫌疑人或者受害者在笔录中说什么话能对应到下面指标呢 - - 例子:指标:是否有金钱往来。回复:["XXX给我转来的钱。","我收到过XXX给我转来的钱","我通过银行给XXX转了钱"] - 现在请对下面这个指标进行举例,例子要尽可能覆盖更多的情况;既要有行为人可能承认的话10条,还要有受害者可能指认的话10条! - 指标:行为人在合同签订后支付了部分货款,并骗取全部货物后,在规定的期限内无正当理由不支付其余货款 - 回复json格式:{"behavior":[""],"victim":[""]} - """; - Prompt prompt = new Prompt(List.of(new UserMessage(template))); - ChatResponse call = chatClient.call(prompt); - Generation result = call.getResult(); - String content = result.getOutput().getContent(); - log.info(content); - } - - private Message buildMessage(Map param) { - String messageTemplate = """ - 以下是案件的身份信息 - --- - 行为人:金(李),吴 - 受害人:刘,尚,张 - --- - 以下是笔录的内容: - --- - 问:我们是经侦大队的民警(出示工作证件),现依法向你询问有关问题。根据刑事诉讼法的有关规定,你应当如实提供证据、证言,如果有意作伪证或者隐匿罪证的,要负法律责任。你明白吗?答:明白。 - 问:现向你宣读《证人诉讼权利义务告知书》(向当事人宣读《证人诉讼权利义务告知书》,并将《证人诉讼权利义务告知书》送交当事人),你对你的权利义务是否清楚?答:清楚了。 - 问:你有什么要求吗?答:没有。 - 问:你的个人情况?答:我叫张,曾用名无,男,1980年出生,汉族,高中文化程度,户籍所在地宁夏固原市,现住宁夏固原市,现在无工作,居民身份证号码642226,联系电话13519。 - 问:你是否是中国共产党党员或国家机关工作人员?答:都不是。 - 问:你认识刘吗?答:认识,我们是朋友。 - 问:刘于2022年03月23日在我局报称:他和你、尚3人想合作经营一家汽车公司,被一名男子金骗了,该金自称是汽车有限公司的工作人员,负责西北地区推销,与刘签订了《汽车租赁合同》,租赁8辆新能源汽车,并协商每辆汽车缴纳5万元的押金使用3年,3年后由汽车公司收回车辆,退还押金,后刘分2次给该公司公户内转账20万元前期押金,剩余20万元押金等车辆到固原后将车辆上牌,再进行支付。支付押金后,金与该公司以各种借口推诿,至今未将租赁车辆交给你们3人,是否属实?答:属实。 - 问:你是如何得知汽车公司的?答:是金给我介绍的。 - 问:金是如何认识的?答:金是我通过平台认识,之后我们成为微信好友。 - 问:金是如何给你介绍汽车的?答:我之前与金在平台认识,之后我们成为微信好友,金平常就在微信内给我推送关于汽车的模式和链接。我当时看这个生意可以盈利,就与金了解汽车。 - 问:你是如何给刘和尚介绍的?答:当时我与金了解了汽车,前期需要投入40万元,我能力有限,就将此事与刘和尚商议,我当时给他们说合作在固原开一家汽车运营公司(包括卖车业务、租车业务、还有新能源汽车销售业务),我们3人商议后达成了合作协议。 - 问:你与刘、尚是如何出资的?答:前期需要缴纳20万元的押金,我出资了3万元,尚出资了5万元,刘出资了12万元,我和尚的钱全部交给刘,是刘分2次给汽车公司转账的。 - 问:你们具体是谁与金联系对接业务的?答:具体尚与金联系对接业务的。 - 问:你们将押金转给该公司后,该公司是否将你们预定的车辆发给你们?答:我们将20万元的押金转给该公司后,该公司与金以各种借口推诿,一直未将我们预定的车辆发给我们。之后我们到你们公安局报案,经你们民警联系后,该公司与金于2022年给我们发了一辆电动车,我当时察看后,该车是一辆旧车,也不能正常从板车上开下,我就没有要。 - 问:是谁给你通知让你去接车的?答:当时是一个山东省电话号码给我打电话,让我去固原市高速公路南出口接车,我去看见是一辆旧车,我就询问车辆司机,该车是从什么地方拉来的,司机称是从西安市拉来的。 - 问:你们当时是谁与金签订的《汽车租赁合同》?答:是刘与金签订的。 - 问:你们除了签订《汽车租赁合同》,还签订什么协议了吗?答:我们还签订了一份《授权及扶持补充协议》。 - 问:签订的《汽车租赁合同》、《授权及扶持补充协议》是谁制作的?答:是金带来的。 - 问:签订合同时,你是否在场?现场还有谁?答:我在场,现场还有尚,还有与金一起的女子。 - 问:你们当时签订合同预定的都是什么牌子的车辆?答:总共8辆全新的(刚出厂未挂牌)汽车,分别是创维4辆,奔腾2辆,尼桑1辆,大众朗逸1辆。 - 问:你讲一下金的基本情况?答:该男子真名不叫金,我们报案后,才知道该男子叫李,是山东人,身份证号码是:37078,联系电话是18,微信号是J,其他什么情况我不清楚。 - 问:你讲一下与金一起来的女子的基本情况?答:我只知道该女子姓吴,金称是他们公司的行政经理,联系电话是178,其他什么情况我不知道。 - 问:该姓吴女子来固原具体都做了什么工作?答:她只是来固原给我们办理注册公司。 - 问:你还有什么需要补充说明的吗?答:2022年01月刘将剩余的16万元转给该公司账户内,01月金(李)在微信内建了个微信群,群内有我、刘、尚、金(李)、刘、吴6人,称刘是汽车公司负责给各地发车,吴现在已经不在该群内了。 - 问:你以上所讲的是否属实?答:属实。 - 问:以上笔录请你仔细阅看。如果记录有误请指出来,我们即给予更正。请你确认记录无误后再在笔录上逐页签名。答:好的。 - --- - - 现在需要你根据以上内容,进行判断并以简体中文输出下面的各项,注意:如果笔录里面有提到了存在相关证据,则你可以认为这些证据文件是真实存在的 - 1.指标名称:{metricName}。 - 2.结论:true({metricTrueDesc})/false({metricFalseDesc}),直接给我true/false。 - 3.笔录对应原话:从笔录的对话中,得到该结论的原文(一定是摘抄的原文)。 - 4.原因:分析得出该结论的原因,需明确说明为什么得到该结论,需要逻辑清晰完整。 - - 判断结果以json格式回复, JSON的value内容我给你的提示,在实际输出的时候不需要带上: - --- - {"metricName":"指标名称", "result":"结论", "originalContext":"笔录对应原话","reason":"原因"} - --- - """; - String format = StrUtil.format(messageTemplate, param); - return new UserMessage(format); - } - - - @GetMapping("/ai/run") - public void run() { - var list = modelMetricService.list(); - for (ModelMetric modelMetric : list) { - Map param = new HashMap<>(); - param.put("metricTrueDesc", modelMetric.getMetricTrueDesc()); - param.put("metricFalseDesc", modelMetric.getMetricFalseDesc()); - param.put("metricName", modelMetric.getMetricName()); - Message message = buildMessage(param); - Prompt prompt = new Prompt(List.of(new SystemMessage("所有的回复以简体中文回答。请以step by step的方式进行。step1:理解笔录设计人员的身份信息;step2:根据笔录的内容分析案件之间的逻辑关系和关联;step3:判断给定的指标是否满足。step4:根据要求的给定格式进行回复"), message)); - log.info("prompt:{}", prompt.toString()); - ChatResponse call = chatClient.call(prompt); - Generation result = call.getResult(); - String content = result.getOutput().getContent(); - log.info(content); - } - } - - private static final String template = """ - 我们现在需要以step by step的方式进行笔录的指标分析工作,得到最终的结果并返回。 - step1:理解下面人员身份信息; - --- - 行为人:{actionUserNameList} - 犯罪嫌疑人:{suspectUserNameList} - 受害人:{victimUserNameList} - 证人:{witnessNameList} - --- - - step2:分析笔录的内容; - 以下是笔录的内容,笔录中"问"是办案警官问,"答"是{noteUserName}回答: - --- - {context} - --- - - step3:现在给你指标以及指标的释义或例子: - 指标:{metricName} - 指标释义或例子及判断标准: - 如({metricTrueDesc}),则为true; - 如({metricFalseDesc}),则为false; - 如果笔录中,没有任何笔录内容涉及到该项指标,则为empty。 - - step4:现在需要你根据上面提供的所有信息,尽可能实事求是完成判断: - 1.判断结论:true/false/empty - 2.得到结论的笔录原话:从笔录的对话中,得到该结论的原文(一定是摘抄的原文且为中文)。如果结论为true,则必须要有原文佐证! - 3.得到结论的原因:分析得出该结论的原因,需明确说明为什么得到该结论,需要实事求是且为中文回复。如果结论为true/false,则必须有原因! - - - step5:必须以json格式回复, JSON的value内容我给你的提示,在实际输出的时候不需要带上: - --- - {"result":"结论", "originalContext":"笔录对应原话","reason":"原因"} - --- - 好了,现在可以回复了! - """; - - @GetMapping("runNoteCheck") - public void runNoteCheck() { - HashMap map = new HashMap<>(); - map.put("杨学明", "/Users/flevance/Desktop/宁夏审讯大模型/陈恩明合同诈骗/受害人杨学明询问笔录.docx"); - map.put("朱文泽", "/Users/flevance/Desktop/宁夏审讯大模型/陈恩明合同诈骗/受害人朱文泽询问笔录.docx"); - map.put("陈恩明", "/Users/flevance/Desktop/宁夏审讯大模型/陈恩明合同诈骗/嫌疑人陈恩明讯问笔录.docx"); - map.put("武桂清1", "/Users/flevance/Desktop/宁夏审讯大模型/陈恩明合同诈骗/嫌疑人武桂清讯问笔录1.docx"); - map.put("武桂清2", "/Users/flevance/Desktop/宁夏审讯大模型/陈恩明合同诈骗/嫌疑人武桂清讯问笔录2.docx"); - for (Map.Entry entry : map.entrySet()) { - String context = WordReadUtil.readWord(entry.getValue()); - List actionUserNameList = new ArrayList<>(); - actionUserNameList.add("陈恩明"); - actionUserNameList.add("武桂清"); - List victimUserNameList = new ArrayList<>(); - victimUserNameList.add("漫旭昌"); - victimUserNameList.add("杨学明"); - victimUserNameList.add("朱文泽"); - var list = modelMetricService.list(); - for (ModelMetric modelMetric : list) { - // 没有跑过的,才继续跑 - Long count = noteCheckRecordService.lambdaQuery().eq(NoteCheckRecord::getPersonName, entry.getKey()).eq(NoteCheckRecord::getMetricCode, modelMetric.getMetricCode()).count(); - if (count < 1) { - Map param = new HashMap<>(); - param.put("actionUserNameList", CollUtil.join(actionUserNameList, ";")); - param.put("victimUserNameList", CollUtil.join(victimUserNameList, ";")); - param.put("witnessNameList", "无"); - param.put("noteUserName", entry.getKey()); - param.put("context", context); - param.put("metricName", modelMetric.getMetricName()); - param.put("metricTrueDesc", modelMetric.getMetricTrueDesc()); - param.put("metricFalseDesc", modelMetric.getMetricFalseDesc()); - String format = StrUtil.format(template, param); - Message message = new UserMessage(format); - Prompt prompt = new Prompt(List.of(new SystemMessage("所有的回复以简体中文回答。请以step by step的方式进行。step1:理解笔录设计人员的身份信息;step2:根据笔录的内容分析案件之间的逻辑关系和关联;step3:根据给定的指标,提取出来可能涉及的笔录内容文本;step4:根据该笔录内容判断给定的指标是否满足。step5:根据要求的给定格式进行回复"), message)); - log.info("prompt:{}", prompt); - ChatResponse call = chatClient.call(prompt); - Generation result = call.getResult(); - String content = result.getOutput().getContent(); - log.info(content); - MetricResultDTO metricResultDTO = JSONUtil.toBean(content, MetricResultDTO.class); - NoteCheckRecord noteCheckRecord = new NoteCheckRecord(); - noteCheckRecord.setPersonName(entry.getKey()); - noteCheckRecord.setNoteName(FileUtil.getName(entry.getValue())); - noteCheckRecord.setType(entry.getKey().contains("询问") ? "询问" : "讯问"); - noteCheckRecord.setMetricCode(modelMetric.getMetricCode()); - noteCheckRecord.setMetricName(modelMetric.getMetricName()); - noteCheckRecord.setResult(metricResultDTO.getResult()); - noteCheckRecord.setOriginalContext(metricResultDTO.getOriginalContext()); - noteCheckRecord.setReason(metricResultDTO.getReason()); - noteCheckRecordService.save(noteCheckRecord); - } - - } - } - - } - - /** - * 从word中读取笔录 - */ - - @GetMapping("runCheck") - public void runCheck() { - - -// List metricCodeList = ListUtil.list(false, "RZ010", "RZ019", "RZ020", "RZ022"); - // 行为人 - List actionUserNameList = ListUtil.list(false, "裴金禄"); - // 犯罪嫌疑人 - List suspectUserNameList = ListUtil.list(false, "裴金禄", "景涛", "李世怀", "万学宝"); - // 受害人 - List victimUserNameList = ListUtil.list(false, "董金才", "吕加国", "吕志仓"); - // 证人 - List witnessNameList = ListUtil.list(false, "白鹏", "丁建华", "雷建贵", "雷建明", "李泽懿", "王存良", "王开阔", "吴尚军", "杨正福", "叶魁伍", "赵景宝"); - // 获取目录下的所有笔录信息 - List files = FileUtil.loopFiles("/Users/flevance/Desktop/宁夏审讯大模型/裴金禄/行为人和受害人/"); - for (File file : files) { - // 只跑裴金禄的笔录 - log.info("开始分析:{}的笔录", file.getName()); - String context = WordReadUtil.readWord(file.getPath()); -// List list = modelMetricService.lambdaQuery().in(ModelMetric::getMetricCode, metricCodeList).list(); - List list = modelMetricService.list(); - for (ModelMetric modelMetric : list) { - - Map param = new HashMap<>(); - param.put("actionUserNameList", CollUtil.join(actionUserNameList, ";")); - param.put("suspectUserNameList", CollUtil.join(suspectUserNameList, ";")); - param.put("victimUserNameList", CollUtil.join(victimUserNameList, ";")); - param.put("witnessNameList", CollUtil.join(witnessNameList, ";")); - param.put("context", context); - param.put("metricName", modelMetric.getMetricName()); - param.put("metricTrueDesc", modelMetric.getMetricTrueDesc()); - param.put("metricFalseDesc", modelMetric.getMetricFalseDesc()); - String format = StrUtil.format(template, param); - List userMessageList = new ArrayList<>(); - log.info("开始提交分析,prompt长度为:{}", format.length()); - // 如果超过8000字,就进行截断,每次以6000字进行提交 - if (format.length() > 8000) { - log.info("分段提交"); - for (String s : StrUtil.split(format, 6000)) { - userMessageList.add(new UserMessage(s)); - userMessageList.add(new AssistantMessage("继续")); - } - userMessageList.remove(userMessageList.size() - 1); - } else { - userMessageList.add(new UserMessage(format)); - } - String systemPrompt = """ - 你是一个善于分析办案笔录的模型,能够根据办案笔录中的当事人的回答内容,实事求是的判断给定指标是否满足。注意,仅根据笔录进行分析,不要做笔录之外的推断。笔录内容可能比较长,可能分多次提交给你。 - """; - List messages = new ArrayList<>(List.of(new SystemMessage(systemPrompt))); - messages.addAll(userMessageList); - Prompt prompt = new Prompt(messages); - RunCheckThread runCheck = new RunCheckThread("裴金禄第五次",chatClient, noteCheckRecordService, prompt, file.getName(), format, systemPrompt, modelMetric, 0); - RunCheckThreadPool.chatExecutor.submit(runCheck); - } - } - } - - - @GetMapping("testLongText") - public void testLongText() { - StringBuilder stringBuilder = new StringBuilder(); - - BufferedReader utf8Reader = FileUtil.getUtf8Reader("/Users/flevance/Desktop/宁夏审讯大模型/了不起的盖茨比 .txt"); - utf8Reader.lines().forEach(stringBuilder::append); - String template = """ - 我现在给你一个小说,请你解析小说的内容: - --- - {context} - --- - 现在请你分析小说内容,讲讲第8章讲了什么内容,清晰的描述出来.请以中文进行回答.并以json的形式进行输出 - """; - String systemPrompt = """ - 你是一个善于归纳分析的大模型,我现在需要你来做小说内容提取。小说内容可能比较长,可能分多次提交给你。所有的回答都以中文的形式进行回答. - """; - Map param = new HashMap<>(); - log.info("大小是:{}", stringBuilder.length()); - param.put("context", stringBuilder.toString()); - String format = StrUtil.format(template, param); - List userMessageList = new ArrayList<>(); - for (String s : StrUtil.split(format, 6000)) { - userMessageList.add(new UserMessage(s)); - userMessageList.add(new AssistantMessage("继续")); - } - userMessageList.remove(userMessageList.size() - 1); - List messages = new ArrayList<>(List.of(new SystemMessage(systemPrompt))); - messages.addAll(userMessageList); - Prompt prompt = new Prompt(messages); - - ChatResponse call = chatClient.call(prompt); - Generation result = call.getResult(); - - String content = result.getOutput().getContent(); - log.info("分析的结果是:{}", content); - - } -} - - diff --git a/src/main/java/com/supervision/springaidemo/controller/ExampleChatController.java b/src/main/java/com/supervision/springaidemo/controller/ExampleChatController.java index c0c1f97..b4a5eea 100644 --- a/src/main/java/com/supervision/springaidemo/controller/ExampleChatController.java +++ b/src/main/java/com/supervision/springaidemo/controller/ExampleChatController.java @@ -8,7 +8,7 @@ import com.supervision.police.mapper.ModelRecordTypeMapper; import com.supervision.police.mapper.NoteRecordSplitMapper; import com.supervision.police.service.ModelRecordTypeService; import com.supervision.springaidemo.service.ModelMetricService; -import com.supervision.police.service.RecordService; +import com.supervision.police.service.NoteRecordSplitService; import com.supervision.police.service.NoteCheckRecordService; import com.supervision.springaidemo.util.RecordRegexUtil; import com.supervision.springaidemo.util.WordReadUtil; @@ -58,7 +58,7 @@ public class ExampleChatController { } @Autowired - private RecordService recordService; + private NoteRecordSplitService noteRecordSplitService; @Autowired private ModelRecordTypeService modelRecordTypeService; @@ -133,7 +133,7 @@ public class ExampleChatController { //保存笔录 noteRecord.setRecordType(type); - recordService.save(noteRecord); + noteRecordSplitService.save(noteRecord); ModelRecordType exist = modelRecordTypeService.queryByName(type); if (exist == null) { @@ -169,8 +169,8 @@ public class ExampleChatController { // } // messages.addAll(userMessageList); // -// RunCheckThread runCheck = new RunCheckThread("裴金禄尝试正则来做", chatClient, noteCheckRecordService, new Prompt(messages), FileUtil.getName(file), format, systemPrompt, modelMetric, 0); -// RunCheckThreadPool.chatExecutor.submit(runCheck); +// TripleExtractThread runCheck = new TripleExtractThread("裴金禄尝试正则来做", chatClient, noteCheckRecordService, new Prompt(messages), FileUtil.getName(file), format, systemPrompt, modelMetric, 0); +// TripleExtractThreadPool.chatExecutor.submit(runCheck); // } // } @@ -231,7 +231,7 @@ public class ExampleChatController { @GetMapping("test1") public void test2(@Param("id") String id) { - NoteRecordSplit noteRecord = recordService.getById(id); + NoteRecordSplit noteRecord = noteRecordSplitService.getById(id); String question = noteRecord.getQuestion(); String answer = noteRecord.getAnswer(); String test = "请从以下对话中提取所有关于" + noteRecord.getRecordType() + "的所有三元组"; diff --git a/src/main/java/com/supervision/springaidemo/controller/NewTestController.java b/src/main/java/com/supervision/springaidemo/controller/NewTestController.java deleted file mode 100644 index 0c9bf62..0000000 --- a/src/main/java/com/supervision/springaidemo/controller/NewTestController.java +++ /dev/null @@ -1,106 +0,0 @@ -package com.supervision.springaidemo.controller; - -import cn.hutool.core.io.FileUtil; -import cn.hutool.core.util.StrUtil; -import com.supervision.springaidemo.domain.ModelMetric; -import com.supervision.springaidemo.dto.QARecordNodeDTO; -import com.supervision.springaidemo.service.ModelMetricService; -import com.supervision.police.service.NoteCheckRecordService; -import com.supervision.springaidemo.thread.RunCheckThread; -import com.supervision.springaidemo.thread.RunCheckThreadPool; -import com.supervision.springaidemo.util.RecordRegexUtil; -import com.supervision.springaidemo.util.WordReadUtil; -import lombok.extern.slf4j.Slf4j; -import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.SystemMessage; -import org.springframework.ai.chat.messages.UserMessage; -import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.ollama.OllamaChatClient; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.web.bind.annotation.GetMapping; -import org.springframework.web.bind.annotation.RequestMapping; -import org.springframework.web.bind.annotation.RestController; - -import java.io.File; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -/** - * 自己测试的方案,计划改成2阶段的形式 - * 第一阶段,首先根据提示词,从原文中获取笔录的内容(首先让大模型从原文中提取出来可能涉及的内容) - * 然后再尝试跑结果 - */ -@RestController -@Slf4j -@RequestMapping("newTestChat") -public class NewTestController { - - private final OllamaChatClient chatClient; - - @Autowired - private ModelMetricService modelMetricService; - - @Autowired - private NoteCheckRecordService noteCheckRecordService; - - @Autowired - public NewTestController(OllamaChatClient chatClient) { - this.chatClient = chatClient; - } - - private static final String template = """ - 请根据以下关键词,从笔录中找到相关的原文段落: - --- - 关键词1:财物去向 - --- - 笔录内容如下: - --- - {context} - --- - 请判断笔录和关键词是否可能有关联,如有关联,则对这段话进行总结 - 必须以json数组格式回复 - --- - {"match":"true/false(相关回复true,不相关回复false)","summary":"如果有关联,则对该笔录进行总结"} - --- - 好了,现在可以回复了。不要胡编乱造!无中生有! - """; - - - @GetMapping("newTestChat") - public void newTestChat() { - // 只查入罪指标 - ModelMetric modelMetric = modelMetricService.lambdaQuery().eq(ModelMetric::getMetricCode, "RZ007").one(); - File file = FileUtil.file("/Users/flevance/Desktop/宁夏审讯大模型/裴金禄/行为人和受害人/裴金禄第一次.docx"); - String context = WordReadUtil.readWord(file.getPath()); - List qaList = RecordRegexUtil.recordRegex(context, "裴金禄"); - for (QARecordNodeDTO qaRecordNodeDTO : qaList) { - String systemPrompt = """ - 你是一个善于内容提取的模型,能够根据给定关键字从原文中找到可能匹配的原文段落。 - """; - List messages = new ArrayList<>(List.of(new SystemMessage(systemPrompt))); - Map param = new HashMap<>(); - param.put("context", qaRecordNodeDTO.toString()); - String format = StrUtil.format(template, param); - List userMessageList = new ArrayList<>(); - if (format.length() > 8000) { - log.info("分段提交"); - for (String s : StrUtil.split(format, 6000)) { - userMessageList.add(new UserMessage(s)); - userMessageList.add(new AssistantMessage("继续")); - } - userMessageList.remove(userMessageList.size() - 1); - } else { - userMessageList.add(new UserMessage(format)); - } - messages.addAll(userMessageList); - - RunCheckThread runCheck = new RunCheckThread("尝试分段提取-第一段找相似", chatClient, noteCheckRecordService, new Prompt(messages), FileUtil.getName(file), format, systemPrompt, modelMetric, 5); - RunCheckThreadPool.chatExecutor.submit(runCheck); - } - } - - -} diff --git a/src/main/java/com/supervision/springaidemo/controller/RzChatController.java b/src/main/java/com/supervision/springaidemo/controller/RzChatController.java deleted file mode 100644 index f16ebdc..0000000 --- a/src/main/java/com/supervision/springaidemo/controller/RzChatController.java +++ /dev/null @@ -1,118 +0,0 @@ -package com.supervision.springaidemo.controller; - -import cn.hutool.core.io.FileUtil; -import cn.hutool.core.map.MapUtil; -import cn.hutool.core.util.StrUtil; -import com.supervision.springaidemo.domain.ModelMetric; -import com.supervision.springaidemo.service.ModelMetricService; -import com.supervision.police.service.NoteCheckRecordService; -import com.supervision.springaidemo.thread.RunCheckThread; -import com.supervision.springaidemo.thread.RunCheckThreadPool; -import com.supervision.springaidemo.util.WordReadUtil; -import lombok.extern.slf4j.Slf4j; -import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.SystemMessage; -import org.springframework.ai.chat.messages.UserMessage; -import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.ollama.OllamaChatClient; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.web.bind.annotation.GetMapping; -import org.springframework.web.bind.annotation.RequestMapping; -import org.springframework.web.bind.annotation.RestController; - -import java.io.File; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -/** - * 入罪指标controller - */ -@RestController -@Slf4j -@RequestMapping("rzChat") -public class RzChatController { - - private static final String step1Template = """ - 我们现在需要以step by step的方式进行笔录的指标分析工作,得到最终的结果并返回。 - - step1:分析笔录的内容; - 以下是笔录的内容,笔录中"问"是办案警官问,"答"是{noteUserName}回答: - --- - {context} - --- - - step2:现在给你指标以及指标的释义或例子: - 指标释义或例子及判断标准: - 如满足:{metricDetailTemplate},则为true; - 如果不满足::{metricDetailTemplate},则为false; - 如果笔录中,没有任何笔录内容涉及到该项指标,无法进行判断,则为empty。 - - step3:现在需要你根据上面提供的所有信息,尽可能实事求是完成判断: - 1.判断结论:true/false/empty - 2.得到结论的笔录原话:从笔录的对话中,得到该结论的原文(一定是摘抄的原文且为中文)。如果结论为true,则必须要有原文佐证! - 3.得到结论的原因:分析得出该结论的原因,需明确说明为什么得到该结论,需要实事求是且为中文回复。如果结论为true/false,则必须有原因! - - step4:必须以json格式回复, JSON的value内容我给你的提示,在实际输出的时候不需要带上: - --- - {"result":"结论", "originalContext":"笔录对应原话","reason":"原因"} - --- - 好了,现在可以回复了! - """; - - private final OllamaChatClient chatClient; - - @Autowired - private ModelMetricService modelMetricService; - - @Autowired - private NoteCheckRecordService noteCheckRecordService; - - @Autowired - public RzChatController(OllamaChatClient chatClient) { - this.chatClient = chatClient; - } - - @GetMapping("extract") - public void extract() throws InterruptedException { - List files = FileUtil.loopFiles("/Users/flevance/Desktop/宁夏审讯大模型/裴金禄/行为人和受害人/"); - for (File file : files) { - String context = WordReadUtil.readWord(file.getPath()); - // 只查入罪指标 - List list = modelMetricService.lambdaQuery().likeRight(ModelMetric::getMetricCode, "RZ").list(); - for (ModelMetric modelMetric : list) { - String systemPrompt = """ - 你是一个善于分析办案笔录的模型,能够根据办案笔录的回答内容,实事求是的判断给定指标是否满足。注意,仅根据笔录进行分析,不要做笔录之外的推断。笔录内容可能比较长,可能分多次提交给你。 - """; - List messages = new ArrayList<>(List.of(new SystemMessage(systemPrompt))); - Map param = new HashMap<>(); - param.put("metricDetailTemplate", StrUtil.format(modelMetric.getMetricDetailTemplate(), MapUtil.of("action", "裴金禄"))); - param.put("noteUserName", "裴金禄"); - param.put("context", context); - String format = StrUtil.format(step1Template, param); - List userMessageList = new ArrayList<>(); - if (format.length() > 8000) { - log.info("分段提交"); - for (String s : StrUtil.split(format, 6000)) { - userMessageList.add(new UserMessage(s)); - userMessageList.add(new AssistantMessage("继续")); - } - userMessageList.remove(userMessageList.size() - 1); - } else { - userMessageList.add(new UserMessage(format)); - } - messages.addAll(userMessageList); - - RunCheckThread runCheck = new RunCheckThread("裴金禄尝试通过直接定义模板", chatClient, noteCheckRecordService, new Prompt(messages), FileUtil.getName(file), format, systemPrompt, modelMetric, 0); - RunCheckThreadPool.chatExecutor.submit(runCheck); - - - } - } - - - } - -} diff --git a/src/main/java/com/supervision/springaidemo/thread/RunCheckThread.java b/src/main/java/com/supervision/springaidemo/thread/RunCheckThread.java deleted file mode 100644 index a59f800..0000000 --- a/src/main/java/com/supervision/springaidemo/thread/RunCheckThread.java +++ /dev/null @@ -1,91 +0,0 @@ -package com.supervision.springaidemo.thread; - -import cn.hutool.core.util.StrUtil; -import cn.hutool.json.JSONUtil; -import com.supervision.springaidemo.domain.ModelMetric; -import com.supervision.springaidemo.domain.NoteCheckRecord; -import com.supervision.springaidemo.dto.MetricResultDTO; -import com.supervision.police.service.NoteCheckRecordService; -import lombok.extern.slf4j.Slf4j; -import org.springframework.ai.chat.ChatResponse; -import org.springframework.ai.chat.Generation; -import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.ollama.OllamaChatClient; -import org.springframework.util.StopWatch; - -@Slf4j -public class RunCheckThread implements Runnable { - - private final String caseName; - - private final OllamaChatClient chatClient; - - private final NoteCheckRecordService noteCheckRecordService; - - private final Prompt prompt; - - private final String fileName; - - private final String format; - - private final String systemPrompt; - - private final ModelMetric modelMetric; - - - private Integer count; - - public RunCheckThread(String caseName, OllamaChatClient chatClient, NoteCheckRecordService noteCheckRecordService, Prompt prompt, String fileName, String format, String systemPrompt, ModelMetric modelMetric, Integer count) { - this.caseName = caseName; - this.chatClient = chatClient; - this.noteCheckRecordService = noteCheckRecordService; - this.prompt = prompt; - this.fileName = fileName; - this.format = format; - this.systemPrompt = systemPrompt; - this.modelMetric = modelMetric; - this.count = count; - } - - @Override - public void run() { - try { - StopWatch stopWatch = new StopWatch(); - stopWatch.start(); - log.info("开始分析:{}",fileName); - ChatResponse call = chatClient.call(prompt); - stopWatch.stop(); - log.info("耗时:{}", stopWatch.getTotalTimeSeconds()); - Generation result = call.getResult(); - - String content = result.getOutput().getContent(); - log.info("分析的结果是:{}", content); - MetricResultDTO metricResultDTO = JSONUtil.toBean(content, MetricResultDTO.class); - // 如果为空,则再跑一次,最多跑5次 - if (StrUtil.isBlank(metricResultDTO.getResult())) { - if (count > 5) { - log.info("{}的{}结果超过5次,不再继续跑了", fileName, modelMetric); - } else { - log.info("{}的{}结果为空,当前跑了{}次,重新提交,再跑一次", fileName, modelMetric, count); - Integer newCount = count++; - RunCheckThread runCheck = new RunCheckThread(caseName, chatClient, noteCheckRecordService, prompt, fileName, format, systemPrompt, modelMetric, newCount); - RunCheckThreadPool.chatExecutor.submit(runCheck); - } - } else { - NoteCheckRecord noteCheckRecord = new NoteCheckRecord(); - noteCheckRecord.setCaseName(caseName); - noteCheckRecord.setNoteName(fileName); - noteCheckRecord.setMetricCode(modelMetric.getMetricCode()); - noteCheckRecord.setMetricName(modelMetric.getMetricName()); - noteCheckRecord.setSystemPrompt(systemPrompt); - noteCheckRecord.setPrompt(format); - noteCheckRecord.setResult(metricResultDTO.getResult()); - noteCheckRecord.setOriginalContext(metricResultDTO.getOriginalContext()); - noteCheckRecord.setReason(metricResultDTO.getReason()); - noteCheckRecordService.save(noteCheckRecord); - } - } catch (Exception e) { - log.error("出现错误", e); - } - } -} diff --git a/src/main/java/com/supervision/thread/TripleExtractThread.java b/src/main/java/com/supervision/thread/TripleExtractThread.java new file mode 100644 index 0000000..7b1c038 --- /dev/null +++ b/src/main/java/com/supervision/thread/TripleExtractThread.java @@ -0,0 +1,88 @@ +package com.supervision.thread; + +import cn.hutool.core.util.StrUtil; +import cn.hutool.json.JSONUtil; +import com.supervision.police.domain.TripleInfo; +import com.supervision.springaidemo.domain.ModelMetric; +import com.supervision.springaidemo.domain.NoteCheckRecord; +import com.supervision.springaidemo.dto.MetricResultDTO; +import com.supervision.police.service.NoteCheckRecordService; +import lombok.extern.slf4j.Slf4j; +import org.json.JSONArray; +import org.json.JSONObject; +import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.Generation; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.ollama.OllamaChatClient; +import org.springframework.util.StopWatch; + +import java.time.LocalDateTime; +import java.util.concurrent.Callable; + +@Slf4j +public class TripleExtractThread implements Callable { + + private final OllamaChatClient chatClient; + + private final String prompt; + + private final String question; + + private final String answer; + + private final String recordSplitId; + + private final String caseId; + + private final String recordId; + + + public TripleExtractThread(OllamaChatClient chatClient, String caseId, String recordId, String recordSplitId, String prompt, String question, String answer) { + this.question = question; + this.chatClient = chatClient; + this.answer = answer; + this.prompt = prompt; + this.recordSplitId = recordSplitId; + this.caseId = caseId; + this.recordId = recordId; + } + + @Override + public TripleInfo call() { + try { + StopWatch stopWatch = new StopWatch(); + // 分析三元组 + Prompt ask = new Prompt(new UserMessage(prompt + question + answer)); + stopWatch.start(); + log.info("开始分析:"); + ChatResponse call = chatClient.call(ask); + stopWatch.stop(); + log.info("耗时:{}", stopWatch.getTotalTimeSeconds()); + String content = call.getResult().getOutput().getContent(); + log.info("分析的结果是:{}", content); + // 获取从提示词中提取到的三元组信息 + JSONObject jsonObject = new JSONObject(content); + JSONArray threeInfo = jsonObject.getJSONArray("result"); + for (int i = 0; i < threeInfo.length(); i++) { + JSONObject object = threeInfo.getJSONObject(i); + String startNodeType = object.getString("startNodeType"); + String entity = object.getString("entity"); + String endNodeType = object.getString("endNodeType"); + String property = object.getString("property"); + String value = object.getString("value"); + // 去空,如果存在任何的空值,则忽略 + if (StrUtil.hasEmpty(startNodeType, entity, endNodeType, property, value)) { + continue; + } + // 构建三元组信息 + return new TripleInfo(entity, property, value, caseId, recordId, recordSplitId, LocalDateTime.now(), startNodeType, endNodeType); + } + } catch (Exception e) { + log.error("提取三元组出现错误", e); + } + return null; + } + + +} diff --git a/src/main/java/com/supervision/springaidemo/thread/RunCheckThreadPool.java b/src/main/java/com/supervision/thread/TripleExtractThreadPool.java similarity index 53% rename from src/main/java/com/supervision/springaidemo/thread/RunCheckThreadPool.java rename to src/main/java/com/supervision/thread/TripleExtractThreadPool.java index e7231b9..b75ae5f 100644 --- a/src/main/java/com/supervision/springaidemo/thread/RunCheckThreadPool.java +++ b/src/main/java/com/supervision/thread/TripleExtractThreadPool.java @@ -1,10 +1,10 @@ -package com.supervision.springaidemo.thread; +package com.supervision.thread; import cn.hutool.core.thread.ThreadUtil; import java.util.concurrent.ExecutorService; -public class RunCheckThreadPool { +public class TripleExtractThreadPool { - public static final ExecutorService chatExecutor = ThreadUtil.newFixedExecutor(20, Integer.MAX_VALUE, "chat", false); + public static final ExecutorService chatExecutor = ThreadUtil.newFixedExecutor(20, Integer.MAX_VALUE, "tripleExtract", false); } diff --git a/src/main/resources/mapper/NotePromptMapper.xml b/src/main/resources/mapper/NotePromptMapper.xml index 0d722f1..493ecbb 100644 --- a/src/main/resources/mapper/NotePromptMapper.xml +++ b/src/main/resources/mapper/NotePromptMapper.xml @@ -3,8 +3,4 @@ PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" "http://mybatis.org/dtd/mybatis-3-mapper.dtd"> - \ No newline at end of file diff --git a/src/main/resources/mapper/TripleInfoMapper.xml b/src/main/resources/mapper/TripleInfoMapper.xml index f190f91..e38dc44 100644 --- a/src/main/resources/mapper/TripleInfoMapper.xml +++ b/src/main/resources/mapper/TripleInfoMapper.xml @@ -3,17 +3,5 @@ PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" "http://mybatis.org/dtd/mybatis-3-mapper.dtd"> - + \ No newline at end of file