Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 14 additions & 22 deletions tzrec/modules/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,11 @@ def __init__(
self._group_name_to_seq_encoder_configs = defaultdict(list)
self._grouped_features_keys = list()

seq_group_names = []
for feature_group in feature_groups:
group_name = feature_group.group_name
self._respect_and_supplement_feature_group(feature_group)
self._add_feature_group_sign_for_sequence_groups(feature_group)
self._inspect_and_supplement_feature_group(feature_group, seq_group_names)
# self._add_feature_group_sign_for_sequence_groups(feature_group)
features_data_group = defaultdict(list)
for feature_name in feature_group.feature_names:
feature = self._name_to_feature[feature_name]
Expand Down Expand Up @@ -241,10 +242,10 @@ def grouped_features_keys(self) -> List[str]:
"""grouped_features_keys."""
return self._grouped_features_keys

def _respect_and_supplement_feature_group(
self, feature_group: FeatureGroupConfig
def _inspect_and_supplement_feature_group(
self, feature_group: FeatureGroupConfig, seq_group_names: List[str]
) -> None:
"""Respect feature group sequence_groups and sequence_encoders."""
"""Inspect feature group sequence_groups and sequence_encoders."""
group_name = feature_group.group_name
sequence_groups = list(feature_group.sequence_groups)
sequence_encoders = list(feature_group.sequence_encoders)
Expand Down Expand Up @@ -273,6 +274,14 @@ def _respect_and_supplement_feature_group(
):
sequence_groups[0].group_name = group_name

for sequence_group in sequence_groups:
if sequence_group.group_name in seq_group_names:
raise ValueError(
f"has repeat sequences groups_name: {sequence_group.group_name}"
)
else:
seq_group_names.append(sequence_group.group_name)

group_has_encoder = {
sequence_group.group_name: False for sequence_group in sequence_groups
}
Expand Down Expand Up @@ -305,23 +314,6 @@ def _respect_and_supplement_feature_group(
f"sequence_groups and sequence_encoders must configured in DEEP"
)

def _add_feature_group_sign_for_sequence_groups(
self, feature_group: FeatureGroupConfig
) -> None:
"""Assign sequence_groups and sequence_encoder relation group name."""
group_name = feature_group.group_name
sequence_groups = list(feature_group.sequence_groups)
sequence_encoders = list(feature_group.sequence_encoders)
if len(sequence_groups) > 0:
for sequence_group in sequence_groups:
sequence_group.group_name = (
group_name + "___" + sequence_group.group_name
)
for sequence_encoder in sequence_encoders:
seq_type = sequence_encoder.WhichOneof("seq_module")
seq_config = getattr(sequence_encoder, seq_type)
seq_config.input = group_name + "___" + seq_config.input

def group_names(self) -> List[str]:
"""Feature group names."""
return list(self._name_to_feature_group.keys())
Expand Down
4 changes: 2 additions & 2 deletions tzrec/modules/embedding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ def test_embedding_group(self, graph_type) -> None:
group_type=model_pb2.FeatureGroupType.DEEP,
sequence_groups=[
model_pb2.SeqGroupConfig(
group_name="buy_seq",
group_name="only_buy_seq",
feature_names=[
"cat_a",
"int_a",
Expand All @@ -489,7 +489,7 @@ def test_embedding_group(self, graph_type) -> None:
sequence_encoders=[
seq_encoder_pb2.SeqEncoderConfig(
simple_attention=seq_encoder_pb2.SimpleAttention(
input="buy_seq"
input="only_buy_seq"
)
),
],
Expand Down