Source code for mmrotate.models.dense_heads.r3_head
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmdet.models.utils import select_single_mlvl
from mmdet.utils import InstanceList, OptInstanceList
from mmengine.config import ConfigDict
from torch import Tensor
from typing import List, Optional, Tuple, Union
from mmrotate.registry import MODELS
from mmrotate.structures.bbox import RotatedBoxes
from .rotated_retina_head import RotatedRetinaHead
[docs]
@MODELS.register_module()
class R3Head(RotatedRetinaHead):
r"""An anchor-based head used in `R3Det
<https://arxiv.org/pdf/1908.05612.pdf>`_.
""" # noqa: W605
[docs]
def filter_bboxes(self, cls_scores: List[Tensor],
bbox_preds: List[Tensor]) -> List[List[Tensor]]:
"""Filter predicted bounding boxes at each position of the feature
maps. Only one bounding boxes with highest score will be left at each
position. This filter will be used in R3Det prior to the first feature
refinement stage.
Args:
cls_scores (list[Tensor]): Box scores for each scale level
Has shape (N, num_anchors * num_classes, H, W)
bbox_preds (list[Tensor]): Box energies / deltas for each scale
level with shape (N, num_anchors * 5, H, W)
Returns:
list[list[Tensor]]: best or refined rbboxes of each level
of each image.
"""
num_levels = len(cls_scores)
assert num_levels == len(bbox_preds)
num_imgs = cls_scores[0].size(0)
for i in range(num_levels):
assert num_imgs == cls_scores[i].size(0) == bbox_preds[i].size(0)
device = cls_scores[0].device
featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
mlvl_anchors = self.prior_generator.grid_priors(
featmap_sizes, device=device)
bboxes_list = [[] for _ in range(num_imgs)]
for lvl in range(num_levels):
cls_score = cls_scores[lvl]
bbox_pred = bbox_preds[lvl]
anchors = mlvl_anchors[lvl]
cls_score = cls_score.permute(0, 2, 3, 1)
cls_score = cls_score.reshape(num_imgs, -1, self.num_anchors,
self.cls_out_channels)
cls_score, _ = cls_score.max(dim=-1, keepdim=True)
best_ind = cls_score.argmax(dim=-2, keepdim=True)
best_ind = best_ind.expand(-1, -1, -1, 5)
bbox_pred = bbox_pred.permute(0, 2, 3, 1)
bbox_pred = bbox_pred.reshape(num_imgs, -1, self.num_anchors, 5)
best_pred = bbox_pred.gather(
dim=-2, index=best_ind).squeeze(dim=-2)
anchors = anchors.reshape(-1, self.num_anchors, 5).tensor
for img_id in range(num_imgs):
best_ind_i = best_ind[img_id]
best_pred_i = best_pred[img_id]
best_anchor_i = anchors.gather(
dim=-2, index=best_ind_i).squeeze(dim=-2)
best_bbox_i = self.bbox_coder.decode(
RotatedBoxes(best_anchor_i), best_pred_i)
bboxes_list[img_id].append(best_bbox_i.detach())
return bboxes_list
[docs]
@MODELS.register_module()
class R3RefineHead(RotatedRetinaHead):
r"""An anchor-based head used in `R3Det
<https://arxiv.org/pdf/1908.05612.pdf>`_.
Args:
num_classes (int): Number of categories excluding the background
category.
in_channels (int): Number of channels in the input feature map.
frm_cfg (dict): Config of the feature refine module.
""" # noqa: W605
def __init__(self,
num_classes: int,
in_channels: int,
frm_cfg: dict = None,
**kwargs) -> None:
super().__init__(
num_classes=num_classes, in_channels=in_channels, **kwargs)
self.feat_refine_module = MODELS.build(frm_cfg)
self.bboxes_as_anchors = None
[docs]
def loss_by_feat(self,
cls_scores: List[Tensor],
bbox_preds: List[Tensor],
batch_gt_instances: InstanceList,
batch_img_metas: List[dict],
batch_gt_instances_ignore: OptInstanceList = None,
rois: List[Tensor] = None) -> dict:
"""Calculate the loss based on the features extracted by the detection
head.
Args:
cls_scores (list[Tensor]): Box scores for each scale level
has shape (N, num_anchors * num_classes, H, W).
bbox_preds (list[Tensor]): Box energies / deltas for each scale
level with shape (N, num_anchors * 4, H, W).
batch_gt_instances (list[:obj:`InstanceData`]): Batch of
gt_instance. It usually includes ``bboxes`` and ``labels``
attributes.
batch_img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
Batch of gt_instances_ignore. It includes ``bboxes`` attribute
data that is ignored during training and testing.
Defaults to None.
rois (list[Tensor])
Returns:
dict: A dictionary of loss components.
"""
assert rois is not None
self.bboxes_as_anchors = rois
return super(RotatedRetinaHead, self).loss_by_feat(
cls_scores=cls_scores,
bbox_preds=bbox_preds,
batch_gt_instances=batch_gt_instances,
batch_img_metas=batch_img_metas,
batch_gt_instances_ignore=batch_gt_instances_ignore)
[docs]
def get_anchors(self,
featmap_sizes: List[tuple],
batch_img_metas: List[dict],
device: Union[torch.device, str] = 'cuda') \
-> Tuple[List[List[Tensor]], List[List[Tensor]]]:
"""Get anchors according to feature map sizes.
Args:
featmap_sizes (list[tuple]): Multi-level feature map sizes.
batch_img_metas (list[dict]): Image meta info.
device (torch.device | str): Device for returned tensors.
Defaults to cuda.
Returns:
tuple:
- anchor_list (list[list[Tensor]]): Anchors of each image.
- valid_flag_list (list[list[Tensor]]): Valid flags of each
image.
"""
anchor_list = [[
RotatedBoxes(bboxes_img_lvl).detach()
for bboxes_img_lvl in bboxes_img
] for bboxes_img in self.bboxes_as_anchors]
# for each image, we compute valid flags of multi level anchors
valid_flag_list = []
for img_id, img_meta in enumerate(batch_img_metas):
multi_level_flags = self.prior_generator.valid_flags(
featmap_sizes, img_meta['pad_shape'], device)
valid_flag_list.append(multi_level_flags)
return anchor_list, valid_flag_list
[docs]
def predict_by_feat(self,
cls_scores: List[Tensor],
bbox_preds: List[Tensor],
score_factors: Optional[List[Tensor]] = None,
rois: List[Tensor] = None,
batch_img_metas: Optional[List[dict]] = None,
cfg: Optional[ConfigDict] = None,
rescale: bool = False,
with_nms: bool = True) -> InstanceList:
"""Transform a batch of output features extracted from the head into
bbox results.
Note: When score_factors is not None, the cls_scores are
usually multiplied by it then obtain the real score used in NMS,
such as CenterNess in FCOS, IoU branch in ATSS.
Args:
cls_scores (list[Tensor]): Classification scores for all
scale levels, each is a 4D-tensor, has shape
(batch_size, num_priors * num_classes, H, W).
bbox_preds (list[Tensor]): Box energies / deltas for all
scale levels, each is a 4D-tensor, has shape
(batch_size, num_priors * 4, H, W).
score_factors (list[Tensor], optional): Score factor for
all scale level, each is a 4D-tensor, has shape
(batch_size, num_priors * 1, H, W). Defaults to None.
rois (list[Tensor]):
batch_img_metas (list[dict], Optional): Batch image meta info.
Defaults to None.
cfg (ConfigDict, optional): Test / postprocessing
configuration, if None, test_cfg would be used.
Defaults to None.
rescale (bool): If True, return boxes in original image space.
Defaults to False.
with_nms (bool): If True, do nms before return boxes.
Defaults to True.
Returns:
list[:obj:`InstanceData`]: Object detection results of each image
after the post process. Each item usually contains following keys.
- scores (Tensor): Classification scores, has a shape
(num_instance, )
- labels (Tensor): Labels of bboxes, has a shape
(num_instances, ).
- bboxes (Tensor): Has a shape (num_instances, 4),
the last dimension 4 arrange as (x1, y1, x2, y2).
"""
assert len(cls_scores) == len(bbox_preds)
assert rois is not None
if score_factors is None:
# e.g. Retina, FreeAnchor, Foveabox, etc.
with_score_factors = False
else:
# e.g. FCOS, PAA, ATSS, AutoAssign, etc.
with_score_factors = True
assert len(cls_scores) == len(score_factors)
num_levels = len(cls_scores)
result_list = []
for img_id in range(len(batch_img_metas)):
img_meta = batch_img_metas[img_id]
cls_score_list = select_single_mlvl(
cls_scores, img_id, detach=True)
bbox_pred_list = select_single_mlvl(
bbox_preds, img_id, detach=True)
if with_score_factors:
score_factor_list = select_single_mlvl(
score_factors, img_id, detach=True)
else:
score_factor_list = [None for _ in range(num_levels)]
results = self._predict_by_feat_single(
cls_score_list=cls_score_list,
bbox_pred_list=bbox_pred_list,
score_factor_list=score_factor_list,
mlvl_priors=rois[img_id],
img_meta=img_meta,
cfg=cfg,
rescale=rescale,
with_nms=with_nms)
result_list.append(results)
return result_list
[docs]
def feature_refine(self, x: List[Tensor],
rois: List[List[Tensor]]) -> List[Tensor]:
"""Refine the input feature use feature refine module.
Args:
x (list[Tensor]): feature maps of multiple scales.
rois (list[list[Tensor]]): input rbboxes of multiple
scales of multiple images, output by former stages
and are to be refined.
Returns:
list[Tensor]: refined feature maps of multiple scales.
"""
return self.feat_refine_module(x, rois)
[docs]
def refine_bboxes(self, cls_scores: List[Tensor], bbox_preds: List[Tensor],
rois: List[List[Tensor]]) -> List[List[Tensor]]:
"""Refine predicted bounding boxes at each position of the feature
maps. This method will be used in R3Det in refinement stages.
Args:
cls_scores (list[Tensor]): Box scores for each scale level
Has shape (N, num_classes, H, W)
bbox_preds (list[Tensor]): Box energies / deltas for each scale
level with shape (N, 5, H, W)
rois (list[list[Tensor]]): input rbboxes of each level of each
image. rois output by former stages and are to be refined
Returns:
list[list[Tensor]]: best or refined rbboxes of each level of each
image.
"""
num_levels = len(cls_scores)
assert num_levels == len(bbox_preds)
num_imgs = cls_scores[0].size(0)
for i in range(num_levels):
assert num_imgs == cls_scores[i].size(0) == bbox_preds[i].size(0)
bboxes_list = [[] for _ in range(num_imgs)]
assert rois is not None
mlvl_rois = [torch.cat(r) for r in zip(*rois)]
for lvl in range(num_levels):
bbox_pred = bbox_preds[lvl]
rois = mlvl_rois[lvl]
assert bbox_pred.size(1) == 5
bbox_pred = bbox_pred.permute(0, 2, 3, 1)
bbox_pred = bbox_pred.reshape(-1, 5)
refined_bbox = self.bbox_coder.decode(rois, bbox_pred)
refined_bbox = refined_bbox.reshape(num_imgs, -1, 5)
for img_id in range(num_imgs):
bboxes_list[img_id].append(refined_bbox[img_id].detach())
return bboxes_list