Skip to content

Commit 0061bae

Browse files
committed
Add Reduce operators to TFLite
1 parent 5629901 commit 0061bae

2 files changed

Lines changed: 118 additions & 1 deletion

File tree

python/tvm/relay/frontend/tflite.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,10 @@ def __init__(self, model, subgraph, exp_tab):
7373
'POW': self.convert_pow,
7474
'MAXIMUM': self.convert_maximum,
7575
'MINIMUM': self.convert_minimum,
76+
'REDUCE_MIN': self._convert_reduce_min,
77+
'REDUCE_MAX': self._convert_reduce_max,
78+
'MEAN': self._convert_reduce_mean,
79+
'REDUCE_PROD': self._convert_reduce_prod,
7680
'FULLY_CONNECTED': self.convert_fully_connected,
7781
'PAD': self.convert_pad,
7882
'LOGISTIC': self.convert_logistic,
@@ -427,6 +431,48 @@ def convert_maximum(self, op):
427431
def convert_minimum(self, op):
428432
return self._convert_elemwise(_op.minimum, op)
429433

434+
def _convert_reduce(self, relay_op, op):
435+
"""Generic method to Convert TFLite MEAN operators"""
436+
try:
437+
from tflite.BuiltinOptions import BuiltinOptions
438+
from tflite.Operator import Operator
439+
from tflite.ReducerOptions import ReducerOptions
440+
except ImportError:
441+
raise ImportError("The tflite package must be installed")
442+
443+
assert isinstance(op, Operator)
444+
input_tensors = self.get_input_tensors(op)
445+
assert len(input_tensors) == 2, "input tensors length should be 2"
446+
447+
# input_tensor
448+
input_tensor = input_tensors[0]
449+
in_expr = self.get_expr(input_tensor.tensor_idx)
450+
451+
# axis
452+
axis = tuple(self.get_tensor_value(input_tensors[1]))
453+
454+
# Options - keep_dims (bool)
455+
assert op.BuiltinOptionsType() == BuiltinOptions.ReducerOptions
456+
reduce_options = ReducerOptions()
457+
op_options = op.BuiltinOptions()
458+
reduce_options.Init(op_options.Bytes, op_options.Pos)
459+
keep_dims = reduce_options.KeepDims()
460+
461+
out = relay_op(in_expr, axis, keep_dims)
462+
return out
463+
464+
def _convert_reduce_min(self, op):
465+
return self._convert_reduce(_op.reduce.min, op)
466+
467+
def _convert_reduce_max(self, op):
468+
return self._convert_reduce(_op.reduce.max, op)
469+
470+
def _convert_reduce_mean(self, op):
471+
return self._convert_reduce(_op.reduce.mean, op)
472+
473+
def _convert_reduce_prod(self, op):
474+
return self._convert_reduce(_op.reduce.prod, op)
475+
430476
def convert_fully_connected(self, op):
431477
"""Convert TFLite fully connected"""
432478
try:

tests/python/frontend/tflite/test_forward.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ def test_forward_concatenation():
360360
# ---
361361

362362
def _test_elemwise(math_op, data, fused_activation_function=None):
363-
""" One iteration of add """
363+
""" One iteration of elemwise """
364364

365365
assert len(data) == 2
366366

@@ -457,6 +457,74 @@ def test_all_elemwise():
457457
_test_forward_elemwise(_test_maximum)
458458
_test_forward_elemwise(_test_minimum)
459459

460+
#######################################################################
461+
# Reduce
462+
# ------
463+
464+
def _test_reduce(math_op, data, keep_dims=None):
465+
""" One iteration of reduce """
466+
467+
assert len(data) == 2
468+
469+
# Test with tensor and constant
470+
with tf.Graph().as_default():
471+
in_data = array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in')
472+
out = math_op(in_data, data[1], keep_dims)
473+
compare_tflite_with_tvm([data[0]], ['in:0'], [in_data], [out])
474+
475+
476+
#######################################################################
477+
# Reduce_min
478+
# ----------
479+
480+
def _test_reduce_min(data, keep_dims=None):
481+
""" One iteration of reduce_min """
482+
return _test_reduce(math_ops.reduce_min, data, keep_dims)
483+
484+
#######################################################################
485+
# Reduce_max
486+
# ----------
487+
488+
def _test_reduce_max(data, keep_dims=None):
489+
""" One iteration of reduce_max """
490+
return _test_reduce(math_ops.reduce_max, data, keep_dims)
491+
492+
#######################################################################
493+
# Reduce_mean
494+
# -----------
495+
496+
def _test_reduce_mean(data, keep_dims=None):
497+
""" One iteration of reduce_mean """
498+
return _test_reduce(math_ops.reduce_mean, data, keep_dims)
499+
500+
#######################################################################
501+
# Reduce_prod
502+
# -----------
503+
504+
def _test_reduce_prod(data, keep_dims=None):
505+
""" One iteration of reduce_prod """
506+
return _test_reduce(math_ops.reduce_prod, data, keep_dims)
507+
508+
509+
def _test_forward_reduce(testop):
510+
""" Reduce """
511+
data0 = [np.random.rand(16, 16, 16, 16).astype("float32"), None]
512+
data1 = [np.random.rand(16, 16, 16, 16).astype("float32"), np.array([1, 2], dtype=np.int32)]
513+
testop(data0)
514+
testop(data0, keep_dims=False)
515+
testop(data0, keep_dims=True)
516+
testop(data1)
517+
testop(data1, keep_dims=False)
518+
testop(data1, keep_dims=True)
519+
520+
521+
def test_all_reduce():
522+
_test_forward_reduce(_test_reduce_min)
523+
_test_forward_reduce(_test_reduce_max)
524+
_test_forward_reduce(_test_reduce_mean)
525+
_test_forward_reduce(_test_reduce_prod)
526+
527+
460528
#######################################################################
461529
# Squeeze
462530
# -------
@@ -695,6 +763,9 @@ def test_forward_ssd_mobilenet_v1():
695763
# Elemwise
696764
test_all_elemwise()
697765

766+
# Reduce
767+
test_all_reduce()
768+
698769
# End to End
699770
test_forward_mobilenet_v1()
700771
test_forward_mobilenet_v2()

0 commit comments

Comments
 (0)