fu-hsi-service/src/main/java/com/supervision/thread/TripleExtractThread.java

146 lines
6.2 KiB
Java

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package com.supervision.thread;
import cn.hutool.core.util.ObjectUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil;
import com.supervision.police.domain.NotePrompt;
import com.supervision.police.domain.TripleInfo;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.util.StopWatch;
import java.time.LocalDateTime;
import java.util.HashMap;
import java.util.List;
import java.util.concurrent.Callable;
/**
 * Callable task that asks an Ollama chat model to extract a knowledge-graph triple
 * (head entity, relation, tail entity) from a single question/answer pair of a record.
 *
 * <p>The prompt template, the head/tail entity types and the relation to look for all come from the
 * supplied {@link NotePrompt}. The task never propagates failures: whenever nothing usable is
 * extracted, or any exception occurs, it logs and returns {@code null}.
 */
@Slf4j
public class TripleExtractThread implements Callable<TripleInfo> {

    private final OllamaChatClient chatClient;
    private final NotePrompt prompt;
    private final String question;
    private final String answer;
    private final String recordSplitId;
    private final String caseId;
    private final String recordId;

    /**
     * Creates an extraction task for one question/answer pair.
     *
     * @param chatClient    client used to send the rendered prompt to the model
     * @param caseId        case id copied onto the produced triple
     * @param recordId      record id copied onto the produced triple
     * @param recordSplitId record-split id copied onto the produced triple
     * @param prompt        prompt template plus the head/tail entity types and relation to extract
     * @param question      text substituted into the {@code {question}} placeholder
     * @param answer        text substituted into the {@code {answer}} placeholder
     */
    public TripleExtractThread(OllamaChatClient chatClient, String caseId, String recordId, String recordSplitId,
                               NotePrompt prompt, String question, String answer) {
        this.question = question;
        this.chatClient = chatClient;
        this.answer = answer;
        this.prompt = prompt;
        this.recordSplitId = recordSplitId;
        this.caseId = caseId;
        this.recordId = recordId;
    }

    /**
     * Renders the prompt, queries the model and converts its JSON reply into a {@link TripleInfo}.
     *
     * <p>The template returned by {@code prompt.getPrompt()} is rendered with Hutool's
     * {@code StrUtil.format(template, map)} using the named placeholders {@code {headEntityType}},
     * {@code {relation}}, {@code {tailEntityType}}, {@code {question}} and {@code {answer}}.
     * A template of roughly the following shape is expected (translated from the original Chinese):
     * <pre>
     * Triple extraction task: from the given dialogue, extract triples matching the given entity
     * types and relation. Head entity type "{headEntityType}", tail entity type "{tailEntityType}",
     * relation "{relation}". If no such triple is found, return json {"result":[]}.
     * Example: head type "actor", relation "forge", tail type "contract";
     *   Officer: describe what happened.  Actor Xiaoming: I made a fake purchase contract.
     *   Expected: {"result":[{"headEntity": {"type": "actor","name":"Xiaoming"},"relation": "forge",
     *              "tailEntity": {"type": "contract","name": "fake purchase contract"}}]}
     * QA pair to analyse:
     * {question}
     * {answer}
     * Rules: be precise; extract only the given types/relation; no over-interpretation;
     * do not copy the example; re-check the result against the given types/relation.
     * The reply must be JSON of the form:
     * {"result":[{"headEntity": {"type": "{headEntityType}","name":"..."},"relation": "{relation}",
     *   "tailEntity": {"type": "{tailEntityType}","name": "..."}}]}
     * </pre>
     *
     * <p>Only one triple is returned per task. This is the first extracted node whose head name,
     * relation and tail name are all non-empty. Incomplete nodes are logged and skipped.
     *
     * @return the extracted triple, or {@code null} if the reply is empty or unparseable, contains
     *         no complete triple, or any error occurs (errors are logged)
     */
    @Override
    public TripleInfo call() {
        try {
            String content = askModel();
            TripleExtractResult extractResult = parseResult(content);
            // Guard the list itself: a reply such as {} or {"result":null} yields a bean whose
            // result field is null, which previously surfaced as an NPE.
            if (ObjectUtil.isEmpty(extractResult) || ObjectUtil.isEmpty(extractResult.getResult())) {
                log.info("提取三元组信息为空,忽略");
                return null;
            }
            for (TripleExtractNode node : extractResult.getResult()) {
                TripleInfo tripleInfo = toTripleInfo(node);
                if (tripleInfo != null) {
                    return tripleInfo;
                }
            }
        } catch (Exception e) {
            log.error("提取三元组出现错误", e);
        }
        return null;
    }

    /** Renders the prompt template, sends it to the model and returns the raw reply text (may be null). */
    private String askModel() {
        StopWatch stopWatch = new StopWatch();
        stopWatch.start();
        HashMap<String, String> paramMap = new HashMap<>();
        paramMap.put("headEntityType", prompt.getStartEntityType());
        paramMap.put("relation", prompt.getRelType());
        paramMap.put("tailEntityType", prompt.getEndEntityType());
        paramMap.put("question", question);
        paramMap.put("answer", answer);
        Prompt ask = new Prompt(new UserMessage(StrUtil.format(prompt.getPrompt(), paramMap)));
        ChatResponse response = chatClient.call(ask);
        stopWatch.stop();
        // getResult() is null when the model produced no generations.
        String content = (response == null || response.getResult() == null)
                ? null
                : response.getResult().getOutput().getContent();
        log.info("耗时:{},分析的结果是:{}", stopWatch.getTotalTimeSeconds(), content);
        return content;
    }

    /**
     * Parses the model reply into a {@link TripleExtractResult}; returns {@code null} for a blank reply.
     * A surrounding Markdown code fence (for example {@code ```json ... ```}) is removed first, because
     * chat models often wrap JSON that way and the JSON parser rejects text not starting with a brace.
     */
    private static TripleExtractResult parseResult(String content) {
        if (StrUtil.isBlank(content)) {
            return null;
        }
        String json = content.trim();
        if (json.startsWith("```")) {
            int firstNewline = json.indexOf('\n');
            int closingFence = json.lastIndexOf("```");
            if (firstNewline >= 0 && closingFence > firstNewline) {
                json = json.substring(firstNewline + 1, closingFence).trim();
            }
        }
        return JSONUtil.toBean(json, TripleExtractResult.class);
    }

    /**
     * Converts one extracted node into a {@link TripleInfo} carrying this task's case/record ids.
     * The node types are taken from the prompt configuration, not from the model output.
     * Returns {@code null} (and logs) when the node, its head/tail name, or its relation is missing.
     */
    private TripleInfo toTripleInfo(TripleExtractNode node) {
        TripleEntity headEntity = node == null ? null : node.getHeadEntity();
        TripleEntity tailEntity = node == null ? null : node.getTailEntity();
        String headName = headEntity == null ? null : headEntity.getName();
        String tailName = tailEntity == null ? null : tailEntity.getName();
        String relation = node == null ? null : node.getRelation();
        if (StrUtil.hasEmpty(headName, relation, tailName)) {
            log.info("提取三元组信息出现空值,忽略,主体:{},关系:{},客体:{}", headName, relation, tailName);
            return null;
        }
        TripleInfo tripleInfo = new TripleInfo();
        tripleInfo.setStartNode(headName);
        tripleInfo.setEndNode(tailName);
        tripleInfo.setRelation(relation);
        tripleInfo.setCaseId(caseId);
        tripleInfo.setRecordId(recordId);
        tripleInfo.setRecordSplitId(recordSplitId);
        tripleInfo.setStartNodeType(prompt.getStartEntityType());
        tripleInfo.setEndNodeType(prompt.getEndEntityType());
        return tripleInfo;
    }

    /** Top-level JSON reply: {@code {"result":[...]}}. */
    @Data
    public static class TripleExtractResult {
        private List<TripleExtractNode> result;
    }

    /** One extracted triple: head entity, relation name, tail entity. */
    @Data
    public static class TripleExtractNode {
        private TripleEntity headEntity;
        private String relation;
        private TripleEntity tailEntity;
    }

    /** An entity in an extracted triple: its surface text ({@code name}) and its entity type. */
    @Data
    public static class TripleEntity {
        private String name;
        private String type;
    }
}