@@ -81,6 +81,36 @@ def transform_function(
8181 return RemoveDropout ().visit (func )
8282
8383
class BroadcastInputs(ExprMutator):
    """
    Binary operators need explicit broadcasting for CLML.

    Rewrites calls to binary element-wise ops so that any operand whose
    shape differs from the call's result shape is wrapped in
    ``relay.broadcast_to``, making the implicit broadcast explicit.
    """

    # Element-wise binary ops that CLML requires pre-broadcast inputs for.
    _BINARY_OPS = ("add", "subtract", "multiply", "divide", "maximum", "minimum")

    def visit_call(self, call):
        """Rewrite a binary-op call; delegate all other calls to the base visitor."""
        # Guard with isinstance: `call.op` may be a relay Function (e.g. a
        # composite), which has no `.name` attribute.
        if isinstance(call.op, tvm.ir.Op) and call.op.name in self._BINARY_OPS:
            new_fn = self.visit(call.op)
            call_shape = call.checked_type.shape
            # Visit both operands unconditionally so that nested binary ops
            # inside a shape-matching argument are also rewritten (the
            # original only visited operands that needed broadcasting).
            lhs = self.visit(call.args[0])
            rhs = self.visit(call.args[1])
            lhs_shape = call.args[0].checked_type.shape
            rhs_shape = call.args[1].checked_type.shape
            if list(call_shape) != list(lhs_shape):
                lhs = relay.broadcast_to(lhs, call_shape)
            if list(call_shape) != list(rhs_shape):
                rhs = relay.broadcast_to(rhs, call_shape)
            return Call(new_fn, [lhs, rhs], call.attrs)
        return super().visit_call(call)
104+
105+
@transform.function_pass(opt_level=0)
class BinaryOpBroadcaster:
    """Relay function pass that applies BroadcastInputs, making broadcasting
    explicit on binary element-wise operators so CLML can offload them."""

    def transform_function(
        self, func: relay.function.Function, mod: tvm.IRModule, _: tvm.transform.PassContext
    ) -> relay.function.Function:
        # Rewrite the function body, inserting broadcast_to where operand
        # shapes differ from the result shape.
        return BroadcastInputs().visit(func)
112+
113+
84114def partition_for_clml (mod , params = None , ** opts ):
85115 """Partition the graph greedily offloading supported
86116 operators to CLML Library.
@@ -104,6 +134,7 @@ def partition_for_clml(mod, params=None, **opts):
104134 [
105135 transform .InferType (),
106136 RemoveDropoutPass (),
137+ BinaryOpBroadcaster (),
107138 transform .FoldConstant (),
108139 transform .MergeComposite (clml_pattern_table ()),
109140 transform .AnnotateTarget ("clml" , False ),
@@ -261,8 +292,6 @@ def concat_pattern():
def dense_pattern():
    """Create a dense pattern: nn.dense with a constant weight operand."""
    return is_op("nn.dense")(wildcard(), is_constant())
267296
268297 def pad_pattern ():
@@ -344,9 +373,19 @@ def check_conv_transpose(extract):
344373
def check_binary_op(extract):
    """Return True when a binary element-wise op is supported by CLML.

    Rejects rank-0 (scalar) operands, int64 operands, and any operand with
    a batch dimension greater than 1.
    """
    call = extract
    for arg in call.args:
        arg_type = arg.checked_type
        # Scalars are not supported.  The original code checked only
        # args[1] for rank 0; a scalar args[0] would then crash on the
        # shape[0] access below, so check every argument uniformly.
        if len(arg_type.shape) == 0:
            return False
        # Avoid any operators with dtype int64.
        if arg_type.dtype == "int64":
            return False
        # No support for batch > 1.
        if arg_type.shape[0] > 1:
            return False

    return True
350389
351390 def check_pad_op (extract ):
352391 call = extract
@@ -377,6 +416,20 @@ def check_concat_op(extract):
377416 return True
378417
def check_default_op(extract):
    """Accept the operator unless any argument has dtype int64 (unsupported by CLML)."""
    return all(arg.checked_type.dtype != "int64" for arg in extract.args)
425+
def check_batch_matmul_op(extract):
    """Accept nn.batch_matmul only when both operands have batch dimension 1."""
    call = extract
    # Only a single matmul is supported: reject batching on either operand.
    for operand in (call.args[0], call.args[1]):
        if operand.checked_type.shape[0] > 1:
            return False
    return True
381434
382435 return [
@@ -394,7 +447,7 @@ def check_default_op(extract):
394447 ("clml.minimum" , is_op ("minimum" )(wildcard (), wildcard ()), check_binary_op ),
395448 ("clml.maximum" , is_op ("maximum" )(wildcard (), wildcard ()), check_binary_op ),
396449 ("clml.softmax" , is_op ("nn.softmax" )(wildcard ()), check_softmax_op ),
397- ("clml.reshape" , is_op ("reshape" )(wildcard ()), check_default_op ),
450+ # ("clml.reshape", is_op("reshape")(wildcard()), check_default_op),
398451 ("clml.avg_pool2d" , is_op ("nn.avg_pool2d" )(wildcard ()), check_default_op ),
399452 ("clml.max_pool2d" , is_op ("nn.max_pool2d" )(wildcard ()), check_default_op ),
400453 ("clml.global_avg_pool2d" , is_op ("nn.global_avg_pool2d" )(wildcard ()), check_default_op ),
@@ -404,6 +457,11 @@ def check_default_op(extract):
404457 ("clml.batch_flatten" , is_op ("nn.batch_flatten" )(wildcard ()), check_default_op ),
405458 ("clml.depth_to_space" , is_op ("nn.depth_to_space" )(wildcard ()), check_default_op ),
406459 ("clml.upsampling" , is_op ("nn.upsampling" )(wildcard ()), check_upsampling_op ),
460+ (
461+ "clml.batch_matmul" ,
462+ is_op ("nn.batch_matmul" )(wildcard (), wildcard ()),
463+ check_batch_matmul_op ,
464+ ),
407465 ]
408466
409467
@@ -570,7 +628,9 @@ def __init__(self, cmod):
570628 runner.MakeDense($input_tensor,
571629 $weight_tensor,
572630 $output_tensor,
573- $bias_tensor, "$dtype");"""
631+ std::vector<cl_uint> ({$in_shape}),
632+ std::vector<cl_uint> ({$wt_shape}),
633+ "$dtype");"""
574634 )
575635 self .MakeSoftMax = Template (
576636 """
@@ -641,13 +701,12 @@ def __init__(self, cmod):
641701 " Output Count : $output_count\\ n"
642702 ' Input MetaInfo\\ n$input_meta\\ n Output MetaInfo\\ n$output_meta");'
643703 )
644-
645704 self .MakeInputMetaInfo = Template (
646- " Input: $in_name\\ n Dtype : $dtype\\ n Shape : [$shape]"
705+ " Input: $in_name\\ n Dtype : $dtype\\ n Shape : [$shape]\\ n "
647706 )
648707
649708 self .MakeOutputMetaInfo = Template (
650- " Output: $out_name\\ n Dtype : $dtype\\ n Shape : [$shape]"
709+ " Output: $out_name\\ n Dtype : $dtype\\ n Shape : [$shape]\\ n "
651710 )
652711
653712 def get_src (self ):
@@ -666,23 +725,40 @@ def get_tensor_from_map(
666725 else :
667726 node = self .nodes [node_seq ]
668727 dtype = str (node ["attrs" ]["dtype" ][0 ][0 ])
728+ if node ["op" ] == "input" :
729+ self .clml_code .append ("// Input Node" )
730+ node_out_name = self .sub_module_name + "_" + "input_" + str (node_seq )
731+ else :
732+ node_out_name = node ["name" ]
669733 if shape is None :
670734 shape = str (tuple (node ["attrs" ]["shape" ][0 ][0 ]))[1 :- 1 ]
671735
672736 self .clml_code .append (
673737 self .MakeCLMLTensor .substitute (
674- name = node [ "name" ] , shape = shape , dtype = dtype , layout = layout
738+ name = node_out_name , shape = shape , dtype = dtype , layout = layout
675739 )
676740 )
677741 self .clml_code .append (
678- self .MapInsert .substitute (nid = node [ "name" ] , tensor_desc = node [ "name" ] )
742+ self .MapInsert .substitute (nid = node_out_name , tensor_desc = node_out_name )
679743 )
744+ if node ["op" ] == "input" :
745+ self .clml_code .append (
746+ Template ("runner.inputs.push_back($clml_input);" ).substitute (
747+ clml_input = node_out_name
748+ )
749+ )
750+ self .input_meta .append (
751+ self .MakeInputMetaInfo .substitute (
752+ in_name = node_out_name , dtype = dtype , shape = shape
753+ )
754+ )
755+
680756 if self .nodes [node_seq ]["op" ] == "const" :
681757 self .clml_code .append (
682758 Template ('runner.consts.push_back("$nid");' ).substitute (nid = node ["name" ])
683759 )
684- self .node_map [node_seq ] = node [ "name" ]
685- return node [ "name" ]
760+ self .node_map [node_seq ] = node_out_name
761+ return node_out_name
686762
687763 def make_output_tensor (
688764 node , node_seq , shape = None , layout = "CL_TENSOR_LAYOUT_OPTIMAL_QCOM" , dtype = "float32"
@@ -697,40 +773,13 @@ def make_output_tensor(
697773 name = node_out_name ,
698774 shape = shape ,
699775 dtype = dtype ,
700- layout = "CL_TENSOR_LAYOUT_OPTIMAL_QCOM" ,
776+ layout = layout ,
701777 )
702778 )
703779 return node_out_name
704780
705781 for node_seq , node in enumerate (self .nodes ):
706- if node ["op" ] == "input" :
707- self .clml_code .append ("// Input Node" )
708- dtype = str (node ["attrs" ]["dtype" ][0 ][0 ])
709- shape = str (tuple (node ["attrs" ]["shape" ][0 ][0 ]))[1 :- 1 ]
710- node_out_name = self .sub_module_name + "_" + "input_" + str (node_seq )
711- self .clml_code .append (
712- self .MakeCLMLTensor .substitute (
713- name = node_out_name ,
714- shape = shape ,
715- dtype = dtype ,
716- layout = "CL_TENSOR_LAYOUT_OPTIMAL_QCOM" ,
717- )
718- )
719- self .clml_code .append (
720- self .MapInsert .substitute (nid = node_out_name , tensor_desc = node_out_name )
721- )
722- self .clml_code .append (
723- Template ("runner.inputs.push_back($clml_input);" ).substitute (
724- clml_input = node_out_name
725- )
726- )
727- self .node_map [node_seq ] = node_out_name
728- self .input_meta .append (
729- self .MakeInputMetaInfo .substitute (
730- in_name = node_out_name , dtype = dtype , shape = shape
731- )
732- )
733- elif node ["op" ] == "kernel" :
782+ if node ["op" ] == "kernel" :
734783 self .clml_code .append ("// Kernel Node : " + node ["name" ])
735784 if node ["name" ] == "nn.conv2d" or node ["name" ] == "nn.depthwise_conv2d" :
736785 if "padding" in node ["attrs" ]:
@@ -791,6 +840,7 @@ def make_output_tensor(
791840 bn_shape = [1 , 1 , 1 , 1 ]
792841 bn_node = self .nodes [node ["inputs" ][bn_index ][0 ]]
793842 bn_shape [axis ] = bn_node ["attrs" ]["shape" ][0 ][0 ]
843+ dtype = bn_node ["attrs" ]["dtype" ][0 ][0 ]
794844
795845 bn_scale_tensor = get_tensor_from_map (
796846 node ["inputs" ][bn_index ][0 ],
@@ -858,6 +908,7 @@ def make_output_tensor(
858908 bn_shape = [1 , 1 , 1 , 1 ]
859909 bn_node = self .nodes [node ["inputs" ][0 ][0 ]]
860910 bn_shape [axis ] = bn_node ["attrs" ]["shape" ][0 ][0 ]
911+ dtype = bn_node ["attrs" ]["dtype" ][0 ][0 ]
861912 bn_scale_tensor = get_tensor_from_map (
862913 node ["inputs" ][0 ][0 ], shape = str (tuple (bn_shape ))[1 :- 1 ], dtype = dtype
863914 )
@@ -947,26 +998,26 @@ def make_output_tensor(
947998 in_shape = tuple (in_node ["attrs" ]["shape" ][0 ][0 ])
948999 wt_shape = tuple (in_node ["attrs" ]["shape" ][0 ][0 ])
9491000 input_tensor = get_tensor_from_map (
950- node ["inputs" ][0 ][0 ], shape = str ( tuple ([ 1 , in_shape [ 1 ], 1 , 1 ]))[ 1 : - 1 ]
1001+ node ["inputs" ][0 ][0 ], layout = "CL_TENSOR_LAYOUT_NCHW_QCOM"
9511002 )
9521003 weight_tensor = get_tensor_from_map (
9531004 node ["inputs" ][1 ][0 ],
954- shape = str (tuple ([wt_shape [0 ], wt_shape [1 ], 1 , 1 ]))[1 :- 1 ],
1005+ shape = str (tuple ([1 , 1 , wt_shape [0 ], wt_shape [1 ]]))[1 :- 1 ],
1006+ layout = "CL_TENSOR_LAYOUT_NCHW_QCOM" ,
9551007 )
956- if len (node ["inputs" ]) == 3 :
957- bias_tensor = "runner.unusedTensor"
958- else :
959- bias_tensor = get_tensor_from_map (node ["inputs" ][2 ][0 ])
960-
9611008 node_out_name = make_output_tensor (
962- node , node_seq , shape = str (tuple ([1 , wt_shape [0 ], 1 , 1 ]))[1 :- 1 ]
1009+ node ,
1010+ node_seq ,
1011+ shape = str (tuple ([in_shape [0 ], wt_shape [0 ], 1 , 1 ]))[1 :- 1 ],
1012+ layout = "CL_TENSOR_LAYOUT_NCHW_QCOM" ,
9631013 )
9641014 self .clml_code .append (
9651015 self .MakeDense .substitute (
9661016 input_tensor = input_tensor ,
9671017 weight_tensor = weight_tensor ,
9681018 output_tensor = node_out_name ,
969- bias_tensor = bias_tensor ,
1019+ in_shape = str (in_shape )[1 :- 1 ],
1020+ wt_shape = str (wt_shape )[1 :- 1 ],
9701021 dtype = node ["attrs" ]["dtype" ][0 ][0 ],
9711022 )
9721023 )
@@ -1045,7 +1096,7 @@ def make_output_tensor(
10451096 )
10461097 self .node_map [node_seq ] = node_out_name
10471098
1048- elif node ["op" ] != "const" :
1099+ elif node ["op" ] not in [ "const" , "input" ] :
10491100 print ("Unknown Node type:" , node ["op" ])
10501101
10511102 # Populate outputs
@@ -1086,8 +1137,8 @@ def make_output_tensor(
10861137 name = self .sub_module_name ,
10871138 input_count = len (self .input_meta ),
10881139 output_count = len (self .output_meta ),
1089- input_meta = "\n " .join (self .input_meta ),
1090- output_meta = "\n " .join (self .output_meta ),
1140+ input_meta = "\\ \ n " .join (self .input_meta ),
1141+ output_meta = "\\ \ n " .join (self .output_meta ),
10911142 )
10921143 )
10931144
0 commit comments