Skip to content

Commit c30cd14

Browse files
feat: support for loose init model from run (#1371)
1 parent 26ae499 commit c30cd14

File tree

3 files changed

+18
-4
lines changed

3 files changed

+18
-4
lines changed

openml/runs/functions.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ def get_run_trace(run_id: int) -> OpenMLRunTrace:
364364
return OpenMLRunTrace.trace_from_xml(trace_xml)
365365

366366

367-
def initialize_model_from_run(run_id: int) -> Any:
367+
def initialize_model_from_run(run_id: int, *, strict_version: bool = True) -> Any:
368368
"""
369369
Initialized a model based on a run_id (i.e., using the exact
370370
same parameter settings)
@@ -373,6 +373,8 @@ def initialize_model_from_run(run_id: int) -> Any:
373373
----------
374374
run_id : int
375375
The Openml run_id
376+
strict_version: bool (default=True)
377+
See `flow_to_model` strict_version.
376378
377379
Returns
378380
-------
@@ -382,7 +384,7 @@ def initialize_model_from_run(run_id: int) -> Any:
382384
# TODO(eddiebergman): I imagine this is None if it's not published,
383385
# might need to raise an explicit error for that
384386
assert run.setup_id is not None
385-
return initialize_model(run.setup_id)
387+
return initialize_model(setup_id=run.setup_id, strict_version=strict_version)
386388

387389

388390
def initialize_model_from_trace(

openml/setups/functions.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ def __list_setups(
265265
return setups
266266

267267

268-
def initialize_model(setup_id: int) -> Any:
268+
def initialize_model(setup_id: int, *, strict_version: bool = True) -> Any:
269269
"""
270270
Initialized a model based on a setup_id (i.e., using the exact
271271
same parameter settings)
@@ -274,6 +274,8 @@ def initialize_model(setup_id: int) -> Any:
274274
----------
275275
setup_id : int
276276
The Openml setup_id
277+
strict_version: bool (default=True)
278+
See `flow_to_model` strict_version.
277279
278280
Returns
279281
-------
@@ -294,7 +296,7 @@ def initialize_model(setup_id: int) -> Any:
294296
subflow = flow
295297
subflow.parameters[hyperparameter.parameter_name] = hyperparameter.value
296298

297-
return flow.extension.flow_to_model(flow)
299+
return flow.extension.flow_to_model(flow, strict_version=strict_version)
298300

299301

300302
def _to_dict(

tests/test_runs/test_run_functions.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1905,6 +1905,16 @@ def test_delete_run(self):
19051905
_run_id = run.run_id
19061906
assert delete_run(_run_id)
19071907

1908+
@unittest.skipIf(
1909+
Version(sklearn.__version__) < Version("0.20"),
1910+
reason="SimpleImputer doesn't handle mixed type DataFrame as input",
1911+
)
1912+
def test_initialize_model_from_run_nonstrict(self):
1913+
# We cannot guarantee that a run with an older version exists on the server.
1914+
# Thus, we test it simply with a run that we know exists that might not be loose.
1915+
# This tests all lines of code for OpenML but not the initialization, which we do not want to guarantee anyhow.
1916+
_ = openml.runs.initialize_model_from_run(run_id=1, strict_version=False)
1917+
19081918

19091919
@mock.patch.object(requests.Session, "delete")
19101920
def test_delete_run_not_owned(mock_delete, test_files_directory, test_api_key):

0 commit comments

Comments
 (0)