"""Filesystem layout and device defaults for the knowledge-base subsystem."""

# NOTE: the wildcard import is kept deliberately. Modules that star-import this
# config may rely on the typing names it re-exports; narrowing it could break
# them. It also supplies ``Optional`` used below.
from typing import *  # noqa: F401,F403
import os

# Root directory holding one sub-directory per knowledge base, anchored to the
# directory containing this module. The directory is NOT created here.
KB_ROOT_PATH = os.path.join(os.path.dirname(__file__), "knowledge_base")


def get_kb_path(knowledge_base_name: str) -> str:
    """Return the directory of the knowledge base *knowledge_base_name*.

    NOTE(review): the name is not validated. ``os.path.join`` discards
    ``KB_ROOT_PATH`` when the name is absolute, and a name containing ``..``
    can point outside it. Confirm that callers sanitize the name.
    """
    return os.path.join(KB_ROOT_PATH, knowledge_base_name)


def get_doc_path(knowledge_base_name: str) -> str:
    """Return the ``content`` directory inside the given knowledge base."""
    return os.path.join(get_kb_path(knowledge_base_name), "content")


def get_vs_path(knowledge_base_name: str, vector_name: str) -> str:
    """Return the path of *vector_name* inside the knowledge base's
    ``vector_store`` directory."""
    return os.path.join(get_kb_path(knowledge_base_name), "vector_store", vector_name)


# Default device for embedding models, used when no device is requested.
EMBEDDING_DEVICE = 'cuda:0'


def embedding_device(device: Optional[str] = None) -> str:
    """Resolve the device to run embeddings on.

    Args:
        device: Explicit device string. Any falsy value (``None`` or ``""``)
            falls back to ``EMBEDDING_DEVICE``.

    Returns:
        ``device`` if it is truthy, otherwise ``EMBEDDING_DEVICE``.
    """
    # EMBEDDING_DEVICE is looked up at call time, so rebinding the module
    # constant changes the default for later calls.
    return device or EMBEDDING_DEVICE