Shortcuts

Source code for mmrotate.models.dense_heads.oriented_rpn_head

# Copyright (c) OpenMMLab. All rights reserved.
import torch

try:
    from mmcv.ops import batched_nms
except ImportError as _mmcv_ops_import_error:
    # The ``as`` name is unbound once this ``except`` clause ends, so keep
    # our own reference: the original import failure is the root cause
    # users need to see when NMS is eventually requested.
    _BATCHED_NMS_IMPORT_ERROR = _mmcv_ops_import_error

    def batched_nms(*args, **kwargs):
        """Placeholder used when ``mmcv.ops.batched_nms`` cannot be imported.

        Importing this module must not require mmcv's ops, so the failure is
        deferred until NMS is actually called.

        Raises:
            RuntimeError: Always; chained to the original ``ImportError``.
        """
        raise RuntimeError('batched_nms from mmcv.ops is not available. '
                           'Please install onedl-mmcv with ops support.'
                           ) from _BATCHED_NMS_IMPORT_ERROR


from mmdet.models.dense_heads import RPNHead
from mmdet.structures.bbox import (BaseBoxes, get_box_tensor, get_box_wh,
                                   scale_boxes)
from mmengine.config import ConfigDict
from mmengine.structures import InstanceData
from typing import Optional

from mmrotate.registry import MODELS
from mmrotate.structures.bbox import rbox2hbox


@MODELS.register_module()
class OrientedRPNHead(RPNHead):
    """Oriented RPN head for Oriented R-CNN.

    Proposals are kept as rotated boxes; only the NMS step in
    :meth:`_bbox_post_process` works on horizontal boxes derived from them.
    """

    def _bbox_post_process(self,
                           results: InstanceData,
                           cfg: ConfigDict,
                           rescale: bool = False,
                           with_nms: bool = True,
                           img_meta: Optional[dict] = None) -> InstanceData:
        """Bbox post-processing method, which uses horizontal bboxes for NMS
        but returns the rotated bboxes.

        Args:
            results (:obj:`InstanceData`): Detection instance results, each
                item has shape (num_bboxes, ). Must provide ``bboxes``,
                ``scores`` and ``level_ids``.
            cfg (ConfigDict): Test / postprocessing configuration. Reads
                ``nms``, ``max_per_img`` and, optionally,
                ``min_bbox_size`` (filtering is skipped when it is missing
                or negative).
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.
            with_nms (bool): If True, do nms before return boxes.
                Must be True for this head. Defaults to True.
            img_meta (dict, optional): Image meta info. Must contain
                ``scale_factor`` when ``rescale`` is True. Defaults to None.

        Returns:
            :obj:`InstanceData`: Detection results of each image after the
            post process. Each item usually contains following keys.

                - scores (Tensor): Classification scores, has a shape
                  (num_instance, ).
                - labels (Tensor): Labels of bboxes, has a shape
                  (num_instances, ). Always 0 with dtype ``torch.long``,
                  since RPN proposals are class-agnostic.
                - bboxes (Tensor or :obj:`BaseBoxes`): Rotated proposals of
                  the same type as the input ``results.bboxes``. A plain
                  tensor has shape (num_instances, 5).

        Raises:
            AssertionError: If ``with_nms`` is False, or if ``rescale`` is
                True and ``img_meta`` does not provide ``scale_factor``.
        """
        assert with_nms, '`with_nms` must be True in RPNHead'
        if rescale:
            # Checking ``img_meta`` first turns an opaque AttributeError on
            # ``None`` into an explicit assertion.
            assert img_meta is not None and \
                img_meta.get('scale_factor') is not None, \
                '`img_meta` with a `scale_factor` is required when rescale=True'
            # Dividing by the scale factor maps boxes back to the original
            # image space.
            scale_factor = [1 / s for s in img_meta['scale_factor']]
            results.bboxes = scale_boxes(results.bboxes, scale_factor)

        # Filter out boxes whose width or height is not above min_bbox_size.
        min_bbox_size = cfg.get('min_bbox_size', -1)
        if min_bbox_size >= 0:
            if isinstance(results.bboxes, BaseBoxes):
                w, h = get_box_wh(results.bboxes)
            else:
                # ``get_box_wh`` treats plain tensors as (x1, y1, x2, y2),
                # but plain tensors here hold rotated boxes in mmrotate's
                # (cx, cy, w, h, t) layout -- the same tensor is passed to
                # ``rbox2hbox`` below -- so read w/h from columns 2 and 3.
                w, h = results.bboxes[:, 2], results.bboxes[:, 3]
            valid_mask = (w > min_bbox_size) & (h > min_bbox_size)
            if not valid_mask.all():
                results = results[valid_mask]

        if results.bboxes.numel() > 0:
            # NMS runs on horizontal boxes derived from the rotated ones;
            # ``keep_idxs`` then selects the surviving rotated boxes.
            # ``level_ids`` is passed as the grouping index.
            hbboxes = rbox2hbox(get_box_tensor(results.bboxes))
            det_bboxes, keep_idxs = batched_nms(hbboxes, results.scores,
                                                results.level_ids, cfg.nms)
            results = results[keep_idxs]
            # Some NMS variants (e.g. soft-NMS) re-weight scores; take the
            # updated score from the last column of ``det_bboxes``.
            results.scores = det_bboxes[:, -1]
            results = results[:cfg.max_per_img]
            # TODO: This would unreasonably show the 0th class label
            # in visualization
            results.labels = results.scores.new_zeros(
                len(results), dtype=torch.long)
            del results.level_ids
        else:
            # No boxes remain: build well-formed empty results so that
            # downstream code sees the same fields as in the non-empty case.
            results_ = InstanceData()
            if isinstance(results.bboxes, BaseBoxes):
                results_.bboxes = results.bboxes.empty_boxes()
            else:
                # Rotated boxes carry 5 values; a (0, 4) tensor would be a
                # horizontal-box shape, inconsistent with the rotated output.
                results_.bboxes = results.scores.new_zeros(0, 5)
            results_.scores = results.scores.new_zeros(0)
            # Same dtype as the labels produced in the non-empty branch.
            results_.labels = results.scores.new_zeros(0, dtype=torch.long)
            results = results_
        return results