diff --git a/src/main/java/com/supervision/demo/controller/ExampleChatController.java b/src/main/java/com/supervision/demo/controller/ExampleChatController.java index 12d2348..7b7f14f 100644 --- a/src/main/java/com/supervision/demo/controller/ExampleChatController.java +++ b/src/main/java/com/supervision/demo/controller/ExampleChatController.java @@ -1,7 +1,11 @@ package com.supervision.demo.controller; +import cn.hutool.core.date.DateTime; +import cn.hutool.core.date.DateUnit; +import cn.hutool.core.date.DateUtil; import cn.hutool.core.io.FileUtil; import cn.hutool.core.util.StrUtil; +import cn.hutool.json.JSONUtil; import com.supervision.demo.service.ModelMetricService; import com.supervision.police.domain.ModelRecordType; import com.supervision.police.domain.NoteRecordSplit; @@ -257,7 +261,7 @@ public class ExampleChatController { } - @PostMapping("chat") + @PostMapping("/chat") public String chat(@RequestBody Map messages) { Prompt prompt = null; @@ -267,15 +271,17 @@ public class ExampleChatController { prompt = new Prompt(List.of(new AssistantMessage(messages.get("assistantMessage")), new UserMessage(messages.get("userMessage")))); } + DateTime start = DateUtil.date(); ChatResponse call = chatClient.call(prompt); Generation result = call.getResult(); - + long between = DateUtil.between(DateUtil.date(), start, DateUnit.MS); String content = result.getOutput().getContent(); log.info("分析的结果是:{}", content); if (StrUtil.isBlank(content)){ content = "{}"; } - return content; + + return JSONUtil.toJsonStr(Map.of("result", content, "time", between+"毫秒")); } } diff --git a/src/test/java/com/supervision/demo/FuHsiApplicationTests.java b/src/test/java/com/supervision/demo/FuHsiApplicationTests.java index 91cf9c4..63ead25 100644 --- a/src/test/java/com/supervision/demo/FuHsiApplicationTests.java +++ b/src/test/java/com/supervision/demo/FuHsiApplicationTests.java @@ -1,21 +1,28 @@ package com.supervision.demo; import cn.hutool.core.collection.CollUtil; +import cn.hutool.core.date.DateTime; +import cn.hutool.core.date.DateUnit; +import cn.hutool.core.date.DateUtil; import cn.hutool.core.util.StrUtil; import cn.hutool.json.JSONUtil; +import cn.hutool.poi.excel.ExcelUtil; +import cn.hutool.poi.excel.ExcelWriter; import com.supervision.common.domain.R; import com.supervision.neo4j.controller.Neo4jController; import com.supervision.neo4j.domain.CaseNode; import com.supervision.demo.controller.ExampleChatController; import com.supervision.police.domain.ModelRecordType; +import com.supervision.police.domain.NotePrompt; import com.supervision.police.domain.NoteRecordSplit; import com.supervision.police.dto.RetrieveReqDTO; import com.supervision.police.dto.RetrieveResDTO; -import com.supervision.police.service.ModelRecordTypeService; -import com.supervision.police.service.NoteRecordSplitService; -import com.supervision.police.service.OCRService; -import com.supervision.police.service.RecordSplitClassifyService; +import com.supervision.police.service.*; import com.supervision.thread.RecordSplitClassifyTask; +import com.supervision.thread.TripleExtractTask; +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.junit.jupiter.api.Test; import org.springframework.ai.chat.ChatResponse; @@ -30,6 +37,11 @@ import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; import java.util.stream.Collectors; @@ -172,10 +184,115 @@ public class FuHsiApplicationTests { @Autowired private OCRService ocrService; @Test - public void test1() { + public void retrieveTest() { String question = "银川市公安局金凤区分局\\r\\n拘留通知书\\r\\n(副本)\\r\\n银金公(经侦)拘通字[2024]10017号\\r\\n梁玉峰家属\\r\\n根据《中华人民共和国刑事诉讼法》第八十二条第(一)项之规定\\r\\n罪\\r\\n将涉嫌合同诈骗\\r\\n我局已于2024年01月16日10时\\r\\n银川市看守所\\r\\n的\\r\\n梁玉峰\\r\\n刑事拘留,现羁押在\\r\\n市\\r\\n二0二\\r\\n年名\\r\\n本通知书已收到。\\r\\n年月\\r\\n日时\\r\\n被拘留人家属:\\r\\n如未在拘留后24小时内通知被拘留人家属,注明原因:办集民警\\r\\n当天拔打梁天峰亲属果无辉毛话,对方电生无法接通\\r\\n办案人:\\r\\n2024年1月16日1时\\r\\n此联附卷"; RetrieveResDTO retrieve = ocrService.retrieve(new RetrieveReqDTO(question)); System.out.println(retrieve); } + + @Autowired + private NotePromptService notePromptService; + + + @Test + public void promptTest() throws InterruptedException { + + List listPromptBySplitId = notePromptService.listPromptBySplitId("1825358898516275202"); + + ArrayList notePrompts = new ArrayList<>(); + notePrompts.addAll(listPromptBySplitId); + notePrompts.addAll(listPromptBySplitId); + notePrompts.addAll(listPromptBySplitId); + notePrompts.addAll(listPromptBySplitId); + notePrompts.addAll(listPromptBySplitId); + notePrompts.addAll(listPromptBySplitId); + notePrompts.addAll(listPromptBySplitId); + notePrompts.addAll(listPromptBySplitId); + notePrompts.addAll(listPromptBySplitId); + notePrompts.addAll(listPromptBySplitId); + notePrompts.addAll(listPromptBySplitId); + + NoteRecordSplit noteRecordSplit = new NoteRecordSplit(); + /*noteRecordSplit.setQuestion("办案警官问:你让马旭东帮你买车你们是如何协商的?"); + noteRecordSplit.setAnswer("裴金禄回答:我和马旭东认识十年时间了,我们关系比较好。我找到马旭东告诉他我有车要出售,这样要求把帮忙出售,我不出面。出售完毕之后我再给马旭东支付一些好处费用");*/ + noteRecordSplit.setQuestion("办案警官问:你和马旭东是如何协商好处费的?"); + noteRecordSplit.setAnswer("裴金禄回答:我没有协商过,我都是随心意给马旭东给的,"); + String mainActor = "裴金禄"; + + List tripleRecords = new ArrayList<>(); + AtomicLong atomicLong = new AtomicLong(); + CountDownLatch countDownLatch = new CountDownLatch(notePrompts.size()); + + ThreadPoolExecutor executor = new ThreadPoolExecutor(1, 10000, 100L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>()); + + DateTime start = DateUtil.date(); + for (NotePrompt notePrompt : notePrompts) { + + executor.submit(()->{ + + TripleRecord tripleRecord = chat4Triple(notePrompt, noteRecordSplit, mainActor,null); + atomicLong.getAndAdd(tripleRecord.getTime()); + countDownLatch.countDown(); + tripleRecords.add(tripleRecord); + }); + + + } + + + countDownLatch.await(); + long between = DateUtil.between(start, DateUtil.date(), DateUnit.SECOND); + log.info("实际总耗时:" + between); + ExcelWriter writer = ExcelUtil.getWriter("F:\\tmp\\1\\ollama-模型优化\\50服务器\\prompt.xlsx","速度对比-并发1-qwen2.5-32b"); + writer.write(tripleRecords); + + writer.close(); + log.info("所有线程总耗时:" + atomicLong.get()); + + } + + private TripleRecord chat4Triple(NotePrompt prompt, NoteRecordSplit noteRecordSplit,String mainActor,String type) { + + StopWatch stopWatch = new StopWatch(); + // 分析三元组 + stopWatch.start(); + HashMap paramMap = new HashMap<>(); + paramMap.put("headEntityType", prompt.getStartEntityType()); + paramMap.put("relation", prompt.getRelType()); + paramMap.put("tailEntityType", prompt.getEndEntityType()); + paramMap.put("question", noteRecordSplit.getQuestion()); + paramMap.put("answer", noteRecordSplit.getAnswer()); + //log.info("开始尝试提取三元组:{}-{}-{},mainActor:{}", prompt.getStartEntityType(), prompt.getRelType(), prompt.getEndEntityType(), mainActor == null ? "" : mainActor); + if (mainActor != null && "行为人".equals(prompt.getStartEntityType())) { + paramMap.put("requirement", "当前案件的行为人是" + mainActor + ",只尝试提取" + mainActor + "为头结点的三元组。"); + } else { + paramMap.put("requirement", ""); + } + String format = StrUtil.format(prompt.getPrompt(), paramMap); + + //log.info("提示词内容:{}", format); + log.info("开始执行:chatClient.call 开始"); + ChatResponse call = chatClient.call(new Prompt(new UserMessage(format))); + log.info("开始执行:chatClient.call 结束"); + stopWatch.stop(); + String content = call.getResult().getOutput().getContent(); + //log.info("问题:{}耗时:{},三元组提取结果是:{}", noteRecordSplit.getQuestion(), stopWatch.getTotalTimeSeconds(), content); + return new TripleRecord(noteRecordSplit.getQuestion(),noteRecordSplit.getAnswer(),type, + prompt.getId(),prompt.getPrompt(),format, + content, stopWatch.getTotalTimeMillis()); + } + + @Data + @AllArgsConstructor + class TripleRecord{ + private String question; + private String answer; + private String type; + private String templateId; + private String template; + private String prompt; + private String result; + private long time; + } }