Slackbot
10/13/2023, 8:55 PM

Cristian (Nixtla)
10/15/2023, 5:28 PM
nf.models[0].early_stop_patience_steps = 1
won't work, because the patience is an argument of the early-stopping callback in the Trainer's callbacks, so changing the attribute afterwards has no effect. The function in that post should still work. Is it giving an error?

Phil
10/16/2023, 9:12 PM
```python
from typing import Optional

import torch
from neuralforecast import NeuralForecast
from pytorch_lightning.callbacks import EarlyStopping, TQDMProgressBar


def set_trainer_kwargs(
    nf: NeuralForecast,
    max_steps: int,
    early_stop_patience_steps: int,
    val_check_steps: Optional[int] = None,
) -> None:
    """Set trainer arguments for fine-tuning a pre-trained NeuralForecast model.

    Args:
        nf: A pre-trained NeuralForecast model.
        max_steps: The maximum number of training steps.
        early_stop_patience_steps: Patience for early stopping (0 to disable).
        val_check_steps: The frequency of validation checks during training.

    Returns:
        None

    Example usage:
        trained_model_path = "./results/12315464155/"
        nf = load_neural_forecast_model(model_path=trained_model_path)
        set_trainer_kwargs(nf=nf, max_steps=1000, early_stop_patience_steps=3, val_check_steps=35)
        nf.fit(df=new_df, use_init_models=False, val_size=nf.models[0].h)
    """
    # Trainer arguments.
    trainer_kwargs = {
        # The maximum number of training steps.
        "max_steps": max_steps,
        # Display a progress bar during training.
        "callbacks": [TQDMProgressBar()],
        # Use GPU if available, or "auto" to decide automatically.
        "accelerator": "gpu" if torch.cuda.is_available() else "auto",
        # Use all GPUs if available, or a single device if not.
        "devices": -1 if torch.cuda.is_available() else 1,
        # Disable model checkpointing.
        "enable_checkpointing": False,
    }

    # Early stopping callback.
    # Stop training if the validation loss doesn't improve for `patience`
    # consecutive validation checks (not training steps).
    if early_stop_patience_steps > 0:
        trainer_kwargs["callbacks"].append(
            EarlyStopping(monitor="ptl/val_loss", patience=early_stop_patience_steps)
        )

    # Set custom validation check frequency on the model itself; the Trainer
    # has no `val_check_steps` argument, so it can't go in trainer_kwargs.
    if val_check_steps:
        nf.models[0].val_check_steps = val_check_steps

    # Update trainer arguments for the model and its initialization.
    nf.models[0].trainer_kwargs = trainer_kwargs
    nf.models_init[0].trainer_kwargs = trainer_kwargs
```
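Cristian's point, that assigning nf.models[0].early_stop_patience_steps after construction has no effect, can be illustrated with a toy model. This is a minimal sketch: ToyModel and EarlyStoppingStub are hypothetical stand-ins, not NeuralForecast or Lightning classes, and it assumes (in line with Cristian's reply) that the patience is turned into an early-stopping callback once, when the model is constructed. Rebuilding the callbacks explicitly, as the function above does via trainer_kwargs, is what actually takes effect.

```python
from typing import Any, Dict


class EarlyStoppingStub:
    """Hypothetical stand-in for an early-stopping callback."""

    def __init__(self, patience: int) -> None:
        self.patience = patience


class ToyModel:
    """Hypothetical model that builds its trainer callbacks once, in __init__."""

    def __init__(self, early_stop_patience_steps: int = -1) -> None:
        self.early_stop_patience_steps = early_stop_patience_steps
        self.trainer_kwargs: Dict[str, Any] = {"callbacks": []}
        # The patience is consumed here, at construction time only.
        if early_stop_patience_steps > 0:
            self.trainer_kwargs["callbacks"].append(
                EarlyStoppingStub(patience=early_stop_patience_steps)
            )


model = ToyModel()  # built without early stopping
model.early_stop_patience_steps = 1  # too late: no callback gets created
print(len(model.trainer_kwargs["callbacks"]))  # → 0

# Replacing the callbacks list directly is what actually changes behavior.
model.trainer_kwargs["callbacks"] = [EarlyStoppingStub(patience=1)]
print(model.trainer_kwargs["callbacks"][0].patience)  # → 1
```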
Phil
10/16/2023, 9:13 PM
If I put val_check_steps inside the trainer_kwargs, it throws an error.

Phil
10/16/2023, 9:13 PM
nf.models[0].val_check_steps = val_check_steps
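Phil's error is consistent with how keyword arguments are forwarded, assuming the model passes its trainer_kwargs straight to the trainer as pl.Trainer(**trainer_kwargs): any key the constructor doesn't define raises a TypeError. A minimal sketch with a hypothetical TrainerStub (not the real pytorch_lightning.Trainer, although max_steps and val_check_interval are real Trainer parameter names; Lightning calls its analogous setting val_check_interval):

```python
from typing import Any, Dict, List, Optional


class TrainerStub:
    """Hypothetical stand-in for a trainer; accepts only the keywords it defines."""

    def __init__(
        self,
        max_steps: int = -1,
        val_check_interval: Optional[int] = None,
        callbacks: Optional[List[Any]] = None,
    ) -> None:
        self.max_steps = max_steps
        self.val_check_interval = val_check_interval
        self.callbacks = callbacks or []


def build_trainer(trainer_kwargs: Dict[str, Any]) -> TrainerStub:
    # Mimics forwarding the whole dict as keyword arguments.
    return TrainerStub(**trainer_kwargs)


error_message = None
try:
    # val_check_steps is a model-level setting, not a trainer keyword.
    build_trainer({"max_steps": 1000, "val_check_steps": 35})
except TypeError as err:
    error_message = str(err)

print(error_message)  # mentions the unexpected keyword 'val_check_steps'

# The trainer-level name for a similar setting is accepted.
trainer = build_trainer({"max_steps": 1000, "val_check_interval": 35})
print(trainer.val_check_interval)  # → 35
```

This is why setting val_check_steps on the model, rather than in trainer_kwargs, avoids the error.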