# NOTE(review): wildcard import retained in case importers rely on typing
# names re-exported from this module — confirm before narrowing it.
from typing import *
import logging
import os

LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
# Configures the *root* logger (getLogger() with no name), so the level and
# format apply process-wide to every logger that propagates to root.
logger = logging.getLogger()
logger.setLevel(logging.INFO)
logging.basicConfig(format=LOG_FORMAT)

# Parent of the directory containing this file (presumably the project root).
_PROJECT_ROOT = os.path.dirname(os.path.dirname(__file__))

# Directory where log files are stored.
LOG_PATH = os.path.join(_PROJECT_ROOT, "logs")
# makedirs(exist_ok=True) avoids the exists()/mkdir() race and also creates
# any missing parent directories.
os.makedirs(LOG_PATH, exist_ok=True)
logger.info(f"logger path: {LOG_PATH} ")

# Root directory under which every knowledge base lives.
KB_ROOT_PATH = os.path.join(_PROJECT_ROOT, "knowledge_base")
os.makedirs(KB_ROOT_PATH, exist_ok=True)
logger.info(f"knowledge base path: {KB_ROOT_PATH} ")


def get_kb_path(knowledge_base_name: str) -> str:
    """Return the directory of the knowledge base named *knowledge_base_name*.

    The path is ``KB_ROOT_PATH/<knowledge_base_name>``; it is not created here.
    """
    return os.path.join(KB_ROOT_PATH, knowledge_base_name)


def get_doc_path(knowledge_base_name: str) -> str:
    """Return the ``content`` sub-directory of the given knowledge base.

    The path is not created here.
    """
    return os.path.join(get_kb_path(knowledge_base_name), "content")


def get_vs_path(knowledge_base_name: str, vector_name: str) -> str:
    """Return ``<kb>/vector_store/<vector_name>`` for the given knowledge base.

    The path is not created here.
    """
    return os.path.join(get_kb_path(knowledge_base_name), "vector_store", vector_name)


# Default device string returned by embedding_device() when none is given.
EMBEDDING_DEVICE = 'cuda:0'


def embedding_device(device: Optional[str] = None) -> str:
    """Return *device*, or ``EMBEDDING_DEVICE`` when *device* is falsy.

    Any falsy value (``None`` or ``""``) falls back to the module default.
    """
    return device or EMBEDDING_DEVICE