22import copy
33import logging
44import numpy as np
5+ from typing import List , Optional , Union
56import torch
67
8+ from detectron2 .config import configurable
9+
710from . import detection_utils as utils
811from . import transforms as T
912
@@ -31,38 +34,81 @@ class DatasetMapper:
3134 3. Prepare data and annotations to Tensor and :class:`Instances`
3235 """
3336
@configurable
def __init__(
    self,
    is_train: bool,
    *,
    augmentations: List[Union[T.Augmentation, T.Transform]],
    image_format: str,
    use_instance_mask: bool = False,
    use_keypoint: bool = False,
    instance_mask_format: str = "polygon",
    keypoint_hflip_indices: Optional[np.ndarray] = None,
    precomputed_proposal_topk: Optional[int] = None,
    recompute_boxes: bool = False,
):
    """
    Store the mapper settings on the instance and log the augmentations.

    NOTE: this interface is experimental.

    Args:
        is_train: whether it's used in training or inference
        augmentations: a list of augmentations or deterministic transforms to apply
        image_format: an image format supported by :func:`detection_utils.read_image`.
        use_instance_mask: whether to process instance segmentation annotations, if available
        use_keypoint: whether to process keypoint annotations if available
        instance_mask_format: one of "polygon" or "bitmask". Process instance segmentation
            masks into this format.
        keypoint_hflip_indices: see :func:`detection_utils.create_keypoint_hflip_indices`
        precomputed_proposal_topk: if given, will load pre-computed
            proposals from dataset_dict and keep the top k proposals for each image.
        recompute_boxes: whether to overwrite bounding box annotations
            by computing tight bounding boxes from instance mask annotations.

    Raises:
        AssertionError: if ``recompute_boxes`` is True while ``use_instance_mask``
            is False.
    """
    if recompute_boxes:
        # Tight boxes are derived from the masks, so masks must be kept.
        assert use_instance_mask, "recompute_boxes requires instance masks"
    # fmt: off
    self.is_train               = is_train
    self.augmentations          = augmentations
    self.image_format           = image_format
    self.use_instance_mask      = use_instance_mask
    self.instance_mask_format   = instance_mask_format
    self.use_keypoint           = use_keypoint
    self.keypoint_hflip_indices = keypoint_hflip_indices
    # Stored under a shorter attribute name than the constructor argument.
    self.proposal_topk          = precomputed_proposal_topk
    self.recompute_boxes        = recompute_boxes
    # fmt: on
    logger = logging.getLogger(__name__)
    # Report the actual mode: the mapper is also constructed with is_train=False,
    # in which case a message claiming "training" would be misleading.
    mode = "training" if is_train else "inference"
    logger.info("Augmentations used in %s: %s", mode, augmentations)
@classmethod
def from_config(cls, cfg, is_train: bool = True):
    """
    Translate a config object into a dict of mapper settings.

    The keys of the returned dict match the parameter names of ``__init__``.

    Args:
        cfg: config object; the INPUT, MODEL and DATASETS entries read below
            must be present.
        is_train: whether the mapper is being built for training.

    Returns:
        dict: always contains "is_train", "augmentations", "image_format",
        "use_instance_mask", "instance_mask_format", "use_keypoint" and
        "recompute_boxes"; "keypoint_hflip_indices" and
        "precomputed_proposal_topk" are added only when the corresponding
        config switches are on.
    """
    augs = utils.build_augmentation(cfg, is_train)

    # Random cropping is prepended only for training; in that case boxes are
    # recomputed whenever masks are enabled.
    crop_in_training = cfg.INPUT.CROP.ENABLED and is_train
    if crop_in_training:
        crop_aug = T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE)
        augs.insert(0, crop_aug)
    recompute_boxes = cfg.MODEL.MASK_ON if crop_in_training else False

    ret = dict(
        is_train=is_train,
        augmentations=augs,
        image_format=cfg.INPUT.FORMAT,
        use_instance_mask=cfg.MODEL.MASK_ON,
        instance_mask_format=cfg.INPUT.MASK_FORMAT,
        use_keypoint=cfg.MODEL.KEYPOINT_ON,
        recompute_boxes=recompute_boxes,
    )

    if cfg.MODEL.KEYPOINT_ON:
        hflip_indices = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN)
        ret["keypoint_hflip_indices"] = hflip_indices

    if cfg.MODEL.LOAD_PROPOSALS:
        # Training and testing use separate top-k settings.
        if is_train:
            topk = cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN
        else:
            topk = cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST
        ret["precomputed_proposal_topk"] = topk

    return ret
66112
67113 def __call__ (self , dataset_dict ):
68114 """
@@ -74,7 +120,7 @@ def __call__(self, dataset_dict):
74120 """
75121 dataset_dict = copy .deepcopy (dataset_dict ) # it will be modified by code below
76122 # USER: Write your own image loading if it's not from a file
77- image = utils .read_image (dataset_dict ["file_name" ], format = self .img_format )
123+ image = utils .read_image (dataset_dict ["file_name" ], format = self .image_format )
78124 utils .check_image_size (dataset_dict , image )
79125
80126 # USER: Remove if you don't do semantic/panoptic segmentation.
@@ -84,7 +130,7 @@ def __call__(self, dataset_dict):
84130 sem_seg_gt = None
85131
86132 aug_input = T .StandardAugInput (image , sem_seg = sem_seg_gt )
87- transforms = aug_input .apply_augmentations (self .augmentation )
133+ transforms = aug_input .apply_augmentations (self .augmentations )
88134 image , sem_seg_gt = aug_input .image , aug_input .sem_seg
89135
90136 image_shape = image .shape [:2 ] # h, w
@@ -97,13 +143,9 @@ def __call__(self, dataset_dict):
97143
98144 # USER: Remove if you don't use pre-computed proposals.
99145 # Most users would not need this feature.
100- if self .load_proposals :
146+ if self .proposal_topk is not None :
101147 utils .transform_proposals (
102- dataset_dict ,
103- image_shape ,
104- transforms ,
105- proposal_topk = self .proposal_topk ,
106- min_box_size = self .proposal_min_box_size ,
148+ dataset_dict , image_shape , transforms , proposal_topk = self .proposal_topk
107149 )
108150
109151 if not self .is_train :
@@ -115,9 +157,9 @@ def __call__(self, dataset_dict):
115157 if "annotations" in dataset_dict :
116158 # USER: Modify this if you want to keep them for some reason.
117159 for anno in dataset_dict ["annotations" ]:
118- if not self .mask_on :
160+ if not self .use_instance_mask :
119161 anno .pop ("segmentation" , None )
120- if not self .keypoint_on :
162+ if not self .use_keypoint :
121163 anno .pop ("keypoints" , None )
122164
123165 # USER: Implement additional transformations if you have other types of data
@@ -129,15 +171,15 @@ def __call__(self, dataset_dict):
129171 if obj .get ("iscrowd" , 0 ) == 0
130172 ]
131173 instances = utils .annotations_to_instances (
132- annos , image_shape , mask_format = self .mask_format
174+ annos , image_shape , mask_format = self .instance_mask_format
133175 )
134176
135177 # After transforms such as cropping are applied, the bounding box may no longer
136178 # tightly bound the object. As an example, imagine a triangle object
137179 # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight
138180 # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to
139181 # the intersection of original bounding box and the cropping box.
140- if self .compute_tight_boxes and instances . has ( "gt_masks" ) :
182+ if self .recompute_boxes :
141183 instances .gt_boxes = instances .gt_masks .get_bounding_boxes ()
142184 dataset_dict ["instances" ] = utils .filter_empty_instances (instances )
143185 return dataset_dict
0 commit comments