tortreinador package

Subpackages

Submodules

tortreinador.train module

class tortreinador.train.TorchTrainer(is_gpu: bool = True, epoch: int = 150, log_dir: Optional[str] = None, model: Optional[Module] = None, optimizer: Optional[Optimizer] = None, extra_metric: Optional[Module] = None, criterion: Optional[Module] = None, data_save_mode: str = 'recorder')

Bases: object

A class that implements the training and validation loop for PyTorch models.

epoch

The number of epochs to train the model.

Type:

int

model

The model to be trained.

Type:

nn.Module

optimizer

The optimizer used for training the model.

Type:

Optimizer

criterion

The loss function used for training.

Type:

nn.Module

data_save_mode

The mode to save data (‘recorder’ or ‘csv’).

Type:

str

device

The device (CPU or GPU) on which the model will be trained.

Type:

torch.device

writer

TensorBoard writer for visualizing training metrics.

Type:

SummaryWriter

train_loss_recorder

Recorder for tracking training loss.

Type:

Recorder

val_loss_recorder

Recorder for tracking validation loss.

Type:

Recorder

train_metric_recorder

Recorder for tracking training metrics.

Type:

Recorder

val_metric_recorder

Recorder for tracking validation metrics.

Type:

Recorder

epoch_train_loss

Recorder for tracking training loss across epochs (only if data_save_mode is ‘recorder’).

Type:

RecorderForEpoch

epoch_val_loss

Recorder for tracking validation loss across epochs (only if data_save_mode is ‘recorder’).

Type:

RecorderForEpoch

epoch_train_metric

Recorder for tracking training metrics across epochs (only if data_save_mode is ‘recorder’).

Type:

RecorderForEpoch

epoch_val_metric

Recorder for tracking validation metrics across epochs (only if data_save_mode is ‘recorder’).

Type:

RecorderForEpoch

epoch_extra_metric

Recorder for tracking additional metrics across epochs (only if an extra_metric module is provided).

Type:

RecorderForEpoch, optional

csv_filename

The filename for saving data in CSV format (only if data_save_mode is ‘csv’).

Type:

str, optional
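The attributes above distinguish batch-level Recorder objects (e.g. train_loss_recorder) from per-epoch RecorderForEpoch objects (e.g. epoch_train_loss). The plain-Python sketch below illustrates that two-level pattern under the assumption that each epoch's value is the mean of its batch values. BatchRecorder and EpochRecorder are hypothetical names, not tortreinador's classes, whose actual interfaces may differ.

```python
# Hypothetical sketch: BatchRecorder / EpochRecorder are illustrative stand-ins,
# NOT tortreinador's Recorder / RecorderForEpoch classes.
class BatchRecorder:
    """Accumulates per-batch values within one epoch and reports their mean."""

    def __init__(self):
        self.values = []

    def update(self, value):
        self.values.append(float(value))

    def avg(self):
        # Mean of the values seen so far; 0.0 if nothing was recorded.
        return sum(self.values) / len(self.values) if self.values else 0.0

    def reset(self):
        self.values.clear()


class EpochRecorder:
    """Stores one summary value per epoch."""

    def __init__(self):
        self.history = []

    def update(self, value):
        self.history.append(float(value))


train_loss = BatchRecorder()
epoch_train_loss = EpochRecorder()

# Two fake epochs of per-batch losses.
for epoch_batches in ([0.9, 0.7, 0.5], [0.4, 0.3, 0.2]):
    for batch_loss in epoch_batches:
        train_loss.update(batch_loss)
    epoch_train_loss.update(train_loss.avg())  # assumed: epoch value = batch mean
    train_loss.reset()  # start fresh for the next epoch

print(epoch_train_loss.history)
```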

cal_result(*args)

Updates the loss and metric recorders based on the results from the current epoch.

Parameters:

*args – A variable-length argument list containing the results from the epoch.

Returns:

A dictionary containing the updated loss and metrics.

Return type:

dict

calculate(x, y, mode='t')

Performs a forward pass of the model and calculates the loss and metrics.

Parameters:
  • x (torch.Tensor) – Input data.

  • y (torch.Tensor) – Target data.

  • mode (str) – Mode of operation: ‘t’ for training, ‘v’ for validation.

Returns:

A list containing the loss, metric, and other relevant information based on the mode.

Return type:

list

fit(t_l, v_l, **kwargs)

Trains and validates the model using the provided training and validation data loaders, optionally applying a linear learning-rate warm-up and milestone-based learning-rate decay. The function also saves the model based on validation performance.

Parameters:
  • t_l (DataLoader) – DataLoader containing the training data.

  • v_l (DataLoader) – DataLoader containing the validation data.

  • **kwargs – Additional keyword arguments:

      ◦ m_p (str) – Path where the model should be saved.

      ◦ w_e (int, optional) – Number of initial epochs during which the learning rate is increased linearly.

      ◦ l_m (list of int, optional) – Epoch indices at which the learning rate should be decreased.

      ◦ gamma (float, optional) – Multiplicative factor by which the learning rate is decayed at each milestone.

      ◦ b_m (float, optional) – Threshold for a performance metric (e.g., accuracy) used to determine the best model.

      ◦ b_l (float, optional) – Threshold for the loss value used to determine the best model.

      ◦ condition (int) – Derived automatically from b_m and b_l and cannot be specified directly: 0 if only b_m is given, 1 if only b_l is given, 2 if both are given.
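The mapping from b_m and b_l to condition, and a possible best-model check built on it, can be sketched in plain Python as follows. derive_condition and is_best are hypothetical helpers, not part of tortreinador; the comparison directions in is_best (a higher metric and a lower loss count as better, with strict inequalities) are assumptions rather than documented behavior.

```python
from typing import Optional


def derive_condition(b_m: Optional[float], b_l: Optional[float]) -> Optional[int]:
    """Map b_m / b_l to the condition code described for fit()."""
    if b_m is not None and b_l is not None:
        return 2
    if b_m is not None:
        return 0
    if b_l is not None:
        return 1
    return None  # neither threshold given: not defined by the description


def is_best(condition: Optional[int], val_metric: float, val_loss: float,
            b_m: Optional[float], b_l: Optional[float]) -> bool:
    """Assumed rule: metric must exceed b_m and/or loss must fall below b_l."""
    if condition == 0:
        return val_metric > b_m
    if condition == 1:
        return val_loss < b_l
    if condition == 2:
        return val_metric > b_m and val_loss < b_l
    return False


cond = derive_condition(b_m=0.9, b_l=None)
print(cond, is_best(cond, val_metric=0.93, val_loss=0.2, b_m=0.9, b_l=None))
```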

Process:
  1. The function initializes required devices and settings for training.

  2. It enters a training loop for the specified number of epochs, handling both training and validation phases.

  3. During the training phase, it optionally applies a warm-up schedule and adjusts the learning rate based on milestones.

  4. After each epoch, it checks whether the current model outperforms the previous best results, using b_m and/or b_l according to condition, and saves the model if it does.

  5. Outputs training progress and validation metrics after each epoch.
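Step 3 can be illustrated with a small schedule function: a linear warm-up over the first w_e epochs, followed by multiplying the learning rate by gamma at each milestone in l_m (similar in spirit to PyTorch's MultiStepLR). lr_at_epoch is a hypothetical helper; the 0-based epoch indexing and the exact shape of the warm-up ramp are assumptions, and fit may implement these steps differently.

```python
from typing import List, Optional


def lr_at_epoch(epoch: int, base_lr: float, w_e: Optional[int] = None,
                l_m: Optional[List[int]] = None,
                gamma: Optional[float] = None) -> float:
    """Learning rate for a 0-based epoch index (assumed convention)."""
    if w_e and epoch < w_e:
        # Linear warm-up: ramps from base_lr / w_e up to base_lr.
        return base_lr * (epoch + 1) / w_e
    lr = base_lr
    if l_m and gamma is not None:
        # Decay once by gamma for every milestone already reached.
        passed = sum(1 for m in l_m if epoch >= m)
        lr *= gamma ** passed
    return lr


for e in range(8):
    print(e, lr_at_epoch(e, base_lr=0.1, w_e=3, l_m=[5, 7], gamma=0.1))
```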

Returns:

  • epoch_train_loss (RecorderForEpoch): Tracker for training loss across epochs.

  • epoch_val_loss (RecorderForEpoch): Tracker for validation loss across epochs.

  • epoch_val_metric (RecorderForEpoch): Tracker for the validation metric compared against b_m.

  • epoch_train_metric (RecorderForEpoch): Tracker for the training metric.

  • epoch_extra_metric (RecorderForEpoch, optional): Tracker for additional metrics, returned only if an extra_metric module was provided.

Return type:

tuple (the per-epoch trackers listed above)

Raises:

ValueError – If required keyword arguments (such as b_m) are missing, or if validation metrics fall outside the expected range.

tortreinador.train.config_generator(model_save_path: str, warmup_epochs: Optional[int] = None, lr_milestones: Optional[list] = None, lr_decay_rate: Optional[float] = None, best_metric: Optional[float] = None, best_loss: Optional[float] = None, validation_rate: float = 0.2)

Generates a configuration dictionary for model training based on specified parameters.

Parameters:
  • model_save_path (str) – Path where the model should be saved.

  • warmup_epochs (int, optional) – Number of initial epochs during which the learning rate is increased linearly.

  • lr_milestones (list of int, optional) – Epoch indices at which the learning rate should be decreased.

  • lr_decay_rate (float, optional) – Multiplicative factor by which the learning rate is decayed at each milestone.

  • best_metric (float, optional) – Threshold for a performance metric (e.g., accuracy) to determine the best model.

  • best_loss (float, optional) – Threshold for the loss value to determine the best model.

  • validation_rate (float) – Fraction of the training data to split off as the validation set (under development).

Raises:

ValueError – If model_save_path is None or if lr_milestones is set but lr_decay_rate is not provided.

Returns:

Configuration dictionary containing all settings necessary for model training, including paths, learning rate schedules, and performance thresholds.

Return type:

dict

This function constructs a dictionary that includes settings for saving the model, applying warmup periods and learning rate schedules, and monitoring specific performance metrics or loss values to save the best model during training.
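The validation rules listed under Raises can be mirrored in a small helper like the one below. make_training_config is hypothetical and not part of tortreinador, and the dictionary keys it returns are illustrative only; the real config_generator may use different keys.

```python
from typing import List, Optional


def make_training_config(
    model_save_path: Optional[str],
    warmup_epochs: Optional[int] = None,
    lr_milestones: Optional[List[int]] = None,
    lr_decay_rate: Optional[float] = None,
    best_metric: Optional[float] = None,
    best_loss: Optional[float] = None,
    validation_rate: float = 0.2,
) -> dict:
    """Hypothetical stand-in for config_generator; key names are illustrative."""
    # Rule 1: a save path is mandatory.
    if model_save_path is None:
        raise ValueError("model_save_path is required")
    # Rule 2: milestones need a decay factor to be meaningful.
    if lr_milestones is not None and lr_decay_rate is None:
        raise ValueError("lr_decay_rate is required when lr_milestones is set")
    return {
        "model_save_path": model_save_path,
        "warmup_epochs": warmup_epochs,
        "lr_milestones": lr_milestones,
        "lr_decay_rate": lr_decay_rate,
        "best_metric": best_metric,
        "best_loss": best_loss,
        "validation_rate": validation_rate,
    }


config = make_training_config("checkpoints/model.pth", warmup_epochs=5,
                              lr_milestones=[30, 60], lr_decay_rate=0.1,
                              best_metric=0.8)
print(config)
```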

Module contents