添加链接
link管理
链接快照平台
  • 输入网页链接,自动生成快照
  • 标签化管理网页链接

MNIST PyTorch Lightning

This is a "Hello world!" example with PyTorch Lightning.
It trains a convolutional neural network on the MNIST dataset.

Credits:
MNIST dataset, see http://yann.lecun.com/exdb/mnist/
Code adapted from the documentation of the Lightning-AI and PyTorch projects

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
# Define the model and its training and validation steps.
# The model uses convolutional neural network layers.
class LitMNIST(pl.LightningModule):
    """Convolutional digit classifier wrapped as a LightningModule.

    The network is two 3x3 convolutions, a 2x2 max-pool and two fully
    connected layers, with dropout after the pooling and after the first
    linear layer. ``forward`` returns per-class log-probabilities.
    """

    def __init__(self):
        """Create the layers and the loss function."""
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        # 9216 = 64 channels * 12 * 12: two unpadded 3x3 convolutions shrink
        # a 28x28 image to 24x24 and the 2x2 max-pool halves that to 12x12.
        # (Assumes 1x28x28 inputs such as MNIST images.)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)
        # forward() already ends with log_softmax, so the matching criterion
        # is NLLLoss. CrossEntropyLoss expects raw logits and would apply
        # log_softmax a second time (same value, redundant work).
        self.loss_fn = nn.NLLLoss()

    def forward(self, x):
        """Return log-probabilities over the 10 classes for a batch ``x``."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        # Keep the batch dimension, flatten everything else for the linear layers.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters (learning rate 1e-3)."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def training_step(self, batch, batch_idx):
        """Compute, log (as ``train_loss``) and return the loss for one batch."""
        x, y = batch
        log_probs = self(x)
        loss = self.loss_fn(log_probs, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss`` and ``val_accuracy`` for one validation batch."""
        x, y = batch
        log_probs = self(x)
        loss = self.loss_fn(log_probs, y)
        # The predicted class is the one with the highest log-probability.
        preds = torch.argmax(log_probs, dim=1)
        accuracy = (preds == y).float().mean()
        self.log('val_loss', loss)
        self.log('val_accuracy', accuracy)
# Seed PyTorch's RNG before the random split and model construction so that
# repeated runs start from the same state.
torch.manual_seed(1)

# Run settings used below.
# NOTE(review): these names were referenced but never defined in this file;
# the values here are placeholders - confirm against the original notebook.
num_workers = 0  # 0 loads data in the main process (the DataLoader default)
num_epochs = 5

# Convert images to tensors and normalize them.
# (0.1307, 0.3081) are presumably the MNIST mean and standard deviation.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Download the MNIST training set if needed and hold out 5000 of its images
# for validation.
dataset = MNIST('../data', train=True, download=True, transform=transform)
train_dataset, val_dataset = random_split(dataset, [55000, 5000])

train_loader = DataLoader(train_dataset, batch_size=64, num_workers=num_workers)
val_loader = DataLoader(val_dataset, batch_size=1000, num_workers=num_workers)

# Train the model, running validation on the held-out split.
model = LitMNIST()
trainer = pl.Trainer(max_epochs=num_epochs)
trainer.fit(model, train_loader, val_loader)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  | Name     | Type             | Params
----------------------------------------------
0 | conv1    | Conv2d           | 320   
1 | conv2    | Conv2d           | 18.5 K
2 | dropout1 | Dropout          | 0     
3 | dropout2 | Dropout          | 0     
4 | fc1      | Linear           | 1.2 M 
5 | fc2      | Linear           | 1.3 K 
6 | loss_fn  | CrossEntropyLoss | 0     
----------------------------------------------
1.2 M     Trainable params
0         Non-trainable params
1.2 M     Total params
4.800     Total estimated model params size (MB)
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃      Validate metric             DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       val_accuracy            0.9901999831199646     │
│         val_loss              0.03850765526294708    │
└───────────────────────────┴───────────────────────────┘