# Source code for mmrotate.models.dense_heads.rotated_atss_head
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, Scale
from mmdet.models.dense_heads.atss_head import ATSSHead
from mmdet.models.task_modules.prior_generators import anchor_inside_flags
from mmdet.models.utils import images_to_levels, multi_apply, unmap
from mmdet.structures.bbox import cat_boxes, get_box_tensor
from mmdet.utils import InstanceList, OptInstanceList
from mmengine.structures import InstanceData
from torch import Tensor
from typing import List, Optional
from mmrotate.registry import MODELS
from mmrotate.structures.bbox import RotatedBoxes
@MODELS.register_module()
class RotatedATSSHead(ATSSHead):
    """Rotated variant of the detection head of
    `ATSS <https://arxiv.org/abs/1912.02424>`_.

    ATSS head structure is similar to FCOS, however ATSS uses anchor boxes
    and assigns labels by Adaptive Training Sample Selection instead of
    max-IoU. In this subclass the regression branch predicts
    ``self.bbox_coder.encode_size`` values per prior, and centerness targets
    are computed against the horizontal bounding boxes of the rotated
    ground truths.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        pred_kernel_size (int): Kernel size of the prediction ``nn.Conv2d``
            layers.
        stacked_convs (int): Number of stacking convs of the head.
        conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
            convolution layer. Defaults to None.
        norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
            layer. Defaults to ``dict(type='GN', num_groups=32,
            requires_grad=True)``.
        reg_decoded_bbox (bool): If true, the regression loss would be
            applied directly on decoded bounding boxes, converting both
            the predicted boxes and regression targets to absolute
            coordinates format. Defaults to False. It should be `True` when
            using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head.
            Note that this head decodes the predictions and derives the
            centerness targets from the regression targets treated as
            absolute rotated boxes, so it expects this to be `True`.
        loss_centerness (:obj:`ConfigDict` or dict): Config of centerness loss.
            Defaults to ``dict(type='CrossEntropyLoss', use_sigmoid=True,
            loss_weight=1.0)``.
        init_cfg (:obj:`ConfigDict` or dict or list[dict] or
            list[:obj:`ConfigDict`]): Initialization config dict.
    """

    def _init_layers(self) -> None:
        """Initialize layers of the head.

        Builds ``stacked_convs`` classification and regression conv towers,
        the classification / regression / centerness prediction convs, and
        one learnable ``Scale`` per prior-generator stride.
        """
        self.relu = nn.ReLU(inplace=True)
        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            chn = self.in_channels if i == 0 else self.feat_channels
            self.cls_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))
            self.reg_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))
        # Half-kernel padding keeps the spatial size for odd kernel sizes.
        pred_pad_size = self.pred_kernel_size // 2
        # Use ``num_base_priors`` (``num_anchors`` is a deprecated alias in
        # mmdet's AnchorHead) so all three branches agree.
        self.atss_cls = nn.Conv2d(
            self.feat_channels,
            self.num_base_priors * self.cls_out_channels,
            self.pred_kernel_size,
            padding=pred_pad_size)
        # Rotated boxes need ``encode_size`` regression outputs per prior.
        reg_dim = self.bbox_coder.encode_size
        self.atss_reg = nn.Conv2d(
            self.feat_channels,
            self.num_base_priors * reg_dim,
            self.pred_kernel_size,
            padding=pred_pad_size)
        self.atss_centerness = nn.Conv2d(
            self.feat_channels,
            self.num_base_priors,
            self.pred_kernel_size,
            padding=pred_pad_size)
        self.scales = nn.ModuleList(
            [Scale(1.0) for _ in self.prior_generator.strides])

    def loss_by_feat_single(self, anchors: Tensor, cls_score: Tensor,
                            bbox_pred: Tensor, centerness: Tensor,
                            labels: Tensor, label_weights: Tensor,
                            bbox_targets: Tensor, avg_factor: float) -> tuple:
        """Calculate the loss of a single scale level based on the features
        extracted by the detection head.

        Args:
            anchors (Tensor): Box reference for each scale level with shape
                (N, num_total_anchors, box_dim); box_dim is 5 for rotated
                anchors in <cx, cy, w, h, t> format.
            cls_score (Tensor): Box scores for each scale level with shape
                (N, num_anchors * num_classes, H, W).
            bbox_pred (Tensor): Box energies / deltas for each scale
                level with shape (N, num_anchors * encode_size, H, W).
            centerness (Tensor): Centerness predictions for each scale level
                with shape (N, num_anchors, H, W).
            labels (Tensor): Labels of each anchor with shape
                (N, num_total_anchors).
            label_weights (Tensor): Label weights of each anchor with shape
                (N, num_total_anchors).
            bbox_targets (Tensor): BBox regression targets of each anchor
                with shape (N, num_total_anchors, target_dim).
            avg_factor (float): Average factor that is used to average
                the loss. When using sampling method, avg_factor is usually
                the sum of positive and negative priors. When using
                `PseudoSampler`, `avg_factor` is usually equal to the number
                of positive priors.

        Returns:
            tuple[Tensor]: ``(loss_cls, loss_bbox, loss_centerness,
            centerness_target_sum)``. ``loss_bbox`` is computed with
            ``avg_factor=1.0``; the summed centerness targets are returned
            presumably so the caller can normalize it across levels.
        """
        dim = self.bbox_coder.encode_size
        # Anchors and targets are flattened using their own trailing width,
        # which need not equal ``encode_size`` (e.g. decoded box targets).
        anchors = anchors.reshape(-1, anchors.size(-1))
        cls_score = cls_score.permute(0, 2, 3, 1).reshape(
            -1, self.cls_out_channels).contiguous()
        bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, dim)
        centerness = centerness.permute(0, 2, 3, 1).reshape(-1)
        bbox_targets = bbox_targets.reshape(-1, bbox_targets.size(-1))
        labels = labels.reshape(-1)
        label_weights = label_weights.reshape(-1)

        # classification loss
        loss_cls = self.loss_cls(
            cls_score, labels, label_weights, avg_factor=avg_factor)

        # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
        bg_class_ind = self.num_classes
        pos_inds = ((labels >= 0)
                    & (labels < bg_class_ind)).nonzero().squeeze(1)

        if len(pos_inds) > 0:
            pos_bbox_targets = bbox_targets[pos_inds]
            pos_bbox_pred = bbox_pred[pos_inds]
            pos_anchors = anchors[pos_inds]
            pos_centerness = centerness[pos_inds]

            centerness_targets = self.centerness_target(
                pos_anchors, pos_bbox_targets)
            pos_decode_bbox_pred = self.bbox_coder.decode(
                pos_anchors, pos_bbox_pred)
            pos_decode_bbox_pred = get_box_tensor(pos_decode_bbox_pred)

            # regression loss, weighted per positive by its centerness target
            loss_bbox = self.loss_bbox(
                pos_decode_bbox_pred,
                pos_bbox_targets,
                weight=centerness_targets,
                avg_factor=1.0)

            # centerness loss
            loss_centerness = self.loss_centerness(
                pos_centerness, centerness_targets, avg_factor=avg_factor)
        else:
            # Zero losses that stay connected to the graph.
            loss_bbox = bbox_pred.sum() * 0
            loss_centerness = centerness.sum() * 0
            centerness_targets = bbox_targets.new_tensor(0.)

        return loss_cls, loss_bbox, loss_centerness, centerness_targets.sum()

    def centerness_target(self, anchors: Tensor, gts: Tensor) -> Tensor:
        """Calculate the centerness between anchors and gts.

        Only calculate pos centerness targets, otherwise there may be nan.
        The rotated gts are first converted to their horizontal bounding
        boxes; the anchor centers are then measured against those boxes.

        Args:
            anchors (Tensor): Anchors with shape (N, 5),
                <cx, cy, w, h, t> format.
            gts (Tensor): Ground truth bboxes with shape (N, 5),
                <cx, cy, w, h, t> format.

        Returns:
            Tensor: Centerness between anchors and gts, shape (N,).
        """
        # After conversion, gts columns are x1, y1, x2, y2.
        gts = RotatedBoxes(gts).convert_to('hbox').tensor
        anchors_cx, anchors_cy = RotatedBoxes(anchors).centers.unbind(dim=-1)
        l_ = anchors_cx - gts[:, 0]
        t_ = anchors_cy - gts[:, 1]
        r_ = gts[:, 2] - anchors_cx
        b_ = gts[:, 3] - anchors_cy

        left_right = torch.stack([l_, r_], dim=1)
        top_bottom = torch.stack([t_, b_], dim=1)
        centerness = torch.sqrt(
            (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) *
            (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0]))
        assert not torch.isnan(centerness).any()
        return centerness

    def _get_targets_single(self,
                            flat_anchors: Tensor,
                            valid_flags: Tensor,
                            num_level_anchors: List[int],
                            gt_instances: InstanceData,
                            img_meta: dict,
                            gt_instances_ignore: Optional[InstanceData] = None,
                            unmap_outputs: bool = True) -> tuple:
        """Compute regression, classification targets for anchors in a single
        image.

        Args:
            flat_anchors (Tensor): Multi-level anchors of the image, which are
                concatenated into a single box object (e.g. ``RotatedBoxes``)
                or tensor of shape (num_anchors, 5).
            valid_flags (Tensor): Multi level valid flags of the image,
                which are concatenated into a single tensor of
                shape (num_anchors,).
            num_level_anchors (List[int]): Number of anchors of each scale
                level.
            gt_instances (:obj:`InstanceData`): Ground truth of instance
                annotations. It usually includes ``bboxes`` and ``labels``
                attributes.
            img_meta (dict): Meta information for current image.
            gt_instances_ignore (:obj:`InstanceData`, optional): Instances
                to be ignored during training. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.
            unmap_outputs (bool): Whether to map outputs back to the original
                set of anchors.

        Returns:
            tuple: N is the number of total anchors in the image.
                anchors (Tensor): Anchors of the image; when
                    ``unmap_outputs`` is True, a tensor of shape (N, 5).
                labels (Tensor): Labels of all anchors in the image with shape
                    (N,).
                label_weights (Tensor): Label weights of all anchor in the
                    image with shape (N,).
                bbox_targets (Tensor): BBox targets of all anchors in the
                    image with shape (N, target_dim).
                bbox_weights (Tensor): BBox weights of all anchors in the
                    image with shape (N, target_dim).
                pos_inds (Tensor): Indices of positive anchor with shape
                    (num_pos,).
                neg_inds (Tensor): Indices of negative anchor with shape
                    (num_neg,).
                sampling_result (:obj:`SamplingResult`): Sampling results.

        Raises:
            ValueError: If no anchor lies inside the image boundary.
        """
        inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
                                           img_meta['img_shape'][:2],
                                           self.train_cfg['allowed_border'])
        if not inside_flags.any():
            raise ValueError(
                'There is no valid anchor inside the image boundary. Please '
                'check the image size and anchor sizes, or set '
                '``allowed_border`` to -1 to skip the condition.')
        # assign gt and sample anchors
        anchors = flat_anchors[inside_flags]

        num_level_anchors_inside = self.get_num_level_anchors_inside(
            num_level_anchors, inside_flags)
        pred_instances = InstanceData(priors=anchors)
        assign_result = self.assigner.assign(pred_instances,
                                             num_level_anchors_inside,
                                             gt_instances, gt_instances_ignore)

        sampling_result = self.sampler.sample(assign_result, pred_instances,
                                              gt_instances)

        num_valid_anchors = anchors.shape[0]
        # Decoded targets keep the gt box width; encoded ones use the coder's.
        target_dim = gt_instances.bboxes.size(-1) if self.reg_decoded_bbox \
            else self.bbox_coder.encode_size
        bbox_targets = anchors.new_zeros(num_valid_anchors, target_dim)
        bbox_weights = anchors.new_zeros(num_valid_anchors, target_dim)
        labels = anchors.new_full((num_valid_anchors, ),
                                  self.num_classes,
                                  dtype=torch.long)
        label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)

        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds
        if len(pos_inds) > 0:
            if self.reg_decoded_bbox:
                pos_bbox_targets = sampling_result.pos_gt_bboxes
                pos_bbox_targets = get_box_tensor(pos_bbox_targets)
            else:
                pos_bbox_targets = self.bbox_coder.encode(
                    sampling_result.pos_priors, sampling_result.pos_gt_bboxes)

            bbox_targets[pos_inds] = pos_bbox_targets
            bbox_weights[pos_inds, :] = 1.0

            labels[pos_inds] = sampling_result.pos_gt_labels
            if self.train_cfg['pos_weight'] <= 0:
                label_weights[pos_inds] = 1.0
            else:
                label_weights[pos_inds] = self.train_cfg['pos_weight']
        if len(neg_inds) > 0:
            label_weights[neg_inds] = 1.0

        # map up to original set of anchors
        if unmap_outputs:
            num_total_anchors = flat_anchors.size(0)
            # ``get_box_tensor`` accepts both box-type objects and plain
            # tensors, whereas ``.tensor`` only exists on box types.
            anchors = unmap(
                get_box_tensor(anchors), num_total_anchors, inside_flags)
            labels = unmap(
                labels, num_total_anchors, inside_flags, fill=self.num_classes)
            label_weights = unmap(label_weights, num_total_anchors,
                                  inside_flags)
            bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
            bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)

        return (anchors, labels, label_weights, bbox_targets, bbox_weights,
                pos_inds, neg_inds, sampling_result)

    def get_targets(self,
                    anchor_list: List[List[Tensor]],
                    valid_flag_list: List[List[Tensor]],
                    batch_gt_instances: InstanceList,
                    batch_img_metas: List[dict],
                    batch_gt_instances_ignore: OptInstanceList = None,
                    unmap_outputs: bool = True) -> tuple:
        """Get targets for ATSS head.

        This method is almost the same as `AnchorHead.get_targets()`. Besides
        returning the targets as the parent method does, it also returns the
        anchors as the first element of the returned tuple.

        Note:
            ``anchor_list`` and ``valid_flag_list`` are modified in place:
            each image's per-level entries are replaced by their
            concatenation.

        Returns:
            tuple: ``(anchors_list, labels_list, label_weights_list,
            bbox_targets_list, bbox_weights_list, avg_factor)``, where each
            ``*_list`` is split per scale level with ``images_to_levels``
            and ``avg_factor`` is summed over the images' sampling results.
        """
        num_imgs = len(batch_img_metas)
        assert len(anchor_list) == len(valid_flag_list) == num_imgs

        # anchor number of multi levels
        num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
        num_level_anchors_list = [num_level_anchors] * num_imgs

        # concat all level anchors and flags to a single tensor
        for i in range(num_imgs):
            assert len(anchor_list[i]) == len(valid_flag_list[i])
            anchor_list[i] = cat_boxes(anchor_list[i])
            # Valid flags are plain bool tensors, never box types.
            valid_flag_list[i] = torch.cat(valid_flag_list[i])

        # compute targets for each image
        if batch_gt_instances_ignore is None:
            batch_gt_instances_ignore = [None] * num_imgs
        (all_anchors, all_labels, all_label_weights, all_bbox_targets,
         all_bbox_weights, pos_inds_list, neg_inds_list,
         sampling_results_list) = multi_apply(
             self._get_targets_single,
             anchor_list,
             valid_flag_list,
             num_level_anchors_list,
             batch_gt_instances,
             batch_img_metas,
             batch_gt_instances_ignore,
             unmap_outputs=unmap_outputs)
        # Get `avg_factor` of all images, which calculate in `SamplingResult`.
        # When using sampling method, avg_factor is usually the sum of
        # positive and negative priors. When using `PseudoSampler`,
        # `avg_factor` is usually equal to the number of positive priors.
        avg_factor = sum(
            [results.avg_factor for results in sampling_results_list])
        # split targets to a list w.r.t. multiple levels
        anchors_list = images_to_levels(all_anchors, num_level_anchors)
        labels_list = images_to_levels(all_labels, num_level_anchors)
        label_weights_list = images_to_levels(all_label_weights,
                                              num_level_anchors)
        bbox_targets_list = images_to_levels(all_bbox_targets,
                                             num_level_anchors)
        bbox_weights_list = images_to_levels(all_bbox_weights,
                                             num_level_anchors)
        return (anchors_list, labels_list, label_weights_list,
                bbox_targets_list, bbox_weights_list, avg_factor)