Shortcuts

Source code for mmrotate.models.utils.misc

# Copyright (c) OpenMMLab. All rights reserved.
import torch

try:
    from mmcv.ops import convex_iou
except ImportError as _import_error:  # mmcv (or its compiled ops) unavailable
    # Python unbinds the ``as`` target when the except clause ends, so keep a
    # module-level reference that the fallback below can chain from.
    _CONVEX_IOU_IMPORT_ERROR = _import_error

    def convex_iou(*args, **kwargs):
        """Placeholder used when ``mmcv.ops.convex_iou`` cannot be imported.

        Raises:
            RuntimeError: Always. The ``ImportError`` raised while importing
                ``mmcv.ops`` is attached as ``__cause__`` so the real reason
                for the missing op is visible in the traceback.
        """
        raise RuntimeError('convex_iou from mmcv.ops is not available. '
                           'Please install onedl-mmcv with ops support.'
                           ) from _CONVEX_IOU_IMPORT_ERROR


def points_center_pts(RPoints, y_first=True, num_points=9):
    """Compute the mean center point of each point set.

    Args:
        RPoints (torch.Tensor): Point sets with the two coordinates of each
            point stored consecutively, shape (k, 2 * num_points). The input
            is reshaped to (k, num_points, 2), so any tensor with
            k * 2 * num_points elements in that order is accepted.
        y_first (bool, optional): If True, each point is stored as (y, x);
            otherwise as (x, y). Defaults to True.
        num_points (int, optional): Number of points in each set.
            Defaults to 9.

    Returns:
        torch.Tensor: Mean center of each point set in (x, y) order,
        shape (k, 2).
    """
    pts = RPoints.reshape(-1, num_points, 2)
    # Average over the points axis; the two columns keep the input order.
    center_pts = pts.mean(dim=1)
    if y_first:
        # Input columns are (y, x); swap them to return (x, y).
        center_pts = center_pts.flip(-1)
    return center_pts
def convex_overlaps(gt_bboxes, points):
    """Compute overlaps between ground-truth polygons and point sets.

    Args:
        gt_bboxes (torch.Tensor): Ground-truth polygons, shape (k, 8).
        points (torch.Tensor): Point sets to be assigned, shape (n, 18).

    Returns:
        torch.Tensor: Overlaps between the k gt_bboxes and the n point sets,
        shape (k, n).
    """
    # ``convex_iou`` takes the point sets first. Given the (k, n) result
    # documented above, it presumably returns (n, k); transpose so that rows
    # index ground truths.
    overlaps = convex_iou(points, gt_bboxes)
    return overlaps.transpose(1, 0)
def levels_to_images(mlvl_tensor, flatten=False):
    """Concat multi-level feature maps by image.

    [feature_level0, feature_level1...] -> [feature_image0, feature_image1...]
    Convert the shape of each element in mlvl_tensor from (N, C, H, W) to
    (N, H*W, C), then split the element into N elements with shape (H*W, C),
    and concat the elements of the same image from all levels along the
    first dimension.

    Args:
        mlvl_tensor (list[torch.Tensor]): List of tensors collected from the
            corresponding levels. Each element is of shape (N, C, H, W), or
            (N, H, W, C) when ``flatten`` is True.
        flatten (bool, optional): Set False if the elements of mlvl_tensor
            are (N, C, H, W); set True if they are (N, H, W, C).

    Returns:
        list[torch.Tensor]: A list that contains N tensors and each tensor is
        of shape (num_elements, C).
    """
    batch_size = mlvl_tensor[0].size(0)
    channels = mlvl_tensor[0].size(-1 if flatten else 1)
    batch_list = [[] for _ in range(batch_size)]
    for t in mlvl_tensor:
        if not flatten:
            # Move channels last so each spatial location becomes one row.
            t = t.permute(0, 2, 3, 1)
        # ``reshape`` (unlike ``view``) also handles inputs whose H/W strides
        # cannot be merged, e.g. a channels-last tensor produced by a
        # transpose; it only copies when a view is impossible.
        t = t.reshape(batch_size, -1, channels).contiguous()
        for img in range(batch_size):
            batch_list[img].append(t[img])
    return [torch.cat(item, 0) for item in batch_list]
def get_num_level_anchors_inside(num_level_anchors, inside_flags):
    """Count, for every level, how many anchors are flagged as inside.

    Args:
        num_level_anchors (List[int]): Number of anchors on each level; used
            as the section sizes when splitting ``inside_flags``.
        inside_flags (torch.Tensor): Flags of all anchors, laid out level by
            level.

    Returns:
        List[int]: Per-level sum of the flags, i.e. the number of inside
        anchors when the flags are boolean or 0/1.
    """
    inside_counts = []
    for level_flags in inside_flags.split(num_level_anchors):
        inside_counts.append(int(level_flags.sum()))
    return inside_counts