
Now that you've done the setup, you can train your model. Lightning offers an object called a trainer that handles a lot of the training boilerplate for you. When you use pl.Trainer, you no longer have to write training loops. What a relief!

What Does the Trainer Do?

The Trainer object handles a lot of details for you, such as:

  • Moving parameters and data between the CPU and GPU
  • Executing callbacks
  • Logging metrics
  • And much more (see the documentation for the rest)

    PyTorch Lightning Trainer Example

    Once you instantiate a trainer, all you need to do is pass the model and datamodule to its .fit method.

    trainer = pl.Trainer(
        accelerator='gpu', devices=torch.cuda.device_count(), # Tell the trainer how many GPUs to use
        max_epochs=30, # Set the number of epochs
        callbacks=callbacks, # Pass the callbacks to the trainer
        logger=logger, # Pass the logger to the trainer
        log_every_n_steps=1, # How often to log metrics. If computing metrics is slow, a larger value can speed up training.
    )
    
    INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
    INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
    INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
    INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
    

    Fit the Model

    In the cell below, you begin training your model. If you are running this code in the workset notebook, revisit the cell where TensorBoard is running to keep track of the model's training progress!

    trainer.fit(model=model, datamodule=datamodule)
    
    WARNING:pytorch_lightning.loggers.tensorboard:Missing logger folder: lightning_logs/lightning_mlp
    INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
    INFO:pytorch_lightning.callbacks.model_summary:
        | Name      | Type               | Params
    -------------------------------------------------
    0 | input     | Sequential         | 100 K 
    1 | hidden    | Sequential         | 50.3 K
    2 | out       | Linear             | 1.3 K 
    3 | loss      | CrossEntropyLoss   | 0     
    4 | train_acc | MulticlassAccuracy | 0     
    5 | valid_acc | MulticlassAccuracy | 0     
    -------------------------------------------------
    152 K     Trainable params
    0         Non-trainable params
    152 K     Total params
    0.609     Total estimated model params size (MB)
    Sanity Checking: 0it [00:00, ?it/s]
    Training: 0it [00:00, ?it/s]
    Validation: 0it [00:00, ?it/s]
    INFO:pytorch_lightning.callbacks.early_stopping:Metric valid_loss improved. New best score: 0.173
    Validation: 0it [00:00, ?it/s]
    INFO:pytorch_lightning.callbacks.early_stopping:Metric valid_loss improved by 0.023 >= min_delta = 0.0. New best score: 0.150
    Validation: 0it [00:00, ?it/s]
    INFO:pytorch_lightning.callbacks.early_stopping:Metric valid_loss improved by 0.024 >= min_delta = 0.0. New best score: 0.126
    Validation: 0it [00:00, ?it/s]
    INFO:pytorch_lightning.callbacks.early_stopping:Metric valid_loss improved by 0.009 >= min_delta = 0.0. New best score: 0.117
    Validation: 0it [00:00, ?it/s]
    INFO:pytorch_lightning.callbacks.early_stopping:Metric valid_loss improved by 0.014 >= min_delta = 0.0. New best score: 0.103
    Validation: 0it [00:00, ?it/s]
    INFO:pytorch_lightning.callbacks.early_stopping:Metric valid_loss improved by 0.009 >= min_delta = 0.0. New best score: 0.094
    Validation: 0it [00:00, ?it/s]
    Validation: 0it [00:00, ?it/s]
    INFO:pytorch_lightning.callbacks.early_stopping:Metric valid_loss improved by 0.002 >= min_delta = 0.0. New best score: 0.092
    Validation: 0it [00:00, ?it/s]
    Validation: 0it [00:00, ?it/s]
    Validation: 0it [00:00, ?it/s]
    INFO:pytorch_lightning.callbacks.early_stopping:Monitored metric valid_loss did not improve in the last 3 records. Best score: 0.092. Signaling Trainer to stop.
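
    The early-stopping messages above follow a simple rule: stop once valid_loss has not improved by more than min_delta for patience (here, 3) consecutive validation rounds. Here is a plain-Python sketch of that rule (not Lightning's implementation; the unlogged in-between loss values used in the example below are made up):

```python
def early_stop_round(losses, min_delta=0.0, patience=3):
    """Return the validation round at which training stops, or None."""
    best = float("inf")
    bad_rounds = 0
    for i, loss in enumerate(losses):
        if best - loss > min_delta:   # strict improvement over the best score
            best, bad_rounds = loss, 0
        else:
            bad_rounds += 1
            if bad_rounds >= patience:
                return i              # signal the trainer to stop
    return None
```

    For a run whose best scores evolve like the log above (0.173 down to 0.092, followed by three rounds without improvement), this rule stops training at the third bad round, just as the callback did.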
      

    Exercise 9.2: Experiment with TensorBoard

    In this exercise, you will run a few experiments to see whether changes to your model improve the best validation score you can achieve. Return to the line of code where you instantiate your model and change some of the hyperparameters, one at a time. Do these changes improve your score? Use the running instance of TensorBoard from a previous lesson to monitor these experiments. In the cell below, take some notes on which hyperparameters work best. If model performance is about the same even after some experimentation, what sort of choices would you make if you had to choose a final model?
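
    One way to keep these experiments organized is to sweep one hyperparameter grid and record the best validation score for each setting. Here is a plain-Python sketch (the grid contents and the evaluate function are illustrative, not part of this lesson's code):

```python
from itertools import product

def sweep(grid, evaluate):
    """Try every combination in grid; return (best_config, best_score)."""
    best_config, best_score = None, float("inf")
    for values in product(*grid.values()):
        config = dict(zip(grid.keys(), values))
        score = evaluate(config)          # e.g. best valid_loss for this run
        if score < best_score:
            best_config, best_score = config, score
    return best_config, best_score
```

    Logging each run under a distinct name keeps the curves separate in TensorBoard, which makes the comparison in your notes much easier.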

    Learn by Doing

    # Notes:
      

    Complete Exercise 9.2 in this notebook.

    Summary: PyTorch Lightning Trainer

    In this lesson, you used the Trainer object to train your model, achieving about 95% accuracy in about 10 epochs. You also ran a few experiments, using TensorBoard to track your progress.

    Notes:

  • The Trainer object is PyTorch Lightning's training class
  • The Trainer moves parameters and data between the CPU and GPU
  • The Trainer executes callbacks
  • The Trainer creates logs