21 changes: 18 additions & 3 deletions tzrec/modules/sequence.py
@@ -61,6 +61,7 @@ class DINEncoder(SequenceEncoder):
query_dim (int): query tensor channel dimension.
input(str): input feature group name.
attn_mlp (dict): target attention MLP module parameters.
max_seq_length (int): maximum sequence length.
"""

def __init__(
@@ -69,6 +70,7 @@ def __init__(
query_dim: int,
input: str,
attn_mlp: Dict[str, Any],
max_seq_length: int = 0,
**kwargs: Optional[Dict[str, Any]],
) -> None:
super().__init__(input)
@@ -81,6 +83,7 @@ def __init__(
self._query_name = f"{input}.query"
self._sequence_name = f"{input}.sequence"
self._sequence_length_name = f"{input}.sequence_length"
self._max_seq_length = max_seq_length

def output_dim(self) -> int:
"""Output dimension of the module."""
@@ -91,6 +94,9 @@ def forward(self, sequence_embedded: Dict[str, torch.Tensor]) -> torch.Tensor:
query = sequence_embedded[self._query_name]
sequence = sequence_embedded[self._sequence_name]
sequence_length = sequence_embedded[self._sequence_length_name]
if self._max_seq_length > 0:
sequence_length = torch.clamp_max(sequence_length, self._max_seq_length)
sequence = sequence[:, : self._max_seq_length, :]
max_seq_length = sequence.size(1)
sequence_mask = fx_arange(
max_seq_length, device=sequence_length.device
@@ -121,6 +127,7 @@ def __init__(
sequence_dim: int,
query_dim: int,
input: str,
max_seq_length: int = 0,
**kwargs: Optional[Dict[str, Any]],
) -> None:
super().__init__(input)
@@ -129,6 +136,7 @@ def __init__(
self._query_name = f"{input}.query"
self._sequence_name = f"{input}.sequence"
self._sequence_length_name = f"{input}.sequence_length"
self._max_seq_length = max_seq_length

def output_dim(self) -> int:
"""Output dimension of the module."""
@@ -139,6 +147,9 @@ def forward(self, sequence_embedded: Dict[str, torch.Tensor]) -> torch.Tensor:
query = sequence_embedded[self._query_name]
sequence = sequence_embedded[self._sequence_name]
sequence_length = sequence_embedded[self._sequence_length_name]
if self._max_seq_length > 0:
sequence_length = torch.clamp_max(sequence_length, self._max_seq_length)
sequence = sequence[:, : self._max_seq_length, :]
max_seq_length = sequence.size(1)
sequence_mask = fx_arange(max_seq_length, sequence_length.device).unsqueeze(
0
@@ -165,6 +176,7 @@ def __init__(
sequence_dim: int,
input: str,
pooling_type: str = "mean",
max_seq_length: int = 0,
**kwargs: Optional[Dict[str, Any]],
) -> None:
super().__init__(input)
@@ -176,6 +188,7 @@ def __init__(
], "only sum|mean pooling type supported now."
self._sequence_name = f"{input}.sequence"
self._sequence_length_name = f"{input}.sequence_length"
self._max_seq_length = max_seq_length

def output_dim(self) -> int:
"""Output dimension of the module."""
@@ -184,12 +197,14 @@ def output_dim(self) -> int:
def forward(self, sequence_embedded: Dict[str, torch.Tensor]) -> torch.Tensor:
"""Forward the module."""
sequence = sequence_embedded[self._sequence_name]
if self._max_seq_length > 0:
sequence = sequence[:, : self._max_seq_length, :]
feature = torch.sum(sequence, dim=1)
if self._pooling_type == "mean":
sequence_length = sequence_embedded[self._sequence_length_name]
sequence_length = torch.max(
sequence_length, torch.ones_like(sequence_length)
)
if self._max_seq_length > 0:
sequence_length = torch.clamp_max(sequence_length, self._max_seq_length)
sequence_length = torch.clamp_min(sequence_length, 1)
feature = feature / sequence_length.unsqueeze(1)
return feature

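For orientation, the following is a minimal, self-contained sketch (not part of the diff; the tensor shapes and length values are illustrative) of the truncation that the new max_seq_length option applies before attention or pooling:

import torch

# Hypothetical padded batch: 4 sequences, padded to length 10, embedding dim 16.
sequence = torch.randn(4, 10, 16)
sequence_length = torch.tensor([2, 3, 12, 5])  # reported lengths may exceed the cap
max_seq_length = 3

if max_seq_length > 0:
    # Cap the reported lengths and drop embeddings beyond the cap,
    # mirroring the branches added in DINEncoder / SimpleAttention / PoolingEncoder.
    sequence_length = torch.clamp_max(sequence_length, max_seq_length)
    sequence = sequence[:, :max_seq_length, :]

print(sequence.shape)            # torch.Size([4, 3, 16])
print(sequence_length.tolist())  # [2, 3, 3, 3]

With max_seq_length left at its default of 0, both lines are skipped and the encoders behave exactly as before.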
94 changes: 64 additions & 30 deletions tzrec/modules/sequence_test.py
@@ -29,18 +29,21 @@
class DINEncoderTest(unittest.TestCase):
@parameterized.expand(
[
[TestGraphType.NORMAL, False, False],
[TestGraphType.FX_TRACE, False, False],
[TestGraphType.JIT_SCRIPT, False, False],
[TestGraphType.NORMAL, True, False],
[TestGraphType.FX_TRACE, True, False],
[TestGraphType.JIT_SCRIPT, True, False],
[TestGraphType.NORMAL, False, True],
[TestGraphType.FX_TRACE, False, True],
[TestGraphType.JIT_SCRIPT, False, True],
[TestGraphType.NORMAL, False, False, 0],
[TestGraphType.FX_TRACE, False, False, 0],
[TestGraphType.JIT_SCRIPT, False, False, 0],
[TestGraphType.NORMAL, True, False, 0],
[TestGraphType.FX_TRACE, True, False, 0],
[TestGraphType.JIT_SCRIPT, True, False, 0],
[TestGraphType.NORMAL, False, True, 0],
[TestGraphType.FX_TRACE, False, True, 0],
[TestGraphType.JIT_SCRIPT, False, True, 0],
[TestGraphType.NORMAL, False, True, 3],
[TestGraphType.FX_TRACE, False, True, 3],
[TestGraphType.JIT_SCRIPT, False, True, 3],
]
)
def test_din_encoder(self, graph_type, use_bn, use_dice) -> None:
def test_din_encoder(self, graph_type, use_bn, use_dice, max_seq_length) -> None:
din = DINEncoder(
query_dim=16,
sequence_dim=16,
@@ -51,14 +54,24 @@ def test_din_encoder(self, graph_type, use_bn, use_dice) -> None:
use_bn=use_bn,
dropout_ratio=0.9,
),
max_seq_length=max_seq_length,
)
self.assertEqual(din.output_dim(), 16)
din = create_test_module(din, graph_type)
embedded = {
"click_seq.query": torch.randn(4, 16),
"click_seq.sequence": torch.randn(4, 10, 16),
"click_seq.sequence_length": torch.tensor([2, 3, 4, 5]),
}
if max_seq_length > 0:
embedded = {
"click_seq.query": torch.randn(4, 16),
"click_seq.sequence": torch.randn(4, max_seq_length, 16),
"click_seq.sequence_length": torch.clamp_max(
torch.tensor([2, 3, 4, 5]), max_seq_length
),
}
else:
embedded = {
"click_seq.query": torch.randn(4, 16),
"click_seq.sequence": torch.randn(4, 10, 16),
"click_seq.sequence_length": torch.tensor([2, 3, 4, 5]),
}
result = din(embedded)
self.assertEqual(result.size(), (4, 16))

@@ -138,31 +151,52 @@ def test_hstu_encoder_padding(self, graph_type) -> None:

class SimpleAttentionTest(unittest.TestCase):
@parameterized.expand(
[[TestGraphType.NORMAL], [TestGraphType.FX_TRACE], [TestGraphType.JIT_SCRIPT]]
[
[TestGraphType.NORMAL, 0],
[TestGraphType.FX_TRACE, 0],
[TestGraphType.JIT_SCRIPT, 0],
[TestGraphType.NORMAL, 3],
[TestGraphType.FX_TRACE, 3],
[TestGraphType.JIT_SCRIPT, 3],
]
)
def test_simple_attention(self, graph_type) -> None:
attn = SimpleAttention(
16,
16,
input="click_seq",
)
def test_simple_attention(self, graph_type, max_seq_length) -> None:
attn = SimpleAttention(16, 16, input="click_seq", max_seq_length=max_seq_length)
self.assertEqual(attn.output_dim(), 16)
attn = create_test_module(attn, graph_type)
embedded = {
"click_seq.query": torch.randn(4, 16),
"click_seq.sequence": torch.randn(4, 10, 16),
"click_seq.sequence_length": torch.tensor([2, 3, 4, 5]),
}
if max_seq_length > 0:
embedded = {
"click_seq.query": torch.randn(4, 16),
"click_seq.sequence": torch.randn(4, max_seq_length, 16),
"click_seq.sequence_length": torch.clamp_max(
torch.tensor([2, 3, 4, 5]), max_seq_length
),
}
else:
embedded = {
"click_seq.query": torch.randn(4, 16),
"click_seq.sequence": torch.randn(4, 10, 16),
"click_seq.sequence_length": torch.tensor([2, 3, 4, 5]),
}
result = attn(embedded)
self.assertEqual(result.size(), (4, 16))


class PoolingEncoderTest(unittest.TestCase):
@parameterized.expand(
[[TestGraphType.NORMAL], [TestGraphType.FX_TRACE], [TestGraphType.JIT_SCRIPT]]
[
[TestGraphType.NORMAL, 0],
[TestGraphType.FX_TRACE, 0],
[TestGraphType.JIT_SCRIPT, 0],
[TestGraphType.NORMAL, 3],
[TestGraphType.FX_TRACE, 3],
[TestGraphType.JIT_SCRIPT, 3],
]
)
def test_mean_pooling(self, graph_type) -> None:
attn = PoolingEncoder(16, input="click_seq", pooling_type="mean")
def test_mean_pooling(self, graph_type, max_seq_length) -> None:
attn = PoolingEncoder(
16, input="click_seq", pooling_type="mean", max_seq_length=max_seq_length
)
self.assertEqual(attn.output_dim(), 16)
attn = create_test_module(attn, graph_type)
sequence_length = torch.tensor([2, 3, 4, 5])
6 changes: 6 additions & 0 deletions tzrec/protos/seq_encoder.proto
@@ -10,13 +10,17 @@ message DINEncoder {
required string input = 2;
// mlp config for target attention score
required MLP attn_mlp = 3;
// maximum sequence length
optional int32 max_seq_length = 6 [default = 0];
}

message SimpleAttention {
// seq encoder name
optional string name = 1;
// sequence feature name
required string input = 2;
// maximum sequence length
optional int32 max_seq_length = 6 [default = 0];
}

message PoolingEncoder {
@@ -26,6 +30,8 @@ message PoolingEncoder {
required string input = 2;
// pooling type, sum or mean
optional string pooling_type = 3 [default = 'mean'];
// maximum sequence length
optional int32 max_seq_length = 6 [default = 0];
}

message MultiWindowDINEncoder {
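As a usage note (not part of the diff), the new field would simply be set alongside the existing encoder options in a model config. The fragment below is a hypothetical pbtxt sketch; the surrounding config structure and the hidden_units field of MLP are assumptions, not shown in this pull request:

din_encoder {
  input: "click_seq"
  attn_mlp {
    hidden_units: [32, 8]
  }
  max_seq_length: 50
}

Omitting max_seq_length keeps the default of 0, i.e. no truncation.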
2 changes: 1 addition & 1 deletion tzrec/version.py
@@ -9,4 +9,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "0.7.11"
__version__ = "0.7.12"