diff --git a/webui.py b/webui.py index e8262ea..12196ec 100644 --- a/webui.py +++ b/webui.py @@ -93,10 +93,30 @@ if torch.cuda.is_available() or ngpu != 0: def set_default(): - global default_batch_size,default_max_batch_size,gpu_info,default_sovits_epoch,default_sovits_save_every_epoch,max_sovits_epoch,max_sovits_save_every_epoch,default_batch_size_s1 + global default_batch_size,default_max_batch_size,gpu_info,default_sovits_epoch,default_sovits_save_every_epoch,max_sovits_epoch,max_sovits_save_every_epoch,default_batch_size_s1,if_force_ckpt + if_force_ckpt = False if if_gpu_ok and len(gpu_infos) > 0: gpu_info = "\n".join(gpu_infos) minmem = min(mem) + if version == "v3" and minmem < 14: + # API读取不到共享显存,直接填充确认 + try: + torch.zeros((1024,1024,1024,14),dtype=torch.int8,device="cuda") + torch.cuda.empty_cache() + minmem = 14 + except RuntimeError as _: + # 强制梯度检查只需要12G显存 + if minmem >= 12 : + if_force_ckpt = True + minmem = 14 + else: + try: + torch.zeros((1024,1024,1024,12),dtype=torch.int8,device="cuda") + torch.cuda.empty_cache() + if_force_ckpt = True + minmem = 14 + except RuntimeError as _: + print("显存不足以开启V3训练") default_batch_size = minmem // 2 if version!="v3"else minmem//14 default_batch_size_s1=minmem // 2 else: @@ -783,7 +803,7 @@ def switch_version(version_): else: gr.Warning(i18n(f'未下载{version.upper()}模型')) set_default() - return {'__type__':'update', 'value':pretrained_sovits_name[int(version[-1])-1]}, {'__type__':'update', 'value':pretrained_sovits_name[int(version[-1])-1].replace("s2G","s2D")}, {'__type__':'update', 'value':pretrained_gpt_name[int(version[-1])-1]}, {'__type__':'update', 'value':pretrained_gpt_name[int(version[-1])-1]}, {'__type__':'update', 'value':pretrained_sovits_name[int(version[-1])-1]},{'__type__':'update',"value":default_batch_size,"maximum":default_max_batch_size},{'__type__':'update',"value":default_sovits_epoch,"maximum":max_sovits_epoch},{'__type__':'update',"value":default_sovits_save_every_epoch,"maximum":max_sovits_save_every_epoch},{'__type__':'update',"interactive":True if version!="v3"else False},{'__type__':'update',"interactive":True if version == "v3" else False},{'__type__':'update',"interactive":False if version == "v3" else True,"value":False} + return {'__type__':'update', 'value':pretrained_sovits_name[int(version[-1])-1]}, {'__type__':'update', 'value':pretrained_sovits_name[int(version[-1])-1].replace("s2G","s2D")}, {'__type__':'update', 'value':pretrained_gpt_name[int(version[-1])-1]}, {'__type__':'update', 'value':pretrained_gpt_name[int(version[-1])-1]}, {'__type__':'update', 'value':pretrained_sovits_name[int(version[-1])-1]},{'__type__':'update',"value":default_batch_size,"maximum":default_max_batch_size},{'__type__':'update',"value":default_sovits_epoch,"maximum":max_sovits_epoch},{'__type__':'update',"value":default_sovits_save_every_epoch,"maximum":max_sovits_save_every_epoch},{'__type__':'update',"interactive":True if version!="v3"else False},{'__type__':'update',"value":False if not if_force_ckpt else True, "interactive":True if not if_force_ckpt else False},{'__type__':'update',"interactive":False if version == "v3" else True,"value":False} if os.path.exists('GPT_SoVITS/text/G2PWModel'):... else: