Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions python/tvm/relay/backend/contrib/ethosu/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@

from . import _ffi_api

# We are currently using copy_constants scheduler In the long run,
# this should be a single intelligent and a composite scheduler
# that can perform scheduling based on user inputs such as
# scratch memory size.
SCHEDULER = copy_constants


class OptimizeLUTs(ExprMutator):
"""A pass to merge an identity operator with a LUT based activation function with
Expand Down Expand Up @@ -356,12 +362,7 @@ def relay_to_tir(mod: tvm.ir.IRModule) -> tvm.ir.IRModule:
gv: "ethos-u" for gv, _ in filter(lambda x: util.is_npu_func(x[1]), mod.functions.items())
}
mod = mod.with_attr("device_contexts", device_contexts)

# We are currently using copy_constants scheduler In the long run,
# this should be a single intelligent and a composite scheduler
# that can perform scheduling based on user inputs such as
# scratch memory size.
mod = LowerToTIR(copy_constants)(mod)
mod = LowerToTIR(SCHEDULER)(mod)

return mod

Expand Down
22 changes: 9 additions & 13 deletions python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
# under the License.
# pylint: disable=invalid-name, unused-argument
"""Extract information from the binary_elementwise operators in TIR."""
from typing import Dict, Tuple
from typing import Tuple
import tvm
from .utils import get_outer_loops, get_op_attrs
from .dma import get_ifm_params, get_ofm_params
from .spec import SerialActivation, SerialBinaryElementwise
from .producers_consumers import ProducersConsumers


def ignore_cast(tir_load: tvm.tir.expr.Load) -> tvm.tir.Var:
Expand All @@ -42,22 +43,17 @@ def ignore_cast(tir_load: tvm.tir.expr.Load) -> tvm.tir.Var:


def get_binary_elementwise_params(
stmt: tvm.tir.AttrStmt,
producers: Dict[tvm.tir.Var, tvm.tir.AttrStmt],
consumers: Dict[tvm.tir.Var, tvm.tir.AttrStmt],
stmt: tvm.tir.AttrStmt, producers_consumers: ProducersConsumers
) -> Tuple[SerialBinaryElementwise, tvm.tir.Var, tvm.tir.Var]:
"""Get the parameters necessary to construct a call_extern for a binary_elementwise.

Parameters
----------
stmt : tvm.tir.AttrStmt
The outermost attribute statement of a binary elementwise loop nest.
producers : Dict[tvm.tir.Var, tvm.tir.AttrStmt]
A dictionary to associate pointers with the loop nest
that produces their values.
consumers : Dict[tvm.tir.Var, tvm.tir.AttrStmt]
A dictionary to associate pointers with the loop nest
that consumes their values.
producers_consumers: ProducersConsumers
It associates pointers with the loop nest that produces
their values and with the loop nest that consumes their values.

Returns
-------
Expand All @@ -84,10 +80,10 @@ def get_binary_elementwise_params(
input_pointer, input_pointer1 = input_pointer1, input_pointer
output_pointer = inner.buffer.data
# Get feature map info
serial_ifm, _ = get_ifm_params(input_pointer, producers)
serial_ifm2, _ = get_ifm_params(input_pointer1, producers)
serial_ifm, _ = get_ifm_params(input_pointer, producers_consumers, stmt)
serial_ifm2, _ = get_ifm_params(input_pointer1, producers_consumers, stmt)
serial_ofm, serial_block_config, replace_pointer, is_allocator = get_ofm_params(
output_pointer, consumers, producers
output_pointer, producers_consumers, stmt
)
# Get activation info
serial_activation = SerialActivation(
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def lower_ethosu(sch, args, const_dict, name="main"):

mod = tvm.tir.transform.Simplify()(mod)
mod = ethosu_passes.RemoveConcatenates()(mod)
mod = tvm.tir.transform.InjectRollingBuffer()(mod)
mod = tvm.tir.transform.StorageFlatten(64)(mod)
mod = tvm.tir.transform.UnrollLoop()(mod)
mod = tvm.tir.transform.Simplify()(mod)
Expand Down
15 changes: 6 additions & 9 deletions python/tvm/relay/backend/contrib/ethosu/tir/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,16 @@
from .spec import SerialKernel, SerialAddressRange, SerialActivation, Serial2DConvolution


def get_conv2d_params(stmt, producers, consumers):
def get_conv2d_params(stmt, producers_consumers):
"""Get the parameters necessary to construct a call_extern for a 2D convolution.

Parameters
----------
stmt : tvm.tir.AttrStmt
The outermost attribute statement of a convolution loop nest.
producers : dict of tvm.tir.Var to tvm.tir.AttrStmt
A dictionary to associate pointers with the loop nest
that produces their values.
consumers : dict of tvm.tir.Var to tvm.tir.AttrStmt
A dictionary to associate pointers with the loop nest
that consumes their values.
producers_consumers: ProducersConsumers
It associates pointers with the loop nest that produces
their values and with the loop nest that consumes their values.

Returns
-------
Expand All @@ -62,9 +59,9 @@ def get_conv2d_params(stmt, producers, consumers):
input_pointer = loads[1].buffer.data
output_pointer = stores[0].buffer.data
# Get feature map info
serial_ifm, serial_padding = get_ifm_params(input_pointer, producers)
serial_ifm, serial_padding = get_ifm_params(input_pointer, producers_consumers, stmt)
serial_ofm, serial_block_config, replace_pointer, is_allocator = get_ofm_params(
output_pointer, consumers, producers
output_pointer, producers_consumers, stmt
)
# Get kernel info
serial_kernel = SerialKernel(
Expand Down
20 changes: 8 additions & 12 deletions python/tvm/relay/backend/contrib/ethosu/tir/depthwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
# pylint: disable=invalid-name, unused-argument
"""Extract information from the depthwise convolution operators in TIR."""
from typing import Dict, Tuple
from typing import Tuple
import tvm
from ..vela_api import SCALE_BIAS_LENGTH
from .utils import get_outer_loops, get_op_attrs, get_base_address, get_loads, get_stores
Expand All @@ -27,25 +27,21 @@
SerialActivation,
Serial2DDepthwise,
)
from .producers_consumers import ProducersConsumers


def get_depthwise_conv2d_params(
stmt: tvm.tir.AttrStmt,
producers: Dict[tvm.tir.Var, tvm.tir.AttrStmt],
consumers: Dict[tvm.tir.Var, tvm.tir.AttrStmt],
stmt: tvm.tir.AttrStmt, producers_consumers: ProducersConsumers
) -> Tuple[Serial2DDepthwise, tvm.tir.Var, tvm.tir.Var]:
"""Get the parameters necessary to construct a call_extern for a depthwise_conv2d.

Parameters
----------
stmt : tvm.tir.AttrStmt
The outermost attribute statement of a depthwise loop nest.
producers : Dict[tvm.tir.Var, tvm.tir.AttrStmt]
A dictionary to associate pointers with the loop nest
that produces their values.
consumers : Dict[tvm.tir.Var, tvm.tir.AttrStmt]
A dictionary to associate pointers with the loop nest
that consumes their values.
producers_consumers: ProducersConsumers
It associates pointers with the loop nest that produces
their values and with the loop nest that consumes their values.

Returns
-------
Expand All @@ -71,9 +67,9 @@ def get_depthwise_conv2d_params(
input_pointer = loads[1].buffer.data
output_pointer = stores[0].buffer.data
# Get feature map info
serial_ifm, serial_padding = get_ifm_params(input_pointer, producers)
serial_ifm, serial_padding = get_ifm_params(input_pointer, producers_consumers, stmt)
serial_ofm, serial_block_config, replace_pointer, is_allocator = get_ofm_params(
output_pointer, consumers, producers
output_pointer, producers_consumers, stmt
)
# Get kernel info
serial_kernel = SerialKernel(
Expand Down
Loading