Returns activation class by its name from the torch.nn namespace. This function supports all modules available from torch.nn as well as their lower-case aliases. On top of that, it supports a few extra aliases: leaky_relu (LeakyReLU), swish (SiLU).

act_cls = get_activation_type("LeakyReLU", inplace=True, slope=0.01)
act = act_cls()

def get_builtin_activation_type(activation: Union[str, None], **kwargs) -> Type[nn.Module]:
    """
    Returns activation class by its name from torch.nn namespace. This function supports all modules available from
    torch.nn and also their lower-case aliases.
    On top of that, it supports a few extra aliases: leaky_relu (LeakyReLU), swish (silu).

    >>> act_cls = get_builtin_activation_type("LeakyReLU", inplace=True, slope=0.01)
    >>> act = act_cls()

    :param activation: Activation function name (E.g. ReLU). If None - return nn.Identity
    :param **kwargs  : Extra arguments to pass to constructor during instantiation (E.g. inplace=True)
    :returns         : Type of the activation function that is ready to be instantiated
    """
    if activation is None:
        activation_cls = nn.Identity
    else:
        lowercase_aliases: Dict[str, str] = dict((k.lower(), k) for k in torch.nn.__dict__.keys())
        # Register additional aliases
        lowercase_aliases["leaky_relu"] = "LeakyReLU"  # LeakyRelu in snake_case
        lowercase_aliases["swish"] = "SiLU"  # Swish, which is equivalent to SiLU
        lowercase_aliases["none"] = "Identity"
        if activation in lowercase_aliases:
            activation = lowercase_aliases[activation]
        if activation not in torch.nn.__dict__:
            raise KeyError(f"Requested activation function {activation} is not known")
        activation_cls = torch.nn.__dict__[activation]
        if len(kwargs):
            activation_cls = partial(activation_cls, **kwargs)
    return activation_cls
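
For reference, a short usage sketch of the helper above (note that nn.LeakyReLU's actual keyword is negative_slope; the extra kwargs are forwarded to the constructor unchanged):

import torch
from torch import nn

act_cls = get_builtin_activation_type("leaky_relu", negative_slope=0.01, inplace=True)
act = act_cls()                                  # nn.LeakyReLU(negative_slope=0.01, inplace=True)
print(act(torch.tensor([-1.0, 2.0])))            # tensor([-0.0100,  2.0000])

identity = get_builtin_activation_type(None)()   # None maps to nn.Identity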
def batch_distance2bbox(points: Tensor, distance: Tensor, max_shapes: Optional[Tensor] = None) -> Tensor:
    """Decode distance prediction to bounding box for batch.
    :param points: [B, ..., 2], "xy" format
    :param distance: [B, ..., 4], "ltrb" format
    :param max_shapes: [B, 2], "h,w" format, Shape of the image.
    :return: Tensor: Decoded bboxes, "x1y1x2y2" format.
    """
    lt, rb = torch.split(distance, 2, dim=-1)
    # When adding a tensor and parameters, the parameters should be the second operand
    x1y1 = -lt + points
    x2y2 = rb + points
    out_bbox = torch.cat([x1y1, x2y2], dim=-1)
    if max_shapes is not None:
        max_shapes = max_shapes.flip(-1).tile([1, 2])
        delta_dim = out_bbox.ndim - max_shapes.ndim
        for _ in range(delta_dim):
            max_shapes.unsqueeze_(1)
        out_bbox = torch.where(out_bbox < max_shapes, out_bbox, max_shapes)
        out_bbox = torch.where(out_bbox > 0, out_bbox, torch.zeros_like(out_bbox))
    return out_bbox
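
A small, self-contained sketch of the decoding above, with toy values chosen for illustration:

import torch

points = torch.tensor([[[10.0, 10.0], [50.0, 40.0]]])        # [B=1, N=2, 2], "xy"
distance = torch.tensor([[[5.0, 5.0, 5.0, 5.0],
                          [20.0, 10.0, 30.0, 60.0]]])        # [1, 2, 4], "ltrb"
max_shapes = torch.tensor([[64.0, 64.0]])                     # [1, 2], "h,w"

boxes = batch_distance2bbox(points, distance, max_shapes)
# First box:  [10-5, 10-5, 10+5, 10+5] -> [5, 5, 15, 15]
# Second box: [30, 30, 80, 100] clipped to the 64x64 image -> [30, 30, 64, 64]
print(boxes)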
      

Base callback class with all the callback methods. Derived classes may override one or many of the available events to receive callbacks when such events are triggered by the training loop.

The order of the events is as follows:

on_training_start(context) # called once before training starts, good for setting up the warmup LR

for epoch in range(epochs):
    on_train_loader_start(context)
        for batch in train_loader:
            on_train_batch_start(context)
            on_train_batch_loss_end(context)               # called after loss has been computed
            on_train_batch_backward_end(context)           # called after .backward() was called
            on_train_batch_gradient_step_start(context)    # called before the optimizer step about to happen (gradient clipping, logging of gradients)
            on_train_batch_gradient_step_end(context)      # called after gradient step was done, good place to update LR (for step-based schedulers)
            on_train_batch_end(context)
    on_train_loader_end(context)
    on_validation_loader_start(context)
        for batch in validation_loader:
            on_validation_batch_start(context)
            on_validation_batch_end(context)
    on_validation_loader_end(context)
    on_validation_end_best_epoch(context)
on_test_start(context)
    for batch in test_loader:
        on_test_batch_start(context)
        on_test_batch_end(context)
on_test_end(context)

on_training_end(context) # called once after training ends.
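
As an illustration of this API, below is a minimal sketch of a custom callback (a hypothetical epoch timer; the import path is assumed to match the module shown further down and may differ between SuperGradients versions):

import time

from super_gradients.training.utils.callbacks.base_callbacks import Callback, PhaseContext


class EpochTimerCallback(Callback):
    """Hypothetical example: measures how long each training epoch takes."""

    def __init__(self):
        self._epoch_start = None

    def on_train_loader_start(self, context: PhaseContext) -> None:
        # Fired once per epoch, before the first training batch.
        self._epoch_start = time.monotonic()

    def on_train_loader_end(self, context: PhaseContext) -> None:
        # Fired once per epoch, after the last training batch.
        print(f"Epoch {context.epoch} took {time.monotonic() - self._epoch_start:.1f}s")

Such a callback is typically registered through the training parameters (e.g. a phase_callbacks list); check your SuperGradients version for the exact wiring.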

Correspondence mapping from the old callback API:

on_training_start(context)                                 <-> Phase.PRE_TRAINING
for epoch in range(epochs):
    on_train_loader_start(context)                         <-> Phase.TRAIN_EPOCH_START
        for batch in train_loader:
            on_train_batch_start(context)
            on_train_batch_loss_end(context)
            on_train_batch_backward_end(context)           <-> Phase.TRAIN_BATCH_END
            on_train_batch_gradient_step_start(context)
            on_train_batch_gradient_step_end(context)      <-> Phase.TRAIN_BATCH_STEP
            on_train_batch_end(context)
    on_train_loader_end(context)                           <-> Phase.TRAIN_EPOCH_END

on_validation_loader_start(context)
    for batch in validation_loader:
        on_validation_batch_start(context)
        on_validation_batch_end(context)               <-> Phase.VALIDATION_BATCH_END
on_validation_loader_end(context)                      <-> Phase.VALIDATION_EPOCH_END
on_validation_end_best_epoch(context)                  <-> Phase.VALIDATION_END_BEST_EPOCH

on_test_start(context)
    for batch in test_loader:
        on_test_batch_start(context)
        on_test_batch_end(context)                         <-> Phase.TEST_BATCH_END
on_test_end(context)                                       <-> Phase.TEST_END

on_training_end(context) <-> Phase.POST_TRAINING

Source code in V3_2/src/super_gradients/training/utils/callbacks/base_callbacks.py
class Callback:
    """
    Base callback class with all the callback methods. Derived classes may override one or many of the available events
    to receive callbacks when such events are triggered by the training loop.
    The order of the events is as follows:
    on_training_start(context)                              # called once before training starts, good for setting up the warmup LR
        for epoch in range(epochs):
            on_train_loader_start(context)
                for batch in train_loader:
                    on_train_batch_start(context)
                    on_train_batch_loss_end(context)               # called after loss has been computed
                    on_train_batch_backward_end(context)           # called after .backward() was called
                    on_train_batch_gradient_step_start(context)    # called before the optimizer step about to happen (gradient clipping, logging of gradients)
                    on_train_batch_gradient_step_end(context)      # called after gradient step was done, good place to update LR (for step-based schedulers)
                    on_train_batch_end(context)
            on_train_loader_end(context)
            on_validation_loader_start(context)
                for batch in validation_loader:
                    on_validation_batch_start(context)
                    on_validation_batch_end(context)
            on_validation_loader_end(context)
            on_validation_end_best_epoch(context)
        on_test_start(context)
            for batch in test_loader:
                on_test_batch_start(context)
                on_test_batch_end(context)
        on_test_end(context)
    on_training_end(context)                    # called once after training ends.
    Correspondence mapping from the old callback API:
    on_training_start(context)                                 <-> Phase.PRE_TRAINING
    for epoch in range(epochs):
        on_train_loader_start(context)                         <-> Phase.TRAIN_EPOCH_START
            for batch in train_loader:
                on_train_batch_start(context)
                on_train_batch_loss_end(context)
                on_train_batch_backward_end(context)           <-> Phase.TRAIN_BATCH_END
                on_train_batch_gradient_step_start(context)
                on_train_batch_gradient_step_end(context)      <-> Phase.TRAIN_BATCH_STEP
                on_train_batch_end(context)
        on_train_loader_end(context)                           <-> Phase.TRAIN_EPOCH_END
        on_validation_loader_start(context)
            for batch in validation_loader:
                on_validation_batch_start(context)
                on_validation_batch_end(context)               <-> Phase.VALIDATION_BATCH_END
        on_validation_loader_end(context)                      <-> Phase.VALIDATION_EPOCH_END
        on_validation_end_best_epoch(context)                  <-> Phase.VALIDATION_END_BEST_EPOCH
    on_test_start(context)
        for batch in test_loader:
            on_test_batch_start(context)
            on_test_batch_end(context)                         <-> Phase.TEST_BATCH_END
    on_test_end(context)                                       <-> Phase.TEST_END
    on_training_end(context)                                   <-> Phase.POST_TRAINING
    """
    def on_training_start(self, context: PhaseContext) -> None:
        Called once before start of the first epoch
        At this point, the context argument is guaranteed to have the following attributes:
        - optimizer
        - net
        - checkpoints_dir_path
        - criterion
        - sg_logger
        - train_loader
        - valid_loader
        - training_params
        - checkpoint_params
        - architecture
        - arch_params
        - metric_to_watch
        - device
        - ema_model
        The corresponding Phase enum value for this event is Phase.PRE_TRAINING.
        :param context:
        :return:
    def on_train_loader_start(self, context: PhaseContext) -> None:
        Called each epoch at the start of train data loader (before getting the first batch).
        At this point, the context argument is guaranteed to have the following attributes:
        - epoch
        The corresponding Phase enum value for this event is Phase.TRAIN_EPOCH_START.
        :param context:
        :return:
    def on_train_batch_start(self, context: PhaseContext) -> None:
        Called at each batch after getting batch of data from data loader and moving it to target device.
        This event triggered AFTER Trainer.pre_prediction_callback call (If it was defined).
        At this point the context argument is guaranteed to have the following attributes:
        - batch_idx
        - inputs
        - targets
        - **additional_batch_items
        :param context:
        :return:
    def on_train_batch_loss_end(self, context: PhaseContext) -> None:
        Called after model forward and loss computation has been done.
        At this point the context argument is guaranteed to have the following attributes:
        - preds
        - loss_log_items
        The corresponding Phase enum value for this event is Phase.TRAIN_BATCH_END.
        :param context:
        :return:
    def on_train_batch_backward_end(self, context: PhaseContext) -> None:
        Called after loss.backward() method was called for a given batch
        :param context:
        :return:
    def on_train_batch_gradient_step_start(self, context: PhaseContext) -> None:
        Called before the gradient step is about to happen.
        Good place to clip gradients (with respect to scaler), log gradients to data ratio, etc.
        :param context:
        :return:
    def on_train_batch_gradient_step_end(self, context: PhaseContext) -> None:
        Called after gradient step has been performed. Good place to update LR (for step-based schedulers)
        The corresponding Phase enum value for this event is Phase.TRAIN_BATCH_STEP.
        :param context:
        :return:
    def on_train_batch_end(self, context: PhaseContext) -> None:
        Called after all forward/backward/optimizer steps have been performed for a given batch and there is nothing left to do.
        :param context:
        :return:
    def on_train_loader_end(self, context: PhaseContext) -> None:
        Called each epoch at the end of train data loader (after processing the last batch).
        The corresponding Phase enum value for this event is Phase.TRAIN_EPOCH_END.
        :param context:
        :return:
    def on_validation_loader_start(self, context: PhaseContext) -> None:
        Called each epoch at the start of validation data loader (before getting the first batch).
        :param context:
        :return:
    def on_validation_batch_start(self, context: PhaseContext) -> None:
        Called at each batch after getting batch of data from validation loader and moving it to target device.
        :param context:
        :return:
    def on_validation_batch_end(self, context: PhaseContext) -> None:
        Called after all forward step / loss / metric computation have been performed for a given batch and there is nothing left to do.
        The corresponding Phase enum value for this event is Phase.VALIDATION_BATCH_END.
        :param context:
        :return:
    def on_validation_loader_end(self, context: PhaseContext) -> None:
        Called each epoch at the end of validation data loader (after processing the last batch).
        The corresponding Phase enum value for this event is Phase.VALIDATION_EPOCH_END.
        :param context:
        :return:
    def on_validation_end_best_epoch(self, context: PhaseContext) -> None:
        Called each epoch after validation has been performed and the best metric has been achieved.
        The corresponding Phase enum value for this event is Phase.VALIDATION_END_BEST_EPOCH.
        :param context:
        :return:
    def on_test_loader_start(self, context: PhaseContext) -> None:
        Called once at the start of test data loader (before getting the first batch).
        :param context:
        :return:
    def on_test_batch_start(self, context: PhaseContext) -> None:
        Called at each batch after getting batch of data from test loader and moving it to target device.
        :param context:
        :return:
    def on_test_batch_end(self, context: PhaseContext) -> None:
        Called after all forward step have been performed for a given batch and there is nothing left to do.
        The corresponding Phase enum value for this event is Phase.TEST_BATCH_END.
        :param context:
        :return:
    def on_test_loader_end(self, context: PhaseContext) -> None:
        Called once at the end of test data loader (after processing the last batch).
        The corresponding Phase enum value for this event is Phase.TEST_END.
        :param context:
        :return:
    def on_training_end(self, context: PhaseContext) -> None:
        Called once after the training loop has finished (Due to reaching optimization criterion or because of an error.)
        The corresponding Phase enum value for this event is Phase.POST_TRAINING.
        :param context:
        :return:
      

Called after all forward step have been performed for a given batch and there is nothing left to do. The corresponding Phase enum value for this event is Phase.TEST_BATCH_END.

def on_test_batch_end(self, context: PhaseContext) -> None:
    Called after all forward step have been performed for a given batch and there is nothing left to do.
    The corresponding Phase enum value for this event is Phase.TEST_BATCH_END.
    :param context:
    :return:
      

Called at each batch after getting batch of data from test loader and moving it to target device.

def on_test_batch_start(self, context: PhaseContext) -> None:
    Called at each batch after getting batch of data from test loader and moving it to target device.
    :param context:
    :return:
      

Called once at the end of test data loader (after processing the last batch). The corresponding Phase enum value for this event is Phase.TEST_END.

def on_test_loader_end(self, context: PhaseContext) -> None:
    Called once at the end of test data loader (after processing the last batch).
    The corresponding Phase enum value for this event is Phase.TEST_END.
    :param context:
    :return:
def on_test_loader_start(self, context: PhaseContext) -> None:
    Called once at the start of test data loader (before getting the first batch).
    :param context:
    :return:
def on_train_batch_backward_end(self, context: PhaseContext) -> None:
    Called after loss.backward() method was called for a given batch
    :param context:
    :return:
      

Called after all forward/backward/optimizer steps have been performed for a given batch and there is nothing left to do.

def on_train_batch_end(self, context: PhaseContext) -> None:
    Called after all forward/backward/optimizer steps have been performed for a given batch and there is nothing left to do.
    :param context:
    :return:
      

Called after gradient step has been performed. Good place to update LR (for step-based schedulers) The corresponding Phase enum value for this event is Phase.TRAIN_BATCH_STEP.

def on_train_batch_gradient_step_end(self, context: PhaseContext) -> None:
    Called after gradient step has been performed. Good place to update LR (for step-based schedulers)
    The corresponding Phase enum value for this event is Phase.TRAIN_BATCH_STEP.
    :param context:
    :return:
      

Called before the gradient step is about to happen. Good place to clip gradients (with respect to scaler), log gradients to data ratio, etc.

def on_train_batch_gradient_step_start(self, context: PhaseContext) -> None:
    Called before the gradient step is about to happen.
    Good place to clip gradients (with respect to scaler), log gradients to data ratio, etc.
    :param context:
    :return:
      

Called after model forward and loss computation has been done. At this point the context argument is guaranteed to have the following attributes: - preds - loss_log_items The corresponding Phase enum value for this event is Phase.TRAIN_BATCH_END.

def on_train_batch_loss_end(self, context: PhaseContext) -> None:
    Called after model forward and loss computation has been done.
    At this point the context argument is guaranteed to have the following attributes:
    - preds
    - loss_log_items
    The corresponding Phase enum value for this event is Phase.TRAIN_BATCH_END.
    :param context:
    :return:
      

Called at each batch after getting batch of data from data loader and moving it to target device. This event triggered AFTER Trainer.pre_prediction_callback call (If it was defined).

At this point the context argument is guaranteed to have the following attributes: - batch_idx - inputs - targets - **additional_batch_items

def on_train_batch_start(self, context: PhaseContext) -> None:
    Called at each batch after getting batch of data from data loader and moving it to target device.
    This event triggered AFTER Trainer.pre_prediction_callback call (If it was defined).
    At this point the context argument is guaranteed to have the following attributes:
    - batch_idx
    - inputs
    - targets
    - **additional_batch_items
    :param context:
    :return:
      

Called each epoch at the end of train data loader (after processing the last batch). The corresponding Phase enum value for this event is Phase.TRAIN_EPOCH_END.

def on_train_loader_end(self, context: PhaseContext) -> None:
    Called each epoch at the end of train data loader (after processing the last batch).
    The corresponding Phase enum value for this event is Phase.TRAIN_EPOCH_END.
    :param context:
    :return:
      

Called each epoch at the start of train data loader (before getting the first batch). At this point, the context argument is guaranteed to have the following attributes: - epoch The corresponding Phase enum value for this event is Phase.TRAIN_EPOCH_START.

def on_train_loader_start(self, context: PhaseContext) -> None:
    Called each epoch at the start of train data loader (before getting the first batch).
    At this point, the context argument is guaranteed to have the following attributes:
    - epoch
    The corresponding Phase enum value for this event is Phase.TRAIN_EPOCH_START.
    :param context:
    :return:
      

Called once after the training loop has finished (Due to reaching optimization criterion or because of an error.) The corresponding Phase enum value for this event is Phase.POST_TRAINING.

def on_training_end(self, context: PhaseContext) -> None:
    Called once after the training loop has finished (Due to reaching optimization criterion or because of an error.)
    The corresponding Phase enum value for this event is Phase.POST_TRAINING.
    :param context:
    :return:
      

Called once before start of the first epoch At this point, the context argument is guaranteed to have the following attributes: - optimizer - net - checkpoints_dir_path - criterion - sg_logger - train_loader - valid_loader - training_params - checkpoint_params - architecture - arch_params - metric_to_watch - device - ema_model

The corresponding Phase enum value for this event is Phase.PRE_TRAINING.

def on_training_start(self, context: PhaseContext) -> None:
    Called once before start of the first epoch
    At this point, the context argument is guaranteed to have the following attributes:
    - optimizer
    - net
    - checkpoints_dir_path
    - criterion
    - sg_logger
    - train_loader
    - valid_loader
    - training_params
    - checkpoint_params
    - architecture
    - arch_params
    - metric_to_watch
    - device
    - ema_model
    The corresponding Phase enum value for this event is Phase.PRE_TRAINING.
    :param context:
    :return:
      

Called after all forward step / loss / metric computation have been performed for a given batch and there is nothing left to do. The corresponding Phase enum value for this event is Phase.VALIDATION_BATCH_END.

def on_validation_batch_end(self, context: PhaseContext) -> None:
    Called after all forward step / loss / metric computation have been performed for a given batch and there is nothing left to do.
    The corresponding Phase enum value for this event is Phase.VALIDATION_BATCH_END.
    :param context:
    :return:
      

Called at each batch after getting batch of data from validation loader and moving it to target device.

def on_validation_batch_start(self, context: PhaseContext) -> None:
    Called at each batch after getting batch of data from validation loader and moving it to target device.
    :param context:
    :return:
      

Called each epoch after validation has been performed and the best metric has been achieved. The corresponding Phase enum value for this event is Phase.VALIDATION_END_BEST_EPOCH.

def on_validation_end_best_epoch(self, context: PhaseContext) -> None:
    Called each epoch after validation has been performed and the best metric has been achieved.
    The corresponding Phase enum value for this event is Phase.VALIDATION_END_BEST_EPOCH.
    :param context:
    :return:
      

Called each epoch at the end of validation data loader (after processing the last batch). The corresponding Phase enum value for this event is Phase.VALIDATION_EPOCH_END.

def on_validation_loader_end(self, context: PhaseContext) -> None:
    Called each epoch at the end of validation data loader (after processing the last batch).
    The corresponding Phase enum value for this event is Phase.VALIDATION_EPOCH_END.
    :param context:
    :return:
      

Called each epoch at the start of validation data loader (before getting the first batch).

def on_validation_loader_start(self, context: PhaseContext) -> None:
    Called each epoch at the start of validation data loader (before getting the first batch).
    :param context:
    :return:
    def __init__(self, callbacks: List[Callback]):
        # TODO: Add reordering of callbacks to make sure that they are called in the right order
        # For instance, two callbacks may be dependent on each other, so the first one should be called first
        # Example: Gradient Clipping & Gradient Logging callback. We first need to clip the gradients, and then log them
        # So if user added them in wrong order we can guarantee their order would be correct.
        # We can achieve this by adding a property to the callback to the callback indicating it's priority:
        # Forward   = 0
        # Loss      = 100
        # Backward  = 200
        # Metrics   = 300
        # Scheduler = 400
        # Logging   = 500
        # So ordering callbacks by their order would ensure than we first run all Forward-related callbacks (for a given event),
        # Than backward, and only then - logging.
        self.callbacks = callbacks
    def on_training_start(self, context: PhaseContext) -> None:
        for callback in self.callbacks:
            callback.on_training_start(context)
    def on_train_loader_start(self, context: PhaseContext) -> None:
        for callback in self.callbacks:
            callback.on_train_loader_start(context)
    def on_train_batch_start(self, context: PhaseContext) -> None:
        for callback in self.callbacks:
            callback.on_train_batch_start(context)
    def on_train_batch_loss_end(self, context: PhaseContext) -> None:
        for callback in self.callbacks:
            callback.on_train_batch_loss_end(context)
    def on_train_batch_backward_end(self, context: PhaseContext) -> None:
        for callback in self.callbacks:
            callback.on_train_batch_backward_end(context)
    def on_train_batch_gradient_step_start(self, context: PhaseContext) -> None:
        for callback in self.callbacks:
            callback.on_train_batch_gradient_step_start(context)
    def on_train_batch_gradient_step_end(self, context: PhaseContext) -> None:
        for callback in self.callbacks:
            callback.on_train_batch_gradient_step_end(context)
    def on_train_batch_end(self, context: PhaseContext) -> None:
        for callback in self.callbacks:
            callback.on_train_batch_end(context)
    def on_validation_loader_start(self, context: PhaseContext) -> None:
        for callback in self.callbacks:
            callback.on_validation_loader_start(context)
    def on_validation_batch_start(self, context: PhaseContext) -> None:
        for callback in self.callbacks:
            callback.on_validation_batch_start(context)
    def on_validation_batch_end(self, context: PhaseContext) -> None:
        for callback in self.callbacks:
            callback.on_validation_batch_end(context)
    def on_validation_loader_end(self, context: PhaseContext) -> None:
        for callback in self.callbacks:
            callback.on_validation_loader_end(context)
    def on_train_loader_end(self, context: PhaseContext) -> None:
        for callback in self.callbacks:
            callback.on_train_loader_end(context)
    def on_training_end(self, context: PhaseContext) -> None:
        for callback in self.callbacks:
            callback.on_training_end(context)
    def on_validation_end_best_epoch(self, context: PhaseContext) -> None:
        for callback in self.callbacks:
            callback.on_validation_end_best_epoch(context)
    def on_test_loader_start(self, context: PhaseContext) -> None:
        for callback in self.callbacks:
            callback.on_test_loader_start(context)
    def on_test_batch_start(self, context: PhaseContext) -> None:
        for callback in self.callbacks:
            callback.on_test_batch_start(context)
    def on_test_batch_end(self, context: PhaseContext) -> None:
        for callback in self.callbacks:
            callback.on_test_batch_end(context)
    def on_test_loader_end(self, context: PhaseContext) -> None:
        for callback in self.callbacks:
            callback.on_test_loader_end(context)
      

Kept here for backward compatibility with old code. New callbacks should use the Callback class instead. This callback supports receiving only a subset of the events defined in the Phase enum:

PRE_TRAINING = "PRE_TRAINING"
TRAIN_EPOCH_START = "TRAIN_EPOCH_START"
TRAIN_BATCH_END = "TRAIN_BATCH_END"
TRAIN_BATCH_STEP = "TRAIN_BATCH_STEP"
TRAIN_EPOCH_END = "TRAIN_EPOCH_END"

VALIDATION_BATCH_END = "VALIDATION_BATCH_END"
VALIDATION_EPOCH_END = "VALIDATION_EPOCH_END"
VALIDATION_END_BEST_EPOCH = "VALIDATION_END_BEST_EPOCH"

TEST_BATCH_END = "TEST_BATCH_END"
TEST_END = "TEST_END"
POST_TRAINING = "POST_TRAINING"

Source code in V3_2/src/super_gradients/training/utils/callbacks/base_callbacks.py
class PhaseCallback(Callback):
    Kept here to keep backward compatibility with old code. New callbacks should use Callback class instead.
    This callback supports receiving only a subset of events defined in Phase enum:
    PRE_TRAINING = "PRE_TRAINING"
    TRAIN_EPOCH_START = "TRAIN_EPOCH_START"
    TRAIN_BATCH_END = "TRAIN_BATCH_END"
    TRAIN_BATCH_STEP = "TRAIN_BATCH_STEP"
    TRAIN_EPOCH_END = "TRAIN_EPOCH_END"
    VALIDATION_BATCH_END = "VALIDATION_BATCH_END"
    VALIDATION_EPOCH_END = "VALIDATION_EPOCH_END"
    VALIDATION_END_BEST_EPOCH = "VALIDATION_END_BEST_EPOCH"
    TEST_BATCH_END = "TEST_BATCH_END"
    TEST_END = "TEST_END"
    POST_TRAINING = "POST_TRAINING"
    def __init__(self, phase: Phase):
        self.phase = phase
    def __call__(self, *args, **kwargs):
        raise NotImplementedError
    def __repr__(self) -> str:
        return self.__class__.__name__
    def on_training_start(self, context: PhaseContext) -> None:
        if self.phase == Phase.PRE_TRAINING:
            self(context)
    def on_train_loader_start(self, context: PhaseContext) -> None:
        if self.phase == Phase.TRAIN_EPOCH_START:
            self(context)
    def on_train_batch_loss_end(self, context: PhaseContext) -> None:
        if self.phase == Phase.TRAIN_BATCH_END:
            self(context)
    def on_train_batch_gradient_step_end(self, context: PhaseContext) -> None:
        if self.phase == Phase.TRAIN_BATCH_STEP:
            self(context)
    def on_train_loader_end(self, context: PhaseContext) -> None:
        if self.phase == Phase.TRAIN_EPOCH_END:
            self(context)
    def on_validation_batch_end(self, context: PhaseContext) -> None:
        if self.phase == Phase.VALIDATION_BATCH_END:
            self(context)
    def on_validation_loader_end(self, context: PhaseContext) -> None:
        if self.phase == Phase.VALIDATION_EPOCH_END:
            self(context)
    def on_validation_end_best_epoch(self, context: PhaseContext) -> None:
        if self.phase == Phase.VALIDATION_END_BEST_EPOCH:
            self(context)
    def on_test_batch_end(self, context: PhaseContext) -> None:
        if self.phase == Phase.TEST_BATCH_END:
            self(context)
    def on_test_loader_end(self, context: PhaseContext) -> None:
        if self.phase == Phase.TEST_END:
            self(context)
    def on_training_end(self, context: PhaseContext) -> None:
        if self.phase == Phase.POST_TRAINING:
            self(context)
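
For comparison with the new API, a minimal sketch of an old-style callback built on PhaseCallback (hypothetical example; the import path is assumed and may differ between SuperGradients versions):

from super_gradients.training.utils.callbacks.base_callbacks import Phase, PhaseCallback, PhaseContext


class PrintBestEpochCallback(PhaseCallback):
    """Hypothetical example: fires only when a new best validation metric is reached."""

    def __init__(self):
        super().__init__(phase=Phase.VALIDATION_END_BEST_EPOCH)

    def __call__(self, context: PhaseContext) -> None:
        print(f"New best {context.metric_to_watch} reached at epoch {context.epoch}")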
class PhaseContext:
    Represents the input for phase callbacks, and is constantly updated after callback calls.
    def __init__(
        self,
        epoch=None,
        batch_idx=None,
        optimizer=None,
        metrics_dict=None,
        inputs=None,
        preds=None,
        target=None,
        metrics_compute_fn=None,
        loss_avg_meter=None,
        loss_log_items=None,
        criterion=None,
        device=None,
        experiment_name=None,
        ckpt_dir=None,
        net=None,
        lr_warmup_epochs=None,
        sg_logger=None,
        train_loader=None,
        valid_loader=None,
        training_params=None,
        ddp_silent_mode=None,
        checkpoint_params=None,
        architecture=None,
        arch_params=None,
        metric_to_watch=None,
        valid_metrics=None,
        ema_model=None,
        self.epoch = epoch
        self.batch_idx = batch_idx
        self.optimizer = optimizer
        self.inputs = inputs
        self.preds = preds
        self.target = target
        self.metrics_dict = metrics_dict
        self.metrics_compute_fn = metrics_compute_fn
        self.loss_avg_meter = loss_avg_meter
        self.loss_log_items = loss_log_items
        self.criterion = criterion
        self.device = device
        self.stop_training = False
        self.experiment_name = experiment_name
        self.ckpt_dir = ckpt_dir
        self.net = net
        self.lr_warmup_epochs = lr_warmup_epochs
        self.sg_logger = sg_logger
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.training_params = training_params
        self.ddp_silent_mode = ddp_silent_mode
        self.checkpoint_params = checkpoint_params
        self.architecture = architecture
        self.arch_params = arch_params
        self.metric_to_watch = metric_to_watch
        self.valid_metrics = valid_metrics
        self.ema_model = ema_model
    def update_context(self, **kwargs):
        for attr, attr_val in kwargs.items():
            setattr(self, attr, attr_val)
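
A short sketch of how the context object is used: the trainer creates one instance and keeps pushing fresh state into it before each callback event (illustrative values only):

context = PhaseContext(epoch=0, metric_to_watch="mAP@0.50")

# Later, inside the loop, new values are written into the same object:
context.update_context(batch_idx=3, loss_log_items={"loss": 0.42})

print(context.epoch, context.batch_idx, context.loss_log_items)   # 0 3 {'loss': 0.42}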
@register_lr_warmup(LRWarmups.LINEAR_BATCH_STEP)
class BatchStepLinearWarmupLRCallback(Callback):
    LR scheduling callback for linear step warmup on each batch step.
    LR climbs from warmup_initial_lr to initial_lr.
    def __init__(
        self,
        warmup_initial_lr: float,
        initial_lr: float,
        train_loader_len: int,
        update_param_groups: bool,
        lr_warmup_steps: int,
        training_params,
        net,
        **kwargs,
        :param warmup_initial_lr: Starting learning rate
        :param initial_lr: Target learning rate after warmup
        :param train_loader_len: Length of train data loader
        :param lr_warmup_steps: Optional. If passed, will use fixed number of warmup steps to warmup LR. Default is None.
        :param kwargs:
        super(BatchStepLinearWarmupLRCallback, self).__init__()
        if lr_warmup_steps > train_loader_len:
            logger.warning(
                f"Number of warmup steps ({lr_warmup_steps}) is greater than number of steps in epoch ({train_loader_len}). "
                f"Warmup steps will be capped to number of steps in epoch to avoid interfering with any pre-epoch LR schedulers."
        lr_warmup_steps = min(lr_warmup_steps, train_loader_len)
        learning_rates = np.linspace(start=warmup_initial_lr, stop=initial_lr, num=lr_warmup_steps, endpoint=True)
        self.lr = initial_lr
        self.initial_lr = initial_lr
        self.update_param_groups = update_param_groups
        self.training_params = training_params
        self.net = net
        self.learning_rates = learning_rates
        self.train_loader_len = train_loader_len
        self.lr_warmup_steps = lr_warmup_steps
    def on_train_batch_start(self, context: PhaseContext) -> None:
        global_training_step = context.batch_idx + context.epoch * self.train_loader_len
        if global_training_step < self.lr_warmup_steps:
            self.lr = float(self.learning_rates[global_training_step])
            self.update_lr(context.optimizer, context.epoch, context.batch_idx)
    def update_lr(self, optimizer, epoch, batch_idx=None):
        Same as in LRCallbackBase
        :param optimizer:
        :param epoch:
        :param batch_idx:
        :return:
        if self.update_param_groups:
            param_groups = unwrap_model(self.net).update_param_groups(
                optimizer.param_groups, self.lr, epoch, batch_idx, self.training_params, self.train_loader_len
            optimizer.param_groups = param_groups
        else:
            # UPDATE THE OPTIMIZERS PARAMETER
            for param_group in optimizer.param_groups:
                param_group["lr"] = self.lr
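
The per-step warmup values are a plain linear ramp; a standalone sketch with illustrative numbers:

import numpy as np

warmup_initial_lr, initial_lr, lr_warmup_steps = 1e-6, 1e-3, 5
learning_rates = np.linspace(start=warmup_initial_lr, stop=initial_lr, num=lr_warmup_steps, endpoint=True)
print(learning_rates)
# approximately [1e-06, 2.5075e-04, 5.005e-04, 7.5025e-04, 1e-03]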
          

Same as in LRCallbackBase.
def update_lr(self, optimizer, epoch, batch_idx=None):
    Same as in LRCallbackBase
    :param optimizer:
    :param epoch:
    :param batch_idx:
    :return:
    if self.update_param_groups:
        param_groups = unwrap_model(self.net).update_param_groups(
            optimizer.param_groups, self.lr, epoch, batch_idx, self.training_params, self.train_loader_len
        optimizer.param_groups = param_groups
    else:
        # UPDATE THE OPTIMIZERS PARAMETER
        for param_group in optimizer.param_groups:
            param_group["lr"] = self.lr
      

A callback that adds a visualization of a batch of segmentation predictions to context.sg_logger

class BinarySegmentationVisualizationCallback(PhaseCallback):
    A callback that adds a visualization of a batch of segmentation predictions to context.sg_logger
    :param phase:                   When to trigger the callback.
    :param freq:                    Frequency (in epochs) to perform this callback.
    :param batch_idx:               Batch index to perform visualization for.
    :param last_img_idx_in_batch:   Last image index to add to log. (default=-1, will take entire batch).
    def __init__(self, phase: Phase, freq: int, batch_idx: int = 0, last_img_idx_in_batch: int = -1):
        super(BinarySegmentationVisualizationCallback, self).__init__(phase)
        self.freq = freq
        self.batch_idx = batch_idx
        self.last_img_idx_in_batch = last_img_idx_in_batch
    def __call__(self, context: PhaseContext):
        if context.epoch % self.freq == 0 and context.batch_idx == self.batch_idx:
            if isinstance(context.preds, tuple):
                preds = context.preds[0].clone()
            else:
                preds = context.preds.clone()
            batch_imgs = BinarySegmentationVisualization.visualize_batch(context.inputs, preds, context.target, self.batch_idx)
            batch_imgs = [cv2.cvtColor(image, cv2.COLOR_BGR2RGB) for image in batch_imgs]
            batch_imgs = np.stack(batch_imgs)
            tag = "batch_" + str(self.batch_idx) + "_images"
            context.sg_logger.add_images(tag=tag, images=batch_imgs[: self.last_img_idx_in_batch], global_step=context.epoch, data_format="NHWC")
@register_lr_scheduler(LRSchedulers.COSINE)
class CosineLRCallback(LRCallbackBase):
    Hard coded step cosine annealing learning rate scheduling.
    def __init__(self, max_epochs, cosine_final_lr_ratio, **kwargs):
        super(CosineLRCallback, self).__init__(Phase.TRAIN_BATCH_STEP, **kwargs)
        self.max_epochs = max_epochs
        self.cosine_final_lr_ratio = cosine_final_lr_ratio
    def perform_scheduling(self, context):
        effective_epoch = context.epoch - self.training_params.lr_warmup_epochs
        effective_max_epochs = self.max_epochs - self.training_params.lr_warmup_epochs - self.training_params.lr_cooldown_epochs
        current_iter = max(0, self.train_loader_len * effective_epoch + context.batch_idx - self.training_params.lr_warmup_steps)
        max_iter = self.train_loader_len * effective_max_epochs - self.training_params.lr_warmup_steps
        lr = self.compute_learning_rate(current_iter, max_iter, self.initial_lr, self.cosine_final_lr_ratio)
        self.lr = float(lr)
        self.update_lr(context.optimizer, context.epoch, context.batch_idx)
    def is_lr_scheduling_enabled(self, context):
        # Account of per-step warmup
        if self.training_params.lr_warmup_steps > 0:
            current_step = self.train_loader_len * context.epoch + context.batch_idx
            return current_step >= self.training_params.lr_warmup_steps
        post_warmup_epochs = self.training_params.max_epochs - self.training_params.lr_cooldown_epochs
        return self.training_params.lr_warmup_epochs <= context.epoch < post_warmup_epochs
    @classmethod
    def compute_learning_rate(cls, step: Union[float, np.ndarray], total_steps: float, initial_lr: float, final_lr_ratio: float):
        # the cosine starts from initial_lr and reaches initial_lr * cosine_final_lr_ratio in last epoch
        lr = 0.5 * initial_lr * (1.0 + np.cos(step / (total_steps + 1) * math.pi))
        return lr * (1 - final_lr_ratio) + (initial_lr * final_lr_ratio)
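
A worked sketch of the cosine schedule above (a standalone re-implementation of the same formula, with illustrative numbers):

import math

def cosine_lr(step, total_steps, initial_lr, final_lr_ratio):
    # Mirrors CosineLRCallback.compute_learning_rate above.
    lr = 0.5 * initial_lr * (1.0 + math.cos(step / (total_steps + 1) * math.pi))
    return lr * (1 - final_lr_ratio) + (initial_lr * final_lr_ratio)

for step in (0, 50, 100):
    print(step, round(cosine_lr(step, total_steps=100, initial_lr=0.1, final_lr_ratio=0.01), 5))
# 0 0.1        (starts at initial_lr)
# 50 0.05127   (roughly halfway down the cosine)
# 100 0.00102  (approaches initial_lr * final_lr_ratio)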
@register_callback(Callbacks.DECI_LAB_UPLOAD)
class DeciLabUploadCallback(PhaseCallback):
    Post-training callback for uploading and optimizing a model.
    :param model_meta_data:             Model's meta-data object. Type: ModelMetadata
    :param optimization_request_form:   Optimization request form object. Type: OptimizationRequestForm
    :param ckpt_name:                   Checkpoint filename, inside the checkpoint directory.
    def __init__(
        self,
        model_name: str,
        input_dimensions: Sequence[int],
        target_hardware_types: "Optional[List[str]]" = None,
        target_batch_size: "Optional[int]" = None,
        target_quantization_level: "Optional[str]" = None,
        ckpt_name: str = "ckpt_best.pth",
        **kwargs,
        super().__init__(phase=Phase.POST_TRAINING)
        self.input_dimensions = input_dimensions
        self.model_name = model_name
        self.target_hardware_types = target_hardware_types
        self.target_batch_size = target_batch_size
        self.target_quantization_level = target_quantization_level
        self.ckpt_name = ckpt_name
        self.platform_client = DeciClient()
    @staticmethod
    def log_optimization_failed():
        logger.info("We couldn't finish your model optimization. Visit https://console.deci.ai for details")
    def upload_model(self, model):
        This function will upload the trained model to the Deci Lab
        :param model: The resulting model from the training process
        self.platform_client.upload_model(
            model=model,
            name=self.model_name,
            input_dimensions=self.input_dimensions,
            target_hardware_types=self.target_hardware_types,
            target_batch_size=self.target_batch_size,
            target_quantization_level=self.target_quantization_level,
    def get_optimization_status(self, optimized_model_name: str):
        This function will fetch the optimized version of the trained model and check on its benchmark status.
        The status will be checked against the server every 30 seconds and the process will timeout after 30 minutes
        or log about the successful optimization - whichever happens first.
        :param optimized_model_name: Optimized model name
        :return: Whether or not the optimized model has been benchmarked
        def handler(_signum, _frame):
            logger.error("Process timed out. Visit https://console.deci.ai for details")
            return False
        signal.signal(signal.SIGALRM, handler)
        signal.alarm(1800)
        finished = False
        while not finished:
            if self.platform_client.is_model_benchmarking(name=optimized_model_name):
                time.sleep(30)
            else:
                finished = True
        signal.alarm(0)
        return True
    def __call__(self, context: PhaseContext) -> None:
        This function will attempt to upload the trained model and schedule an optimization for it.
        :param context: Training phase context
        try:
            model = copy.deepcopy(unwrap_model(context.net))
            model_state_dict_path = os.path.join(context.ckpt_dir, self.ckpt_name)
            model_state_dict = torch.load(model_state_dict_path)["net"]
            model.load_state_dict(state_dict=model_state_dict)
            model = model.cpu()
            if hasattr(model, "prep_model_for_conversion"):
                model.prep_model_for_conversion(input_size=self.input_dimensions)
            self.upload_model(model=model)
            model_name = self.model_name
            logger.info(f"Successfully added {model_name} to the model repository")
            optimized_model_name = f"{model_name}_1_1"
            logger.info("We'll wait for the scheduled optimization to finish. Please don't close this window")
            success = self.get_optimization_status(optimized_model_name=optimized_model_name)
            if success:
                logger.info("Successfully finished your model optimization. Visit https://console.deci.ai for details")
            else:
                DeciLabUploadCallback.log_optimization_failed()
        except Exception as ex:
            DeciLabUploadCallback.log_optimization_failed()
            logger.error(ex)
      

This function will attempt to upload the trained model and schedule an optimization for it.

def __call__(self, context: PhaseContext) -> None:
    This function will attempt to upload the trained model and schedule an optimization for it.
    :param context: Training phase context
    try:
        model = copy.deepcopy(unwrap_model(context.net))
        model_state_dict_path = os.path.join(context.ckpt_dir, self.ckpt_name)
        model_state_dict = torch.load(model_state_dict_path)["net"]
        model.load_state_dict(state_dict=model_state_dict)
        model = model.cpu()
        if hasattr(model, "prep_model_for_conversion"):
            model.prep_model_for_conversion(input_size=self.input_dimensions)
        self.upload_model(model=model)
        model_name = self.model_name
        logger.info(f"Successfully added {model_name} to the model repository")
        optimized_model_name = f"{model_name}_1_1"
        logger.info("We'll wait for the scheduled optimization to finish. Please don't close this window")
        success = self.get_optimization_status(optimized_model_name=optimized_model_name)
        if success:
            logger.info("Successfully finished your model optimization. Visit https://console.deci.ai for details")
        else:
            DeciLabUploadCallback.log_optimization_failed()
    except Exception as ex:
        DeciLabUploadCallback.log_optimization_failed()
        logger.error(ex)
      

This function will fetch the optimized version of the trained model and check its benchmark status. The status is checked against the server every 30 seconds; the process either times out after 30 minutes or logs the successful optimization, whichever happens first.

def get_optimization_status(self, optimized_model_name: str):
    This function will fetch the optimized version of the trained model and check on its benchmark status.
    The status will be checked against the server every 30 seconds and the process will timeout after 30 minutes
    or log about the successful optimization - whichever happens first.
    :param optimized_model_name: Optimized model name
    :return: Whether or not the optimized model has been benchmarked
    def handler(_signum, _frame):
        logger.error("Process timed out. Visit https://console.deci.ai for details")
        return False
    signal.signal(signal.SIGALRM, handler)
    signal.alarm(1800)
    finished = False
    while not finished:
        if self.platform_client.is_model_benchmarking(name=optimized_model_name):
            time.sleep(30)
        else:
            finished = True
    signal.alarm(0)
    return True
def upload_model(self, model):
    This function will upload the trained model to the Deci Lab
    :param model: The resulting model from the training process
    self.platform_client.upload_model(
        model=model,
        name=self.model_name,
        input_dimensions=self.input_dimensions,
        target_hardware_types=self.target_hardware_types,
        target_batch_size=self.target_batch_size,
        target_quantization_level=self.target_quantization_level,
      

A callback that adds a visualization of a batch of detection predictions to context.sg_logger

@register_callback(Callbacks.DETECTION_VISUALIZATION_CALLBACK)
class DetectionVisualizationCallback(PhaseCallback):
    A callback that adds a visualization of a batch of detection predictions to context.sg_logger
    :param phase:                   When to trigger the callback.
    :param freq:                    Frequency (in epochs) to perform this callback.
    :param batch_idx:               Batch index to perform visualization for.
    :param classes:                 Class list of the dataset.
    :param last_img_idx_in_batch:   Last image index to add to log. (default=-1, will take entire batch).
    def __init__(
        self,
        phase: Phase,
        freq: int,
        post_prediction_callback: DetectionPostPredictionCallback,
        classes: list,
        batch_idx: int = 0,
        last_img_idx_in_batch: int = -1,
        super(DetectionVisualizationCallback, self).__init__(phase)
        self.freq = freq
        self.post_prediction_callback = post_prediction_callback
        self.batch_idx = batch_idx
        self.classes = classes
        self.last_img_idx_in_batch = last_img_idx_in_batch
    def __call__(self, context: PhaseContext):
        if context.epoch % self.freq == 0 and context.batch_idx == self.batch_idx:
            # SOME CALCULATIONS ARE IN-PLACE IN NMS, SO CLONE THE PREDICTIONS
            preds = (context.preds[0].clone(), None)
            preds = self.post_prediction_callback(preds)
            batch_imgs = DetectionVisualization.visualize_batch(context.inputs, preds, context.target, self.batch_idx, self.classes)
            batch_imgs = [cv2.cvtColor(image, cv2.COLOR_BGR2RGB) for image in batch_imgs]
            batch_imgs = np.stack(batch_imgs)
            tag = "batch_" + str(self.batch_idx) + "_images"
            context.sg_logger.add_images(tag=tag, images=batch_imgs[: self.last_img_idx_in_batch], global_step=context.epoch, data_format="NHWC")
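
A construction sketch for this callback (the post-prediction callback and class list are hypothetical placeholders for your own detector and dataset, and Phase is assumed to be imported):

visualization_callback = DetectionVisualizationCallback(
    phase=Phase.VALIDATION_BATCH_END,
    freq=1,                                        # visualize every epoch
    post_prediction_callback=my_nms_callback,      # e.g. the NMS callback matching your detector
    classes=my_dataset_classes,                    # list of class names
    batch_idx=0,
    last_img_idx_in_batch=4,
)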
      

LR scheduling callback for linear step warmup. This scheduler uses a whole epoch as a single step. LR climbs from warmup_initial_lr in even steps to initial_lr. When warmup_initial_lr is None, the LR climb starts from initial_lr/(1+warmup_epochs).

Source code in V3_2/src/super_gradients/training/utils/callbacks/callbacks.py
@register_lr_warmup(LRWarmups.LINEAR_EPOCH_STEP)
class EpochStepWarmupLRCallback(LRCallbackBase):
    LR scheduling callback for linear step warmup. This scheduler uses a whole epoch as single step.
    LR climbs from warmup_initial_lr with even steps to initial lr. When warmup_initial_lr is None - LR climb starts from
     initial_lr/(1+warmup_epochs).
    def __init__(self, **kwargs):
        super(EpochStepWarmupLRCallback, self).__init__(Phase.TRAIN_EPOCH_START, **kwargs)
        self.warmup_initial_lr = self.training_params.warmup_initial_lr or self.initial_lr / (self.training_params.lr_warmup_epochs + 1)
        self.warmup_step_size = (
            (self.initial_lr - self.warmup_initial_lr) / self.training_params.lr_warmup_epochs if self.training_params.lr_warmup_epochs > 0 else 0
    def perform_scheduling(self, context):
        self.lr = self.warmup_initial_lr + context.epoch * self.warmup_step_size
        self.update_lr(context.optimizer, context.epoch, None)
    def is_lr_scheduling_enabled(self, context):
        return self.training_params.lr_warmup_epochs > 0 and self.training_params.lr_warmup_epochs >= context.epoch
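
For intuition, the per-epoch values this produces with the default warmup_initial_lr (illustrative numbers):

initial_lr, lr_warmup_epochs = 0.1, 3
warmup_initial_lr = initial_lr / (lr_warmup_epochs + 1)                  # 0.025, the default when warmup_initial_lr is None
warmup_step_size = (initial_lr - warmup_initial_lr) / lr_warmup_epochs   # 0.025

for epoch in range(lr_warmup_epochs + 1):
    print(epoch, round(warmup_initial_lr + epoch * warmup_step_size, 6))
# 0 0.025
# 1 0.05
# 2 0.075
# 3 0.1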
@register_lr_scheduler(LRSchedulers.EXP)
class ExponentialLRCallback(LRCallbackBase):
    Exponential decay learning rate scheduling. Decays the learning rate by `lr_decay_factor` every epoch.
    def __init__(self, lr_decay_factor: float, **kwargs):
        super().__init__(phase=Phase.TRAIN_BATCH_STEP, **kwargs)
        self.lr_decay_factor = lr_decay_factor
    def perform_scheduling(self, context):
        effective_epoch = context.epoch - self.training_params.lr_warmup_epochs
        current_iter = self.train_loader_len * effective_epoch + context.batch_idx
        self.lr = self.initial_lr * self.lr_decay_factor ** (current_iter / self.train_loader_len)
        self.update_lr(context.optimizer, context.epoch, context.batch_idx)
    def is_lr_scheduling_enabled(self, context):
        post_warmup_epochs = self.training_params.max_epochs - self.training_params.lr_cooldown_epochs
        return self.training_params.lr_warmup_epochs <= context.epoch < post_warmup_epochs
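
A quick sketch of the resulting decay curve (illustrative numbers, with batch_idx fixed at 0):

initial_lr, lr_decay_factor, train_loader_len = 0.1, 0.9, 100

for epoch in (0, 1, 5, 10):
    current_iter = train_loader_len * epoch
    lr = initial_lr * lr_decay_factor ** (current_iter / train_loader_len)
    print(epoch, round(lr, 5))
# 0 0.1
# 1 0.09
# 5 0.05905
# 10 0.03487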
@register_lr_scheduler(LRSchedulers.FUNCTION)
class FunctionLRCallback(LRCallbackBase):
    Hard coded rate scheduling for user defined lr scheduling function.
    @deprecated(version="3.2.0", reason="This callback is deprecated and will be removed in future versions.")
    def __init__(self, max_epochs, lr_schedule_function, **kwargs):
        super(FunctionLRCallback, self).__init__(Phase.TRAIN_BATCH_STEP, **kwargs)
        assert callable(lr_schedule_function), "self.lr_function must be callable"
        self.lr_schedule_function = lr_schedule_function
        self.max_epochs = max_epochs
    def is_lr_scheduling_enabled(self, context):
        post_warmup_epochs = self.training_params.max_epochs - self.training_params.lr_cooldown_epochs
        return self.training_params.lr_warmup_epochs <= context.epoch < post_warmup_epochs
    def perform_scheduling(self, context):
        effective_epoch = context.epoch - self.training_params.lr_warmup_epochs
        effective_max_epochs = self.max_epochs - self.training_params.lr_warmup_epochs - self.training_params.lr_cooldown_epochs
        self.lr = self.lr_schedule_function(
            initial_lr=self.initial_lr,
            epoch=effective_epoch,
            iter=context.batch_idx,
            max_epoch=effective_max_epochs,
            iters_per_epoch=self.train_loader_len,
        self.update_lr(context.optimizer, context.epoch, context.batch_idx)
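
A hypothetical lr_schedule_function matching the keyword arguments passed above (simple linear decay, for illustration only):

def my_lr_schedule(initial_lr, epoch, iter, max_epoch, iters_per_epoch, **kwargs):
    # `iter` mirrors the keyword used by FunctionLRCallback (the batch index within the epoch).
    progress = (epoch * iters_per_epoch + iter) / float(max_epoch * iters_per_epoch)
    return initial_lr * (1.0 - progress)

Such a function would be passed as lr_schedule_function in the training parameters; check your SuperGradients version for the exact key names.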
class IllegalLRSchedulerMetric(Exception):
    """Exception raised for an illegal combination of training parameters.

    :param metric_name: Name of the metric that is not supported.
    :param metrics_dict: Dictionary of metrics that are supported.
    """
    def __init__(self, metric_name: str, metrics_dict: dict):
        self.message = "Illegal metric name: " + metric_name + ". Expected one of metics_dics keys: " + str(metrics_dict.keys())
        super().__init__(self.message)
@register_callback(Callbacks.LR_CALLBACK_BASE)
class LRCallbackBase(PhaseCallback):
    Base class for hard coded learning rate scheduling regimes, implemented as callbacks.
    def __init__(self, phase, initial_lr, update_param_groups, train_loader_len, net, training_params, **kwargs):
        super(LRCallbackBase, self).__init__(phase)
        self.initial_lr = initial_lr
        self.lr = initial_lr
        self.update_param_groups = update_param_groups
        self.train_loader_len = train_loader_len
        self.net = net
        self.training_params = training_params
    def __call__(self, context: PhaseContext, **kwargs):
        if self.is_lr_scheduling_enabled(context):
            self.perform_scheduling(context)
    def is_lr_scheduling_enabled(self, context: PhaseContext):
        Predicate that controls whether to perform lr scheduling based on values in context.
        :param context: PhaseContext: current phase's context.
        :return: bool, whether to apply lr scheduling or not.
        raise NotImplementedError
    def perform_scheduling(self, context: PhaseContext):
        Performs lr scheduling based on values in context.
        :param context: PhaseContext: current phase's context.
        raise NotImplementedError
    def update_lr(self, optimizer, epoch, batch_idx=None):
        if self.update_param_groups:
            param_groups = unwrap_model(self.net).update_param_groups(
                optimizer.param_groups, self.lr, epoch, batch_idx, self.training_params, self.train_loader_len
            )
            optimizer.param_groups = param_groups
        else:
            # UPDATE THE OPTIMIZER'S PARAMETER GROUPS
            for param_group in optimizer.param_groups:
                param_group["lr"] = self.lr

Learning rate scheduler callback.

When passing __call__ a metrics_dict with a key=self.metric_name, the value of that metric will be monitored for ReduceLROnPlateau (i.e. step(metrics_dict[self.metric_name])).

Parameters:

scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler to be called step() with. (required)
metric_name (str): Metric name for ReduceLROnPlateau learning rate scheduler. (optional)
phase (Phase): Phase of when to trigger it. (required)
@register_callback(Callbacks.LR_SCHEDULER)
class LRSchedulerCallback(PhaseCallback):
    Learning rate scheduler callback.
    When passing __call__ a metrics_dict, with a key=self.metric_name, the value of that metric will be monitored
         for ReduceLROnPlateau (i.e. step(metrics_dict[self.metric_name])).
    :param scheduler:       Learning rate scheduler to be called step() with.
    :param metric_name:     Metric name for ReduceLROnPlateau learning rate scheduler.
    :param phase:           Phase of when to trigger it.
    def __init__(self, scheduler: torch.optim.lr_scheduler._LRScheduler, phase: Phase, metric_name: str = None):
        super(LRSchedulerCallback, self).__init__(phase)
        self.scheduler = scheduler
        self.metric_name = metric_name
    def __call__(self, context: PhaseContext):
        if context.lr_warmup_epochs <= context.epoch:
            if self.metric_name and self.metric_name in context.metrics_dict.keys():
                self.scheduler.step(context.metrics_dict[self.metric_name])
            elif self.metric_name is None:
                self.scheduler.step()
            else:
                raise IllegalLRSchedulerMetric(self.metric_name, context.metrics_dict)
    def __repr__(self):
        return "LRSchedulerCallback: " + repr(self.scheduler)
@register_lr_warmup(LRWarmups.LINEAR_STEP)
class LinearStepWarmupLRCallback(EpochStepWarmupLRCallback):
    """Deprecated, use EpochStepWarmupLRCallback instead"""
    def __init__(self, **kwargs):
        logger.warning(
            f"Parameter {LRWarmups.LINEAR_STEP} has been made deprecated and will be removed in the next SG release. "
            f"Please use `{LRWarmups.LINEAR_EPOCH_STEP}` instead."
        )
        super(LinearStepWarmupLRCallback, self).__init__(**kwargs)
      

Pre-training callback that verifies model conversion to onnx given specified conversion parameters.

The model is converted, then inference is applied with onnx runtime.

Use this callback with the same args as DeciPlatformCallback to prevent conversion fails at the end of training.

@register_callback(Callbacks.MODEL_CONVERSION_CHECK)
class ModelConversionCheckCallback(PhaseCallback):
    Pre-training callback that verifies model conversion to onnx given specified conversion parameters.
    The model is converted, then inference is applied with onnx runtime.
    Use this callback with the same args as DeciPlatformCallback to prevent conversion fails at the end of training.
    :param model_name:              Model's name
    :param input_dimensions:        Model's input dimensions
    :param primary_batch_size:      Model's primary batch size
    :param opset_version:           (default=11)
    :param do_constant_folding:     (default=True)
    :param dynamic_axes:            (default={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})
    :param input_names:             (default=["input"])
    :param output_names:            (default=["output"])
    :param rtol:                    (default=1e-03)
    :param atol:                    (default=1e-05)
    def __init__(self, model_name: str, input_dimensions: Sequence[int], primary_batch_size: int, **kwargs):
        super(ModelConversionCheckCallback, self).__init__(phase=Phase.PRE_TRAINING)
        self.model_name = model_name
        self.input_dimensions = input_dimensions
        self.primary_batch_size = primary_batch_size
        self.opset_version = kwargs.get("opset_version", 10)
        self.do_constant_folding = kwargs.get("do_constant_folding", None) if kwargs.get("do_constant_folding", None) else True
        self.input_names = kwargs.get("input_names") or ["input"]
        self.output_names = kwargs.get("output_names") or ["output"]
        self.dynamic_axes = kwargs.get("dynamic_axes") or {"input": {0: "batch_size"}, "output": {0: "batch_size"}}
        self.rtol = kwargs.get("rtol", 1e-03)
        self.atol = kwargs.get("atol", 1e-05)
    def __call__(self, context: PhaseContext):
        model = copy.deepcopy(unwrap_model(context.net))
        model = model.cpu()
        model.eval()  # Put model into eval mode
        if hasattr(model, "prep_model_for_conversion"):
            model.prep_model_for_conversion(input_size=self.input_dimensions)
        x = torch.randn(self.primary_batch_size, *self.input_dimensions, requires_grad=False)
        tmp_model_path = os.path.join(context.ckpt_dir, self.model_name + "_tmp.onnx")
        with torch.no_grad():
            torch_out = model(x)
        torch.onnx.export(
            model,  # Model being run
            x,  # Model input (or a tuple for multiple inputs)
            tmp_model_path,  # Where to save the model (can be a file or file-like object)
            export_params=True,  # Store the trained parameter weights inside the model file
            opset_version=self.opset_version,
            do_constant_folding=self.do_constant_folding,
            input_names=self.input_names,
            output_names=self.output_names,
            dynamic_axes=self.dynamic_axes,
        )
        onnx_model = onnx.load(tmp_model_path)
        onnx.checker.check_model(onnx_model)
        ort_session = onnxruntime.InferenceSession(tmp_model_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
        # compute ONNX Runtime output prediction
        ort_inputs = {ort_session.get_inputs()[0].name: x.cpu().numpy()}
        ort_outs = ort_session.run(None, ort_inputs)
        # TODO: Ideally we don't want to check this but have the certainty of just calling torch_out.cpu()
        if isinstance(torch_out, List) or isinstance(torch_out, tuple):
            torch_out = torch_out[0]
        # compare ONNX Runtime and PyTorch results
        np.testing.assert_allclose(torch_out.cpu().numpy(), ort_outs[0], rtol=self.rtol, atol=self.atol)
        os.remove(tmp_model_path)
        logger.info("Exported model has been tested with ONNXRuntime, and the result looks good!")
class PhaseContextTestCallback(PhaseCallback):
    A callback that saves the phase context for testing.
    def __init__(self, phase: Phase):
        super(PhaseContextTestCallback, self).__init__(phase)
        self.context = None
    def __call__(self, context: PhaseContext):
        self.context = context
@register_lr_scheduler(LRSchedulers.POLY)
class PolyLRCallback(LRCallbackBase):
    Hard coded polynomial decay learning rate scheduling.
    def __init__(self, max_epochs, **kwargs):
        super(PolyLRCallback, self).__init__(Phase.TRAIN_BATCH_STEP, **kwargs)
        self.max_epochs = max_epochs
    def perform_scheduling(self, context):
        effective_epoch = context.epoch - self.training_params.lr_warmup_epochs
        effective_max_epochs = self.max_epochs - self.training_params.lr_warmup_epochs - self.training_params.lr_cooldown_epochs
        current_iter = (self.train_loader_len * effective_epoch + context.batch_idx) / self.training_params.batch_accumulate
        max_iter = self.train_loader_len * effective_max_epochs / self.training_params.batch_accumulate
        self.lr = self.initial_lr * pow((1.0 - (current_iter / max_iter)), 0.9)
        self.update_lr(context.optimizer, context.epoch, context.batch_idx)
    def is_lr_scheduling_enabled(self, context):
        post_warmup_epochs = self.training_params.max_epochs - self.training_params.lr_cooldown_epochs
        return self.training_params.lr_warmup_epochs <= context.epoch < post_warmup_epochs
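
A small standalone sketch of the polynomial decay formula applied in perform_scheduling above (all values are arbitrary):

initial_lr = 0.1
train_loader_len, batch_accumulate = 500, 1
effective_max_epochs = 100

max_iter = train_loader_len * effective_max_epochs / batch_accumulate
for epoch in (0, 50, 99):
    current_iter = train_loader_len * epoch / batch_accumulate
    lr = initial_lr * pow(1.0 - current_iter / max_iter, 0.9)
    print(epoch, round(lr, 5))  # decays from 0.1 towards 0 as training progresses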
@register_callback(Callbacks.ROBOFLOW_RESULT_CALLBACK)
class RoboflowResultCallback(Callback):
    """Append the training results to a csv file. Be aware that this does not fully overwrite the existing file, just appends."""
    def __init__(self, dataset_name: str, output_path: Optional[str] = None):
        :param dataset_name:    Name of the dataset that was used to train the model.
        :param output_path:     Full path to the output csv file. By default, save at 'checkpoint_dir/results.csv'
        self.dataset_name = dataset_name
        self.output_path = output_path or os.path.join(get_project_checkpoints_dir_path(), "results.csv")
        if self.output_path is None:
            raise ValueError("Output path must be specified")
        super(RoboflowResultCallback, self).__init__()
    @multi_process_safe
    def on_training_end(self, context: PhaseContext):
        with open(self.output_path, mode="a", newline="") as csv_file:
            writer = csv.writer(csv_file)
            mAP = context.metrics_dict["mAP@0.50:0.95"].item()
            writer.writerow([self.dataset_name, mAP])
@register_lr_scheduler(LRSchedulers.STEP)
class StepLRCallback(LRCallbackBase):
    Hard coded step learning rate scheduling (i.e at specific milestones).
    def __init__(self, lr_updates, lr_decay_factor, step_lr_update_freq=None, **kwargs):
        super(StepLRCallback, self).__init__(Phase.TRAIN_EPOCH_END, **kwargs)
        if step_lr_update_freq and len(lr_updates):
            raise ValueError("Only one of [lr_updates, step_lr_update_freq] should be passed to StepLRCallback constructor")
        if step_lr_update_freq:
            max_epochs = self.training_params.max_epochs - self.training_params.lr_cooldown_epochs
            warmup_epochs = self.training_params.lr_warmup_epochs
            lr_updates = [
                int(np.ceil(step_lr_update_freq * x)) for x in range(1, max_epochs) if warmup_epochs <= int(np.ceil(step_lr_update_freq * x)) < max_epochs
            ]
        elif self.training_params.lr_cooldown_epochs > 0:
            logger.warning("Specific lr_updates were passed along with cooldown_epochs > 0," " cooldown will have no effect.")
        self.lr_updates = lr_updates
        self.lr_decay_factor = lr_decay_factor
    def perform_scheduling(self, context):
        num_updates_passed = [x for x in self.lr_updates if x <= context.epoch]
        self.lr = self.initial_lr * self.lr_decay_factor ** len(num_updates_passed)
        self.update_lr(context.optimizer, context.epoch, None)
    def is_lr_scheduling_enabled(self, context):
        return self.training_params.lr_warmup_epochs <= context.epoch
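
A quick standalone sketch of how the step policy resolves the learning rate at a given epoch (milestones and decay factor are arbitrary):

initial_lr, lr_decay_factor = 0.1, 0.1
lr_updates = [30, 60, 90]

for epoch in (10, 45, 95):
    num_updates_passed = [x for x in lr_updates if x <= epoch]
    lr = initial_lr * lr_decay_factor ** len(num_updates_passed)
    print(epoch, lr)  # 0.1 (no milestone passed), 0.01 (one passed), 0.0001 (all three passed)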
      

Phase callback that collects the learning rates in lr_placeholder at the end of each epoch (used for testing). In the case of multiple parameter groups (i.e multiple learning rates) the learning rate is collected from the first one. The phase is VALIDATION_EPOCH_END to ensure all lr updates have been performed before calling this callback.

class TestLRCallback(PhaseCallback):
    Phase callback that collects the learning rates in lr_placeholder at the end of each epoch (used for testing). In
     the case of multiple parameter groups (i.e multiple learning rates) the learning rate is collected from the first
     one. The phase is VALIDATION_EPOCH_END to ensure all lr updates have been performed before calling this callback.
    def __init__(self, lr_placeholder):
        super(TestLRCallback, self).__init__(Phase.VALIDATION_EPOCH_END)
        self.lr_placeholder = lr_placeholder
    def __call__(self, context: PhaseContext):
        self.lr_placeholder.append(context.optimizer.param_groups[0]["lr"])
      

TrainingStageSwitchCallback

A phase callback that is called at a specific epoch (epoch start) to support multi-stage training. It does so by manipulating the objects inside the context.

class TrainingStageSwitchCallbackBase(PhaseCallback):
    TrainingStageSwitchCallback
    A phase callback that is called at a specific epoch (epoch start) to support multi-stage training.
    It does so by manipulating the objects inside the context.
    :param next_stage_start_epoch: Epoch idx to apply the stage change.
    def __init__(self, next_stage_start_epoch: int):
        super(TrainingStageSwitchCallbackBase, self).__init__(phase=Phase.TRAIN_EPOCH_START)
        self.next_stage_start_epoch = next_stage_start_epoch
    def __call__(self, context: PhaseContext):
        if context.epoch == self.next_stage_start_epoch:
            self.apply_stage_change(context)
    def apply_stage_change(self, context: PhaseContext):
        This method is called when the callback is fired on the next_stage_start_epoch,
         and holds the stage change logic that should be applied to the context's objects.
        :param context: PhaseContext, context of current phase
        raise NotImplementedError
      

@register_callback(Callbacks.YOLOX_TRAINING_STAGE_SWITCH)
class YoloXTrainingStageSwitchCallback(TrainingStageSwitchCallbackBase):
    YoloXTrainingStageSwitchCallback
    Training stage switch for YoloX training.
    Disables mosaic, and manipulates YoloX loss to use L1.
    def __init__(self, next_stage_start_epoch: int = 285):
        super(YoloXTrainingStageSwitchCallback, self).__init__(next_stage_start_epoch=next_stage_start_epoch)
    def apply_stage_change(self, context: PhaseContext):
        for transform in context.train_loader.dataset.transforms:
            if hasattr(transform, "close"):
                transform.close()
        iter(context.train_loader)
        context.criterion.use_l1 = True
def create_lr_scheduler_callback(
    lr_mode: Union[str, Mapping],
    train_loader: DataLoader,
    net: torch.nn.Module,
    training_params: Mapping,
    update_param_groups: bool,
    optimizer: torch.optim.Optimizer,
) -> PhaseCallback:
    Creates the phase callback in charge of LR scheduling, to be used by Trainer.
    :param lr_mode: Union[str, Mapping],
                    When str:
                    Learning rate scheduling policy, one of ['step','poly','cosine','function'].
                    'step' refers to constant updates at epoch numbers passed through `lr_updates`. Each update decays the learning rate by `lr_decay_factor`.
                    'cosine' refers to the Cosine Annealing policy as mentioned in https://arxiv.org/abs/1608.03983.
                      The final learning rate ratio is controlled by `cosine_final_lr_ratio` training parameter.
                    'poly' refers to the polynomial decrease: in each epoch iteration `self.lr = self.initial_lr * pow((1.0 - (current_iter / max_iter)), 0.9)`
                    'function' refers to a user-defined learning rate scheduling function, that is passed through `lr_schedule_function`.
                    When Mapping, refers to a torch.optim.lr_scheduler._LRScheduler, following the below API:
                        lr_mode = {LR_SCHEDULER_CLASS_NAME: {**LR_SCHEDULER_KWARGS, "phase": XXX, "metric_name": XXX}}
                        Where "phase" (of Phase type) controls when to call torch.optim.lr_scheduler._LRScheduler.step().
                        For instance, in order to:
                        - Update LR on each batch: Use phase: Phase.TRAIN_BATCH_END
                        - Update LR after each epoch: Use phase: Phase.TRAIN_EPOCH_END
                        The "metric_name" refers to the metric to watch (See docs for "metric_to_watch" in train(...)
                         https://docs.deci.ai/super-gradients/docstring/training/sg_trainer.html) when using
                          ReduceLROnPlateau. In any other case this kwarg is ignored.
                        **LR_SCHEDULER_KWARGS are simply passed to the torch scheduler's __init__.
    :param train_loader: DataLoader, the Trainer.train_loader used for training.
    :param net: torch.nn.Module, the Trainer.net used for training.
    :param training_params: Mapping, Trainer.training_params.
    :param update_param_groups: bool, Whether the Trainer.net has a specific way of updating its parameter group.
    :param optimizer: The optimizer used for training. Will be passed to the LR callback's __init__
     (or the torch scheduler's init, depending on the lr_mode value as described above).
    :return: a PhaseCallback instance to be used by Trainer for LR scheduling.
    if isinstance(lr_mode, str) and lr_mode in LR_SCHEDULERS_CLS_DICT:
        sg_lr_callback_cls = LR_SCHEDULERS_CLS_DICT[lr_mode]
        sg_lr_callback = sg_lr_callback_cls(
            train_loader_len=len(train_loader),
            net=net,
            training_params=training_params,
            update_param_groups=update_param_groups,
            **training_params.to_dict(),
        )
    elif isinstance(lr_mode, Mapping) and list(lr_mode.keys())[0] in TORCH_LR_SCHEDULERS:
        if update_param_groups:
            logger.warning(
                "The network's way of updataing (i.e update_param_groups) is not supported with native " "torch lr schedulers and will have no effect."
        lr_scheduler_name = list(lr_mode.keys())[0]
        torch_scheduler_params = {k: v for k, v in lr_mode[lr_scheduler_name].items() if k != "phase" and k != "metric_name"}
        torch_scheduler_params["optimizer"] = optimizer
        torch_scheduler = TORCH_LR_SCHEDULERS[lr_scheduler_name](**torch_scheduler_params)
        if get_param(lr_mode[lr_scheduler_name], "phase") is None:
            raise ValueError("Phase is required argument when working with torch schedulers.")
        if lr_scheduler_name == "ReduceLROnPlateau" and get_param(lr_mode[lr_scheduler_name], "metric_name") is None:
            raise ValueError("metric_name is required argument when working with ReduceLROnPlateau schedulers.")
        sg_lr_callback = LRSchedulerCallback(
            scheduler=torch_scheduler, phase=lr_mode[lr_scheduler_name]["phase"], metric_name=get_param(lr_mode[lr_scheduler_name], "metric_name")
        )
    else:
        raise ValueError(f"Unknown lr_mode: {lr_mode}")
    return sg_lr_callback
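
A sketch of the Mapping form of lr_mode described in the docstring above; the scheduler names and kwargs are illustrative and assume StepLR and ReduceLROnPlateau are among the registered torch schedulers:

# Scheduler kwargs are forwarded to the torch scheduler's __init__; "phase" and "metric_name" are stripped out.
lr_mode = {
    "StepLR": {
        "gamma": 0.1,                        # torch.optim.lr_scheduler.StepLR kwargs
        "step_size": 30,
        "phase": Phase.TRAIN_EPOCH_END,      # when scheduler.step() is called
    }
}

# ReduceLROnPlateau additionally requires the metric to monitor:
lr_mode_plateau = {
    "ReduceLROnPlateau": {"patience": 5, "phase": Phase.VALIDATION_EPOCH_END, "metric_name": "Accuracy"}
}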
@register_callback(Callbacks.PPYOLOE_TRAINING_STAGE_SWITCH)
class PPYoloETrainingStageSwitchCallback(TrainingStageSwitchCallbackBase):
    PPYoloETrainingStageSwitchCallback
    Training stage switch for PPYolo training.
    It changes the static bbox assigner to a task-aligned assigner after a certain number of epochs have passed
    def __init__(
        self,
        static_assigner_end_epoch: int = 30,
    ):
        super().__init__(next_stage_start_epoch=static_assigner_end_epoch)
    def apply_stage_change(self, context: PhaseContext):
        from super_gradients.training.losses import PPYoloELoss
        if not isinstance(context.criterion, PPYoloELoss):
            raise RuntimeError(
                f"A criterion must be an instance of PPYoloELoss when using PPYoloETrainingStageSwitchCallback. " f"Got criterion {repr(context.criterion)}"
            )
        context.criterion.use_static_assigner = False
class MissingPretrainedWeightsException(Exception):
    """Exception raised by unsupported pretrianed model.
    :param desc: explanation of the error
    def __init__(self, desc):
        self.message = "Missing pretrained wights: " + desc
        super().__init__(self.message)
      

Given a model state dict and a source checkpoint, the method tries to correct the keys in the model_state_dict to fit the checkpoint in order to properly load the weights into the model. If unsuccessful - returns None.
def adapt_state_dict_to_fit_model_layer_names(model_state_dict: dict, source_ckpt: dict, exclude: list = [], solver: callable = None):
    Given a model state dict and source checkpoints, the method tries to correct the keys in the model_state_dict to fit
    the ckpt in order to properly load the weights into the model. If unsuccessful - returns None
        :param model_state_dict:               the model state_dict
        :param source_ckpt:                         checkpoint dict
        :param exclude:                 optional list for excluded layers
        :param solver:                  callable with signature (ckpt_key, ckpt_val, model_key, model_val)
                                        that returns a desired weight for ckpt_val.
        :return: renamed checkpoint dict (if possible)
    if "net" in source_ckpt.keys():
        source_ckpt = source_ckpt["net"]
    model_state_dict_excluded = {k: v for k, v in model_state_dict.items() if not any(x in k for x in exclude)}
    new_ckpt_dict = {}
    for (ckpt_key, ckpt_val), (model_key, model_val) in zip(source_ckpt.items(), model_state_dict_excluded.items()):
        if solver is not None:
            ckpt_val = solver(ckpt_key, ckpt_val, model_key, model_val)
        if ckpt_val.shape != model_val.shape:
            raise ValueError(f"ckpt layer {ckpt_key} with shape {ckpt_val.shape} does not match {model_key}" f" with shape {model_val.shape} in the model")
        new_ckpt_dict[model_key] = ckpt_val
    return {"net": new_ckpt_dict}
      

Adaptively loads state_dict to net, by adapting the state_dict to net's layer names first.

def adaptive_load_state_dict(net: torch.nn.Module, state_dict: dict, strict: Union[bool, StrictLoad], solver=None):
    Adaptively loads state_dict to net, by adapting the state_dict to net's layer names first.
    :param net: (nn.Module) to load state_dict to
    :param state_dict: (dict) Checkpoint state_dict
    :param strict: (StrictLoad) key matching strictness
    :param solver: callable with signature (ckpt_key, ckpt_val, model_key, model_val)
                     that returns a desired weight for ckpt_val.
    :return:
    state_dict = state_dict["net"] if "net" in state_dict else state_dict
    # This is a backward compatibility fix for checkpoints that were saved with DataParallel/DistributedDataParallel wrapper
    # and contains "module." prefix in all keys
    # If all keys start with "module.", then we remove it.
    if all([key.startswith("module.") for key in state_dict.keys()]):
        state_dict = collections.OrderedDict([(key[7:], value) for key, value in state_dict.items()])
    try:
        strict_bool = strict if isinstance(strict, bool) else strict != StrictLoad.OFF
        net.load_state_dict(state_dict, strict=strict_bool)
    except (RuntimeError, ValueError, KeyError) as ex:
        if strict == StrictLoad.NO_KEY_MATCHING:
            adapted_state_dict = adapt_state_dict_to_fit_model_layer_names(net.state_dict(), state_dict, solver=solver)
            net.load_state_dict(adapted_state_dict["net"], strict=True)
        elif strict == StrictLoad.KEY_MATCHING:
            transfer_weights(net, state_dict)
        else:
            raise_informative_runtime_error(net.state_dict(), state_dict, ex)
      

Copy the checkpoint from any supported source (local path, S3, or URL) to a local destination path and return the path to the local checkpoint file.
@explicit_params_validation(validation_type="None")
def copy_ckpt_to_local_folder(
    local_ckpt_destination_dir: str,
    ckpt_filename: str,
    remote_ckpt_source_dir: str = None,
    path_src: str = "local",
    overwrite_local_ckpt: bool = False,
    load_weights_only: bool = False,
):
    Copy the checkpoint from any supported source to a local destination path
        :param local_ckpt_destination_dir:  destination where the checkpoint will be saved to
        :param ckpt_filename:         ckpt_best.pth Or ckpt_latest.pth
        :param remote_ckpt_source_dir:       Name of the source checkpoint to be loaded (S3 Model\full URL)
        :param path_src:              S3 / url
        :param overwrite_local_ckpt:  determines if checkpoint will be saved in destination dir or in a temp folder
        :return: Path to checkpoint
    ckpt_file_full_local_path = None
    # IF NOT DEFINED - IT IS SET TO THE TARGET's FOLDER NAME
    remote_ckpt_source_dir = local_ckpt_destination_dir if remote_ckpt_source_dir is None else remote_ckpt_source_dir
    if not overwrite_local_ckpt:
        # CREATE A TEMP FOLDER TO SAVE THE CHECKPOINT TO
        download_ckpt_destination_dir = tempfile.gettempdir()
        print(
            "PLEASE NOTICE - YOU ARE IMPORTING A REMOTE CHECKPOINT WITH overwrite_local_checkpoint = False "
            "-> IT WILL BE REDIRECTED TO A TEMP FOLDER AND DELETED ON MACHINE RESTART"
        )
    else:
        # SAVE THE CHECKPOINT TO MODEL's FOLDER
        download_ckpt_destination_dir = pkg_resources.resource_filename("checkpoints", local_ckpt_destination_dir)
    if path_src.startswith("s3"):
        model_checkpoints_data_interface = ADNNModelRepositoryDataInterfaces(data_connection_location=path_src)
        # DOWNLOAD THE FILE FROM S3 TO THE DESTINATION FOLDER
        ckpt_file_full_local_path = model_checkpoints_data_interface.load_remote_checkpoints_file(
            ckpt_source_remote_dir=remote_ckpt_source_dir,
            ckpt_destination_local_dir=download_ckpt_destination_dir,
            ckpt_file_name=ckpt_filename,
            overwrite_local_checkpoints_file=overwrite_local_ckpt,
        )
        if not load_weights_only:
            # COPY LOG FILES FROM THE REMOTE DIRECTORY TO THE LOCAL ONE ONLY IF LOADING THE CURRENT MODELs CKPT
            model_checkpoints_data_interface.load_all_remote_log_files(
                model_name=remote_ckpt_source_dir, model_checkpoint_local_dir=download_ckpt_destination_dir
    if path_src == "url":
        ckpt_file_full_local_path = download_ckpt_destination_dir + os.path.sep + ckpt_filename
        # DOWNLOAD THE FILE FROM URL TO THE DESTINATION FOLDER
        with wait_for_the_master(get_local_rank()):
            download_url_to_file(remote_ckpt_source_dir, ckpt_file_full_local_path, progress=True)
    return ckpt_file_full_local_path
          

Loads the state dict in ckpt_local_path to net and returns the checkpoint's state dict.

def load_checkpoint_to_model(
    net: torch.nn.Module,
    ckpt_local_path: str,
    load_backbone: bool = False,
    strict: Union[str, StrictLoad] = StrictLoad.NO_KEY_MATCHING,
    load_weights_only: bool = False,
    load_ema_as_net: bool = False,
    load_processing_params: bool = False,
):
    Loads the state dict in ckpt_local_path to net and returns the checkpoint's state dict.
    :param load_ema_as_net: Will load the EMA inside the checkpoint file to the network when set
    :param ckpt_local_path: local path to the checkpoint file
    :param load_backbone: whether to load the checkpoint as a backbone
    :param net: network to load the checkpoint to
    :param strict:
    :param load_weights_only: Whether to ignore all other entries other then "net".
    :param load_processing_params: Whether to call set_dataset_processing_params on "processing_params" entry inside the checkpoint file (default=False).
    :return:
    if isinstance(strict, str):
        strict = StrictLoad(strict)
    net = unwrap_model(net)
    if load_backbone and not hasattr(net, "backbone"):
        raise ValueError("No backbone attribute in net - Can't load backbone weights")
    # LOAD THE LOCAL CHECKPOINT PATH INTO A state_dict OBJECT
    checkpoint = read_ckpt_state_dict(ckpt_path=ckpt_local_path)
    if load_ema_as_net:
        if "ema_net" not in checkpoint.keys():
            raise ValueError("Can't load ema network- no EMA network stored in checkpoint file")
        else:
            checkpoint["net"] = checkpoint["ema_net"]
    # LOAD THE CHECKPOINTS WEIGHTS TO THE MODEL
    if load_backbone:
        adaptive_load_state_dict(net.backbone, checkpoint, strict)
    else:
        adaptive_load_state_dict(net, checkpoint, strict)
    message_suffix = " checkpoint." if not load_ema_as_net else " EMA checkpoint."
    message_model = "model" if not load_backbone else "model's backbone"
    logger.info("Successfully loaded " + message_model + " weights from " + ckpt_local_path + message_suffix)
    if (isinstance(net, HasPredict)) and load_processing_params:
        if "processing_params" not in checkpoint.keys():
            raise ValueError("Can't load processing params - could not find any stored in checkpoint file.")
        try:
            net.set_dataset_processing_params(**checkpoint["processing_params"])
        except Exception as e:
            logger.warning(
                f"Could not set preprocessing pipeline from the checkpoint dataset: {e}. Before calling"
                "predict make sure to call set_dataset_processing_params."
            )
    if load_weights_only or load_backbone:
        # DISCARD ALL THE DATA STORED IN CHECKPOINT OTHER THAN THE WEIGHTS
        [checkpoint.pop(key) for key in list(checkpoint.keys()) if key != "net"]
    return checkpoint
def load_pretrained_weights(model: torch.nn.Module, architecture: str, pretrained_weights: str):
    Loads pretrained weights from the MODEL_URLS dictionary to model
    :param architecture: name of the model's architecture
    :param model: model to load pretrained weights for
    :param pretrained_weights: name for the pretrained weights (i.e imagenet)
    :return: None
    from super_gradients.common.object_names import Models
    model_url_key = architecture + "_" + str(pretrained_weights)
    if model_url_key not in MODEL_URLS.keys():
        raise MissingPretrainedWeightsException(model_url_key)
    url = MODEL_URLS[model_url_key]
    if architecture in {Models.YOLO_NAS_S, Models.YOLO_NAS_M, Models.YOLO_NAS_L}:
        logger.info(
            "License Notification: YOLO-NAS pre-trained weights are subjected to the specific license terms and conditions detailed in \n"
            "https://github.com/Deci-AI/super-gradients/blob/master/LICENSE.YOLONAS.md\n"
            "By downloading the pre-trained weight files you agree to comply with these terms."
        )
    unique_filename = url.split("https://sghub.deci.ai/models/")[1].replace("/", "_").replace(" ", "_")
    map_location = torch.device("cpu")
    with wait_for_the_master(get_local_rank()):
        pretrained_state_dict = load_state_dict_from_url(url=url, map_location=map_location, file_name=unique_filename)
    _load_weights(architecture, model, pretrained_state_dict)
def load_pretrained_weights_local(model: torch.nn.Module, architecture: str, pretrained_weights: str):
    Loads pretrained weights from the MODEL_URLS dictionary to model
    :param architecture: name of the model's architecture
    :param model: model to load pretrained weights for
    :param pretrained_weights: path to pretrained weights
    :return: None
    map_location = torch.device("cpu")
    pretrained_state_dict = torch.load(pretrained_weights, map_location=map_location)
    _load_weights(architecture, model, pretrained_state_dict)
      

Given a model state dict and source checkpoints, the method calls "adapt_state_dict_to_fit_model_layer_names" and enhances the exception_msg if loading the checkpoint_dict via the conversion method is possible

def raise_informative_runtime_error(state_dict, checkpoint, exception_msg):
    Given a model state dict and source checkpoints, the method calls "adapt_state_dict_to_fit_model_layer_names"
    and enhances the exception_msg if loading the checkpoint_dict via the conversion method is possible
    try:
        new_ckpt_dict = adapt_state_dict_to_fit_model_layer_names(state_dict, checkpoint)
        temp_file = tempfile.NamedTemporaryFile().name + ".pt"
        torch.save(new_ckpt_dict, temp_file)
        exception_msg = (
            f"\n{'=' * 200}\n{str(exception_msg)} \nconvert ckpt via the utils.adapt_state_dict_to_fit_"
            f"model_layer_names method\na converted checkpoint file was saved in the path {temp_file}\n{'=' * 200}"
        )
    except ValueError as ex:  # IN CASE adapt_state_dict_to_fit_model_layer_names WAS UNSUCCESSFUL
        exception_msg = f"\n{'=' * 200} \nThe checkpoint and model shapes do no fit, e.g.: {ex}\n{'=' * 200}"
    finally:
        raise RuntimeError(exception_msg)
def read_ckpt_state_dict(ckpt_path: str, device="cpu") -> Mapping[str, torch.Tensor]:
    Reads a checkpoint state dict from a given path or url
    :param ckpt_path: Checkpoint path or url
    :param device: Target device where tensors should be loaded
    :return: Checkpoint state dict object
    if ckpt_path.startswith("https://"):
        with wait_for_the_master(get_local_rank()):
            state_dict = load_state_dict_from_url(ckpt_path, progress=False, map_location=device)
        return state_dict
    else:
        if not os.path.exists(ckpt_path):
            raise FileNotFoundError(f"Incorrect Checkpoint path: {ckpt_path} (This should be an absolute path)")
        state_dict = torch.load(ckpt_path, map_location=device)
        return state_dict
      

Copy weights from model_state_dict to model, skipping layers that are incompatible (Having different shape). This method is helpful if you are doing some model surgery and want to load part of the model weights into different model. This function will go over all the layers in model_state_dict and will try to find a matching layer in model and copy the weights into it. If shape will not match, the layer will be skipped.

def transfer_weights(model: nn.Module, model_state_dict: Mapping[str, Tensor]) -> None:
    Copy weights from `model_state_dict` to `model`, skipping layers that are incompatible (Having different shape).
    This method is helpful if you are doing some model surgery and want to load
    part of the model weights into different model.
    This function will go over all the layers in `model_state_dict` and will try to find a matching layer in `model` and
    copy the weights into it. If shape will not match, the layer will be skipped.
    :param model: Model to load weights into
    :param model_state_dict: Model state dict to load weights from
    :return: None
    for name, value in model_state_dict.items():
        try:
            model.load_state_dict(collections.OrderedDict([(name, value)]), strict=False)
        except RuntimeError:
            pass
      
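
For example, a small sketch (toy modules assumed) showing how a matching backbone layer is copied while a mismatching head is silently skipped by transfer_weights:

import torch
from torch import nn

pretrained = nn.Sequential(nn.Linear(16, 32), nn.Linear(32, 10))   # e.g. checkpoint with a 10-class head
new_model = nn.Sequential(nn.Linear(16, 32), nn.Linear(32, 5))     # same backbone, new 5-class head

# Layer 0 has matching shapes and is copied; layer 1 mismatches and is skipped.
transfer_weights(new_model, pretrained.state_dict())
assert torch.equal(new_model[0].weight, pretrained[0].weight)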

Implements an access-counting mechanism for configuration settings (dicts/lists). It is achieved by wrapping the underlying config and overriding the __getitem__ and __getattr__ methods to catch read operations and increment the access counter for each property.

class AccessCounterMixin:
    Implements access counting mechanism for configuration settings (dicts/lists).
    It is achieved by wrapping underlying config and override __getitem__, __getattr__ methods to catch read operations
    and increments access counter for each property.
    _access_counter: Mapping[str, int]
    _prefix: str  # Prefix string
    def maybe_wrap_as_counter(self, value, key, count_usage: bool = True):
        Return an attribute value optionally wrapped as access counter adapter to trace read counts.
        :param value: Attribute value
        :param key: Attribute name
        :param count_usage: Whether increment usage count for given attribute. Default is True.
        :return: wrapped value
        key_with_prefix = self._prefix + str(key)
        if count_usage:
            self._access_counter[key_with_prefix] += 1
        if isinstance(value, Mapping):
            return AccessCounterDict(value, access_counter=self._access_counter, prefix=key_with_prefix + ".")
        if isinstance(value, Iterable) and not isinstance(value, str):
            return AccessCounterList(value, access_counter=self._access_counter, prefix=key_with_prefix + ".")
        return value
    @property
    def access_counter(self):
        return self._access_counter
    @abc.abstractmethod
    def get_all_params(self) -> Set[str]:
        raise NotImplementedError()
    def get_used_params(self) -> Set[str]:
        used_params = {k for (k, v) in self._access_counter.items() if v > 0}
        return used_params
    def get_unused_params(self) -> Set[str]:
        unused_params = self.get_all_params() - self.get_used_params()
        return unused_params
    def __copy__(self):
        cls = self.__class__
        result = cls.__new__(cls)
        result.__dict__.update(self.__dict__)
        return result
    def __deepcopy__(self, memo):
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            setattr(result, k, deepcopy(v, memo))
        return result
      


A helper function to check whether all configuration parameters were used in a given block of code. The motivation for this check is to ensure there are no typos or outdated configuration parameters. If at least one of the config parameters was not used, this function raises an UnusedConfigParamException. Example usage:

from super_gradients.training.utils import raise_if_unused_params

with raise_if_unused_params(some_config) as some_config:
    do_something_with_config(some_config)

def raise_if_unused_params(config: Union[HpmStruct, DictConfig, ListConfig, Mapping, list, tuple]) -> ConfigInspector:
    A helper function to check whether all configuration parameters were used on a given block of code. Motivation to have
    this check is to ensure there were no typos or outdated configuration parameters.
    If at least one of the config parameters was not used, this function will raise an UnusedConfigParamException exception.
    Example usage:
    >>> from super_gradients.training.utils import raise_if_unused_params
    >>> with raise_if_unused_params(some_config) as some_config:
    >>>    do_something_with_config(some_config)
    :param config: A config to check
    :return: An instance of ConfigInspector
    if isinstance(config, HpmStruct):
        wrapper_cls = AccessCounterHpmStruct
    elif isinstance(config, (Mapping, DictConfig)):
        wrapper_cls = AccessCounterDict
    elif isinstance(config, (list, tuple, ListConfig)):
        wrapper_cls = AccessCounterList
    else:
        raise RuntimeError(f"Unsupported type. Root configuration object must be a mapping or list. Got type {type(config)}")
    return ConfigInspector(wrapper_cls(config), unused_params_action="raise")
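
A concrete usage sketch with a plain dict config; the keys are made up for illustration, and the exception name follows the description above:

config = {"lr": 0.1, "momentum": 0.9, "unused_flag": True}   # illustrative config

with raise_if_unused_params(config) as cfg:
    lr = cfg["lr"]
    momentum = cfg["momentum"]
# Exiting the context is expected to raise UnusedConfigParamException, because "unused_flag" was never read.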
      

A helper function to check whether all configuration parameters were used in a given block of code. The motivation for this check is to ensure there are no typos or outdated configuration parameters. If at least one of the config parameters was not used, this function emits a warning. Example usage:

from super_gradients.training.utils import warn_if_unused_params

with warn_if_unused_params(some_config) as some_config:
    do_something_with_config(some_config)

def warn_if_unused_params(config):
    A helper function to check whether all configuration parameters were used on a given block of code. Motivation to have
    this check is to ensure there were no typos or outdated configuration parameters.
    If at least one of the config parameters was not used, this function will emit a warning.
    Example usage:
    >>> from super_gradients.training.utils import warn_if_unused_params
    >>> with warn_if_unused_params(some_config) as some_config:
    >>>    do_something_with_config(some_config)
    :param config: A config to check
    :return: An instance of ConfigInspector
    if isinstance(config, HpmStruct):
        wrapper_cls = AccessCounterHpmStruct
    elif isinstance(config, (Mapping, DictConfig)):
        wrapper_cls = AccessCounterDict
    elif isinstance(config, (list, tuple, ListConfig)):
        wrapper_cls = AccessCounterList
    else:
        raise RuntimeError("Unsupported type. Root configuration object must be a mapping or list.")
    return ConfigInspector(wrapper_cls(config), unused_params_action="warn")


def wrap_with_warning(cls: Callable, message: str) -> Any:
    Emits a warning when target class of function is called.
    >>> from super_gradients.training.utils.deprecated_utils import wrap_with_warning
    >>> from super_gradients.training.utils.callbacks import EpochStepWarmupLRCallback, BatchStepLinearWarmupLRCallback
    >>> LR_WARMUP_CLS_DICT = {
    >>>     "linear": wrap_with_warning(
    >>>         EpochStepWarmupLRCallback,
    >>>         message=f"Parameter `linear` has been made deprecated and will be removed in the next SG release. Please use `linear_epoch` instead",
    >>>     ),
    >>>     'linear_epoch': EpochStepWarmupLRCallback,
    >>> }
    :param cls: A class or function to wrap
    :param message: A message to emit when this class is called
    :return: A factory method that returns wrapped class
    def _inner_fn(*args, **kwargs):
        logger.warning(message)
        return cls(*args, **kwargs)
    return _inner_fn
class Anchors(nn.Module):
    A wrapper module to hold the anchors used by detection models such as Yolo
    def __init__(self, anchors_list: List[List], strides: List[int]):
        :param anchors_list: of the shape [[w1,h1,w2,h2,w3,h3], [w4,h4,w5,h5,w6,h6] .... where each sublist holds
            the width and height of the anchors of a specific detection layer.
            i.e. for a model with 3 detection layers, each containing 5 anchors, the format will be a list of 3 sublists of 10 numbers each
            The width and height are in pixels (not relative to image size)
        :param strides: a list containing the stride of the layers from which the detection heads are fed.
            i.e. if the first detection head is connected to the backbone after the input dimensions were reduced by 8, the first number will be 8
        super().__init__()
        self.__anchors_list = anchors_list
        self.__strides = strides
        self._check_all_lists(anchors_list)
        self._check_all_len_equal_and_even(anchors_list)
        self._stride = nn.Parameter(torch.Tensor(strides).float(), requires_grad=False)
        anchors = torch.Tensor(anchors_list).float().view(len(anchors_list), -1, 2)
        self._anchors = nn.Parameter(anchors / self._stride.view(-1, 1, 1), requires_grad=False)
        self._anchor_grid = nn.Parameter(anchors.clone().view(len(anchors_list), 1, -1, 1, 1, 2), requires_grad=False)
    @staticmethod
    def _check_all_lists(anchors: list) -> bool:
        for a in anchors:
            if not isinstance(a, (list, ListConfig)):
                raise RuntimeError("All objects of anchors_list must be lists")
    @staticmethod
    def _check_all_len_equal_and_even(anchors: list) -> bool:
        len_of_first = len(anchors[0])
        for a in anchors:
            if len(a) % 2 == 1 or len(a) != len_of_first:
                raise RuntimeError("All objects of anchors_list must be of the same even length")
    @property
    def stride(self) -> nn.Parameter:
        return self._stride
    @property
    def anchors(self) -> nn.Parameter:
        return self._anchors
    @property
    def anchor_grid(self) -> nn.Parameter:
        return self._anchor_grid
    @property
    def detection_layers_num(self) -> int:
        return self._anchors.shape[0]
    @property
    def num_anchors(self) -> int:
        return self._anchors.shape[1]
    def __repr__(self):
        return f"anchors_list: {self.__anchors_list} strides: {self.__strides}"
@register_collate_function()
class CrowdDetectionCollateFN(DetectionCollateFN):
    Collate function for Yolox training with additional_batch_items that includes crowd targets
    def __init__(self):
        super().__init__()
        self.expected_item_names = ("image", "targets", "crowd_targets")
    def __call__(self, data) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
        try:
            images_batch, labels_batch, crowd_labels_batch = list(zip(*data))
        except (ValueError, TypeError):
            raise DatasetItemsException(data_sample=data[0], collate_type=type(self), expected_item_names=self.expected_item_names)
        return self._format_images(images_batch), self._format_targets(labels_batch), {"crowd_targets": self._format_targets(crowd_labels_batch)}
class CrowdDetectionPPYoloECollateFN(PPYoloECollateFN):
    Collate function for PPYoloE training with additional_batch_items that includes crowd targets
    def __init__(self, random_resize_sizes: Union[List[int], None] = None, random_resize_modes: Union[List[int], None] = None):
        super().__init__(random_resize_sizes, random_resize_modes)
        self.expected_item_names = ("image", "targets", "crowd_targets")
    def __call__(self, data) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
        if self.random_resize_sizes is not None:
            data = self.random_resize(data)
        try:
            images_batch, labels_batch, crowd_labels_batch = list(zip(*data))
        except (ValueError, TypeError):
            raise DatasetItemsException(data_sample=data[0], collate_type=type(self), expected_item_names=self.expected_item_names)
        return self._format_images(images_batch), self._format_targets(labels_batch), {"crowd_targets": self._format_targets(crowd_labels_batch)}
class DatasetItemsException(Exception):
    def __init__(self, data_sample: Tuple, collate_type: Type, expected_item_names: Tuple):
        :param data_sample: item(s) returned by a dataset
        :param collate_type: type of the collate that caused the exception
        :param expected_item_names: tuple of names of items that are expected by the collate to be returned from the dataset
        collate_type_name = collate_type.__name__
        num_sample_items = len(data_sample) if isinstance(data_sample, tuple) else 1
        error_msg = f"`{collate_type_name}` only supports Datasets that return a tuple {expected_item_names}, but got a tuple of len={num_sample_items}"
        super().__init__(error_msg)
        self.expected_item_names = ("image", "targets")
    def __call__(self, data) -> Tuple[torch.Tensor, torch.Tensor]:
        try:
            images_batch, labels_batch = list(zip(*data))
        except (ValueError, TypeError):
            raise DatasetItemsException(data_sample=data[0], collate_type=type(self), expected_item_names=self.expected_item_names)
        return self._format_images(images_batch), self._format_targets(labels_batch)
    def _format_images(self, images_batch: List[Union[torch.Tensor, np.array]]) -> torch.Tensor:
        images_batch = [torch.tensor(img) for img in images_batch]
        images_batch_stack = torch.stack(images_batch, 0)
        if images_batch_stack.shape[3] == 3:
            images_batch_stack = torch.moveaxis(images_batch_stack, -1, 1).float()
        return images_batch_stack
    def _format_targets(self, labels_batch: List[Union[torch.Tensor, np.array]]) -> torch.Tensor:
        Stack a batch id column to targets and concatenate
        :param labels_batch: a list of targets per image (each of arbitrary length)
        :return: one tensor of targets of all images of shape [N, 6], where N is the total number of targets in a batch
                 and the 1st column is batch item index
        labels_batch = [torch.tensor(labels) for labels in labels_batch]
        labels_batch_indexed = []
        for i, labels in enumerate(labels_batch):
            batch_column = labels.new_ones((labels.shape[0], 1)) * i
            labels = torch.cat((batch_column, labels), dim=-1)
            labels_batch_indexed.append(labels)
        return torch.cat(labels_batch_indexed, 0)
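
To see what _format_targets produces, a standalone sketch with two images' targets (sizes and values are arbitrary):

import torch

# Per-image targets of shape [num_boxes, 5] (e.g. label, cx, cy, w, h)
labels_batch = [torch.zeros(2, 5), torch.ones(3, 5)]

indexed = []
for i, labels in enumerate(labels_batch):
    batch_column = labels.new_ones((labels.shape[0], 1)) * i
    indexed.append(torch.cat((batch_column, labels), dim=-1))

targets = torch.cat(indexed, 0)
print(targets.shape)      # torch.Size([5, 6]) - the first column is the batch item index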
class DetectionPostPredictionCallback(ABC, nn.Module):
    def __init__(self) -> None:
        super().__init__()
    @abstractmethod
    def forward(self, x, device: str):
        :param x:       the output of your model
        :param device:  the device to move all output tensors into
        :return:        a list with length batch_size, each item in the list is a detections
                        with shape: nx6 (x1, y1, x2, y2, confidence, class) where x and y are in range [0,1]
        raise NotImplementedError
      

Enum class for the different detection output formats

When NORMALIZED is not specified, the type refers to unnormalized image coordinates (of the bboxes).

For example: LABEL_NORMALIZED_XYXY means [class_idx,x1,y1,x2,y2]

class DetectionTargetsFormat(Enum):
    Enum class for the different detection output formats
    When NORMALIZED is not specified, the type refers to unnormalized image coordinates (of the bboxes).
    For example:
    LABEL_NORMALIZED_XYXY means [class_idx,x1,y1,x2,y2]
    LABEL_XYXY = "LABEL_XYXY"
    XYXY_LABEL = "XYXY_LABEL"
    LABEL_NORMALIZED_XYXY = "LABEL_NORMALIZED_XYXY"
    NORMALIZED_XYXY_LABEL = "NORMALIZED_XYXY_LABEL"
    LABEL_CXCYWH = "LABEL_CXCYWH"
    CXCYWH_LABEL = "CXCYWH_LABEL"
    LABEL_NORMALIZED_CXCYWH = "LABEL_NORMALIZED_CXCYWH"
    NORMALIZED_CXCYWH_LABEL = "NORMALIZED_CXCYWH_LABEL"
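
For illustration, the same box expressed in two of these formats, for a 640x480 image with a box spanning (x1, y1, x2, y2) = (32, 48, 96, 144) and class id 3:

import numpy as np

label_xyxy = np.array([3, 32, 48, 96, 144], dtype=np.float32)  # LABEL_XYXY
normalized_xyxy_label = np.array([32 / 640, 48 / 480, 96 / 640, 144 / 480, 3], dtype=np.float32)  # NORMALIZED_XYXY_LABEL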
class DetectionVisualization:
    @staticmethod
    def _generate_color_mapping(num_classes: int) -> List[Tuple[int]]:
        Generate a unique BGR color for each class
        return generate_color_mapping(num_classes=num_classes)
    @staticmethod
    def _draw_box_title(
        color_mapping: List[Tuple[int]],
        class_names: List[str],
        box_thickness: int,
        image_np: np.ndarray,
        x1: int,
        y1: int,
        x2: int,
        y2: int,
        class_id: int,
        pred_conf: float = None,
        is_target: bool = False,
    ):
        color = color_mapping[class_id]
        class_name = class_names[class_id]
        if is_target:
            title = f"[GT] {class_name}"
        else:
            title = f'[Pred] {class_name}  {str(round(pred_conf, 2)) if pred_conf is not None else ""}'
        image_np = draw_bbox(image=image_np, title=title, x1=x1, y1=y1, x2=x2, y2=y2, box_thickness=box_thickness, color=color)
        return image_np
    @staticmethod
    def _visualize_image(
        image_np: np.ndarray,
        pred_boxes: np.ndarray,
        target_boxes: np.ndarray,
        class_names: List[str],
        box_thickness: int,
        gt_alpha: float,
        image_scale: float,
        checkpoint_dir: str,
        image_name: str,
    ):
        image_np = cv2.resize(image_np, (0, 0), fx=image_scale, fy=image_scale, interpolation=cv2.INTER_NEAREST)
        color_mapping = DetectionVisualization._generate_color_mapping(len(class_names))
        # Draw predictions
        pred_boxes[:, :4] *= image_scale
        for box in pred_boxes:
            image_np = DetectionVisualization._draw_box_title(
                color_mapping, class_names, box_thickness, image_np, *box[:4].astype(int), class_id=int(box[5]), pred_conf=box[4]
            )
        # Draw ground truths
        target_boxes_image = np.zeros_like(image_np, np.uint8)
        for box in target_boxes:
            target_boxes_image = DetectionVisualization._draw_box_title(
                color_mapping, class_names, box_thickness, target_boxes_image, *box[2:], class_id=box[1], is_target=True
            )
        # Transparent overlay of ground truth boxes
        mask = target_boxes_image.astype(bool)
        image_np[mask] = cv2.addWeighted(image_np, 1 - gt_alpha, target_boxes_image, gt_alpha, 0)[mask]
        if checkpoint_dir is None:
            return image_np
        else:
            pathlib.Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)
            cv2.imwrite(os.path.join(checkpoint_dir, str(image_name) + ".jpg"), image_np)
    @staticmethod
    def _scaled_ccwh_to_xyxy(target_boxes: np.ndarray, h: int, w: int, image_scale: float) -> np.ndarray:
        Modifies target_boxes inplace
        :param target_boxes:    (c1, c2, w, h) boxes in [0, 1] range
        :param h:               image height
        :param w:               image width
        :param image_scale:     desired scale for the boxes w.r.t. w and h
        :return:                targets in (x1, y1, x2, y2) format
                                in range [0, w * self.image_scale] [0, h * self.image_scale]
        # unscale
        target_boxes[:, 2:] *= np.array([[w, h, w, h]])
        # x1 = c1 - w // 2; y1 = c2 - h // 2
        target_boxes[:, 2] -= target_boxes[:, 4] // 2
        target_boxes[:, 3] -= target_boxes[:, 5] // 2
        # x2 = w + x1; y2 = h + y1
        target_boxes[:, 4] += target_boxes[:, 2]
        target_boxes[:, 5] += target_boxes[:, 3]
        target_boxes[:, 2:] *= image_scale
        target_boxes = target_boxes.astype(int)
        return target_boxes
    @staticmethod
    def visualize_batch(
        image_tensor: torch.Tensor,
        pred_boxes: List[torch.Tensor],
        target_boxes: torch.Tensor,
        batch_name: Union[int, str],
        class_names: List[str],
        checkpoint_dir: str = None,
        undo_preprocessing_func: Callable[[torch.Tensor], np.ndarray] = undo_image_preprocessing,
        box_thickness: int = 2,
        image_scale: float = 1.0,
        gt_alpha: float = 0.4,
    ):
        A helper function to visualize detections predicted by a network:
        saves images into a given path with a name that is {batch_name}_{image_idx_in_the_batch}.jpg, one batch per call.
        Colors are generated on the fly: uniformly sampled from color wheel to support all given classes.
        Adjustable:
            * Ground truth box transparency;
            * Box width;
            * Image size (larger or smaller than what's provided)
        :param image_tensor:            rgb images, (B, H, W, 3)
        :param pred_boxes:              boxes after NMS for each image in a batch, each (Num_boxes, 6),
                                        values on dim 1 are: x1, y1, x2, y2, confidence, class
        :param target_boxes:            (Num_targets, 6), values on dim 1 are: image id in a batch, class, x y w h
                                        (coordinates scaled to [0, 1])
        :param batch_name:              id of the current batch to use for image naming
        :param class_names:             names of all classes, each on its own index
        :param checkpoint_dir:          a path where images with boxes will be saved. if None, the result images will
                                        be returned as a list of numpy image arrays
        :param undo_preprocessing_func: a function to convert preprocessed images tensor into a batch of cv2-like images
        :param box_thickness:           box line thickness in px
        :param image_scale:             scale of an image w.r.t. given image size,
                                        e.g. incoming images are (320x320), use scale = 2. to preview in (640x640)
        :param gt_alpha:                a value in [0., 1.] transparency on ground truth boxes,
                                        0 for invisible, 1 for fully opaque
        image_np = undo_preprocessing_func(image_tensor.detach())
        targets = DetectionVisualization._scaled_ccwh_to_xyxy(target_boxes.detach().cpu().numpy(), *image_np.shape[1:3], image_scale)
        out_images = []
        for i in range(image_np.shape[0]):
            preds = pred_boxes[i].detach().cpu().numpy() if pred_boxes[i] is not None else np.empty((0, 6))
            targets_cur = targets[targets[:, 0] == i]
            image_name = "_".join([str(batch_name), str(i)])
            res_image = DetectionVisualization._visualize_image(
                image_np[i], preds, targets_cur, class_names, box_thickness, gt_alpha, image_scale, checkpoint_dir, image_name
            )
            if res_image is not None:
                out_images.append(res_image)
        return out_images
class IouThreshold(tuple, Enum):
    MAP_05 = (0.5, 0.5)
    MAP_05_TO_095 = (0.5, 0.95)
    def is_range(self):
        return self[0] != self[1]
    def to_tensor(self):
        if self.is_range():
            return self.from_bounds(self[0], self[1], step=0.05)
        else:
            return torch.tensor([self[0]])
    @classmethod
    def from_bounds(cls, low: float, high: float, step: float = 0.05) -> torch.Tensor:
        Create a tensor with values from low (including) to high (including) with a given step size.
        :param low: Lower bound
        :param high: Upper bound
        :param step: Step size
        :return: Tensor of [low, low + step, low + 2 * step, ..., high]
        n_iou_thresh = int(round((high - low) / step)) + 1
        return torch.linspace(low, high, n_iou_thresh)
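
Quick usage examples for the enum above:

thresholds = IouThreshold.MAP_05_TO_095.to_tensor()
print(thresholds)                      # 10 thresholds: 0.50, 0.55, ..., 0.95
print(IouThreshold.MAP_05.is_range())  # False - lower and upper bounds are equal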
      

class NMS_Type(str, Enum):
    Type of non max suppression algorithm that can be used for post processing detection
    ITERATIVE = "iterative"
    MATRIX = "matrix"
class PPYoloECollateFN(DetectionCollateFN):
    Collate function for PPYoloE training
    def __init__(self, random_resize_sizes: Union[List[int], None] = None, random_resize_modes: Union[List[int], None] = None):
        :param random_resize_sizes: List of sizes to randomly resize the batch images to; each size is applied to both rows and cols
        :param random_resize_modes: List of interpolation modes (cv2 interpolation flags) to randomly choose from when resizing
        super().__init__()
        self.random_resize_sizes = random_resize_sizes
        self.random_resize_modes = random_resize_modes
    def __repr__(self):
        return f"PPYoloECollateFN(random_resize_sizes={self.random_resize_sizes}, random_resize_modes={self.random_resize_modes})"
    def __str__(self):
        return self.__repr__()
    def __call__(self, data) -> Tuple[torch.Tensor, torch.Tensor]:
        if self.random_resize_sizes is not None:
            data = self.random_resize(data)
        return super().__call__(data)
    def random_resize(self, batch):
        target_size = random.choice(self.random_resize_sizes)
        interpolation = random.choice(self.random_resize_modes)
        batch = [self.random_resize_sample(sample, target_size, interpolation) for sample in batch]
        return batch
    def random_resize_sample(self, sample, target_size, interpolation):
        if len(sample) == 2:
            image, targets = sample  # TARGETS ARE IN LABEL_CXCYWH
            with_crowd = False
        elif len(sample) == 3:
            image, targets, crowd_targets = sample
            with_crowd = True
        else:
            raise DatasetItemsException(data_sample=sample, collate_type=type(self), expected_item_names=self.expected_item_names)
        dsize = int(target_size), int(target_size)
        scale_factors = target_size / image.shape[0], target_size / image.shape[1]
        image = cv2.resize(
            image,
            dsize=dsize,
            interpolation=interpolation,
        )
        sy, sx = scale_factors
        targets[:, 1:5] *= np.array([[sx, sy, sx, sy]], dtype=targets.dtype)
        if with_crowd:
            crowd_targets[:, 1:5] *= np.array([[sx, sy, sx, sy]], dtype=targets.dtype)
            return image, targets, crowd_targets
        return image, targets
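
A usage sketch (the dataset and DataLoader below are assumptions, not part of the code above): build the collate with a few candidate sizes and cv2 interpolation flags, then hand it to a DataLoader whose dataset returns (image, targets) or (image, targets, crowd_targets).

import cv2

collate = PPYoloECollateFN(
    random_resize_sizes=[480, 512, 544],
    random_resize_modes=[cv2.INTER_LINEAR, cv2.INTER_NEAREST],
)
# loader = DataLoader(train_dataset, batch_size=16, collate_fn=collate)  # train_dataset is a placeholder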
def adjust_box_anns(bbox, scale_ratio, padw, padh, w_max, h_max):
    Adjusts the bbox annotations of rescaled, padded image.
    :param bbox: (np.array) bbox to modify.
    :param scale_ratio: (float) scale ratio between rescale output image and original one.
    :param padw: (int) width padding size.
    :param padh: (int) height padding size.
    :param w_max: (int) width border.
    :param h_max: (int) height border
    :return: modified bbox (np.array)
    bbox[:, 0::2] = np.clip(bbox[:, 0::2] * scale_ratio + padw, 0, w_max)
    bbox[:, 1::2] = np.clip(bbox[:, 1::2] * scale_ratio + padh, 0, h_max)
    return bbox
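
For example, rescaling a box by 0.5, shifting it by a (16, 8) pixel pad and clipping it to a 100x100 canvas:

import numpy as np

bbox = np.array([[10.0, 20.0, 110.0, 220.0]])
print(adjust_box_anns(bbox.copy(), scale_ratio=0.5, padw=16, padh=8, w_max=100, h_max=100))
# [[ 21.  18.  71. 100.]] - x coordinates scaled and shifted by padw, y coordinates by padh, both clipped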
      

Return intersection-over-union (Jaccard index) of boxes. Both sets of boxes are expected to be in (x1, y1, x2, y2) format.

def box_iou(box1: torch.Tensor, box2: torch.Tensor) -> torch.Tensor:
    # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
    Return intersection-over-union (Jaccard index) of boxes.
    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
    :param box1: Tensor of shape [N, 4]
    :param box2: Tensor of shape [M, 4]
    :return:     iou, Tensor of shape [N, M]: the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2
    def box_area(box):
        # box = 4xn
        return (box[2] - box[0]) * (box[3] - box[1])
    area1 = box_area(box1.T)
    area2 = box_area(box2.T)
    # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
    inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
    return inter / (area1[:, None] + area2 - inter)  # iou = inter / (area1 + area2 - inter)
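
For example, a box against an identical box and a half-overlapping one:

import torch

boxes_a = torch.tensor([[0.0, 0.0, 10.0, 10.0]])
boxes_b = torch.tensor([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 15.0, 15.0]])
print(box_iou(boxes_a, boxes_b))  # tensor([[1.0000, 0.1429]]) - 25 / (100 + 100 - 25) = 1/7 for the second pair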
def calc_bbox_iou_matrix(pred: torch.Tensor):
    calculate iou for every pair of boxes in the boxes vector
    :param pred: a 3-dimensional tensor containing all boxes for a batch of images [N, num_boxes, 4], where
                 each box format is [x1,y1,x2,y2]
    :return: a 3-dimensional matrix where M_i_j_k is the iou of box j and box k of the i'th image in the batch
    box = pred[:, :, :4]  #
    b1_x1, b1_y1 = box[:, :, 0].unsqueeze(1), box[:, :, 1].unsqueeze(1)
    b1_x2, b1_y2 = box[:, :, 2].unsqueeze(1), box[:, :, 3].unsqueeze(1)
    b2_x1 = b1_x1.transpose(2, 1)
    b2_x2 = b1_x2.transpose(2, 1)
    b2_y1 = b1_y1.transpose(2, 1)
    b2_y2 = b1_y2.transpose(2, 1)
    intersection_area = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
    # Union Area
    w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1
    w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1
    union_area = (w1 * h1 + 1e-16) + w2 * h2 - intersection_area
    ious = intersection_area / union_area
    return ious
      

Calculate an IoU matrix containing iou(i, j) for every pair of boxes where box i is in box1 and box j is in box2.

def calculate_bbox_iou_matrix(box1, box2, x1y1x2y2=True, GIoU: bool = False, DIoU=False, CIoU=False, eps=1e-9):
    Calculate an IoU matrix containing iou(i, j) for every pair of boxes where box i is in box1 and box j is in box2
    :param box1: a 2D tensor of boxes (shape N x 4)
    :param box2: a 2D tensor of boxes (shape M x 4)
    :param x1y1x2y2: boxes format is x1y1x2y2 (True) or xywh where xy is the center (False)
    :return: a 2D iou matrix (shape NxM)
    if box1.dim() > 1:
        box1 = box1.T
    # Get the coordinates of bounding boxes
    if x1y1x2y2:  # x1, y1, x2, y2 = box1
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3]
    else:  # x, y, w, h = box1
        b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
        b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
        b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2
        b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2
    b1_x1, b1_y1, b1_x2, b1_y2 = b1_x1.unsqueeze(1), b1_y1.unsqueeze(1), b1_x2.unsqueeze(1), b1_y2.unsqueeze(1)
    return _iou(CIoU, DIoU, GIoU, b1_x1, b1_x2, b1_y1, b1_y2, b2_x1, b2_x2, b2_y1, b2_y2, eps)
def compute_box_area(box: torch.Tensor) -> torch.Tensor:
    Compute the area of one or many boxes.
    :param box: One or many boxes, shape = (4, ?), each box in format (x1, y1, x2, y2)
    :return: Area of every box, shape = (1, ?)
    # box = 4xn
    return (box[2] - box[0]) * (box[3] - box[1])
      

Match predictions (NMS output) and the targets (ground truth) with respect to IoU and confidence score.

def compute_detection_matching(
    output: List[torch.Tensor],
    targets: torch.Tensor,
    height: int,
    width: int,
    iou_thresholds: torch.Tensor,
    denormalize_targets: bool,
    device: str,
    crowd_targets: Optional[torch.Tensor] = None,
    top_k: int = 100,
    return_on_cpu: bool = True,
) -> List[Tuple]:
    Match predictions (NMS output) and the targets (ground truth) with respect to IoU and confidence score.
    :param output:          list (of length batch_size) of Tensors of shape (num_predictions, 6)
                            format:     (x1, y1, x2, y2, confidence, class_label) where x1,y1,x2,y2 are according to image size
    :param targets:         targets for all images of shape (total_num_targets, 6)
                            format:     (index, label, cx, cy, w, h) where cx,cy,w,h are in range [0,1]
    :param height:          dimensions of the image
    :param width:           dimensions of the image
    :param iou_thresholds:  Threshold to compute the mAP
    :param device:          Device
    :param crowd_targets:   crowd targets for all images of shape (total_num_crowd_targets, 6)
                            format:     (index, x, y, w, h, label) where x,y,w,h are in range [0,1]
    :param top_k:           Number of predictions to keep per class, ordered by confidence score
    :param denormalize_targets: If True, denormalize the targets and crowd_targets
    :param return_on_cpu:   If True, the output will be returned on "CPU", otherwise it will be returned on "device"
    :return:                list of the following tensors, for every image:
        :preds_matched:     Tensor of shape (num_img_predictions, n_iou_thresholds)
                            True when prediction (i) is matched with a target with respect to the (j)th IoU threshold
        :preds_to_ignore:   Tensor of shape (num_img_predictions, n_iou_thresholds)
                            True when prediction (i) is matched with a crowd target with respect to the (j)th IoU threshold
        :preds_scores:      Tensor of shape (num_img_predictions), confidence score for every prediction
        :preds_cls:         Tensor of shape (num_img_predictions), predicted class for every prediction
        :targets_cls:       Tensor of shape (num_img_targets), ground truth class for every target
    output = map(lambda tensor: None if tensor is None else tensor.to(device), output)
    targets, iou_thresholds = targets.to(device), iou_thresholds.to(device)
    # If crowd_targets is not provided, we patch it with an empty tensor
    crowd_targets = torch.zeros(size=(0, 6), device=device) if crowd_targets is None else crowd_targets.to(device)
    batch_metrics = []
    for img_i, img_preds in enumerate(output):
        # If img_preds is None (not prediction for this image), we patch it with an empty tensor
        img_preds = img_preds if img_preds is not None else torch.zeros(size=(0, 6), device=device)
        img_targets = targets[targets[:, 0] == img_i, 1:]
        img_crowd_targets = crowd_targets[crowd_targets[:, 0] == img_i, 1:]
        img_matching_tensors = compute_img_detection_matching(
            preds=img_preds,
            targets=img_targets,
            crowd_targets=img_crowd_targets,
            denormalize_targets=denormalize_targets,
            height=height,
            width=width,
            device=device,
            iou_thresholds=iou_thresholds,
            top_k=top_k,
            return_on_cpu=return_on_cpu,
        )
        batch_metrics.append(img_matching_tensors)
    return batch_metrics
      

Compute the list of precision, recall, MaP and f1 for every recall IoU threshold and for every class.

def compute_detection_metrics(
    preds_matched: torch.Tensor,
    preds_to_ignore: torch.Tensor,
    preds_scores: torch.Tensor,
    preds_cls: torch.Tensor,
    targets_cls: torch.Tensor,
    device: str,
    recall_thresholds: Optional[torch.Tensor] = None,
    score_threshold: Optional[float] = 0.1,
    calc_best_score_thresholds: bool = False,
) -> Tuple:
    Compute the list of precision, recall, MaP and f1 for every recall IoU threshold and for every class.
    :param preds_matched:      Tensor of shape (num_predictions, n_iou_thresholds)
                                    True when prediction (i) is matched with a target with respect to the (j)th IoU threshold
    :param preds_to_ignore:    Tensor of shape (num_predictions, n_iou_thresholds)
                                    True when prediction (i) is matched with a crowd target with respect to the (j)th IoU threshold
    :param preds_scores:       Tensor of shape (num_predictions), confidence score for every prediction
    :param preds_cls:          Tensor of shape (num_predictions), predicted class for every prediction
    :param targets_cls:        Tensor of shape (num_targets), ground truth class for every target box to be detected
    :param recall_thresholds:   Recall thresholds used to compute MaP.
    :param score_threshold:    Minimum confidence score to consider a prediction for the computation of
                                    precision, recall and f1 (not MaP)
    :param device:             Device
    :param calc_best_score_thresholds: If True, the best confidence score threshold is computed for each class
    :return:
        :ap, precision, recall, f1: Tensors of shape (n_class, nb_iou_thrs)
        :unique_classes:            Vector with all unique target classes
        :best_score_threshold:      torch.float with the best overall score threshold if calc_best_score_thresholds
                                    is True else None
        :best_score_threshold_per_cls:     dict that stores the best score threshold for each class , if
                                            calc_best_score_thresholds is True else None
    preds_matched, preds_to_ignore = preds_matched.to(device), preds_to_ignore.to(device)
    preds_scores, preds_cls, targets_cls = preds_scores.to(device), preds_cls.to(device), targets_cls.to(device)
    recall_thresholds = torch.linspace(0, 1, 101, device=device) if recall_thresholds is None else recall_thresholds.to(device)
    unique_classes = torch.unique(targets_cls).long()
    n_class, nb_iou_thrs = len(unique_classes), preds_matched.shape[-1]
    ap = torch.zeros((n_class, nb_iou_thrs), device=device)
    precision = torch.zeros((n_class, nb_iou_thrs), device=device)
    recall = torch.zeros((n_class, nb_iou_thrs), device=device)
    nb_score_thrs = 101
    all_score_thresholds = torch.linspace(0, 1, nb_score_thrs, device=device)
    f1_per_class_per_threshold = torch.zeros((n_class, nb_score_thrs), device=device) if calc_best_score_thresholds else None
    best_score_threshold_per_cls = dict() if calc_best_score_thresholds else None
    for cls_i, cls in enumerate(unique_classes):
        cls_preds_idx, cls_targets_idx = (preds_cls == cls), (targets_cls == cls)
        cls_ap, cls_precision, cls_recall, cls_f1_per_threshold, cls_best_score_threshold = compute_detection_metrics_per_cls(
            preds_matched=preds_matched[cls_preds_idx],
            preds_to_ignore=preds_to_ignore[cls_preds_idx],
            preds_scores=preds_scores[cls_preds_idx],
            n_targets=cls_targets_idx.sum(),
            recall_thresholds=recall_thresholds,
            score_threshold=score_threshold,
            device=device,
            calc_best_score_thresholds=calc_best_score_thresholds,
            nb_score_thrs=nb_score_thrs,
        )
        ap[cls_i, :] = cls_ap
        precision[cls_i, :] = cls_precision
        recall[cls_i, :] = cls_recall
        if calc_best_score_thresholds:
            f1_per_class_per_threshold[cls_i, :] = cls_f1_per_threshold
            best_score_threshold_per_cls[f"Best_score_threshold_cls_{int(cls)}"] = cls_best_score_threshold
    f1 = 2 * precision * recall / (precision + recall + 1e-16)
    if calc_best_score_thresholds:
        mean_f1_across_classes = torch.mean(f1_per_class_per_threshold, dim=0)
        best_score_threshold = all_score_thresholds[torch.argmax(mean_f1_across_classes)]
    else:
        best_score_threshold = None
    return ap, precision, recall, f1, unique_classes, best_score_threshold, best_score_threshold_per_cls
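
A hedged end-to-end sketch of how the matching and metric functions above compose. The shapes follow the docstrings; the boxes and image size are illustrative, and the targets are assumed to be laid out as (index, label, cx, cy, w, h) with normalized coordinates, matching the collate function above.

import torch

preds = [torch.tensor([[10.0, 10.0, 50.0, 50.0, 0.9, 0.0]])]          # one image, one (x1, y1, x2, y2, conf, cls) prediction
targets = torch.tensor([[0.0, 0.0, 0.09375, 0.09375, 0.125, 0.125]])  # (index, label, cx, cy, w, h), normalized to [0, 1]

matching = compute_detection_matching(
    output=preds,
    targets=targets,
    height=320,
    width=320,
    iou_thresholds=IouThreshold.MAP_05_TO_095.to_tensor(),
    denormalize_targets=True,
    device="cpu",
)
preds_matched, preds_to_ignore, preds_scores, preds_cls, targets_cls = matching[0]
ap, precision, recall, f1, classes, _, _ = compute_detection_metrics(
    preds_matched, preds_to_ignore, preds_scores, preds_cls, targets_cls, device="cpu"
)
print(ap.mean())  # tensor(1.) - the single prediction matches its target at every IoU threshold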
      

Compute the list of precision, recall and MaP of a given class for every recall IoU threshold.

def compute_detection_metrics_per_cls(
    preds_matched: torch.Tensor,
    preds_to_ignore: torch.Tensor,
    preds_scores: torch.Tensor,
    n_targets: int,
    recall_thresholds: torch.Tensor,
    score_threshold: float,
    device: str,
    calc_best_score_thresholds: bool = False,
    nb_score_thrs: int = 101,
):
    Compute the list of precision, recall and MaP of a given class for every recall IoU threshold.
        :param preds_matched:      Tensor of shape (num_predictions, n_iou_thresholds)
                                        True when prediction (i) is matched with a target
                                        with respect to the(j)th IoU threshold
        :param preds_to_ignore:    Tensor of shape (num_predictions, n_iou_thresholds)
                                        True when prediction (i) is matched with a crowd target
                                        with respect to the (j)th IoU threshold
        :param preds_scores:       Tensor of shape (num_predictions), confidence score for every prediction
        :param n_targets:          Number of target boxes of this class
        :param recall_thresholds:  Tensor of shape (max_n_rec_thresh) list of recall thresholds used to compute MaP
        :param score_threshold:    Minimum confidence score to consider a prediction for the computation of
                                        precision and recall (not MaP)
        :param device:             Device
        :param calc_best_score_thresholds: If True, the best confidence score threshold is computed for this class
        :param nb_score_thrs:       Number of score thresholds to consider when calc_best_score_thresholds is True
        :return:
            :ap, precision, recall:     Tensors of shape (nb_iou_thrs)
            :mean_f1_per_threshold:     Tensor of shape (nb_score_thresholds) if calc_best_score_thresholds is True else None
            :best_score_threshold:      torch.float if calc_best_score_thresholds is True else None
    nb_iou_thrs = preds_matched.shape[-1]
    mean_f1_per_threshold = torch.zeros(nb_score_thrs, device=device) if calc_best_score_thresholds else None
    best_score_threshold = torch.tensor(0.0, dtype=torch.float, device=device) if calc_best_score_thresholds else None
    tps = preds_matched
    fps = torch.logical_and(torch.logical_not(preds_matched), torch.logical_not(preds_to_ignore))
    if len(tps) == 0:
        return (
            torch.zeros(nb_iou_thrs, device=device),
            torch.zeros(nb_iou_thrs, device=device),
            torch.zeros(nb_iou_thrs, device=device),
            mean_f1_per_threshold,
            best_score_threshold,
        )
    # Sort by decreasing score
    dtype = torch.uint8 if preds_scores.is_cuda and preds_scores.dtype is torch.bool else preds_scores.dtype
    sort_ind = torch.argsort(preds_scores.to(dtype), descending=True)
    tps = tps[sort_ind, :]
    fps = fps[sort_ind, :]
    preds_scores = preds_scores[sort_ind].contiguous()
    # Rolling sum over the predictions
    rolling_tps = torch.cumsum(tps, axis=0, dtype=torch.float)
    rolling_fps = torch.cumsum(fps, axis=0, dtype=torch.float)
    rolling_recalls = rolling_tps / n_targets
    rolling_precisions = rolling_tps / (rolling_tps + rolling_fps + torch.finfo(torch.float64).eps)
    # Reversed cummax to only have decreasing values
    rolling_precisions = rolling_precisions.flip(0).cummax(0).values.flip(0)
    # ==================
    # RECALL & PRECISION
    # We want the rolling precision/recall at index i so that: preds_scores[i-1] >= score_threshold > preds_scores[i]
    # Note: torch.searchsorted works on increasing sequence and preds_scores is decreasing, so we work with "-"
    # Note2: right=True due to negation
    lowest_score_above_threshold = torch.searchsorted(-preds_scores, -score_threshold, right=True)
    if lowest_score_above_threshold == 0:  # Here score_threshold > preds_scores[0], so no pred is above the threshold
        recall = torch.zeros(nb_iou_thrs, device=device)
        precision = torch.zeros(nb_iou_thrs, device=device)  # the precision is not really defined when no pred but we need to give it a value
    else:
        recall = rolling_recalls[lowest_score_above_threshold - 1]
        precision = rolling_precisions[lowest_score_above_threshold - 1]
    # ==================
    # BEST CONFIDENCE SCORE THRESHOLD PER CLASS
    if calc_best_score_thresholds:
        all_score_thresholds = torch.linspace(0, 1, nb_score_thrs, device=device)
        # We want the rolling precision/recall at index i so that: preds_scores[i-1] > score_threshold >= preds_scores[i]
        # Note: torch.searchsorted works on increasing sequence and preds_scores is decreasing, so we work with "-"
        lowest_scores_above_thresholds = torch.searchsorted(-preds_scores, -all_score_thresholds, right=True)
        # When score_threshold > preds_scores[0], then no pred is above the threshold, so we pad with zeros
        rolling_recalls_padded = torch.cat((torch.zeros(1, nb_iou_thrs, device=device), rolling_recalls), dim=0)
        rolling_precisions_padded = torch.cat((torch.zeros(1, nb_iou_thrs, device=device), rolling_precisions), dim=0)
        # shape = (n_score_thresholds, nb_iou_thrs)
        recalls_per_threshold = torch.index_select(input=rolling_recalls_padded, dim=0, index=lowest_scores_above_thresholds)
        precisions_per_threshold = torch.index_select(input=rolling_precisions_padded, dim=0, index=lowest_scores_above_thresholds)
        # shape (n_score_thresholds, nb_iou_thrs)
        f1_per_threshold = 2 * recalls_per_threshold * precisions_per_threshold / (recalls_per_threshold + precisions_per_threshold + 1e-16)
        mean_f1_per_threshold = torch.mean(f1_per_threshold, dim=1)  # average over iou thresholds
        best_score_threshold = all_score_thresholds[torch.argmax(mean_f1_per_threshold)]
    # ==================
    # AVERAGE PRECISION
    # shape = (nb_iou_thrs, n_recall_thresholds)
    recall_thresholds = recall_thresholds.view(1, -1).repeat(nb_iou_thrs, 1)
    # We want the index i so that: rolling_recalls[i-1] < recall_thresholds[k] <= rolling_recalls[i]
    # Note:  when recall_thresholds[k] > max(rolling_recalls), i = len(rolling_recalls)
    # Note2: we work with transpose (.T) to apply torch.searchsorted on first dim instead of the last one
    recall_threshold_idx = torch.searchsorted(rolling_recalls.T.contiguous(), recall_thresholds, right=False).T
    # When recall_thresholds[k] > max(rolling_recalls), rolling_precisions[i] is not defined, and we want precision = 0
    rolling_precisions = torch.cat((rolling_precisions, torch.zeros(1, nb_iou_thrs, device=device)), dim=0)
    # shape = (n_recall_thresholds, nb_iou_thrs)
    sampled_precision_points = torch.gather(input=rolling_precisions, index=recall_threshold_idx, dim=0)
    # Average over the recall_thresholds
    ap = sampled_precision_points.mean(0)
    return ap, precision, recall, mean_f1_per_threshold, best_score_threshold
      

Match predictions (NMS output) and the targets (ground truth) with respect to IoU and confidence score for a given image.

def compute_img_detection_matching(
    preds: torch.Tensor,
    targets: torch.Tensor,
    crowd_targets: torch.Tensor,
    height: int,
    width: int,
    iou_thresholds: torch.Tensor,
    device: str,
    denormalize_targets: bool,
    top_k: int = 100,
    return_on_cpu: bool = True,
) -> Tuple:
    Match predictions (NMS output) and the targets (ground truth) with respect to IoU and confidence score
    for a given image.
    :param preds:           Tensor of shape (num_img_predictions, 6)
                            format:     (x1, y1, x2, y2, confidence, class_label) where x1,y1,x2,y2 are according to image size
    :param targets:         targets for this image of shape (num_img_targets, 6)
                            format:     (label, cx, cy, w, h) where cx, cy, w, h are in range [0,1]
    :param height:          dimensions of the image
    :param width:           dimensions of the image
    :param iou_thresholds:  Threshold to compute the mAP
    :param device:
    :param crowd_targets:   crowd targets for all images of shape (total_num_crowd_targets, 6)
                            format:     (index, x, y, w, h, label) where x,y,w,h are in range [0,1]
    :param top_k:           Number of predictions to keep per class, ordered by confidence score
    :param device:          Device
    :param denormalize_targets: If True, denormalize the targets and crowd_targets
    :param return_on_cpu:   If True, the output will be returned on "CPU", otherwise it will be returned on "device"
    :return:
        :preds_matched:     Tensor of shape (num_img_predictions, n_iou_thresholds)
                                True when prediction (i) is matched with a target with respect to the (j)th IoU threshold
        :preds_to_ignore:   Tensor of shape (num_img_predictions, n_iou_thresholds)
                                True when prediction (i) is matched with a crowd target with respect to the (j)th IoU threshold
        :preds_scores:      Tensor of shape (num_img_predictions), confidence score for every prediction
        :preds_cls:         Tensor of shape (num_img_predictions), predicted class for every prediction
        :targets_cls:       Tensor of shape (num_img_targets), ground truth class for every target
    num_iou_thresholds = len(iou_thresholds)
    if preds is None or len(preds) == 0:
        if return_on_cpu:
            device = "cpu"
        preds_matched = torch.zeros((0, num_iou_thresholds), dtype=torch.bool, device=device)
        preds_to_ignore = torch.zeros((0, num_iou_thresholds), dtype=torch.bool, device=device)
        preds_scores = torch.tensor([], dtype=torch.float32, device=device)
        preds_cls = torch.tensor([], dtype=torch.float32, device=device)
        targets_cls = targets[:, 0].to(device=device)
        return preds_matched, preds_to_ignore, preds_scores, preds_cls, targets_cls
    preds_matched = torch.zeros(len(preds), num_iou_thresholds, dtype=torch.bool, device=device)
    targets_matched = torch.zeros(len(targets), num_iou_thresholds, dtype=torch.bool, device=device)
    preds_to_ignore = torch.zeros(len(preds), num_iou_thresholds, dtype=torch.bool, device=device)
    preds_cls, preds_box, preds_scores = preds[:, -1], preds[:, 0:4], preds[:, 4]
    targets_cls, targets_box = targets[:, 0], targets[:, 1:5]
    crowd_targets_cls, crowd_target_box = crowd_targets[:, 0], crowd_targets[:, 1:5]
    # Ignore all but the predictions that were top_k for their class
    preds_idx_to_use = get_top_k_idx_per_cls(preds_scores, preds_cls, top_k)
    preds_to_ignore[:, :] = True
    preds_to_ignore[preds_idx_to_use] = False
    if len(targets) > 0 or len(crowd_targets) > 0:
        # CHANGE bboxes TO FIT THE IMAGE SIZE
        change_bbox_bounds_for_image_size(preds, (height, width))
        targets_box = cxcywh2xyxy(targets_box)
        crowd_target_box = cxcywh2xyxy(crowd_target_box)
        if denormalize_targets:
            targets_box[:, [0, 2]] *= width
            targets_box[:, [1, 3]] *= height
            crowd_target_box[:, [0, 2]] *= width
            crowd_target_box[:, [1, 3]] *= height
    if len(targets) > 0:
        # shape = (n_preds x n_targets)
        iou = box_iou(preds_box[preds_idx_to_use], targets_box)
        # Fill IoU values at index (i, j) with 0 when the prediction (i) and target(j) are of different class
        # Filling with 0 is equivalent to ignoring these values since we want IoU > iou_threshold > 0
        cls_mismatch = preds_cls[preds_idx_to_use].view(-1, 1) != targets_cls.view(1, -1)
        iou[cls_mismatch] = 0
        # The matching priority is first detection confidence and then IoU value.
        # The detection is already sorted by confidence in NMS, so here for each prediction we order the targets by iou.
        sorted_iou, target_sorted = iou.sort(descending=True, stable=True)
        # Only iterate over IoU values higher than min threshold to speed up the process
        for pred_selected_i, target_sorted_i in (sorted_iou > iou_thresholds[0]).nonzero(as_tuple=False):
            # pred_selected_i and target_sorted_i are relative to filters/sorting, so we extract their absolute indexes
            pred_i = preds_idx_to_use[pred_selected_i]
            target_i = target_sorted[pred_selected_i, target_sorted_i]
            # Vector[j], True when IoU(pred_i, target_i) is above the (j)th threshold
            is_iou_above_threshold = sorted_iou[pred_selected_i, target_sorted_i] > iou_thresholds
            # Vector[j], True when both pred_i and target_i are not matched yet for the (j)th threshold
            are_candidates_free = torch.logical_and(~preds_matched[pred_i, :], ~targets_matched[target_i, :])
            # Vector[j], True when (pred_i, target_i) can be matched for the (j)th threshold
            are_candidates_good = torch.logical_and(is_iou_above_threshold, are_candidates_free)
            # For every threshold (j) where target_i and pred_i can be matched together ( are_candidates_good[j]==True )
            # fill the matching placeholders with True
            targets_matched[target_i, are_candidates_good] = True
            preds_matched[pred_i, are_candidates_good] = True
            # When all the targets are matched with a prediction for every IoU Threshold, stop.
            if targets_matched.all():
                break
    # Crowd targets can be matched with many predictions.
    # Therefore, for every prediction we just need to check if it has IoA large enough with any crowd target.
    if len(crowd_targets) > 0:
        # shape = (n_preds_to_use x n_crowd_targets)
        ioa = crowd_ioa(preds_box[preds_idx_to_use], crowd_target_box)
        # Fill IoA values at index (i, j) with 0 when the prediction (i) and target(j) are of different class
        # Filling with 0 is equivalent to ignoring these values since we want IoA > threshold > 0
        cls_mismatch = preds_cls[preds_idx_to_use].view(-1, 1) != crowd_targets_cls.view(1, -1)
        ioa[cls_mismatch] = 0
        # For each prediction, we keep its highest score with any crowd target (of same class)
        # shape = (n_preds_to_use)
        best_ioa, _ = ioa.max(1)
        # If a prediction has IoA higher than threshold (with any target of same class), then there is a match
        # shape = (n_preds_to_use x iou_thresholds)
        is_matching_with_crowd = best_ioa.view(-1, 1) > iou_thresholds.view(1, -1)
        preds_to_ignore[preds_idx_to_use] = torch.logical_or(preds_to_ignore[preds_idx_to_use], is_matching_with_crowd)
    if return_on_cpu:
        preds_matched = preds_matched.to("cpu")
        preds_to_ignore = preds_to_ignore.to("cpu")
        preds_scores = preds_scores.to("cpu")
        preds_cls = preds_cls.to("cpu")
        targets_cls = targets_cls.to("cpu")
    return preds_matched, preds_to_ignore, preds_scores, preds_cls, targets_cls
      

Converts bounding box format from [cx, cy, w, h] to [x1, y1, x2, y2].

def convert_cxcywh_bbox_to_xyxy(input_bbox: torch.Tensor):
    Converts bounding box format from [cx, cy, w, h] to [x1, y1, x2, y2]
        :param input_bbox:  input bbox either 2-dimensional (for all boxes of a single image) or 3-dimensional (for
                            boxes of a batch of images)
        :return:            Converted bbox in same dimensions as the original
    need_squeeze = False
    # the input is always processed as a batch. in case it is not a batch, it is unsqueezed, processed and then squeezed back.
    if input_bbox.dim() < 3:
        need_squeeze = True
        input_bbox = input_bbox.unsqueeze(0)
    converted_bbox = torch.zeros_like(input_bbox) if isinstance(input_bbox, torch.Tensor) else np.zeros_like(input_bbox)
    converted_bbox[:, :, 0] = input_bbox[:, :, 0] - input_bbox[:, :, 2] / 2
    converted_bbox[:, :, 1] = input_bbox[:, :, 1] - input_bbox[:, :, 3] / 2
    converted_bbox[:, :, 2] = input_bbox[:, :, 0] + input_bbox[:, :, 2] / 2
    converted_bbox[:, :, 3] = input_bbox[:, :, 1] + input_bbox[:, :, 3] / 2
    # squeeze back if needed
    if need_squeeze:
        converted_bbox = converted_bbox[0]
    return converted_bbox
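
For example, a single 2-dimensional input (boxes of a single image):

import torch

bbox = torch.tensor([[50.0, 50.0, 20.0, 40.0]])  # (cx, cy, w, h)
print(convert_cxcywh_bbox_to_xyxy(bbox))         # tensor([[40., 30., 60., 70.]])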
      

Return intersection-over-detection_area of boxes, used for crowd ground truths. Both sets of boxes are expected to be in (x1, y1, x2, y2) format.

def crowd_ioa(det_box: torch.Tensor, crowd_box: torch.Tensor) -> torch.Tensor:
    Return intersection-over-detection_area of boxes, used for crowd ground truths.
    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
    :param det_box:     Tensor of shape [N, 4]
    :param crowd_box:   Tensor of shape [M, 4]
    :return: crowd_ioa, Tensor of shape [N, M]: the NxM matrix containing the pairwise IoA values for every element in det_box and crowd_box
    det_area = compute_box_area(det_box.T)
    # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
    inter = (torch.min(det_box[:, None, 2:], crowd_box[:, 2:]) - torch.max(det_box[:, None, :2], crowd_box[:, :2])).clamp(0).prod(2)
    return inter / det_area[:, None]  # crowd_ioa = inter / det_area
def cxcywh2xyxy(bboxes):
    Transforms bboxes from centered cxcywh format to xyxy format
    :param bboxes: array, shaped (nboxes, 4)
    :return: modified bboxes
    bboxes[:, 1] = bboxes[:, 1] - bboxes[:, 3] * 0.5
    bboxes[:, 0] = bboxes[:, 0] - bboxes[:, 2] * 0.5
    bboxes[:, 3] = bboxes[:, 3] + bboxes[:, 1]
    bboxes[:, 2] = bboxes[:, 2] + bboxes[:, 0]
    return bboxes
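
For example (note that the input array is modified in place):

import numpy as np

boxes = np.array([[100.0, 60.0, 40.0, 20.0]])  # (cx, cy, w, h)
print(cxcywh2xyxy(boxes.copy()))               # [[ 80.  50. 120.  70.]] - (x1, y1, x2, y2)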
def get_cls_posx_in_target(target_format: DetectionTargetsFormat) -> int:
    """Get the label of a given target
    :param target_format:   Representation of the target (ex: LABEL_XYXY)
    :return:                Position of the class id in a bbox
                                ex: 0 if bbox of format label_xyxy | -1 if bbox of format xyxy_label
    format_split = target_format.value.split("_")
    if format_split[0] == "LABEL":
        return 0
    elif format_split[-1] == "LABEL":
        return -1
    else:
        raise NotImplementedError(f"No implementation to find index of LABEL in {target_format.value}")
def get_mosaic_coordinate(mosaic_index, xc, yc, w, h, input_h, input_w):
    Returns the mosaic coordinates of final mosaic image according to mosaic image index.
    :param mosaic_index: (int) mosaic image index
    :param xc: (int) center x coordinate of the entire mosaic grid.
    :param yc: (int) center y coordinate of the entire mosaic grid.
    :param w: (int) width of bbox
    :param h: (int) height of bbox
    :param input_h: (int) image input height (should be 1/2 of the final mosaic output image height).
    :param input_w: (int) image input width (should be 1/2 of the final mosaic output image width).
    :return: (x1, y1, x2, y2), (x1s, y1s, x2s, y2s) where (x1, y1, x2, y2) are the coordinates in the final mosaic
        output image, and (x1s, y1s, x2s, y2s) are the coordinates in the placed image.
    # index0 to top left part of image
    if mosaic_index == 0:
        x1, y1, x2, y2 = max(xc - w, 0), max(yc - h, 0), xc, yc
        small_coord = w - (x2 - x1), h - (y2 - y1), w, h
    # index1 to top right part of image
    elif mosaic_index == 1:
        x1, y1, x2, y2 = xc, max(yc - h, 0), min(xc + w, input_w * 2), yc
        small_coord = 0, h - (y2 - y1), min(w, x2 - x1), h
    # index2 to bottom left part of image
    elif mosaic_index == 2:
        x1, y1, x2, y2 = max(xc - w, 0), yc, xc, min(input_h * 2, yc + h)
        small_coord = w - (x2 - x1), 0, w, min(y2 - y1, h)
    # index3 to bottom right part of image
    elif mosaic_index == 3:
        x1, y1, x2, y2 = xc, yc, min(xc + w, input_w * 2), min(input_h * 2, yc + h)  # noqa
        small_coord = 0, 0, min(w, x2 - x1), min(y2 - y1, h)
    return (x1, y1, x2, y2), small_coord
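
For example, placing a 200x150 image into the top-left tile of a mosaic built from 320x320 inputs whose crossing point is (xc, yc) = (320, 320):

coords, small_coords = get_mosaic_coordinate(0, 320, 320, 200, 150, 320, 320)
print(coords)        # (120, 170, 320, 320) - region filled in the mosaic canvas
print(small_coords)  # (0, 0, 200, 150)     - region taken from the placed image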
def get_top_k_idx_per_cls(preds_scores: torch.Tensor, preds_cls: torch.Tensor, top_k: int):
    """Get the indexes of all the top k predictions for every class
    :param preds_scores:   The confidence scores, vector of shape (n_pred)
    :param preds_cls:      The predicted class, vector of shape (n_pred)
    :param top_k:          Number of predictions to keep per class, ordered by confidence score
    :return top_k_idx:     Indexes of the top k predictions. length <= (k * n_unique_class)
    n_unique_cls = torch.max(preds_cls)
    mask = preds_cls.view(-1, 1) == torch.arange(n_unique_cls + 1, device=preds_scores.device).view(1, -1)
    preds_scores_per_cls = preds_scores.view(-1, 1) * mask
    sorted_scores_per_cls, sorting_idx = preds_scores_per_cls.sort(0, descending=True)
    idx_with_satisfying_scores = sorted_scores_per_cls[:top_k, :].nonzero(as_tuple=False)
    top_k_idx = sorting_idx[idx_with_satisfying_scores.split(1, dim=1)]
    return top_k_idx.view(-1)
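
For example, keeping at most two predictions per class:

import torch

scores = torch.tensor([0.9, 0.8, 0.7, 0.6])
classes = torch.tensor([0.0, 0.0, 0.0, 1.0])
print(get_top_k_idx_per_cls(scores, classes, top_k=2))
# tensor([0, 3, 1]) - the two best class-0 predictions (0 and 1) plus the single class-1 prediction (3)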
      

Performs Matrix Non-Maximum Suppression (NMS) on inference results https://arxiv.org/pdf/1912.04488.pdf

def matrix_non_max_suppression(
    pred, conf_thres: float = 0.1, kernel: str = "gaussian", sigma: float = 3.0, max_num_of_detections: int = 500, class_agnostic_nms: bool = False
) -> List[torch.Tensor]:
    """Performs Matrix Non-Maximum Suppression (NMS) on inference results https://arxiv.org/pdf/1912.04488.pdf
    :param pred:        Raw model prediction (in test mode) - a Tensor of shape [batch, num_predictions, 85]
                        where each item format is (x, y, w, h, object_conf, class_conf, ... 80 classes score ...)
    :param conf_thres:  Threshold under which prediction are discarded
    :param kernel:      Type of kernel to use ['gaussian', 'linear']
    :param sigma:       Sigma for the gaussian kernel
    :param max_num_of_detections: Maximum number of boxes to output
    :param class_agnostic_nms: If True, boxes of all classes decay each other; if False (default), only boxes of the same class do
    :return: Detections list, one tensor per image, where each row is formatted as (x1, y1, x2, y2, confidence, class)
    """
    # MULTIPLY CONF BY CLASS CONF TO GET COMBINED CONFIDENCE
    class_conf, class_pred = pred[:, :, 5:].max(2)
    pred[:, :, 4] *= class_conf
    # BOX (CENTER X, CENTER Y, WIDTH, HEIGHT) TO (X1, Y1, X2, Y2)
    pred[:, :, :4] = convert_cxcywh_bbox_to_xyxy(pred[:, :, :4])
    # DETECTIONS ORDERED AS (x1y1x2y2, obj_conf, class_conf, class_pred)
    pred = torch.cat((pred[:, :, :5], class_pred.unsqueeze(2)), 2)
    # SORT DETECTIONS BY DECREASING CONFIDENCE SCORES
    sort_ind = (-pred[:, :, 4]).argsort()
    pred = torch.stack([pred[i, sort_ind[i]] for i in range(pred.shape[0])])[:, 0:max_num_of_detections]
    ious = calc_bbox_iou_matrix(pred)
    ious = ious.triu(1)
    if not class_agnostic_nms:
        # CREATE A LABELS MASK, WE WANT ONLY BOXES WITH THE SAME LABEL TO AFFECT EACH OTHER
        labels = pred[:, :, 5:]
        labeles_matrix = (labels == labels.transpose(2, 1)).float().triu(1)
        ious *= labeles_matrix
    ious_cmax, _ = ious.max(1)
    ious_cmax = ious_cmax.unsqueeze(2).repeat(1, 1, max_num_of_detections)
    if kernel == "gaussian":
        decay_matrix = torch.exp(-1 * sigma * (ious**2))
        compensate_matrix = torch.exp(-1 * sigma * (ious_cmax**2))
        decay, _ = (decay_matrix / compensate_matrix).min(dim=1)
    else:
        decay = (1 - ious) / (1 - ious_cmax)
        decay, _ = decay.min(dim=1)
    pred[:, :, 4] *= decay
    output = [pred[i, pred[i, :, 4] > conf_thres] for i in range(pred.shape[0])]
    return output
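
A minimal call sketch, assuming the function above is in scope (its helpers such as convert_cxcywh_bbox_to_xyxy and calc_bbox_iou_matrix come from the same module); the shapes follow the docstring and the numbers are illustrative:

import torch

raw = torch.rand(2, 1000, 85)  # (batch, num_predictions, 4 box coords + objectness + 80 class scores)
detections = matrix_non_max_suppression(raw, conf_thres=0.1, kernel="gaussian", sigma=3.0)
# detections is a list with one tensor per image, holding the predictions whose
# decayed confidence is still above conf_thres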
          

Performs Non-Maximum Suppression (NMS) on inference results.
def non_max_suppression(
    prediction, conf_thres=0.1, iou_thres=0.6, multi_label_per_box: bool = True, with_confidence: bool = False, class_agnostic_nms: bool = False
):
    """Performs Non-Maximum Suppression (NMS) on inference results
    :param prediction: raw model prediction. Should be a list of Tensors of shape (cx, cy, w, h, confidence, cls0, cls1, ...)
    :param conf_thres: below the confidence threshold - prediction are discarded
    :param iou_thres: IoU threshold for the nms algorithm
    :param multi_label_per_box: controls whether to decode multiple labels per box.
                                True - each anchor can produce multiple labels of different classes
                                       that pass confidence threshold check (default).
                                False - each anchor can produce only one label of the class with the highest score.
    :param with_confidence: whether to multiply objectness score with class score.
                            usually valid for Yolo models only.
    :param class_agnostic_nms: indicates how boxes of different classes will be treated during NMS
                               True - NMS will be performed on all classes together.
                               False - NMS will be performed on each class separately (default).
    :return: detections with shape nx6 (x1, y1, x2, y2, confidence, class)
    """
    candidates_above_thres = prediction[..., 4] > conf_thres  # filter by confidence
    output = [None] * prediction.shape[0]
    for image_idx, pred in enumerate(prediction):
        pred = pred[candidates_above_thres[image_idx]]  # confident
        if not pred.shape[0]:  # If none remain process next image
            continue
        if with_confidence:
            pred[:, 5:] *= pred[:, 4:5]  # multiply objectness score with class score
        box = convert_cxcywh_bbox_to_xyxy(pred[:, :4])  # cxcywh to xyxy
        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label_per_box:  # try for all good confidence classes
            i, j = (pred[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            pred = torch.cat((box[i], pred[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = pred[:, 5:].max(1, keepdim=True)
            pred = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
        if not pred.shape[0]:  # If none remain process next image
            continue
        # Apply torch batched NMS algorithm
        boxes, scores, cls_idx = pred[:, :4], pred[:, 4], pred[:, 5]
        if class_agnostic_nms:
            idx_to_keep = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
        else:
            idx_to_keep = torchvision.ops.boxes.batched_nms(boxes, scores, cls_idx, iou_thres)
        output[image_idx] = pred[idx_to_keep]
    return output
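
And a matching sketch for the standard NMS path (again assuming the function above is in scope):

import torch

raw = torch.rand(4, 8400, 85)  # (batch, anchors, cx cy w h + objectness + 80 class scores)
dets = non_max_suppression(raw, conf_thres=0.25, iou_thres=0.6, with_confidence=True)
for image_dets in dets:        # one entry per image; None when nothing passes the confidence threshold
    if image_dets is not None:
        boxes, scores, classes = image_dets[:, :4], image_dets[:, 4], image_dets[:, 5]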
def undo_image_preprocessing(im_tensor: torch.Tensor) -> np.ndarray:
    """
    :param im_tensor: images in a batch after preprocessing for inference, RGB, (B, C, H, W)
    :return:          images in a batch in cv2 format, BGR, (B, H, W, C)
    """
    im_np = im_tensor.cpu().numpy()
    im_np = im_np[:, ::-1, :, :].transpose(0, 2, 3, 1)
    im_np *= 255.0
    return np.ascontiguousarray(im_np, dtype=np.uint8)
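
For example, assuming the function above is in scope:

import torch

batch = torch.rand(2, 3, 224, 224)        # normalized RGB batch, NCHW
images = undo_image_preprocessing(batch)
print(images.shape, images.dtype)          # (2, 224, 224, 3) uint8, BGR channel order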
def xyxy2cxcywh(bboxes):
    """Transforms bboxes from xyxy format to cxcywh (center-x, center-y, width, height) format
    :param bboxes: array, shaped (nboxes, 4)
    :return: modified bboxes
    """
    bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0]
    bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1]
    bboxes[:, 0] = bboxes[:, 0] + bboxes[:, 2] * 0.5
    bboxes[:, 1] = bboxes[:, 1] + bboxes[:, 3] * 0.5
    return bboxes
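
A quick numeric check (note the conversion is done in-place on the input array):

import numpy as np

boxes = np.array([[10.0, 20.0, 50.0, 80.0]])  # x1, y1, x2, y2
print(xyxy2cxcywh(boxes))                     # [[30. 50. 40. 60.]] -> cx, cy, w, h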
          Source code in V3_2/src/super_gradients/training/utils/distributed_training_utils.py
class DDPNotSetupException(Exception):
    """Exception raised when DDP setup is required but was not done"""
    def __init__(self):
        self.message = (
            "Your environment was not setup correctly for DDP.\n"
            "Please run at the beginning of your script:\n"
            ">>> from super_gradients.training.utils.distributed_training_utils import setup_device\n"
            ">>> from super_gradients.common.data_types.enum import MultiGPUMode\n"
            ">>> setup_device(multi_gpu=MultiGPUMode.DISTRIBUTED_DATA_PARALLEL, num_gpus=...)"
        )
        super().__init__(self.message)
          

@torch.no_grad()
def compute_precise_bn_stats(model: nn.Module, loader: torch.utils.data.DataLoader, precise_bn_batch_size: int, num_gpus: int):
    """
    :param model:                   The model being trained (ie: Trainer.net)
    :param loader:                  Training dataloader (ie: Trainer.train_loader)
    :param precise_bn_batch_size:   The effective batch size we want to calculate the batchnorm on. For example, if we are training a model
                                    on 8 gpus, with a batch of 128 on each gpu, a good rule of thumb would be to give it 8192
                                    (ie: effective_batch_size * num_gpus = batch_per_gpu * num_gpus * num_gpus).
                                    If precise_bn_batch_size is not provided in the training_params, the latter heuristic
                                    will be taken.
    :param num_gpus:                The number of gpus we are training on
    """
    # Compute the number of minibatches to use
    num_iter = int(precise_bn_batch_size / (loader.batch_size * num_gpus)) if precise_bn_batch_size else num_gpus
    num_iter = min(num_iter, len(loader))
    # Retrieve the BN layers
    bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
    # Initialize BN stats storage for computing mean(mean(batch)) and mean(var(batch))
    running_means = [torch.zeros_like(bn.running_mean) for bn in bns]
    running_vars = [torch.zeros_like(bn.running_var) for bn in bns]
    # Remember momentum values
    momentums = [bn.momentum for bn in bns]
    # Set momentum to 1.0 to compute BN stats that only reflect the current batch
    for bn in bns:
        bn.momentum = 1.0
    # Average the BN stats for each BN layer over the batches
    for inputs, _labels in itertools.islice(loader, num_iter):
        model(inputs.cuda())
        for i, bn in enumerate(bns):
            running_means[i] += bn.running_mean / num_iter
            running_vars[i] += bn.running_var / num_iter
    # Sync BN stats across GPUs (no reduction if 1 GPU used)
    running_means = scaled_all_reduce(running_means, num_gpus=num_gpus)
    running_vars = scaled_all_reduce(running_vars, num_gpus=num_gpus)
    # Set BN stats and restore original momentum values
    for i, bn in enumerate(bns):
        bn.running_mean = running_means[i]
        bn.running_var = running_vars[i]
        bn.momentum = momentums[i]
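
A hedged usage sketch with a tiny synthetic loader, assuming the function above is in scope (a CUDA device is required, since the function moves inputs to the GPU):

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU()).cuda().train()
dataset = TensorDataset(torch.randn(64, 3, 32, 32), torch.zeros(64, dtype=torch.long))
loader = DataLoader(dataset, batch_size=16)

compute_precise_bn_stats(model=model, loader=loader, precise_bn_batch_size=64, num_gpus=1)
# the BatchNorm running_mean / running_var now reflect averages over the sampled batches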
      

This method performs a reduce operation on multiple nodes running distributed training. It first sums the results from all nodes and then divides the sum by the number of nodes.

def distributed_all_reduce_tensor_average(tensor, n):
    """
    This method performs a reduce operation on multiple nodes running distributed training.
    It first sums the results from all nodes and then divides the sum by the number of nodes.
    :param tensor:  The tensor to perform the reduce operation for
    :param n:  Number of nodes
    :return:   Averaged tensor from all of the nodes
    """
    rt = tensor.clone()
    torch.distributed.all_reduce(rt, op=torch.distributed.ReduceOp.SUM)
    rt /= n
    return rt
        Source code in V3_2/src/super_gradients/training/utils/distributed_training_utils.py
def get_gpu_mem_utilization():
    """GPU memory managed by the caching allocator in bytes for a given device."""
    # Workaround to work on any torch version
    if hasattr(torch.cuda, "memory_reserved"):
        return torch.cuda.memory_reserved()
    else:
        return torch.cuda.memory_cached()
        Source code in V3_2/src/super_gradients/training/utils/distributed_training_utils.py
def get_local_rank():
    """Returns the local rank if running in DDP, and 0 otherwise
    :return: local rank
    """
    return dist.get_rank() if dist.is_initialized() else 0
        Source code in V3_2/src/super_gradients/training/utils/distributed_training_utils.py
def get_world_size() -> int:
    """Returns the world size if running in DDP, and 1 otherwise
    :return: world size
    """
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()
      

Initialize Distributed Data Parallel

Important note: (1) in distributed training it is customary to specify learning rates and batch sizes per GPU. Whatever learning rate and schedule you specify will be applied to the each GPU individually. Since gradients are passed and summed (reduced) from all to all GPUs, the effective batch size is the batch you specify times the number of GPUs. In the literature there are several "best practices" to set learning rates and schedules for large batch sizes.

Source code in V3_2/src/super_gradients/training/utils/distributed_training_utils.py 345
def initialize_ddp():
    """
    Initialize Distributed Data Parallel
    Important note: (1) in distributed training it is customary to specify learning rates and batch sizes per GPU.
    Whatever learning rate and schedule you specify will be applied to each GPU individually.
    Since gradients are passed and summed (reduced) from all to all GPUs, the effective batch size is the
    batch you specify times the number of GPUs. In the literature there are several "best practices" to set
    learning rates and schedules for large batch sizes.
    """
    if device_config.assigned_rank > 0:
        mute_current_process()
    logger.info("Distributed training starting...")
    if not torch.distributed.is_initialized():
        backend = "gloo" if os.name == "nt" else "nccl"
        torch.distributed.init_process_group(backend=backend, init_method="env://")
    torch.cuda.set_device(device_config.assigned_rank)
    if torch.distributed.get_rank() == 0:
        logger.info(f"Training in distributed mode... with {str(torch.distributed.get_world_size())} GPUs")
    device_config.device = "cuda:%d" % device_config.assigned_rank
def reduce_results_tuple_for_ddp(validation_results_tuple, device):
    """Gather all validation tuples from the various devices and average them"""
    validation_results_list = list(validation_results_tuple)
    for i, validation_result in enumerate(validation_results_list):
        if torch.is_tensor(validation_result):
            validation_result = validation_result.clone().detach()
        else:
            validation_result = torch.tensor(validation_result)
        validation_results_list[i] = distributed_all_reduce_tensor_average(tensor=validation_result.to(device), n=torch.distributed.get_world_size())
    validation_results_tuple = tuple(validation_results_list)
    return validation_results_tuple
      

Launch the same script as the one that was launched (i.e. the command used to start the current process is re-used) but on subprocesses (i.e. with DDP).

Source code in V3_2/src/super_gradients/training/utils/distributed_training_utils.py 390
@record
def restart_script_with_ddp(num_gpus: int = None):
    """Launch the same script as the one that was launched (i.e. the command used to start the current process is re-used) but on subprocesses (i.e. with DDP).
    :param num_gpus: How many gpu's you want to run the script on. If not specified, every available device will be used.
    """
    ddp_port = find_free_port()
    # Get the value fom recipe if specified, otherwise take all available devices.
    num_gpus = num_gpus if num_gpus is not None else torch.cuda.device_count()
    if num_gpus > torch.cuda.device_count():
        raise ValueError(f"You specified num_gpus={num_gpus} but only {torch.cuda.device_count()} GPU's are available")
    logger.info(
        "Launching DDP with:\n"
        f"   - ddp_port = {ddp_port}\n"
        f"   - num_gpus = {num_gpus}/{torch.cuda.device_count()} available\n"
        "-------------------------------------\n"
    config = LaunchConfig(
        nproc_per_node=num_gpus,
        min_nodes=1,
        max_nodes=1,
        run_id="sg_initiated",
        role="default",
        rdzv_endpoint=f"127.0.0.1:{ddp_port}",
        rdzv_backend="static",
        rdzv_configs={"rank": 0, "timeout": 900},
        rdzv_timeout=-1,
        max_restarts=0,
        monitor_interval=5,
        start_method="spawn",
        log_dir=None,
        redirects=Std.NONE,
        tee=Std.NONE,
        metrics_cfg={},
    )
    elastic_launch(config=config, entrypoint=sys.executable)(*sys.argv, *EXTRA_ARGS)
    # The code below should actually never be reached as the process will be in a loop inside elastic_launch until any subprocess crashes.
    sys.exit(0)
      

Performs the scaled all_reduce operation on the provided tensors. The input tensors are modified in-place. Currently supports only the sum reduction operator. The reduced values are scaled by the inverse size of the process group (equivalent to num_gpus).

Source code in V3_2/src/super_gradients/training/utils/distributed_training_utils.py 94
def scaled_all_reduce(tensors: torch.Tensor, num_gpus: int):
    """
    Performs the scaled all_reduce operation on the provided tensors.
    The input tensors are modified in-place.
    Currently supports only the sum reduction operator.
    The reduced values are scaled by the inverse size of the process group (equivalent to num_gpus).
    """
    # There is no need for reduction in the single-proc case
    if num_gpus == 1:
        return tensors
    # Queue the reductions
    reductions = []
    for tensor in tensors:
        reduction = torch.distributed.all_reduce(tensor, async_op=True)
        reductions.append(reduction)
    # Wait for reductions to finish
    for reduction in reductions:
        reduction.wait()
    # Scale the results
    for tensor in tensors:
        tensor.mul_(1.0 / num_gpus)
    return tensors
        Source code in V3_2/src/super_gradients/training/utils/distributed_training_utils.py
def setup_cpu(multi_gpu: MultiGPUMode = MultiGPUMode.AUTO, num_gpus: int = None):
    """
    :param multi_gpu:    DDP, DP, Off or AUTO
    :param num_gpus:     Number of GPU's to use.
    """
    if multi_gpu not in (MultiGPUMode.OFF, MultiGPUMode.AUTO):
        raise ValueError(f"device='cpu' and multi_gpu={multi_gpu} are not compatible together.")
    if num_gpus not in (0, None):
        raise ValueError(f"device='cpu' and num_gpus={num_gpus} are not compatible together.")
    device_config.device = "cpu"
    device_config.multi_gpu = MultiGPUMode.OFF
          

Source code in V3_2/src/super_gradients/training/utils/distributed_training_utils.py 242
@resolve_param("multi_gpu", TypeFactory(MultiGPUMode.dict()))
def setup_device(multi_gpu: MultiGPUMode = MultiGPUMode.AUTO, num_gpus: int = None, device: str = "cuda"):
    """If required, launch ddp subprocesses.
    :param multi_gpu:   DDP, DP, Off or AUTO
    :param num_gpus:    Number of GPU's to use. When None, use all available devices on DDP or only one device on DP/OFF.
    :param device:      The device you want to use ('cpu' or 'cuda')
    If you only set num_gpus, your device will be set up according to the following logic:
        - `setup_device(num_gpus=0)`  => `gpu_mode='OFF'` and `device='cpu'`
        - `setup_device(num_gpus=1)`  => `gpu_mode='OFF'` and `device='gpu'`
        - `setup_device(num_gpus>=2)` => `gpu_mode='DDP'` and `device='gpu'`
        - `setup_device(num_gpus=-1)` => `gpu_mode='DDP'` and `device='gpu'` and `num_gpus=<N-AVAILABLE-GPUs>`
    """
    init_trainer()
    # When launching with torch.distributed.launch or torchrun, multi_gpu might not be set to DDP (since we are not using the recipe params)
    # To avoid any issue we force multi_gpu to be DDP if the current process is ddp subprocess. We also set num_gpus, device to run smoothly.
    if not is_launched_using_sg() and is_distributed():
        multi_gpu, num_gpus, device = MultiGPUMode.DISTRIBUTED_DATA_PARALLEL, None, "cuda"
    if device is None:
        device = "cuda"
    if device == "cuda" and not torch.cuda.is_available():
        logger.warning("CUDA device is not available on your device... Moving to CPU.")
        multi_gpu, num_gpus, device = MultiGPUMode.OFF, 0, "cpu"
    if device == "cpu":
        setup_cpu(multi_gpu, num_gpus)
    elif device == "cuda":
        setup_gpu(multi_gpu, num_gpus)
    else:
        raise ValueError(f"Only valid values for device are: 'cpu' and 'cuda'. Received: '{device}'")
          

Source code in V3_2/src/super_gradients/training/utils/distributed_training_utils.py 278
def setup_gpu(multi_gpu: MultiGPUMode = MultiGPUMode.AUTO, num_gpus: int = None):
    """If required, launch ddp subprocesses.
    :param multi_gpu:    DDP, DP, Off or AUTO
    :param num_gpus:     Number of GPU's to use. When None, use all available devices on DDP or only one device on DP/OFF.
    """
    if num_gpus == 0:
        raise ValueError("device='cuda' and num_gpus=0 are not compatible together.")
    multi_gpu, num_gpus = _resolve_gpu_params(multi_gpu=multi_gpu, num_gpus=num_gpus)
    device_config.device = "cuda"
    device_config.multi_gpu = multi_gpu
    if is_distributed():
        initialize_ddp()
    elif multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL:
        restart_script_with_ddp(num_gpus=num_gpus)
          

Source code in V3_2/src/super_gradients/training/utils/distributed_training_utils.py 205
def setup_gpu_mode(gpu_mode: MultiGPUMode = MultiGPUMode.OFF, num_gpus: int = None):
    """[DEPRECATED in favor of setup_device] If required, launch ddp subprocesses.
    :param gpu_mode:    DDP, DP, Off or AUTO
    :param num_gpus:    Number of GPU's to use. When None, use all available devices on DDP or only one device on DP/OFF.
    """
    logger.warning("setup_gpu_mode is now deprecated in favor of setup_device")
    setup_device(multi_gpu=gpu_mode, num_gpus=num_gpus)
        Source code in V3_2/src/super_gradients/training/utils/distributed_training_utils.py
@contextmanager
def wait_for_the_master(local_rank: int):
    """Make all processes wait for the master to do some task."""
    if local_rank > 0:
        dist.barrier()
    yield
    if local_rank == 0:
        if not dist.is_available():
            return
        if not dist.is_initialized():
            return
        else:
            dist.barrier()
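
A minimal sketch of the intended usage (the import path mirrors the source path quoted above; prepare_data is a hypothetical placeholder):

from super_gradients.training.utils.distributed_training_utils import get_local_rank, wait_for_the_master

def prepare_data():
    print("downloading / unpacking the dataset on rank 0 only")  # hypothetical placeholder work

local_rank = get_local_rank()
with wait_for_the_master(local_rank):
    if local_rank == 0:
        prepare_data()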
class EarlyStop(PhaseCallback):
    """
    Callback to monitor a metric and stop training when it stops improving.
    Inspired by pytorch_lightning.callbacks.early_stopping and tf.keras.callbacks.EarlyStopping
    """
    mode_dict = {"min": torch.lt, "max": torch.gt}
    supported_phases = (Phase.VALIDATION_EPOCH_END, Phase.TRAIN_EPOCH_END)
    def __init__(
        self,
        phase: Phase,
        monitor: str,
        mode: str = "min",
        min_delta: float = 0.0,
        patience: int = 3,
        check_finite: bool = True,
        threshold: Optional[float] = None,
        verbose: bool = False,
        strict: bool = True,
    ):
        """
        :param phase: Callback phase event.
        :param monitor: name of the metric to be monitored.
        :param mode: one of 'min', 'max'. In 'min' mode, training will stop when the quantity
           monitored has stopped decreasing and in 'max' mode it will stop when the quantity
           monitored has stopped increasing.
        :param min_delta: minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute
           change of less than `min_delta`, will count as no improvement.
        :param patience: number of checks with no improvement after which training will be stopped.
            One check happens after every phase event.
        :param check_finite: When set ``True``, stops training when the monitor becomes NaN or infinite.
        :param threshold: Stop training immediately once the monitored quantity reaches this threshold. For mode 'min'
            stops training when below threshold, For mode 'max' stops training when above threshold.
        :param verbose: If `True` print logs.
        :param strict: whether to crash the training if `monitor` is not found in the metrics.
        """
        super(EarlyStop, self).__init__(phase)
        if phase not in self.supported_phases:
            raise ValueError(f"EarlyStop doesn't support phase: {phase}, " f"excepted {', '.join([str(x) for x in self.supported_phases])}")
        self.phase = phase
        self.monitor_key = monitor
        self.min_delta = min_delta
        self.patience = patience
        self.mode = mode
        self.check_finite = check_finite
        self.threshold = threshold
        self.verbose = verbose
        self.strict = strict
        self.wait_count = 0
        if self.mode not in self.mode_dict:
            raise Exception(f"`mode` can be {', '.join(self.mode_dict.keys())}, got {self.mode}")
        self.monitor_op = self.mode_dict[self.mode]
        self.min_delta *= 1 if self.monitor_op == torch.gt else -1
        torch_inf = torch.tensor(np.Inf)
        self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf
    def _get_metric_value(self, metrics_dict):
        if self.monitor_key not in metrics_dict.keys():
            msg = f"Can't find EarlyStop monitor {self.monitor_key} in metrics_dict: {metrics_dict.keys()}"
            exception_cls = RuntimeError if self.strict else MissingMonitorKeyException
            raise exception_cls(msg)
        return metrics_dict[self.monitor_key]
    def _check_for_early_stop(self, current: torch.Tensor):
        should_stop = False
        # check if current value is Nan or inf
        if self.check_finite and not torch.isfinite(current):
            should_stop = True
            reason = (
                f"Monitored metric {self.monitor_key} = {current} is not finite." f" Previous best value was {self.best_score:.3f}. Signaling Trainer to stop."
        # check if current value reached threshold value
        elif self.threshold is not None and self.monitor_op(current, self.threshold):
            should_stop = True
            reason = "Stopping threshold reached:" f" {self.monitor_key} = {current} {self.monitor_op} {self.threshold}." " Signaling Trainer to stop."
        # check if current is an improvement of monitor_key metric.
        elif self.monitor_op(current - self.min_delta, self.best_score.to(current.device)):
            should_stop = False
            if torch.isfinite(self.best_score):
                reason = (
                    f"Metric {self.monitor_key} improved by {abs(self.best_score - current):.3f} >="
                    f" min_delta = {abs(self.min_delta)}. New best score: {current:.3f}"
            else:
                reason = f"Metric {self.monitor_key} improved. New best score: {current:.3f}"
            self.best_score = current
            self.wait_count = 0
        # no improvement in monitor_key metric, check if wait_count is bigger than patience.
        else:
            self.wait_count += 1
            reason = f"Monitored metric {self.monitor_key} did not improve in the last {self.wait_count} records."
            if self.wait_count >= self.patience:
                should_stop = True
                reason += f" Best score: {self.best_score:.3f}. Signaling Trainer to stop."
        return reason, should_stop
    def __call__(self, context: PhaseContext):
        try:
            current = self._get_metric_value(context.metrics_dict)
        except MissingMonitorKeyException as e:
            logger.warning(e)
            return
        if not isinstance(current, torch.Tensor):
            current = torch.tensor(current)
        reason, self.should_stop = self._check_for_early_stop(current)
        # log reason message, and signal early stop if should_stop=True.
        if self.should_stop:
            self._signal_early_stop(context, reason)
        elif self.verbose:
            logger.info(reason)
    def _signal_early_stop(self, context: PhaseContext, reason: str):
        logger.info(reason)
        context.update_context(stop_training=True)
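
A hedged usage sketch: stop training when the validation loss has not improved for 5 consecutive validation epochs (the import paths are assumptions and may differ between versions):

from super_gradients.training.utils.callbacks import Phase           # assumed import path
from super_gradients.training.utils.early_stopping import EarlyStop  # assumed import path

early_stop = EarlyStop(phase=Phase.VALIDATION_EPOCH_END, monitor="loss", mode="min", patience=5, verbose=True)
# then pass it to the trainer, e.g. training_params={"phase_callbacks": [early_stop], ...}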
          

class MissingMonitorKeyException(Exception):
    """Exception raised for missing monitor key in metrics_dict."""
      

Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models Keep a moving average of everything in the model state_dict (parameters and buffers). This is intended to allow functionality like https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage A smoothed version of the weights is necessary for some training schemes to perform well. This class is sensitive where it is initialized in the sequence of model init, GPU assignment and distributed training wrappers.

Source code in V3_2/src/super_gradients/training/utils/ema.py 186
class KDModelEMA(ModelEMA):
    """Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
    Keep a moving average of everything in the model state_dict (parameters and buffers).
    This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    A smoothed version of the weights is necessary for some training schemes to perform well.
    This class is sensitive where it is initialized in the sequence of model init,
    GPU assignment and distributed training wrappers.
    """
    def __init__(self, kd_model: KDModule, decay: float, decay_function: IDecayFunction):
        """
        Init the EMA
        :param kd_model: KDModule, the training Knowledge distillation model to construct the EMA model by
                    IMPORTANT: WHEN THE APPLICATION OF EMA ONLY ON A SUBSET OF ATTRIBUTES IS DESIRED, WRAP THE NN.MODULE
                    AS SgModule AND OVERWRITE get_include_attributes() AND get_exclude_attributes() AS DESIRED.
        :param decay: the maximum decay value. As the training process advances, the decay will climb towards this value
                      until EMA_t+1 = EMA_t * decay + TRAINING_MODEL * (1 - decay)
        :param decay_function: decay schedule (an IDecayFunction) controlling how quickly the decay ramps up to its maximum value
        """
        # Only work on the student (we don't want to update and to have a duplicate of the teacher)
        super().__init__(model=unwrap_model(kd_model).student, decay=decay, decay_function=decay_function)
        # Overwrite current ema attribute with combination of the student model EMA (current self.ema)
        # with already the instantiated teacher, to have the final KD EMA
        self.ema = KDModule(
            arch_params=unwrap_model(kd_model).arch_params,
            student=self.ema,
            teacher=unwrap_model(kd_model).teacher,
            run_teacher_on_eval=unwrap_model(kd_model).run_teacher_on_eval,
        )
          


Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models Keep a moving average of everything in the model state_dict (parameters and buffers). This is intended to allow functionality like https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage A smoothed version of the weights is necessary for some training schemes to perform well. This class is sensitive where it is initialized in the sequence of model init, GPU assignment and distributed training wrappers.

Source code in V3_2/src/super_gradients/training/utils/ema.py 152
class ModelEMA:
    """Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
    Keep a moving average of everything in the model state_dict (parameters and buffers).
    This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    A smoothed version of the weights is necessary for some training schemes to perform well.
    This class is sensitive where it is initialized in the sequence of model init,
    GPU assignment and distributed training wrappers.
    """
    def __init__(self, model: nn.Module, decay: float, decay_function: IDecayFunction):
        """
        Init the EMA
        :param model: Union[SgModule, nn.Module], the training model to construct the EMA model by
                    IMPORTANT: WHEN THE APPLICATION OF EMA ONLY ON A SUBSET OF ATTRIBUTES IS DESIRED, WRAP THE NN.MODULE
                    AS SgModule AND OVERWRITE get_include_attributes() AND get_exclude_attributes() AS DESIRED.
        :param decay: the maximum decay value. As the training process advances, the decay will climb towards this value
                      until EMA_t+1 = EMA_t * decay + TRAINING_MODEL * (1 - decay)
        :param decay_function: decay schedule (an IDecayFunction) controlling how quickly the decay ramps up to its maximum value
        """
        # Create EMA
        model = unwrap_model(model)
        self.ema = deepcopy(model)
        self.ema.eval()
        self.decay = decay
        self.decay_function = decay_function
        """
        we hold a list of model attributes (not weights and biases) which we would like to include in each
        attribute update or exclude from each update. a SgModule declares these attributes using
        get_include_attributes and get_exclude_attributes functions. for a nn.Module which is not a SgModule
        all non-private (not starting with '_') attributes will be updated (and only them).
        """
        if isinstance(model, SgModule):
            self.include_attributes = model.get_include_attributes()
            self.exclude_attributes = model.get_exclude_attributes()
        else:
            warnings.warn("Warning: EMA should be used with SgModule instance. All attributes of the model will be " "included in EMA")
            self.include_attributes = []
            self.exclude_attributes = []
        for p in self.ema.parameters():
            p.requires_grad_(False)
    @classmethod
    def from_params(cls, model: nn.Module, decay_type: str = None, decay: float = None, **kwargs):
        if decay is None:
            logger.warning(
                "Parameter `decay` is not specified for EMA params. Please specify `decay` parameter explicitly in your config:\n"
                "ema: True\n"
                "ema_params: \n"
                "  decay: 0.9999\n"
                "  decay_type: exp\n"
                "  beta: 15\n"
                "Will default to decay: 0.9999\n"
                "In the next major release of SG this warning will become an error."
            decay = 0.9999
        if "exp_activation" in kwargs:
            logger.warning(
                "Parameter `exp_activation` is deprecated for EMA model. Please update your config to use decay_type: str (constant|exp|threshold) instead:\n"
                "ema: True\n"
                "ema_params: \n"
                "  decay: 0.9999\n"
                "  decay_type: exp # Equivalent to exp_activation: True\n"
                "  beta: 15\n"
                "\n"
                "ema: True\n"
                "ema_params: \n"
                "  decay: 0.9999\n"
                "  decay_type: constant # Equivalent to exp_activation: False\n"
                "\n"
                "In the next major release of SG this warning will become an error."
            decay_type = "exp" if bool(kwargs.pop("exp_activation")) else "constant"
        if decay_type is None:
            logger.warning(
                "Parameter decay_type is not specified for EMA model. Please specify decay_type parameter explicitly in your config:\n"
                "ema: True\n"
                "ema_params: \n"
                "  decay: 0.9999\n"
                "  decay_type: constant|exp|threshold\n"
                "Will default to `exp` decay with beta = 15\n"
                "In the next major release of SG this warning will become an error."
            decay_type = "exp"
            if "beta" not in kwargs:
                kwargs["beta"] = 15
        try:
            decay_cls = EMA_DECAY_FUNCTIONS[decay_type]
        except KeyError:
            raise UnknownTypeException(decay_type, list(EMA_DECAY_FUNCTIONS.keys()))
        decay_function = decay_cls(**kwargs)
        return cls(model, decay, decay_function)
    def update(self, model, step: int, total_steps: int):
        """
        Update the state of the EMA model.
        :param model: Current training model
        :param step: Current training step
        :param total_steps: Total training steps
        """
        # Update EMA parameters
        model = unwrap_model(model)
        with torch.no_grad():
            decay = self.decay_function(self.decay, step, total_steps)
            for ema_v, model_v in zip(self.ema.state_dict().values(), model.state_dict().values()):
                if ema_v.dtype.is_floating_point:
                    ema_v.copy_(ema_v * decay + (1.0 - decay) * model_v.detach())
    def update_attr(self, model):
        """
        This function updates model attributes (not weights and biases) from original model to the ema model.
        attributes of the original model, such as anchors and grids (of detection models), may be crucial to the
        model operation and need to be updated.
        If include_attributes and exclude_attributes lists were not defined, all non-private (not starting with '_')
        attributes will be updated (and only them).
        :param model: the source model
        """
        copy_attr(self.ema, unwrap_model(model), self.include_attributes, self.exclude_attributes)
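
A hedged sketch of keeping an EMA copy in a manual training loop, using from_params as shown above (normally the Trainer handles this via ema: True / ema_params in the training params, as the warnings in from_params indicate):

import torch
from torch import nn

model = nn.Linear(10, 2)
ema = ModelEMA.from_params(model, decay_type="exp", decay=0.9999, beta=15)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

total_steps = 100
for step in range(total_steps):
    loss = model(torch.randn(4, 10)).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.update(model, step=step, total_steps=total_steps)

# ema.ema holds the smoothed copy of the weights (e.g. for evaluation or checkpointing)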
class ConstantDecay(IDecayFunction):
    """Keeps the EMA decay constant at its maximum value throughout training."""
    # NOTE: class header reconstructed from context (the "constant" decay type referenced in ModelEMA.from_params);
    # only the __call__ body below appeared in the original listing.

    def __call__(self, decay: float, step: int, total_steps: int) -> float:
        return decay
class ExpDecay(IDecayFunction):
    """Gradually increases EMA decay from 0 to the maximum value using the formula: decay * (1 - exp(-x * beta))"""
    def __init__(self, beta: float, **kwargs):
        self.beta = beta
    def __call__(self, decay: float, step, total_steps: int) -> float:
        x = step / total_steps
        return decay * (1 - np.exp(-x * self.beta))
      

Interface for EMA decay schedule. The decay schedule is a function of the maximum decay value and training progress. Usually it gradually increase EMA from to the maximum value. The exact ramp-up schedule is defined by the concrete implementation.

Source code in V3_2/src/super_gradients/training/utils/ema_decay_schedules.py 23
class IDecayFunction:
    """
    Interface for EMA decay schedule. The decay schedule is a function of the maximum decay value and training progress.
    Usually it gradually increases the decay from zero to the maximum value. The exact ramp-up schedule is defined by the
    concrete implementation.
    """
    @abstractmethod
    def __call__(self, decay: float, step: int, total_steps: int) -> float:
        """
        :param decay: The maximum decay value.
        :param step: Current training step. The unit-range training percentage can be obtained by `step / total_steps`.
        :param total_steps:  Total number of training steps.
        :return: Computed decay value for a given step.
        """
          

class ThresholdDecay(IDecayFunction):
    """Gradually increases EMA decay from 0.1 to the maximum value using the formula: min(decay, (1 + step) / (10 + step))"""

    def __init__(self, **kwargs):
        pass

    def __call__(self, decay: float, step, total_steps: int) -> float:
        return np.minimum(decay, (1 + step) / (10 + step))
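
A quick numeric comparison of the two ramp-up schedules above, assuming the classes are in scope (decay capped at 0.9999 over 1000 steps):

exp_decay = ExpDecay(beta=15)
threshold_decay = ThresholdDecay()

for step in (0, 50, 500, 1000):
    print(step, exp_decay(0.9999, step, 1000), threshold_decay(0.9999, step, 1000))
# ExpDecay follows decay * (1 - exp(-15 * step / 1000)); ThresholdDecay follows min(decay, (1 + step) / (10 + step))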
      

Fuses consecutive nn.Conv2d and nn.BatchNorm2d layers recursively inplace in all of the model

def fuse_conv_bn(model: nn.Module, replace_bn_with_identity: bool = False):
    """Fuses consecutive nn.Conv2d and nn.BatchNorm2d layers recursively inplace in all of the model
    :param model: the target model
    :param replace_bn_with_identity: if set to true, bn will be replaced with identity. otherwise, bn will be removed
    :return: the number of fuses executed
    """
    children = list(model.named_children())
    counter = 0
    for i in range(len(children) - 1):
        if isinstance(children[i][1], torch.nn.Conv2d) and isinstance(children[i + 1][1], torch.nn.BatchNorm2d):
            setattr(model, children[i][0], torch.nn.utils.fuse_conv_bn_eval(children[i][1], children[i + 1][1]))
            if replace_bn_with_identity:
                setattr(model, children[i + 1][0], nn.Identity())
            else:
                delattr(model, children[i + 1][0])
            counter += 1
    for child_name, child in children:
        counter += fuse_conv_bn(child, replace_bn_with_identity)
    return counter
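
A minimal sketch, assuming the function above is in scope (the model must be in eval mode, since torch's fuse_conv_bn_eval only supports eval-mode fusion):

from torch import nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU(), nn.Conv2d(8, 8, 3), nn.BatchNorm2d(8))
model.eval()
num_fused = fuse_conv_bn(model, replace_bn_with_identity=True)
print(num_fused)  # 2 -> both Conv2d + BatchNorm2d pairs were fused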
def get_input_output_shapes(batch_size: int, input_dims: Union[list, tuple], output_dims: Union[list, tuple]):
    """Returns input/output shapes for single/multiple input/s output/s"""
    if isinstance(input_dims[0], list):
        input_shape = [i.size() for i in input_dims[0] if i is not None]
    else:
        input_shape = list(input_dims[0].size())
    input_shape[0] = batch_size
    if isinstance(output_dims, (list, tuple)):
        output_shape = [[-1] + list(o.size())[1:] for o in output_dims if o is not None]
    else:
        output_shape = list(output_dims.size())
        output_shape[0] = batch_size
    return input_shape, output_shape
      

Return the model summary as a string. The block(type) column represents the layers (lines) above it.

Source code in V3_2/src/super_gradients/training/utils/get_model_stats.py 94
def get_model_stats(
    model: nn.Module,
    input_dims: Union[list, tuple],
    high_verbosity: bool = True,
    batch_size: int = 1,
    device: str = "cuda",  # noqa: C901
    dtypes=None,
    iterations: int = 100,
):
    """
    return the model summary as a string
    The block(type) column represents the lines (layers) above
        :param dtypes:          The input types (list of inputs types)
        :param high_verbosity:  prints layer by layer information
    """
    dtypes = dtypes or [torch.FloatTensor] * len(input_dims)
    def register_hook(module):
        """Add a hook (all the desirable actions) for every layer that is not nn.Sequential/nn.ModuleList"""
        def hook(module, input, output):
            class_name = str(module.__class__).split(".")[-1].split("'")[0]
            module_idx = len(summary)
            m_key = f"{class_name}-{module_idx + 1}"
            summary[m_key] = OrderedDict()
            # block_name refers to all layers that contains other layers
            if len(module._modules) != 0:
                summary[m_key]["block_name"] = class_name
            summary[m_key]["inference_time"] = np.round(timer.stop(), 3)
            timer.start()
            summary[m_key]["gpu_occupation"] = (round(torch.cuda.memory_allocated(0) / 1024**3, 2), "GB") if torch.cuda.is_available() else [0]
            summary[m_key]["gpu_cached_memory"] = (round(torch.cuda.memory_reserved(0) / 1024**3, 2), "GB") if torch.cuda.is_available() else [0]
            summary[m_key]["input_shape"], summary[m_key]["output_shape"] = get_input_output_shapes(batch_size=batch_size, input_dims=input, output_dims=output)
            params = 0
            if hasattr(module, "weight") and hasattr(module.weight, "size"):
                params += torch.prod(torch.LongTensor(list(module.weight.size())))
                summary[m_key]["trainable"] = module.weight.requires_grad
            if hasattr(module, "bias") and hasattr(module.bias, "size"):
                params += torch.prod(torch.LongTensor(list(module.bias.size())))
            summary[m_key]["nb_params"] = params
        if not isinstance(module, nn.Sequential) and not isinstance(module, nn.ModuleList):
            hooks.append(module.register_forward_hook(hook))
    # multiple inputs to the network
    if isinstance(input_dims, tuple):
        input_dims = [input_dims]
    x = [torch.rand(batch_size, *input_dim).type(dtype).to(device=device) for input_dim, dtype in zip(input_dims, dtypes)]
    summary_list = []
    with torch.no_grad():
        for i in range(iterations + 10):
            # create properties
            summary = OrderedDict()
            hooks = []
            # register hook
            model.apply(register_hook)
            timer = Timer(device=device)
            timer.start()
            # make a forward pass
            model(*x)
            # remove these hooks
            for h in hooks:
                h.remove()
            # we start counting from the 10th iteration for warmup
            if i >= 10:
                summary_list.append(summary)
    summary = _average_inference_time(summary_list=summary_list, summary=summary, divisor=iterations)
    return _convert_summary_dict_to_string(summary=summary, high_verbosity=high_verbosity, input_dims=input_dims, batch_size=batch_size, device=device)
def check_image_typing(image: ImageSource) -> bool:
    """Check if the given object respects typing of image.
    :param image: Image to check.
    :return: True if the object is an image, False otherwise.
    """
    if isinstance(image, get_args(SingleImageSource)):
        return True
    elif isinstance(image, list):
        return all([isinstance(image_item, get_args(SingleImageSource)) for image_item in image])
    else:
        return False
      

Generator that loads images one at a time.

Supported types include: - str: A string representing either an image or an URL. - numpy.ndarray: A numpy array representing the image - torch.Tensor: A PyTorch tensor representing the image - PIL.Image.Image: A PIL Image object - List: A list of images of any of the above types.

def generate_image_loader(images: Union[List[ImageSource], ImageSource]) -> Iterable[np.ndarray]:
    """Generator that loads images one at a time.
    Supported types include:
        - str:              A string representing either an image or an URL.
        - numpy.ndarray:    A numpy array representing the image
        - torch.Tensor:     A PyTorch tensor representing the image
        - PIL.Image.Image:  A PIL Image object
        - List:             A list of images of any of the above types.
    :param images:  Single image or a list of images of supported types.
    :return:        Generator of images as numpy arrays. If loaded from string, the image will be returned as RGB.
    """
    if isinstance(images, str) and os.path.isdir(images):
        images_paths = list_images_in_folder(images)
        for image_path in images_paths:
            yield load_image(image=image_path)
    elif isinstance(images, (list, Iterator)):
        for image in images:
            yield load_image(image=image)
    else:
        yield load_image(image=images)
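For illustration, a minimal sketch of lazily iterating over a folder of images with this generator; the import path below is an assumption based on the module layout and may differ in your installation:

# Assumed import path - adjust to where the image helpers live in your installation.
from super_gradients.training.utils.media.image import generate_image_loader

for image in generate_image_loader("/path/to/images"):
    # Each item is a numpy array; images loaded from disk or a URL are returned as RGB.
    print(image.shape)

Because the generator yields one image at a time, it avoids holding the whole folder in memory, unlike load_images below.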
def is_image(filename: str) -> bool:
    """Check if the given file name refers to image.
    :param filename:    The filename to check.
    :return:            True if the file is an image, False otherwise.
    """
    return filename.split(".")[-1].lower() in IMG_EXTENSIONS
def is_url(url: str) -> bool:
    """Check if the given string is a URL.
    :param url:  String to check.
    :return:     True if the string is a valid URL, False otherwise.
    """
    try:
        result = urlparse(url)
        return all([result.scheme, result.netloc, result.path])
    except Exception:
        return False
def list_images_in_folder(directory: str) -> List[str]:
    """List all the images in a directory.
    :param directory: The path to the directory containing the images.
    :return: A list of image file paths.
    """
    files = os.listdir(directory)
    images_paths = [os.path.join(directory, f) for f in files if f.lower().endswith((".jpg", ".jpeg", ".png", ".bmp", ".gif"))]
    return images_paths
      

Load a single image and return it as a numpy array.

Supported image types include:

- numpy.ndarray: A numpy array representing the image.
- torch.Tensor: A PyTorch tensor representing the image.
- PIL.Image.Image: A PIL Image object.
- str: A string representing either a local file path or a URL to an image.

def load_image(image: ImageSource) -> np.ndarray:
    """Load a single image and return it as a numpy arrays.
    Supported image types include:
        - numpy.ndarray:    A numpy array representing the image
        - torch.Tensor:     A PyTorch tensor representing the image
        - PIL.Image.Image:  A PIL Image object
        - str:              A string representing either a local file path or a URL to an image
    :param image: Single image of supported types.
    :return:      Image as a numpy array. If loaded from string, the image will be returned as RGB.
    """
    if isinstance(image, np.ndarray):
        return image
    elif isinstance(image, torch.Tensor):
        return image.numpy()
    elif isinstance(image, PIL.Image.Image):
        return load_np_image_from_pil(image)
    elif isinstance(image, str):
        image = load_pil_image_from_str(image_str=image)
        return load_np_image_from_pil(image)
    else:
        raise ValueError(f"Unsupported image type: {type(image)}")
      

Load a single image or a list of images and return them as a list of numpy arrays.

Supported types include:

- str: A string representing either an image or a URL.
- numpy.ndarray: A numpy array representing the image.
- torch.Tensor: A PyTorch tensor representing the image.
- PIL.Image.Image: A PIL Image object.
- List: A list of images of any of the above types.

def load_images(images: Union[List[ImageSource], ImageSource]) -> List[np.ndarray]:
    """Load a single image or a list of images and return them as a list of numpy arrays.
    Supported types include:
        - str:              A string representing either an image or an URL.
        - numpy.ndarray:    A numpy array representing the image
        - torch.Tensor:     A PyTorch tensor representing the image
        - PIL.Image.Image:  A PIL Image object
        - List:             A list of images of any of the above types.
    :param images:  Single image or a list of images of supported types.
    :return:        List of images as numpy arrays. If loaded from string, the image will be returned as RGB.
    """
    return [image for image in generate_image_loader(images=images)]
def load_np_image_from_pil(image: PIL.Image.Image) -> np.ndarray:
    """Convert a PIL image to numpy array in RGB format."""
    return np.asarray(image.convert("RGB"))
def load_pil_image_from_str(image_str: str) -> PIL.Image.Image:
    """Load an image based on a string (local file path or URL)."""
    if is_url(image_str):
        response = requests.get(image_str, stream=True)
        response.raise_for_status()
        return PIL.Image.open(io.BytesIO(response.content))
    else:
        return PIL.Image.open(image_str)
def save_image(image: np.ndarray, path: str) -> None:
    """Save a numpy array as an image.
    :param image:  Image to save, (H, W, C), RGB.
    :param path:   Path to save the image to.
    """
    Image.fromarray(image).save(path)
def show_image(image: np.ndarray) -> None:
    """Show an image using matplotlib.
    :param image: Image to show in (H, W, C), RGB.
    """
    plt.figure(figsize=(image.shape[1] / 100.0, image.shape[0] / 100.0), dpi=100)
    plt.imshow(image, interpolation="nearest")
    plt.axis("off")
    plt.tight_layout()
    plt.show()
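The image helpers above compose into a simple load / display / save workflow. A short sketch, with the import path assumed from the module layout and the inputs purely illustrative:

import numpy as np
# Assumed import path - adjust to your installation.
from super_gradients.training.utils.media.image import load_images, save_image, show_image

sources = ["path/to/local_image.jpg", np.zeros((64, 64, 3), dtype=np.uint8)]  # illustrative inputs
images = load_images(sources)          # list of numpy arrays, RGB when loaded from a path or URL
show_image(images[0])                  # display with matplotlib
save_image(images[0], "output.png")    # save an (H, W, C) RGB array to disk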
    """Class for calculating the FPS of a video stream."""
    def __init__(self, update_frequency: Optional[float] = None):
        """Create a new FPSCounter object.
        :param update_frequency: Minimum time (in seconds) between updates to the FPS counter.
                                 If None, the counter is updated every frame.
        """
        self._update_frequency = update_frequency
        self._start_time = time.time()
        self._frame_count = 0
        self._fps = 0.0
    def _update_fps(self, elapsed_time, current_time) -> None:
        """Compute new value of FPS and reset the counter."""
        self._fps = self._frame_count / elapsed_time
        self._start_time = current_time
        self._frame_count = 0
    @property
    def fps(self) -> float:
        """Current FPS value."""
        self._frame_count += 1
        current_time, elapsed_time = time.time(), time.time() - self._start_time
        if self._update_frequency is None or elapsed_time > self._update_frequency:
            self._update_fps(elapsed_time=elapsed_time, current_time=current_time)
        return self._fps
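A minimal sketch of how FPSCounter is meant to be used inside a frame loop: reading the fps property once per frame both counts the frame and returns the current estimate. The import path is an assumption and may differ in your installation:

import time
# Assumed import path - adjust to your installation.
from super_gradients.training.utils.media.stream import FPSCounter

counter = FPSCounter(update_frequency=1.0)  # recompute the estimate at most once per second
fps = 0.0
for _ in range(100):
    time.sleep(0.01)    # stand-in for per-frame processing work
    fps = counter.fps   # counts this frame and returns the latest FPS estimate
print(f"~{fps:.1f} FPS")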
class WebcamStreaming:
    """Stream video from a webcam. Press 'q' to quit the streaming.
    :param window_name:          Name of the window to display the video stream.
    :param frame_processing_fn:  Function to apply to each frame before displaying it.
                                 If None, frames are displayed as is.
    :param capture:              ID of the video capture device to use.
                                 Default is cv2.CAP_ANY (which selects the first available device).
    :param fps_update_frequency: Minimum time (in seconds) between updates to the FPS counter.
                                 If None, the counter is updated every frame.
    """
    def __init__(
        self,
        window_name: str = "",
        frame_processing_fn: Optional[Callable[[np.ndarray], np.ndarray]] = None,
        capture: int = cv2.CAP_ANY,
        fps_update_frequency: Optional[float] = None,
    ):
        self.window_name = window_name
        self.frame_processing_fn = frame_processing_fn
        self.cap = cv2.VideoCapture(capture)
        if not self.cap.isOpened():
            raise ValueError("Could not open video capture device")
        self._fps_counter = FPSCounter(update_frequency=fps_update_frequency)
    def run(self) -> None:
        """Start streaming video from the webcam and displaying it in a window.
        Press 'q' to quit the streaming.
        while not self._stop():
            self._display_single_frame()
    def _display_single_frame(self) -> None:
        """Read a single frame from the video capture device, apply any specified frame processing,
        and display the resulting frame in the window.
        Also updates the FPS counter and displays it in the frame.
        _ret, frame = self.cap.read()
        if self.frame_processing_fn:
            frame = self.frame_processing_fn(frame)
        _write_fps_to_frame(frame, self.fps)
        cv2.imshow(self.window_name, frame)
    def _stop(self) -> bool:
        """Stopping condition for the streaming."""
        return cv2.waitKey(1) & 0xFF == ord("q")
    @property
    def fps(self) -> float:
        return self._fps_counter.fps
    def __del__(self):
        """Release the video capture device and close the window."""
        self.cap.release()
        cv2.destroyAllWindows()
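A usage sketch for WebcamStreaming, assuming a webcam is available; the frame-processing function here just flips each frame and stands in for any np.ndarray -> np.ndarray callable (import path assumed):

import cv2
# Assumed import path - adjust to your installation.
from super_gradients.training.utils.media.stream import WebcamStreaming

stream = WebcamStreaming(
    window_name="Webcam demo",
    frame_processing_fn=lambda frame: cv2.flip(frame, 1),  # any per-frame transform works here
)
stream.run()  # press 'q' in the window to stop streaming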
def includes_video_extension(file_path: str) -> bool:
    """Check if a file includes a video extension.
    :param file_path:   Path to the video file.
    :return:            True if the file includes a video extension.
    """
    return isinstance(file_path, str) and file_path.lower().endswith(VIDEO_EXTENSIONS)
def load_video(file_path: str, max_frames: Optional[int] = None) -> Tuple[List[np.ndarray], int]:
    """Open a video file and extract each frame into numpy array.
    :param file_path:   Path to the video file.
    :param max_frames:  Optional, maximum number of frames to extract.
    :return:
                - Frames representing the video, each in (H, W, C), RGB.
                - Frames per Second (FPS).
    """
    cap = _open_video(file_path)
    frames = _extract_frames(cap, max_frames)
    fps = cap.get(cv2.CAP_PROP_FPS)
    cap.release()
    return frames, fps
def save_gif(output_path: str, frames: List[np.ndarray], fps: int) -> None:
    """Save a video locally in .gif format.
    :param output_path: Where the video will be saved
    :param frames:      Frames representing the video, each in (H, W, C), RGB. Note that all the frames are expected to have the same shape.
    :param fps:         Frames per second
    """
    frames_pil = [PIL.Image.fromarray(frame) for frame in frames]
    frames_pil[0].save(output_path, save_all=True, append_images=frames_pil[1:], duration=int(1000 / fps), loop=0)
def save_mp4(output_path: str, frames: List[np.ndarray], fps: int) -> None:
    """Save a video locally in .mp4 format.
    :param output_path: Where the video will be saved
    :param frames:      Frames representing the video, each in (H, W, C), RGB. Note that all the frames are expected to have the same shape.
    :param fps:         Frames per second
    """
    video_height, video_width = _validate_frames(frames)
    video_writer = cv2.VideoWriter(
        output_path,
        cv2.VideoWriter_fourcc(*"mp4v"),
        fps,
        (video_width, video_height),
    )
    for frame in frames:
        video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    video_writer.release()
      

Save a video locally. Depending on the extension, the video will be saved as a .mp4 file or as a .gif file.

def save_video(output_path: str, frames: List[np.ndarray], fps: int) -> None:
    """Save a video locally. Depending on the extension, the video will be saved as a .mp4 file or as a .gif file.
    :param output_path: Where the video will be saved
    :param frames:      Frames representing the video, each in (H, W, C), RGB. Note that all the frames are expected to have the same shape.
    :param fps:         Frames per second
    """
    if not includes_video_extension(output_path):
        logger.info(f'Output path "{output_path}" does not have a video extension, and therefore will be saved as {output_path}.mp4')
        output_path += ".mp4"
    if check_is_gif(output_path):
        save_gif(output_path, frames, fps)
    else:
        save_mp4(output_path, frames, fps)
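A short sketch combining load_video and save_video; the file names are illustrative and the import path is an assumption:

# Assumed import path - adjust to your installation.
from super_gradients.training.utils.media.video import load_video, save_video

frames, fps = load_video("input.mp4", max_frames=100)  # list of (H, W, C) RGB frames + FPS
save_video("clip.gif", frames, fps=int(fps))           # .gif extension -> save_gif, otherwise save_mp4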
def show_video_from_disk(video_path: str, window_name: str = "Prediction"):
    """Display a video from disk using OpenCV.
    :param video_path:   Path to the video file.
    :param window_name:  Name of the window to display the video.
    """
    cap = _open_video(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS)
    while cap.isOpened():
        ret, frame = cap.read()
        if ret:
            # Display the frame
            cv2.imshow(window_name, frame)
            # Wait for the specified number of milliseconds before displaying the next frame
            if cv2.waitKey(int(1000 / fps)) & 0xFF == ord("q"):
                break
        else:
            break
    # Release the VideoCapture object and destroy the window
    cap.release()
    cv2.destroyAllWindows()
    cv2.waitKey(1)
def show_video_from_frames(frames: List[np.ndarray], fps: float, window_name: str = "Prediction") -> None:
    """Display a video from a list of frames using OpenCV.
    :param frames:      Frames representing the video, each in (H, W, C), RGB. Note that all the frames are expected to have the same shape.
    :param fps:         Frames per second
    :param window_name: Name of the window to display the video.
    """
    for frame in frames:
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        cv2.imshow(window_name, frame)
        cv2.waitKey(int(1000 / fps))
    cv2.destroyAllWindows()
    cv2.waitKey(1)
      

Wrapper function for initializing the optimizer.

:param net: the nn_module to build the optimizer for
:param lr: initial learning rate
:param training_params: training_parameters

Source code in V3_2/src/super_gradients/training/utils/optimizer_utils.py
def build_optimizer(net: nn.Module, lr: float, training_params) -> optim.Optimizer:
    """
    Wrapper function for initializing the optimizer
        :param net: the nn_module to build the optimizer for
        :param lr: initial learning rate
        :param training_params: training_parameters
    """
    if is_model_wrapped(net):
        raise ValueError("Argument net for build_optimizer must be an unwrapped model. " "Please use build_optimizer(unwrap_model(net), ...).")
    if isinstance(training_params.optimizer, str):
        optimizer_cls = OptimizersTypeFactory().get(training_params.optimizer)
    else:
        optimizer_cls = training_params.optimizer
    optimizer_params = OPTIMIZERS_DEFAULT_PARAMS[optimizer_cls].copy() if optimizer_cls in OPTIMIZERS_DEFAULT_PARAMS.keys() else dict()
    optimizer_params.update(**training_params.optimizer_params)
    training_params.optimizer_params = optimizer_params
    weight_decay = get_param(training_params.optimizer_params, "weight_decay", 0.0)
    # OPTIMIZER PARAM GROUPS ARE SET USING DEFAULT OR MODEL SPECIFIC INIT
    if hasattr(net, "initialize_param_groups"):
        # INITIALIZE_PARAM_GROUPS MUST RETURN A LIST OF DICTS WITH 'named_params' AND OPTIMIZER's ATTRIBUTES PER GROUP
        net_named_params = net.initialize_param_groups(lr, training_params)
    else:
        net_named_params = [{"named_params": net.named_parameters()}]
    if training_params.zero_weight_decay_on_bias_and_bn:
        optimizer_training_params = separate_zero_wd_params_groups_for_optimizer(net, net_named_params, weight_decay)
    else:
        # Overwrite groups to include params instead of named params
        for ind_group, param_group in enumerate(net_named_params):
            param_group["params"] = [param[1] for param in list(param_group["named_params"])]
            del param_group["named_params"]
            net_named_params[ind_group] = param_group
        optimizer_training_params = net_named_params
    # CREATE AN OPTIMIZER OBJECT AND INITIALIZE IT
    optimizer = optimizer_cls(optimizer_training_params, lr=lr, **training_params.optimizer_params)
    return optimizer
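As an illustration only, a sketch of calling build_optimizer directly. In normal training the training_params object comes from the library's training hyper-parameters; the SimpleNamespace below is a stand-in that sets just the fields read above, and the optimizer name and import path are assumptions:

from types import SimpleNamespace
import torch.nn as nn
# Import path taken from the optimizer_utils.py source reference above - adjust if needed.
from super_gradients.training.utils.optimizer_utils import build_optimizer

model = nn.Linear(10, 2)
# Stand-in for the usual training hyper-parameters object; only fields used by build_optimizer are set.
training_params = SimpleNamespace(
    optimizer="SGD",
    optimizer_params={"momentum": 0.9, "weight_decay": 1e-4},
    zero_weight_decay_on_bias_and_bn=False,
)
optimizer = build_optimizer(net=model, lr=0.1, training_params=training_params)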
      

Separate param groups for batchnorm/bias parameters and all other parameters, and return a list of param groups in the format required by torch Optimizer classes: bias + BN parameters get weight decay = 0, the rest get the given weight decay.

:param module: train net module.
:param net_named_params: list of params groups, output of SgModule.initialize_param_groups
:param weight_decay: value to set for the non BN and bias parameters

Source code in V3_2/src/super_gradients/training/utils/optimizer_utils.py
def separate_zero_wd_params_groups_for_optimizer(module: nn.Module, net_named_params, weight_decay: float):
    """Separate param groups for batchnorm and biases and others with weight decay. Return a list of param groups in
    the format required by torch Optimizer classes: bias + BN with weight decay=0 and the rest with the given weight decay.
        :param module: train net module.
        :param net_named_params: list of params groups, output of SgModule.initialize_param_groups
        :param weight_decay: value to set for the non BN and bias parameters
    """
    # FIXME - replace usage of ids addresses to find batchnorm and biases params.
    #  This solution iterate 2 times over module parameters, find a way to iterate only one time.
    no_decay_ids = _get_no_decay_param_ids(module)
    # split param groups for optimizer
    optimizer_param_groups = []
    for param_group in net_named_params:
        no_decay_params = []
        decay_params = []
        for name, param in param_group["named_params"]:
            if id(param) in no_decay_ids:
                no_decay_params.append(param)
            else:
                decay_params.append(param)
        # append two param groups from the original param group, with and without weight decay.
        extra_optim_params = {key: param_group[key] for key in param_group if key not in ["named_params", "weight_decay"]}
        optimizer_param_groups.append({"params": no_decay_params, "weight_decay": 0.0, **extra_optim_params})
        optimizer_param_groups.append({"params": decay_params, "weight_decay": weight_decay, **extra_optim_params})
    return optimizer_param_groups
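A sketch showing how the returned groups can be fed straight into a torch optimizer, so that biases and BatchNorm parameters get weight_decay=0 while everything else keeps the requested value (import path assumed from the source reference above):

import torch
import torch.nn as nn
# Assumed import path - adjust to your installation.
from super_gradients.training.utils.optimizer_utils import separate_zero_wd_params_groups_for_optimizer

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
net_named_params = [{"named_params": model.named_parameters()}]
param_groups = separate_zero_wd_params_groups_for_optimizer(model, net_named_params, weight_decay=1e-4)
# Each input group is split in two: bias/BN params with weight_decay=0.0, the rest with weight_decay=1e-4.
optimizer = torch.optim.SGD(param_groups, lr=0.1, momentum=0.9)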
      

This implementation is taken from timm's github: https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/lamb.py

PyTorch Lamb optimizer w/ behaviour similar to NVIDIA FusedLamb

This optimizer code was adapted from the following (starting with latest):

* https://github.com/HabanaAI/Model-References/blob/2b435114fe8e31f159b1d3063b8280ae37af7423/PyTorch/nlp/bert/pretraining/lamb.py
* https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py
* https://github.com/cybertronai/pytorch-lamb

Use FusedLamb if you can (GPU). The reason for including this variant of Lamb is to have a version that is similar in behaviour to APEX FusedLamb if you aren't using NVIDIA GPUs or cannot install/use APEX.

In addition to some cleanup, this Lamb impl has been modified to support PyTorch XLA and has been tested on TPU.

Original copyrights for above sources are below.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

MIT License

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

Implements a pure pytorch variant of FuseLAMB (NvLamb variant) optimizer from apex.optimizers.FusedLAMB reference: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py

LAMB was proposed in Large Batch Optimization for Deep Learning: Training BERT in 76 minutes_.

Arguments:

- params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
- lr (float, optional): learning rate. (default: 1e-3)
- betas (Tuple[float, float], optional): coefficients used for computing running averages of gradient and its norm. (default: (0.9, 0.999))
- eps (float, optional): term added to the denominator to improve numerical stability. (default: 1e-8)
- weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
- grad_averaging (bool, optional): whether apply (1-beta2) to grad when calculating running averages of gradient. (default: True)
- max_grad_norm (float, optional): value used to clip global grad norm (default: 1.0)
- trust_clip (bool): enable LAMBC trust ratio clipping (default: False)
- always_adapt (boolean, optional): Apply adaptive learning rate to 0.0 weight decay parameter (default: False)

.. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes: https://arxiv.org/abs/1904.00962
.. _On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ

Source code in V3_2/src/super_gradients/training/utils/optimizers/lamb.py
@register_optimizer(Optimizers.LAMB)
class Lamb(Optimizer):
    """Implements a pure pytorch variant of FuseLAMB (NvLamb variant) optimizer from apex.optimizers.FusedLAMB
    reference: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py
    LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its norm. (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability. (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        grad_averaging (bool, optional): whether apply (1-beta2) to grad when
            calculating running averages of gradient. (default: True)
        max_grad_norm (float, optional): value used to clip global grad norm (default: 1.0)
        trust_clip (bool): enable LAMBC trust ratio clipping (default: False)
        always_adapt (boolean, optional): Apply adaptive learning rate to 0.0
            weight decay parameter (default: False)
    .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
        https://arxiv.org/abs/1904.00962
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """
    def __init__(
        self,
        params: Union[Iterable[torch.Tensor], Iterable[dict]],
        lr: float = 1e-3,
        bias_correction: bool = True,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-6,
        weight_decay: float = 0.01,
        grad_averaging: bool = True,
        max_grad_norm: float = 1.0,
        trust_clip: bool = False,
        always_adapt: bool = False,
    ):
        defaults = dict(
            lr=lr,
            bias_correction=bias_correction,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            grad_averaging=grad_averaging,
            max_grad_norm=max_grad_norm,
            trust_clip=trust_clip,
            always_adapt=always_adapt,
        )
        super().__init__(params, defaults)
    @torch.no_grad()
    def step(self, closure: Optional[callable] = None) -> torch.Tensor:
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        device = self.param_groups[0]["params"][0].device
        one_tensor = torch.tensor(1.0, device=device)  # because torch.where doesn't handle scalars correctly
        global_grad_norm = torch.zeros(1, device=device)
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError("Lamb does not support sparse gradients, consider SparseAdam instad.")
                global_grad_norm.add_(grad.pow(2).sum())
        global_grad_norm = torch.sqrt(global_grad_norm)
        # FIXME it'd be nice to remove explicit tensor conversion of scalars when torch.where promotes
        # scalar types properly https://github.com/pytorch/pytorch/issues/9190
        max_grad_norm = torch.tensor(self.defaults["max_grad_norm"], device=device)
        clip_global_grad_norm = torch.where(global_grad_norm > max_grad_norm, global_grad_norm / max_grad_norm, one_tensor)
        for group in self.param_groups:
            bias_correction = 1 if group["bias_correction"] else 0
            beta1, beta2 = group["betas"]
            grad_averaging = 1 if group["grad_averaging"] else 0
            beta3 = 1 - beta1 if grad_averaging else 1.0
            # assume same step across group now to simplify things
            # per parameter step can be easily support by making it tensor, or pass list into kernel
            if "step" in group:
                group["step"] += 1
            else:
                group["step"] = 1
            if bias_correction:
                bias_correction1 = 1 - beta1 ** group["step"]
                bias_correction2 = 1 - beta2 ** group["step"]
            else:
                bias_correction1, bias_correction2 = 1.0, 1.0
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.div_(clip_global_grad_norm)
                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p)
                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=beta3)  # m_t
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # v_t
                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group["eps"])
                update = (exp_avg / bias_correction1).div_(denom)
                weight_decay = group["weight_decay"]
                if weight_decay != 0:
                    update.add_(p, alpha=weight_decay)
                if weight_decay != 0 or group["always_adapt"]:
                    # Layer-wise LR adaptation. By default, skip adaptation on parameters that are
                    # excluded from weight decay, unless always_adapt == True, then always enabled.
                    w_norm = p.norm(2.0)
                    g_norm = update.norm(2.0)
                    # FIXME nested where required since logical and/or not working in PT XLA
                    trust_ratio = torch.where(
                        w_norm > 0,
                        torch.where(g_norm > 0, w_norm / g_norm, one_tensor),
                        one_tensor,
                    )
                    if group["trust_clip"]:
                        # LAMBC trust clipping, upper bound fixed at one
                        trust_ratio = torch.minimum(trust_ratio, one_tensor)
                    update.mul_(trust_ratio)
                p.add_(update, alpha=-group["lr"])
        return loss
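Usage mirrors any other torch optimizer; a minimal sketch, with the import path assumed from the lamb.py source reference above:

import torch
import torch.nn as nn
from super_gradients.training.utils.optimizers.lamb import Lamb  # assumed import path

model = nn.Linear(128, 10)
optimizer = Lamb(model.parameters(), lr=1e-3, weight_decay=0.01, max_grad_norm=1.0)

x, target = torch.randn(4, 128), torch.randint(0, 10, (4,))
loss = nn.functional.cross_entropy(model(x), target)
loss.backward()
optimizer.step()
optimizer.zero_grad()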
      

@register_optimizer(Optimizers.LION)
class Lion(Optimizer):
    r"""Implements Lion algorithm.
    Generally, it is recommended to divide the lr used by AdamW by 10 and multiply the weight decay by 10.
    """
    def __init__(
        self,
        params: Union[Iterable[torch.Tensor], Iterable[dict]],
        lr: float = 1e-4,
        betas: Tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.0,
    ):
        """Initialize the hyperparameters.
        :param params:          Iterable of parameters to optimize or dicts defining parameter groups
        :param lr:              Learning rate (default: 1e-4)
        :param betas:           Coefficients used for computing running averages of gradient and its square (default: (0.9, 0.99))
        :param weight_decay:    Weight decay coefficient (default: 0)
        """
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
        super().__init__(params, defaults)
    @torch.no_grad()
    def step(self, closure: Optional[callable] = None) -> torch.Tensor:
        """Perform a single optimization step.
        :param closure: A closure that reevaluates the model and returns the loss.
        :return: Loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                # Perform stepweight decay
                p.data.mul_(1 - group["lr"] * group["weight_decay"])
                grad = p.grad
                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p)
                exp_avg = state["exp_avg"]
                beta1, beta2 = group["betas"]
                # Weight update
                update = exp_avg * beta1 + grad * (1 - beta1)
                p.add_(torch.sign(update), alpha=-group["lr"])
                # Decay the momentum running average coefficient
                exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)
        return loss
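A minimal sketch following the note above (roughly 10x smaller lr and 10x larger weight decay than an AdamW baseline); the import path is an assumption:

import torch
import torch.nn as nn
from super_gradients.training.utils.optimizers.lion import Lion  # assumed import path

model = nn.Linear(64, 4)
optimizer = Lion(model.parameters(), lr=1e-4, betas=(0.9, 0.99), weight_decay=0.1)

loss = model(torch.randn(8, 64)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()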

Implements RMSprop algorithm (TensorFlow-style epsilon).

NOTE: This is a direct cut-and-paste of PyTorch RMSprop with eps applied before sqrt and a few other modifications to closer match Tensorflow for matching hyper-params. Noteworthy changes include:

1. Epsilon applied inside square-root
2. square_avg initialized to ones
3. LR scaling of update accumulated in momentum buffer

Proposed by G. Hinton in his course <http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>. The centered version first appears in Generating Sequences With Recurrent Neural Networks <https://arxiv.org/pdf/1308.0850v5.pdf>.

Source code in V3_2/src/super_gradients/training/utils/optimizers/rmsprop_tf.py
@register_optimizer(Optimizers.RMS_PROP_TF)
class RMSpropTF(Optimizer):
    """Implements RMSprop algorithm (TensorFlow style epsilon)
    NOTE: This is a direct cut-and-paste of PyTorch RMSprop with eps applied before sqrt
    and a few other modifications to closer match Tensorflow for matching hyper-params.
    Noteworthy changes include:
    1. Epsilon applied inside square-root
    2. square_avg initialized to ones
    3. LR scaling of update accumulated in momentum buffer
    Proposed by G. Hinton in his
    `course <http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>`_.
    The centered version first appears in `Generating Sequences
    With Recurrent Neural Networks <https://arxiv.org/pdf/1308.0850v5.pdf>`_."""
    def __init__(
        self,
        params: Union[Iterable[torch.Tensor], Iterable[dict]],
        lr: float = 1e-2,
        alpha: float = 0.9,
        eps: float = 1e-10,
        weight_decay: float = 0,
        momentum: float = 0.0,
        centered: bool = False,
        decoupled_decay: bool = False,
        lr_in_momentum: bool = True,
    ):
        """RMSprop optimizer that follows the tf's RMSprop characteristics
        :param params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
        :param lr (float, optional): learning rate
        :param momentum (float, optional): momentum factor
        :param alpha (float, optional): smoothing (decay) constant
        :param eps (float, optional): term added to the denominator to improve numerical stability
        :param centered (bool, optional) : if ``True``, compute the centered RMSProp, the gradient is normalized by an
         estimation of its variance
        :param weight_decay (float, optional): weight decay (L2 penalty)
        :param decoupled_decay (bool, optional): decoupled weight decay as per https://arxiv.org/abs/1711.05101
        :param lr_in_momentum (bool, optional): learning rate scaling is included in the momentum buffer update as per
         defaults in Tensorflow
        """
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= momentum:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if not 0.0 <= alpha:
            raise ValueError("Invalid alpha value: {}".format(alpha))
        defaults = dict(
            lr=lr,
            momentum=momentum,
            alpha=alpha,
            eps=eps,
            centered=centered,
            weight_decay=weight_decay,
            decoupled_decay=decoupled_decay,
            lr_in_momentum=lr_in_momentum,
        )
        super(RMSpropTF, self).__init__(params, defaults)
    def __setstate__(self, state):
        super(RMSpropTF, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault("momentum", 0)
            group.setdefault("centered", False)
    def step(self, closure: Optional[callable] = None) -> torch.Tensor:  # noqa: C901
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError("RMSprop does not support sparse gradients")
                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    state["square_avg"] = torch.ones_like(p.data)  # PyTorch inits to zero
                    if group["momentum"] > 0:
                        state["momentum_buffer"] = torch.zeros_like(p.data)
                    if group["centered"]:
                        state["grad_avg"] = torch.zeros_like(p.data)
                square_avg = state["square_avg"]
                one_minus_alpha = 1.0 - group["alpha"]
                state["step"] += 1
                if group["weight_decay"] != 0:
                    if "decoupled_decay" in group and group["decoupled_decay"]:
                        p.data.add_(-group["weight_decay"], p.data)
                    else:
                        grad = grad.add(group["weight_decay"], p.data)
                # Tensorflow order of ops for updating squared avg
                square_avg.add_(one_minus_alpha, grad.pow(2) - square_avg)
                # square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad)  # PyTorch original
                if group["centered"]:
                    grad_avg = state["grad_avg"]
                    grad_avg.add_(one_minus_alpha, grad - grad_avg)
                    # grad_avg.mul_(alpha).add_(1 - alpha, grad)  # PyTorch original
                    avg = square_avg.addcmul(-1, grad_avg, grad_avg).add(group["eps"]).sqrt_()  # eps moved in sqrt
                else:
                    avg = square_avg.add(group["eps"]).sqrt_()  # eps moved in sqrt
                if group["momentum"] > 0:
                    buf = state["momentum_buffer"]
                    # Tensorflow accumulates the LR scaling in the momentum buffer
                    if "lr_in_momentum" in group and group["lr_in_momentum"]:
                        buf.mul_(group["momentum"]).addcdiv_(group["lr"], grad, avg)
                        p.data.add_(-buf)
                    else:
                        # PyTorch scales the param update by LR
                        buf.mul_(group["momentum"]).addcdiv_(grad, avg)
                        p.data.add_(-group["lr"], buf)
                else:
                    p.data.addcdiv_(-group["lr"], grad, avg)
        return loss
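Again, usage mirrors the regular torch RMSprop; a sketch with the TensorFlow-style defaults shown above, the import path being an assumption based on the rmsprop_tf.py source reference:

import torch
import torch.nn as nn
from super_gradients.training.utils.optimizers.rmsprop_tf import RMSpropTF  # assumed import path

model = nn.Conv2d(3, 16, 3)
optimizer = RMSpropTF(model.parameters(), lr=1e-2, alpha=0.9, eps=1e-10, momentum=0.9)

loss = model(torch.randn(1, 3, 32, 32)).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()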
Source code in V3_2/src/super_gradients/training/utils/pose_estimation/dekr_decode_callbacks.py
class DEKRPoseEstimationDecodeCallback(nn.Module):
    """Class that implements decoding logic of DEKR's model predictions into poses."""
    def __init__(
        self,
        output_stride: int,
        max_num_people: int,
        keypoint_threshold: float,
        nms_threshold: float,
        nms_num_threshold: int,
        apply_sigmoid: bool,
        min_confidence: float = 0.0,
    ):
        """
        :param output_stride: Output stride of the model
        :param int max_num_people: Maximum number of decoded poses
        :param float keypoint_threshold: (float) A minimum score of a pose center keypoint for pose to be considered as a potential candidate
        :param float nms_threshold: The maximum distance between two joints for them to be considered as belonging to the same pose.
                              Given in terms of a percentage of a square root of the area of the pose bounding box.
        :param int nms_num_threshold: Number of joints that must pass the NMS check for the pose to be considered as a valid one.
        :param bool apply_sigmoid: If True, apply the sigmoid activation on heatmap. This is needed when heatmap is not
                              bound to [0..1] range and trained with logits (E.g focal loss)
        :param float min_confidence: Minimum confidence threshold for pose
        """
        super().__init__()
        self.keypoint_threshold = keypoint_threshold
        self.max_num_people = max_num_people
        self.output_stride = output_stride
        self.nms_threshold = nms_threshold
        self.nms_num_threshold = nms_num_threshold
        self.apply_sigmoid = apply_sigmoid
        self.min_confidence = min_confidence
    @torch.no_grad()
    def forward(self, predictions: Union[Tensor, Tuple[Tensor, Tensor]]) -> Tuple[List[np.ndarray], List[np.ndarray]]:
        """
        :param predictions: Tuple (heatmap, offset):
            heatmap - [BatchSize, NumJoints+1,H,W]
            offset - [BatchSize, NumJoints*2,H,W]
        :return: Tuple of (all_poses, all_scores), one entry per image in the batch.
        """
        all_poses = []
        all_scores = []
        heatmap, offset = predictions
        batch_size = len(heatmap)
        for i in range(batch_size):
            poses, scores = self.decode_one_sized_batch(predictions=(heatmap[i : i + 1], offset[i : i + 1]))
            all_poses.append(poses)
            all_scores.append(scores)
        return all_poses, all_scores
    def decode_one_sized_batch(self, predictions: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
        heatmap, offset = predictions
        posemap = _offset_to_pose(offset)  # [1, 2 * num_joints, H, W]
        if heatmap.size(0) != 1:
            raise RuntimeError("Batch size of 1 is required")
        if self.apply_sigmoid:
            heatmap = heatmap.sigmoid()
        heatmap_sum, poses_sum = aggregate_results(
            heatmap,
            posemap,
            pose_center_score_threshold=self.keypoint_threshold,
            max_num_people=self.max_num_people,
            output_stride=self.output_stride,
        )
        poses, scores = pose_nms(
            heatmap_sum,
            poses_sum,
            max_num_people=self.max_num_people,
            nms_threshold=self.nms_threshold,
            nms_num_threshold=self.nms_num_threshold,
            pose_score_threshold=self.min_confidence,
        )
        if len(poses) != len(scores):
            raise RuntimeError("Decoding error detected. Returned mismatching number of poses/scores")
        return poses, scores
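A decoding sketch with random tensors standing in for DEKR outputs; all threshold values here are placeholders rather than recipe defaults, and the import path is assumed from the dekr_decode_callbacks.py source reference:

import torch
from super_gradients.training.utils.pose_estimation.dekr_decode_callbacks import DEKRPoseEstimationDecodeCallback  # assumed path

decode = DEKRPoseEstimationDecodeCallback(
    output_stride=4,          # placeholder values - match them to your DEKR model/recipe
    max_num_people=30,
    keypoint_threshold=0.05,
    nms_threshold=0.05,
    nms_num_threshold=8,
    apply_sigmoid=True,
)
num_joints, batch_size = 17, 2
heatmap = torch.randn(batch_size, num_joints + 1, 64, 64)  # [B, NumJoints+1, H, W]
offset = torch.randn(batch_size, num_joints * 2, 64, 64)   # [B, NumJoints*2, H, W]
poses, scores = decode((heatmap, offset))                  # per-image lists of poses and scores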
                float
          

(float) A minimum score of a pose center keypoint for pose to be considered as a potential candidate

required nms_threshold float

The maximum distance between two joints for them to be considered as belonging to the same pose. Given in terms of a percentage of a square root of the area of the pose bounding box.

required nms_num_threshold

Number of joints that must pass the NMS check for the pose to be considered as a valid one.

required apply_sigmoid

If True, apply the sigmoid activation on heatmap. This is needed when heatmap is not bound to [0..1] range and trained with logits (E.g focal loss)

required min_confidence float

Minimum confidence threshold for pose

Source code in V3_2/src/super_gradients/training/utils/pose_estimation/dekr_decode_callbacks.py 292
def __init__(
    self,
    output_stride: int,
    max_num_people: int,
    keypoint_threshold: float,
    nms_threshold: float,
    nms_num_threshold: int,
    apply_sigmoid: bool,
    min_confidence: float = 0.0,
    :param output_stride: Output stride of the model
    :param int max_num_people: Maximum number of decoded poses
    :param float keypoint_threshold: (float) A minimum score of a pose center keypoint for pose to be considered as a potential candidate
    :param float nms_threshold: The maximum distance between two joints for them to be considered as belonging to the same pose.
                          Given in terms of a percentage of a square root of the area of the pose bounding box.
    :param int nms_num_threshold: Number of joints that must pass the NMS check for the pose to be considered as a valid one.
    :param bool apply_sigmoid: If True, apply the sigmoid activation on heatmap. This is needed when heatmap is not
                          bound to [0..1] range and trained with logits (E.g focal loss)
    :param float min_confidence: Minimum confidence threshold for pose
    super().__init__()
    self.keypoint_threshold = keypoint_threshold
    self.max_num_people = max_num_people
    self.output_stride = output_stride
    self.nms_threshold = nms_threshold
    self.nms_num_threshold = nms_num_threshold
    self.apply_sigmoid = apply_sigmoid
    self.min_confidence = min_confidence
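
Example (a minimal sketch, assuming the decoder class is exported as DEKRPoseEstimationDecodeCallback from super_gradients.training.utils.pose_estimation; tensor shapes follow the docstrings above and all threshold values are illustrative):

import torch
from super_gradients.training.utils.pose_estimation import DEKRPoseEstimationDecodeCallback  # assumed export

num_joints = 17
decode = DEKRPoseEstimationDecodeCallback(
    output_stride=4,
    max_num_people=30,
    keypoint_threshold=0.05,
    nms_threshold=0.05,
    nms_num_threshold=8,
    apply_sigmoid=True,   # heatmap head was trained with logits
    min_confidence=0.1,
)

# Dummy predictions: heatmap [B, NumJoints+1, H, W], offset [B, NumJoints*2, H, W]
heatmap = torch.randn(2, num_joints + 1, 96, 96)
offset = torch.randn(2, num_joints * 2, 96, 96)
all_poses, all_scores = decode((heatmap, offset))  # lists with one entry per image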
Parameters:

predictions (Union[Tensor, Tuple[Tensor, Tensor]]): Tuple (heatmap, offset): heatmap - [BatchSize, NumJoints+1, H, W]; offset - [BatchSize, NumJoints*2, H, W]. required

Source code in V3_2/src/super_gradients/training/utils/pose_estimation/dekr_decode_callbacks.py 313
@torch.no_grad()
def forward(self, predictions: Union[Tensor, Tuple[Tensor, Tensor]]) -> Tuple[List[np.ndarray], List[np.ndarray]]:
    :param predictions: Tuple (heatmap, offset):
        heatmap - [BatchSize, NumJoints+1,H,W]
        offset - [BatchSize, NumJoints*2,H,W]
    :return: Tuple of (all_poses, all_scores) lists, one entry per image in the batch
    all_poses = []
    all_scores = []
    heatmap, offset = predictions
    batch_size = len(heatmap)
    for i in range(batch_size):
        poses, scores = self.decode_one_sized_batch(predictions=(heatmap[i : i + 1], offset[i : i + 1]))
        all_poses.append(poses)
        all_scores.append(scores)
    return all_poses, all_scores
      

Get initial pose proposals and aggregate the results of all scales. Note that this implementation works only for a batch size of 1.

Parameters:

pose_center_score_threshold (float): A minimum score of a pose center keypoint for a pose to be considered as a potential candidate. required
max_num_people (int): Maximum number of decoded poses. required

Source code in V3_2/src/super_gradients/training/utils/pose_estimation/dekr_decode_callbacks.py 255
def aggregate_results(
    heatmap: Tensor, posemap: Tensor, output_stride: int, pose_center_score_threshold: float, max_num_people: int
) -> Tuple[Tensor, List[Tensor]]:
    Get initial pose proposals and aggregate the results of all scales.
    Note that this implementation works only for a batch size of 1.
    :param heatmap: Heatmap at this scale (B, 1+num_joints, w, h)
    :param posemap: Posemap at this scale (B, 2*num_joints, w, h)
    :param output_stride: Ratio of input size / predictions size
    :param pose_center_score_threshold: (float) A minimum score of a pose center keypoint for pose to be considered as a potential candidate
    :param max_num_people: (int) Maximum number of decoded poses
    :return:
        - heatmap_sum: Sum of the heatmaps (1, 1+num_joints, w, h)
        - poses (List): Gather of the pose proposals [B, (num_people, num_joints, 3)]
    poses = []
    h, w = heatmap[0].size(-1), heatmap[0].size(-2)
    heatmap_sum = _up_interpolate(heatmap, size=(int(output_stride * w), int(output_stride * h)))
    center_heatmap = heatmap[0, -1:]
    pose_ind, ctr_score = _get_maximum_from_heatmap(center_heatmap, pose_center_score_threshold=pose_center_score_threshold, max_num_people=max_num_people)
    posemap = posemap[0].permute(1, 2, 0).view(h * w, -1, 2)
    pose = output_stride * posemap[pose_ind]
    ctr_score = ctr_score[:, None].expand(-1, pose.shape[-2])[:, :, None]
    poses.append(torch.cat([pose, ctr_score], dim=2))
    return heatmap_sum, poses
        Source code in V3_2/src/super_gradients/training/utils/pose_estimation/dekr_decode_callbacks.py 24
def get_locations(output_h: int, output_w: int, device):
    Generate location map (each pixel contains its own XY coordinate)
    :param output_h: Feature map height (rows)
    :param output_w: Feature map width (cols)
    :param device: Target device to put tensor on
    :return: [H * W, 2]
    shifts_x = torch.arange(0, output_w, step=1, dtype=torch.float32, device=device)
    shifts_y = torch.arange(0, output_h, step=1, dtype=torch.float32, device=device)
    shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")
    shift_x = shift_x.reshape(-1)
    shift_y = shift_y.reshape(-1)
    locations = torch.stack((shift_x, shift_y), dim=1)
    return locations
        Source code in V3_2/src/super_gradients/training/utils/pose_estimation/dekr_decode_callbacks.py 41
def get_reg_poses(offset: Tensor, num_joints: int):
    Decode offset predictions into absolute locations.
    :param offset: Tensor of [num_joints*2,H,W] shape with offset predictions for each joint
    :param num_joints: Number of joints
    :return: [H * W, num_joints, 2]
    _, h, w = offset.shape
    offset = offset.permute(1, 2, 0).reshape(h * w, num_joints, 2)
    locations = get_locations(h, w, offset.device)
    locations = locations[:, None, :].expand(-1, num_joints, -1)
    poses = locations - offset
    return poses
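
The two helpers above implement the DEKR decoding rule "absolute joint position = pixel location - predicted offset". A short, self-contained illustration with toy tensors (shapes and values are made up for the example):

import torch

num_joints, h, w = 2, 3, 4
offset = torch.zeros(num_joints * 2, h, w)
offset[0] = 1.0  # joint 0: x-offset of 1 pixel everywhere

# Same steps as get_locations / get_reg_poses above
offset_hw = offset.permute(1, 2, 0).reshape(h * w, num_joints, 2)
shift_y, shift_x = torch.meshgrid(torch.arange(h, dtype=torch.float32), torch.arange(w, dtype=torch.float32), indexing="ij")
locations = torch.stack((shift_x.reshape(-1), shift_y.reshape(-1)), dim=1)  # [H*W, 2] in XY order
poses = locations[:, None, :].expand(-1, num_joints, -1) - offset_hw        # [H*W, num_joints, 2]

print(poses.shape)  # torch.Size([12, 2, 2])
print(poses[0, 0])  # tensor([-1., 0.]) -> pixel (0, 0) shifted left by the predicted x-offset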
Parameters:

nms_threshold (float): The maximum distance between two joints for them to be considered as belonging to the same pose. Given in terms of a percentage of a square root of the area of the pose bounding box. required
nms_num_threshold (int): Number of joints that must pass the NMS check for the pose to be considered as a valid one. required
pose_score_threshold (float): Minimum confidence threshold for pose. Pose with confidence lower than this threshold will be discarded. required

Source code in V3_2/src/super_gradients/training/utils/pose_estimation/dekr_decode_callbacks.py 222
def pose_nms(
    heatmap_avg, poses, max_num_people: int, nms_threshold: float, nms_num_threshold: int, pose_score_threshold: float
) -> Tuple[np.ndarray, np.ndarray]:
    NMS for the regressed poses results.
    :param Tensor heatmap_avg: Avg of the heatmaps at all scales (1, 1+num_joints, w, h)
    :param List poses: Gather of the pose proposals [(num_people, num_joints, 3)]
    :param int max_num_people: Maximum number of decoded poses
    :param float nms_threshold: The maximum distance between two joints for them to be considered as belonging to the same pose.
                          Given in terms of a percentage of a square root of the area of the pose bounding box.
    :param int nms_num_threshold: Number of joints that must pass the NMS check for the pose to be considered as a valid one.
    :param float pose_score_threshold: Minimum confidence threshold for pose. Pose with confidence lower than this threshold will be discarded.
    :return Tuple of (poses, scores)
    assert len(poses) == 1
    pose_score = torch.cat([pose[:, :, 2:] for pose in poses], dim=0)
    pose_coord = torch.cat([pose[:, :, :2] for pose in poses], dim=0)
    num_people, num_joints, _ = pose_coord.shape
    if num_people == 0:
        return np.zeros((0, num_joints, 3), dtype=np.float32), np.zeros((0,), dtype=np.float32)
    heatval = _get_heat_value(pose_coord, heatmap_avg[0])
    heat_score = (torch.sum(heatval, dim=1) / num_joints)[:, 0]
    pose_score = pose_score * heatval
    poses = torch.cat([pose_coord.cpu(), pose_score.cpu()], dim=2)
    keep_pose_inds = _nms_core(pose_coord, heat_score, nms_threshold=nms_threshold, nms_num_threshold=nms_num_threshold)
    poses = poses[keep_pose_inds]
    heat_score = heat_score[keep_pose_inds]
    if len(keep_pose_inds) > max_num_people:
        heat_score, topk_inds = torch.topk(heat_score, max_num_people)
        poses = poses[topk_inds]
    poses = poses.numpy()
    if len(poses):
        scores = poses[:, :, 2].mean(axis=1)
        mask = scores >= pose_score_threshold
        poses = poses[mask]
        scores = scores[mask]
    else:
        return np.zeros((0, num_joints, 3), dtype=np.float32), np.zeros((0,), dtype=np.float32)
    return poses, scores
      

A callback that adds a visualization of a batch of pose estimation predictions to context.sg_logger

Source code in V3_2/src/super_gradients/training/utils/pose_estimation/dekr_visualization_callbacks.py 172
@register_callback(Callbacks.DEKR_VISUALIZATION)
class DEKRVisualizationCallback(PhaseCallback):
    A callback that adds a visualization of a batch of pose estimation predictions to context.sg_logger
    :param phase:                   When to trigger the callback.
    :param prefix:                  Prefix to add to the log.
    :param mean:                    Mean to subtract from image.
    :param std:                     Standard deviation to subtract from image.
    :param apply_sigmoid:           Whether to apply sigmoid to the output.
    :param batch_idx:               Batch index to perform visualization for.
    :param keypoints_threshold:     Keypoint threshold to use for visualization.
    def __init__(
        self,
        phase: Phase,
        prefix: str,
        mean: List[float],
        std: List[float],
        apply_sigmoid: bool = False,
        batch_idx: int = 0,
        keypoints_threshold: float = 0.01,
    ):
        super(DEKRVisualizationCallback, self).__init__(phase)
        self.batch_idx = batch_idx
        self.prefix = prefix
        self.mean = np.array(list(map(float, mean))).reshape((1, 1, -1))
        self.std = np.array(list(map(float, std))).reshape((1, 1, -1))
        self.apply_sigmoid = apply_sigmoid
        self.keypoints_threshold = keypoints_threshold
    def denormalize_image(self, image_normalized: Tensor) -> np.ndarray:
        Reverse image normalization image_normalized (image / 255 - mean) / std
        :param image_normalized: normalized [3,H,W]
        :return:
        image_normalized = torch.moveaxis(image_normalized, 0, -1).detach().cpu().numpy()
        image = (image_normalized * self.std + self.mean) * 255
        image = np.clip(image, 0, 255).astype(np.uint8)[..., ::-1]
        return image
    @classmethod
    def visualize_heatmap(self, heatmap: Tensor, apply_sigmoid: bool, dsize, min_value=None, max_value=None, colormap=cv2.COLORMAP_JET):
        if apply_sigmoid:
            heatmap = heatmap.sigmoid()
        if min_value is None:
            min_value = heatmap.min().item()
        if max_value is None:
            max_value = heatmap.max().item()
        heatmap = heatmap.detach().cpu().numpy()
        real_min = heatmap.min()
        real_max = heatmap.max()
        heatmap = np.max(heatmap, axis=0)
        heatmap = (heatmap - min_value) / (1e-8 + max_value - min_value)
        heatmap = np.clip(heatmap, 0, 1)
        heatmap_8u = (heatmap * 255).astype(np.uint8)
        heatmap_bgr = cv2.applyColorMap(heatmap_8u, colormap)
        heatmap_rgb = cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB)
        if dsize is not None:
            heatmap_rgb = cv2.resize(heatmap_rgb, dsize=dsize)
        cv2.putText(
            heatmap_rgb,
            f"min:{real_min:.3f}",
            (5, 15),
            fontFace=cv2.FONT_HERSHEY_PLAIN,
            color=(255, 255, 255),
            fontScale=0.8,
            thickness=1,
            lineType=cv2.LINE_AA,
        )
        cv2.putText(
            heatmap_rgb,
            f"max:{real_max:.3f}",
            (5, heatmap_rgb.shape[0] - 10),
            cv2.FONT_HERSHEY_PLAIN,
            color=(255, 255, 255),
            fontScale=0.8,
            thickness=1,
            lineType=cv2.LINE_AA,
        )
        return heatmap, heatmap_rgb
    @multi_process_safe
    def __call__(self, context: PhaseContext):
        if context.batch_idx == self.batch_idx:
            batch_imgs = self.visualize_batch(context.inputs, context.preds, context.target)
            batch_imgs = np.stack(batch_imgs)
            tag = self.prefix + str(self.batch_idx) + "_images"
            context.sg_logger.add_images(tag=tag, images=batch_imgs, global_step=context.epoch, data_format="NHWC")
    @torch.no_grad()
    def visualize_batch(self, inputs, predictions, targets):
        num_samples = len(inputs)
        batch_imgs = []
        gt_heatmap, mask, _, _ = targets
        # Check whether model also produce supervised output predictions
        if isinstance(predictions, tuple) and len(predictions) == 2 and torch.is_tensor(predictions[0]) and torch.is_tensor(predictions[1]):
            heatmap, _ = predictions
        else:
            (heatmap, _), (_, _) = predictions
        for i in range(num_samples):
            batch_imgs.append(self.visualize_sample(inputs[i], predicted_heatmap=heatmap[i], target_heatmap=gt_heatmap[i], target_mask=mask[i]))
        return batch_imgs
    def visualize_sample(self, input, predicted_heatmap, target_heatmap, target_mask):
        image_rgb = self.denormalize_image(input)
        dsize = image_rgb.shape[1], image_rgb.shape[0]
        half_size = dsize[0] // 2, dsize[1] // 2
        target_heatmap_f32, target_heatmap_rgb = self.visualize_heatmap(target_heatmap, apply_sigmoid=False, dsize=half_size)
        target_heatmap_f32 = cv2.resize(target_heatmap_f32, dsize=dsize)
        target_heatmap_f32 = np.expand_dims(target_heatmap_f32, -1)
        peaks_heatmap = _hierarchical_pool(predicted_heatmap)[0]
        peaks_heatmap = predicted_heatmap.eq(peaks_heatmap) & (predicted_heatmap > self.keypoints_threshold)
        peaks_heatmap = peaks_heatmap.sum(dim=0, keepdim=False) > 0
        # Apply masking with GT mask to suppress predictions on ignored areas of the image (where target_mask==0)
        flat_target_mask = target_mask.sum(dim=0, keepdim=False) > 0
        peaks_heatmap &= flat_target_mask
        peaks_heatmap = peaks_heatmap.detach().cpu().numpy().astype(np.uint8) * 255
        peaks_heatmap = cv2.applyColorMap(peaks_heatmap, cv2.COLORMAP_JET)
        peaks_heatmap = cv2.cvtColor(peaks_heatmap, cv2.COLOR_BGR2RGB)
        peaks_heatmap = cv2.resize(peaks_heatmap, dsize=half_size)
        _, predicted_heatmap_rgb = self.visualize_heatmap(
            predicted_heatmap, min_value=target_heatmap.min().item(), max_value=target_heatmap.max().item(), apply_sigmoid=self.apply_sigmoid, dsize=half_size
        )
        image_heatmap_overlay = image_rgb * (1 - target_heatmap_f32) + target_heatmap_f32 * cv2.resize(target_heatmap_rgb, dsize=dsize)
        image_heatmap_overlay = image_heatmap_overlay.astype(np.uint8)
        _, target_mask_rgb = self.visualize_heatmap(target_mask, min_value=0, max_value=1, apply_sigmoid=False, dsize=half_size, colormap=cv2.COLORMAP_BONE)
        return np.hstack(
            [
                image_heatmap_overlay,
                np.vstack([target_heatmap_rgb, predicted_heatmap_rgb]),
                np.vstack([target_mask_rgb, peaks_heatmap]),
            ]
        )
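
Example (a minimal sketch of constructing the callback; the import paths are assumptions, and the mean/std values must match whatever normalization the dataloader applies):

from super_gradients.training.utils.callbacks import Phase
from super_gradients.training.utils.pose_estimation import DEKRVisualizationCallback  # assumed export

viz_callback = DEKRVisualizationCallback(
    phase=Phase.VALIDATION_BATCH_END,
    prefix="dekr_val_",
    mean=[0.485, 0.456, 0.406],  # example ImageNet stats -- use your dataloader's values
    std=[0.229, 0.224, 0.225],
    apply_sigmoid=True,
    batch_idx=0,
    keypoints_threshold=0.01,
)

# Typically passed to the trainer through training_params, e.g. training_params={"phase_callbacks": [viz_callback], ...}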
        Source code in V3_2/src/super_gradients/training/utils/pose_estimation/dekr_visualization_callbacks.py 59
def denormalize_image(self, image_normalized: Tensor) -> np.ndarray:
    Reverse image normalization image_normalized (image / 255 - mean) / std
    :param image_normalized: normalized [3,H,W]
    :return:
    image_normalized = torch.moveaxis(image_normalized, 0, -1).detach().cpu().numpy()
    image = (image_normalized * self.std + self.mean) * 255
    image = np.clip(image, 0, 255).astype(np.uint8)[..., ::-1]
    return image
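
The denormalization is simply the inverse of (image / 255 - mean) / std followed by an RGB/BGR channel flip. A self-contained sketch of the same steps (the mean/std values are illustrative):

import numpy as np
import torch

mean = np.array([0.485, 0.456, 0.406]).reshape((1, 1, -1))
std = np.array([0.229, 0.224, 0.225]).reshape((1, 1, -1))

image_normalized = torch.randn(3, 32, 32)                    # normalized CHW tensor
hwc = torch.moveaxis(image_normalized, 0, -1).cpu().numpy()  # CHW -> HWC
image = np.clip((hwc * std + mean) * 255, 0, 255).astype(np.uint8)[..., ::-1]  # undo normalization, flip channels
print(image.shape, image.dtype)                              # (32, 32, 3) uint8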
          Source code in V3_2/src/super_gradients/training/utils/pose_estimation/rescoring_callback.py 25
class RescoringPoseEstimationDecodeCallback:
    A special adapter callback to be used with PoseEstimationMetrics to use the outputs from rescoring model inside metric class.
    def __init__(self, apply_sigmoid: bool):
        :param apply_sigmoid: If True, apply the sigmoid activation on heatmap. This is needed when heatmap is not
                              bound to [0..1] range and trained with logits (E.g focal loss)
        super().__init__()
        self.apply_sigmoid = apply_sigmoid
    def __call__(self, predictions: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
        """ """
        poses, scores = predictions
        if self.apply_sigmoid:
            scores = scores.sigmoid()
        return poses, scores.squeeze(-1)  # Pose Estimation Callback expects that scores don't have the dummy dimension
          

Parameters:

apply_sigmoid (bool): If True, apply the sigmoid activation on heatmap. This is needed when heatmap is not bound to [0..1] range and trained with logits (e.g. focal loss). required

Source code in V3_2/src/super_gradients/training/utils/pose_estimation/rescoring_callback.py 18
def __init__(self, apply_sigmoid: bool):
    :param apply_sigmoid: If True, apply the sigmoid activation on heatmap. This is needed when heatmap is not
                          bound to [0..1] range and trained with logits (E.g focal loss)
    super().__init__()
    self.apply_sigmoid = apply_sigmoid
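
Example (a minimal sketch, assuming the class is exported from super_gradients.training.utils.pose_estimation; the tensors are dummy data only meant to show the sigmoid + squeeze behaviour):

import torch
from super_gradients.training.utils.pose_estimation import RescoringPoseEstimationDecodeCallback  # assumed export

callback = RescoringPoseEstimationDecodeCallback(apply_sigmoid=True)

poses = torch.randn(8, 17, 3)   # [num_poses, num_joints, (x, y, score)]
logits = torch.randn(8, 1)      # raw rescoring-model outputs with a trailing singleton dim
poses_out, scores = callback((poses, logits))
print(scores.shape)             # torch.Size([8]) -- sigmoid applied, dummy dimension squeezed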
      

Object wrapping an image and a pose estimation model's prediction.

:attr image: Input image
:attr predictions: Predictions of the model
:attr class_names: List of the class names to predict

Source code in V3_2/src/super_gradients/training/utils/predict/prediction_pose_estimation_results.py 123
@dataclass
class ImagePoseEstimationPrediction(ImagePrediction):
    """Object wrapping an image and a detection model's prediction.
    :attr image:        Input image
    :attr predictions:  Predictions of the model
    :attr class_names:  List of the class names to predict
    image: np.ndarray
    prediction: PoseEstimationPrediction
    def draw(
        self,
        edge_colors=None,
        joint_thickness: int = 2,
        keypoint_colors=None,
        keypoint_radius: int = 5,
        box_thickness: int = 2,
        show_confidence: bool = False,
    ) -> np.ndarray:
        """Draw the predicted bboxes on the image.
        :param edge_colors:    Optional list of tuples representing the colors for each joint.
                                If None, default colors are used.
                                If not None the length must be equal to the number of joint links in the skeleton.
        :param joint_thickness: Thickness of the joint links  (in pixels).
        :param keypoint_colors: Optional list of tuples representing the colors for each keypoint.
                                If None, default colors are used.
                                If not None the length must be equal to the number of joints in the skeleton.
        :param keypoint_radius: Radius of the keypoints (in pixels).
        :param show_confidence: Whether to show confidence scores on the image.
        :param box_thickness:   Thickness of bounding boxes.
        :return:                Image with predicted bboxes. Note that this does not modify the original image.
        image = self.image.copy()
        for pred_i in np.argsort(self.prediction.scores):
            image = draw_skeleton(
                image=image,
                keypoints=self.prediction.poses[pred_i],
                score=self.prediction.scores[pred_i],
                show_confidence=show_confidence,
                edge_links=self.prediction.edge_links,
                edge_colors=edge_colors or self.prediction.edge_colors,
                joint_thickness=joint_thickness,
                keypoint_colors=keypoint_colors or self.prediction.keypoint_colors,
                keypoint_radius=keypoint_radius,
                box_thickness=box_thickness,
            )
        return image
    def show(
        self,
        edge_colors=None,
        joint_thickness: int = 2,
        keypoint_colors=None,
        keypoint_radius: int = 5,
        box_thickness: int = 2,
        show_confidence: bool = False,
    ) -> None:
        """Display the image with predicted bboxes.
        :param edge_colors:    Optional list of tuples representing the colors for each joint.
                                If None, default colors are used.
                                If not None the length must be equal to the number of joint links in the skeleton.
        :param joint_thickness: Thickness of the joint links  (in pixels).
        :param keypoint_colors: Optional list of tuples representing the colors for each keypoint.
                                If None, default colors are used.
                                If not None the length must be equal to the number of joints in the skeleton.
        :param keypoint_radius: Radius of the keypoints (in pixels).
        :param show_confidence: Whether to show confidence scores on the image.
        :param box_thickness:   Thickness of bounding boxes.
        image = self.draw(
            edge_colors=edge_colors,
            joint_thickness=joint_thickness,
            keypoint_colors=keypoint_colors,
            keypoint_radius=keypoint_radius,
            box_thickness=box_thickness,
            show_confidence=show_confidence,
        )
        show_image(image)
    def save(
        self,
        output_path: str,
        edge_colors=None,
        joint_thickness: int = 2,
        keypoint_colors=None,
        keypoint_radius: int = 5,
        box_thickness: int = 2,
        show_confidence: bool = False,
    ) -> None:
        """Save the predicted bboxes on the images.
        :param output_path:     Path to the output image file.
        :param edge_colors:    Optional list of tuples representing the colors for each joint.
                                If None, default colors are used.
                                If not None the length must be equal to the number of joint links in the skeleton.
        :param joint_thickness: Thickness of the joint links  (in pixels).
        :param keypoint_colors: Optional list of tuples representing the colors for each keypoint.
                                If None, default colors are used.
                                If not None the length must be equal to the number of joints in the skeleton.
        :param keypoint_radius: Radius of the keypoints (in pixels).
        :param show_confidence: Whether to show confidence scores on the image.
        :param box_thickness:   Thickness of bounding boxes.
        image = self.draw(box_thickness=box_thickness, show_confidence=show_confidence)
        save_image(image=image, path=output_path)
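
In practice these objects are usually obtained from a pose estimation model's predict() call rather than constructed by hand. A hedged sketch (the model name, weights id and the exact return type are assumptions; show()/save() follow the wrappers documented in this section):

from super_gradients.training import models

model = models.get("dekr_w32_no_dc", pretrained_weights="coco_pose")   # assumed model/weights names
result = model.predict("path/to/image.jpg")                            # assumed to return the pose prediction wrapper

result.show(show_confidence=True)              # draw skeletons with the default colors
result.save(output_folder="pose_predictions")  # see ImagesPoseEstimationPrediction.save further below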
          

Source code in V3_2/src/super_gradients/training/utils/predict/prediction_pose_estimation_results.py 64
def draw(
    self,
    edge_colors=None,
    joint_thickness: int = 2,
    keypoint_colors=None,
    keypoint_radius: int = 5,
    box_thickness: int = 2,
    show_confidence: bool = False,
) -> np.ndarray:
    """Draw the predicted bboxes on the image.
    :param edge_colors:    Optional list of tuples representing the colors for each joint.
                            If None, default colors are used.
                            If not None the length must be equal to the number of joint links in the skeleton.
    :param joint_thickness: Thickness of the joint links  (in pixels).
    :param keypoint_colors: Optional list of tuples representing the colors for each keypoint.
                            If None, default colors are used.
                            If not None the length must be equal to the number of joints in the skeleton.
    :param keypoint_radius: Radius of the keypoints (in pixels).
    :param show_confidence: Whether to show confidence scores on the image.
    :param box_thickness:   Thickness of bounding boxes.
    :return:                Image with predicted bboxes. Note that this does not modify the original image.
    image = self.image.copy()
    for pred_i in np.argsort(self.prediction.scores):
        image = draw_skeleton(
            image=image,
            keypoints=self.prediction.poses[pred_i],
            score=self.prediction.scores[pred_i],
            show_confidence=show_confidence,
            edge_links=self.prediction.edge_links,
            edge_colors=edge_colors or self.prediction.edge_colors,
            joint_thickness=joint_thickness,
            keypoint_colors=keypoint_colors or self.prediction.keypoint_colors,
            keypoint_radius=keypoint_radius,
            box_thickness=box_thickness,
        )
    return image
          

Source code in V3_2/src/super_gradients/training/utils/predict/prediction_pose_estimation_results.py 123
def save(
    self,
    output_path: str,
    edge_colors=None,
    joint_thickness: int = 2,
    keypoint_colors=None,
    keypoint_radius: int = 5,
    box_thickness: int = 2,
    show_confidence: bool = False,
) -> None:
    """Save the predicted bboxes on the images.
    :param output_path:     Path to the output image file.
    :param edge_colors:    Optional list of tuples representing the colors for each joint.
                            If None, default colors are used.
                            If not None the length must be equal to the number of joint links in the skeleton.
    :param joint_thickness: Thickness of the joint links  (in pixels).
    :param keypoint_colors: Optional list of tuples representing the colors for each keypoint.
                            If None, default colors are used.
                            If not None the length must be equal to the number of joints in the skeleton.
    :param keypoint_radius: Radius of the keypoints (in pixels).
    :param show_confidence: Whether to show confidence scores on the image.
    :param box_thickness:   Thickness of bounding boxes.
    image = self.draw(box_thickness=box_thickness, show_confidence=show_confidence)
    save_image(image=image, path=output_path)
          

Source code in V3_2/src/super_gradients/training/utils/predict/prediction_pose_estimation_results.py 96
def show(
    self,
    edge_colors=None,
    joint_thickness: int = 2,
    keypoint_colors=None,
    keypoint_radius: int = 5,
    box_thickness: int = 2,
    show_confidence: bool = False,
) -> None:
    """Display the image with predicted bboxes.
    :param edge_colors:    Optional list of tuples representing the colors for each joint.
                            If None, default colors are used.
                            If not None the length must be equal to the number of joint links in the skeleton.
    :param joint_thickness: Thickness of the joint links  (in pixels).
    :param keypoint_colors: Optional list of tuples representing the colors for each keypoint.
                            If None, default colors are used.
                            If not None the length must be equal to the number of joints in the skeleton.
    :param keypoint_radius: Radius of the keypoints (in pixels).
    :param show_confidence: Whether to show confidence scores on the image.
    :param box_thickness:   Thickness of bounding boxes.
    image = self.draw(
        edge_colors=edge_colors,
        joint_thickness=joint_thickness,
        keypoint_colors=keypoint_colors,
        keypoint_radius=keypoint_radius,
        box_thickness=box_thickness,
        show_confidence=show_confidence,
    )
    show_image(image)
          Source code in V3_2/src/super_gradients/training/utils/predict/prediction_pose_estimation_results.py 204
@dataclass
class ImagesPoseEstimationPrediction(ImagesPredictions):
    """Object wrapping the list of image detection predictions.
    :attr _images_prediction_lst:  List of the predictions results
    _images_prediction_lst: List[ImagePoseEstimationPrediction]
    def show(
        self,
        edge_colors=None,
        joint_thickness: int = 2,
        keypoint_colors=None,
        keypoint_radius: int = 5,
        box_thickness: int = 2,
        show_confidence: bool = False,
    ) -> None:
        """Display the predicted bboxes on the images.
        :param edge_colors:    Optional list of tuples representing the colors for each joint.
                                If None, default colors are used.
                                If not None the length must be equal to the number of joint links in the skeleton.
        :param joint_thickness: Thickness of the joint links  (in pixels).
        :param keypoint_colors: Optional list of tuples representing the colors for each keypoint.
                                If None, default colors are used.
                                If not None the length must be equal to the number of joints in the skeleton.
        :param keypoint_radius: Radius of the keypoints (in pixels).
        :param show_confidence: Whether to show confidence scores on the image.
        :param box_thickness:   Thickness of bounding boxes.
        for prediction in self._images_prediction_lst:
            prediction.show(
                edge_colors=edge_colors,
                joint_thickness=joint_thickness,
                keypoint_colors=keypoint_colors,
                keypoint_radius=keypoint_radius,
                box_thickness=box_thickness,
                show_confidence=show_confidence,
            )

    def save(
        self,
        output_folder: str,
        edge_colors=None,
        joint_thickness: int = 2,
        keypoint_colors=None,
        keypoint_radius: int = 5,
        box_thickness: int = 2,
        show_confidence: bool = False,
    ) -> None:
        """Save the predicted bboxes on the images.
        :param output_folder:   Folder path, where the images will be saved.
        :param edge_colors:    Optional list of tuples representing the colors for each joint.
                                If None, default colors are used.
                                If not None the length must be equal to the number of joint links in the skeleton.
        :param joint_thickness: Thickness of the joint links  (in pixels).
        :param keypoint_colors: Optional list of tuples representing the colors for each keypoint.
                                If None, default colors are used.
                                If not None the length must be equal to the number of joints in the skeleton.
        :param keypoint_radius: Radius of the keypoints (in pixels).
        :param show_confidence: Whether to show confidence scores on the image.
        :param box_thickness:   Thickness of bounding boxes.
        if output_folder:
            os.makedirs(output_folder, exist_ok=True)
        for i, prediction in enumerate(self._images_prediction_lst):
            image_output_path = os.path.join(output_folder, f"pred_{i}.jpg")
            prediction.save(
                output_path=image_output_path,
                edge_colors=edge_colors,
                joint_thickness=joint_thickness,
                keypoint_colors=keypoint_colors,
                keypoint_radius=keypoint_radius,
                box_thickness=box_thickness,
                show_confidence=show_confidence,
            )
          

Source code in V3_2/src/super_gradients/training/utils/predict/prediction_pose_estimation_results.py 204
def save(
    self,
    output_folder: str,
    edge_colors=None,
    joint_thickness: int = 2,
    keypoint_colors=None,
    keypoint_radius: int = 5,
    box_thickness: int = 2,
    show_confidence: bool = False,
) -> None:
    """Save the predicted bboxes on the images.
    :param output_folder:   Folder path, where the images will be saved.
    :param edge_colors:    Optional list of tuples representing the colors for each joint.
                            If None, default colors are used.
                            If not None the length must be equal to the number of joint links in the skeleton.
    :param joint_thickness: Thickness of the joint links  (in pixels).
    :param keypoint_colors: Optional list of tuples representing the colors for each keypoint.
                            If None, default colors are used.
                            If not None the length must be equal to the number of joints in the skeleton.
    :param keypoint_radius: Radius of the keypoints (in pixels).
    :param show_confidence: Whether to show confidence scores on the image.
    :param box_thickness:   Thickness of bounding boxes.
    if output_folder:
        os.makedirs(output_folder, exist_ok=True)
    for i, prediction in enumerate(self._images_prediction_lst):
        image_output_path = os.path.join(output_folder, f"pred_{i}.jpg")
        prediction.save(
            output_path=image_output_path,
            edge_colors=edge_colors,
            joint_thickness=joint_thickness,
            keypoint_colors=keypoint_colors,
            keypoint_radius=keypoint_radius,
            box_thickness=box_thickness,
            show_confidence=show_confidence,
        )
          

Source code in V3_2/src/super_gradients/training/utils/predict/prediction_pose_estimation_results.py 165
def show(
    self,
    edge_colors=None,
    joint_thickness: int = 2,
    keypoint_colors=None,
    keypoint_radius: int = 5,
    box_thickness: int = 2,
    show_confidence: bool = False,
) -> None:
    """Display the predicted bboxes on the images.
    :param edge_colors:    Optional list of tuples representing the colors for each joint.
                            If None, default colors are used.
                            If not None the length must be equal to the number of joint links in the skeleton.
    :param joint_thickness: Thickness of the joint links  (in pixels).
    :param keypoint_colors: Optional list of tuples representing the colors for each keypoint.
                            If None, default colors are used.
                            If not None the length must be equal to the number of joints in the skeleton.
    :param keypoint_radius: Radius of the keypoints (in pixels).
    :param show_confidence: Whether to show confidence scores on the image.
    :param box_thickness:   Thickness of bounding boxes.
    for prediction in self._images_prediction_lst:
        prediction.show(
            edge_colors=edge_colors,
            joint_thickness=joint_thickness,
            keypoint_colors=keypoint_colors,
            keypoint_radius=keypoint_radius,
            box_thickness=box_thickness,
            show_confidence=show_confidence,
        )
      

Object wrapping the list of image pose estimation predictions as a video.

:attr _images_prediction_lst: List of the predictions results
:attr fps: Frames per second of the video

Source code in V3_2/src/super_gradients/training/utils/predict/prediction_pose_estimation_results.py 320
@dataclass
class VideoPoseEstimationPrediction(VideoPredictions):
    """Object wrapping the list of image detection predictions as a Video.
    :attr _images_prediction_lst:   List of the predictions results
    :att fps:                       Frames per second of the video
    _images_prediction_lst: List[ImagePoseEstimationPrediction]
    fps: int
    def draw(
        self,
        edge_colors=None,
        joint_thickness: int = 2,
        keypoint_colors=None,
        keypoint_radius: int = 5,
        box_thickness: int = 2,
        show_confidence: bool = False,
    ) -> List[np.ndarray]:
        """Draw the predicted bboxes on the images.
        :param edge_colors:    Optional list of tuples representing the colors for each joint.
                                If None, default colors are used.
                                If not None the length must be equal to the number of joint links in the skeleton.
        :param joint_thickness: Thickness of the joint links  (in pixels).
        :param keypoint_colors: Optional list of tuples representing the colors for each keypoint.
                                If None, default colors are used.
                                If not None the length must be equal to the number of joints in the skeleton.
        :param keypoint_radius: Radius of the keypoints (in pixels).
        :param show_confidence: Whether to show confidence scores on the image.
        :param box_thickness:   Thickness of bounding boxes.
        :return:                List of images with predicted bboxes. Note that this does not modify the original image.
        frames_with_bbox = [
            result.draw(
                edge_colors=edge_colors,
                joint_thickness=joint_thickness,
                keypoint_colors=keypoint_colors,
                keypoint_radius=keypoint_radius,
                box_thickness=box_thickness,
                show_confidence=show_confidence,
            )
            for result in self._images_prediction_lst
        ]
        return frames_with_bbox
    def show(
        self,
        edge_colors=None,
        joint_thickness: int = 2,
        keypoint_colors=None,
        keypoint_radius: int = 5,
        box_thickness: int = 2,
        show_confidence: bool = False,
    ) -> None:
        """Display the predicted bboxes on the images.
        :param edge_colors:    Optional list of tuples representing the colors for each joint.
                                If None, default colors are used.
                                If not None the length must be equal to the number of joint links in the skeleton.
        :param joint_thickness: Thickness of the joint links  (in pixels).
        :param keypoint_colors: Optional list of tuples representing the colors for each keypoint.
                                If None, default colors are used.
                                If not None the length must be equal to the number of joints in the skeleton.
        :param keypoint_radius: Radius of the keypoints (in pixels).
        :param show_confidence: Whether to show confidence scores on the image.
        :param box_thickness:   Thickness of bounding boxes.
        frames = self.draw(
            edge_colors=edge_colors,
            joint_thickness=joint_thickness,
            keypoint_colors=keypoint_colors,
            keypoint_radius=keypoint_radius,
            box_thickness=box_thickness,
            show_confidence=show_confidence,
        )
        show_video_from_frames(window_name="Pose Estimation", frames=frames, fps=self.fps)
    def save(
        self,
        output_path: str,
        edge_colors=None,
        joint_thickness: int = 2,
        keypoint_colors=None,
        keypoint_radius: int = 5,
        box_thickness: int = 2,
        show_confidence: bool = False,
    ) -> None:
        """Save the predicted bboxes on the images.
        :param output_path:     Path to the output video file.
        :param edge_colors:    Optional list of tuples representing the colors for each joint.
                                If None, default colors are used.
                                If not None the length must be equal to the number of joint links in the skeleton.
        :param joint_thickness: Thickness of the joint links  (in pixels).
        :param keypoint_colors: Optional list of tuples representing the colors for each keypoint.
                                If None, default colors are used.
                                If not None the length must be equal to the number of joints in the skeleton.
        :param keypoint_radius: Radius of the keypoints (in pixels).
        :param show_confidence: Whether to show confidence scores on the image.
        :param box_thickness:   Thickness of bounding boxes.
        frames = self.draw(
            edge_colors=edge_colors,
            joint_thickness=joint_thickness,
            keypoint_colors=keypoint_colors,
            keypoint_radius=keypoint_radius,
            box_thickness=box_thickness,
            show_confidence=show_confidence,
        )
        save_video(output_path=output_path, frames=frames, fps=self.fps)
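
Example (a hedged sketch; the model/weights names and predict() support for video files are assumptions, while show()/save() follow VideoPoseEstimationPrediction above):

from super_gradients.training import models

model = models.get("dekr_w32_no_dc", pretrained_weights="coco_pose")   # assumed model/weights names
video_result = model.predict("path/to/video.mp4")                      # assumed to return VideoPoseEstimationPrediction

video_result.show(joint_thickness=2, keypoint_radius=4)   # plays the rendered frames in a window
video_result.save(output_path="poses_out.mp4")            # re-encodes the rendered frames at video_result.fps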
          

Source code in V3_2/src/super_gradients/training/utils/predict/prediction_pose_estimation_results.py 254
def draw(
    self,
    edge_colors=None,
    joint_thickness: int = 2,
    keypoint_colors=None,
    keypoint_radius: int = 5,
    box_thickness: int = 2,
    show_confidence: bool = False,
) -> List[np.ndarray]:
    """Draw the predicted bboxes on the images.
    :param edge_colors:    Optional list of tuples representing the colors for each joint.
                            If None, default colors are used.
                            If not None the length must be equal to the number of joint links in the skeleton.
    :param joint_thickness: Thickness of the joint links  (in pixels).
    :param keypoint_colors: Optional list of tuples representing the colors for each keypoint.
                            If None, default colors are used.
                            If not None the length must be equal to the number of joints in the skeleton.
    :param keypoint_radius: Radius of the keypoints (in pixels).
    :param show_confidence: Whether to show confidence scores on the image.
    :param box_thickness:   Thickness of bounding boxes.
    :return:                List of images with predicted bboxes. Note that this does not modify the original image.
    frames_with_bbox = [
        result.draw(
            edge_colors=edge_colors,
            joint_thickness=joint_thickness,
            keypoint_colors=keypoint_colors,
            keypoint_radius=keypoint_radius,
            box_thickness=box_thickness,
            show_confidence=show_confidence,
        )
        for result in self._images_prediction_lst
    ]
    return frames_with_bbox
          

Source code in V3_2/src/super_gradients/training/utils/predict/prediction_pose_estimation_results.py 320
def save(
    self,
    output_path: str,
    edge_colors=None,
    joint_thickness: int = 2,
    keypoint_colors=None,
    keypoint_radius: int = 5,
    box_thickness: int = 2,
    show_confidence: bool = False,
) -> None:
    """Save the predicted bboxes on the images.
    :param output_path:     Path to the output video file.
    :param edge_colors:    Optional list of tuples representing the colors for each joint.
                            If None, default colors are used.
                            If not None the length must be equal to the number of joint links in the skeleton.
    :param joint_thickness: Thickness of the joint links  (in pixels).
    :param keypoint_colors: Optional list of tuples representing the colors for each keypoint.
                            If None, default colors are used.
                            If not None the length must be equal to the number of joints in the skeleton.
    :param keypoint_radius: Radius of the keypoints (in pixels).
    :param show_confidence: Whether to show confidence scores on the image.
    :param box_thickness:   Thickness of bounding boxes.
    frames = self.draw(
        edge_colors=edge_colors,
        joint_thickness=joint_thickness,
        keypoint_colors=keypoint_colors,
        keypoint_radius=keypoint_radius,
        box_thickness=box_thickness,
        show_confidence=show_confidence,
    )
    save_video(output_path=output_path, frames=frames, fps=self.fps)
          

Source code in V3_2/src/super_gradients/training/utils/predict/prediction_pose_estimation_results.py 286
def show(
    self,
    edge_colors=None,
    joint_thickness: int = 2,
    keypoint_colors=None,
    keypoint_radius: int = 5,
    box_thickness: int = 2,
    show_confidence: bool = False,
) -> None:
    """Display the predicted bboxes on the images.
    :param edge_colors:    Optional list of tuples representing the colors for each joint.
                            If None, default colors are used.
                            If not None the length must be equal to the number of joint links in the skeleton.
    :param joint_thickness: Thickness of the joint links  (in pixels).
    :param keypoint_colors: Optional list of tuples representing the colors for each keypoint.
                            If None, default colors are used.
                            If not None the length must be equal to the number of joints in the skeleton.
    :param keypoint_radius: Radius of the keypoints (in pixels).
    :param show_confidence: Whether to show confidence scores on the image.
    :param box_thickness:   Thickness of bounding boxes.
    frames = self.draw(
        edge_colors=edge_colors,
        joint_thickness=joint_thickness,
        keypoint_colors=keypoint_colors,
        keypoint_radius=keypoint_radius,
        box_thickness=box_thickness,
        show_confidence=show_confidence,
    )
    show_video_from_frames(window_name="Pose Estimation", frames=frames, fps=self.fps)
      

Object wrapping an image and a classification model's prediction.

:attr image: Input image
:attr predictions: Predictions of the model
:attr class_names: List of the class names to predict

Source code in V3_2/src/super_gradients/training/utils/predict/prediction_results.py 91
@dataclass
class ImageClassificationPrediction(ImagePrediction):
    """Object wrapping an image and a classification model's prediction.
    :attr image:        Input image
    :attr predictions:  Predictions of the model
    :attr class_names:  List of the class names to predict
    image: np.ndarray
    prediction: ClassificationPrediction
    class_names: List[str]
    def draw(self, show_confidence: bool = True) -> np.ndarray:
        """Draw the predicted label on the image.
        :param show_confidence: Whether to show confidence scores on the image.
        :return:                Image with predicted label.
        image = self.image.copy()
        return draw_label(
            image=image, label=self.class_names[self.prediction.labels], confidence=str(self.prediction.confidence), image_shape=self.prediction.image_shape[1:]
        )

    def show(self, show_confidence: bool = True) -> None:
        """Display the image with predicted label.
        :param show_confidence: Whether to show confidence scores on the image.
        # to do draw the prediction on the image
        image = self.draw(show_confidence=show_confidence)
        show_image(image)
    def save(
        self,
        output_path: str,
        show_confidence: bool = True,
    ) -> None:
        """Save the predicted label on the images.
        :param output_path:     Path to the output image file.
        :param show_confidence: Whether to show confidence scores on the image.
        image = self.draw(show_confidence=show_confidence)
        save_image(image=image, path=output_path)
69
def draw(self, show_confidence: bool = True) -> np.ndarray:
    """Draw the predicted label on the image.
    :param show_confidence: Whether to show confidence scores on the image.
    :return:                Image with predicted label.
    image = self.image.copy()
    return draw_label(
        image=image, label=self.class_names[self.prediction.labels], confidence=str(self.prediction.confidence), image_shape=self.prediction.image_shape[1:]
    )

def save(
    self,
    output_path: str,
    show_confidence: bool = True,
) -> None:
    """Save the predicted label on the images.
    :param output_path:     Path to the output image file.
    :param show_confidence: Whether to show confidence scores on the image.
    image = self.draw(show_confidence=show_confidence)
    save_image(image=image, path=output_path)
78
def show(self, show_confidence: bool = True) -> None:
    """Display the image with predicted label.
    :param show_confidence: Whether to show confidence scores on the image.
    # to do draw the prediction on the image
    image = self.draw(show_confidence=show_confidence)
    show_image(image)
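
Example (a hedged sketch; the model and weights names are assumptions, and the wrapper returned by predict() is assumed to expose the show()/save() methods documented here):

from super_gradients.training import models

model = models.get("resnet18", pretrained_weights="imagenet")   # assumed model/weights names
result = model.predict("path/to/image.jpg")                     # assumed classification prediction wrapper

result.show(show_confidence=True)                    # draws the predicted label on the image
result.save(output_folder="classification_preds")    # see ImagesClassificationPrediction.save below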
      

Object wrapping an image and a detection model's prediction.

:attr image: Input image
:attr predictions: Predictions of the model
:attr class_names: List of the class names to predict

Source code in V3_2/src/super_gradients/training/utils/predict/prediction_results.py 157
@dataclass
class ImageDetectionPrediction(ImagePrediction):
    """Object wrapping an image and a detection model's prediction.
    :attr image:        Input image
    :attr predictions:  Predictions of the model
    :attr class_names:  List of the class names to predict
    image: np.ndarray
    prediction: DetectionPrediction
    class_names: List[str]
    def draw(self, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int, int, int]]] = None) -> np.ndarray:
        """Draw the predicted bboxes on the image.
        :param box_thickness:   Thickness of bounding boxes.
        :param show_confidence: Whether to show confidence scores on the image.
        :param color_mapping:   List of tuples representing the colors for each class.
                                Default is None, which generates a default color mapping based on the number of class names.
        :return:                Image with predicted bboxes. Note that this does not modify the original image.
        image = self.image.copy()
        color_mapping = color_mapping or generate_color_mapping(len(self.class_names))
        for pred_i in np.argsort(self.prediction.confidence):
            class_id = int(self.prediction.labels[pred_i])
            score = "" if not show_confidence else str(round(self.prediction.confidence[pred_i], 2))
            image = draw_bbox(
                image=image,
                title=f"{self.class_names[class_id]} {score}",
                color=color_mapping[class_id],
                box_thickness=box_thickness,
                x1=int(self.prediction.bboxes_xyxy[pred_i, 0]),
                y1=int(self.prediction.bboxes_xyxy[pred_i, 1]),
                x2=int(self.prediction.bboxes_xyxy[pred_i, 2]),
                y2=int(self.prediction.bboxes_xyxy[pred_i, 3]),
            )
        return image
    def show(self, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int, int, int]]] = None) -> None:
        """Display the image with predicted bboxes.
        :param box_thickness:   Thickness of bounding boxes.
        :param show_confidence: Whether to show confidence scores on the image.
        :param color_mapping:   List of tuples representing the colors for each class.
                                Default is None, which generates a default color mapping based on the number of class names.
        image = self.draw(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping)
        show_image(image)
    def save(self, output_path: str, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int, int, int]]] = None) -> None:
        """Save the predicted bboxes on the images.
        :param output_path:     Path to the output image file.
        :param box_thickness:   Thickness of bounding boxes.
        :param show_confidence: Whether to show confidence scores on the image.
        :param color_mapping:   List of tuples representing the colors for each class.
                                Default is None, which generates a default color mapping based on the number of class names.
        image = self.draw(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping)
        save_image(image=image, path=output_path)
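
Example (a hedged sketch; the model/weights names and the predict() signature are assumptions, while the color_mapping/box_thickness arguments follow the draw()/show()/save() methods above):

from super_gradients.training import models

model = models.get("yolo_nas_s", pretrained_weights="coco")   # assumed model/weights names
result = model.predict("path/to/image.jpg", conf=0.4)         # assumed detection prediction wrapper

# One RGB color per class; the length must equal the number of class names (80 for COCO).
color_mapping = [(0, 255, 0)] * 80
result.show(box_thickness=2, show_confidence=True, color_mapping=color_mapping)
result.save(output_folder="detections", color_mapping=color_mapping)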
Source code in V3_2/src/super_gradients/training/utils/predict/prediction_results.py 134
def draw(self, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int, int, int]]] = None) -> np.ndarray:
    """Draw the predicted bboxes on the image.
    :param box_thickness:   Thickness of bounding boxes.
    :param show_confidence: Whether to show confidence scores on the image.
    :param color_mapping:   List of tuples representing the colors for each class.
                            Default is None, which generates a default color mapping based on the number of class names.
    :return:                Image with predicted bboxes. Note that this does not modify the original image.
    image = self.image.copy()
    color_mapping = color_mapping or generate_color_mapping(len(self.class_names))
    for pred_i in np.argsort(self.prediction.confidence):
        class_id = int(self.prediction.labels[pred_i])
        score = "" if not show_confidence else str(round(self.prediction.confidence[pred_i], 2))
        image = draw_bbox(
            image=image,
            title=f"{self.class_names[class_id]} {score}",
            color=color_mapping[class_id],
            box_thickness=box_thickness,
            x1=int(self.prediction.bboxes_xyxy[pred_i, 0]),
            y1=int(self.prediction.bboxes_xyxy[pred_i, 1]),
            x2=int(self.prediction.bboxes_xyxy[pred_i, 2]),
            y2=int(self.prediction.bboxes_xyxy[pred_i, 3]),
        )
    return image
Source code in V3_2/src/super_gradients/training/utils/predict/prediction_results.py 157
def save(self, output_path: str, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int, int, int]]] = None) -> None:
    """Save the predicted bboxes on the image.

    :param output_path:     Path to the output image file.
    :param box_thickness:   Thickness of bounding boxes.
    :param show_confidence: Whether to show confidence scores on the image.
    :param color_mapping:   List of tuples representing the colors for each class.
                            Default is None, which generates a default color mapping based on the number of class names.
    """
    image = self.draw(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping)
    save_image(image=image, path=output_path)
Source code in V3_2/src/super_gradients/training/utils/predict/prediction_results.py 145
def show(self, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int, int, int]]] = None) -> None:
    """Display the image with predicted bboxes.

    :param box_thickness:   Thickness of bounding boxes.
    :param show_confidence: Whether to show confidence scores on the image.
    :param color_mapping:   List of tuples representing the colors for each class.
                            Default is None, which generates a default color mapping based on the number of class names.
    """
    image = self.draw(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping)
    show_image(image)
      

Object wrapping an image and a model's prediction.

:attr image: Input image
:attr predictions: Predictions of the model
:attr class_names: List of the class names to predict

Source code in V3_2/src/super_gradients/training/utils/predict/prediction_results.py 43
@dataclass
class ImagePrediction(ABC):
    """Object wrapping an image and a model's prediction.
    :attr image:        Input image
    :attr predictions:  Predictions of the model
    :attr class_names:  List of the class names to predict
    """

    image: np.ndarray
    prediction: Prediction
    class_names: List[str]
    @abstractmethod
    def draw(self, *args, **kwargs) -> np.ndarray:
        """Draw the predictions on the image."""
    @abstractmethod
    def show(self, *args, **kwargs) -> None:
        """Display the predictions on the image."""
    @abstractmethod
    def save(self, *args, **kwargs) -> None:
        """Save the predictions on the image."""
@abstractmethod
def draw(self, *args, **kwargs) -> np.ndarray:
    """Draw the predictions on the image."""
@abstractmethod
def save(self, *args, **kwargs) -> None:
    """Save the predictions on the image."""
@abstractmethod
def show(self, *args, **kwargs) -> None:
    """Display the predictions on the image."""
          Source code in V3_2/src/super_gradients/training/utils/predict/prediction_results.py
238
@dataclass
class ImagesClassificationPrediction(ImagesPredictions):
    """Object wrapping the list of image classification predictions.

    :attr _images_prediction_lst:  List of the predictions results
    """

    _images_prediction_lst: List[ImageClassificationPrediction]

    def show(self, show_confidence: bool = True) -> None:
        """Display the predicted labels on the images.

        :param show_confidence: Whether to show confidence scores on the image.
        """
        for prediction in self._images_prediction_lst:
            prediction.show(show_confidence=show_confidence)

    def save(self, output_folder: str, show_confidence: bool = True) -> None:
        """Save the predicted labels on the images.

        :param output_folder:   Folder path, where the images will be saved.
        :param show_confidence: Whether to show confidence scores on the image.
        """
        if output_folder:
            os.makedirs(output_folder, exist_ok=True)

        for i, prediction in enumerate(self._images_prediction_lst):
            image_output_path = os.path.join(output_folder, f"pred_{i}.jpg")
            prediction.save(output_path=image_output_path, show_confidence=show_confidence)
        Source code in V3_2/src/super_gradients/training/utils/predict/prediction_results.py
238
def save(self, output_folder: str, show_confidence: bool = True) -> None:
    """Save the predicted label on the images.
    :param output_folder:     Folder path, where the images will be saved.
    :param show_confidence: Whether to show confidence scores on the image.
    """
    if output_folder:
        os.makedirs(output_folder, exist_ok=True)
    for i, prediction in enumerate(self._images_prediction_lst):
        image_output_path = os.path.join(output_folder, f"pred_{i}.jpg")
        prediction.save(output_path=image_output_path, show_confidence=show_confidence)
        Source code in V3_2/src/super_gradients/training/utils/predict/prediction_results.py
225
def show(self, show_confidence: bool = True) -> None:
    """Display the predicted labels on the images.
    :param show_confidence: Whether to show confidence scores on the image.
    """
    for prediction in self._images_prediction_lst:
        prediction.show(show_confidence=show_confidence)
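
A short hedged sketch of how these two methods might be called, assuming `predictions` is an ImagesClassificationPrediction produced by a classification pipeline (the folder name is a placeholder):

# Display every image with its predicted label and confidence.
predictions.show(show_confidence=True)

# Write pred_0.jpg, pred_1.jpg, ... into the folder, creating it if needed.
predictions.save(output_folder="./classification_outputs", show_confidence=False)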
          Source code in V3_2/src/super_gradients/training/utils/predict/prediction_results.py
277
@dataclass
class ImagesDetectionPrediction(ImagesPredictions):
    """Object wrapping the list of image detection predictions.

    :attr _images_prediction_lst:  List of the predictions results
    """

    _images_prediction_lst: List[ImageDetectionPrediction]

    def show(self, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int, int, int]]] = None) -> None:
        """Display the predicted bboxes on the images.

        :param box_thickness:   Thickness of bounding boxes.
        :param show_confidence: Whether to show confidence scores on the image.
        :param color_mapping:   List of tuples representing the colors for each class.
                                Default is None, which generates a default color mapping based on the number of class names.
        """
        for prediction in self._images_prediction_lst:
            prediction.show(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping)

    def save(
        self, output_folder: str, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int, int, int]]] = None
    ) -> None:
        """Save the predicted bboxes on the images.

        :param output_folder:   Folder path, where the images will be saved.
        :param box_thickness:   Thickness of bounding boxes.
        :param show_confidence: Whether to show confidence scores on the image.
        :param color_mapping:   List of tuples representing the colors for each class.
                                Default is None, which generates a default color mapping based on the number of class names.
        """
        if output_folder:
            os.makedirs(output_folder, exist_ok=True)

        for i, prediction in enumerate(self._images_prediction_lst):
            image_output_path = os.path.join(output_folder, f"pred_{i}.jpg")
            prediction.save(output_path=image_output_path, box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping)
Source code in V3_2/src/super_gradients/training/utils/predict/prediction_results.py 277
def save(
    self, output_folder: str, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int, int, int]]] = None
) -> None:
    """Save the predicted bboxes on the images.
    :param output_folder:     Folder path, where the images will be saved.
    :param box_thickness:   Thickness of bounding boxes.
    :param show_confidence: Whether to show confidence scores on the image.
    :param color_mapping:   List of tuples representing the colors for each class.
                            Default is None, which generates a default color mapping based on the number of class names.
    """
    if output_folder:
        os.makedirs(output_folder, exist_ok=True)
    for i, prediction in enumerate(self._images_prediction_lst):
        image_output_path = os.path.join(output_folder, f"pred_{i}.jpg")
        prediction.save(output_path=image_output_path, box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping)
Source code in V3_2/src/super_gradients/training/utils/predict/prediction_results.py 259
def show(self, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int, int, int]]] = None) -> None:
    """Display the predicted bboxes on the images.
    :param box_thickness:   Thickness of bounding boxes.
    :param show_confidence: Whether to show confidence scores on the image.
    :param color_mapping:   List of tuples representing the colors for each class.
                            Default is None, which generates a default color mapping based on the number of class names.
    """
    for prediction in self._images_prediction_lst:
        prediction.show(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping)
          Source code in V3_2/src/super_gradients/training/utils/predict/prediction_results.py
186
@dataclass
class ImagesPredictions(ABC):
    """Object wrapping the list of image predictions.
    :attr _images_prediction_lst: List of results of the run
    """

    _images_prediction_lst: List[ImagePrediction]
    def __len__(self) -> int:
        return len(self._images_prediction_lst)
    def __getitem__(self, index: int) -> ImagePrediction:
        return self._images_prediction_lst[index]
    def __iter__(self) -> Iterator[ImagePrediction]:
        return iter(self._images_prediction_lst)
    @abstractmethod
    def show(self, *args, **kwargs) -> None:
        """Display the predictions on the images."""
    @abstractmethod
    def save(self, *args, **kwargs) -> None:
        """Save the predictions on the images."""
        Source code in V3_2/src/super_gradients/training/utils/predict/prediction_results.py
186
@abstractmethod
def save(self, *args, **kwargs) -> None:
    """Save the predictions on the images."""
        Source code in V3_2/src/super_gradients/training/utils/predict/prediction_results.py
181
@abstractmethod
def show(self, *args, **kwargs) -> None:
    """Display the predictions on the images."""
      

Object wrapping the list of image detection predictions as a Video.

:attr _images_prediction_lst: List of the predictions results
:attr fps: Frames per second of the video

Source code in V3_2/src/super_gradients/training/utils/predict/prediction_results.py 326
@dataclass
class VideoDetectionPrediction(VideoPredictions):
    """Object wrapping the list of image detection predictions as a Video.

    :attr _images_prediction_lst:   List of the predictions results
    :attr fps:                      Frames per second of the video
    """

    _images_prediction_lst: List[ImageDetectionPrediction]
    fps: int

    def draw(self, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int, int, int]]] = None) -> List[np.ndarray]:
        """Draw the predicted bboxes on the images.

        :param box_thickness:   Thickness of bounding boxes.
        :param show_confidence: Whether to show confidence scores on the image.
        :param color_mapping:   List of tuples representing the colors for each class.
                                Default is None, which generates a default color mapping based on the number of class names.
        :return:                List of images with predicted bboxes. Note that this does not modify the original image.
        """
        frames_with_bbox = [
            result.draw(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping) for result in self._images_prediction_lst
        ]
        return frames_with_bbox

    def show(self, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int, int, int]]] = None) -> None:
        """Display the predicted bboxes on the images.

        :param box_thickness:   Thickness of bounding boxes.
        :param show_confidence: Whether to show confidence scores on the image.
        :param color_mapping:   List of tuples representing the colors for each class.
                                Default is None, which generates a default color mapping based on the number of class names.
        """
        frames = self.draw(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping)
        show_video_from_frames(window_name="Detection", frames=frames, fps=self.fps)

    def save(self, output_path: str, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int, int, int]]] = None) -> None:
        """Save the predicted bboxes on the images.

        :param output_path:     Path to the output video file.
        :param box_thickness:   Thickness of bounding boxes.
        :param show_confidence: Whether to show confidence scores on the image.
        :param color_mapping:   List of tuples representing the colors for each class.
                                Default is None, which generates a default color mapping based on the number of class names.
        """
        frames = self.draw(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping)
        save_video(output_path=output_path, frames=frames, fps=self.fps)
Source code in V3_2/src/super_gradients/training/utils/predict/prediction_results.py 303
def draw(self, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int, int, int]]] = None) -> List[np.ndarray]:
    """Draw the predicted bboxes on the images.
    :param box_thickness:   Thickness of bounding boxes.
    :param show_confidence: Whether to show confidence scores on the image.
    :param color_mapping:   List of tuples representing the colors for each class.
                            Default is None, which generates a default color mapping based on the number of class names.
    :return:                List of images with predicted bboxes. Note that this does not modify the original image.
    """
    frames_with_bbox = [
        result.draw(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping) for result in self._images_prediction_lst
    ]
    return frames_with_bbox
Source code in V3_2/src/super_gradients/training/utils/predict/prediction_results.py 326
def save(self, output_path: str, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int, int, int]]] = None) -> None:
    """Save the predicted bboxes on the images.
    :param output_path:     Path to the output video file.
    :param box_thickness:   Thickness of bounding boxes.
    :param show_confidence: Whether to show confidence scores on the image.
    :param color_mapping:   List of tuples representing the colors for each class.
                            Default is None, which generates a default color mapping based on the number of class names.
    """
    frames = self.draw(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping)
    save_video(output_path=output_path, frames=frames, fps=self.fps)
Source code in V3_2/src/super_gradients/training/utils/predict/prediction_results.py 314
def show(self, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int, int, int]]] = None) -> None:
    """Display the predicted bboxes on the images.
    :param box_thickness:   Thickness of bounding boxes.
    :param show_confidence: Whether to show confidence scores on the image.
    :param color_mapping:   List of tuples representing the colors for each class.
                            Default is None, which generates a default color mapping based on the number of class names.
    """
    frames = self.draw(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping)
    show_video_from_frames(window_name="Detection", frames=frames, fps=self.fps)
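
A hedged usage sketch, assuming `video_prediction` is a VideoDetectionPrediction produced by running a detection pipeline on a video (the output path is a placeholder):

frames = video_prediction.draw(box_thickness=2)   # list of np.ndarray frames with boxes drawn
video_prediction.show()                           # plays the drawn frames at video_prediction.fps
video_prediction.save("detections.mp4")           # encodes the drawn frames into a video file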
      

Object wrapping the list of image predictions as a Video.

:attr _images_prediction_lst: List of results of the run
:attr fps: Frames per second of the video

Source code in V3_2/src/super_gradients/training/utils/predict/prediction_results.py 208
@dataclass
class VideoPredictions(ImagesPredictions, ABC):
    """Object wrapping the list of image predictions as a Video.
    :attr _images_prediction_lst:   List of results of the run
    :attr fps:                      Frames per second of the video
    """

    _images_prediction_lst: List[ImagePrediction]
    fps: float
    @abstractmethod
    def show(self, *args, **kwargs) -> None:
        """Display the predictions on the video."""
    @abstractmethod
    def save(self, *args, **kwargs) -> None:
        """Save the predictions on the video."""
        Source code in V3_2/src/super_gradients/training/utils/predict/prediction_results.py
208
@abstractmethod
def save(self, *args, **kwargs) -> None:
    """Save the predictions on the video."""
        Source code in V3_2/src/super_gradients/training/utils/predict/prediction_results.py
203
@abstractmethod
def show(self, *args, **kwargs) -> None:
    """Display the predictions on the video."""
@dataclass
class ClassificationPrediction(Prediction):
    """Represents a Classification prediction"""

    confidence: float
    labels: int
    image_shape: Tuple[int, int]

    def __init__(self, confidence: float, labels: int, image_shape: Optional[Tuple[int, int]]):
        """
        :param confidence:  Confidence score of the prediction
        :param labels:      Predicted label
        :param image_shape: Shape of the image the prediction is made on, (H, W).
        """
        self._validate_input(confidence, labels)

        self.confidence = confidence
        self.labels = labels
        self.image_shape = image_shape

    def _validate_input(self, confidence: float, labels: int) -> None:
        if not isinstance(confidence, float):
            raise ValueError(f"Argument confidence must be a float, not {type(confidence)}")
        if not isinstance(labels, int):
            raise ValueError(f"Argument labels must be an int, not {type(labels)}")

    def __len__(self):
        return len(self.labels)
def __init__(self, confidence: float, labels: int, image_shape: Optional[Tuple[int, int]]):
    """
    :param confidence:  Confidence score of the prediction
    :param labels:      Predicted label
    :param image_shape: Shape of the image the prediction is made on, (H, W).
    """
    self._validate_input(confidence, labels)

    self.confidence = confidence
    self.labels = labels
    self.image_shape = image_shape
@dataclass
class DetectionPrediction(Prediction):
    """Represents a detection prediction, with bboxes represented in xyxy format."""

    bboxes_xyxy: np.ndarray
    confidence: np.ndarray
    labels: np.ndarray

    def __init__(self, bboxes: np.ndarray, bbox_format: str, confidence: np.ndarray, labels: np.ndarray, image_shape: Tuple[int, int]):
        """
        :param bboxes:      BBoxes in the format specified by bbox_format
        :param bbox_format: BBoxes format that can be a string ("xyxy", "cxywh", ...)
        :param confidence:  Confidence scores for each bounding box
        :param labels:      Labels for each bounding box.
        :param image_shape: Shape of the image the prediction is made on, (H, W). This is used to convert bboxes to xyxy format
        """
        self._validate_input(bboxes, confidence, labels)

        factory = BBoxFormatFactory()
        bboxes_xyxy = convert_bboxes(
            bboxes=bboxes,
            image_shape=image_shape,
            source_format=factory.get(bbox_format),
            target_format=factory.get("xyxy"),
            inplace=False,
        )

        self.bboxes_xyxy = bboxes_xyxy
        self.confidence = confidence
        self.labels = labels

    def _validate_input(self, bboxes: np.ndarray, confidence: np.ndarray, labels: np.ndarray) -> None:
        n_bboxes, n_confidences, n_labels = bboxes.shape[0], confidence.shape[0], labels.shape[0]
        if n_bboxes != n_confidences != n_labels:
            raise ValueError(
                f"The number of bounding boxes ({n_bboxes}) does not match the number of confidence scores ({n_confidences}) and labels ({n_labels})."
            )

    def __len__(self):
        return len(self.bboxes_xyxy)
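
To make the constructor contract concrete, here is a synthetic construction sketch; the import path is inferred from the source file path shown above and may differ between versions, and all values are dummy data:

import numpy as np
from super_gradients.training.utils.predict.predictions import DetectionPrediction  # path inferred from the file path above

prediction = DetectionPrediction(
    bboxes=np.array([[10.0, 20.0, 110.0, 220.0]]),  # one box, already in xyxy pixel coordinates
    bbox_format="xyxy",                             # so no real conversion is needed here
    confidence=np.array([0.87]),
    labels=np.array([0]),
    image_shape=(480, 640),                         # (H, W) of the source image
)
print(prediction.bboxes_xyxy, len(prediction))      # boxes are stored as xyxy; __len__ is the box count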
def __init__(self, bboxes: np.ndarray, bbox_format: str, confidence: np.ndarray, labels: np.ndarray, image_shape: Tuple[int, int]):
    """
    :param bboxes:      BBoxes in the format specified by bbox_format
    :param bbox_format: BBoxes format that can be a string ("xyxy", "cxywh", ...)
    :param confidence:  Confidence scores for each bounding box
    :param labels:      Labels for each bounding box.
    :param image_shape: Shape of the image the prediction is made on, (H, W). This is used to convert bboxes to xyxy format
    """
    self._validate_input(bboxes, confidence, labels)

    factory = BBoxFormatFactory()
    bboxes_xyxy = convert_bboxes(
        bboxes=bboxes,
        image_shape=image_shape,
        source_format=factory.get(bbox_format),
        target_format=factory.get("xyxy"),
        inplace=False,
    )

    self.bboxes_xyxy = bboxes_xyxy
    self.confidence = confidence
    self.labels = labels
      

Represents a pose estimation prediction.

:attr poses: Numpy array of [Num Poses, Num Joints, 2] shape
:attr scores: Numpy array of [Num Poses] shape

Source code in V3_2/src/super_gradients/training/utils/predict/predictions.py 108
@dataclass
class PoseEstimationPrediction(Prediction):
    """Represents a pose estimation prediction.

    :attr poses: Numpy array of [Num Poses, Num Joints, 2] shape
    :attr scores: Numpy array of [Num Poses] shape
    """

    poses: np.ndarray
    scores: np.ndarray
    edge_links: np.ndarray
    edge_colors: np.ndarray
    keypoint_colors: np.ndarray
    image_shape: Tuple[int, int]

    def __init__(
        self,
        poses: np.ndarray,
        scores: np.ndarray,
        edge_links: np.ndarray,
        edge_colors: np.ndarray,
        keypoint_colors: np.ndarray,
        image_shape: Tuple[int, int],
    ):
        """
        :param poses:
        :param scores:
        :param image_shape: Shape of the image the prediction is made on, (H, W).
        """
        self._validate_input(poses, scores, edge_links, edge_colors, keypoint_colors)

        self.poses = poses
        self.scores = scores
        self.edge_links = edge_links
        self.edge_colors = edge_colors
        self.image_shape = image_shape
        self.keypoint_colors = keypoint_colors

    def _validate_input(self, poses: np.ndarray, scores: np.ndarray, edge_links, edge_colors, keypoint_colors) -> None:
        if not isinstance(poses, np.ndarray):
            raise ValueError(f"Argument poses must be a numpy array, not {type(poses)}")
        if not isinstance(scores, np.ndarray):
            raise ValueError(f"Argument scores must be a numpy array, not {type(scores)}")
        if not isinstance(keypoint_colors, np.ndarray):
            raise ValueError(f"Argument keypoint_colors must be a numpy array, not {type(keypoint_colors)}")
        if len(poses) != len(scores) != len(keypoint_colors):
            raise ValueError(f"The number of poses ({len(poses)}) does not match the number of scores ({len(scores)}).")
        if len(edge_links) != len(edge_colors):
            raise ValueError(f"The number of joint links ({len(edge_links)}) does not match the number of joint colors ({len(edge_colors)}).")

    def __len__(self):
        return len(self.poses)
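
Similarly, a dummy-data sketch for PoseEstimationPrediction with a single 3-joint pose (import path inferred from the file path above; all arrays are synthetic):

import numpy as np
from super_gradients.training.utils.predict.predictions import PoseEstimationPrediction  # path inferred

pose_prediction = PoseEstimationPrediction(
    poses=np.zeros((1, 3, 2)),                                          # [Num Poses, Num Joints, 2]
    scores=np.array([0.9]),                                             # [Num Poses]
    edge_links=np.array([[0, 1], [1, 2]]),                              # joint index pairs forming the skeleton
    edge_colors=np.array([[0, 255, 0], [255, 0, 0]]),                   # one color per edge link
    keypoint_colors=np.array([[0, 0, 255], [0, 255, 0], [255, 0, 0]]),  # one color per joint
    image_shape=(480, 640),
)
print(len(pose_prediction))  # 1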

Quantization utilities

Methods are based on: https://github.com/NVIDIA/TensorRT/blob/51a4297753d3e12d0eed864be52400f429a6a94c/tools/pytorch-quantization/examples/torchvision/classification_flow.py#L385

(Licensed under the Apache License, Version 2.0)

class QuantizationCalibrator:
    def __init__(self, torch_hist: bool = True, verbose: bool = True) -> None:
        if _imported_pytorch_quantization_failure is not None:
            raise _imported_pytorch_quantization_failure
        super().__init__()
        self.verbose = verbose
        self.torch_hist = torch_hist
    def calibrate_model(
        self,
        model: torch.nn.Module,
        calib_data_loader: torch.utils.data.DataLoader,
        method: str = "percentile",
        num_calib_batches: int = 2,
        percentile: float = 99.99,
    ):
        """
        Calibrates torch model with quantized modules.

        :param model:               torch.nn.Module, model to perform the calibration on.
        :param calib_data_loader:   torch.utils.data.DataLoader, data loader of the calibration dataset.
                                    Assumes that the first element of the tuple is the input image.
        :param method:              str, One of [percentile, mse, entropy, max].
                                    Statistics method for amax computation of the quantized modules
                                    (Default=percentile).
        :param num_calib_batches:   int, number of batches to collect the statistics from.
        :param percentile:          float, percentile value to use when SgModel,quant_modules_calib_method='percentile'.
                                    Discarded when other methods are used (Default=99.99).
        """
        logging_level = logging.getLogger("absl").getEffectiveLevel()
        if not self.verbose:  # suppress pytorch-quantization spam
            logging.getLogger("absl").setLevel("ERROR")
        acceptable_methods = ["percentile", "mse", "entropy", "max"]
        if method in acceptable_methods:
            with torch.no_grad():
                device = next(model.parameters()).device
                self._collect_stats(model, calib_data_loader, num_batches=num_calib_batches)
                # FOR PERCENTILE WE MUST PASS PERCENTILE VALUE THROUGH KWARGS,
                # SO IT WOULD BE PASSED TO module.load_calib_amax(**kwargs), AND IN OTHER METHODS WE MUST NOT PASS IT.
                if method == "precentile":
                    self._compute_amax(model, method="percentile", percentile=percentile)
                else:
                    self._compute_amax(model, method=method)
                model.to(device)
        else:
            raise ValueError(f"Unsupported quantization calibration method, " f"expected one of: {'.'.join(acceptable_methods)}, however, received: {method}")
        logging.getLogger("absl").setLevel(logging_level)
    def _collect_stats(self, model, data_loader, num_batches):
        """Feed data to the network and collect statistics"""
        local_rank = get_local_rank()
        world_size = get_world_size()
        device = next(model.parameters()).device
        # Enable calibrators
        self._enable_calibrators(model)
        # Feed data to the network for collecting stats
        for i, (image, *_) in tqdm(enumerate(data_loader), total=num_batches, disable=local_rank > 0):
            if world_size > 1:
                all_batches = [torch.zeros_like(image, device=device) for _ in range(world_size)]
                all_gather(all_batches, image.to(device=device))
            else:
                all_batches = [image]
            for local_image in all_batches:
                model(local_image.to(device=device))
            if i >= num_batches:
                break
        # Disable calibrators
        self._disable_calibrators(model)
    def _disable_calibrators(self, model):
        for name, module in model.named_modules():
            if isinstance(module, quant_nn.TensorQuantizer):
                if module._calibrator is not None:
                    module.disable_calib()
                    module.enable_quant()
                else:
                    module.enable()
    def reset_calibrators(self, model):
        for name, module in model.named_modules():
            if isinstance(module, quant_nn.TensorQuantizer):
                if module._calibrator is not None:
                    module._calibrator.reset()  # release memory
    def _enable_calibrators(self, model):
        for name, module in model.named_modules():
            if isinstance(module, quant_nn.TensorQuantizer):
                if module._calibrator is not None:
                    if isinstance(module._calibrator, calib.HistogramCalibrator):
                        module._calibrator._torch_hist = self.torch_hist  # TensorQuantizer does not expose it as API
                    module.disable_quant()
                    module.enable_calib()
                else:
                    module.disable()
    def _compute_amax(self, model, **kwargs):
        for name, module in model.named_modules():
            if isinstance(module, quant_nn.TensorQuantizer):
                if module._calibrator is not None:
                    if isinstance(module._calibrator, calib.MaxCalibrator):
                        module.load_calib_amax()
                    else:
                        module.load_calib_amax(**kwargs)
                if hasattr(module, "clip"):
                    module.init_learn_amax()
                if self.verbose:
                    print(f"{name:40}: {module}")

This class wraps a float module instance, and defines a mapping from this instance to the corresponding quantized class, with relevant quant descriptors.

Example: self.my_block = QuantizedMapping(float_module=MyBlock(4, n_classes), quantized_target_class=MyQuantizedBlock)

Source code in V3_2/src/super_gradients/training/utils/quantization/core.py 165
class QuantizedMapping(nn.Module):
    """
    This class wraps a float module instance, and defines a mapping from this instance to the corresponding quantized
    class, with relevant quant descriptors.

    Example:
        self.my_block = QuantizedMapping(float_module=MyBlock(4, n_classes), quantized_target_class=MyQuantizedBlock)
    """

    def __init__(
        self,
        float_module: nn.Module,
        quantized_target_class: Union[Type[QuantMixin], Type[QuantInputMixin], Type[SGQuantMixin]],
        action=QuantizedMetadata.ReplacementAction.REPLACE,
        input_quant_descriptor: QuantDescriptor = None,
        weights_quant_descriptor: QuantDescriptor = None,
    ) -> None:
        super().__init__()
        self.float_module = float_module
        self.quantized_target_class = quantized_target_class
        self.action = action
        self.input_quant_descriptor = input_quant_descriptor
        self.weights_quant_descriptor = weights_quant_descriptor
        self.forward = float_module.forward
      

This dataclass is responsible for holding the information regarding float->quantized module relation. It can be both layer-grained and module-grained, e.g., module.backbone.conv1 -> QuantConv2d, nn.Linear -> QuantLinear, etc...


@dataclass
class QuantizedMetadata:
    """
    This dataclass is responsible for holding the information regarding float->quantized module relation.
    It can be both layer-grained and module-grained, e.g., `module.backbone.conv1 -> QuantConv2d`, `nn.Linear -> QuantLinear`, etc...

    :param float_source:                Name of a specific layer (e.g., `module.backbone.conv1`), or a specific type (e.g., `Conv2d`)
                                        that will be later quantized
    :param quantized_target_class:      Quantized type that the source will be converted to
    :param action:                      how to resolve the conversion, we either:
                                        - SKIP: skip it,
                                        - UNWRAP: unwrap the instance and work with the wrapped one (i.e., we wrap with a mapper),
                                        - REPLACE: replace source with an instance of the quantized type
                                        - REPLACE_AND_RECURE: replace source with an instance of the quantized type,
                                          then try to recursively quantize the child modules of that type
                                        - RECURE_AND_REPLACE: recursively quantize the child modules,
                                          then replace source with an instance of the quantized type
    :param input_quant_descriptor:      Quantization descriptor for inputs (None will take the default one)
    :param weights_quant_descriptor:    Quantization descriptor for weights (None will take the default one)
    """

    class ReplacementAction(Enum):
        REPLACE = "replace"
        REPLACE_AND_RECURE = "replace_and_recure"
        RECURE_AND_REPLACE = "recure_and_replace"
        UNWRAP = "unwrap"
        SKIP = "skip"

    float_source: Union[str, Type]
    quantized_target_class: Optional[Union[Type[QuantMixin], Type[QuantInputMixin], Type[SGQuantMixin]]]
    action: ReplacementAction
    input_quant_descriptor: QuantDescriptor = None  # default is used if None
    weights_quant_descriptor: QuantDescriptor = None  # default is used if None

    def __post_init__(self):
        if self.action in (
            QuantizedMetadata.ReplacementAction.REPLACE,
            QuantizedMetadata.ReplacementAction.REPLACE_AND_RECURE,
            QuantizedMetadata.ReplacementAction.RECURE_AND_REPLACE,
        ):
            assert issubclass(self.quantized_target_class, (SGQuantMixin, QuantMixin, QuantInputMixin))
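
For example, a custom QuantizedMetadata entry can be handed to SelectiveQuantizer (documented further below) to route a specific float type to a user-provided quantized implementation; MyBlock and MyQuantizedBlock are hypothetical classes used only for illustration:

custom_mappings = {
    MyBlock: QuantizedMetadata(                     # MyBlock: hypothetical float nn.Module type
        float_source=MyBlock,
        quantized_target_class=MyQuantizedBlock,    # hypothetical SGQuantMixin subclass
        action=QuantizedMetadata.ReplacementAction.REPLACE,
    )
}
quantizer = SelectiveQuantizer(custom_mappings=custom_mappings)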

A base class for user custom Quantized classes. Every Quantized class must inherit this mixin, which adds from_float class-method. NOTES: * the Quantized class may also inherit from the native QuantMixin or QuantInputMixin * quant descriptors (for inputs and weights) will be passed as kwargs. The module may ignore them if they are not necessary * the default implementation of from_float is inspecting the init args, and searching for corresponding properties from the float instance that is passed as argument, e.g., for __init__(self, a) the mechanism will look for float_instance.a and pass that value to the __init__ method

Source code in V3_2/src/super_gradients/training/utils/quantization/core.py 78
class SGQuantMixin(nn.Module):
    """
    A base class for user custom Quantized classes.
    Every Quantized class must inherit this mixin, which adds `from_float` class-method.
    NOTES:
        * the Quantized class may also inherit from the native `QuantMixin` or `QuantInputMixin`
        * quant descriptors (for inputs and weights) will be passed as `kwargs`. The module may ignore them if they are
          not necessary
        * the default implementation of `from_float` is inspecting the __init__ args, and searching for corresponding
          properties from the float instance that is passed as argument, e.g., for `__init__(self, a)`
          the mechanism will look for `float_instance.a` and pass that value to the `__init__` method
    """

    @classmethod
    def from_float(cls, float_instance, **kwargs):
        required_init_params = list(inspect.signature(cls.__init__).parameters)[1:]  # [0] is self
        # if cls.__init__ has explicit `quant_desc_input` or `quant_desc_weight` - we don't search the state of the
        # float module, because it would not contain this state. these values are injected by the framework
        ignore_init_args = {"quant_desc_input", "quant_desc_weight"}.intersection(set(required_init_params))
        # if cls.__init__ doesn't have neither **kwargs, nor `quant_desc_input` and `quant_desc_weight`,
        # we should also remove these keys from the passed kwargs and make sure there's nothing more!
        if "kwargs" not in required_init_params:
            for arg in ("quant_desc_input", "quant_desc_weight"):
                if arg in ignore_init_args:
                    continue
                kwargs.pop(arg, None)  # we ignore if not existing
        return _from_float(cls, float_instance, ignore_init_args, **kwargs)
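
As a hedged illustration of that default `from_float` behaviour: a quantized replacement whose __init__ arguments mirror attributes stored on the float module can rely entirely on the inherited implementation. MyBlock and MyQuantizedBlock below are hypothetical; the SGQuantMixin import path is taken from the source file path above:

import torch.nn as nn
from pytorch_quantization import nn as quant_nn
from super_gradients.training.utils.quantization.core import SGQuantMixin

class MyBlock(nn.Module):
    def __init__(self, in_channels: int, n_classes: int):
        super().__init__()
        self.in_channels = in_channels   # stored under the same names as the __init__ args
        self.n_classes = n_classes
        self.linear = nn.Linear(in_channels, n_classes)

class MyQuantizedBlock(SGQuantMixin):
    # The inherited from_float reads float_instance.in_channels / float_instance.n_classes and passes them here;
    # quant descriptors are dropped because this __init__ declares no **kwargs.
    def __init__(self, in_channels: int, n_classes: int):
        super().__init__()
        self.linear = quant_nn.QuantLinear(in_channels, n_classes)

q_block = MyQuantizedBlock.from_float(MyBlock(4, 10))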
      

This class wraps a float module instance, and defines that this instance will not be converted to quantized version

Example: self.my_block = SkipQuantization(MyBlock(4, n_classes))

Source code in V3_2/src/super_gradients/training/utils/quantization/core.py 92
class SkipQuantization(nn.Module):
    """
    This class wraps a float module instance, and defines that this instance will not be converted to quantized version

    Example:
        self.my_block = SkipQuantization(MyBlock(4, n_classes))
    """

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.float_module = module
        self.forward = module.forward
          

def export_quantized_module_to_onnx(
    model: torch.nn.Module, onnx_filename: str, input_shape: tuple, train: bool = False, to_cpu: bool = True, deepcopy_model=False, **kwargs
):
    """
    Method for exporting onnx after QAT.

    :param to_cpu: transfer model to CPU before converting to ONNX, dirty workaround when model's tensors are on different devices
    :param train: export model in training mode
    :param model: torch.nn.Module, model to export
    :param onnx_filename: str, target path for the onnx file
    :param input_shape: tuple, input shape (usually BCHW)
    :param deepcopy_model: Whether to export deepcopy(model). Necessary in case further training is performed and
     prep_model_for_conversion makes the network un-trainable (i.e RepVGG blocks).
    """
    if _imported_pytorch_quantization_failure is not None:
        raise _imported_pytorch_quantization_failure
    if deepcopy_model:
        model = deepcopy(model)
    use_fb_fake_quant_state = quant_nn.TensorQuantizer.use_fb_fake_quant
    quant_nn.TensorQuantizer.use_fb_fake_quant = True
    # Export ONNX for multiple batch sizes
    logger.info("Creating ONNX file: " + onnx_filename)
    if train:
        training_mode = TrainingMode.TRAINING
        model.train()
    else:
        training_mode = TrainingMode.EVAL
        model.eval()
        if hasattr(model, "prep_model_for_conversion"):
            model.prep_model_for_conversion(**kwargs)
    # workaround when model.prep_model_for_conversion does reparametrization
    # and tensors get scattered to different devices
    if to_cpu:
        export_model = model.cpu()
    else:
        export_model = model
    dummy_input = torch.randn(input_shape, device=next(model.parameters()).device)
    torch.onnx.export(export_model, dummy_input, onnx_filename, verbose=False, opset_version=13, do_constant_folding=True, training=training_mode)
    # Restore functions of quant_nn back as expected
    quant_nn.TensorQuantizer.use_fb_fake_quant = use_fb_fake_quant_state
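
A usage sketch, assuming `model` has already been quantized and calibrated as shown earlier (the file name and input shape are placeholders):

export_quantized_module_to_onnx(
    model=model,
    onnx_filename="model_int8.onnx",
    input_shape=(1, 3, 224, 224),   # BCHW
    train=False,
    to_cpu=True,
    deepcopy_model=True,            # keep the original model untouched for further training
)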
          Source code in V3_2/src/super_gradients/training/utils/quantization/selective_quantization_utils.py
326
class SelectiveQuantizer:
    """
    :param custom_mappings:                             custom mappings that extend the default mappings with extra behaviour
    :param default_per_channel_quant_weights:           whether quant module weights should be per channel (default=True)
    :param default_quant_modules_calibrator_weights:    default calibrator method for weights (default='max')
    :param default_quant_modules_calibrator_inputs:     default calibrator method for inputs (default='histogram')
    :param default_learn_amax:                          EXPERIMENTAL! whether quant modules should have learnable amax (default=False)
    """

    if _imported_pytorch_quantization_failure is not None:
        raise _imported_pytorch_quantization_failure
    mapping_instructions: Dict[Union[str, Type], QuantizedMetadata] = {
        **{
            float_type: QuantizedMetadata(
                float_source=float_type,
                quantized_target_class=quantized_target_class,
                action=QuantizedMetadata.ReplacementAction.REPLACE,
            )
            for (float_type, quantized_target_class) in [
                (nn.Conv1d, quant_nn.QuantConv1d),
                (nn.Conv2d, quant_nn.QuantConv2d),
                (nn.Conv3d, quant_nn.QuantConv3d),
                (nn.ConvTranspose1d, quant_nn.QuantConvTranspose1d),
                (nn.ConvTranspose2d, quant_nn.QuantConvTranspose2d),
                (nn.ConvTranspose3d, quant_nn.QuantConvTranspose3d),
                (nn.Linear, quant_nn.Linear),
                (nn.LSTM, quant_nn.LSTM),
                (nn.LSTMCell, quant_nn.LSTMCell),
                (nn.AvgPool1d, quant_nn.QuantAvgPool1d),
                (nn.AvgPool2d, quant_nn.QuantAvgPool2d),
                (nn.AvgPool3d, quant_nn.QuantAvgPool3d),
                (nn.AdaptiveAvgPool1d, quant_nn.QuantAdaptiveAvgPool1d),
                (nn.AdaptiveAvgPool2d, quant_nn.QuantAdaptiveAvgPool2d),
                (nn.AdaptiveAvgPool3d, quant_nn.QuantAdaptiveAvgPool3d),
            ]
        },
        SkipQuantization: QuantizedMetadata(float_source=SkipQuantization, quantized_target_class=None, action=QuantizedMetadata.ReplacementAction.UNWRAP),
    }  # DEFAULT MAPPING INSTRUCTIONS
    def __init__(
        self,
        custom_mappings: dict = None,
        default_quant_modules_calibrator_weights: str = "max",
        default_quant_modules_calibrator_inputs: str = "histogram",
        default_per_channel_quant_weights: bool = True,
        default_learn_amax: bool = False,
        verbose: bool = True,
    ) -> None:
        super().__init__()
        self.default_quant_modules_calibrator_weights = default_quant_modules_calibrator_weights
        self.default_quant_modules_calibrator_inputs = default_quant_modules_calibrator_inputs
        self.default_per_channel_quant_weights = default_per_channel_quant_weights
        self.default_learn_amax = default_learn_amax
        self.verbose = verbose
        self.mapping_instructions = self.mapping_instructions.copy()
        if custom_mappings is not None:
            self.mapping_instructions.update(custom_mappings)  # OVERRIDE DEFAULT WITH CUSTOM. CUSTOM IS PRIORITIZED
    def _get_default_quant_descriptor(self, for_weights=False):
        methods = {"percentile": "histogram", "mse": "histogram", "entropy": "histogram", "histogram": "histogram", "max": "max"}
        if for_weights:
            axis = 0 if self.default_per_channel_quant_weights else None
            learn_amax = self.default_learn_amax
            if self.default_learn_amax and self.default_per_channel_quant_weights:
                logger.error("Learnable amax is suported only for per-tensor quantization. Disabling it for weights quantization!")
                learn_amax = False
            return QuantDescriptor(calib_method=methods[self.default_quant_modules_calibrator_weights], axis=axis, learn_amax=learn_amax)
        else:
            # activations stay per-tensor by default
            return QuantDescriptor(calib_method=methods[self.default_quant_modules_calibrator_inputs], learn_amax=self.default_learn_amax)
    def register_skip_quantization(self, *, layer_names: Optional[Set[str]] = None):
        if layer_names is not None:
            self.mapping_instructions.update(
                {
                    name: QuantizedMetadata(float_source=name, quantized_target_class=None, action=QuantizedMetadata.ReplacementAction.SKIP)
                    for name in layer_names
                }
            )

    def register_quantization_mapping(
        self, *, layer_names: Set[str], quantized_target_class: Type[SGQuantMixin], input_quant_descriptor=None, weights_quant_descriptor=None
    ):
        self.mapping_instructions.update(
            {
                name: QuantizedMetadata(
                    float_source=name,
                    quantized_target_class=quantized_target_class,
                    action=QuantizedMetadata.ReplacementAction.REPLACE,
                    input_quant_descriptor=input_quant_descriptor,
                    weights_quant_descriptor=weights_quant_descriptor,
                )
                for name in layer_names
            }
        )
    def _preprocess_skips_and_custom_mappings(self, module: nn.Module, nesting: Tuple[str, ...] = ()):
        """
        This pass is done to extract layer name and mapping instructions, so that we regard to per-layer processing.
        Relevant layer-specific mapping instructions are either `SkipQuantization` or `QuantizedMapping`, which are then
        being added to the mappings
        """
        mapping_instructions = dict()
        for name, child_module in module.named_children():
            nested_name = ".".join(nesting + (name,))
            if isinstance(child_module, SkipQuantization):
                mapping_instructions[nested_name] = QuantizedMetadata(
                    float_source=nested_name, quantized_target_class=None, action=QuantizedMetadata.ReplacementAction.UNWRAP
                )
            if isinstance(child_module, QuantizedMapping):
                mapping_instructions[nested_name] = QuantizedMetadata(
                    float_source=nested_name,
                    quantized_target_class=child_module.quantized_target_class,
                    input_quant_descriptor=child_module.input_quant_descriptor,
                    weights_quant_descriptor=child_module.weights_quant_descriptor,
                    action=child_module.action,
                )
            if isinstance(child_module, nn.Module):  # recursive call
                mapping_instructions.update(self._preprocess_skips_and_custom_mappings(child_module, nesting + (name,)))
        return mapping_instructions
    def _instantiate_quantized_from_float(self, float_module, metadata, preserve_state_dict):
        base_classes = (QuantMixin, QuantInputMixin, SGQuantMixin)
        if not issubclass(metadata.quantized_target_class, base_classes):
            raise AssertionError(
                f"Quantization suite for {type(float_module).__name__} is invalid. "
                f"{metadata.quantized_target_class.__name__} must inherit one of "
                f"{', '.join(map(lambda _: _.__name__, base_classes))}"
        # USE PROVIDED QUANT DESCRIPTORS, OR DEFAULT IF NONE PROVIDED
        quant_descriptors = dict()
        if issubclass(metadata.quantized_target_class, (SGQuantMixin, QuantMixin, QuantInputMixin)):
            quant_descriptors = {"quant_desc_input": metadata.input_quant_descriptor or self._get_default_quant_descriptor(for_weights=False)}
        if issubclass(metadata.quantized_target_class, (SGQuantMixin, QuantMixin)):
            quant_descriptors.update({"quant_desc_weight": metadata.weights_quant_descriptor or self._get_default_quant_descriptor(for_weights=True)})
        if not hasattr(metadata.quantized_target_class, "from_float"):
            assert isinstance(metadata.quantized_target_class, SGQuantMixin), (
                f"{metadata.quantized_target_class.__name__} must inherit from " f"{SGQuantMixin.__name__}, so that it would include `from_float` class method"
        q_instance = metadata.quantized_target_class.from_float(float_module, **quant_descriptors)
        # MOVE TENSORS TO ORIGINAL DEVICE
        if len(list(float_module.parameters(recurse=False))) > 0:
            q_instance = q_instance.to(next(float_module.parameters(recurse=False)).device)
        elif len(list(float_module.buffers(recurse=False))):
            q_instance = q_instance.to(next(float_module.buffers(recurse=False)).device)
        # COPY STATE DICT IF NEEDED
        if preserve_state_dict:
            # quant state dict may have additional parameters for Clip and strict loading will fail
            # if we find at least one Clip module in q_instance, disable strict loading and hope for the best
            strict_load = True
            for k in q_instance.state_dict().keys():
                if "clip.clip_value_max" in k or "clip.clip_value_min" in k:
                    strict_load = False
                    logger.debug(
                        "Instantiating quant module in non-strict mode leaving Clip parameters non-initilaized. Use QuantizationCalibrator to initialize them."
                    break
            q_instance.load_state_dict(float_module.state_dict(), strict=strict_load)
        return q_instance
    def _maybe_quantize_one_layer(
        self,
        module: nn.Module,
        child_name: str,
        nesting: Tuple[str, ...],
        child_module: nn.Module,
        mapping_instructions: Dict[Union[str, Type], QuantizedMetadata],
        preserve_state_dict: bool,
    ) -> bool:
        """
        Does the heavy lifting of (maybe) quantizing a layer: creates a quantized instance based on a float instance,
        and replaces it in the "parent" module

        :param module:                  the module we'd like to quantize a specific layer in
        :param child_name:              the attribute name of the layer in the module
        :param nesting:                 the current nesting we're in. Needed to find the appropriate key in the mappings
        :param child_module:            the instance of the float module we'd like to quantize
        :param mapping_instructions:    mapping instructions: how to quantize
        :param preserve_state_dict:     whether to copy the state dict from the float instance to the quantized instance
        :return: a boolean indicating whether we found a match and should not continue recursively
        """
        # if we don't have any instruction for the specific layer or the specific type - we continue
        # NOTE! IT IS IMPORTANT TO FIRST PROCESS THE NAME AND ONLY THEN THE TYPE
        if _imported_pytorch_quantization_failure is not None:
            raise _imported_pytorch_quantization_failure
        for candidate_key in (".".join(nesting + (child_name,)), type(child_module)):
            if candidate_key not in mapping_instructions:
                continue
            metadata: QuantizedMetadata = mapping_instructions[candidate_key]
            if metadata.action == QuantizedMetadata.ReplacementAction.SKIP:
                return True
            elif metadata.action == QuantizedMetadata.ReplacementAction.UNWRAP:
                assert isinstance(child_module, SkipQuantization)
                setattr(module, child_name, child_module.float_module)
                return True
            elif metadata.action in (
                QuantizedMetadata.ReplacementAction.REPLACE,
                QuantizedMetadata.ReplacementAction.REPLACE_AND_RECURE,
                QuantizedMetadata.ReplacementAction.RECURE_AND_REPLACE,
            ):
                if isinstance(child_module, QuantizedMapping):  # UNWRAP MAPPING
                    child_module = child_module.float_module
                q_instance: nn.Module = self._instantiate_quantized_from_float(
                    float_module=child_module, metadata=metadata, preserve_state_dict=preserve_state_dict
                )

                # ACTUAL REPLACEMENT
                def replace():
                    setattr(module, child_name, q_instance)

                def recurse_quantize():
                    self._quantize_module_aux(
                        module=getattr(module, child_name),
                        mapping_instructions=mapping_instructions,
                        nesting=nesting + (child_name,),
                        preserve_state_dict=preserve_state_dict,
                    )

                if metadata.action == QuantizedMetadata.ReplacementAction.REPLACE:
                    replace()
                elif metadata.action == QuantizedMetadata.ReplacementAction.REPLACE_AND_RECURE:
                    replace()
                    recurse_quantize()
                elif metadata.action == QuantizedMetadata.ReplacementAction.RECURE_AND_REPLACE:
                    recurse_quantize()
                    replace()
                return True
            else:
                raise NotImplementedError
        return False
    def quantize_module(self, module: nn.Module, *, preserve_state_dict=True):
        per_layer_mappings = self._preprocess_skips_and_custom_mappings(module)
        mapping_instructions = {
            **per_layer_mappings,
            **self.mapping_instructions,
        }  # we first regard the per layer mappings, and then override with the custom mappings in case there is overlap
        logging_level = logging.getLogger("absl").getEffectiveLevel()
        if not self.verbose:  # suppress pytorch-quantization spam
            logging.getLogger("absl").setLevel("ERROR")
        device = next(module.parameters()).device
        self._quantize_module_aux(mapping_instructions=mapping_instructions, module=module, nesting=(), preserve_state_dict=preserve_state_dict)
        module.to(device)
        logging.getLogger("absl").setLevel(logging_level)
    def _quantize_module_aux(self, mapping_instructions, module, nesting, preserve_state_dict):
        for name, child_module in module.named_children():
            found = self._maybe_quantize_one_layer(module, name, nesting, child_module, mapping_instructions, preserve_state_dict)
            # RECURSIVE CALL, to support module_list, sequential, custom (nested) modules
            if not found and isinstance(child_module, nn.Module):
                self._quantize_module_aux(mapping_instructions, child_module, nesting + (name,), preserve_state_dict)
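The sketch below shows the intended entry point: constructing a SelectiveQuantizer and letting it rewrite a model in place. This is a minimal, hedged example, assuming pytorch-quantization is installed and that the default constructor arguments of SelectiveQuantizer are acceptable; it is not an official tutorial.

import torch.nn as nn
from super_gradients.training.utils.quantization.selective_quantization_utils import SelectiveQuantizer

model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU(), nn.Conv2d(8, 8, kernel_size=3))

q_util = SelectiveQuantizer()                              # uses the class-level mapping_instructions
q_util.quantize_module(model, preserve_state_dict=True)    # replaces supported layers in place
print(model)                                               # supported layers should now appear as their quantized counterparts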
        Source code in V3_2/src/super_gradients/training/utils/quantization/selective_quantization_utils.py
def register_quantized_module(
    float_source: Union[str, Type[nn.Module]],
    action: QuantizedMetadata.ReplacementAction = QuantizedMetadata.ReplacementAction.REPLACE,
    input_quant_descriptor: Optional[QuantDescriptor] = None,
    weights_quant_descriptor: Optional[QuantDescriptor] = None,
) -> Callable:
    Decorator used to register a Quantized module as a quantized version for Float module
    :param action:                      action to perform on the float_source
    :param float_source:                the float module type that is being registered
    :param input_quant_descriptor:      the input quantization descriptor
    :param weights_quant_descriptor:    the weight quantization descriptor
    def decorator(quant_module: Type[SGQuantMixin]) -> Type[SGQuantMixin]:
        if float_source in SelectiveQuantizer.mapping_instructions:
            metadata = SelectiveQuantizer.mapping_instructions[float_source]
            raise ValueError(f"`{float_source}` is already registered with following metadata {metadata}")
        SelectiveQuantizer.mapping_instructions.update(
            {
                float_source: QuantizedMetadata(
                    float_source=float_source,
                    quantized_target_class=quant_module,
                    input_quant_descriptor=input_quant_descriptor,
                    weights_quant_descriptor=weights_quant_descriptor,
                    action=action,
                )
            }
        )
        return quant_module  # this is required since the decorator assigns the result to the `quant_module`
    return decorator
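As a usage illustration, the following hedged sketch registers a quantized counterpart for a custom float block. `MyFloatBlock` and `MyQuantizedBlock` are hypothetical classes, and the imports assume the module layout cited in the source links above (pytorch-quantization must be installed).

import torch.nn as nn
from super_gradients.training.utils.quantization.core import SGQuantMixin
from super_gradients.training.utils.quantization.selective_quantization_utils import register_quantized_module

class MyFloatBlock(nn.Module):
    def __init__(self, channels: int):
        super().__init__()
        self.channels = channels
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)

@register_quantized_module(float_source=MyFloatBlock)
class MyQuantizedBlock(SGQuantMixin, nn.Module):
    # Hypothetical quantized counterpart; once registered, SelectiveQuantizer will replace
    # every MyFloatBlock it encounters with this class (REPLACE is the default action).
    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)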
      

Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

Intended usage of this block is the following:

class ResNetBlock(nn.Module):
    def __init__(self, ..., drop_path_rate: float):
        self.drop_path = DropPath(drop_path_rate)

    def forward(self, x):
        return x + self.drop_path(self.conv_bn_act(x))

Code taken from TIMM (https://github.com/rwightman/pytorch-image-models) Apache License 2.0

Source code in V3_2/src/super_gradients/training/utils/regularization_utils.py 52
class DropPath(nn.Module):
    Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
    Intended usage of this block is the following:
    >>> class ResNetBlock(nn.Module):
    >>>   def __init__(self, ..., drop_path_rate:float):
    >>>     self.drop_path = DropPath(drop_path_rate)
    >>>   def forward(self, x):
    >>>     return x + self.drop_path(self.conv_bn_act(x))
    Code taken from TIMM (https://github.com/rwightman/pytorch-image-models)
    Apache License 2.0
    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        :param drop_prob: Probability of zeroing out individual vector (channel dimension) of each feature map
        :param scale_by_keep: Whether to scale the output by the keep probability. Enable by default and helps to
                              keep output mean & std in the same range as w/o drop path.
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep
    def forward(self, x):
        if self.drop_prob == 0.0 or not self.training:
            return x
        return drop_path(x, self.drop_prob, self.scale_by_keep)
    def extra_repr(self):
        return f"drop_prob={round(self.drop_prob,3):0.3f}"
def drop_path(x, drop_prob: float = 0.0, scale_by_keep: bool = True):
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor
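A minimal sketch of the behavior, assuming DropPath is imported from the module cited above:

import torch

x = torch.ones(4, 8, 16, 16)        # (batch, channels, H, W)
block = DropPath(drop_prob=0.25)

block.train()
y = block(x)                        # whole samples are randomly zeroed, survivors are scaled by 1 / 0.75

block.eval()
assert torch.equal(block(x), x)     # identity at inference time (and whenever drop_prob == 0)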
class BinarySegmentationVisualization:
    @staticmethod
    def _visualize_image(image_np: np.ndarray, pred_mask: torch.Tensor, target_mask: torch.Tensor, image_scale: float, checkpoint_dir: str, image_name: str):
        pred_mask = pred_mask.copy()
        image_np = torch.from_numpy(np.moveaxis(image_np, -1, 0).astype(np.uint8))
        pred_mask = pred_mask[np.newaxis, :, :] > 0.5
        target_mask = target_mask[np.newaxis, :, :].astype(bool)
        tp_mask = np.logical_and(pred_mask, target_mask)
        fp_mask = np.logical_and(pred_mask, np.logical_not(target_mask))
        fn_mask = np.logical_and(np.logical_not(pred_mask), target_mask)
        overlay = torch.from_numpy(np.concatenate([tp_mask, fp_mask, fn_mask]))
        # SWITCH BETWEEN BLUE AND RED IF WE SAVE THE IMAGE ON THE DISC AS OTHERWISE WE CHANGE CHANNEL ORDERING
        colors = ["green", "red", "blue"]
        res_image = draw_segmentation_masks(image_np, overlay, colors=colors).detach().numpy()
        res_image = np.concatenate([res_image[ch, :, :, np.newaxis] for ch in range(3)], 2)
        res_image = cv2.resize(res_image.astype(np.uint8), (0, 0), fx=image_scale, fy=image_scale, interpolation=cv2.INTER_NEAREST)
        if checkpoint_dir is None:
            return res_image
        else:
            cv2.imwrite(os.path.join(checkpoint_dir, str(image_name) + ".jpg"), res_image)
    @staticmethod
    def visualize_batch(
        image_tensor: torch.Tensor,
        pred_mask: torch.Tensor,
        target_mask: torch.Tensor,
        batch_name: Union[int, str],
        checkpoint_dir: str = None,
        undo_preprocessing_func: Callable[[torch.Tensor], np.ndarray] = reverse_imagenet_preprocessing,
        image_scale: float = 1.0,
    ):
        A helper function to visualize binary segmentation masks predicted by a network:
        saves images into a given path with a name that is {batch_name}_{image_idx_in_the_batch}.jpg, one batch per call.
        Colors are generated on the fly: uniformly sampled from color wheel to support all given classes.
        :param image_tensor:            rgb images, (B, H, W, 3)
        :param pred_mask:               predicted masks (raw network output) for each image in the batch
        :param target_mask:             ground-truth binary mask for each image in the batch
        :param batch_name:              id of the current batch to use for image naming
        :param checkpoint_dir:          a path where images with masks will be saved. If None, the result images will
                                        be returned as a list of numpy image arrays
        :param undo_preprocessing_func: a function to convert preprocessed images tensor into a batch of cv2-like images
        :param image_scale:             scale factor for output image
        image_np = undo_preprocessing_func(image_tensor.detach())
        pred_mask = torch.sigmoid(pred_mask[:, 0, :, :])  # comment out
        out_images = []
        for i in range(image_np.shape[0]):
            preds = pred_mask[i].detach().cpu().numpy()
            targets = target_mask[i].detach().cpu().numpy()
            image_name = "_".join([str(batch_name), str(i)])
            res_image = BinarySegmentationVisualization._visualize_image(image_np[i], preds, targets, image_scale, checkpoint_dir, image_name)
            if res_image is not None:
                out_images.append(res_image)
        return out_images
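A hedged usage sketch (shapes inferred from the function body above, not an official example); with checkpoint_dir=None the drawn images are returned instead of written to disk.

import torch

images = torch.rand(2, 3, 64, 64)                    # preprocessed batch, as fed to the network
pred_logits = torch.randn(2, 1, 64, 64)              # raw network output; sigmoid is applied internally
targets = (torch.rand(2, 64, 64) > 0.5).float()      # binary ground-truth masks

drawn = BinarySegmentationVisualization.visualize_batch(
    image_tensor=images,
    pred_mask=pred_logits,
    target_mask=targets,
    batch_name=0,
    checkpoint_dir=None,
)
print(len(drawn), drawn[0].shape)                    # 2 images with TP/FP/FN overlays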
def one_hot_to_binary_edge(x: torch.Tensor, kernel_size: int, flatten_channels: bool = True) -> torch.Tensor:
    Utils function to create edge feature maps.
    :param x: input tensor, must be one_hot tensor with shape [B, C, H, W]
    :param kernel_size: kernel size of dilation erosion convolutions. The resulting edge width depends on this argument
        as follows: `edge_width = kernel - 1`
    :param flatten_channels: Whether to apply logical_or across channels dimension, if at least one pixel class is
        considered as edge pixel flatten value is 1. If set as `False` the output tensor shape is [B, C, H, W], else
        [B, 1, H, W]. Default is `True`.
    :return: one_hot edge torch.Tensor.
    if kernel_size < 0 or kernel_size % 2 == 0:
        raise ValueError(f"kernel size must be an odd positive values, such as [1, 3, 5, ..], found: {kernel_size}")
    _kernel = torch.ones(x.size(1), 1, kernel_size, kernel_size, dtype=torch.float32, device=x.device)
    padding = (kernel_size - 1) // 2
    # Use replicate padding to prevent class shifting and edge formation at the image boundaries.
    padded_x = F.pad(x.float(), mode="replicate", pad=[padding] * 4)
    # The binary edge feature map is created by subtracting the eroded features from the dilated features.
    # First the positive one value masks are expanded (dilation) by applying a sliding window filter of one values.
    # The resulted output is then clamped to binary format to [0, 1], this way the one-hot boundaries are expanded by
    # (kernel_size - 1) / 2.
    dilation = torch.clamp(F.conv2d(padded_x, _kernel, groups=x.size(1)), 0, 1)
    # Similar to dilation, erosion (can be seen as inverse of dilation) is applied to contract the one-hot features by
    # applying a dilation operation on the inverse of the one-hot features.
    erosion = 1 - torch.clamp(F.conv2d(1 - padded_x, _kernel, groups=x.size(1)), 0, 1)
    # Finally, the edge features are the result of subtracting the erosion from the dilation.
    # i.e for a simple 1D one-hot input:    [0, 0, 0, 1, 1, 1, 0, 0, 0], using sliding kernel with size 3: [1, 1, 1]
    # Dilated features:                     [0, 0, 1, 1, 1, 1, 1, 0, 0]
    # Eroded features:                      [0, 0, 0, 0, 1, 0, 0, 0, 0]
    # Edge features: dilation - erosion:    [0, 0, 1, 1, 0, 1, 1, 0, 0]
    edge = dilation - erosion
    if flatten_channels:
        # use max operator across channels. Equivalent to logical or for input with binary values [0, 1].
        edge = edge.max(dim=1, keepdim=True)[0]
    return edge
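A tiny worked example, mirroring the 1D illustration in the comments above (lifted to a [B, C, H, W] tensor):

import torch

one_hot = torch.zeros(1, 1, 1, 9)
one_hot[..., 3:6] = 1.0                      # [0, 0, 0, 1, 1, 1, 0, 0, 0]

edge = one_hot_to_binary_edge(one_hot, kernel_size=3)
print(edge[0, 0, 0].tolist())                # [0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0]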
def reverse_imagenet_preprocessing(im_tensor: torch.Tensor) -> np.ndarray:
    :param im_tensor: images in a batch after preprocessing for inference, RGB, (B, C, H, W)
    :return:          images in a batch in cv2 format, BGR, (B, H, W, C)
    im_np = im_tensor.cpu().numpy()
    im_np = im_np[:, ::-1, :, :].transpose(0, 2, 3, 1)
    im_np *= np.array([[[0.229, 0.224, 0.225][::-1]]])
    im_np += np.array([[[0.485, 0.456, 0.406][::-1]]])
    im_np *= 255.0
    return np.ascontiguousarray(im_np, dtype=np.uint8)
          

def target_to_binary_edge(target: torch.Tensor, num_classes: int, kernel_size: int, ignore_index: int = None, flatten_channels: bool = True) -> torch.Tensor:
    Utils function to create edge feature maps from target.
    :param target: Class labels long tensor, with shape [N, H, W]
    :param num_classes: num of classes in datasets excluding ignore label, this is the output channels of the one hot
        result.
    :param kernel_size: kernel size of dilation erosion convolutions. The resulting edge width depends on this argument
        as follows: `edge_width = kernel - 1`
    :param flatten_channels: Whether to apply logical or across channels dimension, if at least one pixel class is
        considered as edge pixel flatten value is 1. If set as `False` the output tensor shape is [B, C, H, W], else
        [B, 1, H, W]. Default is `True`.
    :return: one_hot edge torch.Tensor.
    one_hot = to_one_hot(target, num_classes=num_classes, ignore_index=ignore_index)
    return one_hot_to_binary_edge(one_hot, kernel_size=kernel_size, flatten_channels=flatten_channels)
          

def to_one_hot(target: torch.Tensor, num_classes: int, ignore_index: int = None):
    Target label to one_hot tensor. labels and ignore_index must be consecutive numbers.
    :param target: Class labels long tensor, with shape [N, H, W]
    :param num_classes: num of classes in datasets excluding ignore label, this is the output channels of the one hot
        result.
    :return: one hot tensor with shape [N, num_classes, H, W]
    num_classes = num_classes if ignore_index is None else num_classes + 1
    one_hot = F.one_hot(target, num_classes).permute((0, 3, 1, 2))
    if ignore_index is not None:
        # remove ignore_index channel
        one_hot = torch.cat([one_hot[:, :ignore_index], one_hot[:, ignore_index + 1 :]], dim=1)
    return one_hot
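Putting the two helpers above together, a short hedged sketch for integer label maps:

import torch

target = torch.randint(0, 3, (2, 32, 32))                             # [N, H, W] labels for 3 classes

one_hot = to_one_hot(target, num_classes=3)                           # [N, 3, H, W]
edges = target_to_binary_edge(target, num_classes=3, kernel_size=3)   # [N, 1, H, W] since flatten_channels=True
print(one_hot.shape, edges.shape)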
      

Type of improvement compared to previous value, i.e. if the value is better, worse or the same.

Difference with "increase": If a loss goes from 1 to 0.5, the value is smaller (decreased), but the result is better (improvement). For accuracy from 1 to 0.5, the value is smaller, but this time the result decreased, because greater is better.

Source code in V3_2/src/super_gradients/training/utils/sg_trainer_utils.py 75
class ImprovementType(Enum):
    """Type of improvement compared to previous value, i.e. if the value is better, worse or the same.
    Difference with "increase":
        If a loss goes from 1 to 0.5, the value is smaller (decreased), but the result is better (improvement).
        For accuracy from 1 to 0.5, the value is smaller, but this time the result decreased, because greater is better.
    IS_BETTER = "better"
    IS_WORSE = "worse"
    IS_SAME = "same"
    NONE = "none"
    def to_color(self) -> Union[str, None]:
        """Get the color representing the current improvement type"""
        if self == ImprovementType.IS_SAME:
            return "white"
        elif self == ImprovementType.IS_BETTER:
            return "green"
        elif self == ImprovementType.IS_WORSE:
            return "red"
        else:
            return None
      

Type of increase compared to previous value, i.e. if the value is greater, smaller or the same.

Difference with "improvement": If a loss goes from 1 to 0.5, the value is smaller (decreased), but the result is better (improvement). For accuracy from 1 to 0.5, the value is smaller, but this time the result decreased, because greater is better.

Source code in V3_2/src/super_gradients/training/utils/sg_trainer_utils.py 50
class IncreaseType(Enum):
    """Type of increase compared to previous value, i.e. if the value is greater, smaller or the same.
    Difference with "improvement":
        If a loss goes from 1 to 0.5, the value is smaller (decreased), but the result is better (improvement).
        For accuracy from 1 to 0.5, the value is smaller, but this time the result decreased, because greater is better.
    NONE = "none"
    IS_GREATER = "greater"
    IS_SMALLER = "smaller"
    IS_EQUAL = "equal"
    def to_symbol(self) -> str:
        """Get the symbol representing the current increase type"""
        if self == IncreaseType.NONE:
            return ""
        elif self == IncreaseType.IS_GREATER:
            return "↗"
        elif self == IncreaseType.IS_SMALLER:
            return "↘"
        else:
            return "="
      

Store a value and some indicators relative to its past iterations.

The value can be a metric/loss, and the iteration can be epochs/batch.

@dataclass
class MonitoredValue:
    """Store a value and some indicators relative to its past iterations.
    The value can be a metric/loss, and the iteration can be epochs/batch.
    :param name:                    Name of the metric
    :param greater_is_better:       True, a greater value is considered better.
                                      ex: (greater_is_better=True) For Accuracy 1 is greater and therefore better than 0.4
                                      ex: (greater_is_better=False) For Loss 1 is greater and therefore worse than 0.4
                                    None when unknown
    :param current:                 Current value of the metric
    :param previous:                Value of the metric in previous iteration
    :param best:                    Value of the metric in best iteration (best according to greater_is_better)
    :param change_from_previous:    Change compared to previous iteration value
    :param change_from_best:        Change compared to best iteration value
    name: str
    greater_is_better: Optional[bool] = None
    current: Optional[float] = None
    previous: Optional[float] = None
    best: Optional[float] = None
    change_from_previous: Optional[float] = None
    change_from_best: Optional[float] = None
    @property
    def has_increased_from_previous(self) -> IncreaseType:
        """Type of increase compared to previous value, i.e. if the value is greater, smaller or the same."""
        return self._get_increase_type(self.change_from_previous)
    @property
    def has_improved_from_previous(self) -> ImprovementType:
        """Type of improvement compared to previous value, i.e. if the value is better, worse or the same."""
        return self._get_improvement_type(delta=self.change_from_previous)
    @property
    def has_increased_from_best(self) -> IncreaseType:
        """Type of increase compared to best value, i.e. if the value is greater, smaller or the same."""
        return self._get_increase_type(self.change_from_best)
    @property
    def has_improved_from_best(self) -> ImprovementType:
        """Type of improvement compared to best value, i.e. if the value is better, worse or the same."""
        return self._get_improvement_type(delta=self.change_from_best)
    def _get_increase_type(self, delta: float) -> IncreaseType:
        """Type of increase, i.e. if the value is greater, smaller or the same."""
        if self.change_from_best is None:
            return IncreaseType.NONE
        if delta > 0:
            return IncreaseType.IS_GREATER
        elif delta < 0:
            return IncreaseType.IS_SMALLER
        else:
            return IncreaseType.IS_EQUAL
    def _get_improvement_type(self, delta: float) -> ImprovementType:
        """Type of improvement, i.e. if value is better, worse or the same."""
        if self.greater_is_better is None or self.change_from_best is None:
            return ImprovementType.NONE
        has_increased, has_decreased = delta > 0, delta < 0
        if has_increased and self.greater_is_better or has_decreased and not self.greater_is_better:
            return ImprovementType.IS_BETTER
        elif has_increased and not self.greater_is_better or has_decreased and self.greater_is_better:
            return ImprovementType.IS_WORSE
        else:
            return ImprovementType.IS_SAME
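A small illustration of how direction (increase) and improvement are decoupled; the numbers are made up:

loss = MonitoredValue(
    name="loss",
    greater_is_better=False,
    current=0.5,
    previous=1.0,
    best=1.0,
    change_from_previous=-0.5,
    change_from_best=-0.5,
)
print(loss.has_increased_from_previous.to_symbol())   # "↘"     - the value went down
print(loss.has_improved_from_previous.to_color())     # "green" - but a lower loss is an improvement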
def add_log_to_file(filename, results_titles_list, results_values_list, epoch, max_epochs):
    """Add a message to the log file"""
    # Note: opening and closing the file every time is inefficient; it is done for experimental purposes
    with open(filename, "a") as f:
        f.write("\nEpoch (%d/%d)  - " % (epoch, max_epochs))
        for result_title, result_value in zip(results_titles_list, results_values_list):
            if isinstance(result_value, torch.Tensor):
                result_value = result_value.item()
            f.write(result_title + ": " + str(result_value) + "\t")
def display_epoch_summary(epoch: int, n_digits: int, monitored_values_dict: Dict[str, Dict[str, MonitoredValue]]) -> None:
    """Display a summary of loss/metric of interest, for a given epoch.
    :param epoch: the number of epoch.
    :param n_digits: number of digits to display on screen for float values
    :param monitored_values_dict: Dict of Dict. The first one represents the split, and the second one a loss/metric.
    def _format_to_str(val: Optional[float]) -> str:
        return str(round(val, n_digits)) if val is not None else "None"
    def _generate_tree(value_name: str, monitored_value: MonitoredValue) -> Tree:
        """Generate a tree that represents the stats of a given loss/metric."""
        current = _format_to_str(monitored_value.current)
        root_id = str(hash(f"{value_name} = {current}")) + str(random.random())
        tree = Tree()
        tree.create_node(tag=f"{value_name.capitalize()} = {current}", identifier=root_id)
        if monitored_value.previous is not None:
            previous = _format_to_str(monitored_value.previous)
            best = _format_to_str(monitored_value.best)
            change_from_previous = _format_to_str(monitored_value.change_from_previous)
            change_from_best = _format_to_str(monitored_value.change_from_best)
            diff_with_prev_colored = colored(
                text=f"{monitored_value.has_increased_from_previous.to_symbol()} {change_from_previous}",
                color=monitored_value.has_improved_from_previous.to_color(),
            )
            diff_with_best_colored = colored(
                text=f"{monitored_value.has_increased_from_best.to_symbol()} {change_from_best}", color=monitored_value.has_improved_from_best.to_color()
            )
            tree.create_node(tag=f"Epoch N-1      = {previous:6} ({diff_with_prev_colored:8})", identifier=f"0_previous_{root_id}", parent=root_id)
            tree.create_node(tag=f"Best until now = {best:6} ({diff_with_best_colored:8})", identifier=f"1_best_{root_id}", parent=root_id)
        return tree
    summary_tree = Tree()
    summary_tree.create_node(f"SUMMARY OF EPOCH {epoch}", "Summary")
    for split, monitored_values in monitored_values_dict.items():
        if len(monitored_values):
            split_tree = Tree()
            split_tree.create_node(split, split)
            for name, value in monitored_values.items():
                split_tree.paste(split, new_tree=_generate_tree(name, monitored_value=value))
            summary_tree.paste("Summary", split_tree)
    print("===========================================================")
    summary_tree.show(key=False)
    print("===========================================================")
def get_callable_param_names(obj: callable) -> Tuple[str]:
    """Get the param names of a given callable (function, class, ...)
    :param obj: Object to inspect
    :return: Param names of that object
    return tuple(inspect.signature(obj).parameters)
def init_summary_writer(tb_dir, checkpoint_loaded, user_prompt=False):
    """Remove previous tensorboard files from directory and launch a tensor board process"""
    # If the training is from scratch, Walk through destination folder and delete existing tensorboard logs
    user = ""
    if not checkpoint_loaded:
        for filename in os.listdir(tb_dir):
            if "events" in filename:
                if not user_prompt:
                    logger.debug('"{}" will not be deleted'.format(filename))
                    continue
                while True:
                    # Verify with user before deleting old tensorboard files
                    user = (
                        input('\nOLDER TENSORBOARD FILES EXISTS IN EXPERIMENT FOLDER:\n"{}"\n' "DO YOU WANT TO DELETE THEM? [y/n]".format(filename))
                        if (user != "n" or user != "y")
                        else user
                    )
                    if user == "y":
                        os.remove("{}/{}".format(tb_dir, filename))
                        print("DELETED: {}!".format(filename))
                        break
                    elif user == "n":
                        print('"{}" will not be deleted'.format(filename))
                        break
                    print("Unknown answer...")
    # Launch a tensorboard process
    return SummaryWriter(tb_dir)
      

launch_tensorboard_process - Default behavior is to scan all free ports from 6006-6016 and try using them unless a port is defined by the user. Returns a tuple of (tensorboard process, port).

Source code in V3_2/src/super_gradients/training/utils/sg_trainer_utils.py 308
def launch_tensorboard_process(checkpoints_dir_path: str, sleep_postpone: bool = True, port: int = None) -> Tuple[Process, int]:
    launch_tensorboard_process - Default behavior is to scan all free ports from 6006-6016 and try using them
                                 unless port is defined by the user
        :param checkpoints_dir_path:
        :param sleep_postpone:
        :param port:
        :return: tuple of tb process, port
    logdir_path = str(Path(checkpoints_dir_path).parent.absolute())
    tb_cmd = "tensorboard --logdir=" + logdir_path + " --bind_all"
    if port is not None:
        tb_ports = [port]
    else:
        tb_ports = range(6006, 6016)
    for tb_port in tb_ports:
        if not try_port(tb_port):
            continue
        else:
            print("Starting Tensor-Board process on port: " + str(tb_port))
            tensor_board_process = Process(target=os.system, args=([tb_cmd + " --port=" + str(tb_port)]))
            tensor_board_process.daemon = True
            tensor_board_process.start()
            # LET THE TENSORBOARD PROCESS START
            if sleep_postpone:
                time.sleep(3)
            return tensor_board_process, tb_port
    # RETURNING IRRELEVANT VALUES
    print("Failed to initialize Tensor-Board process on port: " + ", ".join(map(str, tb_ports)))
    return None, -1
def log_main_training_params(
    multi_gpu: MultiGPUMode, num_gpus: int, batch_size: int, batch_accumulate: int, train_dataset_length: int, train_dataloader_len: int
):
    """Log training parameters"""
    msg = (
        "TRAINING PARAMETERS:\n"
        f"    - Mode:                         {multi_gpu.name if multi_gpu else 'Single GPU'}\n"
        f"    - Number of GPUs:               {num_gpus if 'cuda' in device_config.device  else 0:<10} ({torch.cuda.device_count()} available on the machine)\n"
        f"    - Dataset size:                 {train_dataset_length:<10} (len(train_set))\n"
        f"    - Batch size per GPU:           {batch_size:<10} (batch_size)\n"
        f"    - Batch Accumulate:             {batch_accumulate:<10} (batch_accumulate)\n"
        f"    - Total batch size:             {num_gpus * batch_size:<10} (num_gpus * batch_size)\n"
        f"    - Effective Batch size:         {num_gpus * batch_size * batch_accumulate:<10} (num_gpus * batch_size * batch_accumulate)\n"
        f"    - Iterations per epoch:         {int(train_dataloader_len):<10} (len(train_loader))\n"
        f"    - Gradient updates per epoch:   {int(train_dataloader_len / batch_accumulate):<10} (len(train_loader) / batch_accumulate)\n"
    logger.info(msg)
    def log_exceptook(excepthook: Callable) -> Callable:
        """Wrapping function that logs exceptions that are not KeyboardInterrupt"""
        def handle_exception(exc_type, exc_value, exc_traceback):
            if not issubclass(exc_type, KeyboardInterrupt):
                logger.error("Uncaught exception", exc_info=(exc_type, exc_value, exc_traceback))
            excepthook(exc_type, exc_value, exc_traceback)
            return
        return handle_exception
    sys.excepthook = log_exceptook(sys.excepthook)
      

parse args from a config. unlike get_param(), in this case only parameters that appear in the config will override default params from the function's signature

Source code in V3_2/src/super_gradients/training/utils/sg_trainer_utils.py 438
def parse_args(cfg, arg_names: Union[Sequence[str], callable]) -> dict:
    parse args from a config.
    unlike get_param(), in this case only parameters that appear in the config will override default params from the function's signature
    if not isinstance(arg_names, Sequence):
        arg_names = get_callable_param_names(arg_names)
    kwargs_dict = {}
    for arg_name in arg_names:
        if hasattr(cfg, arg_name) and getattr(cfg, arg_name) is not None:
            kwargs_dict[arg_name] = getattr(cfg, arg_name)
    return kwargs_dict
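A hedged sketch of the filtering behavior; `make_optimizer` and the config object are hypothetical:

from types import SimpleNamespace

def make_optimizer(lr=1e-3, momentum=0.9, weight_decay=None):
    ...

cfg = SimpleNamespace(lr=0.01, weight_decay=None, unrelated_key=42)
print(parse_args(cfg, make_optimizer))   # {'lr': 0.01} - entries that are None or not in the signature are skipped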
def try_port(port):
    # Check whether the given port is free by attempting to bind a socket to it
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    is_port_available = False
    try:
        sock.bind(("localhost", port))
        is_port_available = True
    except Exception as ex:
        print("Port " + str(port) + " is in use" + str(ex))
    sock.close()
    return is_port_available
def unpack_batch_items(batch_items: Union[tuple, torch.Tensor]):
    Adds support for unpacking batch items in train/validation loop.
    :param batch_items: (Union[tuple, torch.Tensor]) returned by the data loader, which is expected to be in one of
         the following formats:
            1. torch.Tensor or tuple, s.t inputs = batch_items[0], targets = batch_items[1] and len(batch_items) = 2
            2. tuple: (inputs, targets, additional_batch_items)
         where inputs are fed to the network, targets are their corresponding labels and additional_batch_items is a
         dictionary (format {additional_batch_item_i_name: additional_batch_item_i ...}) which can be accessed through
         the phase context under the attribute additional_batch_item_i_name, using a phase callback.
    :return: inputs, target, additional_batch_items
    additional_batch_items = {}
    if len(batch_items) == 2:
        inputs, target = batch_items
    elif len(batch_items) == 3:
        inputs, target, additional_batch_items = batch_items
    else:
        raise UnsupportedBatchItemsFormat(batch_items)
    return inputs, target, additional_batch_items
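Both supported formats in one hedged sketch:

import torch

inputs, targets = torch.rand(4, 3, 32, 32), torch.randint(0, 10, (4,))

x, y, extras = unpack_batch_items((inputs, targets))                                   # format 1
assert extras == {}

x, y, extras = unpack_batch_items((inputs, targets, {"sample_ids": list(range(4))}))   # format 2
assert "sample_ids" in extras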
def update_monitored_value(previous_monitored_value: MonitoredValue, new_value: float) -> MonitoredValue:
    """Update the given ValueToMonitor object (could be a loss or a metric) with the new value
    :param previous_monitored_value: The stats about the value that is monitored throughout epochs.
    :param new_value: The value of the current epoch that will be used to update previous_monitored_value
    :return:
    previous_value, previous_best_value = previous_monitored_value.current, previous_monitored_value.best
    name, greater_is_better = previous_monitored_value.name, previous_monitored_value.greater_is_better
    if previous_best_value is None:
        previous_best_value = previous_value
    elif greater_is_better:
        previous_best_value = max(previous_value, previous_best_value)
    else:
        previous_best_value = min(previous_value, previous_best_value)
    if previous_value is None:
        change_from_previous = None
        change_from_best = None
    else:
        change_from_previous = new_value - previous_value
        change_from_best = new_value - previous_best_value
    return MonitoredValue(
        name=name,
        current=new_value,
        previous=previous_value,
        best=previous_best_value,
        change_from_previous=change_from_previous,
        change_from_best=change_from_best,
        greater_is_better=greater_is_better,
    )
def update_monitored_values_dict(monitored_values_dict: Dict[str, MonitoredValue], new_values_dict: Dict[str, float]) -> Dict[str, MonitoredValue]:
    """Update the given ValueToMonitor object (could be a loss or a metric) with the new value
    :param monitored_values_dict: Dict mapping value names to their stats throughout epochs.
    :param new_values_dict: Dict mapping value names to their new (i.e. current epoch) value.
    :return: Updated monitored_values_dict
    relevant_keys = set(new_values_dict.keys()).intersection(monitored_values_dict.keys())
    for monitored_value_name in relevant_keys:
        previous_value = monitored_values_dict[monitored_value_name]
        monitored_values_dict[monitored_value_name] = update_monitored_value(
            new_value=new_values_dict[monitored_value_name],
            previous_monitored_value=previous_value,
        )
    return monitored_values_dict
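For example, tracking a loss across two epochs (a hedged sketch; note that best reflects previous epochs only, as implemented above):

epoch_0 = MonitoredValue(name="loss", greater_is_better=False, current=1.0)
epoch_1 = update_monitored_value(previous_monitored_value=epoch_0, new_value=0.5)

print(epoch_1.current, epoch_1.previous, epoch_1.best)   # 0.5 1.0 1.0
print(epoch_1.change_from_previous)                      # -0.5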
def write_hpms(writer, hpmstructs=[], special_conf={}):
    """Stores the training and dataset hyper params in the tensorboard file"""
    hpm_string = ""
    for hpm in hpmstructs:
        for key, val in hpm.__dict__.items():
            hpm_string += "{}: {}  \n  ".format(key, val)
    for key, val in special_conf.items():
        hpm_string += "{}: {}  \n  ".format(key, val)
    writer.add_text("Hyper_parameters", hpm_string)
    writer.flush()
      

Stores the training and validation loss and accuracy for current epoch in a tensorboard file

Source code in