Skip to content

Commit f00b3a7

Browse files
committed
Merge branch 'master' into feat/network-params
2 parents 52e0c39 + eb5f471 commit f00b3a7

1 file changed

Lines changed: 34 additions & 0 deletions

File tree

tests/python_package_test/test_dask.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import joblib
2525
import numpy as np
2626
import pandas as pd
27+
import sklearn.utils.estimator_checks as sklearn_checks
2728
from dask.array.utils import assert_eq
2829
from dask.distributed import Client, LocalCluster, default_client, wait
2930
from distributed.utils_test import client, cluster_fixture, gen_cluster, loop
@@ -1156,3 +1157,36 @@ def test_dask_methods_and_sklearn_equivalents_have_similar_signatures(methods):
11561157
for param in dask_spec.args:
11571158
error_msg = f"param '{param}' has different default values in the methods"
11581159
assert dask_params[param].default == sklearn_params[param].default, error_msg
1160+
1161+
1162+
def sklearn_checks_to_run():
1163+
check_names = [
1164+
"check_estimator_get_tags_default_keys",
1165+
"check_get_params_invariance",
1166+
"check_set_params"
1167+
]
1168+
for check_name in check_names:
1169+
check_func = getattr(sklearn_checks, check_name, None)
1170+
if check_func:
1171+
yield check_func
1172+
1173+
1174+
def _tested_estimators():
1175+
for Estimator in [lgb.DaskLGBMClassifier, lgb.DaskLGBMRegressor]:
1176+
yield Estimator()
1177+
1178+
1179+
@pytest.mark.parametrize("estimator", _tested_estimators())
1180+
@pytest.mark.parametrize("check", sklearn_checks_to_run())
1181+
def test_sklearn_integration(estimator, check, client):
1182+
estimator.set_params(local_listen_port=18000, time_out=5)
1183+
name = type(estimator).__name__
1184+
check(name, estimator)
1185+
client.close(timeout=CLIENT_CLOSE_TIMEOUT)
1186+
1187+
1188+
# this test is separate because it takes a not-yet-constructed estimator
1189+
@pytest.mark.parametrize("estimator", list(_tested_estimators()))
1190+
def test_parameters_default_constructible(estimator):
1191+
name, Estimator = estimator.__class__.__name__, estimator.__class__
1192+
sklearn_checks.check_parameters_default_constructible(name, Estimator)

0 commit comments

Comments
 (0)