def get_builtin_activation_type(activation: Union[str, None], **kwargs) -> Type[nn.Module]:
Returns activation class by its name from torch.nn namespace. This function supports all modules available from
torch.nn and also their lower-case aliases.
On top of that, it supports a few extra aliases: leaky_relu (LeakyReLU), swish (SiLU).
>>> act_cls = get_builtin_activation_type("LeakyReLU", inplace=True, negative_slope=0.01)
>>> act = act_cls()
:param activation: Activation function name (E.g. ReLU). If None - return nn.Identity
:param **kwargs : Extra arguments to pass to constructor during instantiation (E.g. inplace=True)
:returns : Type of the activation function that is ready to be instantiated
if activation is None:
activation_cls = nn.Identity
else:
lowercase_aliases: Dict[str, str] = dict((k.lower(), k) for k in torch.nn.__dict__.keys())
# Register additional aliases
lowercase_aliases["leaky_relu"] = "LeakyReLU" # LeakyRelu in snake_case
lowercase_aliases["swish"] = "SiLU" # Swish shich is equivalent to SiLU
lowercase_aliases["none"] = "Identity"
if activation in lowercase_aliases:
activation = lowercase_aliases[activation]
if activation not in torch.nn.__dict__:
raise KeyError(f"Requested activation function {activation} is not known")
activation_cls = torch.nn.__dict__[activation]
if len(kwargs):
activation_cls = partial(activation_cls, **kwargs)
return activation_cls
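A short usage sketch (illustrative only, assuming torch is importable): the returned class can be instantiated and dropped into any module, e.g. a small conv block.

import torch.nn as nn

# Resolve the activation class by its lower-case alias and bind constructor kwargs.
act_cls = get_builtin_activation_type("leaky_relu", negative_slope=0.01, inplace=True)

# The returned class/partial is then instantiated like any nn.Module.
block = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3, padding=1), act_cls())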
def batch_distance2bbox(points: Tensor, distance: Tensor, max_shapes: Optional[Tensor] = None) -> Tensor:
"""Decode distance prediction to bounding box for batch.
:param points: [B, ..., 2], "xy" format
:param distance: [B, ..., 4], "ltrb" format
:param max_shapes: [B, 2], "h,w" format, Shape of the image.
:return: Tensor: Decoded bboxes, "x1y1x2y2" format.
lt, rb = torch.split(distance, 2, dim=-1)
# When adding a tensor and a parameter, the parameter should be the second operand
x1y1 = -lt + points
x2y2 = rb + points
out_bbox = torch.cat([x1y1, x2y2], dim=-1)
if max_shapes is not None:
max_shapes = max_shapes.flip(-1).tile([1, 2])
delta_dim = out_bbox.ndim - max_shapes.ndim
for _ in range(delta_dim):
max_shapes.unsqueeze_(1)
out_bbox = torch.where(out_bbox < max_shapes, out_bbox, max_shapes)
out_bbox = torch.where(out_bbox > 0, out_bbox, torch.zeros_like(out_bbox))
return out_bbox
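A small worked example of the decoding above (values are illustrative only):

import torch

points = torch.tensor([[[50.0, 40.0]]])                # [B=1, N=1, 2], anchor center in "xy"
distance = torch.tensor([[[10.0, 5.0, 20.0, 15.0]]])   # [B=1, N=1, 4], "ltrb" offsets
boxes = batch_distance2bbox(points, distance)
# boxes -> tensor([[[40., 35., 70., 55.]]]), i.e. (x1, y1) = (50-10, 40-5) and (x2, y2) = (50+20, 40+15)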
Source code in V3_2/src/super_gradients/training/utils/callbacks/base_callbacks.py
class Callback:
Base callback class with all the callback methods. Derived classes may override one or many of the available events
to receive callbacks when such events are triggered by the training loop.
The order of the events is as follows:
on_training_start(context) # called once before training starts, good for setting up the warmup LR
for epoch in range(epochs):
on_train_loader_start(context)
for batch in train_loader:
on_train_batch_start(context)
on_train_batch_loss_end(context) # called after loss has been computed
on_train_batch_backward_end(context) # called after .backward() was called
on_train_batch_gradient_step_start(context) # called before the optimizer step is about to happen (gradient clipping, logging of gradients)
on_train_batch_gradient_step_end(context) # called after gradient step was done, good place to update LR (for step-based schedulers)
on_train_batch_end(context)
on_train_loader_end(context)
on_validation_loader_start(context)
for batch in validation_loader:
on_validation_batch_start(context)
on_validation_batch_end(context)
on_validation_loader_end(context)
on_validation_end_best_epoch(context)
on_test_start(context)
for batch in test_loader:
on_test_batch_start(context)
on_test_batch_end(context)
on_test_end(context)
on_training_end(context) # called once after training ends.
Correspondence mapping from the old callback API:
on_training_start(context) <-> Phase.PRE_TRAINING
for epoch in range(epochs):
on_train_loader_start(context) <-> Phase.TRAIN_EPOCH_START
for batch in train_loader:
on_train_batch_start(context)
on_train_batch_loss_end(context)
on_train_batch_backward_end(context) <-> Phase.TRAIN_BATCH_END
on_train_batch_gradient_step_start(context)
on_train_batch_gradient_step_end(context) <-> Phase.TRAIN_BATCH_STEP
on_train_batch_end(context)
on_train_loader_end(context) <-> Phase.TRAIN_EPOCH_END
on_validation_loader_start(context)
for batch in validation_loader:
on_validation_batch_start(context)
on_validation_batch_end(context) <-> Phase.VALIDATION_BATCH_END
on_validation_loader_end(context) <-> Phase.VALIDATION_EPOCH_END
on_validation_end_best_epoch(context) <-> Phase.VALIDATION_END_BEST_EPOCH
on_test_start(context)
for batch in test_loader:
on_test_batch_start(context)
on_test_batch_end(context) <-> Phase.TEST_BATCH_END
on_test_end(context) <-> Phase.TEST_END
on_training_end(context) <-> Phase.POST_TRAINING
def on_training_start(self, context: PhaseContext) -> None:
Called once before the start of the first epoch.
At this point, the context argument is guaranteed to have the following attributes:
- optimizer
- net
- checkpoints_dir_path
- criterion
- sg_logger
- train_loader
- valid_loader
- training_params
- checkpoint_params
- architecture
- arch_params
- metric_to_watch
- device
- ema_model
The corresponding Phase enum value for this event is Phase.PRE_TRAINING.
:param context:
:return:
def on_train_loader_start(self, context: PhaseContext) -> None:
Called each epoch at the start of train data loader (before getting the first batch).
At this point, the context argument is guaranteed to have the following attributes:
- epoch
The corresponding Phase enum value for this event is Phase.TRAIN_EPOCH_START.
:param context:
:return:
def on_train_batch_start(self, context: PhaseContext) -> None:
Called at each batch after getting batch of data from data loader and moving it to target device.
This event is triggered AFTER the Trainer.pre_prediction_callback call (if it was defined).
At this point the context argument is guaranteed to have the following attributes:
- batch_idx
- inputs
- targets
- **additional_batch_items
:param context:
:return:
def on_train_batch_loss_end(self, context: PhaseContext) -> None:
Called after model forward and loss computation has been done.
At this point the context argument is guaranteed to have the following attributes:
- preds
- loss_log_items
The corresponding Phase enum value for this event is Phase.TRAIN_BATCH_END.
:param context:
:return:
def on_train_batch_backward_end(self, context: PhaseContext) -> None:
Called after loss.backward() method was called for a given batch
:param context:
:return:
def on_train_batch_gradient_step_start(self, context: PhaseContext) -> None:
Called before the gradient step is about to happen.
Good place to clip gradients (with respect to scaler), log gradients to data ratio, etc.
:param context:
:return:
def on_train_batch_gradient_step_end(self, context: PhaseContext) -> None:
Called after gradient step has been performed. Good place to update LR (for step-based schedulers)
The corresponding Phase enum value for this event is Phase.TRAIN_BATCH_STEP.
:param context:
:return:
def on_train_batch_end(self, context: PhaseContext) -> None:
Called after all forward/backward/optimizer steps have been performed for a given batch and there is nothing left to do.
:param context:
:return:
def on_train_loader_end(self, context: PhaseContext) -> None:
Called each epoch at the end of train data loader (after processing the last batch).
The corresponding Phase enum value for this event is Phase.TRAIN_EPOCH_END.
:param context:
:return:
def on_validation_loader_start(self, context: PhaseContext) -> None:
Called each epoch at the start of validation data loader (before getting the first batch).
:param context:
:return:
def on_validation_batch_start(self, context: PhaseContext) -> None:
Called at each batch after getting batch of data from validation loader and moving it to target device.
:param context:
:return:
def on_validation_batch_end(self, context: PhaseContext) -> None:
Called after all forward step / loss / metric computation have been performed for a given batch and there is nothing left to do.
The corresponding Phase enum value for this event is Phase.VALIDATION_BATCH_END.
:param context:
:return:
def on_validation_loader_end(self, context: PhaseContext) -> None:
Called each epoch at the end of validation data loader (after processing the last batch).
The corresponding Phase enum value for this event is Phase.VALIDATION_EPOCH_END.
:param context:
:return:
def on_validation_end_best_epoch(self, context: PhaseContext) -> None:
Called each epoch after validation has been performed and the best metric has been achieved.
The corresponding Phase enum value for this event is Phase.VALIDATION_END_BEST_EPOCH.
:param context:
:return:
def on_test_loader_start(self, context: PhaseContext) -> None:
Called once at the start of test data loader (before getting the first batch).
:param context:
:return:
def on_test_batch_start(self, context: PhaseContext) -> None:
Called at each batch after getting batch of data from test loader and moving it to target device.
:param context:
:return:
def on_test_batch_end(self, context: PhaseContext) -> None:
Called after all forward step have been performed for a given batch and there is nothing left to do.
The corresponding Phase enum value for this event is Phase.TEST_BATCH_END.
:param context:
:return:
def on_test_loader_end(self, context: PhaseContext) -> None:
Called once at the end of test data loader (after processing the last batch).
The corresponding Phase enum value for this event is Phase.TEST_END.
:param context:
:return:
def on_training_end(self, context: PhaseContext) -> None:
Called once after the training loop has finished (due to reaching the optimization criterion or because of an error).
The corresponding Phase enum value for this event is Phase.POST_TRAINING.
:param context:
:return:
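As a concrete (hypothetical) example, a user-defined callback only needs to override the events it cares about; the attribute names used below (epoch, metrics_dict) are the ones documented above.

class LogBestEpochCallback(Callback):
    """Hypothetical example callback: prints a message whenever a new best epoch is reached."""

    def on_validation_end_best_epoch(self, context: PhaseContext) -> None:
        # Fired only on epochs where the monitored metric improved.
        print(f"New best model at epoch {context.epoch}: {context.metrics_dict}")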
class CallbackHandler(Callback):
def __init__(self, callbacks: List[Callback]):
# TODO: Add reordering of callbacks to make sure that they are called in the right order
# For instance, two callbacks may be dependent on each other, so the first one should be called first
# Example: Gradient Clipping & Gradient Logging callback. We first need to clip the gradients, and then log them
# So if the user added them in the wrong order, we can still guarantee their order is correct.
# We can achieve this by adding a property to the callback indicating its priority:
# Forward = 0
# Loss = 100
# Backward = 200
# Metrics = 300
# Scheduler = 400
# Logging = 500
# So ordering callbacks by their priority would ensure that we first run all Forward-related callbacks (for a given event),
# then Backward, and only then Logging.
self.callbacks = callbacks
def on_training_start(self, context: PhaseContext) -> None:
for callback in self.callbacks:
callback.on_training_start(context)
def on_train_loader_start(self, context: PhaseContext) -> None:
for callback in self.callbacks:
callback.on_train_loader_start(context)
def on_train_batch_start(self, context: PhaseContext) -> None:
for callback in self.callbacks:
callback.on_train_batch_start(context)
def on_train_batch_loss_end(self, context: PhaseContext) -> None:
for callback in self.callbacks:
callback.on_train_batch_loss_end(context)
def on_train_batch_backward_end(self, context: PhaseContext) -> None:
for callback in self.callbacks:
callback.on_train_batch_backward_end(context)
def on_train_batch_gradient_step_start(self, context: PhaseContext) -> None:
for callback in self.callbacks:
callback.on_train_batch_gradient_step_start(context)
def on_train_batch_gradient_step_end(self, context: PhaseContext) -> None:
for callback in self.callbacks:
callback.on_train_batch_gradient_step_end(context)
def on_train_batch_end(self, context: PhaseContext) -> None:
for callback in self.callbacks:
callback.on_train_batch_end(context)
def on_validation_loader_start(self, context: PhaseContext) -> None:
for callback in self.callbacks:
callback.on_validation_loader_start(context)
def on_validation_batch_start(self, context: PhaseContext) -> None:
for callback in self.callbacks:
callback.on_validation_batch_start(context)
def on_validation_batch_end(self, context: PhaseContext) -> None:
for callback in self.callbacks:
callback.on_validation_batch_end(context)
def on_validation_loader_end(self, context: PhaseContext) -> None:
for callback in self.callbacks:
callback.on_validation_loader_end(context)
def on_train_loader_end(self, context: PhaseContext) -> None:
for callback in self.callbacks:
callback.on_train_loader_end(context)
def on_training_end(self, context: PhaseContext) -> None:
for callback in self.callbacks:
callback.on_training_end(context)
def on_validation_end_best_epoch(self, context: PhaseContext) -> None:
for callback in self.callbacks:
callback.on_validation_end_best_epoch(context)
def on_test_loader_start(self, context: PhaseContext) -> None:
for callback in self.callbacks:
callback.on_test_loader_start(context)
def on_test_batch_start(self, context: PhaseContext) -> None:
for callback in self.callbacks:
callback.on_test_batch_start(context)
def on_test_batch_end(self, context: PhaseContext) -> None:
for callback in self.callbacks:
callback.on_test_batch_end(context)
def on_test_loader_end(self, context: PhaseContext) -> None:
for callback in self.callbacks:
callback.on_test_loader_end(context)
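Illustrative dispatch sketch (assuming the handler class above, CallbackHandler, and reusing the hypothetical LogBestEpochCallback from earlier): every event method simply fans the call out to the registered callbacks in order.

callbacks = [LogBestEpochCallback()]
handler = CallbackHandler(callbacks=callbacks)

context = PhaseContext(epoch=0)
handler.on_training_start(context)             # forwards the event to every registered callback
handler.on_validation_end_best_epoch(context)  # triggers the print from LogBestEpochCallback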
Source code in V3_2/src/super_gradients/training/utils/callbacks/base_callbacks.py
class PhaseCallback(Callback):
Kept here to keep backward compatibility with old code. New callbacks should use Callback class instead.
This callback supports receiving only a subset of events defined in Phase enum:
PRE_TRAINING = "PRE_TRAINING"
TRAIN_EPOCH_START = "TRAIN_EPOCH_START"
TRAIN_BATCH_END = "TRAIN_BATCH_END"
TRAIN_BATCH_STEP = "TRAIN_BATCH_STEP"
TRAIN_EPOCH_END = "TRAIN_EPOCH_END"
VALIDATION_BATCH_END = "VALIDATION_BATCH_END"
VALIDATION_EPOCH_END = "VALIDATION_EPOCH_END"
VALIDATION_END_BEST_EPOCH = "VALIDATION_END_BEST_EPOCH"
TEST_BATCH_END = "TEST_BATCH_END"
TEST_END = "TEST_END"
POST_TRAINING = "POST_TRAINING"
def __init__(self, phase: Phase):
self.phase = phase
def __call__(self, *args, **kwargs):
raise NotImplementedError
def __repr__(self) -> str:
return self.__class__.__name__
def on_training_start(self, context: PhaseContext) -> None:
if self.phase == Phase.PRE_TRAINING:
self(context)
def on_train_loader_start(self, context: PhaseContext) -> None:
if self.phase == Phase.TRAIN_EPOCH_START:
self(context)
def on_train_batch_loss_end(self, context: PhaseContext) -> None:
if self.phase == Phase.TRAIN_BATCH_END:
self(context)
def on_train_batch_gradient_step_end(self, context: PhaseContext) -> None:
if self.phase == Phase.TRAIN_BATCH_STEP:
self(context)
def on_train_loader_end(self, context: PhaseContext) -> None:
if self.phase == Phase.TRAIN_EPOCH_END:
self(context)
def on_validation_batch_end(self, context: PhaseContext) -> None:
if self.phase == Phase.VALIDATION_BATCH_END:
self(context)
def on_validation_loader_end(self, context: PhaseContext) -> None:
if self.phase == Phase.VALIDATION_EPOCH_END:
self(context)
def on_validation_end_best_epoch(self, context: PhaseContext) -> None:
if self.phase == Phase.VALIDATION_END_BEST_EPOCH:
self(context)
def on_test_batch_end(self, context: PhaseContext) -> None:
if self.phase == Phase.TEST_BATCH_END:
self(context)
def on_test_loader_end(self, context: PhaseContext) -> None:
if self.phase == Phase.TEST_END:
self(context)
def on_training_end(self, context: PhaseContext) -> None:
if self.phase == Phase.POST_TRAINING:
self(context)
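A minimal sketch of this legacy, phase-based style (the callback name is hypothetical): the subclass picks a single Phase and implements __call__.

class PrintEpochEndCallback(PhaseCallback):
    """Hypothetical example: fires only at Phase.TRAIN_EPOCH_END."""

    def __init__(self):
        super().__init__(phase=Phase.TRAIN_EPOCH_END)

    def __call__(self, context: PhaseContext):
        print(f"Finished training epoch {context.epoch}")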
class PhaseContext:
Represents the input for phase callbacks, and is constantly updated after callback calls.
def __init__(
self,
epoch=None,
batch_idx=None,
optimizer=None,
metrics_dict=None,
inputs=None,
preds=None,
target=None,
metrics_compute_fn=None,
loss_avg_meter=None,
loss_log_items=None,
criterion=None,
device=None,
experiment_name=None,
ckpt_dir=None,
net=None,
lr_warmup_epochs=None,
sg_logger=None,
train_loader=None,
valid_loader=None,
training_params=None,
ddp_silent_mode=None,
checkpoint_params=None,
architecture=None,
arch_params=None,
metric_to_watch=None,
valid_metrics=None,
ema_model=None,
self.epoch = epoch
self.batch_idx = batch_idx
self.optimizer = optimizer
self.inputs = inputs
self.preds = preds
self.target = target
self.metrics_dict = metrics_dict
self.metrics_compute_fn = metrics_compute_fn
self.loss_avg_meter = loss_avg_meter
self.loss_log_items = loss_log_items
self.criterion = criterion
self.device = device
self.stop_training = False
self.experiment_name = experiment_name
self.ckpt_dir = ckpt_dir
self.net = net
self.lr_warmup_epochs = lr_warmup_epochs
self.sg_logger = sg_logger
self.train_loader = train_loader
self.valid_loader = valid_loader
self.training_params = training_params
self.ddp_silent_mode = ddp_silent_mode
self.checkpoint_params = checkpoint_params
self.architecture = architecture
self.arch_params = arch_params
self.metric_to_watch = metric_to_watch
self.valid_metrics = valid_metrics
self.ema_model = ema_model
def update_context(self, **kwargs):
for attr, attr_val in kwargs.items():
setattr(self, attr, attr_val)
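A quick sketch of how the training loop enriches the same context object between events (values are illustrative):

context = PhaseContext(epoch=0)
context.update_context(batch_idx=10, loss_log_items={"loss": 0.42})
assert context.batch_idx == 10 and context.loss_log_items["loss"] == 0.42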
@register_lr_warmup(LRWarmups.LINEAR_BATCH_STEP)
class BatchStepLinearWarmupLRCallback(Callback):
LR scheduling callback for linear step warmup on each batch step.
LR climbs from warmup_initial_lr to initial_lr.
def __init__(
self,
warmup_initial_lr: float,
initial_lr: float,
train_loader_len: int,
update_param_groups: bool,
lr_warmup_steps: int,
training_params,
net,
**kwargs,
:param warmup_initial_lr: Starting learning rate
:param initial_lr: Target learning rate after warmup
:param train_loader_len: Length of train data loader
:param lr_warmup_steps: Optional. If passed, will use fixed number of warmup steps to warmup LR. Default is None.
:param kwargs:
super(BatchStepLinearWarmupLRCallback, self).__init__()
if lr_warmup_steps > train_loader_len:
logger.warning(
f"Number of warmup steps ({lr_warmup_steps}) is greater than number of steps in epoch ({train_loader_len}). "
f"Warmup steps will be capped to number of steps in epoch to avoid interfering with any pre-epoch LR schedulers."
lr_warmup_steps = min(lr_warmup_steps, train_loader_len)
learning_rates = np.linspace(start=warmup_initial_lr, stop=initial_lr, num=lr_warmup_steps, endpoint=True)
self.lr = initial_lr
self.initial_lr = initial_lr
self.update_param_groups = update_param_groups
self.training_params = training_params
self.net = net
self.learning_rates = learning_rates
self.train_loader_len = train_loader_len
self.lr_warmup_steps = lr_warmup_steps
def on_train_batch_start(self, context: PhaseContext) -> None:
global_training_step = context.batch_idx + context.epoch * self.train_loader_len
if global_training_step < self.lr_warmup_steps:
self.lr = float(self.learning_rates[global_training_step])
self.update_lr(context.optimizer, context.epoch, context.batch_idx)
def update_lr(self, optimizer, epoch, batch_idx=None):
Same as in LRCallbackBase
:param optimizer:
:param epoch:
:param batch_idx:
:return:
if self.update_param_groups:
param_groups = unwrap_model(self.net).update_param_groups(
optimizer.param_groups, self.lr, epoch, batch_idx, self.training_params, self.train_loader_len
optimizer.param_groups = param_groups
else:
# UPDATE THE OPTIMIZERS PARAMETER
for param_group in optimizer.param_groups:
param_group["lr"] = self.lr
class BinarySegmentationVisualizationCallback(PhaseCallback):
A callback that adds a visualization of a batch of segmentation predictions to context.sg_logger
:param phase: When to trigger the callback.
:param freq: Frequency (in epochs) to perform this callback.
:param batch_idx: Batch index to perform visualization for.
:param last_img_idx_in_batch: Last image index to add to log. (default=-1, will take entire batch).
def __init__(self, phase: Phase, freq: int, batch_idx: int = 0, last_img_idx_in_batch: int = -1):
super(BinarySegmentationVisualizationCallback, self).__init__(phase)
self.freq = freq
self.batch_idx = batch_idx
self.last_img_idx_in_batch = last_img_idx_in_batch
def __call__(self, context: PhaseContext):
if context.epoch % self.freq == 0 and context.batch_idx == self.batch_idx:
if isinstance(context.preds, tuple):
preds = context.preds[0].clone()
else:
preds = context.preds.clone()
batch_imgs = BinarySegmentationVisualization.visualize_batch(context.inputs, preds, context.target, self.batch_idx)
batch_imgs = [cv2.cvtColor(image, cv2.COLOR_BGR2RGB) for image in batch_imgs]
batch_imgs = np.stack(batch_imgs)
tag = "batch_" + str(self.batch_idx) + "_images"
context.sg_logger.add_images(tag=tag, images=batch_imgs[: self.last_img_idx_in_batch], global_step=context.epoch, data_format="NHWC")
@register_lr_scheduler(LRSchedulers.COSINE)
class CosineLRCallback(LRCallbackBase):
Hard-coded cosine annealing learning rate scheduling.
def __init__(self, max_epochs, cosine_final_lr_ratio, **kwargs):
super(CosineLRCallback, self).__init__(Phase.TRAIN_BATCH_STEP, **kwargs)
self.max_epochs = max_epochs
self.cosine_final_lr_ratio = cosine_final_lr_ratio
def perform_scheduling(self, context):
effective_epoch = context.epoch - self.training_params.lr_warmup_epochs
effective_max_epochs = self.max_epochs - self.training_params.lr_warmup_epochs - self.training_params.lr_cooldown_epochs
current_iter = max(0, self.train_loader_len * effective_epoch + context.batch_idx - self.training_params.lr_warmup_steps)
max_iter = self.train_loader_len * effective_max_epochs - self.training_params.lr_warmup_steps
lr = self.compute_learning_rate(current_iter, max_iter, self.initial_lr, self.cosine_final_lr_ratio)
self.lr = float(lr)
self.update_lr(context.optimizer, context.epoch, context.batch_idx)
def is_lr_scheduling_enabled(self, context):
# Account for per-step warmup
if self.training_params.lr_warmup_steps > 0:
current_step = self.train_loader_len * context.epoch + context.batch_idx
return current_step >= self.training_params.lr_warmup_steps
post_warmup_epochs = self.training_params.max_epochs - self.training_params.lr_cooldown_epochs
return self.training_params.lr_warmup_epochs <= context.epoch < post_warmup_epochs
@classmethod
def compute_learning_rate(cls, step: Union[float, np.ndarray], total_steps: float, initial_lr: float, final_lr_ratio: float):
# the cosine starts from initial_lr and reaches initial_lr * cosine_final_lr_ratio in last epoch
lr = 0.5 * initial_lr * (1.0 + np.cos(step / (total_steps + 1) * math.pi))
return lr * (1 - final_lr_ratio) + (initial_lr * final_lr_ratio)
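Endpoint sanity check for the compute_learning_rate classmethod above (values are approximate):

lr_start = CosineLRCallback.compute_learning_rate(step=0, total_steps=1000, initial_lr=0.1, final_lr_ratio=0.01)
lr_end = CosineLRCallback.compute_learning_rate(step=1000, total_steps=1000, initial_lr=0.1, final_lr_ratio=0.01)
# lr_start ≈ 0.1 (initial_lr), lr_end ≈ 0.001 (initial_lr * final_lr_ratio)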
@register_callback(Callbacks.DECI_LAB_UPLOAD)
class DeciLabUploadCallback(PhaseCallback):
Post-training callback for uploading and optimizing a model.
:param model_name: Name of the model as it will appear in the Deci platform.
:param input_dimensions: Model input dimensions (without the batch dimension).
:param target_hardware_types: Optional list of hardware types to optimize the model for.
:param target_batch_size: Optional target batch size for the optimization.
:param target_quantization_level: Optional target quantization level for the optimization.
:param ckpt_name: Checkpoint filename, inside the checkpoint directory.
def __init__(
self,
model_name: str,
input_dimensions: Sequence[int],
target_hardware_types: "Optional[List[str]]" = None,
target_batch_size: "Optional[int]" = None,
target_quantization_level: "Optional[str]" = None,
ckpt_name: str = "ckpt_best.pth",
**kwargs,
super().__init__(phase=Phase.POST_TRAINING)
self.input_dimensions = input_dimensions
self.model_name = model_name
self.target_hardware_types = target_hardware_types
self.target_batch_size = target_batch_size
self.target_quantization_level = target_quantization_level
self.ckpt_name = ckpt_name
self.platform_client = DeciClient()
@staticmethod
def log_optimization_failed():
logger.info("We couldn't finish your model optimization. Visit https://console.deci.ai for details")
def upload_model(self, model):
This function will upload the trained model to the Deci Lab
:param model: The resulting model from the training process
self.platform_client.upload_model(
model=model,
name=self.model_name,
input_dimensions=self.input_dimensions,
target_hardware_types=self.target_hardware_types,
target_batch_size=self.target_batch_size,
target_quantization_level=self.target_quantization_level,
def get_optimization_status(self, optimized_model_name: str):
This function fetches the optimized version of the trained model and checks its benchmark status.
The status is checked against the server every 30 seconds; the process times out after 30 minutes
or logs the successful optimization, whichever happens first.
:param optimized_model_name: Optimized model name
:return: Whether or not the optimized model has been benchmarked
def handler(_signum, _frame):
logger.error("Process timed out. Visit https://console.deci.ai for details")
return False
signal.signal(signal.SIGALRM, handler)
signal.alarm(1800)
finished = False
while not finished:
if self.platform_client.is_model_benchmarking(name=optimized_model_name):
time.sleep(30)
else:
finished = True
signal.alarm(0)
return True
def __call__(self, context: PhaseContext) -> None:
This function will attempt to upload the trained model and schedule an optimization for it.
:param context: Training phase context
try:
model = copy.deepcopy(unwrap_model(context.net))
model_state_dict_path = os.path.join(context.ckpt_dir, self.ckpt_name)
model_state_dict = torch.load(model_state_dict_path)["net"]
model.load_state_dict(state_dict=model_state_dict)
model = model.cpu()
if hasattr(model, "prep_model_for_conversion"):
model.prep_model_for_conversion(input_size=self.input_dimensions)
self.upload_model(model=model)
model_name = self.model_name
logger.info(f"Successfully added {model_name} to the model repository")
optimized_model_name = f"{model_name}_1_1"
logger.info("We'll wait for the scheduled optimization to finish. Please don't close this window")
success = self.get_optimization_status(optimized_model_name=optimized_model_name)
if success:
logger.info("Successfully finished your model optimization. Visit https://console.deci.ai for details")
else:
DeciLabUploadCallback.log_optimization_failed()
except Exception as ex:
DeciLabUploadCallback.log_optimization_failed()
logger.error(ex)
@register_callback(Callbacks.DETECTION_VISUALIZATION_CALLBACK)
class DetectionVisualizationCallback(PhaseCallback):
A callback that adds a visualization of a batch of detection predictions to context.sg_logger
:param phase: When to trigger the callback.
:param freq: Frequency (in epochs) to perform this callback.
:param batch_idx: Batch index to perform visualization for.
:param classes: Class list of the dataset.
:param last_img_idx_in_batch: Last image index to add to log. (default=-1, will take entire batch).
def __init__(
self,
phase: Phase,
freq: int,
post_prediction_callback: DetectionPostPredictionCallback,
classes: list,
batch_idx: int = 0,
last_img_idx_in_batch: int = -1,
super(DetectionVisualizationCallback, self).__init__(phase)
self.freq = freq
self.post_prediction_callback = post_prediction_callback
self.batch_idx = batch_idx
self.classes = classes
self.last_img_idx_in_batch = last_img_idx_in_batch
def __call__(self, context: PhaseContext):
if context.epoch % self.freq == 0 and context.batch_idx == self.batch_idx:
# SOME CALCULATIONS ARE IN-PLACE IN NMS, SO CLONE THE PREDICTIONS
preds = (context.preds[0].clone(), None)
preds = self.post_prediction_callback(preds)
batch_imgs = DetectionVisualization.visualize_batch(context.inputs, preds, context.target, self.batch_idx, self.classes)
batch_imgs = [cv2.cvtColor(image, cv2.COLOR_BGR2RGB) for image in batch_imgs]
batch_imgs = np.stack(batch_imgs)
tag = "batch_" + str(self.batch_idx) + "_images"
context.sg_logger.add_images(tag=tag, images=batch_imgs[: self.last_img_idx_in_batch], global_step=context.epoch, data_format="NHWC")
Source code in V3_2/src/super_gradients/training/utils/callbacks/callbacks.py
@register_lr_warmup(LRWarmups.LINEAR_EPOCH_STEP)
class EpochStepWarmupLRCallback(LRCallbackBase):
LR scheduling callback for linear step warmup. This scheduler uses a whole epoch as single step.
LR climbs from warmup_initial_lr in even steps to initial_lr. When warmup_initial_lr is None, the LR climb starts from
initial_lr/(1+warmup_epochs).
def __init__(self, **kwargs):
super(EpochStepWarmupLRCallback, self).__init__(Phase.TRAIN_EPOCH_START, **kwargs)
self.warmup_initial_lr = self.training_params.warmup_initial_lr or self.initial_lr / (self.training_params.lr_warmup_epochs + 1)
self.warmup_step_size = (
(self.initial_lr - self.warmup_initial_lr) / self.training_params.lr_warmup_epochs if self.training_params.lr_warmup_epochs > 0 else 0
def perform_scheduling(self, context):
self.lr = self.warmup_initial_lr + context.epoch * self.warmup_step_size
self.update_lr(context.optimizer, context.epoch, None)
def is_lr_scheduling_enabled(self, context):
return self.training_params.lr_warmup_epochs > 0 and self.training_params.lr_warmup_epochs >= context.epoch
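For example (illustrative numbers), with initial_lr=0.1, lr_warmup_epochs=4, and no explicit warmup_initial_lr, the warmup starts at 0.1/5 = 0.02 and climbs by 0.02 per epoch, reaching 0.1 at epoch 4.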
@register_lr_scheduler(LRSchedulers.EXP)
class ExponentialLRCallback(LRCallbackBase):
Exponential decay learning rate scheduling. Decays the learning rate by `lr_decay_factor` every epoch.
def __init__(self, lr_decay_factor: float, **kwargs):
super().__init__(phase=Phase.TRAIN_BATCH_STEP, **kwargs)
self.lr_decay_factor = lr_decay_factor
def perform_scheduling(self, context):
effective_epoch = context.epoch - self.training_params.lr_warmup_epochs
current_iter = self.train_loader_len * effective_epoch + context.batch_idx
self.lr = self.initial_lr * self.lr_decay_factor ** (current_iter / self.train_loader_len)
self.update_lr(context.optimizer, context.epoch, context.batch_idx)
def is_lr_scheduling_enabled(self, context):
post_warmup_epochs = self.training_params.max_epochs - self.training_params.lr_cooldown_epochs
return self.training_params.lr_warmup_epochs <= context.epoch < post_warmup_epochs
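For example (illustrative numbers), with lr_decay_factor=0.9 the post-warmup LR after k full epochs is initial_lr * 0.9**k, since current_iter / train_loader_len grows by exactly one per epoch.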
@register_lr_scheduler(LRSchedulers.FUNCTION)
class FunctionLRCallback(LRCallbackBase):
Hard-coded LR scheduling driven by a user-defined scheduling function.
@deprecated(version="3.2.0", reason="This callback is deprecated and will be removed in future versions.")
def __init__(self, max_epochs, lr_schedule_function, **kwargs):
super(FunctionLRCallback, self).__init__(Phase.TRAIN_BATCH_STEP, **kwargs)
assert callable(lr_schedule_function), "self.lr_function must be callable"
self.lr_schedule_function = lr_schedule_function
self.max_epochs = max_epochs
def is_lr_scheduling_enabled(self, context):
post_warmup_epochs = self.training_params.max_epochs - self.training_params.lr_cooldown_epochs
return self.training_params.lr_warmup_epochs <= context.epoch < post_warmup_epochs
def perform_scheduling(self, context):
effective_epoch = context.epoch - self.training_params.lr_warmup_epochs
effective_max_epochs = self.max_epochs - self.training_params.lr_warmup_epochs - self.training_params.lr_cooldown_epochs
self.lr = self.lr_schedule_function(
initial_lr=self.initial_lr,
epoch=effective_epoch,
iter=context.batch_idx,
max_epoch=effective_max_epochs,
iters_per_epoch=self.train_loader_len,
self.update_lr(context.optimizer, context.epoch, context.batch_idx)
class IllegalLRSchedulerMetric(Exception):
"""Exception raised illegal combination of training parameters.
:param metric_name: Name of the metric that is not supported.
:param metrics_dict: Dictionary of metrics that are supported.
def __init__(self, metric_name: str, metrics_dict: dict):
self.message = "Illegal metric name: " + metric_name + ". Expected one of metics_dics keys: " + str(metrics_dict.keys())
super().__init__(self.message)
@register_callback(Callbacks.LR_CALLBACK_BASE)
class LRCallbackBase(PhaseCallback):
Base class for hard coded learning rate scheduling regimes, implemented as callbacks.
def __init__(self, phase, initial_lr, update_param_groups, train_loader_len, net, training_params, **kwargs):
super(LRCallbackBase, self).__init__(phase)
self.initial_lr = initial_lr
self.lr = initial_lr
self.update_param_groups = update_param_groups
self.train_loader_len = train_loader_len
self.net = net
self.training_params = training_params
def __call__(self, context: PhaseContext, **kwargs):
if self.is_lr_scheduling_enabled(context):
self.perform_scheduling(context)
def is_lr_scheduling_enabled(self, context: PhaseContext):
Predicate that controls whether to perform lr scheduling based on values in context.
:param context: PhaseContext: current phase's context.
:return: bool, whether to apply lr scheduling or not.
raise NotImplementedError
def perform_scheduling(self, context: PhaseContext):
Performs lr scheduling based on values in context.
:param context: PhaseContext: current phase's context.
raise NotImplementedError
def update_lr(self, optimizer, epoch, batch_idx=None):
if self.update_param_groups:
param_groups = unwrap_model(self.net).update_param_groups(
optimizer.param_groups, self.lr, epoch, batch_idx, self.training_params, self.train_loader_len
optimizer.param_groups = param_groups
else:
# UPDATE THE OPTIMIZERS PARAMETER
for param_group in optimizer.param_groups:
param_group["lr"] = self.lr
@register_callback(Callbacks.LR_SCHEDULER)
class LRSchedulerCallback(PhaseCallback):
Learning rate scheduler callback.
When passing __call__ a metrics_dict with a key=self.metric_name, the value of that metric will be monitored
for ReduceLROnPlateau (i.e. step(metrics_dict[self.metric_name])).
:param scheduler: Learning rate scheduler to be called step() with.
:param metric_name: Metric name for ReduceLROnPlateau learning rate scheduler.
:param phase: Phase of when to trigger it.
def __init__(self, scheduler: torch.optim.lr_scheduler._LRScheduler, phase: Phase, metric_name: str = None):
super(LRSchedulerCallback, self).__init__(phase)
self.scheduler = scheduler
self.metric_name = metric_name
def __call__(self, context: PhaseContext):
if context.lr_warmup_epochs <= context.epoch:
if self.metric_name and self.metric_name in context.metrics_dict.keys():
self.scheduler.step(context.metrics_dict[self.metric_name])
elif self.metric_name is None:
self.scheduler.step()
else:
raise IllegalLRSchedulerMetric(self.metric_name, context.metrics_dict)
def __repr__(self):
return "LRSchedulerCallback: " + repr(self.scheduler)
@register_lr_warmup(LRWarmups.LINEAR_STEP)
class LinearStepWarmupLRCallback(EpochStepWarmupLRCallback):
"""Deprecated, use EpochStepWarmupLRCallback instead"""
def __init__(self, **kwargs):
logger.warning(
f"Parameter {LRWarmups.LINEAR_STEP} has been made deprecated and will be removed in the next SG release. "
f"Please use `{LRWarmups.LINEAR_EPOCH_STEP}` instead."
super(LinearStepWarmupLRCallback, self).__init__(**kwargs)
@register_callback(Callbacks.MODEL_CONVERSION_CHECK)
class ModelConversionCheckCallback(PhaseCallback):
Pre-training callback that verifies model conversion to onnx given specified conversion parameters.
The model is converted, then inference is applied with onnx runtime.
Use this callback with the same args as DeciPlatformCallback to prevent conversion fails at the end of training.
:param model_name: Model's name
:param input_dimensions: Model's input dimensions
:param primary_batch_size: Model's primary batch size
:param opset_version: (default=11)
:param do_constant_folding: (default=True)
:param dynamic_axes: (default={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}})
:param input_names: (default=["input"])
:param output_names: (default=["output"])
:param rtol: (default=1e-03)
:param atol: (default=1e-05)
def __init__(self, model_name: str, input_dimensions: Sequence[int], primary_batch_size: int, **kwargs):
super(ModelConversionCheckCallback, self).__init__(phase=Phase.PRE_TRAINING)
self.model_name = model_name
self.input_dimensions = input_dimensions
self.primary_batch_size = primary_batch_size
self.opset_version = kwargs.get("opset_version", 10)
self.do_constant_folding = kwargs.get("do_constant_folding", None) if kwargs.get("do_constant_folding", None) else True
self.input_names = kwargs.get("input_names") or ["input"]
self.output_names = kwargs.get("output_names") or ["output"]
self.dynamic_axes = kwargs.get("dynamic_axes") or {"input": {0: "batch_size"}, "output": {0: "batch_size"}}
self.rtol = kwargs.get("rtol", 1e-03)
self.atol = kwargs.get("atol", 1e-05)
def __call__(self, context: PhaseContext):
model = copy.deepcopy(unwrap_model(context.net))
model = model.cpu()
model.eval() # Put model into eval mode
if hasattr(model, "prep_model_for_conversion"):
model.prep_model_for_conversion(input_size=self.input_dimensions)
x = torch.randn(self.primary_batch_size, *self.input_dimensions, requires_grad=False)
tmp_model_path = os.path.join(context.ckpt_dir, self.model_name + "_tmp.onnx")
with torch.no_grad():
torch_out = model(x)
torch.onnx.export(
model, # Model being run
x, # Model input (or a tuple for multiple inputs)
tmp_model_path, # Where to save the model (can be a file or file-like object)
export_params=True, # Store the trained parameter weights inside the model file
opset_version=self.opset_version,
do_constant_folding=self.do_constant_folding,
input_names=self.input_names,
output_names=self.output_names,
dynamic_axes=self.dynamic_axes,
onnx_model = onnx.load(tmp_model_path)
onnx.checker.check_model(onnx_model)
ort_session = onnxruntime.InferenceSession(tmp_model_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: x.cpu().numpy()}
ort_outs = ort_session.run(None, ort_inputs)
# TODO: Ideally we don't want to check this but have the certainty of just calling torch_out.cpu()
if isinstance(torch_out, List) or isinstance(torch_out, tuple):
torch_out = torch_out[0]
# compare ONNX Runtime and PyTorch results
np.testing.assert_allclose(torch_out.cpu().numpy(), ort_outs[0], rtol=self.rtol, atol=self.atol)
os.remove(tmp_model_path)
logger.info("Exported model has been tested with ONNXRuntime, and the result looks good!")
class PhaseContextTestCallback(PhaseCallback):
A callback that saves the phase context for testing.
def __init__(self, phase: Phase):
super(PhaseContextTestCallback, self).__init__(phase)
self.context = None
def __call__(self, context: PhaseContext):
self.context = context
@register_lr_scheduler(LRSchedulers.POLY)
class PolyLRCallback(LRCallbackBase):
Hard-coded polynomial decay learning rate scheduling.
def __init__(self, max_epochs, **kwargs):
super(PolyLRCallback, self).__init__(Phase.TRAIN_BATCH_STEP, **kwargs)
self.max_epochs = max_epochs
def perform_scheduling(self, context):
effective_epoch = context.epoch - self.training_params.lr_warmup_epochs
effective_max_epochs = self.max_epochs - self.training_params.lr_warmup_epochs - self.training_params.lr_cooldown_epochs
current_iter = (self.train_loader_len * effective_epoch + context.batch_idx) / self.training_params.batch_accumulate
max_iter = self.train_loader_len * effective_max_epochs / self.training_params.batch_accumulate
self.lr = self.initial_lr * pow((1.0 - (current_iter / max_iter)), 0.9)
self.update_lr(context.optimizer, context.epoch, context.batch_idx)
def is_lr_scheduling_enabled(self, context):
post_warmup_epochs = self.training_params.max_epochs - self.training_params.lr_cooldown_epochs
return self.training_params.lr_warmup_epochs <= context.epoch < post_warmup_epochs
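For intuition, a standalone sketch of the polynomial decay rule implemented in perform_scheduling above; the loader length and epoch counts are made-up numbers, only the formula is taken from the code.
# lr = initial_lr * (1 - current_iter / max_iter) ** 0.9
initial_lr = 0.1
train_loader_len = 100      # iterations per epoch (made up)
batch_accumulate = 1
effective_max_epochs = 10   # max_epochs minus warmup and cooldown epochs

max_iter = train_loader_len * effective_max_epochs / batch_accumulate
for epoch in range(effective_max_epochs):
    for batch_idx in range(train_loader_len):
        current_iter = (train_loader_len * epoch + batch_idx) / batch_accumulate
        lr = initial_lr * pow(1.0 - current_iter / max_iter, 0.9)
# lr starts at initial_lr and decays towards 0 as current_iter approaches max_iter.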
@register_callback(Callbacks.ROBOFLOW_RESULT_CALLBACK)
class RoboflowResultCallback(Callback):
"""Append the training results to a csv file. Be aware that this does not fully overwrite the existing file, just appends."""
def __init__(self, dataset_name: str, output_path: Optional[str] = None):
:param dataset_name: Name of the dataset that was used to train the model.
:param output_path: Full path to the output csv file. By default, save at 'checkpoint_dir/results.csv'
self.dataset_name = dataset_name
self.output_path = output_path or os.path.join(get_project_checkpoints_dir_path(), "results.csv")
if self.output_path is None:
raise ValueError("Output path must be specified")
super(RoboflowResultCallback, self).__init__()
@multi_process_safe
def on_training_end(self, context: PhaseContext):
with open(self.output_path, mode="a", newline="") as csv_file:
writer = csv.writer(csv_file)
mAP = context.metrics_dict["[email protected]:0.95"].item()
writer.writerow([self.dataset_name, mAP])
@register_lr_scheduler(LRSchedulers.STEP)
class StepLRCallback(LRCallbackBase):
Hard coded step learning rate scheduling (i.e. at specific milestones).
def __init__(self, lr_updates, lr_decay_factor, step_lr_update_freq=None, **kwargs):
super(StepLRCallback, self).__init__(Phase.TRAIN_EPOCH_END, **kwargs)
if step_lr_update_freq and len(lr_updates):
raise ValueError("Only one of [lr_updates, step_lr_update_freq] should be passed to StepLRCallback constructor")
if step_lr_update_freq:
max_epochs = self.training_params.max_epochs - self.training_params.lr_cooldown_epochs
warmup_epochs = self.training_params.lr_warmup_epochs
lr_updates = [
int(np.ceil(step_lr_update_freq * x)) for x in range(1, max_epochs) if warmup_epochs <= int(np.ceil(step_lr_update_freq * x)) < max_epochs
]
elif self.training_params.lr_cooldown_epochs > 0:
logger.warning("Specific lr_updates were passed along with cooldown_epochs > 0," " cooldown will have no effect.")
self.lr_updates = lr_updates
self.lr_decay_factor = lr_decay_factor
def perform_scheduling(self, context):
num_updates_passed = [x for x in self.lr_updates if x <= context.epoch]
self.lr = self.initial_lr * self.lr_decay_factor ** len(num_updates_passed)
self.update_lr(context.optimizer, context.epoch, None)
def is_lr_scheduling_enabled(self, context):
return self.training_params.lr_warmup_epochs <= context.epoch
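A standalone sketch of the step decay rule in perform_scheduling above (milestones and rates are made up): the initial LR is multiplied by lr_decay_factor once for every milestone that has already passed.
initial_lr = 0.1
lr_decay_factor = 0.1
lr_updates = [150, 200, 250]  # made-up milestone epochs

for epoch in (100, 150, 220, 260):
    num_updates_passed = [x for x in lr_updates if x <= epoch]
    lr = initial_lr * lr_decay_factor ** len(num_updates_passed)
    # epoch 100 -> 0.1, epoch 150 -> 0.01, epoch 220 -> 0.001, epoch 260 -> 0.0001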
Phase callback that collects the learning rates in lr_placeholder at the end of each epoch (used for testing). In
the case of multiple parameter groups (i.e multiple learning rates) the learning rate is collected from the first
one. The phase is VALIDATION_EPOCH_END to ensure all lr updates have been performed before calling this callback.
Source code in V3_2/src/super_gradients/training/utils/callbacks/callbacks.py
class TestLRCallback(PhaseCallback):
Phase callback that collects the learning rates in lr_placeholder at the end of each epoch (used for testing). In
the case of multiple parameter groups (i.e multiple learning rates) the learning rate is collected from the first
one. The phase is VALIDATION_EPOCH_END to ensure all lr updates have been performed before calling this callback.
def __init__(self, lr_placeholder):
super(TestLRCallback, self).__init__(Phase.VALIDATION_EPOCH_END)
self.lr_placeholder = lr_placeholder
def __call__(self, context: PhaseContext):
self.lr_placeholder.append(context.optimizer.param_groups[0]["lr"])
TrainingStageSwitchCallback
A phase callback that is called at a specific epoch (epoch start) to support multi-stage training.
It does so by manipulating the objects inside the context.
class TrainingStageSwitchCallbackBase(PhaseCallback):
TrainingStageSwitchCallback
A phase callback that is called at a specific epoch (epoch start) to support multi-stage training.
It does so by manipulating the objects inside the context.
:param next_stage_start_epoch: Epoch idx to apply the stage change.
def __init__(self, next_stage_start_epoch: int):
super(TrainingStageSwitchCallbackBase, self).__init__(phase=Phase.TRAIN_EPOCH_START)
self.next_stage_start_epoch = next_stage_start_epoch
def __call__(self, context: PhaseContext):
if context.epoch == self.next_stage_start_epoch:
self.apply_stage_change(context)
def apply_stage_change(self, context: PhaseContext):
This method is called when the callback is fired on the next_stage_start_epoch,
and holds the stage change logic that should be applied to the context's objects.
:param context: PhaseContext, context of current phase
raise NotImplementedError
@register_callback(Callbacks.YOLOX_TRAINING_STAGE_SWITCH)
class YoloXTrainingStageSwitchCallback(TrainingStageSwitchCallbackBase):
YoloXTrainingStageSwitchCallback
Training stage switch for YoloX training.
Disables mosaic, and manipulates YoloX loss to use L1.
def __init__(self, next_stage_start_epoch: int = 285):
super(YoloXTrainingStageSwitchCallback, self).__init__(next_stage_start_epoch=next_stage_start_epoch)
def apply_stage_change(self, context: PhaseContext):
for transform in context.train_loader.dataset.transforms:
if hasattr(transform, "close"):
transform.close()
iter(context.train_loader)
context.criterion.use_l1 = True
def create_lr_scheduler_callback(
lr_mode: Union[str, Mapping],
train_loader: DataLoader,
net: torch.nn.Module,
training_params: Mapping,
update_param_groups: bool,
optimizer: torch.optim.Optimizer,
) -> PhaseCallback:
Creates the phase callback in charge of LR scheduling, to be used by Trainer.
:param lr_mode: Union[str, Mapping],
When str:
Learning rate scheduling policy, one of ['step','poly','cosine','function'].
'step' refers to constant updates at epoch numbers passed through `lr_updates`. Each update decays the learning rate by `lr_decay_factor`.
'cosine' refers to the Cosine Annealing policy as mentioned in https://arxiv.org/abs/1608.03983.
The final learning rate ratio is controlled by `cosine_final_lr_ratio` training parameter.
'poly' refers to the polynomial decrease: in each epoch iteration `self.lr = self.initial_lr * pow((1.0 - (current_iter / max_iter)), 0.9)`
'function' refers to a user-defined learning rate scheduling function, that is passed through `lr_schedule_function`.
When Mapping, refers to a torch.optim.lr_scheduler._LRScheduler, following the below API:
lr_mode = {LR_SCHEDULER_CLASS_NAME: {**LR_SCHEDULER_KWARGS, "phase": XXX, "metric_name": XXX}}
Where "phase" (of Phase type) controls when to call torch.optim.lr_scheduler._LRScheduler.step().
For instance, in order to:
- Update LR on each batch: Use phase: Phase.TRAIN_BATCH_END
- Update LR after each epoch: Use phase: Phase.TRAIN_EPOCH_END
The "metric_name" refers to the metric to watch (See docs for "metric_to_watch" in train(...)
https://docs.deci.ai/super-gradients/docstring/training/sg_trainer.html) when using
ReduceLROnPlateau. In any other case this kwarg is ignored.
**LR_SCHEDULER_KWARGS are simply passed to the torch scheduler's __init__.
:param train_loader: DataLoader, the Trainer.train_loader used for training.
:param net: torch.nn.Module, the Trainer.net used for training.
:param training_params: Mapping, Trainer.training_params.
:param update_param_groups: bool, Whether the Trainer.net has a specific way of updating its parameter group.
:param optimizer: The optimizer used for training. Will be passed to the LR callback's __init__
(or the torch scheduler's init, depending on the lr_mode value as described above).
:return: a PhaseCallback instance to be used by Trainer for LR scheduling.
if isinstance(lr_mode, str) and lr_mode in LR_SCHEDULERS_CLS_DICT:
sg_lr_callback_cls = LR_SCHEDULERS_CLS_DICT[lr_mode]
sg_lr_callback = sg_lr_callback_cls(
train_loader_len=len(train_loader),
net=net,
training_params=training_params,
update_param_groups=update_param_groups,
**training_params.to_dict(),
)
elif isinstance(lr_mode, Mapping) and list(lr_mode.keys())[0] in TORCH_LR_SCHEDULERS:
if update_param_groups:
logger.warning(
"The network's way of updataing (i.e update_param_groups) is not supported with native " "torch lr schedulers and will have no effect."
lr_scheduler_name = list(lr_mode.keys())[0]
torch_scheduler_params = {k: v for k, v in lr_mode[lr_scheduler_name].items() if k != "phase" and k != "metric_name"}
torch_scheduler_params["optimizer"] = optimizer
torch_scheduler = TORCH_LR_SCHEDULERS[lr_scheduler_name](**torch_scheduler_params)
if get_param(lr_mode[lr_scheduler_name], "phase") is None:
raise ValueError("Phase is required argument when working with torch schedulers.")
if lr_scheduler_name == "ReduceLROnPlateau" and get_param(lr_mode[lr_scheduler_name], "metric_name") is None:
raise ValueError("metric_name is required argument when working with ReduceLROnPlateau schedulers.")
sg_lr_callback = LRSchedulerCallback(
scheduler=torch_scheduler, phase=lr_mode[lr_scheduler_name]["phase"], metric_name=get_param(lr_mode[lr_scheduler_name], "metric_name")
)
else:
raise ValueError(f"Unknown lr_mode: {lr_mode}")
return sg_lr_callback
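As an illustrative sketch of the two lr_mode forms handled above (the scheduler kwargs, metric name and the import path for Phase are assumptions, not taken from this page; the torch scheduler class names are assumed to be registered in TORCH_LR_SCHEDULERS):
from super_gradients.training.utils.callbacks import Phase  # assumed import path

# 1) Built-in SG scheduler, selected by name:
lr_mode = "cosine"

# 2) Native torch scheduler, selected by class name with its kwargs plus "phase":
lr_mode = {"StepLR": {"gamma": 0.1, "step_size": 20, "phase": Phase.TRAIN_EPOCH_END}}

# For ReduceLROnPlateau, "metric_name" is required as well, e.g.:
# lr_mode = {"ReduceLROnPlateau": {"patience": 3, "phase": Phase.VALIDATION_EPOCH_END, "metric_name": "Accuracy"}}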
Source code in V3_2/src/super_gradients/training/utils/callbacks/ppyoloe_switch_callback.py
@register_callback(Callbacks.PPYOLOE_TRAINING_STAGE_SWITCH)
class PPYoloETrainingStageSwitchCallback(TrainingStageSwitchCallbackBase):
PPYoloETrainingStageSwitchCallback
Training stage switch for PPYolo training.
It changes the static bbox assigner to a task-aligned assigner after a certain number of epochs have passed
def __init__(
self,
static_assigner_end_epoch: int = 30,
):
super().__init__(next_stage_start_epoch=static_assigner_end_epoch)
def apply_stage_change(self, context: PhaseContext):
from super_gradients.training.losses import PPYoloELoss
if not isinstance(context.criterion, PPYoloELoss):
raise RuntimeError(
f"A criterion must be an instance of PPYoloELoss when using PPYoloETrainingStageSwitchCallback. " f"Got criterion {repr(context.criterion)}"
context.criterion.use_static_assigner = False
class MissingPretrainedWeightsException(Exception):
"""Exception raised by unsupported pretrianed model.
:param desc: explanation of the error
def __init__(self, desc):
self.message = "Missing pretrained weights: " + desc
super().__init__(self.message)
Given a model state dict and source checkpoints, the method tries to correct the keys in the model_state_dict to fit
the ckpt in order to properly load the weights into the model. If unsuccessful - returns None
:param model_state_dict: the model state_dict
:param source_ckpt: checkpoint dict
:param exclude: optional list of layers to exclude
:param solver: callable with signature (ckpt_key, ckpt_val, model_key, model_val)
that returns a desired weight for ckpt_val.
:return: renamed checkpoint dict (if possible)
Source code in V3_2/src/super_gradients/training/utils/checkpoint_utils.py
def adapt_state_dict_to_fit_model_layer_names(model_state_dict: dict, source_ckpt: dict, exclude: list = [], solver: callable = None):
Given a model state dict and source checkpoints, the method tries to correct the keys in the model_state_dict to fit
the ckpt in order to properly load the weights into the model. If unsuccessful - returns None
:param model_state_dict: the model state_dict
:param source_ckpt: checkpoint dict
:param exclude: optional list of layers to exclude
:param solver: callable with signature (ckpt_key, ckpt_val, model_key, model_val)
that returns a desired weight for ckpt_val.
:return: renamed checkpoint dict (if possible)
if "net" in source_ckpt.keys():
source_ckpt = source_ckpt["net"]
model_state_dict_excluded = {k: v for k, v in model_state_dict.items() if not any(x in k for x in exclude)}
new_ckpt_dict = {}
for (ckpt_key, ckpt_val), (model_key, model_val) in zip(source_ckpt.items(), model_state_dict_excluded.items()):
if solver is not None:
ckpt_val = solver(ckpt_key, ckpt_val, model_key, model_val)
if ckpt_val.shape != model_val.shape:
raise ValueError(f"ckpt layer {ckpt_key} with shape {ckpt_val.shape} does not match {model_key}" f" with shape {model_val.shape} in the model")
new_ckpt_dict[model_key] = ckpt_val
return {"net": new_ckpt_dict}
Adaptively loads state_dict to net, by adapting the state_dict to net's layer names first.
def adaptive_load_state_dict(net: torch.nn.Module, state_dict: dict, strict: Union[bool, StrictLoad], solver=None):
Adaptively loads state_dict to net, by adapting the state_dict to net's layer names first.
:param net: (nn.Module) to load state_dict to
:param state_dict: (dict) Checkpoint state_dict
:param strict: (StrictLoad) key matching strictness
:param solver: callable with signature (ckpt_key, ckpt_val, model_key, model_val)
that returns a desired weight for ckpt_val.
:return:
state_dict = state_dict["net"] if "net" in state_dict else state_dict
# This is a backward compatibility fix for checkpoints that were saved with DataParallel/DistributedDataParallel wrapper
# and contains "module." prefix in all keys
# If all keys start with "module.", then we remove it.
if all([key.startswith("module.") for key in state_dict.keys()]):
state_dict = collections.OrderedDict([(key[7:], value) for key, value in state_dict.items()])
try:
strict_bool = strict if isinstance(strict, bool) else strict != StrictLoad.OFF
net.load_state_dict(state_dict, strict=strict_bool)
except (RuntimeError, ValueError, KeyError) as ex:
if strict == StrictLoad.NO_KEY_MATCHING:
adapted_state_dict = adapt_state_dict_to_fit_model_layer_names(net.state_dict(), state_dict, solver=solver)
net.load_state_dict(adapted_state_dict["net"], strict=True)
elif strict == StrictLoad.KEY_MATCHING:
transfer_weights(net, state_dict)
else:
raise_informative_runtime_error(net.state_dict(), state_dict, ex)
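A usage sketch (the network and checkpoint path are placeholders, and the StrictLoad import location is an assumption): with StrictLoad.NO_KEY_MATCHING, a failed strict load falls back to the layer-name adaptation above.
import torch
from super_gradients.training.utils.checkpoint_utils import adaptive_load_state_dict
from super_gradients.common.data_types import StrictLoad  # assumed import location

net = torch.nn.Linear(10, 2)                                           # placeholder network
checkpoint = torch.load("/path/to/ckpt_best.pth", map_location="cpu")  # placeholder path

adaptive_load_state_dict(net, checkpoint, strict=StrictLoad.NO_KEY_MATCHING)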
Copy the checkpoint from any supported source to a local destination path
:param local_ckpt_destination_dir: destination where the checkpoint will be saved to
:param ckpt_filename: ckpt_best.pth Or ckpt_latest.pth
:param remote_ckpt_source_dir: Name of the source checkpoint to be loaded (S3 model / full URL)
:param path_src: S3 / url
:param overwrite_local_ckpt: determines if checkpoint will be saved in destination dir or in a temp folder
:return: Path to checkpoint
Source code in V3_2/src/super_gradients/training/utils/checkpoint_utils.py
@explicit_params_validation(validation_type="None")
def copy_ckpt_to_local_folder(
local_ckpt_destination_dir: str,
ckpt_filename: str,
remote_ckpt_source_dir: str = None,
path_src: str = "local",
overwrite_local_ckpt: bool = False,
load_weights_only: bool = False,
):
Copy the checkpoint from any supported source to a local destination path
:param local_ckpt_destination_dir: destination where the checkpoint will be saved to
:param ckpt_filename: ckpt_best.pth Or ckpt_latest.pth
:param remote_ckpt_source_dir: Name of the source checkpoint to be loaded (S3 model / full URL)
:param path_src: S3 / url
:param overwrite_local_ckpt: determines if checkpoint will be saved in destination dir or in a temp folder
:return: Path to checkpoint
ckpt_file_full_local_path = None
# IF NOT DEFINED - IT IS SET TO THE TARGET's FOLDER NAME
remote_ckpt_source_dir = local_ckpt_destination_dir if remote_ckpt_source_dir is None else remote_ckpt_source_dir
if not overwrite_local_ckpt:
# CREATE A TEMP FOLDER TO SAVE THE CHECKPOINT TO
download_ckpt_destination_dir = tempfile.gettempdir()
print(
"PLEASE NOTICE - YOU ARE IMPORTING A REMOTE CHECKPOINT WITH overwrite_local_checkpoint = False "
"-> IT WILL BE REDIRECTED TO A TEMP FOLDER AND DELETED ON MACHINE RESTART"
else:
# SAVE THE CHECKPOINT TO MODEL's FOLDER
download_ckpt_destination_dir = pkg_resources.resource_filename("checkpoints", local_ckpt_destination_dir)
if path_src.startswith("s3"):
model_checkpoints_data_interface = ADNNModelRepositoryDataInterfaces(data_connection_location=path_src)
# DOWNLOAD THE FILE FROM S3 TO THE DESTINATION FOLDER
ckpt_file_full_local_path = model_checkpoints_data_interface.load_remote_checkpoints_file(
ckpt_source_remote_dir=remote_ckpt_source_dir,
ckpt_destination_local_dir=download_ckpt_destination_dir,
ckpt_file_name=ckpt_filename,
overwrite_local_checkpoints_file=overwrite_local_ckpt,
)
if not load_weights_only:
# COPY LOG FILES FROM THE REMOTE DIRECTORY TO THE LOCAL ONE ONLY IF LOADING THE CURRENT MODELs CKPT
model_checkpoints_data_interface.load_all_remote_log_files(
model_name=remote_ckpt_source_dir, model_checkpoint_local_dir=download_ckpt_destination_dir
)
if path_src == "url":
ckpt_file_full_local_path = download_ckpt_destination_dir + os.path.sep + ckpt_filename
# DOWNLOAD THE FILE FROM URL TO THE DESTINATION FOLDER
with wait_for_the_master(get_local_rank()):
download_url_to_file(remote_ckpt_source_dir, ckpt_file_full_local_path, progress=True)
return ckpt_file_full_local_path
def load_checkpoint_to_model(
net: torch.nn.Module,
ckpt_local_path: str,
load_backbone: bool = False,
strict: Union[str, StrictLoad] = StrictLoad.NO_KEY_MATCHING,
load_weights_only: bool = False,
load_ema_as_net: bool = False,
load_processing_params: bool = False,
):
Loads the state dict in ckpt_local_path to net and returns the checkpoint's state dict.
:param load_ema_as_net: Will load the EMA inside the checkpoint file to the network when set
:param ckpt_local_path: local path to the checkpoint file
:param load_backbone: whether to load the checkpoint as a backbone
:param net: network to load the checkpoint to
:param strict:
:param load_weights_only: Whether to ignore all other entries other then "net".
:param load_processing_params: Whether to call set_dataset_processing_params on "processing_params" entry inside the
checkpoint file (default=False).
:return:
if isinstance(strict, str):
strict = StrictLoad(strict)
net = unwrap_model(net)
if load_backbone and not hasattr(net, "backbone"):
raise ValueError("No backbone attribute in net - Can't load backbone weights")
# LOAD THE LOCAL CHECKPOINT PATH INTO A state_dict OBJECT
checkpoint = read_ckpt_state_dict(ckpt_path=ckpt_local_path)
if load_ema_as_net:
if "ema_net" not in checkpoint.keys():
raise ValueError("Can't load ema network- no EMA network stored in checkpoint file")
else:
checkpoint["net"] = checkpoint["ema_net"]
# LOAD THE CHECKPOINTS WEIGHTS TO THE MODEL
if load_backbone:
adaptive_load_state_dict(net.backbone, checkpoint, strict)
else:
adaptive_load_state_dict(net, checkpoint, strict)
message_suffix = " checkpoint." if not load_ema_as_net else " EMA checkpoint."
message_model = "model" if not load_backbone else "model's backbone"
logger.info("Successfully loaded " + message_model + " weights from " + ckpt_local_path + message_suffix)
if (isinstance(net, HasPredict)) and load_processing_params:
if "processing_params" not in checkpoint.keys():
raise ValueError("Can't load processing params - could not find any stored in checkpoint file.")
try:
net.set_dataset_processing_params(**checkpoint["processing_params"])
except Exception as e:
logger.warning(
f"Could not set preprocessing pipeline from the checkpoint dataset: {e}. Before calling"
"predict make sure to call set_dataset_processing_params."
if load_weights_only or load_backbone:
# DISCARD ALL THE DATA STORED IN CHECKPOINT OTHER THAN THE WEIGHTS
[checkpoint.pop(key) for key in list(checkpoint.keys()) if key != "net"]
return checkpoint
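A usage sketch (the model factory call and the checkpoint path are placeholders; only load_checkpoint_to_model itself comes from the code above):
from super_gradients.training import models
from super_gradients.training.utils.checkpoint_utils import load_checkpoint_to_model

net = models.get("resnet18", num_classes=10)       # placeholder model
checkpoint = load_checkpoint_to_model(
    net=net,
    ckpt_local_path="/path/to/ckpt_best.pth",      # placeholder path
    load_backbone=False,
    load_weights_only=True,                        # drop everything but the "net" entry
)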
def load_pretrained_weights(model: torch.nn.Module, architecture: str, pretrained_weights: str):
Loads pretrained weights from the MODEL_URLS dictionary to model
:param architecture: name of the model's architecture
:param model: model to load pretrained weights for
:param pretrained_weights: name of the pretrained weights (i.e. imagenet)
:return: None
from super_gradients.common.object_names import Models
model_url_key = architecture + "_" + str(pretrained_weights)
if model_url_key not in MODEL_URLS.keys():
raise MissingPretrainedWeightsException(model_url_key)
url = MODEL_URLS[model_url_key]
if architecture in {Models.YOLO_NAS_S, Models.YOLO_NAS_M, Models.YOLO_NAS_L}:
logger.info(
"License Notification: YOLO-NAS pre-trained weights are subjected to the specific license terms and conditions detailed in \n"
"https://github.com/Deci-AI/super-gradients/blob/master/LICENSE.YOLONAS.md\n"
"By downloading the pre-trained weight files you agree to comply with these terms."
unique_filename = url.split("https://sghub.deci.ai/models/")[1].replace("/", "_").replace(" ", "_")
map_location = torch.device("cpu")
with wait_for_the_master(get_local_rank()):
pretrained_state_dict = load_state_dict_from_url(url=url, map_location=map_location, file_name=unique_filename)
_load_weights(architecture, model, pretrained_state_dict)
def load_pretrained_weights_local(model: torch.nn.Module, architecture: str, pretrained_weights: str):
Loads pretrained weights from a local checkpoint file into the model
:param architecture: name of the model's architecture
:param model: model to load pretrained weights for
:param pretrained_weights: path to the pretrained weights file
:return: None
map_location = torch.device("cpu")
pretrained_state_dict = torch.load(pretrained_weights, map_location=map_location)
_load_weights(architecture, model, pretrained_state_dict)
Given a model state dict and source checkpoints, the method calls "adapt_state_dict_to_fit_model_layer_names"
and enhances the exception_msg if loading the checkpoint_dict via the conversion method is possible
Source code in V3_2/src/super_gradients/training/utils/checkpoint_utils.py
def raise_informative_runtime_error(state_dict, checkpoint, exception_msg):
Given a model state dict and source checkpoints, the method calls "adapt_state_dict_to_fit_model_layer_names"
and enhances the exception_msg if loading the checkpoint_dict via the conversion method is possible
try:
new_ckpt_dict = adapt_state_dict_to_fit_model_layer_names(state_dict, checkpoint)
temp_file = tempfile.NamedTemporaryFile().name + ".pt"
torch.save(new_ckpt_dict, temp_file)
exception_msg = (
f"\n{'=' * 200}\n{str(exception_msg)} \nconvert ckpt via the utils.adapt_state_dict_to_fit_"
f"model_layer_names method\na converted checkpoint file was saved in the path {temp_file}\n{'=' * 200}"
except ValueError as ex: # IN CASE adapt_state_dict_to_fit_model_layer_names WAS UNSUCCESSFUL
exception_msg = f"\n{'=' * 200} \nThe checkpoint and model shapes do not fit, e.g.: {ex}\n{'=' * 200}"
finally:
raise RuntimeError(exception_msg)
def read_ckpt_state_dict(ckpt_path: str, device="cpu") -> Mapping[str, torch.Tensor]:
Reads a checkpoint state dict from a given path or url
:param ckpt_path: Checkpoint path or url
:param device: Target device where tensors should be loaded
:return: Checkpoint state dict object
if ckpt_path.startswith("https://"):
with wait_for_the_master(get_local_rank()):
state_dict = load_state_dict_from_url(ckpt_path, progress=False, map_location=device)
return state_dict
else:
if not os.path.exists(ckpt_path):
raise FileNotFoundError(f"Incorrect Checkpoint path: {ckpt_path} (This should be an absolute path)")
state_dict = torch.load(ckpt_path, map_location=device)
return state_dict
Copy weights from model_state_dict to model, skipping layers that are incompatible (having a different shape).
This method is helpful if you are doing some model surgery and want to load part of the model weights into a different model.
This function will go over all the layers in model_state_dict, try to find a matching layer in model, and copy the weights into it. If the shapes do not match, the layer will be skipped.
def transfer_weights(model: nn.Module, model_state_dict: Mapping[str, Tensor]) -> None:
Copy weights from `model_state_dict` to `model`, skipping layers that are incompatible (Having different shape).
This method is helpful if you are doing some model surgery and want to load
part of the model weights into different model.
This function will go over all the layers in `model_state_dict` and will try to find a matching layer in `model` and
copy the weights into it. If shape will not match, the layer will be skipped.
:param model: Model to load weights into
:param model_state_dict: Model state dict to load weights from
:return: None
for name, value in model_state_dict.items():
try:
model.load_state_dict(collections.OrderedDict([(name, value)]), strict=False)
except RuntimeError:
pass
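A small usage sketch with two synthetic networks that only partially overlap: the matching layer is copied, the mismatched one is skipped silently by the except clause above.
import torch.nn as nn

src = nn.Sequential(nn.Linear(8, 4), nn.Linear(4, 2))
dst = nn.Sequential(nn.Linear(8, 4), nn.Linear(4, 3))  # second layer has a different shape

transfer_weights(dst, src.state_dict())
# dst[0] now holds src[0]'s weights; the incompatible second layer keeps its own initialization.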
Implements access counting mechanism for configuration settings (dicts/lists).
It is achieved by wrapping underlying config and override getitem, getattr methods to catch read operations
and increments access counter for each property.
Source code in V3_2/src/super_gradients/training/utils/config_utils.py
class AccessCounterMixin:
Implements access counting mechanism for configuration settings (dicts/lists).
It is achieved by wrapping underlying config and override __getitem__, __getattr__ methods to catch read operations
and increments access counter for each property.
_access_counter: Mapping[str, int]
_prefix: str # Prefix string
def maybe_wrap_as_counter(self, value, key, count_usage: bool = True):
Return an attribute value optionally wrapped as access counter adapter to trace read counts.
:param value: Attribute value
:param key: Attribute name
:param count_usage: Whether increment usage count for given attribute. Default is True.
:return: wrapped value
key_with_prefix = self._prefix + str(key)
if count_usage:
self._access_counter[key_with_prefix] += 1
if isinstance(value, Mapping):
return AccessCounterDict(value, access_counter=self._access_counter, prefix=key_with_prefix + ".")
if isinstance(value, Iterable) and not isinstance(value, str):
return AccessCounterList(value, access_counter=self._access_counter, prefix=key_with_prefix + ".")
return value
@property
def access_counter(self):
return self._access_counter
@abc.abstractmethod
def get_all_params(self) -> Set[str]:
raise NotImplementedError()
def get_used_params(self) -> Set[str]:
used_params = {k for (k, v) in self._access_counter.items() if v > 0}
return used_params
def get_unused_params(self) -> Set[str]:
unused_params = self.get_all_params() - self.get_used_params()
return unused_params
def __copy__(self):
cls = self.__class__
result = cls.__new__(cls)
result.__dict__.update(self.__dict__)
return result
def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
setattr(result, k, deepcopy(v, memo))
return result
A helper function to check whether all configuration parameters were used in a given block of code. The motivation for
this check is to ensure there are no typos or outdated configuration parameters.
If at least one of the config parameters was not used, this function will raise an UnusedConfigParamException exception.
Example usage:
from super_gradients.training.utils import raise_if_unused_params
with raise_if_unused_params(some_config) as some_config:
do_something_with_config(some_config)
def raise_if_unused_params(config: Union[HpmStruct, DictConfig, ListConfig, Mapping, list, tuple]) -> ConfigInspector:
A helper function to check whether all configuration parameters were used in a given block of code. The motivation for
this check is to ensure there are no typos or outdated configuration parameters.
If at least one of the config parameters was not used, this function will raise an UnusedConfigParamException exception.
Example usage:
>>> from super_gradients.training.utils import raise_if_unused_params
>>> with raise_if_unused_params(some_config) as some_config:
>>> do_something_with_config(some_config)
:param config: A config to check
:return: An instance of ConfigInspector
if isinstance(config, HpmStruct):
wrapper_cls = AccessCounterHpmStruct
elif isinstance(config, (Mapping, DictConfig)):
wrapper_cls = AccessCounterDict
elif isinstance(config, (list, tuple, ListConfig)):
wrapper_cls = AccessCounterList
else:
raise RuntimeError(f"Unsupported type. Root configuration object must be a mapping or list. Got type {type(config)}")
return ConfigInspector(wrapper_cls(config), unused_params_action="raise")
A helper function to check whether all configuration parameters were used in a given block of code. The motivation for
this check is to ensure there are no typos or outdated configuration parameters.
If at least one of the config parameters was not used, this function will emit a warning.
Example usage:
from super_gradients.training.utils import warn_if_unused_params
with warn_if_unused_params(some_config) as some_config:
do_something_with_config(some_config)
def warn_if_unused_params(config):
A helper function to check whether all configuration parameters were used in a given block of code. The motivation for
this check is to ensure there are no typos or outdated configuration parameters.
If at least one of the config parameters was not used, this function will emit a warning.
Example usage:
>>> from super_gradients.training.utils import warn_if_unused_params
>>> with warn_if_unused_params(some_config) as some_config:
>>> do_something_with_config(some_config)
:param config: A config to check
:return: An instance of ConfigInspector
if isinstance(config, HpmStruct):
wrapper_cls = AccessCounterHpmStruct
elif isinstance(config, (Mapping, DictConfig)):
wrapper_cls = AccessCounterDict
elif isinstance(config, (list, tuple, ListConfig)):
wrapper_cls = AccessCounterList
else:
raise RuntimeError("Unsupported type. Root configuration object must be a mapping or list.")
return ConfigInspector(wrapper_cls(config), unused_params_action="warn")
from super_gradients.training.utils.deprecated_utils import wrap_with_warning
from super_gradients.training.utils.callbacks import EpochStepWarmupLRCallback, BatchStepLinearWarmupLRCallback
LR_WARMUP_CLS_DICT = {
    "linear": wrap_with_warning(
        EpochStepWarmupLRCallback,
        message="Parameter `linear` has been made deprecated and will be removed in the next SG release. Please use `linear_epoch` instead",
    ),
    "linear_epoch": EpochStepWarmupLRCallback,
}
def wrap_with_warning(cls: Callable, message: str) -> Any:
Emits a warning when target class of function is called.
>>> from super_gradients.training.utils.deprecated_utils import wrap_with_warning
>>> from super_gradients.training.utils.callbacks import EpochStepWarmupLRCallback, BatchStepLinearWarmupLRCallback
>>> LR_WARMUP_CLS_DICT = {
>>> "linear": wrap_with_warning(
>>> EpochStepWarmupLRCallback,
>>> message=f"Parameter `linear` has been made deprecated and will be removed in the next SG release. Please use `linear_epoch` instead",
>>> 'linear_epoch`': EpochStepWarmupLRCallback,
:param cls: A class or function to wrap
:param message: A message to emit when this class is called
:return: A factory method that returns wrapped class
def _inner_fn(*args, **kwargs):
logger.warning(message)
return cls(*args, **kwargs)
return _inner_fn
class Anchors(nn.Module):
A wrapper module that holds the anchors used by detection models such as Yolo
def __init__(self, anchors_list: List[List], strides: List[int]):
:param anchors_list: of the shape [[w1,h1,w2,h2,w3,h3], [w4,h4,w5,h5,w6,h6] .... where each sublist holds
the width and height of the anchors of a specific detection layer.
i.e. for a model with 3 detection layers, each containing 5 anchors, the format will be 3 sublists of 10 numbers each
The width and height are in pixels (not relative to image size)
:param strides: a list containing the stride of the layers from which the detection heads are fed.
i.e. if the first detection head is connected to the backbone after the input dimensions were reduced by 8, the first number will be 8
super().__init__()
self.__anchors_list = anchors_list
self.__strides = strides
self._check_all_lists(anchors_list)
self._check_all_len_equal_and_even(anchors_list)
self._stride = nn.Parameter(torch.Tensor(strides).float(), requires_grad=False)
anchors = torch.Tensor(anchors_list).float().view(len(anchors_list), -1, 2)
self._anchors = nn.Parameter(anchors / self._stride.view(-1, 1, 1), requires_grad=False)
self._anchor_grid = nn.Parameter(anchors.clone().view(len(anchors_list), 1, -1, 1, 1, 2), requires_grad=False)
@staticmethod
def _check_all_lists(anchors: list) -> bool:
for a in anchors:
if not isinstance(a, (list, ListConfig)):
raise RuntimeError("All objects of anchors_list must be lists")
@staticmethod
def _check_all_len_equal_and_even(anchors: list) -> bool:
len_of_first = len(anchors[0])
for a in anchors:
if len(a) % 2 == 1 or len(a) != len_of_first:
raise RuntimeError("All objects of anchors_list must be of the same even length")
@property
def stride(self) -> nn.Parameter:
return self._stride
@property
def anchors(self) -> nn.Parameter:
return self._anchors
@property
def anchor_grid(self) -> nn.Parameter:
return self._anchor_grid
@property
def detection_layers_num(self) -> int:
return self._anchors.shape[0]
@property
def num_anchors(self) -> int:
return self._anchors.shape[1]
def __repr__(self):
return f"anchors_list: {self.__anchors_list} strides: {self.__strides}"
@register_collate_function()
class CrowdDetectionCollateFN(DetectionCollateFN):
Collate function for Yolox training with additional_batch_items that includes crowd targets
def __init__(self):
super().__init__()
self.expected_item_names = ("image", "targets", "crowd_targets")
def __call__(self, data) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
try:
images_batch, labels_batch, crowd_labels_batch = list(zip(*data))
except (ValueError, TypeError):
raise DatasetItemsException(data_sample=data[0], collate_type=type(self), expected_item_names=self.expected_item_names)
return self._format_images(images_batch), self._format_targets(labels_batch), {"crowd_targets": self._format_targets(crowd_labels_batch)}
class CrowdDetectionPPYoloECollateFN(PPYoloECollateFN):
Collate function for PPYoloE training with additional_batch_items that includes crowd targets
def __init__(self, random_resize_sizes: Union[List[int], None] = None, random_resize_modes: Union[List[int], None] = None):
super().__init__(random_resize_sizes, random_resize_modes)
self.expected_item_names = ("image", "targets", "crowd_targets")
def __call__(self, data) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
if self.random_resize_sizes is not None:
data = self.random_resize(data)
try:
images_batch, labels_batch, crowd_labels_batch = list(zip(*data))
except (ValueError, TypeError):
raise DatasetItemsException(data_sample=data[0], collate_type=type(self), expected_item_names=self.expected_item_names)
return self._format_images(images_batch), self._format_targets(labels_batch), {"crowd_targets": self._format_targets(crowd_labels_batch)}
class DatasetItemsException(Exception):
def __init__(self, data_sample: Tuple, collate_type: Type, expected_item_names: Tuple):
:param data_sample: item(s) returned by a dataset
:param collate_type: type of the collate that caused the exception
:param expected_item_names: tuple of names of items that are expected by the collate to be returned from the dataset
collate_type_name = collate_type.__name__
num_sample_items = len(data_sample) if isinstance(data_sample, tuple) else 1
error_msg = f"`{collate_type_name}` only supports Datasets that return a tuple {expected_item_names}, but got a tuple of len={num_sample_items}"
super().__init__(error_msg)
self.expected_item_names = ("image", "targets")
def __call__(self, data) -> Tuple[torch.Tensor, torch.Tensor]:
try:
images_batch, labels_batch = list(zip(*data))
except (ValueError, TypeError):
raise DatasetItemsException(data_sample=data[0], collate_type=type(self), expected_item_names=self.expected_item_names)
return self._format_images(images_batch), self._format_targets(labels_batch)
def _format_images(self, images_batch: List[Union[torch.Tensor, np.array]]) -> torch.Tensor:
images_batch = [torch.tensor(img) for img in images_batch]
images_batch_stack = torch.stack(images_batch, 0)
if images_batch_stack.shape[3] == 3:
images_batch_stack = torch.moveaxis(images_batch_stack, -1, 1).float()
return images_batch_stack
def _format_targets(self, labels_batch: List[Union[torch.Tensor, np.array]]) -> torch.Tensor:
Stack a batch id column to targets and concatenate
:param labels_batch: a list of targets per image (each of arbitrary length)
:return: one tensor of targets of all images of shape [N, 6], where N is the total number of targets in a batch
and the 1st column is batch item index
labels_batch = [torch.tensor(labels) for labels in labels_batch]
labels_batch_indexed = []
for i, labels in enumerate(labels_batch):
batch_column = labels.new_ones((labels.shape[0], 1)) * i
labels = torch.cat((batch_column, labels), dim=-1)
labels_batch_indexed.append(labels)
return torch.cat(labels_batch_indexed, 0)
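A usage sketch (my_detection_dataset is a hypothetical dataset assumed to return (image, targets) pairs in the format this collate expects):
from torch.utils.data import DataLoader

collate_fn = DetectionCollateFN()
loader = DataLoader(my_detection_dataset, batch_size=16, collate_fn=collate_fn)  # my_detection_dataset is hypothetical

images, targets = next(iter(loader))
# images:  float tensor of shape [16, 3, H, W] (channels moved to dim 1 by _format_images)
# targets: tensor of shape [total_boxes_in_batch, 6]; column 0 holds the batch-item index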
class DetectionPostPredictionCallback(ABC, nn.Module):
def __init__(self) -> None:
super().__init__()
@abstractmethod
def forward(self, x, device: str):
:param x: the output of your model
:param device: the device to move all output tensors into
:return: a list with length batch_size, each item in the list is a detections
with shape: nx6 (x1, y1, x2, y2, confidence, class) where x and y are in range [0,1]
raise NotImplementedError
:param x: the output of your model
:param device: the device to move all output tensors into
:return: a list with length batch_size, each item in the list is a detections
with shape: nx6 (x1, y1, x2, y2, confidence, class) where x and y are in range [0,1]
raise NotImplementedError
Enum class for the different detection output formats
When NORMALIZED is not specified, the type refers to unnormalized image coordinates (of the bboxes).
For example:
LABEL_NORMALIZED_XYXY means [class_idx,x1,y1,x2,y2]
Source code in V3_2/src/super_gradients/training/utils/detection_utils.py
class DetectionTargetsFormat(Enum):
Enum class for the different detection output formats
When NORMALIZED is not specified, the type refers to unnormalized image coordinates (of the bboxes).
For example:
LABEL_NORMALIZED_XYXY means [class_idx,x1,y1,x2,y2]
LABEL_XYXY = "LABEL_XYXY"
XYXY_LABEL = "XYXY_LABEL"
LABEL_NORMALIZED_XYXY = "LABEL_NORMALIZED_XYXY"
NORMALIZED_XYXY_LABEL = "NORMALIZED_XYXY_LABEL"
LABEL_CXCYWH = "LABEL_CXCYWH"
CXCYWH_LABEL = "CXCYWH_LABEL"
LABEL_NORMALIZED_CXCYWH = "LABEL_NORMALIZED_CXCYWH"
NORMALIZED_CXCYWH_LABEL = "NORMALIZED_CXCYWH_LABEL"
class DetectionVisualization:
@staticmethod
def _generate_color_mapping(num_classes: int) -> List[Tuple[int]]:
Generate a unique BGR color for each class
return generate_color_mapping(num_classes=num_classes)
@staticmethod
def _draw_box_title(
color_mapping: List[Tuple[int]],
class_names: List[str],
box_thickness: int,
image_np: np.ndarray,
x1: int,
y1: int,
x2: int,
y2: int,
class_id: int,
pred_conf: float = None,
is_target: bool = False,
):
color = color_mapping[class_id]
class_name = class_names[class_id]
if is_target:
title = f"[GT] {class_name}"
else:
title = f'[Pred] {class_name} {str(round(pred_conf, 2)) if pred_conf is not None else ""}'
image_np = draw_bbox(image=image_np, title=title, x1=x1, y1=y1, x2=x2, y2=y2, box_thickness=box_thickness, color=color)
return image_np
@staticmethod
def _visualize_image(
image_np: np.ndarray,
pred_boxes: np.ndarray,
target_boxes: np.ndarray,
class_names: List[str],
box_thickness: int,
gt_alpha: float,
image_scale: float,
checkpoint_dir: str,
image_name: str,
):
image_np = cv2.resize(image_np, (0, 0), fx=image_scale, fy=image_scale, interpolation=cv2.INTER_NEAREST)
color_mapping = DetectionVisualization._generate_color_mapping(len(class_names))
# Draw predictions
pred_boxes[:, :4] *= image_scale
for box in pred_boxes:
image_np = DetectionVisualization._draw_box_title(
color_mapping, class_names, box_thickness, image_np, *box[:4].astype(int), class_id=int(box[5]), pred_conf=box[4]
)
# Draw ground truths
target_boxes_image = np.zeros_like(image_np, np.uint8)
for box in target_boxes:
target_boxes_image = DetectionVisualization._draw_box_title(
color_mapping, class_names, box_thickness, target_boxes_image, *box[2:], class_id=box[1], is_target=True
)
# Transparent overlay of ground truth boxes
mask = target_boxes_image.astype(bool)
image_np[mask] = cv2.addWeighted(image_np, 1 - gt_alpha, target_boxes_image, gt_alpha, 0)[mask]
if checkpoint_dir is None:
return image_np
else:
pathlib.Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)
cv2.imwrite(os.path.join(checkpoint_dir, str(image_name) + ".jpg"), image_np)
@staticmethod
def _scaled_ccwh_to_xyxy(target_boxes: np.ndarray, h: int, w: int, image_scale: float) -> np.ndarray:
Modifies target_boxes inplace
:param target_boxes: (c1, c2, w, h) boxes in [0, 1] range
:param h: image height
:param w: image width
:param image_scale: desired scale for the boxes w.r.t. w and h
:return: targets in (x1, y1, x2, y2) format
in range [0, w * self.image_scale] [0, h * self.image_scale]
# unscale
target_boxes[:, 2:] *= np.array([[w, h, w, h]])
# x1 = c1 - w // 2; y1 = c2 - h // 2
target_boxes[:, 2] -= target_boxes[:, 4] // 2
target_boxes[:, 3] -= target_boxes[:, 5] // 2
# x2 = w + x1; y2 = h + y1
target_boxes[:, 4] += target_boxes[:, 2]
target_boxes[:, 5] += target_boxes[:, 3]
target_boxes[:, 2:] *= image_scale
target_boxes = target_boxes.astype(int)
return target_boxes
@staticmethod
def visualize_batch(
image_tensor: torch.Tensor,
pred_boxes: List[torch.Tensor],
target_boxes: torch.Tensor,
batch_name: Union[int, str],
class_names: List[str],
checkpoint_dir: str = None,
undo_preprocessing_func: Callable[[torch.Tensor], np.ndarray] = undo_image_preprocessing,
box_thickness: int = 2,
image_scale: float = 1.0,
gt_alpha: float = 0.4,
):
A helper function to visualize detections predicted by a network:
saves images into a given path with a name that is {batch_name}_{image_idx_in_the_batch}.jpg, one batch per call.
Colors are generated on the fly: uniformly sampled from color wheel to support all given classes.
Adjustable:
* Ground truth box transparency;
* Box width;
* Image size (larger or smaller than what's provided)
:param image_tensor: rgb images, (B, H, W, 3)
:param pred_boxes: boxes after NMS for each image in a batch, each (Num_boxes, 6),
values on dim 1 are: x1, y1, x2, y2, confidence, class
:param target_boxes: (Num_targets, 6), values on dim 1 are: image id in a batch, class, x y w h
(coordinates scaled to [0, 1])
:param batch_name: id of the current batch to use for image naming
:param class_names: names of all classes, each on its own index
:param checkpoint_dir: a path where images with boxes will be saved. if None, the result images will
be returned as a list of numpy image arrays
:param undo_preprocessing_func: a function to convert preprocessed images tensor into a batch of cv2-like images
:param box_thickness: box line thickness in px
:param image_scale: scale of an image w.r.t. given image size,
e.g. incoming images are (320x320), use scale = 2. to preview in (640x640)
:param gt_alpha: a value in [0., 1.] transparency on ground truth boxes,
0 for invisible, 1 for fully opaque
image_np = undo_preprocessing_func(image_tensor.detach())
targets = DetectionVisualization._scaled_ccwh_to_xyxy(target_boxes.detach().cpu().numpy(), *image_np.shape[1:3], image_scale)
out_images = []
for i in range(image_np.shape[0]):
preds = pred_boxes[i].detach().cpu().numpy() if pred_boxes[i] is not None else np.empty((0, 6))
targets_cur = targets[targets[:, 0] == i]
image_name = "_".join([str(batch_name), str(i)])
res_image = DetectionVisualization._visualize_image(
image_np[i], preds, targets_cur, class_names, box_thickness, gt_alpha, image_scale, checkpoint_dir, image_name
)
if res_image is not None:
out_images.append(res_image)
return out_images
class IouThreshold(tuple, Enum):
MAP_05 = (0.5, 0.5)
MAP_05_TO_095 = (0.5, 0.95)
def is_range(self):
return self[0] != self[1]
def to_tensor(self):
if self.is_range():
return self.from_bounds(self[0], self[1], step=0.05)
else:
return torch.tensor([self[0]])
@classmethod
def from_bounds(cls, low: float, high: float, step: float = 0.05) -> torch.Tensor:
Create a tensor with values from low (including) to high (including) with a given step size.
:param low: Lower bound
:param high: Upper bound
:param step: Step size
:return: Tensor of [low, low + step, low + 2 * step, ..., high]
n_iou_thresh = int(round((high - low) / step)) + 1
return torch.linspace(low, high, n_iou_thresh)
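A quick illustration of the two enum members and to_tensor() above:
thresholds = IouThreshold.MAP_05_TO_095.to_tensor()
# tensor([0.5000, 0.5500, ..., 0.9500]) -- 10 thresholds, 0.05 apart

single = IouThreshold.MAP_05.to_tensor()
# tensor([0.5000]) -- is_range() is False, so only the single threshold is returned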
class NMS_Type(str, Enum):
Type of non max suppression algorithm that can be used for post processing detection
ITERATIVE = "iterative"
MATRIX = "matrix"
class PPYoloECollateFN(DetectionCollateFN):
Collate function for PPYoloE training
def __init__(self, random_resize_sizes: Union[List[int], None] = None, random_resize_modes: Union[List[int], None] = None):
:param random_resize_sizes: (rows, cols)
super().__init__()
self.random_resize_sizes = random_resize_sizes
self.random_resize_modes = random_resize_modes
def __repr__(self):
return f"PPYoloECollateFN(random_resize_sizes={self.random_resize_sizes}, random_resize_modes={self.random_resize_modes})"
def __str__(self):
return self.__repr__()
def __call__(self, data) -> Tuple[torch.Tensor, torch.Tensor]:
if self.random_resize_sizes is not None:
data = self.random_resize(data)
return super().__call__(data)
def random_resize(self, batch):
target_size = random.choice(self.random_resize_sizes)
interpolation = random.choice(self.random_resize_modes)
batch = [self.random_resize_sample(sample, target_size, interpolation) for sample in batch]
return batch
def random_resize_sample(self, sample, target_size, interpolation):
if len(sample) == 2:
image, targets = sample # TARGETS ARE IN LABEL_CXCYWH
with_crowd = False
elif len(sample) == 3:
image, targets, crowd_targets = sample
with_crowd = True
else:
raise DatasetItemsException(data_sample=sample, collate_type=type(self), expected_item_names=self.expected_item_names)
dsize = int(target_size), int(target_size)
scale_factors = target_size / image.shape[0], target_size / image.shape[1]
image = cv2.resize(
image,
dsize=dsize,
interpolation=interpolation,
)
sy, sx = scale_factors
targets[:, 1:5] *= np.array([[sx, sy, sx, sy]], dtype=targets.dtype)
if with_crowd:
crowd_targets[:, 1:5] *= np.array([[sx, sy, sx, sy]], dtype=targets.dtype)
return image, targets, crowd_targets
return image, targets
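A hedged construction sketch; the sizes and cv2 interpolation flags below are arbitrary picks, not library defaults:
>>> import cv2
>>> collate_fn = PPYoloECollateFN(random_resize_sizes=[480, 512, 544, 576, 608],
...                               random_resize_modes=[cv2.INTER_LINEAR, cv2.INTER_CUBIC])
>>> # loader = DataLoader(train_dataset, batch_size=16, collate_fn=collate_fn)  # hypothetical dataset/loader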
def adjust_box_anns(bbox, scale_ratio, padw, padh, w_max, h_max):
Adjusts the bbox annotations of rescaled, padded image.
:param bbox: (np.array) bbox to modify.
:param scale_ratio: (float) scale ratio between rescale output image and original one.
:param padw: (int) width padding size.
:param padh: (int) height padding size.
:param w_max: (int) width border.
:param h_max: (int) height border
:return: modified bbox (np.array)
bbox[:, 0::2] = np.clip(bbox[:, 0::2] * scale_ratio + padw, 0, w_max)
bbox[:, 1::2] = np.clip(bbox[:, 1::2] * scale_ratio + padh, 0, h_max)
return bbox
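A small worked example with arbitrary values: coordinates are scaled by 0.5, shifted by the padding offsets, then clipped to the borders.
>>> import numpy as np
>>> bbox = np.array([[10., 20., 30., 40.]])
>>> adjust_box_anns(bbox, scale_ratio=0.5, padw=100, padh=50, w_max=640, h_max=480)  # -> [[105., 60., 115., 70.]]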
def box_iou(box1: torch.Tensor, box2: torch.Tensor) -> torch.Tensor:
# https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
Return intersection-over-union (Jaccard index) of boxes.
Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
:param box1: Tensor of shape [N, 4]
:param box2: Tensor of shape [M, 4]
:return: iou, Tensor of shape [N, M]: the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2
def box_area(box):
# box = 4xn
return (box[2] - box[0]) * (box[3] - box[1])
area1 = box_area(box1.T)
area2 = box_area(box2.T)
# inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
return inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter)
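A quick sanity check with arbitrary boxes: identical boxes give IoU 1.0, a partially overlapping pair gives 25 / 175 ≈ 0.143.
>>> a = torch.tensor([[0., 0., 10., 10.], [5., 5., 15., 15.]])
>>> b = torch.tensor([[0., 0., 10., 10.]])
>>> box_iou(a, b)  # tensor([[1.0000], [0.1429]])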
def calc_bbox_iou_matrix(pred: torch.Tensor):
calculate iou for every pair of boxes in the boxes vector
:param pred: a 3-dimensional tensor containing all boxes for a batch of images [N, num_boxes, 4], where
each box format is [x1,y1,x2,y2]
:return: a 3-dimensional matrix where M_i_j_k is the iou of box j and box k of the i'th image in the batch
box = pred[:, :, :4] #
b1_x1, b1_y1 = box[:, :, 0].unsqueeze(1), box[:, :, 1].unsqueeze(1)
b1_x2, b1_y2 = box[:, :, 2].unsqueeze(1), box[:, :, 3].unsqueeze(1)
b2_x1 = b1_x1.transpose(2, 1)
b2_x2 = b1_x2.transpose(2, 1)
b2_y1 = b1_y1.transpose(2, 1)
b2_y2 = b1_y2.transpose(2, 1)
intersection_area = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
# Union Area
w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1
w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1
union_area = (w1 * h1 + 1e-16) + w2 * h2 - intersection_area
ious = intersection_area / union_area
return ious
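The same two boxes as a single-image batch (shape [1, 2, 4]) yield a symmetric [1, 2, 2] IoU matrix with ones on the diagonal:
>>> boxes = torch.tensor([[[0., 0., 10., 10.], [5., 5., 15., 15.]]])
>>> calc_bbox_iou_matrix(boxes)[0]  # tensor([[1.0000, 0.1429], [0.1429, 1.0000]])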
def calculate_bbox_iou_matrix(box1, box2, x1y1x2y2=True, GIoU: bool = False, DIoU=False, CIoU=False, eps=1e-9):
calculate iou matrix containing the iou of every couple iou(i, j) where i is in box1 and j is in box2
:param box1: a 2D tensor of boxes (shape N x 4)
:param box2: a 2D tensor of boxes (shape M x 4)
:param x1y1x2y2: boxes format is x1y1x2y2 (True) or xywh where xy is the center (False)
:return: a 2D iou matrix (shape NxM)
if box1.dim() > 1:
box1 = box1.T
# Get the coordinates of bounding boxes
if x1y1x2y2: # x1, y1, x2, y2 = box1
b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3]
else: # x, y, w, h = box1
b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2
b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2
b1_x1, b1_y1, b1_x2, b1_y2 = b1_x1.unsqueeze(1), b1_y1.unsqueeze(1), b1_x2.unsqueeze(1), b1_y2.unsqueeze(1)
return _iou(CIoU, DIoU, GIoU, b1_x1, b1_x2, b1_y1, b1_y2, b2_x1, b2_x2, b2_y1, b2_y2, eps)
def compute_box_area(box: torch.Tensor) -> torch.Tensor:
Compute the area of one or many boxes.
:param box: One or many boxes, shape = (4, ?), each box in format (x1, y1, x2, y2)
:return: Area of every box, shape = (1, ?)
# box = 4xn
return (box[2] - box[0]) * (box[3] - box[1])
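Note the (4, ?) layout: boxes are passed transposed, one column per box (arbitrary values below).
>>> boxes = torch.tensor([[0., 0., 4., 5.], [1., 1., 3., 3.]])
>>> compute_box_area(boxes.T)  # tensor([20., 4.])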
def compute_detection_matching(
output: List[torch.Tensor],
targets: torch.Tensor,
height: int,
width: int,
iou_thresholds: torch.Tensor,
denormalize_targets: bool,
device: str,
crowd_targets: Optional[torch.Tensor] = None,
top_k: int = 100,
return_on_cpu: bool = True,
) -> List[Tuple]:
Match predictions (NMS output) and the targets (ground truth) with respect to IoU and confidence score.
:param output: list (of length batch_size) of Tensors of shape (num_predictions, 6)
format: (x1, y1, x2, y2, confidence, class_label) where x1,y1,x2,y2 are according to image size
:param targets: targets for all images of shape (total_num_targets, 6)
format: (index, x, y, w, h, label) where x,y,w,h are in range [0,1]
:param height: dimensions of the image
:param width: dimensions of the image
:param iou_thresholds: Threshold to compute the mAP
:param device: Device
:param crowd_targets: crowd targets for all images of shape (total_num_crowd_targets, 6)
format: (index, x, y, w, h, label) where x,y,w,h are in range [0,1]
:param top_k: Number of predictions to keep per class, ordered by confidence score
:param denormalize_targets: If True, denormalize the targets and crowd_targets
:param return_on_cpu: If True, the output will be returned on "CPU", otherwise it will be returned on "device"
:return: list of the following tensors, for every image:
:preds_matched: Tensor of shape (num_img_predictions, n_iou_thresholds)
True when prediction (i) is matched with a target with respect to the (j)th IoU threshold
:preds_to_ignore: Tensor of shape (num_img_predictions, n_iou_thresholds)
True when prediction (i) is matched with a crowd target with respect to the (j)th IoU threshold
:preds_scores: Tensor of shape (num_img_predictions), confidence score for every prediction
:preds_cls: Tensor of shape (num_img_predictions), predicted class for every prediction
:targets_cls: Tensor of shape (num_img_targets), ground truth class for every target
output = map(lambda tensor: None if tensor is None else tensor.to(device), output)
targets, iou_thresholds = targets.to(device), iou_thresholds.to(device)
# If crowd_targets is not provided, we patch it with an empty tensor
crowd_targets = torch.zeros(size=(0, 6), device=device) if crowd_targets is None else crowd_targets.to(device)
batch_metrics = []
for img_i, img_preds in enumerate(output):
# If img_preds is None (not prediction for this image), we patch it with an empty tensor
img_preds = img_preds if img_preds is not None else torch.zeros(size=(0, 6), device=device)
img_targets = targets[targets[:, 0] == img_i, 1:]
img_crowd_targets = crowd_targets[crowd_targets[:, 0] == img_i, 1:]
img_matching_tensors = compute_img_detection_matching(
preds=img_preds,
targets=img_targets,
crowd_targets=img_crowd_targets,
denormalize_targets=denormalize_targets,
height=height,
width=width,
device=device,
iou_thresholds=iou_thresholds,
top_k=top_k,
return_on_cpu=return_on_cpu,
)
batch_metrics.append(img_matching_tensors)
return batch_metrics
def compute_detection_metrics(
preds_matched: torch.Tensor,
preds_to_ignore: torch.Tensor,
preds_scores: torch.Tensor,
preds_cls: torch.Tensor,
targets_cls: torch.Tensor,
device: str,
recall_thresholds: Optional[torch.Tensor] = None,
score_threshold: Optional[float] = 0.1,
calc_best_score_thresholds: bool = False,
) -> Tuple:
Compute the list of precision, recall, MaP and f1 for every recall IoU threshold and for every class.
:param preds_matched: Tensor of shape (num_predictions, n_iou_thresholds)
True when prediction (i) is matched with a target with respect to the (j)th IoU threshold
:param preds_to_ignore: Tensor of shape (num_predictions, n_iou_thresholds)
True when prediction (i) is matched with a crowd target with respect to the (j)th IoU threshold
:param preds_scores: Tensor of shape (num_predictions), confidence score for every prediction
:param preds_cls: Tensor of shape (num_predictions), predicted class for every prediction
:param targets_cls: Tensor of shape (num_targets), ground truth class for every target box to be detected
:param recall_thresholds: Recall thresholds used to compute MaP.
:param score_threshold: Minimum confidence score to consider a prediction for the computation of
precision, recall and f1 (not MaP)
:param device: Device
:param calc_best_score_thresholds: If True, the best confidence score threshold is computed for each class
:return:
:ap, precision, recall, f1: Tensors of shape (n_class, nb_iou_thrs)
:unique_classes: Vector with all unique target classes
:best_score_threshold: torch.float with the best overall score threshold if calc_best_score_thresholds
is True else None
:best_score_threshold_per_cls: dict that stores the best score threshold for each class, if
calc_best_score_thresholds is True else None
preds_matched, preds_to_ignore = preds_matched.to(device), preds_to_ignore.to(device)
preds_scores, preds_cls, targets_cls = preds_scores.to(device), preds_cls.to(device), targets_cls.to(device)
recall_thresholds = torch.linspace(0, 1, 101, device=device) if recall_thresholds is None else recall_thresholds.to(device)
unique_classes = torch.unique(targets_cls).long()
n_class, nb_iou_thrs = len(unique_classes), preds_matched.shape[-1]
ap = torch.zeros((n_class, nb_iou_thrs), device=device)
precision = torch.zeros((n_class, nb_iou_thrs), device=device)
recall = torch.zeros((n_class, nb_iou_thrs), device=device)
nb_score_thrs = 101
all_score_thresholds = torch.linspace(0, 1, nb_score_thrs, device=device)
f1_per_class_per_threshold = torch.zeros((n_class, nb_score_thrs), device=device) if calc_best_score_thresholds else None
best_score_threshold_per_cls = dict() if calc_best_score_thresholds else None
for cls_i, cls in enumerate(unique_classes):
cls_preds_idx, cls_targets_idx = (preds_cls == cls), (targets_cls == cls)
cls_ap, cls_precision, cls_recall, cls_f1_per_threshold, cls_best_score_threshold = compute_detection_metrics_per_cls(
preds_matched=preds_matched[cls_preds_idx],
preds_to_ignore=preds_to_ignore[cls_preds_idx],
preds_scores=preds_scores[cls_preds_idx],
n_targets=cls_targets_idx.sum(),
recall_thresholds=recall_thresholds,
score_threshold=score_threshold,
device=device,
calc_best_score_thresholds=calc_best_score_thresholds,
nb_score_thrs=nb_score_thrs,
)
ap[cls_i, :] = cls_ap
precision[cls_i, :] = cls_precision
recall[cls_i, :] = cls_recall
if calc_best_score_thresholds:
f1_per_class_per_threshold[cls_i, :] = cls_f1_per_threshold
best_score_threshold_per_cls[f"Best_score_threshold_cls_{int(cls)}"] = cls_best_score_threshold
f1 = 2 * precision * recall / (precision + recall + 1e-16)
if calc_best_score_thresholds:
mean_f1_across_classes = torch.mean(f1_per_class_per_threshold, dim=0)
best_score_threshold = all_score_thresholds[torch.argmax(mean_f1_across_classes)]
else:
best_score_threshold = None
return ap, precision, recall, f1, unique_classes, best_score_threshold, best_score_threshold_per_cls
Source code in V3_2/src/super_gradients/training/utils/detection_utils.py
def compute_detection_metrics_per_cls(
preds_matched: torch.Tensor,
preds_to_ignore: torch.Tensor,
preds_scores: torch.Tensor,
n_targets: int,
recall_thresholds: torch.Tensor,
score_threshold: float,
device: str,
calc_best_score_thresholds: bool = False,
nb_score_thrs: int = 101,
):
Compute the list of precision, recall and MaP of a given class for every recall IoU threshold.
:param preds_matched: Tensor of shape (num_predictions, n_iou_thresholds)
True when prediction (i) is matched with a target
with respect to the(j)th IoU threshold
:param preds_to_ignore: Tensor of shape (num_predictions, n_iou_thresholds)
True when prediction (i) is matched with a crowd target
with respect to the (j)th IoU threshold
:param preds_scores: Tensor of shape (num_predictions), confidence score for every prediction
:param n_targets: Number of target boxes of this class
:param recall_thresholds: Tensor of shape (max_n_rec_thresh) list of recall thresholds used to compute MaP
:param score_threshold: Minimum confidence score to consider a prediction for the computation of
precision and recall (not MaP)
:param device: Device
:param calc_best_score_thresholds: If True, the best confidence score threshold is computed for this class
:param nb_score_thrs: Number of score thresholds to consider when calc_best_score_thresholds is True
:return:
:ap, precision, recall: Tensors of shape (nb_iou_thrs)
:mean_f1_per_threshold: Tensor of shape (nb_score_thresholds) if calc_best_score_thresholds is True else None
:best_score_threshold: torch.float if calc_best_score_thresholds is True else None
nb_iou_thrs = preds_matched.shape[-1]
mean_f1_per_threshold = torch.zeros(nb_score_thrs, device=device) if calc_best_score_thresholds else None
best_score_threshold = torch.tensor(0.0, dtype=torch.float, device=device) if calc_best_score_thresholds else None
tps = preds_matched
fps = torch.logical_and(torch.logical_not(preds_matched), torch.logical_not(preds_to_ignore))
if len(tps) == 0:
return (
torch.zeros(nb_iou_thrs, device=device),
torch.zeros(nb_iou_thrs, device=device),
torch.zeros(nb_iou_thrs, device=device),
mean_f1_per_threshold,
best_score_threshold,
)
# Sort by decreasing score
dtype = torch.uint8 if preds_scores.is_cuda and preds_scores.dtype is torch.bool else preds_scores.dtype
sort_ind = torch.argsort(preds_scores.to(dtype), descending=True)
tps = tps[sort_ind, :]
fps = fps[sort_ind, :]
preds_scores = preds_scores[sort_ind].contiguous()
# Rolling sum over the predictions
rolling_tps = torch.cumsum(tps, axis=0, dtype=torch.float)
rolling_fps = torch.cumsum(fps, axis=0, dtype=torch.float)
rolling_recalls = rolling_tps / n_targets
rolling_precisions = rolling_tps / (rolling_tps + rolling_fps + torch.finfo(torch.float64).eps)
# Reversed cummax to only have decreasing values
rolling_precisions = rolling_precisions.flip(0).cummax(0).values.flip(0)
# ==================
# RECALL & PRECISION
# We want the rolling precision/recall at index i so that: preds_scores[i-1] >= score_threshold > preds_scores[i]
# Note: torch.searchsorted works on increasing sequence and preds_scores is decreasing, so we work with "-"
# Note2: right=True due to negation
lowest_score_above_threshold = torch.searchsorted(-preds_scores, -score_threshold, right=True)
if lowest_score_above_threshold == 0: # Here score_threshold > preds_scores[0], so no pred is above the threshold
recall = torch.zeros(nb_iou_thrs, device=device)
precision = torch.zeros(nb_iou_thrs, device=device) # the precision is not really defined when no pred but we need to give it a value
else:
recall = rolling_recalls[lowest_score_above_threshold - 1]
precision = rolling_precisions[lowest_score_above_threshold - 1]
# ==================
# BEST CONFIDENCE SCORE THRESHOLD PER CLASS
if calc_best_score_thresholds:
all_score_thresholds = torch.linspace(0, 1, nb_score_thrs, device=device)
# We want the rolling precision/recall at index i so that: preds_scores[i-1] > score_threshold >= preds_scores[i]
# Note: torch.searchsorted works on increasing sequence and preds_scores is decreasing, so we work with "-"
lowest_scores_above_thresholds = torch.searchsorted(-preds_scores, -all_score_thresholds, right=True)
# When score_threshold > preds_scores[0], then no pred is above the threshold, so we pad with zeros
rolling_recalls_padded = torch.cat((torch.zeros(1, nb_iou_thrs, device=device), rolling_recalls), dim=0)
rolling_precisions_padded = torch.cat((torch.zeros(1, nb_iou_thrs, device=device), rolling_precisions), dim=0)
# shape = (n_score_thresholds, nb_iou_thrs)
recalls_per_threshold = torch.index_select(input=rolling_recalls_padded, dim=0, index=lowest_scores_above_thresholds)
precisions_per_threshold = torch.index_select(input=rolling_precisions_padded, dim=0, index=lowest_scores_above_thresholds)
# shape (n_score_thresholds, nb_iou_thrs)
f1_per_threshold = 2 * recalls_per_threshold * precisions_per_threshold / (recalls_per_threshold + precisions_per_threshold + 1e-16)
mean_f1_per_threshold = torch.mean(f1_per_threshold, dim=1) # average over iou thresholds
best_score_threshold = all_score_thresholds[torch.argmax(mean_f1_per_threshold)]
# ==================
# AVERAGE PRECISION
# shape = (nb_iou_thrs, n_recall_thresholds)
recall_thresholds = recall_thresholds.view(1, -1).repeat(nb_iou_thrs, 1)
# We want the index i so that: rolling_recalls[i-1] < recall_thresholds[k] <= rolling_recalls[i]
# Note: when recall_thresholds[k] > max(rolling_recalls), i = len(rolling_recalls)
# Note2: we work with transpose (.T) to apply torch.searchsorted on first dim instead of the last one
recall_threshold_idx = torch.searchsorted(rolling_recalls.T.contiguous(), recall_thresholds, right=False).T
# When recall_thresholds[k] > max(rolling_recalls), rolling_precisions[i] is not defined, and we want precision = 0
rolling_precisions = torch.cat((rolling_precisions, torch.zeros(1, nb_iou_thrs, device=device)), dim=0)
# shape = (n_recall_thresholds, nb_iou_thrs)
sampled_precision_points = torch.gather(input=rolling_precisions, index=recall_threshold_idx, dim=0)
# Average over the recall_thresholds
ap = sampled_precision_points.mean(0)
return ap, precision, recall, mean_f1_per_threshold, best_score_threshold
def compute_img_detection_matching(
preds: torch.Tensor,
targets: torch.Tensor,
crowd_targets: torch.Tensor,
height: int,
width: int,
iou_thresholds: torch.Tensor,
device: str,
denormalize_targets: bool,
top_k: int = 100,
return_on_cpu: bool = True,
) -> Tuple:
Match predictions (NMS output) and the targets (ground truth) with respect to IoU and confidence score
for a given image.
:param preds: Tensor of shape (num_img_predictions, 6)
format: (x1, y1, x2, y2, confidence, class_label) where x1,y1,x2,y2 are according to image size
:param targets: targets for this image of shape (num_img_targets, 5)
format: (label, cx, cy, w, h) where cx, cy, w, h are in range [0,1]
:param height: dimensions of the image
:param width: dimensions of the image
:param iou_thresholds: Threshold to compute the mAP
:param crowd_targets: crowd targets for all images of shape (total_num_crowd_targets, 6)
format: (index, x, y, w, h, label) where x,y,w,h are in range [0,1]
:param top_k: Number of predictions to keep per class, ordered by confidence score
:param device: Device
:param denormalize_targets: If True, denormalize the targets and crowd_targets
:param return_on_cpu: If True, the output will be returned on "CPU", otherwise it will be returned on "device"
:return:
:preds_matched: Tensor of shape (num_img_predictions, n_iou_thresholds)
True when prediction (i) is matched with a target with respect to the (j)th IoU threshold
:preds_to_ignore: Tensor of shape (num_img_predictions, n_iou_thresholds)
True when prediction (i) is matched with a crowd target with respect to the (j)th IoU threshold
:preds_scores: Tensor of shape (num_img_predictions), confidence score for every prediction
:preds_cls: Tensor of shape (num_img_predictions), predicted class for every prediction
:targets_cls: Tensor of shape (num_img_targets), ground truth class for every target
num_iou_thresholds = len(iou_thresholds)
if preds is None or len(preds) == 0:
if return_on_cpu:
device = "cpu"
preds_matched = torch.zeros((0, num_iou_thresholds), dtype=torch.bool, device=device)
preds_to_ignore = torch.zeros((0, num_iou_thresholds), dtype=torch.bool, device=device)
preds_scores = torch.tensor([], dtype=torch.float32, device=device)
preds_cls = torch.tensor([], dtype=torch.float32, device=device)
targets_cls = targets[:, 0].to(device=device)
return preds_matched, preds_to_ignore, preds_scores, preds_cls, targets_cls
preds_matched = torch.zeros(len(preds), num_iou_thresholds, dtype=torch.bool, device=device)
targets_matched = torch.zeros(len(targets), num_iou_thresholds, dtype=torch.bool, device=device)
preds_to_ignore = torch.zeros(len(preds), num_iou_thresholds, dtype=torch.bool, device=device)
preds_cls, preds_box, preds_scores = preds[:, -1], preds[:, 0:4], preds[:, 4]
targets_cls, targets_box = targets[:, 0], targets[:, 1:5]
crowd_targets_cls, crowd_target_box = crowd_targets[:, 0], crowd_targets[:, 1:5]
# Ignore all but the predictions that were top_k for their class
preds_idx_to_use = get_top_k_idx_per_cls(preds_scores, preds_cls, top_k)
preds_to_ignore[:, :] = True
preds_to_ignore[preds_idx_to_use] = False
if len(targets) > 0 or len(crowd_targets) > 0:
# CHANGE bboxes TO FIT THE IMAGE SIZE
change_bbox_bounds_for_image_size(preds, (height, width))
targets_box = cxcywh2xyxy(targets_box)
crowd_target_box = cxcywh2xyxy(crowd_target_box)
if denormalize_targets:
targets_box[:, [0, 2]] *= width
targets_box[:, [1, 3]] *= height
crowd_target_box[:, [0, 2]] *= width
crowd_target_box[:, [1, 3]] *= height
if len(targets) > 0:
# shape = (n_preds x n_targets)
iou = box_iou(preds_box[preds_idx_to_use], targets_box)
# Fill IoU values at index (i, j) with 0 when the prediction (i) and target(j) are of different class
# Filling with 0 is equivalent to ignoring these values since we want IoU > iou_threshold > 0
cls_mismatch = preds_cls[preds_idx_to_use].view(-1, 1) != targets_cls.view(1, -1)
iou[cls_mismatch] = 0
# The matching priority is first detection confidence and then IoU value.
# The detection is already sorted by confidence in NMS, so here for each prediction we order the targets by iou.
sorted_iou, target_sorted = iou.sort(descending=True, stable=True)
# Only iterate over IoU values higher than min threshold to speed up the process
for pred_selected_i, target_sorted_i in (sorted_iou > iou_thresholds[0]).nonzero(as_tuple=False):
# pred_selected_i and target_sorted_i are relative to filters/sorting, so we extract their absolute indexes
pred_i = preds_idx_to_use[pred_selected_i]
target_i = target_sorted[pred_selected_i, target_sorted_i]
# Vector[j], True when IoU(pred_i, target_i) is above the (j)th threshold
is_iou_above_threshold = sorted_iou[pred_selected_i, target_sorted_i] > iou_thresholds
# Vector[j], True when both pred_i and target_i are not matched yet for the (j)th threshold
are_candidates_free = torch.logical_and(~preds_matched[pred_i, :], ~targets_matched[target_i, :])
# Vector[j], True when (pred_i, target_i) can be matched for the (j)th threshold
are_candidates_good = torch.logical_and(is_iou_above_threshold, are_candidates_free)
# For every threshold (j) where target_i and pred_i can be matched together ( are_candidates_good[j]==True )
# fill the matching placeholders with True
targets_matched[target_i, are_candidates_good] = True
preds_matched[pred_i, are_candidates_good] = True
# When all the targets are matched with a prediction for every IoU Threshold, stop.
if targets_matched.all():
break
# Crowd targets can be matched with many predictions.
# Therefore, for every prediction we just need to check if it has IoA large enough with any crowd target.
if len(crowd_targets) > 0:
# shape = (n_preds_to_use x n_crowd_targets)
ioa = crowd_ioa(preds_box[preds_idx_to_use], crowd_target_box)
# Fill IoA values at index (i, j) with 0 when the prediction (i) and target(j) are of different class
# Filling with 0 is equivalent to ignoring these values since we want IoA > threshold > 0
cls_mismatch = preds_cls[preds_idx_to_use].view(-1, 1) != crowd_targets_cls.view(1, -1)
ioa[cls_mismatch] = 0
# For each prediction, we keep its highest score with any crowd target (of same class)
# shape = (n_preds_to_use)
best_ioa, _ = ioa.max(1)
# If a prediction has IoA higher than threshold (with any target of same class), then there is a match
# shape = (n_preds_to_use x iou_thresholds)
is_matching_with_crowd = best_ioa.view(-1, 1) > iou_thresholds.view(1, -1)
preds_to_ignore[preds_idx_to_use] = torch.logical_or(preds_to_ignore[preds_idx_to_use], is_matching_with_crowd)
if return_on_cpu:
preds_matched = preds_matched.to("cpu")
preds_to_ignore = preds_to_ignore.to("cpu")
preds_scores = preds_scores.to("cpu")
preds_cls = preds_cls.to("cpu")
targets_cls = targets_cls.to("cpu")
return preds_matched, preds_to_ignore, preds_scores, preds_cls, targets_cls
Source code in V3_2/src/super_gradients/training/utils/detection_utils.py
def convert_cxcywh_bbox_to_xyxy(input_bbox: torch.Tensor):
Converts bounding box format from [cx, cy, w, h] to [x1, y1, x2, y2]
:param input_bbox: input bbox either 2-dimensional (for all boxes of a single image) or 3-dimensional (for
boxes of a batch of images)
:return: Converted bbox in same dimensions as the original
need_squeeze = False
# The input is always processed as a batch. If it is not a batch, it is unsqueezed, processed and then squeezed back.
if input_bbox.dim() < 3:
need_squeeze = True
input_bbox = input_bbox.unsqueeze(0)
converted_bbox = torch.zeros_like(input_bbox) if isinstance(input_bbox, torch.Tensor) else np.zeros_like(input_bbox)
converted_bbox[:, :, 0] = input_bbox[:, :, 0] - input_bbox[:, :, 2] / 2
converted_bbox[:, :, 1] = input_bbox[:, :, 1] - input_bbox[:, :, 3] / 2
converted_bbox[:, :, 2] = input_bbox[:, :, 0] + input_bbox[:, :, 2] / 2
converted_bbox[:, :, 3] = input_bbox[:, :, 1] + input_bbox[:, :, 3] / 2
# squeeze back if needed
if need_squeeze:
converted_bbox = converted_bbox[0]
return converted_bbox
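For example, a single 20x10 box centred at (50, 50); the 2-D input is unsqueezed and squeezed back internally:
>>> convert_cxcywh_bbox_to_xyxy(torch.tensor([[50., 50., 20., 10.]]))  # tensor([[40., 45., 60., 55.]])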
def crowd_ioa(det_box: torch.Tensor, crowd_box: torch.Tensor) -> torch.Tensor:
Return intersection-over-detection_area of boxes, used for crowd ground truths.
Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
:param det_box: Tensor of shape [N, 4]
:param crowd_box: Tensor of shape [M, 4]
:return: crowd_ioa, Tensor of shape [N, M]: the NxM matrix containing the pairwise IoA values for every element in det_box and crowd_box
det_area = compute_box_area(det_box.T)
# inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
inter = (torch.min(det_box[:, None, 2:], crowd_box[:, 2:]) - torch.max(det_box[:, None, :2], crowd_box[:, :2])).clamp(0).prod(2)
return inter / det_area[:, None] # crowd_ioa = inter / det_area
def cxcywh2xyxy(bboxes):
Transforms bboxes from centered xy-wh (cx, cy, w, h) format to xyxy format
:param bboxes: array, shaped (nboxes, 4)
:return: modified bboxes
bboxes[:, 1] = bboxes[:, 1] - bboxes[:, 3] * 0.5
bboxes[:, 0] = bboxes[:, 0] - bboxes[:, 2] * 0.5
bboxes[:, 3] = bboxes[:, 3] + bboxes[:, 1]
bboxes[:, 2] = bboxes[:, 2] + bboxes[:, 0]
return bboxes
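Both this function and xyxy2cxcywh (documented further below) modify the array in place and return it, so the pair round-trips (arbitrary values):
>>> boxes = np.array([[50., 50., 20., 10.]])
>>> cxcywh2xyxy(boxes)   # -> [[40., 45., 60., 55.]] (boxes is modified in place)
>>> xyxy2cxcywh(boxes)   # -> back to [[50., 50., 20., 10.]]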
def get_cls_posx_in_target(target_format: DetectionTargetsFormat) -> int:
"""Get the label of a given target
:param target_format: Representation of the target (ex: LABEL_XYXY)
:return: Position of the class id in a bbox
ex: 0 if bbox of format label_xyxy | -1 if bbox of format xyxy_label
format_split = target_format.value.split("_")
if format_split[0] == "LABEL":
return 0
elif format_split[-1] == "LABEL":
return -1
else:
raise NotImplementedError(f"No implementation to find index of LABEL in {target_format.value}")
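For illustration (enum member names assumed to match the formats mentioned above, e.g. LABEL_XYXY and XYXY_LABEL):
>>> get_cls_posx_in_target(DetectionTargetsFormat.LABEL_XYXY)  # 0, label comes first
>>> get_cls_posx_in_target(DetectionTargetsFormat.XYXY_LABEL)  # -1, label comes last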
def get_mosaic_coordinate(mosaic_index, xc, yc, w, h, input_h, input_w):
Returns the mosaic coordinates of final mosaic image according to mosaic image index.
:param mosaic_index: (int) mosaic image index
:param xc: (int) center x coordinate of the entire mosaic grid.
:param yc: (int) center y coordinate of the entire mosaic grid.
:param w: (int) width of bbox
:param h: (int) height of bbox
:param input_h: (int) image input height (should be 1/2 of the final mosaic output image height).
:param input_w: (int) image input width (should be 1/2 of the final mosaic output image width).
:return: (x1, y1, x2, y2), (x1s, y1s, x2s, y2s) where (x1, y1, x2, y2) are the coordinates in the final mosaic
output image, and (x1s, y1s, x2s, y2s) are the coordinates in the placed image.
# index0 to top left part of image
if mosaic_index == 0:
x1, y1, x2, y2 = max(xc - w, 0), max(yc - h, 0), xc, yc
small_coord = w - (x2 - x1), h - (y2 - y1), w, h
# index1 to top right part of image
elif mosaic_index == 1:
x1, y1, x2, y2 = xc, max(yc - h, 0), min(xc + w, input_w * 2), yc
small_coord = 0, h - (y2 - y1), min(w, x2 - x1), h
# index2 to bottom left part of image
elif mosaic_index == 2:
x1, y1, x2, y2 = max(xc - w, 0), yc, xc, min(input_h * 2, yc + h)
small_coord = w - (x2 - x1), 0, w, min(y2 - y1, h)
# index3 to bottom right part of image
elif mosaic_index == 3:
x1, y1, x2, y2 = xc, yc, min(xc + w, input_w * 2), min(input_h * 2, yc + h) # noqa
small_coord = 0, 0, min(w, x2 - x1), min(y2 - y1, h)
return (x1, y1, x2, y2), small_coord
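A worked example, assuming a 320x320 input (so the mosaic canvas is 640x640) and an arbitrary 400x300 image placed in the top-left cell:
>>> canvas_xyxy, source_xyxy = get_mosaic_coordinate(mosaic_index=0, xc=320, yc=320, w=400, h=300, input_h=320, input_w=320)
>>> canvas_xyxy  # (0, 20, 320, 320) -> region of the mosaic canvas to paste into
>>> source_xyxy  # (80, 0, 400, 300) -> matching crop of the placed image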
def get_top_k_idx_per_cls(preds_scores: torch.Tensor, preds_cls: torch.Tensor, top_k: int):
"""Get the indexes of all the top k predictions for every class
:param preds_scores: The confidence scores, vector of shape (n_pred)
:param preds_cls: The predicted class, vector of shape (n_pred)
:param top_k: Number of predictions to keep per class, ordered by confidence score
:return top_k_idx: Indexes of the top k predictions. length <= (k * n_unique_class)
n_unique_cls = torch.max(preds_cls)
mask = preds_cls.view(-1, 1) == torch.arange(n_unique_cls + 1, device=preds_scores.device).view(1, -1)
preds_scores_per_cls = preds_scores.view(-1, 1) * mask
sorted_scores_per_cls, sorting_idx = preds_scores_per_cls.sort(0, descending=True)
idx_with_satisfying_scores = sorted_scores_per_cls[:top_k, :].nonzero(as_tuple=False)
top_k_idx = sorting_idx[idx_with_satisfying_scores.split(1, dim=1)]
return top_k_idx.view(-1)
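Small sketch with arbitrary scores: with top_k=2, the two best class-0 predictions (indexes 0 and 1) and the single class-1 prediction (index 2) are kept.
>>> scores = torch.tensor([0.9, 0.8, 0.7, 0.6])
>>> classes = torch.tensor([0., 0., 1., 0.])
>>> get_top_k_idx_per_cls(scores, classes, top_k=2)  # tensor([0, 2, 1])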
def matrix_non_max_suppression(
pred, conf_thres: float = 0.1, kernel: str = "gaussian", sigma: float = 3.0, max_num_of_detections: int = 500, class_agnostic_nms: bool = False
) -> List[torch.Tensor]:
"""Performs Matrix Non-Maximum Suppression (NMS) on inference results https://arxiv.org/pdf/1912.04488.pdf
:param pred: Raw model prediction (in test mode) - a Tensor of shape [batch, num_predictions, 85]
where each item format is (x, y, w, h, object_conf, class_conf, ... 80 classes score ...)
:param conf_thres: Threshold under which predictions are discarded
:param kernel: Type of kernel to use ['gaussian', 'linear']
:param sigma: Sigma for the gaussian kernel
:param max_num_of_detections: Maximum number of boxes to output
:param class_agnostic_nms: If True, NMS is performed on all classes together; if False (default), per class
:return: Detections list with shape (x1, y1, x2, y2, object_conf, class_conf, class)
# MULTIPLY CONF BY CLASS CONF TO GET COMBINED CONFIDENCE
class_conf, class_pred = pred[:, :, 5:].max(2)
pred[:, :, 4] *= class_conf
# BOX (CENTER X, CENTER Y, WIDTH, HEIGHT) TO (X1, Y1, X2, Y2)
pred[:, :, :4] = convert_cxcywh_bbox_to_xyxy(pred[:, :, :4])
# DETECTIONS ORDERED AS (x1y1x2y2, obj_conf, class_conf, class_pred)
pred = torch.cat((pred[:, :, :5], class_pred.unsqueeze(2)), 2)
# SORT DETECTIONS BY DECREASING CONFIDENCE SCORES
sort_ind = (-pred[:, :, 4]).argsort()
pred = torch.stack([pred[i, sort_ind[i]] for i in range(pred.shape[0])])[:, 0:max_num_of_detections]
ious = calc_bbox_iou_matrix(pred)
ious = ious.triu(1)
if not class_agnostic_nms:
# CREATE A LABELS MASK, WE WANT ONLY BOXES WITH THE SAME LABEL TO AFFECT EACH OTHER
labels = pred[:, :, 5:]
labeles_matrix = (labels == labels.transpose(2, 1)).float().triu(1)
ious *= labeles_matrix
ious_cmax, _ = ious.max(1)
ious_cmax = ious_cmax.unsqueeze(2).repeat(1, 1, max_num_of_detections)
if kernel == "gaussian":
decay_matrix = torch.exp(-1 * sigma * (ious**2))
compensate_matrix = torch.exp(-1 * sigma * (ious_cmax**2))
decay, _ = (decay_matrix / compensate_matrix).min(dim=1)
else:
decay = (1 - ious) / (1 - ious_cmax)
decay, _ = decay.min(dim=1)
pred[:, :, 4] *= decay
output = [pred[i, pred[i, :, 4] > conf_thres] for i in range(pred.shape[0])]
return output
def non_max_suppression(
prediction, conf_thres=0.1, iou_thres=0.6, multi_label_per_box: bool = True, with_confidence: bool = False, class_agnostic_nms: bool = False
):
Performs Non-Maximum Suppression (NMS) on inference results
:param prediction: raw model prediction. Should be a list of Tensors of shape (cx, cy, w, h, confidence, cls0, cls1, ...)
:param conf_thres: predictions below this confidence threshold are discarded
:param iou_thres: IoU threshold for the nms algorithm
:param multi_label_per_box: controls whether to decode multiple labels per box.
True - each anchor can produce multiple labels of different classes
that pass confidence threshold check (default).
False - each anchor can produce only one label of the class with the highest score.
:param with_confidence: whether to multiply objectness score with class score.
usually valid for Yolo models only.
:param class_agnostic_nms: indicates how boxes of different classes will be treated during NMS
True - NMS will be performed on all classes together.
False - NMS will be performed on each class separately (default).
:return: detections with shape nx6 (x1, y1, x2, y2, object_conf, class_conf, class)
candidates_above_thres = prediction[..., 4] > conf_thres # filter by confidence
output = [None] * prediction.shape[0]
for image_idx, pred in enumerate(prediction):
pred = pred[candidates_above_thres[image_idx]] # confident
if not pred.shape[0]: # If none remain process next image
continue
if with_confidence:
pred[:, 5:] *= pred[:, 4:5] # multiply objectness score with class score
box = convert_cxcywh_bbox_to_xyxy(pred[:, :4]) # cxcywh to xyxy
# Detections matrix nx6 (xyxy, conf, cls)
if multi_label_per_box: # try for all good confidence classes
i, j = (pred[:, 5:] > conf_thres).nonzero(as_tuple=False).T
pred = torch.cat((box[i], pred[i, j + 5, None], j[:, None].float()), 1)
else: # best class only
conf, j = pred[:, 5:].max(1, keepdim=True)
pred = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
if not pred.shape[0]: # If none remain process next image
continue
# Apply torch batched NMS algorithm
boxes, scores, cls_idx = pred[:, :4], pred[:, 4], pred[:, 5]
if class_agnostic_nms:
idx_to_keep = torchvision.ops.boxes.nms(boxes, scores, iou_thres)
else:
idx_to_keep = torchvision.ops.boxes.batched_nms(boxes, scores, cls_idx, iou_thres)
output[image_idx] = pred[idx_to_keep]
return output
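A hedged end-to-end sketch for one image with two classes; each raw row is (cx, cy, w, h, objectness, score_cls0, score_cls1) and the values are arbitrary. The two overlapping class-0 boxes collapse into a single detection.
>>> raw = torch.tensor([[[50., 50., 20., 20., 0.9, 0.8, 0.1],
...                      [52., 52., 20., 20., 0.8, 0.7, 0.2],
...                      [150., 150., 30., 30., 0.7, 0.1, 0.9]]])
>>> dets = non_max_suppression(raw, conf_thres=0.5, iou_thres=0.5)
>>> dets[0].shape  # torch.Size([2, 6]) -> rows of (x1, y1, x2, y2, score, class)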
def undo_image_preprocessing(im_tensor: torch.Tensor) -> np.ndarray:
:param im_tensor: images in a batch after preprocessing for inference, RGB, (B, C, H, W)
:return: images in a batch in cv2 format, BGR, (B, H, W, C)
im_np = im_tensor.cpu().numpy()
im_np = im_np[:, ::-1, :, :].transpose(0, 2, 3, 1)
im_np *= 255.0
return np.ascontiguousarray(im_np, dtype=np.uint8)
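Shape and dtype sketch with a random batch:
>>> batch = torch.rand(2, 3, 64, 64)       # RGB, channels-first, values in [0, 1]
>>> undo_image_preprocessing(batch).shape  # (2, 64, 64, 3), dtype uint8, BGR channel order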
def xyxy2cxcywh(bboxes):
Transforms bboxes from xyxy format to centered xy-wh (cx, cy, w, h) format
:param bboxes: array, shaped (nboxes, 4)
:return: modified bboxes
bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0]
bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1]
bboxes[:, 0] = bboxes[:, 0] + bboxes[:, 2] * 0.5
bboxes[:, 1] = bboxes[:, 1] + bboxes[:, 3] * 0.5
return bboxes
Source code in V3_2/src/super_gradients/training/utils/distributed_training_utils.py
class DDPNotSetupException(Exception):
"""Exception raised when DDP setup is required but was not done"""
def __init__(self):
self.message = (
"Your environment was not setup correctly for DDP.\n"
"Please run at the beginning of your script:\n"
">>> from super_gradients.training.utils.distributed_training_utils import setup_device\n"
">>> from super_gradients.common.data_types.enum import MultiGPUMode\n"
">>> setup_device(multi_gpu=MultiGPUMode.DISTRIBUTED_DATA_PARALLEL, num_gpus=...)"
super().__init__(self.message)
@torch.no_grad()
def compute_precise_bn_stats(model: nn.Module, loader: torch.utils.data.DataLoader, precise_bn_batch_size: int, num_gpus: int):
:param model: The model being trained (ie: Trainer.net)
:param loader: Training dataloader (ie: Trainer.train_loader)
:param precise_bn_batch_size: The effective batch size we want to calculate the batchnorm on. For example, if we are training a model
on 8 gpus, with a batch of 128 on each gpu, a good rule of thumb would be to give it 8192
(ie: effective_batch_size * num_gpus = batch_per_gpu * num_gpus * num_gpus).
If precise_bn_batch_size is not provided in the training_params, the latter heuristic
will be taken.
:param num_gpus: The number of gpus we are training on
# Compute the number of minibatches to use
num_iter = int(precise_bn_batch_size / (loader.batch_size * num_gpus)) if precise_bn_batch_size else num_gpus
num_iter = min(num_iter, len(loader))
# Retrieve the BN layers
bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
# Initialize BN stats storage for computing mean(mean(batch)) and mean(var(batch))
running_means = [torch.zeros_like(bn.running_mean) for bn in bns]
running_vars = [torch.zeros_like(bn.running_var) for bn in bns]
# Remember momentum values
momentums = [bn.momentum for bn in bns]
# Set momentum to 1.0 to compute BN stats that only reflect the current batch
for bn in bns:
bn.momentum = 1.0
# Average the BN stats for each BN layer over the batches
for inputs, _labels in itertools.islice(loader, num_iter):
model(inputs.cuda())
for i, bn in enumerate(bns):
running_means[i] += bn.running_mean / num_iter
running_vars[i] += bn.running_var / num_iter
# Sync BN stats across GPUs (no reduction if 1 GPU used)
running_means = scaled_all_reduce(running_means, num_gpus=num_gpus)
running_vars = scaled_all_reduce(running_vars, num_gpus=num_gpus)
# Set BN stats and restore original momentum values
for i, bn in enumerate(bns):
bn.running_mean = running_means[i]
bn.running_var = running_vars[i]
bn.momentum = momentums[i]
def distributed_all_reduce_tensor_average(tensor, n):
This method performs a reduce operation on multiple nodes running distributed training
It first sums all of the results and then divides the summation
:param tensor: The tensor to perform the reduce operation for
:param n: Number of nodes
:return: Averaged tensor from all of the nodes
rt = tensor.clone()
torch.distributed.all_reduce(rt, op=torch.distributed.ReduceOp.SUM)
rt /= n
return rt
Source code in V3_2/src/super_gradients/training/utils/distributed_training_utils.py
def get_gpu_mem_utilization():
"""GPU memory managed by the caching allocator in bytes for a given device."""
# Workaround to work on any torch version
if hasattr(torch.cuda, "memory_reserved"):
return torch.cuda.memory_reserved()
else:
return torch.cuda.memory_cached()
Source code in V3_2/src/super_gradients/training/utils/distributed_training_utils.py
def get_local_rank():
Returns the local rank if running in DDP, and 0 otherwise
:return: local rank
return dist.get_rank() if dist.is_initialized() else 0
Source code in V3_2/src/super_gradients/training/utils/distributed_training_utils.py
def get_world_size() -> int:
Returns the world size if running in DDP, and 1 otherwise
:return: world size
if not dist.is_available():
return 1
if not dist.is_initialized():
return 1
return dist.get_world_size()
Source code in V3_2/src/super_gradients/training/utils/distributed_training_utils.py
def initialize_ddp():
Initialize Distributed Data Parallel
Important note: (1) in distributed training it is customary to specify learning rates and batch sizes per GPU.
Whatever learning rate and schedule you specify will be applied to each GPU individually.
Since gradients are passed and summed (reduced) from all to all GPUs, the effective batch size is the
batch you specify times the number of GPUs. In the literature there are several "best practices" to set
learning rates and schedules for large batch sizes.
if device_config.assigned_rank > 0:
mute_current_process()
logger.info("Distributed training starting...")
if not torch.distributed.is_initialized():
backend = "gloo" if os.name == "nt" else "nccl"
torch.distributed.init_process_group(backend=backend, init_method="env://")
torch.cuda.set_device(device_config.assigned_rank)
if torch.distributed.get_rank() == 0:
logger.info(f"Training in distributed mode... with {str(torch.distributed.get_world_size())} GPUs")
device_config.device = "cuda:%d" % device_config.assigned_rank
def reduce_results_tuple_for_ddp(validation_results_tuple, device):
"""Gather all validation tuples from the various devices and average them"""
validation_results_list = list(validation_results_tuple)
for i, validation_result in enumerate(validation_results_list):
if torch.is_tensor(validation_result):
validation_result = validation_result.clone().detach()
else:
validation_result = torch.tensor(validation_result)
validation_results_list[i] = distributed_all_reduce_tensor_average(tensor=validation_result.to(device), n=torch.distributed.get_world_size())
validation_results_tuple = tuple(validation_results_list)
return validation_results_tuple
Source code in V3_2/src/super_gradients/training/utils/distributed_training_utils.py
@record
def restart_script_with_ddp(num_gpus: int = None):
"""Launch the same script as the one that was launched (i.e. the command used to start the current process is re-used) but on subprocesses (i.e. with DDP).
:param num_gpus: How many gpu's you want to run the script on. If not specified, every available device will be used.
ddp_port = find_free_port()
# Get the value from recipe if specified, otherwise take all available devices.
num_gpus = num_gpus if num_gpus is not None else torch.cuda.device_count()
if num_gpus > torch.cuda.device_count():
raise ValueError(f"You specified num_gpus={num_gpus} but only {torch.cuda.device_count()} GPU's are available")
logger.info(
"Launching DDP with:\n"
f" - ddp_port = {ddp_port}\n"
f" - num_gpus = {num_gpus}/{torch.cuda.device_count()} available\n"
"-------------------------------------\n"
config = LaunchConfig(
nproc_per_node=num_gpus,
min_nodes=1,
max_nodes=1,
run_id="sg_initiated",
role="default",
rdzv_endpoint=f"127.0.0.1:{ddp_port}",
rdzv_backend="static",
rdzv_configs={"rank": 0, "timeout": 900},
rdzv_timeout=-1,
max_restarts=0,
monitor_interval=5,
start_method="spawn",
log_dir=None,
redirects=Std.NONE,
tee=Std.NONE,
metrics_cfg={},
)
elastic_launch(config=config, entrypoint=sys.executable)(*sys.argv, *EXTRA_ARGS)
# The code below should actually never be reached as the process will be in a loop inside elastic_launch until any subprocess crashes.
sys.exit(0)
Source code in V3_2/src/super_gradients/training/utils/distributed_training_utils.py
def scaled_all_reduce(tensors: torch.Tensor, num_gpus: int):
Performs the scaled all_reduce operation on the provided tensors.
The input tensors are modified in-place.
Currently supports only the sum
reduction operator.
The reduced values are scaled by the inverse size of the
process group (equivalent to num_gpus).
# There is no need for reduction in the single-proc case
if num_gpus == 1:
return tensors
# Queue the reductions
reductions = []
for tensor in tensors:
reduction = torch.distributed.all_reduce(tensor, async_op=True)
reductions.append(reduction)
# Wait for reductions to finish
for reduction in reductions:
reduction.wait()
# Scale the results
for tensor in tensors:
tensor.mul_(1.0 / num_gpus)
return tensors
Source code in V3_2/src/super_gradients/training/utils/distributed_training_utils.py
def setup_cpu(multi_gpu: MultiGPUMode = MultiGPUMode.AUTO, num_gpus: int = None):
:param multi_gpu: DDP, DP, Off or AUTO
:param num_gpus: Number of GPU's to use.
if multi_gpu not in (MultiGPUMode.OFF, MultiGPUMode.AUTO):
raise ValueError(f"device='cpu' and multi_gpu={multi_gpu} are not compatible together.")
if num_gpus not in (0, None):
raise ValueError(f"device='cpu' and num_gpus={num_gpus} are not compatible together.")
device_config.device = "cpu"
device_config.multi_gpu = MultiGPUMode.OFF
Source code in V3_2/src/super_gradients/training/utils/distributed_training_utils.py
@resolve_param("multi_gpu", TypeFactory(MultiGPUMode.dict()))
def setup_device(multi_gpu: MultiGPUMode = MultiGPUMode.AUTO, num_gpus: int = None, device: str = "cuda"):
If required, launch ddp subprocesses.
:param multi_gpu: DDP, DP, Off or AUTO
:param num_gpus: Number of GPU's to use. When None, use all available devices on DDP or only one device on DP/OFF.
:param device: The device you want to use ('cpu' or 'cuda')
If you only set num_gpus, your device will be set up according to the following logic:
- `setup_device(num_gpus=0)` => `gpu_mode='OFF'` and `device='cpu'`
- `setup_device(num_gpus=1)` => `gpu_mode='OFF'` and `device='gpu'`
- `setup_device(num_gpus>=2)` => `gpu_mode='DDP'` and `device='gpu'`
- `setup_device(num_gpus=-1)` => `gpu_mode='DDP'` and `device='gpu'` and `num_gpus=<N-AVAILABLE-GPUs>`
init_trainer()
# When launching with torch.distributed.launch or torchrun, multi_gpu might not be set to DDP (since we are not using the recipe params)
# To avoid any issue we force multi_gpu to be DDP if the current process is ddp subprocess. We also set num_gpus, device to run smoothly.
if not is_launched_using_sg() and is_distributed():
multi_gpu, num_gpus, device = MultiGPUMode.DISTRIBUTED_DATA_PARALLEL, None, "cuda"
if device is None:
device = "cuda"
if device == "cuda" and not torch.cuda.is_available():
logger.warning("CUDA device is not available on your device... Moving to CPU.")
multi_gpu, num_gpus, device = MultiGPUMode.OFF, 0, "cpu"
if device == "cpu":
setup_cpu(multi_gpu, num_gpus)
elif device == "cuda":
setup_gpu(multi_gpu, num_gpus)
else:
raise ValueError(f"Only valid values for device are: 'cpu' and 'cuda'. Received: '{device}'")
Source code in V3_2/src/super_gradients/training/utils/distributed_training_utils.py
def setup_gpu(multi_gpu: MultiGPUMode = MultiGPUMode.AUTO, num_gpus: int = None):
If required, launch ddp subprocesses.
:param multi_gpu: DDP, DP, Off or AUTO
:param num_gpus: Number of GPU's to use. When None, use all available devices on DDP or only one device on DP/OFF.
if num_gpus == 0:
raise ValueError("device='cuda' and num_gpus=0 are not compatible together.")
multi_gpu, num_gpus = _resolve_gpu_params(multi_gpu=multi_gpu, num_gpus=num_gpus)
device_config.device = "cuda"
device_config.multi_gpu = multi_gpu
if is_distributed():
initialize_ddp()
elif multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL:
restart_script_with_ddp(num_gpus=num_gpus)
num_gpus: Number of GPU's to use. When None, use all available devices on DDP or only one device on DP/OFF.
Source code in V3_2/src/super_gradients/training/utils/distributed_training_utils.py
def setup_gpu_mode(gpu_mode: MultiGPUMode = MultiGPUMode.OFF, num_gpus: int = None):
"""[DEPRECATED in favor of setup_device] If required, launch ddp subprocesses.
:param gpu_mode: DDP, DP, Off or AUTO
:param num_gpus: Number of GPU's to use. When None, use all available devices on DDP or only one device on DP/OFF.
logger.warning("setup_gpu_mode is now deprecated in favor of setup_device")
setup_device(multi_gpu=gpu_mode, num_gpus=num_gpus)
Source code in V3_2/src/super_gradients/training/utils/distributed_training_utils.py
@contextmanager
def wait_for_the_master(local_rank: int):
Make all processes wait for the master to finish its task.
if local_rank > 0:
dist.barrier()
yield
if local_rank == 0:
if not dist.is_available():
return
if not dist.is_initialized():
return
else:
dist.barrier()
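Example (a sketch of the intended pattern; prepare_dataset is a hypothetical placeholder for work only the master process should do):
import os

local_rank = int(os.environ.get("LOCAL_RANK", 0))  # standard torchrun environment variable
with wait_for_the_master(local_rank):
    if local_rank == 0:
        prepare_dataset()  # hypothetical: non-master ranks wait at the barrier until this finishes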
class EarlyStop(PhaseCallback):
Callback to monitor a metric and stop training when it stops improving.
Inspired by pytorch_lightning.callbacks.early_stopping and tf.keras.callbacks.EarlyStopping
mode_dict = {"min": torch.lt, "max": torch.gt}
supported_phases = (Phase.VALIDATION_EPOCH_END, Phase.TRAIN_EPOCH_END)
def __init__(
self,
phase: Phase,
monitor: str,
mode: str = "min",
min_delta: float = 0.0,
patience: int = 3,
check_finite: bool = True,
threshold: Optional[float] = None,
verbose: bool = False,
strict: bool = True,
:param phase: Callback phase event.
:param monitor: name of the metric to be monitored.
:param mode: one of 'min', 'max'. In 'min' mode, training will stop when the quantity
monitored has stopped decreasing and in 'max' mode it will stop when the quantity
monitored has stopped increasing.
:param min_delta: minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute
change of less than `min_delta`, will count as no improvement.
:param patience: number of checks with no improvement after which training will be stopped.
One check happens after every phase event.
:param check_finite: When set ``True``, stops training when the monitor becomes NaN or infinite.
:param threshold: Stop training immediately once the monitored quantity reaches this threshold. For mode 'min'
stops training when below threshold, For mode 'max' stops training when above threshold.
:param verbose: If `True` print logs.
:param strict: whether to crash the training if `monitor` is not found in the metrics.
super(EarlyStop, self).__init__(phase)
if phase not in self.supported_phases:
raise ValueError(f"EarlyStop doesn't support phase: {phase}, " f"excepted {', '.join([str(x) for x in self.supported_phases])}")
self.phase = phase
self.monitor_key = monitor
self.min_delta = min_delta
self.patience = patience
self.mode = mode
self.check_finite = check_finite
self.threshold = threshold
self.verbose = verbose
self.strict = strict
self.wait_count = 0
if self.mode not in self.mode_dict:
raise Exception(f"`mode` can be {', '.join(self.mode_dict.keys())}, got {self.mode}")
self.monitor_op = self.mode_dict[self.mode]
self.min_delta *= 1 if self.monitor_op == torch.gt else -1
torch_inf = torch.tensor(np.Inf)
self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf
def _get_metric_value(self, metrics_dict):
if self.monitor_key not in metrics_dict.keys():
msg = f"Can't find EarlyStop monitor {self.monitor_key} in metrics_dict: {metrics_dict.keys()}"
exception_cls = RuntimeError if self.strict else MissingMonitorKeyException
raise exception_cls(msg)
return metrics_dict[self.monitor_key]
def _check_for_early_stop(self, current: torch.Tensor):
should_stop = False
# check if current value is Nan or inf
if self.check_finite and not torch.isfinite(current):
should_stop = True
reason = (
f"Monitored metric {self.monitor_key} = {current} is not finite." f" Previous best value was {self.best_score:.3f}. Signaling Trainer to stop."
# check if current value reached threshold value
elif self.threshold is not None and self.monitor_op(current, self.threshold):
should_stop = True
reason = "Stopping threshold reached:" f" {self.monitor_key} = {current} {self.monitor_op} {self.threshold}." " Signaling Trainer to stop."
# check if current is an improvement of monitor_key metric.
elif self.monitor_op(current - self.min_delta, self.best_score.to(current.device)):
should_stop = False
if torch.isfinite(self.best_score):
reason = (
f"Metric {self.monitor_key} improved by {abs(self.best_score - current):.3f} >="
f" min_delta = {abs(self.min_delta)}. New best score: {current:.3f}"
else:
reason = f"Metric {self.monitor_key} improved. New best score: {current:.3f}"
self.best_score = current
self.wait_count = 0
# no improvement in monitor_key metric, check if wait_count is bigger than patience.
else:
self.wait_count += 1
reason = f"Monitored metric {self.monitor_key} did not improve in the last {self.wait_count} records."
if self.wait_count >= self.patience:
should_stop = True
reason += f" Best score: {self.best_score:.3f}. Signaling Trainer to stop."
return reason, should_stop
def __call__(self, context: PhaseContext):
try:
current = self._get_metric_value(context.metrics_dict)
except MissingMonitorKeyException as e:
logger.warning(e)
return
if not isinstance(current, torch.Tensor):
current = torch.tensor(current)
reason, self.should_stop = self._check_for_early_stop(current)
# log reason message, and signal early stop if should_stop=True.
if self.should_stop:
self._signal_early_stop(context, reason)
elif self.verbose:
logger.info(reason)
def _signal_early_stop(self, context: PhaseContext, reason: str):
logger.info(reason)
context.update_context(stop_training=True)
Parameters:
mode: one of 'min', 'max'. In 'min' mode, training will stop when the quantity monitored has stopped decreasing and in 'max' mode it will stop when the quantity monitored has stopped increasing. Default: 'min'.
min_delta (float): minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute change of less than min_delta will count as no improvement.
patience: number of checks with no improvement after which training will be stopped. One check happens after every phase event.
check_finite: when set True, stops training when the monitor becomes NaN or infinite.
threshold (Optional[float]): stop training immediately once the monitored quantity reaches this threshold. For mode 'min' stops training when below threshold, for mode 'max' stops training when above threshold.
verbose: if True, print logs. Default: False.
strict: whether to crash the training if monitor is not found in the metrics.
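Example (a minimal usage sketch; the monitored metric name is a placeholder, and registering the callback through the phase_callbacks training parameter is an assumption about the surrounding training setup):
early_stop = EarlyStop(
    phase=Phase.VALIDATION_EPOCH_END,
    monitor="loss",   # placeholder metric name; must match a key in the metrics dict
    mode="min",
    patience=5,
    verbose=True,
)
# Assumed registration point: training_params = {"phase_callbacks": [early_stop], ...}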
class MissingMonitorKeyException(Exception):
Exception raised for missing monitor key in metrics_dict.
Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
Keep a moving average of everything in the model state_dict (parameters and buffers).
This is intended to allow functionality like
https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
A smoothed version of the weights is necessary for some training schemes to perform well.
This class is sensitive to where it is initialized in the sequence of model init,
GPU assignment and distributed training wrappers.
Source code in V3_2/src/super_gradients/training/utils/ema.py
class KDModelEMA(ModelEMA):
"""Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
Keep a moving average of everything in the model state_dict (parameters and buffers).
This is intended to allow functionality like
https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
A smoothed version of the weights is necessary for some training schemes to perform well.
This class is sensitive to where it is initialized in the sequence of model init,
GPU assignment and distributed training wrappers.
def __init__(self, kd_model: KDModule, decay: float, decay_function: IDecayFunction):
Init the EMA
:param kd_model: KDModule, the training Knowledge distillation model to construct the EMA model by
IMPORTANT: WHEN THE APPLICATION OF EMA ONLY ON A SUBSET OF ATTRIBUTES IS DESIRED, WRAP THE NN.MODULE
AS SgModule AND OVERWRITE get_include_attributes() AND get_exclude_attributes() AS DESIRED.
:param decay: the maximum decay value. As the training process advances, the decay will climb towards this value
              until EMA_t+1 = EMA_t * decay + TRAINING_MODEL * (1 - decay)
:param decay_function: IDecayFunction schedule controlling how the decay ramps up. For the exponential schedule, the higher the beta,
              the sooner in the training the decay will saturate to its final value; beta=15 is ~40% of the training process.
# Only work on the student (we don't want to update and to have a duplicate of the teacher)
super().__init__(model=unwrap_model(kd_model).student, decay=decay, decay_function=decay_function)
# Overwrite current ema attribute with combination of the student model EMA (current self.ema)
# with already the instantiated teacher, to have the final KD EMA
self.ema = KDModule(
arch_params=unwrap_model(kd_model).arch_params,
student=self.ema,
teacher=unwrap_model(kd_model).teacher,
run_teacher_on_eval=unwrap_model(kd_model).run_teacher_on_eval,
Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
Keep a moving average of everything in the model state_dict (parameters and buffers).
This is intended to allow functionality like
https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
A smoothed version of the weights is necessary for some training schemes to perform well.
This class is sensitive to where it is initialized in the sequence of model init,
GPU assignment and distributed training wrappers.
Source code in V3_2/src/super_gradients/training/utils/ema.py
class ModelEMA:
"""Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
Keep a moving average of everything in the model state_dict (parameters and buffers).
This is intended to allow functionality like
https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
A smoothed version of the weights is necessary for some training schemes to perform well.
This class is sensitive to where it is initialized in the sequence of model init,
GPU assignment and distributed training wrappers.
def __init__(self, model: nn.Module, decay: float, decay_function: IDecayFunction):
Init the EMA
:param model: Union[SgModule, nn.Module], the training model to construct the EMA model by
IMPORTANT: WHEN THE APPLICATION OF EMA ONLY ON A SUBSET OF ATTRIBUTES IS DESIRED, WRAP THE NN.MODULE
AS SgModule AND OVERWRITE get_include_attributes() AND get_exclude_attributes() AS DESIRED.
:param decay: the maximum decay value. As the training process advances, the decay will climb towards this value
              until EMA_t+1 = EMA_t * decay + TRAINING_MODEL * (1 - decay)
:param decay_function: IDecayFunction schedule controlling how the decay ramps up. For the exponential schedule, the higher the beta,
              the sooner in the training the decay will saturate to its final value; beta=15 is ~40% of the training process.
# Create EMA
model = unwrap_model(model)
self.ema = deepcopy(model)
self.ema.eval()
self.decay = decay
self.decay_function = decay_function
We hold a list of model attributes (not weights and biases) which we would like to include in each
attribute update or exclude from each update. An SgModule declares these attributes using the
get_include_attributes and get_exclude_attributes functions. For an nn.Module which is not an SgModule,
all non-private (not starting with '_') attributes will be updated (and only them).
if isinstance(model, SgModule):
self.include_attributes = model.get_include_attributes()
self.exclude_attributes = model.get_exclude_attributes()
else:
warnings.warn("Warning: EMA should be used with SgModule instance. All attributes of the model will be " "included in EMA")
self.include_attributes = []
self.exclude_attributes = []
for p in self.ema.parameters():
p.requires_grad_(False)
@classmethod
def from_params(cls, model: nn.Module, decay_type: str = None, decay: float = None, **kwargs):
if decay is None:
logger.warning(
"Parameter `decay` is not specified for EMA params. Please specify `decay` parameter explicitly in your config:\n"
"ema: True\n"
"ema_params: \n"
" decay: 0.9999\n"
" decay_type: exp\n"
" beta: 15\n"
"Will default to decay: 0.9999\n"
"In the next major release of SG this warning will become an error."
decay = 0.9999
if "exp_activation" in kwargs:
logger.warning(
"Parameter `exp_activation` is deprecated for EMA model. Please update your config to use decay_type: str (constant|exp|threshold) instead:\n"
"ema: True\n"
"ema_params: \n"
" decay: 0.9999\n"
" decay_type: exp # Equivalent to exp_activation: True\n"
" beta: 15\n"
"\n"
"ema: True\n"
"ema_params: \n"
" decay: 0.9999\n"
" decay_type: constant # Equivalent to exp_activation: False\n"
"\n"
"In the next major release of SG this warning will become an error."
decay_type = "exp" if bool(kwargs.pop("exp_activation")) else "constant"
if decay_type is None:
logger.warning(
"Parameter decay_type is not specified for EMA model. Please specify decay_type parameter explicitly in your config:\n"
"ema: True\n"
"ema_params: \n"
" decay: 0.9999\n"
" decay_type: constant|exp|threshold\n"
"Will default to `exp` decay with beta = 15\n"
"In the next major release of SG this warning will become an error."
decay_type = "exp"
if "beta" not in kwargs:
kwargs["beta"] = 15
try:
decay_cls = EMA_DECAY_FUNCTIONS[decay_type]
except KeyError:
raise UnknownTypeException(decay_type, list(EMA_DECAY_FUNCTIONS.keys()))
decay_function = decay_cls(**kwargs)
return cls(model, decay, decay_function)
def update(self, model, step: int, total_steps: int):
Update the state of the EMA model.
:param model: Current training model
:param step: Current training step
:param total_steps: Total training steps
# Update EMA parameters
model = unwrap_model(model)
with torch.no_grad():
decay = self.decay_function(self.decay, step, total_steps)
for ema_v, model_v in zip(self.ema.state_dict().values(), model.state_dict().values()):
if ema_v.dtype.is_floating_point:
ema_v.copy_(ema_v * decay + (1.0 - decay) * model_v.detach())
def update_attr(self, model):
This function updates model attributes (not weights and biases) from the original model to the EMA model.
attributes of the original model, such as anchors and grids (of detection models), may be crucial to the
model operation and need to be updated.
If include_attributes and exclude_attributes lists were not defined, all non-private (not starting with '_')
attributes will be updated (and only them).
:param model: the source model
copy_attr(self.ema, unwrap_model(model), self.include_attributes, self.exclude_attributes)
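Example (a minimal training-loop sketch; model, train_loader and num_epochs are placeholders, and the config keys mirror the warnings printed by from_params above):
ema = ModelEMA.from_params(model, decay_type="exp", decay=0.9999, beta=15)
total_steps = len(train_loader) * num_epochs  # placeholders for the real schedule length
for step, batch in enumerate(train_loader):
    ...  # forward, backward and optimizer step on `model`
    ema.update(model, step=step, total_steps=total_steps)
ema.update_attr(model)  # copy non-parameter attributes (e.g. anchors/grids) onto the EMA copy
# ema.ema now holds the smoothed model used for evaluation / checkpointing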
The constant schedule (referenced above as decay_type: constant) simply returns the maximum decay value unchanged:
def __call__(self, decay: float, step: int, total_steps: int) -> float:
    return decay
class ExpDecay(IDecayFunction):
Gradually increases the EMA decay from 0 to the maximum value using the following formula: decay * (1 - math.exp(-x * self.beta))
def __init__(self, beta: float, **kwargs):
self.beta = beta
def __call__(self, decay: float, step, total_steps: int) -> float:
x = step / total_steps
return decay * (1 - np.exp(-x * self.beta))
Interface for EMA decay schedule. The decay schedule is a function of the maximum decay value and training progress.
Usually it gradually increases the EMA decay up to the maximum value. The exact ramp-up schedule is defined by the concrete
implementation.
Source code in V3_2/src/super_gradients/training/utils/ema_decay_schedules.py
class IDecayFunction:
Interface for EMA decay schedule. The decay schedule is a function of the maximum decay value and training progress.
Usually it gradually increases the EMA decay up to the maximum value. The exact ramp-up schedule is defined by the concrete
implementation.
@abstractmethod
def __call__(self, decay: float, step: int, total_steps: int) -> float:
:param decay: The maximum decay value.
:param step: Current training step. The unit-range training percentage can be obtained by `step / total_steps`.
:param total_steps: Total number of training steps.
:return: Computed decay value for a given step.
class ThresholdDecay(IDecayFunction):
Gradually increases the EMA decay from 0.1 to the maximum value using the following formula: min(decay, (1 + step) / (10 + step))
def __init__(self, **kwargs):
def __call__(self, decay: float, step, total_steps: int) -> float:
return np.minimum(decay, (1 + step) / (10 + step))
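Example (a small sketch comparing how the two schedules above ramp the effective decay towards a maximum of 0.9999; ThresholdDecay() assumes its __init__ takes no required arguments, as the signature above suggests):
exp_schedule = ExpDecay(beta=15)
threshold_schedule = ThresholdDecay()
for step in (0, 50, 200, 1000):
    print(
        step,
        round(float(exp_schedule(0.9999, step, total_steps=1000)), 4),
        round(float(threshold_schedule(0.9999, step, total_steps=1000)), 4),
    )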
Fuses consecutive nn.Conv2d and nn.BatchNorm2d layers recursively and in-place throughout the model.
def fuse_conv_bn(model: nn.Module, replace_bn_with_identity: bool = False):
Fuses consecutive nn.Conv2d and nn.BatchNorm2d layers recursively inplace in all of the model
:param replace_bn_with_identity: if set to true, bn will be replaced with identity. otherwise, bn will be removed
:param model: the target model
:return: the number of fuses executed
children = list(model.named_children())
counter = 0
for i in range(len(children) - 1):
if isinstance(children[i][1], torch.nn.Conv2d) and isinstance(children[i + 1][1], torch.nn.BatchNorm2d):
setattr(model, children[i][0], torch.nn.utils.fuse_conv_bn_eval(children[i][1], children[i + 1][1]))
if replace_bn_with_identity:
setattr(model, children[i + 1][0], nn.Identity())
else:
delattr(model, children[i + 1][0])
counter += 1
for child_name, child in children:
counter += fuse_conv_bn(child, replace_bn_with_identity)
return counter
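Example (a sketch on a hypothetical toy model; fuse_conv_bn is the function shown above):
import torch.nn as nn

toy = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8), nn.ReLU())
toy.eval()  # fusion relies on fuse_conv_bn_eval, so the model should be in eval mode
num_fused = fuse_conv_bn(toy, replace_bn_with_identity=True)
print(num_fused)  # 1: the Conv2d+BatchNorm2d pair was fused and the BN slot replaced by nn.Identity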
def get_input_output_shapes(batch_size: int, input_dims: Union[list, tuple], output_dims: Union[list, tuple]):
Returns input/output shapes for single or multiple inputs/outputs.
if isinstance(input_dims[0], list):
input_shape = [i.size() for i in input_dims[0] if i is not None]
else:
input_shape = list(input_dims[0].size())
input_shape[0] = batch_size
if isinstance(output_dims, (list, tuple)):
output_shape = [[-1] + list(o.size())[1:] for o in output_dims if o is not None]
else:
output_shape = list(output_dims.size())
output_shape[0] = batch_size
return input_shape, output_shape
return the model summary as a string
The block(type) column represents the lines (layers) above
:param dtypes: The input types (list of inputs types)
:param high_verbosity: prints layer by layer information
Source code in V3_2/src/super_gradients/training/utils/get_model_stats.py
def get_model_stats(
model: nn.Module,
input_dims: Union[list, tuple],
high_verbosity: bool = True,
batch_size: int = 1,
device: str = "cuda", # noqa: C901
dtypes=None,
iterations: int = 100,
return the model summary as a string
The block(type) column represents the lines (layers) above
:param dtypes: The input types (list of inputs types)
:param high_verbosity: prints layer by layer information
dtypes = dtypes or [torch.FloatTensor] * len(input_dims)
def register_hook(module):
add a hook (all the desirable actions) for every layer that is not nn.Sequential/nn.ModuleList
def hook(module, input, output):
class_name = str(module.__class__).split(".")[-1].split("'")[0]
module_idx = len(summary)
m_key = f"{class_name}-{module_idx + 1}"
summary[m_key] = OrderedDict()
# block_name refers to all layers that contains other layers
if len(module._modules) != 0:
summary[m_key]["block_name"] = class_name
summary[m_key]["inference_time"] = np.round(timer.stop(), 3)
timer.start()
summary[m_key]["gpu_occupation"] = (round(torch.cuda.memory_allocated(0) / 1024**3, 2), "GB") if torch.cuda.is_available() else [0]
summary[m_key]["gpu_cached_memory"] = (round(torch.cuda.memory_reserved(0) / 1024**3, 2), "GB") if torch.cuda.is_available() else [0]
summary[m_key]["input_shape"], summary[m_key]["output_shape"] = get_input_output_shapes(batch_size=batch_size, input_dims=input, output_dims=output)
params = 0
if hasattr(module, "weight") and hasattr(module.weight, "size"):
params += torch.prod(torch.LongTensor(list(module.weight.size())))
summary[m_key]["trainable"] = module.weight.requires_grad
if hasattr(module, "bias") and hasattr(module.bias, "size"):
params += torch.prod(torch.LongTensor(list(module.bias.size())))
summary[m_key]["nb_params"] = params
if not isinstance(module, nn.Sequential) and not isinstance(module, nn.ModuleList):
hooks.append(module.register_forward_hook(hook))
# multiple inputs to the network
if isinstance(input_dims, tuple):
input_dims = [input_dims]
x = [torch.rand(batch_size, *input_dim).type(dtype).to(device=device) for input_dim, dtype in zip(input_dims, dtypes)]
summary_list = []
with torch.no_grad():
for i in range(iterations + 10):
# create properties
summary = OrderedDict()
hooks = []
# register hook
model.apply(register_hook)
timer = Timer(device=device)
timer.start()
# make a forward pass
model(*x)
# remove these hooks
for h in hooks:
h.remove()
# we start counting from the 10th iteration for warmup
if i >= 10:
summary_list.append(summary)
summary = _average_inference_time(summary_list=summary_list, summary=summary, divisor=iterations)
return _convert_summary_dict_to_string(summary=summary, high_verbosity=high_verbosity, input_dims=input_dims, batch_size=batch_size, device=device)
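Example (a usage sketch on a hypothetical toy model; device="cpu" is used here only to keep the sketch GPU-free, the default is "cuda"):
import torch.nn as nn

toy = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU())
summary_str = get_model_stats(toy, input_dims=(3, 64, 64), batch_size=1, device="cpu", iterations=10)
print(summary_str)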
def check_image_typing(image: ImageSource) -> bool:
"""Check if the given object respects typing of image.
:param image: Image to check.
:return: True if the object is an image, False otherwise.
if isinstance(image, get_args(SingleImageSource)):
return True
elif isinstance(image, list):
return all([isinstance(image_item, get_args(SingleImageSource)) for image_item in image])
else:
return False
Generator that loads images one at a time.
Supported types include:
- str: A string representing either an image or an URL.
- numpy.ndarray: A numpy array representing the image
- torch.Tensor: A PyTorch tensor representing the image
- PIL.Image.Image: A PIL Image object
- List: A list of images of any of the above types.
def generate_image_loader(images: Union[List[ImageSource], ImageSource]) -> Iterable[np.ndarray]:
"""Generator that loads images one at a time.
Supported types include:
- str: A string representing either an image or an URL.
- numpy.ndarray: A numpy array representing the image
- torch.Tensor: A PyTorch tensor representing the image
- PIL.Image.Image: A PIL Image object
- List: A list of images of any of the above types.
:param images: Single image or a list of images of supported types.
:return: Generator of images as numpy arrays. If loaded from string, the image will be returned as RGB.
if isinstance(images, str) and os.path.isdir(images):
images_paths = list_images_in_folder(images)
for image_path in images_paths:
yield load_image(image=image_path)
elif isinstance(images, (list, Iterator)):
for image in images:
yield load_image(image=image)
else:
yield load_image(image=images)
def is_image(filename: str) -> bool:
"""Check if the given file name refers to image.
:param filename: The filename to check.
:return: True if the file is an image, False otherwise.
return filename.split(".")[-1].lower() in IMG_EXTENSIONS
def is_url(url: str) -> bool:
"""Check if the given string is a URL.
:param url: String to check.
try:
result = urlparse(url)
return all([result.scheme, result.netloc, result.path])
except Exception:
return False
def list_images_in_folder(directory: str) -> List[str]:
"""List all the images in a directory.
:param directory: The path to the directory containing the images.
:return: A list of image file names.
files = os.listdir(directory)
images_paths = [os.path.join(directory, f) for f in files if f.lower().endswith((".jpg", ".jpeg", ".png", ".bmp", ".gif"))]
return images_paths
Load a single image and return it as a numpy array.
Supported image types include:
- numpy.ndarray: A numpy array representing the image
- torch.Tensor: A PyTorch tensor representing the image
- PIL.Image.Image: A PIL Image object
- str: A string representing either a local file path or a URL to an image
def load_image(image: ImageSource) -> np.ndarray:
"""Load a single image and return it as a numpy arrays.
Supported image types include:
- numpy.ndarray: A numpy array representing the image
- torch.Tensor: A PyTorch tensor representing the image
- PIL.Image.Image: A PIL Image object
- str: A string representing either a local file path or a URL to an image
:param image: Single image of supported types.
:return: Image as a numpy array. If loaded from string, the image will be returned as RGB.
if isinstance(image, np.ndarray):
return image
elif isinstance(image, torch.Tensor):
return image.numpy()
elif isinstance(image, PIL.Image.Image):
return load_np_image_from_pil(image)
elif isinstance(image, str):
image = load_pil_image_from_str(image_str=image)
return load_np_image_from_pil(image)
else:
raise ValueError(f"Unsupported image type: {type(image)}")
Load a single image or a list of images and return them as a list of numpy arrays.
Supported types include:
- str: A string representing either an image or an URL.
- numpy.ndarray: A numpy array representing the image
- torch.Tensor: A PyTorch tensor representing the image
- PIL.Image.Image: A PIL Image object
- List: A list of images of any of the above types.
def load_images(images: Union[List[ImageSource], ImageSource]) -> List[np.ndarray]:
"""Load a single image or a list of images and return them as a list of numpy arrays.
Supported types include:
- str: A string representing either an image or an URL.
- numpy.ndarray: A numpy array representing the image
- torch.Tensor: A PyTorch tensor representing the image
- PIL.Image.Image: A PIL Image object
- List: A list of images of any of the above types.
:param images: Single image or a list of images of supported types.
:return: List of images as numpy arrays. If loaded from string, the image will be returned as RGB.
return [image for image in generate_image_loader(images=images)]
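Example (the local path and URL below are placeholders):
import numpy as np

sources = [
    "path/to/cat.jpg",                        # placeholder local file
    "https://example.com/dog.png",            # placeholder URL
    np.zeros((224, 224, 3), dtype=np.uint8),  # already a numpy image, returned as-is
]
images = load_images(sources)  # list of numpy arrays, RGB when decoded from file/URL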
def load_np_image_from_pil(image: PIL.Image.Image) -> np.ndarray:
"""Convert a PIL image to numpy array in RGB format."""
return np.asarray(image.convert("RGB"))
def load_pil_image_from_str(image_str: str) -> PIL.Image.Image:
"""Load an image based on a string (local file path or URL)."""
if is_url(image_str):
response = requests.get(image_str, stream=True)
response.raise_for_status()
return PIL.Image.open(io.BytesIO(response.content))
else:
return PIL.Image.open(image_str)
def save_image(image: np.ndarray, path: str) -> None:
"""Save a numpy array as an image.
:param image: Image to save, (H, W, C), RGB.
:param path: Path to save the image to.
Image.fromarray(image).save(path)
def show_image(image: np.ndarray) -> None:
"""Show an image using matplotlib.
:param image: Image to show in (H, W, C), RGB.
plt.figure(figsize=(image.shape[1] / 100.0, image.shape[0] / 100.0), dpi=100)
plt.imshow(image, interpolation="nearest")
plt.axis("off")
plt.tight_layout()
plt.show()
"""Class for calculating the FPS of a video stream."""
def __init__(self, update_frequency: Optional[float] = None):
"""Create a new FPSCounter object.
:param update_frequency: Minimum time (in seconds) between updates to the FPS counter.
If None, the counter is updated every frame.
self._update_frequency = update_frequency
self._start_time = time.time()
self._frame_count = 0
self._fps = 0.0
def _update_fps(self, elapsed_time, current_time) -> None:
"""Compute new value of FPS and reset the counter."""
self._fps = self._frame_count / elapsed_time
self._start_time = current_time
self._frame_count = 0
@property
def fps(self) -> float:
"""Current FPS value."""
self._frame_count += 1
current_time, elapsed_time = time.time(), time.time() - self._start_time
if self._update_frequency is None or elapsed_time > self._update_frequency:
self._update_fps(elapsed_time=elapsed_time, current_time=current_time)
return self._fps
class WebcamStreaming:
"""Stream video from a webcam. Press 'q' to quit the streaming.
:param window_name: Name of the window to display the video stream.
:param frame_processing_fn: Function to apply to each frame before displaying it.
If None, frames are displayed as is.
:param capture: ID of the video capture device to use.
Default is cv2.CAP_ANY (which selects the first available device).
:param fps_update_frequency: Minimum time (in seconds) between updates to the FPS counter.
If None, the counter is updated every frame.
def __init__(
self,
window_name: str = "",
frame_processing_fn: Optional[Callable[[np.ndarray], np.ndarray]] = None,
capture: int = cv2.CAP_ANY,
fps_update_frequency: Optional[float] = None,
self.window_name = window_name
self.frame_processing_fn = frame_processing_fn
self.cap = cv2.VideoCapture(capture)
if not self.cap.isOpened():
raise ValueError("Could not open video capture device")
self._fps_counter = FPSCounter(update_frequency=fps_update_frequency)
def run(self) -> None:
"""Start streaming video from the webcam and displaying it in a window.
Press 'q' to quit the streaming.
while not self._stop():
self._display_single_frame()
def _display_single_frame(self) -> None:
"""Read a single frame from the video capture device, apply any specified frame processing,
and display the resulting frame in the window.
Also updates the FPS counter and displays it in the frame.
_ret, frame = self.cap.read()
if self.frame_processing_fn:
frame = self.frame_processing_fn(frame)
_write_fps_to_frame(frame, self.fps)
cv2.imshow(self.window_name, frame)
def _stop(self) -> bool:
"""Stopping condition for the streaming."""
return cv2.waitKey(1) & 0xFF == ord("q")
@property
def fps(self) -> float:
return self._fps_counter.fps
def __del__(self):
"""Release the video capture device and close the window."""
self.cap.release()
cv2.destroyAllWindows()
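Example (a usage sketch; identity_processing is a hypothetical stand-in for, e.g., drawing model predictions on each frame, and a webcam must be available):
import numpy as np

def identity_processing(frame: np.ndarray) -> np.ndarray:
    return frame  # replace with any per-frame processing

WebcamStreaming(window_name="webcam demo", frame_processing_fn=identity_processing).run()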
Start streaming video from the webcam and displaying it in a window.
Press 'q' to quit the streaming.
Source code in V3_2/src/super_gradients/training/utils/media/stream.py
def includes_video_extension(file_path: str) -> bool:
"""Check if a file includes a video extension.
:param file_path: Path to the video file.
:return: True if the file includes a video extension.
return isinstance(file_path, str) and file_path.lower().endswith(VIDEO_EXTENSIONS)
def load_video(file_path: str, max_frames: Optional[int] = None) -> Tuple[List[np.ndarray], int]:
"""Open a video file and extract each frame into numpy array.
:param file_path: Path to the video file.
:param max_frames: Optional, maximum number of frames to extract.
:return:
- Frames representing the video, each in (H, W, C), RGB.
- Frames per Second (FPS).
cap = _open_video(file_path)
frames = _extract_frames(cap, max_frames)
fps = cap.get(cv2.CAP_PROP_FPS)
cap.release()
return frames, fps
def save_gif(output_path: str, frames: List[np.ndarray], fps: int) -> None:
"""Save a video locally in .gif format.
:param output_path: Where the video will be saved
:param frames: Frames representing the video, each in (H, W, C), RGB. Note that all the frames are expected to have the same shape.
:param fps: Frames per second
frames_pil = [PIL.Image.fromarray(frame) for frame in frames]
frames_pil[0].save(output_path, save_all=True, append_images=frames_pil[1:], duration=int(1000 / fps), loop=0)
def save_mp4(output_path: str, frames: List[np.ndarray], fps: int) -> None:
"""Save a video locally in .mp4 format.
:param output_path: Where the video will be saved
:param frames: Frames representing the video, each in (H, W, C), RGB. Note that all the frames are expected to have the same shape.
:param fps: Frames per second
video_height, video_width = _validate_frames(frames)
video_writer = cv2.VideoWriter(
output_path,
cv2.VideoWriter_fourcc(*"mp4v"),
fps,
(video_width, video_height),
for frame in frames:
video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
video_writer.release()
Save a video locally. Depending on the extension, the video will be saved as a .mp4 file or as a .gif file.
def save_video(output_path: str, frames: List[np.ndarray], fps: int) -> None:
"""Save a video locally. Depending on the extension, the video will be saved as a .mp4 file or as a .gif file.
:param output_path: Where the video will be saved
:param frames: Frames representing the video, each in (H, W, C), RGB. Note that all the frames are expected to have the same shape.
:param fps: Frames per second
if not includes_video_extension(output_path):
logger.info(f'Output path "{output_path}" does not have a video extension, and therefore will be saved as {output_path}.mp4')
output_path += ".mp4"
if check_is_gif(output_path):
save_gif(output_path, frames, fps)
else:
save_mp4(output_path, frames, fps)
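Example (a sketch tying the video helpers together; file paths are placeholders):
frames, fps = load_video("input.mp4", max_frames=120)  # placeholder path
save_video("preview.gif", frames, fps=int(fps))         # .gif extension -> saved with save_gif
save_video("preview", frames, fps=int(fps))             # no video extension -> saved as preview.mp4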
def show_video_from_disk(video_path: str, window_name: str = "Prediction"):
"""Display a video from disk using OpenCV.
:param video_path: Path to the video file.
:param window_name: Name of the window to display the video
cap = _open_video(video_path)
fps = cap.get(cv2.CAP_PROP_FPS)
while cap.isOpened():
ret, frame = cap.read()
if ret:
# Display the frame
cv2.imshow(window_name, frame)
# Wait for the specified number of milliseconds before displaying the next frame
if cv2.waitKey(int(1000 / fps)) & 0xFF == ord("q"):
break
else:
break
# Release the VideoCapture object and destroy the window
cap.release()
cv2.destroyAllWindows()
cv2.waitKey(1)
def show_video_from_frames(frames: List[np.ndarray], fps: float, window_name: str = "Prediction") -> None:
"""Display a video from a list of frames using OpenCV.
:param frames: Frames representing the video, each in (H, W, C), RGB. Note that all the frames are expected to have the same shape.
:param fps: Frames per second
:param window_name: Name of the window to display the video
for frame in frames:
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
cv2.imshow(window_name, frame)
cv2.waitKey(int(1000 / fps))
cv2.destroyAllWindows()
cv2.waitKey(1)
Wrapper function for initializing the optimizer
:param net: the nn_module to build the optimizer for
:param lr: initial learning rate
:param training_params: training_parameters
Source code in V3_2/src/super_gradients/training/utils/optimizer_utils.py
def build_optimizer(net: nn.Module, lr: float, training_params) -> optim.Optimizer:
Wrapper function for initializing the optimizer
:param net: the nn_module to build the optimizer for
:param lr: initial learning rate
:param training_params: training_parameters
if is_model_wrapped(net):
raise ValueError("Argument net for build_optimizer must be an unwrapped model. " "Please use build_optimizer(unwrap_model(net), ...).")
if isinstance(training_params.optimizer, str):
optimizer_cls = OptimizersTypeFactory().get(training_params.optimizer)
else:
optimizer_cls = training_params.optimizer
optimizer_params = OPTIMIZERS_DEFAULT_PARAMS[optimizer_cls].copy() if optimizer_cls in OPTIMIZERS_DEFAULT_PARAMS.keys() else dict()
optimizer_params.update(**training_params.optimizer_params)
training_params.optimizer_params = optimizer_params
weight_decay = get_param(training_params.optimizer_params, "weight_decay", 0.0)
# OPTIMIZER PARAM GROUPS ARE SET USING DEFAULT OR MODEL SPECIFIC INIT
if hasattr(net, "initialize_param_groups"):
# INITIALIZE_PARAM_GROUPS MUST RETURN A LIST OF DICTS WITH 'named_params' AND OPTIMIZER's ATTRIBUTES PER GROUP
net_named_params = net.initialize_param_groups(lr, training_params)
else:
net_named_params = [{"named_params": net.named_parameters()}]
if training_params.zero_weight_decay_on_bias_and_bn:
optimizer_training_params = separate_zero_wd_params_groups_for_optimizer(net, net_named_params, weight_decay)
else:
# Overwrite groups to include params instead of named params
for ind_group, param_group in enumerate(net_named_params):
param_group["params"] = [param[1] for param in list(param_group["named_params"])]
del param_group["named_params"]
net_named_params[ind_group] = param_group
optimizer_training_params = net_named_params
# CREATE AN OPTIMIZER OBJECT AND INITIALIZE IT
optimizer = optimizer_cls(optimizer_training_params, lr=lr, **training_params.optimizer_params)
return optimizer
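Example (a minimal sketch; inside SuperGradients, training_params is normally an HpmStruct built by the Trainer, so the SimpleNamespace below, and the assumption that the string "SGD" resolves through OptimizersTypeFactory, are illustrative only):
from types import SimpleNamespace
import torch.nn as nn

model = nn.Linear(10, 2)  # hypothetical unwrapped model
training_params = SimpleNamespace(
    optimizer="SGD",  # assumed to be resolvable by OptimizersTypeFactory
    optimizer_params={"momentum": 0.9, "weight_decay": 1e-4},
    zero_weight_decay_on_bias_and_bn=True,
)
optimizer = build_optimizer(net=model, lr=0.01, training_params=training_params)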
Separates param groups for batchnorm and bias parameters from the rest, and returns a list of param groups in the format
required by torch Optimizer classes: bias + BN parameters with weight_decay=0 and the rest with the given weight decay.
:param module: train net module.
:param net_named_params: list of params groups, output of SgModule.initialize_param_groups
:param weight_decay: value to set for the non BN and bias parameters
Source code in V3_2/src/super_gradients/training/utils/optimizer_utils.py
def separate_zero_wd_params_groups_for_optimizer(module: nn.Module, net_named_params, weight_decay: float):
separate param groups for batchnorm and biases and others with weight decay. return list of param groups in format
required by torch Optimizer classes.
bias + BN with weight decay=0 and the rest with the given weight decay
:param module: train net module.
:param net_named_params: list of params groups, output of SgModule.initialize_param_groups
:param weight_decay: value to set for the non BN and bias parameters
# FIXME - replace usage of ids addresses to find batchnorm and biases params.
# This solution iterate 2 times over module parameters, find a way to iterate only one time.
no_decay_ids = _get_no_decay_param_ids(module)
# split param groups for optimizer
optimizer_param_groups = []
for param_group in net_named_params:
no_decay_params = []
decay_params = []
for name, param in param_group["named_params"]:
if id(param) in no_decay_ids:
no_decay_params.append(param)
else:
decay_params.append(param)
# append two param groups from the original param group, with and without weight decay.
extra_optim_params = {key: param_group[key] for key in param_group if key not in ["named_params", "weight_decay"]}
optimizer_param_groups.append({"params": no_decay_params, "weight_decay": 0.0, **extra_optim_params})
optimizer_param_groups.append({"params": decay_params, "weight_decay": weight_decay, **extra_optim_params})
return optimizer_param_groups
This implementation is taken from timm's github:
https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/lamb.py
PyTorch Lamb optimizer w/ behaviour similar to NVIDIA FusedLamb
This optimizer code was adapted from the following (starting with latest)
* https://github.com/HabanaAI/Model-References/blob/2b435114fe8e31f159b1d3063b8280ae37af7423/PyTorch/nlp/bert/pretraining/lamb.py
* https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py
* https://github.com/cybertronai/pytorch-lamb
Use FusedLamb if you can (GPU). The reason for including this variant of Lamb is to have a version that is
similar in behaviour to APEX FusedLamb if you aren't using NVIDIA GPUs or cannot install/use APEX.
In addition to some cleanup, this Lamb impl has been modified to support PyTorch XLA and has been tested on TPU.
Original copyrights for above sources are below.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
MIT License
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Implements a pure pytorch variant of FuseLAMB (NvLamb variant) optimizer from apex.optimizers.FusedLAMB
reference: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py
LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its norm. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
grad_averaging (bool, optional): whether apply (1-beta2) to grad when
calculating running averages of gradient. (default: True)
max_grad_norm (float, optional): value used to clip global grad norm (default: 1.0)
trust_clip (bool): enable LAMBC trust ratio clipping (default: False)
always_adapt (boolean, optional): Apply adaptive learning rate to 0.0
weight decay parameter (default: False)
.. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
https://arxiv.org/abs/1904.00962
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
Source code in V3_2/src/super_gradients/training/utils/optimizers/lamb.py
@register_optimizer(Optimizers.LAMB)
class Lamb(Optimizer):
"""Implements a pure pytorch variant of FuseLAMB (NvLamb variant) optimizer from apex.optimizers.FusedLAMB
reference: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py
LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its norm. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
grad_averaging (bool, optional): whether apply (1-beta2) to grad when
calculating running averages of gradient. (default: True)
max_grad_norm (float, optional): value used to clip global grad norm (default: 1.0)
trust_clip (bool): enable LAMBC trust ratio clipping (default: False)
always_adapt (boolean, optional): Apply adaptive learning rate to 0.0
weight decay parameter (default: False)
.. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
https://arxiv.org/abs/1904.00962
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
def __init__(
self,
params: Union[Iterable[torch.Tensor], Iterable[dict]],
lr: float = 1e-3,
bias_correction: bool = True,
betas: Tuple[float, float] = (0.9, 0.999),
eps: float = 1e-6,
weight_decay: float = 0.01,
grad_averaging: bool = True,
max_grad_norm: float = 1.0,
trust_clip: bool = False,
always_adapt: bool = False,
defaults = dict(
lr=lr,
bias_correction=bias_correction,
betas=betas,
eps=eps,
weight_decay=weight_decay,
grad_averaging=grad_averaging,
max_grad_norm=max_grad_norm,
trust_clip=trust_clip,
always_adapt=always_adapt,
super().__init__(params, defaults)
@torch.no_grad()
def step(self, closure: Optional[callable] = None) -> torch.Tensor:
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
device = self.param_groups[0]["params"][0].device
one_tensor = torch.tensor(1.0, device=device) # because torch.where doesn't handle scalars correctly
global_grad_norm = torch.zeros(1, device=device)
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad
if grad.is_sparse:
raise RuntimeError("Lamb does not support sparse gradients, consider SparseAdam instad.")
global_grad_norm.add_(grad.pow(2).sum())
global_grad_norm = torch.sqrt(global_grad_norm)
# FIXME it'd be nice to remove explicit tensor conversion of scalars when torch.where promotes
# scalar types properly https://github.com/pytorch/pytorch/issues/9190
max_grad_norm = torch.tensor(self.defaults["max_grad_norm"], device=device)
clip_global_grad_norm = torch.where(global_grad_norm > max_grad_norm, global_grad_norm / max_grad_norm, one_tensor)
for group in self.param_groups:
bias_correction = 1 if group["bias_correction"] else 0
beta1, beta2 = group["betas"]
grad_averaging = 1 if group["grad_averaging"] else 0
beta3 = 1 - beta1 if grad_averaging else 1.0
# assume same step across group now to simplify things
# per-parameter step can easily be supported by making it a tensor, or by passing a list into the kernel
if "step" in group:
group["step"] += 1
else:
group["step"] = 1
if bias_correction:
bias_correction1 = 1 - beta1 ** group["step"]
bias_correction2 = 1 - beta2 ** group["step"]
else:
bias_correction1, bias_correction2 = 1.0, 1.0
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad.div_(clip_global_grad_norm)
state = self.state[p]
# State initialization
if len(state) == 0:
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(p)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(p)
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(grad, alpha=beta3) # m_t
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) # v_t
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group["eps"])
update = (exp_avg / bias_correction1).div_(denom)
weight_decay = group["weight_decay"]
if weight_decay != 0:
update.add_(p, alpha=weight_decay)
if weight_decay != 0 or group["always_adapt"]:
# Layer-wise LR adaptation. By default, skip adaptation on parameters that are
# excluded from weight decay, unless always_adapt == True, then always enabled.
w_norm = p.norm(2.0)
g_norm = update.norm(2.0)
# FIXME nested where required since logical and/or not working in PT XLA
trust_ratio = torch.where(
w_norm > 0,
torch.where(g_norm > 0, w_norm / g_norm, one_tensor),
one_tensor,
)
if group["trust_clip"]:
# LAMBC trust clipping, upper bound fixed at one
trust_ratio = torch.minimum(trust_ratio, one_tensor)
update.mul_(trust_ratio)
p.add_(update, alpha=-group["lr"])
return loss
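A minimal usage sketch follows. The import path is an assumption inferred from the source location shown on this page; the constructor arguments follow the signature above.

import torch
from torch import nn
from super_gradients.training.utils.optimizers.lamb import Lamb  # assumed import path

model = nn.Linear(16, 4)
optimizer = Lamb(model.parameters(), lr=1e-3, weight_decay=0.01, max_grad_norm=1.0)

x, y = torch.randn(8, 16), torch.randn(8, 4)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()       # clips the global gradient norm, then applies the layer-wise trust ratio
optimizer.zero_grad()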
@register_optimizer(Optimizers.LION)
class Lion(Optimizer):
r"""Implements Lion algorithm.
Generally, it is recommended to divide the lr used for AdamW by 10 and multiply the weight decay by 10."""
def __init__(
self,
params: Union[Iterable[torch.Tensor], Iterable[dict]],
lr: float = 1e-4,
betas: Tuple[float, float] = (0.9, 0.99),
weight_decay: float = 0.0,
):
"""Initialize the hyperparameters.
:param params: Iterable of parameters to optimize or dicts defining parameter groups
:param lr: Learning rate (default: 1e-4)
:param betas: Coefficients used for computing running averages of gradient and its square (default: (0.9, 0.99))
:param weight_decay: Weight decay coefficient (default: 0)
"""
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
super().__init__(params, defaults)
@torch.no_grad()
def step(self, closure: Optional[callable] = None) -> torch.Tensor:
"""Perform a single optimization step.
:param closure: A closure that reevaluates the model and returns the loss.
:return: Loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
# Perform stepweight decay
p.data.mul_(1 - group["lr"] * group["weight_decay"])
grad = p.grad
state = self.state[p]
# State initialization
if len(state) == 0:
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(p)
exp_avg = state["exp_avg"]
beta1, beta2 = group["betas"]
# Weight update
update = exp_avg * beta1 + grad * (1 - beta1)
p.add_(torch.sign(update), alpha=-group["lr"])
# Decay the momentum running average coefficient
exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)
return loss
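A short, hedged sketch of the rule of thumb above: if an AdamW baseline used lr=3e-4 and weight_decay=1e-2, a reasonable Lion starting point is lr=3e-5 and weight_decay=1e-1. The import path is an assumption based on the naming of the other optimizer modules on this page.

import torch
from torch import nn
from super_gradients.training.utils.optimizers.lion import Lion  # assumed import path

model = nn.Linear(16, 4)
optimizer = Lion(model.parameters(), lr=3e-5, betas=(0.9, 0.99), weight_decay=1e-1)

loss = model(torch.randn(8, 16)).pow(2).mean()
loss.backward()
optimizer.step()       # takes the sign of the interpolated momentum/gradient, with decoupled weight decay applied first
optimizer.zero_grad()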
Implements RMSprop algorithm (TensorFlow style epsilon)
NOTE: This is a direct cut-and-paste of PyTorch RMSprop with eps applied before sqrt
and a few other modifications to closer match Tensorflow for matching hyper-params.
Noteworthy changes include:
1. Epsilon applied inside square-root
2. square_avg initialized to ones
3. LR scaling of update accumulated in momentum buffer
Proposed by G. Hinton in his
course <http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>.
The centered version first appears in Generating Sequences
With Recurrent Neural Networks <https://arxiv.org/pdf/1308.0850v5.pdf>.
Source code in V3_2/src/super_gradients/training/utils/optimizers/rmsprop_tf.py
@register_optimizer(Optimizers.RMS_PROP_TF)
class RMSpropTF(Optimizer):
"""Implements RMSprop algorithm (TensorFlow style epsilon)
NOTE: This is a direct cut-and-paste of PyTorch RMSprop with eps applied before sqrt
and a few other modifications to closer match Tensorflow for matching hyper-params.
Noteworthy changes include:
1. Epsilon applied inside square-root
2. square_avg initialized to ones
3. LR scaling of update accumulated in momentum buffer
Proposed by G. Hinton in his
`course <http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>`_.
The centered version first appears in `Generating Sequences
With Recurrent Neural Networks <https://arxiv.org/pdf/1308.0850v5.pdf>`_."""
def __init__(
self,
params: Union[Iterable[torch.Tensor], Iterable[dict]],
lr: float = 1e-2,
alpha: float = 0.9,
eps: float = 1e-10,
weight_decay: float = 0,
momentum: float = 0.0,
centered: bool = False,
decoupled_decay: bool = False,
lr_in_momentum: bool = True,
"""RMSprop optimizer that follows the tf's RMSprop characteristics
:param params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
:param lr (float, optional): learning rate
:param momentum (float, optional): momentum factor
:param alpha (float, optional): smoothing (decay) constant
:param eps (float, optional): term added to the denominator to improve numerical stability
:param centered (bool, optional) : if ``True``, compute the centered RMSProp, the gradient is normalized by an
estimation of its variance
:param weight_decay (float, optional): weight decay (L2 penalty)
:param decoupled_decay (bool, optional): decoupled weight decay as per https://arxiv.org/abs/1711.05101
:param lr_in_momentum (bool, optional): learning rate scaling is included in the momentum buffer update as per
defaults in TensorFlow
"""
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= momentum:
raise ValueError("Invalid momentum value: {}".format(momentum))
if not 0.0 <= weight_decay:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
if not 0.0 <= alpha:
raise ValueError("Invalid alpha value: {}".format(alpha))
defaults = dict(
lr=lr,
momentum=momentum,
alpha=alpha,
eps=eps,
centered=centered,
weight_decay=weight_decay,
decoupled_decay=decoupled_decay,
lr_in_momentum=lr_in_momentum,
)
super(RMSpropTF, self).__init__(params, defaults)
def __setstate__(self, state):
super(RMSpropTF, self).__setstate__(state)
for group in self.param_groups:
group.setdefault("momentum", 0)
group.setdefault("centered", False)
def step(self, closure: Optional[callable] = None) -> torch.Tensor: # noqa: C901
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError("RMSprop does not support sparse gradients")
state = self.state[p]
# State initialization
if len(state) == 0:
state["step"] = 0
state["square_avg"] = torch.ones_like(p.data) # PyTorch inits to zero
if group["momentum"] > 0:
state["momentum_buffer"] = torch.zeros_like(p.data)
if group["centered"]:
state["grad_avg"] = torch.zeros_like(p.data)
square_avg = state["square_avg"]
one_minus_alpha = 1.0 - group["alpha"]
state["step"] += 1
if group["weight_decay"] != 0:
if "decoupled_decay" in group and group["decoupled_decay"]:
p.data.add_(-group["weight_decay"], p.data)
else:
grad = grad.add(group["weight_decay"], p.data)
# Tensorflow order of ops for updating squared avg
square_avg.add_(one_minus_alpha, grad.pow(2) - square_avg)
# square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad) # PyTorch original
if group["centered"]:
grad_avg = state["grad_avg"]
grad_avg.add_(one_minus_alpha, grad - grad_avg)
# grad_avg.mul_(alpha).add_(1 - alpha, grad) # PyTorch original
avg = square_avg.addcmul(-1, grad_avg, grad_avg).add(group["eps"]).sqrt_() # eps moved in sqrt
else:
avg = square_avg.add(group["eps"]).sqrt_() # eps moved in sqrt
if group["momentum"] > 0:
buf = state["momentum_buffer"]
# Tensorflow accumulates the LR scaling in the momentum buffer
if "lr_in_momentum" in group and group["lr_in_momentum"]:
buf.mul_(group["momentum"]).addcdiv_(group["lr"], grad, avg)
p.data.add_(-buf)
else:
# PyTorch scales the param update by LR
buf.mul_(group["momentum"]).addcdiv_(grad, avg)
p.data.add_(-group["lr"], buf)
else:
p.data.addcdiv_(-group["lr"], grad, avg)
return loss
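A hedged usage sketch; the import path is inferred from the source location shown above. decoupled_decay applies the weight decay directly to the weights (instead of adding it to the gradient), and lr_in_momentum folds the learning rate into the momentum buffer, matching the TensorFlow defaults.

import torch
from torch import nn
from super_gradients.training.utils.optimizers.rmsprop_tf import RMSpropTF  # assumed import path

model = nn.Conv2d(3, 16, kernel_size=3)
optimizer = RMSpropTF(model.parameters(), lr=1e-2, alpha=0.9, eps=1e-10, weight_decay=1e-5, momentum=0.9, decoupled_decay=True)

loss = model(torch.randn(1, 3, 32, 32)).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()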
Source code in V3_2/src/super_gradients/training/utils/pose_estimation/dekr_decode_callbacks.py
class DEKRPoseEstimationDecodeCallback(nn.Module):
Class that implements decoding logic of DEKR's model predictions into poses.
def __init__(
self,
output_stride: int,
max_num_people: int,
keypoint_threshold: float,
nms_threshold: float,
nms_num_threshold: int,
apply_sigmoid: bool,
min_confidence: float = 0.0,
:param output_stride: Output stride of the model
:param int max_num_people: Maximum number of decoded poses
:param float keypoint_threshold: (float) A minimum score of a pose center keypoint for pose to be considered as a potential candidate
:param float nms_threshold: The maximum distance between two joints for them to be considered as belonging to the same pose.
Given in terms of a percentage of a square root of the area of the pose bounding box.
:param int nms_num_threshold: Number of joints that must pass the NMS check for the pose to be considered as a valid one.
:param bool apply_sigmoid: If True, apply the sigmoid activation on heatmap. This is needed when heatmap is not
bound to [0..1] range and trained with logits (E.g focal loss)
:param float min_confidence: Minimum confidence threshold for pose
super().__init__()
self.keypoint_threshold = keypoint_threshold
self.max_num_people = max_num_people
self.output_stride = output_stride
self.nms_threshold = nms_threshold
self.nms_num_threshold = nms_num_threshold
self.apply_sigmoid = apply_sigmoid
self.min_confidence = min_confidence
@torch.no_grad()
def forward(self, predictions: Union[Tensor, Tuple[Tensor, Tensor]]) -> Tuple[List[np.ndarray], List[np.ndarray]]:
:param predictions: Tuple (heatmap, offset):
heatmap - [BatchSize, NumJoints+1,H,W]
offset - [BatchSize, NumJoints*2,H,W]
:return: Tuple
all_poses = []
all_scores = []
heatmap, offset = predictions
batch_size = len(heatmap)
for i in range(batch_size):
poses, scores = self.decode_one_sized_batch(predictions=(heatmap[i : i + 1], offset[i : i + 1]))
all_poses.append(poses)
all_scores.append(scores)
return all_poses, all_scores
def decode_one_sized_batch(self, predictions: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
heatmap, offset = predictions
posemap = _offset_to_pose(offset) # [1, 2 * num_joints, H, W]
if heatmap.size(0) != 1:
raise RuntimeError("Batch size of 1 is required")
if self.apply_sigmoid:
heatmap = heatmap.sigmoid()
heatmap_sum, poses_sum = aggregate_results(
heatmap,
posemap,
pose_center_score_threshold=self.keypoint_threshold,
max_num_people=self.max_num_people,
output_stride=self.output_stride,
poses, scores = pose_nms(
heatmap_sum,
poses_sum,
max_num_people=self.max_num_people,
nms_threshold=self.nms_threshold,
nms_num_threshold=self.nms_num_threshold,
pose_score_threshold=self.min_confidence,
if len(poses) != len(scores):
raise RuntimeError("Decoding error detected. Returned mismatching number of poses/scores")
return poses, scores
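A hedged sketch of feeding raw DEKR-style outputs through the decode callback. The import path is inferred from the source location above, and the tensors are random stand-ins for a real model's (heatmap, offset) pair.

import torch
from super_gradients.training.utils.pose_estimation.dekr_decode_callbacks import DEKRPoseEstimationDecodeCallback  # assumed import path

decode = DEKRPoseEstimationDecodeCallback(
    output_stride=4,
    max_num_people=30,
    keypoint_threshold=0.05,
    nms_threshold=0.05,
    nms_num_threshold=8,
    apply_sigmoid=True,   # the heatmap below is raw logits
)

num_joints = 17
heatmap = torch.randn(2, num_joints + 1, 64, 48)  # [B, NumJoints+1, H, W]
offset = torch.randn(2, num_joints * 2, 64, 48)   # [B, NumJoints*2, H, W]
poses, scores = decode((heatmap, offset))         # two lists of length B with per-image poses / scores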
Get initial pose proposals and aggregate the results of all scales.
Note: this implementation works only for a batch size of 1.
Source code in V3_2/src/super_gradients/training/utils/pose_estimation/dekr_decode_callbacks.py
def aggregate_results(
heatmap: Tensor, posemap: Tensor, output_stride: int, pose_center_score_threshold: float, max_num_people: int
) -> Tuple[Tensor, List[Tensor]]:
Get initial pose proposals and aggregate the results of all scales.
Note: this implementation works only for a batch size of 1.
:param heatmap: Heatmap at this scale (B, 1+num_joints, w, h)
:param posemap: Posemap at this scale (B, 2*num_joints, w, h)
:param output_stride: Ratio of input size / predictions size
:param pose_center_score_threshold: (float) A minimum score of a pose center keypoint for pose to be considered as a potential candidate
:param max_num_people: (int)
:return:
- heatmap_sum: Sum of the heatmaps (1, 1+num_joints, w, h)
- poses (List): Gather of the pose proposals [B, (num_people, num_joints, 3)]
poses = []
h, w = heatmap[0].size(-1), heatmap[0].size(-2)
heatmap_sum = _up_interpolate(heatmap, size=(int(output_stride * w), int(output_stride * h)))
center_heatmap = heatmap[0, -1:]
pose_ind, ctr_score = _get_maximum_from_heatmap(center_heatmap, pose_center_score_threshold=pose_center_score_threshold, max_num_people=max_num_people)
posemap = posemap[0].permute(1, 2, 0).view(h * w, -1, 2)
pose = output_stride * posemap[pose_ind]
ctr_score = ctr_score[:, None].expand(-1, pose.shape[-2])[:, :, None]
poses.append(torch.cat([pose, ctr_score], dim=2))
return heatmap_sum, poses
Source code in V3_2/src/super_gradients/training/utils/pose_estimation/dekr_decode_callbacks.py
def get_locations(output_h: int, output_w: int, device):
Generate location map (each pixel contains its own XY coordinate)
:param output_h: Feature map height (rows)
:param output_w: Feature map width (cols)
:param device: Target device to put tensor on
:return: [H * W, 2]
shifts_x = torch.arange(0, output_w, step=1, dtype=torch.float32, device=device)
shifts_y = torch.arange(0, output_h, step=1, dtype=torch.float32, device=device)
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")
shift_x = shift_x.reshape(-1)
shift_y = shift_y.reshape(-1)
locations = torch.stack((shift_x, shift_y), dim=1)
return locations
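A tiny worked example of the coordinate layout (the import path is an assumption based on the source location above):

import torch
from super_gradients.training.utils.pose_estimation.dekr_decode_callbacks import get_locations  # assumed import path

locations = get_locations(output_h=2, output_w=3, device=torch.device("cpu"))
# locations has shape [6, 2]; rows are (x, y) pairs in row-major order over the feature map:
# [[0, 0], [1, 0], [2, 0], [0, 1], [1, 1], [2, 1]]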
Source code in V3_2/src/super_gradients/training/utils/pose_estimation/dekr_decode_callbacks.py
def get_reg_poses(offset: Tensor, num_joints: int):
Decode offset predictions into absolute locations.
:param offset: Tensor of [num_joints*2,H,W] shape with offset predictions for each joint
:param num_joints: Number of joints
:return: [H * W, num_joints, 2]
_, h, w = offset.shape
offset = offset.permute(1, 2, 0).reshape(h * w, num_joints, 2)
locations = get_locations(h, w, offset.device)
locations = locations[:, None, :].expand(-1, num_joints, -1)
poses = locations - offset
return poses
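A small sanity-check sketch: with an all-zero offset map, every pixel "predicts" a pose whose joints all sit at that pixel's own (x, y) location (import path assumed as above).

import torch
from super_gradients.training.utils.pose_estimation.dekr_decode_callbacks import get_reg_poses  # assumed import path

num_joints, h, w = 17, 4, 4
offset = torch.zeros(num_joints * 2, h, w)  # all-zero offsets
poses = get_reg_poses(offset, num_joints)   # shape [H*W, num_joints, 2]
# poses[i, j] equals the (x, y) coordinate of pixel i for every joint j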
Source code in V3_2/src/super_gradients/training/utils/pose_estimation/dekr_decode_callbacks.py
def pose_nms(
heatmap_avg, poses, max_num_people: int, nms_threshold: float, nms_num_threshold: int, pose_score_threshold: float
) -> Tuple[np.ndarray, np.ndarray]:
NMS for the regressed poses results.
:param Tensor heatmap_avg: Avg of the heatmaps at all scales (1, 1+num_joints, w, h)
:param List poses: Gather of the pose proposals [(num_people, num_joints, 3)]
:param int max_num_people: Maximum number of decoded poses
:param float nms_threshold: The maximum distance between two joints for them to be considered as belonging to the same pose.
Given in terms of a percentage of a square root of the area of the pose bounding box.
:param int nms_num_threshold: Number of joints that must pass the NMS check for the pose to be considered as a valid one.
:param float pose_score_threshold: Minimum confidence threshold for pose. Pose with confidence lower than this threshold will be discarded.
:return Tuple of (poses, scores)
assert len(poses) == 1
pose_score = torch.cat([pose[:, :, 2:] for pose in poses], dim=0)
pose_coord = torch.cat([pose[:, :, :2] for pose in poses], dim=0)
num_people, num_joints, _ = pose_coord.shape
if num_people == 0:
return np.zeros((0, num_joints, 3), dtype=np.float32), np.zeros((0,), dtype=np.float32)
heatval = _get_heat_value(pose_coord, heatmap_avg[0])
heat_score = (torch.sum(heatval, dim=1) / num_joints)[:, 0]
pose_score = pose_score * heatval
poses = torch.cat([pose_coord.cpu(), pose_score.cpu()], dim=2)
keep_pose_inds = _nms_core(pose_coord, heat_score, nms_threshold=nms_threshold, nms_num_threshold=nms_num_threshold)
poses = poses[keep_pose_inds]
heat_score = heat_score[keep_pose_inds]
if len(keep_pose_inds) > max_num_people:
heat_score, topk_inds = torch.topk(heat_score, max_num_people)
poses = poses[topk_inds]
poses = poses.numpy()
if len(poses):
scores = poses[:, :, 2].mean(axis=1)
mask = scores >= pose_score_threshold
poses = poses[mask]
scores = scores[mask]
else:
return np.zeros((0, num_joints, 3), dtype=np.float32), np.zeros((0,), dtype=np.float32)
return poses, scores
A callback that adds a visualization of a batch of pose estimation predictions to context.sg_logger
Source code in V3_2/src/super_gradients/training/utils/pose_estimation/dekr_visualization_callbacks.py
@register_callback(Callbacks.DEKR_VISUALIZATION)
class DEKRVisualizationCallback(PhaseCallback):
A callback that adds a visualization of a batch of pose estimation predictions to context.sg_logger
:param phase: When to trigger the callback.
:param prefix: Prefix to add to the log.
:param mean: Mean to subtract from image.
:param std: Standard deviation used to normalize the image.
:param apply_sigmoid: Whether to apply sigmoid to the output.
:param batch_idx: Batch index to perform visualization for.
:param keypoints_threshold: Keypoint threshold to use for visualization.
def __init__(
self,
phase: Phase,
prefix: str,
mean: List[float],
std: List[float],
apply_sigmoid: bool = False,
batch_idx: int = 0,
keypoints_threshold: float = 0.01,
super(DEKRVisualizationCallback, self).__init__(phase)
self.batch_idx = batch_idx
self.prefix = prefix
self.mean = np.array(list(map(float, mean))).reshape((1, 1, -1))
self.std = np.array(list(map(float, std))).reshape((1, 1, -1))
self.apply_sigmoid = apply_sigmoid
self.keypoints_threshold = keypoints_threshold
def denormalize_image(self, image_normalized: Tensor) -> np.ndarray:
Reverse the image normalization: image_normalized = (image / 255 - mean) / std
:param image_normalized: normalized [3,H,W]
:return:
image_normalized = torch.moveaxis(image_normalized, 0, -1).detach().cpu().numpy()
image = (image_normalized * self.std + self.mean) * 255
image = np.clip(image, 0, 255).astype(np.uint8)[..., ::-1]
return image
@classmethod
def visualize_heatmap(self, heatmap: Tensor, apply_sigmoid: bool, dsize, min_value=None, max_value=None, colormap=cv2.COLORMAP_JET):
if apply_sigmoid:
heatmap = heatmap.sigmoid()
if min_value is None:
min_value = heatmap.min().item()
if max_value is None:
max_value = heatmap.max().item()
heatmap = heatmap.detach().cpu().numpy()
real_min = heatmap.min()
real_max = heatmap.max()
heatmap = np.max(heatmap, axis=0)
heatmap = (heatmap - min_value) / (1e-8 + max_value - min_value)
heatmap = np.clip(heatmap, 0, 1)
heatmap_8u = (heatmap * 255).astype(np.uint8)
heatmap_bgr = cv2.applyColorMap(heatmap_8u, colormap)
heatmap_rgb = cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB)
if dsize is not None:
heatmap_rgb = cv2.resize(heatmap_rgb, dsize=dsize)
cv2.putText(
heatmap_rgb,
f"min:{real_min:.3f}",
(5, 15),
fontFace=cv2.FONT_HERSHEY_PLAIN,
color=(255, 255, 255),
fontScale=0.8,
thickness=1,
lineType=cv2.LINE_AA,
cv2.putText(
heatmap_rgb,
f"max:{real_max:.3f}",
(5, heatmap_rgb.shape[0] - 10),
cv2.FONT_HERSHEY_PLAIN,
color=(255, 255, 255),
fontScale=0.8,
thickness=1,
lineType=cv2.LINE_AA,
return heatmap, heatmap_rgb
@multi_process_safe
def __call__(self, context: PhaseContext):
if context.batch_idx == self.batch_idx:
batch_imgs = self.visualize_batch(context.inputs, context.preds, context.target)
batch_imgs = np.stack(batch_imgs)
tag = self.prefix + str(self.batch_idx) + "_images"
context.sg_logger.add_images(tag=tag, images=batch_imgs, global_step=context.epoch, data_format="NHWC")
@torch.no_grad()
def visualize_batch(self, inputs, predictions, targets):
num_samples = len(inputs)
batch_imgs = []
gt_heatmap, mask, _, _ = targets
# Check whether the model also produces supervised output predictions
if isinstance(predictions, tuple) and len(predictions) == 2 and torch.is_tensor(predictions[0]) and torch.is_tensor(predictions[1]):
heatmap, _ = predictions
else:
(heatmap, _), (_, _) = predictions
for i in range(num_samples):
batch_imgs.append(self.visualize_sample(inputs[i], predicted_heatmap=heatmap[i], target_heatmap=gt_heatmap[i], target_mask=mask[i]))
return batch_imgs
def visualize_sample(self, input, predicted_heatmap, target_heatmap, target_mask):
image_rgb = self.denormalize_image(input)
dsize = image_rgb.shape[1], image_rgb.shape[0]
half_size = dsize[0] // 2, dsize[1] // 2
target_heatmap_f32, target_heatmap_rgb = self.visualize_heatmap(target_heatmap, apply_sigmoid=False, dsize=half_size)
target_heatmap_f32 = cv2.resize(target_heatmap_f32, dsize=dsize)
target_heatmap_f32 = np.expand_dims(target_heatmap_f32, -1)
peaks_heatmap = _hierarchical_pool(predicted_heatmap)[0]
peaks_heatmap = predicted_heatmap.eq(peaks_heatmap) & (predicted_heatmap > self.keypoints_threshold)
peaks_heatmap = peaks_heatmap.sum(dim=0, keepdim=False) > 0
# Apply masking with GT mask to suppress predictions on ignored areas of the image (where target_mask==0)
flat_target_mask = target_mask.sum(dim=0, keepdim=False) > 0
peaks_heatmap &= flat_target_mask
peaks_heatmap = peaks_heatmap.detach().cpu().numpy().astype(np.uint8) * 255
peaks_heatmap = cv2.applyColorMap(peaks_heatmap, cv2.COLORMAP_JET)
peaks_heatmap = cv2.cvtColor(peaks_heatmap, cv2.COLOR_BGR2RGB)
peaks_heatmap = cv2.resize(peaks_heatmap, dsize=half_size)
_, predicted_heatmap_rgb = self.visualize_heatmap(
predicted_heatmap, min_value=target_heatmap.min().item(), max_value=target_heatmap.max().item(), apply_sigmoid=self.apply_sigmoid, dsize=half_size
image_heatmap_overlay = image_rgb * (1 - target_heatmap_f32) + target_heatmap_f32 * cv2.resize(target_heatmap_rgb, dsize=dsize)
image_heatmap_overlay = image_heatmap_overlay.astype(np.uint8)
_, target_mask_rgb = self.visualize_heatmap(target_mask, min_value=0, max_value=1, apply_sigmoid=False, dsize=half_size, colormap=cv2.COLORMAP_BONE)
return np.hstack(
image_heatmap_overlay,
np.vstack([target_heatmap_rgb, predicted_heatmap_rgb]),
np.vstack([target_mask_rgb, peaks_heatmap]),
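A hedged configuration sketch. The Phase member name and the phase_callbacks training parameter are assumptions about the surrounding training API; the constructor arguments follow the signature above.

from super_gradients.training.utils.callbacks import Phase  # assumed location of the Phase enum
from super_gradients.training.utils.pose_estimation.dekr_visualization_callbacks import DEKRVisualizationCallback  # assumed import path

viz_callback = DEKRVisualizationCallback(
    phase=Phase.VALIDATION_BATCH_END,  # assumed member name; pick the event you want to visualize on
    prefix="dekr_val_",
    mean=[0.485, 0.456, 0.406],        # must match the normalization used by the dataloader
    std=[0.229, 0.224, 0.225],
    apply_sigmoid=True,
    batch_idx=0,
    keypoints_threshold=0.01,
)
# Typically passed via training_params["phase_callbacks"] so the rendered batch ends up in context.sg_logger.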
Source code in V3_2/src/super_gradients/training/utils/pose_estimation/rescoring_callback.py
class RescoringPoseEstimationDecodeCallback:
A special adapter callback to be used with PoseEstimationMetrics to use the outputs from rescoring model inside metric class.
def __init__(self, apply_sigmoid: bool):
:param apply_sigmoid: If True, apply the sigmoid activation on heatmap. This is needed when heatmap is not
bound to [0..1] range and trained with logits (E.g focal loss)
super().__init__()
self.apply_sigmoid = apply_sigmoid
def __call__(self, predictions: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
""" """
poses, scores = predictions
if self.apply_sigmoid:
scores = scores.sigmoid()
return poses, scores.squeeze(-1) # Pose Estimation Callback expects that scores don't have the dummy dimension
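A minimal sketch of what the adapter does to the rescoring model's outputs (import path assumed from the source location above):

import torch
from super_gradients.training.utils.pose_estimation.rescoring_callback import RescoringPoseEstimationDecodeCallback  # assumed import path

adapter = RescoringPoseEstimationDecodeCallback(apply_sigmoid=True)
poses = torch.randn(8, 17, 3)  # [num_poses, num_joints, (x, y, confidence)]
scores = torch.randn(8, 1)     # raw logits from the rescoring model
poses_out, scores_out = adapter((poses, scores))
# scores_out has shape [8] with values in (0, 1): sigmoid applied, trailing dimension squeezed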
Object wrapping an image and a pose estimation model's prediction.
:attr image: Input image
:attr predictions: Predictions of the model
:attr class_names: List of the class names to predict
Source code in V3_2/src/super_gradients/training/utils/predict/prediction_pose_estimation_results.py
@dataclass
class ImagePoseEstimationPrediction(ImagePrediction):
"""Object wrapping an image and a detection model's prediction.
:attr image: Input image
:attr predictions: Predictions of the model
:attr class_names: List of the class names to predict
image: np.ndarray
prediction: PoseEstimationPrediction
def draw(
self,
edge_colors=None,
joint_thickness: int = 2,
keypoint_colors=None,
keypoint_radius: int = 5,
box_thickness: int = 2,
show_confidence: bool = False,
) -> np.ndarray:
"""Draw the predicted bboxes on the image.
:param edge_colors: Optional list of tuples representing the colors for each joint.
If None, default colors are used.
If not None the length must be equal to the number of joint links in the skeleton.
:param joint_thickness: Thickness of the joint links (in pixels).
:param keypoint_colors: Optional list of tuples representing the colors for each keypoint.
If None, default colors are used.
If not None the length must be equal to the number of joints in the skeleton.
:param keypoint_radius: Radius of the keypoints (in pixels).
:param show_confidence: Whether to show confidence scores on the image.
:param box_thickness: Thickness of bounding boxes.
:return: Image with predicted bboxes. Note that this does not modify the original image.
image = self.image.copy()
for pred_i in np.argsort(self.prediction.scores):
image = draw_skeleton(
image=image,
keypoints=self.prediction.poses[pred_i],
score=self.prediction.scores[pred_i],
show_confidence=show_confidence,
edge_links=self.prediction.edge_links,
edge_colors=edge_colors or self.prediction.edge_colors,
joint_thickness=joint_thickness,
keypoint_colors=keypoint_colors or self.prediction.keypoint_colors,
keypoint_radius=keypoint_radius,
box_thickness=box_thickness,
return image
def show(
self,
edge_colors=None,
joint_thickness: int = 2,
keypoint_colors=None,
keypoint_radius: int = 5,
box_thickness: int = 2,
show_confidence: bool = False,
) -> None:
"""Display the image with predicted bboxes.
:param edge_colors: Optional list of tuples representing the colors for each joint.
If None, default colors are used.
If not None the length must be equal to the number of joint links in the skeleton.
:param joint_thickness: Thickness of the joint links (in pixels).
:param keypoint_colors: Optional list of tuples representing the colors for each keypoint.
If None, default colors are used.
If not None the length must be equal to the number of joints in the skeleton.
:param keypoint_radius: Radius of the keypoints (in pixels).
:param show_confidence: Whether to show confidence scores on the image.
:param box_thickness: Thickness of bounding boxes.
image = self.draw(
edge_colors=edge_colors,
joint_thickness=joint_thickness,
keypoint_colors=keypoint_colors,
keypoint_radius=keypoint_radius,
box_thickness=box_thickness,
show_confidence=show_confidence,
show_image(image)
def save(
self,
output_path: str,
edge_colors=None,
joint_thickness: int = 2,
keypoint_colors=None,
keypoint_radius: int = 5,
box_thickness: int = 2,
show_confidence: bool = False,
) -> None:
"""Save the predicted bboxes on the images.
:param output_path: Path to the output image file.
:param edge_colors: Optional list of tuples representing the colors for each joint.
If None, default colors are used.
If not None the length must be equal to the number of joint links in the skeleton.
:param joint_thickness: Thickness of the joint links (in pixels).
:param keypoint_colors: Optional list of tuples representing the colors for each keypoint.
If None, default colors are used.
If not None the length must be equal to the number of joints in the skeleton.
:param keypoint_radius: Radius of the keypoints (in pixels).
:param show_confidence: Whether to show confidence scores on the image.
:param box_thickness: Thickness of bounding boxes.
image = self.draw(box_thickness=box_thickness, show_confidence=show_confidence)
save_image(image=image, path=output_path)
Source code in V3_2/src/super_gradients/training/utils/predict/prediction_pose_estimation_results.py
@dataclass
class ImagesPoseEstimationPrediction(ImagesPredictions):
"""Object wrapping the list of image detection predictions.
:attr _images_prediction_lst: List of the predictions results
_images_prediction_lst: List[ImagePoseEstimationPrediction]
def show(
self,
edge_colors=None,
joint_thickness: int = 2,
keypoint_colors=None,
keypoint_radius: int = 5,
box_thickness: int = 2,
show_confidence: bool = False,
) -> None:
"""Display the predicted bboxes on the images.
:param edge_colors: Optional list of tuples representing the colors for each joint.
If None, default colors are used.
If not None the length must be equal to the number of joint links in the skeleton.
:param joint_thickness: Thickness of the joint links (in pixels).
:param keypoint_colors: Optional list of tuples representing the colors for each keypoint.
If None, default colors are used.
If not None the length must be equal to the number of joints in the skeleton.
:param keypoint_radius: Radius of the keypoints (in pixels).
:param show_confidence: Whether to show confidence scores on the image.
:param box_thickness: Thickness of bounding boxes.
for prediction in self._images_prediction_lst:
prediction.show(
edge_colors=edge_colors,
joint_thickness=joint_thickness,
keypoint_colors=keypoint_colors,
keypoint_radius=keypoint_radius,
box_thickness=box_thickness,
show_confidence=show_confidence,
def save(
self,
output_folder: str,
edge_colors=None,
joint_thickness: int = 2,
keypoint_colors=None,
keypoint_radius: int = 5,
box_thickness: int = 2,
show_confidence: bool = False,
) -> None:
"""Save the predicted bboxes on the images.
:param output_folder: Folder path, where the images will be saved.
:param edge_colors: Optional list of tuples representing the colors for each joint.
If None, default colors are used.
If not None the length must be equal to the number of joint links in the skeleton.
:param joint_thickness: Thickness of the joint links (in pixels).
:param keypoint_colors: Optional list of tuples representing the colors for each keypoint.
If None, default colors are used.
If not None the length must be equal to the number of joints in the skeleton.
:param keypoint_radius: Radius of the keypoints (in pixels).
:param show_confidence: Whether to show confidence scores on the image.
:param box_thickness: Thickness of bounding boxes.
if output_folder:
os.makedirs(output_folder, exist_ok=True)
for i, prediction in enumerate(self._images_prediction_lst):
image_output_path = os.path.join(output_folder, f"pred_{i}.jpg")
prediction.save(
output_path=image_output_path,
edge_colors=edge_colors,
joint_thickness=joint_thickness,
keypoint_colors=keypoint_colors,
keypoint_radius=keypoint_radius,
box_thickness=box_thickness,
show_confidence=show_confidence,
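In practice these wrappers are returned by a pose model's predict() call rather than constructed by hand. A hedged end-to-end sketch; the model name, checkpoint name and predict() behaviour are assumptions, not guaranteed by this page.

from super_gradients.training import models
from super_gradients.common.object_names import Models

model = models.get(Models.DEKR_W32_NO_DC, pretrained_weights="coco_pose")  # assumed model / weights identifiers
predictions = model.predict(["img_0.jpg", "img_1.jpg"])  # assumed to return the prediction wrappers documented here
predictions.show(show_confidence=True)                   # draws skeletons with the default edge / keypoint colors
predictions.save(output_folder="pose_predictions/")      # writes pred_0.jpg, pred_1.jpg into the folder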
Object wrapping the list of image pose estimation predictions as a Video.
:attr _images_prediction_lst: List of the predictions results
:attr fps: Frames per second of the video
Source code in V3_2/src/super_gradients/training/utils/predict/prediction_pose_estimation_results.py
@dataclass
class VideoPoseEstimationPrediction(VideoPredictions):
"""Object wrapping the list of image detection predictions as a Video.
:attr _images_prediction_lst: List of the predictions results
:att fps: Frames per second of the video
_images_prediction_lst: List[ImagePoseEstimationPrediction]
fps: int
def draw(
self,
edge_colors=None,
joint_thickness: int = 2,
keypoint_colors=None,
keypoint_radius: int = 5,
box_thickness: int = 2,
show_confidence: bool = False,
) -> List[np.ndarray]:
"""Draw the predicted bboxes on the images.
:param edge_colors: Optional list of tuples representing the colors for each joint.
If None, default colors are used.
If not None the length must be equal to the number of joint links in the skeleton.
:param joint_thickness: Thickness of the joint links (in pixels).
:param keypoint_colors: Optional list of tuples representing the colors for each keypoint.
If None, default colors are used.
If not None the length must be equal to the number of joints in the skeleton.
:param keypoint_radius: Radius of the keypoints (in pixels).
:param show_confidence: Whether to show confidence scores on the image.
:param box_thickness: Thickness of bounding boxes.
:return: List of images with predicted bboxes. Note that this does not modify the original image.
frames_with_bbox = [
result.draw(
edge_colors=edge_colors,
joint_thickness=joint_thickness,
keypoint_colors=keypoint_colors,
keypoint_radius=keypoint_radius,
box_thickness=box_thickness,
show_confidence=show_confidence,
for result in self._images_prediction_lst
return frames_with_bbox
def show(
self,
edge_colors=None,
joint_thickness: int = 2,
keypoint_colors=None,
keypoint_radius: int = 5,
box_thickness: int = 2,
show_confidence: bool = False,
) -> None:
"""Display the predicted bboxes on the images.
:param edge_colors: Optional list of tuples representing the colors for each joint.
If None, default colors are used.
If not None the length must be equal to the number of joint links in the skeleton.
:param joint_thickness: Thickness of the joint links (in pixels).
:param keypoint_colors: Optional list of tuples representing the colors for each keypoint.
If None, default colors are used.
If not None the length must be equal to the number of joints in the skeleton.
:param keypoint_radius: Radius of the keypoints (in pixels).
:param show_confidence: Whether to show confidence scores on the image.
:param box_thickness: Thickness of bounding boxes.
frames = self.draw(
edge_colors=edge_colors,
joint_thickness=joint_thickness,
keypoint_colors=keypoint_colors,
keypoint_radius=keypoint_radius,
box_thickness=box_thickness,
show_confidence=show_confidence,
show_video_from_frames(window_name="Pose Estimation", frames=frames, fps=self.fps)
def save(
self,
output_path: str,
edge_colors=None,
joint_thickness: int = 2,
keypoint_colors=None,
keypoint_radius: int = 5,
box_thickness: int = 2,
show_confidence: bool = False,
) -> None:
"""Save the predicted bboxes on the images.
:param output_path: Path to the output video file.
:param edge_colors: Optional list of tuples representing the colors for each joint.
If None, default colors are used.
If not None the length must be equal to the number of joint links in the skeleton.
:param joint_thickness: Thickness of the joint links (in pixels).
:param keypoint_colors: Optional list of tuples representing the colors for each keypoint.
If None, default colors are used.
If not None the length must be equal to the number of joints in the skeleton.
:param keypoint_radius: Radius of the keypoints (in pixels).
:param show_confidence: Whether to show confidence scores on the image.
:param box_thickness: Thickness of bounding boxes.
frames = self.draw(
edge_colors=edge_colors,
joint_thickness=joint_thickness,
keypoint_colors=keypoint_colors,
keypoint_radius=keypoint_radius,
box_thickness=box_thickness,
show_confidence=show_confidence,
save_video(output_path=output_path, frames=frames, fps=self.fps)
Source code in V3_2/src/super_gradients/training/utils/predict/prediction_pose_estimation_results.py
def draw(
self,
edge_colors=None,
joint_thickness: int = 2,
keypoint_colors=None,
keypoint_radius: int = 5,
box_thickness: int = 2,
show_confidence: bool = False,
) -> List[np.ndarray]:
"""Draw the predicted bboxes on the images.
:param output_folder: Folder path, where the images will be saved.
:param edge_colors: Optional list of tuples representing the colors for each joint.
If None, default colors are used.
If not None the length must be equal to the number of joint links in the skeleton.
:param joint_thickness: Thickness of the joint links (in pixels).
:param keypoint_colors: Optional list of tuples representing the colors for each keypoint.
If None, default colors are used.
If not None the length must be equal to the number of joints in the skeleton.
:param keypoint_radius: Radius of the keypoints (in pixels).
:param show_confidence: Whether to show confidence scores on the image.
:param box_thickness: Thickness of bounding boxes.
:return: List of images with predicted bboxes. Note that this does not modify the original image.
frames_with_bbox = [
result.draw(
edge_colors=edge_colors,
joint_thickness=joint_thickness,
keypoint_colors=keypoint_colors,
keypoint_radius=keypoint_radius,
box_thickness=box_thickness,
show_confidence=show_confidence,
for result in self._images_prediction_lst
return frames_with_bbox
Source code in V3_2/src/super_gradients/training/utils/predict/prediction_pose_estimation_results.py
def save(
self,
output_path: str,
edge_colors=None,
joint_thickness: int = 2,
keypoint_colors=None,
keypoint_radius: int = 5,
box_thickness: int = 2,
show_confidence: bool = False,
) -> None:
"""Save the predicted bboxes on the images.
:param output_path: Path to the output video file.
:param edge_colors: Optional list of tuples representing the colors for each joint.
If None, default colors are used.
If not None the length must be equal to the number of joint links in the skeleton.
:param joint_thickness: Thickness of the joint links (in pixels).
:param keypoint_colors: Optional list of tuples representing the colors for each keypoint.
If None, default colors are used.
If not None the length must be equal to the number of joints in the skeleton.
:param keypoint_radius: Radius of the keypoints (in pixels).
:param show_confidence: Whether to show confidence scores on the image.
:param box_thickness: Thickness of bounding boxes.
frames = self.draw(
edge_colors=edge_colors,
joint_thickness=joint_thickness,
keypoint_colors=keypoint_colors,
keypoint_radius=keypoint_radius,
box_thickness=box_thickness,
show_confidence=show_confidence,
save_video(output_path=output_path, frames=frames, fps=self.fps)
Source code in V3_2/src/super_gradients/training/utils/predict/prediction_pose_estimation_results.py
def show(
self,
edge_colors=None,
joint_thickness: int = 2,
keypoint_colors=None,
keypoint_radius: int = 5,
box_thickness: int = 2,
show_confidence: bool = False,
) -> None:
"""Display the predicted bboxes on the images.
:param edge_colors: Optional list of tuples representing the colors for each joint.
If None, default colors are used.
If not None the length must be equal to the number of joint links in the skeleton.
:param joint_thickness: Thickness of the joint links (in pixels).
:param keypoint_colors: Optional list of tuples representing the colors for each keypoint.
If None, default colors are used.
If not None the length must be equal to the number of joints in the skeleton.
:param keypoint_radius: Radius of the keypoints (in pixels).
:param show_confidence: Whether to show confidence scores on the image.
:param box_thickness: Thickness of bounding boxes.
frames = self.draw(
edge_colors=edge_colors,
joint_thickness=joint_thickness,
keypoint_colors=keypoint_colors,
keypoint_radius=keypoint_radius,
box_thickness=box_thickness,
show_confidence=show_confidence,
show_video_from_frames(window_name="Pose Estimation", frames=frames, fps=self.fps)
Object wrapping an image and a classification model's prediction.
:attr image: Input image
:attr predictions: Predictions of the model
:attr class_names: List of the class names to predict
Source code in V3_2/src/super_gradients/training/utils/predict/prediction_results.py
@dataclass
class ImageClassificationPrediction(ImagePrediction):
"""Object wrapping an image and a classification model's prediction.
:attr image: Input image
:attr predictions: Predictions of the model
:attr class_names: List of the class names to predict
image: np.ndarray
prediction: ClassificationPrediction
class_names: List[str]
def draw(self, show_confidence: bool = True) -> np.ndarray:
"""Draw the predicted label on the image.
:param show_confidence: Whether to show confidence scores on the image.
:return: Image with predicted label.
image = self.image.copy()
return draw_label(
image=image, label=self.class_names[self.prediction.labels], confidence=str(self.prediction.confidence), image_shape=self.prediction.image_shape[1:]
def show(self, show_confidence: bool = True) -> None:
"""Display the image with predicted label.
:param show_confidence: Whether to show confidence scores on the image.
# draw the prediction on the image
image = self.draw(show_confidence=show_confidence)
show_image(image)
def save(
self,
output_path: str,
show_confidence: bool = True,
) -> None:
"""Save the predicted label on the images.
:param output_path: Path to the output video file.
:param show_confidence: Whether to show confidence scores on the image.
image = self.draw(show_confidence=show_confidence)
save_image(image=image, path=output_path)
def draw(self, show_confidence: bool = True) -> np.ndarray:
"""Draw the predicted label on the image.
:param show_confidence: Whether to show confidence scores on the image.
:return: Image with predicted label.
image = self.image.copy()
return draw_label(
image=image, label=self.class_names[self.prediction.labels], confidence=str(self.prediction.confidence), image_shape=self.prediction.image_shape[1:]
"""Save the predicted label on the images.
:param output_path: Path to the output video file.
:param show_confidence: Whether to show confidence scores on the image.
image = self.draw(show_confidence=show_confidence)
save_image(image=image, path=output_path)
def show(self, show_confidence: bool = True) -> None:
"""Display the image with predicted label.
:param show_confidence: Whether to show confidence scores on the image.
# draw the prediction on the image
image = self.draw(show_confidence=show_confidence)
show_image(image)
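A small construction sketch, assuming both classes are exported from super_gradients.training.utils.predict (the import path is an assumption; only the fields documented above are used):

import numpy as np
from super_gradients.training.utils.predict import ClassificationPrediction, ImageClassificationPrediction  # assumed exports

image = np.zeros((224, 224, 3), dtype=np.uint8)   # dummy RGB image
prediction = ClassificationPrediction(confidence=0.91, labels=1, image_shape=image.shape[:2])
result = ImageClassificationPrediction(image=image, prediction=prediction, class_names=["cat", "dog"])
annotated = result.draw(show_confidence=True)     # np.ndarray with the label drawn on top
result.save(output_path="prediction.jpg")         # write the annotated image to disk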
Object wrapping an image and a detection model's prediction.
:attr image: Input image
:attr predictions: Predictions of the model
:attr class_names: List of the class names to predict
Source code in V3_2/src/super_gradients/training/utils/predict/prediction_results.py
@dataclass
class ImageDetectionPrediction(ImagePrediction):
"""Object wrapping an image and a detection model's prediction.
:attr image: Input image
:attr predictions: Predictions of the model
:attr class_names: List of the class names to predict
image: np.ndarray
prediction: DetectionPrediction
class_names: List[str]
def draw(self, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int, int, int]]] = None) -> np.ndarray:
"""Draw the predicted bboxes on the image.
:param box_thickness: Thickness of bounding boxes.
:param show_confidence: Whether to show confidence scores on the image.
:param color_mapping: List of tuples representing the colors for each class.
Default is None, which generates a default color mapping based on the number of class names.
:return: Image with predicted bboxes. Note that this does not modify the original image.
image = self.image.copy()
color_mapping = color_mapping or generate_color_mapping(len(self.class_names))
for pred_i in np.argsort(self.prediction.confidence):
class_id = int(self.prediction.labels[pred_i])
score = "" if not show_confidence else str(round(self.prediction.confidence[pred_i], 2))
image = draw_bbox(
image=image,
title=f"{self.class_names[class_id]} {score}",
color=color_mapping[class_id],
box_thickness=box_thickness,
x1=int(self.prediction.bboxes_xyxy[pred_i, 0]),
y1=int(self.prediction.bboxes_xyxy[pred_i, 1]),
x2=int(self.prediction.bboxes_xyxy[pred_i, 2]),
y2=int(self.prediction.bboxes_xyxy[pred_i, 3]),
return image
def show(self, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int, int, int]]] = None) -> None:
"""Display the image with predicted bboxes.
:param box_thickness: Thickness of bounding boxes.
:param show_confidence: Whether to show confidence scores on the image.
:param color_mapping: List of tuples representing the colors for each class.
Default is None, which generates a default color mapping based on the number of class names.
image = self.draw(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping)
show_image(image)
def save(self, output_path: str, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int, int, int]]] = None) -> None:
"""Save the predicted bboxes on the images.
:param output_path: Path to the output video file.
:param box_thickness: Thickness of bounding boxes.
:param show_confidence: Whether to show confidence scores on the image.
:param color_mapping: List of tuples representing the colors for each class.
Default is None, which generates a default color mapping based on the number of class names.
image = self.draw(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping)
save_image(image=image, path=output_path)
Source code in V3_2/src/super_gradients/training/utils/predict/prediction_results.py
def draw(self, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int, int, int]]] = None) -> np.ndarray:
"""Draw the predicted bboxes on the image.
:param box_thickness: Thickness of bounding boxes.
:param show_confidence: Whether to show confidence scores on the image.
:param color_mapping: List of tuples representing the colors for each class.
Default is None, which generates a default color mapping based on the number of class names.
:return: Image with predicted bboxes. Note that this does not modify the original image.
image = self.image.copy()
color_mapping = color_mapping or generate_color_mapping(len(self.class_names))
for pred_i in np.argsort(self.prediction.confidence):
class_id = int(self.prediction.labels[pred_i])
score = "" if not show_confidence else str(round(self.prediction.confidence[pred_i], 2))
image = draw_bbox(
image=image,
title=f"{self.class_names[class_id]} {score}",
color=color_mapping[class_id],
box_thickness=box_thickness,
x1=int(self.prediction.bboxes_xyxy[pred_i, 0]),
y1=int(self.prediction.bboxes_xyxy[pred_i, 1]),
x2=int(self.prediction.bboxes_xyxy[pred_i, 2]),
y2=int(self.prediction.bboxes_xyxy[pred_i, 3]),
return image
Source code in V3_2/src/super_gradients/training/utils/predict/prediction_results.py
def save(self, output_path: str, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int, int, int]]] = None) -> None:
"""Save the predicted bboxes on the images.
:param output_path: Path to the output video file.
:param box_thickness: Thickness of bounding boxes.
:param show_confidence: Whether to show confidence scores on the image.
:param color_mapping: List of tuples representing the colors for each class.
Default is None, which generates a default color mapping based on the number of class names.
image = self.draw(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping)
save_image(image=image, path=output_path)
Source code in V3_2/src/super_gradients/training/utils/predict/prediction_results.py
def show(self, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int, int, int]]] = None) -> None:
"""Display the image with predicted bboxes.
:param box_thickness: Thickness of bounding boxes.
:param show_confidence: Whether to show confidence scores on the image.
:param color_mapping: List of tuples representing the colors for each class.
Default is None, which generates a default color mapping based on the number of class names.
image = self.draw(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping)
show_image(image)
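The same pattern works for detections; a sketch is shown below (import path assumed, and DetectionPrediction itself is documented further down):

import numpy as np
from super_gradients.training.utils.predict import DetectionPrediction, ImageDetectionPrediction  # assumed exports

image = np.zeros((480, 640, 3), dtype=np.uint8)
prediction = DetectionPrediction(
    bboxes=np.array([[10.0, 20.0, 100.0, 200.0]]),  # one box
    bbox_format="xyxy",                              # already x1y1x2y2, so conversion is a no-op
    confidence=np.array([0.87]),
    labels=np.array([0]),
    image_shape=image.shape[:2],
)
result = ImageDetectionPrediction(image=image, prediction=prediction, class_names=["person"])
annotated = result.draw(box_thickness=2, show_confidence=True)  # the original image is left untouched
result.save(output_path="detections.jpg")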
Object wrapping an image and a model's prediction.
:attr image: Input image
:attr predictions: Predictions of the model
:attr class_names: List of the class names to predict
Source code in V3_2/src/super_gradients/training/utils/predict/prediction_results.py
@dataclass
class ImagePrediction(ABC):
"""Object wrapping an image and a model's prediction.
:attr image: Input image
:attr predictions: Predictions of the model
:attr class_names: List of the class names to predict
image: np.ndarray
prediction: Prediction
class_names: List[str]
@abstractmethod
def draw(self, *args, **kwargs) -> np.ndarray:
"""Draw the predictions on the image."""
@abstractmethod
def show(self, *args, **kwargs) -> None:
"""Display the predictions on the image."""
@abstractmethod
def save(self, *args, **kwargs) -> None:
"""Save the predictions on the image."""
@abstractmethod
def draw(self, *args, **kwargs) -> np.ndarray:
"""Draw the predictions on the image."""
@abstractmethod
def save(self, *args, **kwargs) -> None:
"""Save the predictions on the image."""
@abstractmethod
def show(self, *args, **kwargs) -> None:
"""Display the predictions on the image."""
Source code in V3_2/src/super_gradients/training/utils/predict/prediction_results.py
@dataclass
class ImagesClassificationPrediction(ImagesPredictions):
"""Object wrapping the list of image classification predictions.
:attr _images_prediction_lst: List of the predictions results
_images_prediction_lst: List[ImageClassificationPrediction]
def show(self, show_confidence: bool = True) -> None:
"""Display the predicted labels on the images.
:param show_confidence: Whether to show confidence scores on the image.
for prediction in self._images_prediction_lst:
prediction.show(show_confidence=show_confidence)
def save(self, output_folder: str, show_confidence: bool = True) -> None:
"""Save the predicted label on the images.
:param output_folder: Folder path, where the images will be saved.
:param show_confidence: Whether to show confidence scores on the image.
if output_folder:
os.makedirs(output_folder, exist_ok=True)
for i, prediction in enumerate(self._images_prediction_lst):
image_output_path = os.path.join(output_folder, f"pred_{i}.jpg")
prediction.save(output_path=image_output_path, show_confidence=show_confidence)
Source code in V3_2/src/super_gradients/training/utils/predict/prediction_results.py
def save(self, output_folder: str, show_confidence: bool = True) -> None:
"""Save the predicted label on the images.
:param output_folder: Folder path, where the images will be saved.
:param show_confidence: Whether to show confidence scores on the image.
if output_folder:
os.makedirs(output_folder, exist_ok=True)
for i, prediction in enumerate(self._images_prediction_lst):
image_output_path = os.path.join(output_folder, f"pred_{i}.jpg")
prediction.save(output_path=image_output_path, show_confidence=show_confidence)
Source code in V3_2/src/super_gradients/training/utils/predict/prediction_results.py
def show(self, show_confidence: bool = True) -> None:
"""Display the predicted labels on the images.
:param show_confidence: Whether to show confidence scores on the image.
for prediction in self._images_prediction_lst:
prediction.show(show_confidence=show_confidence)
Source code in V3_2/src/super_gradients/training/utils/predict/prediction_results.py
@dataclass
class ImagesDetectionPrediction(ImagesPredictions):
"""Object wrapping the list of image detection predictions.
:attr _images_prediction_lst: List of the predictions results
_images_prediction_lst: List[ImageDetectionPrediction]
def show(self, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int, int, int]]] = None) -> None:
"""Display the predicted bboxes on the images.
:param box_thickness: Thickness of bounding boxes.
:param show_confidence: Whether to show confidence scores on the image.
:param color_mapping: List of tuples representing the colors for each class.
Default is None, which generates a default color mapping based on the number of class names.
for prediction in self._images_prediction_lst:
prediction.show(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping)
def save(
self, output_folder: str, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int, int, int]]] = None
) -> None:
"""Save the predicted bboxes on the images.
:param output_folder: Folder path, where the images will be saved.
:param box_thickness: Thickness of bounding boxes.
:param show_confidence: Whether to show confidence scores on the image.
:param color_mapping: List of tuples representing the colors for each class.
Default is None, which generates a default color mapping based on the number of class names.
if output_folder:
os.makedirs(output_folder, exist_ok=True)
for i, prediction in enumerate(self._images_prediction_lst):
image_output_path = os.path.join(output_folder, f"pred_{i}.jpg")
prediction.save(output_path=image_output_path, box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping)
Source code in V3_2/src/super_gradients/training/utils/predict/prediction_results.py
def save(
self, output_folder: str, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int, int, int]]] = None
) -> None:
"""Save the predicted bboxes on the images.
:param output_folder: Folder path, where the images will be saved.
:param box_thickness: Thickness of bounding boxes.
:param show_confidence: Whether to show confidence scores on the image.
:param color_mapping: List of tuples representing the colors for each class.
Default is None, which generates a default color mapping based on the number of class names.
if output_folder:
os.makedirs(output_folder, exist_ok=True)
for i, prediction in enumerate(self._images_prediction_lst):
image_output_path = os.path.join(output_folder, f"pred_{i}.jpg")
prediction.save(output_path=image_output_path, box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping)
Source code in V3_2/src/super_gradients/training/utils/predict/prediction_results.py
def show(self, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int, int, int]]] = None) -> None:
"""Display the predicted bboxes on the images.
:param box_thickness: Thickness of bounding boxes.
:param show_confidence: Whether to show confidence scores on the image.
:param color_mapping: List of tuples representing the colors for each class.
Default is None, which generates a default color mapping based on the number of class names.
for prediction in self._images_prediction_lst:
prediction.show(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping)
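Batch saving in one call, sketched below; images_detections is assumed to be an ImagesDetectionPrediction instance (e.g. the result of running a detector over several images):

images_detections.save(
    output_folder="out/",   # created on demand via os.makedirs(..., exist_ok=True)
    box_thickness=2,
    show_confidence=True,
    color_mapping=None,     # None falls back to a generated per-class color mapping
)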
Source code in V3_2/src/super_gradients/training/utils/predict/prediction_results.py
@dataclass
class ImagesPredictions(ABC):
"""Object wrapping the list of image predictions.
:attr _images_prediction_lst: List of results of the run
_images_prediction_lst: List[ImagePrediction]
def __len__(self) -> int:
return len(self._images_prediction_lst)
def __getitem__(self, index: int) -> ImagePrediction:
return self._images_prediction_lst[index]
def __iter__(self) -> Iterator[ImagePrediction]:
return iter(self._images_prediction_lst)
@abstractmethod
def show(self, *args, **kwargs) -> None:
"""Display the predictions on the images."""
@abstractmethod
def save(self, *args, **kwargs) -> None:
"""Save the predictions on the images."""
Source code in V3_2/src/super_gradients/training/utils/predict/prediction_results.py
@abstractmethod
def save(self, *args, **kwargs) -> None:
"""Save the predictions on the images."""
Source code in V3_2/src/super_gradients/training/utils/predict/prediction_results.py
@abstractmethod
def show(self, *args, **kwargs) -> None:
"""Display the predictions on the images."""
Object wrapping the list of image detection predictions as a Video.
:attr _images_prediction_lst: List of the predictions results
:attr fps: Frames per second of the video
Source code in V3_2/src/super_gradients/training/utils/predict/prediction_results.py
@dataclass
class VideoDetectionPrediction(VideoPredictions):
"""Object wrapping the list of image detection predictions as a Video.
:attr _images_prediction_lst: List of the predictions results
:attr fps: Frames per second of the video
_images_prediction_lst: List[ImageDetectionPrediction]
fps: int
def draw(self, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int, int, int]]] = None) -> List[np.ndarray]:
"""Draw the predicted bboxes on the images.
:param box_thickness: Thickness of bounding boxes.
:param show_confidence: Whether to show confidence scores on the image.
:param color_mapping: List of tuples representing the colors for each class.
Default is None, which generates a default color mapping based on the number of class names.
:return: List of images with predicted bboxes. Note that this does not modify the original image.
frames_with_bbox = [
result.draw(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping) for result in self._images_prediction_lst
return frames_with_bbox
def show(self, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int, int, int]]] = None) -> None:
"""Display the predicted bboxes on the images.
:param box_thickness: Thickness of bounding boxes.
:param show_confidence: Whether to show confidence scores on the image.
:param color_mapping: List of tuples representing the colors for each class.
Default is None, which generates a default color mapping based on the number of class names.
frames = self.draw(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping)
show_video_from_frames(window_name="Detection", frames=frames, fps=self.fps)
def save(self, output_path: str, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int, int, int]]] = None) -> None:
"""Save the predicted bboxes on the images.
:param output_path: Path to the output video file.
:param box_thickness: Thickness of bounding boxes.
:param show_confidence: Whether to show confidence scores on the image.
:param color_mapping: List of tuples representing the colors for each class.
Default is None, which generates a default color mapping based on the number of class names.
frames = self.draw(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping)
save_video(output_path=output_path, frames=frames, fps=self.fps)
Source code in V3_2/src/super_gradients/training/utils/predict/prediction_results.py
def draw(self, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int, int, int]]] = None) -> List[np.ndarray]:
"""Draw the predicted bboxes on the images.
:param box_thickness: Thickness of bounding boxes.
:param show_confidence: Whether to show confidence scores on the image.
:param color_mapping: List of tuples representing the colors for each class.
Default is None, which generates a default color mapping based on the number of class names.
:return: List of images with predicted bboxes. Note that this does not modify the original image.
frames_with_bbox = [
result.draw(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping) for result in self._images_prediction_lst
return frames_with_bbox
Source code in V3_2/src/super_gradients/training/utils/predict/prediction_results.py
def save(self, output_path: str, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int, int, int]]] = None) -> None:
"""Save the predicted bboxes on the images.
:param output_path: Path to the output video file.
:param box_thickness: Thickness of bounding boxes.
:param show_confidence: Whether to show confidence scores on the image.
:param color_mapping: List of tuples representing the colors for each class.
Default is None, which generates a default color mapping based on the number of class names.
frames = self.draw(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping)
save_video(output_path=output_path, frames=frames, fps=self.fps)
Source code in V3_2/src/super_gradients/training/utils/predict/prediction_results.py
def show(self, box_thickness: int = 2, show_confidence: bool = True, color_mapping: Optional[List[Tuple[int, int, int]]] = None) -> None:
"""Display the predicted bboxes on the images.
:param box_thickness: Thickness of bounding boxes.
:param show_confidence: Whether to show confidence scores on the image.
:param color_mapping: List of tuples representing the colors for each class.
Default is None, which generates a default color mapping based on the number of class names.
frames = self.draw(box_thickness=box_thickness, show_confidence=show_confidence, color_mapping=color_mapping)
show_video_from_frames(window_name="Detection", frames=frames, fps=self.fps)
Object wrapping the list of image predictions as a Video.
:attr _images_prediction_lst: List of results of the run
:attr fps: Frames per second of the video
Source code in V3_2/src/super_gradients/training/utils/predict/prediction_results.py
@dataclass
class VideoPredictions(ImagesPredictions, ABC):
"""Object wrapping the list of image predictions as a Video.
:attr _images_prediction_lst: List of results of the run
:attr fps: Frames per second of the video
_images_prediction_lst: List[ImagePrediction]
fps: float
@abstractmethod
def show(self, *args, **kwargs) -> None:
"""Display the predictions on the video."""
@abstractmethod
def save(self, *args, **kwargs) -> None:
"""Save the predictions on the video."""
Source code in V3_2/src/super_gradients/training/utils/predict/prediction_results.py
@abstractmethod
def save(self, *args, **kwargs) -> None:
"""Save the predictions on the video."""
Source code in V3_2/src/super_gradients/training/utils/predict/prediction_results.py
@abstractmethod
def show(self, *args, **kwargs) -> None:
"""Display the predictions on the video."""
@dataclass
class ClassificationPrediction(Prediction):
"""Represents a Classification prediction"""
confidence: float
labels: int
image_shape: Tuple[int, int]
def __init__(self, confidence: float, labels: int, image_shape: Optional[Tuple[int, int]]):
:param confidence: Confidence scores for each bounding box
:param labels: Labels for each bounding box.
:param image_shape: Shape of the image the prediction is made on, (H, W).
self._validate_input(confidence, labels)
self.confidence = confidence
self.labels = labels
self.image_shape = image_shape
def _validate_input(self, confidence: np.ndarray, labels: np.ndarray) -> None:
if not isinstance(confidence, float):
raise ValueError(f"Argument confidence must be a numpy array, not {type(confidence)}")
if not isinstance(labels, int):
raise ValueError(f"Argument labels must be a numpy array, not {type(labels)}")
def __len__(self):
return len(self.labels)
def __init__(self, confidence: float, labels: int, image_shape: Optional[Tuple[int, int]]):
:param confidence: Confidence scores for each bounding box
:param labels: Labels for each bounding box.
:param image_shape: Shape of the image the prediction is made on, (H, W).
self._validate_input(confidence, labels)
self.confidence = confidence
self.labels = labels
self.image_shape = image_shape
@dataclass
class DetectionPrediction(Prediction):
"""Represents a detection prediction, with bboxes represented in xyxy format."""
bboxes_xyxy: np.ndarray
confidence: np.ndarray
labels: np.ndarray
def __init__(self, bboxes: np.ndarray, bbox_format: str, confidence: np.ndarray, labels: np.ndarray, image_shape: Tuple[int, int]):
:param bboxes: BBoxes in the format specified by bbox_format
:param bbox_format: BBoxes format that can be a string ("xyxy", "cxywh", ...)
:param confidence: Confidence scores for each bounding box
:param labels: Labels for each bounding box.
:param image_shape: Shape of the image the prediction is made on, (H, W). This is used to convert bboxes to xyxy format
self._validate_input(bboxes, confidence, labels)
factory = BBoxFormatFactory()
bboxes_xyxy = convert_bboxes(
bboxes=bboxes,
image_shape=image_shape,
source_format=factory.get(bbox_format),
target_format=factory.get("xyxy"),
inplace=False,
self.bboxes_xyxy = bboxes_xyxy
self.confidence = confidence
self.labels = labels
def _validate_input(self, bboxes: np.ndarray, confidence: np.ndarray, labels: np.ndarray) -> None:
n_bboxes, n_confidences, n_labels = bboxes.shape[0], confidence.shape[0], labels.shape[0]
if n_bboxes != n_confidences != n_labels:
raise ValueError(
f"The number of bounding boxes ({n_bboxes}) does not match the number of confidence scores ({n_confidences}) and labels ({n_labels})."
def __len__(self):
return len(self.bboxes_xyxy)
def __init__(self, bboxes: np.ndarray, bbox_format: str, confidence: np.ndarray, labels: np.ndarray, image_shape: Tuple[int, int]):
:param bboxes: BBoxes in the format specified by bbox_format
:param bbox_format: BBoxes format that can be a string ("xyxy", "cxywh", ...)
:param confidence: Confidence scores for each bounding box
:param labels: Labels for each bounding box.
:param image_shape: Shape of the image the prediction is made on, (H, W). This is used to convert bboxes to xyxy format
self._validate_input(bboxes, confidence, labels)
factory = BBoxFormatFactory()
bboxes_xyxy = convert_bboxes(
bboxes=bboxes,
image_shape=image_shape,
source_format=factory.get(bbox_format),
target_format=factory.get("xyxy"),
inplace=False,
self.bboxes_xyxy = bboxes_xyxy
self.confidence = confidence
self.labels = labels
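The constructor's main job is converting the incoming boxes to xyxy. A sketch with centre-based boxes is shown below; the "cxcywh" format key and the import path are assumptions about what BBoxFormatFactory registers and where the class is exported:

import numpy as np
from super_gradients.training.utils.predict import DetectionPrediction  # assumed export

pred = DetectionPrediction(
    bboxes=np.array([[320.0, 240.0, 100.0, 80.0]]),  # cx, cy, w, h
    bbox_format="cxcywh",                            # format key assumed; see BBoxFormatFactory for registered keys
    confidence=np.array([0.5]),
    labels=np.array([3]),
    image_shape=(480, 640),
)
print(pred.bboxes_xyxy)  # stored internally as x1, y1, x2, y2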
Represents a pose estimation prediction.
:attr poses: Numpy array of [Num Poses, Num Joints, 2] shape
:attr scores: Numpy array of [Num Poses] shape
Source code in V3_2/src/super_gradients/training/utils/predict/predictions.py
@dataclass
class PoseEstimationPrediction(Prediction):
"""Represents a pose estimation prediction.
:attr poses: Numpy array of [Num Poses, Num Joints, 2] shape
:attr scores: Numpy array of [Num Poses] shape
poses: np.ndarray
scores: np.ndarray
edge_links: np.ndarray
edge_colors: np.ndarray
keypoint_colors: np.ndarray
image_shape: Tuple[int, int]
def __init__(
self,
poses: np.ndarray,
scores: np.ndarray,
edge_links: np.ndarray,
edge_colors: np.ndarray,
keypoint_colors: np.ndarray,
image_shape: Tuple[int, int],
:param poses:
:param scores:
:param image_shape: Shape of the image the prediction is made on, (H, W). This is used to convert bboxes to xyxy format
self._validate_input(poses, scores, edge_links, edge_colors, keypoint_colors)
self.poses = poses
self.scores = scores
self.edge_links = edge_links
self.edge_colors = edge_colors
self.image_shape = image_shape
self.keypoint_colors = keypoint_colors
def _validate_input(self, poses: np.ndarray, scores: np.ndarray, edge_links, edge_colors, keypoint_colors) -> None:
if not isinstance(poses, np.ndarray):
raise ValueError(f"Argument poses must be a numpy array, not {type(poses)}")
if not isinstance(scores, np.ndarray):
raise ValueError(f"Argument scores must be a numpy array, not {type(scores)}")
if not isinstance(keypoint_colors, np.ndarray):
raise ValueError(f"Argument keypoint_colors must be a numpy array, not {type(keypoint_colors)}")
if len(poses) != len(scores) != len(keypoint_colors):
raise ValueError(f"The number of poses ({len(poses)}) does not match the number of scores ({len(scores)}).")
if len(edge_links) != len(edge_colors):
raise ValueError(f"The number of joint links ({len(edge_links)}) does not match the number of joint colors ({len(edge_colors)}).")
def __len__(self):
return len(self.poses)
edge_links: np.ndarray,
edge_colors: np.ndarray,
keypoint_colors: np.ndarray,
image_shape: Tuple[int, int],
:param poses:
:param scores:
:param image_shape: Shape of the image the prediction is made on, (H, W). This is used to convert bboxes to xyxy format
self._validate_input(poses, scores, edge_links, edge_colors, keypoint_colors)
self.poses = poses
self.scores = scores
self.edge_links = edge_links
self.edge_colors = edge_colors
self.image_shape = image_shape
self.keypoint_colors = keypoint_colors
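Construction sketch with dummy arrays (import path assumed); the shapes follow the attribute description above:

import numpy as np
from super_gradients.training.utils.predict import PoseEstimationPrediction  # assumed export

num_poses, num_joints = 2, 17
pose_pred = PoseEstimationPrediction(
    poses=np.zeros((num_poses, num_joints, 2)),           # (x, y) per joint
    scores=np.ones(num_poses),                            # one confidence score per pose
    edge_links=np.array([[0, 1], [1, 2]]),                # joint-index pairs forming the skeleton
    edge_colors=np.array([[255, 0, 0], [0, 255, 0]]),     # one RGB color per link
    keypoint_colors=np.zeros((num_joints, 3), dtype=int), # one RGB color per joint
    image_shape=(480, 640),
)
print(len(pose_pred))  # -> 2, the number of poses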
Quantization utilities
Methods are based on:
https://github.com/NVIDIA/TensorRT/blob/51a4297753d3e12d0eed864be52400f429a6a94c/tools/pytorch-quantization/examples/torchvision/classification_flow.py#L385
(Licensed under the Apache License, Version 2.0)
class QuantizationCalibrator:
def __init__(self, torch_hist: bool = True, verbose: bool = True) -> None:
if _imported_pytorch_quantization_failure is not None:
raise _imported_pytorch_quantization_failure
super().__init__()
self.verbose = verbose
self.torch_hist = torch_hist
def calibrate_model(
self,
model: torch.nn.Module,
calib_data_loader: torch.utils.data.DataLoader,
method: str = "percentile",
num_calib_batches: int = 2,
percentile: float = 99.99,
Calibrates torch model with quantized modules.
:param model: torch.nn.Module, model to perform the calibration on.
:param calib_data_loader: torch.utils.data.DataLoader, data loader of the calibration dataset.
Assumes that the first element of the tuple is the input image.
:param method: str, One of [percentile, mse, entropy, max].
Statistics method for amax computation of the quantized modules
(Default=percentile).
:param num_calib_batches: int, number of batches to collect the statistics from.
:param percentile: float, percentile value to use when quant_modules_calib_method='percentile'.
Discarded when other methods are used (Default=99.99).
logging_level = logging.getLogger("absl").getEffectiveLevel()
if not self.verbose: # suppress pytorch-quantization spam
logging.getLogger("absl").setLevel("ERROR")
acceptable_methods = ["percentile", "mse", "entropy", "max"]
if method in acceptable_methods:
with torch.no_grad():
device = next(model.parameters()).device
self._collect_stats(model, calib_data_loader, num_batches=num_calib_batches)
# FOR PERCENTILE WE MUST PASS PERCENTILE VALUE THROUGH KWARGS,
# SO IT WOULD BE PASSED TO module.load_calib_amax(**kwargs), AND IN OTHER METHODS WE MUST NOT PASS IT.
if method == "precentile":
self._compute_amax(model, method="percentile", percentile=percentile)
else:
self._compute_amax(model, method=method)
model.to(device)
else:
raise ValueError(f"Unsupported quantization calibration method, " f"expected one of: {'.'.join(acceptable_methods)}, however, received: {method}")
logging.getLogger("absl").setLevel(logging_level)
def _collect_stats(self, model, data_loader, num_batches):
"""Feed data to the network and collect statistics"""
local_rank = get_local_rank()
world_size = get_world_size()
device = next(model.parameters()).device
# Enable calibrators
self._enable_calibrators(model)
# Feed data to the network for collecting stats
for i, (image, *_) in tqdm(enumerate(data_loader), total=num_batches, disable=local_rank > 0):
if world_size > 1:
all_batches = [torch.zeros_like(image, device=device) for _ in range(world_size)]
all_gather(all_batches, image.to(device=device))
else:
all_batches = [image]
for local_image in all_batches:
model(local_image.to(device=device))
if i >= num_batches:
break
# Disable calibrators
self._disable_calibrators(model)
def _disable_calibrators(self, model):
for name, module in model.named_modules():
if isinstance(module, quant_nn.TensorQuantizer):
if module._calibrator is not None:
module.disable_calib()
module.enable_quant()
else:
module.enable()
def reset_calibrators(self, model):
for name, module in model.named_modules():
if isinstance(module, quant_nn.TensorQuantizer):
if module._calibrator is not None:
module._calibrator.reset() # release memory
def _enable_calibrators(self, model):
for name, module in model.named_modules():
if isinstance(module, quant_nn.TensorQuantizer):
if module._calibrator is not None:
if isinstance(module._calibrator, calib.HistogramCalibrator):
module._calibrator._torch_hist = self.torch_hist # TensorQuantizer does not expose it as API
module.disable_quant()
module.enable_calib()
else:
module.disable()
def _compute_amax(self, model, **kwargs):
for name, module in model.named_modules():
if isinstance(module, quant_nn.TensorQuantizer):
if module._calibrator is not None:
if isinstance(module._calibrator, calib.MaxCalibrator):
module.load_calib_amax()
else:
module.load_calib_amax(**kwargs)
if hasattr(module, "clip"):
module.init_learn_amax()
if self.verbose:
print(f"{name:40}: {module}")
self,
model: torch.nn.Module,
calib_data_loader: torch.utils.data.DataLoader,
method: str = "percentile",
num_calib_batches: int = 2,
percentile: float = 99.99,
Calibrates torch model with quantized modules.
:param model: torch.nn.Module, model to perform the calibration on.
:param calib_data_loader: torch.utils.data.DataLoader, data loader of the calibration dataset.
Assumes that the first element of the tuple is the input image.
:param method: str, One of [percentile, mse, entropy, max].
Statistics method for amax computation of the quantized modules
(Default=percentile).
:param num_calib_batches: int, number of batches to collect the statistics from.
:param percentile: float, percentile value to use when quant_modules_calib_method='percentile'.
Discarded when other methods are used (Default=99.99).
logging_level = logging.getLogger("absl").getEffectiveLevel()
if not self.verbose: # suppress pytorch-quantization spam
logging.getLogger("absl").setLevel("ERROR")
acceptable_methods = ["percentile", "mse", "entropy", "max"]
if method in acceptable_methods:
with torch.no_grad():
device = next(model.parameters()).device
self._collect_stats(model, calib_data_loader, num_batches=num_calib_batches)
# FOR PERCENTILE WE MUST PASS PERCENTILE VALUE THROUGH KWARGS,
# SO IT WOULD BE PASSED TO module.load_calib_amax(**kwargs), AND IN OTHER METHODS WE MUST NOT PASS IT.
if method == "precentile":
self._compute_amax(model, method="percentile", percentile=percentile)
else:
self._compute_amax(model, method=method)
model.to(device)
else:
raise ValueError(f"Unsupported quantization calibration method, " f"expected one of: {'.'.join(acceptable_methods)}, however, received: {method}")
logging.getLogger("absl").setLevel(logging_level)
This class wraps a float module instance, and defines a mapping from this instance to the corresponding quantized
class, with relevant quant descriptors.
Example:
self.my_block = QuantizedMapping(float_module=MyBlock(4, n_classes), quantized_target_class=MyQuantizedBlock)
Source code in V3_2/src/super_gradients/training/utils/quantization/core.py
class QuantizedMapping(nn.Module):
This class wraps a float module instance, and defines a mapping from this instance to the corresponding quantized
class, with relevant quant descriptors.
Example:
self.my_block = QuantizedMapping(float_module=MyBlock(4, n_classes), quantized_target_class=MyQuantizedBlock)
def __init__(
self,
float_module: nn.Module,
quantized_target_class: Union[Type[QuantMixin], Type[QuantInputMixin], Type[SGQuantMixin]],
action=QuantizedMetadata.ReplacementAction.REPLACE,
input_quant_descriptor: QuantDescriptor = None,
weights_quant_descriptor: QuantDescriptor = None,
) -> None:
super().__init__()
self.float_module = float_module
self.quantized_target_class = quantized_target_class
self.action = action
self.input_quant_descriptor = input_quant_descriptor
self.weights_quant_descriptor = weights_quant_descriptor
self.forward = float_module.forward
This dataclass is responsible for holding the information regarding float->quantized module relation.
It can be both layer-grained and module-grained, e.g., module.backbone.conv1 -> QuantConv2d, nn.Linear -> QuantLinear, etc.
class QuantizedMetadata:
This dataclass is responsible for holding the information regarding float->quantized module relation.
It can be both layer-grained and module-grained, e.g.,
`module.backbone.conv1 -> QuantConv2d`, `nn.Linear -> QuantLinear`, etc...
:param float_source: Name of a specific layer (e.g., `module.backbone.conv1`),
or a specific type (e.g., `Conv2d`) that will be later quantized
:param quantized_target_class: Quantized type that the source will be converted to
:param action: how to resolve the conversion, we either:
- SKIP: skip it,
- UNWRAP: unwrap the instance and work with the wrapped one
(i.e., we wrap with a mapper),
- REPLACE: replace source with an instance of the
quantized type
- REPLACE_AND_RECURE: replace source with an instance of the
quantized type, then try to recursively quantize the child modules of that type
- RECURE_AND_REPLACE: recursively quantize the child modules, then
replace source with an instance of the quantized type
:param input_quant_descriptor: Quantization descriptor for inputs (None will take the default one)
:param weights_quant_descriptor: Quantization descriptor for weights (None will take the default one)
class ReplacementAction(Enum):
REPLACE = "replace"
REPLACE_AND_RECURE = "replace_and_recure"
RECURE_AND_REPLACE = "recure_and_replace"
UNWRAP = "unwrap"
SKIP = "skip"
float_source: Union[str, Type]
quantized_target_class: Optional[Union[Type[QuantMixin], Type[QuantInputMixin], Type[SGQuantMixin]]]
action: ReplacementAction
input_quant_descriptor: QuantDescriptor = None # default is used if None
weights_quant_descriptor: QuantDescriptor = None # default is used if None
def __post_init__(self):
if self.action in (
QuantizedMetadata.ReplacementAction.REPLACE,
QuantizedMetadata.ReplacementAction.REPLACE_AND_RECURE,
QuantizedMetadata.ReplacementAction.RECURE_AND_REPLACE,
assert issubclass(self.quantized_target_class, (SGQuantMixin, QuantMixin, QuantInputMixin))
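A sketch of custom mapping instructions that could be passed to SelectiveQuantizer(custom_mappings=...); MyBlock and MyQuantizedBlock are hypothetical classes, and the import path is assumed from the source listing:

from super_gradients.training.utils.quantization.core import QuantizedMetadata  # import path assumed

custom_mappings = {
    MyBlock: QuantizedMetadata(                   # MyBlock: hypothetical float module type
        float_source=MyBlock,
        quantized_target_class=MyQuantizedBlock,  # hypothetical SGQuantMixin subclass
        action=QuantizedMetadata.ReplacementAction.REPLACE,
    ),
    "backbone.stem.conv": QuantizedMetadata(      # layer-grained entry: skip one layer by name
        float_source="backbone.stem.conv",
        quantized_target_class=None,
        action=QuantizedMetadata.ReplacementAction.SKIP,
    ),
}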
A base class for user custom Quantized classes.
Every Quantized class must inherit this mixin, which adds the from_float class-method.
NOTES:
* the Quantized class may also inherit from the native QuantMixin or QuantInputMixin
* quant descriptors (for inputs and weights) will be passed as kwargs. The module may ignore them if they are not necessary
* the default implementation of from_float inspects the __init__ args and searches for corresponding properties on the float instance that is passed as argument, e.g., for __init__(self, a) the mechanism will look for float_instance.a and pass that value to the __init__ method
Source code in V3_2/src/super_gradients/training/utils/quantization/core.py
class SGQuantMixin(nn.Module):
A base class for user custom Quantized classes.
Every Quantized class must inherit this mixin, which adds `from_float` class-method.
NOTES:
* the Quantized class may also inherit from the native `QuantMixin` or `QuantInputMixin`
* quant descriptors (for inputs and weights) will be passed as `kwargs`. The module may ignore them if they are
not necessary
* the default implementation of `from_float` is inspecting the __init__ args, and searching for corresponding
properties from the float instance that is passed as argument, e.g., for `__init__(self, a)`
the mechanism will look for `float_instance.a` and pass that value to the `__init__` method
@classmethod
def from_float(cls, float_instance, **kwargs):
required_init_params = list(inspect.signature(cls.__init__).parameters)[1:] # [0] is self
# if cls.__init__ has explicit `quant_desc_input` or `quant_desc_weight` - we don't search the state of the
# float module, because it would not contain this state. these values are injected by the framework
ignore_init_args = {"quant_desc_input", "quant_desc_weight"}.intersection(set(required_init_params))
# if cls.__init__ doesn't have neither **kwargs, nor `quant_desc_input` and `quant_desc_weight`,
# we should also remove these keys from the passed kwargs and make sure there's nothing more!
if "kwargs" not in required_init_params:
for arg in ("quant_desc_input", "quant_desc_weight"):
if arg in ignore_init_args:
continue
kwargs.pop(arg, None) # we ignore if not existing
return _from_float(cls, float_instance, ignore_init_args, **kwargs)
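A sketch of a user-defined quantized block relying on the default from_float: because __init__ takes in_channels and out_channels, the default mechanism will read float_instance.in_channels and float_instance.out_channels from the wrapped float module. All class and attribute names here are hypothetical, and forwarding the quant descriptors through **kwargs to QuantConv2d is an assumption about pytorch-quantization's constructor:

from pytorch_quantization.nn import QuantConv2d
from super_gradients.training.utils.quantization.core import SGQuantMixin  # import path assumed

class MyQuantizedBlock(SGQuantMixin):  # hypothetical quantized counterpart of a float MyBlock
    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        # quant_desc_input / quant_desc_weight arrive through **kwargs and are forwarded to the quantized conv
        self.conv = QuantConv2d(in_channels, out_channels, kernel_size=3, padding=1, **kwargs)

    def forward(self, x):
        return self.conv(x)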
This class wraps a float module instance, and defines that this instance will not be converted to quantized version
Example:
self.my_block = SkipQuantization(MyBlock(4, n_classes))
Source code in V3_2/src/super_gradients/training/utils/quantization/core.py
class SkipQuantization(nn.Module):
This class wraps a float module instance, and defines that this instance will not be converted to quantized version
Example:
self.my_block = SkipQuantization(MyBlock(4, n_classes))
def __init__(self, module: nn.Module) -> None:
super().__init__()
self.float_module = module
self.forward = module.forward
def export_quantized_module_to_onnx(
model: torch.nn.Module, onnx_filename: str, input_shape: tuple, train: bool = False, to_cpu: bool = True, deepcopy_model=False, **kwargs
Method for exporting onnx after QAT.
:param to_cpu: transfer model to CPU before converting to ONNX, dirty workaround when model's tensors are on different devices
:param train: export model in training mode
:param model: torch.nn.Module, model to export
:param onnx_filename: str, target path for the onnx file,
:param input_shape: tuple, input shape (usually BCHW)
:param deepcopy_model: Whether to export deepcopy(model). Necessary in case further training is performed and
prep_model_for_conversion makes the network un-trainable (i.e RepVGG blocks).
if _imported_pytorch_quantization_failure is not None:
raise _imported_pytorch_quantization_failure
if deepcopy_model:
model = deepcopy(model)
use_fb_fake_quant_state = quant_nn.TensorQuantizer.use_fb_fake_quant
quant_nn.TensorQuantizer.use_fb_fake_quant = True
# Export ONNX for multiple batch sizes
logger.info("Creating ONNX file: " + onnx_filename)
if train:
training_mode = TrainingMode.TRAINING
model.train()
else:
training_mode = TrainingMode.EVAL
model.eval()
if hasattr(model, "prep_model_for_conversion"):
model.prep_model_for_conversion(**kwargs)
# workaround when model.prep_model_for_conversion does reparametrization
# and tensors get scattered to different devices
if to_cpu:
export_model = model.cpu()
else:
export_model = model
dummy_input = torch.randn(input_shape, device=next(model.parameters()).device)
torch.onnx.export(export_model, dummy_input, onnx_filename, verbose=False, opset_version=13, do_constant_folding=True, training=training_mode)
# Restore functions of quant_nn back as expected
quant_nn.TensorQuantizer.use_fb_fake_quant = use_fb_fake_quant_state
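Call sketch for exporting a calibrated, quantized model; the import path is assumed, and quantized_model stands in for your own module:

from super_gradients.training.utils.quantization.export import export_quantized_module_to_onnx  # import path assumed

export_quantized_module_to_onnx(
    model=quantized_model,           # calibrated, quantized torch.nn.Module (assumed to exist)
    onnx_filename="model_int8.onnx",
    input_shape=(1, 3, 640, 640),    # BCHW shape used to build the dummy tracing input
    train=False,                     # export in eval mode
    to_cpu=True,                     # move tensors to CPU first (workaround described above)
    deepcopy_model=True,             # keep the original model trainable after export
)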
Source code in V3_2/src/super_gradients/training/utils/quantization/selective_quantization_utils.py
class SelectiveQuantizer:
:param custom_mappings: custom mappings that extend the default mappings with extra behaviour
:param default_per_channel_quant_weights: whether quant module weights should be per channel (default=True)
:param default_quant_modules_calibrator_weights: default calibrator method for weights (default='max')
:param default_quant_modules_calibrator_inputs: default calibrator method for inputs (default='histogram')
:param default_learn_amax: EXPERIMENTAL! whether quant modules should have learnable amax (default=False)
if _imported_pytorch_quantization_failure is not None:
raise _imported_pytorch_quantization_failure
mapping_instructions: Dict[Union[str, Type], QuantizedMetadata] = {
float_type: QuantizedMetadata(
float_source=float_type,
quantized_target_class=quantized_target_class,
action=QuantizedMetadata.ReplacementAction.REPLACE,
for (float_type, quantized_target_class) in [
(nn.Conv1d, quant_nn.QuantConv1d),
(nn.Conv2d, quant_nn.QuantConv2d),
(nn.Conv3d, quant_nn.QuantConv3d),
(nn.ConvTranspose1d, quant_nn.QuantConvTranspose1d),
(nn.ConvTranspose2d, quant_nn.QuantConvTranspose2d),
(nn.ConvTranspose3d, quant_nn.QuantConvTranspose3d),
(nn.Linear, quant_nn.Linear),
(nn.LSTM, quant_nn.LSTM),
(nn.LSTMCell, quant_nn.LSTMCell),
(nn.AvgPool1d, quant_nn.QuantAvgPool1d),
(nn.AvgPool2d, quant_nn.QuantAvgPool2d),
(nn.AvgPool3d, quant_nn.QuantAvgPool3d),
(nn.AdaptiveAvgPool1d, quant_nn.QuantAdaptiveAvgPool1d),
(nn.AdaptiveAvgPool2d, quant_nn.QuantAdaptiveAvgPool2d),
(nn.AdaptiveAvgPool3d, quant_nn.QuantAdaptiveAvgPool3d),
SkipQuantization: QuantizedMetadata(float_source=SkipQuantization, quantized_target_class=None, action=QuantizedMetadata.ReplacementAction.UNWRAP),
} # DEFAULT MAPPING INSTRUCTIONS
def __init__(
self,
custom_mappings: dict = None,
default_quant_modules_calibrator_weights: str = "max",
default_quant_modules_calibrator_inputs: str = "histogram",
default_per_channel_quant_weights: bool = True,
default_learn_amax: bool = False,
verbose: bool = True,
) -> None:
super().__init__()
self.default_quant_modules_calibrator_weights = default_quant_modules_calibrator_weights
self.default_quant_modules_calibrator_inputs = default_quant_modules_calibrator_inputs
self.default_per_channel_quant_weights = default_per_channel_quant_weights
self.default_learn_amax = default_learn_amax
self.verbose = verbose
self.mapping_instructions = self.mapping_instructions.copy()
if custom_mappings is not None:
self.mapping_instructions.update(custom_mappings) # OVERRIDE DEFAULT WITH CUSTOM. CUSTOM IS PRIORITIZED
def _get_default_quant_descriptor(self, for_weights=False):
methods = {"percentile": "histogram", "mse": "histogram", "entropy": "histogram", "histogram": "histogram", "max": "max"}
if for_weights:
axis = 0 if self.default_per_channel_quant_weights else None
learn_amax = self.default_learn_amax
if self.default_learn_amax and self.default_per_channel_quant_weights:
logger.error("Learnable amax is suported only for per-tensor quantization. Disabling it for weights quantization!")
learn_amax = False
return QuantDescriptor(calib_method=methods[self.default_quant_modules_calibrator_weights], axis=axis, learn_amax=learn_amax)
else:
# activations stay per-tensor by default
return QuantDescriptor(calib_method=methods[self.default_quant_modules_calibrator_inputs], learn_amax=self.default_learn_amax)
    def register_skip_quantization(self, *, layer_names: Optional[Set[str]] = None):
        if layer_names is not None:
            self.mapping_instructions.update(
                {
                    name: QuantizedMetadata(float_source=name, quantized_target_class=None, action=QuantizedMetadata.ReplacementAction.SKIP)
                    for name in layer_names
                }
            )

    def register_quantization_mapping(
        self, *, layer_names: Set[str], quantized_target_class: Type[SGQuantMixin], input_quant_descriptor=None, weights_quant_descriptor=None
    ):
        self.mapping_instructions.update(
            {
                name: QuantizedMetadata(
                    float_source=name,
                    quantized_target_class=quantized_target_class,
                    action=QuantizedMetadata.ReplacementAction.REPLACE,
                    input_quant_descriptor=input_quant_descriptor,
                    weights_quant_descriptor=weights_quant_descriptor,
                )
                for name in layer_names
            }
        )
    def _preprocess_skips_and_custom_mappings(self, module: nn.Module, nesting: Tuple[str, ...] = ()):
        """
        This pass extracts layer names and their mapping instructions, so that we can apply per-layer processing.
        Relevant layer-specific mapping instructions are either `SkipQuantization` or `QuantizedMapping`, which are then
        added to the mappings.
        """
        mapping_instructions = dict()
        for name, child_module in module.named_children():
            nested_name = ".".join(nesting + (name,))
            if isinstance(child_module, SkipQuantization):
                mapping_instructions[nested_name] = QuantizedMetadata(
                    float_source=nested_name, quantized_target_class=None, action=QuantizedMetadata.ReplacementAction.UNWRAP
                )

            if isinstance(child_module, QuantizedMapping):
                mapping_instructions[nested_name] = QuantizedMetadata(
                    float_source=nested_name,
                    quantized_target_class=child_module.quantized_target_class,
                    input_quant_descriptor=child_module.input_quant_descriptor,
                    weights_quant_descriptor=child_module.weights_quant_descriptor,
                    action=child_module.action,
                )

            if isinstance(child_module, nn.Module):  # recursive call
                mapping_instructions.update(self._preprocess_skips_and_custom_mappings(child_module, nesting + (name,)))

        return mapping_instructions
def _instantiate_quantized_from_float(self, float_module, metadata, preserve_state_dict):
base_classes = (QuantMixin, QuantInputMixin, SGQuantMixin)
if not issubclass(metadata.quantized_target_class, base_classes):
raise AssertionError(
f"Quantization suite for {type(float_module).__name__} is invalid. "
f"{metadata.quantized_target_class.__name__} must inherit one of "
f"{', '.join(map(lambda _: _.__name__, base_classes))}"
# USE PROVIDED QUANT DESCRIPTORS, OR DEFAULT IF NONE PROVIDED
quant_descriptors = dict()
if issubclass(metadata.quantized_target_class, (SGQuantMixin, QuantMixin, QuantInputMixin)):
quant_descriptors = {"quant_desc_input": metadata.input_quant_descriptor or self._get_default_quant_descriptor(for_weights=False)}
if issubclass(metadata.quantized_target_class, (SGQuantMixin, QuantMixin)):
quant_descriptors.update({"quant_desc_weight": metadata.weights_quant_descriptor or self._get_default_quant_descriptor(for_weights=True)})
if not hasattr(metadata.quantized_target_class, "from_float"):
assert isinstance(metadata.quantized_target_class, SGQuantMixin), (
f"{metadata.quantized_target_class.__name__} must inherit from " f"{SGQuantMixin.__name__}, so that it would include `from_float` class method"
q_instance = metadata.quantized_target_class.from_float(float_module, **quant_descriptors)
# MOVE TENSORS TO ORIGINAL DEVICE
if len(list(float_module.parameters(recurse=False))) > 0:
q_instance = q_instance.to(next(float_module.parameters(recurse=False)).device)
elif len(list(float_module.buffers(recurse=False))):
q_instance = q_instance.to(next(float_module.buffers(recurse=False)).device)
# COPY STATE DICT IF NEEDED
if preserve_state_dict:
# quant state dict may have additional parameters for Clip and strict loading will fail
# if we find at least one Clip module in q_instance, disable strict loading and hope for the best
strict_load = True
for k in q_instance.state_dict().keys():
if "clip.clip_value_max" in k or "clip.clip_value_min" in k:
strict_load = False
logger.debug(
"Instantiating quant module in non-strict mode leaving Clip parameters non-initilaized. Use QuantizationCalibrator to initialize them."
break
q_instance.load_state_dict(float_module.state_dict(), strict=strict_load)
return q_instance
def _maybe_quantize_one_layer(
self,
module: nn.Module,
child_name: str,
nesting: Tuple[str, ...],
child_module: nn.Module,
mapping_instructions: Dict[Union[str, Type], QuantizedMetadata],
preserve_state_dict: bool,
) -> bool:
"""
Does the heavy lifting of (maybe) quantizing a layer: creates a quantized instance based on a float instance,
and replaces it in the "parent" module
:param module: the module we'd like to quantize a specific layer in
:param child_name: the attribute name of the layer in the module
:param nesting: the current nesting we're in. Needed to find the appropriate key in the mappings
:param child_module: the instance of the float module we'd like to quantize
:param mapping_instructions: mapping instructions: how to quantize
:param preserve_state_dict: whether to copy the state dict from the float instance to the quantized instance
:return: a boolean indicating whether a match was found and handled, so the caller should not continue recursively
"""
# if we don't have any instruction for the specific layer or the specific type - we continue
# NOTE! IT IS IMPORTANT TO FIRST PROCESS THE NAME AND ONLY THEN THE TYPE
if _imported_pytorch_quantization_failure is not None:
raise _imported_pytorch_quantization_failure
for candidate_key in (".".join(nesting + (child_name,)), type(child_module)):
if candidate_key not in mapping_instructions:
continue
metadata: QuantizedMetadata = mapping_instructions[candidate_key]
if metadata.action == QuantizedMetadata.ReplacementAction.SKIP:
return True
elif metadata.action == QuantizedMetadata.ReplacementAction.UNWRAP:
assert isinstance(child_module, SkipQuantization)
setattr(module, child_name, child_module.float_module)
return True
elif metadata.action in (
QuantizedMetadata.ReplacementAction.REPLACE,
QuantizedMetadata.ReplacementAction.REPLACE_AND_RECURE,
QuantizedMetadata.ReplacementAction.RECURE_AND_REPLACE,
):
if isinstance(child_module, QuantizedMapping): # UNWRAP MAPPING
child_module = child_module.float_module
q_instance: nn.Module = self._instantiate_quantized_from_float(
float_module=child_module, metadata=metadata, preserve_state_dict=preserve_state_dict
)
# ACTUAL REPLACEMENT
def replace():
setattr(module, child_name, q_instance)
def recurse_quantize():
self._quantize_module_aux(
module=getattr(module, child_name),
mapping_instructions=mapping_instructions,
nesting=nesting + (child_name,),
preserve_state_dict=preserve_state_dict,
)
if metadata.action == QuantizedMetadata.ReplacementAction.REPLACE:
replace()
elif metadata.action == QuantizedMetadata.ReplacementAction.REPLACE_AND_RECURE:
replace()
recurse_quantize()
elif metadata.action == QuantizedMetadata.ReplacementAction.RECURE_AND_REPLACE:
recurse_quantize()
replace()
return True
else:
raise NotImplementedError
return False
def quantize_module(self, module: nn.Module, *, preserve_state_dict=True):
per_layer_mappings = self._preprocess_skips_and_custom_mappings(module)
mapping_instructions = {
**per_layer_mappings,
**self.mapping_instructions,
} # we first regard the per layer mappings, and then override with the custom mappings in case there is overlap
logging_level = logging.getLogger("absl").getEffectiveLevel()
if not self.verbose: # suppress pytorch-quantization spam
logging.getLogger("absl").setLevel("ERROR")
device = next(module.parameters()).device
self._quantize_module_aux(mapping_instructions=mapping_instructions, module=module, nesting=(), preserve_state_dict=preserve_state_dict)
module.to(device)
logging.getLogger("absl").setLevel(logging_level)
def _quantize_module_aux(self, mapping_instructions, module, nesting, preserve_state_dict):
for name, child_module in module.named_children():
found = self._maybe_quantize_one_layer(module, name, nesting, child_module, mapping_instructions, preserve_state_dict)
# RECURSIVE CALL, to support module_list, sequential, custom (nested) modules
if not found and isinstance(child_module, nn.Module):
self._quantize_module_aux(mapping_instructions, child_module, nesting + (name,), preserve_state_dict)
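Below is a hedged usage sketch of SelectiveQuantizer on a small float model. The toy nn.Sequential and the skipped layer name are illustrative; the import path follows the source file shown above, and pytorch-quantization is assumed to be installed.

import torch
from torch import nn
from super_gradients.training.utils.quantization.selective_quantization_utils import SelectiveQuantizer

model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(16, 10),
)

q_util = SelectiveQuantizer(
    default_quant_modules_calibrator_weights="max",
    default_quant_modules_calibrator_inputs="histogram",
    default_per_channel_quant_weights=True,
)
# Keep the classifier head ("4" is its attribute path inside the Sequential) in float precision
q_util.register_skip_quantization(layer_names={"4"})

# Replaces supported float layers in-place with their quantized counterparts from the default mappings
q_util.quantize_module(model, preserve_state_dict=True)
print(model)  # Conv2d/AdaptiveAvgPool2d are now quantized variants, the Linear head stays float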
Source code in V3_2/src/super_gradients/training/utils/quantization/selective_quantization_utils.py
def register_quantized_module(
    float_source: Union[str, Type[nn.Module]],
    action: QuantizedMetadata.ReplacementAction = QuantizedMetadata.ReplacementAction.REPLACE,
    input_quant_descriptor: Optional[QuantDescriptor] = None,
    weights_quant_descriptor: Optional[QuantDescriptor] = None,
) -> Callable:
    """
    Decorator used to register a Quantized module as a quantized version for Float module
    :param action: action to perform on the float_source
    :param float_source: the float module type that is being registered
    :param input_quant_descriptor: the input quantization descriptor
    :param weights_quant_descriptor: the weight quantization descriptor
    """

    def decorator(quant_module: Type[SGQuantMixin]) -> Type[SGQuantMixin]:
        if float_source in SelectiveQuantizer.mapping_instructions:
            metadata = SelectiveQuantizer.mapping_instructions[float_source]
            raise ValueError(f"`{float_source}` is already registered with following metadata {metadata}")

        SelectiveQuantizer.mapping_instructions.update(
            {
                float_source: QuantizedMetadata(
                    float_source=float_source,
                    quantized_target_class=quant_module,
                    input_quant_descriptor=input_quant_descriptor,
                    weights_quant_descriptor=weights_quant_descriptor,
                    action=action,
                )
            }
        )
        return quant_module  # this is required since the decorator assigns the result to the `quant_module`

    return decorator
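A hedged sketch of registering a custom quantized replacement via this decorator is shown below. MyBlock and QuantMyBlock are hypothetical classes, and the import paths for SGQuantMixin and the decorator are assumed from the source paths shown on this page.

from torch import nn
from pytorch_quantization import nn as quant_nn
from super_gradients.training.utils.quantization.core import SGQuantMixin
from super_gradients.training.utils.quantization.selective_quantization_utils import register_quantized_module


class MyBlock(nn.Module):  # hypothetical float block
    def __init__(self, in_feats: int, out_feats: int):
        super().__init__()
        self.in_feats = in_feats
        self.out_feats = out_feats
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(in_feats, out_feats)

    def forward(self, x):
        return self.linear(self.flatten(x))


@register_quantized_module(float_source=MyBlock)
class QuantMyBlock(SGQuantMixin, nn.Module):  # hypothetical quantized twin
    def __init__(self, in_feats: int, out_feats: int):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear = quant_nn.QuantLinear(in_feats, out_feats)

    def forward(self, x):
        return self.linear(self.flatten(x))

    @classmethod
    def from_float(cls, float_instance: "MyBlock", **kwargs):
        # Quant descriptors arrive in kwargs; this sketch relies on QuantLinear defaults and only copies shapes
        return cls(in_feats=float_instance.in_feats, out_feats=float_instance.out_feats)

After this registration, any SelectiveQuantizer instance will replace MyBlock layers with QuantMyBlock during quantize_module, since the mapping is added to the class-level default mapping instructions.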
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
Intended usage of this block is the following:
class ResNetBlock(nn.Module):
    def __init__(self, ..., drop_path_rate: float):
        self.drop_path = DropPath(drop_path_rate)

    def forward(self, x):
        return x + self.drop_path(self.conv_bn_act(x))
Code taken from TIMM (https://github.com/rwightman/pytorch-image-models)
Apache License 2.0
Source code in V3_2/src/super_gradients/training/utils/regularization_utils.py
class DropPath(nn.Module):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Intended usage of this block is the following:

    >>> class ResNetBlock(nn.Module):
    >>>   def __init__(self, ..., drop_path_rate:float):
    >>>     self.drop_path = DropPath(drop_path_rate)
    >>>
    >>>   def forward(self, x):
    >>>     return x + self.drop_path(self.conv_bn_act(x))

    Code taken from TIMM (https://github.com/rwightman/pytorch-image-models)
    Apache License 2.0
    """

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        """
        :param drop_prob: Probability of zeroing out individual vector (channel dimension) of each feature map
        :param scale_by_keep: Whether to scale the output by the keep probability. Enable by default and helps to
                              keep output mean & std in the same range as w/o drop path.
        """
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        if self.drop_prob == 0.0 or not self.training:
            return x
        return drop_path(x, self.drop_prob, self.scale_by_keep)

    def extra_repr(self):
        return f"drop_prob={round(self.drop_prob, 3):0.3f}"
Parameters:
- drop_prob (float): Probability of zeroing out individual vector (channel dimension) of each feature map. Default: 0.0.
- scale_by_keep (bool): Whether to scale the output by the keep probability. Enabled by default; helps to keep the output mean & std in the same range as without drop path. Default: True.
def drop_path(x, drop_prob: float = 0.0, scale_by_keep: bool = True):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor
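A short behavioural sketch follows; the import path is assumed from the regularization_utils source file referenced above.

import torch
from super_gradients.training.utils.regularization_utils import DropPath, drop_path

x = torch.ones(8, 16, 4, 4)  # (batch, channels, H, W)

# Functional form: roughly half of the samples are zeroed out, survivors are rescaled by 1/keep_prob
y = drop_path(x, drop_prob=0.5, scale_by_keep=True)
print(y[:, 0, 0, 0])  # per-sample values are either 0.0 or 2.0

# Module form: active only in train() mode, identity in eval() mode
block = DropPath(drop_prob=0.5)
block.train()
print(block(x).shape)            # torch.Size([8, 16, 4, 4])
block.eval()
assert torch.equal(block(x), x)  # no-op at inference time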
class BinarySegmentationVisualization:
@staticmethod
def _visualize_image(image_np: np.ndarray, pred_mask: torch.Tensor, target_mask: torch.Tensor, image_scale: float, checkpoint_dir: str, image_name: str):
pred_mask = pred_mask.copy()
image_np = torch.from_numpy(np.moveaxis(image_np, -1, 0).astype(np.uint8))
pred_mask = pred_mask[np.newaxis, :, :] > 0.5
target_mask = target_mask[np.newaxis, :, :].astype(bool)
tp_mask = np.logical_and(pred_mask, target_mask)
fp_mask = np.logical_and(pred_mask, np.logical_not(target_mask))
fn_mask = np.logical_and(np.logical_not(pred_mask), target_mask)
overlay = torch.from_numpy(np.concatenate([tp_mask, fp_mask, fn_mask]))
# SWITCH BETWEEN BLUE AND RED IF WE SAVE THE IMAGE ON THE DISC AS OTHERWISE WE CHANGE CHANNEL ORDERING
colors = ["green", "red", "blue"]
res_image = draw_segmentation_masks(image_np, overlay, colors=colors).detach().numpy()
res_image = np.concatenate([res_image[ch, :, :, np.newaxis] for ch in range(3)], 2)
res_image = cv2.resize(res_image.astype(np.uint8), (0, 0), fx=image_scale, fy=image_scale, interpolation=cv2.INTER_NEAREST)
if checkpoint_dir is None:
return res_image
else:
cv2.imwrite(os.path.join(checkpoint_dir, str(image_name) + ".jpg"), res_image)
    @staticmethod
    def visualize_batch(
        image_tensor: torch.Tensor,
        pred_mask: torch.Tensor,
        target_mask: torch.Tensor,
        batch_name: Union[int, str],
        checkpoint_dir: str = None,
        undo_preprocessing_func: Callable[[torch.Tensor], np.ndarray] = reverse_imagenet_preprocessing,
        image_scale: float = 1.0,
    ):
        """
        A helper function to visualize detections predicted by a network:
        saves images into a given path with a name that is {batch_name}_{image_idx_in_the_batch}.jpg, one batch per call.
        Colors are generated on the fly: uniformly sampled from color wheel to support all given classes.

        :param image_tensor:            rgb images, (B, H, W, 3)
        :param pred_boxes:              boxes after NMS for each image in a batch, each (Num_boxes, 6),
                                        values on dim 1 are: x1, y1, x2, y2, confidence, class
        :param target_boxes:            (Num_targets, 6), values on dim 1 are: image id in a batch, class, x y w h
                                        (coordinates scaled to [0, 1])
        :param batch_name:              id of the current batch to use for image naming
        :param checkpoint_dir:          a path where images with boxes will be saved. if None, the result images will
                                        be returned as a list of numpy image arrays
        :param undo_preprocessing_func: a function to convert preprocessed images tensor into a batch of cv2-like images
        :param image_scale:             scale factor for output image
        """
        image_np = undo_preprocessing_func(image_tensor.detach())
        pred_mask = torch.sigmoid(pred_mask[:, 0, :, :])
        out_images = []
        for i in range(image_np.shape[0]):
            preds = pred_mask[i].detach().cpu().numpy()
            targets = target_mask[i].detach().cpu().numpy()
            image_name = "_".join([str(batch_name), str(i)])
            res_image = BinarySegmentationVisualization._visualize_image(image_np[i], preds, targets, image_scale, checkpoint_dir, image_name)
            if res_image is not None:
                out_images.append(res_image)
        return out_images
visualize_batch(image_tensor, pred_mask, target_mask, batch_name, checkpoint_dir=None, undo_preprocessing_func=reverse_imagenet_preprocessing, image_scale=1.0)
staticmethod
A helper function to visualize detections predicted by a network:
saves images into a given path with a name that is {batch_name}_{image_idx_in_the_batch}.jpg, one batch per call.
Colors are generated on the fly: uniformly sampled from color wheel to support all given classes.
Parameters:
- image_tensor (torch.Tensor): rgb images, (B, H, W, 3). Required.
- pred_boxes: boxes after NMS for each image in a batch, each (Num_boxes, 6); values on dim 1 are: x1, y1, x2, y2, confidence, class. Required.
- target_boxes: (Num_targets, 6); values on dim 1 are: image id in a batch, class, x y w h (coordinates scaled to [0, 1]). Required.
- batch_name (Union[int, str]): id of the current batch to use for image naming. Required.
- checkpoint_dir (str): a path where images with boxes will be saved. If None, the result images will be returned as a list of numpy image arrays. Default: None.
- undo_preprocessing_func (Callable[[torch.Tensor], np.ndarray]): a function to convert a preprocessed images tensor into a batch of cv2-like images. Default: reverse_imagenet_preprocessing.
- image_scale (float): scale factor for output image. Default: 1.0.
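A hedged usage sketch: the tensors below are random placeholders, and the import of BinarySegmentationVisualization from its segmentation utils module is assumed.

import torch

batch_size, h, w = 4, 64, 64
image_tensor = torch.rand(batch_size, 3, h, w)            # preprocessed images, layout expected by undo_preprocessing_func
pred_mask = torch.randn(batch_size, 1, h, w)              # raw logits; sigmoid is applied internally
target_mask = torch.randint(0, 2, (batch_size, h, w)).float()

# checkpoint_dir=None -> annotated images are returned instead of written to disk
images = BinarySegmentationVisualization.visualize_batch(
    image_tensor=image_tensor,
    pred_mask=pred_mask,
    target_mask=target_mask,
    batch_name=0,
    checkpoint_dir=None,
)
print(len(images), images[0].shape)  # 4 images of shape (64, 64, 3); TP/FP/FN drawn in green/red/blue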
Parameters:
- kernel_size (int): kernel size of the dilation/erosion convolutions. The resulting edge width depends on this argument as follows: edge_width = kernel - 1. Required.
- flatten_channels (bool): Whether to apply logical_or across the channels dimension; if at least one pixel class is considered an edge pixel, the flattened value is 1. If set to False the output tensor shape is [B, C, H, W], else [B, 1, H, W]. Default: True.
def one_hot_to_binary_edge(x: torch.Tensor, kernel_size: int, flatten_channels: bool = True) -> torch.Tensor:
    """
    Utils function to create edge feature maps.
    :param x: input tensor, must be one_hot tensor with shape [B, C, H, W]
    :param kernel_size: kernel size of dilation erosion convolutions. The result edge widths depends on this argument as
        follows: `edge_width = kernel - 1`
    :param flatten_channels: Whether to apply logical_or across channels dimension, if at least one pixel class is
        considered as edge pixel flatten value is 1. If set as `False` the output tensor shape is [B, C, H, W], else
        [B, 1, H, W]. Default is `True`.
    :return: one_hot edge torch.Tensor.
    """
if kernel_size < 0 or kernel_size % 2 == 0:
raise ValueError(f"kernel size must be an odd positive values, such as [1, 3, 5, ..], found: {kernel_size}")
_kernel = torch.ones(x.size(1), 1, kernel_size, kernel_size, dtype=torch.float32, device=x.device)
padding = (kernel_size - 1) // 2
# Use replicate padding to prevent class shifting and edge formation at the image boundaries.
padded_x = F.pad(x.float(), mode="replicate", pad=[padding] * 4)
# The binary edges feature map is created by subtracting dilated features from erosed features.
# First the positive one value masks are expanded (dilation) by applying a sliding window filter of one values.
# The resulted output is then clamped to binary format to [0, 1], this way the one-hot boundaries are expanded by
# (kernel_size - 1) / 2.
dilation = torch.clamp(F.conv2d(padded_x, _kernel, groups=x.size(1)), 0, 1)
# Similar to dilation, erosion (can be seen as inverse of dilation) is applied to contract the one-hot features by
# applying a dilation operation on the inverse of the one-hot features.
erosion = 1 - torch.clamp(F.conv2d(1 - padded_x, _kernel, groups=x.size(1)), 0, 1)
# Finally the edge features are the result of subtracting dilation by erosion.
# i.e for a simple 1D one-hot input: [0, 0, 0, 1, 1, 1, 0, 0, 0], using sliding kernel with size 3: [1, 1, 1]
# Dilated features: [0, 0, 1, 1, 1, 1, 1, 0, 0]
# Erosed inverse features: [0, 0, 0, 0, 1, 0, 0, 0, 0]
# Edge features: dilation - erosion: [0, 0, 1, 1, 0, 1, 1, 0, 0]
edge = dilation - erosion
if flatten_channels:
# use max operator across channels. Equivalent to logical or for input with binary values [0, 1].
edge = edge.max(dim=1, keepdim=True)[0]
return edge
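A toy example of the dilation-minus-erosion edge extraction; the import of one_hot_to_binary_edge from its segmentation utils module is assumed.

import torch

# Single class channel: a 2x2 square of ones in the middle of a 6x6 map
x = torch.zeros(1, 1, 6, 6)
x[:, :, 2:4, 2:4] = 1.0

edge = one_hot_to_binary_edge(x, kernel_size=3, flatten_channels=True)
print(edge.shape)  # torch.Size([1, 1, 6, 6])
# With kernel_size=3 the edge band is roughly (kernel_size - 1) = 2 pixels wide around the square's boundary
print(edge[0, 0])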
def reverse_imagenet_preprocessing(im_tensor: torch.Tensor) -> np.ndarray:
    """
    :param im_tensor: images in a batch after preprocessing for inference, RGB, (B, C, H, W)
    :return: images in a batch in cv2 format, BGR, (B, H, W, C)
    """
    im_np = im_tensor.cpu().numpy()
    im_np = im_np[:, ::-1, :, :].transpose(0, 2, 3, 1)
    im_np *= np.array([[[0.229, 0.224, 0.225][::-1]]])
    im_np += np.array([[[0.485, 0.456, 0.406][::-1]]])
    im_np *= 255.0
    return np.ascontiguousarray(im_np, dtype=np.uint8)
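A small round-trip sketch: normalize an image the ImageNet way, then undo it for visualization. The import path of reverse_imagenet_preprocessing is assumed from context.

import torch

mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

rgb = torch.rand(2, 3, 32, 32)             # RGB in [0, 1], (B, C, H, W)
preprocessed = (rgb - mean) / std          # standard ImageNet normalization

bgr_images = reverse_imagenet_preprocessing(preprocessed)
print(bgr_images.shape, bgr_images.dtype)  # (2, 32, 32, 3) uint8, BGR channel order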
Parameters:
- num_classes (int): num of classes in the dataset excluding the ignore label; this is the output channel count of the one-hot result. Required.
- kernel_size (int): kernel size of the dilation/erosion convolutions. The resulting edge width depends on this argument as follows: edge_width = kernel - 1. Required.
- flatten_channels (bool): Whether to apply logical_or across the channels dimension; if at least one pixel class is considered an edge pixel, the flattened value is 1. If set to False the output tensor shape is [B, C, H, W], else [B, 1, H, W]. Default: True.
def target_to_binary_edge(target: torch.Tensor, num_classes: int, kernel_size: int, ignore_index: int = None, flatten_channels: bool = True) -> torch.Tensor:
    """
    Utils function to create edge feature maps from target.
    :param target: Class labels long tensor, with shape [N, H, W]
    :param num_classes: num of classes in datasets excluding ignore label, this is the output channels of the one hot
        result.
    :param kernel_size: kernel size of dilation erosion convolutions. The result edge widths depends on this argument as
        follows: `edge_width = kernel - 1`
    :param flatten_channels: Whether to apply logical or across channels dimension, if at least one pixel class is
        considered as edge pixel flatten value is 1. If set as `False` the output tensor shape is [B, C, H, W], else
        [B, 1, H, W]. Default is `True`.
    :return: one_hot edge torch.Tensor.
    """
    one_hot = to_one_hot(target, num_classes=num_classes, ignore_index=ignore_index)
    return one_hot_to_binary_edge(one_hot, kernel_size=kernel_size, flatten_channels=flatten_channels)
Parameters:
- num_classes (int): num of classes in the dataset excluding the ignore label; this is the output channel count of the one-hot result. Required.
def to_one_hot(target: torch.Tensor, num_classes: int, ignore_index: int = None):
    """
    Target label to one_hot tensor. labels and ignore_index must be consecutive numbers.
    :param target: Class labels long tensor, with shape [N, H, W]
    :param num_classes: num of classes in datasets excluding ignore label, this is the output channels of the one hot
        result.
    :return: one hot tensor with shape [N, num_classes, H, W]
    """
    num_classes = num_classes if ignore_index is None else num_classes + 1
    one_hot = F.one_hot(target, num_classes).permute((0, 3, 1, 2))
    if ignore_index is not None:
        # remove ignore_index channel
        one_hot = torch.cat([one_hot[:, :ignore_index], one_hot[:, ignore_index + 1 :]], dim=1)
    return one_hot
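A short example of the ignore_index handling; the import of to_one_hot from its segmentation utils module is assumed.

import torch

target = torch.tensor([[[0, 1],
                        [2, 2]]])    # shape [N=1, H=2, W=2], where label 2 is the ignore label
one_hot = to_one_hot(target, num_classes=2, ignore_index=2)
print(one_hot.shape)        # torch.Size([1, 2, 2, 2]) -- the ignore channel is removed
print(one_hot[0, :, 1, 0])  # tensor([0, 0]) -- ignored pixels end up all-zero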
Type of improvement compared to previous value, i.e. if the value is better, worse or the same.
Difference with "increase":
If a loss goes from 1 to 0.5, the value is smaller (decreased), but the result is better (improvement).
For accuracy going from 1 to 0.5, the value is also smaller, but this time the result is worse, because greater is better.
Source code in V3_2/src/super_gradients/training/utils/sg_trainer_utils.py
class ImprovementType(Enum):
"""Type of improvement compared to previous value, i.e. if the value is better, worse or the same.
Difference with "increase":
If a loss goes from 1 to 0.5, the value is smaller (decreased), but the result is better (improvement).
For accuracy from 1 to 0.5, the value is smaller, but this time the result decreased, because greater is better.
"""
IS_BETTER = "better"
IS_WORSE = "worse"
IS_SAME = "same"
NONE = "none"
def to_color(self) -> Union[str, None]:
"""Get the color representing the current improvement type"""
if self == ImprovementType.IS_SAME:
return "white"
elif self == ImprovementType.IS_BETTER:
return "green"
elif self == ImprovementType.IS_WORSE:
return "red"
else:
return None
Type of increase compared to previous value, i.e. if the value is greater, smaller or the same.
Difference with "improvement":
If a loss goes from 1 to 0.5, the value is smaller (decreased), but the result is better (improvement).
For accuracy going from 1 to 0.5, the value is also smaller, but this time the result is worse, because greater is better.
Source code in V3_2/src/super_gradients/training/utils/sg_trainer_utils.py
class IncreaseType(Enum):
"""Type of increase compared to previous value, i.e. if the value is greater, smaller or the same.
Difference with "improvement":
If a loss goes from 1 to 0.5, the value is smaller (decreased), but the result is better (improvement).
For accuracy from 1 to 0.5, the value is smaller, but this time the result decreased, because greater is better.
"""
NONE = "none"
IS_GREATER = "greater"
IS_SMALLER = "smaller"
IS_EQUAL = "equal"
def to_symbol(self) -> str:
"""Get the symbol representing the current increase type"""
if self == IncreaseType.NONE:
return ""
elif self == IncreaseType.IS_GREATER:
return "↗"
elif self == IncreaseType.IS_SMALLER:
return "↘"
else:
return "="
Store a value and some indicators relative to its past iterations.
The value can be a metric/loss, and the iteration can be epochs/batch.
Parameters:
- name (str): Name of the metric. Required.
- greater_is_better (Optional[bool]): If True, a greater value is considered better. Ex: (greater_is_better=True) for Accuracy, 1 is greater and therefore better than 0.4. Ex: (greater_is_better=False) for Loss, 1 is greater and therefore worse than 0.4. None when unknown. Default: None.
- current (Optional[float]): Current value of the metric. Default: None.
- previous (Optional[float]): Value of the metric in the previous iteration. Default: None.
- best (Optional[float]): Value of the metric in the best iteration (best according to greater_is_better). Default: None.
- change_from_previous (Optional[float]): Change compared to the previous iteration value. Default: None.
- change_from_best (Optional[float]): Change compared to the best iteration value. Default: None.
@dataclass
class MonitoredValue:
"""Store a value and some indicators relative to its past iterations.
The value can be a metric/loss, and the iteration can be epochs/batch.
:param name: Name of the metric
:param greater_is_better: True, a greater value is considered better.
ex: (greater_is_better=True) For Accuracy 1 is greater and therefore better than 0.4
ex: (greater_is_better=False) For Loss 1 is greater and therefore worse than 0.4
None when unknown
:param current: Current value of the metric
:param previous: Value of the metric in previous iteration
:param best: Value of the metric in best iteration (best according to greater_is_better)
:param change_from_previous: Change compared to previous iteration value
:param change_from_best: Change compared to best iteration value
"""
name: str
greater_is_better: Optional[bool] = None
current: Optional[float] = None
previous: Optional[float] = None
best: Optional[float] = None
change_from_previous: Optional[float] = None
change_from_best: Optional[float] = None
@property
def has_increased_from_previous(self) -> IncreaseType:
"""Type of increase compared to previous value, i.e. if the value is greater, smaller or the same."""
return self._get_increase_type(self.change_from_previous)
@property
def has_improved_from_previous(self) -> ImprovementType:
"""Type of improvement compared to previous value, i.e. if the value is better, worse or the same."""
return self._get_improvement_type(delta=self.change_from_previous)
@property
def has_increased_from_best(self) -> IncreaseType:
"""Type of increase compared to best value, i.e. if the value is greater, smaller or the same."""
return self._get_increase_type(self.change_from_best)
@property
def has_improved_from_best(self) -> ImprovementType:
"""Type of improvement compared to best value, i.e. if the value is better, worse or the same."""
return self._get_improvement_type(delta=self.change_from_best)
def _get_increase_type(self, delta: float) -> IncreaseType:
"""Type of increase, i.e. if the value is greater, smaller or the same."""
if self.change_from_best is None:
return IncreaseType.NONE
if delta > 0:
return IncreaseType.IS_GREATER
elif delta < 0:
return IncreaseType.IS_SMALLER
else:
return IncreaseType.IS_EQUAL
def _get_improvement_type(self, delta: float) -> ImprovementType:
"""Type of improvement, i.e. if value is better, worse or the same."""
if self.greater_is_better is None or self.change_from_best is None:
return ImprovementType.NONE
has_increased, has_decreased = delta > 0, delta < 0
if has_increased and self.greater_is_better or has_decreased and not self.greater_is_better:
return ImprovementType.IS_BETTER
elif has_increased and not self.greater_is_better or has_decreased and self.greater_is_better:
return ImprovementType.IS_WORSE
else:
return ImprovementType.IS_SAME
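A small sketch of how these indicators behave, assuming the import path follows the sg_trainer_utils source file referenced above.

from super_gradients.training.utils.sg_trainer_utils import ImprovementType, IncreaseType, MonitoredValue

loss = MonitoredValue(
    name="loss",
    greater_is_better=False,   # for a loss, lower is better
    current=0.5,
    previous=1.0,
    best=0.8,
    change_from_previous=-0.5,
    change_from_best=-0.3,
)
assert loss.has_increased_from_previous is IncreaseType.IS_SMALLER
assert loss.has_improved_from_previous is ImprovementType.IS_BETTER
assert loss.has_improved_from_best is ImprovementType.IS_BETTER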
def add_log_to_file(filename, results_titles_list, results_values_list, epoch, max_epochs):
"""Add a message to the log file"""
# Note: opening and closing the file on every call is inefficient; it is done this way for experimental purposes
with open(filename, "a") as f:
f.write("\nEpoch (%d/%d) - " % (epoch, max_epochs))
for result_title, result_value in zip(results_titles_list, results_values_list):
if isinstance(result_value, torch.Tensor):
result_value = result_value.item()
f.write(result_title + ": " + str(result_value) + "\t")
Parameters:
- monitored_values_dict (Dict[str, Dict[str, MonitoredValue]]): Dict of Dicts. The first level represents the split, and the second one a loss/metric. Required.
def display_epoch_summary(epoch: int, n_digits: int, monitored_values_dict: Dict[str, Dict[str, MonitoredValue]]) -> None:
"""Display a summary of loss/metric of interest, for a given epoch.
:param epoch: the number of epoch.
:param n_digits: number of digits to display on screen for float values
:param monitored_values_dict: Dict of Dict. The first one represents the split, and the second one a loss/metric.
"""
def _format_to_str(val: Optional[float]) -> str:
return str(round(val, n_digits)) if val is not None else "None"
def _generate_tree(value_name: str, monitored_value: MonitoredValue) -> Tree:
"""Generate a tree that represents the stats of a given loss/metric."""
current = _format_to_str(monitored_value.current)
root_id = str(hash(f"{value_name} = {current}")) + str(random.random())
tree = Tree()
tree.create_node(tag=f"{value_name.capitalize()} = {current}", identifier=root_id)
if monitored_value.previous is not None:
previous = _format_to_str(monitored_value.previous)
best = _format_to_str(monitored_value.best)
change_from_previous = _format_to_str(monitored_value.change_from_previous)
change_from_best = _format_to_str(monitored_value.change_from_best)
diff_with_prev_colored = colored(
    text=f"{monitored_value.has_increased_from_previous.to_symbol()} {change_from_previous}",
    color=monitored_value.has_improved_from_previous.to_color(),
)
diff_with_best_colored = colored(
    text=f"{monitored_value.has_increased_from_best.to_symbol()} {change_from_best}", color=monitored_value.has_improved_from_best.to_color()
)
tree.create_node(tag=f"Epoch N-1 = {previous:6} ({diff_with_prev_colored:8})", identifier=f"0_previous_{root_id}", parent=root_id)
tree.create_node(tag=f"Best until now = {best:6} ({diff_with_best_colored:8})", identifier=f"1_best_{root_id}", parent=root_id)
return tree
summary_tree = Tree()
summary_tree.create_node(f"SUMMARY OF EPOCH {epoch}", "Summary")
for split, monitored_values in monitored_values_dict.items():
if len(monitored_values):
split_tree = Tree()
split_tree.create_node(split, split)
for name, value in monitored_values.items():
split_tree.paste(split, new_tree=_generate_tree(name, monitored_value=value))
summary_tree.paste("Summary", split_tree)
print("===========================================================")
summary_tree.show(key=False)
print("===========================================================")
def get_callable_param_names(obj: callable) -> Tuple[str]:
"""Get the param names of a given callable (function, class, ...)
:param obj: Object to inspect
:return: Param names of that object
"""
return tuple(inspect.signature(obj).parameters)
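A quick illustration, assuming the import path follows the sg_trainer_utils source file referenced above.

from super_gradients.training.utils.sg_trainer_utils import get_callable_param_names

def train(model, lr: float = 1e-3, epochs: int = 10):
    ...

class Trainer:
    def __init__(self, model, lr: float = 1e-3):
        ...

print(get_callable_param_names(train))    # ('model', 'lr', 'epochs')
print(get_callable_param_names(Trainer))  # ('model', 'lr') -- `self` is not reported for classes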
def init_summary_writer(tb_dir, checkpoint_loaded, user_prompt=False):
"""Remove previous tensorboard files from directory and launch a tensor board process"""
# If the training is from scratch, Walk through destination folder and delete existing tensorboard logs
user = ""
if not checkpoint_loaded:
for filename in os.listdir(tb_dir):
if "events" in filename:
if not user_prompt:
logger.debug('"{}" will not be deleted'.format(filename))
continue
while True:
# Verify with user before deleting old tensorboard files
user = (
input('\nOLDER TENSORBOARD FILES EXISTS IN EXPERIMENT FOLDER:\n"{}"\n' "DO YOU WANT TO DELETE THEM? [y/n]".format(filename))
if (user != "n" or user != "y")
else user
)
if user == "y":
os.remove("{}/{}".format(tb_dir, filename))
print("DELETED: {}!".format(filename))
break
elif user == "n":
print('"{}" will not be deleted'.format(filename))
break
print("Unknown answer...")
# Launch a tensorboard process
return SummaryWriter(tb_dir)
launch_tensorboard_process - Default behavior is to scan all free ports from 6006-6016 and try using them
unless port is defined by the user
:param checkpoints_dir_path:
:param sleep_postpone:
:param port:
:return: tuple of tb process, port
Source code in