以某种频率保存 Keras 模型或模型权重的回调。
继承自:
Callback
tf.keras.callbacks.ModelCheckpoint(
filepath,
monitor='val_loss',
verbose=0,
save_best_only=False,
save_weights_only=False,
mode='auto',
save_freq='epoch',
options=None,
initial_value_threshold=None,
**kwargs
ModelCheckpoint
回调与使用
model.fit()
的训练结合使用,以在某个时间间隔保存模型或权重(在检查点文件中),因此可以稍后加载模型或权重以从保存的状态继续训练。
此回调提供的一些选项包括:
-
是否只保留迄今为止达到 "best performance" 的模型,或者是否在每个 epoch 结束时保存模型,而不管性能如何。
-
“最好”的定义;要监控的数量以及是否应最大化或最小化。
-
保存的频率。目前,回调支持在每个 epoch 结束时或固定数量的训练批次后保存。
-
是仅保存权重,还是保存整个模型。
注意:如果您使用的是
WARNING:tensorflow:Can save best model only with <name> available, skipping
,请参阅
monitor
参数的说明,了解如何正确执行此操作的详细信息。
Example:
model.compile(loss=..., optimizer=...,
metrics=['accuracy'])
EPOCHS = 10
checkpoint_filepath = '/tmp/checkpoint'
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_filepath,
save_weights_only=True,
monitor='val_accuracy',
mode='max',
save_best_only=True)
model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback])
model.load_weights(checkpoint_filepath)
filepath
string 或 PathLike
,保存模型文件的路径。例如 filepath = os.path.join(working_dir, 'ckpt', file_name)。 filepath
可以包含命名格式选项,这些选项将填充 epoch
的值和 logs
中的键(在 on_epoch_end
中传递)。例如:如果 filepath
是 weights.{epoch:02d}-{val_loss:.2f}.hdf5
,则模型检查点将与纪元号和验证损失一起保存在文件名中。文件路径的目录不应被任何其他回调重复使用,以避免冲突。
monitor
要监视的指标名称。通常,指标由 Model.compile
方法设置。笔记:
- 名称前面加上
"val_
" 前缀以监视验证指标。 - 使用
"loss"
或 " val_loss
" 来监控模型的总损耗。 如果将指标指定为字符串(例如 "accuracy"
),请传递相同的字符串(带或不带 "val_"
前缀)。
如果传递 metrics.Metric
对象, monitor
应设置为 metric.name
如果您不确定指标名称,可以检查 history = model.fit()
返回的 history.history
字典的内容
多输出模型在指标名称上设置附加前缀。
verbose
详细模式,0 或 1。模式 0 是静默的,模式 1 在回调执行操作时显示消息。
save_best_only
如果是 save_best_only=True
,则仅在型号被认为是 "best" 时保存,并且根据监控数量最新的最佳型号不会被覆盖。如果 filepath
不包含像 {epoch}
这样的格式化选项,那么 filepath
将被每个新的更好型号覆盖。
mode
{'自动'、'最小'、'最大'}之一。如果是 save_best_only=True
,则根据监视数量的最大值或最小值来决定覆盖当前保存文件。对于 val_acc
,这应该是 max
,对于 val_loss
,这应该是 min
,等等。在 auto
模式下,如果监测的量是“acc”或以“fmeasure”开头,则模式设置为 max
,其余的设置为 min
。数量。
save_weights_only
如果为 True,则仅保存模型的权重 ( model.save_weights(filepath)
),否则保存完整模型 ( model.save(filepath)
)。
save_freq
'epoch'
或整数。使用 'epoch'
时,回调会在每个纪元后保存模型。使用整数时,回调会在这么多批次结束时保存模型。如果 Model
是用 steps_per_execution=N
编译的,则每 N 个批次将检查保存标准。请注意,如果保存未与纪元对齐,则受监控的指标可能不太可靠(它可能只反映 1 个批次,因为指标在每个纪元都会重置)。默认为 'epoch'
。
options
如果 save_weights_only
是 true ,则可选 tf.train.CheckpointOptions
对象;如果 save_weights_only
为 false,则可选 tf.saved_model.SaveOptions
对象。
initial_value_threshold
要监视的指标的浮点初始 "best" 值。仅适用于 save_best_value=True
。仅当当前模型的性能优于该值时才覆盖已保存的模型权重。
**kwargs
向后兼容性的附加参数。可能的密钥是 period
。 Methods
set_model
View source
set_model(
model
set_params
View source
set_params(
params
© 2022 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 4.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/keras/callbacks/ModelCheckpoint