|
|
|
@ -45,14 +45,17 @@ i18n = I18nAuto()
|
|
|
|
|
from scipy.io import wavfile
|
|
|
|
|
from tools.my_utils import load_audio
|
|
|
|
|
from multiprocessing import cpu_count
|
|
|
|
|
|
|
|
|
|
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # fall back to CPU for ops that MPS does not support
|
|
|
|
|
|
|
|
|
|
# Number of logical CPUs, used to size worker/process pools elsewhere.
n_cpu = cpu_count()
|
|
|
|
|
|
|
|
|
|
# --- accelerator discovery state ---
# Descriptions of usable devices, one "<index>\t<name>" entry per device.
gpu_infos = []
# Per-device memory budget in GiB (filled in by the detection code below).
mem = []
# Becomes True once at least one usable accelerator has been found.
if_gpu_ok = False

# Count NVIDIA GPUs usable for training and accelerated inference.
ngpu = torch.cuda.device_count()
|
|
|
|
|
|
|
|
|
|
# 判断是否有能用来训练和加速推理的N卡
|
|
|
|
|
if torch.cuda.is_available() or ngpu != 0:
|
|
|
|
|
for i in range(ngpu):
|
|
|
|
|
gpu_name = torch.cuda.get_device_name(i)
|
|
|
|
@ -61,6 +64,12 @@ if torch.cuda.is_available() or ngpu != 0:
|
|
|
|
|
if_gpu_ok = True # 至少有一张能用的N卡
|
|
|
|
|
gpu_infos.append("%s\t%s" % (i, gpu_name))
|
|
|
|
|
mem.append(int(torch.cuda.get_device_properties(i).total_memory/ 1024/ 1024/ 1024+ 0.4))
|
|
|
|
|
# Check whether Apple MPS acceleration is available.
if torch.backends.mps.is_available():
    if_gpu_ok = True
    # Report the Apple GPU as device 0.
    gpu_infos.append("0\tApple GPU")
    # Apple Silicon uses unified memory, so system RAM serves as GPU memory;
    # record total RAM in GiB (empirically this does not run out of "VRAM").
    gib = 1024 * 1024 * 1024
    mem.append(psutil.virtual_memory().total / gib)
|
|
|
|
|
|
|
|
|
|
# If at least one accelerator was found, build the device summary and derive a
# default training batch size from the smallest per-device memory budget.
if if_gpu_ok and len(gpu_infos) > 0:
    gpu_info = "\n".join(gpu_infos)
    # Rule of thumb: ~2 GiB of device memory per batch item. `mem` may contain
    # floats (the Apple-GPU path appends RAM in GiB as a float), so coerce to
    # int, and never let the batch size drop to 0 on very small devices.
    default_batch_size = max(1, int(min(mem) // 2))
|
|
|
|
|