Source code for mmrotate.models.dense_heads.rotated_retina_head
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.models.dense_heads import RetinaHead
from mmdet.structures.bbox import get_box_tensor
from torch import Tensor
from mmrotate.registry import MODELS
[docs]
@MODELS.register_module()
class RotatedRetinaHead(RetinaHead):
"""Rotated retina head.
Args:
loss_bbox_type (str): Set the input type of ``loss_bbox``.
Defaults to 'normal'.
"""
def __init__(self,
*args,
loss_bbox_type: str = 'normal',
**kwargs) -> None:
super().__init__(*args, **kwargs)
self.loss_bbox_type = loss_bbox_type
[docs]
def loss_by_feat_single(self, cls_score: Tensor, bbox_pred: Tensor,
anchors: Tensor, labels: Tensor,
label_weights: Tensor, bbox_targets: Tensor,
bbox_weights: Tensor, avg_factor: int) -> tuple:
"""Calculate the loss of a single scale level based on the features
extracted by the detection head.
Args:
cls_score (Tensor): Box scores for each scale level
Has shape (N, num_anchors * num_classes, H, W).
bbox_pred (Tensor): Box energies / deltas for each scale
level with shape (N, num_anchors * 4, H, W).
anchors (Tensor): Box reference for each scale level with shape
(N, num_total_anchors, 4).
labels (Tensor): Labels of each anchors with shape
(N, num_total_anchors).
label_weights (Tensor): Label weights of each anchor with shape
(N, num_total_anchors)
bbox_targets (Tensor): BBox regression targets of each anchor
weight shape (N, num_total_anchors, 4).
bbox_weights (Tensor): BBox regression loss weights of each anchor
with shape (N, num_total_anchors, 4).
avg_factor (int): Average factor that is used to average the loss.
Returns:
tuple: loss components.
"""
# classification loss
labels = labels.reshape(-1)
label_weights = label_weights.reshape(-1)
cls_score = cls_score.permute(0, 2, 3,
1).reshape(-1, self.cls_out_channels)
loss_cls = self.loss_cls(
cls_score, labels, label_weights, avg_factor=avg_factor)
# regression loss
target_dim = bbox_targets.size(-1)
bbox_targets = bbox_targets.reshape(-1, target_dim)
bbox_weights = bbox_weights.reshape(-1, target_dim)
bbox_pred = bbox_pred.permute(0, 2, 3,
1).reshape(-1,
self.bbox_coder.encode_size)
if self.reg_decoded_bbox and (self.loss_bbox_type != 'kfiou'):
# When the regression loss (e.g. `IouLoss`, `GIouLoss`)
# is applied directly on the decoded bounding boxes, it
# decodes the already encoded coordinates to absolute format.
anchors = anchors.reshape(-1, anchors.size(-1))
bbox_pred = self.bbox_coder.decode(anchors, bbox_pred)
bbox_pred = get_box_tensor(bbox_pred)
if self.loss_bbox_type == 'normal':
loss_bbox = self.loss_bbox(
bbox_pred, bbox_targets, bbox_weights, avg_factor=avg_factor)
elif self.loss_bbox_type == 'kfiou':
# When the regression loss (e.g. `KFLoss`)
# is applied on both the delta and decoded boxes.
anchors = anchors.reshape(-1, anchors.size(-1))
bbox_pred_decode = self.bbox_coder.decode(anchors, bbox_pred)
bbox_pred_decode = get_box_tensor(bbox_pred_decode)
bbox_targets_decode = self.bbox_coder.decode(anchors, bbox_targets)
bbox_targets_decode = get_box_tensor(bbox_targets_decode)
loss_bbox = self.loss_bbox(
bbox_pred,
bbox_targets,
bbox_weights,
pred_decode=bbox_pred_decode,
targets_decode=bbox_targets_decode,
avg_factor=avg_factor)
else:
raise NotImplementedError
return loss_cls, loss_bbox