移除rasa
parent
42b9deb072
commit
15e8f072d0
@ -1,19 +0,0 @@
|
|||||||
# 设置基础镜像
|
|
||||||
FROM rasa_dev:1.1.0
|
|
||||||
|
|
||||||
|
|
||||||
COPY ./docs/docker-entrypoint.sh /usr/local/bin/docker-entrypoint.sh
|
|
||||||
|
|
||||||
# 设置工作目录
|
|
||||||
WORKDIR /data/vp
|
|
||||||
|
|
||||||
# 复制java jar 到容器中
|
|
||||||
#COPY target/virtual-patient-rasa-1.0-SNAPSHOT.jar /data/vp/virtual-patient-rasa-1.0-SNAPSHOT.jar
|
|
||||||
# 复制rasa配置文件到 rasa目录下
|
|
||||||
COPY docs/rasa /rasa
|
|
||||||
RUN rm -f /rasa/config-local.yml
|
|
||||||
|
|
||||||
# 暴漏服务端口
|
|
||||||
EXPOSE 8890
|
|
||||||
# 设置启动命令
|
|
||||||
ENTRYPOINT ["/usr/local/bin/docker-entrypoint.sh"]
|
|
@ -1,16 +0,0 @@
|
|||||||
# 设置基础镜像
|
|
||||||
FROM rasa_dev:1.0.0
|
|
||||||
|
|
||||||
COPY ./bert_chinese /usr/local/text2vec/bert_chinese
|
|
||||||
COPY ./app.py /usr/local/text2vec/
|
|
||||||
#COPY ./question.json /usr/local/text2vec/
|
|
||||||
|
|
||||||
RUN source /root/anaconda3/etc/profile.d/conda.sh && \
|
|
||||||
conda create --name text2vec_env python=3.9 -y && \
|
|
||||||
conda activate text2vec_env && \
|
|
||||||
pip install torch && \
|
|
||||||
pip install flask && \
|
|
||||||
pip install text2vec -i https://pypi.tuna.tsinghua.edu.cn/simple
|
|
||||||
|
|
||||||
expose 5000
|
|
||||||
|
|
@ -1,101 +0,0 @@
|
|||||||
from flask import Flask, request, jsonify
|
|
||||||
from text2vec import SentenceModel
|
|
||||||
import numpy as np
|
|
||||||
import traceback
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
|
|
||||||
app = Flask(__name__)
|
|
||||||
|
|
||||||
# 获取当前脚本所在的目录
|
|
||||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
|
|
||||||
# BERT模型路径
|
|
||||||
model_path = os.path.join(current_dir, 'bert_chinese')
|
|
||||||
model = None
|
|
||||||
questions_data = [] # 用于存储问题数据
|
|
||||||
sentence_embeddings = []
|
|
||||||
default_threshold = 0.7
|
|
||||||
# 数据集文件路径
|
|
||||||
dataset_file_path = os.path.join(current_dir, 'question.json')
|
|
||||||
|
|
||||||
# 初始化函数,用于加载模型和数据集
|
|
||||||
def initialize_app():
|
|
||||||
global model, questions_data, sentence_embeddings
|
|
||||||
model = SentenceModel(model_path)
|
|
||||||
load_dataset()
|
|
||||||
|
|
||||||
# 加载数据集
|
|
||||||
def load_dataset():
|
|
||||||
global questions_data, sentence_embeddings, sentences
|
|
||||||
try:
|
|
||||||
with open(dataset_file_path, 'r', encoding='utf-8') as file:
|
|
||||||
questions_data = json.load(file)
|
|
||||||
sentences = [item["question"] for item in questions_data]
|
|
||||||
# 重新编码句子
|
|
||||||
sentence_embeddings = [model.encode(sent) / np.linalg.norm(model.encode(sent)) for sent in sentences]
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
traceback.print_exc()
|
|
||||||
print(f"Error loading dataset: {str(e)}")
|
|
||||||
|
|
||||||
# 错误处理程序
|
|
||||||
@app.errorhandler(Exception)
|
|
||||||
def handle_error(e):
|
|
||||||
traceback.print_exc()
|
|
||||||
return jsonify({'error': str(e)}), 500
|
|
||||||
|
|
||||||
# 初始化应用
|
|
||||||
initialize_app()
|
|
||||||
|
|
||||||
# 替换数据集的接口
|
|
||||||
@app.route('/update_dataset', methods=['POST'])
|
|
||||||
def update_dataset():
|
|
||||||
global questions_data, sentence_embeddings
|
|
||||||
new_dataset = request.json or []
|
|
||||||
|
|
||||||
# 更新数据集
|
|
||||||
try:
|
|
||||||
with open(dataset_file_path, 'w', encoding='utf-8') as file:
|
|
||||||
json.dump(new_dataset, file, ensure_ascii=False, indent=2)
|
|
||||||
load_dataset()
|
|
||||||
return jsonify({'status': 'success', 'message': '数据集更新成功'})
|
|
||||||
except Exception as e:
|
|
||||||
traceback.print_exc()
|
|
||||||
return jsonify({'error': f'更新数据集错误: {str(e)}'}), 500
|
|
||||||
|
|
||||||
# 获取匹配的接口
|
|
||||||
@app.route('/matches', methods=['POST'])
|
|
||||||
def get_matches():
|
|
||||||
query_sentence = request.json.get('querySentence', '')
|
|
||||||
query_embedding = model.encode([query_sentence])[0]
|
|
||||||
# 对向量进行单位化
|
|
||||||
query_embedding = query_embedding / np.linalg.norm(query_embedding)
|
|
||||||
# 获取阈值参数,如果请求中没有提供阈值,则使用默认阈值
|
|
||||||
threshold = request.json.get('threshold', default_threshold)
|
|
||||||
# 计算相似度
|
|
||||||
similarities = [embedding.dot(query_embedding) for embedding in sentence_embeddings]
|
|
||||||
|
|
||||||
# 获取所有相似度高于阈值的匹配项
|
|
||||||
matches = [{'id': questions_data[i]["id"], 'sentence': sentences[i], 'similarity': float(similarity)}
|
|
||||||
for i, similarity in enumerate(similarities) if similarity >= threshold]
|
|
||||||
|
|
||||||
return jsonify({'status': 'success', 'results': matches} if matches else {'status': 'success', 'message': '未找到匹配项'})
|
|
||||||
|
|
||||||
# 获取所有相似度的接口
|
|
||||||
@app.route('/get_all_similarities', methods=['POST'])
|
|
||||||
def get_all_similarities():
|
|
||||||
query_sentence = request.json.get('querySentence', '')
|
|
||||||
query_embedding = model.encode([query_sentence])[0]
|
|
||||||
# 对向量进行单位化
|
|
||||||
query_embedding = query_embedding / np.linalg.norm(query_embedding)
|
|
||||||
|
|
||||||
# 计算所有数据的相似度和对应的文本
|
|
||||||
results = [{'id': questions_data[i]["id"], 'sentence': sentences[i], 'similarity': float(embedding.dot(query_embedding))}
|
|
||||||
for i, embedding in enumerate(sentence_embeddings)]
|
|
||||||
|
|
||||||
# 返回所有相似度和对应文本
|
|
||||||
return jsonify({'status': 'success', 'results': results})
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
app.run(debug=True, host='0.0.0.0')
|
|
@ -1,32 +0,0 @@
|
|||||||
{
|
|
||||||
"_name_or_path": "hfl/chinese-macbert-base",
|
|
||||||
"architectures": [
|
|
||||||
"BertModel"
|
|
||||||
],
|
|
||||||
"attention_probs_dropout_prob": 0.1,
|
|
||||||
"classifier_dropout": null,
|
|
||||||
"directionality": "bidi",
|
|
||||||
"gradient_checkpointing": false,
|
|
||||||
"hidden_act": "gelu",
|
|
||||||
"hidden_dropout_prob": 0.1,
|
|
||||||
"hidden_size": 768,
|
|
||||||
"initializer_range": 0.02,
|
|
||||||
"intermediate_size": 3072,
|
|
||||||
"layer_norm_eps": 1e-12,
|
|
||||||
"max_position_embeddings": 512,
|
|
||||||
"model_type": "bert",
|
|
||||||
"num_attention_heads": 12,
|
|
||||||
"num_hidden_layers": 12,
|
|
||||||
"pad_token_id": 0,
|
|
||||||
"pooler_fc_size": 768,
|
|
||||||
"pooler_num_attention_heads": 12,
|
|
||||||
"pooler_num_fc_layers": 3,
|
|
||||||
"pooler_size_per_head": 128,
|
|
||||||
"pooler_type": "first_token_transform",
|
|
||||||
"position_embedding_type": "absolute",
|
|
||||||
"torch_dtype": "float32",
|
|
||||||
"transformers_version": "4.12.3",
|
|
||||||
"type_vocab_size": 2,
|
|
||||||
"use_cache": true,
|
|
||||||
"vocab_size": 21128
|
|
||||||
}
|
|
Binary file not shown.
File diff suppressed because it is too large
Load Diff
@ -1,218 +0,0 @@
|
|||||||
[
|
|
||||||
{
|
|
||||||
"id": "101",
|
|
||||||
"question": "你好"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "152",
|
|
||||||
"question": "你好吗?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "3",
|
|
||||||
"question": "你好不好啊?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "4",
|
|
||||||
"question": "你怎么样?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "5",
|
|
||||||
"question": "你好吗,今天?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "6",
|
|
||||||
"question": "你近期好吗?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "7",
|
|
||||||
"question": "你好,对吗?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "8",
|
|
||||||
"question": "你还好吗?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "9",
|
|
||||||
"question": "你一切安好吗?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "10",
|
|
||||||
"question": "你好,还是不好?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "11",
|
|
||||||
"question": "你好吗,朋友?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "12",
|
|
||||||
"question": "你近来可好?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "12",
|
|
||||||
"question": "你还好吗,最近?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "14",
|
|
||||||
"question": "你还好吗,朋友?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "15",
|
|
||||||
"question": "你好吗,最近怎么样?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "16",
|
|
||||||
"question": "你好吗,身体怎么样?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "17",
|
|
||||||
"question": "你好吗,心情如何?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "18",
|
|
||||||
"question": "你好吗,一切顺利吗?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "19",
|
|
||||||
"question": "你好吗,有什么新的计划吗?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "20",
|
|
||||||
"question": "你好,有什么想说的吗?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "21",
|
|
||||||
"question": "你好,有什么问题需要我帮忙的吗?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "22",
|
|
||||||
"question": "您以前有过类似情况吗?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "23",
|
|
||||||
"question": "您之前是否遇到过相似的情况?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "24",
|
|
||||||
"question": "您以前有没有经历过这样的事情?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "25",
|
|
||||||
"question": "这种情况以前对您来说是否熟悉?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "26",
|
|
||||||
"question": "您是否有过与这类似的经历?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "27",
|
|
||||||
"question": "在您的经验中,是否有过这样的先例?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "28",
|
|
||||||
"question": "您是否曾面临过相似的挑战?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "29",
|
|
||||||
"question": "这种情境是否让您回想起过去的某些经历?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "30",
|
|
||||||
"question": "您是否曾经历过与此类似的困境?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "31",
|
|
||||||
"question": "这种情况是否以前也发生过?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "32",
|
|
||||||
"question": "您是否曾有过类似的体验?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "33",
|
|
||||||
"question": "以前是否遇到过相似的问题?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "34",
|
|
||||||
"question": "这种情境是否与您过去的经历相符?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "35",
|
|
||||||
"question": "您是否曾经身处于类似的状况?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "36",
|
|
||||||
"question": "您是否有过和这类似的历史?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "37",
|
|
||||||
"question": "这种情况是否让您想起以前的经历?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "38",
|
|
||||||
"question": "您是否曾遭遇过相同的状况?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "39",
|
|
||||||
"question": "您是否有过类似的先例?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "40",
|
|
||||||
"question": "这种情况是否与您的过往经历相似?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "41",
|
|
||||||
"question": "您是否曾经面对过这样的局面?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "42",
|
|
||||||
"question": "您是否有过和这类似的经验?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "43",
|
|
||||||
"question": "您以前有没有因为这些症状看过其他的医生?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "44",
|
|
||||||
"question": "您之前是否因为这些症状咨询过其他医生?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "45",
|
|
||||||
"question": "您是否曾因这些症状求助于其他医生?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "46",
|
|
||||||
"question": "因为这些症状,您以前看过别的医生吗?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "47",
|
|
||||||
"question": "您是否曾因为这些症状去看过其他医生?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "48",
|
|
||||||
"question": "您是否因这些症状找过其他医生?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "49",
|
|
||||||
"question": "关于这些症状,您以前是否咨询过其他医生?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "50",
|
|
||||||
"question": "您是否曾经因为这些症状找过其他医生?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "51",
|
|
||||||
"question": "您是否曾因为这些症状向其他医生寻求帮助?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "52",
|
|
||||||
"question": "关于这些症状,您是否看过其他医生?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "53",
|
|
||||||
"question": "您以前有没有因为这些症状而咨询过其他医生?"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "54",
|
|
||||||
"question": "这些症状以前是否让您去看过其他医生?"
|
|
||||||
}
|
|
||||||
]
|
|
@ -1,9 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
# 启动text2vec服务
|
|
||||||
source /root/anaconda3/etc/profile.d/conda.sh && \
|
|
||||||
conda activate text2vec_env && \
|
|
||||||
nohup python /usr/local/text2vec/app.py &
|
|
||||||
|
|
||||||
# 启动jar包
|
|
||||||
java -jar -Duser.timezone=Asia/Shanghai -Dfile.encoding=UTF-8 -Xms256m -Xmx1g /data/vp/virtual-patient-rasa-1.0-SNAPSHOT.jar "$@"
|
|
@ -1,29 +0,0 @@
|
|||||||
|
|
||||||
recipe: default.v1
|
|
||||||
language: zh
|
|
||||||
|
|
||||||
pipeline:
|
|
||||||
- name: JiebaTokenizer
|
|
||||||
- name: LanguageModelFeaturizer
|
|
||||||
model_name: "bert"
|
|
||||||
model_weights: "/rasa/bert-base-chinese"
|
|
||||||
- name: RegexFeaturizer
|
|
||||||
- name: DIETClassifier
|
|
||||||
epochs: 100
|
|
||||||
learning_rate: 0.001
|
|
||||||
tensorboard_log_directory: ./log
|
|
||||||
- name: ResponseSelector
|
|
||||||
epochs: 100
|
|
||||||
learning_rate: 0.001
|
|
||||||
- name: FallbackClassifier
|
|
||||||
threshold: 0.4
|
|
||||||
ambiguity_threshold: 0.1
|
|
||||||
- name: EntitySynonymMapper
|
|
||||||
|
|
||||||
policies:
|
|
||||||
- name: MemoizationPolicy
|
|
||||||
- name: TEDPolicy
|
|
||||||
- name: RulePolicy
|
|
||||||
core_fallback_threshold: 0.4
|
|
||||||
core_fallback_action_name: "action_default_fallback"
|
|
||||||
enable_fallback_prediction: True
|
|
@ -1,33 +0,0 @@
|
|||||||
# This file contains the credentials for the voice & chat platforms
|
|
||||||
# which your bot is using.
|
|
||||||
# https://rasa.com/docs/rasa/messaging-and-voice-channels
|
|
||||||
|
|
||||||
rest:
|
|
||||||
# # you don't need to provide anything here - this channel doesn't
|
|
||||||
# # require any credentials
|
|
||||||
|
|
||||||
|
|
||||||
#facebook:
|
|
||||||
# verify: "<verify>"
|
|
||||||
# secret: "<your secret>"
|
|
||||||
# page-access-token: "<your page access token>"
|
|
||||||
|
|
||||||
#slack:
|
|
||||||
# slack_token: "<your slack token>"
|
|
||||||
# slack_channel: "<the slack channel>"
|
|
||||||
# slack_signing_secret: "<your slack signing secret>"
|
|
||||||
|
|
||||||
#socketio:
|
|
||||||
# user_message_evt: <event name for user message>
|
|
||||||
# bot_message_evt: <event name for bot messages>
|
|
||||||
# session_persistence: <true/false>
|
|
||||||
|
|
||||||
#mattermost:
|
|
||||||
# url: "https://<mattermost instance>/api/v4"
|
|
||||||
# token: "<bot token>"
|
|
||||||
# webhook_url: "<callback URL>"
|
|
||||||
|
|
||||||
# This entry is needed if you are using Rasa Enterprise. The entry represents credentials
|
|
||||||
# for the Rasa Enterprise "channel", i.e. Talk to your bot and Share with guest testers.
|
|
||||||
rasa:
|
|
||||||
url: "http://localhost:5002/api"
|
|
@ -1,51 +0,0 @@
|
|||||||
# This file contains the different endpoints your bot can use.
|
|
||||||
|
|
||||||
# Server where the models are pulled from.
|
|
||||||
# https://rasa.com/docs/rasa/model-storage#fetching-models-from-a-server
|
|
||||||
|
|
||||||
#models:
|
|
||||||
# url: http://my-server.com/models/default_core@latest
|
|
||||||
# wait_time_between_pulls: 10 # [optional](default: 100)
|
|
||||||
|
|
||||||
# Server which runs your custom actions.
|
|
||||||
# https://rasa.com/docs/rasa/custom-actions
|
|
||||||
|
|
||||||
action_endpoint:
|
|
||||||
url: "http://127.0.0.1:5055/webhook"
|
|
||||||
|
|
||||||
# Tracker store which is used to store the conversations.
|
|
||||||
# By default the conversations are stored in memory.
|
|
||||||
# https://rasa.com/docs/rasa/tracker-stores
|
|
||||||
|
|
||||||
#tracker_store:
|
|
||||||
# type: redis
|
|
||||||
# url: <host of the redis instance, e.g. localhost>
|
|
||||||
# port: <port of your redis instance, usually 6379>
|
|
||||||
# db: <number of your database within redis, e.g. 0>
|
|
||||||
# password: <password used for authentication>
|
|
||||||
# use_ssl: <whether or not the communication is encrypted, default false>
|
|
||||||
|
|
||||||
#tracker_store:
|
|
||||||
# type: mongod
|
|
||||||
# url: <url to your mongo instance, e.g. mongodb://localhost:27017>
|
|
||||||
# db: <name of the db within your mongo instance, e.g. rasa>
|
|
||||||
# username: <username used for authentication>
|
|
||||||
# password: <password used for authentication>
|
|
||||||
|
|
||||||
|
|
||||||
#tracker_store:
|
|
||||||
# type: mongod
|
|
||||||
# url: mongodb://192.168.10.137:27017
|
|
||||||
# db: dialog_test
|
|
||||||
# username:
|
|
||||||
# password:
|
|
||||||
|
|
||||||
|
|
||||||
# Event broker which all conversation events should be streamed to.
|
|
||||||
# https://rasa.com/docs/rasa/event-brokers
|
|
||||||
|
|
||||||
#event_broker:
|
|
||||||
# url: localhost
|
|
||||||
# username: username
|
|
||||||
# password: password
|
|
||||||
# queue: queue
|
|
@ -1,112 +0,0 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
|
||||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
|
|
||||||
<modelVersion>4.0.0</modelVersion>
|
|
||||||
|
|
||||||
<parent>
|
|
||||||
<groupId>com.supervision</groupId>
|
|
||||||
<artifactId>virtual-patient</artifactId>
|
|
||||||
<version>1.0-SNAPSHOT</version>
|
|
||||||
</parent>
|
|
||||||
|
|
||||||
<artifactId>virtual-patient-rasa</artifactId>
|
|
||||||
<packaging>jar</packaging>
|
|
||||||
|
|
||||||
<name>virtual-patient-rasa</name>
|
|
||||||
|
|
||||||
|
|
||||||
<properties>
|
|
||||||
<java.version>17</java.version>
|
|
||||||
<maven.compiler.source>17</maven.compiler.source>
|
|
||||||
<maven.compiler.target>17</maven.compiler.target>
|
|
||||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
|
||||||
</properties>
|
|
||||||
|
|
||||||
<dependencies>
|
|
||||||
|
|
||||||
<dependency>
|
|
||||||
<groupId>com.alibaba.cloud</groupId>
|
|
||||||
<artifactId>spring-cloud-starter-alibaba-nacos-discovery</artifactId>
|
|
||||||
<exclusions>
|
|
||||||
<exclusion>
|
|
||||||
<groupId>com.alibaba</groupId>
|
|
||||||
<artifactId>fastjson</artifactId>
|
|
||||||
</exclusion>
|
|
||||||
</exclusions>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.springframework.cloud</groupId>
|
|
||||||
<artifactId>spring-cloud-starter-openfeign</artifactId>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.springframework.cloud</groupId>
|
|
||||||
<artifactId>spring-cloud-starter-loadbalancer</artifactId>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
<dependency>
|
|
||||||
<groupId>com.alibaba.cloud</groupId>
|
|
||||||
<artifactId>spring-cloud-starter-alibaba-nacos-config</artifactId>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
<dependency>
|
|
||||||
<groupId>com.supervision</groupId>
|
|
||||||
<artifactId>virtual-patient-common</artifactId>
|
|
||||||
<version>${project.version}</version>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
<dependency>
|
|
||||||
<groupId>com.supervision</groupId>
|
|
||||||
<artifactId>virtual-patient-model</artifactId>
|
|
||||||
<version>${project.version}</version>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.springframework.boot</groupId>
|
|
||||||
<artifactId>spring-boot-starter-web</artifactId>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.projectlombok</groupId>
|
|
||||||
<artifactId>lombok</artifactId>
|
|
||||||
<optional>true</optional>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.springframework.boot</groupId>
|
|
||||||
<artifactId>spring-boot-starter-test</artifactId>
|
|
||||||
<scope>test</scope>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
<dependency>
|
|
||||||
<groupId>cn.hutool</groupId>
|
|
||||||
<artifactId>hutool-all</artifactId>
|
|
||||||
</dependency>
|
|
||||||
<!--用来生成yml文件,jakson的模板库满足不了多行文本|管道符的需求-->
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.freemarker</groupId>
|
|
||||||
<artifactId>freemarker</artifactId>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
</dependencies>
|
|
||||||
|
|
||||||
<build>
|
|
||||||
<plugins>
|
|
||||||
<plugin>
|
|
||||||
<groupId>org.springframework.boot</groupId>
|
|
||||||
<artifactId>spring-boot-maven-plugin</artifactId>
|
|
||||||
<configuration>
|
|
||||||
<excludes>
|
|
||||||
<exclude>
|
|
||||||
<groupId>org.projectlombok</groupId>
|
|
||||||
<artifactId>lombok</artifactId>
|
|
||||||
</exclude>
|
|
||||||
</excludes>
|
|
||||||
</configuration>
|
|
||||||
</plugin>
|
|
||||||
</plugins>
|
|
||||||
</build>
|
|
||||||
|
|
||||||
</project>
|
|
@ -1,39 +0,0 @@
|
|||||||
package com.supervision.rasa;
|
|
||||||
|
|
||||||
import com.supervision.config.WebConfig;
|
|
||||||
import com.supervision.rasa.service.RasaModelManager;
|
|
||||||
import com.supervision.rasa.service.Text2vecService;
|
|
||||||
import org.mybatis.spring.annotation.MapperScan;
|
|
||||||
import org.springframework.boot.SpringApplication;
|
|
||||||
import org.springframework.boot.autoconfigure.SpringBootApplication;
|
|
||||||
import org.springframework.cloud.client.discovery.EnableDiscoveryClient;
|
|
||||||
import org.springframework.context.ConfigurableApplicationContext;
|
|
||||||
import org.springframework.context.annotation.ComponentScan;
|
|
||||||
import org.springframework.context.annotation.FilterType;
|
|
||||||
import org.springframework.scheduling.annotation.EnableScheduling;
|
|
||||||
|
|
||||||
@SpringBootApplication
|
|
||||||
@EnableScheduling
|
|
||||||
@MapperScan(basePackages = {"com.supervision.**.mapper"})
|
|
||||||
// 排除JWT权限校验
|
|
||||||
@ComponentScan(basePackages = {"com.supervision"},excludeFilters = @ComponentScan.Filter(type = FilterType.ASSIGNABLE_TYPE, classes = {WebConfig.class}))
|
|
||||||
@EnableDiscoveryClient
|
|
||||||
public class VirtualPatientRasaApplication {
|
|
||||||
|
|
||||||
public static void main(String[] args) {
|
|
||||||
ConfigurableApplicationContext context = SpringApplication.run(VirtualPatientRasaApplication.class, args);
|
|
||||||
|
|
||||||
// 启动rasa服务
|
|
||||||
RasaModelManager rasaModelManager = context.getBean(RasaModelManager.class);
|
|
||||||
try {
|
|
||||||
rasaModelManager.wakeUpInterruptServerScheduled();
|
|
||||||
} catch (Exception e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
|
|
||||||
// 初始化文本匹配数据
|
|
||||||
Text2vecService text2vecService = context.getBean(Text2vecService.class);
|
|
||||||
text2vecService.initText2vecDataset();
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
@ -1,30 +0,0 @@
|
|||||||
package com.supervision.rasa.config;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.concurrent.*;
|
|
||||||
|
|
||||||
public class ThreadPoolExecutorConfig {
|
|
||||||
|
|
||||||
private volatile static ThreadPoolExecutor instance = null;
|
|
||||||
|
|
||||||
private ThreadPoolExecutorConfig(){}
|
|
||||||
|
|
||||||
|
|
||||||
public static ThreadPoolExecutor getInstance() {
|
|
||||||
new ArrayList<>();
|
|
||||||
if (instance == null) {
|
|
||||||
synchronized (ThreadPoolExecutorConfig.class) { // 加锁
|
|
||||||
if (instance == null) {
|
|
||||||
int corePoolSize = 5;
|
|
||||||
int maximumPoolSize = 10;
|
|
||||||
long keepAliveTime = 100;
|
|
||||||
BlockingQueue<Runnable> workQueue = new ArrayBlockingQueue<>(20);
|
|
||||||
RejectedExecutionHandler rejectedExecutionHandler = new ThreadPoolExecutor.AbortPolicy();
|
|
||||||
instance = new ThreadPoolExecutor(corePoolSize, maximumPoolSize, keepAliveTime, TimeUnit.SECONDS, workQueue, rejectedExecutionHandler);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return instance;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
@ -1,13 +0,0 @@
|
|||||||
package com.supervision.rasa.constant;
|
|
||||||
|
|
||||||
public class RasaConstant {
|
|
||||||
|
|
||||||
public static final String TRAN_SUCCESS_MESSAGE = "Your Rasa model is trained and saved at";
|
|
||||||
public static final String RUN_SUCCESS_MESSAGE = "Rasa server is up and running";
|
|
||||||
public static final String RUN_SHELL = "run.sh";
|
|
||||||
public static final String TRAIN_SHELL = "train.sh";
|
|
||||||
public static final String KILL_SHELL = "kill.sh";
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
@ -1,56 +0,0 @@
|
|||||||
package com.supervision.rasa.controller;
|
|
||||||
|
|
||||||
import cn.hutool.core.util.StrUtil;
|
|
||||||
import com.supervision.exception.BusinessException;
|
|
||||||
import com.supervision.rasa.constant.RasaConstant;
|
|
||||||
import com.supervision.rasa.pojo.vo.RasaCmdArgumentVo;
|
|
||||||
import com.supervision.rasa.service.RasaCmdService;
|
|
||||||
import io.swagger.v3.oas.annotations.Operation;
|
|
||||||
import io.swagger.v3.oas.annotations.tags.Tag;
|
|
||||||
import lombok.RequiredArgsConstructor;
|
|
||||||
import org.springframework.web.bind.annotation.PostMapping;
|
|
||||||
import org.springframework.web.bind.annotation.RequestBody;
|
|
||||||
import org.springframework.web.bind.annotation.RequestMapping;
|
|
||||||
import org.springframework.web.bind.annotation.RestController;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.util.concurrent.ExecutionException;
|
|
||||||
import java.util.concurrent.TimeoutException;
|
|
||||||
|
|
||||||
@Tag(name = "rasa管理")
|
|
||||||
@RestController
|
|
||||||
@RequestMapping("rasaCmd")
|
|
||||||
@RequiredArgsConstructor
|
|
||||||
public class RasaCmdController {
|
|
||||||
|
|
||||||
private final RasaCmdService rasaCmdService;
|
|
||||||
|
|
||||||
@Operation(summary = "执行训练shell命令")
|
|
||||||
@PostMapping("/trainExec")
|
|
||||||
public String trainExec(@RequestBody RasaCmdArgumentVo argument) throws IOException, ExecutionException, InterruptedException, TimeoutException {
|
|
||||||
|
|
||||||
return rasaCmdService.trainExec(argument);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Operation(summary = "执行启动shell命令")
|
|
||||||
@PostMapping("/runExec")
|
|
||||||
public String runExec(@RequestBody RasaCmdArgumentVo argument) throws ExecutionException, InterruptedException, TimeoutException {
|
|
||||||
|
|
||||||
String outString = rasaCmdService.runExec(argument);
|
|
||||||
if (StrUtil.isEmptyIfStr(outString) || !outString.contains(RasaConstant.RUN_SUCCESS_MESSAGE)){
|
|
||||||
throw new BusinessException("任务执行异常。详细日志:"+outString);
|
|
||||||
}
|
|
||||||
return outString;
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Operation(summary = "部署rasa")
|
|
||||||
@PostMapping("/deploy")
|
|
||||||
public boolean deployRasa() throws Exception {
|
|
||||||
return rasaCmdService.deployRasa();
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
@ -1,36 +0,0 @@
|
|||||||
package com.supervision.rasa.controller;
|
|
||||||
|
|
||||||
|
|
||||||
import com.supervision.exception.BusinessException;
|
|
||||||
import com.supervision.rasa.service.RasaFileService;
|
|
||||||
import io.swagger.v3.oas.annotations.Operation;
|
|
||||||
import io.swagger.v3.oas.annotations.tags.Tag;
|
|
||||||
import lombok.RequiredArgsConstructor;
|
|
||||||
import org.springframework.web.bind.annotation.PostMapping;
|
|
||||||
import org.springframework.web.bind.annotation.RequestMapping;
|
|
||||||
import org.springframework.web.bind.annotation.RequestParam;
|
|
||||||
import org.springframework.web.bind.annotation.RestController;
|
|
||||||
import org.springframework.web.multipart.MultipartFile;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
|
|
||||||
@Tag(name = "rasa文件保存")
|
|
||||||
@RestController
|
|
||||||
@RequestMapping("rasaFile")
|
|
||||||
@RequiredArgsConstructor
|
|
||||||
public class RasaFileController {
|
|
||||||
|
|
||||||
private final RasaFileService rasaFileService;
|
|
||||||
|
|
||||||
@Operation(summary = "接受并保存rasa文件")
|
|
||||||
@PostMapping("/saveRasaFile")
|
|
||||||
public String saveRasaFile(@RequestParam("file") MultipartFile file) throws IOException {
|
|
||||||
|
|
||||||
if (file == null || file.isEmpty()) {
|
|
||||||
throw new BusinessException("file is empty");
|
|
||||||
}
|
|
||||||
rasaFileService.saveRasaFile(file,"1");
|
|
||||||
return "success";
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,30 +0,0 @@
|
|||||||
package com.supervision.rasa.controller;
|
|
||||||
|
|
||||||
import com.supervision.rasa.service.RasaTalkService;
|
|
||||||
import com.supervision.vo.rasa.RasaTalkVo;
|
|
||||||
import io.swagger.v3.oas.annotations.Operation;
|
|
||||||
import io.swagger.v3.oas.annotations.tags.Tag;
|
|
||||||
import lombok.RequiredArgsConstructor;
|
|
||||||
import org.springframework.web.bind.annotation.PostMapping;
|
|
||||||
import org.springframework.web.bind.annotation.RequestBody;
|
|
||||||
import org.springframework.web.bind.annotation.RequestMapping;
|
|
||||||
import org.springframework.web.bind.annotation.RestController;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
@Tag(name = "ras对话服务")
|
|
||||||
@RestController
|
|
||||||
@RequestMapping("rasa")
|
|
||||||
@RequiredArgsConstructor
|
|
||||||
public class RasaTalkController {
|
|
||||||
|
|
||||||
private final RasaTalkService rasaTalkService;
|
|
||||||
|
|
||||||
@Operation(summary = "rasa对话")
|
|
||||||
@PostMapping("talkRasa")
|
|
||||||
public List<String> talkRasa(@RequestBody RasaTalkVo rasaTalkVo){
|
|
||||||
|
|
||||||
return rasaTalkService.talkRasa(rasaTalkVo);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
@ -1,38 +0,0 @@
|
|||||||
package com.supervision.rasa.controller;
|
|
||||||
|
|
||||||
import com.supervision.rasa.pojo.dto.Text2vecDataVo;
|
|
||||||
import com.supervision.rasa.pojo.dto.Text2vecMatchesReq;
|
|
||||||
import com.supervision.rasa.pojo.dto.Text2vecMatchesRes;
|
|
||||||
import com.supervision.rasa.service.Text2vecService;
|
|
||||||
import io.swagger.v3.oas.annotations.Operation;
|
|
||||||
import io.swagger.v3.oas.annotations.tags.Tag;
|
|
||||||
import lombok.RequiredArgsConstructor;
|
|
||||||
import org.springframework.web.bind.annotation.PostMapping;
|
|
||||||
import org.springframework.web.bind.annotation.RequestBody;
|
|
||||||
import org.springframework.web.bind.annotation.RequestMapping;
|
|
||||||
import org.springframework.web.bind.annotation.RestController;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
@Tag(name = "text2vec服务")
|
|
||||||
@RestController
|
|
||||||
@RequestMapping("/text2vec")
|
|
||||||
@RequiredArgsConstructor
|
|
||||||
public class Text2vecController {
|
|
||||||
|
|
||||||
private final Text2vecService text2vecService;
|
|
||||||
|
|
||||||
@Operation(summary = "更新数据库")
|
|
||||||
@PostMapping("updateDataset")
|
|
||||||
public boolean talkRasa(@RequestBody List<Text2vecDataVo> text2vecDataVoList){
|
|
||||||
|
|
||||||
return text2vecService.updateDataset(text2vecDataVoList);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Operation(summary = "获取匹配项")
|
|
||||||
@PostMapping("matches")
|
|
||||||
public List<Text2vecMatchesRes> matches(@RequestBody Text2vecMatchesReq text2vecMatchesReq){
|
|
||||||
|
|
||||||
return text2vecService.matches(text2vecMatchesReq);
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,28 +0,0 @@
|
|||||||
package com.supervision.rasa.pojo.dto;
|
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
|
|
||||||
import java.util.LinkedHashMap;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
@Data
|
|
||||||
public class DomainYmlTemplate {
|
|
||||||
|
|
||||||
private List<String> intents;
|
|
||||||
|
|
||||||
private LinkedHashMap<String,List<String>> responses;
|
|
||||||
|
|
||||||
private List<String> actions;
|
|
||||||
|
|
||||||
private SessionConfig session_config = new SessionConfig();
|
|
||||||
|
|
||||||
|
|
||||||
@Data
|
|
||||||
public static class SessionConfig{
|
|
||||||
|
|
||||||
private final int session_expiration_time = 60;
|
|
||||||
|
|
||||||
private final Boolean carry_over_slots_to_new_session = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
@ -1,28 +0,0 @@
|
|||||||
package com.supervision.rasa.pojo.dto;
|
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
@Data
|
|
||||||
public class NluYmlTemplate {
|
|
||||||
|
|
||||||
private List<Nlu> nlu;
|
|
||||||
|
|
||||||
@Data
|
|
||||||
public static class Nlu{
|
|
||||||
private String intent;
|
|
||||||
|
|
||||||
private List<String> examples;
|
|
||||||
|
|
||||||
public Nlu(String intent, List<String> examples) {
|
|
||||||
this.intent = intent;
|
|
||||||
this.examples = examples;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Nlu() {
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
@ -1,20 +0,0 @@
|
|||||||
package com.supervision.rasa.pojo.dto;
|
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
|
||||||
import lombok.Data;
|
|
||||||
import lombok.NoArgsConstructor;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
@Data
|
|
||||||
@AllArgsConstructor
|
|
||||||
@NoArgsConstructor
|
|
||||||
public class QuestionAnswerDTO {
|
|
||||||
|
|
||||||
private List<String> questionList;
|
|
||||||
|
|
||||||
private List<String> answerList;
|
|
||||||
|
|
||||||
private String desc;
|
|
||||||
|
|
||||||
}
|
|
@ -1,13 +0,0 @@
|
|||||||
package com.supervision.rasa.pojo.dto;
|
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
|
|
||||||
@Data
|
|
||||||
public class RasaReqDTO {
|
|
||||||
|
|
||||||
private String sender;
|
|
||||||
|
|
||||||
private String message;
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
@ -1,11 +0,0 @@
|
|||||||
package com.supervision.rasa.pojo.dto;
|
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
|
|
||||||
@Data
|
|
||||||
public class RasaResDTO {
|
|
||||||
|
|
||||||
private String recipient_id;
|
|
||||||
|
|
||||||
private String text;
|
|
||||||
}
|
|
@ -1,72 +0,0 @@
|
|||||||
package com.supervision.rasa.pojo.dto;
|
|
||||||
|
|
||||||
import cn.hutool.core.collection.CollUtil;
|
|
||||||
import lombok.Data;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* rasa 启动参数
|
|
||||||
*/
|
|
||||||
@Data
|
|
||||||
public class RasaRunParam {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* bash 路径
|
|
||||||
*/
|
|
||||||
private String bashPath;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* rasa 启动脚本路径
|
|
||||||
*/
|
|
||||||
private String shellPath;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* rasa 模型路径
|
|
||||||
*/
|
|
||||||
private String rasaModelPath;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* rasa 配置文件位置
|
|
||||||
*/
|
|
||||||
private String endpointsPath;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* rasa 服务端口
|
|
||||||
*/
|
|
||||||
private String port;
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 通过list构建RasaRunParam对象
|
|
||||||
* @param args bashPath = args[0], shellPath = args[1], rasaModelPath = args[2], endpointsPath = args[3], port = args[4]
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public static RasaRunParam build(List<String> args) {
|
|
||||||
RasaRunParam rasaRunParam = new RasaRunParam();
|
|
||||||
if (CollUtil.isEmpty(args)){
|
|
||||||
return rasaRunParam;
|
|
||||||
}
|
|
||||||
|
|
||||||
rasaRunParam.setBashPath(args.get(0));
|
|
||||||
if (args.size()>1){
|
|
||||||
rasaRunParam.setShellPath(args.get(1));
|
|
||||||
}
|
|
||||||
if (args.size()>2){
|
|
||||||
rasaRunParam.setRasaModelPath(args.get(2));
|
|
||||||
}
|
|
||||||
if (args.size()>3){
|
|
||||||
rasaRunParam.setEndpointsPath(args.get(3));
|
|
||||||
}
|
|
||||||
if (args.size()>4){
|
|
||||||
rasaRunParam.setPort(args.get(4));
|
|
||||||
}
|
|
||||||
return rasaRunParam;
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
public List<String> toList(){
|
|
||||||
return CollUtil.newArrayList(bashPath,shellPath,rasaModelPath,endpointsPath,port);
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,86 +0,0 @@
|
|||||||
package com.supervision.rasa.pojo.dto;
|
|
||||||
|
|
||||||
import cn.hutool.core.collection.CollUtil;
|
|
||||||
import lombok.Data;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* rasa 训练参数
|
|
||||||
*/
|
|
||||||
@Data
|
|
||||||
public class RasaTrainParam {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* bash路径
|
|
||||||
*/
|
|
||||||
private String bashPath;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* rasa 训练脚本路径
|
|
||||||
*/
|
|
||||||
private String shellPath;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* rasa 训练配置文件路径
|
|
||||||
*/
|
|
||||||
private String configPath;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* rasa 训练数据路径 (rules.yml nlu.yml)
|
|
||||||
*/
|
|
||||||
private String localDataPath;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* rasa domain.yml 存放路径
|
|
||||||
*/
|
|
||||||
private String domainPath;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* rasa 训练出的模型存放路径
|
|
||||||
*/
|
|
||||||
private String localModelsPath;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 训练出的模型名称
|
|
||||||
*/
|
|
||||||
private String fixedModelName;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 通过list构建RasaTrainParam对象
|
|
||||||
* @param args 参数列表 bashPath = args[0] shellPath = args[1] configPath = args[2]
|
|
||||||
* localDataPath = args[3] domainPath = args[4] localModelsPath = args[5] fixedModelName = args[6]
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public static RasaTrainParam build(List<String> args) {
|
|
||||||
RasaTrainParam rasaTrainParam = new RasaTrainParam();
|
|
||||||
if (CollUtil.isEmpty(args)){
|
|
||||||
return rasaTrainParam;
|
|
||||||
}
|
|
||||||
rasaTrainParam.bashPath = args.get(0);
|
|
||||||
if (args.size() > 1){
|
|
||||||
rasaTrainParam.shellPath = args.get(1);
|
|
||||||
}
|
|
||||||
if (args.size() > 2){
|
|
||||||
rasaTrainParam.configPath = args.get(2);
|
|
||||||
}
|
|
||||||
if (args.size() > 3){
|
|
||||||
rasaTrainParam.localDataPath = args.get(3);
|
|
||||||
}
|
|
||||||
if (args.size() > 4){
|
|
||||||
rasaTrainParam.domainPath = args.get(4);
|
|
||||||
}
|
|
||||||
if (args.size() > 5){
|
|
||||||
rasaTrainParam.localModelsPath = args.get(5);
|
|
||||||
}
|
|
||||||
if (args.size() > 6){
|
|
||||||
rasaTrainParam.fixedModelName = args.get(6);
|
|
||||||
}
|
|
||||||
return rasaTrainParam;
|
|
||||||
}
|
|
||||||
|
|
||||||
public List<String> toList() {
|
|
||||||
return CollUtil.newArrayList(bashPath, shellPath, configPath, localDataPath, domainPath, localModelsPath, fixedModelName);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
@ -1,44 +0,0 @@
|
|||||||
package com.supervision.rasa.pojo.dto;
|
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
|
||||||
import lombok.Data;
|
|
||||||
import lombok.NoArgsConstructor;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
@Data
|
|
||||||
public class RuleYmlTemplate {
|
|
||||||
|
|
||||||
private List<Rule> rules;
|
|
||||||
|
|
||||||
|
|
||||||
@Data
|
|
||||||
public static class Rule {
|
|
||||||
private String rule;
|
|
||||||
|
|
||||||
private List<Step> steps;
|
|
||||||
|
|
||||||
public Rule() {
|
|
||||||
}
|
|
||||||
|
|
||||||
public Rule(String rule, String intent, String action) {
|
|
||||||
this.rule = rule;
|
|
||||||
steps = new ArrayList<>();
|
|
||||||
Step step = new Step(intent, action);
|
|
||||||
steps.add(step);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Data
|
|
||||||
@AllArgsConstructor
|
|
||||||
@NoArgsConstructor
|
|
||||||
public static class Step {
|
|
||||||
private String intent;
|
|
||||||
|
|
||||||
private String action;
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
@ -1,23 +0,0 @@
|
|||||||
package com.supervision.rasa.pojo.dto;
|
|
||||||
|
|
||||||
|
|
||||||
import io.swagger.v3.oas.annotations.media.Schema;
|
|
||||||
import lombok.Data;
|
|
||||||
|
|
||||||
@Data
|
|
||||||
public class Text2vecDataVo {
|
|
||||||
|
|
||||||
@Schema(description = "数据id")
|
|
||||||
private String id;
|
|
||||||
|
|
||||||
@Schema(description = "问题")
|
|
||||||
private String question;
|
|
||||||
|
|
||||||
public Text2vecDataVo(String id, String question) {
|
|
||||||
this.id = id;
|
|
||||||
this.question = question;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Text2vecDataVo() {
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,27 +0,0 @@
|
|||||||
package com.supervision.rasa.pojo.dto;
|
|
||||||
|
|
||||||
|
|
||||||
import io.swagger.v3.oas.annotations.media.Schema;
|
|
||||||
import lombok.Data;
|
|
||||||
|
|
||||||
@Data
|
|
||||||
public class Text2vecMatchesReq {
|
|
||||||
|
|
||||||
@Schema(description = "需要被匹配的语句")
|
|
||||||
private String querySentence;
|
|
||||||
|
|
||||||
@Schema(description = "相似度阈值")
|
|
||||||
private Double threshold;
|
|
||||||
|
|
||||||
public Text2vecMatchesReq() {
|
|
||||||
}
|
|
||||||
|
|
||||||
public Text2vecMatchesReq(String querySentence) {
|
|
||||||
this.querySentence = querySentence;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Text2vecMatchesReq(String querySentence, Double threshold) {
|
|
||||||
this.querySentence = querySentence;
|
|
||||||
this.threshold = threshold;
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,18 +0,0 @@
|
|||||||
package com.supervision.rasa.pojo.dto;
|
|
||||||
|
|
||||||
|
|
||||||
import io.swagger.v3.oas.annotations.media.Schema;
|
|
||||||
import lombok.Data;
|
|
||||||
|
|
||||||
@Data
|
|
||||||
public class Text2vecMatchesRes {
|
|
||||||
|
|
||||||
@Schema(description = "id")
|
|
||||||
private String id;
|
|
||||||
|
|
||||||
@Schema(description = "句子")
|
|
||||||
private String sentence;
|
|
||||||
|
|
||||||
@Schema(description = "相似度")
|
|
||||||
private String similarity;
|
|
||||||
}
|
|
@ -1,21 +0,0 @@
|
|||||||
package com.supervision.rasa.pojo.vo;
|
|
||||||
|
|
||||||
import cn.hutool.core.util.StrUtil;
|
|
||||||
import lombok.Data;
|
|
||||||
|
|
||||||
import java.util.Date;
|
|
||||||
|
|
||||||
@Data
|
|
||||||
public class RasaCmdArgumentVo {
|
|
||||||
|
|
||||||
private String fixedModelName;//fixed-model-name
|
|
||||||
|
|
||||||
private String modelId;
|
|
||||||
|
|
||||||
public void setFixedModelNameIfAbsent(){
|
|
||||||
if (StrUtil.isEmpty(fixedModelName)){
|
|
||||||
fixedModelName = String.valueOf(new Date().getTime()/1000);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
@ -1,30 +0,0 @@
|
|||||||
package com.supervision.rasa.service;
|
|
||||||
|
|
||||||
import com.supervision.rasa.pojo.dto.QuestionAnswerDTO;
|
|
||||||
import com.supervision.rasa.pojo.vo.RasaCmdArgumentVo;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.concurrent.ExecutionException;
|
|
||||||
import java.util.concurrent.TimeoutException;
|
|
||||||
import java.util.function.Predicate;
|
|
||||||
|
|
||||||
public interface RasaCmdService {
|
|
||||||
|
|
||||||
String trainExec(RasaCmdArgumentVo argument) throws IOException, ExecutionException, InterruptedException, TimeoutException;
|
|
||||||
|
|
||||||
String runExec( RasaCmdArgumentVo argument) throws ExecutionException, InterruptedException, TimeoutException;
|
|
||||||
|
|
||||||
|
|
||||||
List<String> execCmd(List<String> cmds, Predicate<String> endPredicate, long timeOut) throws InterruptedException, ExecutionException, TimeoutException;
|
|
||||||
|
|
||||||
String getShellPath(String shell);
|
|
||||||
|
|
||||||
boolean deployRasa() throws Exception;
|
|
||||||
|
|
||||||
Map<String, QuestionAnswerDTO> generateRasaYml(String path);
|
|
||||||
|
|
||||||
Map<String, QuestionAnswerDTO> getIntentCodeAndIdMap();
|
|
||||||
|
|
||||||
}
|
|
@ -1,10 +0,0 @@
|
|||||||
package com.supervision.rasa.service;
|
|
||||||
|
|
||||||
import org.springframework.web.multipart.MultipartFile;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
|
|
||||||
public interface RasaFileService {
|
|
||||||
|
|
||||||
void saveRasaFile(MultipartFile file,String modelId) throws IOException;
|
|
||||||
}
|
|
@ -1,149 +0,0 @@
|
|||||||
package com.supervision.rasa.service;
|
|
||||||
|
|
||||||
import cn.hutool.core.collection.CollectionUtil;
|
|
||||||
import cn.hutool.core.io.FileUtil;
|
|
||||||
import cn.hutool.core.lang.Assert;
|
|
||||||
import cn.hutool.core.util.StrUtil;
|
|
||||||
import com.supervision.model.RasaModelInfo;
|
|
||||||
import com.supervision.rasa.constant.RasaConstant;
|
|
||||||
import com.supervision.rasa.pojo.dto.RasaRunParam;
|
|
||||||
import com.supervision.rasa.util.PortUtil;
|
|
||||||
import com.supervision.service.RasaModeService;
|
|
||||||
import lombok.RequiredArgsConstructor;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.springframework.beans.factory.annotation.Value;
|
|
||||||
import org.springframework.scheduling.annotation.Scheduled;
|
|
||||||
import org.springframework.stereotype.Component;
|
|
||||||
|
|
||||||
import java.io.File;
|
|
||||||
import java.io.FileFilter;
|
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.Comparator;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.concurrent.ExecutionException;
|
|
||||||
import java.util.concurrent.TimeoutException;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
|
|
||||||
@Slf4j
|
|
||||||
@Component
|
|
||||||
@RequiredArgsConstructor
|
|
||||||
public class RasaModelManager {
|
|
||||||
|
|
||||||
@Value("${rasa.models-path}")
|
|
||||||
private String modelsPath;
|
|
||||||
|
|
||||||
private final RasaModeService rasaModeService;
|
|
||||||
|
|
||||||
private final RasaCmdService rasaCmdService;
|
|
||||||
|
|
||||||
private boolean wakeUpInterruptServerRunning = false;
|
|
||||||
|
|
||||||
public void wakeUpInterruptServer(){
|
|
||||||
|
|
||||||
// 1. 查找出记录表中存活的服务
|
|
||||||
List<RasaModelInfo> rasaModelInfos = rasaModeService.listActive();
|
|
||||||
List<RasaModelInfo> activeRasaList = rasaModelInfos.stream().filter(info -> CollectionUtil.isNotEmpty(info.getRunCmd()) && null != info.getPort()).collect(Collectors.toList());
|
|
||||||
|
|
||||||
if (CollectionUtil.isEmpty(activeRasaList)){
|
|
||||||
log.info("wakeUpInterruptService: no rasa service need wake up ...");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// 2. 重新启动中断的服务
|
|
||||||
for (RasaModelInfo rasaModelInfo : activeRasaList) {
|
|
||||||
if (PortUtil.portIsActive(rasaModelInfo.getPort())) {
|
|
||||||
log.info("wakeUpInterruptServer: port:{} is run..", rasaModelInfo.getPort());
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
|
||||||
RasaRunParam rasaRunParam = RasaRunParam.build(rasaModelInfo.getRunCmd());
|
|
||||||
rasaRunParam.setPort(String.valueOf(rasaModelInfo.getPort()));
|
|
||||||
rasaRunParam.setShellPath(rasaCmdService.getShellPath(RasaConstant.RUN_SHELL));
|
|
||||||
String rasaModelPath = rasaRunParam.getRasaModelPath();
|
|
||||||
if (StrUtil.isEmpty(rasaModelPath) || !FileUtil.exist(rasaModelPath)) {
|
|
||||||
log.info("wakeUpInterruptServer: rasa model path {} not exist,attempt find last ...", rasaModelPath);
|
|
||||||
String modeParentPath = replaceDuplicateSeparator(String.join(File.separator, modelsPath));
|
|
||||||
String fixedModePath = listLastFilePath(modeParentPath, f -> f.getName().matches("-?\\d+(\\.\\d+)?.tar.gz"));
|
|
||||||
Assert.notEmpty(fixedModePath, "wakeUpInterruptService: no rasa model in path {} ", modeParentPath);
|
|
||||||
rasaRunParam.setRasaModelPath(fixedModePath);
|
|
||||||
}
|
|
||||||
log.info("wakeUpInterruptServer : use fixedModePath :{}", rasaRunParam.getRasaModelPath());
|
|
||||||
List<String> outMessageList = rasaCmdService.execCmd(rasaRunParam.toList(),
|
|
||||||
s -> StrUtil.isNotBlank(s) && s.contains(RasaConstant.RUN_SUCCESS_MESSAGE), 300);
|
|
||||||
|
|
||||||
rasaModelInfo.setRunLog(String.join("\r\n", outMessageList));
|
|
||||||
rasaModelInfo.setRunCmd(rasaRunParam.toList());
|
|
||||||
rasaModeService.updateById(rasaModelInfo);
|
|
||||||
|
|
||||||
if (!runIsSuccess(outMessageList)) {
|
|
||||||
log.info("wakeUpInterruptServer: restart server port for {} failed,details info : {}", rasaModelInfo.getPort(), String.join("\r\n", outMessageList));
|
|
||||||
}
|
|
||||||
} catch (InterruptedException | ExecutionException | TimeoutException e) {
|
|
||||||
log.info("wakeUpInterruptServer: restart server port for {} failed", rasaModelInfo.getPort());
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
log.info("wakeUpInterruptServer: restart server port for {} success ", rasaModelInfo.getPort());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
//默认每十分钟执行一次
|
|
||||||
@Scheduled(cron = "${rasa.wakeup.cron:0 */10 * * * ?}")
|
|
||||||
public void wakeUpInterruptServerScheduled() {
|
|
||||||
log.info("wakeUpInterruptServerScheduled: Scheduled is run .... wakeUpInterruptServerRunning is :{}", wakeUpInterruptServerRunning);
|
|
||||||
if (wakeUpInterruptServerRunning) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
try {
|
|
||||||
wakeUpInterruptServerRunning = true;
|
|
||||||
wakeUpInterruptServer();
|
|
||||||
} finally {
|
|
||||||
wakeUpInterruptServerRunning = false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private boolean runIsSuccess(List<String> messageList){
|
|
||||||
|
|
||||||
return containKey(messageList,RasaConstant.RUN_SUCCESS_MESSAGE);
|
|
||||||
}
|
|
||||||
|
|
||||||
private boolean containKey(List<String> messageList,String keyWord){
|
|
||||||
|
|
||||||
if (CollectionUtil.isEmpty(messageList)){
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (StrUtil.isEmpty(keyWord)){
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return messageList.stream().anyMatch(s->StrUtil.isNotEmpty(s) && s.contains(keyWord));
|
|
||||||
}
|
|
||||||
|
|
||||||
private String replaceDuplicateSeparator(String path){
|
|
||||||
|
|
||||||
if (StrUtil.isEmpty(path)){
|
|
||||||
return path;
|
|
||||||
}
|
|
||||||
|
|
||||||
return path.replace(File.separator + File.separator, File.separator);
|
|
||||||
}
|
|
||||||
|
|
||||||
private String listLastFilePath(String path, FileFilter filter){
|
|
||||||
File file = listLastFile(path, filter);
|
|
||||||
if (null == file){
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
return file.getPath();
|
|
||||||
}
|
|
||||||
|
|
||||||
private File listLastFile(String path,FileFilter filter){
|
|
||||||
File file = new File(path);
|
|
||||||
File[] files = file.listFiles(filter);
|
|
||||||
if (null == files){
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
return Arrays.stream(files).max(Comparator.comparing(File::getName)).orElse(null);
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,11 +0,0 @@
|
|||||||
package com.supervision.rasa.service;
|
|
||||||
|
|
||||||
import com.supervision.vo.rasa.RasaTalkVo;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
public interface RasaTalkService {
|
|
||||||
|
|
||||||
|
|
||||||
List<String> talkRasa(RasaTalkVo rasaTalkVo) ;
|
|
||||||
}
|
|
@ -1,34 +0,0 @@
|
|||||||
package com.supervision.rasa.service;
|
|
||||||
|
|
||||||
import com.supervision.rasa.pojo.dto.QuestionAnswerDTO;
|
|
||||||
import com.supervision.rasa.pojo.dto.Text2vecDataVo;
|
|
||||||
import com.supervision.rasa.pojo.dto.Text2vecMatchesReq;
|
|
||||||
import com.supervision.rasa.pojo.dto.Text2vecMatchesRes;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
public interface Text2vecService {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 更新数据
|
|
||||||
* @param text2vecDataVoList 数据集合
|
|
||||||
* @return 是否更新成功
|
|
||||||
*/
|
|
||||||
boolean updateDataset(List<Text2vecDataVo> text2vecDataVoList);
|
|
||||||
|
|
||||||
boolean updateDataset(Map<String, QuestionAnswerDTO> questionAnswerDTOMap);
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 语句匹配
|
|
||||||
* @param text2vecMatchesReq
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
List<Text2vecMatchesRes> matches(Text2vecMatchesReq text2vecMatchesReq);
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 初始化语料库
|
|
||||||
*/
|
|
||||||
void initText2vecDataset();
|
|
||||||
}
|
|
@ -1,69 +0,0 @@
|
|||||||
package com.supervision.rasa.service.impl;
|
|
||||||
|
|
||||||
import cn.hutool.core.io.FileUtil;
|
|
||||||
import cn.hutool.core.util.ZipUtil;
|
|
||||||
import com.supervision.rasa.service.RasaFileService;
|
|
||||||
import lombok.RequiredArgsConstructor;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.springframework.beans.factory.annotation.Value;
|
|
||||||
import org.springframework.stereotype.Service;
|
|
||||||
import org.springframework.web.multipart.MultipartFile;
|
|
||||||
|
|
||||||
import java.io.File;
|
|
||||||
import java.io.IOException;
|
|
||||||
|
|
||||||
@Service
|
|
||||||
@Slf4j
|
|
||||||
@RequiredArgsConstructor
|
|
||||||
public class RasaFileServiceImpl implements RasaFileService {
|
|
||||||
|
|
||||||
|
|
||||||
@Value("${rasa.data-path:/home/rasa/model_resource/}")
|
|
||||||
private String rasaFilePath;
|
|
||||||
|
|
||||||
@Value("${rasa.file-name:rasa.zip}")
|
|
||||||
private String rasaFileName;
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void saveRasaFile(MultipartFile file,String modelId) throws IOException {
|
|
||||||
|
|
||||||
|
|
||||||
String suffix = "_back";
|
|
||||||
String rasaFullPath = String.join(File.separator, rasaFilePath,modelId, rasaFileName);
|
|
||||||
String rasaBackFullPath = rasaFullPath + suffix;
|
|
||||||
|
|
||||||
//初始化目录
|
|
||||||
File dir = new File(String.join(File.separator, rasaFilePath,modelId));
|
|
||||||
if (!dir.exists()){
|
|
||||||
FileUtil.mkdir(dir);
|
|
||||||
}
|
|
||||||
|
|
||||||
//1.检查路径下是否存在文件
|
|
||||||
File oldFile = new File(rasaFullPath);
|
|
||||||
if (oldFile.exists()){
|
|
||||||
//1.1 如果存在文件,先备份文件
|
|
||||||
FileUtil.rename(oldFile,rasaBackFullPath,true);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
try {
|
|
||||||
//2.把流文件保存到本地
|
|
||||||
file.transferTo(new File(rasaFullPath));
|
|
||||||
|
|
||||||
//3.解压文件
|
|
||||||
ZipUtil.unzip(rasaFullPath,String.join(File.separator, rasaFilePath,modelId));
|
|
||||||
|
|
||||||
//4.删除备份文件
|
|
||||||
FileUtil.del(rasaBackFullPath);
|
|
||||||
} catch (IOException e) {
|
|
||||||
// 恢复文件
|
|
||||||
File backFile = new File(rasaBackFullPath);
|
|
||||||
if (backFile.exists()){ //恢复文件
|
|
||||||
FileUtil.rename(backFile,rasaFileName,true);
|
|
||||||
}
|
|
||||||
throw new IOException(e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
@ -1,65 +0,0 @@
|
|||||||
package com.supervision.rasa.service.impl;
|
|
||||||
|
|
||||||
import cn.hutool.core.collection.CollUtil;
|
|
||||||
import cn.hutool.core.util.StrUtil;
|
|
||||||
import cn.hutool.http.HttpUtil;
|
|
||||||
import cn.hutool.json.JSONUtil;
|
|
||||||
import com.supervision.model.RasaModelInfo;
|
|
||||||
import com.supervision.rasa.pojo.dto.RasaReqDTO;
|
|
||||||
import com.supervision.rasa.pojo.dto.RasaResDTO;
|
|
||||||
import com.supervision.rasa.pojo.dto.Text2vecMatchesReq;
|
|
||||||
import com.supervision.rasa.pojo.dto.Text2vecMatchesRes;
|
|
||||||
import com.supervision.rasa.service.Text2vecService;
|
|
||||||
import com.supervision.vo.rasa.RasaTalkVo;
|
|
||||||
import com.supervision.rasa.service.RasaTalkService;
|
|
||||||
import com.supervision.service.RasaModeService;
|
|
||||||
import lombok.RequiredArgsConstructor;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.springframework.beans.factory.annotation.Value;
|
|
||||||
import org.springframework.stereotype.Service;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.UUID;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
@Service
|
|
||||||
@Slf4j
|
|
||||||
@RequiredArgsConstructor
|
|
||||||
public class RasaTalkServiceImpl implements RasaTalkService {
|
|
||||||
|
|
||||||
@Value("${rasa.url}")
|
|
||||||
private String rasaUrl;
|
|
||||||
|
|
||||||
private final RasaModeService rasaModeService;
|
|
||||||
|
|
||||||
private final Text2vecService text2vecService;
|
|
||||||
@Override
|
|
||||||
public List<String> talkRasa(RasaTalkVo rasaTalkVo) {
|
|
||||||
|
|
||||||
RasaModelInfo rasaModelInfo = rasaModeService.queryByModelId("1");
|
|
||||||
|
|
||||||
RasaReqDTO rasaReqDTO = new RasaReqDTO();
|
|
||||||
rasaReqDTO.setSender(rasaTalkVo.getSessionId());
|
|
||||||
rasaReqDTO.setMessage(rasaTalkVo.getQuestion());
|
|
||||||
|
|
||||||
String rasaUrl = getRasaUrl(rasaModelInfo.getPort());
|
|
||||||
log.info("talkRasa: url is: {}",rasaUrl);
|
|
||||||
|
|
||||||
rasaReqDTO.setSender(UUID.randomUUID().toString());
|
|
||||||
String post = HttpUtil.post(rasaUrl, JSONUtil.toJsonStr(rasaReqDTO));
|
|
||||||
|
|
||||||
List<RasaResDTO> list = JSONUtil.toList(post, RasaResDTO.class);
|
|
||||||
|
|
||||||
log.info("talkRasa: rasa talk result is: {}",JSONUtil.toJsonStr(list));
|
|
||||||
if (CollUtil.isNotEmpty(list)){
|
|
||||||
return list.stream().map(RasaResDTO::getText).collect(Collectors.toList());
|
|
||||||
}
|
|
||||||
log.info("talkRasa: rasa talk result is empty , redirect for text2vecService ...");
|
|
||||||
return text2vecService.matches(new Text2vecMatchesReq(rasaTalkVo.getQuestion()))
|
|
||||||
.stream().map(Text2vecMatchesRes::getId).collect(Collectors.toList());
|
|
||||||
}
|
|
||||||
|
|
||||||
private String getRasaUrl(int port){
|
|
||||||
|
|
||||||
return StrUtil.format(rasaUrl, port);
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,42 +0,0 @@
|
|||||||
package com.supervision.rasa.util;
|
|
||||||
|
|
||||||
import cn.hutool.core.collection.CollectionUtil;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.net.Socket;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
@Slf4j
|
|
||||||
public class PortUtil {
|
|
||||||
|
|
||||||
|
|
||||||
public static boolean portIsActive(int port){
|
|
||||||
try {
|
|
||||||
Socket socket = new Socket("localhost", port);
|
|
||||||
socket.close();
|
|
||||||
return true;
|
|
||||||
} catch (IOException e) {
|
|
||||||
log.info("portIsActive: port:{} connect error",port);
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
public static int findUnusedPort(int minPort, int maxPort, List<Integer> excludePorts){
|
|
||||||
|
|
||||||
if (maxPort < minPort){
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int port = minPort; port < maxPort; port++) {
|
|
||||||
if (CollectionUtil.isNotEmpty(excludePorts)&& excludePorts.contains(port)){
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (!portIsActive(port)){
|
|
||||||
return port;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,10 +0,0 @@
|
|||||||
spring:
|
|
||||||
cloud:
|
|
||||||
nacos:
|
|
||||||
config:
|
|
||||||
server-addr: 192.168.10.137:8848
|
|
||||||
file-extension: yml
|
|
||||||
namespace: b9eea377-79ec-4ba5-9cc2-354f7bd5181e
|
|
||||||
discovery:
|
|
||||||
server-addr: 192.168.10.137:8848
|
|
||||||
namespace: b9eea377-79ec-4ba5-9cc2-354f7bd5181e
|
|
@ -1,10 +0,0 @@
|
|||||||
spring:
|
|
||||||
cloud:
|
|
||||||
nacos:
|
|
||||||
config:
|
|
||||||
server-addr: 192.168.10.137:8848
|
|
||||||
file-extension: yml
|
|
||||||
namespace: 88e1f674-1fbc-4021-9ff1-60b94ee13ef0
|
|
||||||
discovery:
|
|
||||||
server-addr: 192.168.10.137:8848
|
|
||||||
namespace: 88e1f674-1fbc-4021-9ff1-60b94ee13ef0
|
|
@ -1,5 +0,0 @@
|
|||||||
spring:
|
|
||||||
profiles:
|
|
||||||
active: test
|
|
||||||
application:
|
|
||||||
name: virtual-patient-rasa
|
|
@ -1,40 +0,0 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
<configuration>
|
|
||||||
<include resource="org/springframework/boot/logging/logback/base.xml"/>
|
|
||||||
|
|
||||||
<logger name="org.springframework.web" level="INFO"/>
|
|
||||||
|
|
||||||
<!-- 开发环境 -->
|
|
||||||
<springProfile name="local">
|
|
||||||
<logger name="org.springframework.web" level="INFO"/>
|
|
||||||
<logger name="org.springboot.sample" level="INFO"/>
|
|
||||||
<logger name="com.supervision" level="DEBUG"/>
|
|
||||||
<logger name="org.springframework.scheduling" level="INFO"/>
|
|
||||||
</springProfile>
|
|
||||||
|
|
||||||
<!-- 测试环境,生产环境 -->
|
|
||||||
<springProfile name="dev,test,prod">
|
|
||||||
<logger name="org.springframework.web" level="INFO"/>
|
|
||||||
<logger name="org.springboot.sample" level="INFO"/>
|
|
||||||
<logger name="com.supervision" level="INFO"/>
|
|
||||||
<logger name="org.springframework.scheduling" level="INFO"/>
|
|
||||||
<root level="INFO">
|
|
||||||
<appender name="DAILY_LOG" class="ch.qos.logback.core.rolling.RollingFileAppender">
|
|
||||||
<!-- 服务器中当天的日志 -->
|
|
||||||
<file>/data/vp/log/virtual-patient-rasa.log</file>
|
|
||||||
<rollingPolicy class="ch.qos.logback.core.rolling.TimeBasedRollingPolicy">
|
|
||||||
<!-- 服务器归档日志 -->
|
|
||||||
<fileNamePattern>/data/vp/log/history/virtual-patient-rasa-%d{yyyy-MM-dd}.log</fileNamePattern>
|
|
||||||
</rollingPolicy>
|
|
||||||
<encoder>
|
|
||||||
<pattern>%date [%thread] %-5level %logger{35} - %msg%n</pattern>
|
|
||||||
</encoder>
|
|
||||||
</appender>
|
|
||||||
</root>
|
|
||||||
</springProfile>
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
</configuration>
|
|
@ -1,14 +0,0 @@
|
|||||||
!/usr/bin/env bash
|
|
||||||
PORT=$1
|
|
||||||
|
|
||||||
KPID=$(lsof -t -i:$PORT)
|
|
||||||
|
|
||||||
if [ -n "$KPID" ]; then
|
|
||||||
echo "PORT ===>> $PORT map PID is not empty. pid is $KPID "
|
|
||||||
echo "cmd is kill -9 $KPID "
|
|
||||||
kill -9 $KPID
|
|
||||||
else
|
|
||||||
echo "PORT ===>> $PORT map PID is empty."
|
|
||||||
fi
|
|
||||||
|
|
||||||
|
|
@ -1,16 +0,0 @@
|
|||||||
#!/usr/bin/env bash
|
|
||||||
|
|
||||||
source /root/anaconda3/etc/profile.d/conda.sh
|
|
||||||
|
|
||||||
conda activate rasa3
|
|
||||||
|
|
||||||
MODELS_PATH=$1
|
|
||||||
ENDPOINTS=$2
|
|
||||||
PORT=$3
|
|
||||||
|
|
||||||
echo "shell cmd is rasa run -m $MODELS_PATH --enable-api --cors "*" --debug --endpoints $ENDPOINTS --port $PORT "
|
|
||||||
|
|
||||||
rasa run -m $MODELS_PATH --enable-api --cors "*" --debug --endpoints $ENDPOINTS --port $PORT
|
|
||||||
|
|
||||||
echo 'start success ...'
|
|
||||||
|
|
@ -1,21 +0,0 @@
|
|||||||
#!/usr/bin/env bash
|
|
||||||
|
|
||||||
source /root/anaconda3/etc/profile.d/conda.sh
|
|
||||||
|
|
||||||
conda activate rasa3
|
|
||||||
|
|
||||||
CONFIG=$1
|
|
||||||
DATA=$2
|
|
||||||
DOMAIN=$3
|
|
||||||
OUT=$4
|
|
||||||
FIXED_MODEL_NAME=$5
|
|
||||||
|
|
||||||
echo "tran shell is run..."
|
|
||||||
|
|
||||||
echo "cmd is rasa train --config $CONFIG --data $DATA --domain $DOMAIN --out $OUT --fixed-model-name $FIXED_MODEL_NAME"
|
|
||||||
|
|
||||||
rasa train --config $CONFIG --data $DATA --domain $DOMAIN --out $OUT --fixed-model-name $FIXED_MODEL_NAME
|
|
||||||
|
|
||||||
echo 'train done'
|
|
||||||
|
|
||||||
|
|
@ -1,28 +0,0 @@
|
|||||||
recipe: default.v1
|
|
||||||
language: zh
|
|
||||||
|
|
||||||
pipeline:
|
|
||||||
- name: JiebaTokenizer
|
|
||||||
- name: LanguageModelFeaturizer
|
|
||||||
model_name: bert
|
|
||||||
model_weights: bert-base-chinese
|
|
||||||
- name: RegexFeaturizer
|
|
||||||
- name: DIETClassifier
|
|
||||||
epochs: 100
|
|
||||||
learning_rate: 0.001
|
|
||||||
tensorboard_log_directory: ./log
|
|
||||||
- name: ResponseSelector
|
|
||||||
epochs: 100
|
|
||||||
learning_rate: 0.001
|
|
||||||
- name: FallbackClassifier
|
|
||||||
threshold: 0.4
|
|
||||||
ambiguity_threshold: 0.1
|
|
||||||
- name: EntitySynonymMapper
|
|
||||||
|
|
||||||
policies:
|
|
||||||
- name: MemoizationPolicy
|
|
||||||
- name: TEDPolicy
|
|
||||||
- name: RulePolicy
|
|
||||||
core_fallback_threshold: 0.4
|
|
||||||
core_fallback_action_name: "action_default_fallback"
|
|
||||||
enable_fallback_prediction: True
|
|
@ -1,23 +0,0 @@
|
|||||||
version: "3.1"
|
|
||||||
|
|
||||||
intents:
|
|
||||||
<#list intents as intent>
|
|
||||||
- ${intent}
|
|
||||||
</#list>
|
|
||||||
|
|
||||||
responses:
|
|
||||||
<#list responses?keys as response>
|
|
||||||
${response}:
|
|
||||||
<#list responses[response] as item>
|
|
||||||
- text: "${item}"
|
|
||||||
</#list>
|
|
||||||
</#list>
|
|
||||||
|
|
||||||
actions:
|
|
||||||
<#list actions as action>
|
|
||||||
- ${action}
|
|
||||||
</#list>
|
|
||||||
|
|
||||||
session_config:
|
|
||||||
session_expiration_time: 60
|
|
||||||
carry_over_slots_to_new_session: true
|
|
@ -1,10 +0,0 @@
|
|||||||
version: "3.1"
|
|
||||||
|
|
||||||
nlu:
|
|
||||||
<#list nlu as item>
|
|
||||||
- intent: ${item.intent}
|
|
||||||
examples: |
|
|
||||||
<#list item.examples as example>
|
|
||||||
- ${example}
|
|
||||||
</#list>
|
|
||||||
</#list>
|
|
@ -1,12 +0,0 @@
|
|||||||
version: "3.1"
|
|
||||||
|
|
||||||
rules:
|
|
||||||
|
|
||||||
<#list rules as item>
|
|
||||||
- rule: ${item.rule}
|
|
||||||
steps:
|
|
||||||
<#list item.steps as ss>
|
|
||||||
- intent: ${ss.intent}
|
|
||||||
- action: ${ss.action}
|
|
||||||
</#list>
|
|
||||||
</#list>
|
|
@ -1,47 +0,0 @@
|
|||||||
package com.supervision.rasa;
|
|
||||||
|
|
||||||
import com.supervision.rasa.pojo.dto.QuestionAnswerDTO;
|
|
||||||
import com.supervision.rasa.pojo.dto.Text2vecDataVo;
|
|
||||||
import com.supervision.rasa.service.RasaCmdService;
|
|
||||||
import com.supervision.rasa.service.Text2vecService;
|
|
||||||
import com.supervision.util.RedisSequenceUtil;
|
|
||||||
import org.junit.jupiter.api.Test;
|
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
|
||||||
import org.springframework.boot.test.context.SpringBootTest;
|
|
||||||
|
|
||||||
import java.io.File;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
|
|
||||||
@SpringBootTest
|
|
||||||
class VirtualPatientRasaApplicationTests {
|
|
||||||
|
|
||||||
@Autowired
|
|
||||||
private RasaCmdService rasaCmdService;
|
|
||||||
|
|
||||||
@Autowired
|
|
||||||
private Text2vecService text2vecService;
|
|
||||||
@Test
|
|
||||||
void contextLoads() {
|
|
||||||
/*Map<String, QuestionAnswerDTO> questionAnswerDTOMap = rasaCmdService.generateRasaYml("F:\\tmp\\rasa");
|
|
||||||
System.out.println(questionAnswerDTOMap);*/
|
|
||||||
|
|
||||||
/* Map<String, QuestionAnswerDTO> questionAnswerDTOMap = rasaCmdService.generateRasaYml(String.join(File.separator, "F:\\tmp\\rasa"));
|
|
||||||
List<Text2vecDataVo> text2vecDataVoList = questionAnswerDTOMap.entrySet().stream()
|
|
||||||
.flatMap(entry -> entry.getValue().getQuestionList().stream()
|
|
||||||
.map(question -> new Text2vecDataVo(entry.getKey(), question))).collect(Collectors.toList());
|
|
||||||
text2vecService.updateDataset(text2vecDataVoList);*/
|
|
||||||
|
|
||||||
String complexDiseaseNo = RedisSequenceUtil.getComplexDiseaseNo();
|
|
||||||
|
|
||||||
String processNo = RedisSequenceUtil.getProcessNo();
|
|
||||||
|
|
||||||
String questionLibraryCode = RedisSequenceUtil.getQuestionLibraryCode(()->0L);
|
|
||||||
|
|
||||||
String questionLibraryCode1 = RedisSequenceUtil.getQuestionLibraryDefaultAnswerCode(() -> 0L);
|
|
||||||
|
|
||||||
System.out.println("complexDiseaseNo"+complexDiseaseNo+" processNo"+processNo+" questionLibraryCode"+questionLibraryCode+" questionLibraryCode1"+questionLibraryCode1);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
Loading…
Reference in New Issue