You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
fu-hsi-service/src/test/java/com/supervision/demo/FuHsiApplicationTests.java

169 lines
7.4 KiB
Java

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package com.supervision.demo;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil;
import com.supervision.common.domain.R;
import com.supervision.neo4j.controller.Neo4jController;
import com.supervision.neo4j.domain.CaseNode;
import com.supervision.demo.controller.ExampleChatController;
import com.supervision.police.domain.ModelRecordType;
import com.supervision.police.domain.NoteRecordSplit;
import com.supervision.police.service.ModelRecordTypeService;
import com.supervision.police.service.NoteRecordSplitService;
import com.supervision.police.service.RecordSplitClassifyService;
import com.supervision.thread.RecordSplitClassifyTask;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.util.StopWatch;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
/**
 * Integration tests for the FuHsi service: Spring context smoke test, a Neo4j
 * node save, and two transcript-classification experiments that prompt an
 * Ollama chat model and parse its JSON reply.
 *
 * <p>Requires a running Spring context (random web port) with Neo4j, the
 * Ollama client, and the police-record services wired up.
 */
@Slf4j
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
public class FuHsiApplicationTests {

    // Single chat client, constructor-injected. (The original class also
    // field-injected a second, never-used OllamaChatClient; that duplicate
    // injection was removed.)
    private final OllamaChatClient chatClient;

    @Autowired
    private ExampleChatController exampleChatController;

    @Autowired
    private Neo4jController neo4jController;

    @Autowired
    private NoteRecordSplitService noteRecordSplitService;

    @Autowired
    private ModelRecordTypeService modelRecordTypeService;

    @Autowired
    private RecordSplitClassifyService recordSplitClassifyService;

    @Autowired
    public FuHsiApplicationTests(OllamaChatClient chatClient) {
        this.chatClient = chatClient;
    }

    /** Smoke test: passes when the Spring application context starts cleanly. */
    @Test
    public void contextLoads() {
    }

    /** Placeholder from an earlier manual experiment; intentionally does nothing. */
    @Test
    public void test() {
        // exampleChatController.test("1803663875373694977", "你现在是一个笔录分析人员,请用四个字描述一下下述内容属于哪种类型的对话?");
    }

    /**
     * Saves a single "natural person" node through the Neo4j controller and
     * prints the wrapped response.
     */
    @Test
    public void savePersion() {
        CaseNode caseNode = new CaseNode();
        caseNode.setName("自然人");
        R<?> save = neo4jController.save(caseNode);
        // BUGFIX: was System.out.printf(save.toString()) — printf interprets the
        // argument as a format string, so any '%' in the response text would
        // throw IllegalFormatException. println is what was intended.
        System.out.println(save);
    }

    /**
     * Classifies every split Q/A pair of one transcript (looked up by record id)
     * by prompting the chat model with a category-annotated template and parsing
     * the JSON result. Logs the outcome per pair; "无" (none) on empty result or
     * any failure.
     */
    @Test
    public void classificationTest() {
        String recordId = "1824329325387304962";
        // Prompt template. NOTE: task-requirement numbering fixed (was 1,2,2,3).
        final String NEW_TEMPLATE = """
                分类任务: 将对话笔录文本进行分类。
                目标: 将给定的对话分配到预定义的类别中
                预定义类别为:{typeContext}。
                说明: 提供一段对话笔记录文本,分类器应该识别出对话的主题,并将其归类到上述类别中的一个或多个。
                ---
                示例输入:
                办案警官问:你和上述的这些受害人签订协议之后是否实际履行合同?
                行为人XXX回答我和他们签订合同就是为了骗他们相信我我就是伪造的一些假合同等我把他们的钱骗到手之后我就不会履行合同。
                预期输出: {"result":[{"type":"","explain":"XXX"},{"type":"","explain":"XXX"}]}
                ---
                任务要求:
                1. 分类器应当准确地识别对话的主题,分类来自于预定义的类别。
                2. 分类器应该实事求是按照对话进行分类,不要有过多的推测。
                3. 如果一段对话笔记录包含多个主题,请选择最相关的类别,最多可选择三个分类。
                4. 如果不涉及任何分类则回复{"result":[]}
                ---
                以下为问答对内容:
                {question}
                {answer}
                ---
                返回格式为json,字段名要严格一致:{"result":[{"type":"1","explain":""},{"type":"2","explain":""}]}
                """;
        final String TYPE_CONTEXT_TEMPLATE = "{分类type:{type},区别点(分类释义):{typeExt}}";
        List<ModelRecordType> allTypeList = modelRecordTypeService.lambdaQuery().list();

        // Build the category context ONCE: it depends only on allTypeList, not on
        // the record being classified (the original rebuilt it on every loop pass).
        List<String> typeContextList = new ArrayList<>();
        for (ModelRecordType modelRecordType : allTypeList) {
            typeContextList.add(StrUtil.format(TYPE_CONTEXT_TEMPLATE,
                    Map.of("type", modelRecordType.getRecordType(), "typeExt", modelRecordType.getRecordTypeExt())));
        }
        String typeContext = CollUtil.join(typeContextList, ";");

        // Fetch all split Q/A pairs belonging to this transcript.
        List<NoteRecordSplit> noteRecordSplitList = noteRecordSplitService.lambdaQuery()
                .eq(NoteRecordSplit::getNoteRecordId, recordId).list();
        for (NoteRecordSplit noteRecordSplit : noteRecordSplitList) {
            // Default "无" (none): kept when nothing is extracted or the call fails.
            String type = "无";
            try {
                StopWatch stopWatch = new StopWatch();
                // Fill the template with the categories and this Q/A pair.
                Map<String, String> paramMap = new HashMap<>();
                paramMap.put("typeContext", typeContext);
                paramMap.put("question", noteRecordSplit.getQuestion());
                paramMap.put("answer", noteRecordSplit.getAnswer());
                Prompt prompt = new Prompt(new UserMessage(StrUtil.format(NEW_TEMPLATE, paramMap)));
                stopWatch.start();
                log.info("开始分析:");
                ChatResponse call = chatClient.call(prompt);
                stopWatch.stop();
                log.info("耗时:{}", stopWatch.getTotalTimeSeconds());
                String content = call.getResult().getOutput().getContent();
                log.info("问:{}, 答:{}", noteRecordSplit.getQuestion(), noteRecordSplit.getAnswer());
                log.info("分析的结果是:{}", content);
                RecordSplitClassifyTask.TypeResultDTO result =
                        JSONUtil.toBean(content, RecordSplitClassifyTask.TypeResultDTO.class);
                List<RecordSplitClassifyTask.TypeNodeDTO> typeList = result.getResult();
                if (CollUtil.isNotEmpty(typeList)) {
                    // Join the distinct extracted types with ';'.
                    type = CollUtil.join(typeList.stream()
                            .map(RecordSplitClassifyTask.TypeNodeDTO::getType)
                            .collect(Collectors.toSet()), ";");
                }
            } catch (Exception e) {
                log.error("分类任务执行失败:{}", e.getMessage(), e);
                type = "无";
            }
            log.info("最后结果 question:{},answer:{},分析的结果是:{}",
                    noteRecordSplit.getQuestion(), noteRecordSplit.getAnswer(), type);
        }
    }

    /**
     * Delegates classification of one split record (looked up by its own id) to
     * RecordSplitClassifyService and logs the mutated records.
     */
    @Test
    public void classificationTest2() {
        List<ModelRecordType> typeList = modelRecordTypeService.lambdaQuery().list();
        List<NoteRecordSplit> noteRecordSplits = noteRecordSplitService.lambdaQuery()
                .eq(NoteRecordSplit::getId, "1824729361214418946").list();
        recordSplitClassifyService.classify(typeList, noteRecordSplits);
        log.info("分类结果:{}", noteRecordSplits);
    }
}