Skip to content

Commit a0e6735

Browse files
authored
Gluon data 2.0: c++ dataloader and built-in image/bbox transforms (apache#17841)
* c++ dataloader and built-in image/bbox * update * fix error * fix import error * fix ci build * fix vs openmp loop type * fix warning as error with sign/unsign comp * sign/unsign comp * update to pytest * remove nose * fix tear_down * address comments * thread safe dataset * address comments * address comments * fix * serial pytest for data download
1 parent 23df47f commit a0e6735

52 files changed

Lines changed: 6863 additions & 286 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

include/mxnet/c_api.h

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,14 @@ typedef void *ExecutorHandle;
8181
typedef void *DataIterCreator;
8282
/*! \brief handle to a DataIterator */
8383
typedef void *DataIterHandle;
84+
/*! \brief handle to a dataset creator */
85+
typedef void *DatasetCreator;
86+
/*! \brief handle to a Dataset */
87+
typedef void *DatasetHandle;
88+
/*! \brief handle to a BatchifyFunction creator*/
89+
typedef void *BatchifyFunctionCreator;
90+
/*! \brief handle to a BatchifyFunction */
91+
typedef void *BatchifyFunctionHandle;
8492
/*! \brief handle to KVStore */
8593
typedef void *KVStoreHandle;
8694
/*! \brief handle to RecordIO */
@@ -2670,6 +2678,13 @@ MXNET_DLL int MXDataIterNext(DataIterHandle handle,
26702678
*/
26712679
MXNET_DLL int MXDataIterBeforeFirst(DataIterHandle handle);
26722680

2681+
/*!
2682+
* \brief Call iterator.GetLenHint. Note that some iterators don't provide length.
2683+
* \param handle the handle to iterator
2684+
* \return 0 when success, -1 when failure happens
2685+
*/
2686+
MXNET_DLL int MXDataIterGetLenHint(DataIterHandle handle,
2687+
int64_t *len);
26732688
/*!
26742689
* \brief Get the handle to the NDArray of underlying data
26752690
* \param handle the handle pointer to the data iterator
@@ -2705,6 +2720,147 @@ MXNET_DLL int MXDataIterGetPadNum(DataIterHandle handle,
27052720
*/
27062721
MXNET_DLL int MXDataIterGetLabel(DataIterHandle handle,
27072722
NDArrayHandle *out);
2723+
/*!
2724+
* \brief Get the handles to specified underlying ndarrays of index
2725+
* \param handle the handle pointer to the data iterator
2726+
* \param num_outputs the length of outputs
2727+
* \param outputs the handle to an array of NDArrays that stores pointers to handles
2728+
* \return 0 when success, -1 when failure happens
2729+
*/
2730+
MXNET_DLL int MXDataIterGetItems(DataIterHandle handle,
2731+
int* num_outputs,
2732+
NDArrayHandle **outputs);
2733+
2734+
/*!
2735+
* \brief List all the available dataset entries
2736+
* \param out_size the size of returned datasets
2737+
* \param out_array the output dataset entries
2738+
* \return 0 when success, -1 when failure happens
2739+
*/
2740+
MXNET_DLL int MXListDatasets(uint32_t *out_size,
2741+
DatasetCreator **out_array);
2742+
/*!
2743+
* \brief Init a dataset, init with parameters
2744+
* the array size of passed in arguments
2745+
* \param handle of the dataset creator
2746+
* \param num_param number of parameter
2747+
* \param keys parameter keys
2748+
* \param vals parameter values
2749+
* \param out resulting dataset
2750+
* \return 0 when success, -1 when failure happens
2751+
*/
2752+
MXNET_DLL int MXDatasetCreateDataset(DatasetCreator handle,
2753+
uint32_t num_param,
2754+
const char **keys,
2755+
const char **vals,
2756+
DatasetHandle *out);
2757+
/*!
2758+
* \brief Get the detailed information about dataset.
2759+
* \param creator the DatasetCreator.
2760+
* \param name The returned name of the creator.
2761+
* \param description The returned description of the dataset.
2762+
* \param num_args Number of arguments.
2763+
* \param arg_names Name of the arguments.
2764+
* \param arg_type_infos Type information about the arguments.
2765+
* \param arg_descriptions Description information about the arguments.
2766+
* \return 0 when success, -1 when failure happens
2767+
*/
2768+
MXNET_DLL int MXDatasetGetDatasetInfo(DatasetCreator creator,
2769+
const char **name,
2770+
const char **description,
2771+
uint32_t *num_args,
2772+
const char ***arg_names,
2773+
const char ***arg_type_infos,
2774+
const char ***arg_descriptions);
2775+
/*!
2776+
* \brief Free the handle to the dataset
2777+
* \param handle the handle pointer to the dataset
2778+
* \return 0 when success, -1 when failure happens
2779+
*/
2780+
MXNET_DLL int MXDatasetFree(DatasetHandle handle);
2781+
/*!
2782+
* \brief Get dataset overall length (size)
2783+
* \param handle the handle to dataset
2784+
* \param out return value of GetLen
2785+
* \return 0 when success, -1 when failure happens
2786+
*/
2787+
MXNET_DLL int MXDatasetGetLen(DatasetHandle handle,
2788+
uint64_t *out);
2789+
/*!
2790+
* \brief Get Output NDArray given specified indices
2791+
* \param handle the handle to dataset
2792+
* \param index the index of the dataset item to be retrieved
2793+
* \param num_outputs the number of output ndarrays
2794+
* \param outputs the pointers to handles of ndarrays
2795+
* \param is_scalar if not zeros then output should be casted to scalars
2796+
* \return 0 when success, -1 when failure happens
2797+
*/
2798+
MXNET_DLL int MXDatasetGetItems(DatasetHandle handle,
2799+
uint64_t index,
2800+
int* num_outputs,
2801+
NDArrayHandle **outputs);
2802+
2803+
/*!
2804+
* \brief List all the available batchify function entries
2805+
* \param out_size the size of returned batchify functions
2806+
* \param out_array the output batchify function entries
2807+
* \return 0 when success, -1 when failure happens
2808+
*/
2809+
MXNET_DLL int MXListBatchifyFunctions(uint32_t *out_size,
2810+
BatchifyFunctionCreator **out_array);
2811+
/*!
2812+
* \brief Init a batchify function, init with parameters
2813+
* the array size of passed in arguments
2814+
* \param handle of the batchify function creator
2815+
* \param num_param number of parameter
2816+
* \param keys parameter keys
2817+
* \param vals parameter values
2818+
* \param out resulting batchify function
2819+
* \return 0 when success, -1 when failure happens
2820+
*/
2821+
MXNET_DLL int MXBatchifyFunctionCreateFunction(BatchifyFunctionCreator handle,
2822+
uint32_t num_param,
2823+
const char **keys,
2824+
const char **vals,
2825+
BatchifyFunctionHandle *out);
2826+
/*!
2827+
* \brief Get the detailed information about batchify function.
2828+
* \param creator the BatchifyFunctionCreator.
2829+
* \param name The returned name of the creator.
2830+
* \param description The returned description of the batchify function.
2831+
* \param num_args Number of arguments.
2832+
* \param arg_names Name of the arguments.
2833+
* \param arg_type_infos Type information about the arguments.
2834+
* \param arg_descriptions Description information about the arguments.
2835+
* \return 0 when success, -1 when failure happens
2836+
*/
2837+
MXNET_DLL int MXBatchifyFunctionGetFunctionInfo(BatchifyFunctionCreator creator,
2838+
const char **name,
2839+
const char **description,
2840+
uint32_t *num_args,
2841+
const char ***arg_names,
2842+
const char ***arg_type_infos,
2843+
const char ***arg_descriptions);
2844+
/*!
2845+
* \brief Invoke the Batchify Function
2846+
* \param handle the handle pointer to the batchify function
2847+
* \param batch_size the batch size
2848+
* \param num_output the number of ndarrays for output
2849+
* \param inputs the pointers to input ndarrays
2850+
* \param outputs the pointers to output ndarrays
2851+
* \return 0 when success, -1 when failure happens
2852+
*/
2853+
MXNET_DLL int MXBatchifyFunctionInvoke(BatchifyFunctionHandle handle,
2854+
int batch_size,
2855+
int num_output,
2856+
NDArrayHandle *inputs,
2857+
NDArrayHandle **outputs);
2858+
/*!
2859+
* \brief Free the handle to the batchify function
2860+
* \param handle the handle pointer to the batchify function
2861+
* \return 0 when success, -1 when failure happens
2862+
*/
2863+
MXNET_DLL int MXBatchifyFunctionFree(BatchifyFunctionHandle handle);
27082864
//--------------------------------------------
27092865
// Part 6: basic KVStore interface
27102866
//--------------------------------------------

include/mxnet/io.h

Lines changed: 97 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,13 @@ class IIterator : public dmlc::DataIter<DType> {
6161
inline void SetDataName(const std::string data_name) {
6262
data_names.push_back(data_name);
6363
}
64+
/*! \brief request iterator length hint for current epoch.
65+
* Note that the returned value can be < 0, indicating
66+
* that the length of iterator is unknown unless you went through all data.
67+
*/
68+
virtual int64_t GetLenHint(void) const {
69+
return -1;
70+
}
6471
}; // class IIterator
6572

6673
/*! \brief a single data instance */
@@ -104,7 +111,7 @@ struct DataIteratorReg
104111
*
105112
* \code
106113
* // example of registering a mnist iterator
107-
* REGISTER_IO_ITE(MNISTIter)
114+
* MXNET_REGISTER_IO_ITER(MNISTIter)
108115
* .describe("Mnist data iterator")
109116
* .set_body([]() {
110117
* return new PrefetcherIter(new MNISTIter());
@@ -113,5 +120,94 @@ struct DataIteratorReg
113120
*/
114121
#define MXNET_REGISTER_IO_ITER(name) \
115122
DMLC_REGISTRY_REGISTER(::mxnet::DataIteratorReg, DataIteratorReg, name)
123+
124+
/*!
125+
* \brief A randomly accessible dataset which provides GetLen() and GetItem().
126+
* Unlike DataIter, it's a static lookup storage which is friendly to random access.
127+
* The dataset itself should NOT contain data processing, which should be applied during
128+
* data augmentation or transformation processes.
129+
*/
130+
class Dataset {
131+
public:
132+
/*!
133+
* \brief Get the size of the dataset
134+
*/
135+
virtual uint64_t GetLen(void) const = 0;
136+
/*!
137+
* \brief Get the ndarray items given index in dataset
138+
* \param idx the integer index for required data
139+
* \param ret the returned ndarray items
140+
*/
141+
virtual bool GetItem(uint64_t idx, std::vector<NDArray>* ret) = 0;
142+
// virtual destructor
143+
virtual ~Dataset(void) {}
144+
}; // class Dataset
145+
146+
/*! \brief typedef the factory function of dataset */
147+
typedef std::function<Dataset *(
148+
const std::vector<std::pair<std::string, std::string> >&)> DatasetFactory;
149+
/*!
150+
* \brief Registry entry for Dataset factory functions.
151+
*/
152+
struct DatasetReg
153+
: public dmlc::FunctionRegEntryBase<DatasetReg,
154+
DatasetFactory> {
155+
};
156+
//--------------------------------------------------------------
157+
// The following part is the API registration of Datasets
158+
//--------------------------------------------------------------
159+
/*!
160+
* \brief Macro to register Datasets
161+
*
162+
* \code
163+
* // example of registering an image sequence dataset
164+
* MXNET_REGISTER_IO_DATASET(ImageSequenceDataset)
165+
* .describe("image sequence dataset")
166+
* .set_body([]() {
167+
* return new ImageSequenceDataset();
168+
* });
169+
* \endcode
170+
*/
171+
#define MXNET_REGISTER_IO_DATASET(name) \
172+
DMLC_REGISTRY_REGISTER(::mxnet::DatasetReg, DatasetReg, name)
173+
174+
class BatchifyFunction {
175+
public:
176+
/*! \brief Destructor */
177+
virtual ~BatchifyFunction(void) {}
178+
/*! \brief The batchify logic */
179+
virtual bool Batchify(const std::vector<std::vector<NDArray> >& inputs,
180+
std::vector<NDArray>* outputs) = 0;
181+
}; // class BatchifyFunction
182+
183+
using BatchifyFunctionPtr = std::shared_ptr<BatchifyFunction>;
184+
185+
/*! \brief typedef the factory function of batchify function */
186+
typedef std::function<BatchifyFunction *(
187+
const std::vector<std::pair<std::string, std::string> >&)> BatchifyFunctionFactory;
188+
/*!
189+
* \brief Registry entry for BatchifyFunction factory functions.
190+
*/
191+
struct BatchifyFunctionReg
192+
: public dmlc::FunctionRegEntryBase<BatchifyFunctionReg,
193+
BatchifyFunctionFactory> {
194+
};
195+
//--------------------------------------------------------------
196+
// The following part is the API registration of Batchify Functions
197+
//--------------------------------------------------------------
198+
/*!
199+
* \brief Macro to register Batchify Functions
200+
*
201+
* \code
202+
* // example of registering a Batchify Function
203+
* MXNET_REGISTER_IO_BATCHIFY_FUNCTION(StackBatchify)
204+
* .describe("Stack Batchify Function")
205+
* .set_body([]() {
206+
* return new StackBatchify();
207+
* });
208+
* \endcode
209+
*/
210+
#define MXNET_REGISTER_IO_BATCHIFY_FUNCTION(name) \
211+
DMLC_REGISTRY_REGISTER(::mxnet::BatchifyFunctionReg, BatchifyFunctionReg, name)
116212
} // namespace mxnet
117213
#endif // MXNET_IO_H_

python/mxnet/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,8 @@ def _load_lib():
365365
ExecutorHandle = ctypes.c_void_p
366366
DataIterCreatorHandle = ctypes.c_void_p
367367
DataIterHandle = ctypes.c_void_p
368+
DatasetHandle = ctypes.c_void_p
369+
BatchifyFunctionhandle = ctypes.c_void_p
368370
KVStoreHandle = ctypes.c_void_p
369371
RecordIOHandle = ctypes.c_void_p
370372
RtcHandle = ctypes.c_void_p

python/mxnet/gluon/contrib/data/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,5 +20,6 @@
2020
"""Contrib datasets."""
2121

2222
from . import text
23+
from . import vision
2324

2425
from .sampler import *
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
# coding: utf-8
19+
# pylint: disable=wildcard-import
20+
"""Contrib vision utilities."""
21+
from .transforms import *
22+
from .dataloader import *

0 commit comments

Comments
 (0)