99# See the License for the specific language governing permissions and
1010# limitations under the License.
1111
12+ import inspect
1213import os
1314from typing import Any , Dict , List , Optional , Sequence
1415
@@ -142,9 +143,9 @@ def forward(
142143 Return:
143144 predictions (dict): a dict of predicted result.
144145 """
145- grouped_features = self .embedding_group (data , device )
146- y = self .dense (grouped_features )
147- return y
146+ emb_ebc , _ = self .embedding_group (data , device )
147+ outputs = self .dense (emb_ebc )
148+ return outputs
148149
149150
150151def get_trt_max_seq_len () -> int :
@@ -157,29 +158,36 @@ def get_trt_max_seq_len() -> int:
157158
158159
159160def export_model_trt (
160- model : nn .Module , data : Dict [str , torch .Tensor ], save_dir : str
161+ sparse_model : nn .Module ,
162+ dense_model : nn .Module ,
163+ data : Dict [str , torch .Tensor ],
164+ save_dir : str ,
161165) -> None :
162166 """Export trt model.
163167
164168 Args:
165- model (nn.Module): the model
169+ sparse_model (nn.Module): the sparse part
170+ dense_model (nn.Module): the dense part
166171 data (Dict[str, torch.Tensor]): the test data
167172 save_dir (str): model save dir
168173 """
169- # ScriptWrapperList for trace the ScriptWrapperTRT(emb_trace_gpu, dense_layer_trt)
170- emb_trace_gpu = ScriptWrapperList (model .model .embedding_group )
171- emb_res = emb_trace_gpu (data , "cuda:0" )
172- emb_trace_gpu = symbolic_trace (emb_trace_gpu )
173- emb_trace_gpu = torch .jit .script (emb_trace_gpu )
174+ emb_ebc , _ = sparse_model (data , "cuda:0" )
175+ sparse_model_traced = symbolic_trace (sparse_model )
176+
177+ with open (os .path .join (save_dir , "gm_sparse.code" ), "w" ) as f :
178+ f .write (sparse_model_traced .code )
179+
180+ sparse_model_scripted = torch .jit .script (sparse_model_traced )
174181
175182 # dynamic shapes
176183 max_batch_size = get_max_export_batch_size ()
177184 max_seq_len = get_trt_max_seq_len ()
178185 batch = torch .export .Dim ("batch" , min = 1 , max = max_batch_size )
179186 dynamic_shapes_list = []
180187 values_list_cuda = []
181- for i , value in enumerate (emb_res ):
182- v = value .detach ().to ("cuda:0" )
188+ key_list = []
189+ for i , k in enumerate (emb_ebc .keys ()):
190+ v = emb_ebc [k ].detach ().to ("cuda:0" )
183191 dict_dy = {0 : batch }
184192 if v .dim () == 3 :
185193 # workaround -> 0/1 specialization
@@ -191,46 +199,64 @@ def export_model_trt(
191199 v = torch .zeros ((2 ,) + v .size ()[1 :], device = "cuda:0" , dtype = v .dtype )
192200 values_list_cuda .append (v )
193201 dynamic_shapes_list .append (dict_dy )
194-
202+ key_list . append ( k )
195203 # convert dense
196- dense = model .model .dense
197- logger .info ("dense res: %s" , dense (values_list_cuda ))
198- dense_layer = symbolic_trace (dense )
199- dynamic_shapes = {"args" : dynamic_shapes_list }
204+ # logger.info("dense res: %s", dense_model(emb_ebc))
205+
206+ dense_layer = symbolic_trace (dense_model )
207+ dense_signature = inspect .signature (dense_model .forward )
208+ dense_arg_name = list (dense_signature .parameters .keys ())[0 ]
209+ dynamic_shapes = {}
210+ dynamic_shapes [dense_arg_name ] = {}
211+ for i , k in enumerate (key_list ):
212+ dynamic_shapes [dense_arg_name ].update ({k : dynamic_shapes_list [i ]})
213+
200214 exp_program = torch .export .export (
201- dense_layer , (values_list_cuda ,), dynamic_shapes = dynamic_shapes
215+ dense_layer , (emb_ebc ,), dynamic_shapes = dynamic_shapes
202216 )
203- dense_layer_trt = trt_convert (exp_program , values_list_cuda )
204- dict_res = dense_layer_trt (values_list_cuda )
205- logger .info ("dense trt res: %s" , dict_res )
217+ dense_layer_trt = trt_convert (exp_program , (emb_ebc ,))
218+ # logger.info("dense trt res: %s", dense_layer_trt(emb_ebc))
206219
220+ dense_layer_trt_traced = torch .jit .trace (
221+ dense_layer_trt , example_inputs = (emb_ebc ,), strict = False
222+ )
223+ with open (os .path .join (save_dir , "gm_dense.code" ), "w" ) as f :
224+ f .write (dense_layer_trt_traced .code )
225+
226+ dense_layer_trt_scripted = torch .jit .script (dense_layer_trt_traced )
207227 # save combined_model
208- combined_model = ScriptWrapperTRT (emb_trace_gpu , dense_layer_trt )
228+ combined_model = ScriptWrapperTRT (
229+ embedding_group = sparse_model_scripted ,
230+ dense = dense_layer_trt_scripted ,
231+ )
209232 result = combined_model (data , "cuda:0" )
210233 logger .info ("combined model result: %s" , result )
211234 # combined_model = symbolic_trace(combined_model)
212- combined_model = torch .jit .trace (
213- combined_model , example_inputs = ( data ,) , strict = False
214- )
235+ # combined_model = torch.jit.trace(
236+ # combined_model, example_inputs=data, strict=False
237+ # )
215238 scripted_model = torch .jit .script (combined_model )
216239 # pyre-ignore [16]
217240 scripted_model .save (os .path .join (save_dir , "scripted_model.pt" ))
218241
242+ with open (os .path .join (save_dir , "gm.code" ), "w" ) as f :
243+ f .write (scripted_model .code )
244+
219245 if is_debug_trt ():
220246 with profile (
221247 activities = [ProfilerActivity .CPU , ProfilerActivity .CUDA ],
222248 record_shapes = True ,
223249 ) as prof :
224250 with record_function ("model_inference_dense" ):
225- dict_res = dense ( values_list_cuda )
251+ _ = dense_model ( emb_ebc )
226252 logger .info (prof .key_averages ().table (sort_by = "cuda_time_total" , row_limit = 100 ))
227253
228254 with profile (
229255 activities = [ProfilerActivity .CPU , ProfilerActivity .CUDA ],
230256 record_shapes = True ,
231257 ) as prof :
232258 with record_function ("model_inference_dense_trt" ):
233- dict_res = dense_layer_trt (values_list_cuda )
259+ _ = dense_layer_trt (emb_ebc )
234260 logger .info (prof .key_averages ().table (sort_by = "cuda_time_total" , row_limit = 100 ))
235261
236262 model_gpu_combined = torch .jit .load (
@@ -243,7 +269,7 @@ def export_model_trt(
243269 record_shapes = True ,
244270 ) as prof :
245271 with record_function ("model_inference_combined_trt" ):
246- dict_res = model_gpu_combined (data )
272+ _ = model_gpu_combined (data )
247273 logger .info (prof .key_averages ().table (sort_by = "cuda_time_total" , row_limit = 100 ))
248274
249275 logger .info ("trt convert success" )
0 commit comments