alpine.trainers package#
Submodules#
alpine.trainers.alpine_base module#
- class alpine.trainers.alpine_base.AlpineBaseModule#
Bases:
Module
- __init__()#
Base class for all Alpine INR models. Each INR model defined in alpine.models inherits AlpineBaseModule.
- compile(optimizer_name='adam', learning_rate=0.0001, scheduler=None)#
Sets up the optimizer and optional learning-rate scheduler.
- Parameters:
optimizer_name (str, optional) – Optimizer name. Defaults to “adam”.
learning_rate (float, optional) – Learning rate. Defaults to 1e-4.
scheduler (torch.optim.lr_scheduler, optional) – PyTorch scheduler object. Defaults to None.
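Example: a minimal sketch of compiling a model before fitting. The Siren model and its constructor arguments are assumptions for illustration; substitute any INR from alpine.models.
from alpine.models import Siren  # hypothetical model choice; any alpine.models INR applies

# Instantiate an INR and attach an Adam optimizer with the default learning rate.
model = Siren(in_features=2, hidden_features=256, hidden_layers=3, out_features=3)
model.compile(optimizer_name="adam", learning_rate=1e-4)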
- fit_signal(*, input: Tensor | None = None, signal: Tensor | dict | None = None, dataloader: DataLoader | None = None, n_iters: int = 1000, closure: callable | None = None, enable_tqdm: bool = True, return_features: bool = False, track_loss_history: bool = False, metric_trackers: dict | None = None, save_best_weights: bool = False, kwargs: dict = {}) dict #
Fits the INR to the given signal, using either input coordinates and a signal tensor/dict or a dataloader.
- Parameters:
input (torch.Tensor, optional) – Input coordinates of shape (B x * x D), where B is the batch size and D is the dimensionality of the input grid.
signal (Union[torch.Tensor, dict], optional) – PyTorch tensor or dictionary containing the signal and auxiliary data. Please use the key 'signal' for the signal or ground-truth measurement. Defaults to None.
dataloader (torch.utils.data.DataLoader, optional) – Input coordinate-signal pytorch dataloader object. Defaults to None.
n_iters (int, optional) – Number of iterations for fitting signal. Defaults to 1000.
closure (callable, optional) – Callable for custom forward propagation. Defaults to None, in which case AlpineBaseModule’s forward propagation with MSE loss is used.
enable_tqdm (bool, optional) – Enables tqdm progress bar. Defaults to True.
return_features (bool, optional) – Return intermediate INR features. Defaults to False.
track_loss_history (bool, optional) – Track loss while fitting the signal. Defaults to False.
metric_trackers (dict, optional) – Dictionary of torchmetrics MetricTracker objects. Defaults to None.
save_best_weights (bool, optional) – Save the best weights encountered during fitting. Defaults to False.
kwargs (dict, optional) – Additional keyword arguments, provided as a dict of dicts. Defaults to {}.
- Returns:
Returns a dictionary containing the output from the INR, the loss, intermediate features if return_features=True, and any other metrics if provided.
- Return type:
dict
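Example: a hedged sketch of fitting a 2D RGB image. The Siren model is again a hypothetical stand-in, and the coordinate layout follows the (B x * x D) convention above.
import torch
from alpine.models import Siren  # hypothetical model choice

model = Siren(in_features=2, hidden_features=256, hidden_layers=3, out_features=3)
model.compile(optimizer_name="adam", learning_rate=1e-4)

# Coordinates: a 256 x 256 grid in [-1, 1]^2, batched as (1, 256*256, 2).
ys, xs = torch.meshgrid(
    torch.linspace(-1, 1, 256), torch.linspace(-1, 1, 256), indexing="ij"
)
coords = torch.stack([xs, ys], dim=-1).reshape(1, -1, 2)
target = torch.rand(1, 256 * 256, 3)  # stand-in for a real RGB signal

outputs = model.fit_signal(
    input=coords,
    signal={"signal": target},  # the 'signal' key is required, per the parameter description above
    n_iters=1000,
    track_loss_history=True,
    save_best_weights=True,
)
# outputs is a dict with the INR output, the loss history, and any tracked metrics.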
- forward(*args, **kwargs)#
Forward pass.
- Raises:
NotImplementedError – Please implement the forward method in your subclass.
- forward_w_features(*args, **kwargs)#
Forward pass with features.
- Raises:
NotImplementedError – Please implement the forward method in your subclass.
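Both hooks are abstract and must be implemented by a subclass. Below is a minimal sketch of such a subclass; the returned dictionary key names ("output", "features") are assumptions for illustration, not a documented contract.
import torch.nn as nn
from alpine.trainers.alpine_base import AlpineBaseModule

class TinyINR(AlpineBaseModule):
    """Toy INR used only to illustrate the two required hooks."""

    def __init__(self, in_features=2, hidden_features=64, out_features=1):
        super().__init__()
        self.body = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(),
        )
        self.head = nn.Linear(hidden_features, out_features)

    def forward(self, coords):
        # Return a dictionary so registered losses and metrics can index into it.
        return {"output": self.head(self.body(coords))}

    def forward_w_features(self, coords):
        features = self.body(coords)
        return {"output": self.head(features), "features": features}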
- register_loss_function(loss_function: callable)#
Registers a loss function with the model. The default loss function for fitting the signal is mean squared error.
- Parameters:
loss_function (callable) – A PyTorch nn.Module instance or a callable that takes two arguments: the model’s output dictionary and the ground-truth signal tensor or dictionary.
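Example: registering a simple L1 loss in place of the default MSE, continuing the compile example above. The "output" and "signal" key names are assumptions consistent with the conventions described in fit_signal.
import torch

def l1_loss(model_output: dict, ground_truth) -> torch.Tensor:
    # Mean absolute error between the INR prediction and the target signal.
    pred = model_output["output"]  # assumed output key
    target = ground_truth["signal"] if isinstance(ground_truth, dict) else ground_truth
    return (pred - target).abs().mean()

model.register_loss_function(l1_loss)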
- render(input, closure=None, return_features=False, use_best_weights=False)#
Renders the model output for the given input. This method is used for inference or evaluation.
- Parameters:
input (torch.Tensor) – Input coordinates of shape (B x * x D), where B is the batch size and D is the dimensionality of the input grid.
closure (callable, optional) – Callable for custom forward propagation. Defaults to None, in which case AlpineBaseModule’s forward propagation with MSE loss is used.
return_features (bool, optional) – Return intermediate INR features. Defaults to False.
use_best_weights (bool, optional) – Use best weights saved during training. Defaults to False.
- Returns:
Returns a dictionary containing the output from the INR, with features if return_features=True.
- Return type:
dict
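Example: evaluating a fitted model on its coordinate grid, continuing the fit_signal sketch above. The "output" key is an assumption; inspect the returned dictionary for the exact keys.
result = model.render(coords, use_best_weights=True)
reconstruction = result["output"]  # assumed key name for the INR output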
alpine.trainers.base module#
alpine.trainers.metalearn module#
- class alpine.trainers.metalearn.MAMLMetaLearner#
Bases:
object
Metalearning is an early experimental feature. We are in the process of integrating torchmeta with Alpine.
- __init__(model, inner_steps, config={}, custom_loss_fn=None, outer_optimizer='adam', inner_loop_loss_fn=None)#
- configure_optimizers()#
- forward(coords, data_packet)#
- get_inr_parameters(copy=True)#
- get_parameters(copy=True)#
- inner_loop(coords, data_packet)#
Learns the INR for inner_steps iterations. Gradients are taken w.r.t. the INR parameters, so regular backpropagation suffices.
- loss_fn_mse(data_packet)#
- mse_loss(x, y)#
- outer_loop(coords, data_packet)#
- render_inner_loop(coords, gt, inner_loop_steps=1)#
- set_parameters(params)#
- squeeze_output(output, gt)#
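Since metalearning is experimental, the sketch below only shows the intended call pattern; the structure of coords and data_packet, and the commented calls, are assumptions rather than a documented contract.
from alpine.trainers.metalearn import MAMLMetaLearner

# model is any Alpine INR (see AlpineBaseModule above).
meta = MAMLMetaLearner(model=model, inner_steps=3, outer_optimizer="adam")
meta.configure_optimizers()

# For each meta-iteration, a task sampler would provide (coords, data_packet):
#   meta.inner_loop(coords, data_packet)   # adapt a copy of the INR for inner_steps
#   meta.outer_loop(coords, data_packet)   # update the shared initialization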
alpine.trainers.pl module#
- class alpine.trainers.pl.LightningTrainer#
Bases:
LightningModule
- __init__(model, dataloader=None, closure=None, return_features=False, log_results=False, is_distributed=False)#
Lightning Trainer class for Alpine.
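Example: wrapping an Alpine INR for logged or distributed training with a standard Lightning Trainer. The dataloader is a placeholder yielding coordinate-signal batches (see train_dataloader below).
import lightning.pytorch as pl
from alpine.trainers.pl import LightningTrainer

# model is an Alpine INR and dataloader yields (coords, signal) batches.
lit_model = LightningTrainer(model, dataloader=dataloader, log_results=True)
trainer = pl.Trainer(max_epochs=100, accelerator="auto")
trainer.fit(lit_model)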
- configure_optimizers()#
Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you’d need one. But in the case of GANs or similar you might have multiple. Optimization with multiple optimizers only works in the manual optimization mode.
- Returns:
Any of the following options.
- Single optimizer.
- List or Tuple of optimizers.
- Two lists - the first list has multiple optimizers, and the second has multiple LR schedulers (or multiple lr_scheduler_config).
- Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.
- None - Fit will run without any optimizer.
The lr_scheduler_config is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.
lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after an optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # `scheduler.step()`. 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to monitor for schedulers like `ReduceLROnPlateau`
    "monitor": "val_loss",
    # If set to `True`, will enforce that the value specified 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if not found. If set to `False`, it will only produce a warning
    "strict": True,
    # If using the `LearningRateMonitor` callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name
    "name": None,
}
When there are schedulers in which the .step() method is conditioned on a value, such as the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Lightning requires that the lr_scheduler_config contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on.
Metrics can be made available to monitor by simply logging them using self.log('metric_to_track', metric_val) in your LightningModule.
Note
Some things to know:
- Lightning calls .backward() and .step() automatically in case of automatic optimization.
- If a learning rate scheduler is specified in configure_optimizers() with key "interval" (default "epoch") in the scheduler configuration, Lightning will call the scheduler’s .step() method automatically in case of automatic optimization.
- If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizer.
- If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.
- If you use multiple optimizers, you will have to switch to ‘manual optimization’ mode and step them yourself.
- If you need to control how often the optimizer steps, override the optimizer_step() hook.
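For reference, a typical implementation returning the dictionary form described above; this is a generic Lightning pattern, not specific to this class, and the monitored metric name is an assumption.
import torch

def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=10)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            "monitor": "train_loss",  # assumes a metric logged via self.log("train_loss", ...)
        },
    }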
- on_test_end()#
Called at the end of testing.
- on_test_epoch_end()#
Called in the test loop at the very end of the epoch.
- set_closure(closure: callable)#
Set closure function.
- Parameters:
closure (function) – Closure function.
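A hypothetical closure sketch follows; the exact arguments Alpine passes to a closure are not documented in this section, so treat the signature and key names purely as illustration.
def my_closure(model, coords, ground_truth):
    # Custom forward propagation: run the INR and return a dict containing the loss.
    outputs = model(coords)
    loss = ((outputs["output"] - ground_truth) ** 2).mean()
    return {"loss": loss, **outputs}

lit_model.set_closure(my_closure)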
- test_step(batch, batch_idx)#
Test step.
- Parameters:
batch (Any) – Batch of coordinate-signal data from the test dataloader.
batch_idx (int) – Index of the current batch.
- train_dataloader()#
An iterable or collection of iterables specifying training samples.
For more information about multiple dataloaders, see this section.
The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.
For data processing use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
Warning
Do not assign state in prepare_data.
These hooks run before the dataloader is requested:
- fit()
- prepare_data()
- setup()
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- training_step(batch, batch_idx)#
Training step.
- Parameters:
batch (tuple) – Coordinate input data of shape (B x * ... * x D)
batch_idx (int) – Index of the current batch.
- Returns:
Loss value.
- Return type:
torch.Tensor