添加链接
注册
登录
link管理
链接快照平台
输入网页链接,自动生成快照
标签化管理网页链接
相关文章推荐
近视的野马
·
在不知道列和行的情况下,替换pandas ...
·
8 月前
·
奔跑的草稿纸
·
佛山市顺德区容桂铭厨厨具厂
·
1 年前
·
强健的生姜
·
WPF之DragDrop拖放实例---图像资 ...
·
1 年前
·
愉快的单车
·
狐狸的夏天黎晏书成为首席设计师和盛虹签约是第 ...
·
1 年前
·
温柔的针织衫
·
乌龙院四格漫画3傻兄宝弟 - 书有新旧,但文字没有
·
1 年前
·
link管理
›
hfai.pl.callbacks.model_checkpoint_hf — hfai 7.9.7.14 documentation
checkpoint
http://doc.hfai.high-flyer.cn/_modules/hfai/pl/callbacks/model_checkpoint_hf.html
失恋的回锅肉
10 月前
import
torch
from
pytorch_lightning.callbacks
import
ModelCheckpoint
from
hfai.pl.utilities
import
_HFAI_AVAILABLE
if
_HFAI_AVAILABLE
:
import
hfai.client
[docs]
class
ModelCheckpointHF
(
ModelCheckpoint
):
这是一个可以自动处理 Hfai 打断信号,自动挂起任务的 checkpoint 回调函数管理类, 支持 ``1.6.0 <= pytorch_lightning.__version__ <= 1.7.6``
Args:
dirpath (str): 模型文件的保存目录(默认为 ``None``)
filename (str): 模型文件的保存名称,例如 ``{epoch}-{val_loss:.2f}-{other_metric:.2f}``(默认为 ``None``)
monitor (str): 监测的指标名称(默认为 ``None``)
verbose (bool): 输出状态(默认为 ``False``)
save_last (bool): 是否保存最后一个模型文件(默认为 ``None``)
save_top_k (int): 保存前 ``k`` 好的模型文件,``k`` 为 ``0`` 时不保存,``k`` 为 ``-1`` 时保存所有模型文件(默认为 ``1``)
save_weights_only (bool): 是否仅保存模型的权重(默认为 ``False``)
mode (str): 指标的排序方式,包括:从大到小(``max``)或者从小到大(``min``),(默认为 ``min``)
auto_insert_metric_name (bool): 是否在模型名称上自动插入指标的数值(默认为 ``True``)
every_n_train_steps (int): 保存模型文件的训练间隔 step 数量(默认为 ``None``),不能和 ``train_time_interval`` 和 ``every_n_epochs`` 一同使用
train_time_interval (timedelta): 保存模型文件的训练间隔时间(默认为 ``None``),不能和 ``every_n_train_steps`` 和 ``every_n_epochs`` 一同使用
every_n_epochs (int): 保存模型文件的训练间隔 epoch 数量(默认为 ``None``),不能和 ``every_n_train_steps`` 和 ``train_time_interval`` 一同使用
save_on_train_epoch_end (bool): 是否在训练 epoch 时保存模型文件(默认为 ``None``)
Raises:
MisconfigurationException:
如果 ``save_top_k`` 比 ``-1``小
如果 ``monitor`` 不是 ``None`` 同时 ``save_top_k`` 不是 ``None``、``-1``、``0``
如果 ``mode`` 不是 ``"min"`` 或者 ``"max"``
ValueError:
如果 ``trainer.save_checkpoint`` 是 ``None``
Examples:
.. code-block:: python
from hfai.pl import ModelCheckpointHF
output_dir = 'hfai_out'
checkpoint_callback = ModelCheckpointHF(dirpath=output_dir) # 第一步:定义 checkpoint_callback
trainer = pytorch_lightning.Trainer(
max_epochs=3,
gpus=8,
strategy="ddp_bind_numa", # hfai 支持 ddp_bind_numa, ddp_spawn_bind_numa, hfreduce_bind_numa, hfreduce_spawn_bind_numa
plugins=[HFAIEnvironment()],
callbacks=[checkpoint_callback] # 第二步:将 checkpoint_callback 输入到 trainer
model_module = ToyNetModule()
hfai_suspend_ckpt_path = f'{output_dir}/{checkpoint_callback.CHECKPOINT_NAME_SUSPEND}.ckpt'
hfai_suspend_ckpt_path = hfai_suspend_ckpt_path if os.path.exists(hfai_suspend_ckpt_path) else None
trainer.fit(
model_module,
ckpt_path=hfai_suspend_ckpt_path # 第三步:重启后载入打断前的最新模型
CHECKPOINT_NAME_SUSPEND
=
'hfai_latest'
def
hfai_suspend
(
self
,
trainer
:
"pytorch_lightning.Trainer"
):
return
_HFAI_AVAILABLE
and
trainer
.
global_rank
==
0
and
hfai
.
client
.
receive_suspend_command
()
def
_save_hfai_suspend_checkpoint
(
self
,
trainer
:
"pytorch_lightning.Trainer"
,
force_save
:
bool
=
False
)
->
None
:
if
not
force_save
:
if
not
self
.
hfai_suspend
(
trainer
):
return
if
pytorch_lightning
.
__version__
>=
'1.6.0'
:
monitor_candidates
=
self
.
_monitor_candidates
(
trainer
)
filepath
=
self
.
format_checkpoint_name
(
monitor_candidates
,
self
.
CHECKPOINT_NAME_SUSPEND
)
_checkpoint
=
trainer
.
_checkpoint_connector
.
dump_checkpoint
()
if
not
os
.
path
.
exists
(
os
.
path
.
dirname
(
filepath
)):
os
.
makedirs
(
os
.
path
.
dirname
(
filepath
))
torch
.
save
(
_checkpoint
,
filepath
)
else
:
monitor_candidates
=
self
.
_monitor_candidates
(
trainer
,
trainer
.
current_epoch
,
trainer
.
global_step
-
1
)
filepath
=
self
.
format_checkpoint_name
(
monitor_candidates
,
self
.
CHECKPOINT_NAME_SUSPEND
)
trainer
.
save_checkpoint
(
filepath
)
if
self
.
hfai_suspend
(
trainer
):
print
(
f
'Receive suspend command. Now save checkpoint and go suspend! '
f
'Global rank
{
trainer
.
global_rank
}
. Save checkpoint to
{
filepath
}
'
)
time
.
sleep
(
3
)
hfai
.
client
.
go_suspend
()
def
on_train_epoch_end
(
self
,
*
args
,
**
kwargs
)
->
None
:
trainer
=
kwargs
.
get
(
'trainer'
,
None
)
or
args
[
0
]
self
.
_save_hfai_suspend_checkpoint
(
trainer
,
force_save
=
trainer
.
global_rank
==
0
)
super
()
.
on_train_epoch_end
(
*
args
,
**
kwargs
)
def
on_validation_epoch_end
(
self
,
*
args
,
**
kwargs
)
->
None
:
trainer
=
kwargs
.
get
(
'trainer'
,
None
)
or
args
[
0
]
self
.
_save_hfai_suspend_checkpoint
(
trainer
)
super
()
.
on_validation_epoch_end
(
*
args
,
**
kwargs
)
def
on_train_batch_end
(
self
,
*
args
,
**
kwargs
)
->
None
:
trainer
=
kwargs
.
get
(
'trainer'
,
None
)
or
args
[
0
]
self
.
_save_hfai_suspend_checkpoint
(
trainer
)
super
()
.
on_train_batch_end
(
*
args
,
**
kwargs
)
def
on_validation_batch_end
(
self
,
*
args
,
**
kwargs
)
->
None
:
trainer
=
kwargs
.
get
(
'trainer'
,
None
)
or
args
[
0
]
self
.
_save_hfai_suspend_checkpoint
(
trainer
)
super
()
.
on_validation_batch_end
(
*
args
,
**
kwargs
)
推荐文章
近视的野马
·
在不知道列和行的情况下,替换pandas DataFrame中的某个特定值。
8 月前
奔跑的草稿纸
·
佛山市顺德区容桂铭厨厨具厂
1 年前
强健的生姜
·
WPF之DragDrop拖放实例---图像资源管理器 - Rang's Note
1 年前
愉快的单车
·
狐狸的夏天黎晏书成为首席设计师和盛虹签约是第几集_百度知道
1 年前
温柔的针织衫
·
乌龙院四格漫画3傻兄宝弟 - 书有新旧,但文字没有
1 年前