Skip to content

Commit b3c8a2c

Browse files
committed
indent
1 parent 040ad1f commit b3c8a2c

File tree

1 file changed

+43
-44
lines changed

1 file changed

+43
-44
lines changed

tests/python_package_test/test_dask.py

Lines changed: 43 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1118,52 +1118,51 @@ def test_machines_should_be_used_if_provided(task, output):
11181118
if task == 'ranking' and output == 'scipy_csr_matrix':
11191119
pytest.skip('LGBMRanker is not currently tested on sparse matrices')
11201120

1121-
with LocalCluster(n_workers=2) as cluster:
1122-
with Client(cluster) as client:
1123-
if task == 'ranking':
1124-
_, _, _, _, dX, dy, _, dg = _create_ranking_data(
1125-
output=output,
1126-
group=None,
1127-
chunk_size=10,
1128-
)
1129-
dask_model_factory = lgb.DaskLGBMRanker
1130-
else:
1131-
_, _, _, dX, dy, _ = _create_data(
1132-
objective=task,
1133-
output=output,
1134-
chunk_size=10,
1135-
)
1136-
dg = None
1137-
if task == 'classification':
1138-
dask_model_factory = lgb.DaskLGBMClassifier
1139-
elif task == 'regression':
1140-
dask_model_factory = lgb.DaskLGBMRegressor
1141-
1142-
# rebalance data to be sure that each worker has a piece of the data
1143-
if output == 'array':
1144-
client.rebalance()
1145-
1146-
n_workers = len(client.scheduler_info()['workers'])
1147-
open_ports = [_find_random_open_port() for _ in range(n_workers)]
1148-
dask_model = dask_model_factory(
1149-
n_estimators=5,
1150-
num_leaves=5,
1151-
machines=",".join([
1152-
"127.0.0.1:" + str(port)
1153-
for port in open_ports
1154-
]),
1121+
with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
1122+
if task == 'ranking':
1123+
_, _, _, _, dX, dy, _, dg = _create_ranking_data(
1124+
output=output,
1125+
group=None,
1126+
chunk_size=10,
11551127
)
1128+
dask_model_factory = lgb.DaskLGBMRanker
1129+
else:
1130+
_, _, _, dX, dy, _ = _create_data(
1131+
objective=task,
1132+
output=output,
1133+
chunk_size=10,
1134+
)
1135+
dg = None
1136+
if task == 'classification':
1137+
dask_model_factory = lgb.DaskLGBMClassifier
1138+
elif task == 'regression':
1139+
dask_model_factory = lgb.DaskLGBMRegressor
1140+
1141+
# rebalance data to be sure that each worker has a piece of the data
1142+
if output == 'array':
1143+
client.rebalance()
1144+
1145+
n_workers = len(client.scheduler_info()['workers'])
1146+
open_ports = [_find_random_open_port() for _ in range(n_workers)]
1147+
dask_model = dask_model_factory(
1148+
n_estimators=5,
1149+
num_leaves=5,
1150+
machines=",".join([
1151+
"127.0.0.1:" + str(port)
1152+
for port in open_ports
1153+
]),
1154+
)
11561155

1157-
# test that "machines" is actually respected by creating a socket that uses
1158-
# one of the ports mentioned in "machines"
1159-
error_msg = "Binding port %s failed" % open_ports[0]
1160-
with pytest.raises(lgb.basic.LightGBMError, match=error_msg):
1161-
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
1162-
s.bind(('127.0.0.1', open_ports[0]))
1163-
if task == 'ranking':
1164-
dask_model.fit(dX, dy, group=dg)
1165-
else:
1166-
dask_model.fit(dX, dy)
1156+
# test that "machines" is actually respected by creating a socket that uses
1157+
# one of the ports mentioned in "machines"
1158+
error_msg = "Binding port %s failed" % open_ports[0]
1159+
with pytest.raises(lgb.basic.LightGBMError, match=error_msg):
1160+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
1161+
s.bind(('127.0.0.1', open_ports[0]))
1162+
if task == 'ranking':
1163+
dask_model.fit(dX, dy, group=dg)
1164+
else:
1165+
dask_model.fit(dX, dy)
11671166

11681167

11691168
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)