fu-hsi-service/src/test/java/com/supervision/demo/FuHsiApplicationTests.java

182 lines
8.5 KiB
Java

package com.supervision.demo;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil;
import com.supervision.common.domain.R;
import com.supervision.neo4j.controller.Neo4jController;
import com.supervision.neo4j.domain.CaseNode;
import com.supervision.demo.controller.ExampleChatController;
import com.supervision.police.domain.ModelRecordType;
import com.supervision.police.domain.NoteRecordSplit;
import com.supervision.police.dto.RetrieveReqDTO;
import com.supervision.police.dto.RetrieveResDTO;
import com.supervision.police.service.ModelRecordTypeService;
import com.supervision.police.service.NoteRecordSplitService;
import com.supervision.police.service.OCRService;
import com.supervision.police.service.RecordSplitClassifyService;
import com.supervision.thread.RecordSplitClassifyTask;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.util.StopWatch;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
@Slf4j
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
public class FuHsiApplicationTests {
@Autowired
private OllamaChatClient ollamaChatClient;
@Autowired
private ExampleChatController exampleChatController;
@Autowired
private Neo4jController neo4jController;
@Test
public void contextLoads() {
}
@Test
public void test() {
// exampleChatController.test("1803663875373694977", "你现在是一个笔录分析人员,请用四个字描述一下下述内容属于哪种类型的对话?");
}
@Test
public void savePersion() {
CaseNode caseNode = new CaseNode();
caseNode.setName("自然人");
R<?> save = neo4jController.save(caseNode);
System.out.printf(save.toString());
}
private final OllamaChatClient chatClient;
@Autowired
public FuHsiApplicationTests(OllamaChatClient chatClient) {
this.chatClient = chatClient;
}
@Autowired
private NoteRecordSplitService noteRecordSplitService;
@Autowired
private ModelRecordTypeService modelRecordTypeService;
@Test
public void classificationTest() {
String recordId = "1824329325387304962";
final String NEW_TEMPLATE = """
:
:
{typeContext}
:
---
:
XXX
: {"result":[{"type":"合同和协议","explain":"行为人XXX提到签订合同"},{"type":"虚假信息和伪造","explain":"行为人XXX提到合同是假合同"}]}
---
:
1. ,
2. ,
2. ,
3. {"result":[]}
---
:
{question}
{answer}
---
json,:{"result":[{"type":"分类1","explain":"分类原因"},{"type":"分类2","explain":"分类原因"}]}
""";
final String TYPE_CONTEXT_TEMPLATE = "{分类type:{type},区别点(分类释义):{typeExt}}";
List<ModelRecordType> allTypeList = modelRecordTypeService.lambdaQuery().list();
String type = "";
// 根据recordId查询所有的分割后的笔录
List<NoteRecordSplit> noteRecordSplitList = noteRecordSplitService.lambdaQuery().eq(NoteRecordSplit::getNoteRecordId, recordId).list();
for (NoteRecordSplit noteRecordSplit : noteRecordSplitList) {
try {
StopWatch stopWatch = new StopWatch();
// 首先拼接分类模板
List<String> typeContextList = new ArrayList<>();
for (ModelRecordType modelRecordType : allTypeList) {
String format = StrUtil.format(TYPE_CONTEXT_TEMPLATE, Map.of("type", modelRecordType.getRecordType(), "typeExt", modelRecordType.getRecordTypeExt()));
typeContextList.add(format);
}
// 开始对笔录进行分类
Map<String, String> paramMap = new HashMap<>();
paramMap.put("typeContext", CollUtil.join(typeContextList, ";"));
paramMap.put("question", noteRecordSplit.getQuestion());
paramMap.put("answer", noteRecordSplit.getAnswer());
Prompt prompt = new Prompt(new UserMessage(StrUtil.format(NEW_TEMPLATE, paramMap)));
stopWatch.start();
log.info("开始分析:");
ChatResponse call = chatClient.call(prompt);
stopWatch.stop();
log.info("耗时:{}", stopWatch.getTotalTimeSeconds());
String content = call.getResult().getOutput().getContent();
log.info("问:{}, 答:{}", noteRecordSplit.getQuestion(), noteRecordSplit.getAnswer());
log.info("分析的结果是:{}", content);
RecordSplitClassifyTask.TypeResultDTO result = JSONUtil.toBean(content, RecordSplitClassifyTask.TypeResultDTO.class);
List<RecordSplitClassifyTask.TypeNodeDTO> typeList = result.getResult();
if (CollUtil.isNotEmpty(typeList)) {
// 将type进行拼接,并以分号进行分割
type = CollUtil.join(typeList.stream().map(RecordSplitClassifyTask.TypeNodeDTO::getType).collect(Collectors.toSet()), ";");
} else {
// 如果没有提取到,就是无
type = "无";
}
} catch (Exception e) {
log.error("分类任务执行失败:{}", e.getMessage(), e);
type = "无";
}
log.info("最后结果 question:{},answer:{},分析的结果是:{}", noteRecordSplit.getQuestion(),noteRecordSplit.getAnswer(),type);
}
}
@Autowired
private RecordSplitClassifyService recordSplitClassifyService;
@Test
public void classificationTest2() {
List<ModelRecordType> typeList = modelRecordTypeService.lambdaQuery().list();
List<NoteRecordSplit> noteRecordSplits = noteRecordSplitService.lambdaQuery().eq(NoteRecordSplit::getId, "1824729361214418946").list();
recordSplitClassifyService.classify(typeList,noteRecordSplits);
log.info("分类结果:{}",noteRecordSplits);
}
@Autowired
private OCRService ocrService;
@Test
public void test1() {
String question = "银川市公安局金凤区分局\\r\\n拘留通知书\\r\\n副本\\r\\n银金公经侦拘通字[2024]10017号\\r\\n梁玉峰家属\\r\\n根据《中华人民共和国刑事诉讼法》第八十二条第一)项之规定\\r\\n罪\\r\\n将涉嫌合同诈骗\\r\\n我局已于2024年01月16日10时\\r\\n银川市看守所\\r\\n的\\r\\n梁玉峰\\r\\n刑事拘留现羁押在\\r\\n市\\r\\n二0二\\r\\n年名\\r\\n本通知书已收到。\\r\\n年月\\r\\n日时\\r\\n被拘留人家属\\r\\n如未在拘留后24小时内通知被拘留人家属注明原因办集民警\\r\\n当天拔打梁天峰亲属果无辉毛话对方电生无法接通\\r\\n办案人\\r\\n2024年1月16日1时\\r\\n此联附卷";
RetrieveResDTO retrieve = ocrService.retrieve(new RetrieveReqDTO(question));
System.out.println(retrieve);
}
}