rasa: create the text2vec_env image
parent
097019a360
commit
597c865d84
@@ -0,0 +1,14 @@
# Set the base image
FROM rasa_dev:1.0.0

COPY ./bert_chinese /usr/local/text2vec/bert_chinese
COPY ./app.py /usr/local/text2vec/

RUN /root/anaconda3/condabin/conda create --name text2vec_env python=3.9 -y && \
    /root/anaconda3/condabin/conda run --no-capture-output --name text2vec_env pip install torch && \
    /root/anaconda3/condabin/conda run --no-capture-output --name text2vec_env pip install flask && \
    /root/anaconda3/condabin/conda run --no-capture-output --name text2vec_env pip install text2vec -i https://pypi.tuna.tsinghua.edu.cn/simple

EXPOSE 5000

#CMD [ "/root/anaconda3/condabin/conda","run","--no-capture-output","--name","text2vec_env", "python", "/usr/local/text2vec/app.py"]
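For reference, a minimal build-and-run sketch for this image; the tag text2vec:1.0.0 and the build directory are assumptions, and the container command mirrors the commented-out CMD above:

# Build from the directory containing the Dockerfile, bert_chinese/ and app.py
docker build -t text2vec:1.0.0 .

# Run the service, publishing the Flask port declared by EXPOSE
docker run -d -p 5000:5000 text2vec:1.0.0 \
    /root/anaconda3/condabin/conda run --no-capture-output --name text2vec_env \
    python /usr/local/text2vec/app.py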
@@ -0,0 +1,101 @@
from flask import Flask, request, jsonify
from text2vec import SentenceModel
import numpy as np
import traceback
import json
import os

app = Flask(__name__)

# Directory containing this script
current_dir = os.path.dirname(os.path.abspath(__file__))

# Path to the BERT model
model_path = os.path.join(current_dir, 'bert_chinese')
model = None
questions_data = []  # holds the question dataset
sentences = []
sentence_embeddings = []
default_threshold = 0.7
# Path to the dataset file
dataset_file_path = os.path.join(current_dir, 'question.json')

# Initialization: load the model and the dataset
def initialize_app():
    global model, questions_data, sentence_embeddings
    model = SentenceModel(model_path)
    load_dataset()

# Load the dataset
def load_dataset():
    global questions_data, sentence_embeddings, sentences
    try:
        with open(dataset_file_path, 'r', encoding='utf-8') as file:
            questions_data = json.load(file)
        sentences = [item["question"] for item in questions_data]
        # Re-encode the sentences: encode each one once, then L2-normalize
        embeddings = [model.encode(sent) for sent in sentences]
        sentence_embeddings = [emb / np.linalg.norm(emb) for emb in embeddings]
    except Exception as e:
        traceback.print_exc()
        print(f"Error loading dataset: {str(e)}")

# Error handler
@app.errorhandler(Exception)
def handle_error(e):
    traceback.print_exc()
    return jsonify({'error': str(e)}), 500

# Initialize the application
initialize_app()

# Endpoint that replaces the dataset
@app.route('/update_dataset', methods=['POST'])
def update_dataset():
    global questions_data, sentence_embeddings
    new_dataset = request.json or []

    # Persist the new dataset, then reload it
    try:
        with open(dataset_file_path, 'w', encoding='utf-8') as file:
            json.dump(new_dataset, file, ensure_ascii=False, indent=2)
        load_dataset()
        return jsonify({'status': 'success', 'message': 'Dataset updated successfully'})
    except Exception as e:
        traceback.print_exc()
        return jsonify({'error': f'Error updating dataset: {str(e)}'}), 500

# Endpoint that returns matches above the threshold
@app.route('/matches', methods=['POST'])
def get_matches():
    query_sentence = request.json.get('query_sentence', '')
    query_embedding = model.encode([query_sentence])[0]
    # Normalize the vector to unit length
    query_embedding = query_embedding / np.linalg.norm(query_embedding)
    # Read the threshold parameter; fall back to the default if absent
    threshold = request.json.get('threshold', default_threshold)
    # Compute cosine similarities (dot products of unit vectors)
    similarities = [embedding.dot(query_embedding) for embedding in sentence_embeddings]

    # Collect every match whose similarity clears the threshold
    matches = [{'id': questions_data[i]["id"], 'sentence': sentences[i], 'similarity': float(similarity)}
               for i, similarity in enumerate(similarities) if similarity >= threshold]

    return jsonify({'status': 'success', 'matches': matches} if matches else {'status': 'success', 'message': 'No matches found'})

# Endpoint that returns every similarity
@app.route('/get_all_similarities', methods=['POST'])
def get_all_similarities():
    query_sentence = request.json.get('query_sentence', '')
    query_embedding = model.encode([query_sentence])[0]
    # Normalize the vector to unit length
    query_embedding = query_embedding / np.linalg.norm(query_embedding)

    # Compute the similarity and matching text for every entry
    results = [{'id': questions_data[i]["id"], 'sentence': sentences[i], 'similarity': float(embedding.dot(query_embedding))}
               for i, embedding in enumerate(sentence_embeddings)]

    # Return all similarities with their sentences
    return jsonify({'status': 'success', 'results': results})

if __name__ == '__main__':
    app.run(debug=True, host='0.0.0.0')
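A quick smoke test of the two main endpoints, assuming the service is reachable on localhost:5000; the dataset format is inferred from the fields app.py reads (id and question), and the sample sentences are illustrative only:

# Replace the dataset (fields inferred from app.py: id and question)
curl -X POST http://localhost:5000/update_dataset \
    -H 'Content-Type: application/json' \
    -d '[{"id": 1, "question": "What symptoms do you have?"}, {"id": 2, "question": "How long have you had a fever?"}]'

# Query for matches above an explicit threshold
curl -X POST http://localhost:5000/matches \
    -H 'Content-Type: application/json' \
    -d '{"query_sentence": "Do you have a fever?", "threshold": 0.6}'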
@@ -0,0 +1,32 @@
{
  "_name_or_path": "hfl/chinese-macbert-base",
  "architectures": [
    "BertModel"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "directionality": "bidi",
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "position_embedding_type": "absolute",
  "torch_dtype": "float32",
  "transformers_version": "4.12.3",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 21128
}
Binary file not shown.
File diff suppressed because it is too large.
@@ -0,0 +1,7 @@
#!/bin/bash

# Start the text2vec server in the background so the jar can start afterwards
/root/anaconda3/condabin/conda run --no-capture-output --name text2vec_env python /usr/local/text2vec/app.py &

# Start the jar
java -jar -Duser.timezone=Asia/Shanghai /data/vp/virtual-patient-rasa-1.0-SNAPSHOT.jar "$@"
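Any arguments given to the script are forwarded to the jar via "$@"; a hypothetical invocation (the flag name is illustrative, not taken from the jar):

# Example: forward a Spring-style port override to the jar
./start.sh --server.port=8081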