            Tutorial for Mixture of Experts (MoE) Forecasting Model

            This notebook provides a minimal example of how to use the MoE forecasting model.

            MoE runs in two settings:
            1. Using external expert models
            2. Using free parameters (no external experts)

            Example code for both cases is provided below.

            # workaround to enable info-level logging in Jupyter notebook
            %config Application.log_level='WORKAROUND'
            %config Application.log_level='INFO'
            import logging
            logging.getLogger().setLevel(logging.INFO)

            Load dataset

            INFO:numexpr.utils:Note: NumExpr detected 16 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
            INFO:numexpr.utils:NumExpr defaulting to 8 threads.
            100%|██████████| 414/414 [00:00<00:00, 610.42it/s]

            import matplotlib.pyplot as plt
            from merlion.utils import TimeSeries
            # `time_series` (a pandas DataFrame) and `metadata` (whose boolean "trainval"
            # column marks the training rows) are produced by the dataset loader
            train_data = TimeSeries.from_pd(time_series[metadata["trainval"]])
            test_data = TimeSeries.from_pd(time_series[~metadata["trainval"]])
            print('train timeseries shape: ', train_data.to_pd().values.shape)
            print('test timeseries shape: ', test_data.to_pd().values.shape)

            # plot the first univariate of the train split, then of the test split
            column_names = list(train_data.to_pd().columns)
            idx = 0
            tr = train_data.to_pd()[column_names[idx]]
            plt.plot(tr)
            plt.show()
            te = test_data.to_pd()[column_names[idx]]
            plt.plot(te)
            plt.show()
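The MoE model configured in the next section consumes the series as fixed-length samples. Each sample takes `lookback_len` past points as input and `max_forecast_steps` future points as the target (20 and 3 in this notebook). Below is a minimal NumPy sketch of that windowing. It assumes a simple stride-1 sliding window over one univariate series. `make_windows` is a hypothetical helper, not part of Merlion, and Merlion's internal sample construction may differ.

```python
import numpy as np


def make_windows(series: np.ndarray, lookback_len: int, horizon: int):
    """Split a 1-D series into (input, target) pairs with a stride-1 sliding window.

    Hypothetical helper for illustration only; not part of Merlion's API.
    """
    n_samples = len(series) - lookback_len - horizon + 1
    if n_samples <= 0:
        raise ValueError("series is shorter than lookback_len + horizon")
    # each input row holds lookback_len consecutive points
    x = np.stack([series[i:i + lookback_len] for i in range(n_samples)])
    # each target row holds the horizon points that follow its input row
    y = np.stack([series[i + lookback_len:i + lookback_len + horizon] for i in range(n_samples)])
    return x, y


# toy series standing in for one column of train_data
series = np.arange(30, dtype=float)
x, y = make_windows(series, lookback_len=20, horizon=3)
print(x.shape, y.shape)  # (8, 20) (8, 3)
print(y[0])              # [20. 21. 22.]
```

With 30 points, a lookback of 20 and a horizon of 3, only 8 full samples fit. This is why short series yield few training windows.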

            Create MoE model composed of external expert models and train

            Specify hyper-parameters

            # save directory for ensemble state. Replace it with your own choice.
            save_dir = 'models/moe'
            nfree_experts = 0  # <- no free parameters provided
            lookback_len = 20
            max_forecast_steps = 3
            target_seq_index = 0
            use_gpu = False
            ## Pytorch network hyper-params. These are also the hyper-params that are used in case moe_model=None is passed to MoE_ForecasterEnsemble.
            hidden_dim = 256
            dim_head = 2
            mlp_dim = 256
            dim_dropout = 0.  # if data is multi-dimensional, this can be set to a non-zero value to allow model to handle missing dimensions during test time
            time_step_dropout = 0
            ## Pytorch network hyper-params

            from merlion.models.factory import ModelFactory
            from merlion.models.ensemble.MoE_forecast import MoE_ForecasterEnsemble, MoE_ForecasterEnsembleConfig, TransformerModel
            from merlion.models.ensemble.base import EnsembleTrainConfig
            from merlion.transform.base import Identity
            from merlion.transform.resample import TemporalResample
            ## Define configs for all the experts as well as the MoE ensembler
            conf_sarima = {"order": [15, 1, 5], "seasonal_order": [2, 0, 1, 24],
                           "max_forecast_steps": max_forecast_steps,
                           "target_seq_index": target_seq_index, "transform": Identity()}
            config_arima = {"order": [15, 1, 5], "max_forecast_steps": max_forecast_steps,
                            "target_seq_index": target_seq_index, "transform": Identity()}
            config_vector_ar = {"max_forecast_steps": max_forecast_steps,
                                "target_seq_index": target_seq_index, "maxlags": 14}
            config_ensemble = MoE_ForecasterEnsembleConfig(
                batch_size=64, lr=0.0001, nfree_experts=nfree_experts, epoch_max=100,
                lookback_len=lookback_len, max_forecast_steps=max_forecast_steps,
                target_seq_index=target_seq_index, use_gpu=use_gpu,
                transform=TemporalResample())
            train_config_ensemble = EnsembleTrainConfig(valid_frac=0.5)
            # Define expert models
            model = ModelFactory.create("Sarima", **conf_sarima)
            model2 = ModelFactory.create("Arima", **config_arima)
            model3 = ModelFactory.create("VectorAR", **config_vector_ar)
            models = [model, model2, model3]
            nexperts = len(models)

            Instantiate the deep network for MoE. It can also be set to None, in which case the default PyTorch network specified in the MoE_ForecasterEnsemble class is used. For reference, the network below is the default network used by MoE_ForecasterEnsemble.

            moe_model = TransformerModel(input_dim=train_data.dim, lookback_len=lookback_len, nexperts=nexperts,
                                         output_dim=max_forecast_steps, nfree_experts=nfree_experts,
                                         hid_dim=hidden_dim, dim_head=dim_head, mlp_dim=mlp_dim,
                                         pool='cls', dim_dropout=dim_dropout,
                                         time_step_dropout=time_step_dropout)
            # moe_model = None  # use me if you want to see the default model in use
            # create MoE forecaster model
            ensemble = MoE_ForecasterEnsemble(config=config_ensemble, models=models, moe_model=moe_model)
            # train & save MoE
            loss_list = ensemble.train(train_data=train_data, train_config=train_config_ensemble)
            ensemble.save(save_dir)
            INFO:merlion.models.ensemble.MoE_forecast:Training model 1/3...
            INFO:merlion.models.ensemble.MoE_forecast:Training model 2/3...
            INFO:merlion.models.ensemble.MoE_forecast:Training model 3/3...
            INFO:merlion.models.ensemble.MoE_forecast:Extracting and storing expert predictions
            0%| | 0/6 [00:00<?, ?it/s]
            INFO:merlion.models.ensemble.MoE_forecast:Getting model 1/3 predictions...
            INFO:merlion.models.ensemble.MoE_forecast:Getting model 2/3 predictions...
            INFO:merlion.models.ensemble.MoE_forecast:Getting model 3/3 predictions...
            17%|█▋ | 1/6 [00:06<00:33, 6.78s/it]
            INFO:merlion.models.ensemble.MoE_forecast:Getting model 1/3 predictions...
            INFO:merlion.models.ensemble.MoE_forecast:Getting model 2/3 predictions...
            INFO:merlion.models.ensemble.MoE_forecast:Getting model 3/3 predictions...
            33%|███▎ | 2/6 [00:10<00:19, 4.91s/it]
            INFO:merlion.models.ensemble.MoE_forecast:Getting model 1/3 predictions...
            INFO:merlion.models.ensemble.MoE_forecast:Getting model 2/3 predictions...
            INFO:merlion.models.ensemble.MoE_forecast:Getting model 3/3 predictions...
            50%|█████ | 3/6 [00:14<00:13, 4.52s/it]
            INFO:merlion.models.ensemble.MoE_forecast:Getting model 1/3 predictions...
            INFO:merlion.models.ensemble.MoE_forecast:Getting model 2/3 predictions...
            INFO:merlion.models.ensemble.MoE_forecast:Getting model 3/3 predictions...
            67%|██████▋ | 4/6 [00:17<00:08, 4.11s/it]
            INFO:merlion.models.ensemble.MoE_forecast:Getting model 1/3 predictions...
            INFO:merlion.models.ensemble.MoE_forecast:Getting model 2/3 predictions...
            INFO:merlion.models.ensemble.MoE_forecast:Getting model 3/3 predictions...
            83%|████████▎ | 5/6 [00:22<00:04, 4.34s/it]
            INFO:merlion.models.ensemble.MoE_forecast:Getting model 1/3 predictions...
            INFO:merlion.models.ensemble.MoE_forecast:Getting model 2/3 predictions...
            INFO:merlion.models.ensemble.MoE_forecast:Getting model 3/3 predictions...
            100%|██████████| 6/6 [00:23<00:00, 3.91s/it]
            Epoch 1 Loss: 1.304425: 100%|██████████| 6/6 [00:03<00:00, 1.92it/s]
            Epoch 2 Loss: 1.326118: 100%|██████████| 6/6 [00:03<00:00, 1.85it/s]
            Epoch 3 Loss: 1.285732: 100%|██████████| 6/6 [00:03<00:00, 1.61it/s]
            Epoch 4 Loss: 1.239944: 100%|██████████| 6/6 [00:03<00:00, 1.86it/s]
            Epoch 5 Loss: 1.217539: 100%|██████████| 6/6 [00:03<00:00, 1.92it/s]
            Epoch 6 Loss: 1.194959: 100%|██████████| 6/6 [00:04<00:00, 1.43it/s]
            Epoch 7 Loss: 1.182055: 100%|██████████| 6/6 [00:03<00:00, 1.86it/s]
            Epoch 8 Loss: 1.175249: 100%|██████████| 6/6 [00:02<00:00, 2.15it/s]
            Epoch 9 Loss: 1.173702: 100%|██████████| 6/6 [00:02<00:00, 2.43it/s]
            Epoch 10 Loss: 1.186111: 100%|██████████| 6/6 [00:02<00:00, 2.36it/s]
            Epoch 11 Loss: 1.184515: 100%|██████████| 6/6 [00:02<00:00, 2.08it/s]
            Epoch 12 Loss: 1.169571: 100%|██████████| 6/6 [00:02<00:00, 2.29it/s]
            Epoch 13 Loss: 1.174548: 100%|██████████| 6/6 [00:02<00:00, 2.42it/s]
            Epoch 14 Loss: 1.180143: 100%|██████████| 6/6 [00:02<00:00, 2.30it/s]
            Epoch 15 Loss: 1.186909: 100%|██████████| 6/6 [00:03<00:00, 1.95it/s]
            Epoch 16 Loss: 1.175635: 100%|██████████| 6/6 [00:02<00:00, 2.36it/s]
            Epoch 17 Loss: 1.189920: 100%|██████████| 6/6 [00:02<00:00, 2.23it/s]
            Epoch 18 Loss: 1.182831: 100%|██████████| 6/6 [00:03<00:00, 1.88it/s]
            Epoch 19 Loss: 1.176439: 100%|██████████| 6/6 [00:03<00:00, 1.90it/s]
            Epoch 20 Loss: 1.174525: 100%|██████████| 6/6 [00:02<00:00, 2.06it/s]
            Epoch 21 Loss: 1.172877: 100%|██████████| 6/6 [00:02<00:00, 2.07it/s]
            Epoch 22 Loss: 1.182378: 100%|██████████| 6/6 [00:03<00:00, 1.80it/s]
            Epoch 23 Loss: 1.165431: 100%|██████████| 6/6 [00:02<00:00, 2.23it/s]
            Epoch 24 Loss: 1.178699: 100%|██████████| 6/6 [00:02<00:00, 2.48it/s]
            Epoch 25 Loss: 1.179477: 100%|██████████| 6/6 [00:02<00:00, 2.32it/s]
            Epoch 26 Loss: 1.174176: 100%|██████████| 6/6 [00:02<00:00, 2.10it/s]
            Epoch 27 Loss: 1.187021: 100%|██████████| 6/6 [00:02<00:00, 2.32it/s]
            Epoch 28 Loss: 1.175006: 100%|██████████| 6/6 [00:02<00:00, 2.47it/s]
            Epoch 29 Loss: 1.179644: 100%|██████████| 6/6 [00:02<00:00, 2.39it/s]
            Epoch 30 Loss: 1.172667: 100%|██████████| 6/6 [00:02<00:00, 2.01it/s]
            Epoch 31 Loss: 1.166672: 100%|██████████| 6/6 [00:02<00:00, 2.45it/s]
            Epoch 32 Loss: 1.180196: 100%|██████████| 6/6 [00:02<00:00, 2.21it/s]
            Epoch 33 Loss: 1.180659: 100%|██████████| 6/6 [00:02<00:00, 2.03it/s]
            Epoch 34 Loss: 1.172211: 100%|██████████| 6/6 [00:02<00:00, 2.04it/s]
            Epoch 35 Loss: 1.179359: 100%|██████████| 6/6 [00:02<00:00, 2.39it/s]
            Epoch 36 Loss: 1.171128: 100%|██████████| 6/6 [00:03<00:00, 1.97it/s]
            Epoch 37 Loss: 1.175621: 100%|██████████| 6/6 [00:03<00:00, 1.65it/s]
            Epoch 38 Loss: 1.178169: 100%|██████████| 6/6 [00:02<00:00, 2.02it/s]
            Epoch 39 Loss: 1.171359: 100%|██████████| 6/6 [00:03<00:00, 1.94it/s]
            Epoch 40 Loss: 1.178329: 100%|██████████| 6/6 [00:02<00:00, 2.09it/s]
            Epoch 41 Loss: 1.183362: 100%|██████████| 6/6 [00:03<00:00, 1.99it/s]
            Epoch 42 Loss: 1.192560: 100%|██████████| 6/6 [00:02<00:00, 2.57it/s]
            Epoch 43 Loss: 1.180345: 100%|██████████| 6/6 [00:02<00:00, 2.41it/s]
            Epoch 44 Loss: 1.173638: 100%|██████████| 6/6 [00:02<00:00, 2.13it/s]
            Epoch 45 Loss: 1.152926: 100%|██████████| 6/6 [00:02<00:00, 2.26it/s]
            Epoch 46 Loss: 1.173875: 100%|██████████| 6/6 [00:02<00:00, 2.49it/s]
            Epoch 47 Loss: 1.164550: 100%|██████████| 6/6 [00:02<00:00, 2.30it/s]
            Epoch 48 Loss: 1.159929: 100%|██████████| 6/6 [00:03<00:00, 1.90it/s]
            Epoch 49 Loss: 1.153424: 100%|██████████| 6/6 [00:02<00:00, 2.29it/s]
            Epoch 50 Loss: 1.154480: 100%|██████████| 6/6 [00:02<00:00, 2.25it/s]
            Epoch 51 Loss: 1.140436: 100%|██████████| 6/6 [00:02<00:00, 2.04it/s]
            Epoch 52 Loss: 1.146496: 100%|██████████| 6/6 [00:03<00:00, 1.67it/s]
            Epoch 53 Loss: 1.103287: 100%|██████████| 6/6 [00:03<00:00, 1.95it/s]
            Epoch 54 Loss: 1.126842: 100%|██████████| 6/6 [00:03<00:00, 1.91it/s]
            Epoch 55 Loss: 1.090270: 100%|██████████| 6/6 [00:03<00:00, 1.89it/s]
            Epoch 56 Loss: 1.083246: 100%|██████████| 6/6 [00:02<00:00, 2.12it/s]
            Epoch 57 Loss: 1.103495: 100%|██████████| 6/6 [00:02<00:00, 2.47it/s]
            Epoch 58 Loss: 1.081064: 100%|██████████| 6/6 [00:02<00:00, 2.38it/s]
            Epoch 59 Loss: 1.031706: 100%|██████████| 6/6 [00:02<00:00, 2.14it/s]
            Epoch 60 Loss: 1.028626: 100%|██████████| 6/6 [00:02<00:00, 2.26it/s]
            Epoch 61 Loss: 1.070291: 100%|██████████| 6/6 [00:02<00:00, 2.40it/s]
            Epoch 62 Loss: 1.070380: 100%|██████████| 6/6 [00:02<00:00, 2.25it/s]
            Epoch 63 Loss: 1.057747: 100%|██████████| 6/6 [00:03<00:00, 1.96it/s]
            Epoch 64 Loss: 1.022432: 100%|██████████| 6/6 [00:02<00:00, 2.37it/s]
            Epoch 65 Loss: 1.092833: 100%|██████████| 6/6 [00:02<00:00, 2.26it/s]
            Epoch 66 Loss: 1.046667: 100%|██████████| 6/6 [00:02<00:00, 2.01it/s]
            Epoch 67 Loss: 1.024409: 100%|██████████| 6/6 [00:03<00:00, 1.95it/s]
            Epoch 68 Loss: 1.006745: 100%|██████████| 6/6 [00:02<00:00, 2.12it/s]
            Epoch 69 Loss: 1.026753: 100%|██████████| 6/6 [00:03<00:00, 1.99it/s]
            Epoch 70 Loss: 1.035914: 100%|██████████| 6/6 [00:03<00:00, 1.78it/s]
            Epoch 71 Loss: 1.011550: 100%|██████████| 6/6 [00:03<00:00, 1.76it/s]
            Epoch 72 Loss: 1.025698: 100%|██████████| 6/6 [00:02<00:00, 2.32it/s]
            Epoch 73 Loss: 1.021778: 100%|██████████| 6/6 [00:02<00:00, 2.44it/s]
            Epoch 74 Loss: 0.999083: 100%|██████████| 6/6 [00:02<00:00, 2.38it/s]
            Epoch 75 Loss: 1.003671: 100%|██████████| 6/6 [00:02<00:00, 2.02it/s]
            Epoch 76 Loss: 1.007998: 100%|██████████| 6/6 [00:02<00:00, 2.34it/s]
            Epoch 77 Loss: 0.992967: 100%|██████████| 6/6 [00:02<00:00, 2.45it/s]
            Epoch 78 Loss: 1.039574: 100%|██████████| 6/6 [00:02<00:00, 2.33it/s]
            Epoch 79 Loss: 0.979768: 100%|██████████| 6/6 [00:03<00:00, 1.85it/s]
            Epoch 80 Loss: 0.980897: 100%|██████████| 6/6 [00:02<00:00, 2.22it/s]
            Epoch 81 Loss: 1.021148: 100%|██████████| 6/6 [00:02<00:00, 2.36it/s]
            Epoch 82 Loss: 0.964973: 100%|██████████| 6/6 [00:02<00:00, 2.26it/s]
            Epoch 83 Loss: 0.970942: 100%|██████████| 6/6 [00:03<00:00, 1.83it/s]
            Epoch 84 Loss: 0.996262: 100%|██████████| 6/6 [00:03<00:00, 1.74it/s]
            Epoch 85 Loss: 0.961435: 100%|██████████| 6/6 [00:02<00:00, 2.04it/s]
            Epoch 86 Loss: 0.975354: 100%|██████████| 6/6 [00:02<00:00, 2.07it/s]
            Epoch 87 Loss: 0.980454: 100%|██████████| 6/6 [00:02<00:00, 2.06it/s]
            Epoch 88 Loss: 0.957371: 100%|██████████| 6/6 [00:02<00:00, 2.04it/s]
            Epoch 89 Loss: 0.940112: 100%|██████████| 6/6 [00:02<00:00, 2.39it/s]
            Epoch 90 Loss: 0.939432: 100%|██████████| 6/6 [00:02<00:00, 2.41it/s]
            Epoch 91 Loss: 0.952867: 100%|██████████| 6/6 [00:02<00:00, 2.38it/s]
            Epoch 92 Loss: 0.973232: 100%|██████████| 6/6 [00:02<00:00, 2.08it/s]
            Epoch 93 Loss: 0.988913: 100%|██████████| 6/6 [00:02<00:00, 2.53it/s]
            Epoch 94 Loss: 0.954114: 100%|██████████| 6/6 [00:02<00:00, 2.44it/s]
            Epoch 95 Loss: 0.906431: 100%|██████████| 6/6 [00:02<00:00, 2.29it/s]
            Epoch 96 Loss: 0.909996: 100%|██████████| 6/6 [00:03<00:00, 1.94it/s]
            Epoch 97 Loss: 0.945528: 100%|██████████| 6/6 [00:02<00:00, 2.18it/s]
            Epoch 98 Loss: 0.922009: 100%|██████████| 6/6 [00:02<00:00, 2.39it/s]
            Epoch 99 Loss: 0.920070: 100%|██████████| 6/6 [00:02<00:00, 2.21it/s]
            Epoch 100 Loss: 0.912859: 100%|██████████| 6/6 [00:03<00:00, 1.72it/s]
            WARNING:merlion.models.ensemble.base:When initializing an ensemble, you must either provide the dict `model_configs` (mapping each model's name to its config) when creating the `DetectorEnsembleConfig`, or provide a list of `models` to the constructor of `EnsembleBase`. Received both. Overriding `model_configs` with the configs belonging to `models`.

            Forecast using the loaded model

            end_idx = sample_length
            timestamps = test_data.univariates[test_data.names[0]].time_stamps[start_idx:end_idx]
            data = test_data.to_pd().values
            data = data[start_idx:end_idx]
            timestamps = timestamps[lookback_len:]  # keep only the timestamps of the forecast horizon
            x = data[:lookback_len]  # the first lookback_len points are the model input
            x_ts = test_data[start_idx:start_idx + lookback_len]
            y = data[lookback_len:, target_seq_index]  # ground truth for the forecast horizon
            print('True output: \n')
            print(y)
            # perform single forecast
            print('Performing single forecast: \n')
            forecast, se = ensemble_loaded.forecast(time_stamps=timestamps, time_series_prev=x_ts,
                                                    expert_idx=None, mode='max', use_gpu=False)
            print('Forecast \n', forecast)
            print('Standard Error \n', se)
            # perform batch forecast (for simplicity, just feeding a list containing a single sample)
            print('\n\nPerforming batch forecast (notice the output is a list): \n')
            forecast, se = ensemble_loaded.batch_forecast(time_stamps_list=[timestamps], time_series_prev_list=[x_ts],
                                                          expert_idx=None, mode='max', use_gpu=False)
            print('Forecasts \n', forecast)
            print('Standard Errors \n', se)
            1677-10-28 00:00:00  810.339722
            1677-10-28 01:00:00  796.296448
            1677-10-28 02:00:00  760.550842
            Standard Error
                                 H1_err
            1677-10-28 00:00:00   42.765224
            1677-10-28 01:00:00  198.829971
            1677-10-28 02:00:00  350.278381
            Performing batch forecast (notice the output is a list):
            Forecasts
             [                     H1_0
            1677-10-28 00:00:00  810.339722
            1677-10-28 01:00:00  796.296448
            1677-10-28 02:00:00  760.550842]
            Standard Errors
             [                     H1_err_0
            1677-10-28 00:00:00   42.765224
            1677-10-28 01:00:00  198.829971
            1677-10-28 02:00:00  350.278381]
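The `mode='max'` forecast above can be reproduced by hand from the per-expert arrays that `_forecast` prints in the next step. Judging from those printed values, rows are experts and columns are forecast steps. The array labelled "Standard deviation" sums to about 1 in each column, so this sketch treats it as per-step expert weights. That reading is an assumption. So is the confidence-weighted average used for `mode='mean'`, which follows the comment in the evaluation cell ("mean computes the weighted average"). Neither is a documented Merlion internal.

```python
import numpy as np

# Per-expert forecasts (rows = experts, columns = forecast steps) and per-expert
# confidences, copied from the `_forecast` output printed in the next step.
forecast = np.array([
    [828.93043244, 796.29646378, 760.5508621],
    [812.10700752, 760.22152836, 709.16665671],
    [810.33973458, 757.56872591, 700.9503937],
])
confidence = np.array([
    [0.361902, 0.47081903, 0.7306497],
    [0.26711744, 0.3186435, 0.14048512],
    [0.3709806, 0.21053748, 0.12886517],
])

# Each column sums to ~1, so it reads as a distribution over experts per step.
print(confidence.sum(axis=0))

# mode='max': at each step, take the forecast of the most confident expert.
best_expert = confidence.argmax(axis=0)
max_forecast = forecast[best_expert, np.arange(forecast.shape[1])]
print(best_expert)  # [2 0 0]
print(max_forecast)

# mode='mean' (assumed): confidence-weighted average of the expert forecasts.
mean_forecast = (confidence * forecast).sum(axis=0) / confidence.sum(axis=0)
print(mean_forecast)
```

The max rule selects the third expert for the first step and the first expert for the other two. This reproduces the 810.34 / 796.30 / 760.55 forecast printed by `ensemble_loaded.forecast`, up to float32 rounding.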

            Retrieve forecasts of individual experts along with their confidence from the loaded model

            import numpy as np
            # perform a forecast at the start of test_data, with forecast length 3 and lookback 20
            lookback_len = 20
            forecast_len = 3
            sample_length = lookback_len + forecast_len
            target_seq_index = 0
            start_idx = 0
            end_idx = sample_length
            timestamps = test_data.univariates[test_data.names[0]].time_stamps[start_idx:end_idx]
            data = test_data.to_pd().values
            data = data[start_idx:end_idx]
            timestamps = timestamps[lookback_len:]
            x = data[:lookback_len]  # shape (20,1)
            x_ts = test_data[start_idx:start_idx + lookback_len]
            y = data[lookback_len:, target_seq_index]
            print('True output: \n')
            print(y)
            # perform single forecast
            print('Getting individual expert forecast and standard deviation for single data (notice the array shape): \n')
            forecast, std = ensemble_loaded._forecast(time_stamps=timestamps, time_series_prev=x_ts,
                                                      expert_idx=None, use_gpu=False)
            print(f'Forecast (shape: {forecast.shape}) \n', forecast)
            print(f'Standard deviation (shape: {std.shape}) \n', std)
            # perform batch forecast (for simplicity, just feeding a list containing a single sample)
            print('\n\nGetting individual expert forecast and standard deviation for a batch of data (notice the array shape): \n')
            forecast, std = ensemble_loaded._batch_forecast(time_stamps_list=[timestamps],
                                                            time_series_prev_array=np.expand_dims(x, axis=0),  # shape (1,20,1)
                                                            time_series_prev_list=[x_ts], expert_idx=None, use_gpu=False)
            print(f'Forecast (shape: {forecast.shape}) \n', forecast)
            print(f'Standard deviation (shape: {std.shape}) \n', std)
            [803. 769. 751.]
            Getting individual expert forecast and standard deviation for single data (notice the array shape):
            Forecast (shape: (3, 3))
             [[828.93043244 796.29646378 760.5508621 ]
             [812.10700752 760.22152836 709.16665671]
             [810.33973458 757.56872591 700.9503937 ]]
            Standard deviation (shape: (3, 3))
             [[0.361902   0.47081903 0.7306497 ]
             [0.26711744 0.3186435  0.14048512]
             [0.3709806  0.21053748 0.12886517]]
            Getting individual expert forecast and standard deviation for a batch of data (notice the array shape):
            Forecast (shape: (1, 3, 3))
             [[[828.93043244 796.29646378 760.5508621 ]
              [812.10700752 760.22152836 709.16665671]
              [810.33973458 757.56872591 700.9503937 ]]]
            Standard deviation (shape: (1, 3, 3))
             [[[0.361902   0.47081903 0.7306497 ]
              [0.26711744 0.3186435  0.14048512]
              [0.3709806  0.21053748 0.12886517]]]

            Evaluate MoE

            expert_idx = None  # if expert_idx=None, MoE uses all the experts provided and the 'mode' strategy specified below to forecast
                               # if the value is an int (e.g. 0), MoE only uses the external expert at that index of `models` to make forecasts
            mode = 'max'  # either 'mean' or 'max'. Max picks the expert with the highest confidence; mean computes the weighted average.
            use_gpu = False  # set True if a GPU is available, for faster speed
            use_batch_forecast = True  # set True for higher speed
            y_pred_list, std_list, y_list, sMAPE_conf, sMAPE_not_conf, recall, overall_sMAPE = \
                ensemble_loaded.evaluate(test_data, mode=mode, expert_idx=expert_idx,
                                         use_gpu=use_gpu, use_batch_forecast=use_batch_forecast, confidence_thres=100)
            out_idx = 0  # index of the forecast step to plot (0 = first step ahead)
            print(y_pred_list.shape)
            plt.plot(y_pred_list[:100, out_idx], '--', color='k', label='prediction', linewidth=1)  # plotting the first 100 for clarity
            plt.plot(y_list[:100, out_idx], color='b', label='data', linewidth=1)
            # plt.fill_between(range(y_pred_list[:100, out_idx].shape[0]), y_pred_list[:100, out_idx]-std_list[:100, out_idx],
            #                  y_pred_list[:100, out_idx]+std_list[:100, out_idx])  # standard deviation error band
            plt.legend()
            plt.show()
            print(f'sMAPE on confident samples: {sMAPE_conf:.2f}')
            print(f'sMAPE on not confident samples: {sMAPE_not_conf:.2f}')
            print(f'Percentage of samples on which MoE was confident: {recall:.2f}% (use a different confidence_thres to change this)')
            print(f'sMAPE on all samples: {overall_sMAPE:.2f}')
            0%| | 0/1 [00:00<?, ?it/s]
            sMAPE_conf: 2.184 sMAPE_not_conf: 0.000 recall: 100.000% | Plain sMAPE 2.184: 0%| | 0/1 [00:01<?, ?it/s]
            sMAPE_conf: 2.184 sMAPE_not_conf: 0.000 recall: 100.000% | Plain sMAPE 2.184: 100%|██████████| 1/1 [00:01<00:00, 1.71s/it]
            sMAPE on confident samples: 2.18
            sMAPE on not confident samples: 0.00
            Percentage of samples on which MoE was confident: 100.00% (use a different confidence_thres to change this)
            sMAPE on all samples: 2.18
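The evaluation reports sMAPE (symmetric mean absolute percentage error). The sketch below assumes the common definition 100/n * sum(2|F - A| / (|A| + |F|)), which ranges from 0 to 200. Merlion's own implementation may scale or handle zeros differently. `smape` is a hypothetical helper, not a Merlion function. It is applied to the single test window used earlier: the true values [803, 769, 751] and the `mode='max'` forecast. How `confidence_thres` splits samples into confident and not confident is not shown in this notebook, so the sketch does not model it.

```python
import numpy as np


def smape(y_true, y_pred) -> float:
    """Symmetric MAPE in percent (0 to 200), using a common textbook definition."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    denom = np.abs(y_true) + np.abs(y_pred)
    # Treat 0/0 (both values zero) as zero error instead of NaN.
    ratio = np.divide(2.0 * np.abs(y_pred - y_true), denom,
                      out=np.zeros_like(denom), where=denom != 0)
    return 100.0 * ratio.mean()


# True values and mode='max' forecast for the first test window,
# copied from the outputs printed earlier in this notebook.
y_true = [803.0, 769.0, 751.0]
y_pred = [810.339722, 796.296448, 760.550842]
print(f"sMAPE: {smape(y_true, y_pred):.2f}")
```

Because the denominator uses both the actual and the forecast value, over- and under-forecasts by the same amount score almost identically. This makes sMAPE a convenient scale-free summary across the many series in the test set.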

            Create MoE model containing free parameters (no external experts) and train

            Specify hyper-parameters

            import os
            # save directory for ensemble state. Replace it with your own choice.
            save_dir = 'models/moe2'
            nfree_experts = 3000  # <- number of free experts
            lookback_len = 20
            max_forecast_steps = 3
            target_seq_index = 0
            use_gpu = False
            ## Pytorch network hyper-params. These are also the hyper-params that are used in case moe_model=None is passed to MoE_ForecasterEnsemble.
            hidden_dim = 256
            dim_head = 2
            mlp_dim = 256
            dim_dropout = 0.  # if data is multi-dimensional, this can be set to a non-zero value to allow model to handle missing dimensions during test time
            time_step_dropout = 0
            ## Pytorch network hyper-params
            config_ensemble = MoE_ForecasterEnsembleConfig(
                batch_size=64, lr=0.0001, nfree_experts=nfree_experts, epoch_max=300,
                lookback_len=lookback_len, max_forecast_steps=max_forecast_steps,
                target_seq_index=target_seq_index, use_gpu=use_gpu,
                transform=TemporalResample())
            train_config_ensemble = EnsembleTrainConfig(valid_frac=0.5)
            # Define expert models
            models = [] # <- no external experts provided
            nexperts = len(models)
            # instantiate deep network for MoE
            moe_model = TransformerModel(input_dim=len(train_data.names), lookback_len=lookback_len, nexperts=nexperts,
                                         output_dim=max_forecast_steps, nfree_experts=nfree_experts,
                                         hid_dim=hidden_dim, dim_head=dim_head, mlp_dim=mlp_dim,
                                         pool='cls', dim_dropout=dim_dropout,
                                         time_step_dropout=time_step_dropout)
            moe_model = None  # overrides the network above, so MoE_ForecasterEnsemble uses its default network; comment out to use the custom one
            # create MoE forecaster model
            ensemble = MoE_ForecasterEnsemble(config=config_ensemble, models=models, moe_model=moe_model)
            # train & save MoE
            loss_list = ensemble.train(train_data=train_data, train_config=train_config_ensemble)
            ensemble.save(save_dir)
            Epoch 1 Loss: 8.253956: 100%|██████████| 11/11 [00:03<00:00,  3.00it/s]
            Epoch 2 Loss: 8.215341: 100%|██████████| 11/11 [00:03<00:00,  3.09it/s]
            Epoch 3 Loss: 8.140997: 100%|██████████| 11/11 [00:03<00:00,  3.37it/s]
            Epoch 4 Loss: 8.030143: 100%|██████████| 11/11 [00:02<00:00,  3.70it/s]
            Epoch 5 Loss: 7.880345: 100%|██████████| 11/11 [00:03<00:00,  3.66it/s]
            Epoch 6 Loss: 7.679686: 100%|██████████| 11/11 [00:03<00:00,  3.23it/s]
            Epoch 7 Loss: 7.431627: 100%|██████████| 11/11 [00:03<00:00,  3.65it/s]
            Epoch 8 Loss: 7.121155: 100%|██████████| 11/11 [00:03<00:00,  3.66it/s]
            Epoch 9 Loss: 6.788598: 100%|██████████| 11/11 [00:03<00:00,  3.32it/s]
            Epoch 10 Loss: 6.469567: 100%|██████████| 11/11 [00:03<00:00,  2.98it/s]
            Epoch 11 Loss: 6.230416: 100%|██████████| 11/11 [00:03<00:00,  3.48it/s]
            Epoch 12 Loss: 6.048095: 100%|██████████| 11/11 [00:03<00:00,  3.60it/s]
            Epoch 13 Loss: 5.915649: 100%|██████████| 11/11 [00:03<00:00,  2.79it/s]
            Epoch 14 Loss: 5.811184: 100%|██████████| 11/11 [00:04<00:00,  2.69it/s]
            Epoch 15 Loss: 5.744193: 100%|██████████| 11/11 [00:03<00:00,  3.31it/s]
            Epoch 16 Loss: 5.686575: 100%|██████████| 11/11 [00:03<00:00,  3.11it/s]
            Epoch 17 Loss: 5.657169: 100%|██████████| 11/11 [00:03<00:00,  3.31it/s]
            Epoch 18 Loss: 5.635586: 100%|██████████| 11/11 [00:03<00:00,  3.36it/s]
            Epoch 19 Loss: 5.627381: 100%|██████████| 11/11 [00:03<00:00,  3.62it/s]
            Epoch 20 Loss: 5.592227: 100%|██████████| 11/11 [00:03<00:00,  3.46it/s]
            Epoch 21 Loss: 5.565175: 100%|██████████| 11/11 [00:03<00:00,  3.18it/s]
            Epoch 22 Loss: 5.561776: 100%|██████████| 11/11 [00:03<00:00,  3.66it/s]
            Epoch 23 Loss: 5.541681: 100%|██████████| 11/11 [00:03<00:00,  3.61it/s]
            Epoch 24 Loss: 5.545103: 100%|██████████| 11/11 [00:03<00:00,  3.17it/s]
            Epoch 25 Loss: 5.522188: 100%|██████████| 11/11 [00:03<00:00,  3.21it/s]
            Epoch 26 Loss: 5.508581: 100%|██████████| 11/11 [00:03<00:00,  3.52it/s]
            Epoch 27 Loss: 5.489652: 100%|██████████| 11/11 [00:03<00:00,  3.46it/s]
            Epoch 28 Loss: 5.476248: 100%|██████████| 11/11 [00:04<00:00,  2.51it/s]
            Epoch 29 Loss: 5.466399: 100%|██████████| 11/11 [00:03<00:00,  3.08it/s]
            Epoch 30 Loss: 5.475472: 100%|██████████| 11/11 [00:03<00:00,  3.17it/s]
            Epoch 31 Loss: 5.461387: 100%|██████████| 11/11 [00:03<00:00,  2.90it/s]
            Epoch 32 Loss: 5.433480: 100%|██████████| 11/11 [00:03<00:00,  3.43it/s]
            Epoch 33 Loss: 5.417164: 100%|██████████| 11/11 [00:03<00:00,  3.66it/s]
            Epoch 34 Loss: 5.391122: 100%|██████████| 11/11 [00:03<00:00,  3.55it/s]
            Epoch 35 Loss: 5.355259: 100%|██████████| 11/11 [00:03<00:00,  3.24it/s]
            Epoch 36 Loss: 5.324664: 100%|██████████| 11/11 [00:03<00:00,  3.55it/s]
            Epoch 37 Loss: 5.289907: 100%|██████████| 11/11 [00:03<00:00,  3.53it/s]
            Epoch 38 Loss: 5.252770: 100%|██████████| 11/11 [00:03<00:00,  3.27it/s]
            Epoch 39 Loss: 5.224870: 100%|██████████| 11/11 [00:03<00:00,  3.06it/s]
            Epoch 40 Loss: 5.189947: 100%|██████████| 11/11 [00:03<00:00,  3.58it/s]
            Epoch 41 Loss: 5.141784: 100%|██████████| 11/11 [00:03<00:00,  3.42it/s]
            Epoch 42 Loss: 5.089564: 100%|██████████| 11/11 [00:03<00:00,  2.83it/s]
            Epoch 43 Loss: 5.053859: 100%|██████████| 11/11 [00:03<00:00,  2.92it/s]
            Epoch 44 Loss: 5.017187: 100%|██████████| 11/11 [00:03<00:00,  3.16it/s]
            Epoch 45 Loss: 4.980298: 100%|██████████| 11/11 [00:03<00:00,  2.88it/s]
            Epoch 46 Loss: 4.946596: 100%|██████████| 11/11 [00:03<00:00,  2.99it/s]
            Epoch 47 Loss: 4.913864: 100%|██████████| 11/11 [00:03<00:00,  3.65it/s]
            Epoch 48 Loss: 4.886967: 100%|██████████| 11/11 [00:03<00:00,  3.64it/s]
            Epoch 49 Loss: 4.860845: 100%|██████████| 11/11 [00:03<00:00,  3.38it/s]
            Epoch 50 Loss: 4.831104: 100%|██████████| 11/11 [00:03<00:00,  3.44it/s]
            Epoch 51 Loss: 4.816362: 100%|██████████| 11/11 [00:03<00:00,  3.58it/s]
            Epoch 52 Loss: 4.779174: 100%|██████████| 11/11 [00:03<00:00,  3.39it/s]
            Epoch 53 Loss: 4.736116: 100%|██████████| 11/11 [00:03<00:00,  2.90it/s]
            Epoch 54 Loss: 4.717974: 100%|██████████| 11/11 [00:03<00:00,  3.40it/s]
            Epoch 55 Loss: 4.690528: 100%|██████████| 11/11 [00:03<00:00,  3.47it/s]
            Epoch 56 Loss: 4.657373: 100%|██████████| 11/11 [00:04<00:00,  2.68it/s]
            Epoch 57 Loss: 4.630430: 100%|██████████| 11/11 [00:04<00:00,  2.64it/s]
            Epoch 58 Loss: 4.626685: 100%|██████████| 11/11 [00:03<00:00,  3.33it/s]
            Epoch 59 Loss: 4.576031: 100%|██████████| 11/11 [00:03<00:00,  3.12it/s]
            Epoch 60 Loss: 4.536126: 100%|██████████| 11/11 [00:03<00:00,  3.28it/s]
            Epoch 61 Loss: 4.522658: 100%|██████████| 11/11 [00:03<00:00,  3.41it/s]
            Epoch 62 Loss: 4.482119: 100%|██████████| 11/11 [00:02<00:00,  3.70it/s]
            Epoch 63 Loss: 4.440992: 100%|██████████| 11/11 [00:03<00:00,  3.66it/s]
            Epoch 64 Loss: 4.420945: 100%|██████████| 11/11 [00:03<00:00,  3.26it/s]
            Epoch 65 Loss: 4.380464: 100%|██████████| 11/11 [00:03<00:00,  3.59it/s]
            Epoch 66 Loss: 4.354569: 100%|██████████| 11/11 [00:03<00:00,  3.41it/s]
            Epoch 67 Loss: 4.330108: 100%|██████████| 11/11 [00:03<00:00,  3.27it/s]
            Epoch 68 Loss: 4.302790: 100%|██████████| 11/11 [00:03<00:00,  3.11it/s]
            Epoch 69 Loss: 4.276388: 100%|██████████| 11/11 [00:03<00:00,  3.63it/s]
            Epoch 70 Loss: 4.252587: 100%|██████████| 11/11 [00:03<00:00,  3.08it/s]
            Epoch 71 Loss: 4.220399: 100%|██████████| 11/11 [00:03<00:00,  2.80it/s]
            Epoch 72 Loss: 4.197742: 100%|██████████| 11/11 [00:03<00:00,  2.84it/s]
            Epoch 73 Loss: 4.184703: 100%|██████████| 11/11 [00:03<00:00,  3.26it/s]
            Epoch 74 Loss: 4.177047: 100%|██████████| 11/11 [00:03<00:00,  3.37it/s]
            Epoch 75 Loss: 4.132162: 100%|██████████| 11/11 [00:03<00:00,  3.12it/s]
            Epoch 76 Loss: 4.098515: 100%|██████████| 11/11 [00:03<00:00,  3.65it/s]
            Epoch 77 Loss: 4.085912: 100%|██████████| 11/11 [00:03<00:00,  3.63it/s]
            Epoch 78 Loss: 4.072778: 100%|██████████| 11/11 [00:03<00:00,  3.48it/s]
            Epoch 79 Loss: 4.026296: 100%|██████████| 11/11 [00:03<00:00,  3.13it/s]
            Epoch 80 Loss: 4.000065: 100%|██████████| 11/11 [00:03<00:00,  3.65it/s]
            Epoch 81 Loss: 3.971918: 100%|██████████| 11/11 [00:03<00:00,  3.48it/s]
            Epoch 82 Loss: 3.954153: 100%|██████████| 11/11 [00:03<00:00,  3.20it/s]
            Epoch 83 Loss: 3.924814: 100%|██████████| 11/11 [00:03<00:00,  3.21it/s]
            Epoch 84 Loss: 3.895262: 100%|██████████| 11/11 [00:03<00:00,  3.16it/s]
            Epoch 85 Loss: 3.874157: 100%|██████████| 11/11 [00:03<00:00,  3.06it/s]
            Epoch 86 Loss: 3.862910: 100%|██████████| 11/11 [00:03<00:00,  3.02it/s]
            Epoch 87 Loss: 3.845997: 100%|██████████| 11/11 [00:03<00:00,  2.90it/s]
            Epoch 88 Loss: 3.811604: 100%|██████████| 11/11 [00:03<00:00,  3.11it/s]
            Epoch 89 Loss: 3.794181: 100%|██████████| 11/11 [00:03<00:00,  3.51it/s]
            Epoch 90 Loss: 3.760030: 100%|██████████| 11/11 [00:03<00:00,  3.47it/s]
            Epoch 91 Loss: 3.738671: 100%|██████████| 11/11 [00:03<00:00,  3.12it/s]
            Epoch 92 Loss: 3.720268: 100%|██████████| 11/11 [00:03<00:00,  3.55it/s]
            Epoch 93 Loss: 3.713584: 100%|██████████| 11/11 [00:02<00:00,  3.70it/s]
            Epoch 94 Loss: 3.671633: 100%|██████████| 11/11 [00:03<00:00,  3.62it/s]
            Epoch 95 Loss: 3.651652: 100%|██████████| 11/11 [00:03<00:00,  3.31it/s]
            Epoch 96 Loss: 3.639432: 100%|██████████| 11/11 [00:03<00:00,  3.07it/s]
            Epoch 97 Loss: 3.607364: 100%|██████████| 11/11 [00:03<00:00,  3.26it/s]
            Epoch 98 Loss: 3.583107: 100%|██████████| 11/11 [00:03<00:00,  3.33it/s]
            Epoch 99 Loss: 3.562718: 100%|██████████| 11/11 [00:03<00:00,  3.03it/s]
            Epoch 100 Loss: 3.548420: 100%|██████████| 11/11 [00:03<00:00,  2.93it/s]
            Epoch 101 Loss: 3.531171: 100%|██████████| 11/11 [00:04<00:00,  2.66it/s]
            Epoch 102 Loss: 3.509029: 100%|██████████| 11/11 [00:03<00:00,  3.47it/s]
            Epoch 103 Loss: 3.485391: 100%|██████████| 11/11 [00:03<00:00,  3.54it/s]
            Epoch 104 Loss: 3.450888: 100%|██████████| 11/11 [00:03<00:00,  3.43it/s]
            Epoch 105 Loss: 3.424491: 100%|██████████| 11/11 [00:03<00:00,  3.10it/s]
            Epoch 106 Loss: 3.403693: 100%|██████████| 11/11 [00:02<00:00,  3.74it/s]
            Epoch 107 Loss: 3.390665: 100%|██████████| 11/11 [00:03<00:00,  3.58it/s]
            Epoch 108 Loss: 3.359253: 100%|██████████| 11/11 [00:03<00:00,  3.43it/s]
            Epoch 109 Loss: 3.345198: 100%|██████████| 11/11 [00:03<00:00,  2.97it/s]
            Epoch 110 Loss: 3.349600: 100%|██████████| 11/11 [00:03<00:00,  3.52it/s]
            Epoch 111 Loss: 3.329763: 100%|██████████| 11/11 [00:03<00:00,  3.44it/s]
            Epoch 112 Loss: 3.303129: 100%|██████████| 11/11 [00:03<00:00,  3.32it/s]
            Epoch 113 Loss: 3.266704: 100%|██████████| 11/11 [00:03<00:00,  2.82it/s]
            Epoch 114 Loss: 3.240781: 100%|██████████| 11/11 [00:03<00:00,  2.94it/s]
            Epoch 115 Loss: 3.215163: 100%|██████████| 11/11 [00:03<00:00,  3.25it/s]
            Epoch 116 Loss: 3.198615: 100%|██████████| 11/11 [00:03<00:00,  2.92it/s]
            Epoch 117 Loss: 3.169775: 100%|██████████| 11/11 [00:03<00:00,  3.29it/s]
            Epoch 118 Loss: 3.152784: 100%|██████████| 11/11 [00:03<00:00,  3.47it/s]
            Epoch 119 Loss: 3.128009: 100%|██████████| 11/11 [00:03<00:00,  3.60it/s]
            Epoch 120 Loss: 3.106502: 100%|██████████| 11/11 [00:03<00:00,  3.67it/s]
            Epoch 121 Loss: 3.093420: 100%|██████████| 11/11 [00:03<00:00,  3.51it/s]
            Epoch 122 Loss: 3.078810: 100%|██████████| 11/11 [00:03<00:00,  3.31it/s]
            Epoch 123 Loss: 3.050844: 100%|██████████| 11/11 [00:03<00:00,  3.17it/s]
            Epoch 124 Loss: 3.052847: 100%|██████████| 11/11 [00:03<00:00,  3.52it/s]
            Epoch 125 Loss: 3.025279: 100%|██████████| 11/11 [00:03<00:00,  3.37it/s]
            Epoch 126 Loss: 3.009476: 100%|██████████| 11/11 [00:03<00:00,  3.38it/s]
            Epoch 127 Loss: 2.985609: 100%|██████████| 11/11 [00:04<00:00,  2.57it/s]
            Epoch 128 Loss: 2.966175: 100%|██████████| 11/11 [00:03<00:00,  2.95it/s]
            Epoch 129 Loss: 2.948951: 100%|██████████| 11/11 [00:03<00:00,  3.23it/s]
            Epoch 130 Loss: 2.939453: 100%|██████████| 11/11 [00:03<00:00,  3.16it/s]
            Epoch 131 Loss: 2.932430: 100%|██████████| 11/11 [00:03<00:00,  3.57it/s]
            Epoch 132 Loss: 2.893917: 100%|██████████| 11/11 [00:03<00:00,  3.20it/s]
            Epoch 133 Loss: 2.855756: 100%|██████████| 11/11 [00:03<00:00,  3.59it/s]
            Epoch 134 Loss: 2.839368: 100%|██████████| 11/11 [00:02<00:00,  3.70it/s]
            Epoch 135 Loss: 2.817158: 100%|██████████| 11/11 [00:03<00:00,  3.30it/s]
            Epoch 136 Loss: 2.791486: 100%|██████████| 11/11 [00:03<00:00,  3.09it/s]
            Epoch 137 Loss: 2.780610: 100%|██████████| 11/11 [00:03<00:00,  3.17it/s]
            Epoch 138 Loss: 2.761558: 100%|██████████| 11/11 [00:03<00:00,  3.42it/s]
            Epoch 139 Loss: 2.750291: 100%|██████████| 11/11 [00:03<00:00,  3.31it/s]
            Epoch 140 Loss: 2.730247: 100%|██████████| 11/11 [00:03<00:00,  3.28it/s]
            Epoch 141 Loss: 2.698439: 100%|██████████| 11/11 [00:04<00:00,  2.67it/s]
            Epoch 142 Loss: 2.685364: 100%|██████████| 11/11 [00:03<00:00,  3.15it/s]
            Epoch 143 Loss: 2.662411: 100%|██████████| 11/11 [00:03<00:00,  2.98it/s]
            Epoch 144 Loss: 2.658126: 100%|██████████| 11/11 [00:03<00:00,  2.92it/s]
            Epoch 145 Loss: 2.632239: 100%|██████████| 11/11 [00:03<00:00,  3.14it/s]
            Epoch 146 Loss: 2.610601: 100%|██████████| 11/11 [00:03<00:00,  3.34it/s]
            Epoch 147 Loss: 2.583837: 100%|██████████| 11/11 [00:03<00:00,  3.65it/s]
            Epoch 148 Loss: 2.556195: 100%|██████████| 11/11 [00:03<00:00,  3.44it/s]
            Epoch 149 Loss: 2.547484: 100%|██████████| 11/11 [00:03<00:00,  3.17it/s]
            Epoch 150 Loss: 2.525524: 100%|██████████| 11/11 [00:02<00:00,  3.67it/s]
            Epoch 151 Loss: 2.514556: 100%|██████████| 11/11 [00:03<00:00,  3.44it/s]
            Epoch 152 Loss: 2.501829: 100%|██████████| 11/11 [00:03<00:00,  3.11it/s]
            Epoch 153 Loss: 2.478203: 100%|██████████| 11/11 [00:03<00:00,  3.21it/s]
            Epoch 154 Loss: 2.457710: 100%|██████████| 11/11 [00:03<00:00,  3.10it/s]
            Epoch 155 Loss: 2.436375: 100%|██████████| 11/11 [00:03<00:00,  3.06it/s]
            Epoch 156 Loss: 2.412483: 100%|██████████| 11/11 [00:03<00:00,  2.77it/s]
            Epoch 157 Loss: 2.394380: 100%|██████████| 11/11 [00:03<00:00,  3.07it/s]
            Epoch 158 Loss: 2.380202: 100%|██████████| 11/11 [00:03<00:00,  3.58it/s]
            Epoch 159 Loss: 2.355152: 100%|██████████| 11/11 [00:03<00:00,  3.50it/s]
            Epoch 160 Loss: 2.347697: 100%|██████████| 11/11 [00:03<00:00,  3.20it/s]
            Epoch 161 Loss: 2.345931: 100%|██████████| 11/11 [00:03<00:00,  3.60it/s]
            Epoch 162 Loss: 2.296478: 100%|██████████| 11/11 [00:03<00:00,  3.64it/s]
            Epoch 163 Loss: 2.282725: 100%|██████████| 11/11 [00:03<00:00,  3.58it/s]
            Epoch 164 Loss: 2.249418: 100%|██████████| 11/11 [00:03<00:00,  3.25it/s]
            Epoch 165 Loss: 2.241911: 100%|██████████| 11/11 [00:03<00:00,  3.60it/s]
            Epoch 166 Loss: 2.215718: 100%|██████████| 11/11 [00:03<00:00,  3.45it/s]
            Epoch 167 Loss: 2.194424: 100%|██████████| 11/11 [00:03<00:00,  3.08it/s]
            Epoch 168 Loss: 2.184618: 100%|██████████| 11/11 [00:03<00:00,  3.13it/s]
            Epoch 169 Loss: 2.156022: 100%|██████████| 11/11 [00:03<00:00,  3.27it/s]
            Epoch 170 Loss: 2.152164: 100%|██████████| 11/11 [00:03<00:00,  3.05it/s]
            Epoch 171 Loss: 2.128624: 100%|██████████| 11/11 [00:03<00:00,  2.82it/s]
            Epoch 172 Loss: 2.109892: 100%|██████████| 11/11 [00:03<00:00,  2.82it/s]
            Epoch 173 Loss: 2.137191: 100%|██████████| 11/11 [00:02<00:00,  3.74it/s]
            Epoch 174 Loss: 2.118212: 100%|██████████| 11/11 [00:03<00:00,  3.55it/s]
            Epoch 175 Loss: 2.092953: 100%|██████████| 11/11 [00:03<00:00,  3.12it/s]
            Epoch 176 Loss: 2.065081: 100%|██████████| 11/11 [00:03<00:00,  3.61it/s]
            Epoch 177 Loss: 2.064701: 100%|██████████| 11/11 [00:02<00:00,  3.69it/s]
            Epoch 178 Loss: 2.061531: 100%|██████████| 11/11 [00:03<00:00,  3.42it/s]
            Epoch 179 Loss: 2.041025: 100%|██████████| 11/11 [00:03<00:00,  2.90it/s]
            Epoch 180 Loss: 1.999497: 100%|██████████| 11/11 [00:03<00:00,  3.55it/s]
            Epoch 181 Loss: 1.953928: 100%|██████████| 11/11 [00:03<00:00,  3.42it/s]
            Epoch 182 Loss: 1.929946: 100%|██████████| 11/11 [00:03<00:00,  2.86it/s]
            Epoch 183 Loss: 1.923069: 100%|██████████| 11/11 [00:03<00:00,  2.93it/s]
            Epoch 184 Loss: 1.905399: 100%|██████████| 11/11 [00:03<00:00,  3.23it/s]
            Epoch 185 Loss: 1.892242: 100%|██████████| 11/11 [00:03<00:00,  3.06it/s]
            Epoch 186 Loss: 1.869357: 100%|██████████| 11/11 [00:03<00:00,  3.02it/s]
            Epoch 187 Loss: 1.840604: 100%|██████████| 11/11 [00:03<00:00,  3.59it/s]
            Epoch 188 Loss: 1.837614: 100%|██████████| 11/11 [00:02<00:00,  3.68it/s]
            Epoch 189 Loss: 1.816457: 100%|██████████| 11/11 [00:03<00:00,  3.57it/s]
            Epoch 190 Loss: 1.786459: 100%|██████████| 11/11 [00:03<00:00,  3.11it/s]
            Epoch 191 Loss: 1.776885: 100%|██████████| 11/11 [00:02<00:00,  3.79it/s]
            Epoch 192 Loss: 1.766457: 100%|██████████| 11/11 [00:03<00:00,  3.61it/s]
            Epoch 193 Loss: 1.742621: 100%|██████████| 11/11 [00:03<00:00,  3.39it/s]
            Epoch 194 Loss: 1.723335: 100%|██████████| 11/11 [00:03<00:00,  2.79it/s]
            Epoch 195 Loss: 1.708375: 100%|██████████| 11/11 [00:03<00:00,  3.44it/s]
            Epoch 196 Loss: 1.692905: 100%|██████████| 11/11 [00:03<00:00,  3.27it/s]
            Epoch 197 Loss: 1.694218: 100%|██████████| 11/11 [00:03<00:00,  2.89it/s]
            Epoch 198 Loss: 1.677686: 100%|██████████| 11/11 [00:03<00:00,  2.96it/s]
            Epoch 199 Loss: 1.644678: 100%|██████████| 11/11 [00:03<00:00,  3.09it/s]
            Epoch 200 Loss: 1.632020: 100%|██████████| 11/11 [00:03<00:00,  3.16it/s]
            Epoch 201 Loss: 1.604089: 100%|██████████| 11/11 [00:03<00:00,  3.53it/s]
            Epoch 202 Loss: 1.597124: 100%|██████████| 11/11 [00:03<00:00,  3.08it/s]
            Epoch 203 Loss: 1.580226: 100%|██████████| 11/11 [00:02<00:00,  3.74it/s]
            Epoch 204 Loss: 1.577800: 100%|██████████| 11/11 [00:03<00:00,  3.62it/s]
            Epoch 205 Loss: 1.556550: 100%|██████████| 11/11 [00:03<00:00,  3.33it/s]
            Epoch 206 Loss: 1.531670: 100%|██████████| 11/11 [00:03<00:00,  3.20it/s]
            Epoch 207 Loss: 1.524184: 100%|██████████| 11/11 [00:03<00:00,  3.55it/s]
            Epoch 208 Loss: 1.504326: 100%|██████████| 11/11 [00:03<00:00,  3.43it/s]
            Epoch 209 Loss: 1.495385: 100%|██████████| 11/11 [00:03<00:00,  3.06it/s]
            Epoch 210 Loss: 1.477028: 100%|██████████| 11/11 [00:03<00:00,  2.92it/s]
            Epoch 211 Loss: 1.466456: 100%|██████████| 11/11 [00:03<00:00,  3.16it/s]
            Epoch 212 Loss: 1.437730: 100%|██████████| 11/11 [00:03<00:00,  2.97it/s]
            Epoch 213 Loss: 1.439750: 100%|██████████| 11/11 [00:04<00:00,  2.70it/s]
            Epoch 214 Loss: 1.427772: 100%|██████████| 11/11 [00:03<00:00,  3.19it/s]
            Epoch 215 Loss: 1.417080: 100%|██████████| 11/11 [00:03<00:00,  3.54it/s]
            Epoch 216 Loss: 1.403871: 100%|██████████| 11/11 [00:03<00:00,  3.40it/s]
            Epoch 217 Loss: 1.374399: 100%|██████████| 11/11 [00:03<00:00,  3.29it/s]
            Epoch 218 Loss: 1.362701: 100%|██████████| 11/11 [00:02<00:00,  3.83it/s]
            Epoch 219 Loss: 1.348216: 100%|██████████| 11/11 [00:02<00:00,  3.76it/s]
            Epoch 220 Loss: 1.337092: 100%|██████████| 11/11 [00:03<00:00,  3.51it/s]
            Epoch 221 Loss: 1.331288: 100%|██████████| 11/11 [00:03<00:00,  3.04it/s]
            Epoch 222 Loss: 1.322314: 100%|██████████| 11/11 [00:03<00:00,  3.54it/s]
            Epoch 223 Loss: 1.303005: 100%|██████████| 11/11 [00:03<00:00,  3.46it/s]
            Epoch 224 Loss: 1.308330: 100%|██████████| 11/11 [00:03<00:00,  2.86it/s]
            Epoch 225 Loss: 1.307320: 100%|██████████| 11/11 [00:05<00:00,  2.17it/s]
            Epoch 226 Loss: 1.259610: 100%|██████████| 11/11 [00:04<00:00,  2.59it/s]
            Epoch 227 Loss: 1.247879: 100%|██████████| 11/11 [00:04<00:00,  2.44it/s]
            Epoch 228 Loss: 1.232500: 100%|██████████| 11/11 [00:03<00:00,  2.84it/s]
            Epoch 229 Loss: 1.223805: 100%|██████████| 11/11 [00:04<00:00,  2.71it/s]
            Epoch 230 Loss: 1.210760: 100%|██████████| 11/11 [00:03<00:00,  3.23it/s]
            Epoch 231 Loss: 1.193618: 100%|██████████| 11/11 [00:03<00:00,  3.09it/s]
            Epoch 232 Loss: 1.186537: 100%|██████████| 11/11 [00:03<00:00,  3.09it/s]
            Epoch 233 Loss: 1.159970: 100%|██████████| 11/11 [00:03<00:00,  2.82it/s]
            Epoch 234 Loss: 1.156743: 100%|██████████| 11/11 [00:03<00:00,  3.41it/s]
            Epoch 235 Loss: 1.146074: 100%|██████████| 11/11 [00:03<00:00,  2.96it/s]
            Epoch 236 Loss: 1.151413: 100%|██████████| 11/11 [00:03<00:00,  2.90it/s]
            Epoch 237 Loss: 1.131100: 100%|██████████| 11/11 [00:03<00:00,  3.30it/s]
            Epoch 238 Loss: 1.122501: 100%|██████████| 11/11 [00:03<00:00,  3.11it/s]
            Epoch 239 Loss: 1.099850: 100%|██████████| 11/11 [00:03<00:00,  2.81it/s]
            Epoch 240 Loss: 1.086024: 100%|██████████| 11/11 [00:04<00:00,  2.57it/s]
            Epoch 241 Loss: 1.072572: 100%|██████████| 11/11 [00:03<00:00,  2.93it/s]
            Epoch 242 Loss: 1.068453: 100%|██████████| 11/11 [00:03<00:00,  3.37it/s]
            Epoch 243 Loss: 1.044812: 100%|██████████| 11/11 [00:03<00:00,  3.52it/s]
            Epoch 244 Loss: 1.030824: 100%|██████████| 11/11 [00:03<00:00,  3.23it/s]
            Epoch 245 Loss: 1.019708: 100%|██████████| 11/11 [00:03<00:00,  3.34it/s]
            Epoch 246 Loss: 1.011575: 100%|██████████| 11/11 [00:03<00:00,  3.59it/s]
            Epoch 247 Loss: 0.990726: 100%|██████████| 11/11 [00:03<00:00,  3.46it/s]
            Epoch 248 Loss: 0.986041: 100%|██████████| 11/11 [00:03<00:00,  2.91it/s]
            Epoch 249 Loss: 0.981658: 100%|██████████| 11/11 [00:03<00:00,  3.21it/s]
            Epoch 250 Loss: 0.980595: 100%|██████████| 11/11 [00:03<00:00,  3.37it/s]
            Epoch 251 Loss: 0.968010: 100%|██████████| 11/11 [00:03<00:00,  3.13it/s]
            Epoch 252 Loss: 0.962707: 100%|██████████| 11/11 [00:03<00:00,  2.80it/s]
            Epoch 253 Loss: 0.942022: 100%|██████████| 11/11 [00:03<00:00,  3.11it/s]
            Epoch 254 Loss: 0.944430: 100%|██████████| 11/11 [00:03<00:00,  2.89it/s]
            Epoch 255 Loss: 0.929161: 100%|██████████| 11/11 [00:03<00:00,  2.88it/s]
            Epoch 256 Loss: 0.937370: 100%|██████████| 11/11 [00:03<00:00,  2.90it/s]
            Epoch 257 Loss: 0.903115: 100%|██████████| 11/11 [00:03<00:00,  3.54it/s]
            Epoch 258 Loss: 0.883225: 100%|██████████| 11/11 [00:03<00:00,  3.25it/s]
            Epoch 259 Loss: 0.879743: 100%|██████████| 11/11 [00:03<00:00,  2.87it/s]
            Epoch 260 Loss: 0.868025: 100%|██████████| 11/11 [00:03<00:00,  3.35it/s]
            Epoch 261 Loss: 0.871274: 100%|██████████| 11/11 [00:03<00:00,  3.52it/s]
            Epoch 262 Loss: 0.845327: 100%|██████████| 11/11 [00:03<00:00,  3.02it/s]
            Epoch 263 Loss: 0.830346: 100%|██████████| 11/11 [00:03<00:00,  2.75it/s]
            Epoch 264 Loss: 0.824749: 100%|██████████| 11/11 [00:03<00:00,  3.36it/s]
            Epoch 265 Loss: 0.825509: 100%|██████████| 11/11 [00:03<00:00,  3.28it/s]
            Epoch 266 Loss: 0.826433: 100%|██████████| 11/11 [00:04<00:00,  2.64it/s]
            Epoch 267 Loss: 0.803352: 100%|██████████| 11/11 [00:04<00:00,  2.58it/s]
            Epoch 268 Loss: 0.803057: 100%|██████████| 11/11 [00:03<00:00,  3.02it/s]
            Epoch 269 Loss: 0.779101: 100%|██████████| 11/11 [00:03<00:00,  3.00it/s]
            Epoch 270 Loss: 0.776825: 100%|██████████| 11/11 [00:03<00:00,  3.17it/s]
            Epoch 271 Loss: 0.781099: 100%|██████████| 11/11 [00:03<00:00,  3.20it/s]
            Epoch 272 Loss: 0.785840: 100%|██████████| 11/11 [00:03<00:00,  3.28it/s]
            Epoch 273 Loss: 0.802700: 100%|██████████| 11/11 [00:03<00:00,  3.29it/s]
            Epoch 274 Loss: 0.772463: 100%|██████████| 11/11 [00:03<00:00,  3.00it/s]
            Epoch 275 Loss: 0.764557: 100%|██████████| 11/11 [00:03<00:00,  3.41it/s]
            Epoch 276 Loss: 0.756142: 100%|██████████| 11/11 [00:03<00:00,  3.40it/s]
            Epoch 277 Loss: 0.749000: 100%|██████████| 11/11 [00:03<00:00,  3.18it/s]
            Epoch 278 Loss: 0.727899: 100%|██████████| 11/11 [00:03<00:00,  2.91it/s]
            Epoch 279 Loss: 0.736639: 100%|██████████| 11/11 [00:03<00:00,  3.44it/s]
            Epoch 280 Loss: 0.732949: 100%|██████████| 11/11 [00:04<00:00,  2.73it/s]
            Epoch 281 Loss: 0.714376: 100%|██████████| 11/11 [00:03<00:00,  2.79it/s]
            Epoch 282 Loss: 0.706372: 100%|██████████| 11/11 [00:04<00:00,  2.62it/s]
            Epoch 283 Loss: 0.718250: 100%|██████████| 11/11 [00:03<00:00,  3.10it/s]
            Epoch 284 Loss: 0.702688: 100%|██████████| 11/11 [00:03<00:00,  3.32it/s]
            Epoch 285 Loss: 0.683941: 100%|██████████| 11/11 [00:03<00:00,  3.02it/s]
            Epoch 286 Loss: 0.669917: 100%|██████████| 11/11 [00:03<00:00,  3.04it/s]
            Epoch 287 Loss: 0.657017: 100%|██████████| 11/11 [00:03<00:00,  3.44it/s]
            Epoch 288 Loss: 0.645382: 100%|██████████| 11/11 [00:03<00:00,  3.27it/s]
            Epoch 289 Loss: 0.646296: 100%|██████████| 11/11 [00:03<00:00,  3.01it/s]
            Epoch 290 Loss: 0.640179: 100%|██████████| 11/11 [00:03<00:00,  2.87it/s]
            Epoch 291 Loss: 0.632041: 100%|██████████| 11/11 [00:03<00:00,  3.20it/s]
            Epoch 292 Loss: 0.623746: 100%|██████████| 11/11 [00:03<00:00,  3.01it/s]
            Epoch 293 Loss: 0.606923: 100%|██████████| 11/11 [00:04<00:00,  2.67it/s]
            Epoch 294 Loss: 0.608943: 100%|██████████| 11/11 [00:04<00:00,  2.50it/s]
            Epoch 295 Loss: 0.607394: 100%|██████████| 11/11 [00:03<00:00,  2.89it/s]
            Epoch 296 Loss: 0.594561: 100%|██████████| 11/11 [00:04<00:00,  2.60it/s]
            Epoch 297 Loss: 0.581359: 100%|██████████| 11/11 [00:04<00:00,  2.61it/s]
            Epoch 298 Loss: 0.581903: 100%|██████████| 11/11 [00:03<00:00,  3.34it/s]
            Epoch 299 Loss: 0.572175: 100%|██████████| 11/11 [00:03<00:00,  3.24it/s]
            Epoch 300 Loss: 0.570626: 100%|██████████| 11/11 [00:03<00:00,  3.02it/s]
            WARNING:merlion.models.ensemble.base:When initializing an ensemble, you must either provide the dict `model_configs` (mapping each model's name to its config) when creating the `DetectorEnsembleConfig`, or provide a list of `models` to the constructor of `EnsembleBase`. Received both. Overriding `model_configs` with the configs belonging to `models`.
            expert_idx = None
            # When no external experts are used, expert_idx is ignored by the
            # forecast/batch_forecast/evaluate functions.
            mode = 'max'  # 'max' or 'mean': 'max' picks the expert with the highest confidence; 'mean' computes the confidence-weighted average of the experts' forecasts.
            use_gpu = False  # set to True if a GPU is available, for faster evaluation
            use_batch_forecast = True  # batch forecasting is faster
            y_pred_list, std_list, y_list, sMAPE_conf, sMAPE_not_conf, recall, overall_sMAPE = ensemble_loaded.evaluate(
                test_data, mode=mode, expert_idx=expert_idx,
                use_gpu=use_gpu, use_batch_forecast=use_batch_forecast, confidence_thres=1.0)
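The `mode` argument controls how the experts' outputs are merged into a single forecast. As a rough illustration of the two options ('max' and 'mean'), the NumPy sketch below combines per-expert forecasts using per-sample confidences. The function name `combine_experts` and the array shapes are assumptions made for this example; it is not Merlion's internal implementation.

```python
import numpy as np


def combine_experts(forecasts: np.ndarray, confidences: np.ndarray, mode: str = "max") -> np.ndarray:
    """Merge per-expert forecasts into one forecast per sample (illustrative only).

    forecasts:   shape (n_samples, n_experts, horizon), one forecast per expert
    confidences: shape (n_samples, n_experts), non-negative gating weights
    """
    if mode == "max":
        # For each sample, keep only the forecast of the most confident expert.
        best = np.argmax(confidences, axis=1)                  # (n_samples,)
        return forecasts[np.arange(forecasts.shape[0]), best]   # (n_samples, horizon)
    if mode == "mean":
        # Normalize the confidences per sample, then take the weighted average over experts.
        weights = confidences / confidences.sum(axis=1, keepdims=True)
        return np.einsum("se,seh->sh", weights, forecasts)
    raise ValueError("mode must be 'max' or 'mean'")


# Usage: one sample, two experts, a 2-step horizon.
f = np.array([[[1.0, 2.0], [3.0, 4.0]]])
c = np.array([[0.25, 0.75]])
print(combine_experts(f, c, "max"))
print(combine_experts(f, c, "mean"))
```

'max' commits to a single expert per sample, while 'mean' blends all experts, which tends to be smoother when no expert clearly dominates.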
            out_idx = 0  # which forecast step (out of all the steps forecast by MoE) to plot
            print(y_pred_list.shape)
            # plot the first 100 test samples
            plt.plot(y_pred_list[:100, out_idx], '--', color='k', label='prediction', linewidth=1)
            plt.plot(y_list[:100, out_idx], color='b', label='data', linewidth=1)
            # Uncomment to draw a +/- 1 standard deviation error band around the prediction:
            # plt.fill_between(range(y_pred_list[:100, out_idx].shape[0]),
            #                  y_pred_list[:100, out_idx] - std_list[:100, out_idx],
            #                  y_pred_list[:100, out_idx] + std_list[:100, out_idx])
            plt.legend()
            plt.show()
            print(f'sMAPE on confident samples: {sMAPE_conf:.2f}')
            print(f'sMAPE on not confident samples: {sMAPE_not_conf:.2f}')
            print(f'Percentage of samples on which MoE was confident: {recall:.2f}% (use a different confidence_thres to change this)')
            print(f'sMAPE on all samples: {overall_sMAPE:.2f}')
            sMAPE on confident samples: 1.87
            sMAPE on not confident samples: 2.41
            Percentage of samples on which MoE was confident: 21.33% (use a different confidence_thres to change this)
            sMAPE on all samples: 2.30
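The metrics above split the test samples by whether MoE was confident about them. A minimal NumPy sketch of that bookkeeping follows. It uses the common sMAPE definition (the mean of 200 * |y - y_hat| / (|y| + |y_hat|), in percent). The confidence rule is a hypothetical stand-in, since the exact criterion used by `evaluate` is not shown here: a sample counts as confident when its mean relative standard deviation, in percent, falls below `confidence_thres`. The names `smape` and `evaluate_split` are illustrative.

```python
import numpy as np


def smape(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """sMAPE in percent: mean of 200 * |y - y_hat| / (|y| + |y_hat|)."""
    denom = np.abs(y_true) + np.abs(y_pred)
    # Where both values are zero, count the error as 0 instead of dividing by zero.
    ratio = np.divide(np.abs(y_true - y_pred), denom,
                      out=np.zeros_like(denom, dtype=float), where=denom != 0)
    return float(200.0 * ratio.mean())


def evaluate_split(y_true, y_pred, std, confidence_thres=1.0):
    """Return (sMAPE_conf, sMAPE_not_conf, recall_percent, overall_sMAPE).

    All arrays have shape (n_samples, horizon). HYPOTHETICAL confidence rule:
    a sample is confident when its mean relative std (in percent) is below
    confidence_thres.
    """
    rel_std = 100.0 * np.mean(std / np.maximum(np.abs(y_pred), 1e-8), axis=1)
    confident = rel_std < confidence_thres  # boolean mask, one entry per sample

    def subset_smape(mask):
        # NaN for an empty subset, so it is not mistaken for a perfect score.
        return smape(y_true[mask], y_pred[mask]) if mask.any() else float("nan")

    recall = 100.0 * confident.mean()  # percentage of samples deemed confident
    return subset_smape(confident), subset_smape(~confident), recall, smape(y_true, y_pred)


# Usage: two samples with a 2-step horizon; the first is low-variance and exact.
y = np.array([[100.0, 100.0], [100.0, 100.0]])
y_hat = np.array([[100.0, 100.0], [90.0, 110.0]])
s = np.array([[0.5, 0.5], [5.0, 5.0]])
print(evaluate_split(y, y_hat, s, confidence_thres=1.0))
```

Under this assumed rule, lowering `confidence_thres` marks fewer samples as confident, which lowers `recall` and usually tightens the sMAPE on the confident subset.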
            