提交代码优化

topo_dev
liu 9 months ago
parent f0ff638a94
commit 18d8c0f16c

@ -3,7 +3,9 @@ package com.supervision;
import org.mybatis.spring.annotation.MapperScan; import org.mybatis.spring.annotation.MapperScan;
import org.springframework.boot.SpringApplication; import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.scheduling.annotation.EnableAsync;
@EnableAsync
@MapperScan(basePackages = {"com.supervision.**.mapper"}) @MapperScan(basePackages = {"com.supervision.**.mapper"})
@SpringBootApplication(scanBasePackages = {"com.supervision.**"}) @SpringBootApplication(scanBasePackages = {"com.supervision.**"})
public class SpringAiDemoApplication { public class SpringAiDemoApplication {

@ -89,10 +89,10 @@ public class Neo4jController {
return neo4jService.getNode(picType, caseId); return neo4jService.getNode(picType, caseId);
} }
@GetMapping("/test") // @GetMapping("/test")
public R<?> test() { // public R<?> test() {
return neo4jService.test(); // return neo4jService.test();
} // }
@ApiOperation("构建抽象图谱") @ApiOperation("构建抽象图谱")
@GetMapping("createAbstractGraph") @GetMapping("createAbstractGraph")
@ -106,4 +106,10 @@ public class Neo4jController {
neo4jService.deleteAbstractGraph(); neo4jService.deleteAbstractGraph();
} }
@ApiOperation("mock测试数据")
@GetMapping("mockTestGraph")
public void mockTestGraph(String path, String sheetName, String recordId, String recordSplitId, String caseId) {
neo4jService.mockTestGraph(path, sheetName, recordId, recordSplitId, caseId);
}
} }

@ -18,9 +18,9 @@ public class CaseNode {
private String nodeType; private String nodeType;
private String recordId; private String recordSplitId;
private String recordsId; private String recordId;
private String caseId; private String caseId;
@ -37,20 +37,20 @@ public class CaseNode {
this.name = name; this.name = name;
} }
public CaseNode(String name, String nodeType, String recordId, String recordsId, String caseId, String picType) { public CaseNode(String name, String nodeType, String recordSplitId, String recordId, String caseId, String picType) {
this.name = name; this.name = name;
this.nodeType = nodeType; this.nodeType = nodeType;
this.recordSplitId = recordSplitId;
this.recordId = recordId; this.recordId = recordId;
this.recordsId = recordsId;
this.caseId = caseId; this.caseId = caseId;
this.picType = picType; this.picType = picType;
} }
public CaseNode(Long id, String name, String nodeType, String recordId, String caseId, String picType) { public CaseNode(Long id, String name, String nodeType, String recordSplitId, String caseId, String picType) {
this.id = id; this.id = id;
this.name = name; this.name = name;
this.nodeType = nodeType; this.nodeType = nodeType;
this.recordId = recordId; this.recordSplitId = recordSplitId;
this.caseId = caseId; this.caseId = caseId;
this.picType = picType; this.picType = picType;
} }

@ -0,0 +1,19 @@
package com.supervision.neo4j.dto;
import lombok.Data;
@Data
public class WebRelDTO {
private long source;
private long target;
private String name;
public WebRelDTO(long source, long target, String name) {
this.source = source;
this.target = target;
this.name = name;
}
}

@ -21,6 +21,7 @@ public interface Neo4jService {
CaseNode findById(Long id); CaseNode findById(Long id);
List<CaseNode> findByName(String caseId, String recordId, String nodeType, String name, String picType); List<CaseNode> findByName(String caseId, String recordId, String nodeType, String name, String picType);
CaseNode findOneByName(String caseId, String recordId, String nodeType, String name, String picType); CaseNode findOneByName(String caseId, String recordId, String nodeType, String name, String picType);
Rel findRelation(Rel rel); Rel findRelation(Rel rel);
@ -29,9 +30,11 @@ public interface Neo4jService {
R<?> getNode(String picType, String caseId); R<?> getNode(String picType, String caseId);
R<?> test(); // R<?> test();
void deleteAbstractGraph(); void deleteAbstractGraph();
void createAbstractGraph(String path,String sheetName); void createAbstractGraph(String path, String sheetName);
void mockTestGraph(String path, String sheetName, String recordId, String recordSplitId,String caseId);
} }

@ -7,6 +7,7 @@ import com.supervision.common.domain.R;
import com.supervision.common.utils.StringUtils; import com.supervision.common.utils.StringUtils;
import com.supervision.neo4j.domain.CaseNode; import com.supervision.neo4j.domain.CaseNode;
import com.supervision.neo4j.domain.Rel; import com.supervision.neo4j.domain.Rel;
import com.supervision.neo4j.dto.WebRelDTO;
import com.supervision.neo4j.service.Neo4jService; import com.supervision.neo4j.service.Neo4jService;
import com.supervision.neo4j.utils.Neo4jUtils; import com.supervision.neo4j.utils.Neo4jUtils;
import lombok.Data; import lombok.Data;
@ -39,7 +40,7 @@ public class Neo4jServiceImpl implements Neo4jService {
if (StringUtils.isEmpty(caseNode.getName()) || StringUtils.isEmpty(caseNode.getNodeType())) { if (StringUtils.isEmpty(caseNode.getName()) || StringUtils.isEmpty(caseNode.getNodeType())) {
throw new RuntimeException("未传节点名称或节点类型或图谱类型!"); throw new RuntimeException("未传节点名称或节点类型或图谱类型!");
} }
List<CaseNode> byName = findByName(caseNode.getCaseId(), caseNode.getRecordsId(), caseNode.getNodeType(), caseNode.getName(), caseNode.getPicType()); List<CaseNode> byName = findByName(caseNode.getCaseId(), caseNode.getRecordId(), caseNode.getNodeType(), caseNode.getName(), caseNode.getPicType());
if (byName != null && !byName.isEmpty()) { if (byName != null && !byName.isEmpty()) {
throw new RuntimeException("名称已存在!"); throw new RuntimeException("名称已存在!");
} }
@ -50,14 +51,14 @@ public class Neo4jServiceImpl implements Neo4jService {
Map<String, Object> params = new HashMap<>(); Map<String, Object> params = new HashMap<>();
cql.append("CREATE (n:").append(caseNode.getNodeType()).append("{name:$name"); cql.append("CREATE (n:").append(caseNode.getNodeType()).append("{name:$name");
params.put("name", caseNode.getName()); params.put("name", caseNode.getName());
if (StringUtils.isNotEmpty(caseNode.getRecordId())) {
cql.append(", recordSplitId:$recordSplitId");
params.put("recordSplitId", caseNode.getRecordSplitId());
}
if (StringUtils.isNotEmpty(caseNode.getRecordId())) { if (StringUtils.isNotEmpty(caseNode.getRecordId())) {
cql.append(", recordId:$recordId"); cql.append(", recordId:$recordId");
params.put("recordId", caseNode.getRecordId()); params.put("recordId", caseNode.getRecordId());
} }
if (StringUtils.isNotEmpty(caseNode.getRecordsId())) {
cql.append(", recordsId:$recordsId");
params.put("recordsId", caseNode.getRecordsId());
}
if (StringUtils.isNotEmpty(caseNode.getCaseId())) { if (StringUtils.isNotEmpty(caseNode.getCaseId())) {
cql.append(", caseId:$caseId"); cql.append(", caseId:$caseId");
params.put("caseId", caseNode.getCaseId()); params.put("caseId", caseNode.getCaseId());
@ -134,7 +135,7 @@ public class Neo4jServiceImpl implements Neo4jService {
} }
@Override @Override
public List<CaseNode> findByName(String caseId, String recordsId, String nodeType, String name, String picType) { public List<CaseNode> findByName(String caseId, String recordId, String nodeType, String name, String picType) {
List<CaseNode> list = new ArrayList<>(); List<CaseNode> list = new ArrayList<>();
try { try {
Session session = driver.session(); Session session = driver.session();
@ -149,9 +150,9 @@ public class Neo4jServiceImpl implements Neo4jService {
cql.append(" and n.caseId = "); cql.append(" and n.caseId = ");
cql.append(caseId); cql.append(caseId);
} }
if (StringUtils.isNotEmpty(recordsId)) { if (StringUtils.isNotEmpty(recordId)) {
cql.append(" and n.recordsId = "); cql.append(" and n.recordId = ");
cql.append(recordsId); cql.append(recordId);
} }
if (StringUtils.isNotEmpty(name)) { if (StringUtils.isNotEmpty(name)) {
cql.append(" and n.name = '"); cql.append(" and n.name = '");
@ -172,7 +173,7 @@ public class Neo4jServiceImpl implements Neo4jService {
} }
@Override @Override
public CaseNode findOneByName(String caseId, String recordsId, String nodeType, String name, String picType) { public CaseNode findOneByName(String caseId, String recordId, String nodeType, String name, String picType) {
CaseNode node = null; CaseNode node = null;
try { try {
Session session = driver.session(); Session session = driver.session();
@ -188,9 +189,9 @@ public class Neo4jServiceImpl implements Neo4jService {
cql.append(" and n.caseId = $caseId"); cql.append(" and n.caseId = $caseId");
params.put("caseId", caseId); params.put("caseId", caseId);
} }
if (StringUtils.isNotEmpty(recordsId)) { if (StringUtils.isNotEmpty(recordId)) {
cql.append(" and n.recordsId = $recordsId"); cql.append(" and n.recordId = $recordId");
params.put("recordsId", recordsId); params.put("recordId", recordId);
} }
if (StringUtils.isNotEmpty(name)) { if (StringUtils.isNotEmpty(name)) {
cql.append(" and n.name = $name"); cql.append(" and n.name = $name");
@ -255,8 +256,8 @@ public class Neo4jServiceImpl implements Neo4jService {
@Override @Override
public R<?> getNode(String picType, String caseId) { public R<?> getNode(String picType, String caseId) {
Map<String, Object> map = new HashMap<>(); Map<String, Object> map = new HashMap<>();
List<Rel> list = new ArrayList<>(); List<WebRelDTO> list = new ArrayList<>();
List<Map<String, String>> nodes = new ArrayList<>(); List<Map<String, Object>> nodes = new ArrayList<>();
try { try {
Session session = driver.session(); Session session = driver.session();
Map<String, Object> params = new HashMap<>(); Map<String, Object> params = new HashMap<>();
@ -266,23 +267,23 @@ public class Neo4jServiceImpl implements Neo4jService {
" RETURN id(rel) as id, n.name as source, id(n) as sourceId, type(rel) as name, r.name as target, id(r) as targetId", params); " RETURN id(rel) as id, n.name as source, id(n) as sourceId, type(rel) as name, r.name as target, id(r) as targetId", params);
while (run.hasNext()) { while (run.hasNext()) {
Record record = run.next(); Record record = run.next();
long id = record.get("id").asLong(); //long id = record.get("id").asLong();
String source = record.get("source").asString(); //String source = record.get("source").asString();
long sourceId = record.get("sourceId").asLong(); long sourceId = record.get("sourceId").asLong();
String name = record.get("name").asString(); String name = record.get("name").asString();
String target = record.get("target").asString(); //String target = record.get("target").asString();
long targetId = record.get("targetId").asLong(); long targetId = record.get("targetId").asLong();
list.add(new Rel(id, source, sourceId, name, target, targetId)); list.add(new WebRelDTO(sourceId, targetId, name));
} }
Result node = session.run("MATCH (n) where n.picType = $picType and n.caseId = $caseId RETURN id(n) as id, n.name as name", params); Result node = session.run("MATCH (n) where n.picType = $picType and n.caseId = $caseId RETURN id(n) as id, n.name as name", params);
while (node.hasNext()) { while (node.hasNext()) {
Record record = node.next(); Record record = node.next();
String name = record.get("name").asString(); String name = record.get("name").asString();
long idlong = record.get("id").asLong(); long idLong = record.get("id").asLong();
Map<String, String> nodeMap = new HashMap<>(); Map<String, Object> nodeMap = new HashMap<>();
nodeMap.put("name", name); nodeMap.put("name", name);
nodeMap.put("entityName", name); nodeMap.put("entityName", name);
// nodeMap.put("id", idlong + ""); nodeMap.put("id", idLong);
nodes.add(nodeMap); nodes.add(nodeMap);
} }
} catch (Exception e) { } catch (Exception e) {
@ -293,28 +294,29 @@ public class Neo4jServiceImpl implements Neo4jService {
return R.ok(map); return R.ok(map);
} }
@Override
public R<?> test() {
Session session = driver.session();
Map<String, Object> params = new HashMap<>();
params.put("lawActor", "行为人");
params.put("lawParty", "aaaaaa");
Result run = session.run("MATCH (m:LawActor), (n:FictionalOrgan) where m.name=$lawActor OPTIONAL MATCH (m)-[r:`冒充`]->(n) RETURN id(m) as startId, id(n) as endId, id(r) as relId, m.recordId as recordId, m.recordsId as recordsId", params);
while (run.hasNext()) {
Record record = run.next();
String id = Neo4jUtils.valueTransportString(record.get("startId")); // @Override
String endId = Neo4jUtils.valueTransportString(record.get("endId")); // public R<?> test() {
String relId = Neo4jUtils.valueTransportString(record.get("relId")); // Session session = driver.session();
System.out.println("************" + id); // Map<String, Object> params = new HashMap<>();
System.out.println("************" + endId); // params.put("lawActor", "行为人");
System.out.println("************" + relId); // params.put("lawParty", "aaaaaa");
} // Result run = session.run("MATCH (m:LawActor), (n:FictionalOrgan) where m.name=$lawActor OPTIONAL MATCH (m)-[r:`冒充`]->(n) RETURN id(m) as startId, id(n) as endId, id(r) as relId, m.recordId as recordId, m.recordsId as recordsId", params);
return R.ok("222"); // while (run.hasNext()) {
} // Record record = run.next();
//
// String id = Neo4jUtils.valueTransportString(record.get("startId"));
// String endId = Neo4jUtils.valueTransportString(record.get("endId"));
// String relId = Neo4jUtils.valueTransportString(record.get("relId"));
// System.out.println("************" + id);
// System.out.println("************" + endId);
// System.out.println("************" + relId);
// }
// return R.ok("222");
// }
@Override @Override
public void createAbstractGraph(String path,String sheetName) { public void createAbstractGraph(String path, String sheetName) {
// 首先从数据库中读到数据 // 首先从数据库中读到数据
ExcelReader reader = ExcelUtil.getReader(path, sheetName); ExcelReader reader = ExcelUtil.getReader(path, sheetName);
List<AbstractGraphExcelHeader> abstractGraphExcelHeaders = reader.readAll(AbstractGraphExcelHeader.class); List<AbstractGraphExcelHeader> abstractGraphExcelHeaders = reader.readAll(AbstractGraphExcelHeader.class);
@ -382,4 +384,45 @@ public class Neo4jServiceImpl implements Neo4jService {
private String relation; private String relation;
private String to; private String to;
} }
@Data
private static class MockDataGraphExcelHeader {
private String fromType;
private String from;
private String relation;
private String to;
private String toType;
}
@Override
public void mockTestGraph(String path, String sheetName, String recordId, String recordSplitId, String caseId) {
// 首先从数据库中读到数据
ExcelReader reader = ExcelUtil.getReader(path, sheetName);
List<MockDataGraphExcelHeader> mockDataGraphExcelList = reader.readAll(MockDataGraphExcelHeader.class);
Map<String, CaseNode> nodeMap = new HashMap<>();
Map<String, Rel> relMap = new HashMap<>();
for (MockDataGraphExcelHeader mockData : mockDataGraphExcelList) {
// from
if (!nodeMap.containsKey(mockData.getFrom())) {
CaseNode caseNode = new CaseNode(mockData.getFrom(), mockData.getFromType(), recordSplitId, recordId, caseId, "1");
log.info("点:{}插入成功", mockData.getFrom());
CaseNode save = save(caseNode);
nodeMap.put(mockData.getFrom(), save);
}
// to
if (!nodeMap.containsKey(mockData.getTo())) {
CaseNode caseNode = new CaseNode(mockData.getTo(), mockData.getToType(), recordSplitId, recordId, caseId, "1");
CaseNode save = save(caseNode);
log.info("点:{}插入成功", mockData.getTo());
nodeMap.put(mockData.getTo(), save);
}
// relation
if (!relMap.containsKey(mockData.getFrom() + "->" + mockData.getRelation() + "->" + mockData.getTo())) {
Rel rel = new Rel(nodeMap.get(mockData.getFrom()).getId(), mockData.getRelation(), nodeMap.get(mockData.getTo()).getId(), "1");
saveRelation(rel);
log.info("关系:{}插入成功", (mockData.getFrom() + "->" + mockData.getRelation() + "->" + mockData.getTo()));
relMap.put(mockData.getFrom() + "->" + mockData.getRelation() + "->" + mockData.getTo(), rel);
}
}
}
} }

@ -7,7 +7,7 @@ import com.supervision.police.domain.NoteRecord;
import com.supervision.police.domain.TripleInfo; import com.supervision.police.domain.TripleInfo;
import com.supervision.police.dto.ListDTO; import com.supervision.police.dto.ListDTO;
import com.supervision.police.service.ModelRecordTypeService; import com.supervision.police.service.ModelRecordTypeService;
import com.supervision.police.service.RecordService; import com.supervision.police.service.NoteRecordSplitService;
import io.swagger.annotations.ApiOperation; import io.swagger.annotations.ApiOperation;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@ -27,7 +27,7 @@ public class RecordController {
public final ModelRecordTypeService modelRecordTypeService; public final ModelRecordTypeService modelRecordTypeService;
public final RecordService recordService; public final NoteRecordSplitService noteRecordSplitService;
/** /**
* *
@ -107,7 +107,7 @@ public class RecordController {
@PostMapping("/addOrUpdRecords") @PostMapping("/addOrUpdRecords")
public R<String> uploadRecords(NoteRecord records, public R<String> uploadRecords(NoteRecord records,
@RequestPart("file") List<MultipartFile> fileList) throws IOException { @RequestPart("file") List<MultipartFile> fileList) throws IOException {
return R.ok(recordService.uploadRecords(records, fileList)); return R.ok(noteRecordSplitService.uploadRecords(records, fileList));
} }
/** /**
@ -122,7 +122,7 @@ public class RecordController {
public R<Map<String, Object>> queryRecords(@RequestBody NoteRecord noteRecord, public R<Map<String, Object>> queryRecords(@RequestBody NoteRecord noteRecord,
@RequestParam(required = false, defaultValue = "1") Integer page, @RequestParam(required = false, defaultValue = "1") Integer page,
@RequestParam(required = false, defaultValue = "20") Integer size) { @RequestParam(required = false, defaultValue = "20") Integer size) {
return R.ok(recordService.queryRecords(noteRecord, page, size)); return R.ok(noteRecordSplitService.queryRecords(noteRecord, page, size));
} }
/** /**
@ -133,7 +133,7 @@ public class RecordController {
*/ */
@PostMapping("/delRecords") @PostMapping("/delRecords")
public R<?> delRecords(@RequestParam String id) { public R<?> delRecords(@RequestParam String id) {
recordService.delRecords(id); noteRecordSplitService.delRecords(id);
return R.ok(); return R.ok();
} }

@ -36,19 +36,14 @@ public class TripleInfo implements Serializable {
*/ */
private String relation; private String relation;
/** private String caseId;
* id
*/
private String noteRecordId;
@TableField(exist = false) private String recordId;
private String noteRecordsId;
/** /**
* id * id
*/ */
@TableField(exist = false) private String recordSplitId;
private String caseId;
/** /**
* *
@ -74,7 +69,7 @@ public class TripleInfo implements Serializable {
* *
*/ */
@TableField(fill = FieldFill.INSERT_UPDATE) @TableField(fill = FieldFill.INSERT_UPDATE)
@JsonFormat(pattern="yyyy-MM-dd HH:mm:ss",timezone = "GMT+8") @JsonFormat(pattern = "yyyy-MM-dd HH:mm:ss", timezone = "GMT+8")
private LocalDateTime createTime; private LocalDateTime createTime;
/** /**
@ -87,7 +82,7 @@ public class TripleInfo implements Serializable {
* *
*/ */
@TableField(fill = FieldFill.INSERT_UPDATE) @TableField(fill = FieldFill.INSERT_UPDATE)
@JsonFormat(pattern="yyyy-MM-dd HH:mm:ss",timezone = "GMT+8") @JsonFormat(pattern = "yyyy-MM-dd HH:mm:ss", timezone = "GMT+8")
private LocalDateTime updateTime; private LocalDateTime updateTime;
@TableField(exist = false) @TableField(exist = false)
@ -97,11 +92,13 @@ public class TripleInfo implements Serializable {
} }
// todo // todo
public TripleInfo(String startNode, String endNode, String relation, String noteRecordId, LocalDateTime createTime, String startNodeType, String endNodeType) { public TripleInfo(String startNode, String endNode, String relation,String caseId, String recordId, String recordSplitId, LocalDateTime createTime, String startNodeType, String endNodeType) {
this.startNode = startNode; this.startNode = startNode;
this.endNode = endNode; this.endNode = endNode;
this.relation = relation; this.relation = relation;
this.noteRecordId = noteRecordId; this.caseId = caseId;
this.recordId = recordId;
this.recordSplitId = recordSplitId;
this.createTime = createTime; this.createTime = createTime;
this.startNodeType = startNodeType; this.startNodeType = startNodeType;
this.endNodeType = endNodeType; this.endNodeType = endNodeType;

@ -8,6 +8,5 @@ import java.util.List;
public interface NotePromptMapper extends BaseMapper<NotePrompt> { public interface NotePromptMapper extends BaseMapper<NotePrompt> {
List<NotePrompt> queryPrompt(@Param("typeId") String typeId);
} }

@ -8,6 +8,5 @@ import java.util.List;
public interface TripleInfoMapper extends BaseMapper<TripleInfo> { public interface TripleInfoMapper extends BaseMapper<TripleInfo> {
List<TripleInfo> selectByIds(@Param("ids") List<String> ids);
} }

@ -0,0 +1,6 @@
package com.supervision.police.service;
public interface ExtractTripleInfoService {
void extractTripleInfo(String caseId, String name, String recordId);
}

@ -1,7 +1,11 @@
package com.supervision.police.service; package com.supervision.police.service;
import com.supervision.police.domain.NoteRecordSplit;
import com.supervision.springaidemo.domain.NoteCheckRecord; import com.supervision.springaidemo.domain.NoteCheckRecord;
import com.baomidou.mybatisplus.extension.service.IService; import com.baomidou.mybatisplus.extension.service.IService;
import org.apache.ibatis.annotations.Param;
import java.util.List;
/** /**
* @author flevance * @author flevance
@ -10,4 +14,6 @@ import com.baomidou.mybatisplus.extension.service.IService;
*/ */
public interface NoteCheckRecordService extends IService<NoteCheckRecord> { public interface NoteCheckRecordService extends IService<NoteCheckRecord> {
} }

@ -0,0 +1,7 @@
package com.supervision.police.service;
import com.baomidou.mybatisplus.extension.service.IService;
import com.supervision.police.domain.NotePrompt;
public interface NotePromptService extends IService<NotePrompt> {
}

@ -9,7 +9,7 @@ import java.io.IOException;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
public interface RecordService extends IService<NoteRecordSplit> { public interface NoteRecordSplitService extends IService<NoteRecordSplit> {
String uploadRecords(NoteRecord records, List<MultipartFile> fileList) throws IOException; String uploadRecords(NoteRecord records, List<MultipartFile> fileList) throws IOException;

@ -0,0 +1,7 @@
package com.supervision.police.service;
import com.baomidou.mybatisplus.extension.service.IService;
import com.supervision.police.domain.TripleInfo;
public interface TripleInfoService extends IService<TripleInfo> {
}

@ -0,0 +1,100 @@
package com.supervision.police.service.impl;
import cn.hutool.core.util.StrUtil;
import com.alibaba.druid.util.StringUtils;
import com.supervision.police.domain.NotePrompt;
import com.supervision.police.domain.NoteRecordSplit;
import com.supervision.police.domain.TripleInfo;
import com.supervision.police.mapper.NotePromptMapper;
import com.supervision.police.mapper.NoteRecordSplitMapper;
import com.supervision.police.mapper.TripleInfoMapper;
import com.supervision.police.service.*;
import com.supervision.thread.TripleExtractThread;
import com.supervision.thread.TripleExtractThreadPool;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.json.JSONArray;
import org.json.JSONObject;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
import org.springframework.util.StopWatch;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Future;
@Slf4j
@Service
@RequiredArgsConstructor
public class ExtractTripleInfoServiceImpl implements ExtractTripleInfoService {
private final NoteRecordSplitMapper noteRecordSplitMapper;
private final NotePromptService notePromptService;
private final TripleInfoService tripleInfoService;
private final OllamaChatClient chatClient;
@Async
public void extractTripleInfo(String caseId, String name, String recordId) {
// 首先获取所有切分后的笔录
List<NoteRecordSplit> recordSplitList = noteRecordSplitMapper.selectRecord(caseId, name, recordId);
List<TripleInfo> tripleInfos = new ArrayList<>();
List<Future<TripleInfo>> futures = new ArrayList<>();
// 对切分后的笔录进行遍历
for (NoteRecordSplit recordSplit : recordSplitList) {
// 根据笔录类型找到所有的提取三元组的提示词
List<NotePrompt> prompts = notePromptService.lambdaQuery().eq(NotePrompt::getTypeId, recordSplit.getRecordTypeId()).list();
// 遍历提示词进行提取
for (NotePrompt prompt : prompts) {
if (StringUtils.isEmpty(prompt.getPrompt())) {
continue;
}
try {
log.info("提交任务到线程池中进行三元组提取");
Future<TripleInfo> submit = TripleExtractThreadPool.chatExecutor.submit(new TripleExtractThread(chatClient, caseId, recordId, recordSplit.getId(), prompt.getPrompt(), recordSplit.getQuestion(), recordSplit.getAnswer()));
futures.add(submit);
} catch (Exception e) {
log.error(e.getMessage(), e);
}
}
}
while (futures.size() > 0) {
Iterator<Future<TripleInfo>> iterator = futures.iterator();
while (iterator.hasNext()) {
Future<TripleInfo> future = iterator.next();
try {
// 如果提取到结果,且不为空,就进行保存
if (future.isDone()) {
TripleInfo tripleInfo = future.get();
if (tripleInfo != null) {
tripleInfos.add(tripleInfo);
}
iterator.remove();
}
} catch (Exception e) {
log.info("从线程中获取任务失败");
iterator.remove();
}
}
try {
log.info("检查一遍,休眠1s后继续检查");
Thread.sleep(1000);
} catch (Exception e) {
log.error(e.getMessage(), e);
}
}
// 首先清除
tripleInfoService.lambdaUpdate().eq(TripleInfo::getRecordId, recordId).remove();
// 首先要把这个笔录已经提取过的三元组记录删除掉,删除掉之后才可以重新提取
tripleInfoService.saveBatch(tripleInfos);
}
}

@ -1,6 +1,5 @@
package com.supervision.police.service.impl; package com.supervision.police.service.impl;
import cn.hutool.core.util.StrUtil;
import com.alibaba.druid.util.StringUtils; import com.alibaba.druid.util.StringUtils;
import com.baomidou.mybatisplus.core.conditions.Wrapper; import com.baomidou.mybatisplus.core.conditions.Wrapper;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
@ -9,19 +8,12 @@ import com.supervision.common.domain.R;
import com.supervision.neo4j.domain.CaseNode; import com.supervision.neo4j.domain.CaseNode;
import com.supervision.neo4j.domain.Rel; import com.supervision.neo4j.domain.Rel;
import com.supervision.neo4j.service.Neo4jService; import com.supervision.neo4j.service.Neo4jService;
import com.supervision.police.domain.ModelRecordType; import com.supervision.police.domain.*;
import com.supervision.police.domain.NoteRecordSplit;
import com.supervision.police.domain.NotePrompt;
import com.supervision.police.domain.TripleInfo;
import com.supervision.police.mapper.ModelRecordTypeMapper; import com.supervision.police.mapper.ModelRecordTypeMapper;
import com.supervision.police.mapper.NoteRecordSplitMapper; import com.supervision.police.mapper.NoteRecordSplitMapper;
import com.supervision.police.mapper.NotePromptMapper; import com.supervision.police.service.*;
import com.supervision.police.mapper.TripleInfoMapper;
import com.supervision.police.service.ModelRecordTypeService;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.json.JSONArray;
import org.json.JSONObject;
import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.Prompt;
@ -30,8 +22,8 @@ import org.springframework.stereotype.Service;
import org.springframework.util.StopWatch; import org.springframework.util.StopWatch;
import java.time.LocalDateTime; import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Optional;
@Slf4j @Slf4j
@Service @Service
@ -42,19 +34,20 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl<ModelRecordTypeMappe
private final NoteRecordSplitMapper noteRecordSplitMapper; private final NoteRecordSplitMapper noteRecordSplitMapper;
private final NotePromptMapper notePromptMapper; private final NotePromptService notePromptService;
private final TripleInfoMapper tripleInfoMapper; private final TripleInfoService tripleInfoService;
private final Neo4jService neo4jService; private final Neo4jService neo4jService;
private final OllamaChatClient chatClient; private final OllamaChatClient chatClient;
private final CaseTaskRecordService caseTaskRecordService;
private final ExtractTripleInfoService extractTripleInfo;
@Override @Override
public List<ModelRecordType> queryType(String name, Integer page, Integer size) { public List<ModelRecordType> queryType(String name, Integer page, Integer size) {
// IPage<ModelRecordType> iPage = new Page<>(page, size);
// iPage = modelRecordTypeMapper.selectByName(iPage, name);
// return R.ok(IPages.buildDataMap(iPage));
List<ModelRecordType> list = modelRecordTypeMapper.selectByName(name); List<ModelRecordType> list = modelRecordTypeMapper.selectByName(name);
for (ModelRecordType modelRecordType : list) { for (ModelRecordType modelRecordType : list) {
@ -62,11 +55,10 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl<ModelRecordTypeMappe
List<NoteRecordSplit> noteRecords = noteRecordSplitMapper.selectByRecordType(modelRecordType.getRecordType()); List<NoteRecordSplit> noteRecords = noteRecordSplitMapper.selectByRecordType(modelRecordType.getRecordType());
modelRecordType.setRecords(noteRecords); modelRecordType.setRecords(noteRecords);
//提示词 //提示词
List<NotePrompt> prompts = notePromptMapper.queryPrompt(modelRecordType.getId()); List<NotePrompt> prompts = notePromptService.lambdaQuery().eq(NotePrompt::getTypeId, modelRecordType.getId()).list();
modelRecordType.setPrompts(prompts); modelRecordType.setPrompts(prompts);
} }
return list; return list;
// return R.ok(IPages.buildDataMap(iPage));
} }
@Override @Override
@ -95,12 +87,13 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl<ModelRecordTypeMappe
@Override @Override
public R<?> addOrUpdPrompt(NotePrompt prompt) { public R<?> addOrUpdPrompt(NotePrompt prompt) {
int i = 0; int i = 0;
boolean save;
if (StringUtils.isEmpty(prompt.getId())) { if (StringUtils.isEmpty(prompt.getId())) {
i = notePromptMapper.insert(prompt); save = notePromptService.save(prompt);
} else { } else {
i = notePromptMapper.updateById(prompt); save = notePromptService.updateById(prompt);
} }
if (i > 0) { if (save) {
return R.ok("保存成功"); return R.ok("保存成功");
} else { } else {
return R.fail("保存失败"); return R.fail("保存失败");
@ -110,8 +103,8 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl<ModelRecordTypeMappe
@Override @Override
public R<?> delPrompt(NotePrompt prompt) { public R<?> delPrompt(NotePrompt prompt) {
String id = prompt.getId(); String id = prompt.getId();
int i = notePromptMapper.deleteById(id); boolean removeById = notePromptService.removeById(id);
if (i > 0) { if (removeById) {
return R.ok("删除成功"); return R.ok("删除成功");
} else { } else {
return R.fail("删除失败"); return R.fail("删除失败");
@ -121,8 +114,52 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl<ModelRecordTypeMappe
@Override @Override
public List<TripleInfo> getThreeInfo(String caseId, String name, String recordId) { public List<TripleInfo> getThreeInfo(String caseId, String name, String recordId) {
// TODO 这里应该改成异步的形式,通过异步的形式来进行提取三元组信息,不能每次点击就跑一遍 boolean taskStatus = taskExtractStatusCheck(caseId, recordId);
return extractTripleInfo(caseId, name, recordId); // 如果校验结果为false,则说明需要进行提取三元组操作
if (!taskStatus) {
extractTripleInfo.extractTripleInfo(caseId, name, recordId);
}
// 这里进行查询
return tripleInfoService.lambdaQuery().eq(TripleInfo::getRecordId, recordId).list();
}
/**
* ,,,,
*/
private boolean taskExtractStatusCheck(String caseId, String recordId) {
// 首先查询是否存在任务,如果不存在,就新建
Optional<CaseTaskRecord> caseTaskRecordOpt = caseTaskRecordService.lambdaQuery().eq(CaseTaskRecord::getCaseId, caseId).eq(CaseTaskRecord::getRecordId, recordId).oneOpt();
if (caseTaskRecordOpt.isEmpty()) {
CaseTaskRecord newCaseTaskRecord = new CaseTaskRecord();
newCaseTaskRecord.setCaseId(caseId);
newCaseTaskRecord.setRecordId(recordId);
newCaseTaskRecord.setStatus(1);
newCaseTaskRecord.setSubmitTime(LocalDateTime.now());
caseTaskRecordService.save(newCaseTaskRecord);
return false;
} else {
// 如果存在,则校验时间是否已经超过1天,如果超过了1天还没有执行完毕,就重新提交这个任务
CaseTaskRecord caseTaskRecord = caseTaskRecordOpt.get();
if (caseTaskRecordOpt.get().getStatus() == 1 && LocalDateTime.now().isAfter(caseTaskRecord.getSubmitTime().plusDays(1))) {
// 如果已经超过1天,则重新提交任务
caseTaskRecord.setStatus(1);
caseTaskRecord.setSubmitTime(LocalDateTime.now());
caseTaskRecordService.updateById(caseTaskRecord);
return false;
} else if (caseTaskRecordOpt.get().getStatus() == 2) {
return true;
} else if (caseTaskRecordOpt.get().getStatus() == 3) {
caseTaskRecord.setStatus(1);
caseTaskRecord.setSubmitTime(LocalDateTime.now());
caseTaskRecordService.updateById(caseTaskRecord);
return false;
} else {
// 如果没有超过1天,则返回正在执行中
throw new RuntimeException("任务正在执行中");
}
}
} }
@Override @Override
@ -152,95 +189,45 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl<ModelRecordTypeMappe
String content = call.getResult().getOutput().getContent(); String content = call.getResult().getOutput().getContent();
log.info("分析的结果是:{}", content); log.info("分析的结果是:{}", content);
String oldPrompt = """ String oldPrompt = """
1. 1.
{result {result
[ [
{startNodeType:'LawActor',entity:'',endNodeType:'FictionalOrgan',property:'',value:'' }, {startNodeType:'LawActor',entity:'',endNodeType:'FictionalOrgan',property:'',value:'' },
{startNodeType:'LawActor',entity:'',endNodeType:'Seal',property:'',value:''}, {startNodeType:'LawActor',entity:'',endNodeType:'Seal',property:'',value:''},
{startNodeType:'LawActor',entity:'',endNodeType:'BusinessLicense',property:'',value:''} {startNodeType:'LawActor',entity:'',endNodeType:'BusinessLicense',property:'',value:''}
] ]
} }
2. 2.
3.{ "result": [] } 3.{ "result": [] }
1便QQ 1便QQ
{"result":[{"startNodeType":"LawActor","entity":"裴金禄","endNodeType":"FictionalOrgan","property":"兰州胜利机械租赁有限公司","value":"冒充"},{"startNodeType":"LawActor","entity":"裴金禄","endNodeType":"Seal","property":"兰州胜利机械租赁有限公司合同专用章","value":"伪造"},{"startNodeType":"LawActor","entity":"裴金禄","endNodeType":"Seal","property":"中铁北京局集团有限公司合同专用章","value":"伪造"},{"startNodeType":"LawActor","entity":"裴金禄","endNodeType":"BusinessLicense","property":"兰州胜利机械租赁有限公司营业执照","value":"伪造"}]} {"result":[{"startNodeType":"LawActor","entity":"裴金禄","endNodeType":"FictionalOrgan","property":"兰州胜利机械租赁有限公司","value":"冒充"},{"startNodeType":"LawActor","entity":"裴金禄","endNodeType":"Seal","property":"兰州胜利机械租赁有限公司合同专用章","value":"伪造"},{"startNodeType":"LawActor","entity":"裴金禄","endNodeType":"Seal","property":"中铁北京局集团有限公司合同专用章","value":"伪造"},{"startNodeType":"LawActor","entity":"裴金禄","endNodeType":"BusinessLicense","property":"兰州胜利机械租赁有限公司营业执照","value":"伪造"}]}
2:{ "result": [] } 2:{ "result": [] }
"""; """;
} }
private List<TripleInfo> extractTripleInfo(String caseId, String name, String recordId) {
// 首先获取所有切分后的笔录
List<NoteRecordSplit> recordSplitList = noteRecordSplitMapper.selectRecord(caseId, name, recordId);
List<TripleInfo> tripleInfos = new ArrayList<>();
// 对切分后的笔录进行遍历
for (NoteRecordSplit record : recordSplitList) {
// 根据笔录类型找到所有的提取三元组的提示词
List<NotePrompt> prompts = notePromptMapper.queryPrompt(record.getRecordTypeId());
// 遍历提示词进行提取
for (NotePrompt prompt : prompts) {
if (StringUtils.isEmpty(prompt.getPrompt())) {
continue;
}
try {
StopWatch stopWatch = new StopWatch();
// 分析三元组
Prompt ask = new Prompt(new UserMessage(prompt.getPrompt() + record.getQuestion() + record.getAnswer()));
stopWatch.start();
log.info("开始分析:");
ChatResponse call = chatClient.call(ask);
stopWatch.stop();
log.info("耗时:{}", stopWatch.getTotalTimeSeconds());
String content = call.getResult().getOutput().getContent();
log.info("分析的结果是:{}", content);
// 获取从提示词中提取到的三元组信息
JSONObject jsonObject = new JSONObject(content);
JSONArray threeInfo = jsonObject.getJSONArray("result");
for (int i = 0; i < threeInfo.length(); i++) {
JSONObject object = threeInfo.getJSONObject(i);
String startNodeType = object.getString("startNodeType");
String entity = object.getString("entity");
String endNodeType = object.getString("endNodeType");
String property = object.getString("property");
String value = object.getString("value");
// 去空,如果存在任何的空值,则忽略
if (StrUtil.hasEmpty(startNodeType, entity, endNodeType, property, value)) {
continue;
}
// 将三元组信息进行保存操作
TripleInfo tripleInfo = new TripleInfo(entity, property, value, record.getId(), LocalDateTime.now(), startNodeType, endNodeType);
tripleInfoMapper.insert(tripleInfo);
tripleInfos.add(tripleInfo);
}
} catch (Exception e) {
log.error(e.getMessage(), e);
}
}
}
return tripleInfos;
}
@Override @Override
public String addNeo4j(List<String> ids) { public String addNeo4j(List<String> ids) {
List<TripleInfo> tripleInfos = tripleInfoMapper.selectByIds(ids); List<TripleInfo> tripleInfos = tripleInfoService.listByIds(ids);
int i = 0; int i = 0;
for (TripleInfo tripleInfo : tripleInfos) { for (TripleInfo tripleInfo : tripleInfos) {
try { try {
//开始节点 //开始节点
String start = tripleInfo.getStartNode(); String start = tripleInfo.getStartNode();
// 首先看是否已经存在了,如果已经存在了,就不添加了 // 首先看是否已经存在了,如果已经存在了,就不添加了
CaseNode startNode = neo4jService.findOneByName(tripleInfo.getCaseId(), tripleInfo.getNoteRecordsId(), tripleInfo.getStartNodeType(), start, "1"); CaseNode startNode = neo4jService.findOneByName(tripleInfo.getCaseId(), tripleInfo.getRecordId(), tripleInfo.getStartNodeType(), start, "1");
if (startNode == null) { if (startNode == null) {
startNode = new CaseNode(start, tripleInfo.getStartNodeType(), tripleInfo.getNoteRecordId(), tripleInfo.getNoteRecordsId(), tripleInfo.getCaseId(), "1"); startNode = new CaseNode(start, tripleInfo.getStartNodeType(), tripleInfo.getRecordSplitId(), tripleInfo.getRecordId(), tripleInfo.getCaseId(), "1");
CaseNode save = neo4jService.save(startNode); CaseNode save = neo4jService.save(startNode);
startNode.setId(save.getId()); startNode.setId(save.getId());
} }
//结束节点 //结束节点
String end = tripleInfo.getEndNode(); String end = tripleInfo.getEndNode();
CaseNode endNode = neo4jService.findOneByName(tripleInfo.getCaseId(), tripleInfo.getNoteRecordsId(), tripleInfo.getEndNodeType(), end, "1"); CaseNode endNode = neo4jService.findOneByName(tripleInfo.getCaseId(), tripleInfo.getRecordId(), tripleInfo.getEndNodeType(), end, "1");
if (endNode == null) { if (endNode == null) {
endNode = new CaseNode(end, tripleInfo.getEndNodeType(), tripleInfo.getNoteRecordId(), tripleInfo.getNoteRecordsId(), tripleInfo.getCaseId(), "1"); endNode = new CaseNode(end, tripleInfo.getEndNodeType(), tripleInfo.getRecordSplitId(), tripleInfo.getRecordId(), tripleInfo.getCaseId(), "1");
CaseNode save = neo4jService.save(endNode); CaseNode save = neo4jService.save(endNode);
endNode.setId(save.getId()); endNode.setId(save.getId());
} }
@ -251,8 +238,8 @@ public class ModelRecordTypeServiceImpl extends ServiceImpl<ModelRecordTypeMappe
neo4jService.saveRelation(rel); neo4jService.saveRelation(rel);
} }
tripleInfo.setAddNeo4j("1"); tripleInfo.setAddNeo4j("1");
int j = tripleInfoMapper.updateById(tripleInfo); boolean updateResult = tripleInfoService.updateById(tripleInfo);
if (j > 0) { if (updateResult) {
i++; i++;
} }
// TODO 重复添加的OK了,删除的呢? // TODO 重复添加的OK了,删除的呢?

@ -1,11 +1,14 @@
package com.supervision.police.service.impl; package com.supervision.police.service.impl;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.supervision.police.domain.NoteRecordSplit;
import com.supervision.springaidemo.domain.NoteCheckRecord; import com.supervision.springaidemo.domain.NoteCheckRecord;
import com.supervision.police.service.NoteCheckRecordService; import com.supervision.police.service.NoteCheckRecordService;
import com.supervision.police.mapper.NoteCheckRecordMapper; import com.supervision.police.mapper.NoteCheckRecordMapper;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.util.List;
/** /**
* @author flevance * @author flevance
* @description note_check_record()Service * @description note_check_record()Service
@ -15,6 +18,7 @@ import org.springframework.stereotype.Service;
public class NoteCheckRecordServiceImpl extends ServiceImpl<NoteCheckRecordMapper, NoteCheckRecord> public class NoteCheckRecordServiceImpl extends ServiceImpl<NoteCheckRecordMapper, NoteCheckRecord>
implements NoteCheckRecordService{ implements NoteCheckRecordService{
} }

@ -0,0 +1,12 @@
package com.supervision.police.service.impl;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.supervision.police.domain.NotePrompt;
import com.supervision.police.mapper.NotePromptMapper;
import com.supervision.police.service.NotePromptService;
import org.springframework.stereotype.Service;
@Service
public class NotePromptServiceImpl extends ServiceImpl<NotePromptMapper, NotePrompt> implements NotePromptService {
}

@ -18,7 +18,7 @@ import com.supervision.police.mapper.ModelCaseMapper;
import com.supervision.police.mapper.ModelRecordTypeMapper; import com.supervision.police.mapper.ModelRecordTypeMapper;
import com.supervision.police.mapper.NoteRecordSplitMapper; import com.supervision.police.mapper.NoteRecordSplitMapper;
import com.supervision.police.mapper.NoteRecordMapper; import com.supervision.police.mapper.NoteRecordMapper;
import com.supervision.police.service.RecordService; import com.supervision.police.service.NoteRecordSplitService;
import com.supervision.springaidemo.dto.QARecordNodeDTO; import com.supervision.springaidemo.dto.QARecordNodeDTO;
import com.supervision.springaidemo.util.RecordRegexUtil; import com.supervision.springaidemo.util.RecordRegexUtil;
import com.supervision.springaidemo.util.WordReadUtil; import com.supervision.springaidemo.util.WordReadUtil;
@ -41,7 +41,7 @@ import java.util.stream.Collectors;
@Slf4j @Slf4j
@Service @Service
@RequiredArgsConstructor @RequiredArgsConstructor
public class RecordServiceImpl extends ServiceImpl<NoteRecordSplitMapper, NoteRecordSplit> implements RecordService { public class NoteRecordSplitServiceImpl extends ServiceImpl<NoteRecordSplitMapper, NoteRecordSplit> implements NoteRecordSplitService {
private final NoteRecordSplitMapper noteRecordSplitMapper; private final NoteRecordSplitMapper noteRecordSplitMapper;

@ -0,0 +1,12 @@
package com.supervision.police.service.impl;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.supervision.police.domain.TripleInfo;
import com.supervision.police.mapper.TripleInfoMapper;
import com.supervision.police.service.TripleInfoService;
import org.springframework.stereotype.Service;
@Service
public class TripleInfoServiceImpl extends ServiceImpl<TripleInfoMapper, TripleInfo> implements TripleInfoService {
}

@ -1,336 +0,0 @@
package com.supervision.springaidemo.controller;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.collection.ListUtil;
import cn.hutool.core.io.FileUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil;
import com.supervision.springaidemo.domain.ModelMetric;
import com.supervision.springaidemo.domain.NoteCheckRecord;
import com.supervision.springaidemo.dto.MetricResultDTO;
import com.supervision.springaidemo.service.ModelMetricService;
import com.supervision.police.service.NoteCheckRecordService;
import com.supervision.springaidemo.thread.RunCheckThread;
import com.supervision.springaidemo.thread.RunCheckThreadPool;
import com.supervision.springaidemo.util.WordReadUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;
import java.io.BufferedReader;
import java.io.File;
import java.util.*;
@RestController
@Slf4j
public class ChatController {
private final OllamaChatClient chatClient;
@Autowired
private ModelMetricService modelMetricService;
@Autowired
private NoteCheckRecordService noteCheckRecordService;
// 使用多线程进行提交
@Autowired
public ChatController(OllamaChatClient chatClient) {
this.chatClient = chatClient;
}
@GetMapping("/ai/chat")
public void generate() {
String template = """
["XXX给我转来的钱。","我收到过XXX给我转来的钱","我通过银行给XXX转了钱"]
,;10,10!
json:{"behavior":[""],"victim":[""]}
""";
Prompt prompt = new Prompt(List.of(new UserMessage(template)));
ChatResponse call = chatClient.call(prompt);
Generation result = call.getResult();
String content = result.getOutput().getContent();
log.info(content);
}
private Message buildMessage(Map<String, Object> param) {
String messageTemplate = """
---
---
:
---
198064222613519
202203233西85使33退22020诿3
403
2035122
20诿2022
西
84211
3707818J
178
20220116016
---
,,:,
1.:{metricName}
2.:true({metricTrueDesc})/false({metricFalseDesc}),true/false
3.:,()
4.:,,
json, JSONvalue,:
---
{"metricName":"指标名称", "result":"结论", "originalContext":"笔录对应原话","reason":"原因"}
---
""";
String format = StrUtil.format(messageTemplate, param);
return new UserMessage(format);
}
@GetMapping("/ai/run")
public void run() {
var list = modelMetricService.list();
for (ModelMetric modelMetric : list) {
Map<String, Object> param = new HashMap<>();
param.put("metricTrueDesc", modelMetric.getMetricTrueDesc());
param.put("metricFalseDesc", modelMetric.getMetricFalseDesc());
param.put("metricName", modelMetric.getMetricName());
Message message = buildMessage(param);
Prompt prompt = new Prompt(List.of(new SystemMessage("所有的回复以简体中文回答。请以step by step的方式进行。step1:理解笔录设计人员的身份信息;step2:根据笔录的内容分析案件之间的逻辑关系和关联;step3:判断给定的指标是否满足。step4:根据要求的给定格式进行回复"), message));
log.info("prompt:{}", prompt.toString());
ChatResponse call = chatClient.call(prompt);
Generation result = call.getResult();
String content = result.getOutput().getContent();
log.info(content);
}
}
private static final String template = """
step by step,
step1:;
---
:{actionUserNameList}
{suspectUserNameList}
{victimUserNameList}
{witnessNameList}
---
step2:;
,"问","答"{noteUserName}:
---
{context}
---
step3::
:{metricName}
:
({metricTrueDesc}),true;
({metricFalseDesc}),false;
,,empty
step4:,:
1.:true/false/empty
2.:,()true,!
3.:,,true/false,!
step5:json, JSONvalue,:
---
{"result":"结论", "originalContext":"笔录对应原话","reason":"原因"}
---
,!
""";
@GetMapping("runNoteCheck")
public void runNoteCheck() {
HashMap<String, String> map = new HashMap<>();
map.put("杨学明", "/Users/flevance/Desktop/宁夏审讯大模型/陈恩明合同诈骗/受害人杨学明询问笔录.docx");
map.put("朱文泽", "/Users/flevance/Desktop/宁夏审讯大模型/陈恩明合同诈骗/受害人朱文泽询问笔录.docx");
map.put("陈恩明", "/Users/flevance/Desktop/宁夏审讯大模型/陈恩明合同诈骗/嫌疑人陈恩明讯问笔录.docx");
map.put("武桂清1", "/Users/flevance/Desktop/宁夏审讯大模型/陈恩明合同诈骗/嫌疑人武桂清讯问笔录1.docx");
map.put("武桂清2", "/Users/flevance/Desktop/宁夏审讯大模型/陈恩明合同诈骗/嫌疑人武桂清讯问笔录2.docx");
for (Map.Entry<String, String> entry : map.entrySet()) {
String context = WordReadUtil.readWord(entry.getValue());
List<String> actionUserNameList = new ArrayList<>();
actionUserNameList.add("陈恩明");
actionUserNameList.add("武桂清");
List<String> victimUserNameList = new ArrayList<>();
victimUserNameList.add("漫旭昌");
victimUserNameList.add("杨学明");
victimUserNameList.add("朱文泽");
var list = modelMetricService.list();
for (ModelMetric modelMetric : list) {
// 没有跑过的,才继续跑
Long count = noteCheckRecordService.lambdaQuery().eq(NoteCheckRecord::getPersonName, entry.getKey()).eq(NoteCheckRecord::getMetricCode, modelMetric.getMetricCode()).count();
if (count < 1) {
Map<String, Object> param = new HashMap<>();
param.put("actionUserNameList", CollUtil.join(actionUserNameList, ";"));
param.put("victimUserNameList", CollUtil.join(victimUserNameList, ";"));
param.put("witnessNameList", "无");
param.put("noteUserName", entry.getKey());
param.put("context", context);
param.put("metricName", modelMetric.getMetricName());
param.put("metricTrueDesc", modelMetric.getMetricTrueDesc());
param.put("metricFalseDesc", modelMetric.getMetricFalseDesc());
String format = StrUtil.format(template, param);
Message message = new UserMessage(format);
Prompt prompt = new Prompt(List.of(new SystemMessage("所有的回复以简体中文回答。请以step by step的方式进行。step1:理解笔录设计人员的身份信息;step2:根据笔录的内容分析案件之间的逻辑关系和关联;step3:根据给定的指标,提取出来可能涉及的笔录内容文本;step4:根据该笔录内容判断给定的指标是否满足。step5:根据要求的给定格式进行回复"), message));
log.info("prompt:{}", prompt);
ChatResponse call = chatClient.call(prompt);
Generation result = call.getResult();
String content = result.getOutput().getContent();
log.info(content);
MetricResultDTO metricResultDTO = JSONUtil.toBean(content, MetricResultDTO.class);
NoteCheckRecord noteCheckRecord = new NoteCheckRecord();
noteCheckRecord.setPersonName(entry.getKey());
noteCheckRecord.setNoteName(FileUtil.getName(entry.getValue()));
noteCheckRecord.setType(entry.getKey().contains("询问") ? "询问" : "讯问");
noteCheckRecord.setMetricCode(modelMetric.getMetricCode());
noteCheckRecord.setMetricName(modelMetric.getMetricName());
noteCheckRecord.setResult(metricResultDTO.getResult());
noteCheckRecord.setOriginalContext(metricResultDTO.getOriginalContext());
noteCheckRecord.setReason(metricResultDTO.getReason());
noteCheckRecordService.save(noteCheckRecord);
}
}
}
}
/**
* word
*/
@GetMapping("runCheck")
public void runCheck() {
// List<String> metricCodeList = ListUtil.list(false, "RZ010", "RZ019", "RZ020", "RZ022");
// 行为人
List<String> actionUserNameList = ListUtil.list(false, "裴金禄");
// 犯罪嫌疑人
List<String> suspectUserNameList = ListUtil.list(false, "裴金禄", "景涛", "李世怀", "万学宝");
// 受害人
List<String> victimUserNameList = ListUtil.list(false, "董金才", "吕加国", "吕志仓");
// 证人
List<String> witnessNameList = ListUtil.list(false, "白鹏", "丁建华", "雷建贵", "雷建明", "李泽懿", "王存良", "王开阔", "吴尚军", "杨正福", "叶魁伍", "赵景宝");
// 获取目录下的所有笔录信息
List<File> files = FileUtil.loopFiles("/Users/flevance/Desktop/宁夏审讯大模型/裴金禄/行为人和受害人/");
for (File file : files) {
// 只跑裴金禄的笔录
log.info("开始分析:{}的笔录", file.getName());
String context = WordReadUtil.readWord(file.getPath());
// List<ModelMetric> list = modelMetricService.lambdaQuery().in(ModelMetric::getMetricCode, metricCodeList).list();
List<ModelMetric> list = modelMetricService.list();
for (ModelMetric modelMetric : list) {
Map<String, Object> param = new HashMap<>();
param.put("actionUserNameList", CollUtil.join(actionUserNameList, ";"));
param.put("suspectUserNameList", CollUtil.join(suspectUserNameList, ";"));
param.put("victimUserNameList", CollUtil.join(victimUserNameList, ";"));
param.put("witnessNameList", CollUtil.join(witnessNameList, ";"));
param.put("context", context);
param.put("metricName", modelMetric.getMetricName());
param.put("metricTrueDesc", modelMetric.getMetricTrueDesc());
param.put("metricFalseDesc", modelMetric.getMetricFalseDesc());
String format = StrUtil.format(template, param);
List<Message> userMessageList = new ArrayList<>();
log.info("开始提交分析,prompt长度为:{}", format.length());
// 如果超过8000字,就进行截断,每次以6000字进行提交
if (format.length() > 8000) {
log.info("分段提交");
for (String s : StrUtil.split(format, 6000)) {
userMessageList.add(new UserMessage(s));
userMessageList.add(new AssistantMessage("继续"));
}
userMessageList.remove(userMessageList.size() - 1);
} else {
userMessageList.add(new UserMessage(format));
}
String systemPrompt = """
,,,,,
""";
List<Message> messages = new ArrayList<>(List.of(new SystemMessage(systemPrompt)));
messages.addAll(userMessageList);
Prompt prompt = new Prompt(messages);
RunCheckThread runCheck = new RunCheckThread("裴金禄第五次",chatClient, noteCheckRecordService, prompt, file.getName(), format, systemPrompt, modelMetric, 0);
RunCheckThreadPool.chatExecutor.submit(runCheck);
}
}
}
@GetMapping("testLongText")
public void testLongText() {
StringBuilder stringBuilder = new StringBuilder();
BufferedReader utf8Reader = FileUtil.getUtf8Reader("/Users/flevance/Desktop/宁夏审讯大模型/了不起的盖茨比 .txt");
utf8Reader.lines().forEach(stringBuilder::append);
String template = """
,:
---
{context}
---
,8,..json
""";
String systemPrompt = """
,,.
""";
Map<String, Object> param = new HashMap<>();
log.info("大小是:{}", stringBuilder.length());
param.put("context", stringBuilder.toString());
String format = StrUtil.format(template, param);
List<Message> userMessageList = new ArrayList<>();
for (String s : StrUtil.split(format, 6000)) {
userMessageList.add(new UserMessage(s));
userMessageList.add(new AssistantMessage("继续"));
}
userMessageList.remove(userMessageList.size() - 1);
List<Message> messages = new ArrayList<>(List.of(new SystemMessage(systemPrompt)));
messages.addAll(userMessageList);
Prompt prompt = new Prompt(messages);
ChatResponse call = chatClient.call(prompt);
Generation result = call.getResult();
String content = result.getOutput().getContent();
log.info("分析的结果是:{}", content);
}
}

@ -8,7 +8,7 @@ import com.supervision.police.mapper.ModelRecordTypeMapper;
import com.supervision.police.mapper.NoteRecordSplitMapper; import com.supervision.police.mapper.NoteRecordSplitMapper;
import com.supervision.police.service.ModelRecordTypeService; import com.supervision.police.service.ModelRecordTypeService;
import com.supervision.springaidemo.service.ModelMetricService; import com.supervision.springaidemo.service.ModelMetricService;
import com.supervision.police.service.RecordService; import com.supervision.police.service.NoteRecordSplitService;
import com.supervision.police.service.NoteCheckRecordService; import com.supervision.police.service.NoteCheckRecordService;
import com.supervision.springaidemo.util.RecordRegexUtil; import com.supervision.springaidemo.util.RecordRegexUtil;
import com.supervision.springaidemo.util.WordReadUtil; import com.supervision.springaidemo.util.WordReadUtil;
@ -58,7 +58,7 @@ public class ExampleChatController {
} }
@Autowired @Autowired
private RecordService recordService; private NoteRecordSplitService noteRecordSplitService;
@Autowired @Autowired
private ModelRecordTypeService modelRecordTypeService; private ModelRecordTypeService modelRecordTypeService;
@ -133,7 +133,7 @@ public class ExampleChatController {
//保存笔录 //保存笔录
noteRecord.setRecordType(type); noteRecord.setRecordType(type);
recordService.save(noteRecord); noteRecordSplitService.save(noteRecord);
ModelRecordType exist = modelRecordTypeService.queryByName(type); ModelRecordType exist = modelRecordTypeService.queryByName(type);
if (exist == null) { if (exist == null) {
@ -169,8 +169,8 @@ public class ExampleChatController {
// } // }
// messages.addAll(userMessageList); // messages.addAll(userMessageList);
// //
// RunCheckThread runCheck = new RunCheckThread("裴金禄尝试正则来做", chatClient, noteCheckRecordService, new Prompt(messages), FileUtil.getName(file), format, systemPrompt, modelMetric, 0); // TripleExtractThread runCheck = new TripleExtractThread("裴金禄尝试正则来做", chatClient, noteCheckRecordService, new Prompt(messages), FileUtil.getName(file), format, systemPrompt, modelMetric, 0);
// RunCheckThreadPool.chatExecutor.submit(runCheck); // TripleExtractThreadPool.chatExecutor.submit(runCheck);
// } // }
// } // }
@ -231,7 +231,7 @@ public class ExampleChatController {
@GetMapping("test1") @GetMapping("test1")
public void test2(@Param("id") String id) { public void test2(@Param("id") String id) {
NoteRecordSplit noteRecord = recordService.getById(id); NoteRecordSplit noteRecord = noteRecordSplitService.getById(id);
String question = noteRecord.getQuestion(); String question = noteRecord.getQuestion();
String answer = noteRecord.getAnswer(); String answer = noteRecord.getAnswer();
String test = "请从以下对话中提取所有关于" + noteRecord.getRecordType() + "的所有三元组"; String test = "请从以下对话中提取所有关于" + noteRecord.getRecordType() + "的所有三元组";

@ -1,106 +0,0 @@
package com.supervision.springaidemo.controller;
import cn.hutool.core.io.FileUtil;
import cn.hutool.core.util.StrUtil;
import com.supervision.springaidemo.domain.ModelMetric;
import com.supervision.springaidemo.dto.QARecordNodeDTO;
import com.supervision.springaidemo.service.ModelMetricService;
import com.supervision.police.service.NoteCheckRecordService;
import com.supervision.springaidemo.thread.RunCheckThread;
import com.supervision.springaidemo.thread.RunCheckThreadPool;
import com.supervision.springaidemo.util.RecordRegexUtil;
import com.supervision.springaidemo.util.WordReadUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import java.io.File;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* ,2
* ,,()
*
*/
@RestController
@Slf4j
@RequestMapping("newTestChat")
public class NewTestController {
private final OllamaChatClient chatClient;
@Autowired
private ModelMetricService modelMetricService;
@Autowired
private NoteCheckRecordService noteCheckRecordService;
@Autowired
public NewTestController(OllamaChatClient chatClient) {
this.chatClient = chatClient;
}
private static final String template = """
---
1:
---
---
{context}
---
,,
json
---
{"match":"true/false(相关回复true,不相关回复false)","summary":"如果有关联,则对该笔录进行总结"}
---
,!!
""";
@GetMapping("newTestChat")
public void newTestChat() {
// 只查入罪指标
ModelMetric modelMetric = modelMetricService.lambdaQuery().eq(ModelMetric::getMetricCode, "RZ007").one();
File file = FileUtil.file("/Users/flevance/Desktop/宁夏审讯大模型/裴金禄/行为人和受害人/裴金禄第一次.docx");
String context = WordReadUtil.readWord(file.getPath());
List<QARecordNodeDTO> qaList = RecordRegexUtil.recordRegex(context, "裴金禄");
for (QARecordNodeDTO qaRecordNodeDTO : qaList) {
String systemPrompt = """
,
""";
List<Message> messages = new ArrayList<>(List.of(new SystemMessage(systemPrompt)));
Map<String, Object> param = new HashMap<>();
param.put("context", qaRecordNodeDTO.toString());
String format = StrUtil.format(template, param);
List<Message> userMessageList = new ArrayList<>();
if (format.length() > 8000) {
log.info("分段提交");
for (String s : StrUtil.split(format, 6000)) {
userMessageList.add(new UserMessage(s));
userMessageList.add(new AssistantMessage("继续"));
}
userMessageList.remove(userMessageList.size() - 1);
} else {
userMessageList.add(new UserMessage(format));
}
messages.addAll(userMessageList);
RunCheckThread runCheck = new RunCheckThread("尝试分段提取-第一段找相似", chatClient, noteCheckRecordService, new Prompt(messages), FileUtil.getName(file), format, systemPrompt, modelMetric, 5);
RunCheckThreadPool.chatExecutor.submit(runCheck);
}
}
}

@ -1,118 +0,0 @@
package com.supervision.springaidemo.controller;
import cn.hutool.core.io.FileUtil;
import cn.hutool.core.map.MapUtil;
import cn.hutool.core.util.StrUtil;
import com.supervision.springaidemo.domain.ModelMetric;
import com.supervision.springaidemo.service.ModelMetricService;
import com.supervision.police.service.NoteCheckRecordService;
import com.supervision.springaidemo.thread.RunCheckThread;
import com.supervision.springaidemo.thread.RunCheckThreadPool;
import com.supervision.springaidemo.util.WordReadUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import java.io.File;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* controller
*/
@RestController
@Slf4j
@RequestMapping("rzChat")
public class RzChatController {
private static final String step1Template = """
step by step,
step1:;
,"问","答"{noteUserName}:
---
{context}
---
step2::
:
:{metricDetailTemplate},true;
::{metricDetailTemplate},false;
,,,empty
step3:,:
1.:true/false/empty
2.:,()true,!
3.:,,true/false,!
step4:json, JSONvalue,:
---
{"result":"结论", "originalContext":"笔录对应原话","reason":"原因"}
---
,!
""";
private final OllamaChatClient chatClient;
@Autowired
private ModelMetricService modelMetricService;
@Autowired
private NoteCheckRecordService noteCheckRecordService;
@Autowired
public RzChatController(OllamaChatClient chatClient) {
this.chatClient = chatClient;
}
@GetMapping("extract")
public void extract() throws InterruptedException {
List<File> files = FileUtil.loopFiles("/Users/flevance/Desktop/宁夏审讯大模型/裴金禄/行为人和受害人/");
for (File file : files) {
String context = WordReadUtil.readWord(file.getPath());
// 只查入罪指标
List<ModelMetric> list = modelMetricService.lambdaQuery().likeRight(ModelMetric::getMetricCode, "RZ").list();
for (ModelMetric modelMetric : list) {
String systemPrompt = """
,,,,,
""";
List<Message> messages = new ArrayList<>(List.of(new SystemMessage(systemPrompt)));
Map<String, Object> param = new HashMap<>();
param.put("metricDetailTemplate", StrUtil.format(modelMetric.getMetricDetailTemplate(), MapUtil.of("action", "裴金禄")));
param.put("noteUserName", "裴金禄");
param.put("context", context);
String format = StrUtil.format(step1Template, param);
List<Message> userMessageList = new ArrayList<>();
if (format.length() > 8000) {
log.info("分段提交");
for (String s : StrUtil.split(format, 6000)) {
userMessageList.add(new UserMessage(s));
userMessageList.add(new AssistantMessage("继续"));
}
userMessageList.remove(userMessageList.size() - 1);
} else {
userMessageList.add(new UserMessage(format));
}
messages.addAll(userMessageList);
RunCheckThread runCheck = new RunCheckThread("裴金禄尝试通过直接定义模板", chatClient, noteCheckRecordService, new Prompt(messages), FileUtil.getName(file), format, systemPrompt, modelMetric, 0);
RunCheckThreadPool.chatExecutor.submit(runCheck);
}
}
}
}

@ -1,91 +0,0 @@
package com.supervision.springaidemo.thread;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil;
import com.supervision.springaidemo.domain.ModelMetric;
import com.supervision.springaidemo.domain.NoteCheckRecord;
import com.supervision.springaidemo.dto.MetricResultDTO;
import com.supervision.police.service.NoteCheckRecordService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.util.StopWatch;
@Slf4j
public class RunCheckThread implements Runnable {
private final String caseName;
private final OllamaChatClient chatClient;
private final NoteCheckRecordService noteCheckRecordService;
private final Prompt prompt;
private final String fileName;
private final String format;
private final String systemPrompt;
private final ModelMetric modelMetric;
private Integer count;
public RunCheckThread(String caseName, OllamaChatClient chatClient, NoteCheckRecordService noteCheckRecordService, Prompt prompt, String fileName, String format, String systemPrompt, ModelMetric modelMetric, Integer count) {
this.caseName = caseName;
this.chatClient = chatClient;
this.noteCheckRecordService = noteCheckRecordService;
this.prompt = prompt;
this.fileName = fileName;
this.format = format;
this.systemPrompt = systemPrompt;
this.modelMetric = modelMetric;
this.count = count;
}
@Override
public void run() {
try {
StopWatch stopWatch = new StopWatch();
stopWatch.start();
log.info("开始分析:{}",fileName);
ChatResponse call = chatClient.call(prompt);
stopWatch.stop();
log.info("耗时:{}", stopWatch.getTotalTimeSeconds());
Generation result = call.getResult();
String content = result.getOutput().getContent();
log.info("分析的结果是:{}", content);
MetricResultDTO metricResultDTO = JSONUtil.toBean(content, MetricResultDTO.class);
// 如果为空,则再跑一次,最多跑5次
if (StrUtil.isBlank(metricResultDTO.getResult())) {
if (count > 5) {
log.info("{}的{}结果超过5次,不再继续跑了", fileName, modelMetric);
} else {
log.info("{}的{}结果为空,当前跑了{}次,重新提交,再跑一次", fileName, modelMetric, count);
Integer newCount = count++;
RunCheckThread runCheck = new RunCheckThread(caseName, chatClient, noteCheckRecordService, prompt, fileName, format, systemPrompt, modelMetric, newCount);
RunCheckThreadPool.chatExecutor.submit(runCheck);
}
} else {
NoteCheckRecord noteCheckRecord = new NoteCheckRecord();
noteCheckRecord.setCaseName(caseName);
noteCheckRecord.setNoteName(fileName);
noteCheckRecord.setMetricCode(modelMetric.getMetricCode());
noteCheckRecord.setMetricName(modelMetric.getMetricName());
noteCheckRecord.setSystemPrompt(systemPrompt);
noteCheckRecord.setPrompt(format);
noteCheckRecord.setResult(metricResultDTO.getResult());
noteCheckRecord.setOriginalContext(metricResultDTO.getOriginalContext());
noteCheckRecord.setReason(metricResultDTO.getReason());
noteCheckRecordService.save(noteCheckRecord);
}
} catch (Exception e) {
log.error("出现错误", e);
}
}
}

@ -0,0 +1,88 @@
package com.supervision.thread;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil;
import com.supervision.police.domain.TripleInfo;
import com.supervision.springaidemo.domain.ModelMetric;
import com.supervision.springaidemo.domain.NoteCheckRecord;
import com.supervision.springaidemo.dto.MetricResultDTO;
import com.supervision.police.service.NoteCheckRecordService;
import lombok.extern.slf4j.Slf4j;
import org.json.JSONArray;
import org.json.JSONObject;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.util.StopWatch;
import java.time.LocalDateTime;
import java.util.concurrent.Callable;
@Slf4j
public class TripleExtractThread implements Callable<TripleInfo> {
private final OllamaChatClient chatClient;
private final String prompt;
private final String question;
private final String answer;
private final String recordSplitId;
private final String caseId;
private final String recordId;
public TripleExtractThread(OllamaChatClient chatClient, String caseId, String recordId, String recordSplitId, String prompt, String question, String answer) {
this.question = question;
this.chatClient = chatClient;
this.answer = answer;
this.prompt = prompt;
this.recordSplitId = recordSplitId;
this.caseId = caseId;
this.recordId = recordId;
}
@Override
public TripleInfo call() {
try {
StopWatch stopWatch = new StopWatch();
// 分析三元组
Prompt ask = new Prompt(new UserMessage(prompt + question + answer));
stopWatch.start();
log.info("开始分析:");
ChatResponse call = chatClient.call(ask);
stopWatch.stop();
log.info("耗时:{}", stopWatch.getTotalTimeSeconds());
String content = call.getResult().getOutput().getContent();
log.info("分析的结果是:{}", content);
// 获取从提示词中提取到的三元组信息
JSONObject jsonObject = new JSONObject(content);
JSONArray threeInfo = jsonObject.getJSONArray("result");
for (int i = 0; i < threeInfo.length(); i++) {
JSONObject object = threeInfo.getJSONObject(i);
String startNodeType = object.getString("startNodeType");
String entity = object.getString("entity");
String endNodeType = object.getString("endNodeType");
String property = object.getString("property");
String value = object.getString("value");
// 去空,如果存在任何的空值,则忽略
if (StrUtil.hasEmpty(startNodeType, entity, endNodeType, property, value)) {
continue;
}
// 构建三元组信息
return new TripleInfo(entity, property, value, caseId, recordId, recordSplitId, LocalDateTime.now(), startNodeType, endNodeType);
}
} catch (Exception e) {
log.error("提取三元组出现错误", e);
}
return null;
}
}

@ -1,10 +1,10 @@
package com.supervision.springaidemo.thread; package com.supervision.thread;
import cn.hutool.core.thread.ThreadUtil; import cn.hutool.core.thread.ThreadUtil;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
public class RunCheckThreadPool { public class TripleExtractThreadPool {
public static final ExecutorService chatExecutor = ThreadUtil.newFixedExecutor(20, Integer.MAX_VALUE, "chat", false); public static final ExecutorService chatExecutor = ThreadUtil.newFixedExecutor(20, Integer.MAX_VALUE, "tripleExtract", false);
} }

@ -3,8 +3,4 @@
PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN"
"http://mybatis.org/dtd/mybatis-3-mapper.dtd"> "http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<mapper namespace="com.supervision.police.mapper.NotePromptMapper"> <mapper namespace="com.supervision.police.mapper.NotePromptMapper">
<select id="queryPrompt" resultType="com.supervision.police.domain.NotePrompt">
select * from note_prompt
where type_id = #{typeId}
</select>
</mapper> </mapper>

@ -3,17 +3,5 @@
PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN"
"http://mybatis.org/dtd/mybatis-3-mapper.dtd"> "http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<mapper namespace="com.supervision.police.mapper.TripleInfoMapper"> <mapper namespace="com.supervision.police.mapper.TripleInfoMapper">
<select id="selectByIds" resultType="com.supervision.police.domain.TripleInfo">
select nr2.case_id as caseId, nr2.id as noteRecordsId, ti.*
from triple_info ti
left join note_record_split nr on ti.note_record_id = nr.id
left join note_record nr2 on nr.note_records_id = nr2.id
where 1 = 1
<if test="ids != null and ids.size > 0">
and ti.id in
<foreach collection="ids" item="id" open="(" close=")" separator=",">
#{id}
</foreach>
</if>
</select>
</mapper> </mapper>
Loading…
Cancel
Save