fu-hsi-service/src/test/java/com/supervision/demo/FuHsiApplicationTests.java

155 lines
6.8 KiB
Java

package com.supervision.demo;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil;
import com.supervision.common.domain.R;
import com.supervision.neo4j.controller.Neo4jController;
import com.supervision.neo4j.domain.CaseNode;
import com.supervision.demo.controller.ExampleChatController;
import com.supervision.police.domain.ModelRecordType;
import com.supervision.police.domain.NoteRecordSplit;
import com.supervision.police.service.ModelRecordTypeService;
import com.supervision.police.service.NoteRecordSplitService;
import com.supervision.thread.RecordSplitTypeThread;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.util.StopWatch;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
@Slf4j
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
public class FuHsiApplicationTests {
@Autowired
private OllamaChatClient ollamaChatClient;
@Autowired
private ExampleChatController exampleChatController;
@Autowired
private Neo4jController neo4jController;
@Test
public void contextLoads() {
}
@Test
public void test() {
// exampleChatController.test("1803663875373694977", "你现在是一个笔录分析人员,请用四个字描述一下下述内容属于哪种类型的对话?");
}
@Test
public void savePersion() {
CaseNode caseNode = new CaseNode();
caseNode.setName("自然人");
R<?> save = neo4jController.save(caseNode);
System.out.printf(save.toString());
}
private final OllamaChatClient chatClient;
@Autowired
public FuHsiApplicationTests(OllamaChatClient chatClient) {
this.chatClient = chatClient;
}
@Autowired
private NoteRecordSplitService noteRecordSplitService;
@Autowired
private ModelRecordTypeService modelRecordTypeService;
@Test
public void classificationTest() {
String recordId = "1824320931821637633";
final String NEW_TEMPLATE = """
:
:
{typeContext}
:
---
:
XXX
: {"result":[{"type":"合同和协议","explain":"行为人XXX提到签订合同"},{"type":"虚假信息和伪造","explain":"行为人XXX提到合同是假合同"}]}
---
:
1. ,
2. ,
2. ,
3. {"result":[]}
---
:
{question}
{answer}
---
json,:{"result":[{"type":"分类1","explain":"分类原因"},{"type":"分类2","explain":"分类原因"}]}
""";
final String TYPE_CONTEXT_TEMPLATE = "{分类type:{type},区别点(分类释义):{typeExt}}";
List<ModelRecordType> allTypeList = modelRecordTypeService.lambdaQuery().list();
String type = "";
// 根据recordId查询所有的分割后的笔录
List<NoteRecordSplit> noteRecordSplitList = noteRecordSplitService.lambdaQuery().eq(NoteRecordSplit::getNoteRecordId, recordId).list();
for (NoteRecordSplit noteRecordSplit : noteRecordSplitList) {
try {
StopWatch stopWatch = new StopWatch();
// 首先拼接分类模板
List<String> typeContextList = new ArrayList<>();
for (ModelRecordType modelRecordType : allTypeList) {
String format = StrUtil.format(TYPE_CONTEXT_TEMPLATE, Map.of("type", modelRecordType.getRecordType(), "typeExt", modelRecordType.getRecordTypeExt()));
typeContextList.add(format);
}
// 开始对笔录进行分类
Map<String, String> paramMap = new HashMap<>();
paramMap.put("typeContext", CollUtil.join(typeContextList, ";"));
paramMap.put("question", noteRecordSplit.getQuestion());
paramMap.put("answer", noteRecordSplit.getAnswer());
Prompt prompt = new Prompt(new UserMessage(StrUtil.format(NEW_TEMPLATE, paramMap)));
stopWatch.start();
log.info("开始分析:");
ChatResponse call = chatClient.call(prompt);
stopWatch.stop();
log.info("耗时:{}", stopWatch.getTotalTimeSeconds());
String content = call.getResult().getOutput().getContent();
log.info("问:{}, 答:{}", noteRecordSplit.getQuestion(), noteRecordSplit.getAnswer());
log.info("分析的结果是:{}", content);
RecordSplitTypeThread.TypeResultDTO result = JSONUtil.toBean(content, RecordSplitTypeThread.TypeResultDTO.class);
List<RecordSplitTypeThread.TypeNodeDTO> typeList = result.getResult();
if (CollUtil.isNotEmpty(typeList)) {
// 将type进行拼接,并以分号进行分割
type = CollUtil.join(typeList.stream().map(RecordSplitTypeThread.TypeNodeDTO::getType).collect(Collectors.toSet()), ";");
} else {
// 如果没有提取到,就是无
type = "无";
}
} catch (Exception e) {
log.error("分类任务执行失败:{}", e.getMessage(), e);
type = "无";
}
log.info("question:{},answer:{},分析的结果是:{}", noteRecordSplit.getQuestion(),noteRecordSplit.getAnswer(),type);
}
}
}