n2j.trainer#

Class that manages model training

Module Contents#

Classes#

Trainer

Functions#

is_decreasing(arr)

Returns True if the array ever decreased

n2j.trainer.is_decreasing(arr)[source]#

Returns True if the array ever decreased
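The implementation is not shown on this page; a minimal sketch consistent with the docstring (the NumPy-based body below is an assumption, not the library's actual code) might be:

```python
import numpy as np

def is_decreasing(arr):
    """Return True if the array ever decreased, i.e. if any
    consecutive difference is negative (assumed semantics)."""
    arr = np.asarray(arr)
    return bool(np.any(np.diff(arr) < 0))
```

For example, `is_decreasing([3, 2, 5])` returns `True` (the value drops from 3 to 2), while `is_decreasing([1, 2, 3])` returns `False`. A monotonicity check like this is a natural building block for loss-based early stopping.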

class n2j.trainer.Trainer(device_type, checkpoint_dir='trained_models', seed=123)[source]#
seed_everything()[source]#

Seed the training and sampling for reproducibility

load_dataset(data_kwargs, is_train, batch_size, sub_features=None, sub_target=None, sub_target_local=None, rebin=False, num_workers=2, noise_kwargs={'mag': {'override_kwargs': None, 'depth': 5}}, detection_kwargs={})[source]#

Load dataset and dataloader for training or validation

Note#

Must be called on the training data first, so that the normalization statistics it sets are reused for both training and validation.

configure_model(model_name, model_kwargs={})[source]#
load_state(state_path)[source]#

Load the state dict from a past training run

Parameters#

state_path : str or osp.object

Path of the state dict to load

save_state(train_loss, val_loss)[source]#

Save the state dict of the current training to disk

Parameters#

train_loss : float

Current training loss

val_loss : float

Current validation loss

configure_optim(early_stop_memory=20, weight_local_loss=0.1, optim_kwargs={}, lr_scheduler_kwargs={'factor': 0.5, 'min_lr': 1e-07})[source]#

Configure optimization-related objects
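Judging by its name, `early_stop_memory` likely bounds how many recent validation losses are examined for improvement before training halts. A hedged sketch of such a stopping rule (the rule itself is an assumption for illustration, not the library's code) could use a monotonicity check like `is_decreasing` above:

```python
import numpy as np

def should_stop(val_losses, memory=20):
    """Stop when the last `memory` validation losses never decreased,
    i.e. no improvement occurred within the memory window (assumed rule)."""
    if len(val_losses) < memory:
        # Not enough history yet to judge stagnation
        return False
    recent = np.asarray(val_losses[-memory:])
    # Keep training only if the loss decreased at some point in the window
    return not bool(np.any(np.diff(recent) < 0))
```

With `memory=3`, a run whose last three losses are still falling (`[3, 2, 1]`) keeps going, while one whose last three losses rose (`[3, 4, 5]`) stops.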

train_single_epoch(epoch_i)[source]#
train(n_epochs, sample_kwargs={})[source]#
infer(epoch_i)[source]#
_log_kappa_recovery_flow(epoch_i, x, u, y)[source]#
_log_kappa_recovery(epoch_i, x, u, y)[source]#
__repr__()[source]#

Return repr(self).
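Putting the methods above together, a typical training run might look like the following. Only the method names and signatures come from this page; the argument values, the `*_data_kwargs` dicts, and the model name are illustrative placeholders.

```python
from n2j.trainer import Trainer

trainer = Trainer(device_type='cuda', checkpoint_dir='trained_models', seed=123)
trainer.seed_everything()
# Training data must be loaded first, to set the normalization stats
# that the validation loader will reuse
trainer.load_dataset(train_data_kwargs, is_train=True, batch_size=1000)
trainer.load_dataset(val_data_kwargs, is_train=False, batch_size=1000)
trainer.configure_model(model_name, model_kwargs={})  # model_name: whatever the library supports
trainer.configure_optim(early_stop_memory=20, weight_local_loss=0.1)
trainer.train(n_epochs=200)
```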