Shortcuts

Source code for mmrotate.models.losses.h2rbox_v2_consistency_loss

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmdet.utils import ConfigType
from torch import Tensor
from typing import Optional

from mmrotate.registry import MODELS


[docs] @MODELS.register_module() class H2RBoxV2ConsistencyLoss(torch.nn.Module): def __init__(self, loss_rot: ConfigType = dict( type='mmdet.SmoothL1Loss', loss_weight=1.0, beta=0.1), loss_flp: ConfigType = dict( type='mmdet.SmoothL1Loss', loss_weight=0.05, beta=0.1), use_snap_loss: bool = True, reduction: str = 'mean') -> None: super(H2RBoxV2ConsistencyLoss, self).__init__() self.loss_rot = MODELS.build(loss_rot) self.loss_flp = MODELS.build(loss_flp) self.use_snap_loss = use_snap_loss self.reduction = reduction
[docs] def forward(self, pred_ori: Tensor, pred_rot: Tensor, pred_flp: Tensor, target_ori: Tensor, target_rot: Tensor, agnostic_mask: Optional[Tensor] = None, avg_factor: Optional[int] = None, reduction_override: Optional[str] = None) -> Tensor: """Forward function. Args: pred (Tensor): Predicted boxes. target (Tensor): Corresponding gt boxes. weight (Tensor): The weight of loss for each prediction. avg_factor (int, optional): Average factor that is used to average the loss. Defaults to None. reduction_override (str, optional): The reduction method used to override the original reduction method of the loss. Defaults to None. Returns: Calculated loss (Tensor) """ assert reduction_override in (None, 'none', 'mean', 'sum') reduction = ( reduction_override if reduction_override else self.reduction) d_ang_rot = (pred_ori - pred_rot) - (target_ori - target_rot) d_ang_flp = pred_ori + pred_flp if self.use_snap_loss: d_ang_rot = (d_ang_rot + torch.pi / 2) % torch.pi - torch.pi / 2 d_ang_flp = (d_ang_flp + torch.pi / 2) % torch.pi - torch.pi / 2 if agnostic_mask is not None: d_ang_rot[agnostic_mask] = 0 d_ang_flp[agnostic_mask] = 0 loss_rot = self.loss_rot( d_ang_rot, torch.zeros_like(d_ang_rot), reduction_override=reduction, avg_factor=avg_factor) loss_flp = self.loss_flp( d_ang_flp, torch.zeros_like(d_ang_flp), reduction_override=reduction, avg_factor=avg_factor) return loss_rot + loss_flp