from fastapi import FastAPI, HTTPException, BackgroundTasks
from qa_Ask import QAService, match_query, store_data
from pydantic import BaseModel
from collections import deque
import requests
import os
import time
import uuid
import json
import shutil
import yaml
import logging

app = FastAPI()

import sys

# Log to both a file and the terminal.
# logging.FileHandler does not create missing directories, so make sure log/
# exists first; otherwise importing this module fails with FileNotFoundError.
os.makedirs('log', exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        # Explicit utf-8 so logged request text (which may be Chinese) is written
        # the same way regardless of the platform's default encoding.
        logging.FileHandler('log/app.log', encoding='utf-8'),
        logging.StreamHandler(sys.stdout),  # console handler
    ]
)
logger = logging.getLogger(__name__)

class QuestionRequest(BaseModel):
    """Request body for POST /matchQuestion."""
    question: str  # text to match against the newest knowledge base
    scoreThreshold: float  # forwarded to match_query as its score-threshold argument

class QuestionResponse(BaseModel):
    """Response body for POST /matchQuestion.

    The outcome is reported in ``code`` (200 or 500); the endpoint returns this
    model normally, so the HTTP status itself does not signal failure.
    """
    code: int  # 200 on success, 500 on failure
    msg: str  # status message
    data: list  # results returned by match_query; empty on failure

class QuestionItem(BaseModel):
    """One entry of the POST /updateDatabase body; serialized to JSON for store_data."""
    questionId: str
    questionList: list[str]  # presumably alternative phrasings of one question — TODO confirm

class InputText(BaseModel):
    """Request body for POST /extractInformation."""
    inputText: str  # free text to extract a name, card number and ID number from

class ExtractedInfo(BaseModel):
    """Response body for POST /extractInformation; a field is '' when nothing was found."""
    name: str
    cardNumber: str
    # NOTE: "idNnumber" (double N) looks like a typo, but it is the public response
    # field name; renaming it would change the API contract for clients.
    idNnumber: str

# Read the application configuration (path is relative to the working directory).
# yaml.safe_load only constructs plain YAML types, not arbitrary Python objects.
with open('config/config.yaml', 'r') as config_file:
    config_data = yaml.safe_load(config_file)

# File holding the persisted list of knowledge-base names (read by
# load_knowledge_bases, written by save_knowledge_bases).
knowledge_base_file = config_data['knowledge_base_file']
# NOTE(review): appears unused in this module.
api_url = config_data['api']['url']
# NOTE(review): update_kb takes its own `path` parameter and save_to_json picks
# its own file name, so this configured value appears unused in this module.
path = config_data['output_file_path']
# Maximum number of knowledge bases kept; update_kb evicts the oldest beyond this.
max_knowledge_bases = config_data['max_knowledge_bases']

def load_knowledge_bases(file_path=None):
    """Load the persisted list of knowledge-base names, one name per line.

    Args:
        file_path: File to read. Defaults to the configured ``knowledge_base_file``.

    Returns:
        list[str]: The names in file order, or ``[]`` if the file does not exist.
        ``save_knowledge_bases`` writes them oldest-first.
        Surrounding whitespace is stripped and blank lines are skipped. An empty
        name would otherwise reach ``update_kb``'s eviction step, which would then
        ``rmtree`` the whole ``knowledge_base/`` directory.
    """
    if file_path is None:
        file_path = knowledge_base_file
    try:
        with open(file_path, "r", encoding="utf-8") as file:
            lines = file.read().splitlines()
    except FileNotFoundError:
        return []
    return [name for name in (line.strip() for line in lines) if name]

def save_knowledge_bases(names, file_path=None):
    """Write the knowledge-base names to disk, one per line, overwriting the file.

    Args:
        names: Iterable of names (e.g. the ``recent_knowledge_bases`` deque),
            written in iteration order with no trailing newline.
        file_path: File to write. Defaults to the configured ``knowledge_base_file``.
    """
    if file_path is None:
        file_path = knowledge_base_file
    with open(file_path, "w", encoding="utf-8") as file:
        file.write("\n".join(names))

def update_kb(kb_name, qa_service, path, max_knowledge_bases):
    """Build a knowledge base from a JSON file and register it as the newest one.

    ``save_to_json`` schedules this as a FastAPI background task. The steps are:

    1. Call ``store_data``.
    2. If the registry is full, evict the oldest entries; each evicted name's
       ``knowledge_base/<name>`` folder is deleted.
    3. Append ``kb_name`` to ``recent_knowledge_bases`` and persist the registry.

    Args:
        kb_name: Name of the new knowledge base.
        qa_service: ``QAService`` created for ``kb_name``; passed to ``store_data``.
        path: JSON input file. It is deleted when this function finishes, whether
            or not the update succeeded.
        max_knowledge_bases: Maximum number of knowledge bases to keep.

    Raises:
        Any exception from the steps above. It is logged together with ``kb_name``
        and then re-raised. If ``store_data`` fails, ``kb_name`` is not registered.
    """
    try:
        store_data(qa_service, path)

        # Evict the oldest entries to make room. `>=` (not `==`) also copes with
        # a registry that is already over the limit.
        while recent_knowledge_bases and len(recent_knowledge_bases) >= max_knowledge_bases:
            folder_to_delete = recent_knowledge_bases.popleft()
            if not folder_to_delete:
                # An empty name would make the path below knowledge_base/ itself.
                continue
            try:
                shutil.rmtree(os.path.join("knowledge_base", folder_to_delete))
            except FileNotFoundError:
                # Already gone: nothing to clean up, and no reason to abort the update.
                logger.warning(f"Knowledge base folder not found, skipping: {folder_to_delete}")

        recent_knowledge_bases.append(kb_name)
        save_knowledge_bases(recent_knowledge_bases)

        logger.info(f"Knowledge base updated: {kb_name}\n"
                    f"Please wait while the database is being updated···")
    except Exception:
        logger.exception(f"Failed to update knowledge base: {kb_name}")
        raise
    finally:
        # The input file is only needed by store_data; remove it on success and on
        # failure so stale input files do not accumulate.
        try:
            os.remove(path)
        except FileNotFoundError:
            pass  # already removed

# In-memory registry of knowledge-base names: update_kb appends new names on the
# right and evicts from the left, so the oldest is leftmost and the newest is
# rightmost. Bounded by max_knowledge_bases; if the saved list is longer, the
# deque keeps only the rightmost entries and the older names are dropped
# (their folders are not deleted here).
recent_knowledge_bases = deque(maxlen=max_knowledge_bases)
recent_knowledge_bases.extend(load_knowledge_bases())

# Translation table from Chinese digit characters to ASCII digits, built once.
# '〇' is a common written form of zero alongside '零'.
_CHINESE_DIGITS = str.maketrans({
    '零': '0', '〇': '0', '一': '1', '二': '2', '三': '3', '四': '4',
    '五': '5', '六': '6', '七': '7', '八': '8', '九': '9',
})


def text_to_number(text_id):
    """Replace every Chinese digit character in *text_id* with its ASCII digit.

    The conversion is character by character ("一二三" -> "123"). Place-value
    characters such as 十/百/千 are not interpreted and are left unchanged.
    Every occurrence in the text is converted, including digits that are part of
    ordinary words or names.

    Args:
        text_id: Input text.

    Returns:
        str: The text with Chinese digits replaced.
    """
    return text_id.translate(_CHINESE_DIGITS)

@app.post("/updateDatabase")
async def save_to_json(question_items: list[QuestionItem], background_tasks: BackgroundTasks):
    """Persist the posted question items and rebuild the knowledge base in the background.

    The items are written to a JSON file, and a ``QAService`` is created for a new
    knowledge base with a random UUID name. ``update_kb`` is then scheduled to
    build it from that file (and delete the file) after the response is sent.

    Returns:
        dict: ``{"status": "success", ...}`` once the task is scheduled. If writing
        the file or preparing the task fails, ``{"status": "error", ...}`` is
        returned instead, still with HTTP 200.
    """
    kb_name = str(uuid.uuid4())
    # One input file per request. With a single shared "output.json", a later
    # request could overwrite the file before an earlier background task read
    # it, and that earlier task's cleanup would then delete the later task's input.
    # NOTE(review): the configured `output_file_path` is not used here — confirm
    # which location is intended.
    json_path = f"output_{kb_name}.json"

    try:
        json_data = json.dumps([item.dict() for item in question_items], ensure_ascii=False, indent=2)

        with open(json_path, "w", encoding="utf-8") as file:
            file.write(json_data)

        device = None
        qa_service = QAService(kb_name, device)

        background_tasks.add_task(
            update_kb, kb_name, qa_service, json_path, max_knowledge_bases
        )

        return {"status": "success", "message": "Please wait while the database is being updated···"}

    except Exception as e:
        logger.exception(f"Error saving data to file or scheduling knowledge base update task: {e}")
        # No task will consume the input file now, so do not leave it behind.
        try:
            os.remove(json_path)
        except FileNotFoundError:
            pass  # the file was never written
        return {"status": "error", "message": "update task error···"}

@app.post("/matchQuestion")
def match_question(request: QuestionRequest):
    """Match a question against the newest knowledge base.

    Calls ``match_query`` with ``top_k=3`` and forwards ``request.scoreThreshold``.

    Returns:
        QuestionResponse: ``code=200`` with the match results on success. On
        failure it returns ``code=500`` with an error ``msg`` and empty ``data``;
        the HTTP status is still 200.
    """
    try:
        logger.info(f"match_question:Request: {request}")
        start_time = time.perf_counter()
        query = request.question

        if not recent_knowledge_bases:
            # Nothing has been built yet (no successful /updateDatabase).
            # Report this explicitly instead of surfacing an IndexError.
            logger.warning("match_question:No knowledge base available")
            return QuestionResponse(code=500, msg="no knowledge base available", data=[])
        newest = recent_knowledge_bases[-1]

        top_k = 3

        device = None
        qa_service = QAService(newest, device)

        result = match_query(qa_service, query, top_k, request.scoreThreshold)

        response = QuestionResponse(code=200, msg="success", data=result)
        duration = time.perf_counter() - start_time

        logger.info(f"match_question:Matched question in {duration} seconds. "
                    f"Response: {result}")
        return response

    except Exception as e:
        logger.exception(f"Error matching question: {e}")
        return QuestionResponse(code=500, msg="error", data=[])

# UIE extraction labels, grouped by output field in extract_information.
_EXTRACTION_SCHEMA = ["姓名", "嫌疑人", "涉案人员", "身份证号", "交易证件号", "卡号", "交易卡号", "银行卡号"]

# PaddleNLP pipelines keyed by task name. Building a Taskflow loads model
# weights, so each pipeline is created on first use and reused afterwards.
_taskflows = {}


def _get_taskflow(task, **kwargs):
    """Return the cached Taskflow for *task*, creating it with *kwargs* on first use.

    *kwargs* only take effect on the first call for a given *task*.
    """
    if task not in _taskflows:
        from paddlenlp import Taskflow  # heavy dependency, imported lazily
        _taskflows[task] = Taskflow(task, **kwargs)
    return _taskflows[task]


def _first_text(result, *keys):
    """Return the first non-empty value in *result* among *keys*, or '' if none."""
    for key in keys:
        if result.get(key):
            return result[key]
    return ''


@app.post("/extractInformation")
async def extract_information(input_data: InputText):
    """Extract a person's name, bank-card number and ID number from free text.

    Processing steps:

    1. Run PaddleNLP text correction.
    2. Convert Chinese digits to ASCII digits with ``text_to_number``.
    3. Run UIE information extraction with ``_EXTRACTION_SCHEMA``.

    Each output field takes the first non-empty value among its synonymous labels
    and is '' when none was found.

    Raises:
        HTTPException: 500 if any step fails.
    """
    try:
        inputText = input_data.inputText
        if not inputText.strip():
            # Nothing to extract; do not run the models on empty input.
            return ExtractedInfo(name='', cardNumber='', idNnumber='')

        corrector = _get_taskflow("text_correction")
        target_value = corrector(inputText)[0]['target']

        # NOTE(review): the digit conversion is applied to the whole text before
        # extraction, so digits inside names are converted too (e.g. "张三" -> "张3").
        converted_id = text_to_number(target_value)

        ie = _get_taskflow('information_extraction', schema=_EXTRACTION_SCHEMA, model='uie-base')
        extracted_info = ie(converted_id)

        # Flatten to {label: text of its first span}. Labels without spans are
        # skipped instead of raising IndexError.
        result = {}
        for item in extracted_info:
            for key, spans in item.items():
                if spans:
                    result[key.lower()] = spans[0]['text']

        return ExtractedInfo(
            name=_first_text(result, '姓名', '嫌疑人', '涉案人员'),
            cardNumber=_first_text(result, '卡号', '交易卡号', '银行卡号'),
            idNnumber=_first_text(result, '身份证号', '交易证件号'),
        )

    except Exception as e:
        logger.exception(f"Error extracting information: {e}")
        raise HTTPException(status_code=500, detail="Internal Server Error")

if __name__ == "__main__":
    # Running this file directly serves the app with uvicorn on all interfaces, port 8000.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=8000)