[10]:
import os
# Directory where the trained ensemble state is saved. Replace it with your own choice.
save_dir = 'models/moe2'
nfree_experts = 3000  # number of free experts
lookback_len = 20
max_forecast_steps = 3
target_seq_index = 0
use_gpu = False

## PyTorch network hyper-params. These are also the hyper-params that are used
## in case moe_model=None is passed to MoE_ForecasterEnsemble.
hidden_dim = 256
dim_head = 2
mlp_dim = 256
# If the data is multi-dimensional, this can be set to a non-zero value so the
# model learns to handle missing dimensions at test time.
dim_dropout = 0.
time_step_dropout = 0

# Set True to train the custom TransformerModel configured below instead of the
# network MoE_ForecasterEnsemble builds when moe_model=None. The model is only
# constructed when it will actually be used, instead of being built and then
# immediately discarded.
use_custom_moe_model = False

## Ensemble / training configuration
config_ensemble = MoE_ForecasterEnsembleConfig(
    batch_size=64, lr=0.0001, nfree_experts=nfree_experts, epoch_max=300,
    lookback_len=lookback_len, max_forecast_steps=max_forecast_steps,
    target_seq_index=target_seq_index, use_gpu=use_gpu,
    transform=TemporalResample())
train_config_ensemble = EnsembleTrainConfig(valid_frac=0.5)

# Expert models: no external experts are provided.
models = []
nexperts = len(models)

# Deep network for the MoE (None -> use the ensemble's default network).
if use_custom_moe_model:
    moe_model = TransformerModel(
        input_dim=len(train_data.names), lookback_len=lookback_len, nexperts=nexperts,
        output_dim=max_forecast_steps, nfree_experts=nfree_experts,
        hid_dim=hidden_dim, dim_head=dim_head, mlp_dim=mlp_dim,
        pool='cls', dim_dropout=dim_dropout,
        time_step_dropout=time_step_dropout)
else:
    moe_model = None

# Create, train and persist the MoE forecaster.
ensemble = MoE_ForecasterEnsemble(config=config_ensemble, models=models, moe_model=moe_model)
loss_list = ensemble.train(train_data=train_data, train_config=train_config_ensemble)
os.makedirs(save_dir, exist_ok=True)  # make sure the target directory exists before saving
ensemble.save(save_dir)
Epoch 1 Loss: 8.253956: 100%|██████████| 11/11 [00:03<00:00, 3.00it/s]
Epoch 2 Loss: 8.215341: 100%|██████████| 11/11 [00:03<00:00, 3.09it/s]
Epoch 3 Loss: 8.140997: 100%|██████████| 11/11 [00:03<00:00, 3.37it/s]
Epoch 4 Loss: 8.030143: 100%|██████████| 11/11 [00:02<00:00, 3.70it/s]
Epoch 5 Loss: 7.880345: 100%|██████████| 11/11 [00:03<00:00, 3.66it/s]
Epoch 6 Loss: 7.679686: 100%|██████████| 11/11 [00:03<00:00, 3.23it/s]
Epoch 7 Loss: 7.431627: 100%|██████████| 11/11 [00:03<00:00, 3.65it/s]
Epoch 8 Loss: 7.121155: 100%|██████████| 11/11 [00:03<00:00, 3.66it/s]
Epoch 9 Loss: 6.788598: 100%|██████████| 11/11 [00:03<00:00, 3.32it/s]
Epoch 10 Loss: 6.469567: 100%|██████████| 11/11 [00:03<00:00, 2.98it/s]
Epoch 11 Loss: 6.230416: 100%|██████████| 11/11 [00:03<00:00, 3.48it/s]
Epoch 12 Loss: 6.048095: 100%|██████████| 11/11 [00:03<00:00, 3.60it/s]
Epoch 13 Loss: 5.915649: 100%|██████████| 11/11 [00:03<00:00, 2.79it/s]
Epoch 14 Loss: 5.811184: 100%|██████████| 11/11 [00:04<00:00, 2.69it/s]
Epoch 15 Loss: 5.744193: 100%|██████████| 11/11 [00:03<00:00, 3.31it/s]
Epoch 16 Loss: 5.686575: 100%|██████████| 11/11 [00:03<00:00, 3.11it/s]
Epoch 17 Loss: 5.657169: 100%|██████████| 11/11 [00:03<00:00, 3.31it/s]
Epoch 18 Loss: 5.635586: 100%|██████████| 11/11 [00:03<00:00, 3.36it/s]
Epoch 19 Loss: 5.627381: 100%|██████████| 11/11 [00:03<00:00, 3.62it/s]
Epoch 20 Loss: 5.592227: 100%|██████████| 11/11 [00:03<00:00, 3.46it/s]
Epoch 21 Loss: 5.565175: 100%|██████████| 11/11 [00:03<00:00, 3.18it/s]
Epoch 22 Loss: 5.561776: 100%|██████████| 11/11 [00:03<00:00, 3.66it/s]
Epoch 23 Loss: 5.541681: 100%|██████████| 11/11 [00:03<00:00, 3.61it/s]
Epoch 24 Loss: 5.545103: 100%|██████████| 11/11 [00:03<00:00, 3.17it/s]
Epoch 25 Loss: 5.522188: 100%|██████████| 11/11 [00:03<00:00, 3.21it/s]
Epoch 26 Loss: 5.508581: 100%|██████████| 11/11 [00:03<00:00, 3.52it/s]
Epoch 27 Loss: 5.489652: 100%|██████████| 11/11 [00:03<00:00, 3.46it/s]
Epoch 28 Loss: 5.476248: 100%|██████████| 11/11 [00:04<00:00, 2.51it/s]
Epoch 29 Loss: 5.466399: 100%|██████████| 11/11 [00:03<00:00, 3.08it/s]
Epoch 30 Loss: 5.475472: 100%|██████████| 11/11 [00:03<00:00, 3.17it/s]
Epoch 31 Loss: 5.461387: 100%|██████████| 11/11 [00:03<00:00, 2.90it/s]
Epoch 32 Loss: 5.433480: 100%|██████████| 11/11 [00:03<00:00, 3.43it/s]
Epoch 33 Loss: 5.417164: 100%|██████████| 11/11 [00:03<00:00, 3.66it/s]
Epoch 34 Loss: 5.391122: 100%|██████████| 11/11 [00:03<00:00, 3.55it/s]
Epoch 35 Loss: 5.355259: 100%|██████████| 11/11 [00:03<00:00, 3.24it/s]
Epoch 36 Loss: 5.324664: 100%|██████████| 11/11 [00:03<00:00, 3.55it/s]
Epoch 37 Loss: 5.289907: 100%|██████████| 11/11 [00:03<00:00, 3.53it/s]
Epoch 38 Loss: 5.252770: 100%|██████████| 11/11 [00:03<00:00, 3.27it/s]
Epoch 39 Loss: 5.224870: 100%|██████████| 11/11 [00:03<00:00, 3.06it/s]
Epoch 40 Loss: 5.189947: 100%|██████████| 11/11 [00:03<00:00, 3.58it/s]
Epoch 41 Loss: 5.141784: 100%|██████████| 11/11 [00:03<00:00, 3.42it/s]
Epoch 42 Loss: 5.089564: 100%|██████████| 11/11 [00:03<00:00, 2.83it/s]
Epoch 43 Loss: 5.053859: 100%|██████████| 11/11 [00:03<00:00, 2.92it/s]
Epoch 44 Loss: 5.017187: 100%|██████████| 11/11 [00:03<00:00, 3.16it/s]
Epoch 45 Loss: 4.980298: 100%|██████████| 11/11 [00:03<00:00, 2.88it/s]
Epoch 46 Loss: 4.946596: 100%|██████████| 11/11 [00:03<00:00, 2.99it/s]
Epoch 47 Loss: 4.913864: 100%|██████████| 11/11 [00:03<00:00, 3.65it/s]
Epoch 48 Loss: 4.886967: 100%|██████████| 11/11 [00:03<00:00, 3.64it/s]
Epoch 49 Loss: 4.860845: 100%|██████████| 11/11 [00:03<00:00, 3.38it/s]
Epoch 50 Loss: 4.831104: 100%|██████████| 11/11 [00:03<00:00, 3.44it/s]
Epoch 51 Loss: 4.816362: 100%|██████████| 11/11 [00:03<00:00, 3.58it/s]
Epoch 52 Loss: 4.779174: 100%|██████████| 11/11 [00:03<00:00, 3.39it/s]
Epoch 53 Loss: 4.736116: 100%|██████████| 11/11 [00:03<00:00, 2.90it/s]
Epoch 54 Loss: 4.717974: 100%|██████████| 11/11 [00:03<00:00, 3.40it/s]
Epoch 55 Loss: 4.690528: 100%|██████████| 11/11 [00:03<00:00, 3.47it/s]
Epoch 56 Loss: 4.657373: 100%|██████████| 11/11 [00:04<00:00, 2.68it/s]
Epoch 57 Loss: 4.630430: 100%|██████████| 11/11 [00:04<00:00, 2.64it/s]
Epoch 58 Loss: 4.626685: 100%|██████████| 11/11 [00:03<00:00, 3.33it/s]
Epoch 59 Loss: 4.576031: 100%|██████████| 11/11 [00:03<00:00, 3.12it/s]
Epoch 60 Loss: 4.536126: 100%|██████████| 11/11 [00:03<00:00, 3.28it/s]
Epoch 61 Loss: 4.522658: 100%|██████████| 11/11 [00:03<00:00, 3.41it/s]
Epoch 62 Loss: 4.482119: 100%|██████████| 11/11 [00:02<00:00, 3.70it/s]
Epoch 63 Loss: 4.440992: 100%|██████████| 11/11 [00:03<00:00, 3.66it/s]
Epoch 64 Loss: 4.420945: 100%|██████████| 11/11 [00:03<00:00, 3.26it/s]
Epoch 65 Loss: 4.380464: 100%|██████████| 11/11 [00:03<00:00, 3.59it/s]
Epoch 66 Loss: 4.354569: 100%|██████████| 11/11 [00:03<00:00, 3.41it/s]
Epoch 67 Loss: 4.330108: 100%|██████████| 11/11 [00:03<00:00, 3.27it/s]
Epoch 68 Loss: 4.302790: 100%|██████████| 11/11 [00:03<00:00, 3.11it/s]
Epoch 69 Loss: 4.276388: 100%|██████████| 11/11 [00:03<00:00, 3.63it/s]
Epoch 70 Loss: 4.252587: 100%|██████████| 11/11 [00:03<00:00, 3.08it/s]
Epoch 71 Loss: 4.220399: 100%|██████████| 11/11 [00:03<00:00, 2.80it/s]
Epoch 72 Loss: 4.197742: 100%|██████████| 11/11 [00:03<00:00, 2.84it/s]
Epoch 73 Loss: 4.184703: 100%|██████████| 11/11 [00:03<00:00, 3.26it/s]
Epoch 74 Loss: 4.177047: 100%|██████████| 11/11 [00:03<00:00, 3.37it/s]
Epoch 75 Loss: 4.132162: 100%|██████████| 11/11 [00:03<00:00, 3.12it/s]
Epoch 76 Loss: 4.098515: 100%|██████████| 11/11 [00:03<00:00, 3.65it/s]
Epoch 77 Loss: 4.085912: 100%|██████████| 11/11 [00:03<00:00, 3.63it/s]
Epoch 78 Loss: 4.072778: 100%|██████████| 11/11 [00:03<00:00, 3.48it/s]
Epoch 79 Loss: 4.026296: 100%|██████████| 11/11 [00:03<00:00, 3.13it/s]
Epoch 80 Loss: 4.000065: 100%|██████████| 11/11 [00:03<00:00, 3.65it/s]
Epoch 81 Loss: 3.971918: 100%|██████████| 11/11 [00:03<00:00, 3.48it/s]
Epoch 82 Loss: 3.954153: 100%|██████████| 11/11 [00:03<00:00, 3.20it/s]
Epoch 83 Loss: 3.924814: 100%|██████████| 11/11 [00:03<00:00, 3.21it/s]
Epoch 84 Loss: 3.895262: 100%|██████████| 11/11 [00:03<00:00, 3.16it/s]
Epoch 85 Loss: 3.874157: 100%|██████████| 11/11 [00:03<00:00, 3.06it/s]
Epoch 86 Loss: 3.862910: 100%|██████████| 11/11 [00:03<00:00, 3.02it/s]
Epoch 87 Loss: 3.845997: 100%|██████████| 11/11 [00:03<00:00, 2.90it/s]
Epoch 88 Loss: 3.811604: 100%|██████████| 11/11 [00:03<00:00, 3.11it/s]
Epoch 89 Loss: 3.794181: 100%|██████████| 11/11 [00:03<00:00, 3.51it/s]
Epoch 90 Loss: 3.760030: 100%|██████████| 11/11 [00:03<00:00, 3.47it/s]
Epoch 91 Loss: 3.738671: 100%|██████████| 11/11 [00:03<00:00, 3.12it/s]
Epoch 92 Loss: 3.720268: 100%|██████████| 11/11 [00:03<00:00, 3.55it/s]
Epoch 93 Loss: 3.713584: 100%|██████████| 11/11 [00:02<00:00, 3.70it/s]
Epoch 94 Loss: 3.671633: 100%|██████████| 11/11 [00:03<00:00, 3.62it/s]
Epoch 95 Loss: 3.651652: 100%|██████████| 11/11 [00:03<00:00, 3.31it/s]
Epoch 96 Loss: 3.639432: 100%|██████████| 11/11 [00:03<00:00, 3.07it/s]
Epoch 97 Loss: 3.607364: 100%|██████████| 11/11 [00:03<00:00, 3.26it/s]
Epoch 98 Loss: 3.583107: 100%|██████████| 11/11 [00:03<00:00, 3.33it/s]
Epoch 99 Loss: 3.562718: 100%|██████████| 11/11 [00:03<00:00, 3.03it/s]
Epoch 100 Loss: 3.548420: 100%|██████████| 11/11 [00:03<00:00, 2.93it/s]
Epoch 101 Loss: 3.531171: 100%|██████████| 11/11 [00:04<00:00, 2.66it/s]
Epoch 102 Loss: 3.509029: 100%|██████████| 11/11 [00:03<00:00, 3.47it/s]
Epoch 103 Loss: 3.485391: 100%|██████████| 11/11 [00:03<00:00, 3.54it/s]
Epoch 104 Loss: 3.450888: 100%|██████████| 11/11 [00:03<00:00, 3.43it/s]
Epoch 105 Loss: 3.424491: 100%|██████████| 11/11 [00:03<00:00, 3.10it/s]
Epoch 106 Loss: 3.403693: 100%|██████████| 11/11 [00:02<00:00, 3.74it/s]
Epoch 107 Loss: 3.390665: 100%|██████████| 11/11 [00:03<00:00, 3.58it/s]
Epoch 108 Loss: 3.359253: 100%|██████████| 11/11 [00:03<00:00, 3.43it/s]
Epoch 109 Loss: 3.345198: 100%|██████████| 11/11 [00:03<00:00, 2.97it/s]
Epoch 110 Loss: 3.349600: 100%|██████████| 11/11 [00:03<00:00, 3.52it/s]
Epoch 111 Loss: 3.329763: 100%|██████████| 11/11 [00:03<00:00, 3.44it/s]
Epoch 112 Loss: 3.303129: 100%|██████████| 11/11 [00:03<00:00, 3.32it/s]
Epoch 113 Loss: 3.266704: 100%|██████████| 11/11 [00:03<00:00, 2.82it/s]
Epoch 114 Loss: 3.240781: 100%|██████████| 11/11 [00:03<00:00, 2.94it/s]
Epoch 115 Loss: 3.215163: 100%|██████████| 11/11 [00:03<00:00, 3.25it/s]
Epoch 116 Loss: 3.198615: 100%|██████████| 11/11 [00:03<00:00, 2.92it/s]
Epoch 117 Loss: 3.169775: 100%|██████████| 11/11 [00:03<00:00, 3.29it/s]
Epoch 118 Loss: 3.152784: 100%|██████████| 11/11 [00:03<00:00, 3.47it/s]
Epoch 119 Loss: 3.128009: 100%|██████████| 11/11 [00:03<00:00, 3.60it/s]
Epoch 120 Loss: 3.106502: 100%|██████████| 11/11 [00:03<00:00, 3.67it/s]
Epoch 121 Loss: 3.093420: 100%|██████████| 11/11 [00:03<00:00, 3.51it/s]
Epoch 122 Loss: 3.078810: 100%|██████████| 11/11 [00:03<00:00, 3.31it/s]
Epoch 123 Loss: 3.050844: 100%|██████████| 11/11 [00:03<00:00, 3.17it/s]
Epoch 124 Loss: 3.052847: 100%|██████████| 11/11 [00:03<00:00, 3.52it/s]
Epoch 125 Loss: 3.025279: 100%|██████████| 11/11 [00:03<00:00, 3.37it/s]
Epoch 126 Loss: 3.009476: 100%|██████████| 11/11 [00:03<00:00, 3.38it/s]
Epoch 127 Loss: 2.985609: 100%|██████████| 11/11 [00:04<00:00, 2.57it/s]
Epoch 128 Loss: 2.966175: 100%|██████████| 11/11 [00:03<00:00, 2.95it/s]
Epoch 129 Loss: 2.948951: 100%|██████████| 11/11 [00:03<00:00, 3.23it/s]
Epoch 130 Loss: 2.939453: 100%|██████████| 11/11 [00:03<00:00, 3.16it/s]
Epoch 131 Loss: 2.932430: 100%|██████████| 11/11 [00:03<00:00, 3.57it/s]
Epoch 132 Loss: 2.893917: 100%|██████████| 11/11 [00:03<00:00, 3.20it/s]
Epoch 133 Loss: 2.855756: 100%|██████████| 11/11 [00:03<00:00, 3.59it/s]
Epoch 134 Loss: 2.839368: 100%|██████████| 11/11 [00:02<00:00, 3.70it/s]
Epoch 135 Loss: 2.817158: 100%|██████████| 11/11 [00:03<00:00, 3.30it/s]
Epoch 136 Loss: 2.791486: 100%|██████████| 11/11 [00:03<00:00, 3.09it/s]
Epoch 137 Loss: 2.780610: 100%|██████████| 11/11 [00:03<00:00, 3.17it/s]
Epoch 138 Loss: 2.761558: 100%|██████████| 11/11 [00:03<00:00, 3.42it/s]
Epoch 139 Loss: 2.750291: 100%|██████████| 11/11 [00:03<00:00, 3.31it/s]
Epoch 140 Loss: 2.730247: 100%|██████████| 11/11 [00:03<00:00, 3.28it/s]
Epoch 141 Loss: 2.698439: 100%|██████████| 11/11 [00:04<00:00, 2.67it/s]
Epoch 142 Loss: 2.685364: 100%|██████████| 11/11 [00:03<00:00, 3.15it/s]
Epoch 143 Loss: 2.662411: 100%|██████████| 11/11 [00:03<00:00, 2.98it/s]
Epoch 144 Loss: 2.658126: 100%|██████████| 11/11 [00:03<00:00, 2.92it/s]
Epoch 145 Loss: 2.632239: 100%|██████████| 11/11 [00:03<00:00, 3.14it/s]
Epoch 146 Loss: 2.610601: 100%|██████████| 11/11 [00:03<00:00, 3.34it/s]
Epoch 147 Loss: 2.583837: 100%|██████████| 11/11 [00:03<00:00, 3.65it/s]
Epoch 148 Loss: 2.556195: 100%|██████████| 11/11 [00:03<00:00, 3.44it/s]
Epoch 149 Loss: 2.547484: 100%|██████████| 11/11 [00:03<00:00, 3.17it/s]
Epoch 150 Loss: 2.525524: 100%|██████████| 11/11 [00:02<00:00, 3.67it/s]
Epoch 151 Loss: 2.514556: 100%|██████████| 11/11 [00:03<00:00, 3.44it/s]
Epoch 152 Loss: 2.501829: 100%|██████████| 11/11 [00:03<00:00, 3.11it/s]
Epoch 153 Loss: 2.478203: 100%|██████████| 11/11 [00:03<00:00, 3.21it/s]
Epoch 154 Loss: 2.457710: 100%|██████████| 11/11 [00:03<00:00, 3.10it/s]
Epoch 155 Loss: 2.436375: 100%|██████████| 11/11 [00:03<00:00, 3.06it/s]
Epoch 156 Loss: 2.412483: 100%|██████████| 11/11 [00:03<00:00, 2.77it/s]
Epoch 157 Loss: 2.394380: 100%|██████████| 11/11 [00:03<00:00, 3.07it/s]
Epoch 158 Loss: 2.380202: 100%|██████████| 11/11 [00:03<00:00, 3.58it/s]
Epoch 159 Loss: 2.355152: 100%|██████████| 11/11 [00:03<00:00, 3.50it/s]
Epoch 160 Loss: 2.347697: 100%|██████████| 11/11 [00:03<00:00, 3.20it/s]
Epoch 161 Loss: 2.345931: 100%|██████████| 11/11 [00:03<00:00, 3.60it/s]
Epoch 162 Loss: 2.296478: 100%|██████████| 11/11 [00:03<00:00, 3.64it/s]
Epoch 163 Loss: 2.282725: 100%|██████████| 11/11 [00:03<00:00, 3.58it/s]
Epoch 164 Loss: 2.249418: 100%|██████████| 11/11 [00:03<00:00, 3.25it/s]
Epoch 165 Loss: 2.241911: 100%|██████████| 11/11 [00:03<00:00, 3.60it/s]
Epoch 166 Loss: 2.215718: 100%|██████████| 11/11 [00:03<00:00, 3.45it/s]
Epoch 167 Loss: 2.194424: 100%|██████████| 11/11 [00:03<00:00, 3.08it/s]
Epoch 168 Loss: 2.184618: 100%|██████████| 11/11 [00:03<00:00, 3.13it/s]
Epoch 169 Loss: 2.156022: 100%|██████████| 11/11 [00:03<00:00, 3.27it/s]
Epoch 170 Loss: 2.152164: 100%|██████████| 11/11 [00:03<00:00, 3.05it/s]
Epoch 171 Loss: 2.128624: 100%|██████████| 11/11 [00:03<00:00, 2.82it/s]
Epoch 172 Loss: 2.109892: 100%|██████████| 11/11 [00:03<00:00, 2.82it/s]
Epoch 173 Loss: 2.137191: 100%|██████████| 11/11 [00:02<00:00, 3.74it/s]
Epoch 174 Loss: 2.118212: 100%|██████████| 11/11 [00:03<00:00, 3.55it/s]
Epoch 175 Loss: 2.092953: 100%|██████████| 11/11 [00:03<00:00, 3.12it/s]
Epoch 176 Loss: 2.065081: 100%|██████████| 11/11 [00:03<00:00, 3.61it/s]
Epoch 177 Loss: 2.064701: 100%|██████████| 11/11 [00:02<00:00, 3.69it/s]
Epoch 178 Loss: 2.061531: 100%|██████████| 11/11 [00:03<00:00, 3.42it/s]
Epoch 179 Loss: 2.041025: 100%|██████████| 11/11 [00:03<00:00, 2.90it/s]
Epoch 180 Loss: 1.999497: 100%|██████████| 11/11 [00:03<00:00, 3.55it/s]
Epoch 181 Loss: 1.953928: 100%|██████████| 11/11 [00:03<00:00, 3.42it/s]
Epoch 182 Loss: 1.929946: 100%|██████████| 11/11 [00:03<00:00, 2.86it/s]
Epoch 183 Loss: 1.923069: 100%|██████████| 11/11 [00:03<00:00, 2.93it/s]
Epoch 184 Loss: 1.905399: 100%|██████████| 11/11 [00:03<00:00, 3.23it/s]
Epoch 185 Loss: 1.892242: 100%|██████████| 11/11 [00:03<00:00, 3.06it/s]
Epoch 186 Loss: 1.869357: 100%|██████████| 11/11 [00:03<00:00, 3.02it/s]
Epoch 187 Loss: 1.840604: 100%|██████████| 11/11 [00:03<00:00, 3.59it/s]
Epoch 188 Loss: 1.837614: 100%|██████████| 11/11 [00:02<00:00, 3.68it/s]
Epoch 189 Loss: 1.816457: 100%|██████████| 11/11 [00:03<00:00, 3.57it/s]
Epoch 190 Loss: 1.786459: 100%|██████████| 11/11 [00:03<00:00, 3.11it/s]
Epoch 191 Loss: 1.776885: 100%|██████████| 11/11 [00:02<00:00, 3.79it/s]
Epoch 192 Loss: 1.766457: 100%|██████████| 11/11 [00:03<00:00, 3.61it/s]
Epoch 193 Loss: 1.742621: 100%|██████████| 11/11 [00:03<00:00, 3.39it/s]
Epoch 194 Loss: 1.723335: 100%|██████████| 11/11 [00:03<00:00, 2.79it/s]
Epoch 195 Loss: 1.708375: 100%|██████████| 11/11 [00:03<00:00, 3.44it/s]
Epoch 196 Loss: 1.692905: 100%|██████████| 11/11 [00:03<00:00, 3.27it/s]
Epoch 197 Loss: 1.694218: 100%|██████████| 11/11 [00:03<00:00, 2.89it/s]
Epoch 198 Loss: 1.677686: 100%|██████████| 11/11 [00:03<00:00, 2.96it/s]
Epoch 199 Loss: 1.644678: 100%|██████████| 11/11 [00:03<00:00, 3.09it/s]
Epoch 200 Loss: 1.632020: 100%|██████████| 11/11 [00:03<00:00, 3.16it/s]
Epoch 201 Loss: 1.604089: 100%|██████████| 11/11 [00:03<00:00, 3.53it/s]
Epoch 202 Loss: 1.597124: 100%|██████████| 11/11 [00:03<00:00, 3.08it/s]
Epoch 203 Loss: 1.580226: 100%|██████████| 11/11 [00:02<00:00, 3.74it/s]
Epoch 204 Loss: 1.577800: 100%|██████████| 11/11 [00:03<00:00, 3.62it/s]
Epoch 205 Loss: 1.556550: 100%|██████████| 11/11 [00:03<00:00, 3.33it/s]
Epoch 206 Loss: 1.531670: 100%|██████████| 11/11 [00:03<00:00, 3.20it/s]
Epoch 207 Loss: 1.524184: 100%|██████████| 11/11 [00:03<00:00, 3.55it/s]
Epoch 208 Loss: 1.504326: 100%|██████████| 11/11 [00:03<00:00, 3.43it/s]
Epoch 209 Loss: 1.495385: 100%|██████████| 11/11 [00:03<00:00, 3.06it/s]
Epoch 210 Loss: 1.477028: 100%|██████████| 11/11 [00:03<00:00, 2.92it/s]
Epoch 211 Loss: 1.466456: 100%|██████████| 11/11 [00:03<00:00, 3.16it/s]
Epoch 212 Loss: 1.437730: 100%|██████████| 11/11 [00:03<00:00, 2.97it/s]
Epoch 213 Loss: 1.439750: 100%|██████████| 11/11 [00:04<00:00, 2.70it/s]
Epoch 214 Loss: 1.427772: 100%|██████████| 11/11 [00:03<00:00, 3.19it/s]
Epoch 215 Loss: 1.417080: 100%|██████████| 11/11 [00:03<00:00, 3.54it/s]
Epoch 216 Loss: 1.403871: 100%|██████████| 11/11 [00:03<00:00, 3.40it/s]
Epoch 217 Loss: 1.374399: 100%|██████████| 11/11 [00:03<00:00, 3.29it/s]
Epoch 218 Loss: 1.362701: 100%|██████████| 11/11 [00:02<00:00, 3.83it/s]
Epoch 219 Loss: 1.348216: 100%|██████████| 11/11 [00:02<00:00, 3.76it/s]
Epoch 220 Loss: 1.337092: 100%|██████████| 11/11 [00:03<00:00, 3.51it/s]
Epoch 221 Loss: 1.331288: 100%|██████████| 11/11 [00:03<00:00, 3.04it/s]
Epoch 222 Loss: 1.322314: 100%|██████████| 11/11 [00:03<00:00, 3.54it/s]
Epoch 223 Loss: 1.303005: 100%|██████████| 11/11 [00:03<00:00, 3.46it/s]
Epoch 224 Loss: 1.308330: 100%|██████████| 11/11 [00:03<00:00, 2.86it/s]
Epoch 225 Loss: 1.307320: 100%|██████████| 11/11 [00:05<00:00, 2.17it/s]
Epoch 226 Loss: 1.259610: 100%|██████████| 11/11 [00:04<00:00, 2.59it/s]
Epoch 227 Loss: 1.247879: 100%|██████████| 11/11 [00:04<00:00, 2.44it/s]
Epoch 228 Loss: 1.232500: 100%|██████████| 11/11 [00:03<00:00, 2.84it/s]
Epoch 229 Loss: 1.223805: 100%|██████████| 11/11 [00:04<00:00, 2.71it/s]
Epoch 230 Loss: 1.210760: 100%|██████████| 11/11 [00:03<00:00, 3.23it/s]
Epoch 231 Loss: 1.193618: 100%|██████████| 11/11 [00:03<00:00, 3.09it/s]
Epoch 232 Loss: 1.186537: 100%|██████████| 11/11 [00:03<00:00, 3.09it/s]
Epoch 233 Loss: 1.159970: 100%|██████████| 11/11 [00:03<00:00, 2.82it/s]
Epoch 234 Loss: 1.156743: 100%|██████████| 11/11 [00:03<00:00, 3.41it/s]
Epoch 235 Loss: 1.146074: 100%|██████████| 11/11 [00:03<00:00, 2.96it/s]
Epoch 236 Loss: 1.151413: 100%|██████████| 11/11 [00:03<00:00, 2.90it/s]
Epoch 237 Loss: 1.131100: 100%|██████████| 11/11 [00:03<00:00, 3.30it/s]
Epoch 238 Loss: 1.122501: 100%|██████████| 11/11 [00:03<00:00, 3.11it/s]
Epoch 239 Loss: 1.099850: 100%|██████████| 11/11 [00:03<00:00, 2.81it/s]
Epoch 240 Loss: 1.086024: 100%|██████████| 11/11 [00:04<00:00, 2.57it/s]
Epoch 241 Loss: 1.072572: 100%|██████████| 11/11 [00:03<00:00, 2.93it/s]
Epoch 242 Loss: 1.068453: 100%|██████████| 11/11 [00:03<00:00, 3.37it/s]
Epoch 243 Loss: 1.044812: 100%|██████████| 11/11 [00:03<00:00, 3.52it/s]
Epoch 244 Loss: 1.030824: 100%|██████████| 11/11 [00:03<00:00, 3.23it/s]
Epoch 245 Loss: 1.019708: 100%|██████████| 11/11 [00:03<00:00, 3.34it/s]
Epoch 246 Loss: 1.011575: 100%|██████████| 11/11 [00:03<00:00, 3.59it/s]
Epoch 247 Loss: 0.990726: 100%|██████████| 11/11 [00:03<00:00, 3.46it/s]
Epoch 248 Loss: 0.986041: 100%|██████████| 11/11 [00:03<00:00, 2.91it/s]
Epoch 249 Loss: 0.981658: 100%|██████████| 11/11 [00:03<00:00, 3.21it/s]
Epoch 250 Loss: 0.980595: 100%|██████████| 11/11 [00:03<00:00, 3.37it/s]
Epoch 251 Loss: 0.968010: 100%|██████████| 11/11 [00:03<00:00, 3.13it/s]
Epoch 252 Loss: 0.962707: 100%|██████████| 11/11 [00:03<00:00, 2.80it/s]
Epoch 253 Loss: 0.942022: 100%|██████████| 11/11 [00:03<00:00, 3.11it/s]
Epoch 254 Loss: 0.944430: 100%|██████████| 11/11 [00:03<00:00, 2.89it/s]
Epoch 255 Loss: 0.929161: 100%|██████████| 11/11 [00:03<00:00, 2.88it/s]
Epoch 256 Loss: 0.937370: 100%|██████████| 11/11 [00:03<00:00, 2.90it/s]
Epoch 257 Loss: 0.903115: 100%|██████████| 11/11 [00:03<00:00, 3.54it/s]
Epoch 258 Loss: 0.883225: 100%|██████████| 11/11 [00:03<00:00, 3.25it/s]
Epoch 259 Loss: 0.879743: 100%|██████████| 11/11 [00:03<00:00, 2.87it/s]
Epoch 260 Loss: 0.868025: 100%|██████████| 11/11 [00:03<00:00, 3.35it/s]
Epoch 261 Loss: 0.871274: 100%|██████████| 11/11 [00:03<00:00, 3.52it/s]
Epoch 262 Loss: 0.845327: 100%|██████████| 11/11 [00:03<00:00, 3.02it/s]
Epoch 263 Loss: 0.830346: 100%|██████████| 11/11 [00:03<00:00, 2.75it/s]
Epoch 264 Loss: 0.824749: 100%|██████████| 11/11 [00:03<00:00, 3.36it/s]
Epoch 265 Loss: 0.825509: 100%|██████████| 11/11 [00:03<00:00, 3.28it/s]
Epoch 266 Loss: 0.826433: 100%|██████████| 11/11 [00:04<00:00, 2.64it/s]
Epoch 267 Loss: 0.803352: 100%|██████████| 11/11 [00:04<00:00, 2.58it/s]
Epoch 268 Loss: 0.803057: 100%|██████████| 11/11 [00:03<00:00, 3.02it/s]
Epoch 269 Loss: 0.779101: 100%|██████████| 11/11 [00:03<00:00, 3.00it/s]
Epoch 270 Loss: 0.776825: 100%|██████████| 11/11 [00:03<00:00, 3.17it/s]
Epoch 271 Loss: 0.781099: 100%|██████████| 11/11 [00:03<00:00, 3.20it/s]
Epoch 272 Loss: 0.785840: 100%|██████████| 11/11 [00:03<00:00, 3.28it/s]
Epoch 273 Loss: 0.802700: 100%|██████████| 11/11 [00:03<00:00, 3.29it/s]
Epoch 274 Loss: 0.772463: 100%|██████████| 11/11 [00:03<00:00, 3.00it/s]
Epoch 275 Loss: 0.764557: 100%|██████████| 11/11 [00:03<00:00, 3.41it/s]
Epoch 276 Loss: 0.756142: 100%|██████████| 11/11 [00:03<00:00, 3.40it/s]
Epoch 277 Loss: 0.749000: 100%|██████████| 11/11 [00:03<00:00, 3.18it/s]
Epoch 278 Loss: 0.727899: 100%|██████████| 11/11 [00:03<00:00, 2.91it/s]
Epoch 279 Loss: 0.736639: 100%|██████████| 11/11 [00:03<00:00, 3.44it/s]
Epoch 280 Loss: 0.732949: 100%|██████████| 11/11 [00:04<00:00, 2.73it/s]
Epoch 281 Loss: 0.714376: 100%|██████████| 11/11 [00:03<00:00, 2.79it/s]
Epoch 282 Loss: 0.706372: 100%|██████████| 11/11 [00:04<00:00, 2.62it/s]
Epoch 283 Loss: 0.718250: 100%|██████████| 11/11 [00:03<00:00, 3.10it/s]
Epoch 284 Loss: 0.702688: 100%|██████████| 11/11 [00:03<00:00, 3.32it/s]
Epoch 285 Loss: 0.683941: 100%|██████████| 11/11 [00:03<00:00, 3.02it/s]
Epoch 286 Loss: 0.669917: 100%|██████████| 11/11 [00:03<00:00, 3.04it/s]
Epoch 287 Loss: 0.657017: 100%|██████████| 11/11 [00:03<00:00, 3.44it/s]
Epoch 288 Loss: 0.645382: 100%|██████████| 11/11 [00:03<00:00, 3.27it/s]
Epoch 289 Loss: 0.646296: 100%|██████████| 11/11 [00:03<00:00, 3.01it/s]
Epoch 290 Loss: 0.640179: 100%|██████████| 11/11 [00:03<00:00, 2.87it/s]
Epoch 291 Loss: 0.632041: 100%|██████████| 11/11 [00:03<00:00, 3.20it/s]
Epoch 292 Loss: 0.623746: 100%|██████████| 11/11 [00:03<00:00, 3.01it/s]
Epoch 293 Loss: 0.606923: 100%|██████████| 11/11 [00:04<00:00, 2.67it/s]
Epoch 294 Loss: 0.608943: 100%|██████████| 11/11 [00:04<00:00, 2.50it/s]
Epoch 295 Loss: 0.607394: 100%|██████████| 11/11 [00:03<00:00, 2.89it/s]
Epoch 296 Loss: 0.594561: 100%|██████████| 11/11 [00:04<00:00, 2.60it/s]
Epoch 297 Loss: 0.581359: 100%|██████████| 11/11 [00:04<00:00, 2.61it/s]
Epoch 298 Loss: 0.581903: 100%|██████████| 11/11 [00:03<00:00, 3.34it/s]
Epoch 299 Loss: 0.572175: 100%|██████████| 11/11 [00:03<00:00, 3.24it/s]
Epoch 300 Loss: 0.570626: 100%|██████████| 11/11 [00:03<00:00, 3.02it/s]
WARNING:merlion.models.ensemble.base:When initializing an ensemble, you must either provide the dict `model_configs` (mapping each model's name to its config) when creating the `DetectorEnsembleConfig`, or provide a list of `models` to the constructor of `EnsembleBase`. Received both. Overriding `model_configs` with the configs belonging to `models`.
# Evaluate the trained MoE on the test data and plot one forecast step.
# NOTE(review): `ensemble_loaded` is not defined in the visible cells; presumably
# an earlier cell reloads the ensemble saved to `save_dir` — confirm before Run All.
expert_idx = None  # when no external experts are used, expert_idx is not used by forecast/batch_forecast/evaluate
mode = 'max'  # either 'mean' or 'max'. Max picks the expert with the highest confidence; mean computes the weighted average.
use_gpu = False  # set True if GPU available for faster speed
use_batch_forecast = True  # set True for higher speed
confidence_thres = 1.  # threshold deciding which samples count as "confident" in the report below

y_pred_list, std_list, y_list, sMAPE_conf, sMAPE_not_conf, recall, overall_sMAPE = ensemble_loaded.evaluate(
    test_data, mode=mode, expert_idx=expert_idx,
    use_gpu=use_gpu, use_batch_forecast=use_batch_forecast,
    confidence_thres=confidence_thres)

out_idx = 0  # plot this idx of all the steps forecasted by MoE
n_plot = 100  # number of test samples shown in the plot
show_std_band = False  # set True to shade prediction +/- one standard deviation

print(f'Prediction array shape: {y_pred_list.shape}')

pred_window = y_pred_list[:n_plot, out_idx]
fig, ax = plt.subplots()
ax.plot(pred_window, '--', color='k', label='prediction', linewidth=1)
ax.plot(y_list[:n_plot, out_idx], color='b', label='data', linewidth=1)
if show_std_band:
    std_window = std_list[:n_plot, out_idx]
    ax.fill_between(range(len(pred_window)), pred_window - std_window,
                    pred_window + std_window, alpha=0.3, label='+/- 1 std')
ax.set(title=f'MoE forecast vs. data (forecast step index {out_idx})',
       xlabel='test sample index', ylabel='value')
ax.legend()
plt.show()

print(f'sMAPE on confident samples: {sMAPE_conf:.2f}')
print(f'sMAPE on not confident samples: {sMAPE_not_conf:.2f}')
print(f'Percentage of samples on which MoE was confident: {recall:.2f}% (use a different confidence_thres to change this)')
print(f'sMAPE on all samples: {overall_sMAPE:.2f}')
sMAPE on confident samples: 1.87
sMAPE on not confident samples: 2.41
Percentage of samples on which MoE was confident: 21.33% (use a different confidence_thres to change this)
sMAPE on all samples: 2.30