Skip to content

Commit fc62a06

Browse files
committed
Guard optional rapidata metric import and tighten validation
1 parent aa2b198 commit fc62a06

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

src/pruna/evaluation/metrics/__init__.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,15 @@
2222
from pruna.evaluation.metrics.metric_memory import DiskMemoryMetric, InferenceMemoryMetric, TrainingMemoryMetric
2323
from pruna.evaluation.metrics.metric_model_architecture import TotalMACsMetric, TotalParamsMetric
2424
from pruna.evaluation.metrics.metric_pairwise_clip import PairwiseClipScore
25-
from pruna.evaluation.metrics.metric_rapiddata import RapidataMetric
2625
from pruna.evaluation.metrics.metric_sharpness import SharpnessMetric
2726
from pruna.evaluation.metrics.metric_torch import TorchMetricWrapper
2827

28+
try:
29+
from pruna.evaluation.metrics.metric_rapiddata import RapidataMetric
30+
except ModuleNotFoundError as e:
31+
if e.name != "rapidata":
32+
raise
33+
2934
__all__ = [
3035
"MetricRegistry",
3136
"TorchMetricWrapper",
@@ -44,5 +49,7 @@
4449
"DinoScore",
4550
"SharpnessMetric",
4651
"AestheticLAION",
47-
"RapidataMetric",
4852
]
53+
54+
if "RapidataMetric" in globals():
55+
__all__.append("RapidataMetric")

src/pruna/evaluation/metrics/metric_rapiddata.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,7 @@ def compute(self) -> None:
299299
:meth:`retrieve_granular_results` once enough votes have been
300300
collected.
301301
"""
302+
self._require_benchmark()
302303
self._require_model()
303304
if not self.media_cache:
304305
raise ValueError("No data accumulated. Call update() before compute().")
@@ -348,7 +349,7 @@ def retrieve_results(self, *args, **kwargs) -> CompositeMetricResult | None:
348349
if "ValidationError" in type(e).__name__:
349350
pruna_logger.warning(
350351
"The benchmark hasn't finished yet.\n "
351-
"Please wait for more votes and try again."
352+
"Please wait for more votes and try again.\n "
352353
"Skipping."
353354
)
354355
return None

0 commit comments

Comments
 (0)