You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
fu-hsi-service/src/main/java/com/supervision/thread/TripleExtractThread.java

146 lines
6.2 KiB
Java

10 months ago
package com.supervision.thread;
import cn.hutool.core.util.ObjectUtil;
10 months ago
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil;
import com.supervision.police.domain.NotePrompt;
10 months ago
import com.supervision.police.domain.TripleInfo;
import lombok.Data;
10 months ago
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.util.StopWatch;
import java.time.LocalDateTime;
import java.util.HashMap;
import java.util.List;
10 months ago
import java.util.concurrent.Callable;
@Slf4j
public class TripleExtractThread implements Callable<TripleInfo> {
private final OllamaChatClient chatClient;
private final NotePrompt prompt;
10 months ago
private final String question;
private final String answer;
private final String recordSplitId;
private final String caseId;
private final String recordId;
public TripleExtractThread(OllamaChatClient chatClient, String caseId, String recordId, String recordSplitId,
NotePrompt prompt, String question, String answer) {
10 months ago
this.question = question;
this.chatClient = chatClient;
this.answer = answer;
this.prompt = prompt;
this.recordSplitId = recordSplitId;
this.caseId = caseId;
this.recordId = recordId;
}
/**
* :
* "{headEntityType}";"{tailEntityType}","{relation}"
* json:{"result":[]}
* ---
* :
* ::"行为人":"伪造",:"合同"
* : :
* "伪造",{"result":[{"headEntity": {"type": "行为人","name":"小明"},"relation": "伪造","tailEntity": {"type": "合同","name": "假的购房合同"}}]}
* ---
* QA
* {question}
* {answer}
* ---
*
* 1. QA
* 2.
* 3.
* 4. ,
* 5. ,,
* json:
* {"result":[{"headEntity": {"type": "{headEntityType}","name":"提取到的头实体内容1"},"relation": "{relation}","tailEntity": {"type": "{tailEntityType}","name": "提取到的尾实体内容1"}}]}
*/
10 months ago
@Override
public TripleInfo call() {
try {
StopWatch stopWatch = new StopWatch();
// 分析三元组
stopWatch.start();
HashMap<String, String> paramMap = new HashMap<>();
paramMap.put("headEntityType", prompt.getStartEntityType());
paramMap.put("relation", prompt.getRelType());
paramMap.put("tailEntityType", prompt.getEndEntityType());
paramMap.put("question", question);
paramMap.put("answer", answer);
Prompt ask = new Prompt(new UserMessage(StrUtil.format(prompt.getPrompt(), paramMap)));
10 months ago
ChatResponse call = chatClient.call(ask);
stopWatch.stop();
String content = call.getResult().getOutput().getContent();
log.info("耗时:{},分析的结果是:{}", stopWatch.getTotalTimeSeconds(), content);
10 months ago
// 获取从提示词中提取到的三元组信息
TripleExtractResult extractResult = JSONUtil.toBean(content, TripleExtractResult.class);
if (ObjectUtil.isEmpty(extractResult) || extractResult.result.isEmpty()) {
log.info("提取三元组信息为空,忽略");
return null;
10 months ago
}
for (TripleExtractNode tripleExtractNode : extractResult.getResult()) {
TripleEntity headEntity = tripleExtractNode.getHeadEntity();
TripleEntity tailEntity = tripleExtractNode.getTailEntity();
String relation = tripleExtractNode.getRelation();
if (StrUtil.hasEmpty(headEntity.getName(), relation, tailEntity.getName())) {
log.info("提取三元组信息出现空值,忽略,主体:{},关系:{},客体:{}", headEntity.getName(), relation, tailEntity.getName());
return null;
}
// 构建三元组信息
TripleInfo tripleInfo = new TripleInfo();
tripleInfo.setStartNode(headEntity.getName());
tripleInfo.setEndNode(tailEntity.getName());
tripleInfo.setRelation(relation);
tripleInfo.setCaseId(caseId);
tripleInfo.setRecordId(recordId);
tripleInfo.setRecordSplitId(recordSplitId);
tripleInfo.setStartNodeType(prompt.getStartEntityType());
tripleInfo.setEndNodeType(prompt.getEndEntityType());
return tripleInfo;
}
10 months ago
} catch (Exception e) {
log.error("提取三元组出现错误", e);
}
return null;
}
@Data
public static class TripleExtractResult {
private List<TripleExtractNode> result;
}
@Data
public static class TripleExtractNode {
private TripleEntity headEntity;
private String relation;
private TripleEntity tailEntity;
10 months ago
}
@Data
public static class TripleEntity {
private String name;
private String type;
}
10 months ago
}