# Define the model and its training and validation steps
# The model uses convolutional neural network layers
class LitMNIST(pl.LightningModule):
    """Small convolutional classifier with training and validation hooks.

    Two 3x3 convolutions are followed by a 2x2 max-pool and two fully
    connected layers that emit 10 class scores. ``fc1`` expects 9216
    flattened features, which is what a 1x28x28 input yields after the
    convolutions and pooling (64 channels * 12 * 12).
    """

    def __init__(self):
        """Create the layers and the loss function."""
        # nn.Module must be initialised before any sub-module is assigned;
        # without this call the assignments below raise AttributeError.
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)
        # NOTE: CrossEntropyLoss applies log-softmax internally. forward()
        # already returns log-probabilities, and log-softmax of
        # log-probabilities is the identity, so the result equals the NLL of
        # forward()'s output. Redundant, but numerically harmless.
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-class log-probabilities for a batch of images.

        Args:
            x: Input batch; the layer sizes assume shape (N, 1, 28, 28).

        Returns:
            Tensor of shape (N, 10) holding log-softmax scores over dim 1.
        """
        # Convolutional feature extractor.
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)

        # Flatten everything except the batch dimension for the dense head.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with lr=1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch.

        Args:
            batch: Pair ``(x, y)`` of inputs and integer class targets.
            batch_idx: Index of the batch (unused).

        Returns:
            The scalar loss tensor used for back-propagation.
        """
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss and accuracy for one batch.

        Args:
            batch: Pair ``(x, y)`` of inputs and integer class targets.
            batch_idx: Index of the batch (unused).
        """
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        # Predicted class is the one with the highest log-probability.
        preds = torch.argmax(y_hat, dim=1)
        accuracy = (preds == y).float().mean()
        self.log('val_loss', loss)
        self.log('val_accuracy', accuracy)
# Preprocessing: convert images to tensors, then normalise with the fixed
# mean/std constants below (presumably the MNIST statistics -- confirm).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# Training portion of MNIST, downloaded into ../data if not already present.
dataset = MNIST(root='../data', train=True, transform=transform, download=True)

# Random 55,000 / 5,000 split into training and validation subsets.
train_dataset, val_dataset = random_split(dataset, lengths=[55000, 5000])

# Small batches for training, large batches for validation.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=64,
    num_workers=num_workers,
)
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=1000,
    num_workers=num_workers,
)

# Build the model and run the Lightning fit loop for num_epochs epochs.
model = LitMNIST()
trainer = pl.Trainer(max_epochs=num_epochs)
trainer.fit(model, train_loader, val_loader)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
| Name | Type | Params
0 | conv1 | Conv2d | 320
1 | conv2 | Conv2d | 18.5 K
2 | dropout1 | Dropout | 0
3 | dropout2 | Dropout | 0
4 | fc1 | Linear | 1.2 M
5 | fc2 | Linear | 1.3 K
6 | loss_fn | CrossEntropyLoss | 0
1.2 M Trainable params
0 Non-trainable params
1.2 M Total params
4.800 Total estimated model params size (MB)
┃ Validate metric ┃ DataLoader 0 ┃
│ val_accuracy │ 0.9901999831199646 │
│ val_loss │ 0.03850765526294708 │