Skip to content

Commit c590b3a

Browse files
authored
Make OpenMLTraceIteration a dataclass (#1201)
It provides a better repr and is less verbose.
1 parent beb598c commit c590b3a

2 files changed

Lines changed: 38 additions & 50 deletions

File tree

openml/runs/trace.py

Lines changed: 37 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# License: BSD 3-Clause
22

33
from collections import OrderedDict
4+
from dataclasses import dataclass
45
import json
56
import os
67
from typing import List, Tuple, Optional # noqa F401
@@ -331,12 +332,12 @@ def trace_from_xml(cls, xml):
331332
)
332333

333334
current = OpenMLTraceIteration(
334-
repeat,
335-
fold,
336-
iteration,
337-
setup_string,
338-
evaluation,
339-
selected,
335+
repeat=repeat,
336+
fold=fold,
337+
iteration=iteration,
338+
setup_string=setup_string,
339+
evaluation=evaluation,
340+
selected=selected,
340341
)
341342
trace[(repeat, fold, iteration)] = current
342343

@@ -386,8 +387,11 @@ def __iter__(self):
386387
yield val
387388

388389

389-
class OpenMLTraceIteration(object):
390-
"""OpenML Trace Iteration: parsed output from Run Trace call
390+
@dataclass
391+
class OpenMLTraceIteration:
392+
"""
393+
OpenML Trace Iteration: parsed output from Run Trace call
394+
Exactly one of `setup_string` or `parameters` must be provided.
391395
392396
Parameters
393397
----------
@@ -400,8 +404,9 @@ class OpenMLTraceIteration(object):
400404
iteration : int
401405
iteration number of optimization procedure
402406
403-
setup_string : str
407+
setup_string : str, optional
404408
json string representing the parameters
409+
If not provided, ``parameters`` should be set.
405410
406411
evaluation : double
407412
The evaluation that was awarded to this trace iteration.
@@ -412,42 +417,37 @@ class OpenMLTraceIteration(object):
412417
selected for making predictions. Per fold/repeat there
413418
should be only one iteration selected
414419
415-
parameters : OrderedDict
420+
parameters : OrderedDict, optional
421+
Dictionary specifying parameter names and their values.
422+
If not provided, ``setup_string`` should be set.
416423
"""
417424

418-
def __init__(
419-
self,
420-
repeat,
421-
fold,
422-
iteration,
423-
setup_string,
424-
evaluation,
425-
selected,
426-
parameters=None,
427-
):
428-
429-
if not isinstance(selected, bool):
430-
raise TypeError(type(selected))
431-
if setup_string and parameters:
425+
repeat: int
426+
fold: int
427+
iteration: int
428+
429+
evaluation: float
430+
selected: bool
431+
432+
setup_string: Optional[str] = None
433+
parameters: Optional[OrderedDict] = None
434+
435+
def __post_init__(self):
436+
# TODO: refactor into one argument of type <str | OrderedDict>
437+
if self.setup_string and self.parameters:
432438
raise ValueError(
433-
"Can only be instantiated with either " "setup_string or parameters argument."
439+
"Can only be instantiated with either `setup_string` or `parameters` argument."
434440
)
435-
elif not setup_string and not parameters:
436-
raise ValueError("Either setup_string or parameters needs to be passed as " "argument.")
437-
if parameters is not None and not isinstance(parameters, OrderedDict):
441+
elif not (self.setup_string or self.parameters):
442+
raise ValueError(
443+
"Either `setup_string` or `parameters` needs to be passed as argument."
444+
)
445+
if self.parameters is not None and not isinstance(self.parameters, OrderedDict):
438446
raise TypeError(
439447
"argument parameters is not an instance of OrderedDict, but %s"
440-
% str(type(parameters))
448+
% str(type(self.parameters))
441449
)
442450

443-
self.repeat = repeat
444-
self.fold = fold
445-
self.iteration = iteration
446-
self.setup_string = setup_string
447-
self.evaluation = evaluation
448-
self.selected = selected
449-
self.parameters = parameters
450-
451451
def get_parameters(self):
452452
result = {}
453453
# parameters have prefix 'parameter_'
@@ -461,15 +461,3 @@ def get_parameters(self):
461461
for param, value in self.parameters.items():
462462
result[param[len(PREFIX) :]] = value
463463
return result
464-
465-
def __repr__(self):
466-
"""
467-
tmp string representation, will be changed in the near future
468-
"""
469-
return "[(%d,%d,%d): %f (%r)]" % (
470-
self.repeat,
471-
self.fold,
472-
self.iteration,
473-
self.evaluation,
474-
self.selected,
475-
)

tests/test_runs/test_trace.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def test_duplicate_name(self):
6363
]
6464
trace_content = [[0, 0, 0, 0.5, "true", 1], [0, 0, 0, 0.9, "false", 2]]
6565
with self.assertRaisesRegex(
66-
ValueError, "Either setup_string or parameters needs to be passed as argument."
66+
ValueError, "Either `setup_string` or `parameters` needs to be passed as argument."
6767
):
6868
OpenMLRunTrace.generate(trace_attributes, trace_content)
6969

0 commit comments

Comments
 (0)