Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion openml/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,11 @@ def _encode_if_category(column):

@staticmethod
def _unpack_categories(series, categories):
# NaN-likes (None, NaN) cannot be explicitly specified as a category
def valid_category(cat):
    """Return True if ``cat`` may be passed to pandas as an explicit category.

    Strings are always accepted (including the literal text "nan").
    Missing-value markers such as ``None`` and NaN are rejected.
    """
    # pd.isna is defined for arbitrary scalars, whereas np.isnan raises
    # TypeError for non-numeric values such as bytes.
    return isinstance(cat, str) or not pd.isna(cat)

filtered_categories = [c for c in categories if valid_category(c)]
col = []
for x in series:
try:
Expand All @@ -647,7 +652,7 @@ def _unpack_categories(series, categories):
col.append(np.nan)
# We require two lines to create a series of categories as detailed here:
# https://pandas.pydata.org/pandas-docs/version/0.24/user_guide/categorical.html#series-creation # noqa E501
raw_cat = pd.Categorical(col, ordered=True, categories=categories)
raw_cat = pd.Categorical(col, ordered=True, categories=filtered_categories)
return pd.Series(raw_cat, index=series.index, name=series.name)

def get_data(
Expand Down
11 changes: 11 additions & 0 deletions tests/test_datasets/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,17 @@ def test_init_string_validation(self):
name="somename", description="a description", citation="Something by Müller"
)

def test__unpack_categories_with_nan_likes(self):
    """NaN-like entries in the category header must not break decoding.

    ``_unpack_categories`` decodes numeric categorical values according to
    the header; NaN-like header entries should simply be left out of the
    resulting categories.
    """
    header = ["a", "b", None, float("nan"), np.nan]
    encoded = pd.Series([0, 1, None, float("nan"), np.nan, 1, 0])

    decoded = OpenMLDataset._unpack_categories(encoded, header)

    # Missing inputs decode to NaN; valid codes map onto the header entries.
    self.assertListEqual(
        list(decoded.values),
        ["a", "b", np.nan, np.nan, np.nan, "b", "a"],
    )
    # Only the real categories survive in the categorical dtype.
    self.assertListEqual(list(decoded.cat.categories.values), ["a", "b"])

def test_get_data_array(self):
# Basic usage
rval, _, categorical, attribute_names = self.dataset.get_data(dataset_format="array")
Expand Down