Shortcuts

Source code for mmrotate.datasets.dior

# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
import os.path as osp
import xml.etree.ElementTree as ET
from mmengine.dataset import BaseDataset
from mmengine.fileio import get, get_local_path, list_from_file
from typing import List, Optional, Union

from mmrotate.registry import DATASETS


@DATASETS.register_module()
class DIORDataset(BaseDataset):
    """DIOR dataset for detection.

    Annotations are read from one XML file per image. Every object is
    stored as an 8-value polygon ``(x1, y1, x2, y2, x3, y3, x4, y4)`` no
    matter whether oriented (``obb``) or horizontal (``hbb``) boxes are
    requested.

    Args:
        ann_subdir (str): Subdir where annotations are.
            Defaults to 'Annotations/Oriented Bounding Boxes/'.
        file_client_args (dict): Arguments to instantiate the
            corresponding backend in mmdet <= 3.0.0rc6. Deprecated: any
            value other than None raises ``RuntimeError``.
            Defaults to None.
        backend_args (dict, optional): Arguments to instantiate the
            corresponding backend. Defaults to None.
        ann_type (str): Choose obb or hbb as ground truth.
            Defaults to `obb`.
    """

    METAINFO = {
        'classes':
        ('airplane', 'airport', 'baseballfield', 'basketballcourt', 'bridge',
         'chimney', 'expressway-service-area', 'expressway-toll-station',
         'dam', 'golffield', 'groundtrackfield', 'harbor', 'overpass', 'ship',
         'stadium', 'storagetank', 'tenniscourt', 'trainstation', 'vehicle',
         'windmill'),
        # palette is a list of color tuples, which is used for visualization.
        'palette': [(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230),
                    (106, 0, 228), (0, 60, 100), (0, 80, 100), (0, 0, 70),
                    (0, 0, 192), (250, 170, 30), (100, 170, 30),
                    (220, 220, 0), (175, 116, 175), (250, 0, 30),
                    (165, 42, 42), (255, 77, 255), (0, 226, 252),
                    (182, 182, 255), (0, 82, 0), (120, 166, 157)]
    }

    # Child tags of ``<robndbox>`` in polygon order (corners listed
    # consecutively around the box).
    _OBB_COORD_TAGS = ('x_left_top', 'y_left_top', 'x_right_top',
                       'y_right_top', 'x_right_bottom', 'y_right_bottom',
                       'x_left_bottom', 'y_left_bottom')
    # Child tags of ``<bndbox>`` expanded to the same 4-corner polygon order.
    _HBB_COORD_TAGS = ('xmin', 'ymin', 'xmax', 'ymin', 'xmax', 'ymax', 'xmin',
                       'ymax')

    def __init__(self,
                 ann_subdir: str = 'Annotations/Oriented Bounding Boxes/',
                 file_client_args: Optional[dict] = None,
                 backend_args: Optional[dict] = None,
                 ann_type: str = 'obb',
                 **kwargs) -> None:
        """Store the DIOR-specific options and initialise the base dataset.

        Raises:
            AssertionError: If ``ann_type`` is neither 'hbb' nor 'obb'.
            RuntimeError: If the deprecated ``file_client_args`` is given.
        """
        assert ann_type in ['hbb', 'obb'], \
            f"ann_type must be 'hbb' or 'obb', but got {ann_type!r}"
        # These attributes are read by `load_data_list`/`parse_data_info`.
        # NOTE(review): they are assigned before `super().__init__` because
        # the base initializer presumably may already load the data list.
        self.ann_type = ann_type
        self.ann_subdir = ann_subdir
        self.backend_args = backend_args
        if file_client_args is not None:
            raise RuntimeError(
                'The `file_client_args` is deprecated, '
                'please use `backend_args` instead, please refer to '
                'https://github.com/vbti-development/onedl-mmdetection/blob/main/configs/_base_/datasets/coco_detection.py'  # noqa: E501
            )
        super().__init__(**kwargs)

    def load_data_list(self) -> List[dict]:
        """Load annotation from XML style ann_file.

        ``self.ann_file`` lists one image id per line; for each id the
        image is ``<id>.jpg`` and the annotation is
        ``<data_root>/<ann_subdir>/<id>.xml``.

        Returns:
            list[dict]: Annotation info from XML file.
        """
        assert self._metainfo.get('classes', None) is not None, \
            'classes in `DIORDataset` can not be None.'
        # Category name -> contiguous label index, in `classes` order.
        self.cat2label = {
            cat: i
            for i, cat in enumerate(self.metainfo['classes'])
        }

        data_list = []
        img_ids = list_from_file(self.ann_file, backend_args=self.backend_args)
        for img_id in img_ids:
            raw_img_info = {
                'img_id': img_id,
                'file_name': f'{img_id}.jpg',
                'xml_path': osp.join(self.data_root, self.ann_subdir,
                                     f'{img_id}.xml'),
            }
            data_list.append(self.parse_data_info(raw_img_info))
        return data_list

    @property
    def bbox_min_size(self) -> Optional[float]:
        """Return the minimum size of bounding boxes in the images.

        The value is taken from ``filter_cfg['bbox_min_size']``; None is
        returned when no ``filter_cfg`` is set or the key is absent.
        """
        if self.filter_cfg is not None:
            return self.filter_cfg.get('bbox_min_size', None)
        return None

    def parse_data_info(self, img_info: dict) -> Union[dict, List[dict]]:
        """Parse raw annotation to target format.

        Args:
            img_info (dict): Raw image information, usually it includes
                `img_id`, `file_name`, and `xml_path`.

        Returns:
            Union[dict, List[dict]]: Parsed annotation with keys
            `img_path`, `img_id`, `xml_path`, `height`, `width` and
            `instances`. Each instance holds `bbox` (float32 array of 8
            polygon coordinates), `bbox_label` and `ignore_flag`.
        """
        data_info = {}
        img_path = osp.join(self.data_prefix['img_path'],
                            img_info['file_name'])
        data_info['img_path'] = img_path
        data_info['img_id'] = img_info['img_id']
        data_info['xml_path'] = img_info['xml_path']

        # deal with xml file
        with get_local_path(
                img_info['xml_path'],
                backend_args=self.backend_args) as local_path:
            raw_ann_info = ET.parse(local_path)
        root = raw_ann_info.getroot()

        size = root.find('size')
        if size is not None:
            width = int(size.find('width').text)
            height = int(size.find('height').text)
        else:
            # No <size> tag: fall back to decoding the image itself.
            img_bytes = get(img_path, backend_args=self.backend_args)
            img = mmcv.imfrombytes(img_bytes, backend='cv2')
            # Image arrays are laid out (height, width, channels), so the
            # height comes first.
            height, width = img.shape[:2]
            del img, img_bytes

        data_info['height'] = height
        data_info['width'] = width

        # The box element and its coordinate tags only depend on ann_type,
        # so resolve them once instead of per object.
        if self.ann_type == 'obb':
            box_tag, coord_tags = 'robndbox', self._OBB_COORD_TAGS
        else:
            box_tag, coord_tags = 'bndbox', self._HBB_COORD_TAGS
        bbox_min_size = self.bbox_min_size

        instances = []
        for obj in root.findall('object'):
            cls = obj.find('name').text.lower()
            # Use .get() so objects of categories that are not part of
            # `classes` are skipped instead of raising KeyError.
            label = self.cat2label.get(cls)
            if label is None:
                continue

            bnd_box = obj.find(box_tag)
            if bnd_box is None:
                continue
            polygon = np.array(
                [float(bnd_box.find(tag).text) for tag in coord_tags],
                dtype=np.float32)

            ignore = False
            if bbox_min_size is not None:
                assert not self.test_mode
                # Measure the box itself (not the image): the lengths of two
                # adjacent polygon edges are the box width and height, for
                # both horizontal and oriented boxes.
                corners = polygon.reshape(4, 2)
                side_a = float(np.linalg.norm(corners[1] - corners[0]))
                side_b = float(np.linalg.norm(corners[2] - corners[1]))
                if side_a < bbox_min_size or side_b < bbox_min_size:
                    ignore = True

            instances.append({
                'ignore_flag': 1 if ignore else 0,
                'bbox': polygon,
                'bbox_label': label,
            })

        data_info['instances'] = instances
        return data_info

    def filter_data(self) -> List[dict]:
        """Filter annotations according to filter_cfg.

        In test mode nothing is filtered. Otherwise images without any
        instance are dropped when ``filter_empty_gt`` is set, and images
        whose shorter side is below ``min_size`` are dropped.

        Returns:
            List[dict]: Filtered results.
        """
        if self.test_mode:
            return self.data_list

        filter_cfg = self.filter_cfg if self.filter_cfg is not None else {}
        filter_empty_gt = filter_cfg.get('filter_empty_gt', False)
        min_size = filter_cfg.get('min_size', 0)

        valid_data_infos = []
        for data_info in self.data_list:
            if filter_empty_gt and len(data_info['instances']) == 0:
                continue
            if min(data_info['width'], data_info['height']) >= min_size:
                valid_data_infos.append(data_info)

        return valid_data_infos

    def get_cat_ids(self, idx: int) -> List[int]:
        """Get DIOR category ids by index.

        Args:
            idx (int): Index of data.

        Returns:
            List[int]: All categories in the image of specified index.
        """
        instances = self.get_data_info(idx)['instances']
        return [instance['bbox_label'] for instance in instances]