|
24 | 24 | import joblib |
25 | 25 | import numpy as np |
26 | 26 | import pandas as pd |
| 27 | +import sklearn.utils.estimator_checks as sklearn_checks |
27 | 28 | from dask.array.utils import assert_eq |
28 | 29 | from dask.distributed import Client, LocalCluster, default_client, wait |
29 | 30 | from distributed.utils_test import client, cluster_fixture, gen_cluster, loop |
@@ -1156,3 +1157,36 @@ def test_dask_methods_and_sklearn_equivalents_have_similar_signatures(methods): |
1156 | 1157 | for param in dask_spec.args: |
1157 | 1158 | error_msg = f"param '{param}' has different default values in the methods" |
1158 | 1159 | assert dask_params[param].default == sklearn_params[param].default, error_msg |
| 1160 | + |
| 1161 | + |
| 1162 | +def sklearn_checks_to_run(): |
| 1163 | + check_names = [ |
| 1164 | + "check_estimator_get_tags_default_keys", |
| 1165 | + "check_get_params_invariance", |
| 1166 | + "check_set_params" |
| 1167 | + ] |
| 1168 | + for check_name in check_names: |
| 1169 | + check_func = getattr(sklearn_checks, check_name, None) |
| 1170 | + if check_func: |
| 1171 | + yield check_func |
| 1172 | + |
| 1173 | + |
| 1174 | +def _tested_estimators(): |
| 1175 | + for Estimator in [lgb.DaskLGBMClassifier, lgb.DaskLGBMRegressor]: |
| 1176 | + yield Estimator() |
| 1177 | + |
| 1178 | + |
| 1179 | +@pytest.mark.parametrize("estimator", _tested_estimators()) |
| 1180 | +@pytest.mark.parametrize("check", sklearn_checks_to_run()) |
| 1181 | +def test_sklearn_integration(estimator, check, client): |
| 1182 | + estimator.set_params(local_listen_port=18000, time_out=5) |
| 1183 | + name = type(estimator).__name__ |
| 1184 | + check(name, estimator) |
| 1185 | + client.close(timeout=CLIENT_CLOSE_TIMEOUT) |
| 1186 | + |
| 1187 | + |
| 1188 | +# this test is separate because it takes a not-yet-constructed estimator |
| 1189 | +@pytest.mark.parametrize("estimator", list(_tested_estimators())) |
| 1190 | +def test_parameters_default_constructible(estimator): |
| 1191 | + name, Estimator = estimator.__class__.__name__, estimator.__class__ |
| 1192 | + sklearn_checks.check_parameters_default_constructible(name, Estimator) |
0 commit comments