2626from tvm .contrib import cublas
2727
2828from ...dataflow_pattern import is_op , wildcard
29+ from .te_target import lower_composite , relay_to_runtime
2930from .register import register_pattern_table
3031
3132
33+ tvm ._ffi .register_func ("relay.ext.cublas" , relay_to_runtime (tvm .target .cuda ()))
34+
35+
3236def partition_for_cublas (
3337 mod : tvm .IRModule , params : Optional [Dict [str , tvm .runtime .NDArray ]] = None
3438) -> tvm .IRModule :
@@ -111,51 +115,7 @@ def check_matmul_like(matched: relay.Call) -> bool:
111115 ]
112116
113117
114- _LowerFunc = Callable [[relay .Call , List [te .Tensor ]], te .Tensor ]
115- _LOWER_MAP : Dict [str , _LowerFunc ] = {}
116-
117-
118- def _lower_composite (comp_name : str ) -> Callable [[_LowerFunc ], _LowerFunc ]:
119- """Register a lowering function for a given composite function name."""
120-
121- def _register (f : _LowerFunc ) -> _LowerFunc :
122- _LOWER_MAP [comp_name ] = f
123- return f
124-
125- return _register
126-
127-
128- @tvm ._ffi .register_func ("relay.ext.cublas" )
129- def relay_to_runtime (partition : relay .Function ) -> tvm .runtime .Module :
130- """Compile cuBLAS Relay functions to a runtime module."""
131- assert isinstance (partition , relay .Function )
132- assert isinstance (partition .body , relay .Call )
133- assert isinstance (partition .body .op , relay .Function )
134-
135- global_name = str (partition .attrs .global_symbol )
136- target = tvm .target .cuda ()
137- comp_func = partition .body .op
138- comp_name = comp_func .attrs ["Composite" ]
139- assert comp_name in _LOWER_MAP
140- assert isinstance (comp_func .body , relay .Call )
141-
142- op = comp_func .body
143- inputs = []
144- for i , param in enumerate (comp_func .params ):
145- inputs .append (
146- te .placeholder (
147- param .checked_type .shape ,
148- name = f"input_{ i } " ,
149- dtype = param .checked_type .dtype ,
150- )
151- )
152-
153- output = _LOWER_MAP [comp_name ](op , inputs )
154- prim_func = te .create_prim_func (inputs + [output ])
155- return tvm .build (prim_func , target = target , name = global_name )
156-
157-
158- @_lower_composite ("cublas.matmul" )
118+ @lower_composite ("cublas.matmul" )
159119def _lower_matmul (op : relay .Call , inputs : List [te .Tensor ]) -> te .Tensor :
160120 """Lower a matmul using cuBLAS."""
161121 return cublas .matmul (
@@ -167,7 +127,7 @@ def _lower_matmul(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
167127 )
168128
169129
170- @_lower_composite ("cublas.batch_matmul" )
130+ @lower_composite ("cublas.batch_matmul" )
171131def _lower_batch_matmul (op : relay .Call , inputs : List [te .Tensor ]) -> te .Tensor :
172132 """Lower a batch_matmul using cuBLAS."""
173133 return cublas .batch_matmul (
@@ -179,7 +139,7 @@ def _lower_batch_matmul(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
179139 )
180140
181141
182- @_lower_composite ("cublas.dense" )
142+ @lower_composite ("cublas.dense" )
183143def _lower_dense (op : relay .Call , inputs : List [te .Tensor ]) -> te .Tensor :
184144 """Lower a dense using cuBLAS."""
185145 return cublas .matmul (
0 commit comments