from flask import Flask, request, jsonify
from text2vec import SentenceModel
import numpy as np
import traceback
import json
import os

app = Flask(__name__)

# Directory containing this script
current_dir = os.path.dirname(os.path.abspath(__file__))
# Path to the BERT model
model_path = os.path.join(current_dir, 'bert_chinese')

model = None
questions_data = []       # question records loaded from the dataset
sentences = []            # question texts, parallel to questions_data
sentence_embeddings = []  # unit-normalized sentence embeddings
default_threshold = 0.7

# Dataset file path
dataset_file_path = os.path.join(current_dir, 'question.json')


# Initialization: load the model and the dataset
def initialize_app():
    global model
    model = SentenceModel(model_path)
    load_dataset()


# Load (or reload) the dataset and re-encode its sentences
def load_dataset():
    global questions_data, sentences, sentence_embeddings
    try:
        with open(dataset_file_path, 'r', encoding='utf-8') as file:
            questions_data = json.load(file)
        sentences = [item["question"] for item in questions_data]
        # Encode each sentence once and unit-normalize it, so the dot product
        # with a unit query vector below equals cosine similarity
        embeddings = [model.encode(sent) for sent in sentences]
        sentence_embeddings = [emb / np.linalg.norm(emb) for emb in embeddings]
    except Exception as e:
        traceback.print_exc()
        print(f"Error loading dataset: {str(e)}")


# Global error handler
@app.errorhandler(Exception)
def handle_error(e):
    traceback.print_exc()
    return jsonify({'error': str(e)}), 500


# Initialize the application
initialize_app()


# Endpoint: replace the dataset
@app.route('/update_dataset', methods=['POST'])
def update_dataset():
    new_dataset = request.json or []
    try:
        # Persist the new dataset, then reload it and re-encode the sentences
        with open(dataset_file_path, 'w', encoding='utf-8') as file:
            json.dump(new_dataset, file, ensure_ascii=False, indent=2)
        load_dataset()
        return jsonify({'status': 'success', 'message': 'Dataset updated successfully'})
    except Exception as e:
        traceback.print_exc()
        return jsonify({'error': f'Error updating dataset: {str(e)}'}), 500


# Endpoint: return matches above the similarity threshold
@app.route('/matches', methods=['POST'])
def get_matches():
    query_sentence = request.json.get('querySentence', '')
    query_embedding = model.encode([query_sentence])[0]
    # Unit-normalize the query vector
    query_embedding = query_embedding / np.linalg.norm(query_embedding)
    # Use the threshold from the request, falling back to the default
    threshold = request.json.get('threshold', default_threshold)
    # Cosine similarity via dot product of unit vectors
    similarities = [embedding.dot(query_embedding) for embedding in sentence_embeddings]
    # Keep every item whose similarity meets the threshold
    matches = [{'id': questions_data[i]["id"],
                'sentence': sentences[i],
                'similarity': float(similarity)}
               for i, similarity in enumerate(similarities)
               if similarity >= threshold]
    return jsonify({'status': 'success', 'results': matches}
                   if matches else
                   {'status': 'success', 'message': 'No matches found'})


# Endpoint: return the similarity of every sentence in the dataset
@app.route('/get_all_similarities', methods=['POST'])
def get_all_similarities():
    query_sentence = request.json.get('querySentence', '')
    query_embedding = model.encode([query_sentence])[0]
    # Unit-normalize the query vector
    query_embedding = query_embedding / np.linalg.norm(query_embedding)
    # Pair each sentence with its similarity to the query
    results = [{'id': questions_data[i]["id"],
                'sentence': sentences[i],
                'similarity': float(embedding.dot(query_embedding))}
               for i, embedding in enumerate(sentence_embeddings)]
    return jsonify({'status': 'success', 'results': results})


if __name__ == '__main__':
    app.run(debug=True, host='0.0.0.0')
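
# --- Usage sketch (illustrative only) ---
# Assumptions: the server runs locally on Flask's default port 5000, and
# question.json holds entries shaped like {"id": 1, "question": "..."}.
# The query text, threshold, and response values below are made-up examples.
#
#   curl -X POST http://localhost:5000/matches \
#        -H 'Content-Type: application/json' \
#        -d '{"querySentence": "how do I reset my password", "threshold": 0.8}'
#
# A successful response with at least one match looks like:
#   {"status": "success",
#    "results": [{"id": 1, "sentence": "...", "similarity": 0.93}]}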