tortreinador package
Subpackages
- tortreinador.models package
- tortreinador.utils package
Submodules
tortreinador.train module
- class tortreinador.train.TorchTrainer(is_gpu: bool = True, epoch: int = 150, log_dir: Optional[str] = None, model: Optional[Module] = None, optimizer: Optional[Optimizer] = None, extra_metric: Optional[Module] = None, criterion: Optional[Module] = None, data_save_mode: str = 'recorder')
Bases: object
A class to implement the training and validation loop based on PyTorch.
- epoch
The number of epochs to train the model.
- Type:
int
- model
The model to be trained.
- Type:
nn.Module
- optimizer
The optimizer used for training the model.
- Type:
Optimizer
- criterion
The loss function used for training.
- Type:
nn.Module
- data_save_mode
The mode to save data (‘recorder’ or ‘csv’).
- Type:
str
- device
The device (CPU or GPU) on which the model will be trained.
- Type:
torch.device
- writer
TensorBoard writer for visualizing training metrics.
- Type:
SummaryWriter
- epoch_train_loss
Recorder for tracking loss across epochs (only if data_save_mode is ‘recorder’).
- Type:
RecorderForEpoch, optional
- epoch_val_loss
Recorder for tracking validation loss across epochs (only if data_save_mode is ‘recorder’).
- Type:
RecorderForEpoch, optional
- epoch_train_metric
Recorder for tracking training metrics across epochs (only if data_save_mode is ‘recorder’).
- Type:
RecorderForEpoch, optional
- epoch_val_metric
Recorder for tracking validation metrics across epochs (only if data_save_mode is ‘recorder’).
- Type:
RecorderForEpoch, optional
- epoch_extra_metric
Recorder for tracking additional metrics across epochs.
- Type:
RecorderForEpoch, optional
- csv_filename
The filename for saving data in CSV format (only if data_save_mode is ‘csv’).
- Type:
str, optional
- cal_result(*args)
Updates the loss and metric recorders based on the results from the current epoch.
- Parameters:
*args – A variable-length argument list containing the results from the epoch.
- Returns:
A dictionary containing the updated loss and metrics.
- Return type:
dict
- calculate(x, y, mode='t')
Performs a forward pass of the model and calculates the loss and metrics.
- Parameters:
x (torch.Tensor) – Input data.
y (torch.Tensor) – Target data.
mode (str) – Mode of operation: ‘t’ for training, ‘v’ for validation.
- Returns:
A list containing the loss, metric, and other relevant information based on the mode.
- Return type:
list
- fit(t_l, v_l, **kwargs)
Trains and validates the model using the provided training and validation data loaders, applying an optional learning-rate warm-up and milestone-based learning-rate decay. It also saves the model based on validation performance.
- Parameters:
t_l (DataLoader) – DataLoader containing the training data.
v_l (DataLoader) – DataLoader containing the validation data.
**kwargs – Additional keyword arguments:
m_p (str): Path where the model should be saved.
w_e (int, optional): Number of initial epochs during which the learning rate is increased linearly.
l_m (list of int, optional): Epoch indices at which the learning rate should be decreased.
gamma (float, optional): Multiplicative factor by which the learning rate is decayed at each milestone.
b_m (float, optional): Threshold for a performance metric (e.g., accuracy) used to determine the best model.
b_l (float, optional): Threshold for the loss value used to determine the best model.
condition (int): Derived from b_m and b_l; it cannot be specified directly. condition=0 if only b_m is set, condition=1 if only b_l is set, and condition=2 if both are set.
- Process:
The function initializes required devices and settings for training.
It enters a training loop for the specified number of epochs, handling both training and validation phases.
During the training phase, it optionally applies a warm-up schedule and adjusts the learning rate based on milestones.
After each epoch, it checks whether the current model outperforms the previous best against the configured thresholds (b_m and/or b_l) and, if so, saves the model.
Outputs training progress and validation metrics after each epoch.
- Returns:
epoch_train_loss (MetricTracker): Tracker for training loss across epochs.
epoch_val_loss (MetricTracker): Tracker for validation loss across epochs.
epoch_val_metric (MetricTracker): Tracker for validation metric specified by b_m.
epoch_train_metric (MetricTracker): Tracker for training metric.
epoch_extra_metric (MetricTracker, optional): Additional metrics tracker, returned only if extra_metric is provided.
- Return type:
tuple (the metric trackers collected over epochs)
- Raises:
ValueError – If required keyword arguments (such as b_m) are missing or if validation metrics exceed the expected bounds.
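`cal_result` updates per-epoch recorders such as `epoch_train_loss` and `epoch_val_loss`. The internals of `RecorderForEpoch` are not documented here, so the following is only a minimal sketch of one plausible design, assuming a recorder that averages batch values weighted by batch size and stores one value per epoch. `EpochRecorder`, `update`, and `end_epoch` are hypothetical names, not tortreinador's API.

```python
from typing import List


class EpochRecorder:
    """Hypothetical per-epoch recorder: averages batch values, keeps one entry per epoch."""

    def __init__(self) -> None:
        self.history: List[float] = []  # one averaged value per finished epoch
        self._total = 0.0
        self._count = 0

    def update(self, value: float, n: int = 1) -> None:
        # Accumulate a batch value weighted by the number of samples n.
        self._total += value * n
        self._count += n

    def end_epoch(self) -> float:
        # Close the epoch: store the weighted mean and reset the accumulators.
        mean = self._total / self._count if self._count else float("nan")
        self.history.append(mean)
        self._total, self._count = 0.0, 0
        return mean


rec = EpochRecorder()
rec.update(0.5, n=2)  # e.g. mean loss of a batch of 2 samples
rec.update(0.2, n=2)
print(rec.end_epoch())  # 0.35
```

Weighting by `n` keeps the epoch value correct when the last batch is smaller than the others; a recorder that simply averaged batch means would not.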
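The rule for the derived `condition` value in `fit` can be restated as a small pure-Python function. This sketch only mirrors the mapping documented above; `derive_condition` is a hypothetical name and not part of tortreinador. The case where both `b_m` and `b_l` are `None` is not documented, so raising `ValueError` there is an assumption.

```python
from typing import Optional


def derive_condition(b_m: Optional[float], b_l: Optional[float]) -> int:
    """Map the b_m / b_l thresholds to the documented condition code.

    Hypothetical helper that mirrors the rule described for fit();
    it is not a tortreinador function.
    """
    if b_m is not None and b_l is not None:
        return 2  # both a metric threshold and a loss threshold are set
    if b_m is not None:
        return 0  # only the metric threshold b_m is set
    if b_l is not None:
        return 1  # only the loss threshold b_l is set
    # Assumption: the docs do not define this case, so treat it as an error.
    raise ValueError("at least one of b_m or b_l must be provided")


print(derive_condition(0.8, None))   # 0
print(derive_condition(None, 0.05))  # 1
print(derive_condition(0.8, 0.05))   # 2
```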
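`fit` combines a linear learning-rate warm-up (`w_e`) with step decay at milestones (`l_m`, `gamma`). The sketch below shows one common way to compute such a schedule per epoch. The exact warm-up formula tortreinador uses is not documented here, so the linear ramp `base_lr * (epoch + 1) / w_e`, the 0-based epoch count, and the hypothetical function name `lr_for_epoch` are all assumptions.

```python
from bisect import bisect_right
from typing import Optional, Sequence


def lr_for_epoch(
    epoch: int,
    base_lr: float,
    w_e: Optional[int] = None,
    l_m: Optional[Sequence[int]] = None,
    gamma: Optional[float] = None,
) -> float:
    """Learning rate for a 0-based epoch under warm-up plus step decay (sketch)."""
    # Assumed linear warm-up: ramp from base_lr / w_e up to base_lr.
    if w_e and epoch < w_e:
        return base_lr * (epoch + 1) / w_e
    if l_m:
        if gamma is None:
            raise ValueError("gamma is required when milestones are given")
        # Number of milestones already reached at this epoch.
        passed = bisect_right(sorted(l_m), epoch)
        return base_lr * gamma ** passed
    return base_lr


schedule = [lr_for_epoch(e, 0.1, w_e=2, l_m=[4, 6], gamma=0.5) for e in range(8)]
print(schedule)  # [0.05, 0.1, 0.1, 0.1, 0.05, 0.05, 0.025, 0.025]
```

In this sketch the decay factor is `gamma` raised to the number of milestones passed, which matches the description of `gamma` as a multiplicative factor applied at each milestone.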
- tortreinador.train.config_generator(model_save_path: str, warmup_epochs: Optional[int] = None, lr_milestones: Optional[list] = None, lr_decay_rate: Optional[float] = None, best_metric: Optional[float] = None, best_loss: Optional[float] = None, validation_rate: float = 0.2)
Generates a configuration dictionary for model training based on specified parameters.
- Parameters:
model_save_path (str) – Path where the model should be saved.
warmup_epochs (int, optional) – Number of initial epochs during which the learning rate is increased linearly.
lr_milestones (list of int, optional) – Epoch indices at which the learning rate should be decreased.
lr_decay_rate (float, optional) – Multiplicative factor by which the learning rate is decayed at each milestone.
best_metric (float, optional) – Threshold for a performance metric (e.g., accuracy) to determine the best model.
best_loss (float, optional) – Threshold for the loss value to determine the best model.
validation_rate (float) – Fraction of the training data split off as the validation set (still under development).
- Raises:
ValueError – If model_save_path is None or if lr_milestones is set but lr_decay_rate is not provided.
- Returns:
Configuration dictionary containing all settings necessary for model training, including paths, learning rate schedules, and performance thresholds.
- Return type:
dict
This function constructs a dictionary that includes settings for saving the model, applying warmup periods and learning rate schedules, and monitoring specific performance metrics or loss values to save the best model during training.
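The validation rules listed for `config_generator` can be illustrated with a stand-alone stand-in. `make_training_config` is hypothetical, and the dictionary keys it returns (`m_p`, `w_e`, `l_m`, `gamma`, `b_m`, `b_l`) are borrowed from the `fit` keyword arguments as an assumption; the real function's keys are not documented here. `validation_rate` is left out because it is marked as still under development.

```python
from typing import Any, Dict, List, Optional


def make_training_config(
    model_save_path: Optional[str],
    warmup_epochs: Optional[int] = None,
    lr_milestones: Optional[List[int]] = None,
    lr_decay_rate: Optional[float] = None,
    best_metric: Optional[float] = None,
    best_loss: Optional[float] = None,
) -> Dict[str, Any]:
    """Hypothetical stand-in that applies config_generator's documented checks."""
    if model_save_path is None:
        raise ValueError("model_save_path must be provided")
    if lr_milestones is not None and lr_decay_rate is None:
        raise ValueError("lr_decay_rate is required when lr_milestones is set")
    # Assumed key names, chosen to match the fit() keyword arguments.
    return {
        "m_p": model_save_path,
        "w_e": warmup_epochs,
        "l_m": lr_milestones,
        "gamma": lr_decay_rate,
        "b_m": best_metric,
        "b_l": best_loss,
    }


config = make_training_config("./models/", warmup_epochs=5,
                              lr_milestones=[30, 60], lr_decay_rate=0.7,
                              best_metric=0.8)
print(config["gamma"])  # 0.7
```

If `config_generator` does return keys in this form, its result could be unpacked directly into `fit`, e.g. `trainer.fit(train_loader, val_loader, **config)`; check the returned keys before relying on that.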