In this work, we decompose a forecasting model into a backbone and a loss function and wrap them with a BackboneLossModel class. You can use this framework to develop your own backbones or loss functions; the following sections show how.
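The composition can be sketched as follows. This is a minimal illustration of the backbone/loss decomposition, not the repository's actual `BackboneLossModel`; the toy classes, their names, and their signatures are assumptions for the sake of the example.

```python
import torch
import torch.nn as nn

class ToyBackbone(nn.Module):
    """Hypothetical backbone: maps a past window to future latent tokens."""
    def __init__(self, in_len, out_tokens, d_model):
        super().__init__()
        self.proj = nn.Linear(in_len, out_tokens * d_model)
        self.out_tokens, self.d_model = out_tokens, d_model

    def forward(self, x_seq):
        # [batch, in_len] -> [batch, out_tokens, d_model]
        return self.proj(x_seq).view(-1, self.out_tokens, self.d_model)

class ToyLoss(nn.Module):
    """Hypothetical loss: turns latent tokens into a forecast and an MSE loss."""
    def __init__(self, d_model, patch_size):
        super().__init__()
        self.head = nn.Linear(d_model, patch_size)

    def compute_loss(self, target_seq, dec_condition):
        pred = self.head(dec_condition).flatten(1)  # [batch, out_tokens * patch_size]
        return ((pred - target_seq) ** 2).mean()

    def predict(self, dec_condition):
        return self.head(dec_condition).flatten(1)

class BackboneLossModel(nn.Module):
    """Wrapper in the spirit of the repo's class: backbone produces tokens,
    the loss function scores them against the target."""
    def __init__(self, backbone, loss_func):
        super().__init__()
        self.backbone, self.loss_func = backbone, loss_func

    def forward(self, x_seq, target_seq):
        tokens = self.backbone(x_seq)
        return self.loss_func.compute_loss(target_seq, tokens)
```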

## 1. Custom Backbone

Any backbone that takes the past window as input and outputs future latent tokens can work with the MMPD loss. As an example, here we implement a conventional Attention Is All You Need Transformer with patchification and learnable positional embeddings.


*Figure: conventional encoder-decoder Transformer with patchification and learnable positional embeddings.*

The full implementation is in `models/backbones/encoder_decoder_transformer.py`; the key part is shown below:

```python
def forward(self, x_seq, *args, **kwargs):
    ...

    # encoder patchify
    patch_seq = rearrange(x_seq, 'b (n p) -> b n p', p=self.patch_size)
    patch_embed = self.patch_embedding(patch_seq)  # [batch_size, input_patch_num, d_model]

    # position embedding
    pos_idxs = torch.arange(input_patch_num + output_patch_num)[None, :].expand(flatten_batch_size, -1).to(x_seq.device)
    pos_embed = self.position_embedding(pos_idxs)
    in_pos_embed = pos_embed[:, :input_patch_num, :]
    out_pos_embed = pos_embed[:, input_patch_num:, :]

    # add position embedding, pass to the transformer
    enc_in = patch_embed + in_pos_embed
    dec_in = out_pos_embed
    dec_out = self.encoder_decoder(src=enc_in, tgt=dec_in)

    ...
```
1. The input `x_seq` is divided into patches and embedded into `patch_embed`.
2. Positional embeddings for the input and output positions are obtained as `in_pos_embed` and `out_pos_embed`.
3. The sum of `patch_embed` and `in_pos_embed` serves as the input of the encoder, while `out_pos_embed` serves as the input of the decoder. The `encoder_decoder` here is the standard Transformer from `torch.nn.Transformer`.
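The three steps above can be reproduced in a self-contained sketch. Plain `view` stands in for `einops.rearrange`, the sequence lengths follow the example command, and the patch size and model width are assumptions:

```python
import torch
import torch.nn as nn

# Sizes: in_len / out_len match the example command; patch_size and d_model are assumed.
batch, in_len, out_len, patch_size, d_model = 4, 336, 96, 16, 64
in_patches, out_patches = in_len // patch_size, out_len // patch_size  # 21, 6

patch_embedding = nn.Linear(patch_size, d_model)
position_embedding = nn.Embedding(in_patches + out_patches, d_model)
transformer = nn.Transformer(d_model=d_model, nhead=4,
                             num_encoder_layers=1, num_decoder_layers=1,
                             batch_first=True)

x_seq = torch.randn(batch, in_len)

# 1. patchify and embed
patch_seq = x_seq.view(batch, in_patches, patch_size)
patch_embed = patch_embedding(patch_seq)        # [batch, in_patches, d_model]

# 2. positional embeddings for input and output positions
pos_idxs = torch.arange(in_patches + out_patches).expand(batch, -1)
pos_embed = position_embedding(pos_idxs)
in_pos_embed = pos_embed[:, :in_patches]
out_pos_embed = pos_embed[:, in_patches:]

# 3. encoder input = patch embeddings + positions; decoder input = output positions
enc_in = patch_embed + in_pos_embed
dec_in = out_pos_embed
dec_out = transformer(src=enc_in, tgt=dec_in)   # [batch, out_patches, d_model]
```

The decoder output `dec_out` is the stream of future latent tokens that a loss function would consume.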

We register this backbone as `EncoderDecoder` on line 38 of `exp/exp_forecast.py`. Then we can evaluate it with the MMPD loss by running (from the root directory):

```bash
python main_mmpd.py --backbone EncoderDecoder --loss MMPD --data ETTh1 --in_len 336 --out_len 96
```

We get the following results:

| Top-3 MSE | Top-3 MAE | MSE | CRPS |
|-----------|-----------|-----|------|
| 0.332 | 0.373 | 0.377 | 0.290 |

These numbers are very close to those of the decoder-only Transformer used in the MMPD paper, showing that the classical Transformer remains competitive when paired with a proper loss function.

## 2. Custom Loss

Loss functions inherit from `BaseLossFunc` in `models/loss_funcs/base_loss.py`. A loss function should implement two methods:

1. `compute_loss`, which computes the loss between the predicted future tokens and the ground-truth target series.
2. `predict`, which makes the prediction based on the predicted future tokens.

Here we implement an independent Gaussian loss with predicted mean and variance as an example. The full implementation is in `models/loss_funcs/distribution/gaussian_loss.py`; the key parts of the two methods are shown below:

```python
def compute_loss(self, target_seq, dec_condition, *args, **kwargs):
    ...

    mu, sigma = self.net(dec_condition)  # [batch_size, data_dim, seq_len]

    gaussian_distribution = torch.distributions.Normal(loc=mu, scale=sigma)
    negative_log_likelihood = -gaussian_distribution.log_prob(target_seq)
    loss = negative_log_likelihood.sum(dim=-1)

    return loss
```

Here `dec_condition` holds the predicted future tokens from the backbone, and the mean and standard deviation are predicted by a two-head MLP. A Gaussian distribution is then constructed, and its negative log-likelihood is used as the loss.
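A two-head MLP of this kind could look like the sketch below. The repository's actual `self.net` may differ; in particular, using `softplus` plus a small floor to keep the standard deviation strictly positive is an assumption here, not the paper's stated choice.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TwoHeadMLP(nn.Module):
    """Hypothetical two-head projection: a shared trunk, then one head for
    the mean and one for the (positive) standard deviation."""
    def __init__(self, d_model, out_dim, hidden=128):
        super().__init__()
        self.shared = nn.Sequential(nn.Linear(d_model, hidden), nn.ReLU())
        self.mu_head = nn.Linear(hidden, out_dim)
        self.sigma_head = nn.Linear(hidden, out_dim)

    def forward(self, dec_condition):
        h = self.shared(dec_condition)
        mu = self.mu_head(h)
        # softplus keeps sigma > 0 so torch.distributions.Normal accepts it
        sigma = F.softplus(self.sigma_head(h)) + 1e-6
        return mu, sigma
```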

```python
def predict(self, dec_condition, sample_num=1, *args, **kwargs):
    ...
    deterministic_pred = mu[:, :, :self.out_len]

    samples = gaussian_distribution.sample((sample_num,))
    prob_samples = samples.permute(1, 2, 0, 3)[:, :, :, :self.out_len]

    return deterministic_pred, None, prob_samples
```

In the `predict` method, the predicted mean is used as the deterministic prediction, and we sample from the Gaussian distribution to obtain probabilistic samples. Note that this loss does not support multi-mode prediction, so the second return value is set to `None`.
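Given an ensemble of probabilistic samples like the one `predict` returns, CRPS can be estimated empirically via the energy form CRPS ≈ E|X − y| − ½ E|X − X′|. The sketch below is an assumption about how such a metric could be computed from samples, not the repository's actual metric code:

```python
import torch

def crps_from_samples(samples, target):
    """Empirical CRPS estimate from an ensemble.
    samples: [..., sample_num], target: [...] (same leading shape).
    Uses the energy form: E|X - y| - 0.5 * E|X - X'|."""
    # mean absolute error between each sample and the target
    abs_err = (samples - target.unsqueeze(-1)).abs().mean(dim=-1)
    # mean pairwise spread of the ensemble
    spread = (samples.unsqueeze(-1) - samples.unsqueeze(-2)).abs().mean(dim=(-1, -2))
    return abs_err - 0.5 * spread
```

A sharper and better-calibrated ensemble lowers both terms' imbalance, which is why CRPS rewards probabilistic losses like the Gaussian one over a purely deterministic MSE fit.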

We register this loss as `Gauss` on line 43 of `exp/exp_forecast.py`. Then we can evaluate it with the decoder-only Transformer backbone used in the MMPD paper by running (from the root directory):

```bash
python main_mmpd.py --backbone Decoder --loss Gauss --data ETTh1 --in_len 336 --out_len 96
```

Compared with the conventional MSE loss, CRPS improves from 0.316 to 0.300.