Skip to content

Commit 525ce95

Browse files
[feat] add hstu rank model (#227)
1 parent 44f6c48 commit 525ce95

15 files changed

Lines changed: 874 additions & 44 deletions

tzrec/models/dlrm_hstu.py

Lines changed: 467 additions & 0 deletions
Large diffs are not rendered by default.

tzrec/models/dlrm_hstu_test.py

Lines changed: 265 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,265 @@
1+
# Copyright (c) 2025, Alibaba Group;
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import unittest
13+
14+
import torch
15+
from parameterized import parameterized
16+
from torchrec import JaggedTensor, KeyedJaggedTensor
17+
18+
from tzrec.datasets.utils import BASE_DATA_GROUP, Batch
19+
from tzrec.features.feature import create_features
20+
from tzrec.models.dlrm_hstu import DlrmHSTU
21+
from tzrec.ops import Kernel
22+
from tzrec.protos import (
23+
feature_pb2,
24+
loss_pb2,
25+
model_pb2,
26+
module_pb2,
27+
tower_pb2,
28+
)
29+
from tzrec.protos.models import multi_task_rank_pb2
30+
from tzrec.utils.state_dict_util import init_parameters
31+
from tzrec.utils.test_util import TestGraphType, create_test_model
32+
33+
34+
class DlrmHSTUTest(unittest.TestCase):
35+
@parameterized.expand(
36+
[[TestGraphType.NORMAL], [TestGraphType.FX_TRACE], [TestGraphType.JIT_SCRIPT]]
37+
)
38+
def test_dlrm_hstu(self, graph_type) -> None:
39+
feature_cfgs = [
40+
feature_pb2.FeatureConfig(
41+
id_feature=feature_pb2.IdFeature(
42+
feature_name="user_id", embedding_dim=16, num_buckets=100
43+
)
44+
),
45+
feature_pb2.FeatureConfig(
46+
id_feature=feature_pb2.IdFeature(
47+
feature_name="user_active_degree",
48+
embedding_dim=16,
49+
num_buckets=1000,
50+
)
51+
),
52+
feature_pb2.FeatureConfig(
53+
sequence_id_feature=feature_pb2.SequenceIdFeature(
54+
feature_name="video_id",
55+
embedding_dim=16,
56+
embedding_name="video_id_emb",
57+
num_buckets=1000,
58+
)
59+
),
60+
feature_pb2.FeatureConfig(
61+
sequence_id_feature=feature_pb2.SequenceIdFeature(
62+
feature_name="item_video_id",
63+
embedding_dim=16,
64+
embedding_name="video_id_emb",
65+
num_buckets=1000,
66+
)
67+
),
68+
feature_pb2.FeatureConfig(
69+
sequence_id_feature=feature_pb2.SequenceIdFeature(
70+
feature_name="action_timestamp"
71+
)
72+
),
73+
feature_pb2.FeatureConfig(
74+
sequence_id_feature=feature_pb2.SequenceIdFeature(
75+
feature_name="item_query_time"
76+
)
77+
),
78+
feature_pb2.FeatureConfig(
79+
sequence_id_feature=feature_pb2.SequenceIdFeature(
80+
feature_name="action_weight",
81+
num_buckets=1000,
82+
)
83+
),
84+
feature_pb2.FeatureConfig(
85+
sequence_id_feature=feature_pb2.SequenceIdFeature(
86+
feature_name="item_action_weight",
87+
num_buckets=1000,
88+
)
89+
),
90+
feature_pb2.FeatureConfig(
91+
sequence_raw_feature=feature_pb2.SequenceRawFeature(
92+
feature_name="watch_time"
93+
)
94+
),
95+
feature_pb2.FeatureConfig(
96+
sequence_raw_feature=feature_pb2.SequenceRawFeature(
97+
feature_name="item_target_watchtime"
98+
)
99+
),
100+
]
101+
features = create_features(feature_cfgs)
102+
feature_groups = [
103+
model_pb2.FeatureGroupConfig(
104+
group_name="contextual",
105+
feature_names=["user_id", "user_active_degree"],
106+
group_type=model_pb2.FeatureGroupType.SEQUENCE,
107+
),
108+
model_pb2.FeatureGroupConfig(
109+
group_name="uih",
110+
feature_names=[
111+
"video_id",
112+
],
113+
group_type=model_pb2.FeatureGroupType.SEQUENCE,
114+
),
115+
model_pb2.FeatureGroupConfig(
116+
group_name="candidate",
117+
feature_names=[
118+
"item_video_id",
119+
],
120+
group_type=model_pb2.FeatureGroupType.SEQUENCE,
121+
),
122+
]
123+
124+
model_config = model_pb2.ModelConfig(
125+
feature_groups=feature_groups,
126+
dlrm_hstu=multi_task_rank_pb2.DlrmHSTU(
127+
uih_id_feature_name="video_id",
128+
uih_action_time_feature_name="action_timestamp",
129+
uih_action_weight_feature_name="action_weight",
130+
uih_watchtime_feature_name="watch_time",
131+
candidates_id_feature_name="item_video_id",
132+
candidates_query_time_feature_name="item_query_time",
133+
candidates_action_weight_feature_name="item_action_weight",
134+
candidates_watchtime_feature_name="item_target_watchtime",
135+
hstu=module_pb2.HSTU(
136+
stu=module_pb2.STU(
137+
embedding_dim=512,
138+
num_heads=4,
139+
hidden_dim=128,
140+
attention_dim=128,
141+
output_dropout_ratio=0.2,
142+
),
143+
positional_encoder=module_pb2.GRPositionalEncoder(
144+
num_position_buckets=8192,
145+
num_time_buckets=2048,
146+
use_time_encoding=True,
147+
),
148+
input_preprocessor=module_pb2.GRInputPreprocessor(
149+
contextual_preprocessor=module_pb2.GRContextualPreprocessor(
150+
action_encoder=module_pb2.GRActionEncoder(
151+
action_embedding_dim=8,
152+
action_feature_name="action_weight",
153+
action_weights=[1, 2, 4],
154+
),
155+
action_mlp=module_pb2.GRContextualizedMLP(
156+
simple_mlp=module_pb2.GRSimpleContextualizedMLP(
157+
hidden_dim=256
158+
)
159+
),
160+
content_mlp=module_pb2.GRContextualizedMLP(
161+
simple_mlp=module_pb2.GRSimpleContextualizedMLP(
162+
hidden_dim=256
163+
)
164+
),
165+
)
166+
),
167+
output_postprocessor=module_pb2.GROutputPostprocessor(
168+
layernorm_postprocessor=module_pb2.GRLayerNormPostprocessor()
169+
),
170+
),
171+
fusion_mtl_tower=tower_pb2.FusionMTLTower(
172+
mlp=module_pb2.MLP(hidden_units=[512], activation="nn.SiLU"),
173+
task_configs=[
174+
tower_pb2.FusionSubTaskConfig(
175+
task_name="is_click",
176+
label_name="item_action_weight",
177+
task_bitmask=1,
178+
losses=[
179+
loss_pb2.LossConfig(
180+
binary_cross_entropy=loss_pb2.BinaryCrossEntropy()
181+
)
182+
],
183+
),
184+
tower_pb2.FusionSubTaskConfig(
185+
task_name="is_like",
186+
label_name="item_action_weight",
187+
task_bitmask=2,
188+
losses=[
189+
loss_pb2.LossConfig(
190+
binary_cross_entropy=loss_pb2.BinaryCrossEntropy()
191+
)
192+
],
193+
),
194+
tower_pb2.FusionSubTaskConfig(
195+
task_name="is_comment",
196+
label_name="item_action_weight",
197+
task_bitmask=4,
198+
losses=[
199+
loss_pb2.LossConfig(
200+
binary_cross_entropy=loss_pb2.BinaryCrossEntropy()
201+
)
202+
],
203+
),
204+
tower_pb2.FusionSubTaskConfig(
205+
task_name="watchtime",
206+
label_name="item_target_watchtime",
207+
losses=[loss_pb2.LossConfig(l2_loss=loss_pb2.L2Loss())],
208+
),
209+
],
210+
),
211+
max_seq_len=100,
212+
),
213+
)
214+
dlrm_hstu = DlrmHSTU(
215+
model_config=model_config,
216+
features=features,
217+
labels=["item_action_weight", "item_target_watchtime"],
218+
)
219+
dlrm_hstu.set_kernel(Kernel.PYTORCH)
220+
init_parameters(dlrm_hstu, device=torch.device("cpu"))
221+
dlrm_hstu = create_test_model(dlrm_hstu, graph_type)
222+
223+
sparse_feature = KeyedJaggedTensor.from_lengths_sync(
224+
keys=[
225+
"user_id",
226+
"user_active_degree",
227+
"video_id",
228+
"item_video_id",
229+
"action_weight",
230+
"item_action_weight",
231+
"action_timestamp",
232+
"item_query_time",
233+
],
234+
values=torch.tensor(list(range(37))),
235+
lengths=torch.tensor([1, 1, 1, 1, 2, 3, 2, 4, 2, 3, 2, 4, 2, 3, 2, 4]),
236+
)
237+
sequence_dense_features = {
238+
"watch_time": JaggedTensor(
239+
values=torch.tensor([[0.1], [0.2], [0.3], [0.4], [0.5]]),
240+
lengths=torch.tensor([2, 3]),
241+
),
242+
"item_target_watchtime": JaggedTensor(
243+
values=torch.tensor([[0.1], [0.2], [0.3], [0.4], [0.5], [0.6]]),
244+
lengths=torch.tensor([2, 4]),
245+
),
246+
}
247+
batch = Batch(
248+
sequence_dense_features=sequence_dense_features,
249+
sparse_features={BASE_DATA_GROUP: sparse_feature},
250+
labels={},
251+
)
252+
if graph_type == TestGraphType.JIT_SCRIPT:
253+
predictions = dlrm_hstu(batch.to_dict())
254+
else:
255+
predictions = dlrm_hstu(batch)
256+
self.assertEqual(predictions["logits_is_click"].size(), (6,))
257+
self.assertEqual(predictions["probs_is_click"].size(), (6,))
258+
self.assertEqual(predictions["logits_is_like"].size(), (6,))
259+
self.assertEqual(predictions["probs_is_like"].size(), (6,))
260+
self.assertEqual(predictions["logits_is_comment"].size(), (6,))
261+
self.assertEqual(predictions["probs_is_comment"].size(), (6,))
262+
263+
264+
if __name__ == "__main__":
265+
unittest.main()

tzrec/models/rank_model.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,16 @@ def _update_tensor_dict(
3737
tensor_dict[key] = new_tensor
3838

3939

40+
def _is_classification_loss(loss_cfg: LossConfig) -> bool:
41+
loss_type = loss_cfg.WhichOneof("loss")
42+
return loss_type in [
43+
"binary_cross_entropy",
44+
"softmax_cross_entropy",
45+
"jrc_loss",
46+
"binary_focal_loss",
47+
]
48+
49+
4050
class RankModel(BaseModel):
4151
"""Base model for ranking.
4252

tzrec/modules/norm.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,7 @@
1515
import torch
1616

1717
from tzrec.modules.utils import BaseModule
18-
from tzrec.ops import Kernel
19-
from tzrec.ops.layer_norm import (
20-
layer_norm,
21-
swish_layer_norm,
22-
)
23-
from tzrec.ops.triton.triton_layer_norm import triton_rms_norm
18+
from tzrec.ops.layer_norm import layer_norm, rms_norm, swish_layer_norm
2419

2520

2621
class LayerNorm(BaseModule):
@@ -78,16 +73,9 @@ def __init__(
7873
self._eps = eps
7974
self._weight = torch.nn.Parameter(torch.ones(dim))
8075

81-
def _norm(self, x: torch.Tensor) -> torch.Tensor:
82-
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self._eps)
83-
8476
def forward(self, x: torch.Tensor) -> torch.Tensor:
8577
"""Forward the module."""
86-
if self.kernel() == Kernel.TRITON:
87-
return triton_rms_norm(x, self._weight, self._eps)
88-
else:
89-
output = self._norm(x.float()).type_as(x)
90-
return output * self._weight
78+
return rms_norm(x=x, weight=self._weight, eps=self._eps, kernel=self.kernel())
9179

9280

9381
class SwishLayerNorm(BaseModule):

tzrec/ops/hstu_attention.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,22 @@
1717

1818
import torch
1919
from torch.fx._symbolic_trace import is_fx_tracing
20+
from torch.utils._triton import has_triton
2021

2122
from tzrec.ops import Kernel
2223
from tzrec.ops.pytorch.pt_hstu_attention import (
2324
pytorch_cached_hstu_mha,
2425
pytorch_hstu_mha,
2526
)
26-
from tzrec.ops.triton.triton_hstu_attention import (
27-
triton_cached_hstu_mha,
28-
triton_hstu_mha,
29-
)
27+
28+
if has_triton():
29+
from tzrec.ops.triton.triton_hstu_attention import (
30+
triton_cached_hstu_mha,
31+
triton_hstu_mha,
32+
)
33+
else:
34+
triton_cached_hstu_mha = pytorch_cached_hstu_mha
35+
triton_hstu_mha = pytorch_hstu_mha
3036
from tzrec.ops.utils import switch_to_contiguous_if_needed
3137

3238

tzrec/ops/hstu_compute.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import torch
1919
import torch.nn.functional as F
2020
from torch.fx._symbolic_trace import is_fx_tracing
21+
from torch.utils._triton import has_triton
2122

2223
from tzrec.ops import Kernel
2324
from tzrec.ops.hstu_attention import hstu_mha
@@ -26,12 +27,13 @@
2627
from tzrec.ops.pytorch.pt_hstu_linear import (
2728
pytorch_hstu_compute_output,
2829
)
29-
from tzrec.ops.triton.triton_hstu_linear import (
30-
triton_hstu_compute_output,
31-
)
32-
from tzrec.ops.triton.triton_hstu_preprocess_and_attention import (
33-
triton_hstu_preprocess_and_attention,
34-
)
30+
31+
if has_triton():
32+
from tzrec.ops.triton.triton_hstu_linear import (
33+
triton_hstu_compute_output,
34+
)
35+
else:
36+
triton_hstu_compute_output = pytorch_hstu_compute_output
3537

3638

3739
def hstu_compute_uqvk(
@@ -164,6 +166,10 @@ def hstu_preprocess_and_attention(
164166
"uvqk_weight.shape[1] must equal 2 * num_heads * (hidden_dim + attn_dim)",
165167
)
166168
if kernel == Kernel.TRITON and prefill is False:
169+
from tzrec.ops.triton.triton_hstu_preprocess_and_attention import (
170+
triton_hstu_preprocess_and_attention,
171+
)
172+
167173
u, attn_output = triton_hstu_preprocess_and_attention(
168174
x=x,
169175
norm_weight=norm_weight,

0 commit comments

Comments
 (0)