# Source code for mmrotate.models.dense_heads.h2rbox_head
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import math
import torch
from mmcv.cnn import Scale
from mmdet.models.utils import (filter_scores_and_topk, multi_apply,
select_single_mlvl)
from mmdet.structures.bbox import cat_boxes, get_box_tensor
from mmdet.utils import (ConfigType, InstanceList, OptConfigType,
OptInstanceList, reduce_mean)
from mmengine import ConfigDict
from mmengine.structures import InstanceData
from torch import Tensor
from typing import Dict, List, Optional, Tuple
from mmrotate.models.dense_heads.rotated_fcos_head import RotatedFCOSHead
from mmrotate.registry import MODELS
from mmrotate.structures import RotatedBoxes, hbox2rbox, rbox2hbox
# Large finite float used as a stand-in for "infinity". It is not referenced
# in the visible part of this module (presumably kept for parity with other
# FCOS-style heads -- confirm before removing).
INF = 100_000_000.0
[docs]
@MODELS.register_module()
class H2RBoxHead(RotatedFCOSHead):
    """Anchor-free head used in `H2RBox <https://arxiv.org/abs/2210.06742>`_.

    The head has a weakly supervised (WS) branch, which is the regular
    rotated-FCOS forward, and a self-supervised (SS) branch
    (:meth:`forward_ss`). The SS branch reuses the regression convs of the
    WS branch on a rotated view of the input. Its boxes are trained to agree
    with the WS predictions rotated by the same angle (``loss_bbox_ss``).

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        angle_version (str): Angle representations. Defaults to 'le90'.
        use_hbbox_loss (bool): If true, use horizontal bbox loss and
            loss_angle should not be None. Defaults to False.
        scale_angle (bool): If true, add scale to angle pred branch.
            Defaults to True.
        angle_coder (:obj:`ConfigDict` or dict): Config of angle coder.
        h_bbox_coder (dict): Config of horizontal bbox coder,
            only used when use_hbbox_loss is True.
        bbox_coder (:obj:`ConfigDict` or dict): Config of bbox coder. Defaults
            to 'DistanceAnglePointCoder'.
        loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
        loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
        loss_centerness (:obj:`ConfigDict`, or dict): Config of centerness loss.
        loss_angle (:obj:`ConfigDict` or dict, Optional): Config of angle loss.
        loss_bbox_ss (:obj:`ConfigDict` or dict): Config of consistency loss.
        rotation_agnostic_classes (list, optional): Ids of rotation agnostic
            categories. Their SS angle targets are zeroed. Defaults to None.
        weak_supervised (bool): If true, horizontal gt boxes are input and
            the WS box loss is computed on the axis-aligned boxes enclosing
            the decoded prediction and target. Defaults to True.
        square_classes (list, optional): Ids of square categories. At
            inference their boxes are passed through
            ``hbox2rbox(rbox2hbox(...))``. Defaults to None.
        crop_size (tuple[int]): Crop size (h, w) from image center. Its
            center is used as the rotation center of the SS targets.
            Defaults to (768, 768).

    Example:
        >>> self = H2RBoxHead(11, 7)
        >>> feats = [torch.rand(1, 7, s, s) for s in [4, 8, 16, 32, 64]]
        >>> cls_score, bbox_pred, angle_pred, centerness = self.forward(feats)
        >>> assert len(cls_score) == len(self.scales)
    """  # noqa: E501

    def __init__(self,
                 num_classes: int,
                 in_channels: int,
                 angle_version: str = 'le90',
                 use_hbbox_loss: bool = False,
                 scale_angle: bool = True,
                 angle_coder: ConfigType = dict(type='PseudoAngleCoder'),
                 h_bbox_coder: ConfigType = dict(
                     type='mmdet.DistancePointBBoxCoder'),
                 bbox_coder: ConfigType = dict(type='DistanceAnglePointCoder'),
                 loss_cls: ConfigType = dict(
                     type='mmdet.FocalLoss',
                     use_sigmoid=True,
                     gamma=2.0,
                     alpha=0.25,
                     loss_weight=1.0),
                 loss_bbox: ConfigType = dict(
                     type='RotatedIoULoss', loss_weight=1.0),
                 loss_centerness: ConfigType = dict(
                     type='mmdet.CrossEntropyLoss',
                     use_sigmoid=True,
                     loss_weight=1.0),
                 loss_angle: OptConfigType = None,
                 loss_bbox_ss: ConfigType = dict(
                     type='mmdet.IoULoss', loss_weight=1.0),
                 rotation_agnostic_classes: list = None,
                 weak_supervised: bool = True,
                 square_classes: list = None,
                 crop_size: Tuple[int, int] = (768, 768),
                 **kwargs):
        super().__init__(
            num_classes=num_classes,
            in_channels=in_channels,
            angle_version=angle_version,
            use_hbbox_loss=use_hbbox_loss,
            scale_angle=scale_angle,
            angle_coder=angle_coder,
            h_bbox_coder=h_bbox_coder,
            bbox_coder=bbox_coder,
            loss_cls=loss_cls,
            loss_bbox=loss_bbox,
            loss_centerness=loss_centerness,
            loss_angle=loss_angle,
            **kwargs)
        self.loss_bbox_ss = MODELS.build(loss_bbox_ss)
        self.rotation_agnostic_classes = rotation_agnostic_classes
        self.weak_supervised = weak_supervised
        self.square_classes = square_classes
        self.crop_size = crop_size

    def obb2xyxy(self, rbboxes):
        """Convert rotated boxes to their axis-aligned enclosing boxes.

        Args:
            rbboxes (Tensor): Rotated boxes laid out as (cx, cy, w, h, a)
                with ``a`` in radians. The indexing assumes shape (N, 5).

        Returns:
            Tensor: Enclosing boxes of shape (N, 4) as (x1, y1, x2, y2).
        """
        w = rbboxes[:, 2::5]
        h = rbboxes[:, 3::5]
        a = rbboxes[:, 4::5]
        cosa = torch.cos(a).abs()
        sina = torch.sin(a).abs()
        # Extent of a w x h box rotated by a, projected on the x / y axes.
        hbbox_w = cosa * w + sina * h
        hbbox_h = sina * w + cosa * h
        dx = rbboxes[..., 0]
        dy = rbboxes[..., 1]
        dw = hbbox_w.reshape(-1)
        dh = hbbox_h.reshape(-1)
        x1 = dx - dw / 2
        y1 = dy - dh / 2
        x2 = dx + dw / 2
        y2 = dy + dh / 2
        return torch.stack((x1, y1, x2, y2), -1)

    def _process_rotation_agnostic(self, tensor, cls, dim=4):
        """Zero entries of ``tensor`` for rotation-agnostic classes.

        Args:
            tensor (Tensor): Per-sample values whose first dim is indexed
                by the boolean mask ``cls == c``.
            cls (Tensor): Class id per sample.
            dim (int, optional): Column to zero for matching samples. If
                None, the whole row of a matching sample is zeroed.
                Defaults to 4.

        Returns:
            Tensor: ``tensor`` with the selected entries set to zero
            (a new tensor; the input is not modified).
        """
        _rot_agnostic_mask = torch.ones_like(tensor)
        for c in self.rotation_agnostic_classes:
            if dim is None:
                _rot_agnostic_mask[cls == c] = 0
            else:
                _rot_agnostic_mask[cls == c, dim] = 0
        return tensor * _rot_agnostic_mask

    def forward_ss_single(self, feats: Tensor, scale: Scale,
                          stride: int) -> Tuple[Tensor, Tensor]:
        """Forward features of a single scale level in SS branch.

        Only the regression tower, ``conv_reg`` and ``conv_angle`` are run.
        They are the same modules the WS branch uses.

        Args:
            feats (Tensor): FPN feature maps of the specified stride.
            scale (:obj: `mmcv.cnn.Scale`): Learnable scale module to resize
                the bbox prediction.
            stride (int): The corresponding stride for feature maps, only
                used to rescale the bbox prediction at inference time when
                self.norm_on_bbox is True.

        Returns:
            tuple: bbox predictions and angle predictions of input
            feature maps.
        """
        reg_feat = feats
        for reg_layer in self.reg_convs:
            reg_feat = reg_layer(reg_feat)
        bbox_pred = self.conv_reg(reg_feat)
        bbox_pred = scale(bbox_pred).float()
        if self.norm_on_bbox:
            bbox_pred = bbox_pred.clamp(min=0)
            if not self.training:
                bbox_pred *= stride
        else:
            bbox_pred = bbox_pred.exp()
        angle_pred = self.conv_angle(reg_feat)
        if self.is_scale_angle:
            angle_pred = self.scale_angle(angle_pred).float()
        return bbox_pred, angle_pred

    def forward_ss(self,
                   feats: Tuple[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
        """Forward features from the upstream network in the SS branch.

        Args:
            feats (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple: A tuple of each level outputs.

            - bbox_pred (list[Tensor]): Box energies / deltas for each \
            scale level, each is a 4D-tensor, the channel number is \
            num_points * 4.
            - angle_pred (list[Tensor]): Box angle for each scale level, \
            each is a 4D-tensor, the channel number is \
            num_points * encode_size.
        """
        return multi_apply(self.forward_ss_single, feats, self.scales,
                           self.strides)

    def loss(self, x_ws: Tuple[Tensor], x_ss: Tuple[Tensor], rot: float,
             batch_gt_instances: InstanceList,
             batch_gt_instances_ignore: OptInstanceList,
             batch_img_metas: List[dict]) -> dict:
        """Perform forward propagation and loss calculation of the detection
        head on the features of the upstream network.

        Args:
            x_ws (tuple[Tensor]): Features from the weakly supervised network,
                each is a 4D-tensor.
            x_ss (tuple[Tensor]): Features from the self-supervised network,
                each is a 4D-tensor.
            rot (float): Angle of view rotation (used with math.cos/sin,
                i.e. radians).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing. May be
                None.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.

        Returns:
            dict: A dictionary of loss components.
        """
        cls_scores_ws, bbox_preds_ws, angle_preds_ws, centernesses_ws = self(
            x_ws)
        bbox_preds_ss, angle_preds_ss = self.forward_ss(x_ss)
        losses = self.loss_by_feat(cls_scores_ws, bbox_preds_ws,
                                   angle_preds_ws, centernesses_ws,
                                   bbox_preds_ss, angle_preds_ss, rot,
                                   batch_gt_instances, batch_img_metas,
                                   batch_gt_instances_ignore)
        return losses

    def loss_by_feat(
        self,
        cls_scores: List[Tensor],
        bbox_preds: List[Tensor],
        angle_preds: List[Tensor],
        centernesses: List[Tensor],
        bbox_preds_ss: List[Tensor],
        angle_preds_ss: List[Tensor],
        rot: float,
        batch_gt_instances: InstanceList,
        batch_img_metas: List[dict],
        batch_gt_instances_ignore: OptInstanceList = None
    ) -> Dict[str, Tensor]:
        """Calculate the loss based on the features extracted by the detection
        head.

        Args:
            cls_scores (list[Tensor]): Box scores for each scale level in
                weakly supervised branch, each is a 4D-tensor, the channel
                number is num_points * num_classes.
            bbox_preds (list[Tensor]): Box energies / deltas for each scale
                level in weakly supervised branch, each is a 4D-tensor, the
                channel number is num_points * 4.
            angle_preds (list[Tensor]): Box angle for each scale level in
                weakly supervised branch, each is a 4D-tensor, the channel
                number is num_points * encode_size.
            centernesses (list[Tensor]): centerness for each scale level in
                weakly supervised branch, each is a 4D-tensor, the channel
                number is num_points * 1.
            bbox_preds_ss (list[Tensor]): Box energies / deltas for each scale
                level in self-supervised branch, each is a 4D-tensor, the
                channel number is num_points * 4.
            angle_preds_ss (list[Tensor]): Box angle for each scale level in
                self-supervised branch, each is a 4D-tensor, the channel
                number is num_points * encode_size.
            rot (float): Angle of view rotation (radians).
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
                Batch of gt_instances_ignore. It includes ``bboxes`` attribute
                data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict[str, Tensor]: ``loss_cls``, ``loss_bbox``,
            ``loss_centerness`` and ``loss_bbox_ss``.
        """
        assert len(cls_scores) == len(bbox_preds) \
            == len(angle_preds) == len(centernesses)
        assert len(bbox_preds_ss) == len(angle_preds_ss)
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        all_level_points = self.prior_generator.grid_priors(
            featmap_sizes,
            dtype=bbox_preds[0].dtype,
            device=bbox_preds[0].device)
        # bbox_targets are per-point distances to the box sides (presumably
        # ordered (left, top, right, bottom) as in FCOS -- confirm);
        # angle_targets are not encoded here.
        labels, bbox_targets, angle_targets = self.get_targets(
            all_level_points, batch_gt_instances)

        num_imgs = cls_scores[0].size(0)
        # flatten cls_scores, bbox_preds, angle_preds and centerness
        flatten_cls_scores = [
            cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
            for cls_score in cls_scores
        ]
        flatten_bbox_preds = [
            bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
            for bbox_pred in bbox_preds
        ]
        angle_dim = self.angle_coder.encode_size
        flatten_angle_preds = [
            angle_pred.permute(0, 2, 3, 1).reshape(-1, angle_dim)
            for angle_pred in angle_preds
        ]
        flatten_centerness = [
            centerness.permute(0, 2, 3, 1).reshape(-1)
            for centerness in centernesses
        ]
        flatten_cls_scores = torch.cat(flatten_cls_scores)
        flatten_bbox_preds = torch.cat(flatten_bbox_preds)
        flatten_angle_preds = torch.cat(flatten_angle_preds)
        flatten_centerness = torch.cat(flatten_centerness)
        flatten_labels = torch.cat(labels)
        flatten_bbox_targets = torch.cat(bbox_targets)
        flatten_angle_targets = torch.cat(angle_targets)
        # repeat points to align with bbox_preds
        flatten_points = torch.cat(
            [points.repeat(num_imgs, 1) for points in all_level_points])

        # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
        bg_class_ind = self.num_classes
        pos_inds = ((flatten_labels >= 0)
                    & (flatten_labels < bg_class_ind)).nonzero().reshape(-1)
        num_pos = torch.tensor(
            len(pos_inds), dtype=torch.float, device=bbox_preds[0].device)
        num_pos = max(reduce_mean(num_pos), 1.0)
        loss_cls = self.loss_cls(
            flatten_cls_scores, flatten_labels, avg_factor=num_pos)

        pos_bbox_preds = flatten_bbox_preds[pos_inds]
        pos_angle_preds = flatten_angle_preds[pos_inds]
        pos_centerness = flatten_centerness[pos_inds]
        pos_bbox_targets = flatten_bbox_targets[pos_inds]
        pos_angle_targets = flatten_angle_targets[pos_inds]
        pos_centerness_targets = self.centerness_target(pos_bbox_targets)
        # centerness weighted iou loss
        centerness_denorm = max(
            reduce_mean(pos_centerness_targets.sum().detach()), 1e-6)

        if len(pos_inds) > 0:
            # 2x2 rotation matrix of the view rotation.
            cosa, sina = math.cos(rot), math.sin(rot)
            tf = flatten_cls_scores.new_tensor([[cosa, -sina], [sina, cosa]])

            # For every positive WS location, find the location it lands on
            # after rotating by `rot` about the feature-map center. Locations
            # rotated outside the map are flagged False in pos_inds_ss_v.
            pos_inds_ss = []
            pos_inds_ss_v = torch.empty_like(pos_inds, dtype=torch.bool)
            offset = 0
            for h, w in featmap_sizes:
                level_mask = (offset <= pos_inds).logical_and(
                    pos_inds < offset + num_imgs * h * w)
                pos_ind = pos_inds[level_mask] - offset
                # Within a level the flat index is (b * h + y) * w + x.
                xy = torch.stack((pos_ind % w, (pos_ind // w) % h), dim=-1)
                b = pos_ind // (w * h)
                ctr = tf.new_tensor([[(w - 1) / 2, (h - 1) / 2]])
                xy_ss = ((xy - ctr).matmul(tf.T) + ctr).round().long()
                x_ss = xy_ss[..., 0]
                y_ss = xy_ss[..., 1]
                xy_valid_ss = ((x_ss >= 0) & (x_ss < w) & (y_ss >= 0) &
                               (y_ss < h))
                pos_ind_ss = (b * h + y_ss) * w + x_ss
                pos_inds_ss_v[level_mask] = xy_valid_ss
                pos_inds_ss.append(pos_ind_ss[xy_valid_ss] + offset)
                offset += num_imgs * h * w

            has_valid_ss = pos_inds_ss_v.any()
            pos_points = flatten_points[pos_inds]
            pos_labels = flatten_labels[pos_inds]
            if has_valid_ss:
                pos_inds_ss = torch.cat(pos_inds_ss)
                flatten_bbox_preds_ss = [
                    bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
                    for bbox_pred in bbox_preds_ss
                ]
                # Same channel layout as the WS angle predictions; a
                # hard-coded 1 here mis-indexes coders with encode_size > 1.
                flatten_angle_preds_ss = [
                    angle_pred.permute(0, 2, 3, 1).reshape(-1, angle_dim)
                    for angle_pred in angle_preds_ss
                ]
                flatten_bbox_preds_ss = torch.cat(flatten_bbox_preds_ss)
                flatten_angle_preds_ss = torch.cat(flatten_angle_preds_ss)
                pos_bbox_preds_ss = flatten_bbox_preds_ss[pos_inds_ss]
                pos_angle_preds_ss = flatten_angle_preds_ss[pos_inds_ss]
                pos_points_ss = flatten_points[pos_inds_ss]

            bbox_coder = self.bbox_coder
            pos_decoded_angle_preds = self.angle_coder.decode(
                pos_angle_preds, keepdim=True)
            pos_bbox_preds = torch.cat(
                [pos_bbox_preds, pos_decoded_angle_preds], dim=-1)
            pos_bbox_targets = torch.cat([pos_bbox_targets, pos_angle_targets],
                                         dim=-1)
            pos_decoded_bbox_preds = bbox_coder.decode(pos_points,
                                                       pos_bbox_preds)
            pos_decoded_target_preds = bbox_coder.decode(
                pos_points, pos_bbox_targets)

            if self.weak_supervised:
                # Horizontal GT only: compare enclosing axis-aligned boxes.
                loss_bbox = self.loss_bbox(
                    self.obb2xyxy(pos_decoded_bbox_preds),
                    self.obb2xyxy(pos_decoded_target_preds),
                    weight=pos_centerness_targets,
                    avg_factor=centerness_denorm)
            else:
                loss_bbox = self.loss_bbox(
                    pos_decoded_bbox_preds,
                    pos_decoded_target_preds,
                    weight=pos_centerness_targets,
                    avg_factor=centerness_denorm)

            loss_centerness = self.loss_centerness(
                pos_centerness, pos_centerness_targets, avg_factor=num_pos)

            if has_valid_ss:
                # Decode SS angles with the same coder as the WS branch (an
                # identity for the default PseudoAngleCoder).
                pos_decoded_angle_preds_ss = self.angle_coder.decode(
                    pos_angle_preds_ss, keepdim=True)
                pos_bbox_preds_ss = torch.cat(
                    [pos_bbox_preds_ss, pos_decoded_angle_preds_ss], dim=-1)
                pos_decoded_bbox_preds_ss = bbox_coder.decode(
                    pos_points_ss, pos_bbox_preds_ss)

                # SS targets: the WS predictions rotated by `rot` about the
                # crop center (centers rotated, sizes kept, angle + rot).
                _h, _w = self.crop_size
                _ctr = tf.new_tensor([[(_w - 1) / 2, (_h - 1) / 2]])

                _xy = pos_decoded_bbox_preds[pos_inds_ss_v, :2]
                _wh = pos_decoded_bbox_preds[pos_inds_ss_v, 2:4]
                pos_angle_targets_ss = pos_decoded_bbox_preds[pos_inds_ss_v,
                                                              4:] + rot

                _xy = (_xy - _ctr).matmul(tf.T) + _ctr

                if self.rotation_agnostic_classes:
                    pos_labels_ss = pos_labels[pos_inds_ss_v]
                    pos_angle_targets_ss = self._process_rotation_agnostic(
                        pos_angle_targets_ss, pos_labels_ss, dim=None)

                pos_decoded_target_preds_ss = torch.cat(
                    [_xy, _wh, pos_angle_targets_ss], dim=-1)

                pos_centerness_targets_ss = pos_centerness_targets[
                    pos_inds_ss_v]

                centerness_denorm_ss = max(
                    pos_centerness_targets_ss.sum().detach(), 1)

                loss_bbox_ss = self.loss_bbox_ss(
                    pos_decoded_bbox_preds_ss,
                    pos_decoded_target_preds_ss,
                    weight=pos_centerness_targets_ss,
                    avg_factor=centerness_denorm_ss)
            else:
                # Zero-valued loss that stays attached to the graph.
                loss_bbox_ss = pos_bbox_preds[[]].sum()
        else:
            # No positives: zero-valued losses attached to the graph.
            loss_bbox = pos_bbox_preds.sum()
            loss_bbox_ss = pos_bbox_preds.sum()
            loss_centerness = pos_centerness.sum()

        return dict(
            loss_cls=loss_cls,
            loss_bbox=loss_bbox,
            loss_centerness=loss_centerness,
            loss_bbox_ss=loss_bbox_ss)

    def predict_by_feat(self,
                        cls_scores: List[Tensor],
                        bbox_preds: List[Tensor],
                        angle_preds: List[Tensor],
                        score_factors: Optional[List[Tensor]] = None,
                        batch_img_metas: Optional[List[dict]] = None,
                        cfg: Optional[ConfigDict] = None,
                        rescale: bool = False,
                        with_nms: bool = True) -> InstanceList:
        """Transform a batch of output features extracted from the head into
        bbox results.

        Note: When score_factors is not None, the cls_scores are
        usually multiplied by it then obtain the real score used in NMS,
        such as CenterNess in FCOS, IoU branch in ATSS.

        Args:
            cls_scores (list[Tensor]): Classification scores for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * num_classes, H, W).
            bbox_preds (list[Tensor]): Box energies / deltas for all
                scale levels, each is a 4D-tensor, has shape
                (batch_size, num_priors * 4, H, W).
            angle_preds (list[Tensor]): Box angle for each scale level
                with shape (N, num_points * encode_size, H, W)
            score_factors (list[Tensor], optional): Score factor for
                all scale level, each is a 4D-tensor, has shape
                (batch_size, num_priors * 1, H, W). Defaults to None.
            batch_img_metas (list[dict], Optional): Batch image meta info.
                Must be provided; it is iterated to produce one result per
                image. Defaults to None.
            cfg (ConfigDict, optional): Test / postprocessing
                configuration, if None, test_cfg would be used.
                Defaults to None.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.
            with_nms (bool): If True, do nms before return boxes.
                Defaults to True.

        Returns:
            list[:obj:`InstanceData`]: Object detection results of each image
            after the post process. Each item usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 5),
              the last dimension 5 arrange as (x, y, w, h, t).
        """
        # Levels are zipped later; a length mismatch would silently
        # truncate, so check all per-level inputs up front.
        assert len(cls_scores) == len(bbox_preds) == len(angle_preds)

        if score_factors is None:
            # e.g. Retina, FreeAnchor, Foveabox, etc.
            with_score_factors = False
        else:
            # e.g. FCOS, PAA, ATSS, AutoAssign, etc.
            with_score_factors = True
            assert len(cls_scores) == len(score_factors)

        num_levels = len(cls_scores)

        featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
        mlvl_priors = self.prior_generator.grid_priors(
            featmap_sizes,
            dtype=cls_scores[0].dtype,
            device=cls_scores[0].device)

        result_list = []

        for img_id, img_meta in enumerate(batch_img_metas):
            cls_score_list = select_single_mlvl(
                cls_scores, img_id, detach=True)
            bbox_pred_list = select_single_mlvl(
                bbox_preds, img_id, detach=True)
            angle_pred_list = select_single_mlvl(
                angle_preds, img_id, detach=True)
            if with_score_factors:
                score_factor_list = select_single_mlvl(
                    score_factors, img_id, detach=True)
            else:
                score_factor_list = [None for _ in range(num_levels)]

            results = self._predict_by_feat_single(
                cls_score_list=cls_score_list,
                bbox_pred_list=bbox_pred_list,
                angle_pred_list=angle_pred_list,
                score_factor_list=score_factor_list,
                mlvl_priors=mlvl_priors,
                img_meta=img_meta,
                cfg=cfg,
                rescale=rescale,
                with_nms=with_nms)
            result_list.append(results)
        return result_list

    def _predict_by_feat_single(self,
                                cls_score_list: List[Tensor],
                                bbox_pred_list: List[Tensor],
                                angle_pred_list: List[Tensor],
                                score_factor_list: List[Tensor],
                                mlvl_priors: List[Tensor],
                                img_meta: dict,
                                cfg: ConfigDict,
                                rescale: bool = False,
                                with_nms: bool = True) -> InstanceData:
        """Transform a single image's features extracted from the head into
        bbox results.

        Args:
            cls_score_list (list[Tensor]): Box scores from all scale
                levels of a single image, each item has shape
                (num_priors * num_classes, H, W).
            bbox_pred_list (list[Tensor]): Box energies / deltas from
                all scale levels of a single image, each item has shape
                (num_priors * 4, H, W).
            angle_pred_list (list[Tensor]): Box angle for a single scale
                level with shape (N, num_points * encode_size, H, W).
            score_factor_list (list[Tensor]): Score factor from all scale
                levels of a single image, each item has shape
                (num_priors * 1, H, W).
            mlvl_priors (list[Tensor]): Each element in the list is
                the priors of a single level in feature pyramid. In all
                anchor-based methods, it has shape (num_priors, 4). In
                all anchor-free methods, it has shape (num_priors, 2)
                when `with_stride=True`, otherwise it still has shape
                (num_priors, 4).
            img_meta (dict): Image meta info.
            cfg (mmengine.Config): Test / postprocessing configuration,
                if None, test_cfg would be used.
            rescale (bool): If True, return boxes in original image space.
                Defaults to False.
            with_nms (bool): If True, do nms before return boxes.
                Defaults to True.

        Returns:
            :obj:`InstanceData`: Detection results of each image
            after the post process.
            Each item usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 5),
              the last dimension 5 arrange as (x, y, w, h, t).
        """
        if score_factor_list[0] is None:
            # e.g. Retina, FreeAnchor, etc.
            with_score_factors = False
        else:
            # e.g. FCOS, PAA, ATSS, etc.
            with_score_factors = True

        cfg = self.test_cfg if cfg is None else cfg
        cfg = copy.deepcopy(cfg)
        img_shape = img_meta['img_shape']
        nms_pre = cfg.get('nms_pre', -1)

        mlvl_bbox_preds = []
        mlvl_valid_priors = []
        mlvl_scores = []
        mlvl_labels = []
        if with_score_factors:
            mlvl_score_factors = []
        else:
            mlvl_score_factors = None
        for cls_score, bbox_pred, angle_pred, score_factor, priors in zip(
                cls_score_list, bbox_pred_list, angle_pred_list,
                score_factor_list, mlvl_priors):

            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]

            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
            angle_pred = angle_pred.permute(1, 2, 0).reshape(
                -1, self.angle_coder.encode_size)
            if with_score_factors:
                score_factor = score_factor.permute(1, 2,
                                                    0).reshape(-1).sigmoid()
            cls_score = cls_score.permute(1, 2,
                                          0).reshape(-1, self.cls_out_channels)
            if self.use_sigmoid_cls:
                scores = cls_score.sigmoid()
            else:
                # remind that we set FG labels to [0, num_class-1]
                # since mmdet v2.0
                # BG cat_id: num_class
                scores = cls_score.softmax(-1)[:, :-1]

            # After https://github.com/open-mmlab/mmdetection/pull/6268/,
            # this operation keeps fewer bboxes under the same `nms_pre`.
            # There is no difference in performance for most models. If you
            # find a slight drop in performance, you can set a larger
            # `nms_pre` than before.
            score_thr = cfg.get('score_thr', 0)

            results = filter_scores_and_topk(
                scores, score_thr, nms_pre,
                dict(
                    bbox_pred=bbox_pred, angle_pred=angle_pred, priors=priors))
            scores, labels, keep_idxs, filtered_results = results

            bbox_pred = filtered_results['bbox_pred']
            angle_pred = filtered_results['angle_pred']
            priors = filtered_results['priors']

            decoded_angle = self.angle_coder.decode(angle_pred, keepdim=True)
            bbox_pred = torch.cat([bbox_pred, decoded_angle], dim=-1)

            if with_score_factors:
                score_factor = score_factor[keep_idxs]

            mlvl_bbox_preds.append(bbox_pred)
            mlvl_valid_priors.append(priors)
            mlvl_scores.append(scores)
            mlvl_labels.append(labels)

            if with_score_factors:
                mlvl_score_factors.append(score_factor)

        bbox_pred = torch.cat(mlvl_bbox_preds)
        priors = cat_boxes(mlvl_valid_priors)
        bboxes = self.bbox_coder.decode(priors, bbox_pred, max_shape=img_shape)

        results = InstanceData()
        results.bboxes = RotatedBoxes(bboxes)
        results.scores = torch.cat(mlvl_scores)
        results.labels = torch.cat(mlvl_labels)
        if with_score_factors:
            results.score_factors = torch.cat(mlvl_score_factors)

        results = self._bbox_post_process(
            results=results,
            cfg=cfg,
            rescale=rescale,
            with_nms=with_nms,
            img_meta=img_meta)

        if self.square_classes:
            # Boxes of square classes are round-tripped through their
            # horizontal form (rbox2hbox -> hbox2rbox).
            bboxes = get_box_tensor(results.bboxes)
            for cls_id in self.square_classes:
                inds = results.labels == cls_id
                bboxes[inds, :] = hbox2rbox(rbox2hbox(bboxes[inds, :]))
            results.bboxes = RotatedBoxes(bboxes)

        return results