Shortcuts

Source code for mmrotate.models.roi_heads.bbox_heads.convfc_rbbox_head

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmdet.models.losses import accuracy
from mmdet.models.roi_heads.bbox_heads import Shared2FCBBoxHead
from mmdet.structures.bbox import get_box_tensor
from torch import Tensor
from typing import Optional

from mmrotate.registry import MODELS


@MODELS.register_module()
class RotatedShared2FCBBoxHead(Shared2FCBBoxHead):
    """Shared2FC bbox head whose regression loss call is configurable.

    Args:
        loss_bbox_type (str): Selects how ``loss_bbox`` is invoked in
            :meth:`loss`. With ``'normal'`` it receives the positive
            predictions, targets and weights. With ``'kfiou'`` it
            additionally receives the decoded positive predictions and
            targets through the ``pred_decode`` and ``targets_decode``
            keyword arguments. Any other value makes :meth:`loss` raise
            ``NotImplementedError`` once positive samples are present.
            Defaults to 'normal'.
    """

    def __init__(self,
                 *args,
                 loss_bbox_type: str = 'normal',
                 **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.loss_bbox_type = loss_bbox_type

    def loss(self,
             cls_score: Tensor,
             bbox_pred: Tensor,
             rois: Tensor,
             labels: Tensor,
             label_weights: Tensor,
             bbox_targets: Tensor,
             bbox_weights: Tensor,
             reduction_override: Optional[str] = None) -> dict:
        """Compute classification and box regression losses.

        Args:
            cls_score (Tensor): Classification predictions for every
                proposal in the batch. The classification part is skipped
                when this is None or has no elements.
            bbox_pred (Tensor): Regression predictions for every proposal
                in the batch; the first dimension indexes proposals. The
                regression part is skipped when this is None.
            rois (Tensor): RoIs of all proposals. The first column is
                dropped and the remaining columns are handed to
                ``self.bbox_coder.decode`` whenever decoding is needed.
            labels (Tensor): Class label of every proposal. Values in
                ``[0, self.num_classes)`` are treated as foreground and
                ``self.num_classes`` as background.
            label_weights (Tensor): Per-proposal classification weights.
                The number of strictly positive entries (at least 1) is
                used as the averaging factor of the classification loss.
            bbox_targets (Tensor): Regression targets for every proposal.
                Its first dimension is used as the averaging factor of
                the regression loss.
            bbox_weights (Tensor): Regression weights for every proposal.
            reduction_override (str, optional): Forwarded unchanged to
                ``self.loss_cls`` and ``self.loss_bbox`` to override their
                reduction method. Defaults to None.

        Returns:
            dict: Collected losses. May contain ``'loss_cls'`` (or the
            entries of the dict returned by ``self.loss_cls``), ``'acc'``
            (or the entries returned by ``self.loss_cls.get_accuracy``)
            and ``'loss_bbox'``.

        Raises:
            NotImplementedError: If positive samples exist and
                ``self.loss_bbox_type`` is neither ``'normal'`` nor
                ``'kfiou'``.
        """
        losses = {}

        if cls_score is not None:
            num_weighted = torch.sum(label_weights > 0).float().item()
            avg_factor = max(num_weighted, 1.)
            if cls_score.numel() > 0:
                cls_loss = self.loss_cls(
                    cls_score,
                    labels,
                    label_weights,
                    avg_factor=avg_factor,
                    reduction_override=reduction_override)
                # The classification loss may report one value or several
                # named ones.
                if isinstance(cls_loss, dict):
                    losses.update(cls_loss)
                else:
                    losses['loss_cls'] = cls_loss

                if self.custom_activation:
                    losses.update(
                        self.loss_cls.get_accuracy(cls_score, labels))
                else:
                    losses['acc'] = accuracy(cls_score, labels)

        if bbox_pred is not None:
            # 0 ~ self.num_classes - 1 are FG, self.num_classes is BG.
            bg_class_ind = self.num_classes
            # Background proposals take no part in box regression. The
            # comparison result is already boolean; the single cast keeps
            # the mask dtype explicit.
            pos_mask = ((labels >= 0) &
                        (labels < bg_class_ind)).type(torch.bool)

            if not pos_mask.any():
                # No foreground proposal: summing the empty selection
                # gives a zero-valued loss that still derives from
                # ``bbox_pred``.
                losses['loss_bbox'] = bbox_pred[pos_mask].sum()
            else:
                if self.reg_decoded_bbox and (self.loss_bbox_type !=
                                              'kfiou'):
                    # When the regression loss (e.g. `IouLoss`,
                    # `GIouLoss`, `DIouLoss`) works on decoded boxes, the
                    # encoded predictions are decoded to absolute format
                    # first.
                    bbox_pred = get_box_tensor(
                        self.bbox_coder.decode(rois[:, 1:], bbox_pred))

                num_rois = bbox_pred.size(0)
                if self.reg_class_agnostic:
                    pos_labels = None
                    pos_bbox_pred = bbox_pred.view(num_rois, -1)[pos_mask]
                else:
                    # Pick, for every positive proposal, the prediction
                    # belonging to its own class.
                    pos_labels = labels[pos_mask]
                    pos_bbox_pred = bbox_pred.view(
                        num_rois, self.num_classes, -1)[pos_mask,
                                                        pos_labels]

                if self.loss_bbox_type == 'normal':
                    extra_kwargs = {}
                elif self.loss_bbox_type == 'kfiou':
                    # The regression loss (e.g. `KFLoss`) is applied on
                    # both the deltas and the decoded boxes.
                    proposals = rois[:, 1:]
                    decoded_pred = get_box_tensor(
                        self.bbox_coder.decode(proposals, bbox_pred))
                    decoded_targets = get_box_tensor(
                        self.bbox_coder.decode(proposals, bbox_targets))

                    if self.reg_class_agnostic:
                        pos_decoded_pred = decoded_pred.view(
                            decoded_pred.size(0), 5)[pos_mask]
                    else:
                        pos_decoded_pred = decoded_pred.view(
                            decoded_pred.size(0), -1, 5)[pos_mask,
                                                         pos_labels]

                    extra_kwargs = dict(
                        pred_decode=pos_decoded_pred,
                        targets_decode=decoded_targets[pos_mask])
                else:
                    raise NotImplementedError

                losses['loss_bbox'] = self.loss_bbox(
                    pos_bbox_pred,
                    bbox_targets[pos_mask],
                    bbox_weights[pos_mask],
                    **extra_kwargs,
                    avg_factor=bbox_targets.size(0),
                    reduction_override=reduction_override)

        return losses