添加链接
link管理
链接快照平台
  • 输入网页链接,自动生成快照
  • 标签化管理网页链接
相关文章推荐
谦逊的电脑桌  ·  Training & evaluation ...·  4 天前    · 
飞翔的投影仪  ·  System.IO.IOException无 ...·  5 天前    · 
逆袭的灭火器  ·  Deep learning with ...·  3 周前    · 
咆哮的爆米花  ·  DeepLearning_Notes_CV/ ...·  3 周前    · 
直爽的猕猴桃  ·  Error: Unhandled ...·  3 周前    · 
叛逆的苦瓜  ·  Redirects - October ...·  5 月前    · 
爽快的充电器  ·  Parse error for SVG ...·  11 月前    · 
坚韧的丝瓜  ·  SQL Server ...·  1 年前    · 

以指定频率保存 Keras 模型或权重。

ModelCheckpoint callback 通过 model.fit() 与训练结合使用,以指定时间间隔保存模型或权重为 checkpoint 文件,以便稍后可以加载模型或权重,从而从保存的状态继续训练。

该 callback 提供了如下选项:

  • 是只保留到目前为止 性能最佳 的模型,还是不管性能,每个 epoch 结束时都保存模型;
  • 最佳的定义 :要监控的指标,以及应该最大化还是最小化;
  • 保存的频率,目前支持在每个 epoch 结束时保存,或指定训练 batches 后保存;
  • 是只保存权重,还是保存整个模型。

[!NOTE]
如果出现 WARNING:tensorflow:Can save best model only with <name> available, skipping 信息,可以参考 monitor 参数说明。

filepath

保存模型文件的路径,string 或 PathLike ,例如 filepath = os.path.join(working_dir, 'ckpt', file_name) filepath 可以包含命名格式化选项。例如,如果 filepath weights.{epoch:02d}-{val_loss:.2f}.hdf5 ,则 model checkpoint 文件名包含 epoch 号和 validation loss。 filepath 目录不应该被其它 callback 使用,以避免冲突。

monitor

要监控指标的名称。指标一般通过 Model.compile 方法设置。注意:

  • 在名称前加 “val_” 前缀以监控 validation 指标;
  • 使用 “loss” 或 “val_loss” 以监控模型的总损失;
  • 如果用字符串指定指标,如 “accuracy”,传入相同的字符串(带或不带 “val_” 前缀);
  • 如果传入 metrics.Metric 对象, monitor 应该设置为 metric.name
  • 如果不确定指标名称,可以检查 history=model.fit() 返回的 history.history dict;
  • 多输出模型的指标名称包含额外的前缀。

verbose

详细模式,0 或 1:

  • 0 silent
  • 1 在 callback 执行时显示消息

save_best_only

save_best_only=True 只在模型被认为是目前最好时保存。如果 filepath 不包含格式化选项,例如 {epoch} ,则新保存的更好模型将覆盖之前保存的模型。

{‘auto’, ‘min’, ‘max’} 之一。

如果 save_best_only=True ,则根据监视指标的最大化或最小化来决定是否覆盖保存文件。对 val_acc 应为 max ,对 val_loss 应为 min 。在 auto 模式,如果监控的指标为 acc 或以 ‘fmeasure’ 开头,则模式为 max ,对余下的则为 min

save_weights_only

True 表示只保存模型的权重 model.save_weights(filepath) ,否则保存整个模型 model.save(filepath)

save_freq

‘epoch’ 或 integer。当使用 'epoch' 时,callback 在每个 epoch 后保存模型。当使用 integer,则在这些 batch 后保存模型。如果 Model 使用 steps_per_execution=N 选项进行编译,则每 Nth batch 检查保存条件。注意,如果保存和 epoch 没对齐,则监控指标可能不可靠(它可能只反应一个 batch,因为指标在每个 epoch 结束会重置)。默认 ‘epoch’。

options

save_weights_only 为 True 时可选的 tf.train.CheckpointOptions 对象 或 save_weights_only 为 False 时可选的 tf.saved_model.SaveOptions 对象。

initial_value_threshold

指标的最佳值(浮点数)。 save_best_value=True 时适用。当模型的性能优于该值时才保存模型权重。

model.compile(loss=..., optimizer=..., metrics=["accuracy"])
EPOCHS = 10
checkpoint_filepath = "/tmp/checkpoint"
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor="val_accuracy",
    mode="max",
    save_best_only=True,
# 如果模型是目前为止最好的模型,则在每个 epoch 后保存模型权重
model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback])
# The model weights (that are considered the best) are loaded into the model.
model.load_weights(checkpoint_filepath)
  • https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/ModelCheckpoint
  • https://keras.io/api/callbacks/model_checkpoint/
我们训练完模型之后,一般会需要保存模型或者只保存权重文件。可以利用keras中的回调函数ModelCheckpoint进行保存。 keras.callbacks.ModelCheckpoint( filepath, monitor='val_loss', verbose=0, save_best_only=True, save_weights_only=F...
from keras.callbacks import ModelCheckpoint Keras中使用ModelCheckpoint对训练完成的模型进行保存及载入 Keras函数——keras.callbacks.ModelCheckpoint()各参数解释及模型的训练
转载于https://cloud.tencent.com/developer/article/1049579,如有侵权,请联系 [email protected] 删除 深度学习模式可能需要几个小时,几天甚至几周的时间来训练。 如果运行意外停止,你可能就白干了。 在这篇文章中,你将会发现在使用Keras库的Python训练过程中,如何检查你的深度学习模型 Checkpoint神经网络模型 应用程序Checkpoint是为长时间运行进程准备的容错技术。 这是一种在系统故障的情况下拍摄系统状态快照的方法.
import json import tensorflow.keras.models from tensorflow.keras.callbacks import * #这是独立的包 import tensorflow.keras mnist = tf.keras.datasets.mnist (x_train, y_train), (x_test, y_test) = mnist.load_data() x_train
在神经网络的训练学习过程中,常常需要把训练好的模型保存下来,ModelCheckpoint技术就是一种很实用的模型保存与改进方法。 在keras中通过回调API实现Checkpoint功能,本质上是callbacks的一个类。使用前需要从keras库中调用: from kearas.callbacks import ModelCheckpoint ModelCheckpoint的一般格式是: checkpoint = ModelCheckpoint(filename, monitor='loss', ve
使用tf.trian.NewCheckpointReader(model_dir) 一个标准的模型文件有一下文件, model_dir就是MyModel(没有后缀) checkpoint Model.meta Model.data-00000-of-00001 Model.index import tensorflow as tf import pprint # 使用pprint 提高打印的可读性 NewCheck =tf.train.NewCheckpointReader("model") 打印模型中的所有变量 print("debug_string:\n") pprint.pprin