Commit cc5518b

refactor(tuning): renaming and tests

1 parent 7d7e6a5

3 files changed: 528 additions & 55 deletions

File tree:

examples/tutorials/try_optuna.ipynb (20 additions & 20 deletions; large diff not rendered by default)

packages/openstef-models/src/openstef_models/utils/tuning.py (31 additions & 35 deletions)
@@ -208,7 +208,7 @@ def model_selection_metric(self) -> tuple[QuantileOrGlobal, str, Any]:
         """Metric used to select the best trial: (quantile, metric_name, direction)."""
         ...

-    def get_tunable_hyperparams(self) -> list[ModelTuningInfo]:
+    def get_model_tuning_info(self) -> list[ModelTuningInfo]:
         """Return TunableField with model_hyperparams_field_name, hyperparams_instance and search_space for tuning.

         Can be inherited from TuningConfigMixin.
@@ -326,7 +326,7 @@ def get_search_space(
     return result


-def suggest_hyperparams[HP: BaseModel](
+def apply_trial_suggestions[HP: BaseModel](
     trial: optuna.Trial,
     space: dict[str, TuningRange],
     current: HP,
@@ -343,18 +343,9 @@ def suggest_hyperparams[HP: BaseModel](
     """
     updates: dict[str, Any] = {}
     for hyperparam_name, tuning_range in space.items():
-        if isinstance(tuning_range, FloatRange):
-            if tuning_range.low is not None and tuning_range.high is not None:
-                updates[hyperparam_name] = trial.suggest_float(
-                    hyperparam_name, tuning_range.low, tuning_range.high, log=tuning_range.log
-                )
-        elif isinstance(tuning_range, IntRange):
-            if tuning_range.low is not None and tuning_range.high is not None:
-                updates[hyperparam_name] = trial.suggest_int(
-                    hyperparam_name, tuning_range.low, tuning_range.high, log=tuning_range.log
-                )
-        elif tuning_range.choices is not None:
-            updates[hyperparam_name] = trial.suggest_categorical(hyperparam_name, list(tuning_range.choices))
+        value = _suggest_hyperparam_value(trial, hyperparam_name, tuning_range)
+        if value is not None:
+            updates[hyperparam_name] = value
     return current.model_copy(update=updates)
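The per-range suggestion logic removed here presumably moved into _suggest_hyperparam_value (the renamed _suggest_param, further below). A minimal sketch of that helper, reconstructed from the deleted inline code; the diff shows only its signature and trailing return None, so the exact body is an assumption:

# Sketch (assumption): reconstruction of _suggest_hyperparam_value from the
# inline logic deleted above. FloatRange, IntRange, TuningRange, and the
# typing.Any import all come from this same module; the real body may differ.
def _suggest_hyperparam_value(
    trial: optuna.Trial,
    trial_key: str,
    tuning_range: TuningRange,
) -> Any | None:
    if isinstance(tuning_range, FloatRange):
        if tuning_range.low is not None and tuning_range.high is not None:
            return trial.suggest_float(trial_key, tuning_range.low, tuning_range.high, log=tuning_range.log)
    elif isinstance(tuning_range, IntRange):
        if tuning_range.low is not None and tuning_range.high is not None:
            return trial.suggest_int(trial_key, tuning_range.low, tuning_range.high, log=tuning_range.log)
    elif tuning_range.choices is not None:
        return trial.suggest_categorical(trial_key, list(tuning_range.choices))
    return None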

@@ -389,13 +380,13 @@ def run_optuna_study(


 class TuningConfigMixin:
-    """Mixin for get_tunable_hyperparams for workflow configs.
+    """Mixin for get_model_tuning_info for workflow configs.

     Discovers tunable fields by reflecting over model_fields and returning a TunableField for every field whose value
     is a TunableHyperParams instance with a non-empty search space.
     """

-    def get_tunable_hyperparams(self) -> list[ModelTuningInfo]:
+    def get_model_tuning_info(self) -> list[ModelTuningInfo]:
         """Return one ModelTuningInfo per active tunable hyperparameter group for a model."""
         result: list[ModelTuningInfo] = []
         model_fields: dict[str, Any] = cast(dict[str, Any], getattr(type(self), "model_fields", {}))
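A config opts into this discovery by mixing the class in. A minimal sketch, assuming a pydantic-style config (the mixin reflects over model_fields); MyHyperParams and forecaster_hyperparams are hypothetical names, not part of this diff:

# Hypothetical sketch: a config picking up the mixin's discovery logic.
# TunableHyperParams and TuningConfigMixin come from this module.
from pydantic import BaseModel

class MyWorkflowConfig(TuningConfigMixin, BaseModel):
    forecaster_hyperparams: MyHyperParams  # a TunableHyperParams subclass with tune ranges

config = MyWorkflowConfig(forecaster_hyperparams=MyHyperParams())
tuning_info = config.get_model_tuning_info()
# One ModelTuningInfo per field whose value is a TunableHyperParams
# instance with a non-empty search space.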
@@ -454,7 +445,7 @@ class _TrialEntry(NamedTuple):
     tuning_range: TuningRange


-def _suggest_param(
+def _suggest_hyperparam_value(
     trial: optuna.Trial,
     trial_key: str,
     tuning_range: TuningRange,
@@ -476,6 +467,24 @@ def _suggest_param(
     return None


+def _build_hp_updates(
+    model_tuning_info: list[ModelTuningInfo],
+    per_field: dict[str, dict[str, Any]],
+) -> dict[str, Any]:
+    """Build a config-level update dict by applying *per_field* values to each HP group.
+
+    Returns:
+        Mapping of config field name → updated :class:`TunableHyperParams` instance.
+    """
+    return {
+        tf.model_hyperparams_field_name: tf.tunable_hyperparams.model_copy(
+            update=per_field[tf.model_hyperparams_field_name]
+        )
+        for tf in model_tuning_info
+        if tf.model_hyperparams_field_name in per_field
+    }
+
+
 class _TuningObjective:
     """Callable Optuna objective that encapsulates the context for a tuning run."""
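To make the new helper's shape concrete, a small illustration; the field and hyperparameter names are hypothetical, and model_tuning_info and config stand in for values already available during a tuning run:

# Illustrative only: trial values suggested for one hyperparameter group.
per_field = {"xgb_hyperparams": {"learning_rate": 0.05, "max_depth": 6}}
updates = _build_hp_updates(model_tuning_info, per_field)
# updates maps "xgb_hyperparams" to a copy of that group's TunableHyperParams
# instance with learning_rate=0.05 and max_depth=6 applied; groups absent
# from per_field are left untouched.
tuned_config = config.model_copy(update=updates)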

@@ -506,17 +515,11 @@ def __call__(self, trial: optuna.Trial) -> float:
         """
         per_field: dict[str, dict[str, Any]] = {}
         for trial_key, trial_entry in self._combined_space.items():
-            value = _suggest_param(trial, trial_key, trial_entry.tuning_range)
+            value = _suggest_hyperparam_value(trial, trial_key, trial_entry.tuning_range)
             if value is not None:
                 per_field.setdefault(trial_entry.model_hyperparams_field_name, {})[trial_entry.hyperparam_name] = value

-        updates: dict[str, Any] = {}
-        for tf in self._model_tuning_info:
-            if tf.model_hyperparams_field_name in per_field:
-                updates[tf.model_hyperparams_field_name] = tf.tunable_hyperparams.model_copy(
-                    update=per_field[tf.model_hyperparams_field_name]
-                )
-        tuned_config = self._config.model_copy(update=updates)
+        tuned_config = self._config.model_copy(update=_build_hp_updates(self._model_tuning_info, per_field))

         trial_workflow = self._create_workflow(tuned_config)
         trial_result = trial_workflow.fit(self._train_dataset)
@@ -545,7 +548,7 @@ def tune[ConfigT: TunableWorkflowConfig](
     Raises:
         ValueError: If no hyperparameter field has tune=True ranges.
     """
-    model_tuning_info = config.get_tunable_hyperparams()
+    model_tuning_info = config.get_model_tuning_info()
    if not model_tuning_info:
        msg = (
            f"No tunable hyperparameters found on config '{config.model_id}'. "
@@ -613,14 +616,7 @@ def _reconstruct_best_config[ConfigT: TunableWorkflowConfig](
         hyperparam_name = trial_key
         per_field_best.setdefault(model_hyperparams_field_name, {})[hyperparam_name] = value

-    best_updates: dict[str, Any] = {}
-    for tf in model_tuning_info_list:
-        if tf.model_hyperparams_field_name in per_field_best:
-            best_updates[tf.model_hyperparams_field_name] = tf.tunable_hyperparams.model_copy(
-                update=per_field_best[tf.model_hyperparams_field_name]
-            )
-
-    return config.model_copy(update=best_updates)
+    return config.model_copy(update=_build_hp_updates(model_tuning_info_list, per_field_best))


 def fit_with_tuning[ConfigT: TunableWorkflowConfig](
@@ -652,9 +648,9 @@ def fit_with_tuning[ConfigT: TunableWorkflowConfig](
     "TuningConfigMixin",
     "TuningRange",
     "TuningResult",
+    "apply_trial_suggestions",
     "fit_with_tuning",
     "get_search_space",
     "run_optuna_study",
-    "suggest_hyperparams",
     "tune",
 ]
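Downstream imports change accordingly. A sketch, assuming the import path implied by the file location (openstef_models.utils.tuning):

# suggest_hyperparams is no longer exported; apply_trial_suggestions replaces it.
from openstef_models.utils.tuning import (
    apply_trial_suggestions,
    fit_with_tuning,
    get_search_space,
    run_optuna_study,
    tune,
)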
