You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
fu-hsi-service/src/main/java/com/supervision/thread/TripleExtractThread.java

102 lines
3.9 KiB
Java

10 months ago
package com.supervision.thread;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil;
import com.supervision.police.domain.NotePrompt;
10 months ago
import com.supervision.police.domain.TripleInfo;
import com.supervision.springaidemo.domain.ModelMetric;
import com.supervision.springaidemo.domain.NoteCheckRecord;
import com.supervision.springaidemo.dto.MetricResultDTO;
import com.supervision.police.service.NoteCheckRecordService;
import lombok.extern.slf4j.Slf4j;
import org.json.JSONArray;
import org.json.JSONObject;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.util.StopWatch;
import java.time.LocalDateTime;
import java.util.HashMap;
10 months ago
import java.util.concurrent.Callable;
@Slf4j
public class TripleExtractThread implements Callable<TripleInfo> {
private final OllamaChatClient chatClient;
private final NotePrompt prompt;
10 months ago
private final String question;
private final String answer;
private final String recordSplitId;
private final String caseId;
private final String recordId;
public TripleExtractThread(OllamaChatClient chatClient, String caseId, String recordId, String recordSplitId,
NotePrompt prompt, String question, String answer) {
10 months ago
this.question = question;
this.chatClient = chatClient;
this.answer = answer;
this.prompt = prompt;
this.recordSplitId = recordSplitId;
this.caseId = caseId;
this.recordId = recordId;
}
@Override
public TripleInfo call() {
try {
StopWatch stopWatch = new StopWatch();
// 分析三元组
stopWatch.start();
HashMap<String, String> paramMap = new HashMap<>();
paramMap.put("qaRecord", question + answer);
Prompt ask = new Prompt(new UserMessage(StrUtil.format(prompt.getPrompt(), paramMap)));
10 months ago
log.info("开始分析:");
ChatResponse call = chatClient.call(ask);
stopWatch.stop();
log.info("耗时:{}", stopWatch.getTotalTimeSeconds());
String content = call.getResult().getOutput().getContent();
log.info("分析的结果是:{}", content);
// 获取从提示词中提取到的三元组信息
JSONObject jsonObject = new JSONObject(content);
// 修改,经测试,一次提取多个三元组效果较差,改成一次只提取一个三元组
//JSONArray threeInfo = jsonObject.getJSONArray("result");
//for (int i = 0; i < threeInfo.length(); i++) {
//JSONObject object = threeInfo.getJSONObject(i);
String entity = jsonObject.getString("主体");
String relation = jsonObject.getString("关系");
String value = jsonObject.getString("客体");
// 类型信息从notePrompt对象中获取
// String startNodeType = object.getString("startNodeType");
// String endNodeType = object.getString("endNodeType");
// 去空,如果存在任何的空值,则忽略
// if (StrUtil.hasEmpty(startNodeType, entity, endNodeType, property, value)) {
// continue;
// }
if (StrUtil.hasEmpty(entity, relation, value)) {
log.info("提取三元组信息出现空值,忽略,主体:{},关系:{},客体:{}", entity, relation, value);
return null;
10 months ago
}
// 构建三元组信息
return new TripleInfo(entity, relation, value, caseId, recordId, recordSplitId, LocalDateTime.now(), prompt.getStartEntityType(), prompt.getEndEntityType());
//}
10 months ago
} catch (Exception e) {
log.error("提取三元组出现错误", e);
}
return null;
}
10 months ago
}