Skip to content

Commit ef577b2

Browse files
committed
fix(sample) get_global in the scalar case now returns the scalar with the original type
1 parent e494a85 commit ef577b2

File tree

3 files changed

+13
-4
lines changed

3 files changed

+13
-4
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2323

2424
### Fixes
2525

26+
- (sample) get_global in the scalar case now returns the scalar with the original type.
2627
- (datasets) fix missing location use in get_field_names, and improve corresponding tests.
2728
- (cgns_helpers) update_features_for_CGNS_compatibility: fix behavior where input variable was modified by the function.
2829
- (storage/common/preprocessor) make constant_features split-dependant.

src/plaid/containers/features.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -751,12 +751,19 @@ def get_global(
751751
Optional[Array]: The global array if found, otherwise None. Returns a scalar if the array has size 1.
752752
"""
753753
time = self.resolve_time(time)
754-
if self.has_globals(time):
755-
global_ = CGU.getValueByPath(self.data[time], "Global/" + name)
756-
return global_.item() if getattr(global_, "size", None) == 1 else global_
757-
else:
754+
755+
if not self.has_globals(time):
756+
return None
757+
758+
global_ = CGU.getValueByPath(self.data[time], "Global/" + name)
759+
if global_ is None:
758760
return None
759761

762+
if getattr(global_, "size", None) == 1:
763+
return global_[0]
764+
765+
return global_
766+
760767
def add_global(
761768
self,
762769
name: str,

tests/containers/test_sample.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -629,6 +629,7 @@ def test_get_scalar_empty(self, sample):
629629
def test_get_scalar(self, sample_with_scalar):
630630
assert sample_with_scalar.get_scalar("missing_scalar_name") is None
631631
assert sample_with_scalar.get_scalar("test_scalar_1") is not None
632+
assert isinstance(sample_with_scalar.get_scalar("test_scalar_1"), np.float64)
632633

633634
def test_scalars_add_empty(self, sample_with_scalar):
634635
assert isinstance(sample_with_scalar.get_scalar("test_scalar_1"), float)

0 commit comments

Comments
 (0)