# Source code for mmrotate.models.roi_heads.gv_ratio_roi_head
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmdet.models.roi_heads.standard_roi_head import StandardRoIHead
from mmdet.models.task_modules.samplers import SamplingResult
from mmdet.models.utils import empty_instances
from mmdet.structures.bbox import bbox2roi
from mmdet.utils import ConfigType, InstanceList
from torch import Tensor
from typing import List, Tuple
from mmrotate.registry import MODELS
[docs]
@MODELS.register_module()
class GVRatioRoIHead(StandardRoIHead):
    """Gliding vertex RoI head including one bbox head and one mask head.

    Compared with the parent ``StandardRoIHead``, the bbox head used here
    returns four outputs (``cls_score``, ``bbox_pred``, ``fix_pred`` and
    ``ratio_pred``), and every method below forwards the two extra outputs
    to the bbox head's loss / prediction routines.
    """

    # TODO: Need to refactor later
    def forward(self, x: Tuple[Tensor],
                rpn_results_list: InstanceList) -> tuple:
        """Network forward process. Usually includes backbone, neck and head
        forward without any post-processing.

        Args:
            x (tuple[Tensor]): Multi-level features that may have different
                resolutions.
            rpn_results_list (list[:obj:`InstanceData`]): List of region
                proposals.

        Returns:
            tuple: A tuple of features from ``bbox_head`` and ``mask_head``
            forward. When the bbox head is enabled it contributes
            ``(cls_score, bbox_pred, fix_pred, ratio_pred)``; when the mask
            head is enabled ``mask_preds`` is appended.
        """
        results = ()
        proposals = [rpn_results.bboxes for rpn_results in rpn_results_list]
        rois = bbox2roi(proposals)

        # bbox head
        if self.with_bbox:
            bbox_results = self._bbox_forward(x, rois)
            results = results + (
                bbox_results['cls_score'], bbox_results['bbox_pred'],
                bbox_results['fix_pred'], bbox_results['ratio_pred'])

        # mask head
        if self.with_mask:
            # Only the first 100 rows of ``rois`` are fed to the mask head.
            mask_rois = rois[:100]
            mask_results = self._mask_forward(x, mask_rois)
            results = results + (mask_results['mask_preds'], )
        return results

    def _bbox_forward(self, x: Tuple[Tensor], rois: Tensor) -> dict:
        """Box head forward function used in both training and testing.

        Args:
            x (tuple[Tensor]): List of multi-level img features.
            rois (Tensor): RoIs with the shape (n, 5) where the first
                column indicates batch id of each RoI.

        Returns:
            dict[str, Tensor]: Usually returns a dictionary with keys:

            - `cls_score` (Tensor): Classification scores.
            - `bbox_pred` (Tensor): Box energies / deltas.
            - `fix_pred` (Tensor): fix / deltas.
            - `ratio_pred` (Tensor): ratio / deltas.
            - `bbox_feats` (Tensor): Extract bbox RoI features.
        """
        # TODO: a more flexible way to decide which feature maps to use
        bbox_feats = self.bbox_roi_extractor(
            x[:self.bbox_roi_extractor.num_inputs], rois)
        if self.with_shared_head:
            bbox_feats = self.shared_head(bbox_feats)
        cls_score, bbox_pred, fix_pred, ratio_pred = self.bbox_head(bbox_feats)

        bbox_results = dict(
            cls_score=cls_score,
            bbox_pred=bbox_pred,
            fix_pred=fix_pred,
            ratio_pred=ratio_pred,
            bbox_feats=bbox_feats)
        return bbox_results

    def bbox_loss(self, x: Tuple[Tensor],
                  sampling_results: List[SamplingResult]) -> dict:
        """Perform forward propagation and loss calculation of the bbox head on
        the features of the upstream network.

        Args:
            x (tuple[Tensor]): List of multi-level img features.
            sampling_results (list[:obj:`SamplingResult`]): Sampling results.

        Returns:
            dict[str, Tensor]: Usually returns a dictionary with keys:

            - `cls_score` (Tensor): Classification scores.
            - `bbox_pred` (Tensor): Box energies / deltas.
            - `fix_pred` (Tensor): fix / deltas.
            - `ratio_pred` (Tensor): ratio / deltas.
            - `bbox_feats` (Tensor): Extract bbox RoI features.
            - `loss_bbox` (dict): A dictionary of bbox loss components.
        """
        rois = bbox2roi([res.priors for res in sampling_results])
        bbox_results = self._bbox_forward(x, rois)

        bbox_loss_and_target = self.bbox_head.loss_and_target(
            cls_score=bbox_results['cls_score'],
            bbox_pred=bbox_results['bbox_pred'],
            fix_pred=bbox_results['fix_pred'],
            ratio_pred=bbox_results['ratio_pred'],
            rois=rois,
            sampling_results=sampling_results,
            rcnn_train_cfg=self.train_cfg)

        bbox_results.update(loss_bbox=bbox_loss_and_target['loss_bbox'])
        return bbox_results

    def predict_bbox(self,
                     x: Tuple[Tensor],
                     batch_img_metas: List[dict],
                     rpn_results_list: InstanceList,
                     rcnn_test_cfg: ConfigType,
                     rescale: bool = False) -> InstanceList:
        """Perform forward propagation of the bbox head and predict detection
        results on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Feature maps of all scale level.
            batch_img_metas (list[dict]): List of image information.
            rpn_results_list (list[:obj:`InstanceData`]): List of region
                proposals.
            rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of R-CNN.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.

        Returns:
            list[:obj:`InstanceData`]: Detection results of each image after
            the post process. Each item usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Predicted boxes. NOTE(review): the exact
              layout is produced by ``bbox_head.predict_by_feat`` and
              presumably follows ``bbox_head.predict_box_type`` - confirm
              against the bbox head.
        """
        proposals = [res.bboxes for res in rpn_results_list]
        rois = bbox2roi(proposals)

        # Nothing to predict: return empty results without running the head.
        if rois.shape[0] == 0:
            return empty_instances(
                batch_img_metas,
                rois.device,
                task_type='bbox',
                box_type=self.bbox_head.predict_box_type)

        bbox_results = self._bbox_forward(x, rois)

        # split batch bbox prediction back to each image
        cls_scores = bbox_results['cls_score']
        bbox_preds = bbox_results['bbox_pred']
        fix_preds = bbox_results['fix_pred']
        ratio_preds = bbox_results['ratio_pred']
        num_proposals_per_img = tuple(len(p) for p in proposals)
        rois = rois.split(num_proposals_per_img, 0)
        cls_scores = cls_scores.split(num_proposals_per_img, 0)

        # some detector with_reg is False, bbox_preds will be None
        if bbox_preds is not None:
            # TODO move this to a sabl_roi_head
            # the bbox prediction of some detectors like SABL is not Tensor
            if isinstance(bbox_preds, torch.Tensor):
                bbox_preds = bbox_preds.split(num_proposals_per_img, 0)
            else:
                bbox_preds = self.bbox_head.bbox_pred_split(
                    bbox_preds, num_proposals_per_img)
            # fix / ratio predictions must be split per image no matter how
            # bbox_preds is represented; otherwise whole-batch tensors would
            # be handed to predict_by_feat next to per-image tuples.
            fix_preds = fix_preds.split(num_proposals_per_img, 0)
            ratio_preds = ratio_preds.split(num_proposals_per_img, 0)
        else:
            bbox_preds = (None, ) * len(proposals)
            fix_preds = (None, ) * len(proposals)
            ratio_preds = (None, ) * len(proposals)

        result_list = self.bbox_head.predict_by_feat(
            rois=rois,
            cls_scores=cls_scores,
            bbox_preds=bbox_preds,
            fix_preds=fix_preds,
            ratio_preds=ratio_preds,
            batch_img_metas=batch_img_metas,
            rcnn_test_cfg=rcnn_test_cfg,
            rescale=rescale)
        return result_list