添加链接
link管理
链接快照平台
  • 输入网页链接,自动生成快照
  • 标签化管理网页链接
import torch from pytorch_lightning.callbacks import ModelCheckpoint from hfai.pl.utilities import _HFAI_AVAILABLE if _HFAI_AVAILABLE : import hfai.client
[docs] class ModelCheckpointHF ( ModelCheckpoint ): 这是一个可以自动处理 Hfai 打断信号,自动挂起任务的 checkpoint 回调函数管理类, 支持 ``1.6.0 <= pytorch_lightning.__version__ <= 1.7.6`` Args: dirpath (str): 模型文件的保存目录(默认为 ``None``) filename (str): 模型文件的保存名称,例如 ``{epoch}-{val_loss:.2f}-{other_metric:.2f}``(默认为 ``None``) monitor (str): 监测的指标名称(默认为 ``None``) verbose (bool): 输出状态(默认为 ``False``) save_last (bool): 是否保存最后一个模型文件(默认为 ``None``) save_top_k (int): 保存前 ``k`` 好的模型文件,``k`` 为 ``0`` 时不保存,``k`` 为 ``-1`` 时保存所有模型文件(默认为 ``1``) save_weights_only (bool): 是否仅保存模型的权重(默认为 ``False``) mode (str): 指标的排序方式,包括:从大到小(``max``)或者从小到大(``min``),(默认为 ``min``) auto_insert_metric_name (bool): 是否在模型名称上自动插入指标的数值(默认为 ``True``) every_n_train_steps (int): 保存模型文件的训练间隔 step 数量(默认为 ``None``),不能和 ``train_time_interval`` 和 ``every_n_epochs`` 一同使用 train_time_interval (timedelta): 保存模型文件的训练间隔时间(默认为 ``None``),不能和 ``every_n_train_steps`` 和 ``every_n_epochs`` 一同使用 every_n_epochs (int): 保存模型文件的训练间隔 epoch 数量(默认为 ``None``),不能和 ``every_n_train_steps`` 和 ``train_time_interval`` 一同使用 save_on_train_epoch_end (bool): 是否在训练 epoch 时保存模型文件(默认为 ``None``) Raises: MisconfigurationException: 如果 ``save_top_k`` 比 ``-1``小 如果 ``monitor`` 不是 ``None`` 同时 ``save_top_k`` 不是 ``None``、``-1``、``0`` 如果 ``mode`` 不是 ``"min"`` 或者 ``"max"`` ValueError: 如果 ``trainer.save_checkpoint`` 是 ``None`` Examples: .. code-block:: python from hfai.pl import ModelCheckpointHF output_dir = 'hfai_out' checkpoint_callback = ModelCheckpointHF(dirpath=output_dir) # 第一步:定义 checkpoint_callback trainer = pytorch_lightning.Trainer( max_epochs=3, gpus=8, strategy="ddp_bind_numa", # hfai 支持 ddp_bind_numa, ddp_spawn_bind_numa, hfreduce_bind_numa, hfreduce_spawn_bind_numa plugins=[HFAIEnvironment()], callbacks=[checkpoint_callback] # 第二步:将 checkpoint_callback 输入到 trainer model_module = ToyNetModule() hfai_suspend_ckpt_path = f'{output_dir}/{checkpoint_callback.CHECKPOINT_NAME_SUSPEND}.ckpt' hfai_suspend_ckpt_path = hfai_suspend_ckpt_path if os.path.exists(hfai_suspend_ckpt_path) else None trainer.fit( model_module, ckpt_path=hfai_suspend_ckpt_path # 第三步:重启后载入打断前的最新模型 CHECKPOINT_NAME_SUSPEND = 'hfai_latest' def hfai_suspend ( self , trainer : "pytorch_lightning.Trainer" ): return _HFAI_AVAILABLE and trainer . global_rank == 0 and hfai . client . receive_suspend_command () def _save_hfai_suspend_checkpoint ( self , trainer : "pytorch_lightning.Trainer" , force_save : bool = False ) -> None : if not force_save : if not self . hfai_suspend ( trainer ): return if pytorch_lightning . __version__ >= '1.6.0' : monitor_candidates = self . _monitor_candidates ( trainer ) filepath = self . format_checkpoint_name ( monitor_candidates , self . CHECKPOINT_NAME_SUSPEND ) _checkpoint = trainer . _checkpoint_connector . dump_checkpoint () if not os . path . exists ( os . path . dirname ( filepath )): os . makedirs ( os . path . dirname ( filepath )) torch . save ( _checkpoint , filepath ) else : monitor_candidates = self . _monitor_candidates ( trainer , trainer . current_epoch , trainer . global_step - 1 ) filepath = self . format_checkpoint_name ( monitor_candidates , self . CHECKPOINT_NAME_SUSPEND ) trainer . save_checkpoint ( filepath ) if self . hfai_suspend ( trainer ): print ( f 'Receive suspend command. Now save checkpoint and go suspend! ' f 'Global rank { trainer . global_rank } . Save checkpoint to { filepath } ' ) time . sleep ( 3 ) hfai . client . go_suspend () def on_train_epoch_end ( self , * args , ** kwargs ) -> None : trainer = kwargs . get ( 'trainer' , None ) or args [ 0 ] self . _save_hfai_suspend_checkpoint ( trainer , force_save = trainer . global_rank == 0 ) super () . on_train_epoch_end ( * args , ** kwargs ) def on_validation_epoch_end ( self , * args , ** kwargs ) -> None : trainer = kwargs . get ( 'trainer' , None ) or args [ 0 ] self . _save_hfai_suspend_checkpoint ( trainer ) super () . on_validation_epoch_end ( * args , ** kwargs ) def on_train_batch_end ( self , * args , ** kwargs ) -> None : trainer = kwargs . get ( 'trainer' , None ) or args [ 0 ] self . _save_hfai_suspend_checkpoint ( trainer ) super () . on_train_batch_end ( * args , ** kwargs ) def on_validation_batch_end ( self , * args , ** kwargs ) -> None : trainer = kwargs . get ( 'trainer' , None ) or args [ 0 ] self . _save_hfai_suspend_checkpoint ( trainer ) super () . on_validation_batch_end ( * args , ** kwargs )