Skip to content

Commit ac0b090

Browse files
dcherianclaude
andauthored
More direct use of affine transform (#83)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 8c5a3b1 commit ac0b090

4 files changed

Lines changed: 92 additions & 30 deletions

File tree

src/rasterix/rasterize/core.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,25 @@
88
import numpy as np
99
import xarray as xr
1010

11+
from ..raster_index import RasterIndex
1112
from ..utils import get_affine, get_grid_mapping_var
1213
from .utils import XAXIS, YAXIS, clip_to_bbox, is_in_memory, prepare_for_dask
1314

1415
if TYPE_CHECKING:
1516
import dask_geopandas
17+
from affine import Affine
1618

1719
__all__ = ["rasterize", "geometry_mask", "geometry_clip"]
1820

21+
22+
def _get_affine(obj: xr.Dataset | xr.DataArray, *, x_dim: str, y_dim: str) -> Affine:
23+
"""Get affine transform, preferring RasterIndex if available."""
24+
idx = obj.xindexes.get(x_dim)
25+
if isinstance(idx, RasterIndex):
26+
return idx.transform()
27+
return get_affine(obj, x_dim=x_dim, y_dim=y_dim)
28+
29+
1930
Engine = Literal["rasterio", "rusterize", "exactextract"]
2031

2132

@@ -222,7 +233,7 @@ def rasterize(
222233
if clip:
223234
obj = clip_to_bbox(obj, geometries, xdim=xdim, ydim=ydim)
224235

225-
affine = get_affine(obj, x_dim=xdim, y_dim=ydim)
236+
affine = _get_affine(obj, x_dim=xdim, y_dim=ydim)
226237
engine_merge_alg = _normalize_merge_alg(merge_alg, resolved_engine)
227238

228239
rasterize_geometries, dask_rasterize_wrapper = _get_rasterize_funcs(resolved_engine)
@@ -370,7 +381,7 @@ def geometry_mask(
370381
if clip:
371382
obj = clip_to_bbox(obj, geometries, xdim=xdim, ydim=ydim)
372383

373-
affine = get_affine(obj, x_dim=xdim, y_dim=ydim)
384+
affine = _get_affine(obj, x_dim=xdim, y_dim=ydim)
374385

375386
np_geometry_mask, dask_mask_wrapper = _get_mask_funcs(resolved_engine)
376387

src/rasterix/rasterize/exact.py

Lines changed: 42 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from exactextract import exact_extract
1111
from exactextract.raster import NumPyRasterSource
1212

13+
from ..rasterize.core import _get_affine
1314
from ..utils import get_grid_mapping_var
1415
from .utils import clip_to_bbox, geometries_as_dask_array, is_in_memory
1516

@@ -103,9 +104,9 @@ def get_dtype(coverage_weight: CoverageWeights, geometries):
103104

104105

105106
def np_coverage(
106-
x: np.ndarray,
107-
y: np.ndarray,
107+
affine,
108108
*,
109+
shape: tuple[int, int],
109110
geometries: gpd.GeoDataFrame,
110111
strategy: Strategy = DEFAULT_STRATEGY,
111112
coverage_weight: CoverageWeights = "fraction",
@@ -115,8 +116,7 @@ def np_coverage(
115116
if len(geometries.columns) > 1:
116117
raise ValueError("Require a single geometries column or a GeoSeries.")
117118

118-
shape = (y.size, x.size)
119-
raster = xy_to_raster_source(x, y, srs_wkt=geometries.crs.to_wkt())
119+
raster = affine_to_raster_source(affine, shape, srs_wkt=geometries.crs.to_wkt())
120120
result = exact_extract(
121121
rast=raster,
122122
vec=geometries,
@@ -161,43 +161,59 @@ def np_coverage(
161161

162162
def coverage_np_dask_wrapper(
163163
geom_array: np.ndarray,
164-
x: np.ndarray,
165-
y: np.ndarray,
164+
x_offsets: np.ndarray,
165+
y_offsets: np.ndarray,
166+
x_sizes: np.ndarray,
167+
y_sizes: np.ndarray,
168+
affine,
166169
coverage_weight: CoverageWeights,
167170
strategy: Strategy,
168171
crs,
169172
) -> np.ndarray:
173+
shape = (y_sizes.item(), x_sizes.item())
174+
chunk_affine = affine * affine.translation(x_offsets.item(), y_offsets.item())
170175
return np_coverage(
171-
x=x.squeeze(axis=(GEOM_AXIS, Y_AXIS)),
172-
y=y.squeeze(axis=(GEOM_AXIS, X_AXIS)),
176+
chunk_affine,
177+
shape=shape,
173178
geometries=gpd.GeoDataFrame(geometry=geom_array.squeeze(axis=(X_AXIS, Y_AXIS)), crs=crs),
174179
coverage_weight=coverage_weight,
175180
)
176181

177182

178183
def dask_coverage(
179-
x: dask.array.Array,
180-
y: dask.array.Array,
184+
affine,
181185
*,
186+
chunks: tuple[tuple[int, ...], tuple[int, ...]],
182187
geom_array: dask.array.Array,
183188
coverage_weight: CoverageWeights = "fraction",
184189
strategy: Strategy = DEFAULT_STRATEGY,
185190
crs: Any,
186191
) -> dask.array.Array:
187192
import dask.array
193+
from dask.array import from_array
188194

189-
if any(c == 1 for c in x.chunks[0]) or any(c == 1 for c in y.chunks[0]):
195+
y_chunks, x_chunks = chunks
196+
197+
if any(c == 1 for c in x_chunks) or any(c == 1 for c in y_chunks):
190198
raise ValueError("exactextract does not support a chunksize of 1. Please rechunk to avoid this")
191199

200+
x_sizes = from_array(x_chunks, chunks=1)
201+
y_sizes = from_array(y_chunks, chunks=1)
202+
x_offsets = from_array(np.cumulative_sum(x_chunks[:-1], include_initial=True), chunks=1)
203+
y_offsets = from_array(np.cumulative_sum(y_chunks[:-1], include_initial=True), chunks=1)
204+
192205
out = dask.array.map_blocks(
193206
coverage_np_dask_wrapper,
194207
geom_array[:, np.newaxis, np.newaxis],
195-
x[np.newaxis, np.newaxis, :],
196-
y[np.newaxis, :, np.newaxis],
208+
x_offsets[np.newaxis, np.newaxis, :],
209+
y_offsets[np.newaxis, :, np.newaxis],
210+
x_sizes[np.newaxis, np.newaxis, :],
211+
y_sizes[np.newaxis, :, np.newaxis],
212+
affine=affine,
197213
crs=crs,
198214
coverage_weight=coverage_weight,
199215
strategy=strategy,
200-
chunks=(*geom_array.chunks, *y.chunks, *x.chunks),
216+
chunks=(*geom_array.chunks, tuple(y_chunks), tuple(x_chunks)),
201217
meta=sparse.COO(
202218
[], data=np.array([], dtype=get_dtype(coverage_weight, geom_array)), shape=(0, 0, 0), fill_value=0
203219
),
@@ -290,32 +306,30 @@ def coverage(
290306

291307
if clip:
292308
obj = clip_to_bbox(obj, geometries, xdim=xdim, ydim=ydim)
309+
310+
affine = _get_affine(obj, x_dim=xdim, y_dim=ydim)
311+
shape = (obj.sizes[ydim], obj.sizes[xdim])
312+
293313
if is_in_memory(obj=obj, geometries=geometries):
294314
out = np_coverage(
295-
x=obj[xdim].data,
296-
y=obj[ydim].data,
315+
affine,
316+
shape=shape,
297317
geometries=geometries,
298318
coverage_weight=coverage_weight,
299319
strategy=strategy,
300320
)
301321
geom_array = geometries.to_numpy().squeeze(axis=1)
302322
else:
303-
from dask.array import Array, from_array
304-
305323
geom_dask_array = geometries_as_dask_array(geometries)
306-
if not isinstance(obj[xdim].data, Array):
307-
dask_x = from_array(obj[xdim].data, chunks=obj.chunksizes.get(xdim, -1))
308-
else:
309-
dask_x = obj[xdim].data
310324

311-
if not isinstance(obj[ydim].data, Array):
312-
dask_y = from_array(obj[ydim].data, chunks=obj.chunksizes.get(ydim, -1))
313-
else:
314-
dask_y = obj[ydim].data
325+
chunks = (
326+
obj.chunksizes.get(ydim, (obj.sizes[ydim],)),
327+
obj.chunksizes.get(xdim, (obj.sizes[xdim],)),
328+
)
315329

316330
out = dask_coverage(
317-
x=dask_x,
318-
y=dask_y,
331+
affine,
332+
chunks=chunks,
319333
geom_array=geom_dask_array,
320334
crs=geometries.crs,
321335
coverage_weight=coverage_weight,

tests/test_exact.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,10 @@ def test_coverage_weights(
120120
ds = ds.isel(x=xslicer, y=yslicer)
121121
if xchunks is not None or ychunks is not None:
122122
ds = ds.chunk({"x": xchunks, "y": ychunks})
123+
if not indexed:
124+
# Ensure coordinate arrays are in memory so get_affine doesn't trigger dask compute
125+
ds.coords["x"].load()
126+
ds.coords["y"].load()
123127

124128
with raise_if_dask_computes():
125129
actual = coverage(ds, geometries[["geometry"]], coverage_weight=coverage_weight)

tests/test_rasterize.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import xproj # noqa
77
from xarray.tests import raise_if_dask_computes
88

9+
from rasterix import RasterIndex, assign_index
910
from rasterix.rasterize import geometry_clip, geometry_mask, rasterize
1011

1112
pytestmark = pytest.mark.filterwarnings("ignore:variable '.*' has non-conforming '_FillValue'")
@@ -121,3 +122,35 @@ def test_geometry_clip(engine, dataset):
121122
assert clipped is not None
122123
# Basic check that clipping worked - masked values outside geometries
123124
assert clipped["u"].isnull().any()
125+
126+
127+
@pytest.fixture
128+
def raster_index_dataset(dataset):
129+
"""Same grid as ``dataset`` but with a RasterIndex on the spatial dims."""
130+
ds = assign_index(dataset, x_dim="longitude", y_dim="latitude")
131+
assert isinstance(ds.xindexes["longitude"], RasterIndex)
132+
return ds
133+
134+
135+
def test_rasterize_with_raster_index(engine, raster_index_dataset, dataset):
136+
world = gpd.read_file(geodatasets.get_path("naturalearth land"))
137+
kwargs = dict(xdim="longitude", ydim="latitude", engine=engine)
138+
139+
expected = rasterize(dataset, world[["geometry"]], **kwargs)
140+
result = rasterize(raster_index_dataset, world[["geometry"]], **kwargs)
141+
142+
xr.testing.assert_equal(result, expected)
143+
assert isinstance(result.xindexes["longitude"], RasterIndex)
144+
assert isinstance(result.xindexes["latitude"], RasterIndex)
145+
146+
147+
def test_geometry_mask_with_raster_index(engine, raster_index_dataset, dataset):
148+
world = gpd.read_file(geodatasets.get_path("naturalearth land"))
149+
kwargs = dict(xdim="longitude", ydim="latitude", engine=engine)
150+
151+
expected = geometry_mask(dataset, world[["geometry"]], **kwargs)
152+
result = geometry_mask(raster_index_dataset, world[["geometry"]], **kwargs)
153+
154+
xr.testing.assert_equal(result, expected)
155+
assert isinstance(result.xindexes["longitude"], RasterIndex)
156+
assert isinstance(result.xindexes["latitude"], RasterIndex)

0 commit comments

Comments
 (0)