1. 添加测试用例

topo_dev v1.0.0
xueqingkun 7 months ago
parent 020b335b8d
commit e772af4fb9

@ -1,7 +1,11 @@
package com.supervision.demo.controller;
import cn.hutool.core.date.DateTime;
import cn.hutool.core.date.DateUnit;
import cn.hutool.core.date.DateUtil;
import cn.hutool.core.io.FileUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil;
import com.supervision.demo.service.ModelMetricService;
import com.supervision.police.domain.ModelRecordType;
import com.supervision.police.domain.NoteRecordSplit;
@ -257,7 +261,7 @@ public class ExampleChatController {
}
@PostMapping("chat")
@PostMapping("/chat")
public String chat(@RequestBody Map<String,String> messages) {
Prompt prompt = null;
@ -267,15 +271,17 @@ public class ExampleChatController {
prompt = new Prompt(List.of(new AssistantMessage(messages.get("assistantMessage")), new UserMessage(messages.get("userMessage"))));
}
DateTime start = DateUtil.date();
ChatResponse call = chatClient.call(prompt);
Generation result = call.getResult();
long between = DateUtil.between(DateUtil.date(), start, DateUnit.MS);
String content = result.getOutput().getContent();
log.info("分析的结果是:{}", content);
if (StrUtil.isBlank(content)){
content = "{}";
}
return content;
return JSONUtil.toJsonStr(Map.of("result", content, "time", between+"毫秒"));
}
}

@ -1,21 +1,28 @@
package com.supervision.demo;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.date.DateTime;
import cn.hutool.core.date.DateUnit;
import cn.hutool.core.date.DateUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil;
import cn.hutool.poi.excel.ExcelUtil;
import cn.hutool.poi.excel.ExcelWriter;
import com.supervision.common.domain.R;
import com.supervision.neo4j.controller.Neo4jController;
import com.supervision.neo4j.domain.CaseNode;
import com.supervision.demo.controller.ExampleChatController;
import com.supervision.police.domain.ModelRecordType;
import com.supervision.police.domain.NotePrompt;
import com.supervision.police.domain.NoteRecordSplit;
import com.supervision.police.dto.RetrieveReqDTO;
import com.supervision.police.dto.RetrieveResDTO;
import com.supervision.police.service.ModelRecordTypeService;
import com.supervision.police.service.NoteRecordSplitService;
import com.supervision.police.service.OCRService;
import com.supervision.police.service.RecordSplitClassifyService;
import com.supervision.police.service.*;
import com.supervision.thread.RecordSplitClassifyTask;
import com.supervision.thread.TripleExtractTask;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test;
import org.springframework.ai.chat.ChatResponse;
@ -30,6 +37,11 @@ import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Collectors;
@ -172,10 +184,115 @@ public class FuHsiApplicationTests {
@Autowired
private OCRService ocrService;
@Test
public void test1() {
public void retrieveTest() {
String question = "银川市公安局金凤区分局\\r\\n拘留通知书\\r\\n副本\\r\\n银金公经侦拘通字[2024]10017号\\r\\n梁玉峰家属\\r\\n根据《中华人民共和国刑事诉讼法》第八十二条第一)项之规定\\r\\n罪\\r\\n将涉嫌合同诈骗\\r\\n我局已于2024年01月16日10时\\r\\n银川市看守所\\r\\n的\\r\\n梁玉峰\\r\\n刑事拘留现羁押在\\r\\n市\\r\\n二0二\\r\\n年名\\r\\n本通知书已收到。\\r\\n年月\\r\\n日时\\r\\n被拘留人家属\\r\\n如未在拘留后24小时内通知被拘留人家属注明原因办集民警\\r\\n当天拔打梁天峰亲属果无辉毛话对方电生无法接通\\r\\n办案人\\r\\n2024年1月16日1时\\r\\n此联附卷";
RetrieveResDTO retrieve = ocrService.retrieve(new RetrieveReqDTO(question));
System.out.println(retrieve);
}
@Autowired
private NotePromptService notePromptService;
@Test
public void promptTest() throws InterruptedException {
List<NotePrompt> listPromptBySplitId = notePromptService.listPromptBySplitId("1825358898516275202");
ArrayList<NotePrompt> notePrompts = new ArrayList<>();
notePrompts.addAll(listPromptBySplitId);
notePrompts.addAll(listPromptBySplitId);
notePrompts.addAll(listPromptBySplitId);
notePrompts.addAll(listPromptBySplitId);
notePrompts.addAll(listPromptBySplitId);
notePrompts.addAll(listPromptBySplitId);
notePrompts.addAll(listPromptBySplitId);
notePrompts.addAll(listPromptBySplitId);
notePrompts.addAll(listPromptBySplitId);
notePrompts.addAll(listPromptBySplitId);
notePrompts.addAll(listPromptBySplitId);
NoteRecordSplit noteRecordSplit = new NoteRecordSplit();
/*noteRecordSplit.setQuestion("办案警官问:你让马旭东帮你买车你们是如何协商的?");
noteRecordSplit.setAnswer("裴金禄回答:我和马旭东认识十年时间了,我们关系比较好。我找到马旭东告诉他我有车要出售,这样要求把帮忙出售,我不出面。出售完毕之后我再给马旭东支付一些好处费用");*/
noteRecordSplit.setQuestion("办案警官问:你和马旭东是如何协商好处费的?");
noteRecordSplit.setAnswer("裴金禄回答:我没有协商过,我都是随心意给马旭东给的,");
String mainActor = "裴金禄";
List<TripleRecord> tripleRecords = new ArrayList<>();
AtomicLong atomicLong = new AtomicLong();
CountDownLatch countDownLatch = new CountDownLatch(notePrompts.size());
ThreadPoolExecutor executor = new ThreadPoolExecutor(1, 10000, 100L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>());
DateTime start = DateUtil.date();
for (NotePrompt notePrompt : notePrompts) {
executor.submit(()->{
TripleRecord tripleRecord = chat4Triple(notePrompt, noteRecordSplit, mainActor,null);
atomicLong.getAndAdd(tripleRecord.getTime());
countDownLatch.countDown();
tripleRecords.add(tripleRecord);
});
}
countDownLatch.await();
long between = DateUtil.between(start, DateUtil.date(), DateUnit.SECOND);
log.info("实际总耗时:" + between);
ExcelWriter writer = ExcelUtil.getWriter("F:\\tmp\\1\\ollama-模型优化\\50服务器\\prompt.xlsx","速度对比-并发1-qwen2.5-32b");
writer.write(tripleRecords);
writer.close();
log.info("所有线程总耗时:" + atomicLong.get());
}
private TripleRecord chat4Triple(NotePrompt prompt, NoteRecordSplit noteRecordSplit,String mainActor,String type) {
StopWatch stopWatch = new StopWatch();
// 分析三元组
stopWatch.start();
HashMap<String, String> paramMap = new HashMap<>();
paramMap.put("headEntityType", prompt.getStartEntityType());
paramMap.put("relation", prompt.getRelType());
paramMap.put("tailEntityType", prompt.getEndEntityType());
paramMap.put("question", noteRecordSplit.getQuestion());
paramMap.put("answer", noteRecordSplit.getAnswer());
//log.info("开始尝试提取三元组:{}-{}-{},mainActor:{}", prompt.getStartEntityType(), prompt.getRelType(), prompt.getEndEntityType(), mainActor == null ? "" : mainActor);
if (mainActor != null && "行为人".equals(prompt.getStartEntityType())) {
paramMap.put("requirement", "当前案件的行为人是" + mainActor + ",只尝试提取" + mainActor + "为头结点的三元组。");
} else {
paramMap.put("requirement", "");
}
String format = StrUtil.format(prompt.getPrompt(), paramMap);
//log.info("提示词内容:{}", format);
log.info("开始执行:chatClient.call 开始");
ChatResponse call = chatClient.call(new Prompt(new UserMessage(format)));
log.info("开始执行:chatClient.call 结束");
stopWatch.stop();
String content = call.getResult().getOutput().getContent();
//log.info("问题:{}耗时:{},三元组提取结果是:{}", noteRecordSplit.getQuestion(), stopWatch.getTotalTimeSeconds(), content);
return new TripleRecord(noteRecordSplit.getQuestion(),noteRecordSplit.getAnswer(),type,
prompt.getId(),prompt.getPrompt(),format,
content, stopWatch.getTotalTimeMillis());
}
@Data
@AllArgsConstructor
class TripleRecord{
private String question;
private String answer;
private String type;
private String templateId;
private String template;
private String prompt;
private String result;
private long time;
}
}

Loading…
Cancel
Save