|
|
package com.supervision.thread;
|
|
|
|
|
|
import cn.hutool.core.util.ObjectUtil;
|
|
|
import cn.hutool.core.util.StrUtil;
|
|
|
import cn.hutool.json.JSONUtil;
|
|
|
import com.supervision.police.domain.NotePrompt;
|
|
|
import com.supervision.police.domain.TripleInfo;
|
|
|
import lombok.Data;
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
|
import org.springframework.ai.chat.ChatResponse;
|
|
|
import org.springframework.ai.chat.messages.UserMessage;
|
|
|
import org.springframework.ai.chat.prompt.Prompt;
|
|
|
import org.springframework.ai.ollama.OllamaChatClient;
|
|
|
import org.springframework.util.StopWatch;
|
|
|
|
|
|
import java.time.LocalDateTime;
|
|
|
import java.util.HashMap;
|
|
|
import java.util.List;
|
|
|
import java.util.concurrent.Callable;
|
|
|
|
|
|
@Slf4j
|
|
|
public class TripleExtractThread implements Callable<TripleInfo> {
|
|
|
|
|
|
private final OllamaChatClient chatClient;
|
|
|
|
|
|
private final NotePrompt prompt;
|
|
|
|
|
|
private final String question;
|
|
|
|
|
|
private final String answer;
|
|
|
|
|
|
private final String recordSplitId;
|
|
|
|
|
|
private final String caseId;
|
|
|
|
|
|
private final String recordId;
|
|
|
|
|
|
|
|
|
public TripleExtractThread(OllamaChatClient chatClient, String caseId, String recordId, String recordSplitId,
|
|
|
NotePrompt prompt, String question, String answer) {
|
|
|
this.question = question;
|
|
|
this.chatClient = chatClient;
|
|
|
this.answer = answer;
|
|
|
this.prompt = prompt;
|
|
|
this.recordSplitId = recordSplitId;
|
|
|
this.caseId = caseId;
|
|
|
this.recordId = recordId;
|
|
|
}
|
|
|
|
|
|
|
|
|
/**
|
|
|
* 三元组提取任务:从给定对话中根据给定实体类型和关系提取对应关系的三元组。
|
|
|
* 给定的头实体类型为"{headEntityType}";给定的尾实体类型为"{tailEntityType}",给定的关系为"{relation}"。
|
|
|
* 请仔细分析以下的文本内容,精准找出符合给定关系且头尾实体类型相符的三元组,并进行提取。如果没有识别给定的三元组关系,请返回json:{"result":[]}。
|
|
|
* ---
|
|
|
* 为您提供一个示例供学习:
|
|
|
* 给定三元组类型为:头实体类型:"行为人"关系:"伪造",尾实体类型:"合同"
|
|
|
* 办案警官问:描述一下事情的经过。 行为人小明答:我做了一份假的购房合同。
|
|
|
* 本示例中应提取给定关系为"伪造"的三元组,则最终应提取的三元组为{"result":[{"headEntity": {"type": "行为人","name":"小明"},"relation": "伪造","tailEntity": {"type": "合同","name": "假的购房合同"}}]}。
|
|
|
* ---
|
|
|
* 需要分析提取的QA对如下:
|
|
|
* {question}
|
|
|
* {answer}
|
|
|
* ---
|
|
|
* 在提取三元组时,请务必严格遵循以下要求:
|
|
|
* 1. 精准理解需要分析的QA文本的含义,确保提取的信息准确无误、合理恰当。
|
|
|
* 2. 只提取给定的实体类型和关系,不要提取给定关系和实体之外的三元组。
|
|
|
* 3. 尽量遵循常见的语义和逻辑规则,杜绝过度解读或不合理的关系推断。
|
|
|
* 4. 例子仅供参考,不要简单地返回示例中的结果。
|
|
|
* 5. 提取之后,再检查一遍,提取的关系和实体是否与给定关系和实体类型对应
|
|
|
* 返回格式为必须为以下的json格式:
|
|
|
* {"result":[{"headEntity": {"type": "{headEntityType}","name":"提取到的头实体内容1"},"relation": "{relation}","tailEntity": {"type": "{tailEntityType}","name": "提取到的尾实体内容1"}}]}
|
|
|
*/
|
|
|
@Override
|
|
|
public TripleInfo call() {
|
|
|
try {
|
|
|
StopWatch stopWatch = new StopWatch();
|
|
|
// 分析三元组
|
|
|
stopWatch.start();
|
|
|
HashMap<String, String> paramMap = new HashMap<>();
|
|
|
paramMap.put("headEntityType", prompt.getStartEntityType());
|
|
|
paramMap.put("relation", prompt.getRelType());
|
|
|
paramMap.put("tailEntityType", prompt.getEndEntityType());
|
|
|
paramMap.put("question", question);
|
|
|
paramMap.put("answer", answer);
|
|
|
Prompt ask = new Prompt(new UserMessage(StrUtil.format(prompt.getPrompt(), paramMap)));
|
|
|
ChatResponse call = chatClient.call(ask);
|
|
|
stopWatch.stop();
|
|
|
String content = call.getResult().getOutput().getContent();
|
|
|
log.info("耗时:{},分析的结果是:{}", stopWatch.getTotalTimeSeconds(), content);
|
|
|
// 获取从提示词中提取到的三元组信息
|
|
|
TripleExtractResult extractResult = JSONUtil.toBean(content, TripleExtractResult.class);
|
|
|
if (ObjectUtil.isEmpty(extractResult) || extractResult.result.isEmpty()) {
|
|
|
log.info("提取三元组信息为空,忽略");
|
|
|
return null;
|
|
|
}
|
|
|
for (TripleExtractNode tripleExtractNode : extractResult.getResult()) {
|
|
|
TripleEntity headEntity = tripleExtractNode.getHeadEntity();
|
|
|
TripleEntity tailEntity = tripleExtractNode.getTailEntity();
|
|
|
String relation = tripleExtractNode.getRelation();
|
|
|
if (StrUtil.hasEmpty(headEntity.getName(), relation, tailEntity.getName())) {
|
|
|
log.info("提取三元组信息出现空值,忽略,主体:{},关系:{},客体:{}", headEntity.getName(), relation, tailEntity.getName());
|
|
|
return null;
|
|
|
}
|
|
|
// 构建三元组信息
|
|
|
TripleInfo tripleInfo = new TripleInfo();
|
|
|
tripleInfo.setStartNode(headEntity.getName());
|
|
|
tripleInfo.setEndNode(tailEntity.getName());
|
|
|
tripleInfo.setRelation(relation);
|
|
|
tripleInfo.setCaseId(caseId);
|
|
|
tripleInfo.setRecordId(recordId);
|
|
|
tripleInfo.setRecordSplitId(recordSplitId);
|
|
|
tripleInfo.setStartNodeType(prompt.getStartEntityType());
|
|
|
tripleInfo.setEndNodeType(prompt.getEndEntityType());
|
|
|
return tripleInfo;
|
|
|
}
|
|
|
} catch (Exception e) {
|
|
|
log.error("提取三元组出现错误", e);
|
|
|
}
|
|
|
return null;
|
|
|
}
|
|
|
|
|
|
@Data
|
|
|
public static class TripleExtractResult {
|
|
|
private List<TripleExtractNode> result;
|
|
|
|
|
|
}
|
|
|
|
|
|
@Data
|
|
|
public static class TripleExtractNode {
|
|
|
private TripleEntity headEntity;
|
|
|
private String relation;
|
|
|
private TripleEntity tailEntity;
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
@Data
|
|
|
public static class TripleEntity {
|
|
|
private String name;
|
|
|
private String type;
|
|
|
}
|
|
|
|
|
|
|
|
|
}
|