The pseudo-label approach works by training a supervised model on the labeled source domain and then using this model to predict labels for the unlabeled target domain.
The pseudo-labeled target data is then combined with the source data, and the model is retrained on the combined dataset.
This process is repeated until the model converges.
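The loop just described can be shown without any deep learning machinery. Below is a minimal sketch that makes several assumptions. It replaces the CNN with a toy one-dimensional least-squares model. It uses hypothetical helper names (`fit_line`, `self_train`). It treats "converges" as "the parameters stop changing by more than a tolerance". It is not how `rul_adapt` implements the approach. It only instantiates the same train, pseudo-label, clip, combine and retrain cycle.

```python
"""Toy pseudo-labeling (self-training) loop with a 1-D least-squares model.

Hypothetical illustration only: fit_line and self_train are not rul_adapt APIs.
"""
from typing import List, Tuple


def fit_line(xs: List[float], ys: List[float]) -> Tuple[float, float]:
    """Ordinary least-squares fit of y = slope * x + intercept."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var = sum((x - mean_x) ** 2 for x in xs)
    slope = cov / var
    return slope, mean_y - slope * mean_x


def self_train(
    source_x: List[float],
    source_y: List[float],
    target_x: List[float],
    max_rounds: int = 10,
    tol: float = 1e-6,
) -> Tuple[Tuple[float, float], int]:
    """Return the final (slope, intercept) and the number of rounds used."""
    # 1. Train a supervised model on the labeled source domain only.
    model = fit_line(source_x, source_y)
    for round_idx in range(1, max_rounds + 1):
        slope, intercept = model
        # 2. Predict pseudo labels for the unlabeled target domain and
        #    clip implausible (negative) predictions to zero.
        pseudo_y = [max(0.0, slope * x + intercept) for x in target_x]
        # 3. Retrain on source data plus pseudo-labeled target data.
        new_model = fit_line(source_x + target_x, source_y + pseudo_y)
        # 4. Assumed convergence test: parameters barely change.
        change = max(abs(a - b) for a, b in zip(new_model, model))
        model = new_model
        if change < tol:
            return model, round_idx
    return model, max_rounds
```

For example, `self_train([0, 5, 10, 15, 20], [100, 90, 80, 70, 60], [25, 30])` returns slope `-2.0` and intercept `100.0` after a single round. None of those pseudo labels are clipped, and they lie exactly on the line already fitted to the source data, so the refit cannot move it.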
Here, we train a model on FD003 of the CMAPSS dataset and pseudo-label the FD001 dataset.
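import pytorch_lightning as pl
import rul_adapt
import rul_datasets
import torch

# CNN feature extractor: 14 input features, conv layers with 32, 16 and 8
# filters, window length 30, and a 64-unit fully connected output layer
feature_extractor = rul_adapt.model.CnnExtractor(
    14, [32, 16, 8], 30, fc_units=64
)
# Regression head mapping the 64 features to a single RUL estimate
regressor = rul_adapt.model.FullyConnectedHead(
    64, [1], act_func_on_last_layer=False
)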
fd3 = rul_datasets.CmapssReader(fd=3)
dm_labeled = rul_datasets.RulDataModule(fd3, batch_size=128)
approach = rul_adapt.approach.SupervisedApproach(
    lr=0.001, loss_type="rmse", optim_type="adam"
)
approach.set_model(feature_extractor, regressor)

trainer = pl.Trainer(max_epochs=10)
trainer.fit(approach, dm_labeled)
trainer.validate(approach, dm_labeled)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
| Name | Type | Params
----------------------------------------------------------
0 | train_loss | MeanSquaredError | 0
1 | val_loss | MeanSquaredError | 0
2 | test_loss | MeanSquaredError | 0
3 | evaluator | AdaptionEvaluator | 0
4 | _feature_extractor | CnnExtractor | 15.7 K
5 | _regressor | FullyConnectedHead | 65
----------------------------------------------------------
15.7 K Trainable params
0 Non-trainable params
15.7 K Total params
0.063 Total estimated model params size (MB)
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Validate metric DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
val/loss 14.083422660827637
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
fd1 = rul_datasets.CmapssReader(fd=1, percent_broken=0.8)
dm_unlabeled = rul_datasets.RulDataModule(fd1, batch_size=128)
A pseudo label is generated for the last time step of each sequence.
In early iterations, these labels may be implausible, e.g., negative, and need to be clipped to zero.
When the data module is patched with the pseudo labels, suitable RUL values are created for each sequence.
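To illustrate the last point, here is a hedged sketch of one way a single pseudo label could be turned into targets for every window of a sequence. It assumes consecutive windows are one time step apart, so the RUL grows by one per step back in time. It also assumes targets are capped at a maximum RUL. The function name `expand_pseudo_label` and the cap of 125 are illustrative assumptions. This is not a description of how `patch_pseudo_labels` is implemented.

```python
from typing import List


def expand_pseudo_label(
    pseudo_rul: float, num_windows: int, max_rul: float = 125.0
) -> List[float]:
    """Hypothetical helper: derive per-window RUL targets from one pseudo label.

    Assumes the pseudo label belongs to the last window of a sequence, that
    consecutive windows are one time step apart, and that targets follow a
    piecewise-linear profile capped at ``max_rul``.
    """
    # Clip an implausible negative prediction to zero before expanding.
    last_rul = max(0.0, pseudo_rul)
    # Window i lies (num_windows - 1 - i) steps before the last window,
    # so its RUL is larger by exactly that many steps (up to the cap).
    return [
        min(max_rul, last_rul + (num_windows - 1 - i)) for i in range(num_windows)
    ]
```

For example, `expand_pseudo_label(10.0, 4)` returns `[13.0, 12.0, 11.0, 10.0]`. A negative pseudo label such as `-5.0` is clipped first, giving `[2.0, 1.0, 0.0]` for three windows.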
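pseudo_labels = rul_adapt.approach.generate_pseudo_labels(dm_unlabeled, approach)
# Clip implausible negative pseudo labels to zero
# (loop variable not named `pl` to avoid confusion with the pytorch_lightning alias)
pseudo_labels = [max(0, label) for label in pseudo_labels]
rul_adapt.approach.patch_pseudo_labels(dm_unlabeled, pseudo_labels)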
/home/tilman/Programming/rul-adapt/rul_adapt/approach/pseudo_labels.py:88: UserWarning: At least one of the generated pseudo labels is negative. Please consider clipping them to zero.
warnings.warn(
trainer = pl.Trainer(max_epochs=10)
trainer.validate(approach, dm_unlabeled)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Validate metric DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
val/loss 36.179779052734375
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Afterward, we combine FD003 and the pseudo-labeled FD001 and train the approach for another 10 epochs.
The validation loss on FD001 drops from about 36.18 before this step to about 29.43 after it.
The pseudo labeling and training can now be repeated with the new model until the validation loss converges.
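The combination step relies on two pieces. `torch.utils.data.ConcatDataset` presents several datasets as one by mapping a global index to a dataset and a local index, using running totals of the dataset sizes. A `DataLoader` with `shuffle=True` draws a fresh random permutation of those indices for each epoch. The following stdlib-only sketch mimics that behavior for plain lists. The class `ConcatList` and the helper `shuffled_batches` are hypothetical stand-ins, not PyTorch code.

```python
import bisect
import random
from itertools import accumulate
from typing import Iterator, List


class ConcatList:
    """Hypothetical stand-in for ConcatDataset: several lists viewed as one."""

    def __init__(self, parts: List[list]) -> None:
        self.parts = parts
        # Running totals of the part lengths, e.g. [3, 5] for lengths 3 and 2.
        self.cumulative_sizes = list(accumulate(len(p) for p in parts))

    def __len__(self) -> int:
        return self.cumulative_sizes[-1] if self.cumulative_sizes else 0

    def __getitem__(self, idx: int):
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        # The first part whose running total exceeds idx holds the item.
        part_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        offset = self.cumulative_sizes[part_idx - 1] if part_idx > 0 else 0
        return self.parts[part_idx][idx - offset]


def shuffled_batches(data: ConcatList, batch_size: int, seed: int) -> Iterator[list]:
    """Yield batches from a random permutation of all indices, like shuffle=True."""
    order = list(range(len(data)))
    random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        yield [data[i] for i in order[start:start + batch_size]]
```

With 3 source items, 2 target items and `batch_size=2`, `shuffled_batches` yields three batches of sizes 2, 2 and 1. Together they contain every item exactly once, so source and pseudo-labeled target samples are mixed within batches. As with the `DataLoader` default `drop_last=False`, the last batch may be smaller.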
# Concatenate the labeled FD003 and the pseudo-labeled FD001 training splits
combined_train_data = torch.utils.data.ConcatDataset(
    [dm_labeled.to_dataset("dev"), dm_unlabeled.to_dataset("dev")]
)
combined_train_dl = torch.utils.data.DataLoader(
    combined_train_data, batch_size=128, shuffle=True
)
trainer.fit(approach, train_dataloaders=combined_train_dl)
trainer.validate(approach, dm_unlabeled)
/home/tilman/Programming/rul-adapt/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/configuration_validator.py:108: PossibleUserWarning: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
rank_zero_warn(
| Name | Type | Params
----------------------------------------------------------
0 | train_loss | MeanSquaredError | 0
1 | val_loss | MeanSquaredError | 0
2 | test_loss | MeanSquaredError | 0
3 | evaluator | AdaptionEvaluator | 0
4 | _feature_extractor | CnnExtractor | 15.7 K
5 | _regressor | FullyConnectedHead | 65
----------------------------------------------------------
15.7 K Trainable params
0 Non-trainable params
15.7 K Total params
0.063 Total estimated model params size (MB)
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Validate metric DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
val/loss 29.42894172668457
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────