from flask import Flask, request, jsonify
from text2vec import SentenceModel
import numpy as np
import traceback
import json
import os

app = Flask(__name__)

# Directory containing this script
current_dir = os.path.dirname(os.path.abspath(__file__))

# Path to the BERT model
model_path = os.path.join(current_dir, 'bert_chinese')
model = None
questions_data = []  # Loaded question entries
sentences = []  # Question texts, kept in the same order as questions_data
sentence_embeddings = []  # L2-normalized embedding for each question
default_threshold = 0.7
# Path to the dataset file
dataset_file_path = os.path.join(current_dir, 'question.json')
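
# Assumed structure of question.json (illustrative example; only the "id" and
# "question" fields are read by this service):
# [
#     {"id": 1, "question": "如何重置密码?"},
#     {"id": 2, "question": "营业时间是什么时候?"}
# ]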

# Initialization: load the model and the dataset
def initialize_app():
    global model, questions_data, sentence_embeddings
    model = SentenceModel(model_path)
    load_dataset()

# Load the dataset and precompute normalized sentence embeddings
def load_dataset():
    global questions_data, sentence_embeddings, sentences
    try:
        with open(dataset_file_path, 'r', encoding='utf-8') as file:
            questions_data = json.load(file)
            sentences = [item["question"] for item in questions_data]
            # Batch-encode all questions once, then L2-normalize each embedding
            # so that a plain dot product equals cosine similarity
            embeddings = model.encode(sentences)
            sentence_embeddings = [emb / np.linalg.norm(emb) for emb in embeddings]

    except Exception as e:
        traceback.print_exc()
        print(f"Error loading dataset: {str(e)}")

# Global error handler
@app.errorhandler(Exception)
def handle_error(e):
    traceback.print_exc()
    return jsonify({'error': str(e)}), 500

# Initialize the application at import time
initialize_app()

# Endpoint: replace the dataset
@app.route('/update_dataset', methods=['POST'])
def update_dataset():
    global questions_data, sentence_embeddings
    new_dataset = request.get_json(silent=True) or []

    # Persist the new dataset, then reload sentences and embeddings
    try:
        with open(dataset_file_path, 'w', encoding='utf-8') as file:
            json.dump(new_dataset, file, ensure_ascii=False, indent=2)
        load_dataset()
        return jsonify({'status': 'success', 'message': 'Dataset updated successfully'})
    except Exception as e:
        traceback.print_exc()
        return jsonify({'error': f'Error updating dataset: {str(e)}'}), 500
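
# Illustrative call to /update_dataset (assumes the service runs locally on the
# default Flask port; the payload shape mirrors question.json above):
#   curl -X POST http://localhost:5000/update_dataset \
#        -H "Content-Type: application/json" \
#        -d '[{"id": 1, "question": "如何重置密码?"}]'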

# Endpoint: return matches above the similarity threshold
@app.route('/matches', methods=['POST'])
def get_matches():
    data = request.get_json(silent=True) or {}
    query_sentence = data.get('querySentence', '')
    query_embedding = model.encode([query_sentence])[0]
    # Normalize the query embedding to a unit vector
    query_embedding = query_embedding / np.linalg.norm(query_embedding)
    # Use the threshold from the request, falling back to the default
    threshold = data.get('threshold', default_threshold)
    # Cosine similarity: dot product of unit vectors
    similarities = [embedding.dot(query_embedding) for embedding in sentence_embeddings]

    # Collect all items whose similarity meets or exceeds the threshold
    matches = [{'id': questions_data[i]["id"], 'sentence': sentences[i], 'similarity': float(similarity)}
               for i, similarity in enumerate(similarities) if similarity >= threshold]

    return jsonify({'status': 'success', 'results': matches} if matches else {'status': 'success', 'message': 'No matches found'})
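
# Illustrative /matches exchange (hypothetical values):
#   request:  {"querySentence": "怎么改密码?", "threshold": 0.8}
#   response: {"status": "success",
#              "results": [{"id": 1, "sentence": "如何重置密码?", "similarity": 0.91}]}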

# Endpoint: return the similarity of every dataset entry to the query
@app.route('/get_all_similarities', methods=['POST'])
def get_all_similarities():
    data = request.get_json(silent=True) or {}
    query_sentence = data.get('querySentence', '')
    query_embedding = model.encode([query_sentence])[0]
    # Normalize the query embedding to a unit vector
    query_embedding = query_embedding / np.linalg.norm(query_embedding)

    # Compute the similarity and corresponding text for every entry
    results = [{'id': questions_data[i]["id"], 'sentence': sentences[i], 'similarity': float(embedding.dot(query_embedding))}
               for i, embedding in enumerate(sentence_embeddings)]

    # Return every similarity together with its text
    return jsonify({'status': 'success', 'results': results})
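
# Illustrative /get_all_similarities request (hypothetical value); the response
# lists every dataset entry with its similarity, regardless of any threshold:
#   {"querySentence": "怎么改密码?"}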

if __name__ == '__main__':
    app.run(debug=True, host='0.0.0.0')