# general
e
Hello! Question regarding neuralforecast. I'm training a GMM_TFT model. I noticed that one does not pass a validation dataset to the trainer via the fit method. How, then, does one use an early stopping callback based on validation loss?
k
Hi @Eric Braun, if you are not using the NeuralForecast wrapper and only the PyTorch model, these lines do the trick:
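import numpy as np
from pytorch_lightning.callbacks import TQDMProgressBar
from neuralforecast.tsdataset import TimeSeriesDataset, TimeSeriesLoader
# GMM_TFT and GMM also need to be imported from your installed neuralforecast version;
# `data` is assumed to be a dict holding the 'temporal' DataFrame.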

temporal_df = data['temporal'].copy()
temporal_df = temporal_df.groupby('unique_id').head(228).reset_index()
dataset, *_ = TimeSeriesDataset.from_df(df=temporal_df, sort_df=False)

model = GMM_TFT(input_size=12*4,
                hidden_size=64,
                h=12,
                K=2,
                k_cont_cols=['distance_month', 'y_[lag12]'],
                k_cont_inp_size=2,
                batch_size=32, #4,
                windows_batch_size=256,
                check_val_every_n_epoch=1,
                callbacks=TQDMProgressBar(refresh_rate=1, process_position=0),
                loss=GMM(quantiles=np.arange(0,100,5)[1:]/100),
                #loss=GMM(quantiles=np.arange(0,100,2)[1:]/100),
                #learning_rate=0.1,
                learning_rate=5e-3,
                max_epochs=2,
                enable_progress_bar=True)

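# Fit and predict; the validation split comes from val_size, not a separate dataset
model.fit(dataset, val_size=12, test_size=12)
Y_hat = model.predict(dataset=dataset)
# Reshape to (555, horizon, n_quantiles); 555 is specific to this dataset
Y_hat = Y_hat.reshape(555, 12, len(model.loss.quantiles))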
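(Note on the `val_size`/`test_size` arguments above: as far as I understand, the validation and test sets are carved from the end of each series rather than passed in separately. Below is a minimal sketch of that idea in plain Python. `temporal_split` is a hypothetical helper for illustration, not a neuralforecast function, and the library's actual window sampling is more involved.)

```python
from typing import List, Sequence, Tuple


def temporal_split(
    series: Sequence[float], val_size: int, test_size: int
) -> Tuple[List[float], List[float], List[float]]:
    """Split one series chronologically into train / validation / test.

    Assumption: the last `test_size` steps are the test set and the
    `val_size` steps just before them are the validation set.
    """
    n = len(series)
    if val_size < 0 or test_size < 0 or val_size + test_size > n:
        raise ValueError("val_size + test_size must fit inside the series")
    train_end = n - val_size - test_size  # first validation index
    val_end = n - test_size               # first test index
    return (
        list(series[:train_end]),
        list(series[train_end:val_end]),
        list(series[val_end:]),
    )


# 228 time steps per series, matching head(228) in the snippet above.
series = list(range(228))
train, val, test = temporal_split(series, val_size=12, test_size=12)
```

With `val_size=12` and `test_size=12`, the validation loss that an early stopping callback would monitor is computed on those 12 held-out steps before the test block.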
We just added static variables to the model and homogenized the inputs. Here are the GMMTFT documentation and the GMMTFT code link. Let me know how it goes.
Here is a link to the early stopping documentation. You would need to pass the callbacks, including early stopping, as a list. Keep in mind that on M5, if you are working with that dataset, good validation performance does not carry over to test performance as well.
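(For reference, a rough sketch of what a patience-based early stopping rule does with the validation loss. This is plain Python for illustration, not the PyTorch Lightning implementation; `early_stopping_epoch` and the loss values are hypothetical. Lightning's own `EarlyStopping` callback exposes similar knobs: `monitor`, `patience`, `min_delta`, `mode`.)

```python
from typing import Iterable, Optional


def early_stopping_epoch(
    val_losses: Iterable[float], patience: int = 3, min_delta: float = 0.0
) -> Optional[int]:
    """Return the 0-based epoch at which training would stop, or None.

    Training stops once the validation loss has failed to improve by more
    than `min_delta` for `patience` consecutive validation checks.
    """
    best = float("inf")
    bad_checks = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            best = loss          # new best: reset the counter
            bad_checks = 0
        else:
            bad_checks += 1      # no improvement at this check
            if bad_checks >= patience:
                return epoch
    return None


# Hypothetical validation losses, one per epoch (check_val_every_n_epoch=1).
losses = [1.00, 0.80, 0.70, 0.71, 0.72, 0.69, 0.70, 0.71, 0.72]
stop_epoch = early_stopping_epoch(losses, patience=3)
```

Note that `max_epochs=2` in the snippet above would end training before any reasonable patience could trigger, so it would need to be raised for early stopping to matter.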
e
Thank you! I have used the pytorch-lightning early stopping callback before, so if the syntax is standard I shouldn't have any issues. As for my dataset, I'm using a simpler one with a continuous-valued target (drug prices), since that seems more appropriate for the GMM loss. I think validation performance ought to be relevant, but we'll see!
k
@Eric Braun, I would recommend first trying NHITS/NBEATS with MAE, then moving to MQLoss, and only then to GMMTFT, which is still research code.
e
Thanks, I'll do that! Also - it seems the GMM_TFT class needs to be updated with the new parameters for TFT (tested with the documentation AirPassengers example):
k
Thanks @Eric Braun, I created this GitHub issue and will get to it this week.