Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,13 @@ env:
- TEST_DIR=/tmp/test_dir/
- MODULE=openml
matrix:
- DISTRIB="conda" PYTHON_VERSION="3.6" SKLEARN_VERSION="0.21.2" TEST_DIST="true"
- DISTRIB="conda" PYTHON_VERSION="3.7" SKLEARN_VERSION="0.21.2" RUN_FLAKE8="true" SKIP_TESTS="true"
- DISTRIB="conda" PYTHON_VERSION="3.7" SKLEARN_VERSION="0.21.2" COVERAGE="true" DOCPUSH="true"
- DISTRIB="conda" PYTHON_VERSION="3.7" SKLEARN_VERSION="0.23.1" TEST_DIST="true"
- DISTRIB="conda" PYTHON_VERSION="3.6" SKLEARN_VERSION="0.23.1" TEST_DIST="true"
- DISTRIB="conda" PYTHON_VERSION="3.7" SKLEARN_VERSION="0.22.2" TEST_DIST="true"
- DISTRIB="conda" PYTHON_VERSION="3.6" SKLEARN_VERSION="0.22.2" TEST_DIST="true"
- DISTRIB="conda" PYTHON_VERSION="3.6" SKLEARN_VERSION="0.21.2" TEST_DIST="true"
- DISTRIB="conda" PYTHON_VERSION="3.7" SKLEARN_VERSION="0.20.2"
# Checks for older scikit-learn versions (which also don't nicely work with
# Python3.7)
Expand Down
18 changes: 11 additions & 7 deletions openml/extensions/sklearn/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -994,12 +994,16 @@ def _get_fn_arguments_with_defaults(self, fn_name: Callable) -> Tuple[Dict, Set]
a set with all parameters that do not have a default value
"""
# parameters with defaults are optional, all others are required.
signature = inspect.getfullargspec(fn_name)
if signature.defaults:
optional_params = dict(zip(reversed(signature.args), reversed(signature.defaults)))
else:
optional_params = dict()
required_params = {arg for arg in signature.args if arg not in optional_params}
parameters = inspect.signature(fn_name).parameters
required_params = set()
optional_params = dict()
for param in parameters.keys():
parameter = parameters.get(param)
default_val = parameter.default # type: ignore
if default_val is inspect.Signature.empty:
required_params.add(param)
else:
optional_params[param] = default_val
return optional_params, required_params

def _deserialize_model(
Expand Down Expand Up @@ -1346,7 +1350,7 @@ def _can_measure_cputime(self, model: Any) -> bool:
# check the parameters for n_jobs
n_jobs_vals = SklearnExtension._get_parameter_values_recursive(model.get_params(), "n_jobs")
for val in n_jobs_vals:
if val is not None and val != 1:
if val is not None and val != 1 and val != "deprecated":
return False
return True

Expand Down
Loading