Status: Open
Labels: bug (Something isn't working)

Description
When the pure-Python padding (`pad_with_dim`) is compiled with `torch.compile`, the compiled function appears to work for only a single `nside` value. Calling it with a second `nside` fails.
Here is a minimal reproducer using multiple `nside` values:
```python
import torch
from earth2grid.healpix import HEALPIX_PAD_XY, pad_with_dim

pad_compiled = torch.compile(pad_with_dim)


def test_pad_compile(batch_size, timesteps, nside, nchannels):
    # HEALPix input with 12 * nside**2 pixels along dim -2
    x = torch.rand([batch_size, timesteps, 12 * nside * nside, nchannels], dtype=torch.bfloat16, device="cuda")
    # Eager reference vs. compiled result, both padded by 1 pixel
    x_pad = pad_with_dim(x, 1, dim=-2, pixel_order=HEALPIX_PAD_XY)
    x_pad_c = pad_compiled(x, 1, dim=-2, pixel_order=HEALPIX_PAD_XY)
    return torch.abs(x_pad - x_pad_c).max()


# nsides = [64]  # This works
nsides = [64, 32]  # This fails

for nside in nsides:
    print(f"compiled pad error = {test_pad_compile(16, 1, nside, 128)}")
```

Running this fails with the following error from `torch.compile`:
```
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
NameError: name 'OpaqueUnaryFn_sqrt' is not defined
```
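The error name suggests that a symbolic square root (`OpaqueUnaryFn_sqrt`) reaches the code Inductor generates. Since the single-`nside` run works, one plausible but unconfirmed explanation is that the second `nside` makes Dynamo treat the pixel dimension as dynamic, and recovering `nside` from `12 * nside**2` pixels then requires a sqrt of a symbolic size. A possible workaround, untested against earth2grid, is to compile with `dynamic=False` so that each `nside` gets its own statically shaped graph. The sketch below uses two hypothetical helpers that are not part of earth2grid: `compile_static` wraps `torch.compile(..., dynamic=False)`, and `nside_from_npix` shows the integer inversion of `npix = 12 * nside**2` that works on concrete sizes.

```python
import math

import torch


def nside_from_npix(npix: int) -> int:
    """Hypothetical helper: invert the HEALPix pixel count npix = 12 * nside**2."""
    nside = math.isqrt(npix // 12)
    if 12 * nside * nside != npix:
        raise ValueError(f"{npix} is not a valid HEALPix pixel count")
    return nside


def compile_static(fn, backend="inductor"):
    """Hypothetical wrapper: compile fn with dynamic shapes disabled.

    With dynamic=False, torch.compile always specializes on input shapes, so a
    new nside triggers a recompile with concrete sizes instead of a graph in
    which the pixel dimension (and anything derived from it) is symbolic.
    """
    return torch.compile(fn, dynamic=False, backend=backend)
```

In the reproducer, this would replace the `torch.compile` line with `pad_compiled = compile_static(pad_with_dim)`. The trade-off is one compilation per distinct `nside`, and Dynamo limits how many times a single function is recompiled before it falls back to eager execution.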