Shortcuts

Source code for mmrotate.models.dense_heads.oriented_reppoints_head

# Copyright (c) OpenMMLab. All rights reserved.
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from mmcv.ops import chamfer_distance, min_area_polygons
except ImportError:  # noqa: E722

    def chamfer_distance(*args, **kwargs):
        raise RuntimeError('chamfer_distance from mmcv.ops is not available. '
                           'Please install onedl-mmcv with ops support.')

    def min_area_polygons(*args, **kwargs):
        raise RuntimeError('min_area_polygons from mmcv.ops is not available. '
                           'Please install onedl-mmcv with ops support.')


from mmdet.models.utils import images_to_levels, multi_apply, unmap
from mmdet.utils import ConfigType, InstanceList, OptInstanceList
from mmengine.structures import InstanceData
from torch import Tensor
from typing import Dict, List, Optional, Tuple

from mmrotate.models.dense_heads.rotated_reppoints_head import \
    RotatedRepPointsHead
from mmrotate.registry import MODELS
from ..utils import levels_to_images


def ChamferDistance2D(point_set_1: Tensor,
                      point_set_2: Tensor,
                      distance_weight: float = 0.05,
                      eps: float = 1e-12):
    """Compute the Chamfer distance between two point sets.

    Args:
        point_set_1 (Tensor): point set 1 with shape
            (N_pointsets, N_points, 2)
        point_set_2 (Tensor): point set 2 with shape
            (N_pointsets, N_points, 2)
        distance_weight (float): weight of chamfer distance loss.
        eps (float): a value added to the denominator for numerical
            stability. Defaults to 1e-12.

    Returns:
        Tensor: chamfer distance between two point sets
        with shape (N_pointsets,)
    """
    assert point_set_1.dim() == point_set_2.dim()
    assert point_set_1.shape[-1] == point_set_2.shape[-1]
    assert point_set_1.dim() <= 3
    dist1, dist2, _, _ = chamfer_distance(point_set_1, point_set_2)
    dist1 = torch.sqrt(torch.clamp(dist1, eps))
    dist2 = torch.sqrt(torch.clamp(dist2, eps))
    dist = distance_weight * (dist1.mean(-1) + dist2.mean(-1)) / 2.0

    return dist


[docs] @MODELS.register_module() class OrientedRepPointsHead(RotatedRepPointsHead): """Oriented RepPoints head -<https://arxiv.org/pdf/2105.11111v4.pdf>. The head contains initial and refined stages based on RepPoints. The initial stage regresses coarse point sets, and the refine stage further regresses the fine point sets. The APAA scheme based on the quality of point set samples in the paper is employed in refined stage. Args: loss_spatial_init (:obj:`ConfigDict` or dict): Config of initial spatial loss. loss_spatial_refine (:obj:`ConfigDict` or dict): Config of refine spatial loss. top_ratio (float): Ratio of top high-quality point sets. Defaults to 0.4. init_qua_weight (float): Quality weight of initial stage. Defaults to 0.2. ori_qua_weight (float): Orientation quality weight. Defaults to 0.3. poc_qua_weight (float): Point-wise correlation quality weight. Defaults to 0.1. """ # noqa: W605 def __init__(self, *args, loss_spatial_init: ConfigType = dict( type='SpatialBorderLoss', loss_weight=0.05), loss_spatial_refine: ConfigType = dict( type='SpatialBorderLoss', loss_weight=0.1), top_ratio: float = 0.4, init_qua_weight: float = 0.2, ori_qua_weight: float = 0.3, poc_qua_weight: float = 0.1, **kwargs) -> None: super().__init__(*args, **kwargs) self.loss_spatial_init = MODELS.build(loss_spatial_init) self.loss_spatial_refine = MODELS.build(loss_spatial_refine) self.top_ratio = top_ratio self.init_qua_weight = init_qua_weight self.ori_qua_weight = ori_qua_weight self.poc_qua_weight = poc_qua_weight
[docs] def forward_single(self, x: Tensor) -> Tuple[Tensor]: """Forward feature map of a single FPN level.""" dcn_base_offset = self.dcn_base_offset.type_as(x) # If we use center_init, the initial reppoints is from center points. # If we use bounding bbox representation, the initial reppoints is # from regular grid placed on a pre-defined bbox. if self.use_grid_points or not self.center_init: scale = self.point_base_scale / 2 points_init = dcn_base_offset / dcn_base_offset.max() * scale bbox_init = x.new_tensor([-scale, -scale, scale, scale]).view(1, 4, 1, 1) else: points_init = 0 cls_feat = x pts_feat = x base_feat = x for cls_conv in self.cls_convs: cls_feat = cls_conv(cls_feat) for reg_conv in self.reg_convs: pts_feat = reg_conv(pts_feat) # initialize reppoints pts_out_init = self.reppoints_pts_init_out( self.relu(self.reppoints_pts_init_conv(pts_feat))) if self.use_grid_points: pts_out_init, bbox_out_init = self.gen_grid_from_reg( pts_out_init, bbox_init.detach()) else: pts_out_init = pts_out_init + points_init # refine and classify reppoints pts_out_init_grad_mul = (1 - self.gradient_mul) * pts_out_init.detach( ) + self.gradient_mul * pts_out_init dcn_offset = pts_out_init_grad_mul - dcn_base_offset cls_out = self.reppoints_cls_out( self.relu(self.reppoints_cls_conv(cls_feat, dcn_offset))) pts_out_refine = self.reppoints_pts_refine_out( self.relu(self.reppoints_pts_refine_conv(pts_feat, dcn_offset))) if self.use_grid_points: pts_out_refine, bbox_out_refine = self.gen_grid_from_reg( pts_out_refine, bbox_out_init.detach()) else: pts_out_refine = pts_out_refine + pts_out_init.detach() if self.training: return cls_out, pts_out_init, pts_out_refine, base_feat else: return cls_out, pts_out_refine
def _get_targets_single(self, flat_proposals: Tensor, valid_flags: Tensor, gt_instances: InstanceData, gt_instances_ignore: InstanceData, stage: str = 'init', unmap_outputs: bool = True) -> tuple: """Compute corresponding GT box and classification targets for proposals. Args: flat_proposals (Tensor): Multi level points of a image. valid_flags (Tensor): Multi level valid flags of a image. gt_instances (InstanceData): It usually includes ``bboxes`` and ``labels`` attributes. gt_instances_ignore (InstanceData): It includes ``bboxes`` attribute data that is ignored during training and testing. stage (str): 'init' or 'refine'. Generate target for init stage or refine stage. Defaults to 'init'. unmap_outputs (bool): Whether to map outputs back to the original set of anchors. Defaults to True. Returns: tuple: - labels (Tensor): Labels of each level. - label_weights (Tensor): Label weights of each level. - bbox_targets (Tensor): BBox targets of each level. - bbox_weights (Tensor): BBox weights of each level. - pos_inds (Tensor): positive samples indexes. - neg_inds (Tensor): negative samples indexes. - sampling_result (:obj:`SamplingResult`): Sampling results. """ inside_flags = valid_flags if not inside_flags.any(): raise ValueError( 'There is no valid proposal inside the image boundary. Please ' 'check the image size.') # assign gt and sample proposals proposals = flat_proposals[inside_flags, :] pred_instances = InstanceData(priors=proposals) if stage == 'init': assigner = self.init_assigner pos_weight = self.train_cfg.init.pos_weight else: assigner = self.refine_assigner pos_weight = self.train_cfg.refine.pos_weight assign_result = assigner.assign(pred_instances, gt_instances, gt_instances_ignore) sampling_result = self.sampler.sample(assign_result, pred_instances, gt_instances) gt_inds = assign_result.gt_inds num_valid_proposals = proposals.shape[0] bbox_gt = proposals.new_zeros([num_valid_proposals, 8]) pos_proposals = torch.zeros_like(proposals) proposals_weights = proposals.new_zeros(num_valid_proposals) labels = proposals.new_full((num_valid_proposals, ), self.num_classes, dtype=torch.long) label_weights = proposals.new_zeros( num_valid_proposals, dtype=torch.float) pos_inds = sampling_result.pos_inds neg_inds = sampling_result.neg_inds if len(pos_inds) > 0: bbox_gt[pos_inds, :] = sampling_result.pos_gt_bboxes pos_proposals[pos_inds, :] = proposals[pos_inds, :] proposals_weights[pos_inds] = 1.0 labels[pos_inds] = sampling_result.pos_gt_labels if pos_weight <= 0: label_weights[pos_inds] = 1.0 else: label_weights[pos_inds] = pos_weight if len(neg_inds) > 0: label_weights[neg_inds] = 1.0 # map up to original set of proposals if unmap_outputs: num_total_proposals = flat_proposals.size(0) labels = unmap( labels, num_total_proposals, inside_flags, fill=self.num_classes) # fill bg label label_weights = unmap(label_weights, num_total_proposals, inside_flags) bbox_gt = unmap(bbox_gt, num_total_proposals, inside_flags) pos_proposals = unmap(pos_proposals, num_total_proposals, inside_flags) proposals_weights = unmap(proposals_weights, num_total_proposals, inside_flags) gt_inds = unmap(gt_inds, num_total_proposals, inside_flags) return (labels, label_weights, bbox_gt, pos_proposals, proposals_weights, pos_inds, neg_inds, gt_inds, sampling_result)
[docs] def get_targets(self, proposals_list: List[Tensor], valid_flag_list: List[Tensor], batch_gt_instances: InstanceList, batch_img_metas: List[dict], batch_gt_instances_ignore: OptInstanceList = None, stage: str = 'init', unmap_outputs: bool = True) -> tuple: """Compute corresponding GT box and classification targets for proposals. Args: proposals_list (list[Tensor]): Multi level points/bboxes of each image. valid_flag_list (list[Tensor]): Multi level valid flags of each image. batch_gt_instances (list[:obj:`InstanceData`]): Batch of gt_instance. It usually includes ``bboxes`` and ``labels`` attributes. batch_img_metas (list[dict]): Meta information of each image, e.g., image size, scaling factor, etc. batch_gt_instances_ignore (list[:obj:`InstanceData`], optional): Batch of gt_instances_ignore. It includes ``bboxes`` attribute data that is ignored during training and testing. Defaults to None. stage (str): 'init' or 'refine'. Generate target for init stage or refine stage. unmap_outputs (bool): Whether to map outputs back to the original set of anchors. Returns: tuple: - labels_list (list[Tensor]): Labels of each level. - label_weights_list (list[Tensor]): Label weights of each level. - bbox_gt_list (list[Tensor]): Ground truth bbox of each level. - proposals_list (list[Tensor]): Proposals(points/bboxes) of each level. - proposal_weights_list (list[Tensor]): Proposal weights of each level. - avg_factor (int): Average factor that is used to average the loss. When using sampling method, avg_factor is usually the sum of positive and negative priors. When using `PseudoSampler`, `avg_factor` is usually equal to the number of positive priors. """ assert stage in ['init', 'refine'] num_imgs = len(batch_img_metas) assert len(proposals_list) == len(valid_flag_list) == num_imgs # points number of multi levels num_level_proposals = [points.size(0) for points in proposals_list[0]] # concat all level points and flags to a single tensor for i in range(num_imgs): assert len(proposals_list[i]) == len(valid_flag_list[i]) proposals_list[i] = torch.cat(proposals_list[i]) valid_flag_list[i] = torch.cat(valid_flag_list[i]) if batch_gt_instances_ignore is None: batch_gt_instances_ignore = [None] * num_imgs (all_labels, all_label_weights, all_bbox_gt, all_proposals, all_proposal_weights, pos_inds_list, neg_inds_list, all_gt_inds, sampling_results_list) = multi_apply( self._get_targets_single, proposals_list, valid_flag_list, batch_gt_instances, batch_gt_instances_ignore, stage=stage, unmap_outputs=unmap_outputs) if stage == 'init': # no valid points if any([labels is None for labels in all_labels]): return None # sampled points of all images num_total_pos = sum( [max(inds.numel(), 1) for inds in pos_inds_list]) num_total_neg = sum( [max(inds.numel(), 1) for inds in neg_inds_list]) # avg_refactor = sum( # [results.avg_factor for results in sampling_results_list]) labels_list = images_to_levels(all_labels, num_level_proposals) label_weights_list = images_to_levels(all_label_weights, num_level_proposals) bbox_gt_list = images_to_levels(all_bbox_gt, num_level_proposals) proposals_list = images_to_levels(all_proposals, num_level_proposals) proposal_weights_list = images_to_levels(all_proposal_weights, num_level_proposals) res = (labels_list, label_weights_list, bbox_gt_list, proposals_list, proposal_weights_list, num_total_pos, num_total_neg) else: pos_inds = [] pos_gt_index = [] for i, single_labels in enumerate(all_labels): pos_mask = (0 <= single_labels) & ( single_labels < self.num_classes) pos_inds.append(pos_mask.nonzero(as_tuple=False).view(-1)) pos_gt_index.append( all_gt_inds[i][pos_mask.nonzero(as_tuple=False).view(-1)]) res = (all_labels, all_label_weights, all_bbox_gt, all_proposals, all_proposal_weights, pos_inds, pos_gt_index) return res
[docs] def loss_by_feat_single(self, cls_score: Tensor, pts_pred_init: Tensor, pts_pred_refine: Tensor, labels: Tensor, label_weights, bbox_gt_init: Tensor, bbox_weights_init: Tensor, bbox_gt_refine: Tensor, bbox_weights_refine: Tensor, stride: int, avg_factor_init: int, avg_factor_refine: int) -> Tuple[Tensor]: """Calculate the loss of a single scale level based on the features extracted by the detection head. Args: cls_score (Tensor): Box scores for each scale level Has shape (N, num_classes, h_i, w_i). pts_pred_init (Tensor): Points of shape (batch_size, h_i * w_i, num_points * 2). pts_pred_refine (Tensor): Points refined of shape (batch_size, h_i * w_i, num_points * 2). labels (Tensor): Ground truth class indices with shape (batch_size, h_i * w_i). label_weights (Tensor): Label weights of shape (batch_size, h_i * w_i). bbox_gt_init (Tensor): BBox regression targets in the init stage of shape (batch_size, h_i * w_i, 8). bbox_weights_init (Tensor): BBox regression loss weights in the init stage of shape (batch_size, h_i * w_i, 8). bbox_gt_refine (Tensor): BBox regression targets in the refine stage of shape (batch_size, h_i * w_i, 8). bbox_weights_refine (Tensor): BBox regression loss weights in the refine stage of shape (batch_size, h_i * w_i, 8). stride (int): Point stride. avg_factor_init (int): Average factor that is used to average the loss in the init stage. avg_factor_refine (int): Average factor that is used to average the loss in the refine stage. Returns: Tuple[Tensor]: loss components. """ # classification loss labels = labels.reshape(-1) label_weights = label_weights.reshape(-1) cls_score = cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels) cls_score = cls_score.contiguous() loss_cls = self.loss_cls( cls_score, labels, label_weights, avg_factor=avg_factor_refine) # init loss bbox_gt_init = bbox_gt_init.reshape(-1, 8) pts_pred_init = pts_pred_init.reshape(-1, 2 * self.num_points) bbox_weights_init = bbox_weights_init.reshape(-1) pos_ind_init = (bbox_weights_init > 0).nonzero(as_tuple=False).reshape(-1) pos_bbox_gt_init = bbox_gt_init[pos_ind_init] pos_pts_pred_init = pts_pred_init[pos_ind_init] pos_bbox_weights_init = bbox_weights_init[pos_ind_init] normalize_term = self.point_base_scale * stride loss_pts_init = self.loss_bbox_init( pos_pts_pred_init / normalize_term, pos_bbox_gt_init / normalize_term, pos_bbox_weights_init, avg_factor=avg_factor_init) # refine loss bbox_gt_refine = bbox_gt_refine.reshape(-1, 8) pts_pred_refine = pts_pred_refine.reshape(-1, 2 * self.num_points) bbox_weights_refine = bbox_weights_refine.reshape(-1) pos_ind_refine = (bbox_weights_refine > 0).nonzero(as_tuple=False).reshape(-1) pos_bbox_gt_refine = bbox_gt_refine[pos_ind_refine] pos_pts_pred_refine = pts_pred_refine[pos_ind_refine] pos_bbox_weights_refine = bbox_weights_refine[pos_ind_refine] loss_pts_refine = self.loss_bbox_refine( pos_pts_pred_refine / normalize_term, pos_bbox_gt_refine / normalize_term, pos_bbox_weights_refine, avg_factor=avg_factor_refine) return loss_cls, loss_pts_init, loss_pts_refine
[docs] def loss_by_feat( self, cls_scores: List[Tensor], pts_preds_init: List[Tensor], pts_preds_refine: List[Tensor], base_feat: List[Tensor], batch_gt_instances: InstanceList, batch_img_metas: List[dict], batch_gt_instances_ignore: OptInstanceList = None ) -> Dict[str, Tensor]: """Calculate the loss based on the features extracted by the detection head. Args: cls_scores (list[Tensor]): Box scores for each scale level, each is a 4D-tensor, of shape (batch_size, num_classes, h, w). pts_preds_init (list[Tensor]): Points for each scale level, each is a 3D-tensor, of shape (batch_size, h_i * w_i, num_points * 2). pts_preds_refine (list[Tensor]): Points refined for each scale level, each is a 3D-tensor, of shape (batch_size, h_i * w_i, num_points * 2). batch_gt_instances (list[:obj:`InstanceData`]): Batch of gt_instance. It usually includes ``bboxes`` and ``labels`` attributes. batch_img_metas (list[dict]): Meta information of each image, e.g., image size, scaling factor, etc. batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional): Batch of gt_instances_ignore. It includes ``bboxes`` attribute data that is ignored during training and testing. Defaults to None. Returns: dict[str, Tensor]: A dictionary of loss components. """ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] device = cls_scores[0].device # target for initial stage center_list, valid_flag_list = self.get_points(featmap_sizes, batch_img_metas, device) pts_coordinate_preds_init = self.offset_to_pts(center_list, pts_preds_init) num_proposals_each_level = [(featmap.size(-1) * featmap.size(-2)) for featmap in cls_scores] num_level = len(featmap_sizes) assert num_level == len(pts_coordinate_preds_init) if self.train_cfg.init.assigner['type'] == 'ConvexAssigner': # Assign target for center list candidate_list = center_list else: raise NotImplementedError cls_reg_targets_init = self.get_targets( proposals_list=candidate_list, valid_flag_list=valid_flag_list, batch_gt_instances=batch_gt_instances, batch_img_metas=batch_img_metas, batch_gt_instances_ignore=batch_gt_instances_ignore, stage='init') (*_, bbox_gt_list_init, candidate_list_init, bbox_weights_list_init, num_total_pos_init, num_total_neg_init) = cls_reg_targets_init # target for refinement stage center_list, valid_flag_list = self.get_points(featmap_sizes, batch_img_metas, device) pts_coordinate_preds_refine = self.offset_to_pts( center_list, pts_preds_refine) refine_points_features, = multi_apply(self.get_adaptive_points_feature, base_feat, pts_coordinate_preds_refine, self.point_strides) features_pts_refine = levels_to_images(refine_points_features) features_pts_refine = [ item.reshape(-1, self.num_points, item.shape[-1]) for item in features_pts_refine ] bbox_list = [] for i_img, center in enumerate(center_list): bbox = [] for i_lvl in range(len(pts_preds_refine)): points_preds_init_ = pts_preds_init[i_lvl].detach() points_preds_init_ = points_preds_init_.view( points_preds_init_.shape[0], -1, *points_preds_init_.shape[2:]) points_shift = points_preds_init_.permute( 0, 2, 3, 1) * self.point_strides[i_lvl] points_center = center[i_lvl][:, :2].repeat(1, self.num_points) bbox.append( points_center + points_shift[i_img].reshape(-1, 2 * self.num_points)) bbox_list.append(bbox) cls_reg_targets_refine = self.get_targets( proposals_list=bbox_list, valid_flag_list=valid_flag_list, batch_gt_instances=batch_gt_instances, batch_img_metas=batch_img_metas, batch_gt_instances_ignore=batch_gt_instances_ignore, stage='refine') (labels_list, label_weights_list, bbox_gt_list_refine, candidate_list_refine, bbox_weights_list_refine, pos_inds_list_refine, pos_gt_index_list_refine) = cls_reg_targets_refine cls_scores = levels_to_images(cls_scores) cls_scores = [ item.reshape(-1, self.cls_out_channels) for item in cls_scores ] pts_coordinate_preds_init_img = levels_to_images( pts_coordinate_preds_init, flatten=True) pts_coordinate_preds_init_img = [ item.reshape(-1, 2 * self.num_points) for item in pts_coordinate_preds_init_img ] pts_coordinate_preds_refine_img = levels_to_images( pts_coordinate_preds_refine, flatten=True) pts_coordinate_preds_refine_img = [ item.reshape(-1, 2 * self.num_points) for item in pts_coordinate_preds_refine_img ] with torch.no_grad(): quality_assess_list, = multi_apply( self.pointsets_quality_assessment, features_pts_refine, cls_scores, pts_coordinate_preds_init_img, pts_coordinate_preds_refine_img, labels_list, bbox_gt_list_refine, label_weights_list, bbox_weights_list_refine, pos_inds_list_refine) labels_list, label_weights_list, bbox_weights_list_refine, \ num_pos, pos_normalize_term = multi_apply( self.dynamic_pointset_samples_selection, quality_assess_list, labels_list, label_weights_list, bbox_weights_list_refine, pos_inds_list_refine, pos_gt_index_list_refine, num_proposals_each_level=num_proposals_each_level, num_level=num_level ) num_pos = sum(num_pos) # convert all tensor list to a flatten tensor cls_scores = torch.cat(cls_scores, 0).view(-1, cls_scores[0].size(-1)) pts_preds_refine = torch.cat(pts_coordinate_preds_refine_img, 0).view( -1, pts_coordinate_preds_refine_img[0].size(-1)) labels = torch.cat(labels_list, 0).view(-1) labels_weight = torch.cat(label_weights_list, 0).view(-1) bbox_gt_refine = torch.cat(bbox_gt_list_refine, 0).view(-1, bbox_gt_list_refine[0].size(-1)) bbox_weights_refine = torch.cat(bbox_weights_list_refine, 0).view(-1) pos_normalize_term = torch.cat(pos_normalize_term, 0).reshape(-1) pos_inds_flatten = ((0 <= labels) & (labels < self.num_classes)).nonzero( as_tuple=False).reshape(-1) assert len(pos_normalize_term) == len(pos_inds_flatten) if num_pos: losses_cls = self.loss_cls( cls_scores, labels, labels_weight, avg_factor=num_pos) pos_pts_pred_refine = pts_preds_refine[pos_inds_flatten] pos_bbox_gt_refine = bbox_gt_refine[pos_inds_flatten] pos_bbox_weights_refine = bbox_weights_refine[pos_inds_flatten] losses_pts_refine = self.loss_bbox_refine( pos_pts_pred_refine / pos_normalize_term.reshape(-1, 1), pos_bbox_gt_refine / pos_normalize_term.reshape(-1, 1), pos_bbox_weights_refine) loss_border_refine = self.loss_spatial_refine( pos_pts_pred_refine.reshape(-1, 2 * self.num_points) / pos_normalize_term.reshape(-1, 1), pos_bbox_gt_refine / pos_normalize_term.reshape(-1, 1), pos_bbox_weights_refine, avg_factor=None) else: losses_cls = cls_scores.sum() * 0 losses_pts_refine = pts_preds_refine.sum() * 0 loss_border_refine = pts_preds_refine.sum() * 0 losses_pts_init, loss_border_init = multi_apply( self.init_loss_single, pts_coordinate_preds_init, bbox_gt_list_init, bbox_weights_list_init, self.point_strides) loss_dict_all = { 'loss_cls': losses_cls, 'loss_pts_init': losses_pts_init, 'loss_pts_refine': losses_pts_refine, 'loss_spatial_init': loss_border_init, 'loss_spatial_refine': loss_border_refine } return loss_dict_all
[docs] def sampling_points(self, polygons: Tensor, points_num: int, device: str) -> Tensor: """Sample edge points for polygon. Args: polygons (Tensor): polygons with shape (N, 8) points_num (int): number of sampling points for each polygon edge. 10 by default. device (str): The device the tensor will be put on. Defaults to ``cuda``. Returns: sampling_points (Tensor): sampling points with shape (N, points_num*4, 2) """ polygons_xs, polygons_ys = polygons[:, 0::2], polygons[:, 1::2] ratio = torch.linspace(0, 1, points_num).to(device).repeat( polygons.shape[0], 1) edge_pts_x = [] edge_pts_y = [] for i in range(4): if i < 3: points_x = ratio * polygons_xs[:, i + 1:i + 2] + ( 1 - ratio) * polygons_xs[:, i:i + 1] points_y = ratio * polygons_ys[:, i + 1:i + 2] + ( 1 - ratio) * polygons_ys[:, i:i + 1] else: points_x = ratio * polygons_xs[:, 0].unsqueeze(1) + ( 1 - ratio) * polygons_xs[:, i].unsqueeze(1) points_y = ratio * polygons_ys[:, 0].unsqueeze(1) + ( 1 - ratio) * polygons_ys[:, i].unsqueeze(1) edge_pts_x.append(points_x) edge_pts_y.append(points_y) sampling_points_x = torch.cat(edge_pts_x, dim=1).unsqueeze(dim=2) sampling_points_y = torch.cat(edge_pts_y, dim=1).unsqueeze(dim=2) sampling_points = torch.cat([sampling_points_x, sampling_points_y], dim=2) return sampling_points
[docs] def get_adaptive_points_feature(self, features: Tensor, pt_locations: Tensor, stride: int) -> Tensor: """Get the points features from the locations of predicted points. Args: features (Tensor): base feature with shape (B,C,W,H) pt_locations (Tensor): locations of points in each point set with shape (B, N_points_set(number of point set), N_points(number of points in each point set) *2) stride (int): points strdie Returns: Tensor: sampling features with (B, C, N_points_set, N_points) """ h = features.shape[2] * stride w = features.shape[3] * stride pt_locations = pt_locations.view(pt_locations.shape[0], pt_locations.shape[1], -1, 2).clone() pt_locations[..., 0] = pt_locations[..., 0] / (w / 2.) - 1 pt_locations[..., 1] = pt_locations[..., 1] / (h / 2.) - 1 batch_size = features.size(0) sampled_features = torch.zeros([ pt_locations.shape[0], features.size(1), pt_locations.size(1), pt_locations.size(2) ]).to(pt_locations.device) for i in range(batch_size): feature = nn.functional.grid_sample(features[i:i + 1], pt_locations[i:i + 1])[0] sampled_features[i] = feature return sampled_features,
[docs] def feature_cosine_similarity(self, points_features: Tensor) -> Tensor: """Compute the points features similarity for points-wise correlation. Args: points_features (Tensor): sampling point feature with shape (N_pointsets, N_points, C) Returns: max_correlation (Tensor): max feature similarity in each point set with shape (N_points_set, N_points, C) """ mean_points_feats = torch.mean(points_features, dim=1, keepdim=True) norm_pts_feats = torch.norm( points_features, p=2, dim=2).unsqueeze(dim=2).clamp(min=1e-2) norm_mean_pts_feats = torch.norm( mean_points_feats, p=2, dim=2).unsqueeze(dim=2).clamp(min=1e-2) unity_points_features = points_features / norm_pts_feats unity_mean_points_feats = mean_points_feats / norm_mean_pts_feats feats_similarity = 1.0 - F.cosine_similarity( unity_points_features, unity_mean_points_feats, dim=2, eps=1e-6) max_correlation, _ = torch.max(feats_similarity, dim=1) return max_correlation
[docs] def pointsets_quality_assessment(self, pts_features: Tensor, cls_score: Tensor, pts_pred_init: Tensor, pts_pred_refine: Tensor, label: Tensor, bbox_gt: Tensor, label_weight: Tensor, bbox_weight: Tensor, pos_inds: Tensor) -> Tensor: """Assess the quality of each point set from the classification, localization, orientation, and point-wise correlation based on the assigned point sets samples. Args: pts_features (Tensor): points features with shape (N, 9, C) cls_score (Tensor): classification scores with shape (N, class_num) pts_pred_init (Tensor): initial point sets prediction with shape (N, 9*2) pts_pred_refine (Tensor): refined point sets prediction with shape (N, 9*2) label (Tensor): gt label with shape (N) bbox_gt(Tensor): gt bbox of polygon with shape (N, 8) label_weight (Tensor): label weight with shape (N) bbox_weight (Tensor): box weight with shape (N) pos_inds (Tensor): the inds of positive point set samples Returns: qua (Tensor) : weighted quality values for positive point set samples. """ device = cls_score.device # avoid no positive samplers if pos_inds.shape[0] == 0: pos_scores = cls_score pos_pts_pred_init = pts_pred_init pos_pts_pred_refine = pts_pred_refine pos_pts_refine_features = pts_features pos_bbox_gt = bbox_gt pos_label = label pos_label_weight = label_weight pos_bbox_weight = bbox_weight else: pos_scores = cls_score[pos_inds] pos_pts_pred_init = pts_pred_init[pos_inds] pos_pts_pred_refine = pts_pred_refine[pos_inds] pos_pts_refine_features = pts_features[pos_inds] pos_bbox_gt = bbox_gt[pos_inds] pos_label = label[pos_inds] pos_label_weight = label_weight[pos_inds] pos_bbox_weight = bbox_weight[pos_inds] # quality of point-wise correlation qua_poc = self.poc_qua_weight * self.feature_cosine_similarity( pos_pts_refine_features) qua_cls = self.loss_cls( pos_scores, pos_label, pos_label_weight, avg_factor=self.loss_cls.loss_weight, reduction_override='none') polygons_pred_init = min_area_polygons(pos_pts_pred_init) polygons_pred_refine = min_area_polygons(pos_pts_pred_refine) sampling_pts_pred_init = self.sampling_points( polygons_pred_init, 10, device=device) sampling_pts_pred_refine = self.sampling_points( polygons_pred_refine, 10, device=device) sampling_pts_gt = self.sampling_points(pos_bbox_gt, 10, device=device) # quality of orientation qua_ori_init = self.ori_qua_weight * ChamferDistance2D( sampling_pts_gt, sampling_pts_pred_init) qua_ori_refine = self.ori_qua_weight * ChamferDistance2D( sampling_pts_gt, sampling_pts_pred_refine) # quality of localization qua_loc_init = self.loss_bbox_refine( pos_pts_pred_init, pos_bbox_gt, pos_bbox_weight, avg_factor=self.loss_cls.loss_weight, reduction_override='none') qua_loc_refine = self.loss_bbox_refine( pos_pts_pred_refine, pos_bbox_gt, pos_bbox_weight, avg_factor=self.loss_cls.loss_weight, reduction_override='none') # quality of classification qua_cls = qua_cls.sum(-1) # weighted inti-stage and refine-stage qua = qua_cls + self.init_qua_weight * ( qua_loc_init + qua_ori_init) + (1.0 - self.init_qua_weight) * ( qua_loc_refine + qua_ori_refine) + qua_poc return qua,
[docs] def dynamic_pointset_samples_selection( self, quality: Tensor, label: Tensor, label_weight: Tensor, bbox_weight: Tensor, pos_inds: Tensor, pos_gt_inds: Tensor, num_proposals_each_level: Optional[List[int]] = None, num_level: Optional[int] = None) -> tuple: """The dynamic top k selection of point set samples based on the quality assessment values. Args: quality (Tensor): the quality values of positive point set samples label (Tensor): gt label with shape (N) label_weight (Tensor): label weight with shape (N) bbox_weight (Tensor): box weight with shape (N) pos_inds (Tensor): the inds of positive point set samples pos_gt_inds (Tensor): the inds of positive ground truth num_proposals_each_level (list[int]): proposals number of each level num_level (int): the level number Returns: tuple: - label: gt label with shape (N) - label_weight: label weight with shape (N) - bbox_weight: box weight with shape (N) - num_pos (int): the number of selected positive point samples with high-quality - pos_normalize_term (Tensor): the corresponding positive normalize term """ if len(pos_inds) == 0: return label, label_weight, bbox_weight, 0, Tensor( []).type_as(bbox_weight) num_gt = pos_gt_inds.max() num_proposals_each_level_ = num_proposals_each_level.copy() num_proposals_each_level_.insert(0, 0) inds_level_interval = np.cumsum(num_proposals_each_level_) pos_level_mask = [] for i in range(num_level): mask = (pos_inds >= inds_level_interval[i]) & ( pos_inds < inds_level_interval[i + 1]) pos_level_mask.append(mask) pos_inds_after_select = [] ignore_inds_after_select = [] for gt_ind in range(num_gt): pos_inds_select = [] pos_loss_select = [] gt_mask = pos_gt_inds == (gt_ind + 1) for level in range(num_level): level_mask = pos_level_mask[level] level_gt_mask = level_mask & gt_mask value, topk_inds = quality[level_gt_mask].topk( min(level_gt_mask.sum(), 6), largest=False) pos_inds_select.append(pos_inds[level_gt_mask][topk_inds]) pos_loss_select.append(value) pos_inds_select = torch.cat(pos_inds_select) pos_loss_select = torch.cat(pos_loss_select) if len(pos_inds_select) < 2: pos_inds_after_select.append(pos_inds_select) ignore_inds_after_select.append(pos_inds_select.new_tensor([])) else: pos_loss_select, sort_inds = pos_loss_select.sort( ) # small to large pos_inds_select = pos_inds_select[sort_inds] # dynamic top k topk = math.ceil(pos_loss_select.shape[0] * self.top_ratio) pos_inds_select_topk = pos_inds_select[:topk] pos_inds_after_select.append(pos_inds_select_topk) ignore_inds_after_select.append( pos_inds_select_topk.new_tensor([])) pos_inds_after_select = torch.cat(pos_inds_after_select) ignore_inds_after_select = torch.cat(ignore_inds_after_select) reassign_mask = (pos_inds.unsqueeze(1) != pos_inds_after_select).all(1) reassign_ids = pos_inds[reassign_mask] label[reassign_ids] = self.num_classes label_weight[ignore_inds_after_select] = 0 bbox_weight[reassign_ids] = 0 num_pos = len(pos_inds_after_select) pos_level_mask_after_select = [] for i in range(num_level): mask = (pos_inds_after_select >= inds_level_interval[i]) & ( pos_inds_after_select < inds_level_interval[i + 1]) pos_level_mask_after_select.append(mask) pos_level_mask_after_select = torch.stack(pos_level_mask_after_select, 0).type_as(label) pos_normalize_term = pos_level_mask_after_select * ( self.point_base_scale * torch.as_tensor(self.point_strides).type_as(label)).reshape(-1, 1) pos_normalize_term = pos_normalize_term[pos_normalize_term > 0].type_as(bbox_weight) assert len(pos_normalize_term) == len(pos_inds_after_select) return label, label_weight, bbox_weight, num_pos, pos_normalize_term
[docs] def init_loss_single(self, pts_pred_init: Tensor, bbox_gt_init: Tensor, bbox_weights_init: Tensor, stride: int) -> Tuple[Tensor, Tensor]: """Single initial stage loss function. Args: pts_pred_init (Tensor): Initial point sets prediction with shape (N, 9*2) bbox_gt_init (Tensor): BBox regression targets in the init stage of shape (batch_size, h_i * w_i, 8). bbox_weights_init (Tensor): BBox regression loss weights in the init stage of shape (batch_size, h_i * w_i, 8). stride (int): Point stride. Returns: tuple: - loss_pts_init (Tensor): Initial bbox loss. - loss_border_init (Tensor): Initial spatial border loss. """ normalize_term = self.point_base_scale * stride bbox_gt_init = bbox_gt_init.reshape(-1, 8) bbox_weights_init = bbox_weights_init.reshape(-1) pts_pred_init = pts_pred_init.reshape(-1, 2 * self.num_points) pos_ind_init = (bbox_weights_init > 0).nonzero(as_tuple=False).reshape(-1) pts_pred_init_norm = pts_pred_init[pos_ind_init] bbox_gt_init_norm = bbox_gt_init[pos_ind_init] bbox_weights_pos_init = bbox_weights_init[pos_ind_init] loss_pts_init = self.loss_bbox_init( pts_pred_init_norm / normalize_term, bbox_gt_init_norm / normalize_term, bbox_weights_pos_init) loss_border_init = self.loss_spatial_init( pts_pred_init_norm.reshape(-1, 2 * self.num_points) / normalize_term, bbox_gt_init_norm / normalize_term, bbox_weights_pos_init, avg_factor=None) return loss_pts_init, loss_border_init