Skip to content

Commit d1dd22b

Browse files
cpgaffney1Orbax Authors
authored andcommitted
Internal testing.
PiperOrigin-RevId: 856420208
1 parent 77afe1f commit d1dd22b

6 files changed

Lines changed: 252 additions & 13 deletions

File tree

checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/configs.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,15 @@ class MeshConfig:
5454
degree *across* slices (Data Center Network). This typically contains a
5555
single entry for the data-parallel axis.
5656
Example: {'data': 2}
57+
If None, an ordinary device mesh will be used, rather than a hybrid
58+
device mesh (intended for multi-replica workloads)
5759
allow_split_physical_axes: If True, we will split physical axes if
5860
necessary to produce the desired device mesh.
5961
process_is_granule: If True, treat processes as the units of the
6062
slower/outer network.
6163
"""
6264
mesh_axes: list[str]
6365
ici_parallelism: dict[str, int] = dataclasses.field(default_factory=dict)
64-
dcn_parallelism: dict[str, int] = dataclasses.field(default_factory=dict)
66+
dcn_parallelism: dict[str, int] | None = None
6567
allow_split_physical_axes: bool = False
6668
process_is_granule: bool = False

checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/device_mesh.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,19 @@ def create_mesh(config: configs.MeshConfig) -> jax.sharding.Mesh:
4242
num_devices = len(devices)
4343
# Convert the user-friendly dict maps into ordered lists based on mesh_axes
4444
ici_shape = [config.ici_parallelism.get(axis, 1) for axis in config.mesh_axes]
45-
dcn_shape = [config.dcn_parallelism.get(axis, 1) for axis in config.mesh_axes]
45+
46+
dcn_parallelism = config.dcn_parallelism
47+
if dcn_parallelism is None:
48+
logging.info('Creating ICI-only mesh.')
49+
devices_array = mesh_utils.create_device_mesh(ici_shape, devices)
50+
logging.info(
51+
'Creating mesh with axes: %s',
52+
{axis: dim for axis, dim in zip(config.mesh_axes, devices_array.shape)},
53+
)
54+
return jax.sharding.Mesh(devices_array, config.mesh_axes)
55+
else:
56+
logging.info('Creating hybrid mesh.')
57+
dcn_shape = [dcn_parallelism.get(axis, 1) for axis in config.mesh_axes]
4658

4759
# --- Validation ---
4860
if config.process_is_granule:
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
# Copyright 2025 The Orbax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Benchmarks for orbax.checkpoint.PyTreeCheckpointHandler."""
16+
17+
from __future__ import annotations
18+
19+
import dataclasses
20+
import pprint
21+
import time
22+
from typing import Any
23+
24+
from absl import logging
25+
import jax
26+
from jax.experimental import multihost_utils
27+
import numpy as np
28+
import orbax.checkpoint as ocp_v0 # pylint: disable=unused-import
29+
from orbax.checkpoint import v1 as ocp
30+
from orbax.checkpoint._src.testing.benchmarks.core import core as benchmarks_core
31+
from orbax.checkpoint._src.testing.benchmarks.core import metric as metric_lib
32+
import requests
33+
34+
35+
SERVICE_URL = "http://service-dns/"
36+
37+
38+
def _metrics_to_measure(options: LustreBenchmarkOptions) -> list[str]:
39+
"""Returns the list of metrics to measure."""
40+
del options
41+
metrics = ["time", "rss", "io"]
42+
return metrics
43+
44+
45+
# ==============================================================================
46+
# 1. Define the Options Dataclass for this specific benchmark
47+
# ==============================================================================
48+
@dataclasses.dataclass(frozen=True)
49+
class LustreBenchmarkOptions(benchmarks_core.BenchmarkOptions):
50+
"""Configuration options for benchmarks targeting PyTreeCheckpointHandler.
51+
52+
Each attribute can be a single value or a list of values to create
53+
a parameter sweep.
54+
55+
Attributes:
56+
use_ocdbt: Whether to use OCDBT for checkpointing.
57+
steps: Number of steps to run the benchmark for.
58+
"""
59+
60+
use_ocdbt: bool = True
61+
steps: int = 1
62+
63+
def is_valid(self):
64+
return True
65+
66+
67+
class StorageServiceClient:
68+
"""Docstring."""
69+
70+
def __init__(self, service_url: str | None = None):
71+
self._service_url = service_url or SERVICE_URL
72+
73+
def resolve(self, execution_id: int, step: int) -> str:
74+
"""Resolves an asset path from the service."""
75+
start = time.time()
76+
logging.info("Resolving ID-step: %s-%s.", execution_id, step)
77+
payload = {"execution_id": execution_id, "step": step}
78+
response = requests.post(f"{self._service_url}/resolve", json=payload)
79+
logging.info("Response: %s", response.json())
80+
response.raise_for_status()
81+
result = response.json()["path"]
82+
end = time.time()
83+
logging.info("Resolved %s in %s seconds.", result, end - start)
84+
return result
85+
86+
def finalize(self, execution_id: int, step: int) -> None:
87+
"""Finalizes an asset in the service."""
88+
start = time.time()
89+
payload = {"execution_id": execution_id, "step": step}
90+
response = requests.post(f"{self._service_url}/finalize", json=payload)
91+
response.raise_for_status()
92+
logging.info(response)
93+
# assert response.json()["status"] == "ok"
94+
end = time.time()
95+
logging.info(
96+
"Finalized %s %s in %s seconds.", execution_id, step, end - start
97+
)
98+
99+
100+
def _get_xid() -> int:
101+
"""Returns the XID for this run."""
102+
xid = multihost_utils.broadcast_one_to_all(
103+
np.asarray(int(time.time()))
104+
).item()
105+
logging.info("XID: %s", xid)
106+
return xid
107+
108+
109+
# ==============================================================================
110+
# 2. Implement the Benchmark Generator
111+
# ==============================================================================
112+
@benchmarks_core.benchmark_options(LustreBenchmarkOptions)
113+
class LustreBenchmark(benchmarks_core.BenchmarksGenerator):
114+
"""Docstring."""
115+
116+
def __init__(self, *args, **kwargs):
117+
super().__init__(*args, **kwargs)
118+
self._client = StorageServiceClient()
119+
self._xid = _get_xid()
120+
121+
def _clear_pytree(self, pytree: Any) -> Any:
122+
"""Clears the pytree to free up memory."""
123+
return jax.tree.map(
124+
lambda x: x.delete() if isinstance(x, jax.Array) else None, pytree
125+
)
126+
127+
def test_fn(
128+
self, context: benchmarks_core.TestContext
129+
) -> benchmarks_core.TestResult:
130+
"""The core test logic for a single save/restore cycle.
131+
132+
This function is called for each combination of options generated by the
133+
framework. It uses the `context.options` to configure the handler
134+
dynamically for each run.
135+
136+
Args:
137+
context: The test context containing the pytree, path, and options.
138+
139+
Returns:
140+
The test result containing the metrics.
141+
"""
142+
logging.info(
143+
"JAX info: %s processes, %s devices, %s process index",
144+
jax.process_count(),
145+
jax.device_count(),
146+
jax.process_index(),
147+
)
148+
metrics = metric_lib.Metrics()
149+
pytree = context.pytree
150+
options = context.options
151+
assert isinstance(options, LustreBenchmarkOptions)
152+
153+
logging.info("Benchmark options: %s", pprint.pformat(options))
154+
logging.info("Benchmark context: %s", pprint.pformat(context))
155+
metrics_to_measure = _metrics_to_measure(options)
156+
157+
for step in range(options.steps):
158+
logging.info("Benchmark step %d", step)
159+
160+
with metrics.measure("resolve_cache", metrics_to_measure):
161+
resolved_path = self._client.resolve(self._xid, step)
162+
with metrics.measure("save_cache", metrics_to_measure):
163+
ocp.save_pytree(resolved_path, pytree)
164+
with metrics.measure("finalize_cache", metrics_to_measure):
165+
self._client.finalize(self._xid, step)
166+
with metrics.measure("restore_cache", metrics_to_measure):
167+
restored_pytree = ocp.load_pytree(resolved_path, pytree)
168+
self._clear_pytree(restored_pytree)
169+
170+
with metrics.measure("save", metrics_to_measure):
171+
ocp.save_pytree(context.path / str(step), pytree)
172+
with metrics.measure("restore", metrics_to_measure):
173+
restored_pytree = ocp.load_pytree(context.path / str(step), pytree)
174+
self._clear_pytree(restored_pytree)
175+
176+
return benchmarks_core.TestResult(metrics=metrics)
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Copyright 2025 The Orbax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import logging
16+
17+
from orbax.checkpoint._src.testing import multiprocess_test
18+
from orbax.checkpoint._src.testing.benchmarks import lustre_benchmark
19+
20+
21+
class LustreBenchmarkTest(multiprocess_test.MultiProcessTest):
22+
23+
def test_xid(self):
24+
xid = lustre_benchmark._get_xid()
25+
self.assertIsInstance(xid, int)
26+
logging.info('XID: %s', xid)
27+
28+
29+
if __name__ == '__main__':
30+
multiprocess_test.main()

checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ JAX on TPU.
160160
| `--jax-version` | `newest` | `newest`, `nightly`, or `x.y.z`. | **Debugging**. Use `nightly` to test bleeding-edge JAX features. |
161161
| `--device` | `tpu` | `tpu`, `gpu`, `cpu`. | **Multi-Hardware**. When testing on GPU or CP/Local validation. |
162162
| `--base-image` | `python:3.11...` | Base Docker Image. | **Advanced**. If you need custom drivers or non-standard OS libs. |
163+
| `--no-cache` | `N/A` | Disable Docker build cache for all layers. | **Debugging**. Force rebuild of all layers from scratch. |
163164

164165
---
165166
<!-- LINT.ThenChange(build_image.sh:build_image_flags) -->

checkpoint/orbax/checkpoint/_src/testing/benchmarks/xpk/build_image.sh

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,22 @@ BRANCH="main"
1212
JAX_VERSION="newest"
1313
DEVICE="tpu"
1414
BASE_IMAGE=""
15+
DOCKERFILE_PATH=""
16+
NO_CACHE_FLAG=""
1517

1618
function print_usage() {
1719
echo "Usage: $0 [OPTIONS]"
18-
echo "Options:"
20+
echo "Options:"+
1921
echo " --project PROJECT_ID GCP Project ID"
2022
echo " --pr PR_NUMBER GitHub PR number"
23+
echo " --image IMAGE_NAME Image name (default: orbax-benchmarks)"
2124
echo " --branch BRANCH GitHub branch (default: main)"
2225
echo " --jax-version VERSION JAX version: newest, nightly, or X.Y.Z (default: newest)"
2326
echo " --device DEVICE Device type: tpu, gpu, cpu (default: tpu)"
2427
echo " --base-image IMAGE Base Docker image (optional)"
28+
echo " --dockerfile FILE Dockerfile path (optional)"
2529
echo " --tag TAG Image tag"
30+
echo " --no-cache Disable Docker build cache"
2631
echo " --help Show this help"
2732
}
2833

@@ -32,11 +37,14 @@ while [[ $# -gt 0 ]]; do
3237
case $1 in
3338
--project) PROJECT_ID="$2"; shift 2 ;;
3439
--pr) PR_NUMBER="$2"; shift 2 ;;
40+
--image) IMAGE_NAME="$2"; shift 2 ;;
3541
--branch) BRANCH="$2"; shift 2 ;;
3642
--jax-version) JAX_VERSION="$2"; shift 2 ;;
3743
--device) DEVICE="$2"; shift 2 ;;
3844
--base-image) BASE_IMAGE="$2"; shift 2 ;;
45+
--dockerfile) DOCKERFILE_PATH="$2"; shift 2 ;;
3946
--tag) USER_TAG_FLAG="$2"; shift 2 ;;
47+
--no-cache) NO_CACHE_FLAG="--no-cache"; shift 1 ;;
4048
--help) print_usage; exit 0 ;;
4149
*) echo "Unknown argument: $1"; print_usage; exit 1 ;;
4250
esac
@@ -54,7 +62,9 @@ if [[ -z "$BASE_IMAGE" ]]; then
5462
fi
5563

5664
SCRIPT_DIR="$(dirname "$(realpath "$0")")"
57-
DOCKERFILE_PATH="${SCRIPT_DIR}/Dockerfile"
65+
if [[ -z "$DOCKERFILE_PATH" ]]; then
66+
DOCKERFILE_PATH="${SCRIPT_DIR}/Dockerfile"
67+
fi
5868

5969
if [[ ! -f "$DOCKERFILE_PATH" ]]; then
6070
# Fallback: check if we are running in the source dir
@@ -110,15 +120,23 @@ done
110120

111121
# Build with local Docker
112122
echo "Building with previously installed Docker..."
113-
docker build \
114-
--build-arg BASE_IMAGE="${BASE_IMAGE}" \
115-
--build-arg BRANCH="${BRANCH}" \
116-
--build-arg JAX_VERSION="${JAX_VERSION}" \
117-
--build-arg DEVICE="${DEVICE}" \
118-
--build-arg PR_NUMBER="${PR_NUMBER}" \
119-
"${build_tag_args[@]}" \
120-
-f "${DOCKERFILE_PATH}" \
121-
.
123+
declare -a build_args=()
124+
if [[ -n "${NO_CACHE_FLAG}" ]]; then
125+
build_args+=("${NO_CACHE_FLAG}")
126+
fi
127+
build_args+=(
128+
"--build-arg" "BASE_IMAGE=${BASE_IMAGE}"
129+
"--build-arg" "BRANCH=${BRANCH}"
130+
"--build-arg" "JAX_VERSION=${JAX_VERSION}"
131+
"--build-arg" "DEVICE=${DEVICE}"
132+
"--build-arg" "PR_NUMBER=${PR_NUMBER}"
133+
)
134+
build_args+=("${build_tag_args[@]}")
135+
build_args+=(
136+
"-f" "${DOCKERFILE_PATH}"
137+
"."
138+
)
139+
docker build "${build_args[@]}"
122140

123141
echo "Pushing image to registry..."
124142
for t in "${tags[@]}"; do

0 commit comments

Comments
 (0)