Merged

44 commits
a7276fb
Refactor functional module layout into category packages
loliverhennigh Mar 18, 2026
6206787
Merge branch 'main' into pr1-functional-layout
loliverhennigh Mar 19, 2026
0e40da0
Fix functional exports to match files in geometry split
loliverhennigh Mar 19, 2026
4220e78
Fix functional exports to match files in geometry split
loliverhennigh Mar 19, 2026
7ec27f5
Fix interpolation exports for pr1 functional layout
loliverhennigh Mar 19, 2026
7999210
Fix interpolation compat alias target in pr1
loliverhennigh Mar 19, 2026
f947e26
Merge remote-tracking branch 'upstream/main' into pr1-merge-sync
loliverhennigh Mar 19, 2026
7026e31
Update legacy functional tests for neighbors package layout
loliverhennigh Mar 19, 2026
ebb206a
Apply ruff formatting fixes for functional namespace
loliverhennigh Mar 19, 2026
0433212
Document functional deep-import compatibility impact
loliverhennigh Mar 19, 2026
9259cd5
Add jaxtyping hints to radius_search forward implementations
loliverhennigh Mar 19, 2026
097a364
Add jaxtyping annotations across FunctionSpec functionals
loliverhennigh Mar 19, 2026
ef435cc
Merge branch 'main' into pr1-functional-layout
loliverhennigh Mar 20, 2026
728b50e
Merge upstream/main into pr1-functional-layout
loliverhennigh Mar 20, 2026
b55d416
Refactor functional layout, interpolation split, and parity/benchmark…
loliverhennigh Mar 20, 2026
7cdfa4e
Merge origin/pr1-functional-layout into pr1-functional-layout
loliverhennigh Mar 20, 2026
846507b
Revert "Merge origin/pr1-functional-layout into pr1-functional-layout"
loliverhennigh Mar 20, 2026
8bd1d20
Drop unintended stormcast changes from pr1-functional-layout
loliverhennigh Mar 20, 2026
c37b6c4
Align mesh sampling sample_data.py with upstream/main
loliverhennigh Mar 23, 2026
ba1b9c2
Fix mesh sampling knn import path for functional layout split
loliverhennigh Mar 23, 2026
f60bbd7
Restore UV-sphere SDF benchmark input generation
loliverhennigh Mar 23, 2026
28a36ba
Align import-linter module layers with upstream main
loliverhennigh Mar 23, 2026
c6299b8
Restore nearest-cells changelog entry
loliverhennigh Mar 23, 2026
23272c7
Fix legacy nn.functional deep-import paths after layout split
loliverhennigh Mar 23, 2026
2a2d2de
Merge upstream/main into pr1-functional-layout
loliverhennigh Mar 23, 2026
9444a9d
Port functional runtime, FunctionSpec, and tests from split branch
loliverhennigh Mar 23, 2026
2d38494
Drop out-of-scope derivative and ball-pivoting files from runtime split
loliverhennigh Mar 23, 2026
17fe540
Retarget branch to docs/media and coding standards only
loliverhennigh Mar 24, 2026
4027226
Remove docs media generation scripts from PR
loliverhennigh Mar 24, 2026
c19d23e
Merge remote-tracking branch 'upstream/main' into pr2-functionals-run…
loliverhennigh Mar 24, 2026
69350f2
Port benchmark/runtime API refactor and migrate functional tests
loliverhennigh Mar 24, 2026
fad450a
Remove docs for out-of-scope functionals
loliverhennigh Mar 24, 2026
952e90a
Trim docs media assets from PR2 scope
loliverhennigh Mar 24, 2026
074adeb
Align functional tests with naming/layout conventions
loliverhennigh Mar 25, 2026
5d29cae
Add functional spec-contract compare tests and guideline
loliverhennigh Mar 25, 2026
a800d6e
Require compare hooks only for multi-backend functionals
loliverhennigh Mar 25, 2026
18e2505
Normalize functional test ordering and dispatch guidance
loliverhennigh Mar 25, 2026
b0d93ea
Merge branch 'main' into pr2-functionals-runtime-tests
loliverhennigh Mar 25, 2026
750d15b
Update CODING_STANDARDS/FUNCTIONAL_APIS.md
loliverhennigh Mar 25, 2026
687a897
Update CODING_STANDARDS/FUNCTIONAL_APIS.md
loliverhennigh Mar 25, 2026
539f24e
Keep functional benchmark plot output backward-compatible
loliverhennigh Mar 25, 2026
711b301
Use docs/img as canonical functional benchmark output root
loliverhennigh Mar 25, 2026
7d48f8d
Use existing benchmark.png assets in functional docs
loliverhennigh Mar 25, 2026
640c7cc
Revert unintended interpolation docs page rename
loliverhennigh Mar 25, 2026
218 changes: 162 additions & 56 deletions CODING_STANDARDS/FUNCTIONAL_APIS.md
@@ -58,7 +58,7 @@ This document is structured in two main sections:
| [`FNC-002`](#fnc-002-file-layout-for-functionals) | File layout for functionals | Adding or refactoring functional files |
| [`FNC-003`](#fnc-003-registration-and-dispatch-rules) | Registration and dispatch rules | Registering implementations |
| [`FNC-004`](#fnc-004-optional-dependency-handling) | Optional dependency handling | Using optional backends |
| [`FNC-005`](#fnc-005-benchmarking-hooks) | Benchmarking hooks | Implementing `make_inputs`/`compare` |
| [`FNC-005`](#fnc-005-benchmarking-hooks) | Benchmarking hooks | Implementing `make_inputs_forward`/`make_inputs_backward`/`compare_forward` |
| [`FNC-006`](#fnc-006-testing-functionals) | Testing functionals | Adding functional tests |
| [`FNC-007`](#fnc-007-benchmark-registry) | Benchmark registry | Adding a functional to ASV |

@@ -72,7 +72,8 @@ This document is structured in two main sections:

All functionals must be implemented with `FunctionSpec`, even if only a single
implementation exists. This ensures the operation participates in validation
and benchmarking via `make_inputs` and `compare`.
and benchmarking through input generators and `compare_forward` (and
`compare_backward` where needed).

**Rationale:**

@@ -82,52 +83,41 @@ selection, benchmarking and verification across the codebase.
**Example:**

```python
import importlib
import torch
import warp as wp

from physicsnemo.core.function_spec import FunctionSpec
from physicsnemo.core.version_check import check_version_spec

WARP_AVAILABLE = check_version_spec("warp", "0.6.0", hard_fail=False)

if WARP_AVAILABLE:
wp = importlib.import_module("warp")
wp.init()
wp.config.quiet = True

@wp.kernel
def _identity_kernel(
x: wp.array(dtype=wp.float32),
y: wp.array(dtype=wp.float32),
):
i = wp.tid()
y[i] = x[i]

@torch.library.custom_op("physicsnemo::identity_warp", mutates_args=())
def identity_impl(x: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x)
device, stream = FunctionSpec.warp_launch_context(x)
wp_x = wp.from_torch(x, dtype=wp.float32, return_ctype=True)
wp_y = wp.from_torch(out, dtype=wp.float32, return_ctype=True)
with wp.ScopedStream(stream):
wp.launch(
kernel=_identity_kernel,
dim=x.numel(),
inputs=[wp_x, wp_y],
device=device,
stream=stream,
)
return out

@identity_impl.register_fake
def identity_impl_fake(x: torch.Tensor) -> torch.Tensor:
return torch.empty_like(x)
else:

def identity_impl(*args, **kwargs) -> torch.Tensor:
raise ImportError(
"warp>=0.6.0 is required for the Warp identity implementation"
wp.init()
wp.config.quiet = True

@wp.kernel
def _identity_kernel(
x: wp.array(dtype=wp.float32),
y: wp.array(dtype=wp.float32),
):
i = wp.tid()
y[i] = x[i]

@torch.library.custom_op("physicsnemo::identity_warp", mutates_args=())
def identity_impl(x: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x)
device, stream = FunctionSpec.warp_launch_context(x)
wp_x = wp.from_torch(x, dtype=wp.float32, return_ctype=True)
wp_y = wp.from_torch(out, dtype=wp.float32, return_ctype=True)
with wp.ScopedStream(stream):
wp.launch(
kernel=_identity_kernel,
dim=x.numel(),
inputs=[wp_x, wp_y],
device=device,
stream=stream,
)
return out

@identity_impl.register_fake
def identity_impl_fake(x: torch.Tensor) -> torch.Tensor:
return torch.empty_like(x)

def identity_torch(x: torch.Tensor) -> torch.Tensor:
return x.clone()
@@ -148,14 +138,41 @@ class Identity(FunctionSpec):
return identity_torch(x)

@classmethod
def make_inputs(cls, device: torch.device | str = "cpu"):
def make_inputs_forward(cls, device: torch.device | str = "cpu"):
device = torch.device(device)
yield ("small", (torch.randn(1024, device=device),), {})
yield ("medium", (torch.randn(4096, device=device),), {})
yield ("large", (torch.randn(16384, device=device),), {})

@classmethod
def compare(cls, output: torch.Tensor, reference: torch.Tensor) -> None:
def make_inputs_backward(cls, device: torch.device | str = "cpu"):
device = torch.device(device)
yield (
"small-bwd",
(torch.randn(1024, device=device, requires_grad=True),),
{},
)
yield (
"medium-bwd",
(torch.randn(4096, device=device, requires_grad=True),),
{},
)
yield (
"large-bwd",
(torch.randn(16384, device=device, requires_grad=True),),
{},
)

@classmethod
def compare_forward(
cls, output: torch.Tensor, reference: torch.Tensor
) -> None:
torch.testing.assert_close(output, reference)

@classmethod
def compare_backward(
cls, output: torch.Tensor, reference: torch.Tensor
) -> None:
torch.testing.assert_close(output, reference)

identity = Identity.make_function("identity")
```

@@ -210,6 +227,14 @@ __all__ = ["knn"]
`physicsnemo/nn/functional/<name>/`.
- Keep each backend in its own module (e.g., `_torch_impl.py`).
- Keep shared helpers in `utils.py`.
- For complex Warp backends, prefer a dedicated `_warp_impl/` package with:
- `op.py` for torch custom-op registration and validation
- `launch_forward.py` for forward launch dispatch
- `launch_backward.py` for backward launch dispatch
- `_kernels/` with one kernel per file
- `utils.py` for shared Warp constants/functions
- Keep `launch_forward.py` and `launch_backward.py` as the only launch
surfaces; avoid extra launch helper modules unless there is a strong reason.

**Rationale:**

@@ -228,6 +253,21 @@
```text
physicsnemo/nn/functional/knn/
  ...
  utils.py
```

```text
physicsnemo/nn/functional/interpolation/grid_to_point_interpolation/
grid_to_point_interpolation.py
_torch_impl.py
_warp_impl/
__init__.py
op.py
launch_forward.py
launch_backward.py
_kernels/
forward_3d_stride2.py
backward_3d_stride2.py
utils.py
```
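The single-launch-surface rule can be sketched in plain Python. The kernel
names below are illustrative stand-ins for the Warp kernels in the tree above,
not real physicsnemo APIs:

```python
# Hypothetical sketch: launch_forward.py is the only public launch surface.
# It selects a kernel by input layout so callers never import from _kernels/
# directly. All names here are illustrative.

def _forward_3d_stride2(inputs):
    # Stand-in for a wp.launch of _kernels/forward_3d_stride2.py
    return ("forward_3d_stride2", inputs)

def _forward_3d_stride1(inputs):
    return ("forward_3d_stride1", inputs)

_FORWARD_KERNELS = {
    (3, 2): _forward_3d_stride2,
    (3, 1): _forward_3d_stride1,
}

def launch_forward(ndim, stride, inputs):
    # Single dispatch point: unknown layouts fail loudly here, not in a
    # scattering of helper modules.
    try:
        kernel = _FORWARD_KERNELS[(ndim, stride)]
    except KeyError:
        raise NotImplementedError(
            f"no forward kernel for ndim={ndim}, stride={stride}"
        )
    return kernel(inputs)
```

Keeping the kernel table private to this module is what makes
`launch_forward.py` the only launch surface.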

**Anti-pattern:**

```text
...
```

@@ -308,11 +348,20 @@ import missing_dep  # raises at import time

**Description:**

Implement `make_inputs` and `compare` for every functional. `make_inputs` should
yield labeled inputs ordered from smaller to larger cases. Labels do not have to
be exactly "small/medium/large", and you can provide more than three cases.
`compare` should validate output consistency. Labels are used for benchmark
plots and summaries.
Implement `make_inputs_forward` for every functional so it can be benchmarked.
Implement `compare_forward` when a functional has multiple implementations and
needs cross-backend parity checks in tests.

Implement `make_inputs_backward` only for functionals with a meaningful
backward pass (for example differentiable functionals). Implement
`compare_backward` when a functional has backward support and multiple
implementations that need backward parity checks.

Input generators should yield labeled inputs ordered from smaller to larger
cases. Labels do not have to be exactly "small/medium/large", and you can
provide more than three cases. Compare hooks should validate output
consistency where implemented. Labels are used for benchmark plots and
summaries.
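As a plain-Python illustration of the generator contract (hypothetical harness
code, not the real benchmark runner), each yielded item is a
`(label, args, kwargs)` triple, consumed in order, with the label keying
plots and summaries:

```python
def make_inputs_forward(sizes=(1024, 4096, 16384)):
    # Yield (label, args, kwargs) triples, smallest case first.
    for n in sizes:
        yield (f"n={n}", (list(range(n)),), {})

def run_labeled_cases(fn, cases):
    # Hypothetical harness loop: each label keys one recorded result.
    results = {}
    for label, args, kwargs in cases:
        results[label] = fn(*args, **kwargs)
    return results

# Tiny demonstration with sum() standing in for a functional.
totals = run_labeled_cases(sum, make_inputs_forward(sizes=(4, 8)))
# totals == {"n=4": 6, "n=8": 28}
```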

**Rationale:**

@@ -323,17 +372,30 @@ backends.

```python
@classmethod
def make_inputs(cls, device="cpu"):
def make_inputs_forward(cls, device="cpu"):
yield ("small", (torch.randn(1024, device=device),), {})
yield ("medium", (torch.randn(4096, device=device),), {})
yield ("large", (torch.randn(16384, device=device),), {})

@classmethod
def make_inputs_backward(cls, device="cpu"):
x = torch.randn(4096, device=device, requires_grad=True)
yield ("medium", (x,), {})

@classmethod
def compare_forward(cls, output, reference):
torch.testing.assert_close(output, reference)

@classmethod
def compare_backward(cls, output, reference):
torch.testing.assert_close(output, reference)
```

**Anti-pattern:**

```python
@classmethod
def make_inputs(cls, device="cpu"):
def make_inputs_forward(cls, device="cpu"):
pass
```

@@ -346,6 +408,37 @@ def make_inputs(cls, device="cpu"):
Add tests under `test/nn/functional/` to validate selection, optional
dependencies, and output correctness.

Use a consistent test layout when possible. This is **highly recommended** for
readability and review speed, but it is **not strictly required** when a
functional needs a different shape.

Baseline spec-contract tests (expected for every functional):

1. Backend and reference correctness:
- `test_<functional_name>_<implementation_name>`
2. Dispatch behavior (only when custom dispatch behavior exists):
- `test_<functional_name>_dispatch_*`
3. Benchmark-input contract:
- `test_<functional_name>_make_inputs_forward`
- `test_<functional_name>_make_inputs_backward` (only when backward is meaningful)
4. Validation/deprecation path coverage:
- `test_<functional_name>_error_handling` (when validation branches exist)

Cross-backend parity tests and compare-hook tests
(required only when multiple implementations exist):

1. Forward parity:
- `test_<functional_name>_backend_forward_parity`
- `test_<functional_name>_compare_forward_contract`
2. Backward parity:
- `test_<functional_name>_backend_backward_parity` (only for differentiable ops)
- `test_<functional_name>_compare_backward_contract` (only when backward is meaningful)

Where possible, keep all backend parity checks in one functional test file and
use the functional's `compare_forward`/`compare_backward` hooks for consistency.
For single-implementation functionals, `compare_forward`/`compare_backward`
overrides and compare-hook contract tests are not required.
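A minimal parity-test sketch of this naming scheme, using pure-Python
stand-in backends (the real tests would call the functional's Torch and Warp
implementations and its actual `compare_forward` hook):

```python
# Two stand-in "backends" for one functional, checked through a compare
# hook, mirroring test_<functional_name>_backend_forward_parity.

def square_torch(xs):
    # Reference implementation stand-in.
    return [x * x for x in xs]

def square_warp(xs):
    # Alternate backend stand-in.
    return [x ** 2 for x in xs]

def compare_forward(output, reference):
    assert output == reference, f"parity failure: {output} != {reference}"

def test_square_backend_forward_parity():
    xs = [1.0, -2.0, 3.5]
    compare_forward(square_warp(xs), square_torch(xs))

test_square_backend_forward_parity()
```

Routing the assertion through the compare hook keeps the tolerance policy in
one place instead of duplicating it per test.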

**Rationale:**

Functional APIs are public entry points and need coverage for both the API and
@@ -354,8 +447,20 @@ backend behavior.
**Example:**

```python
def test_knn_cpu():
indices, distances = knn(points, queries, k=4)
def test_grid_to_point_interpolation_torch():
...

def test_grid_to_point_interpolation_warp():
...

def test_grid_to_point_interpolation_backend_forward_parity():
...

def test_grid_to_point_interpolation_backend_backward_parity():
...

def test_grid_to_point_interpolation_error_handling():
...
```

**Anti-pattern:**
@@ -372,7 +477,8 @@ def test_knn_cpu():

Functionals that should be benchmarked must be added to
`benchmarks/physicsnemo/nn/functional/registry.py`. Only add a functional once
its `make_inputs` implementation yields labeled inputs.
its input generators (`make_inputs_forward`, and optionally
`make_inputs_backward`) yield labeled inputs.
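The "only register once inputs exist" rule can be sketched as a guard
(hypothetical helper and specs; the real registry is a plain tuple in
`registry.py`):

```python
class Ready:
    # Stand-in spec whose forward generator yields at least one case.
    @classmethod
    def make_inputs_forward(cls):
        yield ("small", (1,), {})

class NotReady:
    # Stand-in spec with no labeled inputs yet.
    @classmethod
    def make_inputs_forward(cls):
        return iter(())

def yields_labeled_inputs(spec):
    # True if the forward generator produces at least one labeled case.
    return next(iter(spec.make_inputs_forward()), None) is not None

FUNCTIONAL_SPECS = tuple(
    s for s in (Ready, NotReady) if yields_labeled_inputs(s)
)
# FUNCTIONAL_SPECS == (Ready,)
```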

**Rationale:**

Expand All @@ -392,6 +498,6 @@ FUNCTIONAL_SPECS = (KNN, RadiusSearch)
**Anti-pattern:**

```python
# Adding a functional before make_inputs is implemented.
# Adding a functional before input generators are implemented.
FUNCTIONAL_SPECS = (MyFunctionalWithoutInputs,)
```