Stefan Wiegand
08/31/2023, 7:34 PM

Cristian (Nixtla)
08/31/2023, 8:15 PM
cross_validation method. You will first need to set max_steps to 0, so that when it calls the fit method it doesn't run any training iterations. You can do this with the following function:
import torch
from pytorch_lightning.callbacks import EarlyStopping, TQDMProgressBar

def set_trainer_kwargs(nf, max_steps, early_stop_patience_steps):
    ## Trainer arguments ##
    # Max steps, validation steps and check_val_every_n_epoch
    trainer_kwargs = {**{'max_steps': max_steps}}
    if 'max_epochs' in trainer_kwargs.keys():
        raise Exception('max_epochs is deprecated, use max_steps instead.')

    # Callbacks
    if trainer_kwargs.get('callbacks', None) is None:
        callbacks = [TQDMProgressBar()]
        # Early stopping
        if early_stop_patience_steps > 0:
            callbacks += [EarlyStopping(monitor='ptl/val_loss',
                                        patience=early_stop_patience_steps)]
        trainer_kwargs['callbacks'] = callbacks

    # Add GPU accelerator if available
    if trainer_kwargs.get('accelerator', None) is None:
        if torch.cuda.is_available():
            trainer_kwargs['accelerator'] = "gpu"
    if trainer_kwargs.get('devices', None) is None:
        if torch.cuda.is_available():
            trainer_kwargs['devices'] = -1

    # Avoid saturating local memory, disable fit model checkpoints
    if trainer_kwargs.get('enable_checkpointing', None) is None:
        trainer_kwargs['enable_checkpointing'] = False

    # Overwrite the trainer kwargs on the model and its init copy
    nf.models[0].trainer_kwargs = trainer_kwargs
    nf.models_init[0].trainer_kwargs = trainer_kwargs
Then call this function on the nf core object, passing max_steps=0.
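
A minimal usage sketch of that workflow follows; the dataset, the NHITS model, and all hyperparameter values (h, input_size, max_steps, freq, n_windows, step_size) are illustrative assumptions, not part of the original answer. The idea is to train once, zero out max_steps with the helper above, and then call cross_validation so its internal fit call runs no training iterations.

import pandas as pd
from neuralforecast import NeuralForecast
from neuralforecast.models import NHITS

# Illustrative long-format dataframe with columns: unique_id, ds, y
df = pd.read_csv('data.csv')

# Train once with a normal training budget
nf = NeuralForecast(models=[NHITS(h=12, input_size=24, max_steps=100)], freq='M')
nf.fit(df=df)

# Set max_steps=0 so subsequent fit calls run zero training iterations
set_trainer_kwargs(nf, max_steps=0, early_stop_patience_steps=0)

# The fit call inside cross_validation now performs no training
cv_df = nf.cross_validation(df=df, n_windows=3, step_size=12)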
Stefan Wiegand
09/04/2023, 6:41 AM