@@ -1118,52 +1118,51 @@ def test_machines_should_be_used_if_provided(task, output):
11181118 if task == 'ranking' and output == 'scipy_csr_matrix' :
11191119 pytest .skip ('LGBMRanker is not currently tested on sparse matrices' )
11201120
1121- with LocalCluster (n_workers = 2 ) as cluster :
1122- with Client (cluster ) as client :
1123- if task == 'ranking' :
1124- _ , _ , _ , _ , dX , dy , _ , dg = _create_ranking_data (
1125- output = output ,
1126- group = None ,
1127- chunk_size = 10 ,
1128- )
1129- dask_model_factory = lgb .DaskLGBMRanker
1130- else :
1131- _ , _ , _ , dX , dy , _ = _create_data (
1132- objective = task ,
1133- output = output ,
1134- chunk_size = 10 ,
1135- )
1136- dg = None
1137- if task == 'classification' :
1138- dask_model_factory = lgb .DaskLGBMClassifier
1139- elif task == 'regression' :
1140- dask_model_factory = lgb .DaskLGBMRegressor
1141-
1142- # rebalance data to be sure that each worker has a piece of the data
1143- if output == 'array' :
1144- client .rebalance ()
1145-
1146- n_workers = len (client .scheduler_info ()['workers' ])
1147- open_ports = [_find_random_open_port () for _ in range (n_workers )]
1148- dask_model = dask_model_factory (
1149- n_estimators = 5 ,
1150- num_leaves = 5 ,
1151- machines = "," .join ([
1152- "127.0.0.1:" + str (port )
1153- for port in open_ports
1154- ]),
1121+ with LocalCluster (n_workers = 2 ) as cluster , Client (cluster ) as client :
1122+ if task == 'ranking' :
1123+ _ , _ , _ , _ , dX , dy , _ , dg = _create_ranking_data (
1124+ output = output ,
1125+ group = None ,
1126+ chunk_size = 10 ,
11551127 )
1128+ dask_model_factory = lgb .DaskLGBMRanker
1129+ else :
1130+ _ , _ , _ , dX , dy , _ = _create_data (
1131+ objective = task ,
1132+ output = output ,
1133+ chunk_size = 10 ,
1134+ )
1135+ dg = None
1136+ if task == 'classification' :
1137+ dask_model_factory = lgb .DaskLGBMClassifier
1138+ elif task == 'regression' :
1139+ dask_model_factory = lgb .DaskLGBMRegressor
1140+
1141+ # rebalance data to be sure that each worker has a piece of the data
1142+ if output == 'array' :
1143+ client .rebalance ()
1144+
1145+ n_workers = len (client .scheduler_info ()['workers' ])
1146+ open_ports = [_find_random_open_port () for _ in range (n_workers )]
1147+ dask_model = dask_model_factory (
1148+ n_estimators = 5 ,
1149+ num_leaves = 5 ,
1150+ machines = "," .join ([
1151+ "127.0.0.1:" + str (port )
1152+ for port in open_ports
1153+ ]),
1154+ )
11561155
1157- # test that "machines" is actually respected by creating a socket that uses
1158- # one of the ports mentioned in "machines"
1159- error_msg = "Binding port %s failed" % open_ports [0 ]
1160- with pytest .raises (lgb .basic .LightGBMError , match = error_msg ):
1161- with socket .socket (socket .AF_INET , socket .SOCK_STREAM ) as s :
1162- s .bind (('127.0.0.1' , open_ports [0 ]))
1163- if task == 'ranking' :
1164- dask_model .fit (dX , dy , group = dg )
1165- else :
1166- dask_model .fit (dX , dy )
1156+ # test that "machines" is actually respected by creating a socket that uses
1157+ # one of the ports mentioned in "machines"
1158+ error_msg = "Binding port %s failed" % open_ports [0 ]
1159+ with pytest .raises (lgb .basic .LightGBMError , match = error_msg ):
1160+ with socket .socket (socket .AF_INET , socket .SOCK_STREAM ) as s :
1161+ s .bind (('127.0.0.1' , open_ports [0 ]))
1162+ if task == 'ranking' :
1163+ dask_model .fit (dX , dy , group = dg )
1164+ else :
1165+ dask_model .fit (dX , dy )
11671166
11681167
11691168@pytest .mark .parametrize (
0 commit comments