@@ -58,7 +58,7 @@ This document is structured in two main sections:
5858| [ ` FNC-002 ` ] ( #fnc-002-file-layout-for-functionals ) | File layout for functionals | Adding or refactoring functional files |
5959| [ ` FNC-003 ` ] ( #fnc-003-registration-and-dispatch-rules ) | Registration and dispatch rules | Registering implementations |
6060| [ ` FNC-004 ` ] ( #fnc-004-optional-dependency-handling ) | Optional dependency handling | Using optional backends |
61- | [ ` FNC-005 ` ] ( #fnc-005-benchmarking-hooks ) | Benchmarking hooks | Implementing ` make_inputs ` /` compare ` |
61+ | [ ` FNC-005 ` ] ( #fnc-005-benchmarking-hooks ) | Benchmarking hooks | Implementing ` make_inputs_forward ` /` make_inputs_backward ` / ` compare_forward ` |
6262| [ ` FNC-006 ` ] ( #fnc-006-testing-functionals ) | Testing functionals | Adding functional tests |
6363| [ ` FNC-007 ` ] ( #fnc-007-benchmark-registry ) | Benchmark registry | Adding a functional to ASV |
6464
@@ -72,7 +72,8 @@ This document is structured in two main sections:
7272
7373All functionals must be implemented with ` FunctionSpec ` , even if only a single
7474implementation exists. This ensures the operation participates in validation
75- and benchmarking via ` make_inputs ` and ` compare ` .
75+ and benchmarking through input generators and ` compare_forward ` (and
76+ ` compare_backward ` where needed).
7677
7778** Rationale:**
7879
@@ -82,52 +83,41 @@ selection, benchmarking and verification across the codebase.
8283** Example:**
8384
8485``` python
85- import importlib
8686import torch
87+ import warp as wp
8788
8889from physicsnemo.core.function_spec import FunctionSpec
89- from physicsnemo.core.version_check import check_version_spec
90-
91- WARP_AVAILABLE = check_version_spec(" warp" , " 0.6.0" , hard_fail = False )
92-
93- if WARP_AVAILABLE :
94- wp = importlib.import_module(" warp" )
95- wp.init()
96- wp.config.quiet = True
97-
98- @wp.kernel
99- def _identity_kernel (
100- x : wp.array(dtype = wp.float32),
101- y : wp.array(dtype = wp.float32),
102- ):
103- i = wp.tid()
104- y[i] = x[i]
105-
106- @torch.library.custom_op (" physicsnemo::identity_warp" , mutates_args = ())
107- def identity_impl (x : torch.Tensor) -> torch.Tensor:
108- out = torch.empty_like(x)
109- device, stream = FunctionSpec.warp_launch_context(x)
110- wp_x = wp.from_torch(x, dtype = wp.float32, return_ctype = True )
111- wp_y = wp.from_torch(out, dtype = wp.float32, return_ctype = True )
112- with wp.ScopedStream(stream):
113- wp.launch(
114- kernel = _identity_kernel,
115- dim = x.numel(),
116- inputs = [wp_x, wp_y],
117- device = device,
118- stream = stream,
119- )
120- return out
121-
122- @identity_impl.register_fake
123- def identity_impl_fake (x : torch.Tensor) -> torch.Tensor:
124- return torch.empty_like(x)
125- else :
12690
127- def identity_impl (* args , ** kwargs ) -> torch.Tensor:
128- raise ImportError (
129- " warp>=0.6.0 is required for the Warp identity implementation"
91+ wp.init()
92+ wp.config.quiet = True
93+
94+ @wp.kernel
95+ def _identity_kernel (
96+ x : wp.array(dtype = wp.float32),
97+ y : wp.array(dtype = wp.float32),
98+ ):
99+ i = wp.tid()
100+ y[i] = x[i]
101+
102+ @torch.library.custom_op (" physicsnemo::identity_warp" , mutates_args = ())
103+ def identity_impl (x : torch.Tensor) -> torch.Tensor:
104+ out = torch.empty_like(x)
105+ device, stream = FunctionSpec.warp_launch_context(x)
106+ wp_x = wp.from_torch(x, dtype = wp.float32, return_ctype = True )
107+ wp_y = wp.from_torch(out, dtype = wp.float32, return_ctype = True )
108+ with wp.ScopedStream(stream):
109+ wp.launch(
110+ kernel = _identity_kernel,
111+ dim = x.numel(),
112+ inputs = [wp_x, wp_y],
113+ device = device,
114+ stream = stream,
130115 )
116+ return out
117+
118+ @identity_impl.register_fake
119+ def identity_impl_fake (x : torch.Tensor) -> torch.Tensor:
120+ return torch.empty_like(x)
131121
132122def identity_torch (x : torch.Tensor) -> torch.Tensor:
133123 return x.clone()
@@ -148,14 +138,41 @@ class Identity(FunctionSpec):
148138 return identity_torch(x)
149139
150140 @ classmethod
151- def make_inputs (cls , device : torch.device | str = " cpu" ):
141+ def make_inputs_forward (cls , device : torch.device | str = " cpu" ):
152142 device = torch.device(device)
153143 yield (" small" , (torch.randn(1024 , device = device),), {})
154144 yield (" medium" , (torch.randn(4096 , device = device),), {})
155145 yield (" large" , (torch.randn(16384 , device = device),), {})
156146
157147 @ classmethod
158- def compare (cls , output : torch.Tensor, reference : torch.Tensor) -> None :
148+ def make_inputs_backward (cls , device : torch.device | str = " cpu" ):
149+ device = torch.device(device)
150+ yield (
151+ " small-bwd" ,
152+ (torch.randn(1024 , device = device, requires_grad = True ),),
153+ {},
154+ )
155+ yield (
156+ " medium-bwd" ,
157+ (torch.randn(4096 , device = device, requires_grad = True ),),
158+ {},
159+ )
160+ yield (
161+ " large-bwd" ,
162+ (torch.randn(16384 , device = device, requires_grad = True ),),
163+ {},
164+ )
165+
166+ @ classmethod
167+ def compare_forward (
168+ cls , output : torch.Tensor, reference : torch.Tensor
169+ ) -> None :
170+ torch.testing.assert_close(output, reference)
171+
172+ @ classmethod
173+ def compare_backward (
174+ cls , output : torch.Tensor, reference : torch.Tensor
175+ ) -> None :
159176 torch.testing.assert_close(output, reference)
160177
161178identity = Identity.make_function(" identity" )
@@ -210,6 +227,14 @@ __all__ = ["knn"]
210227 ` physicsnemo/nn/functional/<name>/ ` .
211228 - Keep each backend in its own module (e.g., ` _torch_impl.py ` ).
212229 - Keep shared helpers in ` utils.py ` .
230+ - For complex Warp backends, prefer a dedicated ` _warp_impl/ ` package with:
231+ - ` op.py ` for torch custom-op registration and validation
232+ - ` launch_forward.py ` for forward launch dispatch
233+ - ` launch_backward.py ` for backward launch dispatch
234+ - ` _kernels/ ` with one kernel per file
235+ - ` utils.py ` for shared Warp constants/functions
236+ - Keep ` launch_forward.py ` and ` launch_backward.py ` as the only launch
237+ surfaces; avoid extra launch helper modules unless there is a strong reason.
213238
214239** Rationale:**
215240
@@ -228,6 +253,21 @@ physicsnemo/nn/functional/knn/
228253 utils.py
229254```
230255
256+ ``` text
257+ physicsnemo/nn/functional/interpolation/grid_to_point_interpolation/
258+ grid_to_point_interpolation.py
259+ _torch_impl.py
260+ _warp_impl/
261+ __init__.py
262+ op.py
263+ launch_forward.py
264+ launch_backward.py
265+ _kernels/
266+ forward_3d_stride2.py
267+ backward_3d_stride2.py
268+ utils.py
269+ ```
270+
231271** Anti-pattern:**
232272
233273``` text
@@ -308,11 +348,20 @@ import missing_dep # raises at import time
308348
309349** Description:**
310350
311- Implement ` make_inputs ` and ` compare ` for every functional. ` make_inputs ` should
312- yield labeled inputs ordered from smaller to larger cases. Labels do not have to
313- be exactly "small/medium/large", and you can provide more than three cases.
314- ` compare ` should validate output consistency. Labels are used for benchmark
315- plots and summaries.
351+ Implement ` make_inputs_forward ` for every functional so it can be benchmarked.
352+ Implement ` compare_forward ` when a functional has multiple implementations and
353+ needs cross-backend parity checks in tests.
354+
355+ Implement ` make_inputs_backward ` only for functionals with a meaningful
356+ backward pass (for example differentiable functionals). Implement
357+ ` compare_backward ` when a functional has backward support and multiple
358+ implementations that need backward parity checks.
359+
360+ Input generators should yield labeled inputs ordered from smaller to larger
361+ cases. Labels do not have to be exactly "small/medium/large", and you can
362+ provide more than three cases. Compare hooks should validate output
363+ consistency where implemented. Labels are used for benchmark plots and
364+ summaries.
316365
317366** Rationale:**
318367
@@ -323,17 +372,30 @@ backends.
323372
324373``` python
325374@ classmethod
326- def make_inputs (cls , device = " cpu" ):
375+ def make_inputs_forward (cls , device = " cpu" ):
327376 yield (" small" , (torch.randn(1024 , device = device),), {})
328377 yield (" medium" , (torch.randn(4096 , device = device),), {})
329378 yield (" large" , (torch.randn(16384 , device = device),), {})
379+
380+ @ classmethod
381+ def make_inputs_backward (cls , device = " cpu" ):
382+ x = torch.randn(4096 , device = device, requires_grad = True )
383+ yield (" medium" , (x,), {})
384+
385+ @ classmethod
386+ def compare_forward (cls , output , reference ):
387+ torch.testing.assert_close(output, reference)
388+
389+ @ classmethod
390+ def compare_backward (cls , output , reference ):
391+ torch.testing.assert_close(output, reference)
330392```
331393
332394** Anti-pattern:**
333395
334396``` python
335397@ classmethod
336- def make_inputs (cls , device = " cpu" ):
398+ def make_inputs_forward (cls , device = " cpu" ):
337399 pass
338400```
339401
@@ -346,6 +408,37 @@ def make_inputs(cls, device="cpu"):
346408Add tests under ` test/nn/functional/ ` to validate selection, optional
347409dependencies, and output correctness.
348410
411+ Use a consistent test layout when possible. This is ** highly recommended** for
412+ readability and review speed, but it is ** not strictly required** when a
413+ functional needs a different shape.
414+
415+ Baseline spec-contract tests (expected for every functional):
416+
417+ 1 . Backend and reference correctness:
418+ - ` test_<functional_name>_<implementation_name> `
419+ 2 . Dispatch behavior (only when custom dispatch behavior exists):
420+ - ` test_<functional_name>_dispatch_* `
421+ 3 . Benchmark-input contract:
422+ - ` test_<functional_name>_make_inputs_forward `
423+ - ` test_<functional_name>_make_inputs_backward ` (only when backward is meaningful)
424+ 4 . Validation/deprecation path coverage:
425+ - ` test_<functional_name>_error_handling ` (when validation branches exist)
426+
427+ Cross-backend parity tests and compare-hook tests
428+ (required only when multiple implementations exist):
429+
430+ 1 . Forward parity:
431+ - ` test_<functional_name>_backend_forward_parity `
432+ - ` test_<functional_name>_compare_forward_contract `
433+ 2 . Backward parity:
434+ - ` test_<functional_name>_backend_backward_parity ` (only for differentiable ops)
435+ - ` test_<functional_name>_compare_backward_contract ` (only when backward is meaningful)
436+
437+ Where possible, keep all backend parity checks in one functional test file and
438+ use the functional's ` compare_forward ` /` compare_backward ` hooks for consistency.
439+ For single-implementation functionals, ` compare_forward ` /` compare_backward `
440+ overrides and compare-hook contract tests are not required.
441+
349442** Rationale:**
350443
351444Functional APIs are public entry points and need coverage for both the API and
@@ -354,8 +447,20 @@ backend behavior.
354447** Example:**
355448
356449``` python
357- def test_knn_cpu ():
358- indices, distances = knn(points, queries, k = 4 )
450+ def test_grid_to_point_interpolation_torch ():
451+ ...
452+
453+ def test_grid_to_point_interpolation_warp ():
454+ ...
455+
456+ def test_grid_to_point_interpolation_backend_forward_parity ():
457+ ...
458+
459+ def test_grid_to_point_interpolation_backend_backward_parity ():
460+ ...
461+
462+ def test_grid_to_point_interpolation_error_handling ():
463+ ...
359464```
360465
361466** Anti-pattern:**
@@ -372,7 +477,8 @@ def test_knn_cpu():
372477
373478Functionals that should be benchmarked must be added to
374479` benchmarks/physicsnemo/nn/functional/registry.py ` . Only add a functional once
375- its ` make_inputs ` implementation yields labeled inputs.
480+ its input generators (` make_inputs_forward ` , and optionally
481+ ` make_inputs_backward ` ) yield labeled inputs.
376482
377483** Rationale:**
378484
@@ -392,6 +498,6 @@ FUNCTIONAL_SPECS = (KNN, RadiusSearch)
392498** Anti-pattern:**
393499
394500``` python
395- # Adding a functional before make_inputs is implemented.
501+ # Adding a functional before input generators are implemented.
396502FUNCTIONAL_SPECS = (MyFunctionalWithoutInputs,)
397503```