Shortcuts

Source code for mmrotate.models.dense_heads.r3_head

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmdet.models.utils import select_single_mlvl
from mmdet.utils import InstanceList, OptInstanceList
from mmengine.config import ConfigDict
from torch import Tensor
from typing import List, Optional, Tuple, Union

from mmrotate.registry import MODELS
from mmrotate.structures.bbox import RotatedBoxes
from .rotated_retina_head import RotatedRetinaHead


@MODELS.register_module()
class R3Head(RotatedRetinaHead):
    r"""An anchor-based head used in `R3Det
    <https://arxiv.org/pdf/1908.05612.pdf>`_.

    On top of the parent head it provides :meth:`filter_bboxes`, which keeps
    a single decoded rotated box per feature-map location.
    """  # noqa: W605

    def filter_bboxes(self, cls_scores: List[Tensor],
                      bbox_preds: List[Tensor]) -> List[List[Tensor]]:
        """Keep the best-scoring box at every feature-map position.

        For each location only the anchor whose highest class score is the
        largest survives; its regression deltas are decoded against that
        anchor. This filter is used in R3Det prior to the first feature
        refinement stage.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level,
                each with shape (N, num_anchors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level, each with shape (N, num_anchors * 5, H, W).

        Returns:
            list[list[Tensor]]: Outer list indexed by image, inner list
            indexed by level; each entry is the detached decoded boxes of
            that image at that level.
        """
        num_levels = len(cls_scores)
        assert num_levels == len(bbox_preds)
        num_imgs = cls_scores[0].size(0)
        # Every level must carry the same batch size for scores and deltas.
        for lvl_scores, lvl_deltas in zip(cls_scores, bbox_preds):
            assert num_imgs == lvl_scores.size(0) == lvl_deltas.size(0)

        featmap_sizes = [lvl_scores.shape[-2:] for lvl_scores in cls_scores]
        mlvl_anchors = self.prior_generator.grid_priors(
            featmap_sizes, device=cls_scores[0].device)

        per_img_bboxes = [[] for _ in range(num_imgs)]

        for lvl, (lvl_scores, lvl_deltas) in enumerate(
                zip(cls_scores, bbox_preds)):
            # Channels last, then split the channel axis into
            # (num_anchors, cls_out_channels).
            scores = lvl_scores.permute(0, 2, 3, 1).reshape(
                num_imgs, -1, self.num_anchors, self.cls_out_channels)

            # Highest class score of each anchor; the trailing singleton
            # axis is kept so the index can be expanded below.
            top_scores = scores.max(dim=-1, keepdim=True)[0]
            # Index of the winning anchor per location, repeated over the
            # 5 box parameters so it can drive ``gather``.
            winner = top_scores.argmax(dim=-2, keepdim=True)
            winner = winner.expand(-1, -1, -1, 5)

            deltas = lvl_deltas.permute(0, 2, 3, 1).reshape(
                num_imgs, -1, self.num_anchors, 5)
            winner_deltas = deltas.gather(
                dim=-2, index=winner).squeeze(dim=-2)

            # Anchors are shared by all images of the batch.
            anchor_tensor = mlvl_anchors[lvl].reshape(
                -1, self.num_anchors, 5).tensor

            for img_id in range(num_imgs):
                winner_anchors = anchor_tensor.gather(
                    dim=-2, index=winner[img_id]).squeeze(dim=-2)
                decoded = self.bbox_coder.decode(
                    RotatedBoxes(winner_anchors), winner_deltas[img_id])
                per_img_bboxes[img_id].append(decoded.detach())

        return per_img_bboxes
@MODELS.register_module()
class R3RefineHead(RotatedRetinaHead):
    r"""An anchor-based head used in `R3Det
    <https://arxiv.org/pdf/1908.05612.pdf>`_.

    Instead of generating anchors from the prior generator, this head uses
    the boxes handed in as ``rois`` (output of a former stage) as its
    anchors / priors, and owns a feature refine module built from
    ``frm_cfg``.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        frm_cfg (dict): Config of the feature refine module. Although the
            signature defaults to ``None``, the value is passed straight to
            ``MODELS.build``, so a real config is expected.
    """  # noqa: W605

    def __init__(self,
                 num_classes: int,
                 in_channels: int,
                 frm_cfg: Optional[dict] = None,
                 **kwargs) -> None:
        super().__init__(
            num_classes=num_classes, in_channels=in_channels, **kwargs)
        self.feat_refine_module = MODELS.build(frm_cfg)
        # Set by ``loss_by_feat`` and read back by ``get_anchors``.
        self.bboxes_as_anchors = None

    def loss_by_feat(self,
                     cls_scores: List[Tensor],
                     bbox_preds: List[Tensor],
                     batch_gt_instances: InstanceList,
                     batch_img_metas: List[dict],
                     batch_gt_instances_ignore: OptInstanceList = None,
                     rois: Optional[List[List[Tensor]]] = None) -> dict:
        """Calculate the loss based on the features extracted by the
        detection head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level,
                each with shape (N, num_anchors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level, each with shape (N, num_anchors * 5, H, W).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes``
                attribute data that is ignored during training and testing.
                Defaults to None.
            rois (list[list[Tensor]]): Boxes of each level of each image,
                output by the former stage. Required despite the ``None``
                default; they are stored and later served as anchors by
                :meth:`get_anchors`.

        Returns:
            dict: A dictionary of loss components.
        """
        assert rois is not None
        # Stash the rois so that ``get_anchors`` can hand them out as the
        # anchors of this stage.
        self.bboxes_as_anchors = rois
        # Method resolution deliberately starts *after* RotatedRetinaHead,
        # i.e. any ``loss_by_feat`` defined on RotatedRetinaHead itself is
        # bypassed. NOTE(review): presumably intentional -- confirm.
        return super(RotatedRetinaHead, self).loss_by_feat(
            cls_scores=cls_scores,
            bbox_preds=bbox_preds,
            batch_gt_instances=batch_gt_instances,
            batch_img_metas=batch_img_metas,
            batch_gt_instances_ignore=batch_gt_instances_ignore)

    def get_anchors(self,
                    featmap_sizes: List[tuple],
                    batch_img_metas: List[dict],
                    device: Union[torch.device, str] = 'cuda') \
            -> Tuple[List[List[Tensor]], List[List[Tensor]]]:
        """Get anchors according to feature map sizes.

        The anchors are not generated here: they are the boxes previously
        stored in ``self.bboxes_as_anchors`` by :meth:`loss_by_feat`, each
        wrapped in ``RotatedBoxes`` and detached.

        Args:
            featmap_sizes (list[tuple]): Multi-level feature map sizes.
            batch_img_metas (list[dict]): Image meta info. Each dict must
                contain the key ``'pad_shape'``.
            device (torch.device | str): Device for returned tensors.
                Defaults to cuda.

        Returns:
            tuple:

            - anchor_list (list[list[Tensor]]): Anchors of each image.
            - valid_flag_list (list[list[Tensor]]): Valid flags of each
              image.
        """
        anchor_list = [[
            RotatedBoxes(bboxes_img_lvl).detach()
            for bboxes_img_lvl in bboxes_img
        ] for bboxes_img in self.bboxes_as_anchors]

        # For each image, compute the valid flags of multi-level anchors.
        valid_flag_list = [
            self.prior_generator.valid_flags(featmap_sizes,
                                             img_meta['pad_shape'], device)
            for img_meta in batch_img_metas
        ]

        return anchor_list, valid_flag_list

    def predict_by_feat(self,
                        cls_scores: List[Tensor],
                        bbox_preds: List[Tensor],
                        score_factors: Optional[List[Tensor]] = None,
                        rois: Optional[List[List[Tensor]]] = None,
                        batch_img_metas: Optional[List[dict]] = None,
                        cfg: Optional[ConfigDict] = None,
                        rescale: bool = False,
                        with_nms: bool = True) -> InstanceList:
        """Transform a batch of output features extracted from the head into
        bbox results.

        Note: When score_factors is not None, the cls_scores are
        usually multiplied by it then obtain the real score used in NMS,
        such as CenterNess in FCOS, IoU branch in ATSS.

        Args:
            cls_scores (list[Tensor]): Classification scores for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * 5, H, W).
            score_factors (list[Tensor], optional): Score factor for
                all scale level, each is a 4D-tensor, has shape
                (batch_size, num_priors * 1, H, W). Defaults to None.
            rois (list[list[Tensor]]): Boxes of each level of each image,
                output by the former stage. Required despite the ``None``
                default; ``rois[img_id]`` is used as the multi-level priors
                of image ``img_id``.
            batch_img_metas (list[dict], Optional): Batch image meta info.
                Must be provided in practice, since it is iterated to
                produce one result per image.
            cfg (ConfigDict, optional): Test / postprocessing
                configuration, if None, test_cfg would be used.
                Defaults to None.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.
            with_nms (bool): If True, do nms before return boxes.
                Defaults to True.

        Returns:
            list[:obj:`InstanceData`]: Object detection results of each image
            after the post process, as returned by
            ``_predict_by_feat_single``. Each item usually contains
            following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Predicted boxes. NOTE(review): presumably
              rotated boxes with 5 values each, as the deltas of this head
              have 5 channels -- confirm against
              ``_predict_by_feat_single``.
        """
        assert len(cls_scores) == len(bbox_preds)
        assert rois is not None

        # Without score factors: e.g. Retina, FreeAnchor, Foveabox, etc.
        # With score factors: e.g. FCOS, PAA, ATSS, AutoAssign, etc.
        with_score_factors = score_factors is not None
        if with_score_factors:
            assert len(cls_scores) == len(score_factors)

        num_levels = len(cls_scores)

        result_list = []
        for img_id, img_meta in enumerate(batch_img_metas):
            cls_score_list = select_single_mlvl(
                cls_scores, img_id, detach=True)
            bbox_pred_list = select_single_mlvl(
                bbox_preds, img_id, detach=True)
            if with_score_factors:
                score_factor_list = select_single_mlvl(
                    score_factors, img_id, detach=True)
            else:
                score_factor_list = [None for _ in range(num_levels)]

            # The rois of this image take the place of generated priors.
            results = self._predict_by_feat_single(
                cls_score_list=cls_score_list,
                bbox_pred_list=bbox_pred_list,
                score_factor_list=score_factor_list,
                mlvl_priors=rois[img_id],
                img_meta=img_meta,
                cfg=cfg,
                rescale=rescale,
                with_nms=with_nms)
            result_list.append(results)
        return result_list

    def feature_refine(self, x: List[Tensor],
                       rois: List[List[Tensor]]) -> List[Tensor]:
        """Refine the input features with the feature refine module.

        Args:
            x (list[Tensor]): Feature maps of multiple scales.
            rois (list[list[Tensor]]): Input rbboxes of multiple scales of
                multiple images, output by former stages and to be refined.

        Returns:
            list[Tensor]: Refined feature maps of multiple scales.
        """
        return self.feat_refine_module(x, rois)

    def refine_bboxes(self, cls_scores: List[Tensor],
                      bbox_preds: List[Tensor],
                      rois: List[List[Tensor]]) -> List[List[Tensor]]:
        """Refine predicted bounding boxes at each position of the feature
        maps. This method will be used in R3Det in refinement stages.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level,
                each with shape (N, num_classes, H, W). Only used here to
                validate the number of levels and the batch size.
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level, each with shape (N, 5, H, W).
            rois (list[list[Tensor]]): Input rbboxes of each level of each
                image, output by former stages and to be refined.

        Returns:
            list[list[Tensor]]: Refined rbboxes of each level of each image
            (outer list indexed by image, inner list by level), detached.
        """
        num_levels = len(cls_scores)
        assert num_levels == len(bbox_preds)
        num_imgs = cls_scores[0].size(0)
        for cls_score, bbox_pred in zip(cls_scores, bbox_preds):
            assert num_imgs == cls_score.size(0) == bbox_pred.size(0)

        bboxes_list = [[] for _ in range(num_imgs)]

        assert rois is not None
        # Regroup from per-image/per-level to per-level, concatenating the
        # images of a level in batch order so rows line up with the
        # flattened deltas below.
        mlvl_rois = [torch.cat(r) for r in zip(*rois)]

        for lvl, bbox_pred in enumerate(bbox_preds):
            # A dedicated name keeps the ``rois`` argument intact instead of
            # rebinding it on every iteration.
            lvl_rois = mlvl_rois[lvl]
            assert bbox_pred.size(1) == 5
            # Channels last, then one row of 5 deltas per location.
            bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 5)
            refined_bbox = self.bbox_coder.decode(lvl_rois, bbox_pred)
            refined_bbox = refined_bbox.reshape(num_imgs, -1, 5)
            for img_id in range(num_imgs):
                bboxes_list[img_id].append(refined_bbox[img_id].detach())
        return bboxes_list