import psycopg2
from psycopg2 import OperationalError, extras
from loguru import logger
import os
import traceback


# Create a connection (create it once, then reuse it across the service).
def create_connection():
    """
    Open a psycopg2 connection configured from POSTGRESQL_* environment variables.

    Reads POSTGRESQL_DATABASE, POSTGRESQL_USERNAME, POSTGRESQL_PASSWORD,
    POSTGRESQL_HOST, POSTGRESQL_PORT and POSTGRESQL_SCHEMA (used as the
    session ``search_path``). Autocommit is switched off and the session
    time zone is set to 'Asia/Shanghai' before the connection is returned.

    :return: an open psycopg2 connection with autocommit disabled
    :raises KeyError: if one of the environment variables is not set
    :raises OperationalError: logged and re-raised when psycopg2 reports it
    """
    conn = None
    try:
        conn = psycopg2.connect(
            dbname=os.environ['POSTGRESQL_DATABASE'],
            user=os.environ['POSTGRESQL_USERNAME'],
            password=os.environ['POSTGRESQL_PASSWORD'],
            host=os.environ['POSTGRESQL_HOST'],
            port=os.environ['POSTGRESQL_PORT'],
            options=f'-c search_path={os.environ["POSTGRESQL_SCHEMA"]}'
        )
        conn.autocommit = False
        with conn.cursor() as cur:
            cur.execute("SET TIME ZONE 'Asia/Shanghai';")
        conn.commit()
        return conn
    except OperationalError as e:
        logger.error(f"连接数据库失败: {e}")
        # Do not leak a connection that opened but could not be configured.
        if conn is not None:
            conn.close()
        raise
    except Exception:
        # Any other failure after connect() must also release the connection.
        if conn is not None:
            conn.close()
        raise


# Insert a single row.
def insert_data(conn, table, data_dict, pk_name='id'):
    """
    Insert one row into ``table`` and return the value of its ``pk_name`` column.

    This function does not call commit or rollback; the caller controls the
    transaction on ``conn``.

    NOTE(review): ``table``, ``pk_name`` and the keys of ``data_dict`` are
    interpolated into the SQL text verbatim (only the values are bound as
    parameters), so they must come from trusted code, never from user input.

    :param conn: psycopg2 connection object
    :param table: table name (string)
    :param data_dict: mapping of column name to value, e.g.
        {"column1": value1, "column2": value2}; an empty mapping inserts a
        row made of column defaults
    :param pk_name: column returned through ``RETURNING`` (default 'id')
    :return: the ``pk_name`` value of the inserted row, or None when ``conn``
        is None
    :raises Exception: any failure is logged and re-raised unchanged
    """
    if conn is None:
        logger.error("数据库连接无效")
        return None

    try:
        with conn.cursor() as cur:
            columns = list(data_dict)
            values = [data_dict[col] for col in columns]

            # Build the SQL. "INSERT INTO t () VALUES ()" is not valid
            # PostgreSQL, so an empty mapping uses DEFAULT VALUES instead.
            if columns:
                placeholders = ', '.join(['%s'] * len(columns))
                target = f"({', '.join(columns)}) VALUES ({placeholders})"
            else:
                target = "DEFAULT VALUES"
            insert_query = f"""
                INSERT INTO {table} {target}
                RETURNING {pk_name}
            """
            cur.execute(insert_query, values)
            inserted_id = cur.fetchone()[0]
            return inserted_id

    except Exception as e:
        logger.error(f"插入数据失败: {e}")
        raise


def insert_multiple_data(conn, table, data_list, batch_size=100):
    """
    Bulk-insert rows into ``table`` using ``psycopg2.extras.execute_values``.

    The column list is taken from the keys of the first dict; every other
    dict is read by those same keys, so rows may list their keys in any
    order. This function does not call commit or rollback; the caller
    controls the transaction on ``conn``.

    NOTE(review): ``table`` and the column names are interpolated into the
    SQL text verbatim, so they must come from trusted code.

    :param conn: psycopg2 connection object
    :param table: table name (string)
    :param data_list: list of dicts, one per row, e.g.
        [{"column1": value1, "column2": value2}, ...]; an empty list is a
        no-op
    :param batch_size: maximum number of rows sent per INSERT statement
    :return: None
    :raises Exception: any failure (including a row missing one of the
        columns) is logged and re-raised unchanged
    """
    if conn is None:
        logger.error("数据库连接无效")
        return

    if not data_list:
        # Nothing to insert; data_list[0] below would raise IndexError.
        return

    try:
        with conn.cursor() as cur:
            columns = list(data_list[0])
            insert_query = f"""
                INSERT INTO {table} ({', '.join(columns)})
                VALUES %s
            """
            for start in range(0, len(data_list), batch_size):
                batch = data_list[start:start + batch_size]
                # Pick values by column name so each row lines up with the
                # column list regardless of its own key order.
                values = [tuple(row[col] for col in columns) for row in batch]
                # page_size defaults to 100; pass batch_size so a larger
                # batch is not silently split into several statements.
                extras.execute_values(cur, insert_query, values, page_size=batch_size)

    except Exception as e:
        logger.error(f"批量插入数据失败: {e}")
        raise


# Module-level connection used by insert_pdf2md_table below. It is opened at
# import time, so importing this module reads the POSTGRESQL_* environment
# variables and connects to the database.
conn = create_connection()


def _build_analysis_rows(rec_results, pdf_id):
    """Flatten per-page recognition boxes into row dicts for pdf_analysis_output."""
    rows = []
    # Pages are numbered from 1; boxes keep their 0-based position in the page.
    for page_no, page_boxes in enumerate(rec_results, start=1):
        for order, box in enumerate(page_boxes):
            rows.append({
                'layout_type': box.clsid,
                'content': box.content,
                'page_no': page_no,
                'pdf_id': pdf_id,
                'table_title': box.table_title,
                'display_order': order,
            })
    return rows


def insert_pdf2md_table(pdf_path, pdf_name, process_status, start_time, end_time, rec_results):
    """
    Record one processed PDF and, when ``process_status == 2``, its boxes.

    Inserts a row into ``pdf_info``; if ``process_status`` is 2, also inserts
    one ``pdf_analysis_output`` row per box in ``rec_results``. Both inserts
    run on the module-level ``conn`` and are committed together; on any error
    the transaction is rolled back and the exception is re-raised.

    :param pdf_path: stored in ``pdf_info.path``
    :param pdf_name: stored in ``pdf_info.filename``
    :param process_status: stored in ``pdf_info.process_status``; the value 2
        triggers the per-box insert (presumably "analysis succeeded" —
        TODO confirm against the status definitions)
    :param start_time: stored in ``pdf_info.analysis_start_time``
    :param end_time: stored in ``pdf_info.analysis_end_time``
    :param rec_results: sequence of pages, each a sequence of box objects
        exposing ``content``, ``clsid`` and ``table_title``; only read when
        ``process_status == 2``
    :return: the ``id`` of the inserted ``pdf_info`` row
    :raises Exception: any failure is logged with its traceback and re-raised
    """
    pdf_row = {
        'path': pdf_path,
        'filename': pdf_name,
        'process_status': process_status,
        'analysis_start_time': start_time,
        'analysis_end_time': end_time
    }
    try:
        inserted_id = insert_data(conn, 'pdf_info', pdf_row)
        if process_status == 2:
            box_rows = _build_analysis_rows(rec_results, inserted_id)
            # A PDF with no recognised boxes is still valid; skip the bulk
            # insert so the pdf_info row is not rolled back for nothing.
            if box_rows:
                insert_multiple_data(conn, 'pdf_analysis_output', box_rows)
        conn.commit()
        return inserted_id
    except Exception:
        conn.rollback()
        logger.error(f'operate database error!\n{traceback.format_exc()}')
        raise