How to use the ModelCheckpoint callback with Keras and TensorFlow

Checkpointing Best Neural Network Only

Perhaps the biggest downside of checkpointing incremental improvements is that we end up with a bunch of extra files that we are unlikely to be interested in. This is especially true if our validation loss moves up and down over training epochs, because each of these incremental improvements will be captured and serialized to disk. In this case, it’s best to save only one model and simply overwrite it every time our metric improves during training.

Luckily, accomplishing this is as simple as passing ModelCheckpoint a plain string (i.e., a file path without any template variables). Then, whenever our metric improves, that file is simply overwritten. To understand the process, let’s create a second Python file named cifar10_checkpoint_best.py and review the differences.
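To make the difference between a template path and a plain path concrete, the sketch below mimics how a checkpoint file name could be resolved each epoch. It assumes the template is filled in with Python's str.format using the epoch number and the logged metrics. The template string and the resolve_path helper are hypothetical names used only for illustration.

```python
def resolve_path(filepath, epoch, logs):
    """Fill in any template fields; a plain path comes back unchanged."""
    return filepath.format(epoch=epoch, **logs)

# hypothetical template path: produces a new file name for every epoch
template = "weights/weights-{epoch:03d}-{val_loss:.4f}.hdf5"
# plain path: resolves to the same file name for every epoch
plain = "weights/best/cifar10_best_weights.hdf5"

logs = {"val_loss": 1.26677}
templated_name = resolve_path(template, 1, logs)
plain_name = resolve_path(plain, 1, logs)

print(templated_name)
print(plain_name)
```

Because the plain path never changes, every save lands on the same file, which is what makes the overwrite behavior possible.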

First, we need to import our required Python packages:

# import the necessary packages
from sklearn.preprocessing import LabelBinarizer
from pyimagesearch.nn.conv import MiniVGGNet
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.datasets import cifar10
import argparse

Then parse our command line arguments:

# construct the argument parse and parse the arguments
ap = argparse.ArgumentParser()
ap.add_argument("-w", "--weights", required=True,
	help="path to best model weights file")
args = vars(ap.parse_args())

The name of the command line argument itself is the same (--weights), but the description of the switch is now different: “path to best model weights file.” Thus, this command line argument will be a simple string to an output path; no template is applied to this string.
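As a quick check of what the parser hands back, the sketch below builds the same parser and passes the arguments explicitly rather than reading them from the command line. The explicit argument list is only there so the snippet can run on its own; the path is the one used later in this section.

```python
import argparse

ap = argparse.ArgumentParser()
ap.add_argument("-w", "--weights", required=True,
    help="path to best model weights file")

# pass the arguments explicitly instead of reading them from sys.argv
args = vars(ap.parse_args(
    ["--weights", "weights/best/cifar10_best_weights.hdf5"]))

# args is a plain dictionary keyed by the long option name
print(args["weights"])
```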

From there we can load our CIFAR-10 dataset and prepare it for training:

# load the training and testing data, then scale it into the
# range [0, 1]
print("[INFO] loading CIFAR-10 data...")
((trainX, trainY), (testX, testY)) = cifar10.load_data()
trainX = trainX.astype("float") / 255.0
testX = testX.astype("float") / 255.0
# convert the labels from integers to vectors
lb = LabelBinarizer()
trainY = lb.fit_transform(trainY)
testY = lb.transform(testY)

As well as initialize our SGD optimizer and MiniVGGNet architecture:

# initialize the optimizer and model
print("[INFO] compiling model...")
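# note: older tf.keras releases also accept lr= in place of
# learning_rate=, and the decay argument may require the legacy SGD
# optimizer in newer releases
opt = SGD(learning_rate=0.01, decay=0.01 / 40, momentum=0.9, nesterov=True)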
model = MiniVGGNet.build(width=32, height=32, depth=3, classes=10)
model.compile(loss="categorical_crossentropy", optimizer=opt,
	metrics=["accuracy"])
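To get a feel for what a decay of 0.01 / 40 does to the initial learning rate of 0.01, here is a small worked sketch. It assumes the time-based decay rule used by the older Keras optimizers that accept a decay argument, where the rate is divided by (1 + decay * iterations) and iterations counts batch updates. The decayed_lr helper is a hypothetical name, and the batch size of 64 and the 40 epochs come from the training call used in this script.

```python
import math

def decayed_lr(initial_lr, decay, iterations):
    """Time-based decay: the rate shrinks as batch updates accumulate."""
    return initial_lr / (1.0 + decay * iterations)

initial_lr = 0.01
decay = 0.01 / 40
# 50,000 training images processed in batches of 64
steps_per_epoch = math.ceil(50000 / 64)

for epoch in (1, 10, 20, 40):
    iterations = epoch * steps_per_epoch
    print(epoch, round(decayed_lr(initial_lr, decay, iterations), 6))

final_lr = decayed_lr(initial_lr, decay, 40 * steps_per_epoch)
```

Under this assumption the learning rate falls by roughly a factor of nine over the 40 epochs.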

We are now ready to update the ModelCheckpoint code:

# construct the callback to save only the *best* model to disk
# based on the validation loss
checkpoint = ModelCheckpoint(args["weights"], monitor="val_loss",
	save_best_only=True, verbose=1)
callbacks = [checkpoint]

Notice how the fname template string is gone; all we are doing is supplying the value of --weights to ModelCheckpoint. Since there are no template values to fill in, Keras will simply overwrite the existing serialized weights file whenever our monitored metric improves (in this case, whenever validation loss decreases).
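The bookkeeping behind saving only the best model can be sketched in a few lines of plain Python. This is an illustration of the idea and not the Keras implementation: it tracks the lowest validation loss seen so far and records the epochs at which the single output file would be overwritten. The epochs_that_save helper and the loss values are made up for the example.

```python
def epochs_that_save(val_losses):
    """Return the epochs that would trigger an overwrite, plus the best loss."""
    best = float("inf")
    saved = []
    for epoch, val_loss in enumerate(val_losses, start=1):
        # save only when the loss is strictly lower than every earlier epoch
        if val_loss < best:
            best = val_loss
            saved.append(epoch)
    return saved, best

# made-up validation losses that move up and down
losses = [1.27, 1.22, 1.30, 0.95, 0.97, 0.90]
saved_epochs, best_loss = epochs_that_save(losses)
print(saved_epochs, best_loss)
```

Epochs where the loss goes back up leave the file untouched, so the file on disk always corresponds to the lowest loss seen so far.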

Finally, we train our network in the code block below:

# train the network
print("[INFO] training network...")
H = model.fit(trainX, trainY, validation_data=(testX, testY),
	batch_size=64, epochs=40, callbacks=callbacks, verbose=2)

To execute our script, issue the following command:

$ python cifar10_checkpoint_best.py \
	--weights weights/best/cifar10_best_weights.hdf5
[INFO] loading CIFAR-10 data...
[INFO] compiling model...
[INFO] training network...
Train on 50000 samples, validate on 10000 samples
Epoch 1/40
Epoch 00000: val_loss improved from inf to 1.26677, saving model to
	test_best/cifar10_best_weights.hdf5
305s - loss: 1.6657 - acc: 0.4441 - val_loss: 1.2668 - val_acc: 0.5584
Epoch 2/40
Epoch 00001: val_loss improved from 1.26677 to 1.21923, saving model to
	test_best/cifar10_best_weights.hdf5
309s - loss: 1.1996 - acc: 0.5828 - val_loss: 1.2192 - val_acc: 0.5798
...
Epoch 40/40
Epoch 00039: val_loss did not improve
173s - loss: 0.2615 - acc: 0.9079 - val_loss: 0.5511 - val_acc: 0.8250

Here, you can see that we overwrite our cifar10_best_weights.hdf5 file with the updated network only if our validation loss decreases. This has two primary benefits:

  1. There is only one serialized file at the end of the training process: the model from the epoch that obtained the lowest validation loss.
  2. We are not capturing “incremental improvements” where loss fluctuates up and down. Instead, we overwrite the existing best model only when the validation loss is lower than in all previous epochs.

To confirm this, take a look at my weights/best directory where you can see there is only one output file:

$ ls -l weights/best/
total 17024
-rw-rw-r-- 1 adrian adrian 17431968 Apr 28 09:47 cifar10_best_weights.hdf5

You can then take this serialized MiniVGGNet and further evaluate it on the testing data or apply it to your own images.
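A minimal sketch of that evaluation step is shown below. It assumes the checkpoint file holds the full model (the ModelCheckpoint default when save_weights_only is not set), that your TensorFlow/Keras version can read the HDF5 file written by the script, and that testX and testY were preprocessed exactly as in the training script. The evaluate_checkpoint helper is a hypothetical name.

```python
def evaluate_checkpoint(weights_path, testX, testY):
    """Load the serialized model and report its loss and accuracy."""
    # imported here so the helper can be defined before TensorFlow is loaded
    from tensorflow.keras.models import load_model

    model = load_model(weights_path)
    # the model was compiled with metrics=["accuracy"], so evaluate
    # returns the loss followed by the accuracy
    loss, accuracy = model.evaluate(testX, testY, batch_size=64, verbose=0)
    return loss, accuracy
```

With the data loaded as above, a call might look like evaluate_checkpoint("weights/best/cifar10_best_weights.hdf5", testX, testY).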

    About the Author

    Hi there, I’m Adrian Rosebrock, PhD. All too often I see developers, students, and researchers wasting their time, studying the wrong things, and generally struggling to get started with Computer Vision, Deep Learning, and OpenCV. I created this website to show you what I believe is the best possible way to get your start.
