Source code for mmrotate.models.detectors.h2rbox_v2
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import math
import torch
from mmdet.models.detectors.single_stage import SingleStageDetector
from mmdet.models.utils import unpack_gt_instances
from mmdet.structures import DetDataSample, SampleList
from mmdet.structures.bbox import get_box_tensor
from mmdet.utils import ConfigType, InstanceList, OptConfigType, OptMultiConfig
from torch import Tensor
from torch.nn.functional import grid_sample
from torchvision import transforms
from typing import Tuple, Union
from mmrotate.registry import MODELS
from mmrotate.structures.bbox import RotatedBoxes
[docs]
@MODELS.register_module()
class H2RBoxV2Detector(SingleStageDetector):
    """Implementation of `H2RBox-v2 <https://arxiv.org/abs/2304.04403>`_

    A single-stage detector whose training step (``loss``) feeds the head
    three views of every image: the center-cropped original, a rotated copy
    and a vertically flipped copy.

    Args:
        backbone (:obj:`ConfigDict` or dict): Forwarded unchanged to
            ``SingleStageDetector``.
        neck (:obj:`ConfigDict` or dict): Forwarded unchanged to
            ``SingleStageDetector``.
        bbox_head (:obj:`ConfigDict` or dict): Forwarded unchanged to
            ``SingleStageDetector``.
        crop_size (tuple[int]): ``(height, width)`` of the center crop applied
            to the training views. Defaults to (768, 768).
        padding (str): Padding mode handed to ``grid_sample`` when the
            rotated view is generated. Defaults to 'reflection'.
        view_range (tuple[float]): Lower and upper bound of the random view
            rotation, expressed as a fraction of pi.
            Defaults to (0.25, 0.75).
        train_cfg (:obj:`ConfigDict` or dict, optional): Forwarded unchanged
            to ``SingleStageDetector``. Defaults to None.
        test_cfg (:obj:`ConfigDict` or dict, optional): Forwarded unchanged
            to ``SingleStageDetector``. Defaults to None.
        data_preprocessor (:obj:`ConfigDict` or dict, optional): Forwarded
            unchanged to ``SingleStageDetector``. Defaults to None.
        init_cfg (:obj:`ConfigDict` or dict or list, optional): Forwarded
            unchanged to ``SingleStageDetector``. Defaults to None.
    """
def __init__(self,
             backbone: ConfigType,
             neck: ConfigType,
             bbox_head: ConfigType,
             crop_size: Tuple[int, int] = (768, 768),
             padding: str = 'reflection',
             view_range: Tuple[float, float] = (0.25, 0.75),
             train_cfg: OptConfigType = None,
             test_cfg: OptConfigType = None,
             data_preprocessor: OptConfigType = None,
             init_cfg: OptMultiConfig = None) -> None:
    """Build the underlying detector and store the view settings.

    Args:
        backbone: Backbone config, forwarded to the parent class.
        neck: Neck config, forwarded to the parent class.
        bbox_head: Head config, forwarded to the parent class.
        crop_size: ``(height, width)`` passed to ``rotate_crop`` as the
            crop size during training. Defaults to (768, 768).
        padding: Padding mode passed to ``rotate_crop``.
            Defaults to 'reflection'.
        view_range: Bounds, as a fraction of pi, between which ``loss``
            samples the view rotation. Defaults to (0.25, 0.75).
        train_cfg: Training config, forwarded to the parent class.
        test_cfg: Testing config, forwarded to the parent class.
        data_preprocessor: Preprocessor config, forwarded to the parent
            class.
        init_cfg: Initialization config, forwarded to the parent class.
    """
    # Everything the generic single-stage detector understands goes
    # straight through; only the three view settings are handled here.
    parent_kwargs = dict(
        backbone=backbone,
        neck=neck,
        bbox_head=bbox_head,
        train_cfg=train_cfg,
        test_cfg=test_cfg,
        data_preprocessor=data_preprocessor,
        init_cfg=init_cfg,
    )
    super().__init__(**parent_kwargs)

    # Read back by ``loss`` when it builds the training views.
    self.crop_size, self.padding, self.view_range = (
        crop_size, padding, view_range)
[docs]
def rotate_crop(
        self,
        batch_inputs: Tensor,
        rot: float = 0.,
        size: Tuple[int, int] = (768, 768),
        batch_gt_instances: InstanceList = None,
        padding: str = 'reflection'
) -> Union[Tensor, Tuple[Tensor, InstanceList]]:
    """Rotate a batch about the image center, then center-crop it.

    When ``batch_gt_instances`` is given, the boxes are moved along with
    the pixels: box centers are rotated about ``(w / 2, h / 2)``, ``rot``
    is added to the box angle, and the crop offset is subtracted from the
    centers. Boxes are read as ``(x, y, w, h, angle)`` columns. The
    instances are updated **in place** (their ``bboxes`` attribute is
    replaced by a new ``RotatedBoxes``), so pass a copy if the originals
    must be preserved.

    Args:
        batch_inputs (Tensor): Input images of shape (N, C, H, W).
            These should usually be mean centered and std scaled.
        rot (float): Angle of view rotation in radians. A one-element
            tensor is accepted as well. Defaults to 0, which skips the
            resampling and only crops.
        size (tuple[int]): Crop size ``(height, width)`` taken from the
            image center. Defaults to (768, 768).
        batch_gt_instances (list[:obj:`InstanceData`], optional): Batch of
            gt_instance. It usually includes ``bboxes`` and ``labels``
            attributes. Defaults to None.
        padding (str): Padding method of image black edge, used as the
            padding mode of ``grid_sample``. Defaults to 'reflection'.

    Returns:
        Tensor | tuple[Tensor, list[:obj:`InstanceData`]]: Only the
        processed images when ``batch_gt_instances`` is None, otherwise
        the processed images together with the updated instances.
    """
    # ``loss`` hands the angle over as a one-element tensor. Collapsing it
    # once keeps the trigonometry, the zero test and the angle update on a
    # plain Python number, independent of the tensor's device.
    rot = float(rot)

    n, _, h, w = batch_inputs.shape
    size_h, size_w = size
    crop_h = (h - size_h) // 2
    crop_w = (w - size_w) // 2
    rotated = rot != 0

    if rotated:
        device = batch_inputs.device
        cosa, sina = math.cos(rot), math.sin(rot)
        tf = batch_inputs.new_tensor([[cosa, -sina], [sina, cosa]],
                                     dtype=torch.float)
        # Normalized sampling grid in [-1, 1]. Broadcasting yields the same
        # (h, w) coordinate maps as an 'ij'-indexed ``torch.meshgrid`` but
        # avoids the deprecated call without an ``indexing`` argument.
        x_range = torch.linspace(-1, 1, w, device=device)
        y_range = torch.linspace(-1, 1, h, device=device)
        x = x_range.unsqueeze(0).expand(h, w)
        y = y_range.unsqueeze(1).expand(h, w)
        grid = torch.stack([x, y], -1).expand([n, -1, -1, -1])
        grid = grid.reshape(-1, 2).matmul(tf).view(n, h, w, 2)
        # rotate
        batch_inputs = grid_sample(
            batch_inputs, grid, 'bilinear', padding, align_corners=True)

    # Center crop; a no-op slice when ``size`` equals the input size.
    batch_inputs = batch_inputs[..., crop_h:crop_h + size_h,
                                crop_w:crop_w + size_w]

    if batch_gt_instances is None:
        return batch_inputs

    if rotated:
        # Pivot used for the box centers, in pixel coordinates of the
        # un-cropped image.
        ctr = tf.new_tensor([[w / 2, h / 2]])

    for gt_instances in batch_gt_instances:
        gt_bboxes = get_box_tensor(gt_instances.bboxes)
        xy = gt_bboxes[..., :2]
        wh = gt_bboxes[..., 2:4]
        a = gt_bboxes[..., [4]]
        if rotated:
            xy = (xy - ctr).matmul(tf.T) + ctr
            a = a + rot
        # Express the centers relative to the top-left corner of the crop.
        xy = xy - xy.new_tensor([[crop_w, crop_h]])
        gt_instances.bboxes = RotatedBoxes(torch.cat([xy, wh, a], dim=-1))
    return batch_inputs, batch_gt_instances
[docs]
def loss(self, batch_inputs: Tensor,
         batch_data_samples: SampleList) -> Union[dict, list]:
    """Compute the training losses over three views of every image.

    The batch is center-cropped to ``self.crop_size``; a rotated copy and
    a vertically flipped copy are derived from the cropped batch. All
    three are concatenated along the batch dimension and passed through
    the feature extractor and the head in a single call.

    Every ground-truth instance set receives a ``bid`` tensor: a running
    index that starts at 1 within each view, plus a fractional tag
    (0.2 original, 0.4 rotated, 0.6 flipped). How the head consumes
    ``bid`` is not visible from this method.

    Args:
        batch_inputs (Tensor): Input images of shape (N, C, H, W).
            These should usually be mean centered and std scaled.
        batch_data_samples (list[:obj:`DetDataSample`]): The batch
            data samples. It usually includes information such
            as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

    Returns:
        dict: A dictionary of loss components.
    """

    def tag_view(instances_per_image, view_tag):
        # Number the boxes of one view consecutively across the batch,
        # starting at 1, and add the fractional marker of that view.
        next_index = 1
        for instances in instances_per_image:
            num_boxes = len(instances.bboxes)
            instances.bid = torch.arange(
                0, num_boxes, 1,
                device=instances.bboxes.device) + next_index + view_tag
            next_index += num_boxes

    batch_gt_instances, _, _ = unpack_gt_instances(batch_data_samples)

    # View 1: center crop of the original images and gts.
    batch_inputs, batch_gt_instances = self.rotate_crop(
        batch_inputs, 0, self.crop_size, batch_gt_instances, self.padding)
    tag_view(batch_gt_instances, 0.2)

    # View 2: one random angle, shared by the whole batch, drawn between
    # view_range[0] * pi and view_range[1] * pi.
    view_lo = self.view_range[0]
    view_hi = self.view_range[1]
    rot = math.pi * (
        torch.rand(1, device=batch_inputs.device) * (view_hi - view_lo) +
        view_lo)
    # ``rotate_crop`` rewrites boxes in place, so it works on a copy.
    batch_gt_rot = copy.deepcopy(batch_gt_instances)
    batch_inputs_rot, batch_gt_rot = self.rotate_crop(
        batch_inputs, rot, self.crop_size, batch_gt_rot, self.padding)
    tag_view(batch_gt_rot, 0.4)

    # View 3: vertical flip of the cropped images and gts.
    batch_inputs_flp = transforms.functional.vflip(batch_inputs)
    batch_gt_flp = copy.deepcopy(batch_gt_instances)
    img_shape = batch_inputs.shape[2:4]
    for flipped in batch_gt_flp:
        flipped.bboxes.flip_(img_shape, 'vertical')
    tag_view(batch_gt_flp, 0.6)

    # Stack the views in the order original / rotated / flipped and wrap
    # each instance set in a fresh data sample that carries only the gts.
    batch_inputs_all = torch.cat(
        (batch_inputs, batch_inputs_rot, batch_inputs_flp))
    batch_data_samples_all = []
    for view_instances in (batch_gt_instances, batch_gt_rot, batch_gt_flp):
        for instances in view_instances:
            sample = DetDataSample()
            sample.gt_instances = instances
            batch_data_samples_all.append(sample)

    feat = self.extract_feat(batch_inputs_all)
    return self.bbox_head.loss(feat, batch_data_samples_all)