diff --git a/src/test/java/com/supervision/demo/FuHsiApplicationTests.java b/src/test/java/com/supervision/demo/FuHsiApplicationTests.java index dc99ac1..40d02e6 100644 --- a/src/test/java/com/supervision/demo/FuHsiApplicationTests.java +++ b/src/test/java/com/supervision/demo/FuHsiApplicationTests.java @@ -1,15 +1,21 @@ package com.supervision.demo; +import cn.hutool.core.collection.CollUtil; +import cn.hutool.core.util.StrUtil; +import cn.hutool.json.JSONUtil; import com.supervision.common.domain.R; import com.supervision.neo4j.controller.Neo4jController; import com.supervision.neo4j.domain.CaseNode; import com.supervision.demo.controller.ExampleChatController; +import com.supervision.police.domain.ModelRecordType; +import com.supervision.police.domain.NoteRecordSplit; +import com.supervision.police.service.ModelRecordTypeService; +import com.supervision.police.service.NoteRecordSplitService; +import com.supervision.thread.RecordSplitTypeThread; import lombok.extern.slf4j.Slf4j; import org.junit.jupiter.api.Test; import org.springframework.ai.chat.ChatResponse; -import org.springframework.ai.chat.Generation; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.ollama.OllamaChatClient; import org.springframework.beans.factory.annotation.Autowired; @@ -17,7 +23,11 @@ import org.springframework.boot.test.context.SpringBootTest; import org.springframework.util.StopWatch; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + @Slf4j @SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT) @@ -56,19 +66,89 @@ public class FuHsiApplicationTests { this.chatClient = chatClient; } + @Autowired + private NoteRecordSplitService noteRecordSplitService; + @Autowired + private ModelRecordTypeService modelRecordTypeService; + + @Test - public void aaaaa() { - List messages = new ArrayList<>(List.of(new SystemMessage("你是谁"))); - Prompt prompt = new Prompt(messages); - StopWatch stopWatch = new StopWatch(); - stopWatch.start(); - log.info("开始分析:"); - ChatResponse call = chatClient.call(prompt); - stopWatch.stop(); - log.info("耗时:{}", stopWatch.getTotalTimeSeconds()); - Generation result = call.getResult(); - - String content = result.getOutput().getContent(); - log.info("分析的结果是:{}", content); + public void classificationTest() { + + String recordId = "1824320931821637633"; + + final String NEW_TEMPLATE = """ + 分类任务: 将对话笔录文本进行分类。 + 目标: 将给定的对话分配到预定义的类别中 + 预定义类别为:{typeContext}。 + 说明: 提供一段对话笔记录文本,分类器应该识别出对话的主题,并将其归类到上述类别中的一个或多个。 + --- + 示例输入: + 办案警官问:你和上述的这些受害人签订协议之后是否实际履行合同? + 行为人XXX回答:我和他们签订合同就是为了骗他们相信我,我就是伪造的一些假合同,等我把他们的钱骗到手之后我就不会履行合同。 + 预期输出: {"result":[{"type":"合同和协议","explain":"行为人XXX提到签订合同"},{"type":"虚假信息和伪造","explain":"行为人XXX提到合同是假合同"}]} + --- + 任务要求: + 1. 分类器应当准确地识别对话的主题,分类来自于预定义的类别。 + 2. 分类器应该实事求是按照对话进行分类,不要有过多的推测。 + 2. 如果一段对话笔记录包含多个主题,请选择最相关的类别,最多可选择三个分类。 + 3. 如果不涉及任何分类则回复{"result":[]} + --- + 以下为问答对内容: + {question} + {answer} + --- + 返回格式为json,字段名要严格一致:{"result":[{"type":"分类1","explain":"分类原因"},{"type":"分类2","explain":"分类原因"}]} + """; + + final String TYPE_CONTEXT_TEMPLATE = "{分类type:{type},区别点(分类释义):{typeExt}}"; + + List allTypeList = modelRecordTypeService.lambdaQuery().list(); + + String type = ""; + + // 根据recordId查询所有的分割后的笔录 + List noteRecordSplitList = noteRecordSplitService.lambdaQuery().eq(NoteRecordSplit::getNoteRecordId, recordId).list(); + for (NoteRecordSplit noteRecordSplit : noteRecordSplitList) { + try { + StopWatch stopWatch = new StopWatch(); + // 首先拼接分类模板 + List typeContextList = new ArrayList<>(); + for (ModelRecordType modelRecordType : allTypeList) { + String format = StrUtil.format(TYPE_CONTEXT_TEMPLATE, Map.of("type", modelRecordType.getRecordType(), "typeExt", modelRecordType.getRecordTypeExt())); + typeContextList.add(format); + } + // 开始对笔录进行分类 + Map paramMap = new HashMap<>(); + paramMap.put("typeContext", CollUtil.join(typeContextList, ";")); + paramMap.put("question", noteRecordSplit.getQuestion()); + paramMap.put("answer", noteRecordSplit.getAnswer()); + Prompt prompt = new Prompt(new UserMessage(StrUtil.format(NEW_TEMPLATE, paramMap))); + stopWatch.start(); + log.info("开始分析:"); + ChatResponse call = chatClient.call(prompt); + stopWatch.stop(); + log.info("耗时:{}", stopWatch.getTotalTimeSeconds()); + String content = call.getResult().getOutput().getContent(); + log.info("问:{}, 答:{}", noteRecordSplit.getQuestion(), noteRecordSplit.getAnswer()); + log.info("分析的结果是:{}", content); + RecordSplitTypeThread.TypeResultDTO result = JSONUtil.toBean(content, RecordSplitTypeThread.TypeResultDTO.class); + List typeList = result.getResult(); + if (CollUtil.isNotEmpty(typeList)) { + // 将type进行拼接,并以分号进行分割 + type = CollUtil.join(typeList.stream().map(RecordSplitTypeThread.TypeNodeDTO::getType).collect(Collectors.toSet()), ";"); + } else { + // 如果没有提取到,就是无 + type = "无"; + } + + } catch (Exception e) { + log.error("分类任务执行失败:{}", e.getMessage(), e); + type = "无"; + } + log.info("question:{},answer:{},分析的结果是:{}", noteRecordSplit.getQuestion(),noteRecordSplit.getAnswer(),type); + } + + } } diff --git a/src/test/java/com/supervision/demo/RecordSplitTest.java b/src/test/java/com/supervision/demo/RecordSplitTest.java new file mode 100644 index 0000000..747054b --- /dev/null +++ b/src/test/java/com/supervision/demo/RecordSplitTest.java @@ -0,0 +1,30 @@ +package com.supervision.demo; + +import com.supervision.demo.dto.QARecordNodeDTO; +import com.supervision.utils.RecordRegexUtil; +import com.supervision.utils.WordReadUtil; + +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.util.List; + +public class RecordSplitTest { + + public static void main(String[] args) throws FileNotFoundException { + + FileInputStream inputStream = new FileInputStream("F:\\supervision\\doc\\宁夏公安\\行为人第四次(2).docx"); + + String context = WordReadUtil.readWord(inputStream); + + System.out.println("context"); + System.out.println(context); + + System.out.println("end "); + List qaList = RecordRegexUtil.recordRegex(context, "裴金禄"); + + for (QARecordNodeDTO qaRecordNodeDTO : qaList) { + System.out.println(qaRecordNodeDTO.getQuestion()); + System.out.println(qaRecordNodeDTO.getAnswer()); + } + } +}