In this work, we decompose a forecasting model into a backbone and a loss function and wrap them in a `BackboneLossModel` class. You can use this framework to develop your own backbones or loss functions. The following sections show how to do that.
Any backbone that takes the past window as input and outputs future latent tokens can work with the MMPD loss. Here we implement a conventional "Attention Is All You Need" Transformer with patchify and a learnable positional embedding as an example.

Conventional encoder-decoder Transformer with patchify and learnable positional embedding.
The full implementation is in `models/backbones/encoder_decoder_transformer.py`; the key part is shown below:
```python
def forward(self, x_seq, *args, **kwargs):
    ...
    # encoder patchify
    patch_seq = rearrange(x_seq, 'b (n p) -> b n p', p=self.patch_size)
    patch_embed = self.patch_embedding(patch_seq)  # [batch_size, input_patch_num, d_model]
    # position embedding
    pos_idxs = torch.arange(input_patch_num + output_patch_num)[None, :].expand(flatten_batch_size, -1).to(x_seq.device)
    pos_embed = self.position_embedding(pos_idxs)
    in_pos_embed = pos_embed[:, :input_patch_num, :]
    out_pos_embed = pos_embed[:, input_patch_num:, :]
    # add position embedding, pass to transformer
    enc_in = patch_embed + in_pos_embed
    dec_in = out_pos_embed
    dec_out = self.encoder_decoder(src=enc_in, tgt=dec_in)
    ...
```
- The input `x_seq` is divided into patches and embedded into `patch_embed`.
- Positional embeddings for the input and output are obtained as `in_pos_embed` and `out_pos_embed`.
- The sum of `patch_embed` and `in_pos_embed` serves as the input of the encoder, while `out_pos_embed` serves as the input of the decoder. The `encoder_decoder` here is the standard Transformer from `torch.nn.Transformer`.
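To make the data flow concrete, here is a self-contained sketch of the patchify + positional-embedding pipeline described above. All dimensions and layer choices are illustrative (not the repo's actual hyperparameters), and plain `reshape` stands in for the einops `rearrange`:

```python
import torch
import torch.nn as nn

# Illustrative sizes: 336-step input, 96-step output, 16-step patches.
batch_size, seq_len, out_len, patch_size, d_model = 2, 336, 96, 16, 64
input_patch_num = seq_len // patch_size   # 21 input patches
output_patch_num = out_len // patch_size  # 6 output patches

x_seq = torch.randn(batch_size, seq_len)

# Patchify: (b, n*p) -> (b, n, p), same effect as rearrange('b (n p) -> b n p')
patch_seq = x_seq.reshape(batch_size, input_patch_num, patch_size)

patch_embedding = nn.Linear(patch_size, d_model)
position_embedding = nn.Embedding(input_patch_num + output_patch_num, d_model)

patch_embed = patch_embedding(patch_seq)  # (b, n_in, d_model)
pos_idxs = torch.arange(input_patch_num + output_patch_num).expand(batch_size, -1)
pos_embed = position_embedding(pos_idxs)

# Encoder input: patch embeddings + input positions; decoder input: output positions.
enc_in = patch_embed + pos_embed[:, :input_patch_num, :]
dec_in = pos_embed[:, input_patch_num:, :]

transformer = nn.Transformer(d_model=d_model, nhead=4,
                             num_encoder_layers=1, num_decoder_layers=1,
                             batch_first=True)
dec_out = transformer(src=enc_in, tgt=dec_in)
print(dec_out.shape)  # torch.Size([2, 6, 64]): one latent token per future patch
```

The decoder output provides one latent token per future patch, which is exactly what a token-conditioned loss such as MMPD consumes.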
We register this backbone as `EncoderDecoder` at line 38 of `exp/exp_forecast.py`. We can then evaluate it with the MMPD loss from the root directory:
```shell
python main_mmpd.py --backbone EncoderDecoder --loss MMPD --data ETTh1 --in_len 336 --out_len 96
```
We get the following results:
| Top-3 MSE | Top-3 MAE | MSE | CRPS |
|---|---|---|---|
| 0.332 | 0.373 | 0.377 | 0.290 |
which is very close to the decoder-only Transformer used in the MMPD paper, showing that the classical Transformer is still competitive when paired with a proper loss function.
The loss functions inherit from `BaseLossFunc` in `models/loss_funcs/base_loss.py`. A loss function should implement two methods:
- `compute_loss`, which computes the loss between the predicted future tokens and the ground-truth target series.
- `predict`, which makes the prediction based on the predicted future tokens.
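As a hypothetical illustration of this two-method contract, here is a minimal Laplace loss. The class name, `nn.Module` base (standing in for `BaseLossFunc`), and signatures are assumptions for the sketch; the real base class in `models/loss_funcs/base_loss.py` may differ:

```python
import torch
import torch.nn as nn

class LaplaceLoss(nn.Module):  # stand-in for BaseLossFunc
    """Illustrative loss: predicts a Laplace distribution per time step."""

    def __init__(self, d_model, out_len):
        super().__init__()
        self.out_len = out_len
        self.head = nn.Linear(d_model, 2)  # predicts (mu, log_b) per step

    def compute_loss(self, target_seq, dec_condition, *args, **kwargs):
        # dec_condition: predicted future tokens, (..., seq_len, d_model)
        mu, log_b = self.head(dec_condition).unbind(dim=-1)
        dist = torch.distributions.Laplace(loc=mu, scale=log_b.exp())
        return -dist.log_prob(target_seq).sum(dim=-1)  # NLL per series

    def predict(self, dec_condition, sample_num=1, *args, **kwargs):
        mu, log_b = self.head(dec_condition).unbind(dim=-1)
        dist = torch.distributions.Laplace(loc=mu, scale=log_b.exp())
        deterministic = mu[..., :self.out_len]
        samples = dist.sample((sample_num,))[..., :self.out_len]
        # No multi-mode prediction here, so the second return value is None.
        return deterministic, None, samples
```

Any distribution with a tractable `log_prob` and `sample` can be slotted in the same way.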
Here we implement an independent Gaussian loss with predicted mean and variance as an example. The full implementation is in `models/loss_funcs/distribution/gaussian_loss.py`; the key parts of the two methods are shown below:
```python
def compute_loss(self, target_seq, dec_condition, *args, **kwargs):
    ...
    mu, sigma = self.net(dec_condition)  # [batch_size, data_dim, seq_len]
    gaussian_distribution = torch.distributions.Normal(loc=mu, scale=sigma)
    negative_log_likelihood = -gaussian_distribution.log_prob(target_seq)
    loss = negative_log_likelihood.sum(dim=-1)
    return loss
```
Here `dec_condition` is the predicted future tokens from the backbone; the mean and standard deviation are predicted by a two-head MLP. A Gaussian distribution is then constructed, and the negative log-likelihood is used as the loss.
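As a quick sanity check on this loss, `Normal.log_prob` matches the closed-form Gaussian log-density, so the per-step NLL can be verified by hand (values here are arbitrary):

```python
import math
import torch

# Per-step Gaussian NLL: -log N(x; mu, sigma)
#   = log(sigma * sqrt(2*pi)) + (x - mu)^2 / (2 * sigma^2)
mu, sigma, x = torch.tensor(0.5), torch.tensor(2.0), torch.tensor(1.0)
nll = -torch.distributions.Normal(loc=mu, scale=sigma).log_prob(x)
closed_form = math.log(sigma * math.sqrt(2 * math.pi)) + (x - mu) ** 2 / (2 * sigma ** 2)
print(torch.isclose(nll, closed_form))  # tensor(True)
```

Minimizing this NLL jointly fits the mean (like MSE, scaled by the predicted variance) and calibrates the predicted uncertainty.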
```python
def predict(self, dec_condition, sample_num=1, *args, **kwargs):
    ...
    deterministic_pred = mu[:, :, :self.out_len]
    samples = gaussian_distribution.sample((sample_num,))
    prob_samples = samples.permute(1, 2, 0, 3)[:, :, :, :self.out_len]
    return deterministic_pred, None, prob_samples
```
In the `predict` function, the predicted mean is used as the deterministic prediction, and we sample from the Gaussian distribution to obtain probabilistic samples. Note that this loss does not support multi-mode prediction, so the second return value is set to `None`.
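The shape bookkeeping in `prob_samples` is easy to get wrong, so here is a standalone check with illustrative sizes: `sample((sample_num,))` prepends a sample axis, which the `permute` moves to the third position before truncating to `out_len`:

```python
import torch

batch, dim, seq_len, sample_num, out_len = 2, 7, 120, 5, 96
mu = torch.zeros(batch, dim, seq_len)
dist = torch.distributions.Normal(loc=mu, scale=torch.ones_like(mu))

samples = dist.sample((sample_num,))  # (sample_num, batch, dim, seq_len) = (5, 2, 7, 120)
prob_samples = samples.permute(1, 2, 0, 3)[:, :, :, :out_len]
print(prob_samples.shape)  # torch.Size([2, 7, 5, 96])
```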
We register this loss as `Gauss` at line 43 of `exp/exp_forecast.py`. We can then evaluate it with the decoder-only Transformer backbone used in the MMPD paper from the root directory:
```shell
python main_mmpd.py --backbone Decoder --loss Gauss --data ETTh1 --in_len 336 --out_len 96
```
Compared with the conventional MSE loss, the CRPS improves from 0.316 to 0.300.