Commit a16a85a
[feat] TensorRT export (#318)
1 parent 3e025e1 commit a16a85a

10 files changed: 247 additions & 527 deletions

.gitignore

Lines changed: 2 additions & 0 deletions
@@ -40,3 +40,5 @@ docs/source/intro.md
 docs/source/proto.html
 
 .vscode/
+
+tmp/

tzrec/acc/trt_utils.py

Lines changed: 54 additions & 28 deletions
@@ -9,6 +9,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import inspect
 import os
 from typing import Any, Dict, List, Optional, Sequence
 
@@ -142,9 +143,9 @@ def forward(
         Return:
             predictions (dict): a dict of predicted result.
         """
-        grouped_features = self.embedding_group(data, device)
-        y = self.dense(grouped_features)
-        return y
+        emb_ebc, _ = self.embedding_group(data, device)
+        outputs = self.dense(emb_ebc)
+        return outputs
 
 
 def get_trt_max_seq_len() -> int:
@@ -157,29 +158,36 @@ def get_trt_max_seq_len() -> int:
 
 
 def export_model_trt(
-    model: nn.Module, data: Dict[str, torch.Tensor], save_dir: str
+    sparse_model: nn.Module,
+    dense_model: nn.Module,
+    data: Dict[str, torch.Tensor],
+    save_dir: str,
 ) -> None:
     """Export trt model.
 
     Args:
-        model (nn.Module): the model
+        sparse_model (nn.Module): the sparse part
+        dense_model (nn.Module): the dense part
         data (Dict[str, torch.Tensor]): the test data
         save_dir (str): model save dir
     """
-    # ScriptWrapperList for trace the ScriptWrapperTRT(emb_trace_gpu, dense_layer_trt)
-    emb_trace_gpu = ScriptWrapperList(model.model.embedding_group)
-    emb_res = emb_trace_gpu(data, "cuda:0")
-    emb_trace_gpu = symbolic_trace(emb_trace_gpu)
-    emb_trace_gpu = torch.jit.script(emb_trace_gpu)
+    emb_ebc, _ = sparse_model(data, "cuda:0")
+    sparse_model_traced = symbolic_trace(sparse_model)
+
+    with open(os.path.join(save_dir, "gm_sparse.code"), "w") as f:
+        f.write(sparse_model_traced.code)
+
+    sparse_model_scripted = torch.jit.script(sparse_model_traced)
 
     # dynamic shapes
     max_batch_size = get_max_export_batch_size()
     max_seq_len = get_trt_max_seq_len()
     batch = torch.export.Dim("batch", min=1, max=max_batch_size)
     dynamic_shapes_list = []
     values_list_cuda = []
-    for i, value in enumerate(emb_res):
-        v = value.detach().to("cuda:0")
+    key_list = []
+    for i, k in enumerate(emb_ebc.keys()):
+        v = emb_ebc[k].detach().to("cuda:0")
         dict_dy = {0: batch}
         if v.dim() == 3:
             # workaround -> 0/1 specialization
@@ -191,46 +199,64 @@ def export_model_trt(
             v = torch.zeros((2,) + v.size()[1:], device="cuda:0", dtype=v.dtype)
         values_list_cuda.append(v)
         dynamic_shapes_list.append(dict_dy)
-
+        key_list.append(k)
     # convert dense
-    dense = model.model.dense
-    logger.info("dense res: %s", dense(values_list_cuda))
-    dense_layer = symbolic_trace(dense)
-    dynamic_shapes = {"args": dynamic_shapes_list}
+    # logger.info("dense res: %s", dense_model(emb_ebc))
+
+    dense_layer = symbolic_trace(dense_model)
+    dense_signature = inspect.signature(dense_model.forward)
+    dense_arg_name = list(dense_signature.parameters.keys())[0]
+    dynamic_shapes = {}
+    dynamic_shapes[dense_arg_name] = {}
+    for i, k in enumerate(key_list):
+        dynamic_shapes[dense_arg_name].update({k: dynamic_shapes_list[i]})
+
     exp_program = torch.export.export(
-        dense_layer, (values_list_cuda,), dynamic_shapes=dynamic_shapes
+        dense_layer, (emb_ebc,), dynamic_shapes=dynamic_shapes
     )
-    dense_layer_trt = trt_convert(exp_program, values_list_cuda)
-    dict_res = dense_layer_trt(values_list_cuda)
-    logger.info("dense trt res: %s", dict_res)
+    dense_layer_trt = trt_convert(exp_program, (emb_ebc,))
+    # logger.info("dense trt res: %s", dense_layer_trt(emb_ebc))
 
+    dense_layer_trt_traced = torch.jit.trace(
+        dense_layer_trt, example_inputs=(emb_ebc,), strict=False
+    )
+    with open(os.path.join(save_dir, "gm_dense.code"), "w") as f:
+        f.write(dense_layer_trt_traced.code)
+
+    dense_layer_trt_scripted = torch.jit.script(dense_layer_trt_traced)
     # save combined_model
-    combined_model = ScriptWrapperTRT(emb_trace_gpu, dense_layer_trt)
+    combined_model = ScriptWrapperTRT(
+        embedding_group=sparse_model_scripted,
+        dense=dense_layer_trt_scripted,
+    )
     result = combined_model(data, "cuda:0")
     logger.info("combined model result: %s", result)
     # combined_model = symbolic_trace(combined_model)
-    combined_model = torch.jit.trace(
-        combined_model, example_inputs=(data,), strict=False
-    )
+    # combined_model = torch.jit.trace(
+    #     combined_model, example_inputs=data, strict=False
+    # )
     scripted_model = torch.jit.script(combined_model)
     # pyre-ignore [16]
    scripted_model.save(os.path.join(save_dir, "scripted_model.pt"))
 
+    with open(os.path.join(save_dir, "gm.code"), "w") as f:
+        f.write(scripted_model.code)
+
     if is_debug_trt():
         with profile(
             activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
             record_shapes=True,
         ) as prof:
             with record_function("model_inference_dense"):
-                dict_res = dense(values_list_cuda)
+                _ = dense_model(emb_ebc)
         logger.info(prof.key_averages().table(sort_by="cuda_time_total", row_limit=100))
 
         with profile(
             activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
             record_shapes=True,
         ) as prof:
             with record_function("model_inference_dense_trt"):
-                dict_res = dense_layer_trt(values_list_cuda)
+                _ = dense_layer_trt(emb_ebc)
         logger.info(prof.key_averages().table(sort_by="cuda_time_total", row_limit=100))
 
         model_gpu_combined = torch.jit.load(
@@ -243,7 +269,7 @@ def export_model_trt(
             record_shapes=True,
         ) as prof:
             with record_function("model_inference_combined_trt"):
-                dict_res = model_gpu_combined(data)
+                _ = model_gpu_combined(data)
         logger.info(prof.key_averages().table(sort_by="cuda_time_total", row_limit=100))
 
     logger.info("trt convert success")

tzrec/models/multi_tower_din_trt.py

Lines changed: 0 additions & 159 deletions
This file was deleted.

tzrec/protos/model.proto

Lines changed: 1 addition & 3 deletions
@@ -67,9 +67,7 @@ message ModelConfig {
 
     TDM tdm = 400;
 
-    MultiTowerDINTRT multi_tower_din_trt =500;
-
-    RocketLaunching rocket_launching = 600;
+    RocketLaunching rocket_launching = 500;
  }
 
  optional uint32 num_class = 2 [default = 1];

tzrec/protos/models/rank_model.proto

Lines changed: 0 additions & 6 deletions
@@ -32,12 +32,6 @@ message MultiTowerDIN {
    required MLP final = 3;
 }
 
-message MultiTowerDINTRT {
-    repeated Tower towers = 1;
-    repeated DINTower din_towers = 2;
-    required MLP final = 3;
-}
-
 
 message DLRM {
    required MLP dense_mlp = 1;
