Skip to content

Commit 44ccfdd

Browse files
Pr2 functionals runtime tests (#1529)
* Refactor functional module layout into category packages * Fix functional exports to match files in geometry split * Fix functional exports to match files in geometry split * Fix interpolation exports for pr1 functional layout * Fix interpolation compat alias target in pr1 * Update legacy functional tests for neighbors package layout * Apply ruff formatting fixes for functional namespace * Document functional deep-import compatibility impact * Add jaxtyping hints to radius_search forward implementations * Add jaxtyping annotations across FunctionSpec functionals * Refactor functional layout, interpolation split, and parity/benchmark/docs updates * Revert "Merge origin/pr1-functional-layout into pr1-functional-layout" This reverts commit 7cdfa4e, reversing changes made to ef435cc. * Drop unintended stormcast changes from pr1-functional-layout * Align mesh sampling sample_data.py with upstream/main * Fix mesh sampling knn import path for functional layout split * Restore UV-sphere SDF benchmark input generation * Align import-linter module layers with upstream main * Restore nearest-cells changelog entry * Fix legacy nn.functional deep-import paths after layout split * Port functional runtime, FunctionSpec, and tests from split branch * Drop out-of-scope derivative and ball-pivoting files from runtime split * Retarget branch to docs/media and coding standards only * Remove docs media generation scripts from PR * Port benchmark/runtime API refactor and migrate functional tests * Remove docs for out-of-scope functionals * Trim docs media assets from PR2 scope * Align functional tests with naming/layout conventions * Add functional spec-contract compare tests and guideline * Require compare hooks only for multi-backend functionals * Normalize functional test ordering and dispatch guidance * Update CODING_STANDARDS/FUNCTIONAL_APIS.md Co-authored-by: megnvidia <mmiranda@nvidia.com> * Update CODING_STANDARDS/FUNCTIONAL_APIS.md Co-authored-by: megnvidia <mmiranda@nvidia.com> * Keep functional benchmark plot output backward-compatible * Use docs/img as canonical functional benchmark output root * Use existing benchmark.png assets in functional docs * Revert unintended interpolation docs page rename --------- Co-authored-by: megnvidia <mmiranda@nvidia.com>
1 parent c667944 commit 44ccfdd

21 files changed

Lines changed: 2291 additions & 1187 deletions

File tree

CODING_STANDARDS/FUNCTIONAL_APIS.md

Lines changed: 162 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ This document is structured in two main sections:
5858
| [`FNC-002`](#fnc-002-file-layout-for-functionals) | File layout for functionals | Adding or refactoring functional files |
5959
| [`FNC-003`](#fnc-003-registration-and-dispatch-rules) | Registration and dispatch rules | Registering implementations |
6060
| [`FNC-004`](#fnc-004-optional-dependency-handling) | Optional dependency handling | Using optional backends |
61-
| [`FNC-005`](#fnc-005-benchmarking-hooks) | Benchmarking hooks | Implementing `make_inputs`/`compare` |
61+
| [`FNC-005`](#fnc-005-benchmarking-hooks) | Benchmarking hooks | Implementing `make_inputs_forward`/`make_inputs_backward`/`compare_forward` |
6262
| [`FNC-006`](#fnc-006-testing-functionals) | Testing functionals | Adding functional tests |
6363
| [`FNC-007`](#fnc-007-benchmark-registry) | Benchmark registry | Adding a functional to ASV |
6464

@@ -72,7 +72,8 @@ This document is structured in two main sections:
7272

7373
All functionals must be implemented with `FunctionSpec`, even if only a single
7474
implementation exists. This ensures the operation participates in validation
75-
and benchmarking via `make_inputs` and `compare`.
75+
and benchmarking through input generators and `compare_forward` (and
76+
`compare_backward` where needed).
7677

7778
**Rationale:**
7879

@@ -82,52 +83,41 @@ selection, benchmarking and verification across the codebase.
8283
**Example:**
8384

8485
```python
85-
import importlib
8686
import torch
87+
import warp as wp
8788

8889
from physicsnemo.core.function_spec import FunctionSpec
89-
from physicsnemo.core.version_check import check_version_spec
90-
91-
WARP_AVAILABLE = check_version_spec("warp", "0.6.0", hard_fail=False)
92-
93-
if WARP_AVAILABLE:
94-
wp = importlib.import_module("warp")
95-
wp.init()
96-
wp.config.quiet = True
97-
98-
@wp.kernel
99-
def _identity_kernel(
100-
x: wp.array(dtype=wp.float32),
101-
y: wp.array(dtype=wp.float32),
102-
):
103-
i = wp.tid()
104-
y[i] = x[i]
105-
106-
@torch.library.custom_op("physicsnemo::identity_warp", mutates_args=())
107-
def identity_impl(x: torch.Tensor) -> torch.Tensor:
108-
out = torch.empty_like(x)
109-
device, stream = FunctionSpec.warp_launch_context(x)
110-
wp_x = wp.from_torch(x, dtype=wp.float32, return_ctype=True)
111-
wp_y = wp.from_torch(out, dtype=wp.float32, return_ctype=True)
112-
with wp.ScopedStream(stream):
113-
wp.launch(
114-
kernel=_identity_kernel,
115-
dim=x.numel(),
116-
inputs=[wp_x, wp_y],
117-
device=device,
118-
stream=stream,
119-
)
120-
return out
121-
122-
@identity_impl.register_fake
123-
def identity_impl_fake(x: torch.Tensor) -> torch.Tensor:
124-
return torch.empty_like(x)
125-
else:
12690

127-
def identity_impl(*args, **kwargs) -> torch.Tensor:
128-
raise ImportError(
129-
"warp>=0.6.0 is required for the Warp identity implementation"
91+
wp.init()
92+
wp.config.quiet = True
93+
94+
@wp.kernel
95+
def _identity_kernel(
96+
x: wp.array(dtype=wp.float32),
97+
y: wp.array(dtype=wp.float32),
98+
):
99+
i = wp.tid()
100+
y[i] = x[i]
101+
102+
@torch.library.custom_op("physicsnemo::identity_warp", mutates_args=())
103+
def identity_impl(x: torch.Tensor) -> torch.Tensor:
104+
out = torch.empty_like(x)
105+
device, stream = FunctionSpec.warp_launch_context(x)
106+
wp_x = wp.from_torch(x, dtype=wp.float32, return_ctype=True)
107+
wp_y = wp.from_torch(out, dtype=wp.float32, return_ctype=True)
108+
with wp.ScopedStream(stream):
109+
wp.launch(
110+
kernel=_identity_kernel,
111+
dim=x.numel(),
112+
inputs=[wp_x, wp_y],
113+
device=device,
114+
stream=stream,
130115
)
116+
return out
117+
118+
@identity_impl.register_fake
119+
def identity_impl_fake(x: torch.Tensor) -> torch.Tensor:
120+
return torch.empty_like(x)
131121

132122
def identity_torch(x: torch.Tensor) -> torch.Tensor:
133123
return x.clone()
@@ -148,14 +138,41 @@ class Identity(FunctionSpec):
148138
return identity_torch(x)
149139

150140
@classmethod
151-
def make_inputs(cls, device: torch.device | str = "cpu"):
141+
def make_inputs_forward(cls, device: torch.device | str = "cpu"):
152142
device = torch.device(device)
153143
yield ("small", (torch.randn(1024, device=device),), {})
154144
yield ("medium", (torch.randn(4096, device=device),), {})
155145
yield ("large", (torch.randn(16384, device=device),), {})
156146

157147
@classmethod
158-
def compare(cls, output: torch.Tensor, reference: torch.Tensor) -> None:
148+
def make_inputs_backward(cls, device: torch.device | str = "cpu"):
149+
device = torch.device(device)
150+
yield (
151+
"small-bwd",
152+
(torch.randn(1024, device=device, requires_grad=True),),
153+
{},
154+
)
155+
yield (
156+
"medium-bwd",
157+
(torch.randn(4096, device=device, requires_grad=True),),
158+
{},
159+
)
160+
yield (
161+
"large-bwd",
162+
(torch.randn(16384, device=device, requires_grad=True),),
163+
{},
164+
)
165+
166+
@classmethod
167+
def compare_forward(
168+
cls, output: torch.Tensor, reference: torch.Tensor
169+
) -> None:
170+
torch.testing.assert_close(output, reference)
171+
172+
@classmethod
173+
def compare_backward(
174+
cls, output: torch.Tensor, reference: torch.Tensor
175+
) -> None:
159176
torch.testing.assert_close(output, reference)
160177

161178
identity = Identity.make_function("identity")
@@ -210,6 +227,14 @@ __all__ = ["knn"]
210227
`physicsnemo/nn/functional/<name>/`.
211228
- Keep each backend in its own module (e.g., `_torch_impl.py`).
212229
- Keep shared helpers in `utils.py`.
230+
- For complex Warp backends, prefer a dedicated `_warp_impl/` package with:
231+
- `op.py` for torch custom-op registration and validation
232+
- `launch_forward.py` for forward launch dispatch
233+
- `launch_backward.py` for backward launch dispatch
234+
- `_kernels/` with one kernel per file
235+
- `utils.py` for shared Warp constants/functions
236+
- Keep `launch_forward.py` and `launch_backward.py` as the only launch
237+
surfaces; avoid extra launch helper modules unless there is a strong reason.
213238

214239
**Rationale:**
215240

@@ -228,6 +253,21 @@ physicsnemo/nn/functional/knn/
228253
utils.py
229254
```
230255

256+
```text
257+
physicsnemo/nn/functional/interpolation/grid_to_point_interpolation/
258+
grid_to_point_interpolation.py
259+
_torch_impl.py
260+
_warp_impl/
261+
__init__.py
262+
op.py
263+
launch_forward.py
264+
launch_backward.py
265+
_kernels/
266+
forward_3d_stride2.py
267+
backward_3d_stride2.py
268+
utils.py
269+
```
270+
231271
**Anti-pattern:**
232272

233273
```text
@@ -308,11 +348,20 @@ import missing_dep # raises at import time
308348

309349
**Description:**
310350

311-
Implement `make_inputs` and `compare` for every functional. `make_inputs` should
312-
yield labeled inputs ordered from smaller to larger cases. Labels do not have to
313-
be exactly "small/medium/large", and you can provide more than three cases.
314-
`compare` should validate output consistency. Labels are used for benchmark
315-
plots and summaries.
351+
Implement `make_inputs_forward` for every functional so it can be benchmarked.
352+
Implement `compare_forward` when a functional has multiple implementations and
353+
needs cross-backend parity checks in tests.
354+
355+
Implement `make_inputs_backward` only for functionals with a meaningful
356+
backward pass (for example differentiable functionals). Implement
357+
`compare_backward` when a functional has backward support and multiple
358+
implementations that need backward parity checks.
359+
360+
Input generators should yield labeled inputs ordered from smaller to larger
361+
cases. Labels do not have to be exactly "small/medium/large", and you can
362+
provide more than three cases. Compare hooks should validate output
363+
consistency where implemented. Labels are used for benchmark plots and
364+
summaries.
316365

317366
**Rationale:**
318367

@@ -323,17 +372,30 @@ backends.
323372

324373
```python
325374
@classmethod
326-
def make_inputs(cls, device="cpu"):
375+
def make_inputs_forward(cls, device="cpu"):
327376
yield ("small", (torch.randn(1024, device=device),), {})
328377
yield ("medium", (torch.randn(4096, device=device),), {})
329378
yield ("large", (torch.randn(16384, device=device),), {})
379+
380+
@classmethod
381+
def make_inputs_backward(cls, device="cpu"):
382+
x = torch.randn(4096, device=device, requires_grad=True)
383+
yield ("medium", (x,), {})
384+
385+
@classmethod
386+
def compare_forward(cls, output, reference):
387+
torch.testing.assert_close(output, reference)
388+
389+
@classmethod
390+
def compare_backward(cls, output, reference):
391+
torch.testing.assert_close(output, reference)
330392
```
331393

332394
**Anti-pattern:**
333395

334396
```python
335397
@classmethod
336-
def make_inputs(cls, device="cpu"):
398+
def make_inputs_forward(cls, device="cpu"):
337399
pass
338400
```
339401

@@ -346,6 +408,37 @@ def make_inputs(cls, device="cpu"):
346408
Add tests under `test/nn/functional/` to validate selection, optional
347409
dependencies, and output correctness.
348410

411+
Use a consistent test layout when possible. This is **highly recommended** for
412+
readability and review speed, but it is **not strictly required** when a
413+
functional needs a different shape.
414+
415+
Baseline spec-contract tests (expected for every functional):
416+
417+
1. Backend and reference correctness:
418+
- `test_<functional_name>_<implementation_name>`
419+
2. Dispatch behavior (only when custom dispatch behavior exists):
420+
- `test_<functional_name>_dispatch_*`
421+
3. Benchmark-input contract:
422+
- `test_<functional_name>_make_inputs_forward`
423+
- `test_<functional_name>_make_inputs_backward` (only when backward is meaningful)
424+
4. Validation/deprecation path coverage:
425+
- `test_<functional_name>_error_handling` (when validation branches exist)
426+
427+
Cross-backend parity tests and compare-hook tests
428+
(required only when multiple implementations exist):
429+
430+
1. Forward parity:
431+
- `test_<functional_name>_backend_forward_parity`
432+
- `test_<functional_name>_compare_forward_contract`
433+
2. Backward parity:
434+
- `test_<functional_name>_backend_backward_parity` (only for differentiable ops)
435+
- `test_<functional_name>_compare_backward_contract` (only when backward is meaningful)
436+
437+
Where possible, keep all backend parity checks in one functional test file and
438+
use the functional's `compare_forward`/`compare_backward` hooks for consistency.
439+
For single-implementation functionals, `compare_forward`/`compare_backward`
440+
overrides and compare-hook contract tests are not required.
441+
349442
**Rationale:**
350443

351444
Functional APIs are public entry points and need coverage for both the API and
@@ -354,8 +447,20 @@ backend behavior.
354447
**Example:**
355448

356449
```python
357-
def test_knn_cpu():
358-
indices, distances = knn(points, queries, k=4)
450+
def test_grid_to_point_interpolation_torch():
451+
...
452+
453+
def test_grid_to_point_interpolation_warp():
454+
...
455+
456+
def test_grid_to_point_interpolation_backend_forward_parity():
457+
...
458+
459+
def test_grid_to_point_interpolation_backend_backward_parity():
460+
...
461+
462+
def test_grid_to_point_interpolation_error_handling():
463+
...
359464
```
360465

361466
**Anti-pattern:**
@@ -372,7 +477,8 @@ def test_knn_cpu():
372477

373478
Functionals that should be benchmarked must be added to
374479
`benchmarks/physicsnemo/nn/functional/registry.py`. Only add a functional once
375-
its `make_inputs` implementation yields labeled inputs.
480+
its input generators (`make_inputs_forward`, and optionally
481+
`make_inputs_backward`) yield labeled inputs.
376482

377483
**Rationale:**
378484

@@ -392,6 +498,6 @@ FUNCTIONAL_SPECS = (KNN, RadiusSearch)
392498
**Anti-pattern:**
393499

394500
```python
395-
# Adding a functional before make_inputs is implemented.
501+
# Adding a functional before input generators are implemented.
396502
FUNCTIONAL_SPECS = (MyFunctionalWithoutInputs,)
397503
```

0 commit comments

Comments
 (0)