:py:mod:`n2j.trainer`
=====================

.. py:module:: n2j.trainer

.. autoapi-nested-parse::

   Class managing the model training



Module Contents
---------------

Classes
~~~~~~~

.. autoapisummary::

   n2j.trainer.Trainer



Functions
~~~~~~~~~

.. autoapisummary::

   n2j.trainer.is_decreasing



.. py:function:: is_decreasing(arr)

   Returns True if array ever decreased


.. py:class:: Trainer(device_type, checkpoint_dir='trained_models', seed=123)

   .. py:method:: seed_everything()

      Seed the training and sampling for reproducibility


   .. py:method:: load_dataset(data_kwargs, is_train, batch_size, sub_features=None, sub_target=None, sub_target_local=None, rebin=False, num_workers=2, noise_kwargs={'mag': {'override_kwargs': None, 'depth': 5}}, detection_kwargs={})

      Load dataset and dataloader for training or validation

      Note
      ----
      Should be called for training data first, to set the normalizing
      stats used for both training and validation!


   .. py:method:: configure_model(model_name, model_kwargs={})


   .. py:method:: load_state(state_path)

      Load the state dict of the past training

      Parameters
      ----------
      state_path : str or osp.object
          path of the state dict to load


   .. py:method:: save_state(train_loss, val_loss)

      Save the state dict of the current training to disk

      Parameters
      ----------
      train_loss : float
          current training loss
      val_loss : float
          current validation loss


   .. py:method:: configure_optim(early_stop_memory=20, weight_local_loss=0.1, optim_kwargs={}, lr_scheduler_kwargs={'factor': 0.5, 'min_lr': 1e-07})

      Configure optimization-related objects


   .. py:method:: train_single_epoch(epoch_i)


   .. py:method:: train(n_epochs, sample_kwargs={})


   .. py:method:: infer(epoch_i)


   .. py:method:: _log_kappa_recovery_flow(epoch_i, x, u, y)


   .. py:method:: _log_kappa_recovery(epoch_i, x, u, y)


   .. py:method:: __repr__()

      Return repr(self).
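
The one-line description of ``is_decreasing`` ("Returns True if array ever decreased") can be read as a check over consecutive elements. The sketch below is a minimal illustration under that reading, not the module's actual implementation. The names ``is_decreasing_sketch`` and ``should_stop_early`` are hypothetical. Using such a check together with ``early_stop_memory`` is also an assumption: the documentation above does not say how ``Trainer`` uses either one.

```python
def is_decreasing_sketch(arr):
    """Return True if any element is smaller than the one before it.

    Hypothetical stand-in for n2j.trainer.is_decreasing, written from its
    one-line description only.
    """
    values = list(arr)
    # Compare each element with its predecessor.
    return any(later < earlier for earlier, later in zip(values, values[1:]))


def should_stop_early(val_losses, memory=20):
    """Assumed early-stopping rule (not confirmed by the source).

    Stop once the most recent `memory` validation losses contain no decrease.
    """
    if len(val_losses) < memory:
        # Not enough history yet to judge.
        return False
    return not is_decreasing_sketch(val_losses[-memory:])
```

With this rule, a loss history that is still dropping inside the window keeps training going. A window that only plateaus or rises triggers a stop.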