diff --git a/openml/datasets/dataset.py b/openml/datasets/dataset.py index fd13a8e8c..0c065b855 100644 --- a/openml/datasets/dataset.py +++ b/openml/datasets/dataset.py @@ -639,6 +639,11 @@ def _encode_if_category(column): @staticmethod def _unpack_categories(series, categories): + # nan-likes cannot be explicitly specified as a category + def valid_category(cat): + return isinstance(cat, str) or (cat is not None and not np.isnan(cat)) + + filtered_categories = [c for c in categories if valid_category(c)] col = [] for x in series: try: @@ -647,7 +652,7 @@ def _unpack_categories(series, categories): col.append(np.nan) # We require two lines to create a series of categories as detailed here: # https://pandas.pydata.org/pandas-docs/version/0.24/user_guide/categorical.html#series-creation # noqa E501 - raw_cat = pd.Categorical(col, ordered=True, categories=categories) + raw_cat = pd.Categorical(col, ordered=True, categories=filtered_categories) return pd.Series(raw_cat, index=series.index, name=series.name) def get_data( diff --git a/tests/test_datasets/test_dataset.py b/tests/test_datasets/test_dataset.py index 14b1b02b7..416fce534 100644 --- a/tests/test_datasets/test_dataset.py +++ b/tests/test_datasets/test_dataset.py @@ -50,6 +50,17 @@ def test_init_string_validation(self): name="somename", description="a description", citation="Something by Müller" ) + def test__unpack_categories_with_nan_likes(self): + # unpack_categories decodes numeric categorical values according to the header + # Containing a nan-like category in the header shouldn't lead to failure. 
+ categories = ["a", "b", None, float("nan"), np.nan] + series = pd.Series([0, 1, None, float("nan"), np.nan, 1, 0]) + clean_series = OpenMLDataset._unpack_categories(series, categories) + + expected_values = ["a", "b", np.nan, np.nan, np.nan, "b", "a"] + self.assertListEqual(list(clean_series.values), expected_values) + self.assertListEqual(list(clean_series.cat.categories.values), list("ab")) + def test_get_data_array(self): # Basic usage rval, _, categorical, attribute_names = self.dataset.get_data(dataset_format="array")