Skip to content

Commit a52dcfa

Browse files
[feat] support ignoring unused features in negative sampler (#102)
1 parent dcfe077 commit a52dcfa

2 files changed

Lines changed: 66 additions & 13 deletions

File tree

tzrec/datasets/sampler.py

Lines changed: 33 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,7 @@
2929
from tzrec.protos import sampler_pb2
3030
from tzrec.utils.env_util import use_hash_node_id
3131
from tzrec.utils.load_class import get_register_class_meta
32+
from tzrec.utils.logging_util import logger
3233
from tzrec.utils.misc_util import get_free_port
3334

3435

@@ -205,12 +206,25 @@ def __init__(
205206
self._attr_types = []
206207
self._attr_gl_types = []
207208
self._attr_np_types = []
209+
self._valid_attr_names = []
210+
self._ignore_attr_names = set()
208211
for field_name in config.attr_fields:
209-
field = input_fields[field_name]
212+
if field_name in input_fields:
213+
field = input_fields[field_name]
214+
self._valid_attr_names.append(field.name)
215+
else:
216+
field = pa.field(name=field_name, type=pa.string())
217+
self._ignore_attr_names.add(field_name)
210218
self._attr_names.append(field.name)
211219
self._attr_types.append(field.type)
212220
self._attr_gl_types.append(_get_gl_type(field.type))
213221
self._attr_np_types.append(_get_np_type(field.type))
222+
if len(self._ignore_attr_names) > 0:
223+
logger.warning(
224+
f"Features {self._ignore_attr_names} in "
225+
# pyre-ignore [16]
226+
f"{self.__class__.__name__} will be ignored."
227+
)
214228

215229
if config.HasField("field_delimiter"):
216230
gl.set_field_delimiter(config.field_delimiter)
@@ -268,9 +282,12 @@ def _parse_nodes(self, nodes: gl.Nodes) -> List[pa.Array]:
268282
int_idx = 0
269283
float_idx = 0
270284
string_idx = 0
271-
for attr_type, attr_gl_type, attr_np_type in zip(
272-
self._attr_types, self._attr_gl_types, self._attr_np_types
285+
for attr_name, attr_type, attr_gl_type, attr_np_type in zip(
286+
self._attr_names, self._attr_types, self._attr_gl_types, self._attr_np_types
273287
):
288+
if attr_name in self._ignore_attr_names:
289+
string_idx += 1
290+
continue
274291
if attr_gl_type == "int":
275292
feature = nodes.int_attrs[:, :, int_idx]
276293
int_idx += 1
@@ -295,9 +312,12 @@ def _parse_sparse_nodes(
295312
int_idx = 0
296313
float_idx = 0
297314
string_idx = 0
298-
for attr_type, attr_gl_type, attr_np_type in zip(
299-
self._attr_types, self._attr_gl_types, self._attr_np_types
315+
for attr_name, attr_type, attr_gl_type, attr_np_type in zip(
316+
self._attr_names, self._attr_types, self._attr_gl_types, self._attr_np_types
300317
):
318+
if attr_name in self._ignore_attr_names:
319+
string_idx += 1
320+
continue
301321
if attr_gl_type == "int":
302322
feature = nodes.int_attrs[:, int_idx]
303323
int_idx += 1
@@ -379,7 +399,7 @@ def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]:
379399
ids = np.pad(ids, (0, self._batch_size - len(ids)), "edge")
380400
nodes = self._sampler.get(ids)
381401
features = self._parse_nodes(nodes)
382-
result_dict = dict(zip(self._attr_names, features))
402+
result_dict = dict(zip(self._valid_attr_names, features))
383403
return result_dict
384404

385405
@property
@@ -470,7 +490,7 @@ def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]:
470490
dst_ids = np.pad(dst_ids, (0, self._batch_size - len(dst_ids)), "edge")
471491
nodes = self._sampler.get(src_ids, dst_ids)
472492
features = self._parse_nodes(nodes)
473-
result_dict = dict(zip(self._attr_names, features))
493+
result_dict = dict(zip(self._valid_attr_names, features))
474494
return result_dict
475495

476496
@property
@@ -565,7 +585,7 @@ def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]:
565585
for i, v in enumerate(hard_neg_features):
566586
results.append(pa.concat_arrays([neg_features[i], v]))
567587

568-
result_dict = dict(zip(self._attr_names, results))
588+
result_dict = dict(zip(self._valid_attr_names, results))
569589
result_dict["hard_neg_indices"] = pa.array(hard_neg_indices)
570590
return result_dict
571591

@@ -667,7 +687,7 @@ def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]:
667687
for i, v in enumerate(hard_neg_features):
668688
results.append(pa.concat_arrays([neg_features[i], v]))
669689

670-
result_dict = dict(zip(self._attr_names, results))
690+
result_dict = dict(zip(self._valid_attr_names, results))
671691
result_dict["hard_neg_indices"] = pa.array(hard_neg_indices)
672692
return result_dict
673693

@@ -789,7 +809,7 @@ def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]:
789809
"""
790810
ids = _pa_ids_to_npy(input_data[self._item_id_field]).reshape(-1, 1)
791811
batch_size = len(ids)
792-
num_fea = len(self._attr_names[1:])
812+
num_fea = len(self._valid_attr_names[1:])
793813

794814
# positive node.
795815
pos_nodes = self._pos_sampler.get(ids).layer_nodes(1)
@@ -859,8 +879,8 @@ def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]:
859879
for i in range(num_fea)
860880
]
861881

862-
pos_result_dict = dict(zip(self._attr_names[1:], pos_fea_result))
863-
neg_result_dict = dict(zip(self._attr_names[1:], neg_fea_result))
882+
pos_result_dict = dict(zip(self._valid_attr_names[1:], pos_fea_result))
883+
neg_result_dict = dict(zip(self._valid_attr_names[1:], neg_fea_result))
864884

865885
return pos_result_dict, neg_result_dict
866886

@@ -941,7 +961,7 @@ def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]:
941961

942962
pos_nodes = self._pos_sampler.get(ids).layer_nodes(1)
943963
pos_fea_result = self._parse_nodes(pos_nodes)[1:]
944-
pos_result_dict = dict(zip(self._attr_names[1:], pos_fea_result))
964+
pos_result_dict = dict(zip(self._valid_attr_names[1:], pos_fea_result))
945965

946966
return pos_result_dict
947967

tzrec/datasets/sampler_test.py

Lines changed: 33 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -174,6 +174,39 @@ def _sampler_worker(res):
174174
self.assertEqual(len(res["float_b"]), 8)
175175
self.assertEqual(len(res["str_c"]), 8)
176176

177+
def test_negative_sampler_with_ignore_feature(self):
178+
f = self._create_item_gl_data()
179+
180+
def _sampler_worker(res):
181+
config = sampler_pb2.NegativeSampler(
182+
input_path=f.name,
183+
num_sample=8,
184+
attr_fields=["int_a", "float_b", "str_c"],
185+
item_id_field="item_id",
186+
)
187+
sampler = NegativeSampler(
188+
config=config,
189+
fields=[
190+
pa.field(name="int_a", type=pa.int64()),
191+
pa.field(name="str_c", type=pa.string()),
192+
],
193+
batch_size=4,
194+
)
195+
assert sampler.estimated_sample_num == 8
196+
sampler.init_cluster()
197+
sampler.launch_server()
198+
sampler.init()
199+
res.update(sampler.get({"item_id": pa.array([0, 1, 2, 3])}))
200+
201+
res = mp.Manager().dict()
202+
p = mp.Process(target=_sampler_worker, args=(res,))
203+
p.start()
204+
p.join()
205+
if p.exitcode != 0:
206+
raise RuntimeError("worker failed.")
207+
self.assertEqual(len(res["int_a"]), 8)
208+
self.assertEqual(len(res["str_c"]), 8)
209+
177210
@unittest.skip("accidental process defunct error")
178211
def test_negative_sampler_multi_node(self):
179212
f = self._create_item_gl_data()

0 commit comments

Comments
 (0)