# Shortcuts

# Source code for mmrotate.models.roi_heads.gv_ratio_roi_head

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmdet.models.roi_heads.standard_roi_head import StandardRoIHead
from mmdet.models.task_modules.samplers import SamplingResult
from mmdet.models.utils import empty_instances
from mmdet.structures.bbox import bbox2roi
from mmdet.utils import ConfigType, InstanceList
from torch import Tensor
from typing import List, Tuple

from mmrotate.registry import MODELS


@MODELS.register_module()
class GVRatioRoIHead(StandardRoIHead):
    """Gliding vertex RoI head including one bbox head and one mask head.

    Compared with :class:`StandardRoIHead`, the bbox head used here returns
    two extra outputs, ``fix_pred`` and ``ratio_pred``, in addition to
    ``cls_score`` and ``bbox_pred``. These outputs are threaded through the
    forward, loss and predict paths below. The mask-head path relies on
    attributes inherited from ``StandardRoIHead`` (``with_mask``,
    ``_mask_forward``).
    """

    # TODO: Need to refactor later

    def forward(self, x: Tuple[Tensor],
                rpn_results_list: InstanceList) -> tuple:
        """Network forward process without any post-processing.

        Args:
            x (tuple[Tensor]): Multi-level features that may have different
                resolutions.
            rpn_results_list (list[:obj:`InstanceData`]): List of region
                proposals.

        Returns:
            tuple: Features from ``bbox_head`` (``cls_score``, ``bbox_pred``,
            ``fix_pred``, ``ratio_pred``) followed by ``mask_preds`` from
            ``mask_head`` when the corresponding heads are enabled.
        """
        results = ()
        proposals = [rpn_results.bboxes for rpn_results in rpn_results_list]
        rois = bbox2roi(proposals)
        # bbox head
        if self.with_bbox:
            bbox_results = self._bbox_forward(x, rois)
            results = results + (bbox_results['cls_score'],
                                 bbox_results['bbox_pred'],
                                 bbox_results['fix_pred'],
                                 bbox_results['ratio_pred'])
        # mask head
        if self.with_mask:
            # Only the first 100 RoIs are fed to the mask head in this
            # plain forward pass (same limit as the base implementation).
            mask_rois = rois[:100]
            mask_results = self._mask_forward(x, mask_rois)
            results = results + (mask_results['mask_preds'], )
        return results

    def _bbox_forward(self, x: Tuple[Tensor], rois: Tensor) -> dict:
        """Box head forward function used in both training and testing.

        Args:
            x (tuple[Tensor]): List of multi-level img features.
            rois (Tensor): RoIs with the shape (n, 5) where the first
                column indicates batch id of each RoI.

        Returns:
            dict[str, Tensor]: A dictionary with keys:

                - `cls_score` (Tensor): Classification scores.
                - `bbox_pred` (Tensor): Box energies / deltas.
                - `fix_pred` (Tensor): fix / deltas.
                - `ratio_pred` (Tensor): ratio / deltas.
                - `bbox_feats` (Tensor): Extract bbox RoI features.
        """
        # TODO: a more flexible way to decide which feature maps to use
        bbox_feats = self.bbox_roi_extractor(
            x[:self.bbox_roi_extractor.num_inputs], rois)
        if self.with_shared_head:
            bbox_feats = self.shared_head(bbox_feats)
        # The gliding-vertex bbox head returns four outputs instead of two.
        cls_score, bbox_pred, fix_pred, ratio_pred = self.bbox_head(bbox_feats)

        bbox_results = dict(
            cls_score=cls_score,
            bbox_pred=bbox_pred,
            fix_pred=fix_pred,
            ratio_pred=ratio_pred,
            bbox_feats=bbox_feats)
        return bbox_results

    def bbox_loss(self, x: Tuple[Tensor],
                  sampling_results: List[SamplingResult]) -> dict:
        """Run the bbox head forward pass and compute its loss.

        Args:
            x (tuple[Tensor]): List of multi-level img features.
            sampling_results (list[:obj:`SamplingResult`]): Sampling results.

        Returns:
            dict[str, Tensor]: The outputs of :meth:`_bbox_forward`
            (`cls_score`, `bbox_pred`, `fix_pred`, `ratio_pred`,
            `bbox_feats`) plus:

                - `loss_bbox` (dict): A dictionary of bbox loss components.
        """
        rois = bbox2roi([res.priors for res in sampling_results])
        bbox_results = self._bbox_forward(x, rois)

        bbox_loss_and_target = self.bbox_head.loss_and_target(
            cls_score=bbox_results['cls_score'],
            bbox_pred=bbox_results['bbox_pred'],
            fix_pred=bbox_results['fix_pred'],
            ratio_pred=bbox_results['ratio_pred'],
            rois=rois,
            sampling_results=sampling_results,
            rcnn_train_cfg=self.train_cfg)

        bbox_results.update(loss_bbox=bbox_loss_and_target['loss_bbox'])
        return bbox_results

    def predict_bbox(self,
                     x: Tuple[Tensor],
                     batch_img_metas: List[dict],
                     rpn_results_list: InstanceList,
                     rcnn_test_cfg: ConfigType,
                     rescale: bool = False) -> InstanceList:
        """Run the bbox head and predict detection results.

        Args:
            x (tuple[Tensor]): Feature maps of all scale level.
            batch_img_metas (list[dict]): List of image information.
            rpn_results_list (list[:obj:`InstanceData`]): List of region
                proposals.
            rcnn_test_cfg (:obj:`ConfigDict`): `test_cfg` of R-CNN.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.

        Returns:
            list[:obj:`InstanceData`]: Detection results of each image after
            the post process, as produced by
            ``self.bbox_head.predict_by_feat``. Each item usually contains:

                - scores (Tensor): Classification scores, has a shape
                  (num_instance, )
                - labels (Tensor): Labels of bboxes, has a shape
                  (num_instances, ).
                - bboxes: Predicted boxes. Their type follows
                  ``self.bbox_head.predict_box_type``. NOTE(review): the
                  inherited docs said (num_instances, 4) xyxy; confirm the
                  layout for gliding-vertex outputs.
        """
        proposals = [res.bboxes for res in rpn_results_list]
        rois = bbox2roi(proposals)

        if rois.shape[0] == 0:
            return empty_instances(
                batch_img_metas,
                rois.device,
                task_type='bbox',
                box_type=self.bbox_head.predict_box_type)

        bbox_results = self._bbox_forward(x, rois)

        # split batch bbox prediction back to each image
        cls_scores = bbox_results['cls_score']
        bbox_preds = bbox_results['bbox_pred']
        fix_preds = bbox_results['fix_pred']
        ratio_preds = bbox_results['ratio_pred']
        num_proposals_per_img = tuple(len(p) for p in proposals)
        rois = rois.split(num_proposals_per_img, 0)
        cls_scores = cls_scores.split(num_proposals_per_img, 0)

        # some detector with_reg is False, bbox_preds will be None
        if bbox_preds is not None:
            # TODO move this to a sabl_roi_head
            # the bbox prediction of some detectors like SABL is not Tensor
            if isinstance(bbox_preds, torch.Tensor):
                bbox_preds = bbox_preds.split(num_proposals_per_img, 0)
            else:
                bbox_preds = self.bbox_head.bbox_pred_split(
                    bbox_preds, num_proposals_per_img)
            # Split fix/ratio predictions per image regardless of the
            # bbox_pred format, so every per-image sequence handed to
            # predict_by_feat stays aligned (previously they were only split
            # in the Tensor branch).
            fix_preds = fix_preds.split(num_proposals_per_img, 0)
            ratio_preds = ratio_preds.split(num_proposals_per_img, 0)
        else:
            bbox_preds = (None, ) * len(proposals)
            fix_preds = (None, ) * len(proposals)
            ratio_preds = (None, ) * len(proposals)

        result_list = self.bbox_head.predict_by_feat(
            rois=rois,
            cls_scores=cls_scores,
            bbox_preds=bbox_preds,
            fix_preds=fix_preds,
            ratio_preds=ratio_preds,
            batch_img_metas=batch_img_metas,
            rcnn_test_cfg=rcnn_test_cfg,
            rescale=rescale)
        return result_list