# neural-forecast
m
To fine-tune an already fitted model, is it enough to pass `use_init_models=False` when calling the fit method? Thanks
c
yes!
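A minimal sketch of what that looks like (assuming `nf` is the already fitted NeuralForecast object and `new_df` is a placeholder for your fine-tuning data):

# Fine-tune: keep the current weights instead of re-initializing the models
nf.fit(df=new_df, use_init_models=False)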
m
@Cristian (Nixtla) If I want to fine-tune the model for a different number of steps than the number used to train the original model, how can I do it? I tried something like `nf.models[0].max_steps = 100` before calling `fit()`, but it doesn't work.
@Cristian (Nixtla) The idea is that I may have trained the original model for 5000 steps, but for fine-tuning I only want to do 100 steps. The complication seems to stem from the fact that the number of steps is not a parameter of the `fit()` method.
c
it doesn't work because, once the model is fitted, the attribute that is actually used is the `trainer_kwargs` dictionary. Here is the function I used in my experiments on transfer learning:
import torch
from pytorch_lightning.callbacks import EarlyStopping, TQDMProgressBar

def set_trainer_kwargs(nf, max_steps, early_stop_patience_steps):
    ## Trainer arguments ##
    # Max steps, validation steps and check_val_every_n_epoch
    trainer_kwargs = {'max_steps': max_steps}

    if 'max_epochs' in trainer_kwargs.keys():
        raise Exception('max_epochs is deprecated, use max_steps instead.')

    # Callbacks
    if trainer_kwargs.get('callbacks', None) is None:
        callbacks = [TQDMProgressBar()]
        # Early stopping
        if early_stop_patience_steps > 0:
            callbacks += [EarlyStopping(monitor='ptl/val_loss',
                                        patience=early_stop_patience_steps)]
        trainer_kwargs['callbacks'] = callbacks

    # Add GPU accelerator and devices if available
    if trainer_kwargs.get('accelerator', None) is None:
        if torch.cuda.is_available():
            trainer_kwargs['accelerator'] = "gpu"
    if trainer_kwargs.get('devices', None) is None:
        if torch.cuda.is_available():
            trainer_kwargs['devices'] = -1

    # Avoid saturating local memory: disable fit model checkpoints
    if trainer_kwargs.get('enable_checkpointing', None) is None:
        trainer_kwargs['enable_checkpointing'] = False

    # Override the trainer settings on both the initial and the fitted model
    nf.models[0].trainer_kwargs = trainer_kwargs
    nf.models_fitted[0].trainer_kwargs = trainer_kwargs
so you simply call this function and it will override the `trainer_kwargs` of both `models` and `models_fitted`. It also considers other values of the dictionary, including the early stopping patience and GPU settings.
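For example, to fine-tune for 100 extra steps after an initial 5000-step training run (a sketch, assuming the imports above, that `nf` was already fitted, and with `finetune_df` as a placeholder for your fine-tuning data):

# Train 100 more steps without early stopping, reusing the fitted weights
set_trainer_kwargs(nf, max_steps=100, early_stop_patience_steps=0)
nf.fit(df=finetune_df, use_init_models=False)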
hope this helps!
m
Thanks!