You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
fu-hsi-service/src/test/java/com/supervision/demo/FuHsiApplicationTests.java

335 lines
15 KiB
Java

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package com.supervision.demo;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.date.DateTime;
import cn.hutool.core.date.DateUnit;
import cn.hutool.core.date.DateUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil;
import cn.hutool.poi.excel.ExcelUtil;
import cn.hutool.poi.excel.ExcelWriter;
import com.supervision.common.domain.R;
import com.supervision.neo4j.controller.Neo4jController;
import com.supervision.neo4j.domain.CaseNode;
import com.supervision.demo.controller.ExampleChatController;
import com.supervision.police.domain.ModelRecordType;
import com.supervision.police.domain.NotePrompt;
import com.supervision.police.domain.NoteRecordSplit;
import com.supervision.police.dto.EvidenceDirectoryDTO;
import com.supervision.police.dto.RetrieveReqDTO;
import com.supervision.police.dto.RetrieveResDTO;
import com.supervision.police.service.*;
import com.supervision.thread.RecordSplitClassifyTask;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.util.StopWatch;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Collectors;
/**
 * Manual integration tests / experiments against the running application context
 * (LLM classification, triple extraction, OCR retrieval, Neo4j, case evidence).
 *
 * <p>Most tests hit external services (Ollama, database, Neo4j) and use hard-coded record ids,
 * so they are intended to be run by hand rather than in CI.
 */
@Slf4j
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
public class FuHsiApplicationTests {

    /**
     * Prompt used by {@link #classificationTest()} to classify one question/answer pair.
     * The placeholders {@code {typeContext}}, {@code {question}} and {@code {answer}} are filled
     * with hutool {@code StrUtil.format(template, map)}; the literal JSON braces are left untouched
     * because they do not match any map key.
     *
     * <p>NOTE(review): the "任务要求" list numbers "2." twice. The text is kept verbatim so results
     * stay comparable with earlier runs; fix the numbering deliberately if desired.
     */
    private static final String NEW_TEMPLATE = """
            分类任务: 将对话笔录文本进行分类。
            目标: 将给定的对话分配到预定义的类别中
            预定义类别为:{typeContext}。
            说明: 提供一段对话笔记录文本,分类器应该识别出对话的主题,并将其归类到上述类别中的一个或多个。
            ---
            示例输入:
            办案警官问:你和上述的这些受害人签订协议之后是否实际履行合同?
            行为人XXX回答我和他们签订合同就是为了骗他们相信我我就是伪造的一些假合同等我把他们的钱骗到手之后我就不会履行合同。
            预期输出: {"result":[{"type":"","explain":"XXX"},{"type":"","explain":"XXX"}]}
            ---
            任务要求:
            1. 分类器应当准确地识别对话的主题,分类来自于预定义的类别。
            2. 分类器应该实事求是按照对话进行分类,不要有过多的推测。
            2. 如果一段对话笔记录包含多个主题,请选择最相关的类别,最多可选择三个分类。
            3. 如果不涉及任何分类则回复{"result":[]}
            ---
            以下为问答对内容:
            {question}
            {answer}
            ---
            返回格式为json,字段名要严格一致:{"result":[{"type":"1","explain":""},{"type":"2","explain":""}]}
            """;

    /** Per-type fragment of {@code {typeContext}}; fills {@code {type}} and {@code {typeExt}}. */
    private static final String TYPE_CONTEXT_TEMPLATE = "{分类type:{type},区别点(分类释义):{typeExt}}";

    /** How many times {@link #promptTest()} repeats the prompt list to build its workload. */
    private static final int PROMPT_REPEAT_TIMES = 11;

    /**
     * Worker threads used by {@link #promptTest()}. The original pool declared a maximum of 10000,
     * but with an unbounded {@link LinkedBlockingQueue} a {@link ThreadPoolExecutor} never grows past
     * its core size, so the effective concurrency was always 1 (matching the "并发1" sheet name).
     */
    private static final int PROMPT_CONCURRENCY = 1;

    private final OllamaChatClient chatClient;

    @Autowired
    private ExampleChatController exampleChatController;
    @Autowired
    private Neo4jController neo4jController;
    @Autowired
    private NoteRecordSplitService noteRecordSplitService;
    @Autowired
    private ModelRecordTypeService modelRecordTypeService;
    @Autowired
    private RecordSplitClassifyService recordSplitClassifyService;
    @Autowired
    private OCRService ocrService;
    @Autowired
    private NotePromptService notePromptService;
    @Autowired
    private CaseEvidenceService caseEvidenceService;

    /**
     * Constructor injection of the chat client used by every LLM call in this class.
     *
     * @param chatClient Ollama chat client from the application context
     */
    @Autowired
    public FuHsiApplicationTests(OllamaChatClient chatClient) {
        this.chatClient = chatClient;
    }

    /** Verifies that the Spring application context starts. */
    @Test
    public void contextLoads() {
    }

    /** Placeholder; the example chat call below is kept for manual experiments. */
    @Test
    public void test() {
        // exampleChatController.test("1803663875373694977", "你现在是一个笔录分析人员,请用四个字描述一下下述内容属于哪种类型的对话?");
    }

    /** Saves a single {@link CaseNode} named "自然人" through the Neo4j controller and prints the response. */
    @Test
    public void savePersion() {
        CaseNode caseNode = new CaseNode();
        caseNode.setName("自然人");
        R<?> save = neo4jController.save(caseNode);
        // println, not printf: the response text is data, not a format string (a '%' would throw).
        System.out.println(save);
    }

    /**
     * Classifies every split of a hard-coded note record with {@link #NEW_TEMPLATE} and logs the
     * resulting type(s) per question/answer pair.
     */
    @Test
    public void classificationTest() {
        String recordId = "1824329325387304962";
        List<ModelRecordType> allTypeList = modelRecordTypeService.lambdaQuery().list();
        // The type context depends only on allTypeList, so build it once rather than per split.
        String typeContext = buildTypeContext(allTypeList);
        // Load all split (question/answer) records belonging to recordId.
        List<NoteRecordSplit> noteRecordSplitList = noteRecordSplitService.lambdaQuery()
                .eq(NoteRecordSplit::getNoteRecordId, recordId)
                .list();
        for (NoteRecordSplit noteRecordSplit : noteRecordSplitList) {
            String type = classifySplit(typeContext, noteRecordSplit);
            log.info("最后结果 question:{},answer:{},分析的结果是:{}", noteRecordSplit.getQuestion(), noteRecordSplit.getAnswer(), type);
        }
    }

    /**
     * Renders every record type through {@link #TYPE_CONTEXT_TEMPLATE} and joins them with ';'.
     * Null type names/explanations are rendered as empty strings instead of failing.
     */
    private static String buildTypeContext(List<ModelRecordType> types) {
        List<String> typeContextList = new ArrayList<>(types.size());
        for (ModelRecordType modelRecordType : types) {
            Map<String, String> typeParams = Map.of(
                    "type", Objects.toString(modelRecordType.getRecordType(), ""),
                    "typeExt", Objects.toString(modelRecordType.getRecordTypeExt(), ""));
            typeContextList.add(StrUtil.format(TYPE_CONTEXT_TEMPLATE, typeParams));
        }
        return CollUtil.join(typeContextList, ";");
    }

    /**
     * Asks the LLM to classify one split and returns the distinct types joined with ';'.
     *
     * @return the joined types, or "无" when nothing was classified or the call/parse failed
     */
    private String classifySplit(String typeContext, NoteRecordSplit noteRecordSplit) {
        try {
            Map<String, String> paramMap = new HashMap<>();
            paramMap.put("typeContext", typeContext);
            paramMap.put("question", noteRecordSplit.getQuestion());
            paramMap.put("answer", noteRecordSplit.getAnswer());
            Prompt prompt = new Prompt(new UserMessage(StrUtil.format(NEW_TEMPLATE, paramMap)));

            StopWatch stopWatch = new StopWatch();
            stopWatch.start();
            log.info("开始分析:");
            ChatResponse call = chatClient.call(prompt);
            stopWatch.stop();
            log.info("耗时:{}", stopWatch.getTotalTimeSeconds());

            String content = call.getResult().getOutput().getContent();
            log.info("问:{}, 答:{}", noteRecordSplit.getQuestion(), noteRecordSplit.getAnswer());
            log.info("分析的结果是:{}", content);

            RecordSplitClassifyTask.TypeResultDTO result = JSONUtil.toBean(content, RecordSplitClassifyTask.TypeResultDTO.class);
            List<RecordSplitClassifyTask.TypeNodeDTO> typeList = result == null ? null : result.getResult();
            if (CollUtil.isEmpty(typeList)) {
                // Nothing extracted: report "无" (none).
                return "无";
            }
            // Deduplicate the types and join them with ';'.
            return CollUtil.join(typeList.stream().map(RecordSplitClassifyTask.TypeNodeDTO::getType).collect(Collectors.toSet()), ";");
        } catch (Exception e) {
            log.error("分类任务执行失败:{}", e.getMessage(), e);
            return "无";
        }
    }

    /** Runs the service-level classifier on a single hard-coded split and logs the (mutated) splits. */
    @Test
    public void classificationTest2() {
        List<ModelRecordType> typeList = modelRecordTypeService.lambdaQuery().list();
        List<NoteRecordSplit> noteRecordSplits = noteRecordSplitService.lambdaQuery().eq(NoteRecordSplit::getId, "1824729361214418946").list();
        recordSplitClassifyService.classify(typeList, noteRecordSplits);
        log.info("分类结果:{}", noteRecordSplits);
    }

    /** Sends a sample OCR text of a detention notice to {@code OCRService.retrieve} and prints the result. */
    @Test
    public void retrieveTest() {
        String question = "银川市公安局金凤区分局\\r\\n拘留通知书\\r\\n副本\\r\\n银金公经侦拘通字[2024]10017号\\r\\n梁玉峰家属\\r\\n根据《中华人民共和国刑事诉讼法》第八十二条第一)项之规定\\r\\n罪\\r\\n将涉嫌合同诈骗\\r\\n我局已于2024年01月16日10时\\r\\n银川市看守所\\r\\n的\\r\\n梁玉峰\\r\\n刑事拘留现羁押在\\r\\n市\\r\\n二0二\\r\\n年名\\r\\n本通知书已收到。\\r\\n年月\\r\\n日时\\r\\n被拘留人家属\\r\\n如未在拘留后24小时内通知被拘留人家属注明原因办集民警\\r\\n当天拔打梁天峰亲属果无辉毛话对方电生无法接通\\r\\n办案人\\r\\n2024年1月16日1时\\r\\n此联附卷";
        RetrieveResDTO retrieve = ocrService.retrieve(new RetrieveReqDTO(question));
        System.out.println(retrieve);
    }

    /**
     * Benchmarks triple extraction: runs every prompt of a split (repeated
     * {@link #PROMPT_REPEAT_TIMES} times) against a fixed question/answer on a thread pool,
     * logs wall-clock and summed per-call time, and writes all results to an Excel sheet.
     */
    @Test
    public void promptTest() throws InterruptedException {
        List<NotePrompt> listPromptBySplitId = notePromptService.listPromptBySplitId("1825358898516275202");
        // Repeat the prompt list to build a larger benchmark workload.
        List<NotePrompt> notePrompts = new ArrayList<>(listPromptBySplitId.size() * PROMPT_REPEAT_TIMES);
        for (int i = 0; i < PROMPT_REPEAT_TIMES; i++) {
            notePrompts.addAll(listPromptBySplitId);
        }

        NoteRecordSplit noteRecordSplit = new NoteRecordSplit();
        /*noteRecordSplit.setQuestion("办案警官问:你让马旭东帮你买车你们是如何协商的?");
        noteRecordSplit.setAnswer("裴金禄回答:我和马旭东认识十年时间了,我们关系比较好。我找到马旭东告诉他我有车要出售,这样要求把帮忙出售,我不出面。出售完毕之后我再给马旭东支付一些好处费用");*/
        noteRecordSplit.setQuestion("办案警官问:你和马旭东是如何协商好处费的?");
        noteRecordSplit.setAnswer("裴金禄回答:我没有协商过,我都是随心意给马旭东给的,");
        String mainActor = "裴金禄";

        // Appended to from worker threads, so it must be thread-safe.
        List<TripleRecord> tripleRecords = Collections.synchronizedList(new ArrayList<>());
        AtomicLong atomicLong = new AtomicLong();
        CountDownLatch countDownLatch = new CountDownLatch(notePrompts.size());
        ThreadPoolExecutor executor = new ThreadPoolExecutor(PROMPT_CONCURRENCY, PROMPT_CONCURRENCY,
                100L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>());
        DateTime start = DateUtil.date();
        try {
            for (NotePrompt notePrompt : notePrompts) {
                executor.submit(() -> {
                    try {
                        TripleRecord tripleRecord = chat4Triple(notePrompt, noteRecordSplit, mainActor, null);
                        atomicLong.getAndAdd(tripleRecord.getTime());
                        tripleRecords.add(tripleRecord);
                    } catch (Exception e) {
                        // submit() would otherwise swallow the exception silently.
                        log.error("三元组提取失败, promptId:{}", notePrompt.getId(), e);
                    } finally {
                        // Count down last (and always) so await() only returns once every result is
                        // recorded, and a failing task cannot make await() hang forever.
                        countDownLatch.countDown();
                    }
                });
            }
            countDownLatch.await();
        } finally {
            executor.shutdown();
        }
        long between = DateUtil.between(start, DateUtil.date(), DateUnit.SECOND);
        log.info("实际总耗时:{}", between);
        // NOTE(review): hard-coded local Windows output path; adjust per machine.
        try (ExcelWriter writer = ExcelUtil.getWriter("F:\\tmp\\1\\ollama-模型优化\\50服务器\\prompt.xlsx", "速度对比-并发1-qwen2.5-32b")) {
            writer.write(tripleRecords);
        }
        log.info("所有线程总耗时:" + atomicLong.get());
    }

    /**
     * Extracts triples for one prompt template from one question/answer pair.
     *
     * @param prompt          prompt template with start/end entity types and relation type
     * @param noteRecordSplit question/answer pair to analyse
     * @param mainActor       main actor name; when non-null and the head entity type is "行为人",
     *                        the prompt restricts extraction to triples headed by this actor
     * @param type            classification label copied into the result (may be null)
     * @return the rendered prompt, raw model output and elapsed milliseconds
     */
    private TripleRecord chat4Triple(NotePrompt prompt, NoteRecordSplit noteRecordSplit, String mainActor, String type) {
        StopWatch stopWatch = new StopWatch();
        // Time the whole triple-extraction call.
        stopWatch.start();
        HashMap<String, String> paramMap = new HashMap<>();
        paramMap.put("headEntityType", prompt.getStartEntityType());
        paramMap.put("relation", prompt.getRelType());
        paramMap.put("tailEntityType", prompt.getEndEntityType());
        paramMap.put("question", noteRecordSplit.getQuestion());
        paramMap.put("answer", noteRecordSplit.getAnswer());
        //log.info("开始尝试提取三元组:{}-{}-{},mainActor:{}", prompt.getStartEntityType(), prompt.getRelType(), prompt.getEndEntityType(), mainActor == null ? "" : mainActor);
        if (mainActor != null && "行为人".equals(prompt.getStartEntityType())) {
            paramMap.put("requirement", "当前案件的行为人是" + mainActor + ",只尝试提取" + mainActor + "为头结点的三元组。");
        } else {
            paramMap.put("requirement", "");
        }
        String format = StrUtil.format(prompt.getPrompt(), paramMap);
        //log.info("提示词内容:{}", format);
        log.info("开始执行:chatClient.call 开始");
        ChatResponse call = chatClient.call(new Prompt(new UserMessage(format)));
        log.info("开始执行:chatClient.call 结束");
        stopWatch.stop();
        String content = call.getResult().getOutput().getContent();
        //log.info("问题:{}耗时:{},三元组提取结果是:{}", noteRecordSplit.getQuestion(), stopWatch.getTotalTimeSeconds(), content);
        return new TripleRecord(noteRecordSplit.getQuestion(), noteRecordSplit.getAnswer(), type,
                prompt.getId(), prompt.getPrompt(), format,
                content, stopWatch.getTotalTimeMillis());
    }

    /** Currently a no-op: the evidence-directory fixture below is kept commented out for manual runs. */
    @Test
    public void ocrAndExtractTest() {
//        EvidenceDirectoryDTO directoryDTO = new EvidenceDirectoryDTO();
//        directoryDTO.setId("1845745238182703105");
//        directoryDTO.setDirectoryName("书证");
//        EvidenceDirectoryDTO directoryDTO1 = new EvidenceDirectoryDTO();
//        directoryDTO1.setId("1845745256180461570");
//        directoryDTO1.setDirectoryName("合同协议");
//        directoryDTO1.setParentId("1845745238182703105");
//        directoryDTO1.setFileIdList(List.of("1823958525387853825","1831232334119645185"));
//        directoryDTO.setChild(List.of(directoryDTO1));
//
//        EvidenceDirectoryDTO directoryDTO2 = new EvidenceDirectoryDTO();
//        directoryDTO2.setId("2222");
//        directoryDTO2.setDirectoryName("车辆合同协议");
//        directoryDTO2.setParentId("1845745256180461570");
//        directoryDTO2.setFileIdList(List.of("1833015941205143554","1833016215680397313","1833016621840019457"));
//        directoryDTO1.setChild(List.of(directoryDTO2));
//        caseEvidenceService.ocrAndExtract("1823955210189000706", List.of(directoryDTO));
    }

    /** Initialises the evidence directory of a hard-coded case. */
    @Test
    public void initCaseDirectory() {
        caseEvidenceService.initCaseEvidenceDirectory("1823955210189000706", "1");
        System.out.println("222");
    }

    /**
     * One row of the {@link #promptTest()} Excel output. Static so it does not capture the
     * enclosing test instance.
     */
    @Data
    @AllArgsConstructor
    static class TripleRecord {
        private String question;
        private String answer;
        private String type;
        private String templateId;
        /** Raw prompt template. */
        private String template;
        /** Prompt after placeholder substitution. */
        private String prompt;
        /** Raw model output. */
        private String result;
        /** Elapsed time of the LLM call in milliseconds. */
        private long time;
    }

    /** Triggers a refresh of case evidence via the service. */
    @Test
    void refreshCaseEvidenceTest() {
        caseEvidenceService.refreshCaseEvidence();
        System.out.println("执行完成");
    }
}