添加链接
link管理
链接快照平台
  • 输入网页链接,自动生成快照
  • 标签化管理网页链接
相关文章推荐
爱旅游的硬币  ·  Memory-efficient ...·  3 周前    · 
坚韧的包子  ·  Tensorflow model ...·  3 周前    · 
从容的木耳  ·  [email protected] | ...·  3 月前    · 
沉稳的筷子  ·  ajax 导出excel - CSDN文库·  5 月前    · 
贪玩的紫菜  ·  Node-RED : 安全·  6 月前    · 
会开车的仙人掌  ·  How to fix corrupted ...·  1 年前    · 

tf.keras.callbacks.ModelCheckpoint

以某种频率保存 Keras 模型或模型权重的回调。

继承自: Callback

View aliases

用于迁移的兼容别名

有关详细信息,请参阅 Migration guide

tf.compat.v1.keras.callbacks.ModelCheckpoint

tf.keras.callbacks.ModelCheckpoint(
    filepath,
    monitor='val_loss',
    verbose=0,
    save_best_only=False,
    save_weights_only=False,
    mode='auto',
    save_freq='epoch',
    options=None,
    initial_value_threshold=None,
    **kwargs

ModelCheckpoint 回调与使用 model.fit() 的训练结合使用,以在某个时间间隔保存模型或权重(在检查点文件中),因此可以稍后加载模型或权重以从保存的状态继续训练。

此回调提供的一些选项包括:

  • 是否只保留迄今为止达到 "best performance" 的模型,或者是否在每个 epoch 结束时保存模型,而不管性能如何。
  • “最好”的定义;要监控的数量以及是否应最大化或最小化。
  • 保存的频率。目前,回调支持在每个 epoch 结束时或固定数量的训练批次后保存。
  • 是仅保存权重,还是保存整个模型。
注意:如果您使用的是 WARNING:tensorflow:Can save best model only with <name> available, skipping ,请参阅 monitor 参数的说明,了解如何正确执行此操作的详细信息。

Example:

model.compile(loss=..., optimizer=...,
              metrics=['accuracy'])
EPOCHS = 10
checkpoint_filepath = '/tmp/checkpoint'
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True)
# 模型权重在每个 epoch 结束时保存(如果是最好的)
# 迄今为止。
model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback])
# 模型权重(被认为是最好的)被加载到模型中。
model.load_weights(checkpoint_filepath)
  filepath  string 或  PathLike  ,保存模型文件的路径。例如 filepath = os.path.join(working_dir, 'ckpt', file_name)。 filepath  可以包含命名格式选项,这些选项将填充  epoch  的值和  logs  中的键(在  on_epoch_end  中传递)。例如:如果  filepathweights.{epoch:02d}-{val_loss:.2f}.hdf5  ,则模型检查点将与纪元号和验证损失一起保存在文件名中。文件路径的目录不应被任何其他回调重复使用,以避免冲突。 
  monitor  要监视的指标名称。通常,指标由  Model.compile  方法设置。笔记:
  • 名称前面加上 "val_ " 前缀以监视验证指标。
  • 使用 "loss" 或 " val_loss " 来监控模型的总损耗。
  • 如果将指标指定为字符串(例如 "accuracy" ),请传递相同的字符串(带或不带 "val_" 前缀)。

  • 如果传递 metrics.Metric 对象, monitor 应设置为 metric.name

  • 如果您不确定指标名称,可以检查 history = model.fit() 返回的 history.history 字典的内容

  • 多输出模型在指标名称上设置附加前缀。

  • verbose 详细模式,0 或 1。模式 0 是静默的,模式 1 在回调执行操作时显示消息。 save_best_only 如果是 save_best_only=True ,则仅在型号被认为是 "best" 时保存,并且根据监控数量最新的最佳型号不会被覆盖。如果 filepath 不包含像 {epoch} 这样的格式化选项,那么 filepath 将被每个新的更好型号覆盖。 mode {'自动'、'最小'、'最大'}之一。如果是 save_best_only=True ,则根据监视数量的最大值或最小值来决定覆盖当前保存文件。对于 val_acc ,这应该是 max ,对于 val_loss ,这应该是 min ,等等。在 auto 模式下,如果监测的量是“acc”或以“fmeasure”开头,则模式设置为 max ,其余的设置为 min 。数量。 save_weights_only 如果为 True,则仅保存模型的权重 ( model.save_weights(filepath) ),否则保存完整模型 ( model.save(filepath) )。 save_freq 'epoch' 或整数。使用 'epoch' 时,回调会在每个纪元后保存模型。使用整数时,回调会在这么多批次结束时保存模型。如果 Model 是用 steps_per_execution=N 编译的,则每 N 个批次将检查保存标准。请注意,如果保存未与纪元对齐,则受监控的指标可能不太可靠(它可能只反映 1 个批次,因为指标在每个纪元都会重置)。默认为 'epoch' options 如果 save_weights_only 是 true ,则可选 tf.train.CheckpointOptions 对象;如果 save_weights_only 为 false,则可选 tf.saved_model.SaveOptions 对象。 initial_value_threshold 要监视的指标的浮点初始 "best" 值。仅适用于 save_best_value=True 。仅当当前模型的性能优于该值时才覆盖已保存的模型权重。 **kwargs 向后兼容性的附加参数。可能的密钥是 period

    Methods

    set_model

    View source

    set_model(
        model
    

    set_params

    View source

    set_params(
        params
        © 2022 The TensorFlow Authors. All rights reserved.
    Licensed under the Creative Commons Attribution License 4.0.
    Code samples licensed under the Apache 2.0 License.
    https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/keras/callbacks/ModelCheckpoint