Shortcuts

Source code for mmrotate.models.detectors.h2rbox

# Copyright (c) OpenMMLab. All rights reserved.
import math
import torch
from mmdet.models.detectors.single_stage import SingleStageDetector
from mmdet.models.utils import unpack_gt_instances
from mmdet.structures import SampleList
from mmdet.structures.bbox import get_box_tensor
from mmdet.utils import ConfigType, InstanceList, OptConfigType, OptMultiConfig
from torch import Tensor
from torch.nn.functional import grid_sample
from typing import Tuple, Union

from mmrotate.registry import MODELS
from mmrotate.structures.bbox import RotatedBoxes


@MODELS.register_module()
class H2RBoxDetector(SingleStageDetector):
    """Implementation of `H2RBox <https://arxiv.org/abs/2210.06742>`_

    During training every batch is fed through the network twice: once as a
    plain centre crop that carries the (transformed) ground truth (``ws``
    view), and once randomly rotated and centre-cropped without ground truth
    (``ss`` view). Both feature sets and the rotation angle are handed to
    ``bbox_head.loss``.

    Args:
        backbone (:obj:`ConfigDict` or dict): Backbone config.
        neck (:obj:`ConfigDict` or dict): Neck config.
        bbox_head (:obj:`ConfigDict` or dict): Bbox head config.
        crop_size (Tuple[int, int]): ``(height, width)`` of the centre crop
            taken from both training views. Defaults to (768, 768).
        padding (str): ``padding_mode`` passed to ``grid_sample`` when the
            input is rotated. Defaults to 'reflection'.
        train_cfg (:obj:`ConfigDict` or dict, optional): Forwarded to
            :class:`SingleStageDetector`.
        test_cfg (:obj:`ConfigDict` or dict, optional): Forwarded to
            :class:`SingleStageDetector`.
        data_preprocessor (:obj:`ConfigDict` or dict, optional): Forwarded
            to :class:`SingleStageDetector`.
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Forwarded to :class:`SingleStageDetector`.
    """

    def __init__(self,
                 backbone: ConfigType,
                 neck: ConfigType,
                 bbox_head: ConfigType,
                 crop_size: Tuple[int, int] = (768, 768),
                 padding: str = 'reflection',
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(
            backbone=backbone,
            neck=neck,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)
        self.crop_size = crop_size
        self.padding = padding

    def rotate_crop(
            self,
            batch_inputs: Tensor,
            rot: Union[float, Tensor] = 0.,
            size: Tuple[int, int] = (768, 768),
            batch_gt_instances: InstanceList = None,
            padding: str = 'reflection'
    ) -> Union[Tensor, Tuple[Tensor, InstanceList]]:
        """Rotate a batch of images about their centre, then centre-crop it.

        Ground-truth boxes, if given, receive the same rotation and crop.
        They are updated **in place**: each ``gt_instances.bboxes`` in
        ``batch_gt_instances`` is replaced by a new :obj:`RotatedBoxes`.

        Args:
            batch_inputs (Tensor): Input images of shape (N, C, H, W).
                These should usually be mean centered and std scaled.
            rot (float or Tensor): Rotation angle in radians (a one-element
                tensor is also accepted). No rotation is applied when it is
                0. Defaults to 0.
            size (tuple[int]): ``(height, width)`` of the crop taken from the
                image centre. A dimension larger than the image is clamped
                to the image size, i.e. that axis is left uncropped.
                Defaults to (768, 768).
            batch_gt_instances (list[:obj:`InstanceData`], optional): Batch
                of gt_instance. Their ``bboxes`` must be 5-column
                ``(cx, cy, w, h, angle)`` boxes. Defaults to None.
            padding (str): ``padding_mode`` for ``grid_sample``, used to fill
                the corners uncovered by the rotation.
                Defaults to 'reflection'.

        Returns:
            Tensor | tuple[Tensor, list[:obj:`InstanceData`]]: The processed
            images alone when ``batch_gt_instances`` is None, otherwise the
            images together with the (same, mutated) ``batch_gt_instances``.
        """
        device = batch_inputs.device
        n, c, h, w = batch_inputs.shape
        # Never crop more than the image holds; otherwise the offsets below
        # go negative and the slice silently wraps around from the end.
        size_h, size_w = min(size[0], h), min(size[1], w)
        crop_h = (h - size_h) // 2
        crop_w = (w - size_w) // 2

        tf = None
        if rot != 0:
            cosa, sina = math.cos(rot), math.sin(rot)
            tf = batch_inputs.new_tensor([[cosa, -sina], [sina, cosa]],
                                         dtype=torch.float)
            x_range = torch.linspace(-1, 1, w, device=device)
            y_range = torch.linspace(-1, 1, h, device=device)
            # torch.meshgrid defaults to 'ij' indexing: y varies along dim 0
            # and x along dim 1, matching the (H, W) layout of the grid.
            y, x = torch.meshgrid(y_range, x_range)
            grid = torch.stack([x, y], -1).expand([n, -1, -1, -1])
            # grid_sample normalises each axis to [-1, 1] independently, so
            # rotating normalised coordinates directly would shear non-square
            # images relative to the boxes (which are rotated in pixel space
            # below). Convert to pixel offsets from the centre, rotate, and
            # normalise back. For H == W this reduces to the plain rotation.
            scale = tf.new_tensor([max(w - 1, 1), max(h - 1, 1)]) / 2
            grid = (grid.reshape(-1, 2) * scale).matmul(tf) / scale
            grid = grid.view(n, h, w, 2)
            # Output pixel p samples input at p @ tf, i.e. content is rotated
            # by +rot, which is the same transform applied to the boxes below.
            batch_inputs = grid_sample(
                batch_inputs, grid, 'bilinear', padding, align_corners=True)

        batch_inputs = batch_inputs[..., crop_h:crop_h + size_h,
                                    crop_w:crop_w + size_w]
        if batch_gt_instances is None:
            return batch_inputs

        for gt_instances in batch_gt_instances:
            gt_bboxes = get_box_tensor(gt_instances.bboxes)
            xy = gt_bboxes[..., :2]
            wh = gt_bboxes[..., 2:4]
            a = gt_bboxes[..., [4]]
            if tf is not None:
                # Rotate box centres about the image centre and add the
                # rotation to the box angle.
                ctr = tf.new_tensor([[w / 2, h / 2]])
                xy = (xy - ctr).matmul(tf.T) + ctr
                a = a + rot
            # Shift into the coordinate frame of the crop.
            xy = xy - xy.new_tensor([[crop_w, crop_h]])
            gt_instances.bboxes = RotatedBoxes(torch.cat([xy, wh, a], dim=-1))
        return batch_inputs, batch_gt_instances

    def loss(self, batch_inputs: Tensor,
             batch_data_samples: SampleList) -> Union[dict, list]:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            batch_inputs (Tensor): Input images of shape (N, C, H, W).
                These should usually be mean centered and std scaled.
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

        Returns:
            dict: A dictionary of loss components.

        Note:
            The gt instances unpacked from ``batch_data_samples`` are cropped
            in place by :meth:`rotate_crop` (presumably the same objects held
            by the data samples; verify against ``unpack_gt_instances``).
        """
        batch_gt_instances, batch_gt_instances_ignore, batch_img_metas = \
            unpack_gt_instances(batch_data_samples)

        # One random angle in [-pi, pi), shared by the whole batch.
        rot = (torch.rand(1, device=batch_inputs.device) * 2 - 1) * math.pi

        # Un-rotated view: centre crop that carries the ground truth.
        batch_inputs_ws, batch_gt_instances = self.rotate_crop(
            batch_inputs, 0, self.crop_size, batch_gt_instances, self.padding)
        feat_ws = self.extract_feat(batch_inputs_ws)

        # Rotated view: same crop, no ground truth transformed for it.
        batch_inputs_ss = self.rotate_crop(
            batch_inputs, rot, self.crop_size, padding=self.padding)
        feat_ss = self.extract_feat(batch_inputs_ss)

        return self.bbox_head.loss(feat_ws, feat_ss, rot, batch_gt_instances,
                                   batch_gt_instances_ignore, batch_img_metas)