n2j.trainer
Class managing the model training
Module Contents
Classes
Functions
Returns True if array ever decreased
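The summary above does not carry the function's name or signature. As a minimal sketch of the described behaviour, assuming a one-dimensional sequence of numbers and a hypothetical name `ever_decreased` (not taken from the n2j source):

```python
from typing import Sequence


def ever_decreased(values: Sequence[float]) -> bool:
    """Return True if any element is smaller than the one before it.

    Hypothetical stand-in: the real n2j function may differ in name,
    signature, and edge-case handling.
    """
    # Pair each element with its successor and look for a drop.
    return any(later < earlier for earlier, later in zip(values, values[1:]))


print(ever_decreased([0.9, 0.7, 0.8]))  # True: 0.9 -> 0.7 is a decrease
print(ever_decreased([0.1, 0.2, 0.2]))  # False: the values never go down
```

A check like this is typically applied to a history of losses or metrics, but how the trainer uses it is not stated here.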
- class n2j.trainer.Trainer(device_type, checkpoint_dir='trained_models', seed=123)
- load_dataset(data_kwargs, is_train, batch_size, sub_features=None, sub_target=None, sub_target_local=None, rebin=False, num_workers=2, noise_kwargs={'mag': {'override_kwargs': None, 'depth': 5}}, detection_kwargs={})
Load dataset and dataloader for training or validation
Note
Call this for the training data first: doing so sets the normalizing stats that are then used for both training and validation.
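The ordering requirement in the note can be illustrated with a small sketch. It assumes the normalizing stats are a mean and standard deviation, which the text above does not specify; `fit_stats` and `normalize` are hypothetical helpers, not part of n2j.

```python
from statistics import mean, pstdev


def fit_stats(values):
    """Compute normalizing statistics from the training split only."""
    return {"mean": mean(values), "std": pstdev(values)}


def normalize(values, stats):
    """Standardize values with previously fitted statistics."""
    return [(v - stats["mean"]) / stats["std"] for v in values]


train = [1.0, 2.0, 3.0, 4.0]
val = [2.0, 5.0]

stats = fit_stats(train)              # set once, from the training data
train_norm = normalize(train, stats)
val_norm = normalize(val, stats)      # reuses the training statistics
```

If the validation split were loaded first, there would be no training statistics to reuse, which is why the training call has to come first.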
- load_state(state_path)
Load the state dict saved by a previous training run
Parameters
- state_path : str or osp.object
path of the state dict to load
- save_state(train_loss, val_loss)
Save the state dict of the current training to disk
Parameters
- train_loss : float
current training loss
- val_loss : float
current validation loss
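The save and load round trip can be sketched as follows. This is not the n2j implementation: it uses the standard-library `pickle` as a stand-in for whatever serialization the trainer uses, and the file name `state.pkl`, the payload keys, and the free functions below are all assumptions made for illustration.

```python
import os
import pickle
import tempfile


def save_state(checkpoint_dir, state, train_loss, val_loss):
    """Write the state plus the current losses to disk and return the path."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    payload = {"state": state, "train_loss": train_loss, "val_loss": val_loss}
    state_path = os.path.join(checkpoint_dir, "state.pkl")  # hypothetical name
    with open(state_path, "wb") as f:
        pickle.dump(payload, f)
    return state_path


def load_state(state_path):
    """Read back a payload previously written by save_state."""
    with open(state_path, "rb") as f:
        return pickle.load(f)


with tempfile.TemporaryDirectory() as tmp:
    path = save_state(os.path.join(tmp, "trained_models"), {"epoch": 3}, 0.42, 0.57)
    restored = load_state(path)
```

Storing the losses alongside the state makes it possible to tell, after reloading, how well the run was doing when the checkpoint was written.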