1.调整ai话术生成逻辑

master
gitee 1 month ago
parent 1961adb6ae
commit 760c69e3a0

@ -4,27 +4,27 @@ package com.supervision.livedigitalavatarmanage.constant;
* *
*/ */
public class PromptTemplate { public class PromptTemplate {
/** /**
* *
*/ */
public static final String GENERATE_SALESPITCH_TEMPLATE = """ public static final String GENERATE_COPYWRITING_POINTS_TEMPLATE = """
1. 1.
2. 500 2. jsonArray["要点1","要点2"]/no_think
3.
4. ********
5. 10****
6. jsonArray["话术1","话术2"]/no_think
"""; """;
private String aa_back = """
广 /**
112X43cm100 216X54cm120 *
便?FAS* FAS) */
public static final String GENERATE_SALESPITCH_TEMPLATE = """
********
·VDESIGNCONCEPTI/no_think
1. ****150
2. ****
2.
3. {num}****
4. jsonArray["话术1","话术2"]/no_think
"""; """;
} }

@ -1,5 +1,6 @@
package com.supervision.livedigitalavatarmanage.service.impl; package com.supervision.livedigitalavatarmanage.service.impl;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.lang.Assert; import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.StrUtil; import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil; import cn.hutool.json.JSONUtil;
@ -14,7 +15,7 @@ import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.OllamaChatModel; import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.util.List; import java.util.*;
@Slf4j @Slf4j
@Service @Service
@ -28,24 +29,159 @@ public class LiveDigitalServiceImpl implements LiveDigitalService {
Assert.notEmpty(salespitchReqVo.getSpecifications(), "产品规格不能为空"); Assert.notEmpty(salespitchReqVo.getSpecifications(), "产品规格不能为空");
Assert.notEmpty(salespitchReqVo.getDetail(), "产品详情不能为空"); Assert.notEmpty(salespitchReqVo.getDetail(), "产品详情不能为空");
for (int i = 0; i < 3; i++) { List<String> points = generateCopywritingPoints(salespitchReqVo);
try { Assert.notEmpty(points, "生成话术失败!请稍后重试");
String salesPitchTemplate = PromptTemplate.GENERATE_SALESPITCH_TEMPLATE; List<Prompt> prompts = generatePrompts(points, salespitchReqVo);
SystemMessage systemMessage = new SystemMessage(salesPitchTemplate); List<String> result = new ArrayList<>();
UserMessage userMessage = new UserMessage(salespitchReqVo.buildTemplate()); for (Prompt prompt : prompts) {
Prompt prompt = new Prompt(systemMessage, userMessage); for (int i = 0; i < 2; i++) {
try {
ChatResponse call = ollamaChatModel.call(prompt); log.info("======================>>>>");
String text = call.getResult().getOutput().getText(); log.info("开始生成销售话术,请求内容:{}", prompt.getContents());
// 去除think标签 ChatResponse call = ollamaChatModel.call(prompt);
if (StrUtil.isNotBlank(text) && text.contains("<think>") && text.contains("</think>")) { String text = call.getResult().getOutput().getText();
text = text.replaceAll("(?s)<think>.*?</think>", "").trim(); // 去除think标签
text = removeThinkTag(text);
log.info("生成销售话术成功,结果:{}", text);
log.info("<<<<======================");
result.addAll(JSONUtil.toList(text, String.class));
break;
}catch (Exception e) {
log.error("生成销售话术失败,尝试第 {} 次。请求内容:{}", i + 1, salespitchReqVo.buildTemplate(),e);
} }
return JSONUtil.toList(text, String.class);
}catch (Exception e) {
log.error("生成销售话术失败,尝试第 {} 次。请求内容:{}", i + 1, salespitchReqVo.buildTemplate(),e);
} }
} }
throw new RuntimeException("生成销售话术失败,请稍后重试"); if (result.size() > 10){
result = result.subList(0, 10);
}
return result;
}
private List<Prompt> generatePrompts(List<String> points, SalesPitchReqVo salespitchReqVo) {
int promptNum = 10; // 需要生成的销售话术数量
int bestNum = 5;
List<List<Integer>> pointIndexList = generateIndexCombinations(points.size());
List<List<Integer>> lists = pickBestSalesPitch(pointIndexList, bestNum);
List<Prompt> prompts = new ArrayList<>();
int quotient = promptNum / lists.size(); // 商
int remainder = promptNum % lists.size(); // 余数
for (List<Integer> list : lists) {
int num1 = quotient + (remainder > 0 ? 1 : 0);
SystemMessage systemMessage = new SystemMessage(StrUtil.format(PromptTemplate.GENERATE_SALESPITCH_TEMPLATE, Map.of("num", num1)));
UserMessage userMessage = new UserMessage(salespitchReqVo.buildTemplate(list.stream().map(points::get).toList()));
prompts.add(new Prompt(systemMessage, userMessage));
if (remainder > 0){
remainder --;
}
}
return prompts;
}
/**
*
* @param pointIndexList
* @param num
* @return
*/
private List<List<Integer>> pickBestSalesPitch(List<List<Integer>> pointIndexList,int num) {
if (CollUtil.isEmpty(pointIndexList)){
return new ArrayList<>();
}
if (pointIndexList.size() <= num){
return pointIndexList;
}
List<List<Integer>> bestCombinations = new ArrayList<>();
// 先挑选出组合为2的数据
for (List<Integer> indices : pointIndexList) {
if (indices.size() == 2 && bestCombinations.size() < num) {
// 检查当前组合是否与已有组合有交集
boolean noneMatch = bestCombinations.stream().filter(i->i.size() == 2).noneMatch(i -> i.stream().anyMatch(indices::contains));
if (noneMatch){
bestCombinations.add(indices);
}
}
}
if (pointIndexList.size() % 2 != 0){
Optional<Integer> max = pointIndexList.stream().filter(i -> i.size() == 1).flatMap(Collection::stream).min((i1, i2) -> i2 - i1);
bestCombinations.add(List.of(max.get())); // 如果组合数为奇数,先添加第一个组合
}
if (bestCombinations.size() >= num){
return bestCombinations;
}
// 如果组合为2的数据不够再挑选出组合为3的数据
for (List<Integer> indices : pointIndexList) {
if (indices.size() == 3 && bestCombinations.size() < num) {
// 检查当前组合是否与已有组合有交集
boolean noneMatch = bestCombinations.stream().filter(i->i.size() == 3).noneMatch(i -> i.stream().anyMatch(indices::contains));
if (noneMatch){
bestCombinations.add(indices);
}
}
}
if (bestCombinations.size() >= num){
return bestCombinations;
}
// 如果组合为3的数据也不够再挑选出组合为4的数据
for (List<Integer> indices : pointIndexList) {
if (indices.size() == 4 && bestCombinations.size() < num) {
// 检查当前组合是否与已有组合有交集
boolean noneMatch = bestCombinations.stream().filter(i->i.size() == 4).noneMatch(i -> i.stream().anyMatch(indices::contains));
if (noneMatch){
bestCombinations.add(indices);
}
}
}
// 如果组合为4的数据也不够再挑选出组合为1的数据
for (List<Integer> indices : pointIndexList) {
if (indices.size() == 1 && bestCombinations.size() < num) {
bestCombinations.add(indices);
}
}
return bestCombinations;
}
/**
*
* @param salespitchReqVo
* @return
*/
private List<String> generateCopywritingPoints(SalesPitchReqVo salespitchReqVo) {
SystemMessage systemMessage = new SystemMessage(PromptTemplate.GENERATE_COPYWRITING_POINTS_TEMPLATE);
UserMessage userMessage = new UserMessage(salespitchReqVo.buildProductInfo());
Prompt prompt = new Prompt(systemMessage, userMessage);
ChatResponse call = ollamaChatModel.call(prompt);
String text = call.getResult().getOutput().getText();
// 去除think标签
text = removeThinkTag(text);
return JSONUtil.toList(text, String.class);
}
private String removeThinkTag(String text) {
if (StrUtil.isNotBlank(text) && text.contains("<think>") && text.contains("</think>")) {
text = text.replaceAll("(?s)<think>.*?</think>", "").trim();
}
return text;
}
/**
* n
* @param n
* @return
*/
private List<List<Integer>> generateIndexCombinations(int n) {
List<List<Integer>> result = new ArrayList<>();
int total = 1 << n; // 2^n
for (int mask = 1; mask < total; mask++) { // 从1开始跳过空集
List<Integer> current = new ArrayList<>();
for (int i = 0; i < n; i++) {
if ((mask & (1 << i)) != 0) { // 检查第i位是否被选中
current.add(i);
}
}
result.add(current);
}
return result;
} }
} }

@ -29,16 +29,37 @@ public class SalesPitchReqVo {
public String buildTemplate() { public String buildTemplate() {
return "请根据产品信息生成销售话术:" +
"\n" +
buildProductInfo() + "/no_think";
}
public String buildProductInfo() {
StringBuilder info = new StringBuilder();
info.append("产品名称:").append(productName).append("\n");
info.append("产品规格:");
for (ProductSpecification specification : specifications) {
info.append("产品类型:").append(specification.getSize()).append("\n");
info.append("产品价格:").append(specification.getPrice()).append("\n");
}
info.append("产品详情:").append(detail);
return info.toString();
}
public String buildTemplate(List<String> copyingPoints) {
StringBuilder template = new StringBuilder(); StringBuilder template = new StringBuilder();
template.append("请根据产品信息生成销售话术:"); template.append("## 重点介绍点");
template.append("\n"); template.append("\n");
template.append("产品名称:").append(productName).append("\n"); template.append("```").append("\n");
template.append("产品规格:"); for (int i = 0; i < copyingPoints.size(); i++) {
for (ProductSpecification specification : specifications) { template.append(i + 1).append(". ").append(copyingPoints.get(i)).append("\n");
template.append("产品类型:").append(specification.getSize()).append("\n");
template.append("产品价格:").append(specification.getPrice()).append("\n");
} }
template.append("产品详情:").append(detail).append("/no_think"); template.append("```").append("\n");
template.append("## 完整产品信息").append("\n");
template.append("```").append("\n");
template.append(buildProductInfo()).append("\n");
template.append("```").append("\n");
template.append("/no_think");
return template.toString(); return template.toString();
} }
} }

Loading…
Cancel
Save