# Source code for mmrotate.models.utils.misc
# Copyright (c) OpenMMLab. All rights reserved.
import torch
try:
    # Preferred path: use the ``convex_iou`` implementation from ``mmcv.ops``.
    from mmcv.ops import convex_iou
except ImportError:  # noqa: E722
    # The import failed: define a stand-in with the same name so that this
    # module can still be imported. The error is deferred until the function
    # is actually called.
    def convex_iou(*args, **kwargs):
        """Placeholder used when ``mmcv.ops.convex_iou`` cannot be imported.

        Accepts any arguments and ignores them.

        Raises:
            RuntimeError: Always, whenever the placeholder is called.
        """
        raise RuntimeError('convex_iou from mmcv.ops is not available. '
                           'Please install onedl-mmcv with ops support.')
[docs]
def points_center_pts(RPoints, y_first=True):
    """Return the mean (center) point of every 9-point set.

    Args:
        RPoints (torch.Tensor): Point sets, shape (k, 18). They are
            reshaped internally to (k, 9, 2), i.e. nine coordinate pairs
            per set.
        y_first (bool, optional): If True, each coordinate pair is stored
            as (y, x); otherwise as (x, y). Defaults to True.

    Returns:
        torch.Tensor: Center of each point set, always ordered (x, y),
        shape (k, 2).
    """
    point_pairs = RPoints.reshape(-1, 9, 2)
    # Column 0 of each pair holds y when ``y_first`` is set, otherwise x.
    y_col, x_col = (0, 1) if y_first else (1, 0)
    ys = point_pairs[:, :, y_col::2]
    xs = point_pairs[:, :, x_col::2]
    # Average over the nine points of each set (dim 1).
    mean_x = xs.mean(dim=1, keepdim=True).reshape(-1, 1)
    mean_y = ys.mean(dim=1, keepdim=True).reshape(-1, 1)
    # The output is (x, y) regardless of the input ordering.
    return torch.cat([mean_x, mean_y], dim=1).reshape(-1, 2)
[docs]
def convex_overlaps(gt_bboxes, points):
    """Compute overlaps between ground-truth polygons and point sets.

    Args:
        gt_bboxes (torch.Tensor): Ground-truth polygons, shape (k, 8).
        points (torch.Tensor): Point sets to be assigned, shape (n, 18).

    Returns:
        torch.Tensor: Overlaps between the k ground-truth polygons and the
        n point sets, shape (k, n).
    """
    # ``convex_iou`` receives the point sets first; its result is expected
    # to be indexed (n, k), so the two axes are swapped before returning.
    return convex_iou(points, gt_bboxes).transpose(1, 0)
[docs]
def levels_to_images(mlvl_tensor, flatten=False):
    """Regroup multi-level feature maps so that they are indexed by image.

    [feature_level0, feature_level1...] -> [feature_image0, feature_image1...]

    Every level is brought to shape (N, H*W, C) and split along the batch
    dimension; the pieces that belong to the same image are then
    concatenated across all levels along the first dimension.

    Args:
        mlvl_tensor (list[torch.Tensor]): One tensor per level. Each is of
            shape (N, C, H, W) when ``flatten`` is False, or (N, H, W, C)
            when ``flatten`` is True. Must contain at least one level.
        flatten (bool, optional): Set to True when the channel dimension
            is already the last one, so no permutation is applied.
            Defaults to False.

    Returns:
        list[torch.Tensor]: A list of N tensors, one per image, each of
        shape (num_elements, C).
    """
    batch_size = mlvl_tensor[0].size(0)
    # Channels sit in the last dim for channel-last input, else in dim 1.
    channels = mlvl_tensor[0].size(-1 if flatten else 1)
    per_image = [[] for _ in range(batch_size)]
    for level in mlvl_tensor:
        if not flatten:
            # (N, C, H, W) -> (N, H, W, C)
            level = level.permute(0, 2, 3, 1)
        # ``reshape`` rather than ``view``: the permuted tensor is usually
        # non-contiguous, and ``view`` raises when the spatial dims cannot
        # be merged without a copy (e.g. a transposed input). ``reshape``
        # gives the same result and only copies when it has to.
        level = level.reshape(batch_size, -1, channels).contiguous()
        for img_id, chunks in enumerate(per_image):
            chunks.append(level[img_id])
    return [torch.cat(chunks, 0) for chunks in per_image]
[docs]
def get_num_level_anchors_inside(num_level_anchors, inside_flags):
    """Count the inside anchors of each level.

    Args:
        num_level_anchors (List[int]): Number of anchors on every level;
            used as the section sizes when splitting ``inside_flags``.
        inside_flags (torch.Tensor): Flags of all anchors, concatenated
            over the levels.

    Returns:
        List[int]: Number of inside anchors on each level.
    """
    inside_counts = []
    # One chunk of flags per level, sized according to ``num_level_anchors``.
    for level_flags in torch.split(inside_flags, num_level_anchors):
        inside_counts.append(int(level_flags.sum()))
    return inside_counts