# Source code for mmrotate.apis.inference
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
import torch
try:
    from mmcv.ops import RoIPool
except ImportError:

    class RoIPool:
        """Placeholder used when ``mmcv.ops.RoIPool`` cannot be imported.

        The name stays importable and usable in ``isinstance`` checks, but
        any attempt to instantiate it raises ``RuntimeError``.
        """

        def __init__(self, *args, **kwargs):
            # Fail loudly at construction time instead of at import time.
            message = ('RoIPool from mmcv.ops is not available. '
                       'Please install onedl-mmcv with ops support.')
            raise RuntimeError(message)
from mmcv.transforms import Compose
from mmdet.structures import DetDataSample, SampleList
from torch import nn
from typing import List, Optional, Sequence, Union
from mmrotate.utils import (get_multiscale_patch, get_test_pipeline_cfg,
merge_results_by_nms, slide_window)
# Accepted image inputs: a single string or ndarray, or a sequence of either.
# Non-ndarray entries are read with ``mmcv.imread`` by the inference function.
ImagesType = Union[str, np.ndarray, Sequence[str], Sequence[np.ndarray]]
# [docs]
def inference_detector_by_patches(
        model: nn.Module,
        imgs: ImagesType,
        sizes: List[int],
        steps: List[int],
        ratios: List[float],
        nms_cfg: dict,
        test_pipeline: Optional[Compose] = None,
        bs: int = 1) -> Union[DetDataSample, SampleList]:
    """Inference patches with the detector.

    Split huge image(s) into patches and inference them with the detector.
    Finally, merge patch results on one huge image by nms.

    Args:
        model (nn.Module): The loaded detector.
        imgs (str, ndarray, Sequence[str/ndarray]): Either image files or
            loaded images.
        sizes (list[int]): The sizes of patches.
        steps (list[int]): The steps between two patches.
        ratios (list[float]): Image resizing ratios for multi-scale detecting.
        nms_cfg (dict): nms config.
        test_pipeline (Compose, optional): Pipeline applied to every patch.
            When None, it is built from ``model.cfg``: annotation-loading
            transforms are dropped and the first remaining transform is
            switched to ``'LoadPatchFromNDArray'``. Defaults to None.
        bs (int): Batch size, must greater than or equal to 1.

    Returns:
        DetDataSample or list[DetDataSample]: The merged result of
        ``merge_results_by_nms`` for the image when ``imgs`` is not a
        list/tuple, otherwise a list with one merged result per image
        (types as declared by the return annotation).
    """
    assert bs >= 1, 'The batch size must greater than or equal to 1'

    if isinstance(imgs, (list, tuple)):
        is_batch = True
    else:
        imgs = [imgs]
        is_batch = False

    if test_pipeline is None:
        cfg = model.cfg.copy()
        pipeline_cfg = get_test_pipeline_cfg(cfg)
        # Inference has no ground truth, so drop annotation loaders.
        new_test_pipeline = [
            transform for transform in pipeline_cfg
            if transform['type'] not in ('LoadAnnotations',
                                         'LoadPanopticAnnotations')
        ]
        # Set the loading transform type on the filtered pipeline, i.e. on
        # the list that is actually composed below.
        new_test_pipeline[0].type = 'LoadPatchFromNDArray'
        test_pipeline = Compose(new_test_pipeline)

    if model.data_preprocessor.device.type == 'cpu':
        for m in model.modules():
            assert not isinstance(
                m, RoIPool
            ), 'CPU inference with RoIPool is not supported currently.'

    result_list = []
    for img in imgs:
        if not isinstance(img, np.ndarray):
            img = mmcv.imread(img)
        height, width = img.shape[:2]

        # Expand into fresh names: rebinding ``sizes``/``steps`` here would
        # feed the already expanded lists back into the expansion for the
        # next image, compounding the multi-scale ratios.
        patch_sizes, patch_steps = get_multiscale_patch(sizes, steps, ratios)
        patches = slide_window(width, height, patch_sizes, patch_steps)

        results = []
        start = 0
        # The body always runs at least once, then stops after the batch
        # that reaches the end of ``patches``.
        while True:
            # prepare patch data
            patch_datas = dict(inputs=[], data_samples=[])
            end = min(start + bs, len(patches))
            for patch in patches[start:end]:
                data_ = dict(
                    img=img, img_id=0, img_path=None, patch=patch.tolist())
                data = test_pipeline(data_)
                patch_datas['inputs'].append(data['inputs'])
                patch_datas['data_samples'].append(data['data_samples'])

            # forward the model
            with torch.no_grad():
                results.extend(model.test_step(patch_datas))

            if end >= len(patches):
                break
            start += bs

        # The first two columns of ``patches`` are passed along with the
        # per-patch results — presumably each patch's x/y offset in the
        # full image; confirm against ``slide_window``.
        result_list.append(
            merge_results_by_nms(
                results,
                patches[:, :2],
                img_shape=(width, height),
                nms_cfg=nms_cfg,
            ))

    return result_list if is_batch else result_list[0]