2929from tzrec .protos import sampler_pb2
3030from tzrec .utils .env_util import use_hash_node_id
3131from tzrec .utils .load_class import get_register_class_meta
32+ from tzrec .utils .logging_util import logger
3233from tzrec .utils .misc_util import get_free_port
3334
3435
@@ -205,12 +206,25 @@ def __init__(
205206 self ._attr_types = []
206207 self ._attr_gl_types = []
207208 self ._attr_np_types = []
209+ self ._valid_attr_names = []
210+ self ._ignore_attr_names = set ()
208211 for field_name in config .attr_fields :
209- field = input_fields [field_name ]
212+ if field_name in input_fields :
213+ field = input_fields [field_name ]
214+ self ._valid_attr_names .append (field .name )
215+ else :
216+ field = pa .field (name = field_name , type = pa .string ())
217+ self ._ignore_attr_names .add (field_name )
210218 self ._attr_names .append (field .name )
211219 self ._attr_types .append (field .type )
212220 self ._attr_gl_types .append (_get_gl_type (field .type ))
213221 self ._attr_np_types .append (_get_np_type (field .type ))
222+ if len (self ._ignore_attr_names ) > 0 :
223+ logger .warning (
224+ f"Features { self ._ignore_attr_names } in "
225+ # pyre-ignore [16]
226+ f"{ self .__class__ .__name__ } will be ignored."
227+ )
214228
215229 if config .HasField ("field_delimiter" ):
216230 gl .set_field_delimiter (config .field_delimiter )
@@ -268,9 +282,12 @@ def _parse_nodes(self, nodes: gl.Nodes) -> List[pa.Array]:
268282 int_idx = 0
269283 float_idx = 0
270284 string_idx = 0
271- for attr_type , attr_gl_type , attr_np_type in zip (
272- self ._attr_types , self ._attr_gl_types , self ._attr_np_types
285+ for attr_name , attr_type , attr_gl_type , attr_np_type in zip (
286+ self ._attr_names , self . _attr_types , self ._attr_gl_types , self ._attr_np_types
273287 ):
288+ if attr_name in self ._ignore_attr_names :
289+ string_idx += 1
290+ continue
274291 if attr_gl_type == "int" :
275292 feature = nodes .int_attrs [:, :, int_idx ]
276293 int_idx += 1
@@ -295,9 +312,12 @@ def _parse_sparse_nodes(
295312 int_idx = 0
296313 float_idx = 0
297314 string_idx = 0
298- for attr_type , attr_gl_type , attr_np_type in zip (
299- self ._attr_types , self ._attr_gl_types , self ._attr_np_types
315+ for attr_name , attr_type , attr_gl_type , attr_np_type in zip (
316+ self ._attr_names , self . _attr_types , self ._attr_gl_types , self ._attr_np_types
300317 ):
318+ if attr_name in self ._ignore_attr_names :
319+ string_idx += 1
320+ continue
301321 if attr_gl_type == "int" :
302322 feature = nodes .int_attrs [:, int_idx ]
303323 int_idx += 1
@@ -379,7 +399,7 @@ def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]:
379399 ids = np .pad (ids , (0 , self ._batch_size - len (ids )), "edge" )
380400 nodes = self ._sampler .get (ids )
381401 features = self ._parse_nodes (nodes )
382- result_dict = dict (zip (self ._attr_names , features ))
402+ result_dict = dict (zip (self ._valid_attr_names , features ))
383403 return result_dict
384404
385405 @property
@@ -470,7 +490,7 @@ def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]:
470490 dst_ids = np .pad (dst_ids , (0 , self ._batch_size - len (dst_ids )), "edge" )
471491 nodes = self ._sampler .get (src_ids , dst_ids )
472492 features = self ._parse_nodes (nodes )
473- result_dict = dict (zip (self ._attr_names , features ))
493+ result_dict = dict (zip (self ._valid_attr_names , features ))
474494 return result_dict
475495
476496 @property
@@ -565,7 +585,7 @@ def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]:
565585 for i , v in enumerate (hard_neg_features ):
566586 results .append (pa .concat_arrays ([neg_features [i ], v ]))
567587
568- result_dict = dict (zip (self ._attr_names , results ))
588+ result_dict = dict (zip (self ._valid_attr_names , results ))
569589 result_dict ["hard_neg_indices" ] = pa .array (hard_neg_indices )
570590 return result_dict
571591
@@ -667,7 +687,7 @@ def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]:
667687 for i , v in enumerate (hard_neg_features ):
668688 results .append (pa .concat_arrays ([neg_features [i ], v ]))
669689
670- result_dict = dict (zip (self ._attr_names , results ))
690+ result_dict = dict (zip (self ._valid_attr_names , results ))
671691 result_dict ["hard_neg_indices" ] = pa .array (hard_neg_indices )
672692 return result_dict
673693
@@ -789,7 +809,7 @@ def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]:
789809 """
790810 ids = _pa_ids_to_npy (input_data [self ._item_id_field ]).reshape (- 1 , 1 )
791811 batch_size = len (ids )
792- num_fea = len (self ._attr_names [1 :])
812+ num_fea = len (self ._valid_attr_names [1 :])
793813
794814 # positive node.
795815 pos_nodes = self ._pos_sampler .get (ids ).layer_nodes (1 )
@@ -859,8 +879,8 @@ def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]:
859879 for i in range (num_fea )
860880 ]
861881
862- pos_result_dict = dict (zip (self ._attr_names [1 :], pos_fea_result ))
863- neg_result_dict = dict (zip (self ._attr_names [1 :], neg_fea_result ))
882+ pos_result_dict = dict (zip (self ._valid_attr_names [1 :], pos_fea_result ))
883+ neg_result_dict = dict (zip (self ._valid_attr_names [1 :], neg_fea_result ))
864884
865885 return pos_result_dict , neg_result_dict
866886
@@ -941,7 +961,7 @@ def get(self, input_data: Dict[str, pa.Array]) -> Dict[str, pa.Array]:
941961
942962 pos_nodes = self ._pos_sampler .get (ids ).layer_nodes (1 )
943963 pos_fea_result = self ._parse_nodes (pos_nodes )[1 :]
944- pos_result_dict = dict (zip (self ._attr_names [1 :], pos_fea_result ))
964+ pos_result_dict = dict (zip (self ._valid_attr_names [1 :], pos_fea_result ))
945965
946966 return pos_result_dict
947967
0 commit comments