1. 添加测试用例

topo_dev
xueqingkun 10 months ago
parent d7e6d72153
commit 5011fa7e4d

@ -1,15 +1,21 @@
package com.supervision.demo;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil;
import com.supervision.common.domain.R;
import com.supervision.neo4j.controller.Neo4jController;
import com.supervision.neo4j.domain.CaseNode;
import com.supervision.demo.controller.ExampleChatController;
import com.supervision.police.domain.ModelRecordType;
import com.supervision.police.domain.NoteRecordSplit;
import com.supervision.police.service.ModelRecordTypeService;
import com.supervision.police.service.NoteRecordSplitService;
import com.supervision.thread.RecordSplitTypeThread;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.beans.factory.annotation.Autowired;
@ -17,7 +23,11 @@ import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.util.StopWatch;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
@Slf4j
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
@ -56,19 +66,89 @@ public class FuHsiApplicationTests {
this.chatClient = chatClient;
}
@Autowired
private NoteRecordSplitService noteRecordSplitService;
@Autowired
private ModelRecordTypeService modelRecordTypeService;
@Test
public void aaaaa() {
List<Message> messages = new ArrayList<>(List.of(new SystemMessage("你是谁")));
Prompt prompt = new Prompt(messages);
StopWatch stopWatch = new StopWatch();
stopWatch.start();
log.info("开始分析:");
ChatResponse call = chatClient.call(prompt);
stopWatch.stop();
log.info("耗时:{}", stopWatch.getTotalTimeSeconds());
Generation result = call.getResult();
String content = result.getOutput().getContent();
log.info("分析的结果是:{}", content);
public void classificationTest() {
String recordId = "1824320931821637633";
final String NEW_TEMPLATE = """
:
:
{typeContext}
:
---
:
XXX
: {"result":[{"type":"合同和协议","explain":"行为人XXX提到签订合同"},{"type":"虚假信息和伪造","explain":"行为人XXX提到合同是假合同"}]}
---
:
1. ,
2. ,
2. ,
3. {"result":[]}
---
:
{question}
{answer}
---
json,:{"result":[{"type":"分类1","explain":"分类原因"},{"type":"分类2","explain":"分类原因"}]}
""";
final String TYPE_CONTEXT_TEMPLATE = "{分类type:{type},区别点(分类释义):{typeExt}}";
List<ModelRecordType> allTypeList = modelRecordTypeService.lambdaQuery().list();
String type = "";
// 根据recordId查询所有的分割后的笔录
List<NoteRecordSplit> noteRecordSplitList = noteRecordSplitService.lambdaQuery().eq(NoteRecordSplit::getNoteRecordId, recordId).list();
for (NoteRecordSplit noteRecordSplit : noteRecordSplitList) {
try {
StopWatch stopWatch = new StopWatch();
// 首先拼接分类模板
List<String> typeContextList = new ArrayList<>();
for (ModelRecordType modelRecordType : allTypeList) {
String format = StrUtil.format(TYPE_CONTEXT_TEMPLATE, Map.of("type", modelRecordType.getRecordType(), "typeExt", modelRecordType.getRecordTypeExt()));
typeContextList.add(format);
}
// 开始对笔录进行分类
Map<String, String> paramMap = new HashMap<>();
paramMap.put("typeContext", CollUtil.join(typeContextList, ";"));
paramMap.put("question", noteRecordSplit.getQuestion());
paramMap.put("answer", noteRecordSplit.getAnswer());
Prompt prompt = new Prompt(new UserMessage(StrUtil.format(NEW_TEMPLATE, paramMap)));
stopWatch.start();
log.info("开始分析:");
ChatResponse call = chatClient.call(prompt);
stopWatch.stop();
log.info("耗时:{}", stopWatch.getTotalTimeSeconds());
String content = call.getResult().getOutput().getContent();
log.info("问:{}, 答:{}", noteRecordSplit.getQuestion(), noteRecordSplit.getAnswer());
log.info("分析的结果是:{}", content);
RecordSplitTypeThread.TypeResultDTO result = JSONUtil.toBean(content, RecordSplitTypeThread.TypeResultDTO.class);
List<RecordSplitTypeThread.TypeNodeDTO> typeList = result.getResult();
if (CollUtil.isNotEmpty(typeList)) {
// 将type进行拼接,并以分号进行分割
type = CollUtil.join(typeList.stream().map(RecordSplitTypeThread.TypeNodeDTO::getType).collect(Collectors.toSet()), ";");
} else {
// 如果没有提取到,就是无
type = "无";
}
} catch (Exception e) {
log.error("分类任务执行失败:{}", e.getMessage(), e);
type = "无";
}
log.info("question:{},answer:{},分析的结果是:{}", noteRecordSplit.getQuestion(),noteRecordSplit.getAnswer(),type);
}
}
}

@ -0,0 +1,30 @@
package com.supervision.demo;
import com.supervision.demo.dto.QARecordNodeDTO;
import com.supervision.utils.RecordRegexUtil;
import com.supervision.utils.WordReadUtil;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.util.List;
public class RecordSplitTest {
public static void main(String[] args) throws FileNotFoundException {
FileInputStream inputStream = new FileInputStream("F:\\supervision\\doc\\宁夏公安\\行为人第四次(2).docx");
String context = WordReadUtil.readWord(inputStream);
System.out.println("context");
System.out.println(context);
System.out.println("end ");
List<QARecordNodeDTO> qaList = RecordRegexUtil.recordRegex(context, "裴金禄");
for (QARecordNodeDTO qaRecordNodeDTO : qaList) {
System.out.println(qaRecordNodeDTO.getQuestion());
System.out.println(qaRecordNodeDTO.getAnswer());
}
}
}
Loading…
Cancel
Save