# -*- coding: utf-8 -*-

from fastapi import FastAPI, HTTPException, BackgroundTasks
from qa_Ask import QAService, match_query, store_data
from pydantic import BaseModel
from collections import deque
import requests
import os
import time
import uuid
import json
import shutil
import yaml
import sys
import logging
from typing import List
from faiss_kb_service import FaissKBService
from faiss_kb_service import EmbeddingsFunAdapter

app = FastAPI()

# Log to both a file and the terminal.
# FileHandler does not create missing directories, so make sure log/ exists
# before it opens log/app.log. UTF-8 keeps non-ASCII request text loggable
# regardless of the platform's locale encoding.
os.makedirs('log', exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('log/app.log', encoding='utf-8'),
        logging.StreamHandler(sys.stdout),  # console handler
    ],
)
logger = logging.getLogger(__name__)

class QuestionRequest(BaseModel):
    """Request body accepted by ``/matchQuestion``."""

    # Question text to look up in the newest knowledge base.
    question: str
    # Threshold forwarded unchanged to ``match_query`` (its exact meaning is
    # defined there).
    scoreThreshold: float

class EmbeddingResponse(BaseModel):
    """Response body returned by ``/embeddings/``."""

    # Vector produced by ``EmbeddingsFunAdapter.embed_query`` for the input text.
    embeddings: list[float]


class QuestionResponse(BaseModel):
    """Envelope returned by ``/matchQuestion``."""

    # Application-level status: 200 on success, 500 when matching failed.
    code: int
    # Short status message accompanying ``code``.
    msg: str
    # Results from ``match_query``; empty on failure.
    data: list


class QuestionItem(BaseModel):
    """One element of the ``/updateDatabase`` request body."""

    # Identifier of this entry (presumably the id of the question group —
    # confirm against store_data).
    questionId: str
    # Question texts associated with ``questionId``.
    questionList: list[str]

class EmbeddingRequest(BaseModel):
    """Request body accepted by ``/embeddings/``."""

    # Text handed to ``embed_query``.
    text: str

# Load application settings. The file is opened in binary mode so PyYAML
# detects the encoding itself (UTF-8, or UTF-16 via BOM) instead of relying on
# the platform's locale default, which can mis-decode non-ASCII values.
with open('config/config.yaml', 'rb') as config_file:
    config_data = yaml.safe_load(config_file)

# Text file listing the retained knowledge base names, one per line.
knowledge_base_file = config_data['knowledge_base_file']
# Endpoint polled at start-up for the latest question data.
api_url = config_data['api']['url']
# JSON file the start-up fetch writes to; removed again by update_kb.
path = config_data['output_file_path']
# How many knowledge bases to keep before the oldest folder is deleted.
max_knowledge_bases = config_data['max_knowledge_bases']


def load_knowledge_bases(file_path=None):
    """Load the list of knowledge base names, oldest first.

    Blank or whitespace-only lines are skipped. An empty name must never
    reach ``update_kb``: there it would become the path ``knowledge_base/``
    and ``shutil.rmtree`` would wipe every knowledge base.

    Args:
        file_path: File to read. Defaults to the configured
            ``knowledge_base_file``.

    Returns:
        The names in file order, or ``[]`` if the file does not exist.
    """
    source = knowledge_base_file if file_path is None else file_path
    try:
        with open(source, "r", encoding="utf-8") as file:
            return [line.strip() for line in file if line.strip()]
    except FileNotFoundError:
        return []


def save_knowledge_bases(names, file_path=None):
    """Persist the knowledge base names, one per line.

    The names are written to a sibling ``.tmp`` file first and then moved over
    the target with ``os.replace``. A crash mid-write therefore cannot leave a
    truncated list behind; the rename is atomic on POSIX.

    Args:
        names: Iterable of knowledge base names (e.g. the global deque).
        file_path: Destination file. Defaults to the configured
            ``knowledge_base_file``.
    """
    target = knowledge_base_file if file_path is None else file_path
    tmp_path = f"{target}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as file:
        file.write("\n".join(names))
    os.replace(tmp_path, target)


def update_kb(kb_name, qa_service, path, max_knowledge_bases):
    """Build a new knowledge base from *path* and rotate out the oldest ones.

    Scheduled as a background task by ``/updateDatabase`` and also called
    once during start-up.

    Args:
        kb_name: Name of the new knowledge base; retained names map to folders
            under ``knowledge_base/``.
        qa_service: ``QAService`` bound to *kb_name*; passed with *path* to
            ``store_data`` (presumably indexing the file's contents).
        path: JSON file with the question items; deleted after the build.
        max_knowledge_bases: Number of knowledge bases to keep; the oldest
            folders beyond this limit are deleted.
    """
    store_data(qa_service, path)

    # Evict explicitly before appending: the deque's maxlen would otherwise
    # drop the oldest name silently and leave its folder on disk.
    while recent_knowledge_bases and len(recent_knowledge_bases) >= max_knowledge_bases:
        folder_to_delete = recent_knowledge_bases.popleft()
        if not folder_to_delete:
            # An empty name would resolve to "knowledge_base/" itself and
            # wipe every knowledge base.
            continue
        try:
            shutil.rmtree(f"knowledge_base/{folder_to_delete}")
        except FileNotFoundError:
            # Already gone: still record the new knowledge base instead of
            # aborting the whole update.
            logger.warning(f"Knowledge base folder not found, skipping delete: {folder_to_delete}")

    recent_knowledge_bases.append(kb_name)
    save_knowledge_bases(recent_knowledge_bases)

    os.remove(path)
    logger.info(f"Knowledge base updated: {kb_name}\n"
                f"Please wait while the database is being updated···")


def fetch_and_write_data(api_url, path, timeout=30):
    """Fetch question data from the API and dump its ``data`` field to *path*.

    Best-effort: every failure is logged and reported as ``False`` rather than
    raised, because this runs at module import to refresh the knowledge base.

    Args:
        api_url: Endpoint expected to return JSON of the form
            ``{"code": 200, "data": ...}``.
        path: Destination JSON file (written as UTF-8).
        timeout: Seconds ``requests`` waits for the server. Without it
            ``requests`` waits indefinitely, which would hang start-up.

    Returns:
        True if the data was written, False otherwise.
    """
    try:
        response = requests.get(api_url, timeout=timeout)
        if response.status_code != 200:
            # Check the status before parsing: error pages are often not JSON.
            logger.error(f"Failed to fetch data from API. Status code: {response.status_code}, Response data: {response.text}")
            return False

        response_data = response.json()
        if (not isinstance(response_data, dict)
                or response_data.get("code") != 200
                or "data" not in response_data):
            logger.error(f"Failed to fetch data from API. Status code: {response.status_code}, Response data: {response_data}")
            return False

        with open(path, "w", encoding="utf-8") as file:
            json.dump(response_data["data"], file, ensure_ascii=False, indent=2)

        return True
    except Exception as e:
        # Deliberate catch-all: the caller treats False as "skip the refresh".
        logger.exception(f"Error fetching data from API: {e}")
        return False
# Embedding model settings. `or {}` also covers a key that is present in the
# YAML but left empty (loaded as None).
bge_large_zh_v1_5_config = config_data.get('bge_large_zh_v1_5') or {}
embed_model_path = bge_large_zh_v1_5_config.get('embed_model_path', 'default_path_if_not_provided')

# Retained knowledge base names, oldest first.
recent_knowledge_bases = deque(load_knowledge_bases(), maxlen=max_knowledge_bases)
if not recent_knowledge_bases:
    # Fail with an explicit message instead of an opaque "deque index out of
    # range" from the [-1] lookup below.
    raise RuntimeError(
        f"No knowledge base recorded in {knowledge_base_file!r}; "
        "FaissKBService is constructed with the newest one at start-up."
    )
faiss_service = FaissKBService(kb_name=recent_knowledge_bases[-1], embed_model_path=embed_model_path, device=None)

@app.post("/embeddings/", response_model=EmbeddingResponse)
def get_embeddings(request: EmbeddingRequest):
    """Return the embedding vector of ``request.text``.

    Uses the model path and device configured on ``faiss_service``. Declared
    as a plain ``def`` because ``embed_query`` is a synchronous call: inside
    ``async def`` it would block the event loop. FastAPI runs sync path
    operations in its threadpool instead.

    Raises:
        HTTPException: 500 with the error text if the adapter cannot be
            created or embedding fails.
    """
    try:
        embed_func = EmbeddingsFunAdapter(faiss_service.embed_model_path, faiss_service.device)
        embeddings = embed_func.embed_query(request.text)
        return EmbeddingResponse(embeddings=embeddings)
    except Exception as e:
        logger.exception(f"Error computing embeddings: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

@app.post("/updateDatabase")
def save_to_json(question_items: list[QuestionItem], background_tasks: BackgroundTasks):
    """Save the posted question items and schedule a knowledge base rebuild.

    The items are written to a per-request JSON file. ``update_kb`` is then
    scheduled as a background task to build a new knowledge base from that
    file; it deletes the file when done. Declared as a plain ``def`` because
    the file write and ``QAService`` construction are blocking calls, so
    FastAPI runs this in its threadpool.

    Returns:
        ``{"status": "success", ...}`` once the task is scheduled, or
        ``{"status": "error", ...}`` if writing or scheduling failed.
    """
    kb_name = str(uuid.uuid4())
    # One file per request: a shared name would let concurrent uploads
    # overwrite each other, and one task's os.remove would break another.
    upload_path = f"output_{kb_name}.json"
    try:
        json_data = json.dumps([item.dict() for item in question_items], ensure_ascii=False, indent=2)
        with open(upload_path, "w", encoding="utf-8") as file:
            file.write(json_data)

        qa_service = QAService(kb_name, None)

        background_tasks.add_task(
            update_kb, kb_name, qa_service, upload_path, max_knowledge_bases
        )

        return {"status": "success", "message": "Please wait while the database is being updated···"}

    except Exception as e:
        logger.exception(f"Error saving data to file or scheduling knowledge base update task: {e}")
        # No task will consume the file now; remove it on a best-effort basis.
        try:
            os.remove(upload_path)
        except OSError:
            pass
        return {"status": "error", "message": "update task error···"}


@app.post("/matchQuestion")
def match_question(request: QuestionRequest):
    """Match ``request.question`` against the newest knowledge base.

    Returns:
        ``QuestionResponse`` with ``code=200`` and the ``match_query`` results
        (top 3). On any failure, including no knowledge base being recorded,
        returns ``code=500``, ``msg="error"`` and an empty ``data`` list.
    """
    try:
        logger.info(f"match_question:Request: {request}")
        start_time = time.perf_counter()

        # Always answer from the most recently built knowledge base.
        newest = recent_knowledge_bases[-1]
        qa_service = QAService(newest, None)

        top_k = 3
        result = match_query(qa_service, request.question, top_k, request.scoreThreshold)

        response = QuestionResponse(code=200, msg="success", data=result)
        duration = time.perf_counter() - start_time

        logger.info(f"match_question:Matched question in {duration} seconds. "
                    f"Response: {result}")
        return response

    except Exception as e:
        logger.exception(f"Error matching question: {e}")
        # Previously reported msg="success" alongside code=500.
        return QuestionResponse(code=500, msg="error", data=[])

# Re-read the retained knowledge base names from disk before the start-up
# refresh below.
recent_knowledge_bases = deque(load_knowledge_bases(), maxlen=max_knowledge_bases)

# Start-up refresh: pull the latest questions from the API and build a new
# knowledge base from them. This is best-effort, like fetch_and_write_data
# itself. A failed build is logged and the service keeps the previously
# recorded knowledge bases instead of failing to import.
if fetch_and_write_data(api_url, path):
    kb_name = str(uuid.uuid4())
    device = None
    try:
        qa_service = QAService(kb_name, device)
        update_kb(kb_name, qa_service, path, max_knowledge_bases)
    except Exception as e:
        logger.exception(f"Start-up knowledge base update failed, keeping existing knowledge bases: {e}")

if __name__ == "__main__":
    # Serve the app directly when this file is executed as a script.
    import uvicorn

    bind_host, bind_port = "0.0.0.0", 8000
    uvicorn.run(app, host=bind_host, port=bind_port)