A Multilayer Perceptron for Ordinal Regression using CORAL -- Cement Dataset

In this tutorial, we implement a multilayer perceptron for ordinal regression based on the CORAL method. To learn more about CORAL, please have a look at our paper:

  • Wenzhi Cao, Vahid Mirjalili, and Sebastian Raschka (2020): Rank Consistent Ordinal Regression for Neural Networks with Application to Age Estimation. Pattern Recognition Letters, 140, 325-331.
  • General settings and hyperparameters

  • Here, we specify some general settings and hyperparameter values
  • Note that for small datasets, it is neither necessary nor advisable to use multiple workers, as doing so can sometimes cause issues with too many open files in PyTorch
  • BATCH_SIZE = 32
    NUM_EPOCHS = 200
    LEARNING_RATE = 0.01
    NUM_WORKERS = 0
    DATA_BASEPATH = "./data"
    

    Converting a regular classifier into a CORAL ordinal regression model

    Changing a classifier to a CORAL model for ordinal regression is actually really simple and only requires a few changes:

    We replace the output layer

    output_layer = torch.nn.Linear(hidden_units[-1], num_classes)
    

    by a CORAL layer (available through coral_pytorch):

    output_layer = CoralLayer(size_in=hidden_units[-1], num_classes=num_classes)
    

    Convert the integer class labels into the extended binary label format using the levels_from_labelbatch function provided via coral_pytorch:

    levels = levels_from_labelbatch(class_labels, 
                                    num_classes=num_classes)
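
    For intuition, here is what the extended binary label format looks like for a small, hypothetical batch of labels (this snippet is only for illustration and is not part of the original tutorial): each row has num_classes - 1 entries, and a label y is encoded as y leading ones.

    import torch
    from coral_pytorch.dataset import levels_from_labelbatch

    # Hypothetical example batch: labels 0, 2, and 4 with num_classes=5
    example_labels = torch.tensor([0, 2, 4])
    print(levels_from_labelbatch(example_labels, num_classes=5))
    # Expected values (the exact dtype may differ):
    # tensor([[0., 0., 0., 0.],
    #         [1., 1., 0., 0.],
    #         [1., 1., 1., 1.]])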
    

    Swap the cross entropy loss from PyTorch,

    torch.nn.functional.cross_entropy(logits, true_labels)
    

    with the CORAL loss (also provided via coral_pytorch):

    loss = coral_loss(logits, levels)
    

    In a regular classifier, we usually obtain the predicted class labels as follows:

    predicted_labels = torch.argmax(logits, dim=1)
    

    Replace this with the following code, which first converts the logits into probabilities and then into the predicted labels:

    probas = torch.sigmoid(logits)
    predicted_labels = proba_to_label(probas)
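
    To see how proba_to_label behaves, consider a small, hypothetical batch of probabilities (illustration only, not part of the original tutorial): each probability above 0.5 counts as one rank passed, and the predicted label is the number of passed ranks per row.

    import torch
    from coral_pytorch.dataset import proba_to_label

    # Hypothetical probabilities for two examples with num_classes=5
    example_probas = torch.tensor([[0.9, 0.8, 0.6, 0.2],
                                   [0.4, 0.3, 0.2, 0.1]])
    print(proba_to_label(example_probas))
    # Expected: tensor([3, 0])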
    

    Implementing a MultiLayerPerceptron using PyTorch Lightning's LightningModule

  • In this section, we set up the main model architecture using the LightningModule from PyTorch Lightning.
  • We start with defining our MultiLayerPerceptron model in pure PyTorch, and then we use it in the LightningModule to get all the extra benefits that PyTorch Lightning provides.
  • Given a multilayer perceptron classifier with cross-entropy loss, it is very easy to change this classifier into an ordinal regression model using CORAL, as explained in the previous section. In the code example below, we make change 1), replacing the output layer with the CoralLayer.
  • import torch
    from coral_pytorch.layers import CoralLayer
    # Regular PyTorch Module
    class MultiLayerPerceptron(torch.nn.Module):
        def __init__(self, input_size, hidden_units, num_classes):
            super().__init__()
            # num_classes is used by the CORAL loss function
            self.num_classes = num_classes
            # Initialize MLP layers
            all_layers = []
            for hidden_unit in hidden_units:
                layer = torch.nn.Linear(input_size, hidden_unit)
                all_layers.append(layer)
                all_layers.append(torch.nn.ReLU())
                input_size = hidden_unit
            # CORAL: output layer -------------------------------------------
            # Regular classifier would use the following output layer:
            # output_layer = torch.nn.Linear(hidden_units[-1], num_classes)
            # We replace it by the CORAL layer:
            output_layer = CoralLayer(size_in=hidden_units[-1],
                                      num_classes=num_classes)
            # ----------------------------------------------------------------
            all_layers.append(output_layer)
            self.model = torch.nn.Sequential(*all_layers)
        def forward(self, x):
            x = self.model(x)
            return x
    
  • In our LightningModule we use loggers to track mean absolute errors for both the training and validation set during training; this allows us to select the best model based on validation set performance later.
  • Note that we make changes 2) (levels_from_labelbatch), 3) (coral_loss), and 4) (proba_to_label) to implement a CORAL model instead of a regular classifier:
  • from coral_pytorch.losses import coral_loss
    from coral_pytorch.dataset import levels_from_labelbatch
    from coral_pytorch.dataset import proba_to_label
    import pytorch_lightning as pl
    import torchmetrics
    # LightningModule that receives a PyTorch model as input
    class LightningMLP(pl.LightningModule):
        def __init__(self, model, learning_rate):
            super().__init__()
            self.learning_rate = learning_rate
            # The inherited PyTorch module
            self.model = model
            # Save settings and hyperparameters to the log directory
            # but skip the model parameters
            self.save_hyperparameters(ignore=['model'])
            # Set up attributes for computing the MAE
            self.train_mae = torchmetrics.MeanAbsoluteError()
            self.valid_mae = torchmetrics.MeanAbsoluteError()
            self.test_mae = torchmetrics.MeanAbsoluteError()
        # Defining the forward method is only necessary 
        # if you want to use a Trainer's .predict() method (optional)
        def forward(self, x):
            return self.model(x)
        # A common forward step to compute the loss and labels
        # this is used for training, validation, and testing below
        def _shared_step(self, batch):
            features, true_labels = batch
            # Convert class labels for CORAL ------------------------
            levels = levels_from_labelbatch(
                true_labels, num_classes=self.model.num_classes)
            # -------------------------------------------------------
            logits = self(features)
            # CORAL Loss --------------------------------------------
            # A regular classifier uses:
            # loss = torch.nn.functional.cross_entropy(logits, true_labels)
            loss = coral_loss(logits, levels.type_as(logits))
            # -------------------------------------------------------
            # CORAL Prediction to label -----------------------------
            # A regular classifier uses:
            # predicted_labels = torch.argmax(logits, dim=1)
            probas = torch.sigmoid(logits)
            predicted_labels = proba_to_label(probas)
            # -------------------------------------------------------
            return loss, true_labels, predicted_labels
        def training_step(self, batch, batch_idx):
            loss, true_labels, predicted_labels = self._shared_step(batch)
            self.log("train_loss", loss)
            self.train_mae(predicted_labels, true_labels)
            self.log("train_mae", self.train_mae, on_epoch=True, on_step=False)
            return loss  # this is passed to the optimizer for training
        def validation_step(self, batch, batch_idx):
            loss, true_labels, predicted_labels = self._shared_step(batch)
            self.log("valid_loss", loss)
            self.valid_mae(predicted_labels, true_labels)
            self.log("valid_mae", self.valid_mae,
                     on_epoch=True, on_step=False, prog_bar=True)
        def test_step(self, batch, batch_idx):
            _, true_labels, predicted_labels = self._shared_step(batch)
            self.test_mae(predicted_labels, true_labels)
            self.log("test_mae", self.test_mae, on_epoch=True, on_step=False)
        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
            return optimizer
    

    Setting up the dataset

  • In this section, we are going to set up our dataset.
  • We start by downloading and taking a look at the Cement dataset:
  • Inspecting the dataset

    import pandas as pd
    import numpy as np
    data_df = pd.read_csv("https://raw.githubusercontent.com/gagolews/"
                          "ordinal_regression_data/master/cement_strength.csv")
    data_df["response"] = data_df["response"]-1  # labels should start at 0
    data_labels = data_df["response"]
    data_features = data_df.loc[:, [
        "V1", "V2", "V3", "V4", "V5", "V6", "V7", "V8"]]
    data_df.head()
    
    print('Number of features:', data_features.shape[1])
    print('Number of examples:', data_features.shape[0])
    print('Labels:', np.unique(data_labels.values))
    print('Label distribution:', np.bincount(data_labels))
    
    Number of features: 8
    Number of examples: 998
    Labels: [0 1 2 3 4]
    Label distribution: [196 310 244 152  96]
    
  • Above, we can see that the dataset consists of 8 features, and there are 998 examples in total.
  • The labels range from 1 (weakest) to 5 (strongest), and we normalize them to start at zero (hence, the normalized labels are in the range 0 to 4).
  • Notice also that the dataset is quite imbalanced.
  • Performance baseline

  • Especially for imbalanced datasets, it's quite useful to compute a performance baseline.
  • In classification contexts, a useful baseline is to compute the accuracy for a scenario where the model always predicts the majority class -- you want your model to be better than that!
  • Note that if you are interested in a single number that minimizes the dataset mean squared error (MSE), that's the mean; similarly, the median is the number that minimizes the mean absolute error (MAE).
  • So, if we use the mean absolute error (MAE) to evaluate the model, it is useful to compute the MAE pretending the predicted label is always the median:
  • avg_prediction = np.median(data_labels.values)  # median minimizes MAE
    baseline_mae = np.mean(np.abs(data_labels.values - avg_prediction))
    print(f'Baseline MAE: {baseline_mae:.2f}')
    
    Baseline MAE: 1.03
    
  • In other words, a model that always predicts the dataset median would achieve an MAE of 1.03. A model with an MAE greater than 1 is certainly a bad model.
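  • Analogously, a minimal sketch of the corresponding MSE baseline using the dataset mean (not part of the original tutorial, shown only for comparison; it reuses data_labels from the cells above):
    # Hypothetical MSE counterpart to the MAE baseline:
    # the dataset mean minimizes the mean squared error
    mean_prediction = np.mean(data_labels.values)
    baseline_mse = np.mean((data_labels.values - mean_prediction)**2)
    print(f'Baseline MSE: {baseline_mse:.2f}')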
  • Creating a Dataset class

  • Next, let us set up a data loading mechanism for our model.
  • Note that the Cement dataset is a relatively small dataset that fits into memory quite comfortably, so this may seem like overkill. However, the following steps are useful as a template since you can reuse them for arbitrarily sized datasets.
  • First, we define a PyTorch Dataset class that returns the features (inputs) and labels:
  • from torch.utils.data import Dataset
    class MyDataset(Dataset):
        def __init__(self, feature_array, label_array, dtype=np.float32):
            self.features = feature_array.astype(dtype)
            self.labels = label_array
        def __getitem__(self, index):
            inputs = self.features[index]
            label = self.labels[index]
            return inputs, label
        def __len__(self):
            return self.features.shape[0]
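
  • As a quick, hypothetical sanity check (not part of the tutorial's pipeline), the class can be used like this:
    # Hypothetical example with random features and integer labels
    example_ds = MyDataset(np.random.rand(10, 8), np.arange(10) % 5)
    print(len(example_ds))   # 10
    print(example_ds[0])     # (float32 feature vector of example 0, its label)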
    

    Setting up a DataModule

  • There are three main ways we can prepare the dataset for Lightning. We can
  • make the dataset part of the model;
  • set up the data loaders as usual and feed them to the fit method of a Lightning Trainer -- the Trainer is introduced in the next subsection (a short sketch of this approach appears after the DataModule setup below);
  • create a LightningDataModule.
  • Here, we are going to use approach 3, which is the most organized approach. The LightningDataModule consists of several self-explanatory methods as we can see below:
  • import os
    from sklearn.model_selection import train_test_split
    from sklearn.preprocessing import StandardScaler
    from torch.utils.data import DataLoader
    class DataModule(pl.LightningDataModule):
        def __init__(self, data_path='./'):
            super().__init__()
            self.data_path = data_path
        def prepare_data(self):
            data_df = pd.read_csv(
                'https://raw.githubusercontent.com/gagolews/'
                'ordinal_regression_data/master/cement_strength.csv')
            data_df.to_csv(
                os.path.join(self.data_path, 'cement_strength.csv'), index=None)
            return
        def setup(self, stage=None):
            data_df = pd.read_csv(
                os.path.join(self.data_path, 'cement_strength.csv'))
            data_df["response"] = data_df["response"]-1  # labels should start at 0
            self.data_labels = data_df["response"]
            self.data_features = data_df.loc[:, [
                "V1", "V2", "V3", "V4", "V5", "V6", "V7", "V8"]]
            # Split into
            # 20% test and, of the remaining 80%, 10% validation
            # (roughly 72% train / 8% validation / 20% test)
            X_temp, X_test, y_temp, y_test = train_test_split(
                self.data_features.values,
                self.data_labels.values,
                test_size=0.2,
                random_state=1,
                stratify=self.data_labels.values)
            X_train, X_valid, y_train, y_valid = train_test_split(
                X_temp,
                y_temp,
                test_size=0.1,
                random_state=1,
                stratify=y_temp)
            # Standardize features
            sc = StandardScaler()
            X_train_std = sc.fit_transform(X_train)
            X_valid_std = sc.transform(X_valid)
            X_test_std = sc.transform(X_test)
            self.train = MyDataset(X_train_std, y_train)
            self.valid = MyDataset(X_valid_std, y_valid)
            self.test = MyDataset(X_test_std, y_test)
        def train_dataloader(self):
            return DataLoader(self.train, batch_size=BATCH_SIZE,
                              num_workers=NUM_WORKERS,
                              drop_last=True)
        def val_dataloader(self):
            return DataLoader(self.valid, batch_size=BATCH_SIZE,
                              num_workers=NUM_WORKERS)
        def test_dataloader(self):
            return DataLoader(self.test, batch_size=BATCH_SIZE,
                              num_workers=NUM_WORKERS)
    
  • Note that the prepare_data method is usually used for steps that only need to be executed once, for example, downloading the dataset; the setup method defines the dataset loading -- if you run your code in a distributed setting, this will be called on each node / GPU.
  • Next, let's initialize the DataModule; we use a random seed for reproducibility (so that the dataset is shuffled the same way when we re-execute this code):
  • torch.manual_seed(1) 
    data_module = DataModule(data_path=DATA_BASEPATH)
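
  • As an aside, approach 2) from the list above -- passing plain DataLoaders to the Trainer instead of using a DataModule -- would look roughly like the following sketch. This is hypothetical and not used in this tutorial; it assumes the Trainer and lightning_model created in the next subsection, MyDataset instances train_dataset and valid_dataset, and the exact keyword names (train_dataloaders vs. train_dataloader) depend on the PyTorch Lightning version:
    # Hypothetical sketch of approach 2): plain DataLoaders passed to fit()
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                              shuffle=True, drop_last=True,
                              num_workers=NUM_WORKERS)
    valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE,
                              num_workers=NUM_WORKERS)
    trainer.fit(model=lightning_model,
                train_dataloaders=train_loader,
                val_dataloaders=valid_loader)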
    

    Training the model using the PyTorch Lightning Trainer class

  • Next, we initialize our multilayer perceptron model (here, a 2-layer MLP with 24 units in the first hidden layer, and 16 units in the second hidden layer).
  • We wrap the model in our LightningMLP so that we can use PyTorch Lightning's powerful Trainer API.
  • Also, we define a callback so that we can obtain the model with the best validation set performance after training.
  • Note that PyTorch Lightning supports many advanced logging services, such as Weights & Biases. However, here, we will keep things simple and use the CSVLogger:
  • from pytorch_lightning.callbacks import ModelCheckpoint
    from pytorch_lightning.loggers import CSVLogger
    pytorch_model = MultiLayerPerceptron(
        input_size=data_features.shape[1],
        hidden_units=(24, 16),
        num_classes=np.bincount(data_labels).shape[0])
    lightning_model = LightningMLP(
        model=pytorch_model,
        learning_rate=LEARNING_RATE)
    callbacks = [ModelCheckpoint(
        save_top_k=1, mode="min", monitor="valid_mae")]  # save top 1 model 
    logger = CSVLogger(save_dir="logs/", name="mlp-coral-cement")
    
  • Now it's time to train our model:
  • import time
    trainer = pl.Trainer(
        max_epochs=NUM_EPOCHS,
        callbacks=callbacks,
        progress_bar_refresh_rate=50,  # recommended for notebooks
        accelerator="auto",  # Uses GPUs or TPUs if available
        devices="auto",  # Uses all available GPUs/TPUs if applicable
        logger=logger,
        deterministic=True,
        log_every_n_steps=10)
    start_time = time.time()
    trainer.fit(model=lightning_model, datamodule=data_module)
    runtime = (time.time() - start_time)/60
    print(f"Training took {runtime:.2f} min in total.")
    
    GPU available: True, used: True
    TPU available: False, using: 0 TPU cores
    IPU available: False, using: 0 IPUs
    LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
      | Name      | Type                 | Params
    ---------------------------------------------------
    0 | model     | MultiLayerPerceptron | 636   
    1 | train_mae | MeanAbsoluteError    | 0     
    2 | valid_mae | MeanAbsoluteError    | 0     
    3 | test_mae  | MeanAbsoluteError    | 0     
    ---------------------------------------------------
    636       Trainable params
    0         Non-trainable params
    636       Total params
    0.003     Total estimated model params size (MB)
    Validation sanity check: 0it [00:00, ?it/s]
    /home/jovyan/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/data_loading.py:132: UserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
      rank_zero_warn(
    /home/jovyan/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/data_loading.py:132: UserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
      rank_zero_warn(
    Training: 0it [00:00, ?it/s]
    [... repeated "Validating: 0it [00:00, ?it/s]" progress-bar lines omitted ...]
    Training took 0.94 min in total.
    

    Evaluating the model

  • After training, let's plot our training MAE and validation MAE using pandas, which, in turn, uses matplotlib for plotting (you may want to consider a more advanced logger that does that for you):
  • metrics = pd.read_csv(f"{trainer.logger.log_dir}/metrics.csv")
    aggreg_metrics = []
    agg_col = "epoch"
    for i, dfg in metrics.groupby(agg_col):
        agg = dict(dfg.mean())
        agg[agg_col] = i
        aggreg_metrics.append(agg)
    df_metrics = pd.DataFrame(aggreg_metrics)
    df_metrics[["train_loss", "valid_loss"]].plot(
        grid=True, legend=True, xlabel='Epoch', ylabel='Loss')
    df_metrics[["train_mae", "valid_mae"]].plot(
        grid=True, legend=True, xlabel='Epoch', ylabel='MAE')
    
    <AxesSubplot:xlabel='Epoch', ylabel='MAE'>
    
  • As we can see from the loss plot above, the model starts overfitting pretty quickly; however, the validation set MAE keeps improving. Based on the MAE plot, the best model (in terms of validation set MAE) may be around epoch 110.
  • The trainer saved this model automatically for us, and we can load it from the checkpoint via the ckpt_path='best' argument; below, we use the trainer instance to evaluate the best model on the test set:
  • trainer.test(model=lightning_model, datamodule=data_module, ckpt_path='best')
    
    Restoring states from the checkpoint path at logs/mlp-coral-cement/version_3/checkpoints/epoch=114-step=2529.ckpt
    LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
    Loaded model weights from checkpoint at logs/mlp-coral-cement/version_3/checkpoints/epoch=114-step=2529.ckpt
    /home/jovyan/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/data_loading.py:132: UserWarning: The dataloader, test_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
      rank_zero_warn(
    Testing: 0it [00:00, ?it/s]
    --------------------------------------------------------------------------------
    DATALOADER:0 TEST RESULTS
    {'test_mae': 0.25}
    --------------------------------------------------------------------------------
    
  • The MAE of our model is quite good, especially compared to the 1.03 MAE baseline earlier.
  • Predicting labels of new data

  • You can use the trainer.predict method on a new DataLoader or DataModule to apply the model to new data (a short sketch is shown at the end of this section).
  • Alternatively, you can also manually load the best model from a checkpoint as shown below:
  • path = trainer.checkpoint_callback.best_model_path
    print(path)
    
    logs/mlp-coral-cement/version_3/checkpoints/epoch=114-step=2529.ckpt
    
    lightning_model = LightningMLP.load_from_checkpoint(
        path, model=pytorch_model)
    lightning_model.eval();
    
  • Note that our MultiLayerPerceptron, which is passed to LightningMLP, requires input arguments when loading from a checkpoint. The learning rate is restored automatically because we used self.save_hyperparameters() in LightningMLP's __init__ method; the model itself was excluded via ignore=['model'], which is why we pass model=pytorch_model explicitly above.
  • Now, below is an example applying the model manually. Here, pretend that the test_dataloader is a new data loader.
  • test_dataloader = data_module.test_dataloader()
    all_predicted_labels = []
    for batch in test_dataloader:
        features, _ = batch
        logits = lightning_model(features)
        probas = torch.sigmoid(logits)
        predicted_labels = proba_to_label(probas)
        all_predicted_labels.append(predicted_labels)
    all_predicted_labels = torch.cat(all_predicted_labels)
    all_predicted_labels[:5]
    
    tensor([0, 4, 0, 3, 1])
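
  • For completeness, the trainer.predict route mentioned above could look roughly like the sketch below. This is hypothetical and not part of the original tutorial: because the DataLoader yields (features, label) tuples, a small predict_step is added via a thin subclass (here called LightningMLPWithPredict), and the test dataloader is passed in directly:
    # Hypothetical sketch: a thin subclass adding a predict_step so that
    # trainer.predict can handle (features, label) batches directly
    class LightningMLPWithPredict(LightningMLP):
        def predict_step(self, batch, batch_idx, dataloader_idx=0):
            features, _ = batch
            probas = torch.sigmoid(self(features))
            return proba_to_label(probas)

    # Re-wrap the trained weights and predict on the test dataloader
    predict_model = LightningMLPWithPredict.load_from_checkpoint(
        path, model=pytorch_model)
    batches = trainer.predict(model=predict_model,
                              dataloaders=data_module.test_dataloader())
    all_predicted_labels = torch.cat(batches)
    all_predicted_labels[:5]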