Slackbot
10/31/2022, 4:51 PM
Kin Gtz. Olivares
10/31/2022, 4:59 PM
from pytorch_lightning.callbacks import TQDMProgressBar
from neuralforecast.tsdataset import TimeSeriesDataset, TimeSeriesLoader

# Fit GMM_TFT on a truncated copy of the temporal data, then predict and
# reshape the forecasts to (n_series, horizon, n_quantile_outputs).

HORIZON = 12  # forecast horizon `h`; also sizes the val/test splits below

temporal_df = data['temporal'].copy()
# Keep only the first 228 rows of every series.
# NOTE(review): the meaning of 228 is not visible here — confirm.
# NOTE(review): reset_index() without drop=True keeps the old index as a
# column. If 'unique_id' was already a column, this adds an 'index' column
# that from_df may pick up as a temporal feature — confirm this is intended.
temporal_df = temporal_df.groupby('unique_id').head(228).reset_index()
dataset, *_ = TimeSeriesDataset.from_df(df=temporal_df, sort_df=False)

# Quantile levels 0.05, 0.10, ..., 0.95.
quantiles = np.arange(0, 100, 5)[1:] / 100

model = GMM_TFT(
    input_size=12 * 4,
    hidden_size=64,
    h=HORIZON,
    K=2,
    k_cont_cols=['distance_month', 'y_[lag12]'],
    k_cont_inp_size=2,
    batch_size=32,
    windows_batch_size=256,
    check_val_every_n_epoch=1,
    callbacks=TQDMProgressBar(refresh_rate=1, process_position=0),
    loss=GMM(quantiles=quantiles),
    learning_rate=5e-3,
    max_epochs=2,
    enable_progress_bar=True,
)

# Fit and predict
model.fit(dataset, val_size=HORIZON, test_size=HORIZON)
Y_hat = model.predict(dataset=dataset)

# Derive the series count from the data rather than hard-coding it (was 555),
# so the reshape keeps working when the input data changes. The reshape still
# fails loudly if predict() does not return n_series * HORIZON rows.
n_series = temporal_df['unique_id'].nunique()
Y_hat = Y_hat.reshape(n_series, HORIZON, len(model.loss.quantiles))
Kin Gtz. Olivares
10/31/2022, 5:00 PM
Kin Gtz. Olivares
10/31/2022, 5:12 PM
Eric Braun
10/31/2022, 5:15 PM
Kin Gtz. Olivares
10/31/2022, 5:46 PM
Eric Braun
10/31/2022, 5:59 PM
Kin Gtz. Olivares
10/31/2022, 7:09 PM