Shortcuts

Source code for mmrotate.models.dense_heads.cfa_head

# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from mmdet.models.utils import multi_apply
from mmdet.utils import InstanceList, OptInstanceList
from mmengine.structures import InstanceData
from torch import Tensor
from typing import Dict, List, Optional, Tuple

from mmrotate.models.dense_heads.rotated_reppoints_head import \
    RotatedRepPointsHead
from mmrotate.registry import MODELS
from ..utils import convex_overlaps, levels_to_images


[docs] @MODELS.register_module() class CFAHead(RotatedRepPointsHead): """CFA head. Args: topk (int): Number of the highest topk points. Defaults to 6. anti_factor (float): Feature anti-aliasing coefficient. Defaults to 0.75. """ # noqa: W605 def __init__(self, *args, topk: int = 6, anti_factor: float = 0.75, **kwargs) -> None: super().__init__(*args, **kwargs) self.topk = topk self.anti_factor = anti_factor
[docs] def loss_by_feat_single(self, pts_pred_init: Tensor, bbox_gt_init: Tensor, bbox_weights_init: Tensor, stride: int, avg_factor_init: int) -> Tuple[Tensor]: """Calculate the loss of a single scale level based on the features extracted by the detection head. Args: pts_pred_init (Tensor): Points of shape (batch_size, h_i * w_i, num_points * 2). bbox_gt_init (Tensor): BBox regression targets in the init stage of shape (batch_size, h_i * w_i, 8). bbox_weights_init (Tensor): BBox regression loss weights in the init stage of shape (batch_size, h_i * w_i, 8). stride (int): Point stride. avg_factor_init (int): Average factor that is used to average the loss in the init stage. Returns: Tuple[Tensor]: loss components. """ bbox_gt_init = bbox_gt_init.reshape(-1, 8) pts_pred_init = pts_pred_init.reshape(-1, 2 * self.num_points) bbox_weights_init = bbox_weights_init.reshape(-1) pos_ind_init = (bbox_weights_init > 0).nonzero(as_tuple=False).reshape(-1) pos_bbox_gt_init = bbox_gt_init[pos_ind_init] pos_pts_pred_init = pts_pred_init[pos_ind_init] pos_bbox_weights_init = bbox_weights_init[pos_ind_init] normalize_term = self.point_base_scale * stride loss_pts_init = self.loss_bbox_init( pos_pts_pred_init / normalize_term, pos_bbox_gt_init / normalize_term, pos_bbox_weights_init, avg_factor=avg_factor_init) return loss_pts_init,
[docs] def loss_by_feat( self, cls_scores: List[Tensor], pts_preds_init: List[Tensor], pts_preds_refine: List[Tensor], batch_gt_instances: InstanceList, batch_img_metas: List[dict], batch_gt_instances_ignore: OptInstanceList = None ) -> Dict[str, Tensor]: """Calculate the loss based on the features extracted by the detection head. Args: cls_scores (list[Tensor]): Box scores for each scale level, each is a 4D-tensor, of shape (batch_size, num_classes, h, w). pts_preds_init (list[Tensor]): Points for each scale level, each is a 3D-tensor, of shape (batch_size, h_i * w_i, num_points * 2). pts_preds_refine (list[Tensor]): Points refined for each scale level, each is a 3D-tensor, of shape (batch_size, h_i * w_i, num_points * 2). batch_gt_instances (list[:obj:`InstanceData`]): Batch of gt_instance. It usually includes ``bboxes`` and ``labels`` attributes. batch_img_metas (list[dict]): Meta information of each image, e.g., image size, scaling factor, etc. batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional): Batch of gt_instances_ignore. It includes ``bboxes`` attribute data that is ignored during training and testing. Defaults to None. Returns: dict[str, Tensor]: A dictionary of loss components. """ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] device = cls_scores[0].device # target for initial stage center_list, valid_flag_list = self.get_points(featmap_sizes, batch_img_metas, device) pts_coordinate_preds_init = self.offset_to_pts(center_list, pts_preds_init) num_proposals_each_level = [(featmap.size(-1) * featmap.size(-2)) for featmap in cls_scores] num_level = len(featmap_sizes) assert num_level == len(pts_coordinate_preds_init) if self.train_cfg.init.assigner['type'] == 'ConvexAssigner': # Assign target for center list candidate_list = center_list else: raise NotImplementedError cls_reg_targets_init = self.get_targets( proposals_list=candidate_list, valid_flag_list=valid_flag_list, batch_gt_instances=batch_gt_instances, batch_img_metas=batch_img_metas, batch_gt_instances_ignore=batch_gt_instances_ignore, stage='init', return_sampling_results=False) (*_, bbox_gt_list_init, candidate_list_init, bbox_weights_list_init, avg_factor_init) = cls_reg_targets_init # target for refinement stage center_list, valid_flag_list = self.get_points(featmap_sizes, batch_img_metas, device) pts_coordinate_preds_refine = self.offset_to_pts( center_list, pts_preds_refine) bbox_list = [] for i_img, center in enumerate(center_list): bbox = [] for i_lvl in range(len(pts_preds_refine)): points_preds_init_ = pts_preds_init[i_lvl].detach() points_preds_init_ = points_preds_init_.view( points_preds_init_.shape[0], -1, *points_preds_init_.shape[2:]) points_shift = points_preds_init_.permute( 0, 2, 3, 1) * self.point_strides[i_lvl] points_center = center[i_lvl][:, :2].repeat(1, self.num_points) bbox.append( points_center + points_shift[i_img].reshape(-1, 2 * self.num_points)) bbox_list.append(bbox) cls_reg_targets_refine = self.get_cfa_targets( bbox_list, valid_flag_list, batch_gt_instances, batch_img_metas, batch_gt_instances_ignore=batch_gt_instances_ignore, stage='refine', return_sampling_results=False) (labels_list, label_weights_list, bbox_gt_list_refine, _, bbox_weights_list_refine, pos_inds_list_refine, pos_gt_index_list_refine) = cls_reg_targets_refine cls_scores = levels_to_images(cls_scores) cls_scores = [ item.reshape(-1, self.cls_out_channels) for item in cls_scores ] pts_coordinate_preds_init_cfa = levels_to_images( pts_coordinate_preds_init, flatten=True) pts_coordinate_preds_init_cfa = [ item.reshape(-1, 2 * self.num_points) for item in pts_coordinate_preds_init_cfa ] pts_coordinate_preds_refine = levels_to_images( pts_coordinate_preds_refine, flatten=True) pts_coordinate_preds_refine = [ item.reshape(-1, 2 * self.num_points) for item in pts_coordinate_preds_refine ] with torch.no_grad(): pos_losses_list, = multi_apply(self.get_pos_loss, cls_scores, pts_coordinate_preds_init_cfa, labels_list, bbox_gt_list_refine, label_weights_list, bbox_weights_list_refine, pos_inds_list_refine) labels_list, label_weights_list, bbox_weights_list_refine, \ num_pos, pos_normalize_term = multi_apply( self.reassign, pos_losses_list, labels_list, label_weights_list, pts_coordinate_preds_init_cfa, bbox_weights_list_refine, batch_gt_instances, pos_inds_list_refine, pos_gt_index_list_refine, num_proposals_each_level=num_proposals_each_level, num_level=num_level ) num_pos = sum(num_pos) # convert all tensor list to a flatten tensor cls_scores = torch.cat(cls_scores, 0).view(-1, cls_scores[0].size(-1)) pts_preds_refine = torch.cat(pts_coordinate_preds_refine, 0).view( -1, pts_coordinate_preds_refine[0].size(-1)) labels = torch.cat(labels_list, 0).view(-1) labels_weight = torch.cat(label_weights_list, 0).view(-1) rbbox_gt_refine = torch.cat(bbox_gt_list_refine, 0).view(-1, bbox_gt_list_refine[0].size(-1)) convex_weights_refine = torch.cat(bbox_weights_list_refine, 0).view(-1) pos_normalize_term = torch.cat(pos_normalize_term, 0).reshape(-1) pos_inds_flatten = ((0 <= labels) & (labels < self.num_classes)).nonzero( as_tuple=False).reshape(-1) assert len(pos_normalize_term) == len(pos_inds_flatten) if num_pos: losses_cls = self.loss_cls( cls_scores, labels, labels_weight, avg_factor=num_pos) pos_pts_pred_refine = pts_preds_refine[pos_inds_flatten] pos_rbbox_gt_refine = rbbox_gt_refine[pos_inds_flatten] pos_convex_weights_refine = convex_weights_refine[pos_inds_flatten] losses_pts_refine = self.loss_bbox_refine( pos_pts_pred_refine / pos_normalize_term.reshape(-1, 1), pos_rbbox_gt_refine / pos_normalize_term.reshape(-1, 1), pos_convex_weights_refine) else: losses_cls = cls_scores.sum() * 0 losses_pts_refine = pts_preds_refine.sum() * 0 losses_pts_init, = multi_apply( self.loss_by_feat_single, pts_coordinate_preds_init, bbox_gt_list_init, bbox_weights_list_init, self.point_strides, avg_factor_init=avg_factor_init, ) loss_dict_all = { 'loss_cls': losses_cls, 'loss_pts_init': losses_pts_init, 'loss_pts_refine': losses_pts_refine } return loss_dict_all
[docs] def get_cfa_targets(self, proposals_list: List[Tensor], valid_flag_list: List[Tensor], batch_gt_instances: InstanceList, batch_img_metas: List[dict], batch_gt_instances_ignore: OptInstanceList = None, stage: str = 'init', unmap_outputs: bool = True, return_sampling_results: bool = False) -> tuple: """Compute corresponding GT box and classification targets for proposals. Args: proposals_list (list[Tensor]): Multi level points/bboxes of each image. valid_flag_list (list[Tensor]): Multi level valid flags of each image. batch_gt_instances (list[:obj:`InstanceData`]): Batch of gt_instance. It usually includes ``bboxes`` and ``labels`` attributes. batch_img_metas (list[dict]): Meta information of each image, e.g., image size, scaling factor, etc. batch_gt_instances_ignore (list[:obj:`InstanceData`], optional): Batch of gt_instances_ignore. It includes ``bboxes`` attribute data that is ignored during training and testing. Defaults to None. stage (str): 'init' or 'refine'. Generate target for init stage or refine stage. Defaults to 'init'. unmap_outputs (bool): Whether to map outputs back to the original set of anchors. Defaults to True. return_sampling_results (bool): Whether to return the sampling results. Defaults to False. Returns: tuple: - all_labels (list[Tensor]): Labels of each level. - all_label_weights (list[Tensor]): Label weights of each level. - all_bbox_gt (list[Tensor]): Ground truth bbox of each level. - all_proposals (list[Tensor]): Proposals(points/bboxes) of each level. - all_proposal_weights (list[Tensor]): Proposal weights of each level. - pos_inds (list[Tensor]): Index of positive samples in all images. - gt_inds (list[Tensor]): Index of ground truth bbox in all images. """ assert stage in ['init', 'refine'] num_imgs = len(batch_img_metas) assert len(proposals_list) == len(valid_flag_list) == num_imgs # concat all level points and flags to a single tensor for i in range(num_imgs): assert len(proposals_list[i]) == len(valid_flag_list[i]) proposals_list[i] = torch.cat(proposals_list[i]) valid_flag_list[i] = torch.cat(valid_flag_list[i]) if batch_gt_instances_ignore is None: batch_gt_instances_ignore = [None] * num_imgs (all_labels, all_label_weights, all_bbox_gt, all_proposals, all_proposal_weights, pos_inds_list, neg_inds_list, sampling_result) = multi_apply( self._get_targets_single, proposals_list, valid_flag_list, batch_gt_instances, batch_gt_instances_ignore, stage=stage, unmap_outputs=unmap_outputs) pos_inds = [] for i, single_labels in enumerate(all_labels): pos_mask = (0 <= single_labels) & ( single_labels < self.num_classes) pos_inds.append(pos_mask.nonzero(as_tuple=False).view(-1)) gt_inds = [item.pos_assigned_gt_inds for item in sampling_result] return (all_labels, all_label_weights, all_bbox_gt, all_proposals, all_proposal_weights, pos_inds, gt_inds)
[docs] def get_pos_loss(self, cls_score: Tensor, pts_pred: Tensor, label: Tensor, bbox_gt: Tensor, label_weight: Tensor, convex_weight: Tensor, pos_inds: Tensor) -> Tensor: """Calculate loss of all potential positive samples obtained from first match process. Args: cls_score (Tensor): Box scores of single image with shape (num_anchors, num_classes) pts_pred (Tensor): Box energies / deltas of single image with shape (num_anchors, 4) label (Tensor): classification target of each anchor with shape (num_anchors,) bbox_gt (Tensor): Ground truth box. label_weight (Tensor): Classification loss weight of each anchor with shape (num_anchors). convex_weight (Tensor): Bbox weight of each anchor with shape (num_anchors, 4). pos_inds (Tensor): Index of all positive samples got from first assign process. Returns: Tensor: Losses of all positive samples in single image. """ # avoid no positive samplers if pos_inds.shape[0] == 0: pos_scores = cls_score pos_pts_pred = pts_pred pos_bbox_gt = bbox_gt pos_label = label pos_label_weight = label_weight pos_convex_weight = convex_weight else: pos_scores = cls_score[pos_inds] pos_pts_pred = pts_pred[pos_inds] pos_bbox_gt = bbox_gt[pos_inds] pos_label = label[pos_inds] pos_label_weight = label_weight[pos_inds] pos_convex_weight = convex_weight[pos_inds] loss_cls = self.loss_cls( pos_scores, pos_label, pos_label_weight, avg_factor=self.loss_cls.loss_weight, reduction_override='none') loss_bbox = self.loss_bbox_refine( pos_pts_pred, pos_bbox_gt, pos_convex_weight, avg_factor=self.loss_cls.loss_weight, reduction_override='none') loss_cls = loss_cls.sum(-1) pos_loss = loss_bbox + loss_cls return pos_loss,
[docs] def reassign(self, pos_losses: Tensor, label: Tensor, label_weight: Tensor, pts_pred_init: Tensor, convex_weight: Tensor, gt_instances: InstanceData, pos_inds: Tensor, pos_gt_inds: Tensor, num_proposals_each_level: Optional[List] = None, num_level: Optional[int] = None) -> tuple: """CFA reassign process. Args: pos_losses (Tensor): Losses of all positive samples in single image. label (Tensor): classification target of each anchor with shape (num_anchors,) label_weight (Tensor): Classification loss weight of each anchor with shape (num_anchors). pts_pred_init (Tensor): convex_weight (Tensor): Bbox weight of each anchor with shape (num_anchors, 4). gt_instances (:obj:`InstanceData`): Ground truth of instance annotations. It usually includes ``bboxes`` and ``labels`` attributes. pos_inds (Tensor): Index of all positive samples got from first assign process. pos_gt_inds (Tensor): Gt_index of all positive samples got from first assign process. num_proposals_each_level (list, optional): Number of proposals of each level. num_level (int, optional): Number of level. Returns: tuple: Usually returns a tuple containing learning targets. - label (Tensor): classification target of each anchor after paa assign, with shape (num_anchors,) - label_weight (Tensor): Classification loss weight of each anchor after paa assign, with shape (num_anchors). - convex_weight (Tensor): Bbox weight of each anchor with shape (num_anchors, 4). - pos_normalize_term (list): pos normalize term for refine points losses. """ if len(pos_inds) == 0: return label, label_weight, convex_weight, 0, torch.tensor( []).type_as(convex_weight) num_gt = pos_gt_inds.max() + 1 num_proposals_each_level_ = num_proposals_each_level.copy() num_proposals_each_level_.insert(0, 0) inds_level_interval = np.cumsum(num_proposals_each_level_) pos_level_mask = [] for i in range(num_level): mask = (pos_inds >= inds_level_interval[i]) & ( pos_inds < inds_level_interval[i + 1]) pos_level_mask.append(mask) overlaps_matrix = convex_overlaps(gt_instances['bboxes'], pts_pred_init) pos_inds_after_cfa = [] ignore_inds_after_cfa = [] re_assign_weights_after_cfa = [] for gt_ind in range(num_gt): pos_inds_cfa = [] pos_loss_cfa = [] pos_overlaps_init_cfa = [] gt_mask = pos_gt_inds == gt_ind for level in range(num_level): level_mask = pos_level_mask[level] level_gt_mask = level_mask & gt_mask value, topk_inds = pos_losses[level_gt_mask].topk( min(level_gt_mask.sum(), self.topk), largest=False) pos_inds_cfa.append(pos_inds[level_gt_mask][topk_inds]) pos_loss_cfa.append(value) pos_overlaps_init_cfa.append( overlaps_matrix[:, pos_inds[level_gt_mask][topk_inds]]) pos_inds_cfa = torch.cat(pos_inds_cfa) pos_loss_cfa = torch.cat(pos_loss_cfa) pos_overlaps_init_cfa = torch.cat(pos_overlaps_init_cfa, 1) if len(pos_inds_cfa) < 2: pos_inds_after_cfa.append(pos_inds_cfa) ignore_inds_after_cfa.append(pos_inds_cfa.new_tensor([])) re_assign_weights_after_cfa.append( pos_loss_cfa.new_ones([len(pos_inds_cfa)])) else: pos_loss_cfa, sort_inds = pos_loss_cfa.sort() pos_inds_cfa = pos_inds_cfa[sort_inds] pos_overlaps_init_cfa = pos_overlaps_init_cfa[:, sort_inds] \ .reshape(-1, len(pos_inds_cfa)) pos_loss_cfa = pos_loss_cfa.reshape(-1) loss_mean = pos_loss_cfa.mean() loss_var = pos_loss_cfa.var() gauss_prob_density = \ (-(pos_loss_cfa - loss_mean) ** 2 / loss_var) \ .exp() / loss_var.sqrt() index_inverted, _ = torch.arange( len(gauss_prob_density)).sort(descending=True) gauss_prob_inverted = torch.cumsum( gauss_prob_density[index_inverted], 0) gauss_prob = gauss_prob_inverted[index_inverted] gauss_prob_norm = (gauss_prob - gauss_prob.min()) / \ (gauss_prob.max() - gauss_prob.min()) # splitting by gradient consistency loss_curve = gauss_prob_norm * pos_loss_cfa _, max_thr = loss_curve.topk(1) reweights = gauss_prob_norm[:max_thr + 1] # feature anti-aliasing coefficient pos_overlaps_init_cfa = pos_overlaps_init_cfa[:, :max_thr + 1] overlaps_level = pos_overlaps_init_cfa[gt_ind] / ( pos_overlaps_init_cfa.sum(0) + 1e-6) reweights = \ self.anti_factor * overlaps_level * reweights + \ 1e-6 re_assign_weights = \ reweights.reshape(-1) / reweights.sum() * \ torch.ones(len(reweights)).type_as( gauss_prob_norm).sum() pos_inds_temp = pos_inds_cfa[:max_thr + 1] ignore_inds_temp = pos_inds_cfa.new_tensor([]) pos_inds_after_cfa.append(pos_inds_temp) ignore_inds_after_cfa.append(ignore_inds_temp) re_assign_weights_after_cfa.append(re_assign_weights) pos_inds_after_cfa = torch.cat(pos_inds_after_cfa) ignore_inds_after_cfa = torch.cat(ignore_inds_after_cfa) re_assign_weights_after_cfa = torch.cat(re_assign_weights_after_cfa) reassign_mask = (pos_inds.unsqueeze(1) != pos_inds_after_cfa).all(1) reassign_ids = pos_inds[reassign_mask] label[reassign_ids] = self.num_classes label_weight[ignore_inds_after_cfa] = 0 convex_weight[reassign_ids] = 0 num_pos = len(pos_inds_after_cfa) re_assign_weights_mask = ( pos_inds.unsqueeze(1) == pos_inds_after_cfa).any(1) reweight_ids = pos_inds[re_assign_weights_mask] label_weight[reweight_ids] = re_assign_weights_after_cfa convex_weight[reweight_ids] = re_assign_weights_after_cfa pos_level_mask_after_cfa = [] for i in range(num_level): mask = (pos_inds_after_cfa >= inds_level_interval[i]) & ( pos_inds_after_cfa < inds_level_interval[i + 1]) pos_level_mask_after_cfa.append(mask) pos_level_mask_after_cfa = torch.stack(pos_level_mask_after_cfa, 0).type_as(label) pos_normalize_term = pos_level_mask_after_cfa * ( self.point_base_scale * torch.as_tensor(self.point_strides).type_as(label)).reshape(-1, 1) pos_normalize_term = pos_normalize_term[pos_normalize_term > 0].type_as(convex_weight) assert len(pos_normalize_term) == len(pos_inds_after_cfa) return label, label_weight, convex_weight, num_pos, pos_normalize_term