# neural-forecast
Hi @Phil!
`nf.models[0].early_stop_patience_steps = 1` won't work, because early stopping is configured through the `callbacks` argument of the Trainer. The function in that post should still work; is it giving an error?
Hi Cristian, sorry for the delay in my response. It's been a chaotic morning at LinkedIn. I managed to make it work by adapting the function above to this:
from typing import Optional

import torch
from neuralforecast import NeuralForecast
from pytorch_lightning.callbacks import EarlyStopping, TQDMProgressBar


def set_trainer_kwargs(
    nf: NeuralForecast,
    max_steps: int,
    early_stop_patience_steps: int,
    val_check_steps: Optional[int] = None,
) -> None:
    """Set trainer arguments for fine-tuning a pre-trained NeuralForecast model.

    Args:
        nf: A pre-trained NeuralForecast model.
        max_steps: The maximum number of training steps.
        early_stop_patience_steps: Patience for early stopping (0 to disable).
        val_check_steps: The frequency of validation checks during training.

    Returns:
        None

    Example usage:
        trained_model_path = "./results/12315464155/"
        nf = load_neural_forecast_model(model_path=trained_model_path)
        set_trainer_kwargs(nf=nf, max_steps=1000, early_stop_patience_steps=3, val_check_steps=35)
        nf.fit(df=new_df, use_init_models=False, val_size=nf.models[0].h)
    """
    # Trainer arguments.
    trainer_kwargs = {
        # The maximum number of training steps.
        "max_steps": max_steps,
        # Display a progress bar during training.
        "callbacks": [TQDMProgressBar()],  
        # Use GPU if available, or "auto" to decide automatically.
        "accelerator": "gpu" if torch.cuda.is_available() else "auto",  
        # Use all GPUs if available, or 1 CPU if not.
        "devices": -1 if torch.cuda.is_available() else 1, 
        # Disable model checkpointing.
        "enable_checkpointing": False,
    }

    # Early stopping callback.
    # Stop training early if validation loss doesn't improve for 'patience' steps.
    if early_stop_patience_steps > 0:
        trainer_kwargs["callbacks"].append(
            EarlyStopping(monitor="ptl/val_loss", patience=early_stop_patience_steps)
        ) 
    # Set custom validation check frequency.
    if val_check_steps:
        nf.models[0].val_check_steps = val_check_steps
    
    # Update trainer arguments for the model and its initialization.
    nf.models[0].trainer_kwargs = trainer_kwargs
    nf.models_init[0].trainer_kwargs = trainer_kwargs
If I put `val_check_steps` inside the `trainer_kwargs`, it throws an error. I had to set it directly on the model instead, as the code above shows:
nf.models[0].val_check_steps = val_check_steps
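My guess at the reason: `trainer_kwargs` are forwarded to the PyTorch Lightning `Trainer`, which has no `val_check_steps` parameter, while the NeuralForecast model object carries that setting itself. A hypothetical stand-in illustrating the difference (these classes are not the real library internals):

```python
class Trainer:
    """Stand-in for pytorch_lightning.Trainer: accepts a fixed set of
    keyword arguments and rejects anything else."""
    def __init__(self, max_steps=1000, callbacks=None):
        self.max_steps = max_steps
        self.callbacks = callbacks or []


class Model:
    """Stand-in for a NeuralForecast model: val_check_steps lives here."""
    def __init__(self, val_check_steps=100):
        self.val_check_steps = val_check_steps


model = Model()
model.val_check_steps = 35  # plain attribute assignment, always works

try:
    # Unknown keyword argument: Python raises TypeError immediately.
    Trainer(max_steps=500, val_check_steps=35)
except TypeError:
    pass
```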