Source code for mmrotate.models.dense_heads.sam_reppoints_head
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmdet.models.utils import images_to_levels, multi_apply, unmap
from mmdet.utils import InstanceList, OptInstanceList
from mmengine.structures import InstanceData
from torch import Tensor
from typing import Dict, List, Tuple
from mmrotate.models.dense_heads.rotated_reppoints_head import \
RotatedRepPointsHead
from mmrotate.registry import MODELS
from mmrotate.structures.bbox import qbox2rbox
from ..utils import get_num_level_anchors_inside, points_center_pts
[docs]
@MODELS.register_module()
class SAMRepPointsHead(RotatedRepPointsHead):
"""SAM RepPoints head."""
def _get_targets_single(self,
flat_proposals: Tensor,
valid_flags: Tensor,
num_level_proposals: List[int],
gt_instances: InstanceData,
gt_instances_ignore: InstanceData,
stage: str = 'init',
unmap_outputs: bool = True) -> tuple:
"""Compute corresponding GT box and classification targets for
proposals.
Args:
flat_proposals (Tensor): Multi level points of a image.
valid_flags (Tensor): Multi level valid flags of a image.
num_level_proposals (List[int]): Number of anchors of each scale
level.
gt_instances (InstanceData): It usually includes ``bboxes`` and
``labels`` attributes.
gt_instances_ignore (InstanceData): It includes ``bboxes``
attribute data that is ignored during training and testing.
stage (str): 'init' or 'refine'. Generate target for
init stage or refine stage. Defaults to 'init'.
unmap_outputs (bool): Whether to map outputs back to
the original set of anchors. Defaults to True.
Returns:
tuple:
- labels (Tensor): Labels of each level.
- label_weights (Tensor): Label weights of each level.
- bbox_targets (Tensor): BBox targets of each level.
- bbox_weights (Tensor): BBox weights of each level.
- pos_inds (Tensor): positive samples indexes.
- neg_inds (Tensor): negative samples indexes.
- sampling_result (:obj:`SamplingResult`): Sampling results.
"""
inside_flags = valid_flags
if not inside_flags.any():
raise ValueError(
'There is no valid proposal inside the image boundary. Please '
'check the image size.')
# assign gt and sample proposals
proposals = flat_proposals[inside_flags, :]
num_level_proposals_inside = get_num_level_anchors_inside(
num_level_proposals, inside_flags)
pred_instances = InstanceData(priors=proposals)
if stage == 'init':
assigner = self.init_assigner
pos_weight = self.train_cfg.init.pos_weight
assign_result = assigner.assign(pred_instances, gt_instances,
gt_instances_ignore)
else:
assigner = self.refine_assigner
pos_weight = self.train_cfg.refine.pos_weight
if self.train_cfg.refine.assigner['type'] not in (
'ATSSAssigner', 'ATSSConvexAssigner', 'SASAssigner'):
assign_result = assigner.assign(pred_instances, gt_instances,
gt_instances_ignore)
else:
assign_result = assigner.assign(pred_instances,
num_level_proposals_inside,
gt_instances,
gt_instances_ignore)
sampling_result = self.sampler.sample(assign_result, pred_instances,
gt_instances)
num_valid_proposals = proposals.shape[0]
bbox_gt = proposals.new_zeros([num_valid_proposals, 8])
pos_proposals = torch.zeros_like(proposals)
proposals_weights = proposals.new_zeros(num_valid_proposals)
labels = proposals.new_full((num_valid_proposals, ),
self.num_classes,
dtype=torch.long)
label_weights = proposals.new_zeros(
num_valid_proposals, dtype=torch.float)
pos_inds = sampling_result.pos_inds
neg_inds = sampling_result.neg_inds
if len(pos_inds) > 0:
bbox_gt[pos_inds, :] = sampling_result.pos_gt_bboxes
pos_proposals[pos_inds, :] = proposals[pos_inds, :]
proposals_weights[pos_inds] = 1.0
labels[pos_inds] = sampling_result.pos_gt_labels
if pos_weight <= 0:
label_weights[pos_inds] = 1.0
else:
label_weights[pos_inds] = pos_weight
if len(neg_inds) > 0:
label_weights[neg_inds] = 1.0
# use la
rbboxes_center, width, height, angles = torch.split(
qbox2rbox(bbox_gt), [2, 1, 1, 1], dim=-1)
if stage == 'init':
points_xy = pos_proposals[:, :2]
else:
points_xy = points_center_pts(pos_proposals, y_first=True)
distances = torch.zeros_like(angles).reshape(-1)
angles_index_wh = ((width != 0) & (angles >= 0) &
(angles <= 1.57)).squeeze()
angles_index_hw = ((width != 0) & ((angles < 0) |
(angles > 1.57))).squeeze()
# 01_la:compution of distance
distances[angles_index_wh] = torch.sqrt(
(torch.pow(
rbboxes_center[angles_index_wh, 0] -
points_xy[angles_index_wh, 0], 2) /
width[angles_index_wh].squeeze()) +
(torch.pow(
rbboxes_center[angles_index_wh, 1] -
points_xy[angles_index_wh, 1], 2) /
height[angles_index_wh].squeeze()))
distances[angles_index_hw] = torch.sqrt(
(torch.pow(
rbboxes_center[angles_index_hw, 0] -
points_xy[angles_index_hw, 0], 2) /
height[angles_index_hw].squeeze()) +
(torch.pow(
rbboxes_center[angles_index_hw, 1] -
points_xy[angles_index_hw, 1], 2) /
width[angles_index_hw].squeeze()))
distances[distances == float('nan')] = 0.
sam_weights = label_weights * (torch.exp(1 / (distances + 1)))
sam_weights[sam_weights == float('inf')] = 0.
# map up to original set of proposals
if unmap_outputs:
num_total_proposals = flat_proposals.size(0)
labels = unmap(
labels,
num_total_proposals,
inside_flags,
fill=self.num_classes) # fill bg label
label_weights = unmap(label_weights, num_total_proposals,
inside_flags)
bbox_gt = unmap(bbox_gt, num_total_proposals, inside_flags)
pos_proposals = unmap(pos_proposals, num_total_proposals,
inside_flags)
proposals_weights = unmap(proposals_weights, num_total_proposals,
inside_flags)
sam_weights = unmap(sam_weights, num_total_proposals, inside_flags)
return (labels, label_weights, bbox_gt, pos_proposals,
proposals_weights, pos_inds, neg_inds, sampling_result,
sam_weights)
[docs]
def get_targets(self,
proposals_list: List[Tensor],
valid_flag_list: List[Tensor],
batch_gt_instances: InstanceList,
batch_img_metas: List[dict],
batch_gt_instances_ignore: OptInstanceList = None,
stage: str = 'init',
unmap_outputs: bool = True) -> tuple:
"""Compute corresponding GT box and classification targets for
proposals.
Args:
proposals_list (list[Tensor]): Multi level points/bboxes of each
image.
valid_flag_list (list[Tensor]): Multi level valid flags of each
image.
batch_gt_instances (list[:obj:`InstanceData`]): Batch of
gt_instance. It usually includes ``bboxes`` and ``labels``
attributes.
batch_img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
batch_gt_instances_ignore (list[:obj:`InstanceData`], optional):
Batch of gt_instances_ignore. It includes ``bboxes`` attribute
data that is ignored during training and testing.
Defaults to None.
stage (str): 'init' or 'refine'. Generate target for init stage or
refine stage. Defaults to 'init'.
unmap_outputs (bool): Whether to map outputs back to the original
set of anchors. Defaults to True.
Returns:
tuple:
- labels_list (list[Tensor]): Labels of each level.
- label_weights_list (list[Tensor]): Label weights of each
level.
- bbox_gt_list (list[Tensor]): Ground truth bbox of each level.
- proposals_list (list[Tensor]): Proposals(points/bboxes) of
each level.
- proposal_weights_list (list[Tensor]): Proposal weights of
each level.
- avg_factor (int): Average factor that is used to average
the loss. When using sampling method, avg_factor is usually
the sum of positive and negative priors. When using
`PseudoSampler`, `avg_factor` is usually equal to the number
of positive priors.
"""
assert stage in ['init', 'refine']
num_imgs = len(batch_img_metas)
assert len(proposals_list) == len(valid_flag_list) == num_imgs
# points number of multi levels
num_level_proposals = [points.size(0) for points in proposals_list[0]]
num_level_proposals_list = [num_level_proposals] * num_imgs
# concat all level points and flags to a single tensor
for i in range(num_imgs):
assert len(proposals_list[i]) == len(valid_flag_list[i])
proposals_list[i] = torch.cat(proposals_list[i])
valid_flag_list[i] = torch.cat(valid_flag_list[i])
if batch_gt_instances_ignore is None:
batch_gt_instances_ignore = [None] * num_imgs
(all_labels, all_label_weights, all_bbox_gt, all_proposals,
all_proposal_weights, pos_inds_list, neg_inds_list,
sampling_results_list, all_sam_weights) = multi_apply(
self._get_targets_single,
proposals_list,
valid_flag_list,
num_level_proposals_list,
batch_gt_instances,
batch_gt_instances_ignore,
stage=stage,
unmap_outputs=unmap_outputs)
# sampled points of all images
avg_refactor = sum(
[results.avg_factor for results in sampling_results_list])
labels_list = images_to_levels(all_labels, num_level_proposals)
label_weights_list = images_to_levels(all_label_weights,
num_level_proposals)
bbox_gt_list = images_to_levels(all_bbox_gt, num_level_proposals)
proposals_list = images_to_levels(all_proposals, num_level_proposals)
proposal_weights_list = images_to_levels(all_proposal_weights,
num_level_proposals)
sam_weights_list = images_to_levels(all_sam_weights,
num_level_proposals)
res = (labels_list, label_weights_list, bbox_gt_list, proposals_list,
proposal_weights_list, avg_refactor, sam_weights_list)
return res
[docs]
def loss_by_feat_single(self, cls_score: Tensor, pts_pred_init: Tensor,
pts_pred_refine: Tensor, labels: Tensor,
label_weights, bbox_gt_init: Tensor,
bbox_weights_init: Tensor,
sam_weights_init: Tensor, bbox_gt_refine: Tensor,
bbox_weights_refine: Tensor,
sam_weights_refine: Tensor, stride: int,
avg_factor_refine: int) -> Tuple[Tensor]:
"""Calculate the loss of a single scale level based on the features
extracted by the detection head.
Args:
cls_score (Tensor): Box scores for each scale level
Has shape (N, num_classes, h_i, w_i).
pts_pred_init (Tensor): Points of shape
(batch_size, h_i * w_i, num_points * 2).
pts_pred_refine (Tensor): Points refined of shape
(batch_size, h_i * w_i, num_points * 2).
labels (Tensor): Ground truth class indices with shape
(batch_size, h_i * w_i).
label_weights (Tensor): Label weights of shape
(batch_size, h_i * w_i).
bbox_gt_init (Tensor): BBox regression targets in the init stage
of shape (batch_size, h_i * w_i, 8).
bbox_weights_init (Tensor): BBox regression loss weights in the
init stage of shape (batch_size, h_i * w_i, 8).
sam_weights_init (Tensor):
bbox_gt_refine (Tensor): BBox regression targets in the refine
stage of shape (batch_size, h_i * w_i, 8).
bbox_weights_refine (Tensor): BBox regression loss weights in the
refine stage of shape (batch_size, h_i * w_i, 8).
sam_weights_refine (Tensor):
stride (int): Point stride.
avg_factor_refine (int): Average factor that is used to average
the loss in the refine stage.
Returns:
Tuple[Tensor]: loss components.
"""
# classification loss
labels = labels.reshape(-1)
label_weights = label_weights.reshape(-1)
cls_score = cls_score.permute(0, 2, 3,
1).reshape(-1, self.cls_out_channels)
cls_score = cls_score.contiguous()
loss_cls = self.loss_cls(
cls_score, labels, label_weights, avg_factor=avg_factor_refine)
# init loss
bbox_gt_init = bbox_gt_init.reshape(-1, 8)
pts_pred_init = pts_pred_init.reshape(-1, 2 * self.num_points)
bbox_weights_init = bbox_weights_init.reshape(-1)
sam_weights_init = sam_weights_init.reshape(-1)
pos_ind_init = (bbox_weights_init
> 0).nonzero(as_tuple=False).reshape(-1)
pos_bbox_gt_init = bbox_gt_init[pos_ind_init]
pos_pts_pred_init = pts_pred_init[pos_ind_init]
pos_bbox_weights_init = bbox_weights_init[pos_ind_init]
sam_weights_pos_init = sam_weights_init[pos_ind_init]
normalize_term = self.point_base_scale * stride
loss_pts_init = self.loss_bbox_init(
pos_pts_pred_init / normalize_term,
pos_bbox_gt_init / normalize_term,
pos_bbox_weights_init * sam_weights_pos_init)
# refine loss
bbox_gt_refine = bbox_gt_refine.reshape(-1, 8)
pts_pred_refine = pts_pred_refine.reshape(-1, 2 * self.num_points)
bbox_weights_refine = bbox_weights_refine.reshape(-1)
sam_weights_refine = sam_weights_refine.reshape(-1)
pos_ind_refine = (bbox_weights_refine
> 0).nonzero(as_tuple=False).reshape(-1)
pos_bbox_gt_refine = bbox_gt_refine[pos_ind_refine]
pos_pts_pred_refine = pts_pred_refine[pos_ind_refine]
pos_bbox_weights_refine = bbox_weights_refine[pos_ind_refine]
sam_weights_pos_refine = sam_weights_refine[pos_ind_refine]
loss_pts_refine = self.loss_bbox_refine(
pos_pts_pred_refine / normalize_term,
pos_bbox_gt_refine / normalize_term,
pos_bbox_weights_refine * sam_weights_pos_refine)
return loss_cls, loss_pts_init, loss_pts_refine
[docs]
def loss_by_feat(
self,
cls_scores: List[Tensor],
pts_preds_init: List[Tensor],
pts_preds_refine: List[Tensor],
batch_gt_instances: InstanceList,
batch_img_metas: List[dict],
batch_gt_instances_ignore: OptInstanceList = None
) -> Dict[str, Tensor]:
"""Calculate the loss based on the features extracted by the detection
head.
Args:
cls_scores (list[Tensor]): Box scores for each scale level,
each is a 4D-tensor, of shape (batch_size, num_classes, h, w).
pts_preds_init (list[Tensor]): Points for each scale level, each is
a 3D-tensor, of shape (batch_size, h_i * w_i, num_points * 2).
pts_preds_refine (list[Tensor]): Points refined for each scale
level, each is a 3D-tensor, of shape
(batch_size, h_i * w_i, num_points * 2).
batch_gt_instances (list[:obj:`InstanceData`]): Batch of
gt_instance. It usually includes ``bboxes`` and ``labels``
attributes.
batch_img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
Batch of gt_instances_ignore. It includes ``bboxes`` attribute
data that is ignored during training and testing.
Defaults to None.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
device = cls_scores[0].device
# target for initial stage
center_list, valid_flag_list = self.get_points(featmap_sizes,
batch_img_metas, device)
pts_coordinate_preds_init = self.offset_to_pts(center_list,
pts_preds_init)
if self.train_cfg.init.assigner['type'] == 'ConvexAssigner':
# Assign target for center list
candidate_list = center_list
else:
raise NotImplementedError
cls_reg_targets_init = self.get_targets(
proposals_list=candidate_list,
valid_flag_list=valid_flag_list,
batch_gt_instances=batch_gt_instances,
batch_img_metas=batch_img_metas,
batch_gt_instances_ignore=batch_gt_instances_ignore,
stage='init')
(*_, bbox_gt_list_init, candidate_list_init, bbox_weights_list_init,
avg_factor_init, sam_weights_list_init) = cls_reg_targets_init
# target for refinement stage
center_list, valid_flag_list = self.get_points(featmap_sizes,
batch_img_metas, device)
pts_coordinate_preds_refine = self.offset_to_pts(
center_list, pts_preds_refine)
bbox_list = []
for i_img, center in enumerate(center_list):
bbox = []
for i_lvl in range(len(pts_preds_refine)):
points_preds_init_ = pts_preds_init[i_lvl].detach()
points_preds_init_ = points_preds_init_.view(
points_preds_init_.shape[0], -1,
*points_preds_init_.shape[2:])
points_shift = points_preds_init_.permute(
0, 2, 3, 1) * self.point_strides[i_lvl]
points_center = center[i_lvl][:, :2].repeat(1, self.num_points)
bbox.append(
points_center +
points_shift[i_img].reshape(-1, 2 * self.num_points))
bbox_list.append(bbox)
cls_reg_targets_refine = self.get_targets(
proposals_list=bbox_list,
valid_flag_list=valid_flag_list,
batch_gt_instances=batch_gt_instances,
batch_img_metas=batch_img_metas,
batch_gt_instances_ignore=batch_gt_instances_ignore,
stage='refine')
(labels_list, label_weights_list, bbox_gt_list_refine,
candidate_list_refine, bbox_weights_list_refine, avg_factor_refine,
sam_weights_list_refine) = cls_reg_targets_refine
# compute loss
losses_cls, losses_pts_init, losses_pts_refine = multi_apply(
self.loss_by_feat_single,
cls_scores,
pts_coordinate_preds_init,
pts_coordinate_preds_refine,
labels_list,
label_weights_list,
bbox_gt_list_init,
bbox_weights_list_init,
sam_weights_list_init,
bbox_gt_list_refine,
bbox_weights_list_refine,
sam_weights_list_refine,
self.point_strides,
avg_factor_refine=avg_factor_refine)
loss_dict_all = {
'loss_cls': losses_cls,
'loss_pts_init': losses_pts_init,
'loss_pts_refine': losses_pts_refine
}
return loss_dict_all