Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions flaml/tune/searcher/blendsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,24 @@ def __init__(
if global_search_alg is not None:
self._gs = global_search_alg
elif getattr(self, "__name__", None) != "CFO":
if space and self._ls.hierarchical:
# Use define-by-run for OptunaSearch when needed:
# - Hierarchical/conditional spaces are best supported via define-by-run.
# - Ray Tune domain/grid specs can trigger an "unresolved search space" warning
# unless we switch to define-by-run.
use_define_by_run = bool(getattr(self._ls, "hierarchical", False))
if (not use_define_by_run) and isinstance(space, dict) and space:
try:
from .variant_generator import parse_spec_vars

_, domain_vars, grid_vars = parse_spec_vars(space)
use_define_by_run = bool(domain_vars or grid_vars)
except Exception:
# Be conservative: if we can't determine whether the space is
# unresolved, fall back to the original behavior.
use_define_by_run = False

self._use_define_by_run = use_define_by_run
if use_define_by_run:
from functools import partial

gs_space = partial(define_by_run_func, space=space)
Expand Down Expand Up @@ -487,7 +504,7 @@ def on_trial_complete(self, trial_id: str, result: Optional[Dict] = None, error:
self._ls_bound_max,
self._subspace.get(trial_id, self._ls.space),
)
if self._gs is not None and self._experimental and (not self._ls.hierarchical):
if self._gs is not None and self._experimental and (not getattr(self, "_use_define_by_run", False)):
self._gs.add_evaluated_point(flatten_dict(config), objective)
# TODO: recover when supported
# converted = convert_key(config, self._gs.space)
Expand Down
23 changes: 23 additions & 0 deletions test/tune/test_searcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,3 +324,26 @@ def test_no_optuna():
import flaml.tune.searcher.suggestion

subprocess.check_call([sys.executable, "-m", "pip", "install", "optuna==2.8.0"])


def test_unresolved_search_space(caplog):
import logging

from flaml import tune
from flaml.tune.searcher.blendsearch import BlendSearch

if caplog is not None:
caplog.set_level(logging.INFO)

BlendSearch(metric="loss", mode="min", space={"lr": tune.uniform(0.001, 0.1), "depth": tune.randint(1, 10)})
try:
text = caplog.text
except AttributeError:
text = ""
assert (
"unresolved search space" not in text and text
), "BlendSearch should not produce warning about unresolved search space"


if __name__ == "__main__":
test_unresolved_search_space(None)
Loading