You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
fu-hsi-service/src/main/java/com/supervision/thread/TripleExtractThread.java

89 lines
3.3 KiB
Java

10 months ago
package com.supervision.thread;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil;
import com.supervision.police.domain.TripleInfo;
import com.supervision.springaidemo.domain.ModelMetric;
import com.supervision.springaidemo.domain.NoteCheckRecord;
import com.supervision.springaidemo.dto.MetricResultDTO;
import com.supervision.police.service.NoteCheckRecordService;
import lombok.extern.slf4j.Slf4j;
import org.json.JSONArray;
import org.json.JSONObject;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.util.StopWatch;
import java.time.LocalDateTime;
import java.util.concurrent.Callable;
/**
 * Sends one question/answer fragment of a record to an Ollama chat model and turns the
 * model's reply into a single {@link TripleInfo}.
 *
 * <p>Implemented as a {@link Callable} so it can be submitted to an executor. {@link #call()}
 * does not propagate exceptions: any failure is logged and reported as {@code null}.
 */
@Slf4j
public class TripleExtractThread implements Callable<TripleInfo> {

    private final OllamaChatClient chatClient;
    private final String prompt;
    private final String question;
    private final String answer;
    private final String recordSplitId;
    private final String caseId;
    private final String recordId;

    /**
     * @param chatClient    client used to send the extraction prompt to the model
     * @param caseId        copied unchanged into the resulting {@link TripleInfo}
     * @param recordId      copied unchanged into the resulting {@link TripleInfo}
     * @param recordSplitId copied unchanged into the resulting {@link TripleInfo}
     * @param prompt        instruction text; {@code question} and {@code answer} are appended
     *                      directly after it with no separator
     * @param question      question text of the fragment
     * @param answer        answer text of the fragment
     */
    public TripleExtractThread(OllamaChatClient chatClient, String caseId, String recordId, String recordSplitId, String prompt, String question, String answer) {
        this.question = question;
        this.chatClient = chatClient;
        this.answer = answer;
        this.prompt = prompt;
        this.recordSplitId = recordSplitId;
        this.caseId = caseId;
        this.recordId = recordId;
    }

    /**
     * Sends {@code prompt + question + answer} to the model and parses the reply.
     *
     * <p>The reply must be a JSON object with a {@code "result"} array. A {@link TripleInfo} is
     * built from the first element that has non-empty {@code startNodeType}, {@code entity},
     * {@code endNodeType}, {@code property} and {@code value} strings. Elements that are not
     * objects, or that lack any of those fields, are skipped rather than aborting the scan.
     *
     * @return the first complete triple, or {@code null} if the reply contains none, or if the
     *     model call or top-level JSON parsing fails (the failure is logged)
     */
    @Override
    public TripleInfo call() {
        try {
            StopWatch stopWatch = new StopWatch();
            // Ask the model to extract triples from the fragment.
            Prompt ask = new Prompt(new UserMessage(prompt + question + answer));
            stopWatch.start();
            log.info("开始分析:");
            ChatResponse call = chatClient.call(ask);
            stopWatch.stop();
            log.info("耗时:{}", stopWatch.getTotalTimeSeconds());
            String content = call.getResult().getOutput().getContent();
            log.info("分析的结果是:{}", content);
            // Triples the model extracted from the prompt.
            JSONArray triples = new JSONObject(content).getJSONArray("result");
            for (int i = 0; i < triples.length(); i++) {
                // optJSONObject yields null for non-object elements; the helper skips those.
                TripleInfo tripleInfo = toTripleInfo(triples.optJSONObject(i));
                if (tripleInfo != null) {
                    return tripleInfo;
                }
            }
        } catch (Exception e) {
            log.error("提取三元组出现错误", e);
        }
        return null;
    }

    /**
     * Converts one element of the model's {@code "result"} array into a {@link TripleInfo}.
     *
     * <p>Uses {@code optString(key, "")} instead of {@code getString}, so a missing key or a
     * JSON {@code null} becomes an empty string. Such entries are then dropped by the
     * emptiness check instead of throwing and discarding the remaining entries.
     *
     * @param node one element of the {@code "result"} array, or {@code null} if it was not a
     *             JSON object
     * @return the triple, stamped with {@link LocalDateTime#now()}, or {@code null} if
     *     {@code node} is {@code null} or any required field is missing or empty
     */
    private TripleInfo toTripleInfo(JSONObject node) {
        if (node == null) {
            return null;
        }
        String startNodeType = node.optString("startNodeType", "");
        String entity = node.optString("entity", "");
        String endNodeType = node.optString("endNodeType", "");
        String property = node.optString("property", "");
        String value = node.optString("value", "");
        // Drop incomplete entries: if any field is empty, ignore this element.
        if (StrUtil.hasEmpty(startNodeType, entity, endNodeType, property, value)) {
            return null;
        }
        return new TripleInfo(entity, property, value, caseId, recordId, recordSplitId, LocalDateTime.now(), startNodeType, endNodeType);
    }
}