|
|
|
@ -79,15 +79,17 @@ class my_model_ckpt(ModelCheckpoint):
|
|
|
|
|
to_save_od["config"] = self.config
|
|
|
|
|
to_save_od["info"] = "GPT-e%s" % (trainer.current_epoch + 1)
|
|
|
|
|
# torch.save(
|
|
|
|
|
my_save(
|
|
|
|
|
to_save_od,
|
|
|
|
|
"%s/%s-e%s.ckpt"
|
|
|
|
|
% (
|
|
|
|
|
self.half_weights_save_dir,
|
|
|
|
|
self.exp_name,
|
|
|
|
|
trainer.current_epoch + 1,
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
# print(os.environ)
|
|
|
|
|
if(os.environ.get("LOCAL_RANK","0")=="0"):
|
|
|
|
|
my_save(
|
|
|
|
|
to_save_od,
|
|
|
|
|
"%s/%s-e%s.ckpt"
|
|
|
|
|
% (
|
|
|
|
|
self.half_weights_save_dir,
|
|
|
|
|
self.exp_name,
|
|
|
|
|
trainer.current_epoch + 1,
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
self._save_last_checkpoint(trainer, monitor_candidates)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|