Shortcuts

Source code for mmrotate.models.losses.h2rbox_consistency_loss

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmdet.utils import ConfigType
from torch import Tensor
from typing import Optional

from mmrotate.registry import MODELS


@MODELS.register_module()
class H2RBoxConsistencyLoss(torch.nn.Module):
    """Consistency loss combining a center, a shape and an angle term.

    Three sub-losses are built from config through ``MODELS.build`` and
    combined in :meth:`forward`. The shape and angle terms are evaluated
    under two alternative hypotheses (see :meth:`forward`) and the smaller
    of the two results is kept.

    Args:
        center_loss_cfg (ConfigType): Config of the loss applied to the
            first two channels of the boxes. Defaults to
            ``dict(type='mmdet.L1Loss', loss_weight=0.0)``.
        shape_loss_cfg (ConfigType): Config of the loss applied to the
            origin-centred boxes built from channels ``2:4``. Defaults to
            ``dict(type='mmdet.IoULoss', loss_weight=1.0)``.
        angle_loss_cfg (ConfigType): Config of the loss applied to the
            sine / cosine of the difference on channel ``4``. Defaults to
            ``dict(type='mmdet.L1Loss', loss_weight=1.0)``.
        reduction (str): Reduction used when ``reduction_override`` is not
            given to :meth:`forward`. Defaults to ``'mean'``.
        loss_weight (float): Scalar multiplied onto the final loss.
            Defaults to 1.0.
    """

    def __init__(self,
                 center_loss_cfg: ConfigType = dict(
                     type='mmdet.L1Loss', loss_weight=0.0),
                 shape_loss_cfg: ConfigType = dict(
                     type='mmdet.IoULoss', loss_weight=1.0),
                 angle_loss_cfg: ConfigType = dict(
                     type='mmdet.L1Loss', loss_weight=1.0),
                 reduction: str = 'mean',
                 loss_weight: float = 1.0) -> None:
        super().__init__()
        # Scalar settings first; they do not depend on the sub-losses.
        self.reduction = reduction
        self.loss_weight = loss_weight
        # Sub-losses are instantiated from their configs via the registry,
        # in the same order as the constructor arguments.
        self.center_loss = MODELS.build(center_loss_cfg)
        self.shape_loss = MODELS.build(shape_loss_cfg)
        self.angle_loss = MODELS.build(angle_loss_cfg)

    def forward(self,
                pred: Tensor,
                target: Tensor,
                weight: Tensor,
                avg_factor: Optional[int] = None,
                reduction_override: Optional[str] = None) -> Tensor:
        """Compute the consistency loss between ``pred`` and ``target``.

        The last dimension of both box tensors is read as: channels
        ``0:2`` -> center, channels ``2:4`` -> shape, channel ``4`` ->
        angle (passed to ``sin`` / ``cos``).

        Args:
            pred (Tensor): Predicted boxes.
            target (Tensor): Corresponding gt boxes.
            weight (Tensor): Per-prediction loss weight. It is indexed as
                ``weight[:, None]`` for the center term, so a 1-D tensor
                is presumably expected -- confirm against the caller.
            avg_factor (int, optional): Average factor forwarded to every
                sub-loss. Defaults to None.
            reduction_override (str, optional): One of ``'none'``,
                ``'mean'`` or ``'sum'``; replaces ``self.reduction`` when
                given. Defaults to None.

        Returns:
            Tensor: The calculated loss, scaled by ``self.loss_weight``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        # Keyword arguments that every sub-loss call below shares.
        reduce_kwargs = dict(
            reduction_override=reduction_override or self.reduction,
            avg_factor=avg_factor)

        # Channels 2:4 are expanded to an origin-centred box
        # (-a, -b, a, b) so a box loss can compare shapes independently
        # of the center position.
        shape_pred = pred[..., 2:4]
        shape_target = target[..., 2:4]
        box_pred = torch.cat([-shape_pred, shape_pred], dim=-1)
        box_target = torch.cat([-shape_target, shape_target], dim=-1)
        # Same box with its two shape channels exchanged: (-b, -a, b, a).
        box_pred_swapped = box_pred[..., [1, 0, 3, 2]]

        angle_diff = pred[..., 4] - target[..., 4]

        def shape_angle_term(box: Tensor, trig) -> Tensor:
            """Shape loss of ``box`` plus angle loss on ``trig(angle_diff)``.

            The angle residual is pushed towards zero.
            """
            shape_part = self.shape_loss(
                box, box_target, weight=weight, **reduce_kwargs)
            residual = trig(angle_diff)
            angle_part = self.angle_loss(
                residual,
                torch.zeros_like(angle_diff),
                weight=weight,
                **reduce_kwargs)
            return shape_part + angle_part

        center_term = self.center_loss(
            pred[..., :2],
            target[..., :2],
            weight=weight[:, None],
            **reduce_kwargs)

        # Hypothesis 1: shape channels match as-is and sin(angle_diff) is
        # driven to zero.
        aligned_term = shape_angle_term(box_pred, torch.sin)
        # Hypothesis 2: shape channels are exchanged and cos(angle_diff)
        # is driven to zero.
        swapped_term = shape_angle_term(box_pred_swapped, torch.cos)

        # Keep whichever hypothesis yields the smaller loss.
        total = center_term + torch.min(aligned_term, swapped_term)
        return self.loss_weight * total