Shortcuts

Source code for mmrotate.utils.patch.merge_results

# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np

try:
    from mmcv.ops import batched_nms
except ImportError:  # noqa: E722

    def batched_nms(*args, **kwargs):
        raise RuntimeError('batched_nms from mmcv.ops is not available. '
                           'Please install onedl-mmcv with ops support.')


from mmdet.structures import DetDataSample, SampleList
from mmengine.structures import InstanceData
from torch import Tensor
from typing import Sequence, Tuple


def translate_bboxes(bboxes: Tensor, offset: Sequence[int]):
    """Translate bboxes w.r.t offset.

    The bboxes can be three types:

    - HorizontalBoxes: The boxes should be a tensor with shape of (n, 4),
      which means (x, y, x, y).
    - RotatedBoxes: The boxes should be a tensor with shape of (n, 5),
      which means (x, y, w, h, t).
    - QuariBoxes: The boxes should be a tensor with shape of (n, 8),
      which means (x1, y1, x2, y2, x3, y3, x4, y4).

    Args:
        bboxes (Tensor): The bboxes need to be translated. Its shape can
            be (n, 4), (n, 5), or (n, 8).
        offset (Sequence[int]): The translation offsets with shape of (2, ).

    Returns:
        Tensor: Translated bboxes.
    """
    if bboxes.shape[1] == 4:
        offset = bboxes.new_tensor(offset).tile(2)
        bboxes = bboxes + offset
    elif bboxes.shape[1] == 5:
        offset = bboxes.new_tensor(offset)
        bboxes[:, :2] = bboxes[:, :2] + offset
    elif bboxes.shape[1] == 8:
        offset = bboxes.new_tensor(offset).tile(4)
        bboxes = bboxes + offset
    else:
        raise TypeError('Require the shape of `bboxes` to be (n, 5), (n, 6)'
                        'or (n, 8), but get `bboxes` with shape being '
                        f'{bboxes.shape}.')
    return bboxes


def map_masks(masks: np.ndarray, offset: Sequence[int],
              new_shape: Sequence[int]) -> np.ndarray:
    """Map masks to the huge image.

    Args:
        masks (:obj:`np.ndarray`): masks need to be mapped.
        offset (Sequence[int]): The offset to translate with shape of (2, ).
        new_shape (Sequence[int]): A tuple of the huge image's width
            and height.

    Returns:
        :obj:`np.ndarray`: Mapped masks.
    """
    # empty masks
    if not masks:
        return masks

    new_width, new_height = new_shape
    x_start, y_start = offset
    mapped = []
    for mask in masks:
        ori_height, ori_width = mask.shape[:2]

        x_end = x_start + ori_width
        if x_end > new_width:
            ori_width -= x_end - new_width
            x_end = new_width

        y_end = y_start + ori_height
        if y_end > new_height:
            ori_height -= y_end - new_height
            y_end = new_height

        extended_mask = np.zeros((new_height, new_width), dtype=bool)
        extended_mask[y_start:y_end,
                      x_start:x_end] = mask[:ori_height, :ori_width]
        mapped.append(extended_mask)
    return np.stack(mapped, axis=0)


[docs] def merge_results_by_nms(results: SampleList, offsets: np.ndarray, img_shape: Tuple[int, int], nms_cfg: dict) -> DetDataSample: """Merge patch results by nms. Args: results (List[:obj:`DetDataSample`]): A list of patches results. offsets (:obj:`np.ndarray`): Positions of the left top points of patches. img_shape (Tuple[int, int]): A tuple of the huge image's width and height. nms_cfg (dict): it should specify nms type and other parameters like `iou_threshold`. Returns: :obj:`DetDataSample`: merged results. """ assert len(results) == offsets.shape[0], 'The `results` should has the ' \ 'same length with `offsets`.' pred_instances = [] for result, offset in zip(results, offsets): pred_inst = result.pred_instances pred_inst.bboxes = translate_bboxes(pred_inst.bboxes, offset) if 'masks' in result: pred_inst.masks = map_masks(pred_inst.masks, offset, img_shape) pred_instances.append(pred_inst) instances = InstanceData.cat(pred_instances) _, keeps = batched_nms( boxes=instances.bboxes, scores=instances.scores, idxs=instances.labels, nms_cfg=nms_cfg) merged_instances = instances[keeps] merged_result = DetDataSample() # update items like gt_instances, ignore_instances merged_result.update(results[0]) merged_result.pred_instances = merged_instances return merged_result