@@ -208,7 +208,7 @@ def model_selection_metric(self) -> tuple[QuantileOrGlobal, str, Any]:
208208 """Metric used to select the best trial: (quantile, metric_name, direction)."""
209209 ...
210210
211- def get_tunable_hyperparams (self ) -> list [ModelTuningInfo ]:
211+ def get_model_tuning_info (self ) -> list [ModelTuningInfo ]:
212212 """Return TunableField with model_hyperparams_field_name, hyperparams_instance and search_space for tuning.
213213
214214 Can be inherited from TuningConfigMixin.
@@ -326,7 +326,7 @@ def get_search_space(
326326 return result
327327
328328
329- def suggest_hyperparams [HP : BaseModel ](
329+ def apply_trial_suggestions [HP : BaseModel ](
330330 trial : optuna .Trial ,
331331 space : dict [str , TuningRange ],
332332 current : HP ,
@@ -343,18 +343,9 @@ def suggest_hyperparams[HP: BaseModel](
343343 """
344344 updates : dict [str , Any ] = {}
345345 for hyperparam_name , tuning_range in space .items ():
346- if isinstance (tuning_range , FloatRange ):
347- if tuning_range .low is not None and tuning_range .high is not None :
348- updates [hyperparam_name ] = trial .suggest_float (
349- hyperparam_name , tuning_range .low , tuning_range .high , log = tuning_range .log
350- )
351- elif isinstance (tuning_range , IntRange ):
352- if tuning_range .low is not None and tuning_range .high is not None :
353- updates [hyperparam_name ] = trial .suggest_int (
354- hyperparam_name , tuning_range .low , tuning_range .high , log = tuning_range .log
355- )
356- elif tuning_range .choices is not None :
357- updates [hyperparam_name ] = trial .suggest_categorical (hyperparam_name , list (tuning_range .choices ))
346+ value = _suggest_hyperparam_value (trial , hyperparam_name , tuning_range )
347+ if value is not None :
348+ updates [hyperparam_name ] = value
358349 return current .model_copy (update = updates )
359350
360351
@@ -389,13 +380,13 @@ def run_optuna_study(
389380
390381
391382class TuningConfigMixin :
392- """Mixin for get_tunable_hyperparams for workflow configs.
383+ """Mixin for get_model_tuning_info for workflow configs.
393384
394385 Discovers tunable fields by reflecting over model_fields and returning a TunableField for every field whose value
395386 is a TunableHyperParams instance with a non-empty search space.
396387 """
397388
398- def get_tunable_hyperparams (self ) -> list [ModelTuningInfo ]:
389+ def get_model_tuning_info (self ) -> list [ModelTuningInfo ]:
399390 """Return one ModelTuningInfo per active tunable hyperparameter group for a model."""
400391 result : list [ModelTuningInfo ] = []
401392 model_fields : dict [str , Any ] = cast (dict [str , Any ], getattr (type (self ), "model_fields" , {}))
@@ -454,7 +445,7 @@ class _TrialEntry(NamedTuple):
454445 tuning_range : TuningRange
455446
456447
457- def _suggest_param (
448+ def _suggest_hyperparam_value (
458449 trial : optuna .Trial ,
459450 trial_key : str ,
460451 tuning_range : TuningRange ,
@@ -476,6 +467,24 @@ def _suggest_param(
476467 return None
477468
478469
470+ def _build_hp_updates (
471+ model_tuning_info : list [ModelTuningInfo ],
472+ per_field : dict [str , dict [str , Any ]],
473+ ) -> dict [str , Any ]:
474+ """Build a config-level update dict by applying *per_field* values to each HP group.
475+
476+ Returns:
477+ Mapping of config field name → updated :class:`TunableHyperParams` instance.
478+ """
479+ return {
480+ tf .model_hyperparams_field_name : tf .tunable_hyperparams .model_copy (
481+ update = per_field [tf .model_hyperparams_field_name ]
482+ )
483+ for tf in model_tuning_info
484+ if tf .model_hyperparams_field_name in per_field
485+ }
486+
487+
479488class _TuningObjective :
480489 """Callable Optuna objective that encapsulates the context for a tuning run."""
481490
@@ -506,17 +515,11 @@ def __call__(self, trial: optuna.Trial) -> float:
506515 """
507516 per_field : dict [str , dict [str , Any ]] = {}
508517 for trial_key , trial_entry in self ._combined_space .items ():
509- value = _suggest_param (trial , trial_key , trial_entry .tuning_range )
518+ value = _suggest_hyperparam_value (trial , trial_key , trial_entry .tuning_range )
510519 if value is not None :
511520 per_field .setdefault (trial_entry .model_hyperparams_field_name , {})[trial_entry .hyperparam_name ] = value
512521
513- updates : dict [str , Any ] = {}
514- for tf in self ._model_tuning_info :
515- if tf .model_hyperparams_field_name in per_field :
516- updates [tf .model_hyperparams_field_name ] = tf .tunable_hyperparams .model_copy (
517- update = per_field [tf .model_hyperparams_field_name ]
518- )
519- tuned_config = self ._config .model_copy (update = updates )
522+ tuned_config = self ._config .model_copy (update = _build_hp_updates (self ._model_tuning_info , per_field ))
520523
521524 trial_workflow = self ._create_workflow (tuned_config )
522525 trial_result = trial_workflow .fit (self ._train_dataset )
@@ -545,7 +548,7 @@ def tune[ConfigT: TunableWorkflowConfig](
545548 Raises:
546549 ValueError: If no hyperparameter field has tune=True ranges.
547550 """
548- model_tuning_info = config .get_tunable_hyperparams ()
551+ model_tuning_info = config .get_model_tuning_info ()
549552 if not model_tuning_info :
550553 msg = (
551554 f"No tunable hyperparameters found on config '{ config .model_id } '. "
@@ -613,14 +616,7 @@ def _reconstruct_best_config[ConfigT: TunableWorkflowConfig](
613616 hyperparam_name = trial_key
614617 per_field_best .setdefault (model_hyperparams_field_name , {})[hyperparam_name ] = value
615618
616- best_updates : dict [str , Any ] = {}
617- for tf in model_tuning_info_list :
618- if tf .model_hyperparams_field_name in per_field_best :
619- best_updates [tf .model_hyperparams_field_name ] = tf .tunable_hyperparams .model_copy (
620- update = per_field_best [tf .model_hyperparams_field_name ]
621- )
622-
623- return config .model_copy (update = best_updates )
619+ return config .model_copy (update = _build_hp_updates (model_tuning_info_list , per_field_best ))
624620
625621
626622def fit_with_tuning [ConfigT : TunableWorkflowConfig ](
@@ -652,9 +648,9 @@ def fit_with_tuning[ConfigT: TunableWorkflowConfig](
652648 "TuningConfigMixin" ,
653649 "TuningRange" ,
654650 "TuningResult" ,
651+ "apply_trial_suggestions" ,
655652 "fit_with_tuning" ,
656653 "get_search_space" ,
657654 "run_optuna_study" ,
658- "suggest_hyperparams" ,
659655 "tune" ,
660656]
0 commit comments