|
|
|
@ -41,12 +41,18 @@ def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
|
|
|
|
|
shutil.move(tmp_path, "%s/%s" % (dir, name))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
txt_path = "%s/2-name2text-%s.txt" % (opt_dir, i_part)
|
|
|
|
|
if os.path.exists(txt_path) == False:
|
|
|
|
|
bert_dir = "%s/3-bert" % (opt_dir)
|
|
|
|
|
os.makedirs(opt_dir, exist_ok=True)
|
|
|
|
|
os.makedirs(bert_dir, exist_ok=True)
|
|
|
|
|
device = "cuda:0" if torch.cuda.is_available() else "mps"
|
|
|
|
|
if torch.cuda.is_available():
|
|
|
|
|
device = "cuda:0"
|
|
|
|
|
elif torch.backends.mps.is_available():
|
|
|
|
|
device = "mps"
|
|
|
|
|
else:
|
|
|
|
|
device = "cpu"
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained(bert_pretrained_dir)
|
|
|
|
|
bert_model = AutoModelForMaskedLM.from_pretrained(bert_pretrained_dir)
|
|
|
|
|
if is_half == True:
|
|
|
|
|