-
Notifications
You must be signed in to change notification settings - Fork 62
Expand file tree
/
Copy pathbase.py
More file actions
executable file
·915 lines (735 loc) · 36.4 KB
/
base.py
File metadata and controls
executable file
·915 lines (735 loc) · 36.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
"""
Base Benchmark Class for MLPerf Storage.
This module provides the abstract base class for all benchmark implementations.
The Benchmark class implements BenchmarkInterface and provides common
functionality including:
- Cluster information collection via MPI
- Result directory management
- Metadata generation and persistence
- Verification/validation integration
- Command execution with signal handling
Classes:
Benchmark: Abstract base class implementing BenchmarkInterface.
Subclassing:
To create a new benchmark type:
1. Inherit from Benchmark
2. Set BENCHMARK_TYPE class attribute
3. Implement _run() method
4. Optionally override generate_command(), validate_args(), etc.
Example:
class MyBenchmark(Benchmark):
BENCHMARK_TYPE = BENCHMARK_TYPES.my_benchmark
def _run(self):
cmd = self.generate_my_command()
stdout, stderr, rc = self._execute_command(cmd)
return rc
"""
import abc
import json
import os
import pprint
import signal
import sys
import time
import types
from argparse import Namespace
from typing import Tuple, Dict, Any, List, Optional, Callable, Set, TYPE_CHECKING
from functools import wraps
from pyarrow.ipc import open_stream
from mlpstorage_py.config import PARAM_VALIDATION, DATETIME_STR, MLPS_DEBUG, EXEC_TYPE
from mlpstorage_py.debug import debug_tryer_wrapper
from mlpstorage_py.interfaces import BenchmarkInterface, BenchmarkConfig, BenchmarkCommand
from mlpstorage_py.mlps_logging import setup_logging, apply_logging_options
from mlpstorage_py.rules import BenchmarkVerifier, generate_output_location, ClusterInformation
from mlpstorage_py.rules.models import ClusterSnapshots, TimeSeriesData, TimeSeriesSample
from mlpstorage_py.utils import CommandExecutor, MLPSJsonEncoder
from mlpstorage_py.cluster_collector import (
collect_cluster_info,
SSHClusterCollector,
TimeSeriesCollector,
MultiHostTimeSeriesCollector,
)
from mlpstorage_py.progress import create_stage_progress, progress_context
if TYPE_CHECKING:
import logging
class Benchmark(BenchmarkInterface, abc.ABC):
    """Base class for all MLPerf Storage benchmarks.

    This abstract class implements BenchmarkInterface and provides common
    functionality for all benchmark types. Subclasses must implement:
        - _run(): The actual benchmark execution logic
        - BENCHMARK_TYPE: Class attribute defining the benchmark type

    The class supports dependency injection for cluster collectors and validators
    to enable easier testing and flexibility.

    Attributes:
        BENCHMARK_TYPE: Class attribute defining the benchmark type enum value.
        args: Parsed command-line arguments.
        logger: Logger instance for output.
        run_datetime: Timestamp string for the run.
        cluster_information: Collected cluster system information (set by
            subclasses or by _collect_cluster_end()).
    """

    BENCHMARK_TYPE = None

    def __init__(
        self,
        args: Namespace,
        logger: Optional['logging.Logger'] = None,
        run_datetime: Optional[str] = None,
        run_number: int = 0,
        cluster_collector: Optional[Any] = None,
        validator: Optional[Any] = None
    ) -> None:
        """Initialize the benchmark.

        Args:
            args: Parsed command-line arguments (argparse.Namespace).
            logger: Optional logger instance. If not provided, one will be created.
            run_datetime: Optional datetime string in YYYYMMDD_HHMMSS format.
                Defaults to current time.
            run_number: Run number for this benchmark execution (for loops).
            cluster_collector: Optional cluster collector for dependency injection.
                Used for testing without MPI.
            validator: Optional validator for dependency injection.
                Used for testing validation logic.

        Raises:
            ValueError: If BENCHMARK_TYPE is not set (raised while computing
                the output location).
        """
        self.args = args
        self.debug = self.args.debug or MLPS_DEBUG

        if logger:
            self.logger = logger
        else:
            # Ensure there is always a logger available
            self.logger = setup_logging(name=f"{self.BENCHMARK_TYPE}_benchmark", stream_log_level=args.stream_log_level)
            self.logger.warning(f'Benchmark did not get a logger passed. Using default logger.')
            apply_logging_options(self.logger, args)

        if not run_datetime:
            self.logger.warning('No run datetime provided. Using current datetime.')
        self.run_datetime = run_datetime if run_datetime else DATETIME_STR
        self.run_number = run_number
        self.runtime = 0

        # Dependency injection for testability
        self._cluster_collector = cluster_collector
        self._validator = validator

        self.benchmark_run_verifier = None
        self.verification = None
        self.cmd_executor = CommandExecutor(logger=self.logger, debug=args.debug)

        # One entry per command whose output was saved by _execute_command().
        self.command_output_files = list()
        # Raises ValueError when BENCHMARK_TYPE is unset, before the
        # BENCHMARK_TYPE.value accesses below.
        self.run_result_output = self.generate_output_location()
        os.makedirs(self.run_result_output, exist_ok=True)

        self.metadata_filename = f"{self.BENCHMARK_TYPE.value}_{self.run_datetime}_metadata.json"
        self.metadata_file_path = os.path.join(self.run_result_output, self.metadata_filename)

        # Time-series collection (HOST-04, HOST-05)
        self._timeseries_collector = None
        self._timeseries_data = None
        self.timeseries_filename = f"{self.BENCHMARK_TYPE.value}_{self.run_datetime}_timeseries.json"
        self.timeseries_file_path = os.path.join(self.run_result_output, self.timeseries_filename)

        self.logger.status(f'Benchmark results directory: {self.run_result_output}')

    # =========================================================================
    # BenchmarkInterface Implementation
    # =========================================================================

    @property
    def config(self) -> BenchmarkConfig:
        """Return benchmark configuration.

        Subclasses can override this to provide more specific configuration.
        """
        return BenchmarkConfig(
            name=self.BENCHMARK_TYPE.value if self.BENCHMARK_TYPE else "unknown",
            benchmark_type=self.BENCHMARK_TYPE.name if self.BENCHMARK_TYPE else "unknown",
            supported_commands=self._get_supported_commands(),
            requires_cluster_info=True,
            requires_mpi=getattr(self.args, 'exec_type', None) == EXEC_TYPE.MPI,
        )

    def _get_supported_commands(self) -> List[BenchmarkCommand]:
        """Get list of supported commands. Override in subclass."""
        return [BenchmarkCommand.RUN]

    def validate_args(self, args) -> List[str]:
        """Validate command-line arguments.

        Args:
            args: Parsed command-line arguments.

        Returns:
            List of error messages. Empty list indicates valid arguments.
        """
        errors = []
        # Subclasses should override to add specific validation
        return errors

    def get_command_handler(self, command: str) -> Optional[Callable]:
        """Return handler function for the given command.

        Args:
            command: Command string (e.g., 'run', 'datagen').

        Returns:
            Callable that handles the command, or None if not supported.
        """
        # Default implementation - subclasses should override
        handlers = {
            'run': self._run,
        }
        return handlers.get(command)

    def generate_command(self, command: str) -> str:
        """Generate the shell command to execute.

        Args:
            command: Command string (e.g., 'run', 'datagen').

        Returns:
            Shell command string ready for execution.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        # Default implementation - subclasses must override for actual command generation
        raise NotImplementedError("Subclasses must implement generate_command()")

    def collect_results(self) -> Dict[str, Any]:
        """Collect and return benchmark results.

        Returns:
            Dictionary containing benchmark results and metadata.
        """
        return {
            'benchmark_type': self.BENCHMARK_TYPE.name if self.BENCHMARK_TYPE else None,
            'run_datetime': self.run_datetime,
            'runtime': self.runtime,
            'verification': self.verification.name if self.verification else None,
            'result_dir': self.run_result_output,
        }

    def get_metadata(self) -> Dict[str, Any]:
        """Get benchmark metadata for recording.

        Returns:
            Dictionary containing benchmark configuration and parameters.
        """
        return self.metadata

    # =========================================================================
    # Original Benchmark Methods
    # =========================================================================

    def _execute_command(
        self,
        command: str,
        output_file_prefix: Optional[str] = None,
        print_stdout: bool = True,
        print_stderr: bool = True
    ) -> Tuple[str, str, int]:
        """Execute the given command and return stdout, stderr, and return code.

        Handles what-if mode, signal watching for graceful termination,
        and optionally saves output to log files.

        Args:
            command: Shell command string to execute.
            output_file_prefix: If provided, stdout/stderr are saved to
                {prefix}.stdout.log and {prefix}.stderr.log in the result
                directory, and the paths are recorded in command_output_files.
            print_stdout: Whether to print stdout to console in real-time.
            print_stderr: Whether to print stderr to console in real-time.

        Returns:
            Tuple of (stdout_content, stderr_content, return_code).
            In what-if mode, returns ("", "", 0) without execution.
        """
        # Recorded even in what-if mode so the metadata shows what would run.
        self.executed_command = command

        if self.args.what_if:
            self.logger.debug(f'Executing command in --what-if mode means no execution will be performed.')
            log_message = f'What-if mode: \nCommand: {command}'
            if self.debug:
                log_message += f'\n\nParameters: \n{pprint.pformat(vars(self.args))}'
            self.logger.info(log_message)
            return "", "", 0

        # Let SIGINT/SIGTERM be handled by the executor so the child process
        # can be terminated gracefully.
        watch_signals = {signal.SIGINT, signal.SIGTERM}
        stdout, stderr, return_code = self.cmd_executor.execute(command, watch_signals=watch_signals,
                                                                print_stdout=print_stdout,
                                                                print_stderr=print_stderr)

        if output_file_prefix:
            stdout_filename = f"{output_file_prefix}.stdout.log"
            stderr_filename = f"{output_file_prefix}.stderr.log"
            stdout_file = os.path.join(self.run_result_output, stdout_filename)
            stderr_file = os.path.join(self.run_result_output, stderr_filename)

            with open(stdout_file, 'w') as fd:
                self.logger.verbose(f'Command stdout saved to: {stdout_filename}')
                fd.write(stdout)
            with open(stderr_file, 'w') as fd:
                self.logger.verbose(f'Command stderr saved to: {stderr_filename}')
                fd.write(stderr)

            self.command_output_files.append(dict(command=command, stdout=stdout_file, stderr=stderr_file))

        return stdout, stderr, return_code

    @property
    def metadata(self) -> Dict[str, Any]:
        """Generate metadata dict capturing the benchmark run configuration.

        This metadata is designed to be complete enough that BenchmarkRunData
        can be reconstructed from it without needing tool-specific result files.

        The metadata includes:
            - benchmark_type, model, command, run_datetime
            - parameters and override_parameters
            - system_info (cluster configuration)
            - runtime, verification status
            - executed_command and output files

        The dict is rebuilt on every access.

        Returns:
            Dictionary containing all benchmark metadata.
        """
        # Core fields required by BenchmarkRunData
        metadata = {
            'benchmark_type': self.BENCHMARK_TYPE.name,
            'model': getattr(self.args, 'model', None),
            'command': getattr(self.args, 'command', None),
            'run_datetime': self.run_datetime,
            'num_processes': getattr(self.args, 'num_processes', None),
            'accelerator': getattr(self.args, 'accelerator_type', None),
            'result_dir': self.run_result_output,
        }

        # Parameters - prefer combined_params if available (includes YAML + overrides).
        # combined_params / params_dict are set by subclasses, if at all.
        metadata['parameters'] = getattr(self, 'combined_params', {}) if hasattr(self, 'combined_params') else {}

        # Override parameters - user-specified overrides only
        metadata['override_parameters'] = self.params_dict if hasattr(self, 'params_dict') else {}

        # System info - serialize ClusterInformation if available
        if hasattr(self, 'cluster_information') and self.cluster_information:
            metadata['system_info'] = self.cluster_information.as_dict()
        else:
            metadata['system_info'] = None

        # Include cluster snapshots if available (start and end collection)
        if hasattr(self, 'cluster_snapshots') and self.cluster_snapshots:
            metadata['cluster_snapshots'] = self.cluster_snapshots.as_dict()

        # Include time-series data reference if available (HOST-04)
        if hasattr(self, '_timeseries_data') and self._timeseries_data:
            metadata['timeseries_data'] = {
                'file': self.timeseries_filename,
                'num_samples': self._timeseries_data.num_samples,
                'interval_seconds': self._timeseries_data.collection_interval_seconds,
                'hosts_collected': self._timeseries_data.hosts_collected,
            }

        # Additional context (not part of BenchmarkRunData but useful)
        metadata['runtime'] = self.runtime
        metadata['verification'] = self.verification.name if self.verification else None
        metadata['executed_command'] = getattr(self, 'executed_command', None)
        metadata['command_output_files'] = self.command_output_files

        # Include full args for debugging/auditing. Copy the namespace dict so
        # consumers mutating the metadata cannot mutate self.args; fall back to
        # str() for objects that have no __dict__ (vars() raises TypeError).
        try:
            metadata['args'] = dict(vars(self.args))
        except TypeError:
            metadata['args'] = str(self.args)

        return metadata

    def write_metadata(self) -> None:
        """Write benchmark metadata to JSON file.

        Writes metadata to {metadata_file_path}. In verbose/debug mode,
        also prints metadata to stdout.
        """
        # Build once so the file and the console dump are the same snapshot.
        metadata = self.metadata
        with open(self.metadata_file_path, 'w+') as fd:
            json.dump(metadata, fd, indent=2, cls=MLPSJsonEncoder)

        # self.debug already includes args.debug; verbose may be absent on args.
        if getattr(self.args, 'verbose', False) or self.debug:
            json.dump(metadata, sys.stdout, indent=2, cls=MLPSJsonEncoder)

    def write_cluster_info(self):
        """Write detailed cluster information to a separate JSON file.

        Best-effort: failures are logged as warnings and not raised.
        """
        if not hasattr(self, 'cluster_information') or not self.cluster_information:
            return

        cluster_info_filename = f"{self.BENCHMARK_TYPE.value}_cluster_info.json"
        cluster_info_path = os.path.join(self.run_result_output, cluster_info_filename)

        try:
            with open(cluster_info_path, 'w') as fd:
                # Use the project encoder, as the other JSON writers here do,
                # so non-primitive values do not abort the write midway.
                json.dump(self.cluster_information.to_detailed_dict(), fd, indent=2, cls=MLPSJsonEncoder)
            self.logger.verbose(f'Cluster information saved to: {cluster_info_filename}')
        except Exception as e:
            self.logger.warning(f'Failed to write cluster info: {e}')

    def _cluster_collection_applicable(self) -> bool:
        """Return True when the preconditions shared by MPI and SSH collection hold.

        Requires that hosts are provided and non-empty, that the command is
        not 'datagen' or 'configview', and that skip_cluster_collection is
        not set.
        """
        # Check if hosts are specified
        if not getattr(self.args, 'hosts', None):
            return False
        # Skip for certain commands that don't need cluster info
        if getattr(self.args, 'command', None) in ('datagen', 'configview'):
            return False
        # Check if user explicitly disabled collection
        if getattr(self.args, 'skip_cluster_collection', False):
            return False
        return True

    def _should_collect_cluster_info(self) -> bool:
        """Determine if cluster information collection applies to this run.

        Returns True if:
            - hosts argument is provided and not empty
            - command is not 'datagen' or 'configview' (data generation doesn't need cluster info)
            - skip_cluster_collection is not set

        This check does not look at exec_type; _collect_cluster_information()
        additionally requires exec_type == MPI.
        """
        return self._cluster_collection_applicable()

    def _collect_cluster_information(self) -> Optional['ClusterInformation']:
        """Collect cluster information using MPI if available, otherwise return None.

        This method attempts to collect detailed system information from all hosts
        using MPI. If MPI collection fails or is not available, it returns None
        and the subclass should fall back to CLI args-based collection.

        Returns:
            ClusterInformation instance if collection succeeds, None otherwise.
        """
        if not self._should_collect_cluster_info():
            self.logger.debug('Skipping cluster info collection (conditions not met)')
            return None

        # Only attempt MPI collection if exec_type is MPI
        if not hasattr(self.args, 'exec_type') or self.args.exec_type != EXEC_TYPE.MPI:
            self.logger.debug('Skipping MPI cluster collection (exec_type is not MPI)')
            return None

        try:
            self.logger.debug('Collecting cluster information via MPI...')

            # Get collection parameters
            mpi_bin = getattr(self.args, 'mpi_bin', 'mpirun')
            allow_run_as_root = getattr(self.args, 'allow_run_as_root', False)
            timeout = getattr(self.args, 'cluster_collection_timeout', 60)
            ssh_username = getattr(self.args, 'ssh_username', None)
            shared_staging_dir = getattr(self.args, 'shared_staging_dir', None)

            # Collect cluster info. ``results_dir`` is required by
            # ``collect_cluster_info`` for staging the helper script under
            # ``<results_dir>/collector-staging/`` (see issue #363).
            collected_data = collect_cluster_info(
                hosts=self.args.hosts,
                mpi_bin=mpi_bin,
                logger=self.logger,
                results_dir=self.run_result_output,
                allow_run_as_root=allow_run_as_root,
                timeout_seconds=timeout,
                fallback_to_local=True,
                shared_staging_dir=shared_staging_dir,
                ssh_username=ssh_username,
            )

            # Create ClusterInformation from collected data
            cluster_info = ClusterInformation.from_mpi_collection(collected_data, self.logger)

            # Log collection results
            collection_method = collected_data.get('_metadata', {}).get('collection_method', 'unknown')
            self.logger.debug(
                f'Cluster info collected via {collection_method}: '
                f'{cluster_info.num_hosts} hosts, '
                f'{cluster_info.total_memory_bytes / (1024**3):.1f}GiB total memory, '
                f'{cluster_info.total_cores} total cores'
            )

            # Log any consistency warnings
            for issue in cluster_info.host_consistency_issues or ():
                self.logger.warning(f'Cluster consistency: {issue}')

            return cluster_info

        except Exception as e:
            # Best-effort: collection failure must not abort the benchmark.
            self.logger.warning(f'MPI cluster info collection failed: {e}')
            return None

    def _should_use_ssh_collection(self) -> bool:
        """Determine if SSH-based collection should be used.

        SSH collection is used when:
            - hosts are specified
            - command is not 'datagen' or 'configview'
            - skip_cluster_collection is not set
            - exec_type is NOT MPI (or exec_type is not set)

        Returns:
            True if SSH collection should be used, False otherwise.
        """
        if not self._cluster_collection_applicable():
            return False
        # Use SSH for non-MPI execution
        return getattr(self.args, 'exec_type', None) != EXEC_TYPE.MPI

    def _collect_via_ssh(self) -> Optional['ClusterInformation']:
        """Collect cluster information using SSH.

        Best-effort: any failure is logged as a warning and None is returned.

        Returns:
            ClusterInformation instance if collection succeeds, None otherwise.
        """
        try:
            self.logger.debug('Collecting cluster information via SSH...')

            ssh_username = getattr(self.args, 'ssh_username', None)
            timeout = getattr(self.args, 'cluster_collection_timeout', 60)

            collector = SSHClusterCollector(
                hosts=self.args.hosts,
                logger=self.logger,
                ssh_username=ssh_username,
                timeout_seconds=timeout
            )

            if not collector.is_available():
                self.logger.warning('SSH not available for cluster collection')
                return None

            result = collector.collect(self.args.hosts, timeout)

            # Partial results are still used; errors are only reported.
            if not result.success:
                self.logger.warning(f'SSH collection had errors: {result.errors}')

            # Create ClusterInformation from collected data
            cluster_info = ClusterInformation.from_mpi_collection(
                {**result.data, '_metadata': {
                    'collection_method': 'ssh',
                    'collection_timestamp': result.timestamp
                }},
                self.logger
            )

            self.logger.debug(
                f'Cluster info collected via SSH: '
                f'{cluster_info.num_hosts} hosts, '
                f'{cluster_info.total_memory_bytes / (1024**3):.1f}GiB total memory'
            )

            return cluster_info

        except Exception as e:
            self.logger.warning(f'SSH cluster info collection failed: {e}')
            return None

    def _collect_cluster_start(self) -> None:
        """Collect cluster information at benchmark start.

        Stores the result in self._cluster_info_start and the method used
        ('ssh' or 'mpi') in self._collection_method, for _collect_cluster_end().
        Called at the beginning of run().
        """
        if not self._should_collect_cluster_info() and not self._should_use_ssh_collection():
            self.logger.debug('Skipping start cluster collection (conditions not met)')
            return

        hosts = self.args.hosts if hasattr(self.args, 'hosts') else []
        host_count = len(hosts) if hosts else 1
        self.logger.debug(f"Collecting cluster info ({host_count} host{'s' if host_count != 1 else ''})...")

        with progress_context("Collecting cluster info...", total=None) as (_, set_desc):
            if self._should_use_ssh_collection():
                set_desc("Collecting via SSH...")
                self._cluster_info_start = self._collect_via_ssh()
                self._collection_method = 'ssh'
            else:
                set_desc("Collecting via MPI...")
                self._cluster_info_start = self._collect_cluster_information()
                self._collection_method = 'mpi'

        if self._cluster_info_start:
            self.logger.debug(f'Collected start cluster info via {self._collection_method}')

    def _collect_cluster_end(self) -> None:
        """Collect cluster information at benchmark end.

        Only collects if start collection was performed and succeeded.
        Creates ClusterSnapshots with both start and end data, and sets
        cluster_information to the start snapshot.
        """
        if getattr(self, '_cluster_info_start', None) is None:
            self.logger.debug('Skipping end cluster collection (no start collection)')
            return

        self.logger.debug("Collecting end cluster info...")

        # Use the same method as the start snapshot so the two are comparable.
        with progress_context("Collecting cluster info...", total=None) as (_, set_desc):
            if self._collection_method == 'ssh':
                set_desc("Collecting via SSH...")
                self._cluster_info_end = self._collect_via_ssh()
            else:
                set_desc("Collecting via MPI...")
                self._cluster_info_end = self._collect_cluster_information()

        if self._cluster_info_end:
            self.logger.debug(f'Collected end cluster info via {self._collection_method}')

        # Create ClusterSnapshots
        self.cluster_snapshots = ClusterSnapshots(
            start=self._cluster_info_start,
            end=self._cluster_info_end,
            collection_method=getattr(self, '_collection_method', 'unknown')
        )

        # Also set cluster_information to the start snapshot for backward compatibility
        self.cluster_information = self._cluster_info_start

    def _should_collect_timeseries(self) -> bool:
        """Determine if time-series collection should be performed.

        Returns:
            True unless skip_timeseries is set, the command is present and is
            not 'run', or --what-if mode is active.
        """
        # Check if user explicitly disabled
        if hasattr(self.args, 'skip_timeseries') and self.args.skip_timeseries:
            return False

        # Only collect for 'run' command (a missing command attribute does not skip)
        if hasattr(self.args, 'command') and self.args.command not in ('run',):
            return False

        # Skip in what-if mode
        if hasattr(self.args, 'what_if') and self.args.what_if:
            return False

        return True

    def _start_timeseries_collection(self) -> None:
        """Start time-series collection in background.

        Uses MultiHostTimeSeriesCollector if hosts specified,
        otherwise uses single-host TimeSeriesCollector.

        Collection runs in a background thread to minimize performance impact
        on benchmark execution (HOST-05 requirement). On failure to start,
        a warning is logged and the collector is left as None.
        """
        if not self._should_collect_timeseries():
            self.logger.debug('Skipping time-series collection (disabled or not applicable)')
            return

        interval = getattr(self.args, 'timeseries_interval', 10.0)
        max_samples = getattr(self.args, 'max_timeseries_samples', 3600)

        try:
            if hasattr(self.args, 'hosts') and self.args.hosts:
                # Multi-host collection
                ssh_username = getattr(self.args, 'ssh_username', None)
                ssh_timeout = getattr(self.args, 'cluster_collection_timeout', 30)

                self._timeseries_collector = MultiHostTimeSeriesCollector(
                    hosts=self.args.hosts,
                    interval_seconds=interval,
                    max_samples=max_samples,
                    ssh_username=ssh_username,
                    ssh_timeout=ssh_timeout,
                    logger=self.logger
                )
                self.logger.debug(
                    f'Starting multi-host time-series collection ({len(self.args.hosts)} hosts, '
                    f'interval={interval}s)'
                )
            else:
                # Single-host collection (localhost only)
                self._timeseries_collector = TimeSeriesCollector(
                    interval_seconds=interval,
                    max_samples=max_samples,
                    logger=self.logger
                )
                self.logger.debug(
                    f'Starting single-host time-series collection (interval={interval}s)'
                )

            self._timeseries_collector.start()

        except Exception as e:
            self.logger.warning(f'Failed to start time-series collection: {e}')
            self._timeseries_collector = None

    def _stop_timeseries_collection(self) -> None:
        """Stop time-series collection and store results in self._timeseries_data.

        Best-effort: on failure a warning is logged and _timeseries_data is None.
        """
        if self._timeseries_collector is None:
            return

        try:
            if isinstance(self._timeseries_collector, MultiHostTimeSeriesCollector):
                samples_by_host = self._timeseries_collector.stop()
                hosts_collected = self._timeseries_collector.get_hosts_with_data()

                # Convert to TimeSeriesSample dataclasses
                samples_by_host_typed = {
                    host: [TimeSeriesSample.from_dict(s) for s in samples]
                    for host, samples in samples_by_host.items()
                }
                total_samples = sum(len(samples) for samples in samples_by_host.values())

                self._timeseries_data = TimeSeriesData(
                    collection_interval_seconds=self._timeseries_collector.interval_seconds,
                    start_time=self._timeseries_collector.start_time or '',
                    end_time=self._timeseries_collector.end_time or '',
                    num_samples=total_samples,
                    samples_by_host=samples_by_host_typed,
                    collection_method='ssh' if len(hosts_collected) > 1 else 'local',
                    hosts_requested=list(self._timeseries_collector.hosts),
                    hosts_collected=hosts_collected,
                )
            else:
                # Single-host TimeSeriesCollector
                samples = self._timeseries_collector.stop()
                # A sample lacking 'hostname' falls back to 'localhost' instead
                # of discarding every collected sample via the except below.
                hostname = samples[0].get('hostname', 'localhost') if samples else 'localhost'
                samples_typed = [TimeSeriesSample.from_dict(s) for s in samples]

                self._timeseries_data = TimeSeriesData(
                    collection_interval_seconds=self._timeseries_collector.interval_seconds,
                    start_time=self._timeseries_collector.start_time or '',
                    end_time=self._timeseries_collector.end_time or '',
                    num_samples=len(samples),
                    samples_by_host={hostname: samples_typed},
                    collection_method='local',
                    hosts_requested=[hostname],
                    hosts_collected=[hostname] if samples else [],
                )

            self.logger.debug(
                f'Time-series collection complete ({self._timeseries_data.num_samples} samples)'
            )

        except Exception as e:
            self.logger.warning(f'Failed to stop time-series collection: {e}')
            self._timeseries_data = None

    def write_timeseries_data(self) -> None:
        """Write time-series data to JSON file.

        Output file follows naming convention: {benchmark_type}_{datetime}_timeseries.json
        This ensures the file is discoverable alongside other benchmark output files
        (HOST-04 requirement). Does nothing if no data was collected; write
        failures are logged as warnings.
        """
        if self._timeseries_data is None:
            return

        try:
            with open(self.timeseries_file_path, 'w') as f:
                json.dump(self._timeseries_data.to_dict(), f, indent=2, cls=MLPSJsonEncoder)
            self.logger.verbose(f'Time-series data saved to: {self.timeseries_filename}')
        except Exception as e:
            self.logger.warning(f'Failed to write time-series data: {e}')

    def generate_output_location(self) -> str:
        """Generate the output directory path for this benchmark run.

        Creates a path based on BENCHMARK_TYPE, model, command, and datetime.

        Returns:
            Path string for the result directory.

        Raises:
            ValueError: If BENCHMARK_TYPE is not set.
        """
        if not self.BENCHMARK_TYPE:
            raise ValueError('No benchmark specified. Unable to generate output location')
        return generate_output_location(self, self.run_datetime)

    def verify_benchmark(self) -> bool:
        """Verify benchmark parameters meet OPEN or CLOSED requirements.

        Uses BenchmarkVerifier to check if the current configuration
        meets the requirements for closed or open submission.

        Returns:
            True if verification passes.
            May call sys.exit(1) if invalid and --allow-invalid-params not set,
            or if parameters are only valid for OPEN but --open was not given.

        Raises:
            ValueError: If BENCHMARK_TYPE is not set.
        """
        self.logger.verboser(f'Verifying benchmark parameters: {self.args}')
        if not self.benchmark_run_verifier:
            self.benchmark_run_verifier = BenchmarkVerifier(self, logger=self.logger)
        self.verification = self.benchmark_run_verifier.verify()
        self.logger.verboser(f'Benchmark verification result: {self.verification}')

        # NOTE(review): the OPEN branch below treats `closed == False` as
        # "--open was passed", which suggests args.closed is tri-state
        # (None / True / False). `not self.args.closed` is also true for False,
        # so confirm against the argparse definitions that --open runs can
        # actually reach the OPEN check.
        if not self.args.closed and not hasattr(self.args, "open"):
            self.logger.warning(f'Running the benchmark without verification for open or closed configurations. These results are not valid for submission. Use --open or --closed to specify a configuration.')
            return True

        if not self.BENCHMARK_TYPE:
            raise ValueError(f'No benchmark specified. Unable to verify benchmark')

        if not self.verification:
            self.logger.error(f'Verification did not return a result. Contact the developer')
            sys.exit(1)

        if self.verification == PARAM_VALIDATION.CLOSED:
            return True
        elif self.verification == PARAM_VALIDATION.INVALID:
            if self.args.allow_invalid_params:
                self.logger.warning(f'Invalid configuration found. Allowing the benchmark to proceed.')
                return True
            else:
                self.logger.error(f'Invalid configuration found. Aborting benchmark run.')
                sys.exit(1)

        if self.verification == PARAM_VALIDATION.OPEN:
            if self.args.closed == False:
                # "--open" was passed
                self.logger.status(f'Running as allowed open configuration')
                return True
            else:
                self.logger.warning(f'Parameters allowed for open but not closed. Use --open and rerun the benchmark.')
                sys.exit(1)

    @abc.abstractmethod
    def _run(self) -> int:
        """Run the actual benchmark execution.

        Subclasses must implement this method to define the benchmark
        execution logic. The method should:
            1. Generate and execute the benchmark command
            2. Collect and process results
            3. Write metadata and output files
            4. Return the exit code

        Returns:
            Exit code (0 for success, non-zero for failure).
        """
        raise NotImplementedError

    def _validate_environment(self) -> None:
        """Validate environment before benchmark execution.

        Called early in run() to catch configuration issues before
        any work is done. Subclasses can override to add benchmark-
        specific validation.

        Note: Primary environment validation is done in main.py via
        validate_benchmark_environment() BEFORE benchmark instantiation.
        This hook is for benchmark-specific validation that requires
        the benchmark instance to exist.

        Raises:
            DependencyError: If required dependencies are missing.
            ConfigurationError: If configuration is invalid.
        """
        # Environment validation is primarily done in main.py before
        # benchmark instantiation. This hook allows subclasses to add
        # benchmark-specific validation if needed.
        pass

    def run(self) -> int:
        """Execute the benchmark and track runtime.

        Wraps _run() with timing measurement, cluster collection, and
        time-series collection. Shows stage indicators during execution.

        Collects cluster information at start and end of benchmark
        (HOST-03 requirement).

        Collects time-series data during benchmark execution using a
        background thread to minimize performance impact (HOST-04, HOST-05).
        If _run() raises, runtime is still recorded and the time-series
        collector is stopped before the exception propagates.

        Returns:
            Exit code from _run().
        """
        stages = [
            "Validating environment...",
            "Collecting cluster info...",
            "Running benchmark...",
            "Processing results...",
        ]

        with create_stage_progress(stages, logger=self.logger) as advance_stage:
            # Stage 1: Validation
            self._validate_environment()
            advance_stage()

            # Stage 2: Cluster collection
            self._collect_cluster_start()
            self._start_timeseries_collection()
            advance_stage()

            # Stage 3: Benchmark execution
            # Note: Stage progress remains visible showing elapsed time
            # during this phase. DLIO output flows through directly.
            start_time = time.time()
            run_completed = False
            try:
                result = self._run()
                run_completed = True
            finally:
                self.runtime = time.time() - start_time
                if not run_completed:
                    # _run() raised (including SystemExit/KeyboardInterrupt):
                    # stop the background sampler so it does not outlive the
                    # failed run. _stop_timeseries_collection() logs rather
                    # than raises, so it cannot mask the original exception.
                    self._stop_timeseries_collection()
            advance_stage()

            # Stage 4: Cleanup/Processing
            self._stop_timeseries_collection()
            self._collect_cluster_end()
            self.write_timeseries_data()
            advance_stage()

        return result