Manuel  07/12/2023, 2:20 PM
use_init_models=False when calling the fit method? Thanks

Cristian (Nixtla)  07/12/2023, 2:26 PM

Manuel  07/12/2023, 6:38 PM
I set nf.models[0].max_steps = 100 before calling fit(), but it doesn't work; the fit() method doesn't pick it up.

Cristian (Nixtla)
07/12/2023, 7:52 PM
You can overwrite the trainer_kwargs dictionary. Here is the function I used in my experiments on transfer learning:

import torch
from pytorch_lightning.callbacks import TQDMProgressBar, EarlyStopping

def set_trainer_kwargs(nf, max_steps, early_stop_patience_steps):
    ## Trainer arguments ##
    # Max steps, validation steps and check_val_every_n_epoch
    trainer_kwargs = {**{'max_steps': max_steps}}
    if 'max_epochs' in trainer_kwargs.keys():
        raise Exception('max_epochs is deprecated, use max_steps instead.')

    # Callbacks
    if trainer_kwargs.get('callbacks', None) is None:
        callbacks = [TQDMProgressBar()]
        # Early stopping
        if early_stop_patience_steps > 0:
            callbacks += [EarlyStopping(monitor='ptl/val_loss',
                                        patience=early_stop_patience_steps)]
        trainer_kwargs['callbacks'] = callbacks

    # Add GPU accelerator if available
    if trainer_kwargs.get('accelerator', None) is None:
        if torch.cuda.is_available():
            trainer_kwargs['accelerator'] = "gpu"
    if trainer_kwargs.get('devices', None) is None:
        if torch.cuda.is_available():
            trainer_kwargs['devices'] = -1

    # Avoid saturating local memory, disable fit model checkpoints
    if trainer_kwargs.get('enable_checkpointing', None) is None:
        trainer_kwargs['enable_checkpointing'] = False

    nf.models[0].trainer_kwargs = trainer_kwargs
    nf.models_fitted[0].trainer_kwargs = trainer_kwargs
This function overwrites the trainer_kwargs of both models and models_fitted. It also considers other values of the dictionary, including the early stopping patience and the GPU settings.

Manuel  07/12/2023, 7:55 PM
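A minimal sketch of the overwrite pattern Cristian describes, using SimpleNamespace stand-ins instead of a real NeuralForecast object (the override_max_steps helper and the stand-in nf are illustrative, not part of the library):

```python
from types import SimpleNamespace

# Stand-ins for nf.models / nf.models_fitted: for this sketch each entry
# only needs a mutable trainer_kwargs dict.
model = SimpleNamespace(trainer_kwargs={'max_steps': 1000})
nf = SimpleNamespace(models=[model], models_fitted=[model])

def override_max_steps(nf, max_steps):
    # Mirrors set_trainer_kwargs above: rebuild the dict and assign it to
    # the first model in both lists, so a subsequent fit() would use it.
    trainer_kwargs = {'max_steps': max_steps, 'enable_checkpointing': False}
    nf.models[0].trainer_kwargs = trainer_kwargs
    nf.models_fitted[0].trainer_kwargs = trainer_kwargs

override_max_steps(nf, 100)
print(nf.models[0].trainer_kwargs['max_steps'])  # -> 100
```

The key point is that assigning a new trainer_kwargs dict on the stored model object is what changes the arguments passed to the underlying Trainer; setting max_steps directly on the model, as tried above, is not read back at fit time.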