diff --git a/.github/workflows/actions-hygiene.yml b/.github/workflows/actions-hygiene.yml new file mode 100644 index 0000000000..e34a26af97 --- /dev/null +++ b/.github/workflows/actions-hygiene.yml @@ -0,0 +1,60 @@ +name: GitHub Actions hygiene + +on: + push: + branches: [main, next] + paths: + - ".github/workflows/**" + - ".github/actions/**" + - ".github/actionlint.yaml" + pull_request: + types: [opened, reopened, synchronize] + paths: + - ".github/workflows/**" + - ".github/actions/**" + - ".github/actionlint.yaml" + workflow_dispatch: + +permissions: {} + +jobs: + actionlint: + name: Run actionlint + runs-on: ubuntu-latest + permissions: + contents: read + steps: + - name: Checkout repository + uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # pin@v4 + with: + persist-credentials: false + + - name: Set up Go + uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # pin@v6 + with: + go-version: "1.25.x" + + - name: Install actionlint + run: go install github.com/rhysd/actionlint/cmd/actionlint@v1.7.12 + + - name: Run actionlint + run: | + "$(go env GOPATH)/bin/actionlint" -config-file .github/actionlint.yaml + + zizmor: + name: Run zizmor + runs-on: ubuntu-latest + permissions: + actions: read + contents: read + steps: + - name: Checkout repository + uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # pin@v4 + with: + persist-credentials: false + + - name: Run zizmor + uses: zizmorcore/zizmor-action@71321a20a9ded102f6e9ce5718a2fcec2c4f70d8 # pin@v0.5.2 + with: + advanced-security: false + version: 1.23.1 diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 97aa8b6135..8d67c8264a 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -56,3 +56,28 @@ jobs: rustup default ${{ matrix.toolchain }} rustup target add wasm32-wasip2 make build-target-miden + + fuzz-check: + name: Check fuzz crates + runs-on: ubuntu-latest + steps: + - uses: 
actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # pin@v4 + with: + persist-credentials: false + - name: Cleanup large tools for build space + uses: ./.github/actions/cleanup-runner + - name: Install Rust toolchain + run: | + rustup toolchain install --no-self-update + rustup toolchain install --no-self-update nightly + - uses: WarpBuilds/rust-cache@9d0cc3090d9c87de74ea67617b246e978735b1a1 # pin@v2 + with: + # Fuzz crates are separate workspaces, so cache them explicitly. + workspaces: | + . + miden-crypto-fuzz + miden-serde-utils/fuzz + - name: Install cargo-fuzz + run: cargo +nightly install cargo-fuzz --version 0.13.1 + - name: Check fuzz crates + run: make check-fuzz diff --git a/.github/workflows/fuzz.yml b/.github/workflows/fuzz.yml index eaf1a6f177..88097bc587 100644 --- a/.github/workflows/fuzz.yml +++ b/.github/workflows/fuzz.yml @@ -33,12 +33,11 @@ jobs: persist-credentials: false - name: Cleanup large tools for build space uses: ./.github/actions/cleanup-runner - - uses: dtolnay/rust-toolchain@5b842231ba77f5c045dba54ac5560fed2db780e2 # pin@nightly - with: - toolchain: nightly - - uses: Swatinem/rust-cache@42dc69e1aa15d09112580998cf2ef0119e2e91ae # pin@v2 + - name: Install Rust toolchain + run: rustup update --no-self-update nightly + - uses: Swatinem/rust-cache@c19371144df3bb44fab255c43d04cbc2ab54d1c4 # pin@v2.9.1 - name: Install cargo-fuzz - run: cargo +nightly install cargo-fuzz --locked + run: cargo +nightly install cargo-fuzz --version 0.13.1 - name: Run fuzz target (smoke test) working-directory: miden-serde-utils run: | @@ -59,12 +58,11 @@ jobs: persist-credentials: false - name: Cleanup large tools for build space uses: ./.github/actions/cleanup-runner - - uses: dtolnay/rust-toolchain@5b842231ba77f5c045dba54ac5560fed2db780e2 # pin@nightly - with: - toolchain: nightly - - uses: Swatinem/rust-cache@42dc69e1aa15d09112580998cf2ef0119e2e91ae # pin@v2 + - name: Install Rust toolchain + run: rustup update --no-self-update nightly + - uses: 
Swatinem/rust-cache@c19371144df3bb44fab255c43d04cbc2ab54d1c4 # pin@v2.9.1 - name: Install cargo-fuzz - run: cargo +nightly install cargo-fuzz --locked + run: cargo +nightly install cargo-fuzz --version 0.13.1 - name: Run fuzz target (smoke test) run: | # Build the fuzz target first diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 0ba7864f8d..2cac387fce 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -35,7 +35,7 @@ jobs: run: | rustup update --no-self-update nightly rustup +nightly component add rustfmt - - uses: Swatinem/rust-cache@42dc69e1aa15d09112580998cf2ef0119e2e91ae # pin@v2 + - uses: Swatinem/rust-cache@c19371144df3bb44fab255c43d04cbc2ab54d1c4 # pin@v2.9.1 with: save-if: ${{ github.event_name == 'push' && (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/next') }} - name: Fmt @@ -50,7 +50,7 @@ jobs: persist-credentials: false - name: Cleanup large tools for build space uses: ./.github/actions/cleanup-runner - # Added: LLVM/Clang for RocksDB/bindgen + # `make doc` uses --all-features, which enables RocksDB with bindgen. 
- name: Install LLVM/Clang uses: ./.github/actions/install-llvm with: @@ -59,7 +59,7 @@ jobs: run: | rustup update --no-self-update rustup component add clippy - - uses: Swatinem/rust-cache@42dc69e1aa15d09112580998cf2ef0119e2e91ae # pin@v2 + - uses: Swatinem/rust-cache@c19371144df3bb44fab255c43d04cbc2ab54d1c4 # pin@v2.9.1 with: save-if: ${{ github.event_name == 'push' && (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/next') }} - name: Clippy @@ -106,7 +106,7 @@ jobs: version: "17" - name: Rustup run: rustup update --no-self-update - - uses: Swatinem/rust-cache@42dc69e1aa15d09112580998cf2ef0119e2e91ae # pin@v2 + - uses: Swatinem/rust-cache@c19371144df3bb44fab255c43d04cbc2ab54d1c4 # pin@v2.9.1 with: save-if: ${{ github.event_name == 'push' && (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/next') }} - name: Build docs @@ -150,15 +150,6 @@ jobs: - name: Zeroize audit run: make zeroize-audit - fuzz-check: - name: check fuzz crate - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # pin@v4 - with: - persist-credentials: false - - run: make check-fuzz - check-features: name: check all feature combinations runs-on: ubuntu-latest @@ -178,7 +169,7 @@ jobs: - uses: taiki-e/install-action@055f5df8c3f65ea01cd41e9dc855becd88953486 # pin@v2.75.18 with: tool: cargo-hack - - uses: Swatinem/rust-cache@42dc69e1aa15d09112580998cf2ef0119e2e91ae # pin@v2 + - uses: Swatinem/rust-cache@c19371144df3bb44fab255c43d04cbc2ab54d1c4 # pin@v2.9.1 with: save-if: ${{ github.event_name == 'push' && (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/next') }} - name: Check all feature combinations diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index ac947767ed..7d563d83bc 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -90,7 +90,7 @@ jobs: persist-credentials: false - name: Cleanup large tools for build space uses: ./.github/actions/cleanup-runner - - uses: 
Swatinem/rust-cache@42dc69e1aa15d09112580998cf2ef0119e2e91ae # pin@v2 + - uses: Swatinem/rust-cache@c19371144df3bb44fab255c43d04cbc2ab54d1c4 # pin@v2.9.1 - name: Install LLVM/Clang uses: ./.github/actions/install-llvm with: @@ -101,26 +101,6 @@ jobs: rustup default ${{ matrix.toolchain }} make test-docs - doc-build: - name: doc-build - runs-on: ubuntu-latest - if: ${{ github.event_name == 'pull_request' && (github.base_ref == 'main' || github.base_ref == 'next') }} - steps: - - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # pin@v4 - with: - persist-credentials: false - - name: Cleanup large tools for build space - uses: ./.github/actions/cleanup-runner - - name: Install LLVM/Clang - uses: ./.github/actions/install-llvm - with: - version: "17" - - name: Build docs - run: | - rustup update --no-self-update nightly - rustup default nightly - make doc - test-p3-parallel: name: test Miden STARK crates parallel ${{ matrix.toolchain }} runs-on: ubuntu-latest diff --git a/CHANGELOG.md b/CHANGELOG.md index 1e019bf967..4632b2fc4f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,20 @@ +## 0.26.0 (TBD) + +- [BREAKING] Added reusable preprocessed trace setup artifacts for Lifted STARKs: AIRs can declare fixed preprocessed columns, provers build and reuse a `Preprocessed` commitment bundle, and verifier instances receive the trusted preprocessed commitment ([#1021](https://github.com/0xMiden/crypto/pull/1021)). +- [BREAKING] Extracted `BackendReader`, allowing `LargeSmtForest` to work with read-only storage backends ([#986](https://github.com/0xMiden/crypto/pull/986)). +- Optimized prover quotient evaluation by evaluating each AIR's quotient on its native coset (size `n_j · D_j`) and lifting per-AIR, instead of always on the global maximum coset; constraint division is fused into the constraint evaluation loop ([#991](https://github.com/0xMiden/crypto/pull/991)). 
+- [BREAKING] Replaced the per-AIR witness/aux-builder proving model (`AirInstance`, `AirWitness`, `AuxBuilder`, `prove_multi` / `verify_multi`) with a `MultiAir` trait that owns its AIRs (each builds its own aux trace via `LiftedAir::build_aux_trace`), plus validated `Statement` / `ProverStatement` structs carried by `ProverInstance` / `VerifierInstance`. `LiftedAir::reduced_aux_values` and `num_var_len_public_inputs` are replaced by `MultiAir::eval_external`, which returns the cross-AIR external assertions as a flat list of extension-field values that must equal zero, fed by an `aux_inputs` slice whose schema each `MultiAir` owns and validates ([#992](https://github.com/0xMiden/crypto/pull/992)). +- [BREAKING] Refactored `miden-lifted-stark::domain` around a uniform `Coset` trait shared by `TwoAdicSubgroup` and `TwoAdicCoset`, slimmed the `LiftedDomain` surface (drops dead getters, removes silently-dispatched `points`/`bit_reversed_points`/`vanishing_at` in favour of explicit `trace_subgroup()` / `lde_coset()` access), made `LiftedDomain` constructors fallible, moved selector logic onto `LiftedDomain`, and changed `log_blowup` to return `u8` ([#993](https://github.com/0xMiden/crypto/pull/993)). +- [BREAKING] Upgraded direct `rand` dependencies to 0.10, updating RNG trait bounds and removing direct `rand_hc` usage ([#995](https://github.com/0xMiden/crypto/pull/995)). +- [BREAKING] Reorganized `miden-lifted-stark` internals: consolidated `align`, `bitrev`, `horner`, and `packing` helpers under a new `util` module; removed the legacy `fri::*` re-export facade ([#1000](https://github.com/0xMiden/crypto/pull/1000)). +- Improved performance by fusing the per-group accumulator and deferring allocations ([#1008](https://github.com/0xMiden/crypto/pull/1008)). +- [BREAKING] Reduced `LargeSmt` cache depth from 24 to 16 levels ([#1011](https://github.com/0xMiden/crypto/pull/1011)). 
+- [BREAKING] Implemented a two-phase `commit_mutations()` / `apply_mutations()`-style API for `LargeSmtForest` ([#1018](https://github.com/0xMiden/crypto/pull/1018)). +- [BREAKING] Tightened the `miden-lifted-stark` public API surface: dropped the wide crate-root re-export list (callers now import from `miden_lifted_stark::air` and `miden_lifted_stark::{lmcs, pcs, proof, prover, verifier}` directly), demoted internal submodules to `pub(crate)`/`pub(super)`, and folded the `transcript` module into `proof` (`TranscriptChallenger` / `TranscriptData` / `TranscriptError` are re-exported there). Renamed the proof artifact types — `StarkProof` → `StarkProofData` (wire artifact) and `StarkTranscript` → `StarkProof` (parsed view, built via `StarkProof::from_data`) — and `*::from_verifier_channel` → `*::read_from_channel` on the PCS sub-proofs. Dropped the panicking domain constructors (`TwoAdicCoset::unshifted`, `LiftedDomain::{canonical, sub_domain}`) in favour of the fallible `try_*` variants ([#1020](https://github.com/0xMiden/crypto/pull/1020)). +- [BREAKING] Fixed RocksDB CLI safety, non-canonical serde input handling, and qualified `WordWrapper` derive paths ([#1022](https://github.com/0xMiden/crypto/pull/1022)). +- [BREAKING] Simplified `LargeSmtForest` backend API ([#1030](https://github.com/0xMiden/crypto/pull/1030)). +- [BREAKING] Made `LargeSmt` leaf/entry/inner node iterators fallible. + ## 0.25.1 (2026-05-21) - Fixed `miden-lifted-stark` builds when `p3-maybe-rayon/parallel` is enabled without `miden-lifted-stark/parallel` ([#1023](https://github.com/0xMiden/crypto/pull/1023)). 
diff --git a/Cargo.lock b/Cargo.lock index 64c98f4c13..fdf989c45d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -250,6 +250,17 @@ dependencies = [ "cpufeatures 0.2.17", ] +[[package]] +name = "chacha20" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f8d983286843e49675a4b7a2d174efe136dc93a18d69130dd18198a6c167601" +dependencies = [ + "cfg-if", + "cpufeatures 0.3.0", + "rand_core 0.10.1", +] + [[package]] name = "chacha20poly1305" version = "0.10.1" @@ -257,7 +268,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "10cd79432192d1c0f4e1a0fef9527696cc039165d729fb41b3f4f4f354c2dc35" dependencies = [ "aead", - "chacha20", + "chacha20 0.9.1", "cipher", "poly1305", "zeroize", @@ -667,9 +678,6 @@ version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "da0e4dd2a88388a1f4ccc7c9ce104604dab68d9f408dc34cd45823d5a9069095" dependencies = [ - "futures-core", - "futures-sink", - "nanorand", "spin 0.9.8", ] @@ -702,12 +710,6 @@ dependencies = [ "syn", ] -[[package]] -name = "futures-sink" -version = "0.3.32" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c39754e157331b013978ec91992bde1ac089843443c49cbc7f46150b0fad0893" - [[package]] name = "futures-task" version = "0.3.32" @@ -744,19 +746,6 @@ dependencies = [ "zeroize", ] -[[package]] -name = "getrandom" -version = "0.2.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" -dependencies = [ - "cfg-if", - "js-sys", - "libc", - "wasi", - "wasm-bindgen", -] - [[package]] name = "getrandom" version = "0.3.4" @@ -778,6 +767,7 @@ dependencies = [ "cfg-if", "libc", "r-efi 6.0.0", + "rand_core 0.10.1", "wasip2", "wasip3", ] @@ -954,9 +944,7 @@ dependencies = [ "cfg-if", "ecdsa", "elliptic-curve", - "once_cell", "sha2", - "signature", ] [[package]] @@ -1075,7 +1063,7 @@ checksum = 
"f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" [[package]] name = "miden-bench" -version = "0.25.1" +version = "0.26.0" dependencies = [ "clap", "miden-lifted-stark", @@ -1105,7 +1093,7 @@ dependencies = [ [[package]] name = "miden-crypto" -version = "0.25.1" +version = "0.26.0" dependencies = [ "assert_matches", "blake3", @@ -1138,10 +1126,8 @@ dependencies = [ "p3-symmetric", "p3-util", "proptest", - "rand 0.9.4", - "rand_chacha", - "rand_core 0.9.5", - "rand_hc", + "rand 0.10.1", + "rand_chacha 0.10.0", "rayon", "rocksdb", "seq-macro", @@ -1156,15 +1142,16 @@ dependencies = [ [[package]] name = "miden-crypto-derive" -version = "0.25.1" +version = "0.26.0" dependencies = [ + "miden-field", "quote", "syn", ] [[package]] name = "miden-field" -version = "0.25.1" +version = "0.26.0" dependencies = [ "miden-serde-utils", "num-bigint", @@ -1183,9 +1170,10 @@ dependencies = [ [[package]] name = "miden-lifted-air" -version = "0.25.1" +version = "0.26.0" dependencies = [ "p3-air", + "p3-challenger", "p3-field", "p3-matrix", "p3-util", @@ -1194,7 +1182,7 @@ dependencies = [ [[package]] name = "miden-lifted-stark" -version = "0.25.1" +version = "0.26.0" dependencies = [ "criterion", "miden-lifted-air", @@ -1226,7 +1214,7 @@ dependencies = [ [[package]] name = "miden-serde-utils" -version = "0.25.1" +version = "0.26.0" dependencies = [ "p3-field", "p3-goldilocks", @@ -1234,7 +1222,7 @@ dependencies = [ [[package]] name = "miden-stark-transcript" -version = "0.25.1" +version = "0.26.0" dependencies = [ "p3-challenger", "p3-field", @@ -1244,7 +1232,7 @@ dependencies = [ [[package]] name = "miden-stateful-hasher" -version = "0.25.1" +version = "0.26.0" dependencies = [ "p3-bn254", "p3-field", @@ -1259,15 +1247,6 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" -[[package]] -name = "nanorand" -version = "0.7.0" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a51313c5820b0b02bd422f4b44776fbf47961755c74ce64afc73bfad10226c3" -dependencies = [ - "getrandom 0.2.17", -] - [[package]] name = "nom" version = "7.1.3" @@ -1941,7 +1920,7 @@ dependencies = [ "bitflags", "num-traits", "rand 0.9.4", - "rand_chacha", + "rand_chacha 0.9.0", "rand_xorshift", "regex-syntax", "unarray", @@ -1974,7 +1953,7 @@ version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "44c5af06bb1b7d3216d91932aed5265164bf384dc89cd6ba05cf59a35f5f76ea" dependencies = [ - "rand_chacha", + "rand_chacha 0.9.0", "rand_core 0.9.5", ] @@ -1984,6 +1963,8 @@ version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d2e8e8bcc7961af1fdac401278c6a831614941f6164ee3bf4ce61b7edb162207" dependencies = [ + "chacha20 0.10.0", + "getrandom 0.4.2", "rand_core 0.10.1", ] @@ -1997,14 +1978,21 @@ dependencies = [ "rand_core 0.9.5", ] +[[package]] +name = "rand_chacha" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e6af7f3e25ded52c41df4e0b1af2d047e45896c2f3281792ed68a1c243daedb" +dependencies = [ + "ppv-lite86", + "rand_core 0.10.1", +] + [[package]] name = "rand_core" version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" -dependencies = [ - "getrandom 0.2.17", -] [[package]] name = "rand_core" @@ -2021,15 +2009,6 @@ version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "63b8176103e19a2643978565ca18b50549f6101881c443590420e4dc998a3c69" -[[package]] -name = "rand_hc" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b363d4f6370f88d62bf586c80405657bde0f0e1b8945d47d2ad59b906cb4f54" -dependencies = [ - "rand_core 0.6.4", -] - [[package]] name = "rand_xorshift" version = "0.4.0" @@ -2639,12 +2618,6 
@@ dependencies = [ "winapi-util", ] -[[package]] -name = "wasi" -version = "0.11.1+wasi-snapshot-preview1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" - [[package]] name = "wasip2" version = "1.0.3+wasi-0.2.9" diff --git a/Cargo.toml b/Cargo.toml index 884c294f5b..00a612b54a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,16 +21,16 @@ keywords = ["crypto", "hash", "merkle", "miden"] license = "MIT OR Apache-2.0" repository = "https://github.com/0xMiden/crypto" rust-version = "1.90" -version = "0.25.1" +version = "0.26.0" [workspace.dependencies] -miden-crypto-derive = { path = "miden-crypto-derive", version = "0.25" } -miden-field = { path = "miden-field", version = "0.25" } -miden-lifted-air = { default-features = false, path = "stark/miden-lifted-air", version = "0.25" } -miden-lifted-stark = { default-features = false, path = "stark/miden-lifted-stark", version = "0.25" } -miden-serde-utils = { path = "miden-serde-utils", version = "0.25" } -miden-stark-transcript = { default-features = false, path = "stark/miden-stark-transcript", version = "0.25" } -miden-stateful-hasher = { default-features = false, path = "stark/miden-stateful-hasher", version = "0.25" } +miden-crypto-derive = { path = "miden-crypto-derive", version = "0.26" } +miden-field = { path = "miden-field", version = "0.26" } +miden-lifted-air = { default-features = false, path = "stark/miden-lifted-air", version = "0.26" } +miden-lifted-stark = { default-features = false, path = "stark/miden-lifted-stark", version = "0.26" } +miden-serde-utils = { path = "miden-serde-utils", version = "0.26" } +miden-stark-transcript = { default-features = false, path = "stark/miden-stark-transcript", version = "0.26" } +miden-stateful-hasher = { default-features = false, path = "stark/miden-stateful-hasher", version = "0.26" } # Plonky3 p3-air = { default-features = false, version = "0.5" } @@ -61,23 +61,21 @@ 
p3-util = { default-features = false, version = "0.5" } assert_matches = { default-features = false, version = "1.5" } blake3 = { default-features = false, version = "1.8" } cc = "1.2" -chacha20poly1305 = "0.10" +chacha20poly1305 = { default-features = false, version = "0.10" } curve25519-dalek = { default-features = false, version = "4" } der = { default-features = false, version = "0.7" } -ed25519-dalek = "2" -flume = "0.11.1" +ed25519-dalek = { default-features = false, version = "2" } +flume = { default-features = false, version = "0.11.1" } hex = { default-features = false, version = "0.4" } hkdf = { default-features = false, version = "0.12" } itertools = "0.14" -k256 = "0.13" +k256 = { default-features = false, version = "0.13" } num = { default-features = false, version = "0.4" } num-complex = { default-features = false, version = "0.4" } once_cell = { default-features = false, version = "1.21" } proptest = { default-features = false, version = "1.7" } quote = "1.0" -rand_chacha = { default-features = false, version = "0.9" } -rand_core = { default-features = false, version = "0.9" } -rand_hc = "0.3" +rand_chacha = { default-features = false, version = "0.10" } rayon = "1.10" rocksdb = { default-features = false, version = "0.24" } rstest = "0.26" diff --git a/Makefile b/Makefile index 0cfdf1ef97..f642a6115f 100644 --- a/Makefile +++ b/Makefile @@ -8,6 +8,8 @@ help: ALL_FEATURES_EXCEPT_ROCKSDB="concurrent executable internal serde std" MIDEN_STARK_TEST_PACKAGES=-p miden-lifted-air -p miden-lifted-stark -p miden-stateful-hasher -p miden-stark-transcript +MIDEN_CRYPTO_FUZZ_TARGETS=word merkle merkle_store smt_serde partial_smt mmr crypto aead signatures +MIDEN_SERDE_UTILS_FUZZ_TARGETS=primitives collections string vint64 goldilocks budgeted WARNINGS=RUSTDOCFLAGS="-D warnings" # -- linting -------------------------------------------------------------------------------------- @@ -76,7 +78,7 @@ lint: clippy fix format toml typos-check shear cargo-deny ## Run all 
linting tas .PHONY: doc doc: ## Generate and check documentation for workspace crates only - rm -rf "${CARGO_TARGET_DIR:-target}/doc" + rm -rf "$${CARGO_TARGET_DIR:-target}/doc" RUSTDOCFLAGS="--enable-index-page -Zunstable-options -D warnings" cargo +nightly doc --all-features --keep-going --release --no-deps # --- testing ------------------------------------------------------------------------------------- @@ -103,7 +105,7 @@ test-p3-parallel: ## Run Miden STARK crate tests with the parallel feature enabl .PHONY: test-large-smt test-large-smt: ## Run large SMT unit tests and RocksDB integration tests - cargo nextest run --success-output immediate --profile large-smt --cargo-profile test-release --features rocksdb + cargo nextest run --success-output immediate --profile large-smt --cargo-profile test-release --features persistent-forest .PHONY: test test: test-default test-no-std test-docs test-large-smt ## Run all tests except concurrent SMT tests @@ -119,8 +121,15 @@ check-features: ## Check curated feature combinations across the integrated work ./scripts/check-features.sh .PHONY: check-fuzz -check-fuzz: ## Check miden-crypto-fuzz compilation - cd miden-crypto-fuzz && cargo check +check-fuzz: ## Check and link fuzz targets + cd miden-crypto-fuzz && cargo check --locked + cd miden-serde-utils/fuzz && cargo check --locked + for target in $(MIDEN_CRYPTO_FUZZ_TARGETS); do \ + cargo +nightly fuzz build --fuzz-dir miden-crypto-fuzz $$target; \ + done + for target in $(MIDEN_SERDE_UTILS_FUZZ_TARGETS); do \ + (cd miden-serde-utils && cargo +nightly fuzz build $$target); \ + done # --- building ------------------------------------------------------------------------------------ @@ -218,10 +227,10 @@ fuzz-signatures: ## Run fuzzing for DSA signature deserialization check-tools: ## Checks if development tools are installed @echo "Checking development tools..." 
@command -v typos >/dev/null 2>&1 && echo "[OK] typos is installed" || echo "[MISSING] typos is not installed (run: make install-tools)" - @command -v cargo nextest >/dev/null 2>&1 && echo "[OK] nextest is installed" || echo "[MISSING] nextest is not installed (run: make install-tools)" + @command -v cargo-nextest >/dev/null 2>&1 && echo "[OK] nextest is installed" || echo "[MISSING] nextest is not installed (run: make install-tools)" @command -v taplo >/dev/null 2>&1 && echo "[OK] taplo is installed" || echo "[MISSING] taplo is not installed (run: make install-tools)" @command -v cargo-shear >/dev/null 2>&1 && echo "[OK] cargo-shear is installed" || echo "[MISSING] cargo-shear is not installed (run: make install-tools)" - @command -v cargo deny >/dev/null 2>&1 && echo "[OK] cargo-deny is installed" || echo "[MISSING] cargo-deny is not installed (run: make install-tools)" + @command -v cargo-deny >/dev/null 2>&1 && echo "[OK] cargo-deny is installed" || echo "[MISSING] cargo-deny is not installed (run: make install-tools)" .PHONY: install-tools install-tools: ## Installs development tools required by the Makefile (typos, nextest, taplo, cargo-shear, cargo-deny) diff --git a/README.md b/README.md index 0b475c76c1..3ba97f74b6 100644 --- a/README.md +++ b/README.md @@ -131,7 +131,7 @@ Messages sealed as one type must be unsealed using the corresponding method, oth - AIR traits and builders for defining algebraic constraints. - A Lifted Merkle commitment scheme and FRI-based polynomial commitments. -- Prover and verifier for single and multi-trace STARKs (`prove_single`/`prove_multi`, `verify_single`/`verify_multi`). +- Prover and verifier for multi-trace STARKs (`prove` / `verify`). - Fiat-Shamir transcript and challenge generation. - A debug constraint checker for development. 
diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 0000000000..98934e5dbe --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,83 @@ +# Security Policy + +Miden Crypto is currently pre-1.0, has not been independently audited as a whole, +and should be evaluated carefully before production use. Security reports are +valuable, especially when they identify issues that could affect cryptographic +assumptions, proof soundness, verification correctness, serialization +canonicality, release integrity, or unsafe use of the crates in this workspace. + +## Supported Versions + +Security fixes are prioritized for the current active development branch and the +most recent release line. + +| Version | Supported | +| ------- | --------- | +| `next` | Yes | +| Latest published `0.x` release | Best effort | +| Older releases | No guarantee | + +Because the project is pre-1.0, APIs and internals may change between releases. +Users should plan to upgrade to the latest release after a security fix is +published. + +## Reporting a Vulnerability + +Please do not open a public issue, discussion, or pull request for an +undisclosed vulnerability. Report vulnerabilities through GitHub's private +vulnerability reporting flow: + +<https://github.com/0xMiden/crypto/security/advisories/new> + +Use a public GitHub issue only for ordinary bugs that do not create a security +risk. + +## What to Include + +Please set a high bar for security submissions. Reports are most actionable when +they include: + +- A clear description of the vulnerability and its security impact. +- The affected crate, component, version, commit, or branch. +- A minimal proof of concept or reproducible test case. +- A proposed fix patch, where possible. +- A regression test that fails before the fix and passes after it, where + possible. +- Any relevant environment details, configuration, feature flags, inputs, or + assumptions. +- Whether the issue has been disclosed anywhere else. 
+ +Incomplete reports may still be useful, but maintainers may ask for a proof of +concept, a patch, or a regression test before triage can be completed. + +## Scope + +Security-relevant reports include, but are not limited to: + +- Cryptographic assumption, domain-separation, transcript, or randomness issues. +- Bugs that allow invalid signatures, proofs, openings, ciphertexts, or package + artifacts to be accepted as valid. +- Vulnerabilities in hashing, authenticated encryption, key exchange, digital + signatures, Merkle structures, serialization, deserialization, or package + handling. +- Proof soundness, verifier acceptance, or constraint-system issues in the STARK + crates. +- Memory-safety issues, panics, timing leaks, or resource exhaustion with a + plausible security impact. +- Supply-chain, release, CI, or signing issues that could affect published + artifacts. + +General correctness bugs, documentation issues, feature requests, and +performance problems without a plausible security impact should be reported +through the normal public issue tracker. + +## Disclosure + +Maintainers will use GitHub security advisories to coordinate investigation, +fixes, credits, and public disclosure. Please give maintainers a reasonable +opportunity to investigate and release a fix before disclosing the issue +publicly. + +When testing, act in good faith: do not access data that is not yours, do not +disrupt services, and do not use the vulnerability beyond what is necessary to +demonstrate impact. diff --git a/deny.toml b/deny.toml index 5a7bd7db0e..7eafd2df75 100644 --- a/deny.toml +++ b/deny.toml @@ -44,12 +44,14 @@ skip = [ { name = "thiserror-impl" }, ] skip-tree = [ - # miden-crypto uses rand 0.9 while Plonky3 0.5 uses rand 0.10, - # causing duplicate rand / rand_core / getrandom trees + # Stable crypto crates still depend on rand_core 0.6 and chacha20 0.9. 
+ { name = "chacha20", version = "=0.9.1" }, { name = "cpufeatures", version = "0.2.17" }, - { name = "getrandom", version = "=0.2.17" }, - { name = "rand", version = "=0.9.4" }, { name = "rand_core", version = "=0.6.4" }, + + # Dev/test-only proptest still depends on rand 0.9. + { name = "rand", version = "=0.9.4" }, + { name = "rand_chacha", version = "=0.9.0" }, { name = "rand_core", version = "=0.9.5" }, ] wildcards = "allow" diff --git a/miden-bench/src/lifted.rs b/miden-bench/src/lifted.rs index 77915221af..ffbbd312ed 100644 --- a/miden-bench/src/lifted.rs +++ b/miden-bench/src/lifted.rs @@ -3,17 +3,15 @@ use std::fmt; use miden_lifted_stark::{ - AirInstance, AirWitness, StarkConfig, - air::{BaseAir, LiftedAir, LiftedAirBuilder}, - prove_multi, + ProverInstance, StarkConfig, VerifierInstance, + air::{BaseAir, LiftedAir, LiftedAirBuilder, MultiAir, ProverStatement, Statement}, testing::airs::{ - ZeroAuxBuilder, blake3::LiftedBlake3Air, keccak::LiftedKeccakAir, miden::DummyMidenAir, + blake3::LiftedBlake3Air, keccak::LiftedKeccakAir, miden::DummyMidenAir, poseidon2::LiftedPoseidon2Air, }, - verify_multi, }; use p3_field::Field; -use p3_matrix::dense::RowMajorMatrix; +use p3_matrix::{Matrix, dense::RowMajorMatrix}; use tracing::info_span; use crate::{ @@ -65,8 +63,18 @@ impl LiftedAir for LiftedBenchAir { } } - fn num_var_len_public_inputs(&self) -> usize { - 0 + fn build_aux_trace( + &self, + main: &RowMajorMatrix, + _air_inputs: &[Felt], + _aux_inputs: &[Felt], + _challenges: &[EF], + ) -> (RowMajorMatrix, Vec) { + // All-zero aux trace of the AIR's declared width. 
+ let aux_width = LiftedAir::::aux_width(self); + let num_aux_values = LiftedAir::::num_aux_values(self); + let aux = RowMajorMatrix::new(EF::zero_vec(main.height() * aux_width), aux_width); + (aux, EF::zero_vec(num_aux_values)) } fn eval>(&self, builder: &mut AB) { @@ -79,6 +87,22 @@ impl LiftedAir for LiftedBenchAir { } } +// ═══════════════════════════════════════════════════════════════════════════════ +// MultiAir: empty public inputs; each AIR builds its own all-zero aux trace. +// ═══════════════════════════════════════════════════════════════════════════════ + +struct BenchMultiAir { + airs: Vec, +} + +impl MultiAir for BenchMultiAir { + type Air = LiftedBenchAir; + + fn airs(&self) -> &[Self::Air] { + &self.airs + } +} + // ═══════════════════════════════════════════════════════════════════════════════ // Runner // ═══════════════════════════════════════════════════════════════════════════════ @@ -86,13 +110,13 @@ impl LiftedAir for LiftedBenchAir { pub(crate) fn run_lifted( config: &SC, specs: &[TraceSpec], - traces: &[RowMajorMatrix], + traces: Vec>, constants: &Option, cli: &Cli, ) -> RunResult where SC: StarkConfig, - miden_lifted_stark::StarkDigest: PartialEq + fmt::Debug, + miden_lifted_stark::proof::StarkDigest: PartialEq + fmt::Debug, { let airs: Vec = specs .iter() @@ -109,26 +133,14 @@ where }) .collect(); - let aux_builders: Vec = specs - .iter() - .map(|spec| match spec.air_type { - AirType::Miden => ZeroAuxBuilder { - num_aux_cols: spec.num_aux_cols, - num_aux_values: spec.num_aux_cols, - }, - _ => ZeroAuxBuilder::dummy(), - }) - .collect(); - - let instances: Vec<_> = airs - .iter() - .zip(traces) - .zip(&aux_builders) - .map(|((air, trace), aux)| (air, AirWitness::new(trace, &[], &[]), aux)) - .collect(); + let statement = + Statement::new(BenchMultiAir { airs }, Vec::new(), Vec::new()).expect("statement"); + let prover_statement = ProverStatement::new(statement, traces).expect("prover statement"); + let prover_instance = + 
ProverInstance::new(config, &prover_statement, None).expect("no preprocessed columns"); let output = info_span!("prove") - .in_scope(|| prove_multi(config, &instances, config.challenger()).expect("proving failed")); + .in_scope(|| prover_instance.prove(config.challenger()).expect("proving failed")); let result = RunResult { proof_size_bytes: output.proof.size_in_bytes(), @@ -138,21 +150,10 @@ where if !cli.no_verify { info_span!("verify").in_scope(|| { - let verifier_instances: Vec<_> = airs - .iter() - .map(|air| { - ( - air, - AirInstance { - public_values: &[], - var_len_public_inputs: &[], - }, - ) - }) - .collect(); - let digest = - verify_multi(config, &verifier_instances, &output.proof, config.challenger()) - .expect("verification failed"); + let digest = VerifierInstance::new(config, prover_statement.statement(), None) + .expect("no preprocessed columns") + .verify(&output.proof, config.challenger()) + .expect("verification failed"); assert_eq!(output.digest, digest); }); } diff --git a/miden-bench/src/main.rs b/miden-bench/src/main.rs index 94c0f55d6a..e4adb035de 100644 --- a/miden-bench/src/main.rs +++ b/miden-bench/src/main.rs @@ -36,7 +36,8 @@ use std::{fmt, time::Instant}; use clap::Parser; use miden_lifted_stark::{ - GenericStarkConfig, PcsParams, + GenericStarkConfig, + pcs::PcsParams, testing::{ airs::{ blake3::generate_blake3_trace, @@ -253,7 +254,7 @@ fn main() { Dft::default(), gl::test_challenger(), ); - lifted::run_lifted(&config, &specs, &traces, &poseidon2_constants, &cli) + lifted::run_lifted(&config, &specs, traces, &poseidon2_constants, &cli) }, HashFn::Keccak => { let config = GenericStarkConfig::new( @@ -262,7 +263,7 @@ fn main() { Dft::default(), keccak::test_challenger(), ); - lifted::run_lifted(&config, &specs, &traces, &poseidon2_constants, &cli) + lifted::run_lifted(&config, &specs, traces, &poseidon2_constants, &cli) }, HashFn::Blake3 => { let config = GenericStarkConfig::new( @@ -271,7 +272,7 @@ fn main() { Dft::default(), 
blake3::test_challenger(), ); - lifted::run_lifted(&config, &specs, &traces, &poseidon2_constants, &cli) + lifted::run_lifted(&config, &specs, traces, &poseidon2_constants, &cli) }, HashFn::Blake3_192 => { let config = GenericStarkConfig::new( @@ -280,7 +281,7 @@ fn main() { Dft::default(), blake3_192::test_challenger(), ); - lifted::run_lifted(&config, &specs, &traces, &poseidon2_constants, &cli) + lifted::run_lifted(&config, &specs, traces, &poseidon2_constants, &cli) }, }, Mode::Batch => match cli.hash { diff --git a/miden-crypto-derive/Cargo.toml b/miden-crypto-derive/Cargo.toml index b31cecd6f8..a09b85b827 100644 --- a/miden-crypto-derive/Cargo.toml +++ b/miden-crypto-derive/Cargo.toml @@ -20,5 +20,8 @@ test = false quote = { workspace = true } syn = { features = ["full"], workspace = true } +[dev-dependencies] +miden-field = { workspace = true } + [lints] workspace = true diff --git a/miden-crypto-derive/src/lib.rs b/miden-crypto-derive/src/lib.rs index 9c5b080ed0..ccae8dcb5e 100644 --- a/miden-crypto-derive/src/lib.rs +++ b/miden-crypto-derive/src/lib.rs @@ -1,6 +1,6 @@ use proc_macro::TokenStream; use quote::quote; -use syn::{Data, DeriveInput, Fields, Type, parse_macro_input}; +use syn::{Data, DeriveInput, Fields, PathArguments, Type, parse_macro_input}; // SILENT DEBUG MACRO // ================================================================================================ @@ -169,19 +169,24 @@ pub fn word_wrapper_derive(input: TokenStream) -> TokenStream { }, }; - // Verify that the field type is 'Word' (or a path ending in 'Word') - if let Type::Path(type_path) = field_type { - let last_segment = type_path.path.segments.last(); - if let Some(segment) = last_segment { - if segment.ident != "Word" { - return syn::Error::new_spanned( - field_type, - "WordWrapper can only be derived for types wrapping a 'Word' field", - ) - .to_compile_error() - .into(); - } - } else { + let word_type = if let Type::Path(type_path) = field_type { + let Some(segment) = 
type_path.path.segments.last() else { + return syn::Error::new_spanned( + field_type, + "WordWrapper can only be derived for types wrapping a 'Word' field", + ) + .to_compile_error() + .into(); + }; + if segment.ident != "Word" { + return syn::Error::new_spanned( + field_type, + "WordWrapper can only be derived for types wrapping a 'Word' field", + ) + .to_compile_error() + .into(); + } + if !matches!(segment.arguments, PathArguments::None) { return syn::Error::new_spanned( field_type, "WordWrapper can only be derived for types wrapping a 'Word' field", @@ -189,6 +194,8 @@ pub fn word_wrapper_derive(input: TokenStream) -> TokenStream { .to_compile_error() .into(); } + + field_type } else { return syn::Error::new_spanned( field_type, @@ -196,7 +203,7 @@ pub fn word_wrapper_derive(input: TokenStream) -> TokenStream { ) .to_compile_error() .into(); - } + }; let expanded = quote! { impl #impl_generics #name #ty_generics #where_clause { @@ -206,12 +213,12 @@ pub fn word_wrapper_derive(input: TokenStream) -> TokenStream { /// /// This requires the caller to uphold the guarantees/invariants of this type (if any). /// Check the type-level documentation for guarantees/invariants. - pub fn from_raw(word: Word) -> Self { + pub fn from_raw(word: #word_type) -> Self { Self(word) } /// Returns the elements representation of this value. - pub fn as_elements(&self) -> &[Felt] { + pub fn as_elements(&self) -> &[<#word_type as ::core::ops::Index>::Output] { self.0.as_elements() } @@ -226,7 +233,7 @@ pub fn word_wrapper_derive(input: TokenStream) -> TokenStream { } /// Returns the underlying word of this value. 
- pub fn as_word(&self) -> Word { + pub fn as_word(&self) -> #word_type { self.0 } } diff --git a/miden-crypto-derive/tests/word_wrapper.rs b/miden-crypto-derive/tests/word_wrapper.rs new file mode 100644 index 0000000000..76ec864039 --- /dev/null +++ b/miden-crypto-derive/tests/word_wrapper.rs @@ -0,0 +1,67 @@ +#![allow(unused_qualifications)] + +use miden_crypto_derive::WordWrapper; + +mod qualified { + mod field { + pub use miden_field::Word; + } + + #[derive(super::WordWrapper)] + pub struct QualifiedWord(miden_field::Word); + + #[derive(super::WordWrapper)] + pub struct ModuleQualifiedWord(miden_field::word::Word); + + #[derive(super::WordWrapper)] + pub struct ReexportedWord(field::Word); + + #[test] + fn derives_accessors_for_qualified_word_path() { + let word = miden_field::Word::default(); + let wrapper = QualifiedWord::from_raw(word); + + let elements: &[miden_field::Felt] = wrapper.as_elements(); + assert_eq!(elements, word.as_elements()); + assert_eq!(wrapper.as_word(), word); + } + + #[test] + fn derives_accessors_for_module_qualified_word_path() { + let word = miden_field::word::Word::default(); + let wrapper = ModuleQualifiedWord::from_raw(word); + + let elements: &[miden_field::Felt] = wrapper.as_elements(); + assert_eq!(elements, word.as_elements()); + assert_eq!(wrapper.as_word(), word); + } + + #[test] + fn derives_accessors_for_reexported_word_path() { + let word = field::Word::default(); + let wrapper = ReexportedWord::from_raw(word); + + let elements: &[miden_field::Felt] = wrapper.as_elements(); + assert_eq!(elements, word.as_elements()); + assert_eq!(wrapper.as_word(), word); + } +} + +mod unqualified { + use miden_field::{Felt, Word}; + + use super::WordWrapper; + + #[derive(WordWrapper)] + pub struct UnqualifiedWord(Word); + + #[test] + fn derives_accessors_for_unqualified_word_path() { + let word = Word::default(); + let wrapper = UnqualifiedWord::from_raw(word); + + let elements: &[Felt] = wrapper.as_elements(); + assert_eq!(elements, 
word.as_elements()); + assert_eq!(wrapper.as_word(), word); + } +} diff --git a/miden-crypto-fuzz/Cargo.lock b/miden-crypto-fuzz/Cargo.lock index 953501a2e1..1ddc7a4a54 100644 --- a/miden-crypto-fuzz/Cargo.lock +++ b/miden-crypto-fuzz/Cargo.lock @@ -12,6 +12,12 @@ dependencies = [ "generic-array", ] +[[package]] +name = "anyhow" +version = "1.0.102" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" + [[package]] name = "arbitrary" version = "1.4.2" @@ -48,6 +54,12 @@ version = "1.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2af50177e190e07a26ab74f8b1efbfe2ef87da2116221318cb1c2e82baf7de06" +[[package]] +name = "bitflags" +version = "2.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4512299f36f043ab09a583e57bceb5a5aab7a73db1805848e8fef3c9e8c78b3" + [[package]] name = "blake3" version = "1.8.5" @@ -71,12 +83,6 @@ dependencies = [ "generic-array", ] -[[package]] -name = "bumpalo" -version = "3.20.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d20789868f4b01b2f2caec9f5c4e0213b41e3e5702a50157d699ae31ced2fcb" - [[package]] name = "cc" version = "1.2.62" @@ -106,6 +112,17 @@ dependencies = [ "cpufeatures 0.2.17", ] +[[package]] +name = "chacha20" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f8d983286843e49675a4b7a2d174efe136dc93a18d69130dd18198a6c167601" +dependencies = [ + "cfg-if", + "cpufeatures 0.3.0", + "rand_core 0.10.1", +] + [[package]] name = "chacha20poly1305" version = "0.10.1" @@ -113,7 +130,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "10cd79432192d1c0f4e1a0fef9527696cc039165d729fb41b3f4f4f354c2dc35" dependencies = [ "aead", - "chacha20", + "chacha20 0.9.1", "cipher", "poly1305", "zeroize", @@ -333,6 +350,12 @@ dependencies = [ "zeroize", ] +[[package]] +name = 
"equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + [[package]] name = "ff" version = "0.13.1" @@ -361,41 +384,14 @@ version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "da0e4dd2a88388a1f4ccc7c9ce104604dab68d9f408dc34cd45823d5a9069095" dependencies = [ - "futures-core", - "futures-sink", - "nanorand", "spin 0.9.8", ] [[package]] -name = "futures-core" -version = "0.3.32" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d" - -[[package]] -name = "futures-sink" -version = "0.3.32" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c39754e157331b013978ec91992bde1ac089843443c49cbc7f46150b0fad0893" - -[[package]] -name = "futures-task" -version = "0.3.32" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" - -[[package]] -name = "futures-util" -version = "0.3.32" +name = "foldhash" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" -dependencies = [ - "futures-core", - "futures-task", - "pin-project-lite", - "slab", -] +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" [[package]] name = "generic-array" @@ -410,27 +406,28 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.17" +version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" dependencies = [ "cfg-if", - "js-sys", "libc", - "wasi", - "wasm-bindgen", + "r-efi 5.3.0", + "wasip2", ] [[package]] name = 
"getrandom" -version = "0.3.4" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" +checksum = "0de51e6874e94e7bf76d726fc5d13ba782deca734ff60d5bb2fb2607c7406555" dependencies = [ "cfg-if", "libc", - "r-efi", + "r-efi 6.0.0", + "rand_core 0.10.1", "wasip2", + "wasip3", ] [[package]] @@ -444,6 +441,27 @@ dependencies = [ "subtle", ] +[[package]] +name = "hashbrown" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" +dependencies = [ + "foldhash", +] + +[[package]] +name = "hashbrown" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed5909b6e89a2db4456e54cd5f673791d7eca6732202bbf2a9cc504fe2f9b84a" + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + [[package]] name = "hkdf" version = "0.12.4" @@ -462,6 +480,24 @@ dependencies = [ "digest", ] +[[package]] +name = "id-arena" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954" + +[[package]] +name = "indexmap" +version = "2.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d466e9454f08e4a911e14806c24e16fba1b4c121d1ea474396f396069cf949d9" +dependencies = [ + "equivalent", + "hashbrown 0.17.1", + "serde", + "serde_core", +] + [[package]] name = "inout" version = "0.1.4" @@ -480,6 +516,12 @@ dependencies = [ "either", ] +[[package]] +name = "itoa" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" + [[package]] name = "jobserver" version = "0.1.34" @@ 
-490,18 +532,6 @@ dependencies = [ "libc", ] -[[package]] -name = "js-sys" -version = "0.3.98" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67df7112613f8bfd9150013a0314e196f4800d3201ae742489d999db2f979f08" -dependencies = [ - "cfg-if", - "futures-util", - "once_cell", - "wasm-bindgen", -] - [[package]] name = "k256" version = "0.13.4" @@ -511,9 +541,7 @@ dependencies = [ "cfg-if", "ecdsa", "elliptic-curve", - "once_cell", "sha2", - "signature", ] [[package]] @@ -525,6 +553,12 @@ dependencies = [ "cpufeatures 0.2.17", ] +[[package]] +name = "leb128fmt" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2" + [[package]] name = "libc" version = "0.2.186" @@ -556,9 +590,21 @@ dependencies = [ "scopeguard", ] +[[package]] +name = "log" +version = "0.4.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" + +[[package]] +name = "memchr" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" + [[package]] name = "miden-crypto" -version = "0.25.1" +version = "0.26.0" dependencies = [ "blake3", "cc", @@ -585,10 +631,8 @@ dependencies = [ "p3-maybe-rayon", "p3-symmetric", "p3-util", - "rand 0.9.4", + "rand", "rand_chacha", - "rand_core 0.9.5", - "rand_hc", "rayon", "serde", "sha2", @@ -600,7 +644,7 @@ dependencies = [ [[package]] name = "miden-crypto-derive" -version = "0.25.1" +version = "0.26.0" dependencies = [ "quote", "syn", @@ -612,12 +656,12 @@ version = "0.0.0" dependencies = [ "libfuzzer-sys", "miden-crypto", - "rand 0.9.4", + "rand", ] [[package]] name = "miden-field" -version = "0.25.1" +version = "0.26.0" dependencies = [ "miden-serde-utils", "num-bigint", @@ -626,7 +670,7 @@ dependencies = [ "p3-goldilocks", "p3-util", 
"paste", - "rand 0.10.1", + "rand", "serde", "subtle", "thiserror", @@ -634,9 +678,10 @@ dependencies = [ [[package]] name = "miden-lifted-air" -version = "0.25.1" +version = "0.26.0" dependencies = [ "p3-air", + "p3-challenger", "p3-field", "p3-matrix", "p3-util", @@ -645,7 +690,7 @@ dependencies = [ [[package]] name = "miden-lifted-stark" -version = "0.25.1" +version = "0.26.0" dependencies = [ "miden-lifted-air", "miden-stark-transcript", @@ -658,7 +703,7 @@ dependencies = [ "p3-maybe-rayon", "p3-symmetric", "p3-util", - "rand 0.10.1", + "rand", "serde", "thiserror", "tracing", @@ -666,7 +711,7 @@ dependencies = [ [[package]] name = "miden-serde-utils" -version = "0.25.1" +version = "0.26.0" dependencies = [ "p3-field", "p3-goldilocks", @@ -674,7 +719,7 @@ dependencies = [ [[package]] name = "miden-stark-transcript" -version = "0.25.1" +version = "0.26.0" dependencies = [ "p3-challenger", "p3-field", @@ -684,21 +729,12 @@ dependencies = [ [[package]] name = "miden-stateful-hasher" -version = "0.25.1" +version = "0.26.0" dependencies = [ "p3-field", "p3-symmetric", ] -[[package]] -name = "nanorand" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a51313c5820b0b02bd422f4b44776fbf47961755c74ce64afc73bfad10226c3" -dependencies = [ - "getrandom 0.2.17", -] - [[package]] name = "num" version = "0.4.3" @@ -851,7 +887,7 @@ dependencies = [ "p3-maybe-rayon", "p3-util", "paste", - "rand 0.10.1", + "rand", "serde", "tracing", ] @@ -872,7 +908,7 @@ dependencies = [ "p3-symmetric", "p3-util", "paste", - "rand 0.10.1", + "rand", "serde", ] @@ -897,7 +933,7 @@ dependencies = [ "p3-field", "p3-maybe-rayon", "p3-util", - "rand 0.10.1", + "rand", "serde", "tracing", ] @@ -921,7 +957,7 @@ dependencies = [ "p3-field", "p3-symmetric", "p3-util", - "rand 0.10.1", + "rand", ] [[package]] @@ -942,7 +978,7 @@ dependencies = [ "p3-symmetric", "p3-util", "paste", - "rand 0.10.1", + "rand", "serde", "spin 0.10.0", "tracing", @@ -956,7 
+992,7 @@ checksum = "04e2a562fea210baae390a32f9ecf0dd8724ae3f4352d1c8e413077b6f00a162" dependencies = [ "p3-field", "p3-symmetric", - "rand 0.10.1", + "rand", ] [[package]] @@ -969,7 +1005,7 @@ dependencies = [ "p3-mds", "p3-symmetric", "p3-util", - "rand 0.10.1", + "rand", ] [[package]] @@ -1043,6 +1079,16 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "prettyplease" +version = "0.2.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" +dependencies = [ + "proc-macro2", + "syn", +] + [[package]] name = "proc-macro2" version = "1.0.106" @@ -1068,14 +1114,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" [[package]] -name = "rand" -version = "0.9.4" +name = "r-efi" +version = "6.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44c5af06bb1b7d3216d91932aed5265164bf384dc89cd6ba05cf59a35f5f76ea" -dependencies = [ - "rand_chacha", - "rand_core 0.9.5", -] +checksum = "f8dcc9c7d52a811697d2151c701e0d08956f92b0e24136cf4cf27b57a6a0d9bf" [[package]] name = "rand" @@ -1083,17 +1125,19 @@ version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d2e8e8bcc7961af1fdac401278c6a831614941f6164ee3bf4ce61b7edb162207" dependencies = [ + "chacha20 0.10.0", + "getrandom 0.4.2", "rand_core 0.10.1", ] [[package]] name = "rand_chacha" -version = "0.9.0" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +checksum = "3e6af7f3e25ded52c41df4e0b1af2d047e45896c2f3281792ed68a1c243daedb" dependencies = [ "ppv-lite86", - "rand_core 0.9.5", + "rand_core 0.10.1", ] [[package]] @@ -1101,18 +1145,6 @@ name = "rand_core" version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" -dependencies = [ - "getrandom 0.2.17", -] - -[[package]] -name = "rand_core" -version = "0.9.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76afc826de14238e6e8c374ddcc1fa19e374fd8dd986b0d2af0d02377261d83c" -dependencies = [ - "getrandom 0.3.4", -] [[package]] name = "rand_core" @@ -1120,15 +1152,6 @@ version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "63b8176103e19a2643978565ca18b50549f6101881c443590420e4dc998a3c69" -[[package]] -name = "rand_hc" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b363d4f6370f88d62bf586c80405657bde0f0e1b8945d47d2ad59b906cb4f54" -dependencies = [ - "rand_core 0.6.4", -] - [[package]] name = "rayon" version = "1.12.0" @@ -1168,12 +1191,6 @@ dependencies = [ "semver", ] -[[package]] -name = "rustversion" -version = "1.0.22" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" - [[package]] name = "scopeguard" version = "1.2.0" @@ -1230,6 +1247,19 @@ dependencies = [ "syn", ] +[[package]] +name = "serde_json" +version = "1.0.150" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8014e44b4736ed0538adeecded0fce2a272f22dc9578a7eb6b2d9993c74cfb9" +dependencies = [ + "itoa", + "memchr", + "serde", + "serde_core", + "zmij", +] + [[package]] name = "sha2" version = "0.10.9" @@ -1267,12 +1297,6 @@ dependencies = [ "rand_core 0.6.4", ] -[[package]] -name = "slab" -version = "0.4.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5" - [[package]] name = "spin" version = "0.9.8" @@ -1403,6 +1427,12 @@ version = "1.0.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + [[package]] name = "universal-hash" version = "0.5.1" @@ -1419,64 +1449,65 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" -[[package]] -name = "wasi" -version = "0.11.1+wasi-snapshot-preview1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" - [[package]] name = "wasip2" version = "1.0.3+wasi-0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "20064672db26d7cdc89c7798c48a0fdfac8213434a1186e5ef29fd560ae223d6" dependencies = [ - "wit-bindgen", + "wit-bindgen 0.57.1", ] [[package]] -name = "wasm-bindgen" -version = "0.2.121" +name = "wasip3" +version = "0.4.0+wasi-0.3.0-rc-2026-01-06" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49ace1d07c165b0864824eee619580c4689389afa9dc9ed3a4c75040d82e6790" +checksum = "5428f8bf88ea5ddc08faddef2ac4a67e390b88186c703ce6dbd955e1c145aca5" dependencies = [ - "cfg-if", - "once_cell", - "rustversion", - "wasm-bindgen-macro", - "wasm-bindgen-shared", + "wit-bindgen 0.51.0", ] [[package]] -name = "wasm-bindgen-macro" -version = "0.2.121" +name = "wasm-encoder" +version = "0.244.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e68e6f4afd367a562002c05637acb8578ff2dea1943df76afb9e83d177c8578" +checksum = "990065f2fe63003fe337b932cfb5e3b80e0b4d0f5ff650e6985b1048f62c8319" dependencies = [ - "quote", - "wasm-bindgen-macro-support", + "leb128fmt", + "wasmparser", ] [[package]] -name = "wasm-bindgen-macro-support" -version = "0.2.121" +name = "wasm-metadata" +version = "0.244.0" 
source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d95a9ec35c64b2a7cb35d3fead40c4238d0940c86d107136999567a4703259f2" +checksum = "bb0e353e6a2fbdc176932bbaab493762eb1255a7900fe0fea1a2f96c296cc909" dependencies = [ - "bumpalo", - "proc-macro2", - "quote", - "syn", - "wasm-bindgen-shared", + "anyhow", + "indexmap", + "wasm-encoder", + "wasmparser", ] [[package]] -name = "wasm-bindgen-shared" -version = "0.2.121" +name = "wasmparser" +version = "0.244.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4e0100b01e9f0d03189a92b96772a1fb998639d981193d7dbab487302513441" +checksum = "47b807c72e1bac69382b3a6fb3dbe8ea4c0ed87ff5629b8685ae6b9a611028fe" dependencies = [ - "unicode-ident", + "bitflags", + "hashbrown 0.15.5", + "indexmap", + "semver", +] + +[[package]] +name = "wit-bindgen" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" +dependencies = [ + "wit-bindgen-rust-macro", ] [[package]] @@ -1485,6 +1516,85 @@ version = "0.57.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1ebf944e87a7c253233ad6766e082e3cd714b5d03812acc24c318f549614536e" +[[package]] +name = "wit-bindgen-core" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea61de684c3ea68cb082b7a88508a8b27fcc8b797d738bfc99a82facf1d752dc" +dependencies = [ + "anyhow", + "heck", + "wit-parser", +] + +[[package]] +name = "wit-bindgen-rust" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7c566e0f4b284dd6561c786d9cb0142da491f46a9fbed79ea69cdad5db17f21" +dependencies = [ + "anyhow", + "heck", + "indexmap", + "prettyplease", + "syn", + "wasm-metadata", + "wit-bindgen-core", + "wit-component", +] + +[[package]] +name = "wit-bindgen-rust-macro" +version = "0.51.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c0f9bfd77e6a48eccf51359e3ae77140a7f50b1e2ebfe62422d8afdaffab17a" +dependencies = [ + "anyhow", + "prettyplease", + "proc-macro2", + "quote", + "syn", + "wit-bindgen-core", + "wit-bindgen-rust", +] + +[[package]] +name = "wit-component" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d66ea20e9553b30172b5e831994e35fbde2d165325bec84fc43dbf6f4eb9cb2" +dependencies = [ + "anyhow", + "bitflags", + "indexmap", + "log", + "serde", + "serde_derive", + "serde_json", + "wasm-encoder", + "wasm-metadata", + "wasmparser", + "wit-parser", +] + +[[package]] +name = "wit-parser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecc8ac4bc1dc3381b7f59c34f00b67e18f910c2c0f50015669dde7def656a736" +dependencies = [ + "anyhow", + "id-arena", + "indexmap", + "log", + "semver", + "serde", + "serde_derive", + "serde_json", + "unicode-xid", + "wasmparser", +] + [[package]] name = "x25519-dalek" version = "2.0.1" @@ -1520,3 +1630,9 @@ name = "zeroize" version = "1.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" + +[[package]] +name = "zmij" +version = "1.0.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" diff --git a/miden-crypto-fuzz/Cargo.toml b/miden-crypto-fuzz/Cargo.toml index 02a6d6b0c6..58103b93cf 100644 --- a/miden-crypto-fuzz/Cargo.toml +++ b/miden-crypto-fuzz/Cargo.toml @@ -10,7 +10,7 @@ cargo-fuzz = true [dependencies] libfuzzer-sys = "0.4" miden-crypto = { features = ["concurrent", "fuzzing"], path = "../miden-crypto" } -rand = { default-features = false, version = "0.9" } +rand = { default-features = false, version = "0.10" } # Existing SMT fuzz target (for sequential vs parallel consistency) [[bin]] diff --git 
a/miden-crypto-fuzz/fuzz_targets/smt.rs b/miden-crypto-fuzz/fuzz_targets/smt.rs index b62439c2cd..973093f1b3 100644 --- a/miden-crypto-fuzz/fuzz_targets/smt.rs +++ b/miden-crypto-fuzz/fuzz_targets/smt.rs @@ -2,7 +2,7 @@ use libfuzzer_sys::fuzz_target; use miden_crypto::{merkle::smt::Smt, Felt, Word, ONE}; -use rand::Rng; // Needed for randomizing the split percentage +use rand::RngExt; // Needed for randomizing the split percentage struct FuzzInput { entries: Vec<(Word, Word)>, diff --git a/miden-crypto/Cargo.toml b/miden-crypto/Cargo.toml index 6eddb6955f..a9201f2b25 100644 --- a/miden-crypto/Cargo.toml +++ b/miden-crypto/Cargo.toml @@ -88,7 +88,7 @@ internal = ["concurrent"] persistent-forest = ["rocksdb", "serde"] rocksdb = ["concurrent", "dep:rocksdb"] serde = ["dep:serde", "serde?/alloc"] -std = ["blake3/std", "dep:cc", "miden-serde-utils/std", "rand/std", "rand/thread_rng", "serde?/std"] +std = ["blake3/std", "dep:cc", "miden-serde-utils/std", "once_cell/std", "rand/std", "rand/thread_rng", "serde?/std"] testing = ["dep:proptest", "miden-field/testing"] [dependencies] @@ -100,22 +100,20 @@ miden-serde-utils = { workspace = true } # External dependencies blake3 = { workspace = true } -chacha20poly1305 = { features = ["alloc", "stream"], workspace = true } +chacha20poly1305 = { features = ["alloc", "rand_core", "stream"], workspace = true } clap = { features = ["derive"], optional = true, workspace = true } curve25519-dalek = { workspace = true } der = { workspace = true } -ed25519-dalek = { features = ["zeroize"], workspace = true } +ed25519-dalek = { features = ["alloc", "zeroize"], workspace = true } flume = { workspace = true } hkdf = { workspace = true } -k256 = { features = ["ecdh", "ecdsa", "pkcs8"], workspace = true } +k256 = { features = ["alloc", "ecdh", "ecdsa", "pkcs8"], workspace = true } num = { features = ["alloc", "libm"], workspace = true } num-complex = { workspace = true } once_cell = { features = ["alloc", "critical-section"], workspace = 
true } proptest = { features = ["alloc"], optional = true, workspace = true } -rand = { default-features = false, version = "0.9" } +rand = { workspace = true } rand_chacha = { workspace = true } -rand_core = { workspace = true } -rand_hc = { workspace = true } rayon = { optional = true, workspace = true } rocksdb = { features = ["bindgen-runtime", "lz4"], optional = true, workspace = true } serde = { features = ["derive"], optional = true, workspace = true } diff --git a/miden-crypto/benches/common/macros.rs b/miden-crypto/benches/common/macros.rs index 130eb97cc5..3464ca61ce 100644 --- a/miden-crypto/benches/common/macros.rs +++ b/miden-crypto/benches/common/macros.rs @@ -961,8 +961,8 @@ macro_rules! benchmark_aead_bytes { /// Benchmark AEAD operations on byte arrays fn $bytes_fn(c: &mut Criterion) { use miden_crypto::aead::$aead_module::{Nonce, SecretKey}; + use rand::SeedableRng; use rand_chacha::ChaCha20Rng; - use rand_core::SeedableRng; let group_name = format!("{} - Byte Arrays", $group_prefix); let mut group = c.benchmark_group(&group_name); @@ -1044,8 +1044,8 @@ macro_rules! benchmark_aead_field { /// Benchmark AEAD operations on field elements fn $felts_fn(c: &mut Criterion) { use miden_crypto::aead::$aead_module::{Nonce, SecretKey}; + use rand::SeedableRng; use rand_chacha::ChaCha20Rng; - use rand_core::SeedableRng; let group_name = format!("{} - Field Elements", $group_prefix); let mut group = c.benchmark_group(&group_name); diff --git a/miden-crypto/benches/large_smt.rs b/miden-crypto/benches/large_smt.rs index 397f2e497b..1199b98df2 100644 --- a/miden-crypto/benches/large_smt.rs +++ b/miden-crypto/benches/large_smt.rs @@ -164,6 +164,25 @@ benchmark_with_setup_data! { }, } +benchmark_with_setup_data! 
{ + large_smt_clone, + DEFAULT_MEASUREMENT_TIME, + DEFAULT_SAMPLE_SIZE, + "rocksdb_smt_clone", + || { + let entries = generate_smt_entries_sequential(10_000); + let temp_dir = tempfile::TempDir::new().unwrap(); + let storage = RocksDbStorage::open(RocksDbConfig::new(temp_dir.path())).unwrap(); + let smt = LargeSmt::with_entries(storage, entries).unwrap(); + (smt, temp_dir) + }, + |b: &mut criterion::Bencher, (smt, _temp_dir): &(LargeSmt, tempfile::TempDir)| { + // iter_batched drops the returned clone after the timed section, keeping the + // RocksDbStorage::drop flush out of the measurement. + b.iter_batched(|| (), |_| hint::black_box(smt.clone()), BatchSize::SmallInput) + }, +} + benchmark_with_setup_data! { large_smt_compute_mutations, DEFAULT_MEASUREMENT_TIME, @@ -199,6 +218,7 @@ benchmark_batch! { b.iter_batched( || { + let _ = std::fs::remove_dir_all(&bench_dir); std::fs::create_dir_all(&bench_dir).unwrap(); let storage = RocksDbStorage::open(RocksDbConfig::new(&bench_dir)).unwrap(); let smt = LargeSmt::with_entries(storage, base_entries.clone()).unwrap(); @@ -232,6 +252,7 @@ benchmark_batch! 
{ b.iter_batched( || { + let _ = std::fs::remove_dir_all(&bench_dir); std::fs::create_dir_all(&bench_dir).unwrap(); let storage = RocksDbStorage::open(RocksDbConfig::new(&bench_dir)).unwrap(); let smt = LargeSmt::with_entries(storage, base_entries.clone()).unwrap(); @@ -406,6 +427,7 @@ criterion_group!( large_smt_benchmark_group, large_smt_open, large_smt_open_in_large_tree, + large_smt_clone, large_smt_compute_mutations, large_smt_apply_mutations, large_smt_apply_mutations_with_reversion, diff --git a/miden-crypto/src/aead/aead_poseidon2/mod.rs b/miden-crypto/src/aead/aead_poseidon2/mod.rs index 70e443a2e1..8f1546f059 100644 --- a/miden-crypto/src/aead/aead_poseidon2/mod.rs +++ b/miden-crypto/src/aead/aead_poseidon2/mod.rs @@ -13,7 +13,7 @@ use core::ops::Range; use miden_crypto_derive::{SilentDebug, SilentDisplay}; use num::Integer; use rand::{ - Rng, + Rng, RngExt, distr::{Distribution, StandardUniform, Uniform}, }; use subtle::ConstantTimeEq; @@ -772,7 +772,7 @@ impl AeadScheme for AeadPoseidon2 { .map_err(|_| EncryptionError::FailedOperation) } - fn encrypt_bytes( + fn encrypt_bytes( key: &Self::Key, rng: &mut R, plaintext: &[u8], @@ -801,7 +801,7 @@ impl AeadScheme for AeadPoseidon2 { // OPTIMIZED FELT METHODS // -------------------------------------------------------------------------------------------- - fn encrypt_elements( + fn encrypt_elements( key: &Self::Key, rng: &mut R, plaintext: &[Felt], diff --git a/miden-crypto/src/aead/aead_poseidon2/test.rs b/miden-crypto/src/aead/aead_poseidon2/test.rs index e0b33dc723..919f90cbe8 100644 --- a/miden-crypto/src/aead/aead_poseidon2/test.rs +++ b/miden-crypto/src/aead/aead_poseidon2/test.rs @@ -2,7 +2,7 @@ use proptest::{ prelude::{any, prop}, prop_assert_eq, prop_assert_ne, prop_assume, proptest, }; -use rand::{RngCore, SeedableRng}; +use rand::{Rng, SeedableRng}; use rand_chacha::ChaCha20Rng; use super::*; diff --git a/miden-crypto/src/aead/mod.rs b/miden-crypto/src/aead/mod.rs index af9735f3ea..090e8ca587 
100644 --- a/miden-crypto/src/aead/mod.rs +++ b/miden-crypto/src/aead/mod.rs @@ -52,7 +52,7 @@ pub(crate) trait AeadScheme { // BYTE METHODS // ================================================================================================ - fn encrypt_bytes( + fn encrypt_bytes( key: &Self::Key, rng: &mut R, plaintext: &[u8], @@ -69,7 +69,7 @@ pub(crate) trait AeadScheme { // ================================================================================================ /// Encrypts field elements with associated data. Default implementation converts to bytes. - fn encrypt_elements( + fn encrypt_elements( key: &Self::Key, rng: &mut R, plaintext: &[Felt], diff --git a/miden-crypto/src/aead/xchacha/mod.rs b/miden-crypto/src/aead/xchacha/mod.rs index 1d3b58af33..318cf6d687 100644 --- a/miden-crypto/src/aead/xchacha/mod.rs +++ b/miden-crypto/src/aead/xchacha/mod.rs @@ -16,13 +16,14 @@ use chacha20poly1305::{ XChaCha20Poly1305, aead::{Aead, AeadCore, KeyInit}, }; -use rand::{CryptoRng, RngCore}; +use rand::CryptoRng; #[cfg(any(test, feature = "testing"))] use subtle::ConstantTimeEq; use crate::{ Felt, aead::{AeadScheme, DataType, EncryptionError}, + rand::compat::RandCore06, utils::{ ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable, bytes_to_elements_exact, elements_to_bytes, @@ -65,19 +66,10 @@ pub struct Nonce { impl Nonce { /// Creates a new random nonce using the provided random number generator - pub fn with_rng(rng: &mut R) -> Self { - // we use a seedable CSPRNG and seed it with `rng` - // this is a work around the fact that the version of the `rand` dependency in our crate - // is different than the one used in the `chacha20poly1305`. 
This solution will - // no longer be needed once `chacha20poly1305` gets a new release with a version of - // the `rand` dependency matching ours - use chacha20poly1305::aead::rand_core::SeedableRng; - let mut seed = [0_u8; 32]; - RngCore::fill_bytes(rng, &mut seed); - let rng = rand_hc::Hc128Rng::from_seed(seed); - + pub fn with_rng(rng: &mut R) -> Self { + let mut compat_rng = RandCore06::new(rng); Nonce { - inner: XChaCha20Poly1305::generate_nonce(rng), + inner: XChaCha20Poly1305::generate_nonce(&mut compat_rng), } } @@ -115,18 +107,9 @@ impl SecretKey { } /// Creates a new random secret key using the provided random number generator - pub fn with_rng(rng: &mut R) -> Self { - // we use a seedable CSPRNG and seed it with `rng` - // this is a work around the fact that the version of the `rand` dependency in our crate - // is different than the one used in the `chacha20poly1305`. This solution will - // no longer be needed once `chacha20poly1305` gets a new release with a version of - // the `rand` dependency matching ours - use chacha20poly1305::aead::rand_core::SeedableRng; - let mut seed = [0_u8; 32]; - RngCore::fill_bytes(rng, &mut seed); - let rng = rand_hc::Hc128Rng::from_seed(seed); - - let key = XChaCha20Poly1305::generate_key(rng); + pub fn with_rng(rng: &mut R) -> Self { + let mut compat_rng = RandCore06::new(rng); + let key = XChaCha20Poly1305::generate_key(&mut compat_rng); Self(key.into()) } @@ -348,7 +331,7 @@ impl AeadScheme for XChaCha { .map_err(|_| EncryptionError::FailedOperation) } - fn encrypt_bytes( + fn encrypt_bytes( key: &Self::Key, rng: &mut R, plaintext: &[u8], diff --git a/miden-crypto/src/aead/xchacha/test.rs b/miden-crypto/src/aead/xchacha/test.rs index 0e768c10e3..f1a50d2e91 100644 --- a/miden-crypto/src/aead/xchacha/test.rs +++ b/miden-crypto/src/aead/xchacha/test.rs @@ -2,7 +2,7 @@ use proptest::{ prelude::{any, prop}, prop_assert_eq, prop_assert_ne, proptest, }; -use rand::{Rng, SeedableRng}; +use rand::{RngExt, SeedableRng}; use 
rand_chacha::ChaCha20Rng; use super::*; diff --git a/miden-crypto/src/boxed_storage.rs b/miden-crypto/src/boxed_storage.rs index fc17bdb808..8464c912a5 100644 --- a/miden-crypto/src/boxed_storage.rs +++ b/miden-crypto/src/boxed_storage.rs @@ -55,17 +55,22 @@ impl SmtStorageReader for BoxedStorage { ) -> Result, StorageError> { self.0.get_inner_node(index) } - fn iter_leaves(&self) -> Result + '_>, StorageError> { + fn iter_leaves( + &self, + ) -> Result> + '_>, StorageError> + { self.0.iter_leaves() } fn iter_subtrees( &self, - ) -> Result + '_>, StorageError> - { + ) -> Result< + Box> + '_>, + StorageError, + > { self.0.iter_subtrees() } - fn get_depth24(&self) -> Result, StorageError> { - self.0.get_depth24() + fn get_top_subtree_roots(&self) -> Result, StorageError> { + self.0.get_top_subtree_roots() } } diff --git a/miden-crypto/src/dsa/ecdsa_k256_keccak/mod.rs b/miden-crypto/src/dsa/ecdsa_k256_keccak/mod.rs index e0ec2469e9..c67b92b465 100644 --- a/miden-crypto/src/dsa/ecdsa_k256_keccak/mod.rs +++ b/miden-crypto/src/dsa/ecdsa_k256_keccak/mod.rs @@ -10,12 +10,13 @@ use k256::{ pkcs8::DecodePublicKey, }; use miden_crypto_derive::{SilentDebug, SilentDisplay}; -use rand::{CryptoRng, RngCore}; +use rand::CryptoRng; use thiserror::Error; use crate::{ Felt, SequentialCommit, Word, ecdh::k256::{EphemeralPublicKey, SharedSecret}, + rand::compat::RandCore06, utils::{ ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable, bytes_to_packed_u32_elements, @@ -50,20 +51,9 @@ struct SecretKey { impl SecretKey { /// Generates a new secret key using the provided random number generator. - fn with_rng(rng: &mut R) -> Self { - // we use a seedable CSPRNG and seed it with `rng` - // this is a work around the fact that the version of the `rand` dependency in our crate - // is different than the one used in the `k256` one. 
This solution will no longer be needed - // once `k256` gets a new release with a version of the `rand` dependency matching ours - use k256::elliptic_curve::rand_core::SeedableRng; - let mut seed = [0_u8; 32]; - RngCore::fill_bytes(rng, &mut seed); - let mut rng = rand_hc::Hc128Rng::from_seed(seed); - - let signing_key = ecdsa::SigningKey::random(&mut rng); - - // Zeroize the seed to prevent leaking secret material - seed.zeroize(); + fn with_rng(rng: &mut R) -> Self { + let mut compat_rng = RandCore06::new(rng); + let signing_key = ecdsa::SigningKey::random(&mut compat_rng); Self { inner: signing_key } } @@ -137,7 +127,7 @@ impl SigningKey { } /// Generates a new signing key using the provided random number generator. - pub fn with_rng(rng: &mut R) -> Self { + pub fn with_rng(rng: &mut R) -> Self { Self(SecretKey::with_rng(rng)) } @@ -198,7 +188,7 @@ impl KeyExchangeKey { } /// Generates a new signing key using the provided random number generator. - pub fn with_rng(rng: &mut R) -> Self { + pub fn with_rng(rng: &mut R) -> Self { Self(SecretKey::with_rng(rng)) } diff --git a/miden-crypto/src/dsa/eddsa_25519_sha512/mod.rs b/miden-crypto/src/dsa/eddsa_25519_sha512/mod.rs index 30139ddb0e..ca99434c73 100644 --- a/miden-crypto/src/dsa/eddsa_25519_sha512/mod.rs +++ b/miden-crypto/src/dsa/eddsa_25519_sha512/mod.rs @@ -6,7 +6,7 @@ use alloc::{string::ToString, vec::Vec}; use der::{Decode, asn1::BitStringRef}; use ed25519_dalek::{Signer, Verifier}; use miden_crypto_derive::{SilentDebug, SilentDisplay}; -use rand::{CryptoRng, RngCore}; +use rand::CryptoRng; use thiserror::Error; use crate::{ @@ -42,9 +42,9 @@ struct SecretKey { impl SecretKey { /// Generates a new secret key using RNG. 
- fn with_rng(rng: &mut R) -> Self { + fn with_rng(rng: &mut R) -> Self { let mut seed = [0u8; SECRET_KEY_BYTES]; - RngCore::fill_bytes(rng, &mut seed); + rng.fill_bytes(&mut seed); let inner = ed25519_dalek::SigningKey::from_bytes(&seed); @@ -122,7 +122,7 @@ impl SigningKey { } /// Generates a new secret key using RNG. - pub fn with_rng(rng: &mut R) -> Self { + pub fn with_rng(rng: &mut R) -> Self { Self(SecretKey::with_rng(rng)) } @@ -178,7 +178,7 @@ impl KeyExchangeKey { } /// Generates a new secret key using RNG. - pub fn with_rng(rng: &mut R) -> Self { + pub fn with_rng(rng: &mut R) -> Self { Self(SecretKey::with_rng(rng)) } diff --git a/miden-crypto/src/dsa/falcon512_poseidon2/math/ffsampling.rs b/miden-crypto/src/dsa/falcon512_poseidon2/math/ffsampling.rs index 38ab228cd1..5eaf4284c5 100644 --- a/miden-crypto/src/dsa/falcon512_poseidon2/math/ffsampling.rs +++ b/miden-crypto/src/dsa/falcon512_poseidon2/math/ffsampling.rs @@ -199,7 +199,7 @@ pub fn ffsampling( #[cfg(test)] mod tests { use num_complex::Complex64; - use rand::{Rng, SeedableRng}; + use rand::{Rng, RngExt, SeedableRng}; use rand_chacha::ChaCha20Rng; use super::*; diff --git a/miden-crypto/src/dsa/falcon512_poseidon2/tests/mod.rs b/miden-crypto/src/dsa/falcon512_poseidon2/tests/mod.rs index aa902c6827..63db86caa9 100644 --- a/miden-crypto/src/dsa/falcon512_poseidon2/tests/mod.rs +++ b/miden-crypto/src/dsa/falcon512_poseidon2/tests/mod.rs @@ -3,7 +3,7 @@ use data::{ SYNC_DATA_FOR_TEST_VECTOR, }; use prng::Shake256Testing; -use rand::{RngCore, SeedableRng}; +use rand::{Rng, SeedableRng}; use rand_chacha::ChaCha20Rng; use super::{Serializable, math::Polynomial}; diff --git a/miden-crypto/src/dsa/falcon512_poseidon2/tests/prng.rs b/miden-crypto/src/dsa/falcon512_poseidon2/tests/prng.rs index 0202676bb0..3633b38a06 100644 --- a/miden-crypto/src/dsa/falcon512_poseidon2/tests/prng.rs +++ b/miden-crypto/src/dsa/falcon512_poseidon2/tests/prng.rs @@ -1,7 +1,9 @@ use alloc::vec::Vec; -use rand::{Rng, 
RngCore}; -use rand_core::impls; +use rand::{ + Rng, + rand_core::{Infallible, TryRng, utils}, +}; use sha3::{ Shake256, Shake256ReaderCore, digest::{ExtendableOutput, Update, XofReader, core_api::XofReaderCoreWrapper}, @@ -49,17 +51,20 @@ impl Shake256Testing { } } -impl RngCore for Shake256Testing { - fn next_u32(&mut self) -> u32 { - impls::next_u32_via_fill(self) +impl TryRng for Shake256Testing { + type Error = Infallible; + + fn try_next_u32(&mut self) -> Result { + utils::next_word_via_fill::(self) } - fn next_u64(&mut self) -> u64 { - impls::next_u64_via_u32(self) + fn try_next_u64(&mut self) -> Result { + utils::next_u64_via_u32(self) } - fn fill_bytes(&mut self, dest: &mut [u8]) { - self.fill_bytes(dest) + fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Self::Error> { + Shake256Testing::fill_bytes(self, dest); + Ok(()) } } @@ -176,18 +181,21 @@ impl ChaCha { } } -impl RngCore for ChaCha { - fn next_u32(&mut self) -> u32 { - impls::next_u32_via_fill(self) +impl TryRng for ChaCha { + type Error = Infallible; + + fn try_next_u32(&mut self) -> Result { + utils::next_word_via_fill::(self) } - fn next_u64(&mut self) -> u64 { - impls::next_u64_via_u32(self) + fn try_next_u64(&mut self) -> Result { + utils::next_u64_via_u32(self) } - fn fill_bytes(&mut self, dest: &mut [u8]) { + fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Self::Error> { let len = dest.len(); let buffer = self.random_bytes(len); - dest.iter_mut().enumerate().for_each(|(i, d)| *d = buffer[i]) + dest.iter_mut().enumerate().for_each(|(i, d)| *d = buffer[i]); + Ok(()) } } diff --git a/miden-crypto/src/ecdh/k256.rs b/miden-crypto/src/ecdh/k256.rs index 0771cd6a2a..d921db093d 100644 --- a/miden-crypto/src/ecdh/k256.rs +++ b/miden-crypto/src/ecdh/k256.rs @@ -16,14 +16,15 @@ use alloc::{string::ToString, vec::Vec}; use hkdf::{Hkdf, hmac::SimpleHmac}; use k256::{AffinePoint, elliptic_curve::sec1::ToEncodedPoint, sha2::Sha256}; -use rand::{CryptoRng, RngCore}; +use 
rand::CryptoRng; use crate::{ dsa::ecdsa_k256_keccak::{KeyExchangeKey, PUBLIC_KEY_BYTES, PublicKey}, ecdh::KeyAgreementScheme, + rand::compat::RandCore06, utils::{ ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable, - zeroize::{Zeroize, ZeroizeOnDrop, Zeroizing}, + zeroize::{Zeroize, ZeroizeOnDrop}, }, }; // SHARED SECRET @@ -109,17 +110,9 @@ impl EphemeralSecretKey { } /// Generates a new ephemeral secret key using the provided random number generator. - pub fn with_rng(rng: &mut R) -> Self { - // we use a seedable CSPRNG and seed it with `rng` - // this is a work around the fact that the version of the `rand` dependency in our crate - // is different than the one used in the `k256` one. This solution will no longer be needed - // once `k256` gets a new release with a version of the `rand` dependency matching ours - use k256::elliptic_curve::rand_core::SeedableRng; - let mut seed = Zeroizing::new([0_u8; 32]); - RngCore::fill_bytes(rng, &mut *seed); - let mut rng = rand_hc::Hc128Rng::from_seed(*seed); - - let sk_e = k256::ecdh::EphemeralSecret::random(&mut rng); + pub fn with_rng(rng: &mut R) -> Self { + let mut compat_rng = RandCore06::new(rng); + let sk_e = k256::ecdh::EphemeralSecret::random(&mut compat_rng); Self { inner: sk_e } } @@ -191,7 +184,7 @@ impl KeyAgreementScheme for K256 { type SharedSecret = SharedSecret; - fn generate_ephemeral_keypair( + fn generate_ephemeral_keypair( rng: &mut R, ) -> (Self::EphemeralSecretKey, Self::EphemeralPublicKey) { let sk = EphemeralSecretKey::with_rng(rng); diff --git a/miden-crypto/src/ecdh/mod.rs b/miden-crypto/src/ecdh/mod.rs index 8bcfac6a2f..6119f900d7 100644 --- a/miden-crypto/src/ecdh/mod.rs +++ b/miden-crypto/src/ecdh/mod.rs @@ -2,7 +2,7 @@ use alloc::vec::Vec; -use rand::{CryptoRng, RngCore}; +use rand::CryptoRng; use thiserror::Error; use crate::utils::{ @@ -26,7 +26,7 @@ pub(crate) trait KeyAgreementScheme { type SharedSecret: AsRef<[u8]> + Zeroize + ZeroizeOnDrop; /// Returns an 
ephemeral key pair generated from the provided RNG. - fn generate_ephemeral_keypair( + fn generate_ephemeral_keypair( rng: &mut R, ) -> (Self::EphemeralSecretKey, Self::EphemeralPublicKey); diff --git a/miden-crypto/src/ecdh/x25519.rs b/miden-crypto/src/ecdh/x25519.rs index f522ac1bde..df51f6d630 100644 --- a/miden-crypto/src/ecdh/x25519.rs +++ b/miden-crypto/src/ecdh/x25519.rs @@ -16,15 +16,16 @@ use alloc::vec::Vec; use hkdf::{Hkdf, hmac::SimpleHmac}; use k256::sha2::Sha256; -use rand::{CryptoRng, RngCore}; +use rand::CryptoRng; use subtle::ConstantTimeEq; use crate::{ dsa::eddsa_25519_sha512::{KeyExchangeKey, PublicKey}, ecdh::KeyAgreementScheme, + rand::compat::RandCore06, utils::{ ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable, - zeroize::{Zeroize, ZeroizeOnDrop, Zeroizing}, + zeroize::{Zeroize, ZeroizeOnDrop}, }, }; // SHARED SECRETE @@ -107,18 +108,9 @@ impl EphemeralSecretKey { } /// Generates a new random ephemeral secret key using the provided RNG. - pub fn with_rng(rng: &mut R) -> Self { - // we use a seedable CSPRNG and seed it with `rng` - // this is a work around the fact that the version of the `rand` dependency in our crate - // is different than the one used in the `x25519_dalek` one. 
This solution will no longer be - // needed once `x25519_dalek` gets a new release with a version of the `rand` - // dependency matching ours - use k256::elliptic_curve::rand_core::SeedableRng; - let mut seed = Zeroizing::new([0_u8; 32]); - RngCore::fill_bytes(rng, &mut *seed); - let rng = rand_hc::Hc128Rng::from_seed(*seed); - - let sk = x25519_dalek::EphemeralSecret::random_from_rng(rng); + pub fn with_rng(rng: &mut R) -> Self { + let mut compat_rng = RandCore06::new(rng); + let sk = x25519_dalek::EphemeralSecret::random_from_rng(&mut compat_rng); Self { inner: sk } } @@ -186,7 +178,7 @@ impl KeyAgreementScheme for X25519 { type SharedSecret = SharedSecret; - fn generate_ephemeral_keypair( + fn generate_ephemeral_keypair( rng: &mut R, ) -> (Self::EphemeralSecretKey, Self::EphemeralPublicKey) { let sk = EphemeralSecretKey::with_rng(rng); diff --git a/miden-crypto/src/ies/crypto_box.rs b/miden-crypto/src/ies/crypto_box.rs index 03477e8dcb..0b15083e5e 100644 --- a/miden-crypto/src/ies/crypto_box.rs +++ b/miden-crypto/src/ies/crypto_box.rs @@ -6,7 +6,7 @@ use alloc::vec::Vec; -use rand::{CryptoRng, RngCore}; +use rand::CryptoRng; use super::{IesError, IesScheme}; use crate::{ @@ -42,7 +42,7 @@ impl CryptoBox { // BYTE-SPECIFIC METHODS // -------------------------------------------------------------------------------------------- - pub fn seal_bytes_with_associated_data( + pub fn seal_bytes_with_associated_data( rng: &mut R, recipient_public_key: &K::PublicKey, scheme: IesScheme, @@ -103,7 +103,7 @@ impl CryptoBox { // ELEMENT-SPECIFIC METHODS // -------------------------------------------------------------------------------------------- - pub fn seal_elements_with_associated_data( + pub fn seal_elements_with_associated_data( rng: &mut R, recipient_public_key: &K::PublicKey, scheme: IesScheme, diff --git a/miden-crypto/src/ies/keys.rs b/miden-crypto/src/ies/keys.rs index 21b4afff36..b6dc0527b6 100644 --- a/miden-crypto/src/ies/keys.rs +++ 
b/miden-crypto/src/ies/keys.rs @@ -1,7 +1,7 @@ use alloc::vec::Vec; use core::fmt; -use rand::{CryptoRng, RngCore}; +use rand::CryptoRng; use super::{IesError, IesScheme, crypto_box::CryptoBox, message::SealedMessage}; use crate::{ @@ -38,7 +38,7 @@ macro_rules! impl_seal_bytes_with_associated_data { /// /// The returned message can be unsealed with the [UnsealingKey] associated with this /// sealing key. - pub fn seal_bytes_with_associated_data( + pub fn seal_bytes_with_associated_data( &self, rng: &mut R, plaintext: &[u8], @@ -75,7 +75,7 @@ macro_rules! impl_seal_elements_with_associated_data { /// /// The returned message can be unsealed with the [UnsealingKey] associated with this /// sealing key. - pub fn seal_elements_with_associated_data( + pub fn seal_elements_with_associated_data( &self, rng: &mut R, plaintext: &[Felt], @@ -221,7 +221,7 @@ impl SealingKey { /// /// The returned message can be unsealed with the [UnsealingKey] associated with this sealing /// key. - pub fn seal_bytes( + pub fn seal_bytes( &self, rng: &mut R, plaintext: &[u8], @@ -240,7 +240,7 @@ impl SealingKey { /// /// The returned message can be unsealed with the [UnsealingKey] associated with this sealing /// key. - pub fn seal_elements( + pub fn seal_elements( &self, rng: &mut R, plaintext: &[Felt], diff --git a/miden-crypto/src/ies/tests.rs b/miden-crypto/src/ies/tests.rs index b3a89ac895..ef9d9ac014 100644 --- a/miden-crypto/src/ies/tests.rs +++ b/miden-crypto/src/ies/tests.rs @@ -3,7 +3,7 @@ use alloc::vec::Vec; use proptest::prelude::*; -use rand::{RngCore, SeedableRng}; +use rand::{Rng, SeedableRng}; use rand_chacha::ChaCha20Rng; use crate::{ diff --git a/miden-crypto/src/lib.rs b/miden-crypto/src/lib.rs index cb10e6fb53..2d55f07985 100644 --- a/miden-crypto/src/lib.rs +++ b/miden-crypto/src/lib.rs @@ -46,15 +46,15 @@ pub mod stark { //! Lifted STARK proving system based on Plonky3. //! //! Sub-modules from `miden-lifted-stark`: - //! 
- [`proof`] — [`proof::StarkProof`], [`proof::StarkDigest`], [`proof::StarkOutput`], - //! [`proof::StarkTranscript`] + //! - [`proof`] — [`proof::StarkProofData`] (wire artifact), [`proof::StarkProof`] (structured + //! view), [`proof::StarkDigest`], [`proof::StarkOutput`], [`proof::TranscriptChallenger`], + //! [`proof::TranscriptData`] //! - [`air`] — AIR traits, builders, symbolic types (includes all of `p3-air`) - //! - [`fri`] — PCS parameters, DEEP + FRI types + //! - [`pcs`] — PCS parameters, DEEP + FRI sub-proofs //! - [`lmcs`] — Lifted Merkle commitment scheme - //! - [`transcript`] — Fiat-Shamir channels and transcript data //! - [`hasher`] — Stateful hasher primitives - //! - [`prover`] — `prove_single` / `prove_multi` - //! - [`verifier`] — `verify_single` / `verify_multi` + //! - [`prover`] — [`ProverInstance::prove`] + //! - [`verifier`] — [`VerifierInstance::verify`] //! - [`debug`] — Debug constraint checker for lifted AIRs //! //! Sub-modules from upstream Plonky3: @@ -64,11 +64,12 @@ pub mod stark { //! 
- [`symmetric`] — Symmetric cryptographic primitives // Top-level types from lifted-stark - pub use miden_lifted_stark::{GenericStarkConfig, StarkConfig}; - // Lifted-stark sub-modules (re-exported as-is) pub use miden_lifted_stark::{ - air, debug, fri, hasher, lmcs, proof, prover, transcript, verifier, + GenericStarkConfig, Preprocessed, PreprocessedValidationError, ProverInstance, StarkConfig, + VerifierInstance, }; + // Lifted-stark sub-modules (re-exported as-is) + pub use miden_lifted_stark::{air, debug, hasher, lmcs, pcs, proof, prover, verifier}; // Upstream Plonky3: challenger pub mod challenger { diff --git a/miden-crypto/src/main.rs b/miden-crypto/src/main.rs index 37dfa3e806..4b3f8bee1c 100644 --- a/miden-crypto/src/main.rs +++ b/miden-crypto/src/main.rs @@ -1,3 +1,5 @@ +#[cfg(any(test, feature = "rocksdb"))] +use std::path::Path; use std::{path::PathBuf, time::Instant}; use clap::{Parser, ValueEnum}; @@ -6,10 +8,10 @@ use miden_crypto::merkle::smt::{RocksDbConfig, RocksDbStorage}; use miden_crypto::{ EMPTY_WORD, Felt, ONE, Word, hash::poseidon2::Poseidon2, - merkle::smt::{LargeSmt, LargeSmtError, MemoryStorage}, + merkle::smt::{LargeSmt, LargeSmtError, MemoryStorage, StorageError}, rand::test_utils::rand_value, }; -use rand::{Rng, prelude::IteratorRandom, rng}; +use rand::{RngExt, prelude::IteratorRandom, rng}; #[cfg(feature = "executable")] mod boxed_storage; @@ -33,11 +35,14 @@ pub struct BenchmarkCmd { /// Open existing database and skip construction #[clap(short = 'o', long = "open", default_value = "false")] open: bool, + /// Delete an existing benchmark database path before creating a new one + #[clap(long = "reset", default_value = "false")] + reset: bool, /// Number of batch operations #[clap(short = 'b', long = "batches", default_value = "1")] batches: usize, /// Storage backend to use at runtime: memory or rocksdb - #[arg(short = 's', long = "storage", value_enum, default_value = "memory")] + #[arg(long = "storage", value_enum, default_value = 
"memory")] storage: StorageKind, } @@ -47,19 +52,21 @@ pub enum StorageKind { Rocksdb, } -fn main() { - benchmark_smt(); +fn main() -> Result<(), LargeSmtError> { + benchmark_smt()?; println!("Benchmark completed successfully"); + Ok(()) } /// Run a benchmark for [`Smt`]. -pub fn benchmark_smt() { +pub fn benchmark_smt() -> Result<(), LargeSmtError> { let args = BenchmarkCmd::parse(); let tree_size = args.size; let insertions = args.insertions; let updates = args.updates; let storage_path = args.storage_path; let batches = args.batches; + let reset = args.reset; println!( "Running benchmark with {} storage", @@ -78,16 +85,18 @@ pub fn benchmark_smt() { } let mut tree = if args.open { - open_existing(storage_path, args.storage).unwrap() + open_existing(storage_path, args.storage)? } else { - construction(entries.clone(), tree_size, storage_path, args.storage).unwrap() + construction(entries.clone(), tree_size, storage_path, args.storage, reset)? }; - insertion(&mut tree, insertions).unwrap(); + insertion(&mut tree, insertions)?; for _ in 0..batches { - batched_insertion(&mut tree, insertions).unwrap(); - batched_update(&mut tree, &entries, updates).unwrap(); + batched_insertion(&mut tree, insertions)?; + batched_update(&mut tree, &entries, updates)?; } - proof_generation(&mut tree).unwrap(); + proof_generation(&mut tree)?; + + Ok(()) } /// Runs the construction benchmark for [`Smt`], returning the constructed tree. 
@@ -96,10 +105,11 @@ pub fn construction( size: usize, database_path: Option, storage: StorageKind, + reset: bool, ) -> Result, LargeSmtError> { println!("Running a construction benchmark:"); let now = Instant::now(); - let storage = get_storage(database_path, false, storage); + let storage = get_storage(database_path, false, reset, storage)?; let tree = LargeSmt::with_entries(storage, entries)?; let elapsed = now.elapsed().as_secs_f32(); println!("Constructed an SMT with {size} key-value pairs in {elapsed:.1} seconds"); @@ -114,7 +124,7 @@ pub fn open_existing( ) -> Result, LargeSmtError> { println!("Opening an existing database:"); let now = Instant::now(); - let storage = get_storage(storage_path, true, storage); + let storage = get_storage(storage_path, true, false, storage)?; let tree = LargeSmt::load(storage)?; let elapsed = now.elapsed().as_secs_f32(); println!("Opened an existing database in {elapsed:.1} seconds"); @@ -204,16 +214,15 @@ pub fn batched_update( let size = tree.num_leaves(); let mut rng = rng(); - let new_pairs = - entries.iter().choose_multiple(&mut rng, updates).into_iter().map(|&(key, _)| { - let value = if rng.random_bool(REMOVAL_PROBABILITY) { - EMPTY_WORD - } else { - Word::new([ONE, ONE, ONE, Felt::new_unchecked(rng.random())]) - }; + let new_pairs = entries.iter().sample(&mut rng, updates).into_iter().map(|&(key, _)| { + let value = if rng.random_bool(REMOVAL_PROBABILITY) { + EMPTY_WORD + } else { + Word::new([ONE, ONE, ONE, Felt::new_unchecked(rng.random())]) + }; - (key, value) - }); + (key, value) + }); assert_eq!(new_pairs.len(), updates); @@ -260,8 +269,8 @@ pub fn proof_generation(tree: &mut LargeSmt) -> Result<(), LargeSmtErro let keys = tree .leaves()? 
.take(NUM_PROOFS) - .map(|(_, leaf)| leaf.entries()[0].0) - .collect::>(); + .map(|result| result.map(|(_, leaf)| leaf.entries()[0].0)) + .collect::, _>>()?; for key in keys { let now = Instant::now(); @@ -279,9 +288,14 @@ pub fn proof_generation(tree: &mut LargeSmt) -> Result<(), LargeSmtErro } #[allow(unused_variables)] -fn get_storage(database_path: Option, open: bool, kind: StorageKind) -> Storage { +fn get_storage( + database_path: Option, + open: bool, + reset: bool, + kind: StorageKind, +) -> Result { match kind { - StorageKind::Memory => Box::new(BoxedStorage(MemoryStorage::new())), + StorageKind::Memory => Ok(Box::new(BoxedStorage(MemoryStorage::new()))), StorageKind::Rocksdb => { #[cfg(feature = "rocksdb")] { @@ -289,24 +303,131 @@ fn get_storage(database_path: Option, open: bool, kind: StorageKind) -> .unwrap_or_else(|| std::env::temp_dir().join("miden_crypto_benchmark")); println!("Using database path: {}", path.display()); if !open { - // delete the folder if it exists as we are creating a new database - if path.exists() { - std::fs::remove_dir_all(path.clone()).unwrap(); - } - std::fs::create_dir_all(path.clone()) - .expect("Failed to create database directory"); + prepare_database_directory(&path, reset)?; } let db = RocksDbStorage::open( RocksDbConfig::new(path).with_cache_size(1 << 30).with_max_open_files(2048), - ) - .expect("Failed to open database"); - Box::new(BoxedStorage(db)) + )?; + Ok(Box::new(BoxedStorage(db))) } #[cfg(not(feature = "rocksdb"))] { - eprintln!("rocksdb feature not enabled; falling back to memory storage"); - Box::new(BoxedStorage(MemoryStorage::new())) + Err(StorageError::Unsupported( + "rocksdb storage was requested, but the rocksdb feature is not enabled".into(), + ) + .into()) } }, } } + +#[cfg(any(test, feature = "rocksdb"))] +fn prepare_database_directory(path: &Path, reset: bool) -> Result<(), LargeSmtError> { + if path.exists() { + if !reset { + return Err(StorageError::Unsupported(format!( + "database path already 
exists: {}; pass --reset to delete it before creating a new benchmark database", + path.display() + )) + .into()); + } + + std::fs::remove_dir_all(path).map_err(|err| { + storage_io_error(format!("failed to reset database path {}", path.display()), err) + })?; + } + + std::fs::create_dir_all(path).map_err(|err| { + storage_io_error(format!("failed to create database path {}", path.display()), err) + })?; + + Ok(()) +} + +#[cfg(any(test, feature = "rocksdb"))] +fn storage_io_error(message: String, err: std::io::Error) -> LargeSmtError { + StorageError::Backend(Box::new(std::io::Error::new(err.kind(), format!("{message}: {err}")))) + .into() +} + +#[cfg(test)] +mod tests { + use clap::{CommandFactory, Parser, ValueEnum}; + + use super::*; + + #[test] + fn storage_value_parser_accepts_memory() { + assert_eq!(StorageKind::from_str("memory", true).unwrap(), StorageKind::Memory); + } + + #[test] + fn clap_command_definition_is_valid() { + BenchmarkCmd::command().debug_assert(); + } + + #[test] + fn parses_size_short_and_memory_storage() { + let args = BenchmarkCmd::parse_from(["miden-crypto", "-s", "10", "--storage", "memory"]); + + assert_eq!(args.size, 10); + assert_eq!(args.storage, StorageKind::Memory); + } + + #[cfg(not(feature = "rocksdb"))] + #[test] + fn rejects_explicit_rocksdb_storage_without_feature() { + let err = get_storage(None, false, false, StorageKind::Rocksdb).unwrap_err(); + match err { + LargeSmtError::Storage(StorageError::Unsupported(msg)) => { + assert!(msg.contains("rocksdb feature")); + }, + other => panic!("expected unsupported rocksdb storage error, got {other:?}"), + } + } + + #[cfg(feature = "rocksdb")] + #[test] + fn storage_value_parser_accepts_rocksdb_with_feature() { + assert_eq!(StorageKind::from_str("rocksdb", true).unwrap(), StorageKind::Rocksdb); + } + + #[cfg(feature = "rocksdb")] + #[test] + fn parses_explicit_rocksdb_storage_with_feature() { + let args = + BenchmarkCmd::parse_from(["miden-crypto", "--size", "10", "--storage", 
"rocksdb"]); + + assert_eq!(args.size, 10); + assert_eq!(args.storage, StorageKind::Rocksdb); + } + + #[test] + fn existing_database_path_requires_reset_and_preserves_contents() { + let temp_dir = tempfile::tempdir().unwrap(); + let sentinel = temp_dir.path().join("sentinel.txt"); + std::fs::write(&sentinel, "keep").unwrap(); + + let err = prepare_database_directory(temp_dir.path(), false).unwrap_err(); + match err { + LargeSmtError::Storage(StorageError::Unsupported(msg)) => { + assert!(msg.contains("--reset")); + }, + other => panic!("expected reset-required error, got {other:?}"), + } + + assert_eq!(std::fs::read_to_string(&sentinel).unwrap(), "keep"); + } + + #[test] + fn reset_database_path_removes_existing_contents() { + let temp_dir = tempfile::tempdir().unwrap(); + let sentinel = temp_dir.path().join("sentinel.txt"); + std::fs::write(&sentinel, "delete").unwrap(); + + prepare_database_directory(temp_dir.path(), true).unwrap(); + + assert!(temp_dir.path().is_dir()); + assert!(!sentinel.exists()); + } +} diff --git a/miden-crypto/src/merkle/mmr/partial.rs b/miden-crypto/src/merkle/mmr/partial.rs index 244a222f88..2cb5ea52a8 100644 --- a/miden-crypto/src/merkle/mmr/partial.rs +++ b/miden-crypto/src/merkle/mmr/partial.rs @@ -766,7 +766,14 @@ impl Deserializable for PartialMmr { )); } let tracked: Vec = Vec::read_from(source)?; - let tracked_leaves: BTreeSet = tracked.into_iter().collect(); + let mut tracked_leaves = BTreeSet::new(); + for leaf_pos in tracked { + if !tracked_leaves.insert(leaf_pos) { + return Err(DeserializationError::InvalidValue( + "duplicate tracked leaf in partial mmr encoding".to_string(), + )); + } + } // Construct MmrPeaks to validate forest/peaks consistency let peaks = MmrPeaks::new(forest, peaks_vec).map_err(|e| { @@ -1066,6 +1073,27 @@ mod tests { assert_eq!(partial_mmr, decoded); } + #[test] + fn test_partial_mmr_deserialization_rejects_duplicate_tracked_leaves() { + let mmr = Mmr::try_from_iter(LEAVES.iter().copied()).unwrap(); + 
let mut partial_mmr = PartialMmr::from_peaks(mmr.peaks()); + let leaf_pos = 1usize; + let node = mmr.get(leaf_pos).unwrap(); + let proof = mmr.open(leaf_pos).unwrap(); + partial_mmr.track(leaf_pos, node, proof.path().merkle_path()).unwrap(); + + let mut bytes = Vec::new(); + partial_mmr.forest.num_leaves().write_into(&mut bytes); + partial_mmr.peaks.write_into(&mut bytes); + partial_mmr.nodes.write_into(&mut bytes); + bytes.write_u8(PartialMmr::TRACKED_LEAVES_MARKER); + vec![leaf_pos, leaf_pos].write_into(&mut bytes); + + let result = PartialMmr::read_from_bytes(&bytes); + + assert!(matches!(result, Err(DeserializationError::InvalidValue(_)))); + } + #[test] fn test_partial_mmr_deserialization_rejects_large_forest() { let mut bytes = (Forest::MAX_LEAVES + 1).to_bytes(); @@ -1434,6 +1462,7 @@ mod tests { bad_bytes.extend_from_slice(&1usize.to_bytes()); // BTreeMap length bad_bytes.extend_from_slice(&0usize.to_bytes()); // invalid index 0 bad_bytes.extend_from_slice(&int_to_node(0).to_bytes()); // value + bad_bytes.push(PartialMmr::TRACKED_LEAVES_MARKER); // tracked_leaves: empty vec bad_bytes.extend_from_slice(&0usize.to_bytes()); diff --git a/miden-crypto/src/merkle/smt/full/concurrent/tests.rs b/miden-crypto/src/merkle/smt/full/concurrent/tests.rs index b34fa85f0c..a925783b5f 100644 --- a/miden-crypto/src/merkle/smt/full/concurrent/tests.rs +++ b/miden-crypto/src/merkle/smt/full/concurrent/tests.rs @@ -7,7 +7,7 @@ use alloc::{ use assert_matches::assert_matches; use proptest::prelude::*; -use rand::{Rng, SeedableRng, prelude::IteratorRandom}; +use rand::{RngExt, SeedableRng, prelude::IteratorRandom}; use rand_chacha::ChaCha20Rng; use super::{ @@ -132,7 +132,7 @@ fn generate_updates(entries: Vec<(Word, Word)>, updates: usize) -> Vec<(Word, Wo ); let mut sorted_entries: Vec<(Word, Word)> = entries .into_iter() - .choose_multiple(&mut rng, updates) + .sample(&mut rng, updates) .into_iter() .map(|(key, _)| { let value = if rng.random_bool(REMOVAL_PROBABILITY) { diff 
--git a/miden-crypto/src/merkle/smt/large/batch_ops.rs b/miden-crypto/src/merkle/smt/large/batch_ops.rs index 05d1b4f5a5..d3daf4affc 100644 --- a/miden-crypto/src/merkle/smt/large/batch_ops.rs +++ b/miden-crypto/src/merkle/smt/large/batch_ops.rs @@ -5,8 +5,8 @@ use num::Integer; use p3_maybe_rayon::prelude::*; use super::{ - IN_MEMORY_DEPTH, LargeSmt, LargeSmtError, LoadedLeaves, MutatedLeaves, ROOT_MEMORY_INDEX, - SMT_DEPTH, SmtStorage, StorageUpdates, Subtree, SubtreeUpdate, + IN_MEMORY_DEPTH, LargeSmt, LargeSmtError, LargeSmtResult, LoadedLeaves, MutatedLeaves, + ROOT_MEMORY_INDEX, SMT_DEPTH, SmtStorage, StorageUpdates, Subtree, SubtreeUpdate, }; use crate::{ Word, @@ -103,7 +103,7 @@ impl LargeSmt { fn load_leaves_for_pairs( &self, sorted_kv_pairs: &[(Word, Word)], - ) -> Result { + ) -> LargeSmtResult { // Collect the unique leaf indices. If the input is truly sorted, then we can dedup // directly. let mut leaf_indices: Vec = sorted_kv_pairs @@ -359,7 +359,7 @@ impl LargeSmt { pub fn insert_batch( &mut self, kv_pairs: impl IntoIterator, - ) -> Result + ) -> LargeSmtResult where Self: Sized + Sync, { @@ -483,7 +483,7 @@ impl LargeSmt { fn prepare_mutations( &self, mutations: MutationSet, - ) -> Result { + ) -> LargeSmtResult { let MutationSet { old_root, node_mutations, @@ -562,10 +562,7 @@ impl LargeSmt { /// /// Note: This and [`insert_batch()`](Self::insert_batch) are the only two methods that /// persist changes to storage. 
- fn apply_prepared_mutations( - &mut self, - prepared: PreparedMutations, - ) -> Result<(), LargeSmtError> { + fn apply_prepared_mutations(&mut self, prepared: PreparedMutations) -> LargeSmtResult<()> { use NodeMutation::*; let PreparedMutations { @@ -719,7 +716,7 @@ impl LargeSmt { pub fn compute_mutations( &self, kv_pairs: impl IntoIterator, - ) -> Result, LargeSmtError> + ) -> LargeSmtResult> where Self: Sized + Sync, { @@ -818,7 +815,7 @@ impl LargeSmt { pub fn apply_mutations( &mut self, mutations: MutationSet, - ) -> Result<(), LargeSmtError> { + ) -> LargeSmtResult<()> { let prepared = self.prepare_mutations(mutations)?; self.apply_prepared_mutations(prepared)?; Ok(()) @@ -837,7 +834,7 @@ impl LargeSmt { pub fn apply_mutations_with_reversion( &mut self, mutations: MutationSet, - ) -> Result, LargeSmtError> + ) -> LargeSmtResult> where Self: Sized, { diff --git a/miden-crypto/src/merkle/smt/large/construction.rs b/miden-crypto/src/merkle/smt/large/construction.rs index c6c5b3e02a..680d8b535a 100644 --- a/miden-crypto/src/merkle/smt/large/construction.rs +++ b/miden-crypto/src/merkle/smt/large/construction.rs @@ -16,7 +16,7 @@ use crate::{ full::concurrent::{ PairComputations, SUBTREE_DEPTH, SubtreeLeaf, SubtreeLeavesIter, build_subtree, }, - large::to_memory_index, + large::{LargeSmtResult, to_memory_index}, }, }; @@ -40,7 +40,7 @@ impl LargeSmt { /// let storage = MemoryStorage::new(); /// let smt = LargeSmt::new(storage).expect("Failed to create SMT"); /// ``` - pub fn new(storage: S) -> Result { + pub fn new(storage: S) -> LargeSmtResult { if storage.has_leaves()? { return Err(LargeSmtError::StorageNotEmpty); } @@ -50,7 +50,7 @@ impl LargeSmt { /// Loads an existing [LargeSmt] from storage without validating the root. /// /// If the storage is empty, the SMT is initialized with the root of an empty tree. 
- /// Otherwise, the in-memory top of the tree is reconstructed from the cached depth-24 + /// Otherwise, the in-memory top of the tree is reconstructed from the cached in-memory-depth /// subtree hashes stored in the backend. /// /// **Note:** This method does not validate the reconstructed root. Use this only when @@ -69,13 +69,13 @@ impl LargeSmt { /// let smt = LargeSmt::load(storage).expect("Failed to load SMT"); /// # } /// ``` - pub fn load(storage: S) -> Result { + pub fn load(storage: S) -> LargeSmtResult { Self::initialize_from_storage(storage) } /// Loads an existing [LargeSmt] from storage and validates it against the expected root. /// - /// This method reconstructs the in-memory top of the tree from the cached depth-24 + /// This method reconstructs the in-memory top of the tree from the cached in-memory-depth /// subtree hashes, computes the root, and validates it against `expected_root`. /// /// Use this method when reloading a tree to ensure the storage contains the expected @@ -100,7 +100,7 @@ impl LargeSmt { /// .expect("Failed to load SMT with expected root"); /// # } /// ``` - pub fn load_with_root(storage: S, expected_root: Word) -> Result { + pub fn load_with_root(storage: S, expected_root: Word) -> LargeSmtResult { let smt = Self::load(storage)?; let actual_root = smt.root(); @@ -117,8 +117,8 @@ impl LargeSmt { /// Internal method that initializes the in-memory tree from storage. /// /// For empty storage, returns an empty tree. For non-empty storage, - /// rebuilds the in-memory top from cached depth-24 hashes. - fn initialize_from_storage(storage: S) -> Result { + /// rebuilds the in-memory top from cached in-memory-depth hashes. 
+ fn initialize_from_storage(storage: S) -> LargeSmtResult { // Initialize in-memory nodes let mut in_memory_nodes: Vec = vec![EMPTY_WORD; NUM_IN_MEMORY_NODES]; @@ -147,8 +147,8 @@ impl LargeSmt { let leaf_count = storage.leaf_count()?; let entry_count = storage.entry_count()?; - // Get the in-memory top of tree leaves from storage - let in_memory_tree_leaves = storage.get_depth24()?; + // Get the tops of subtrees from storage; these become the leaves of the in-memory tree + let in_memory_tree_leaves = storage.get_top_subtree_roots()?; // Convert in-memory top of tree leaves to SubtreeLeaf let mut leaf_subtrees: Vec = in_memory_tree_leaves @@ -212,7 +212,7 @@ impl LargeSmt { pub fn with_entries( storage: S, entries: impl IntoIterator, - ) -> Result { + ) -> LargeSmtResult { let entries: Vec<(Word, Word)> = entries.into_iter().collect(); if storage.has_leaves()? { diff --git a/miden-crypto/src/merkle/smt/large/error.rs b/miden-crypto/src/merkle/smt/large/error.rs index d279d809a9..ee81701b1c 100644 --- a/miden-crypto/src/merkle/smt/large/error.rs +++ b/miden-crypto/src/merkle/smt/large/error.rs @@ -34,6 +34,9 @@ pub enum LargeSmtError { StorageNotEmpty, } +/// The result type for use within the large SMT portion of the library. 
+pub type LargeSmtResult = Result; + #[cfg(test)] // Compile-time assertion that LargeSmtError implements the required traits const _: fn() = || { diff --git a/miden-crypto/src/merkle/smt/large/iter.rs b/miden-crypto/src/merkle/smt/large/iter.rs index 18a24accbc..8f5a6db1aa 100644 --- a/miden-crypto/src/merkle/smt/large/iter.rs +++ b/miden-crypto/src/merkle/smt/large/iter.rs @@ -1,10 +1,13 @@ use alloc::{boxed::Box, vec::Vec}; -use super::{IN_MEMORY_DEPTH, LargeSmt, SmtStorageReader, is_empty_parent}; +use super::{IN_MEMORY_DEPTH, LargeSmtResult, StorageResult, is_empty_parent}; use crate::{ Word, hash::poseidon2::Poseidon2, - merkle::{InnerNodeInfo, smt::large::subtree::Subtree}, + merkle::{ + InnerNodeInfo, + smt::{LargeSmt, SmtStorageReader, large::subtree::Subtree}, + }, }; // ITERATORS @@ -16,7 +19,7 @@ enum InnerNodeIteratorState<'a> { large_smt_in_memory_nodes: &'a [Word], }, Subtree { - subtree_iter: Box + 'a>, + subtree_iter: Box> + 'a>, current_subtree_node_iter: Option + 'a>>, }, Done, @@ -41,7 +44,7 @@ impl<'a, S: SmtStorageReader> LargeSmtInnerNodeIterator<'a, S> { } impl Iterator for LargeSmtInnerNodeIterator<'_, S> { - type Item = InnerNodeInfo; + type Item = LargeSmtResult; /// Returns the next inner node info in the tree. /// @@ -75,11 +78,11 @@ impl Iterator for LargeSmtInnerNodeIterator<'_, S> { let child_depth = depth + 1; if !is_empty_parent(left, right, child_depth) { - return Some(InnerNodeInfo { + return Some(Ok(InnerNodeInfo { value: Poseidon2::merge(&[left, right]), left, right, - }); + })); } } @@ -92,11 +95,11 @@ impl Iterator for LargeSmtInnerNodeIterator<'_, S> { }; continue; // Start processing subtrees immediately }, - Err(_e) => { - // Storage error occurred - we should propagate this properly - // For now, transition to Done state to avoid infinite loops + Err(e) => { + // Storage error occurred - we should propagate this error. + // We also transition to Done state to avoid infinite loops. 
self.state = InnerNodeIteratorState::Done; - return None; + return Some(LargeSmtResult::Err(e.into())); }, } }, @@ -107,12 +110,12 @@ impl Iterator for LargeSmtInnerNodeIterator<'_, S> { if let Some(node_iter) = current_subtree_node_iter && let Some(info) = node_iter.as_mut().next() { - return Some(info); + return Some(Ok(info)); } // Current subtree exhausted, move to next subtree match subtree_iter.next() { - Some(next_subtree) => { + Some(Ok(next_subtree)) => { // Collect is necessary here because iter_inner_node_info returns // an iterator borrowing from next_subtree, which would outlive // the subtree itself. We need to eagerly evaluate to owned data. @@ -121,6 +124,7 @@ impl Iterator for LargeSmtInnerNodeIterator<'_, S> { next_subtree.iter_inner_node_info().collect(); *current_subtree_node_iter = Some(Box::new(infos.into_iter())); }, + Some(Err(err)) => return Some(Err(err.into())), None => { self.state = InnerNodeIteratorState::Done; return None; // All subtrees processed diff --git a/miden-crypto/src/merkle/smt/large/mod.rs b/miden-crypto/src/merkle/smt/large/mod.rs index cda428cc11..d765d74155 100644 --- a/miden-crypto/src/merkle/smt/large/mod.rs +++ b/miden-crypto/src/merkle/smt/large/mod.rs @@ -1,10 +1,10 @@ //! Large-scale Sparse Merkle Tree backed by pluggable storage. //! -//! `LargeSmt` stores the top of the tree (depths 0–23) in memory and persists the lower -//! depths (24–64) in storage as fixed-size subtrees. This hybrid layout scales beyond RAM +//! `LargeSmt` stores the top of the tree (depths 0–`IN_MEMORY_DEPTH`-1) in memory and persists +//! the lower depths in storage as fixed-size subtrees. This hybrid layout scales beyond RAM //! while keeping common operations fast. With the `rocksdb` feature enabled, the lower //! subtrees and leaves are stored in RocksDB. On reload, the in-memory top is reconstructed -//! from cached depth-24 subtree roots. +//! from cached in-memory-depth subtree roots. //! //! 
Examples below require the `rocksdb` feature. //! @@ -249,7 +249,7 @@ use crate::{ }; mod error; -pub use error::LargeSmtError; +pub use error::{LargeSmtError, LargeSmtResult}; #[cfg(test)] mod property_tests; @@ -262,7 +262,7 @@ pub use subtree::{Subtree, SubtreeError}; mod storage; pub use storage::{ MemoryStorage, MemoryStorageSnapshot, SmtStorage, SmtStorageReader, StorageError, - StorageUpdateParts, StorageUpdates, SubtreeUpdate, + StorageResult, StorageUpdateParts, StorageUpdates, SubtreeUpdate, }; #[cfg(feature = "rocksdb")] pub use storage::{RocksDbConfig, RocksDbSnapshotStorage, RocksDbStorage}; @@ -277,17 +277,17 @@ mod smt_trait; // CONSTANTS // ================================================================================================ -/// Number of levels of the tree that are stored in memory -const IN_MEMORY_DEPTH: u8 = 24; +/// Number of levels of the tree that are stored in memory. +pub(super) const IN_MEMORY_DEPTH: u8 = 16; -/// Number of nodes that are stored in memory (including the unused index 0) +/// Number of nodes that are stored in memory (including the unused index 0). const NUM_IN_MEMORY_NODES: usize = 1 << (IN_MEMORY_DEPTH + 1); /// Index of the root node inside `in_memory_nodes`. pub(super) const ROOT_MEMORY_INDEX: usize = 1; -/// Number of subtree levels below in-memory depth (24-64 in steps of 8) -const NUM_SUBTREE_LEVELS: usize = 5; +/// Number of subtree levels below in-memory depth (16-64 in steps of 8). +const NUM_SUBTREE_LEVELS: usize = 6; /// How many subtrees we buffer before flushing them to storage **during the /// SMT construction phase**. @@ -320,8 +320,8 @@ type MutatedLeaves = (MutatedSubtreeLeaves, Map, Map, /// /// Unlike the regular `Smt`, this implementation is designed for very large trees by using external /// storage (such as RocksDB) for the bulk of the tree data, while keeping only the upper levels (up -/// to depth 24) in memory. 
This hybrid approach allows the tree to scale beyond memory limitations -/// while maintaining good performance for common operations. +/// to `IN_MEMORY_DEPTH`) in memory. This hybrid approach allows the tree to scale beyond memory +/// limitations while maintaining good performance for common operations. /// /// All leaves sit at depth 64. The most significant element of the key is used to identify the leaf /// to which the key maps. @@ -336,7 +336,7 @@ type MutatedLeaves = (MutatedSubtreeLeaves, Map, Map, /// /// `LargeSmt` implements [`Clone`] when its storage is cloneable. The in-memory top is shared and /// detaches on mutation. -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct LargeSmt { storage: S, /// Shared flat array representation of in-memory nodes. @@ -351,17 +351,6 @@ pub struct LargeSmt { entry_count: usize, } -impl Clone for LargeSmt { - fn clone(&self) -> Self { - Self { - storage: self.storage.clone(), - in_memory_nodes: self.in_memory_nodes.clone(), - leaf_count: self.leaf_count, - entry_count: self.entry_count, - } - } -} - impl LargeSmt { // CONSTANTS // -------------------------------------------------------------------------------------------- @@ -372,7 +361,7 @@ impl LargeSmt { pub const EMPTY_ROOT: Word = *EmptySubtreeRoots::entry(SMT_DEPTH, 0); /// Subtree depths for the subtrees stored in storage. - pub const SUBTREE_DEPTHS: [u8; 5] = [56, 48, 40, 32, 24]; + pub const SUBTREE_DEPTHS: [u8; 6] = [56, 48, 40, 32, 24, 16]; // PUBLIC ACCESSORS // -------------------------------------------------------------------------------------------- @@ -430,40 +419,61 @@ impl LargeSmt { // -------------------------------------------------------------------------------------------- /// Returns an iterator over the leaves of this [`LargeSmt`]. - /// Note: This iterator returns owned SmtLeaf values. 
+ /// + /// The returned iterator is fallible: each item is a [`LargeSmtResult`] so storage errors + /// encountered while advancing the iterator are surfaced instead of being skipped. + /// + /// Note: This iterator returns owned [`SmtLeaf`] values. /// /// # Errors /// Returns an error if the storage backend fails to create the iterator. pub fn leaves( &self, - ) -> Result, SmtLeaf)>, LargeSmtError> { + ) -> LargeSmtResult, SmtLeaf)>> + '_> + { let iter = self.storage.iter_leaves()?; - Ok(iter.map(|(idx, leaf)| (LeafIndex::new_max_depth(idx), leaf))) + Ok(iter.map(|result| { + result + .map(|(idx, leaf)| (LeafIndex::new_max_depth(idx), leaf)) + .map_err(Into::into) + })) } /// Returns an iterator over the key-value pairs of this [`LargeSmt`]. - /// Note: This iterator returns owned (Word, Word) tuples. + /// + /// The returned iterator is fallible: each item is a [`LargeSmtResult`] so storage errors + /// from the underlying leaf iterator are propagated while flattening leaf entries. + /// + /// Note: This iterator returns owned `(Word, Word)` tuples. /// /// # Errors /// Returns an error if the storage backend fails to create the iterator. - pub fn entries(&self) -> Result, LargeSmtError> { + pub fn entries( + &self, + ) -> LargeSmtResult> + '_> { let leaves_iter = self.leaves()?; - Ok(leaves_iter.flat_map(|(_, leaf)| { - // Collect the (Word, Word) tuples into an owned Vec - // This ensures they outlive the 'leaf' from which they are derived. - let owned_entries: Vec<(Word, Word)> = leaf.entries().to_vec(); - // Return an iterator over this owned Vec + Ok(leaves_iter.flat_map(|result| { + let mut owned_entries = Vec::new(); + match result { + Ok((_, leaf)) => { + owned_entries.extend(leaf.entries().iter().copied().map(Ok)); + }, + Err(err) => owned_entries.push(Err(err)), + } owned_entries.into_iter() })) } /// Returns an iterator over the inner nodes of this [`LargeSmt`]. 
/// + /// The returned iterator is fallible: each item is a [`LargeSmtResult`] so storage errors + /// from the underlying subtree iterator are surfaced instead of silently ending iteration. + /// /// # Errors /// Returns an error if the storage backend fails during iteration setup. - pub fn inner_nodes(&self) -> Result + '_, LargeSmtError> { - // Pre-validate that storage is accessible - let _ = self.storage.iter_subtrees()?; + pub fn inner_nodes( + &self, + ) -> LargeSmtResult> + '_> { Ok(LargeSmtInnerNodeIterator::new(self)) } @@ -478,6 +488,7 @@ impl LargeSmt { >::get_inner_node(self, index) } + // Triggers copy-on-write: clones the shared node array only if other references exist. pub(crate) fn in_memory_nodes_mut(&mut self) -> &mut [Word] { Arc::make_mut(&mut self.in_memory_nodes) } @@ -518,7 +529,7 @@ impl LargeSmt { /// The new tree shares the same root, leaf count, and entry count as `self`, and its storage /// is a point-in-time snapshot produced by [`SmtStorage::reader`]. The returned tree's storage /// type is `S::Reader: SmtStorageReader`, so it cannot be used for mutations. - pub fn reader(&self) -> Result, LargeSmtError> { + pub fn reader(&self) -> LargeSmtResult> { Ok(LargeSmt { storage: self.storage.reader()?, in_memory_nodes: self.in_memory_nodes.clone(), @@ -526,11 +537,6 @@ impl LargeSmt { entry_count: self.entry_count, }) } -} - -impl LargeSmt { - // STATE MUTATORS - // -------------------------------------------------------------------------------------------- /// Inserts a value at the specified key, returning the previous value associated with that key. 
/// Recall that by definition, any key that hasn't been updated is associated with diff --git a/miden-crypto/src/merkle/smt/large/storage/error.rs b/miden-crypto/src/merkle/smt/large/storage/error.rs index efbd4d1ce2..14cfd813bd 100644 --- a/miden-crypto/src/merkle/smt/large/storage/error.rs +++ b/miden-crypto/src/merkle/smt/large/storage/error.rs @@ -40,3 +40,6 @@ pub enum StorageError { #[error("failed to decode value bytes")] Value(#[from] crate::utils::DeserializationError), } + +/// The result type for use with backends. +pub type StorageResult = Result; diff --git a/miden-crypto/src/merkle/smt/large/storage/memory.rs b/miden-crypto/src/merkle/smt/large/storage/memory.rs index 4cfbb922f2..ca4b17a348 100644 --- a/miden-crypto/src/merkle/smt/large/storage/memory.rs +++ b/miden-crypto/src/merkle/smt/large/storage/memory.rs @@ -1,7 +1,8 @@ use alloc::{boxed::Box, vec::Vec}; use super::{ - SmtStorage, SmtStorageReader, StorageError, StorageUpdateParts, StorageUpdates, SubtreeUpdate, + SmtStorage, SmtStorageReader, StorageError, StorageResult, StorageUpdateParts, StorageUpdates, + SubtreeUpdate, }; use crate::{ EMPTY_WORD, Map, MapEntry, Word, @@ -55,41 +56,41 @@ impl Default for MemoryStorage { impl SmtStorageReader for MemoryStorage { /// Gets the total number of non-empty leaves currently stored. - fn leaf_count(&self) -> Result { + fn leaf_count(&self) -> StorageResult { Ok(self.leaves.len()) } /// Gets the total number of key-value entries currently stored. - fn entry_count(&self) -> Result { + fn entry_count(&self) -> StorageResult { Ok(self.leaves.values().map(SmtLeaf::num_entries).sum()) } /// Retrieves a single leaf node. - fn get_leaf(&self, index: u64) -> Result, StorageError> { + fn get_leaf(&self, index: u64) -> StorageResult> { Ok(self.leaves.get(&index).cloned()) } /// Retrieves multiple leaf nodes. Returns Ok(None) for indices not found. 
- fn get_leaves(&self, indices: &[u64]) -> Result>, StorageError> { + fn get_leaves(&self, indices: &[u64]) -> StorageResult>> { let leaves = indices.iter().map(|idx| self.leaves.get(idx).cloned()).collect(); Ok(leaves) } /// Returns true if the storage has any leaves. - fn has_leaves(&self) -> Result { + fn has_leaves(&self) -> StorageResult { Ok(!self.leaves.is_empty()) } /// Retrieves a single Subtree (representing deep nodes) by its root NodeIndex. /// Assumes index.depth() >= IN_MEMORY_DEPTH. Returns Ok(None) if not found. - fn get_subtree(&self, index: NodeIndex) -> Result, StorageError> { + fn get_subtree(&self, index: NodeIndex) -> StorageResult> { Ok(self.subtrees.get(&index).cloned()) } /// Retrieves multiple Subtrees. /// Assumes index.depth() >= IN_MEMORY_DEPTH for all indices. Returns Ok(None) for indices not /// found. - fn get_subtrees(&self, indices: &[NodeIndex]) -> Result>, StorageError> { + fn get_subtrees(&self, indices: &[NodeIndex]) -> StorageResult>> { let subtrees: Vec<_> = indices.iter().map(|idx| self.subtrees.get(idx).cloned()).collect(); Ok(subtrees) } @@ -103,7 +104,7 @@ impl SmtStorageReader for MemoryStorage { /// - `StorageError::Unsupported`: If `index.depth() < IN_MEMORY_DEPTH`. /// /// Returns `Ok(None)` if the subtree or the specific inner node within it is not found. - fn get_inner_node(&self, index: NodeIndex) -> Result, StorageError> { + fn get_inner_node(&self, index: NodeIndex) -> StorageResult> { if index.depth() < IN_MEMORY_DEPTH { return Err(StorageError::Unsupported( "Cannot get inner node from upper part of the tree".into(), @@ -119,24 +120,28 @@ impl SmtStorageReader for MemoryStorage { /// Returns an iterator over all (index, SmtLeaf) pairs in the storage. /// /// The iterator provides access to the current state of the leaves. 
- fn iter_leaves(&self) -> Result + '_>, StorageError> { - Ok(Box::new(self.leaves.iter().map(|(&k, v)| (k, v.clone())))) + fn iter_leaves( + &self, + ) -> StorageResult> + '_>> { + Ok(Box::new(self.leaves.iter().map(|(&k, v)| Ok((k, v.clone()))))) } /// Returns an iterator over all Subtrees in the storage. /// /// The iterator provides access to the current subtrees from storage. - fn iter_subtrees(&self) -> Result + '_>, StorageError> { - Ok(Box::new(self.subtrees.values().cloned())) + fn iter_subtrees( + &self, + ) -> StorageResult> + '_>> { + Ok(Box::new(self.subtrees.values().cloned().map(Ok))) } - /// Retrieves all depth 24 roots for fast tree rebuilding. + /// Retrieves roots of all subtrees at `IN_MEMORY_DEPTH` depth. /// /// Derived from the subtrees already in memory: for each subtree whose root sits at - /// `IN_MEMORY_DEPTH`, the root node's hash is the depth-24 entry that `initialize_from_storage` + /// `IN_MEMORY_DEPTH`, the root node's hash is the entry that `initialize_from_storage` /// needs to reconstruct the in-memory top of the tree. - fn get_depth24(&self) -> Result, StorageError> { - let depth24 = self + fn get_top_subtree_roots(&self) -> StorageResult> { + let in_mem_roots = self .subtrees .values() .filter(|subtree| subtree.root_index().depth() == IN_MEMORY_DEPTH) @@ -146,7 +151,7 @@ impl SmtStorageReader for MemoryStorage { .map(|node| (subtree.root_index().position(), node.hash())) }) .collect(); - Ok(depth24) + Ok(in_mem_roots) } } @@ -154,7 +159,7 @@ impl SmtStorage for MemoryStorage { type Reader = MemoryStorageSnapshot; /// Returns a read-only snapshot of this in-memory storage by cloning it. - fn reader(&self) -> Result { + fn reader(&self) -> StorageResult { Ok(self.clone().into_snapshot()) } @@ -166,12 +171,7 @@ impl SmtStorage for MemoryStorage { /// /// # Panics /// Panics in debug builds if `value` is `EMPTY_WORD`. 
- fn insert_value( - &mut self, - index: u64, - key: Word, - value: Word, - ) -> Result, StorageError> { + fn insert_value(&mut self, index: u64, key: Word, value: Word) -> StorageResult> { debug_assert_ne!(value, EMPTY_WORD); match self.leaves.get_mut(&index) { @@ -191,7 +191,7 @@ impl SmtStorage for MemoryStorage { /// returned (as `leaf.get_value(&key)` would be `None`). /// - If the leaf at `index` does not exist, `Ok(None)` is returned, as no value could be /// removed. - fn remove_value(&mut self, index: u64, key: Word) -> Result, StorageError> { + fn remove_value(&mut self, index: u64, key: Word) -> StorageResult> { let old_value = match self.leaves.entry(index) { MapEntry::Occupied(mut entry) => { let (old_value, is_empty) = entry.get_mut().remove(key); @@ -209,13 +209,13 @@ impl SmtStorage for MemoryStorage { /// Sets multiple leaf nodes in storage. /// /// If a leaf at a given index already exists, it is overwritten. - fn set_leaves(&mut self, leaves_map: Map) -> Result<(), StorageError> { + fn set_leaves(&mut self, leaves_map: Map) -> StorageResult<()> { self.leaves.extend(leaves_map); Ok(()) } /// Removes a single leaf node. - fn remove_leaf(&mut self, index: u64) -> Result, StorageError> { + fn remove_leaf(&mut self, index: u64) -> StorageResult> { Ok(self.leaves.remove(&index)) } @@ -223,7 +223,7 @@ impl SmtStorage for MemoryStorage { /// /// If a subtree with the same root NodeIndex already exists, it is overwritten. /// Assumes `subtree.root_index().depth() >= IN_MEMORY_DEPTH`. - fn set_subtree(&mut self, subtree: &Subtree) -> Result<(), StorageError> { + fn set_subtree(&mut self, subtree: &Subtree) -> StorageResult<()> { self.subtrees.insert(subtree.root_index(), subtree.clone()); Ok(()) } @@ -232,14 +232,14 @@ impl SmtStorage for MemoryStorage { /// /// If a subtree with a given root NodeIndex already exists, it is overwritten. /// Assumes `subtree.root_index().depth() >= IN_MEMORY_DEPTH` for all subtrees in the vector. 
- fn set_subtrees(&mut self, subtrees_vec: Vec) -> Result<(), StorageError> { + fn set_subtrees(&mut self, subtrees_vec: Vec) -> StorageResult<()> { self.subtrees .extend(subtrees_vec.into_iter().map(|subtree| (subtree.root_index(), subtree))); Ok(()) } /// Removes a single Subtree (representing deep nodes) by its root NodeIndex. - fn remove_subtree(&mut self, index: NodeIndex) -> Result<(), StorageError> { + fn remove_subtree(&mut self, index: NodeIndex) -> StorageResult<()> { self.subtrees.remove(&index); Ok(()) } @@ -258,7 +258,7 @@ impl SmtStorage for MemoryStorage { &mut self, index: NodeIndex, node: InnerNode, - ) -> Result, StorageError> { + ) -> StorageResult> { if index.depth() < IN_MEMORY_DEPTH { return Err(StorageError::Unsupported( "Cannot set inner node in upper part of the tree".into(), @@ -284,7 +284,7 @@ impl SmtStorage for MemoryStorage { /// /// # Errors /// - `StorageError::Unsupported`: If `index.depth() < IN_MEMORY_DEPTH`. - fn remove_inner_node(&mut self, index: NodeIndex) -> Result, StorageError> { + fn remove_inner_node(&mut self, index: NodeIndex) -> StorageResult> { if index.depth() < IN_MEMORY_DEPTH { return Err(StorageError::Unsupported( "Cannot remove inner node from upper part of the tree".into(), @@ -308,7 +308,7 @@ impl SmtStorage for MemoryStorage { /// This method handles updates to: /// - Leaves: Inserts new or updated leaves, removes specified leaves. /// - Subtrees: Inserts new or updated subtrees, removes specified subtrees. 
- fn apply(&mut self, updates: StorageUpdates) -> Result<(), StorageError> { + fn apply(&mut self, updates: StorageUpdates) -> StorageResult<()> { let StorageUpdateParts { leaf_updates, subtree_updates, @@ -349,31 +349,31 @@ impl SmtStorage for MemoryStorage { pub struct MemoryStorageSnapshot(MemoryStorage); impl SmtStorageReader for MemoryStorageSnapshot { - fn leaf_count(&self) -> Result { + fn leaf_count(&self) -> StorageResult { self.0.leaf_count() } - fn entry_count(&self) -> Result { + fn entry_count(&self) -> StorageResult { self.0.entry_count() } - fn get_leaf(&self, index: u64) -> Result, StorageError> { + fn get_leaf(&self, index: u64) -> StorageResult> { self.0.get_leaf(index) } - fn get_leaves(&self, indices: &[u64]) -> Result>, StorageError> { + fn get_leaves(&self, indices: &[u64]) -> StorageResult>> { self.0.get_leaves(indices) } - fn has_leaves(&self) -> Result { + fn has_leaves(&self) -> StorageResult { self.0.has_leaves() } - fn get_subtree(&self, index: NodeIndex) -> Result, StorageError> { + fn get_subtree(&self, index: NodeIndex) -> StorageResult> { self.0.get_subtree(index) } - fn get_subtrees(&self, indices: &[NodeIndex]) -> Result>, StorageError> { + fn get_subtrees(&self, indices: &[NodeIndex]) -> StorageResult>> { self.0.get_subtrees(indices) } @@ -381,23 +381,27 @@ impl SmtStorageReader for MemoryStorageSnapshot { &self, leaf_index: u64, subtree_indices: &[NodeIndex], - ) -> Result<(Option, Vec>), StorageError> { + ) -> StorageResult<(Option, Vec>)> { self.0.get_leaf_and_subtrees(leaf_index, subtree_indices) } - fn get_inner_node(&self, index: NodeIndex) -> Result, StorageError> { + fn get_inner_node(&self, index: NodeIndex) -> StorageResult> { self.0.get_inner_node(index) } - fn iter_leaves(&self) -> Result + '_>, StorageError> { + fn iter_leaves( + &self, + ) -> StorageResult> + '_>> { self.0.iter_leaves() } - fn iter_subtrees(&self) -> Result + '_>, StorageError> { + fn iter_subtrees( + &self, + ) -> StorageResult> + '_>> { 
self.0.iter_subtrees() } - fn get_depth24(&self) -> Result, StorageError> { - self.0.get_depth24() + fn get_top_subtree_roots(&self) -> StorageResult> { + self.0.get_top_subtree_roots() } } diff --git a/miden-crypto/src/merkle/smt/large/storage/mod.rs b/miden-crypto/src/merkle/smt/large/storage/mod.rs index 6a7cab0c39..c9829a6d8e 100644 --- a/miden-crypto/src/merkle/smt/large/storage/mod.rs +++ b/miden-crypto/src/merkle/smt/large/storage/mod.rs @@ -13,7 +13,7 @@ use crate::{ }; mod error; -pub use error::StorageError; +pub use error::{StorageError, StorageResult}; #[cfg(feature = "rocksdb")] mod rocksdb; @@ -26,6 +26,9 @@ pub use memory::{MemoryStorage, MemoryStorageSnapshot}; mod updates; pub use updates::{StorageUpdateParts, StorageUpdates, SubtreeUpdate}; +pub type BoxedFallibleLeafIterator<'a> = + Box> + 'a>; + // SMT STORAGE READER // ================================================================================================ @@ -46,43 +49,43 @@ pub trait SmtStorageReader: 'static + fmt::Debug + Send + Sync { /// /// # Errors /// Returns `StorageError` if the storage read operation fails. - fn leaf_count(&self) -> Result; + fn leaf_count(&self) -> StorageResult; /// Retrieves the total number of unique key-value entries across all leaf nodes. /// /// # Errors /// Returns `StorageError` if the storage read operation fails. - fn entry_count(&self) -> Result; + fn entry_count(&self) -> StorageResult; /// Retrieves a single SMT leaf node by its logical `index`. /// Returns `Ok(None)` if no leaf exists at the given `index`. - fn get_leaf(&self, index: u64) -> Result, StorageError>; + fn get_leaf(&self, index: u64) -> StorageResult>; /// Retrieves multiple SMT leaf nodes by their logical `indices`. /// /// The returned `Vec` will have the same length as the input `indices` slice. /// For each `index` in the input, the corresponding element in the output `Vec` /// will be `Some(SmtLeaf)` if found, or `None` if not found. 
- fn get_leaves(&self, indices: &[u64]) -> Result>, StorageError>; + fn get_leaves(&self, indices: &[u64]) -> StorageResult>>; /// Returns true if the storage has any leaves. /// /// # Errors /// Returns `StorageError` if the storage read operation fails. - fn has_leaves(&self) -> Result; + fn has_leaves(&self) -> StorageResult; /// Retrieves a single SMT Subtree by its root `NodeIndex`. /// /// Subtrees typically represent deeper, compacted parts of the SMT. /// Returns `Ok(None)` if no subtree is found for the given `index`. - fn get_subtree(&self, index: NodeIndex) -> Result, StorageError>; + fn get_subtree(&self, index: NodeIndex) -> StorageResult>; /// Retrieves multiple Subtrees by their root `NodeIndex` values. /// /// The returned `Vec` will have the same length as the input `indices` slice. /// For each `index` in the input, the corresponding element in the output `Vec` /// will be `Some(Subtree)` if found, or `None` if not found. - fn get_subtrees(&self, indices: &[NodeIndex]) -> Result>, StorageError>; + fn get_subtrees(&self, indices: &[NodeIndex]) -> StorageResult>>; /// Retrieves a single leaf and multiple subtrees in one call. /// @@ -98,7 +101,7 @@ pub trait SmtStorageReader: 'static + fmt::Debug + Send + Sync { &self, leaf_index: u64, subtree_indices: &[NodeIndex], - ) -> Result<(Option, Vec>), StorageError> { + ) -> StorageResult<(Option, Vec>)> { let leaf = self.get_leaf(leaf_index)?; // We explicitly do NOT want to delegate to `get_subtrees` here as it can be a very heavy @@ -115,63 +118,71 @@ pub trait SmtStorageReader: 'static + fmt::Debug + Send + Sync { /// /// This method is intended for accessing nodes at depths greater than the in-memory horizon. /// Returns `Ok(None)` if the containing Subtree or the specific inner node is not found. 
- fn get_inner_node(&self, index: NodeIndex) -> Result, StorageError>; + fn get_inner_node(&self, index: NodeIndex) -> StorageResult>; - /// Returns an iterator over all (logical_index, SmtLeaf) pairs currently in storage. + /// Returns an iterator over all `(logical_index, SmtLeaf)` pairs currently in storage. + /// + /// The returned iterator is fallible: each item is a + /// [`crate::merkle::smt::StorageResult`] so backends can report per-element read or + /// deserialization failures encountered after iterator creation. /// /// The order of iteration is not guaranteed unless specified by the implementation. - fn iter_leaves(&self) -> Result + '_>, StorageError>; + fn iter_leaves(&self) -> StorageResult>; /// Returns an iterator over all `Subtree` instances currently in storage. /// + /// The returned iterator is fallible: each item is a + /// [`crate::merkle::smt::StorageResult`] so backends can report per-element read or + /// deserialization failures encountered after iterator creation. + /// /// The order of iteration is not guaranteed unless specified by the implementation. - fn iter_subtrees(&self) -> Result + '_>, StorageError>; + fn iter_subtrees(&self) + -> StorageResult> + '_>>; - /// Retrieves all depth 24 hashes from storage for efficient startup reconstruction. + /// Retrieves roots of all top level subtrees for efficient startup reconstruction. /// - /// Returns a vector of `(node_index_value, InnerNode)` tuples representing - /// the cached roots of nodes at depth 24 (the in-memory/storage boundary). - /// These roots enable fast reconstruction of the upper tree without loading - /// entire subtrees. + /// Returns a vector of `(node_index_value, Word)` tuples representing the roots of nodes at + /// `IN_MEMORY_DEPTH` (the in-memory/storage boundary). These roots enable fast reconstruction + /// of the upper tree without loading entire subtrees. 
/// - /// The hash cache is automatically maintained by subtree operations - no manual - /// cache management is required. - fn get_depth24(&self) -> Result, StorageError>; + /// The hash cache is automatically maintained by subtree operations - no manual cache + /// management is required. + fn get_top_subtree_roots(&self) -> StorageResult>; } impl SmtStorageReader for Box { #[inline] - fn leaf_count(&self) -> Result { + fn leaf_count(&self) -> StorageResult { self.deref().leaf_count() } #[inline] - fn entry_count(&self) -> Result { + fn entry_count(&self) -> StorageResult { self.deref().entry_count() } #[inline] - fn get_leaf(&self, index: u64) -> Result, StorageError> { + fn get_leaf(&self, index: u64) -> StorageResult> { self.deref().get_leaf(index) } #[inline] - fn get_leaves(&self, indices: &[u64]) -> Result>, StorageError> { + fn get_leaves(&self, indices: &[u64]) -> StorageResult>> { self.deref().get_leaves(indices) } #[inline] - fn has_leaves(&self) -> Result { + fn has_leaves(&self) -> StorageResult { self.deref().has_leaves() } #[inline] - fn get_subtree(&self, index: NodeIndex) -> Result, StorageError> { + fn get_subtree(&self, index: NodeIndex) -> StorageResult> { self.deref().get_subtree(index) } #[inline] - fn get_subtrees(&self, indices: &[NodeIndex]) -> Result>, StorageError> { + fn get_subtrees(&self, indices: &[NodeIndex]) -> StorageResult>> { self.deref().get_subtrees(indices) } @@ -180,28 +191,30 @@ impl SmtStorageReader for Box { &self, leaf_index: u64, subtree_indices: &[NodeIndex], - ) -> Result<(Option, Vec>), StorageError> { + ) -> StorageResult<(Option, Vec>)> { self.deref().get_leaf_and_subtrees(leaf_index, subtree_indices) } #[inline] - fn get_inner_node(&self, index: NodeIndex) -> Result, StorageError> { + fn get_inner_node(&self, index: NodeIndex) -> StorageResult> { self.deref().get_inner_node(index) } #[inline] - fn iter_leaves(&self) -> Result + '_>, StorageError> { + fn iter_leaves(&self) -> StorageResult> { 
self.deref().iter_leaves() } #[inline] - fn iter_subtrees(&self) -> Result + '_>, StorageError> { + fn iter_subtrees( + &self, + ) -> StorageResult> + '_>> { self.deref().iter_subtrees() } #[inline] - fn get_depth24(&self) -> Result, StorageError> { - self.deref().get_depth24() + fn get_top_subtree_roots(&self) -> StorageResult> { + self.deref().get_top_subtree_roots() } } @@ -227,7 +240,7 @@ pub trait SmtStorage: SmtStorageReader { /// /// Implementations must return a point-in-time snapshot. Later writes through `self` must not /// affect the returned reader. Holding the reader must not block writes in any way. - fn reader(&self) -> Result; + fn reader(&self) -> StorageResult; /// Inserts a key-value pair into the SMT leaf at the specified logical `index`. /// @@ -243,12 +256,7 @@ pub trait SmtStorage: SmtStorageReader { /// # Errors /// Returns `StorageError` if the storage operation fails (e.g., backend database error, /// insufficient space, serialization failures). - fn insert_value( - &mut self, - index: u64, - key: Word, - value: Word, - ) -> Result, StorageError>; + fn insert_value(&mut self, index: u64, key: Word, value: Word) -> StorageResult>; /// Removes a key-value pair from the SMT leaf at the specified logical `index`. /// @@ -268,7 +276,7 @@ pub trait SmtStorage: SmtStorageReader { /// # Errors /// Returns `StorageError` if the storage operation fails (e.g., backend database error, /// write permission issues, serialization failures). - fn remove_value(&mut self, index: u64, key: Word) -> Result, StorageError>; + fn remove_value(&mut self, index: u64, key: Word) -> StorageResult>; /// Sets or updates multiple SMT leaf nodes in storage. /// @@ -281,7 +289,7 @@ pub trait SmtStorage: SmtStorageReader { /// /// # Errors /// Returns `StorageError` if any storage operation fails during the batch update. 
- fn set_leaves(&mut self, leaves: Map) -> Result<(), StorageError>; + fn set_leaves(&mut self, leaves: Map) -> StorageResult<()>; /// Removes a single SMT leaf node entirely from storage by its logical `index`. /// @@ -291,23 +299,23 @@ pub trait SmtStorage: SmtStorageReader { /// Returns the `SmtLeaf` that was removed, or `Ok(None)` if no leaf existed at `index`. /// Implementations should ensure that removing a leaf also correctly updates /// the overall leaf and entry counts. - fn remove_leaf(&mut self, index: u64) -> Result, StorageError>; + fn remove_leaf(&mut self, index: u64) -> StorageResult>; /// Sets or updates a single SMT Subtree in storage, identified by its root `NodeIndex`. /// /// If a subtree with the same root `NodeIndex` already exists, it is overwritten. - fn set_subtree(&mut self, subtree: &Subtree) -> Result<(), StorageError>; + fn set_subtree(&mut self, subtree: &Subtree) -> StorageResult<()>; /// Sets or updates multiple SMT Subtrees in storage. /// /// For each `Subtree` in the `subtrees` vector, if a subtree with the same root `NodeIndex` /// already exists, it is overwritten. - fn set_subtrees(&mut self, subtrees: Vec) -> Result<(), StorageError>; + fn set_subtrees(&mut self, subtrees: Vec) -> StorageResult<()>; /// Removes a single SMT Subtree from storage, identified by its root `NodeIndex`. /// /// Returns `Ok(())` on successful removal or if the subtree did not exist. - fn remove_subtree(&mut self, index: NodeIndex) -> Result<(), StorageError>; + fn remove_subtree(&mut self, index: NodeIndex) -> StorageResult<()>; /// Sets or updates a single inner node (non-leaf node) within a Subtree. /// @@ -317,14 +325,14 @@ pub trait SmtStorage: SmtStorageReader { &mut self, index: NodeIndex, node: InnerNode, - ) -> Result, StorageError>; + ) -> StorageResult>; /// Removes a single inner node (non-leaf node) from within a Subtree. 
/// /// - If the Subtree becomes empty after removing the node, the Subtree itself might be removed /// by the storage implementation. /// - Returns the `InnerNode` that was removed, if any. - fn remove_inner_node(&mut self, index: NodeIndex) -> Result, StorageError>; + fn remove_inner_node(&mut self, index: NodeIndex) -> StorageResult>; /// Applies a batch of `StorageUpdates` atomically to the storage backend. /// @@ -333,54 +341,49 @@ pub trait SmtStorage: SmtStorageReader { /// new root hash, and count deltas) are applied as a single, indivisible operation. /// If any part of the update fails, the entire transaction should be rolled back, leaving /// the storage in its previous state. - fn apply(&mut self, updates: StorageUpdates) -> Result<(), StorageError>; + fn apply(&mut self, updates: StorageUpdates) -> StorageResult<()>; } impl SmtStorage for Box { type Reader = T::Reader; #[inline] - fn reader(&self) -> Result { + fn reader(&self) -> StorageResult { self.deref().reader() } #[inline] - fn insert_value( - &mut self, - index: u64, - key: Word, - value: Word, - ) -> Result, StorageError> { + fn insert_value(&mut self, index: u64, key: Word, value: Word) -> StorageResult> { self.deref_mut().insert_value(index, key, value) } #[inline] - fn remove_value(&mut self, index: u64, key: Word) -> Result, StorageError> { + fn remove_value(&mut self, index: u64, key: Word) -> StorageResult> { self.deref_mut().remove_value(index, key) } #[inline] - fn set_leaves(&mut self, leaves: Map) -> Result<(), StorageError> { + fn set_leaves(&mut self, leaves: Map) -> StorageResult<()> { self.deref_mut().set_leaves(leaves) } #[inline] - fn remove_leaf(&mut self, index: u64) -> Result, StorageError> { + fn remove_leaf(&mut self, index: u64) -> StorageResult> { self.deref_mut().remove_leaf(index) } #[inline] - fn set_subtree(&mut self, subtree: &Subtree) -> Result<(), StorageError> { + fn set_subtree(&mut self, subtree: &Subtree) -> StorageResult<()> { 
self.deref_mut().set_subtree(subtree) } #[inline] - fn set_subtrees(&mut self, subtrees: Vec) -> Result<(), StorageError> { + fn set_subtrees(&mut self, subtrees: Vec) -> StorageResult<()> { self.deref_mut().set_subtrees(subtrees) } #[inline] - fn remove_subtree(&mut self, index: NodeIndex) -> Result<(), StorageError> { + fn remove_subtree(&mut self, index: NodeIndex) -> StorageResult<()> { self.deref_mut().remove_subtree(index) } @@ -389,17 +392,17 @@ impl SmtStorage for Box { &mut self, index: NodeIndex, node: InnerNode, - ) -> Result, StorageError> { + ) -> StorageResult> { self.deref_mut().set_inner_node(index, node) } #[inline] - fn remove_inner_node(&mut self, index: NodeIndex) -> Result, StorageError> { + fn remove_inner_node(&mut self, index: NodeIndex) -> StorageResult> { self.deref_mut().remove_inner_node(index) } #[inline] - fn apply(&mut self, updates: StorageUpdates) -> Result<(), StorageError> { + fn apply(&mut self, updates: StorageUpdates) -> StorageResult<()> { self.deref_mut().apply(updates) } } diff --git a/miden-crypto/src/merkle/smt/large/storage/rocksdb.rs b/miden-crypto/src/merkle/smt/large/storage/rocksdb.rs index 17fe04f069..22f4f387e0 100644 --- a/miden-crypto/src/merkle/smt/large/storage/rocksdb.rs +++ b/miden-crypto/src/merkle/smt/large/storage/rocksdb.rs @@ -8,7 +8,8 @@ use rocksdb::{ }; use super::{ - SmtStorage, SmtStorageReader, StorageError, StorageUpdateParts, StorageUpdates, SubtreeUpdate, + SmtStorage, SmtStorageReader, StorageError, StorageResult, StorageUpdateParts, StorageUpdates, + SubtreeUpdate, }; use crate::{ EMPTY_WORD, Word, @@ -25,6 +26,7 @@ use crate::{ /// The name of the RocksDB column family used for storing SMT leaves. const LEAVES_CF: &str = "leaves"; /// The names of the RocksDB column families used for storing SMT subtrees (deep nodes). 
+const SUBTREE_16_CF: &str = "st16"; const SUBTREE_24_CF: &str = "st24"; const SUBTREE_32_CF: &str = "st32"; const SUBTREE_40_CF: &str = "st40"; @@ -33,8 +35,9 @@ const SUBTREE_56_CF: &str = "st56"; /// The name of the RocksDB column family used for storing metadata (e.g., counts). const METADATA_CF: &str = "metadata"; -/// The name of the RocksDB column family used for storing level 24 hashes for fast tree rebuilding. -const DEPTH_24_CF: &str = "depth24"; +/// The name of the RocksDB column family used for storing in-memory-depth hashes for fast tree +/// rebuilding. +const IN_MEM_DEPTH_CF: &str = "in_mem_depth"; /// The key used in the `METADATA_CF` column family to store the total count of non-empty leaves. const LEAF_COUNT_KEY: &[u8] = b"leaf_count"; @@ -50,6 +53,8 @@ const ENTRY_COUNT_KEY: &[u8] = b"entry_count"; /// including leaves, subtrees (for deeper parts of the tree), and metadata like the SMT root /// and counts. It leverages RocksDB column families to organize data: /// - `LEAVES_CF` ("leaves"): Stores `SmtLeaf` data, keyed by their logical u64 index. +/// - `SUBTREE_16_CF` ("st16"): Stores serialized `Subtree` data at depth 16, keyed by their root +/// `NodeIndex`. /// - `SUBTREE_24_CF` ("st24"): Stores serialized `Subtree` data at depth 24, keyed by their root /// `NodeIndex`. /// - `SUBTREE_32_CF` ("st32"): Stores serialized `Subtree` data at depth 32, keyed by their root @@ -78,7 +83,7 @@ impl RocksDbStorage { /// # Errors /// Returns `StorageError::Backend` if the database cannot be opened or configured, /// for example, due to path issues, permissions, or RocksDB internal errors. 
- pub fn open(config: RocksDbConfig) -> Result { + pub fn open(config: RocksDbConfig) -> StorageResult { // Base DB options let mut db_opts = Options::default(); // Create DB if it doesn't exist @@ -161,9 +166,9 @@ impl RocksDbStorage { opts } - let mut depth24_opts = Options::default(); - depth24_opts.set_compression_type(DBCompressionType::Lz4); - depth24_opts.set_block_based_table_factory(&table_opts); + let mut in_mem_depth_opts = Options::default(); + in_mem_depth_opts.set_compression_type(DBCompressionType::Lz4); + in_mem_depth_opts.set_block_based_table_factory(&table_opts); // Metadata CF with no compression let mut metadata_opts = Options::default(); @@ -172,13 +177,14 @@ impl RocksDbStorage { // Define column families with tailored options let cfs = vec![ ColumnFamilyDescriptor::new(LEAVES_CF, leaves_opts), + ColumnFamilyDescriptor::new(SUBTREE_16_CF, subtree_cf(&cache, 8.0)), ColumnFamilyDescriptor::new(SUBTREE_24_CF, subtree_cf(&cache, 8.0)), ColumnFamilyDescriptor::new(SUBTREE_32_CF, subtree_cf(&cache, 10.0)), ColumnFamilyDescriptor::new(SUBTREE_40_CF, subtree_cf(&cache, 10.0)), ColumnFamilyDescriptor::new(SUBTREE_48_CF, subtree_cf(&cache, 12.0)), ColumnFamilyDescriptor::new(SUBTREE_56_CF, subtree_cf(&cache, 12.0)), ColumnFamilyDescriptor::new(METADATA_CF, metadata_opts), - ColumnFamilyDescriptor::new(DEPTH_24_CF, depth24_opts), + ColumnFamilyDescriptor::new(IN_MEM_DEPTH_CF, in_mem_depth_opts), ]; // Open the database with our tuned CFs @@ -193,19 +199,20 @@ impl RocksDbStorage { /// /// # Errors /// - Returns `StorageError::Backend` if the flush operation fails. 
- fn sync(&self) -> Result<(), StorageError> { + fn sync(&self) -> StorageResult<()> { let mut fopts = FlushOptions::default(); fopts.set_wait(true); for name in [ LEAVES_CF, + SUBTREE_16_CF, SUBTREE_24_CF, SUBTREE_32_CF, SUBTREE_40_CF, SUBTREE_48_CF, SUBTREE_56_CF, METADATA_CF, - DEPTH_24_CF, + IN_MEM_DEPTH_CF, ] { let cf = self.cf_handle(name)?; self.db.flush_cf_opt(cf, &fopts)?; @@ -226,6 +233,7 @@ impl RocksDbStorage { #[inline(always)] fn subtree_db_key(index: NodeIndex) -> KeyBytes { let keep = match index.depth() { + 16 => 2, 24 => 3, 32 => 4, 40 => 5, @@ -241,7 +249,7 @@ impl RocksDbStorage { /// # Errors /// Returns `StorageError::Backend` if the column family with the given `name` does not /// exist. - fn cf_handle(&self, name: &str) -> Result<&rocksdb::ColumnFamily, StorageError> { + fn cf_handle(&self, name: &str) -> StorageResult<&rocksdb::ColumnFamily> { self.db .cf_handle(name) .ok_or_else(|| StorageError::Unsupported(format!("unknown column family `{name}`"))) @@ -263,7 +271,7 @@ impl SmtStorageReader for RocksDbStorage { /// - `StorageError::Backend`: If the metadata column family is missing or a RocksDB error /// occurs. /// - `StorageError::BadValueLen`: If the retrieved count bytes are invalid. - fn leaf_count(&self) -> Result { + fn leaf_count(&self) -> StorageResult { let cf = self.cf_handle(METADATA_CF)?; self.db.get_cf(cf, LEAF_COUNT_KEY)?.map_or(Ok(0), |bytes| { let arr: [u8; 8] = @@ -283,7 +291,7 @@ impl SmtStorageReader for RocksDbStorage { /// - `StorageError::Backend`: If the metadata column family is missing or a RocksDB error /// occurs. /// - `StorageError::BadValueLen`: If the retrieved count bytes are invalid. 
- fn entry_count(&self) -> Result { + fn entry_count(&self) -> StorageResult { let cf = self.cf_handle(METADATA_CF)?; self.db.get_cf(cf, ENTRY_COUNT_KEY)?.map_or(Ok(0), |bytes| { let arr: [u8; 8] = @@ -301,7 +309,7 @@ impl SmtStorageReader for RocksDbStorage { /// # Errors /// - `StorageError::Backend`: If the leaves column family is missing or a RocksDB error occurs. /// - `StorageError::DeserializationError`: If the retrieved leaf data is corrupt. - fn get_leaf(&self, index: u64) -> Result, StorageError> { + fn get_leaf(&self, index: u64) -> StorageResult> { let cf = self.cf_handle(LEAVES_CF)?; let key = Self::index_db_key(index); match self.db.get_cf(cf, key)? { @@ -318,7 +326,7 @@ impl SmtStorageReader for RocksDbStorage { /// # Errors /// - `StorageError::Backend`: If the leaves column family is missing or a RocksDB error occurs. /// - `StorageError::DeserializationError`: If any retrieved leaf data is corrupt. - fn get_leaves(&self, indices: &[u64]) -> Result>, StorageError> { + fn get_leaves(&self, indices: &[u64]) -> StorageResult>> { let cf = self.cf_handle(LEAVES_CF)?; let db_keys: Vec<[u8; 8]> = indices.iter().map(|&idx| Self::index_db_key(idx)).collect(); let results = self.db.multi_get_cf(db_keys.iter().map(|k| (cf, k.as_ref()))); @@ -339,7 +347,7 @@ impl SmtStorageReader for RocksDbStorage { /// /// # Errors /// Returns `StorageError` if the storage read operation fails. - fn has_leaves(&self) -> Result { + fn has_leaves(&self) -> StorageResult { Ok(self.leaf_count()? > 0) } @@ -361,7 +369,7 @@ impl SmtStorageReader for RocksDbStorage { /// - A `Vec>` where each index corresponds to the original input. /// - `Ok(...)` if all fetches succeed. /// - `Err(StorageError)` if any RocksDB access or deserialization fails. - fn get_subtree(&self, index: NodeIndex) -> Result, StorageError> { + fn get_subtree(&self, index: NodeIndex) -> StorageResult> { let cf = self.subtree_cf(index); let key = Self::subtree_db_key(index); match self.db.get_cf(cf, key)? 
{ @@ -387,10 +395,10 @@ impl SmtStorageReader for RocksDbStorage { /// - A `Vec>` where each index corresponds to the original input. /// - `Ok(...)` if all fetches succeed. /// - `Err(StorageError)` if any RocksDB access or deserialization fails. - fn get_subtrees(&self, indices: &[NodeIndex]) -> Result>, StorageError> { + fn get_subtrees(&self, indices: &[NodeIndex]) -> StorageResult>> { use p3_maybe_rayon::prelude::*; - let mut depth_buckets: [Vec<(usize, NodeIndex)>; 5] = Default::default(); + let mut depth_buckets: [Vec<(usize, NodeIndex)>; 6] = Default::default(); for (original_index, &node_index) in indices.iter().enumerate() { let depth = node_index.depth(); @@ -400,6 +408,7 @@ impl SmtStorageReader for RocksDbStorage { 40 => 2, 32 => 3, 24 => 4, + 16 => 5, _ => { return Err(StorageError::Unsupported(format!( "unsupported subtree depth {depth}" @@ -411,34 +420,32 @@ impl SmtStorageReader for RocksDbStorage { let mut results = vec![None; indices.len()]; // Process depth buckets in parallel - let bucket_results: Result, StorageError> = depth_buckets + let bucket_results: StorageResult> = depth_buckets .into_par_iter() .enumerate() .filter(|(_, bucket)| !bucket.is_empty()) - .map( - |(bucket_index, bucket)| -> Result)>, StorageError> { - let depth = LargeSmt::::SUBTREE_DEPTHS[bucket_index]; - let cf = self.cf_handle(cf_for_depth(depth))?; - let keys: Vec<_> = - bucket.iter().map(|(_, idx)| Self::subtree_db_key(*idx)).collect(); - - let db_results = self.db.multi_get_cf(keys.iter().map(|k| (cf, k.as_ref()))); - - // Process results for this bucket - bucket - .into_iter() - .zip(db_results) - .map(|((original_index, node_index), db_result)| { - let subtree = match db_result { - Ok(Some(bytes)) => Some(Subtree::from_vec(node_index, &bytes)?), - Ok(None) => None, - Err(e) => return Err(e.into()), - }; - Ok((original_index, subtree)) - }) - .collect() - }, - ) + .map(|(bucket_index, bucket)| -> StorageResult)>> { + let depth = 
LargeSmt::::SUBTREE_DEPTHS[bucket_index]; + let cf = self.cf_handle(cf_for_depth(depth))?; + let keys: Vec<_> = + bucket.iter().map(|(_, idx)| Self::subtree_db_key(*idx)).collect(); + + let db_results = self.db.multi_get_cf(keys.iter().map(|k| (cf, k.as_ref()))); + + // Process results for this bucket + bucket + .into_iter() + .zip(db_results) + .map(|((original_index, node_index), db_result)| { + let subtree = match db_result { + Ok(Some(bytes)) => Some(Subtree::from_vec(node_index, &bytes)?), + Ok(None) => None, + Err(e) => return Err(e.into()), + }; + Ok((original_index, subtree)) + }) + .collect() + }) .collect(); // Flatten results and place them in correct positions @@ -460,7 +467,7 @@ impl SmtStorageReader for RocksDbStorage { /// # Errors /// - `StorageError::Backend`: If `index.depth() < IN_MEMORY_DEPTH`, or if RocksDB errors occur. /// - `StorageError::Value`: If the containing Subtree data is corrupt. - fn get_inner_node(&self, index: NodeIndex) -> Result, StorageError> { + fn get_inner_node(&self, index: NodeIndex) -> StorageResult> { if index.depth() < IN_MEMORY_DEPTH { return Err(StorageError::Unsupported( "Cannot get inner node from upper part of the tree".into(), @@ -476,12 +483,14 @@ impl SmtStorageReader for RocksDbStorage { /// /// The iterator uses a RocksDB snapshot for consistency and iterates in lexicographical /// order of the keys (leaf indices). Errors during iteration (e.g., deserialization issues) - /// cause the iterator to skip the problematic item and attempt to continue. + /// are returned as iterator items. /// /// # Errors /// - `StorageError::Backend`: If the leaves column family is missing or a RocksDB error occurs /// during iterator creation. 
- fn iter_leaves(&self) -> Result + '_>, StorageError> { + fn iter_leaves( + &self, + ) -> StorageResult> + '_>> { let cf = self.cf_handle(LEAVES_CF)?; let mut read_opts = ReadOptions::default(); read_opts.set_total_order_seek(true); @@ -494,16 +503,23 @@ impl SmtStorageReader for RocksDbStorage { /// /// The iterator uses a RocksDB snapshot and iterates in lexicographical order of keys /// (subtree root NodeIndex) across all depth column families (24, 32, 40, 48, 56). - /// Errors during iteration (e.g., deserialization issues) cause the iterator to skip - /// the problematic item and attempt to continue. + /// Errors during iteration (e.g., deserialization issues) are returned as iterator items. /// /// # Errors /// - `StorageError::Backend`: If any subtree column family is missing or a RocksDB error occurs /// during iterator creation. - fn iter_subtrees(&self) -> Result + '_>, StorageError> { + fn iter_subtrees( + &self, + ) -> StorageResult> + '_>> { // All subtree column family names in order - const SUBTREE_CFS: [&str; 5] = - [SUBTREE_24_CF, SUBTREE_32_CF, SUBTREE_40_CF, SUBTREE_48_CF, SUBTREE_56_CF]; + const SUBTREE_CFS: [&str; 6] = [ + SUBTREE_16_CF, + SUBTREE_24_CF, + SUBTREE_32_CF, + SUBTREE_40_CF, + SUBTREE_48_CF, + SUBTREE_56_CF, + ]; let mut cf_handles = Vec::new(); for cf_name in SUBTREE_CFS { @@ -513,14 +529,14 @@ impl SmtStorageReader for RocksDbStorage { Ok(Box::new(RocksDbSubtreeIterator::new(&self.db, cf_handles))) } - /// Retrieves all depth 24 hashes for fast tree rebuilding. + /// Retrieves roots of all top level subtrees for efficient startup reconstruction. /// /// # Errors - /// - `StorageError::Backend`: If the depth24 column family is missing or a RocksDB error - /// occurs. + /// - `StorageError::Backend`: If the in-memory-depth column family is missing or a RocksDB + /// error occurs. /// - `StorageError::Value`: If any hash bytes are corrupt. 
- fn get_depth24(&self) -> Result, StorageError> { - let cf = self.cf_handle(DEPTH_24_CF)?; + fn get_top_subtree_roots(&self) -> StorageResult> { + let cf = self.cf_handle(IN_MEM_DEPTH_CF)?; let iter = self.db.iterator_cf(cf, IteratorMode::Start); let mut hashes = Vec::new(); @@ -541,7 +557,7 @@ impl SmtStorage for RocksDbStorage { type Reader = RocksDbSnapshotStorage; /// Returns a detached read-only snapshot of the current RocksDB-backed storage. - fn reader(&self) -> Result { + fn reader(&self) -> StorageResult { Ok(RocksDbSnapshotStorage::new(Arc::clone(&self.db))) } @@ -559,12 +575,7 @@ impl SmtStorage for RocksDbStorage { /// # Errors /// - `StorageError::Backend`: If column families are missing or a RocksDB error occurs. /// - `StorageError::DeserializationError`: If existing leaf data is corrupt. - fn insert_value( - &mut self, - index: u64, - key: Word, - value: Word, - ) -> Result, StorageError> { + fn insert_value(&mut self, index: u64, key: Word, value: Word) -> StorageResult> { debug_assert_ne!(value, EMPTY_WORD); let mut batch = WriteBatch::default(); @@ -633,7 +644,7 @@ impl SmtStorage for RocksDbStorage { /// # Errors /// - `StorageError::Backend`: If column families are missing or a RocksDB error occurs. /// - `StorageError::DeserializationError`: If existing leaf data is corrupt. - fn remove_value(&mut self, index: u64, key: Word) -> Result, StorageError> { + fn remove_value(&mut self, index: u64, key: Word) -> StorageResult> { let Some(mut leaf) = self.get_leaf(index)? else { return Ok(None); }; @@ -676,7 +687,7 @@ impl SmtStorage for RocksDbStorage { /// /// # Errors /// - `StorageError::Backend`: If column families are missing or a RocksDB error occurs. 
- fn set_leaves(&mut self, leaves: Map) -> Result<(), StorageError> { + fn set_leaves(&mut self, leaves: Map) -> StorageResult<()> { let cf = self.cf_handle(LEAVES_CF)?; let leaf_count: usize = leaves.len(); let entry_count: usize = leaves.values().map(|leaf| leaf.entries().len()).sum(); @@ -707,7 +718,7 @@ impl SmtStorage for RocksDbStorage { /// - `StorageError::Backend`: If the leaves column family is missing or a RocksDB error occurs. /// - `StorageError::DeserializationError`: If the retrieved (to be returned) leaf data is /// corrupt. - fn remove_leaf(&mut self, index: u64) -> Result, StorageError> { + fn remove_leaf(&mut self, index: u64) -> StorageResult> { let key = Self::index_db_key(index); let cf = self.cf_handle(LEAVES_CF)?; let old_bytes = self.db.get_cf(cf, key)?; @@ -718,11 +729,11 @@ impl SmtStorage for RocksDbStorage { })) } - /// Stores a single subtree in RocksDB and optionally updates the depth-24 root cache. + /// Stores a single subtree in RocksDB and optionally updates the in-memory-depth root cache. /// /// The subtree is serialized and written to its corresponding column family. - /// If it's a depth-24 subtree, the root node’s hash is also stored in the - /// dedicated `DEPTH_24_CF` cache to support top-level reconstruction. + /// If it’s an in-memory-depth subtree, the root node’s hash is also stored in the + /// dedicated `IN_MEM_DEPTH_CF` cache to support top-level reconstruction. /// /// # Parameters /// - `subtree`: A reference to the subtree to be stored. @@ -730,7 +741,7 @@ impl SmtStorage for RocksDbStorage { /// # Errors /// - Returns `StorageError` if column family lookup, serialization, or the write operation /// fails. 
- fn set_subtree(&mut self, subtree: &Subtree) -> Result<(), StorageError> { + fn set_subtree(&mut self, subtree: &Subtree) -> StorageResult<()> { let subtrees_cf = self.subtree_cf(subtree.root_index()); let mut batch = WriteBatch::default(); @@ -738,16 +749,16 @@ impl SmtStorage for RocksDbStorage { let value = subtree.to_vec(); batch.put_cf(subtrees_cf, key, value); - // Also update level 24 hash cache if this is a level 24 subtree + // Also update in-memory-depth hash cache if this is an in-memory-depth subtree if subtree.root_index().depth() == IN_MEMORY_DEPTH { let root_hash = subtree .get_inner_node(subtree.root_index()) .ok_or_else(|| StorageError::Unsupported("Subtree root node not found".into()))? .hash(); - let depth24_cf = self.cf_handle(DEPTH_24_CF)?; + let in_mem_depth_cf = self.cf_handle(IN_MEM_DEPTH_CF)?; let hash_key = Self::index_db_key(subtree.root_index().position()); - batch.put_cf(depth24_cf, hash_key, root_hash.to_bytes()); + batch.put_cf(in_mem_depth_cf, hash_key, root_hash.to_bytes()); } self.db.write(batch)?; @@ -768,8 +779,8 @@ impl SmtStorage for RocksDbStorage { /// /// # Errors /// - Returns `StorageError::Backend` if any column family lookup or RocksDB write fails. 
- fn set_subtrees(&mut self, subtrees: Vec) -> Result<(), StorageError> { - let depth24_cf = self.cf_handle(DEPTH_24_CF)?; + fn set_subtrees(&mut self, subtrees: Vec) -> StorageResult<()> { + let in_mem_depth_cf = self.cf_handle(IN_MEM_DEPTH_CF)?; let mut batch = WriteBatch::default(); for subtree in subtrees { @@ -782,7 +793,7 @@ impl SmtStorage for RocksDbStorage { && let Some(root_node) = subtree.get_inner_node(subtree.root_index()) { let hash_key = Self::index_db_key(subtree.root_index().position()); - batch.put_cf(depth24_cf, hash_key, root_node.hash().to_bytes()); + batch.put_cf(in_mem_depth_cf, hash_key, root_node.hash().to_bytes()); } } @@ -795,18 +806,18 @@ impl SmtStorage for RocksDbStorage { /// # Errors /// - `StorageError::Backend`: If the subtrees column family is missing or a RocksDB error /// occurs. - fn remove_subtree(&mut self, index: NodeIndex) -> Result<(), StorageError> { + fn remove_subtree(&mut self, index: NodeIndex) -> StorageResult<()> { let subtrees_cf = self.subtree_cf(index); let mut batch = WriteBatch::default(); let key = Self::subtree_db_key(index); batch.delete_cf(subtrees_cf, key); - // Also remove level 24 hash cache if this is a level 24 subtree + // Also remove in-memory-depth hash cache if this is an in-memory-depth subtree if index.depth() == IN_MEMORY_DEPTH { - let depth24_cf = self.cf_handle(DEPTH_24_CF)?; + let in_mem_depth_cf = self.cf_handle(IN_MEM_DEPTH_CF)?; let hash_key = Self::index_db_key(index.position()); - batch.delete_cf(depth24_cf, hash_key); + batch.delete_cf(in_mem_depth_cf, hash_key); } self.db.write(batch)?; @@ -826,7 +837,7 @@ impl SmtStorage for RocksDbStorage { &mut self, index: NodeIndex, node: InnerNode, - ) -> Result, StorageError> { + ) -> StorageResult> { if index.depth() < IN_MEMORY_DEPTH { return Err(StorageError::Unsupported( "Cannot set inner node in upper part of the tree".into(), @@ -851,7 +862,7 @@ impl SmtStorage for RocksDbStorage { /// # Errors /// - `StorageError::Backend`: If 
`index.depth() < IN_MEMORY_DEPTH`, or if RocksDB errors occur. /// - `StorageError::Value`: If existing Subtree data is corrupt. - fn remove_inner_node(&mut self, index: NodeIndex) -> Result, StorageError> { + fn remove_inner_node(&mut self, index: NodeIndex) -> StorageResult> { if index.depth() < IN_MEMORY_DEPTH { return Err(StorageError::Unsupported( "Cannot remove inner node from upper part of the tree".into(), @@ -888,14 +899,14 @@ impl SmtStorage for RocksDbStorage { /// /// # Errors /// - `StorageError::Backend`: If any column family is missing or a RocksDB write error occurs. - fn apply(&mut self, updates: StorageUpdates) -> Result<(), StorageError> { + fn apply(&mut self, updates: StorageUpdates) -> StorageResult<()> { use p3_maybe_rayon::prelude::*; let mut batch = WriteBatch::default(); let leaves_cf = self.cf_handle(LEAVES_CF)?; let metadata_cf = self.cf_handle(METADATA_CF)?; - let depth24_cf = self.cf_handle(DEPTH_24_CF)?; + let in_mem_depth_cf = self.cf_handle(IN_MEM_DEPTH_CF)?; let StorageUpdateParts { leaf_updates, @@ -913,52 +924,52 @@ impl SmtStorage for RocksDbStorage { } } - // Helper for depth 24 operations - let is_depth_24 = |index: NodeIndex| index.depth() == IN_MEMORY_DEPTH; + // Helper for in-memory-depth operations + let is_in_mem_depth = |index: NodeIndex| index.depth() == IN_MEMORY_DEPTH; // Parallel preparation of subtree operations - let subtree_ops: Result, StorageError> = subtree_updates + let subtree_ops: StorageResult> = subtree_updates .into_par_iter() - .map(|update| -> Result<_, StorageError> { - let (index, maybe_bytes, depth24_op) = match update { + .map(|update| -> StorageResult<_> { + let (index, maybe_bytes, in_mem_depth_op) = match update { SubtreeUpdate::Store { index, subtree } => { let bytes = subtree.to_vec(); - let depth24_op = is_depth_24(index) + let in_mem_depth_op = is_in_mem_depth(index) .then(|| subtree.get_inner_node(index)) .flatten() .map(|root_node| { let hash_key = Self::index_db_key(index.position()); 
(hash_key, Some(root_node.hash().to_bytes())) }); - (index, Some(bytes), depth24_op) + (index, Some(bytes), in_mem_depth_op) }, SubtreeUpdate::Delete { index } => { - let depth24_op = is_depth_24(index).then(|| { + let in_mem_depth_op = is_in_mem_depth(index).then(|| { let hash_key = Self::index_db_key(index.position()); (hash_key, None) }); - (index, None, depth24_op) + (index, None, in_mem_depth_op) }, }; let key = Self::subtree_db_key(index); let subtrees_cf = self.subtree_cf(index); - Ok((subtrees_cf, key, maybe_bytes, depth24_op)) + Ok((subtrees_cf, key, maybe_bytes, in_mem_depth_op)) }) .collect(); // Sequential batch building - for (subtrees_cf, key, maybe_bytes, depth24_op) in subtree_ops? { + for (subtrees_cf, key, maybe_bytes, in_mem_depth_op) in subtree_ops? { match maybe_bytes { Some(bytes) => batch.put_cf(subtrees_cf, key, bytes), None => batch.delete_cf(subtrees_cf, key), } - if let Some((hash_key, maybe_hash_bytes)) = depth24_op { + if let Some((hash_key, maybe_hash_bytes)) = in_mem_depth_op { match maybe_hash_bytes { - Some(hash_bytes) => batch.put_cf(depth24_cf, hash_key, hash_bytes), - None => batch.delete_cf(depth24_cf, hash_key), + Some(hash_bytes) => batch.put_cf(in_mem_depth_cf, hash_key, hash_bytes), + None => batch.delete_cf(in_mem_depth_cf, hash_key), } } } @@ -1003,22 +1014,21 @@ impl Drop for RocksDbStorage { /// An iterator over leaves directly from RocksDB. /// /// Wraps a `DBIteratorWithThreadMode` and handles deserialization of keys to `u64` (leaf index) -/// and values to `SmtLeaf`. Skips items that fail to deserialize or if a RocksDB error occurs -/// for an item, attempting to continue iteration. +/// and values to `SmtLeaf`. 
struct RocksDbDirectLeafIterator<'a> { iter: DBIteratorWithThreadMode<'a, DB>, } impl Iterator for RocksDbDirectLeafIterator<'_> { - type Item = (u64, SmtLeaf); + type Item = StorageResult<(u64, SmtLeaf)>; fn next(&mut self) -> Option { - self.iter.find_map(|result| { - let (key_bytes, value_bytes) = result.ok()?; - let leaf_idx = index_from_key_bytes(&key_bytes).ok()?; - let leaf = - SmtLeaf::read_from_bytes_with_budget(&value_bytes, value_bytes.len()).ok()?; - Some((leaf_idx, leaf)) + self.iter.next().map(|result| { + result.map_err(StorageError::from).and_then(|(key_bytes, value_bytes)| { + let leaf_idx = index_from_key_bytes(&key_bytes)?; + let leaf = SmtLeaf::read_from_bytes_with_budget(&value_bytes, value_bytes.len())?; + Ok((leaf_idx, leaf)) + }) }) } } @@ -1057,31 +1067,32 @@ impl<'a> RocksDbSubtreeIterator<'a> { } } - fn try_next_from_iter( + fn next_from_iter( iter: &mut DBIteratorWithThreadMode, cf_index: usize, - ) -> Option { - iter.find_map(|result| { - let (key_bytes, value_bytes) = result.ok()?; - let depth = 24 + (cf_index * 8) as u8; - - let node_idx = subtree_root_from_key_bytes(&key_bytes, depth).ok()?; - let value_vec = value_bytes.into_vec(); - Subtree::from_vec(node_idx, &value_vec).ok() + ) -> Option> { + iter.next().map(|result| { + result.map_err(StorageError::from).and_then(|(key_bytes, value_bytes)| { + let depth = IN_MEMORY_DEPTH + (cf_index * 8) as u8; + + let node_idx = subtree_root_from_key_bytes(&key_bytes, depth)?; + let value_vec = value_bytes.into_vec(); + Ok(Subtree::from_vec(node_idx, &value_vec)?) 
+ }) }) } } impl Iterator for RocksDbSubtreeIterator<'_> { - type Item = Subtree; + type Item = StorageResult; fn next(&mut self) -> Option { loop { let iter = self.current_iter.as_mut()?; // Try to get the next valid subtree from current iterator - if let Some(subtree) = Self::try_next_from_iter(iter, self.current_cf_index) { - return Some(subtree); + if let Some(result) = Self::next_from_iter(iter, self.current_cf_index) { + return Some(result); } // Current CF exhausted, advance to next @@ -1210,7 +1221,7 @@ pub(crate) struct KeyBytes { impl KeyBytes { #[inline(always)] pub fn new(value: u64, keep: usize) -> Self { - debug_assert!((3..=7).contains(&keep)); + debug_assert!((2..=7).contains(&keep)); let bytes = value.to_be_bytes(); debug_assert!(bytes[..8 - keep].iter().all(|&b| b == 0)); Self { bytes, len: keep as u8 } @@ -1237,7 +1248,7 @@ impl AsRef<[u8]> for KeyBytes { /// /// # Errors /// - `StorageError::BadKeyLen`: If `key_bytes` is not 8 bytes long or conversion fails. -fn index_from_key_bytes(key_bytes: &[u8]) -> Result { +fn index_from_key_bytes(key_bytes: &[u8]) -> StorageResult { if key_bytes.len() != 8 { return Err(StorageError::BadKeyLen { expected: 8, found: key_bytes.len() }); } @@ -1246,7 +1257,7 @@ fn index_from_key_bytes(key_bytes: &[u8]) -> Result { Ok(u64::from_be_bytes(arr)) } -fn read_count(what: &'static str, bytes: &[u8]) -> Result { +fn read_count(what: &'static str, bytes: &[u8]) -> StorageResult { let arr: [u8; 8] = bytes.try_into().map_err(|_| StorageError::BadValueLen { what, expected: 8, @@ -1255,9 +1266,9 @@ fn read_count(what: &'static str, bytes: &[u8]) -> Result { Ok(usize::from_be_bytes(arr)) } -fn collect_depth24( +fn collect_to_subtree_roots( iter: DBIteratorWithThreadMode<'_, DB>, -) -> Result, StorageError> { +) -> StorageResult> { let mut hashes = Vec::new(); for item in iter { @@ -1280,14 +1291,16 @@ fn collect_depth24( /// - depth 40 → 5 bytes /// - depth 32 → 4 bytes /// - depth 24 → 3 bytes +/// - depth 16 → 2 bytes /// 
/// # Errors /// * `StorageError::Unsupported` - `depth` is not one of 24/32/40/48/56. /// * `StorageError::DeserializationError` - `key_bytes.len()` does not match the length required by /// `depth`. #[inline(always)] -fn subtree_root_from_key_bytes(key_bytes: &[u8], depth: u8) -> Result { +fn subtree_root_from_key_bytes(key_bytes: &[u8], depth: u8) -> StorageResult { let expected = match depth { + 16 => 2, 24 => 3, 32 => 4, 40 => 5, @@ -1309,6 +1322,7 @@ fn subtree_root_from_key_bytes(key_bytes: &[u8], depth: u8) -> Result &'static str { match depth { + 16 => SUBTREE_16_CF, 24 => SUBTREE_24_CF, 32 => SUBTREE_32_CF, 40 => SUBTREE_40_CF, @@ -1384,7 +1398,7 @@ impl RocksDbSnapshotStorage { } /// Retrieves a handle to a RocksDB column family by its name. - fn cf_handle(&self, name: &str) -> Result<&rocksdb::ColumnFamily, StorageError> { + fn cf_handle(&self, name: &str) -> StorageResult<&rocksdb::ColumnFamily> { self.inner .db .cf_handle(name) @@ -1400,7 +1414,7 @@ impl RocksDbSnapshotStorage { impl SmtStorageReader for RocksDbSnapshotStorage { /// Retrieves the total count of non-empty leaves from the snapshot. - fn leaf_count(&self) -> Result { + fn leaf_count(&self) -> StorageResult { let cf = self.cf_handle(METADATA_CF)?; self.inner .snapshot @@ -1409,7 +1423,7 @@ impl SmtStorageReader for RocksDbSnapshotStorage { } /// Retrieves the total count of key-value entries from the snapshot. - fn entry_count(&self) -> Result { + fn entry_count(&self) -> StorageResult { let cf = self.cf_handle(METADATA_CF)?; self.inner .snapshot @@ -1418,7 +1432,7 @@ impl SmtStorageReader for RocksDbSnapshotStorage { } /// Retrieves a single SMT leaf node by its logical `index` from the snapshot. - fn get_leaf(&self, index: u64) -> Result, StorageError> { + fn get_leaf(&self, index: u64) -> StorageResult> { let cf = self.cf_handle(LEAVES_CF)?; let key = RocksDbStorage::index_db_key(index); match self.inner.snapshot.get_cf(cf, key)? 
{ @@ -1431,7 +1445,7 @@ impl SmtStorageReader for RocksDbSnapshotStorage { } /// Retrieves multiple SMT leaf nodes by their logical `indices` from the snapshot. - fn get_leaves(&self, indices: &[u64]) -> Result>, StorageError> { + fn get_leaves(&self, indices: &[u64]) -> StorageResult>> { let cf = self.cf_handle(LEAVES_CF)?; let db_keys: Vec<[u8; 8]> = indices.iter().map(|&idx| RocksDbStorage::index_db_key(idx)).collect(); @@ -1450,12 +1464,12 @@ impl SmtStorageReader for RocksDbSnapshotStorage { } /// Returns true if the snapshot has any leaves. - fn has_leaves(&self) -> Result { + fn has_leaves(&self) -> StorageResult { Ok(self.leaf_count()? > 0) } /// Retrieves a single SMT Subtree by its root `NodeIndex` from the snapshot. - fn get_subtree(&self, index: NodeIndex) -> Result, StorageError> { + fn get_subtree(&self, index: NodeIndex) -> StorageResult> { let cf = self.subtree_cf(index); let key = RocksDbStorage::subtree_db_key(index); match self.inner.snapshot.get_cf(cf, key)? { @@ -1468,10 +1482,10 @@ impl SmtStorageReader for RocksDbSnapshotStorage { } /// Retrieves multiple subtrees from the snapshot. 
- fn get_subtrees(&self, indices: &[NodeIndex]) -> Result>, StorageError> { + fn get_subtrees(&self, indices: &[NodeIndex]) -> StorageResult>> { use p3_maybe_rayon::prelude::*; - let mut depth_buckets: [Vec<(usize, NodeIndex)>; 5] = Default::default(); + let mut depth_buckets: [Vec<(usize, NodeIndex)>; 6] = Default::default(); for (original_index, &node_index) in indices.iter().enumerate() { let depth = node_index.depth(); @@ -1481,6 +1495,7 @@ impl SmtStorageReader for RocksDbSnapshotStorage { 40 => 2, 32 => 3, 24 => 4, + 16 => 5, _ => { return Err(StorageError::Unsupported(format!( "unsupported subtree depth {depth}" @@ -1491,36 +1506,32 @@ impl SmtStorageReader for RocksDbSnapshotStorage { } let mut results = vec![None; indices.len()]; - let bucket_results: Result, StorageError> = depth_buckets + let bucket_results: StorageResult> = depth_buckets .into_par_iter() .enumerate() .filter(|(_, bucket)| !bucket.is_empty()) - .map( - |(bucket_index, bucket)| -> Result)>, StorageError> { - let depth = LargeSmt::::SUBTREE_DEPTHS[bucket_index]; - let cf = self.cf_handle(cf_for_depth(depth))?; - let keys: Vec<_> = bucket - .iter() - .map(|(_, idx)| RocksDbStorage::subtree_db_key(*idx)) - .collect(); - - let db_results = - self.inner.snapshot.multi_get_cf(keys.iter().map(|k| (cf, k.as_ref()))); - - bucket - .into_iter() - .zip(db_results) - .map(|((original_index, node_index), db_result)| { - let subtree = match db_result { - Ok(Some(bytes)) => Some(Subtree::from_vec(node_index, &bytes)?), - Ok(None) => None, - Err(e) => return Err(e.into()), - }; - Ok((original_index, subtree)) - }) - .collect() - }, - ) + .map(|(bucket_index, bucket)| -> StorageResult)>> { + let depth = LargeSmt::::SUBTREE_DEPTHS[bucket_index]; + let cf = self.cf_handle(cf_for_depth(depth))?; + let keys: Vec<_> = + bucket.iter().map(|(_, idx)| RocksDbStorage::subtree_db_key(*idx)).collect(); + + let db_results = + self.inner.snapshot.multi_get_cf(keys.iter().map(|k| (cf, k.as_ref()))); + + bucket + 
.into_iter() + .zip(db_results) + .map(|((original_index, node_index), db_result)| { + let subtree = match db_result { + Ok(Some(bytes)) => Some(Subtree::from_vec(node_index, &bytes)?), + Ok(None) => None, + Err(e) => return Err(e.into()), + }; + Ok((original_index, subtree)) + }) + .collect() + }) .collect(); for bucket_result in bucket_results? { @@ -1533,7 +1544,7 @@ impl SmtStorageReader for RocksDbSnapshotStorage { } /// Retrieves a single inner node from within a snapshot subtree. - fn get_inner_node(&self, index: NodeIndex) -> Result, StorageError> { + fn get_inner_node(&self, index: NodeIndex) -> StorageResult> { if index.depth() < IN_MEMORY_DEPTH { return Err(StorageError::Unsupported( "Cannot get inner node from upper part of the tree".into(), @@ -1546,7 +1557,9 @@ impl SmtStorageReader for RocksDbSnapshotStorage { } /// Returns an iterator over all leaves in this snapshot. - fn iter_leaves(&self) -> Result + '_>, StorageError> { + fn iter_leaves( + &self, + ) -> StorageResult> + '_>> { let cf = self.cf_handle(LEAVES_CF)?; let mut read_opts = ReadOptions::default(); read_opts.set_total_order_seek(true); @@ -1556,9 +1569,17 @@ impl SmtStorageReader for RocksDbSnapshotStorage { } /// Returns an iterator over all subtrees in this snapshot. - fn iter_subtrees(&self) -> Result + '_>, StorageError> { - const SUBTREE_CFS: [&str; 5] = - [SUBTREE_24_CF, SUBTREE_32_CF, SUBTREE_40_CF, SUBTREE_48_CF, SUBTREE_56_CF]; + fn iter_subtrees( + &self, + ) -> StorageResult> + '_>> { + const SUBTREE_CFS: [&str; 6] = [ + SUBTREE_16_CF, + SUBTREE_24_CF, + SUBTREE_32_CF, + SUBTREE_40_CF, + SUBTREE_48_CF, + SUBTREE_56_CF, + ]; let mut cf_handles = Vec::new(); for cf_name in SUBTREE_CFS { @@ -1568,11 +1589,11 @@ impl SmtStorageReader for RocksDbSnapshotStorage { Ok(Box::new(RocksDbSnapshotSubtreeIterator::new(&self.inner.snapshot, cf_handles))) } - /// Retrieves all depth 24 hashes from this snapshot. 
- fn get_depth24(&self) -> Result, StorageError> { - let cf = self.cf_handle(DEPTH_24_CF)?; + /// Retrieves roots of all top level subtrees for efficient startup reconstruction. + fn get_top_subtree_roots(&self) -> StorageResult> { + let cf = self.cf_handle(IN_MEM_DEPTH_CF)?; let iter = self.inner.snapshot.iterator_cf(cf, IteratorMode::Start); - collect_depth24(iter) + collect_to_subtree_roots(iter) } } @@ -1613,16 +1634,16 @@ impl<'a> RocksDbSnapshotSubtreeIterator<'a> { } impl Iterator for RocksDbSnapshotSubtreeIterator<'_> { - type Item = Subtree; + type Item = StorageResult; fn next(&mut self) -> Option { loop { let iter = self.current_iter.as_mut()?; - if let Some(subtree) = - RocksDbSubtreeIterator::try_next_from_iter(iter, self.current_cf_index) + if let Some(result) = + RocksDbSubtreeIterator::next_from_iter(iter, self.current_cf_index) { - return Some(subtree); + return Some(result); } self.current_cf_index += 1; diff --git a/miden-crypto/src/merkle/smt/large/tests.rs b/miden-crypto/src/merkle/smt/large/tests.rs index 442416d290..198d1ebb67 100644 --- a/miden-crypto/src/merkle/smt/large/tests.rs +++ b/miden-crypto/src/merkle/smt/large/tests.rs @@ -1,6 +1,6 @@ use alloc::{collections::BTreeSet, vec::Vec}; -use rand::{Rng, prelude::IteratorRandom, rng}; +use rand::{RngExt, prelude::IteratorRandom, rng}; use super::{IN_MEMORY_DEPTH, MemoryStorage, SmtStorage, SmtStorageReader}; use crate::{ @@ -35,7 +35,7 @@ fn generate_updates(entries: Vec<(Word, Word)>, updates: usize) -> Vec<(Word, Wo ); let mut sorted_entries: Vec<(Word, Word)> = entries .into_iter() - .choose_multiple(&mut rng, updates) + .sample(&mut rng, updates) .into_iter() .map(|(key, _)| { let value = if rng.random_bool(REMOVAL_PROBABILITY) { @@ -105,7 +105,8 @@ fn test_equivalent_entry_sets() { let (control_smt, large_smt) = create_equivalent_smts_for_testing(storage, entries); let mut entries_control_smt_owned: Vec<(Word, Word)> = control_smt.entries().copied().collect(); - let mut 
entries_large_smt: Vec<(Word, Word)> = large_smt.entries().unwrap().collect(); + let mut entries_large_smt: Vec<(Word, Word)> = + large_smt.entries().unwrap().collect::, _>>().unwrap(); entries_control_smt_owned.sort_by_key(|k| k.0); entries_large_smt.sort_by_key(|k| k.0); @@ -124,7 +125,7 @@ fn test_equivalent_leaf_sets() { let mut leaves_control_smt: Vec<(LeafIndex, SmtLeaf)> = control_smt.leaves().map(|(idx, leaf_ref)| (idx, leaf_ref.clone())).collect(); let mut leaves_large_smt: Vec<(LeafIndex, SmtLeaf)> = - large_smt.leaves().unwrap().collect(); + large_smt.leaves().unwrap().collect::, _>>().unwrap(); leaves_control_smt.sort_by_key(|k| k.0); leaves_large_smt.sort_by_key(|k| k.0); @@ -142,7 +143,8 @@ fn test_equivalent_inner_nodes() { let (control_smt, large_smt) = create_equivalent_smts_for_testing(storage, entries); let mut control_smt_inner_nodes: Vec = control_smt.inner_nodes().collect(); - let mut large_smt_inner_nodes: Vec = large_smt.inner_nodes().unwrap().collect(); + let mut large_smt_inner_nodes: Vec = + large_smt.inner_nodes().unwrap().collect::, _>>().unwrap(); control_smt_inner_nodes.sort_by_key(|info| info.value); large_smt_inner_nodes.sort_by_key(|info| info.value); @@ -185,10 +187,18 @@ fn test_empty_smt() { "get_value on empty SMT should return EMPTY_WORD" ); - assert_eq!(large_smt.entries().unwrap().count(), 0, "Empty SMT should have no entries"); - assert_eq!(large_smt.leaves().unwrap().count(), 0, "Empty SMT should have no leaves"); assert_eq!( - large_smt.inner_nodes().unwrap().count(), + large_smt.entries().unwrap().collect::, _>>().unwrap().len(), + 0, + "Empty SMT should have no entries" + ); + assert_eq!( + large_smt.leaves().unwrap().collect::, _>>().unwrap().len(), + 0, + "Empty SMT should have no leaves" + ); + assert_eq!( + large_smt.inner_nodes().unwrap().collect::, _>>().unwrap().len(), 0, "Empty SMT should have no inner nodes" ); @@ -210,7 +220,7 @@ fn test_single_entry_smt() { let other_key = Word::from([2_u32, 2_u32, 2_u32, 
2_u32]); assert_eq!(smt.get_value(&other_key), EMPTY_WORD, "get_value for non-existing key failed"); - let entries: Vec<_> = smt.entries().unwrap().collect(); + let entries: Vec<_> = smt.entries().unwrap().collect::, _>>().unwrap(); assert_eq!(entries.len(), 1, "Single entry SMT should have one entry"); assert_eq!(entries[0], (key, value), "Single entry SMT entry mismatch"); @@ -241,7 +251,11 @@ fn test_single_entry_smt() { let empty_control_smt = Smt::new(); assert_eq!(smt.root(), empty_control_smt.root(), "SMT root after deletion mismatch"); assert_eq!(smt.get_value(&key), EMPTY_WORD, "get_value after deletion failed"); - assert_eq!(smt.entries().unwrap().count(), 0, "SMT should have no entries after deletion"); + assert_eq!( + smt.entries().unwrap().collect::, _>>().unwrap().len(), + 0, + "SMT should have no entries after deletion" + ); } #[test] @@ -337,7 +351,7 @@ fn test_delete_entry() { "get_value for deleted key should be EMPTY_WORD" ); - let current_entries: Vec<_> = smt.entries().unwrap().collect(); + let current_entries: Vec<_> = smt.entries().unwrap().collect::, _>>().unwrap(); assert!( !current_entries.iter().any(|(k, _v)| k == &key2), "Deleted key should not be in entries" @@ -731,11 +745,11 @@ fn test_flat_layout_children_relationship() { } } -/// Verifies that a snapshot produced by `MemoryStorage::reader()` returns correct depth-24 roots -/// from `get_depth24()`, and that loading a `LargeSmt` from that snapshot reconstructs the same -/// root as the original tree. +/// Verifies that a snapshot produced by `MemoryStorage::reader()` returns correct top subtree +/// roots from `get_top_subtree_roots()`, and that loading a `LargeSmt` from that snapshot +/// reconstructs the same root as the original tree. 
#[test] -fn test_memory_storage_snapshot_depth24() { +fn test_memory_storage_snapshot_in_mem_depth() { use crate::merkle::NodeIndex; let entries = generate_entries(50); @@ -745,17 +759,20 @@ fn test_memory_storage_snapshot_depth24() { let snapshot = smt.storage.reader().unwrap(); - // The depth-24 entries must be non-empty for a non-empty tree. - let depth24 = snapshot.get_depth24().unwrap(); - assert!(!depth24.is_empty(), "snapshot must expose depth-24 roots for a non-empty tree"); + // The in-memory-depth entries must be non-empty for a non-empty tree. + let in_mem_roots = snapshot.get_top_subtree_roots().unwrap(); + assert!( + !in_mem_roots.is_empty(), + "snapshot must expose in-memory-depth roots for a non-empty tree" + ); // Every returned entry must sit exactly at IN_MEMORY_DEPTH. - for (position, _hash) in &depth24 { + for (position, _hash) in &in_mem_roots { let index = NodeIndex::new(IN_MEMORY_DEPTH, *position).unwrap(); assert_eq!( index.depth(), IN_MEMORY_DEPTH, - "depth-24 entry at position {position} has wrong depth" + "in-memory-depth entry at position {position} has wrong depth" ); } diff --git a/miden-crypto/src/merkle/smt/large_forest/backend/memory/mod.rs b/miden-crypto/src/merkle/smt/large_forest/backend/memory/mod.rs index 0b6aa1a475..ab7cb18140 100644 --- a/miden-crypto/src/merkle/smt/large_forest/backend/memory/mod.rs +++ b/miden-crypto/src/merkle/smt/large_forest/backend/memory/mod.rs @@ -7,19 +7,72 @@ mod tests; use alloc::vec::Vec; +#[cfg(test)] +use crate::merkle::smt::large_forest::operation::SmtUpdateBatch; use crate::{ EMPTY_WORD, Map, Word, - merkle::smt::{ - LeafIndex, SMT_DEPTH, Smt, SmtLeaf, SmtProof, VersionId, - large_forest::{ - Backend, - backend::{BackendError, MutationSet, Result}, - operation::{SmtForestUpdateBatch, SmtUpdateBatch}, - root::{LineageId, TreeEntry, TreeWithRoot}, + merkle::{ + MerkleError, + smt::{ + LeafIndex, SMT_DEPTH, Smt, SmtLeaf, SmtProof, VersionId, + large_forest::{ + Backend, BackendReader, + 
backend::{BackendError, Result}, + operation::SmtForestUpdateBatch, + root::{LineageId, TreeEntry, TreeWithRoot}, + utils::{ + AppliedLineageMutation, LineageMutation, LineageMutationKind, MutationSet, + }, + }, }, }, }; +// IN-MEMORY BACKEND SNAPSHOT +// ================================================================================================ + +/// A read-only, point-in-time snapshot of an [`InMemoryBackend`]. +/// +/// This type intentionally implements only [`BackendReader`], not [`Backend`]. It is returned by +/// [`InMemoryBackend::reader`] to hand out a detached copy of the backend state without exposing +/// any mutation capabilities. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct InMemoryBackendSnapshot(InMemoryBackend); + +impl BackendReader for InMemoryBackendSnapshot { + fn open(&self, lineage: LineageId, key: Word) -> Result { + self.0.open(lineage, key) + } + + fn get_leaf(&self, lineage: LineageId, leaf_index: LeafIndex) -> Result { + self.0.get_leaf(lineage, leaf_index) + } + + fn get(&self, lineage: LineageId, key: Word) -> Result> { + self.0.get(lineage, key) + } + + fn version(&self, lineage: LineageId) -> Result { + self.0.version(lineage) + } + + fn lineages(&self) -> Result> { + self.0.lineages() + } + + fn trees(&self) -> Result> { + self.0.trees() + } + + fn entry_count(&self, lineage: LineageId) -> Result { + self.0.entry_count(lineage) + } + + fn entries(&self, lineage: LineageId) -> Result>> { + self.0.entries(lineage) + } +} + // IN-MEMORY BACKEND // ================================================================================================ @@ -31,18 +84,69 @@ pub struct InMemoryBackend { trees: Map, } +/// Prepared mutations for [`InMemoryBackend`]. +/// +/// This is the in-memory backend's concrete [`Backend::PreparedMutations`] type. It stores the +/// forward SMT mutation sets that were computed during the first phase of a forest update. 
Applying +/// it mutates the in-memory trees directly without recomputing the update batches. +/// +/// The fields are private because callers should treat prepared mutation data as opaque and pass it +/// back through +/// [`LargeSmtForest::apply_mutations`](crate::merkle::smt::LargeSmtForest::apply_mutations). +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct InMemoryPreparedMutations { + entries: Vec, +} + +#[derive(Clone, Debug, Eq, PartialEq)] +struct InMemoryPreparedLineageMutation { + lineage: LineageId, + old_version: Option, + version: VersionId, + forward: MutationSet, + kind: LineageMutationKind, +} + impl InMemoryBackend { /// Constructs a new instance of the in-memory backend. pub fn new() -> Self { let trees = Map::default(); Self { trees } } + + /// Converts this backend into a read-only snapshot. + pub fn into_snapshot(self) -> InMemoryBackendSnapshot { + InMemoryBackendSnapshot(self) + } + + fn mutation_from_tree( + lineage: LineageId, + old_version: Option, + new_version: VersionId, + kind: LineageMutationKind, + forward: MutationSet, + ) -> (LineageMutation, InMemoryPreparedLineageMutation) { + let old_root = forward.old_root(); + let new_root = forward.root(); + + let mutation = + LineageMutation::new(lineage, old_version, new_version, old_root, new_root, kind); + let prepared = InMemoryPreparedLineageMutation { + lineage, + old_version, + version: new_version, + forward, + kind, + }; + + (mutation, prepared) + } } -// BACKEND TRAIT +// BACKEND READER TRAIT // ================================================================================================ -impl Backend for InMemoryBackend { +impl BackendReader for InMemoryBackend { /// Returns an opening for the specified `key` in the SMT with the specified `lineage`. 
/// /// # Errors @@ -134,7 +238,163 @@ impl Backend for InMemoryBackend { let tree = self.trees.get(&lineage).ok_or(BackendError::UnknownLineage(lineage))?; Ok(tree.tree.entries().map(|(k, v)| Ok(TreeEntry { key: *k, value: *v }))) } +} + +// BACKEND TRAIT +// ================================================================================================ + +impl Backend for InMemoryBackend { + type Reader = InMemoryBackendSnapshot; + type PreparedMutations = InMemoryPreparedMutations; + + fn reader(&self) -> Result { + Ok(self.clone().into_snapshot()) + } + + /// Computes the mutations required to apply the provided `updates` on the forest. + fn compute_mutations( + &self, + new_version: VersionId, + updates: SmtForestUpdateBatch, + ) -> Result<(Vec, Self::PreparedMutations)> { + let updates = updates.into_iter().collect::>(); + let mut mutations = Vec::with_capacity(updates.len()); + let mut prepared = Vec::with_capacity(updates.len()); + for (lineage, ops) in updates { + let (old_version, kind, forward) = if let Some(tree_data) = self.trees.get(&lineage) { + ( + Some(tree_data.version), + LineageMutationKind::UpdateTree, + tree_data.tree.compute_mutations(ops.into_iter().map(Into::into))?, + ) + } else { + ( + None, + LineageMutationKind::AddLineage, + Smt::new().compute_mutations(ops.into_iter().map(Into::into))?, + ) + }; + let (mutation, prepared_entry) = + Self::mutation_from_tree(lineage, old_version, new_version, kind, forward); + mutations.push(mutation); + prepared.push(prepared_entry); + } + + Ok((mutations, InMemoryPreparedMutations { entries: prepared })) + } + + /// Apply a mutation set to the entire forest, returning the mutation sets that would reverse + /// the changes to each lineage in the forest. + /// + /// - [`BackendError::Merkle`] if an error occurs with the merkle tree semantics. + /// - [`BackendError::UnknownLineage`] if the provided `lineage` is not known by this backend. 
+ fn apply_mutations( + &mut self, + mutations: Self::PreparedMutations, + ) -> Result> { + // We start by checking that all lineages referred to in the `mutations` are valid, + // failing early with an error if need be. + for mutation in &mutations.entries { + match mutation.kind { + LineageMutationKind::AddLineage => { + if self.trees.contains_key(&mutation.lineage) { + return Err(BackendError::DuplicateLineage(mutation.lineage)); + } + }, + LineageMutationKind::UpdateTree => { + let tree_data = self + .trees + .get(&mutation.lineage) + .ok_or(BackendError::UnknownLineage(mutation.lineage))?; + + if Some(tree_data.version) != mutation.old_version { + return Err(BackendError::BadVersion { + provided: mutation.old_version.unwrap_or_default(), + latest: tree_data.version, + }); + } + + let old_root = mutation.forward.old_root(); + let latest_root = tree_data.tree.root(); + if latest_root != old_root { + return Err(MerkleError::ConflictingRoots { + expected_root: old_root, + actual_root: latest_root, + } + .into()); + } + }, + } + } + + let mut applied = Vec::with_capacity(mutations.entries.len()); + + // Apply mutations to each lineage. 
+ for mutation in mutations.entries { + let old_root = mutation.forward.old_root(); + let new_root = mutation.forward.root(); + match mutation.kind { + LineageMutationKind::AddLineage => { + let mut tree = Smt::new(); + let reverse = MutationSet::default(); + + if !mutation.forward.is_empty() { + tree.apply_mutations(mutation.forward) + .map_err(BackendError::internal_from)?; + } + + applied.push(AppliedLineageMutation::new( + mutation.lineage, + mutation.old_version, + mutation.version, + old_root, + new_root, + 0, + reverse, + mutation.kind, + )); + self.trees + .insert(mutation.lineage, TreeData { version: mutation.version, tree }); + }, + LineageMutationKind::UpdateTree => { + let tree_data = self + .trees + .get_mut(&mutation.lineage) + .ok_or(BackendError::UnknownLineage(mutation.lineage))?; + let old_entry_count = tree_data.tree.num_entries(); + + let reverse = if mutation.forward.is_empty() { + mutation.forward + } else { + let reverse = tree_data + .tree + .apply_mutations_with_reversion(mutation.forward) + .map_err(BackendError::internal_from)?; + tree_data.version = mutation.version; + reverse + }; + + applied.push(AppliedLineageMutation::new( + mutation.lineage, + mutation.old_version, + mutation.version, + old_root, + new_root, + old_entry_count, + reverse, + mutation.kind, + )); + }, + } + } + + Ok(applied) + } +} +// These are the implementations of helper methods used by the backend tests. +#[cfg(test)] +impl InMemoryBackend { /// Adds the provided `lineage` to the forest. /// /// # Errors @@ -142,35 +402,27 @@ impl Backend for InMemoryBackend { /// - [`BackendError::DuplicateLineage`] if the provided `lineage` is the same as an /// already-known lineage. No data is changed in this case. /// - [`BackendError::Merkle`] if the provided `updates` cannot be applied to the empty tree. 
- fn add_lineage( + pub(crate) fn add_lineage( &mut self, lineage: LineageId, version: VersionId, updates: SmtUpdateBatch, ) -> Result { - // Returning this in the case of a duplicate lineage is required by the method contract on - // the `Backend` trait. if self.trees.contains_key(&lineage) { return Err(BackendError::DuplicateLineage(lineage)); } - let mut tree = Smt::new(); - - // A failure to compute mutations is a failure derived from user input, so we forward it as - // appropriate. - let mutations = tree.compute_mutations(updates.into_iter().map(Into::into))?; + let mut batch = SmtForestUpdateBatch::empty(); + batch.operations(lineage).add_operations(updates.into_iter()); + let (_mutations, persistent_mutations) = self.compute_mutations(version, batch)?; - // If computation of the mutations has succeeded but the application fails, then this should - // be reported as an internal error, not a merkle error, to allow the caller to decide what - // to do. - tree.apply_mutations(mutations).map_err(BackendError::internal_from)?; + let mut applied_mutations = self.apply_mutations(persistent_mutations)?; + let applied_mutation = applied_mutations + .pop() + .expect("should have applied exactly one lineage mutation"); - // The following has had its preconditions checked, so we can change the state without - // worrying about consistency. - let tree_data = TreeData { version, tree }; - let root = tree_data.tree.root(); - self.trees.insert(lineage, tree_data); - Ok(TreeWithRoot::new(lineage, version, root)) + // Finally we just return the necessary metadata. + Ok(TreeWithRoot::new(lineage, version, applied_mutation.new_root())) } /// Performs the provided `updates` on the tree with the specified `lineage`, returning the @@ -183,43 +435,27 @@ impl Backend for InMemoryBackend { /// - [`BackendError::Merkle`] if the application of `updates` to the tree fails for any reason. 
/// - [`BackendError::UnknownLineage`] If the provided `lineage` is one not known by this /// backend. - fn update_tree( + pub(crate) fn update_tree( &mut self, lineage: LineageId, new_version: VersionId, updates: SmtUpdateBatch, ) -> Result { - // The method contract requires raising this error in the case that `lineage` is unknown to - // the backend. - let tree_data = - self.trees.get_mut(&lineage).ok_or(BackendError::UnknownLineage(lineage))?; - let tree = &mut tree_data.tree; - - // We compute the mutations as a precondition check, which will leave the underlying tree in - // the same state if anything errors. Any error this yields is considered to be derived from - // user-input and hence is forwarded as-is. - let mutations = tree.compute_mutations(updates.into_iter().map(Into::into))?; - - // The invariants on this method given by the `Backend` trait states that no new allocations - // should be performed if the updates do not change the tree. As a result, we can - // short-circuit even trying. - if mutations.is_empty() { - // As the reverse of an empty mutations is also empty mutations, we can just return - // that. - return Ok(mutations); + if !self.trees.contains_key(&lineage) { + return Err(BackendError::UnknownLineage(lineage)); } - // Any failure to apply the mutations here is considered an internal error, so we transform - // it as such. - let reversion_set = tree - .apply_mutations_with_reversion(mutations) - .map_err(BackendError::internal_from)?; + let mut batch = SmtForestUpdateBatch::empty(); + batch.operations(lineage).add_operations(updates.into_iter()); + let (_mutations, persistent_mutations) = self.compute_mutations(new_version, batch)?; - // With preconditions checked, we can actually perform our modifications as it should yield - // a consistent state. 
- tree_data.version = new_version; + let mut applied_mutations = self.apply_mutations(persistent_mutations)?; + let applied_mutation = applied_mutations + .pop() + .expect("should have applied exactly one lineage mutation"); - Ok(reversion_set) + // We then just return the reversion set for the operations in question. + Ok(applied_mutation.into_reverse()) } /// Adds multiple new `lineages` to the tree, creating an empty tree for each and applying the @@ -234,54 +470,26 @@ impl Backend for InMemoryBackend { /// lineage. No data is changed in this case. /// - [`BackendError::Merkle`] if any of the provided updates cannot be applied on top of the /// empty tree. - fn add_lineages( + pub(crate) fn add_lineages( &mut self, version: VersionId, lineages: SmtForestUpdateBatch, ) -> Result> { - // We start by checking that all lineages referred to in the batch of `updates` are valid, - // failing early with an error if need be. - let updates = lineages - .into_iter() - .map(|(lineage, ops)| { - if self.trees.contains_key(&lineage) { - return Err(BackendError::DuplicateLineage(lineage)); - } - Ok((lineage, ops)) - }) - .collect::>>()?; - - // Next, we compute all the relevant mutations to each tree, also failing with an error - // where relevant. - let mutations = updates - .into_iter() - .map(|(lineage, ops)| { - let tree = Smt::new(); - let mutations = tree.compute_mutations(ops.into_iter().map(Into::into))?; - Ok((lineage, tree, mutations)) - }) - .collect::>>()?; - - // With the preconditions checked, we can unconditionally perform the changes on all trees. - // We apply all mutations first without modifying self.trees, so that if any mutation - // fails, no data is changed. - let applied = mutations - .into_iter() - .map(|(lineage, mut tree, mutations)| { - tree.apply_mutations(mutations).map_err(BackendError::internal_from)?; - Ok((lineage, tree)) - }) - .collect::>>()?; - - // Once they have all succeeded, we then modify the data in memory. 
- let results = applied + for lineage in lineages.lineages() { + if self.trees.contains_key(lineage) { + return Err(BackendError::DuplicateLineage(*lineage)); + } + } + + let (_mutations, persistent_mutations) = self.compute_mutations(version, lineages)?; + + let applied_mutations = self.apply_mutations(persistent_mutations)?; + + // Build the return value from the applied mutations. + let results = applied_mutations .into_iter() - .map(|(lineage, tree)| { - let root = tree.root(); - self.trees.insert(lineage, TreeData { version, tree }); - (lineage, TreeWithRoot::new(lineage, version, root)) - }) - .collect::>(); + .map(|applied_mutation| (applied_mutation.lineage(), applied_mutation.result())) + .collect(); Ok(results) } @@ -301,54 +509,26 @@ impl Backend for InMemoryBackend { /// # Panics /// /// - If a tree that has been checked to be present is not present upon later access. - fn update_forest( + pub(crate) fn update_forest( &mut self, new_version: VersionId, updates: SmtForestUpdateBatch, ) -> Result> { - // We start by checking that all lineages referred to in the batch of `updates` are valid, - // failing early with an error if need be. - let updates = updates - .into_iter() - .map(|(lineage, ops)| { - if !self.trees.contains_key(&lineage) { - return Err(BackendError::UnknownLineage(lineage)); - } - - Ok((lineage, ops)) - }) - .collect::>>()?; - - // Next, we compute all the relevant mutations to each tree, also failing with an error - // where relevant. - let mutations = updates - .into_iter() - .map(|(lineage, ops)| { - let tree = self.trees.get(&lineage).expect("Tree known to be present was not"); - let mutations = tree.tree.compute_mutations(ops.into_iter().map(Into::into))?; - Ok((lineage, mutations)) - }) - .collect::>>()?; - - // With the preconditions checked, we can unconditionally perform the changes on all trees. 
- let reversion_sets = mutations + for lineage in updates.lineages() { + if !self.trees.contains_key(lineage) { + return Err(BackendError::UnknownLineage(*lineage)); + } + } + + let (_mutations, persistent_mutations) = self.compute_mutations(new_version, updates)?; + + let applied_mutations = self.apply_mutations(persistent_mutations)?; + + // Build the return value from the applied mutations. + let reversion_sets = applied_mutations .into_iter() - .map(|(lineage, mutations)| { - if mutations.is_empty() { - // The inverse of empty mutations is empty mutations. - Ok((lineage, mutations)) - } else { - let tree = - self.trees.get_mut(&lineage).expect("Tree known to be present was not"); - let reversion = tree - .tree - .apply_mutations_with_reversion(mutations) - .map_err(BackendError::internal_from)?; - tree.version = new_version; - Ok((lineage, reversion)) - } - }) - .collect::>>()?; + .map(|applied_mutation| (applied_mutation.lineage(), applied_mutation.into_reverse())) + .collect(); Ok(reversion_sets) } diff --git a/miden-crypto/src/merkle/smt/large_forest/backend/memory/property_tests.rs b/miden-crypto/src/merkle/smt/large_forest/backend/memory/property_tests.rs index 466ab48294..efc68cafe9 100644 --- a/miden-crypto/src/merkle/smt/large_forest/backend/memory/property_tests.rs +++ b/miden-crypto/src/merkle/smt/large_forest/backend/memory/property_tests.rs @@ -9,7 +9,7 @@ use proptest::prelude::*; use crate::{ EMPTY_WORD, merkle::smt::{ - Backend, Smt, SmtForestUpdateBatch, SmtUpdateBatch, TreeWithRoot, + BackendReader, Smt, SmtForestUpdateBatch, SmtUpdateBatch, TreeWithRoot, large_forest::{ InMemoryBackend, test_utils::{ diff --git a/miden-crypto/src/merkle/smt/large_forest/backend/memory/tests.rs b/miden-crypto/src/merkle/smt/large_forest/backend/memory/tests.rs index 3268043ae8..f9baed08ef 100644 --- a/miden-crypto/src/merkle/smt/large_forest/backend/memory/tests.rs +++ b/miden-crypto/src/merkle/smt/large_forest/backend/memory/tests.rs @@ -13,7 +13,7 @@ use 
itertools::Itertools; use crate::{ EMPTY_WORD, Word, merkle::smt::{ - Backend, BackendError, Smt, SmtForestUpdateBatch, SmtUpdateBatch, VersionId, + Backend, BackendError, BackendReader, Smt, SmtForestUpdateBatch, SmtUpdateBatch, VersionId, large_forest::{ InMemoryBackend, backend::Result, @@ -622,6 +622,85 @@ fn update_tree() -> Result<()> { Ok(()) } +#[test] +fn apply_mutations_returns_reversion_data() -> Result<()> { + let mut backend = InMemoryBackend::new(); + let mut rng = ContinuousRng::new([0x77; 32]); + + let lineage: LineageId = rng.value(); + let version_1: VersionId = rng.value(); + let version_2: VersionId = rng.value(); + let key_1: Word = rng.value(); + let value_1: Word = rng.value(); + let key_2: Word = rng.value(); + let value_2: Word = rng.value(); + let value_3: Word = rng.value(); + + let mut initial = SmtUpdateBatch::default(); + initial.add_insert(key_1, value_1); + initial.add_insert(key_2, value_2); + backend.add_lineage(lineage, version_1, initial)?; + + let mut reference = Smt::new(); + reference.insert(key_1, value_1)?; + reference.insert(key_2, value_2)?; + let old_entry_count = reference.num_entries(); + + let mut updates = SmtUpdateBatch::default(); + updates.add_insert(key_2, value_3); + let expected_forward = reference.compute_mutations([(key_2, value_3)])?; + let expected_reverse = reference.apply_mutations_with_reversion(expected_forward)?; + + let mut batch = SmtForestUpdateBatch::empty(); + batch.operations(lineage).add_operations(updates.into_iter()); + let (_, prepared) = backend.compute_mutations(version_2, batch)?; + let applied = backend.apply_mutations(prepared)?; + + assert_eq!(applied.len(), 1); + assert_eq!(applied[0].lineage(), lineage); + assert_eq!(applied[0].old_entry_count(), old_entry_count); + assert_eq!(applied[0].reverse(), &expected_reverse); + assert_eq!(backend.version(lineage)?, version_2); + assert!(backend.trees()?.any(|tree| tree.root() == reference.root())); + + Ok(()) +} + +#[test] +fn 
apply_mutations_rejects_stale_prepared_update() -> Result<()> { + let mut backend = InMemoryBackend::new(); + let mut rng = ContinuousRng::new([0xa5; 32]); + + let lineage: LineageId = rng.value(); + let key_1: Word = rng.value(); + let value_1: Word = rng.value(); + let key_2: Word = rng.value(); + let value_2: Word = rng.value(); + let key_3: Word = rng.value(); + let value_3: Word = rng.value(); + + let mut initial = SmtUpdateBatch::default(); + initial.add_insert(key_1, value_1); + backend.add_lineage(lineage, 1, initial)?; + + let mut stale_updates = SmtUpdateBatch::default(); + stale_updates.add_insert(key_2, value_2); + let mut stale_batch = SmtForestUpdateBatch::empty(); + stale_batch.operations(lineage).add_operations(stale_updates.into_iter()); + let (_visible, stale_prepared) = backend.compute_mutations(2, stale_batch)?; + + let mut intervening_updates = SmtUpdateBatch::default(); + intervening_updates.add_insert(key_3, value_3); + backend.update_tree(lineage, 2, intervening_updates)?; + + assert!( + backend.apply_mutations(stale_prepared).is_err(), + "stale prepared mutations must not apply after the lineage root changes" + ); + + Ok(()) +} + #[test] fn update_forest() -> Result<()> { let mut backend = InMemoryBackend::new(); diff --git a/miden-crypto/src/merkle/smt/large_forest/backend/mod.rs b/miden-crypto/src/merkle/smt/large_forest/backend/mod.rs index fc6076fcc4..7536f517f6 100644 --- a/miden-crypto/src/merkle/smt/large_forest/backend/mod.rs +++ b/miden-crypto/src/merkle/smt/large_forest/backend/mod.rs @@ -17,25 +17,26 @@ use crate::{ smt::{ LeafIndex, SMT_DEPTH, SmtLeaf, SmtProof, large_forest::{ - operation::{SmtForestUpdateBatch, SmtUpdateBatch}, + operation::SmtForestUpdateBatch, root::{LineageId, TreeEntry, TreeWithRoot, VersionId}, - utils::MutationSet, + utils::{AppliedLineageMutation, LineageMutation}, }, }, }, }; -// BACKEND +// BACKEND READER // 
================================================================================================ -/// The backing storage for the SMT forest, providing the necessary high-level methods for -/// performing operations on the full trees that make up the forest, while allowing the forest -/// itself to be storage agnostic. +/// The read-only interface for the SMT forest storage backend. +/// +/// This trait provides the query operations necessary to read the full trees that make up the +/// forest. It is a supertrait of [`Backend`], which extends it with write operations. /// /// # Backend Data Storage /// -/// Having a generic [`Backend`] provides no guarantees to the user about how it stores data and -/// what patterns are used for data access under the hood. It is, however, guaranteed to store +/// Having a generic [`BackendReader`] provides no guarantees to the user about how it stores data +/// and what patterns are used for data access under the hood. It is, however, guaranteed to store /// _only_ the data necessary to describe the latest state of each tree in the forest. /// /// # Error Handling @@ -56,12 +57,6 @@ use crate::{ /// /// # Expected Behavior /// -/// Certain methods on this trait (e.g. [`Backend::update_tree`]) provide behaviors expected for -/// that method. These combine with the following trait-level behavior requirements to become part -/// of the contract of the method, but a portion that cannot be encoded in the type system. Any -/// failure to conform to these expected behaviors is **considered a bug in the implementation** of -/// the backend, and must be rectified. -/// /// The following behavior is expected of all methods in implementations of this trait: /// /// - For any failure derived from user input (see _User-Derived Errors_ above), the data and the @@ -70,13 +65,10 @@ use crate::{ /// caller by returning a variant of [`BackendError`] that is **not [`BackendError::Internal`]**. 
/// Methods may place additional constraints on which errors are used to signal certain failures. /// Such failures should not lead to data corruption of any persistent data. -pub trait Backend +pub trait BackendReader where Self: Debug, { - // QUERIES - // ============================================================================================ - /// Returns an opening for the specified `key` in the SMT with the specified `lineage`. /// /// It is the responsibility of the forest to ensure lineage existence before querying the @@ -148,80 +140,93 @@ where /// - `None` will be returned upon successful completion, or at any time after an error has been /// returned. fn entries(&self, lineage: LineageId) -> Result>>; +} - // SINGLE-TREE MODIFIERS - // ============================================================================================ +// BACKEND +// ================================================================================================ - /// Adds a new `lineage` to the forest with the provided `version` and sets the associated SMT - /// to have the value created by applying `updates` to the empty tree, returning the new root of - /// that tree. - /// - /// # Expected Behavior - /// - /// Implementations must guarantee the following behavior in addition to the global invariants: +/// The full read-write interface for the SMT forest storage backend. +/// +/// This trait extends [`BackendReader`] with mutation operations, allowing the forest to add new +/// lineages and update existing ones. +/// +/// # Implementation Contract +/// +/// Method-level doc comments describe invariants that cannot be encoded in the type system. +/// Implementations are responsible for upholding them. +pub trait Backend: BackendReader { + /// The read-only view type returned by [`Self::reader`]. /// - /// - If the provided `lineage` conflicts with an already-existing lineage in the backend, it - /// must return [`BackendError::DuplicateLineage`]. 
- fn add_lineage( - &mut self, - lineage: LineageId, - version: VersionId, - updates: SmtUpdateBatch, - ) -> Result; + /// The returned type implements [`BackendReader`] but not [`Backend`], providing a read-only + /// guarantee. Implementations may return either a point-in-time snapshot or a live view, but + /// the view must always reflect a consistent committed state (not partial writes). Holding the + /// reader must not block writes in any way. + type Reader: BackendReader; - /// Performs the provided `updates` on the tree with the specified `lineage`, returning the - /// mutation set that will revert the changes made to the tree. - /// - /// # Expected Behavior + /// Backend-specific data prepared during mutation computation and consumed during application. /// - /// Implementations must guarantee the following behavior in addition to the global invariants: + /// This type is intentionally opaque to forest users. Implementations should store enough + /// information here to apply the already-computed mutations without repeating the expensive + /// tree update computation. /// - /// - At most one new root must be added to the forest for the entire batch. - /// - If applying the provided `updates` results in no changes to the tree, no new tree must be - /// allocated. - fn update_tree( - &mut self, - lineage: LineageId, - new_version: VersionId, - updates: SmtUpdateBatch, - ) -> Result; + /// The prepared value must represent only prospective changes. Computing it must not change + /// the backend's committed state. It may contain ordinary SMT mutation sets, storage-level + /// updates, serialized values, or any other implementation-specific data needed to apply the + /// mutation efficiently later. + type PreparedMutations; + + /// Returns a read-only view of this backend that observes its current state. 
+ fn reader(&self) -> Result; - // MULTI-TREE MODIFIERS + // TWO-PHASE MODIFIERS // ============================================================================================ - /// Adds multiple new `lineages` to the backend with the provided `version` and sets the - /// associated SMTs to have the value created by applying the provided updates to the empty - /// tree, returning the new root of that tree. + /// Computes the backend data required to mutate lineages, without applying it. /// /// # Expected Behavior /// /// Implementations must guarantee the following behavior in addition to the global invariants: /// - /// - If any provided lineage conflicts with an already-existing lineage in the backend, it must - /// return [`BackendError::DuplicateLineage`]. - fn add_lineages( - &mut self, - version: VersionId, - lineages: SmtForestUpdateBatch, - ) -> Result>; + /// - The backend's committed state must not change. + /// - Each unknown lineage in `updates` is treated as an addition from the empty tree. + /// - Each known lineage in `updates` is treated as an update to its latest tree. + /// - Each lineage in `updates` must produce at most one [`LineageMutation`]. + /// - No-op lineage updates must not allocate new backend tree versions when applied. + /// - The prepared mutations must be applicable atomically by [`Self::apply_mutations`] where + /// the backend supports atomic writes. + fn compute_mutations( + &self, + new_version: VersionId, + updates: SmtForestUpdateBatch, + ) -> Result<(Vec, Self::PreparedMutations)>; - /// Performs the provided `updates` on the forest, setting all new tree states to have the - /// provided `new_version` and returning a vector of the mutation sets that reverse the changes - /// to each changed tree. + /// Applies previously-computed backend mutations. + /// + /// This method consumes the opaque prepared data returned by one of the backend compute + /// methods. 
It commits the backend's latest-tree state and returns the applied lineage data + /// needed by [`crate::merkle::smt::LargeSmtForest`] to update forest-level lineage metadata and + /// history. /// /// # Expected Behavior /// /// Implementations must guarantee the following behavior in addition to the global invariants: /// - /// - At most one new root must be added to the forest for each target root in the provided - /// `updates`. - /// - If applying the provided `updates` results in no changes to a given lineage of trees in - /// the forest, then no new tree must be allocated in that lineage. - fn update_forest( + /// - The prepared mutation data must still be applicable to the current backend state before + /// any mutation is written. For updates, the current version and root must match the + /// version/root captured during the compute phase. For additions, the lineage must still be + /// absent. + /// - User-derived errors must leave the backend in a consistent committed state. + /// - If the prepared data contains multiple lineage updates, they should be committed + /// atomically when the backend's storage engine supports atomic batched writes. + /// - The method must not recompute Merkle mutations from the original user updates; that work + /// belongs to the compute methods. + /// - On success, the returned [`AppliedLineageMutation`] values must correspond to the applied + /// prepared mutations in the same lineage set, including reverse mutations and old entry + /// counts for update history. + fn apply_mutations( &mut self, - new_version: VersionId, - updates: SmtForestUpdateBatch, - ) -> Result>; + mutations: Self::PreparedMutations, + ) -> Result>; } // BACKEND ERROR @@ -230,6 +235,10 @@ where /// The error type for use within Backends. #[derive(Debug, Error)] pub enum BackendError { + /// Raised when an update was prepared against a version that is no longer current. 
+ #[error("Version {provided} is not current backend version {latest}")] + BadVersion { provided: VersionId, latest: VersionId }, + /// Raised when corrupted data is encountered in the backend. /// /// It exists as a separate error variant to allow the forest itself to handle it better if diff --git a/miden-crypto/src/merkle/smt/large_forest/backend/persistent/mod.rs b/miden-crypto/src/merkle/smt/large_forest/backend/persistent/mod.rs index b03e1b7ce3..1fdc5f4973 100644 --- a/miden-crypto/src/merkle/smt/large_forest/backend/persistent/mod.rs +++ b/miden-crypto/src/merkle/smt/large_forest/backend/persistent/mod.rs @@ -32,28 +32,31 @@ mod internal; mod iterator; mod keys; mod property_tests; +mod snapshot; mod tests; mod tree_metadata; use alloc::{string::ToString, sync::Arc, vec::Vec}; use core::ffi::c_int; -use std::{collections::HashMap, iter::once, mem}; +use std::{collections::HashMap, mem}; use miden_serde_utils::{Deserializable, DeserializationError, Serializable}; use num::Integer; use rayon::prelude::*; use rocksdb as db; +pub use snapshot::PersistentBackendReader; use super::{BackendError, Result}; +#[cfg(test)] +use crate::merkle::smt::SmtUpdateBatch; use crate::{ EMPTY_WORD, Map, Word, merkle::{ - EmptySubtreeRoots, MerkleError, NodeIndex, SparseMerklePath, + EmptySubtreeRoots, MerkleError, NodeIndex, smt::{ - Backend, InnerNode, LeafIndex, LineageId, NodeMutation, NodeMutations, SMT_DEPTH, - SmtForestUpdateBatch, SmtLeaf, SmtLeafError, SmtProof, SmtUpdateBatch, - StorageUpdateParts, StorageUpdates, Subtree, SubtreeError, TreeEntry, TreeWithRoot, - VersionId, + Backend, BackendReader, LeafIndex, LineageId, NodeMutation, NodeMutations, SMT_DEPTH, + SmtForestUpdateBatch, SmtLeaf, SmtLeafError, SmtProof, StorageUpdateParts, + StorageUpdates, Subtree, SubtreeError, TreeEntry, TreeWithRoot, VersionId, full::concurrent::{ MutatedSubtreeLeaves, SUBTREE_DEPTH, SubtreeLeaf, SubtreeLeavesIter, fetch_sibling_pair, process_sorted_pairs_to_leaves, @@ -67,7 +70,9 @@ 
use crate::{ keys::{LeafKey, SubtreeKey}, tree_metadata::TreeMetadata, }, - utils::MutationSet, + utils::{ + AppliedLineageMutation, LineageMutation, LineageMutationKind, MutationSet, + }, }, }, }, @@ -82,6 +87,34 @@ type DB = db::DB; /// The type of a write batch in the database associated with a transaction. type WriteBatch = db::WriteBatch; +/// Prepared mutations for [`PersistentBackend`]. +/// +/// This is the persistent backend's concrete [`Backend::PreparedMutations`] type. It stores +/// storage-level updates and the resulting metadata computed during the first phase of a forest +/// update. Applying it builds and commits a RocksDB [`WriteBatch`] without recomputing the Merkle +/// update batches. +/// +/// The fields are private because callers should treat prepared mutation data as opaque and pass it +/// back through +/// [`LargeSmtForest::apply_mutations`](crate::merkle::smt::LargeSmtForest::apply_mutations). +#[derive(Debug)] +pub struct PersistentPreparedMutations { + entries: Vec, +} + +#[derive(Debug)] +struct PersistentPreparedLineageMutation { + lineage: LineageId, + old_version: Option, + new_version: VersionId, + old_root: Word, + old_entry_count: usize, + reverse: MutationSet, + metadata: TreeMetadata, + storage_updates: StorageUpdates, + kind: LineageMutationKind, +} + // CONSTANTS / COLUMN FAMILY NAMES // ================================================================================================ @@ -145,69 +178,10 @@ const MIN_LINEAGES_IN_BATCH_TO_PARALLELIZE: usize = 5; /// The minimum number of items per rayon chunk when parallelizing deserialization and extraction. const CHUNKING_UNIT: usize = 100; -// PERSISTENT BACKEND -// ================================================================================================ - -/// The persistent backend for the SMT forest, providing durable storage for the latest tree in each -/// lineage in the forest. -#[derive(Debug)] -pub struct PersistentBackend { - /// The underlying database. 
- /// - /// # Layout - /// - /// The data on each tree is stored across a series of RocksDB column families, along with - /// additional metadata. The layout is fixed (for the moment), and has the following column - /// families. - /// - /// - [`LEAVES_CF`]: Stores the [`SmtLeaf`] data, keyed by a [`LeafKey`] instance. - /// - [`METADATA_CF`]: Stores a [`TreeMetadata`] instance for each tree, keyed by - /// [`LineageId`]. This acts like a mirror of the in-memory `lineages` data, which exists to - /// speed up common queries. - /// - `SUBTREE_XX_CF`: Stores the [`Subtree`]s with their root at level `XX` in the backend, - /// keyed on the [`SubtreeKey`]. - db: Arc, - - /// An in-memory cache of the tree metadata enabling the more rapid servicing of certain kinds - /// of queries. - /// - /// Care must be taken that this is _always_ kept in sync with the on-disk copy in the - /// [`METADATA_CF`] column. - lineages: HashMap, - - /// Whether writes should be synchronously flushed to disk. - /// - /// Setting this to true will result in reduced throughput but may result in higher durability - /// in the presence of crashes. - sync_writes: bool, -} - -// CONSTRUCTION -// ================================================================================================ - -/// This block contains functions for the construction of the persistent backend. -impl PersistentBackend { - /// Constructs an instance of the persistent backend, either opening or creating the data store - /// at the location specified in the `config`. - /// - /// # Errors - /// - /// - [`BackendError::CorruptedData`] if data corruption is encountered when loading the forest - /// from disk. - /// - [`BackendError::Internal`] if the backend cannot be started up properly. 
- pub fn load(config: Config) -> Result { - let db = Arc::new(Self::build_db_with_options(&config)?); - let lineages = Self::read_all_metadata(db.clone())?; - let sync_writes = config.sync_writes; - - Ok(Self { db, lineages, sync_writes }) - } -} - -// BACKEND TRAIT +// BACKEND READER TRAIT // ================================================================================================ -impl Backend for PersistentBackend { +impl BackendReader for PersistentBackend { /// Returns an opening for the specified `key` in the SMT with the specified `lineage`. /// /// # Errors @@ -215,44 +189,13 @@ impl Backend for PersistentBackend { /// - [`BackendError::UnknownLineage`] if the provided `lineage` is not known by the backend. /// - [`BackendError::Internal`] if the backing database cannot be accessed for some reason. fn open(&self, lineage: LineageId, key: Word) -> Result { - // We fail early if we don't know about the lineage in question, as querying further could - // cause very strange behavior. - if !self.lineages.contains_key(&lineage) { - return Err(BackendError::UnknownLineage(lineage)); - } - - // We get our leaf first. - let leaf = self - .load_leaf_for(lineage, key)? - .unwrap_or_else(|| SmtLeaf::new_empty(LeafIndex::from(key))); - - // We then have to load both the corresponding leaf, and the siblings for its path out of - // storage. - let leaf_index: NodeIndex = LeafIndex::from(key).into(); - - // We calculate the roots of the subtrees in order to know their keys for loading. As an - // opening only ever needs to retrieve 8 subtrees we just do this sequentially. - let subtree_roots = (0..SMT_DEPTH / SUBTREE_DEPTH) - .scan(leaf_index.parent(), |cursor, _| { - let subtree_root = Subtree::find_subtree_root(*cursor); - *cursor = subtree_root.parent(); - Some(subtree_root) - }) - .collect::>(); - - // Doing this as a separate step exhibits better performance than loading these subtrees - // inline in the path creation. 
This appears to be due to better pipelining and - // branch-predictor behavior. - let mut subtree_cache = HashMap::::new(); - for root in subtree_roots { - let maybe_tree = self.load_subtree(SubtreeKey { lineage, index: root })?; - subtree_cache.insert(root, maybe_tree.unwrap_or_else(|| Subtree::new(root))); - } - - let merkle_path = self.compute_path(leaf_index, &subtree_cache); - - // This is safe to do unchecked as we ensure that the path is valid by construction. - Ok(SmtProof::new_unchecked(merkle_path, leaf)) + snapshot::open_proof( + &self.lineages, + lineage, + key, + |l, k| self.load_leaf_for(l, k), + |k| self.load_subtree(k), + ) } /// Returns the leaf stored at `leaf_index` in the SMT with the specified `lineage`. @@ -367,7 +310,188 @@ impl Backend for PersistentBackend { // its type, so we delegate to our custom entries iterator impl. Ok(PersistentBackendEntriesIterator::new(lineage, pfx_iterator)) } +} +// BACKEND TRAIT +// ================================================================================================ + +impl Backend for PersistentBackend { + type Reader = PersistentBackendReader; + type PreparedMutations = PersistentPreparedMutations; + + fn reader(&self) -> Result { + let snapshot = self.db.snapshot(); + // SAFETY: `SnapshotInner` holds both the snapshot and `Arc`, and its `Drop` impl + // drops the snapshot before decrementing the Arc. This guarantees the DB outlives the + // snapshot, making the 'static transmute sound. + let snapshot: db::Snapshot<'static> = unsafe { mem::transmute(snapshot) }; + Ok(PersistentBackendReader::new( + Arc::clone(&self.db), + snapshot, + Arc::clone(&self.lineages), + )) + } + + /// Computes the mutations required to apply the provided `updates` on the forest. + /// + /// The order of application of these mutations is unspecified, but is guaranteed to produce no + /// more than one new root for each operated-upon lineage. 
All operations are performed as part + /// of one atomic update, leaving the data on disk in a consistent state even if failures occur. + /// + /// # Errors + /// + /// - [`BackendError::Internal`] if the database cannot be accessed at any point. + /// - [`BackendError::Merkle`] if an error occurs with the merkle tree semantics. + /// - [`BackendError::UnknownLineage`] if the provided `lineage` is not known by this backend. + fn compute_mutations( + &self, + new_version: VersionId, + updates: SmtForestUpdateBatch, + ) -> Result<(Vec, Self::PreparedMutations)> { + let updates = updates + .into_iter() + .map(|(lineage, ops)| { + let (metadata, kind) = if let Some(metadata) = self.lineages.get(&lineage) { + (metadata.clone(), LineageMutationKind::UpdateTree) + } else { + ( + TreeMetadata { + version: new_version, + root_value: *EmptySubtreeRoots::entry(SMT_DEPTH, 0), + entry_count: 0, + }, + LineageMutationKind::AddLineage, + ) + }; + Ok((lineage, ops, metadata, kind)) + }) + .collect::>>()?; + + // Now we can simply issue the work in parallel. + let lineage_data = updates + .into_par_iter() + .map(|(lineage, ops, metadata, kind)| { + self.prepare_tree_update( + lineage, + metadata, + new_version, + ops.into_iter().map(Into::into).collect(), + kind, + ) + }) + .collect::>>()?; + + let (mutations, prepared): (Vec<_>, Vec<_>) = lineage_data.into_iter().unzip(); + + Ok((mutations, PersistentPreparedMutations { entries: prepared })) + } + + /// Apply a mutation set to the entire forest, returning the mutation sets that would reverse + /// the changes to each lineage in the forest. + /// + /// All operations are performed as part of one atomic update, leaving the data on disk in a + /// consistent state even if failures occur. + /// + /// - [`BackendError::Internal`] if the database cannot be accessed at any point. + /// - [`BackendError::Merkle`] if an error occurs with the merkle tree semantics. 
+ /// - [`BackendError::UnknownLineage`] if the provided `lineage` is not known by this backend. + fn apply_mutations( + &mut self, + mutations: Self::PreparedMutations, + ) -> Result> { + // We first have to check our precondition that all lineages are valid. + for entry in &mutations.entries { + match entry.kind { + LineageMutationKind::AddLineage => { + if self.lineages.contains_key(&entry.lineage) { + return Err(BackendError::DuplicateLineage(entry.lineage)); + } + }, + LineageMutationKind::UpdateTree => { + let metadata = self + .lineages + .get(&entry.lineage) + .ok_or(BackendError::UnknownLineage(entry.lineage))?; + + if Some(metadata.version) != entry.old_version { + return Err(BackendError::BadVersion { + provided: entry.old_version.unwrap_or_default(), + latest: metadata.version, + }); + } + + if metadata.root_value != entry.old_root { + return Err(MerkleError::ConflictingRoots { + expected_root: entry.old_root, + actual_root: metadata.root_value, + } + .into()); + } + }, + } + } + + let lineage_count = mutations.entries.len(); + + // We want to update all trees as part of an atomic update to the backing database, but we + // also want to do this in parallel. As we cannot share a transaction directly, we instead + // share a write-batch per tree. 
+ let mutations_with_batch = mutations + .entries + .into_iter() + .map(|mutation| { + let batch = WriteBatch::default(); + (mutation, batch) + }) + .collect::>(); + + let lineage_data = mutations_with_batch + .into_par_iter() + .map(|(entry, batch)| { + let applied_entry = AppliedLineageMutation::new( + entry.lineage, + entry.old_version, + entry.new_version, + entry.old_root, + entry.metadata.root_value, + entry.old_entry_count, + entry.reverse, + entry.kind, + ); + let batch = + self.apply_updates_to_lineage(batch, entry.lineage, entry.storage_updates)?; + let batch = self.write_metadata(batch, entry.lineage, &entry.metadata)?; + + Ok((batch, (applied_entry, (entry.lineage, entry.metadata)))) + }) + .collect::>>()?; + let (batches, (applied_entries, metadata_updates)): (Vec<_>, (Vec<_>, Vec<_>)) = + lineage_data.into_iter().unzip(); + + // We construct our final WriteBatch in parallel if we have enough of them, otherwise we + // just do it in serial. + let final_batch = if lineage_count > MIN_LINEAGES_IN_BATCH_TO_PARALLELIZE { + batches + .into_par_iter() + .fold(WriteBatch::new, |l, r| merge_batches(l, &r)) + .reduce(WriteBatch::new, |l, r| merge_batches(l, &r)) + } else { + batches.into_iter().fold(WriteBatch::new(), |l, r| merge_batches(l, &r)) + }; + + // We first write the full atomic update to disk. If it errors, we bail. + self.write(final_batch)?; + + // If it hasn't errored, we can now safely update the in-memory metadata cache. + self.lineages_mut().extend(metadata_updates); + + Ok(applied_entries) + } +} + +// These are the implementations of helper methods used by the backend tests. +#[cfg(test)] +impl PersistentBackend { /// Adds the provided `lineage` to the forest with the provided `version` and sets the /// associated tree to have the value created by applying `updates` to the empty tree, returning /// the root of this new tree. @@ -378,51 +502,27 @@ impl Backend for PersistentBackend { /// backend. 
/// - [`BackendError::Internal`] if the database cannot be accessed at any point. /// - [`BackendError::Merkle`] if an error occurs with the merkle tree semantics. - fn add_lineage( + pub(crate) fn add_lineage( &mut self, lineage: LineageId, version: VersionId, updates: SmtUpdateBatch, ) -> Result { - // We start by checking if the lineage already exists, as we are expected by contract to - // error if it is a duplicate. if self.lineages.contains_key(&lineage) { return Err(BackendError::DuplicateLineage(lineage)); } - // We now build our new tree metadata, which begins as a fully-empty tree with the default - // root, no leaves, and no entries. - let new_lineage_meta = TreeMetadata { - version, - root_value: *EmptySubtreeRoots::entry(SMT_DEPTH, 0), - entry_count: 0, - }; - - // We add and update the tree all in one go in the backend, using a single batch to ensure - // internal consistency. - let batch = WriteBatch::default(); - - // We perform the update. If this fails due to an error, the batch will get dropped at the - // end of the scope, and any staged mutations will be forgotten without being applied. - let (batch, reversion_set, tree_metadata) = self.update_tree_in_write_batch( - batch, - lineage, - new_lineage_meta, - version, - updates.into(), - )?; - - // Upon returning successfully, we need to update the metadata on disk as part of this - // operation. - let batch = self.write_metadata(batch, lineage, &tree_metadata)?; + let mut batch = SmtForestUpdateBatch::empty(); + batch.operations(lineage).add_operations(updates.into_iter()); + let (_mutations, persistent_mutations) = self.compute_mutations(version, batch)?; - // Only when the batch has been successfully written to disk do we write to the in-memory - // metadata, ensuring that the state remains consistent. 
- let new_root = tree_metadata.root_value; - self.finalize_update(batch, once((lineage, tree_metadata, reversion_set)))?; + let mut applied_mutations = self.apply_mutations(persistent_mutations)?; + let applied_mutation = applied_mutations + .pop() + .expect("should have applied exactly one lineage mutation"); // Finally we just return the necessary metadata. - Ok(TreeWithRoot::new(lineage, version, new_root)) + Ok(TreeWithRoot::new(lineage, version, applied_mutation.new_root())) } /// Performs the provided `updates` on the tree with the specified `lineage`, returning the @@ -435,47 +535,27 @@ impl Backend for PersistentBackend { /// - [`BackendError::Internal`] if the database cannot be accessed at any point. /// - [`BackendError::Merkle`] if an error occurs with the merkle tree semantics. /// - [`BackendError::UnknownLineage`] if the provided `lineage` is not known by this backend. - fn update_tree( + pub(crate) fn update_tree( &mut self, lineage: LineageId, new_version: VersionId, updates: SmtUpdateBatch, ) -> Result { - // We check our lineage existence at the start just as a sanity check. - let tree_metadata = self - .lineages - .get(&lineage) - .ok_or(BackendError::UnknownLineage(lineage))? - .clone(); - - // All the work needs to happen atomically, so we construct a new write batch in which we - // stage our operations. - let batch = WriteBatch::default(); - - // We perform the update. If this fails with an error, the write batch will be dropped - // during error unwind, and any staged mutations will be forgotten about without being - // applied. - let (batch, reversion_set, tree_metadata) = self.update_tree_in_write_batch( - batch, - lineage, - tree_metadata, - new_version, - updates.into(), - )?; - - // At this point the batch contains the updates to the tree, but we still need to handle - // updating the metadata. 
- let batch = self.write_metadata(batch, lineage, &tree_metadata)?; - - // Writing the batch may fail, so we only write to the in-memory metadata once it is - // successful to ensure the in-memory state cache remains consistent with the database. - let mut res = self.finalize_update(batch, once((lineage, tree_metadata, reversion_set)))?; - let Some((_, reversion_set)) = res.pop() else { - unreachable!("finalize_update did not return the same number of output elements") - }; + if !self.lineages.contains_key(&lineage) { + return Err(BackendError::UnknownLineage(lineage)); + } + + let mut batch = SmtForestUpdateBatch::empty(); + batch.operations(lineage).add_operations(updates.into_iter()); + let (_mutations, persistent_mutations) = self.compute_mutations(new_version, batch)?; + + let mut applied_mutations = self.apply_mutations(persistent_mutations)?; + let applied_mutation = applied_mutations + .pop() + .expect("should have applied exactly one lineage mutation"); // We then just return the reversion set for the operations in question. - Ok(reversion_set) + Ok(applied_mutation.into_reverse()) } /// Adds multiple new `lineages` to the tree, creating an empty tree for each and applying the @@ -490,88 +570,25 @@ impl Backend for PersistentBackend { /// backend. /// - [`BackendError::Internal`] if the database cannot be accessed at any point. /// - [`BackendError::Merkle`] if an error occurs with the merkle tree semantics. - fn add_lineages( + pub(crate) fn add_lineages( &mut self, version: VersionId, lineages: SmtForestUpdateBatch, ) -> Result> { - // We start by checking that none of the lineages already exist, as we are expected by - // contract to error if any is a duplicate. 
- let updates = lineages - .into_iter() - .map(|(lineage, ops)| { - if self.lineages.contains_key(&lineage) { - return Err(BackendError::DuplicateLineage(lineage)); - } - Ok((lineage, ops)) - }) - .collect::>>()?; - let lineage_count = updates.len(); - - // If we have no lineages, then we can exit early. - if updates.is_empty() { - return Ok(Vec::new()); + for lineage in lineages.lineages() { + if self.lineages.contains_key(lineage) { + return Err(BackendError::DuplicateLineage(*lineage)); + } } - // Build the initial metadata and set up empty write batches for each new lineage - let updates_with_batch = updates - .into_iter() - .map(|(lineage, ops)| { - let new_meta = TreeMetadata { - version, - root_value: *EmptySubtreeRoots::entry(SMT_DEPTH, 0), - entry_count: 0, - }; - let batch = WriteBatch::default(); - (lineage, ops, new_meta, batch) - }) - .collect::>(); + let (_mutations, persistent_mutations) = self.compute_mutations(version, lineages)?; - // Process lineages in parallel, updating each tree in its own write batch - let lineage_data = updates_with_batch - .into_par_iter() - .map(|(lineage, ops, new_meta, batch)| { - let ops = ops.into_iter().map(Into::into).collect(); - let (batch, reversion, tree_data) = - self.update_tree_in_write_batch(batch, lineage, new_meta, version, ops)?; - let batch = self.write_metadata(batch, lineage, &tree_data)?; - let root = tree_data.root_value; - - Ok((batch, (lineage, tree_data, reversion), (lineage, root))) - }) - .collect::>>()?; - let (batches, mutation_sets, roots): (Vec<_>, Vec<_>, Vec<_>) = - lineage_data.into_iter().fold( - ( - Vec::with_capacity(lineage_count), - Vec::with_capacity(lineage_count), - Vec::with_capacity(lineage_count), - ), - |(mut bs, mut ms, mut rs), (b, m, r)| { - bs.push(b); - ms.push(m); - rs.push(r); - (bs, ms, rs) - }, - ); + let applied_mutations = self.apply_mutations(persistent_mutations)?; - // Merge all the write batches into one atomic batch - let final_batch = if lineage_count > 
MIN_LINEAGES_IN_BATCH_TO_PARALLELIZE { - batches - .into_par_iter() - .fold(WriteBatch::new, |l, r| merge_batches(l, &r)) - .reduce(WriteBatch::new, |l, r| merge_batches(l, &r)) - } else { - batches.into_iter().fold(WriteBatch::new(), |l, r| merge_batches(l, &r)) - }; - - // Atomically write to disk and update in-memory cache. - self.finalize_update(final_batch, mutation_sets.into_iter())?; - - // Build the return value from the captured roots. - let results = roots + // Build the return value from the applied mutations. + let results = applied_mutations .into_iter() - .map(|(lineage, root)| (lineage, TreeWithRoot::new(lineage, version, root))) + .map(|applied_mutation| (applied_mutation.lineage(), applied_mutation.result())) .collect(); Ok(results) @@ -589,120 +606,114 @@ impl Backend for PersistentBackend { /// - [`BackendError::Internal`] if the database cannot be accessed at any point. /// - [`BackendError::Merkle`] if an error occurs with the merkle tree semantics. /// - [`BackendError::UnknownLineage`] if the provided `lineage` is not known by this backend. - fn update_forest( + pub(crate) fn update_forest( &mut self, new_version: VersionId, updates: SmtForestUpdateBatch, ) -> Result> { - // We first have to check our precondition that all lineages are valid, returning an error - // as required by our contract if any lineage is unknown to the backend. - let updates: Vec<_> = updates - .into_iter() - .map(|(lineage, ops)| { - if !self.lineages.contains_key(&lineage) { - return Err(BackendError::UnknownLineage(lineage)); - } - let tree_data = self.lineages.get(&lineage).expect("Known to exist").clone(); - Ok((lineage, ops, tree_data)) - }) - .collect::>>()?; - let lineage_count = updates.len(); - - // We want to update all trees as part of an atomic update to the backing database, but we - // also want to do this in parallel. As we cannot share a transaction directly, we instead - // share a write-batch per tree. 
- let updates_with_batch = updates - .into_iter() - .map(|(lineage, ops, tree_data)| { - let batch = WriteBatch::default(); - (lineage, ops, tree_data, batch) - }) - .collect::>(); - - // Now we can simply issue the work in parallel. - let lineage_data = updates_with_batch - .into_par_iter() - .map(|(lineage, ops, tree_data, batch)| { - let ops = ops.into_iter().map(Into::into).collect(); - let (batch, reversion, tree_data) = - self.update_tree_in_write_batch(batch, lineage, tree_data, new_version, ops)?; - let batch = self.write_metadata(batch, lineage, &tree_data)?; + for lineage in updates.lineages() { + if !self.lineages.contains_key(lineage) { + return Err(BackendError::UnknownLineage(*lineage)); + } + } - Ok((batch, (lineage, tree_data, reversion))) - }) - .collect::>>()?; - let (batches, mutation_sets): (Vec<_>, Vec<_>) = lineage_data.into_iter().unzip(); + let (_mutations, persistent_mutations) = self.compute_mutations(new_version, updates)?; - // We construct our final WriteBatch in parallel if we have enough of them, otherwise we - // just do it in serial. - let final_batch = if lineage_count > MIN_LINEAGES_IN_BATCH_TO_PARALLELIZE { - batches - .into_par_iter() - .fold(WriteBatch::new, |l, r| merge_batches(l, &r)) - .reduce(WriteBatch::new, |l, r| merge_batches(l, &r)) - } else { - batches.into_iter().fold(WriteBatch::new(), |l, r| merge_batches(l, &r)) - }; + let applied_mutations = self.apply_mutations(persistent_mutations)?; - // Only at this point do we write to the in-memory metadata, ensuring that the state remains - // consistent. - let result = self.finalize_update(final_batch, mutation_sets.into_iter())?; + // Build the return value from the applied mutations. 
+ let reversion_sets = applied_mutations + .into_iter() + .map(|applied_mutation| (applied_mutation.lineage(), applied_mutation.into_reverse())) + .collect(); - Ok(result) + Ok(reversion_sets) } } -// INTERNAL / UTILITY +// PERSISTENT BACKEND // ================================================================================================ -/// This block contains methods for internal use only that provide useful functionality for the -/// implementation of the backend. +/// The persistent backend for the SMT forest, providing durable storage for the latest tree in each +/// lineage in the forest. +#[derive(Debug)] +pub struct PersistentBackend { + /// The underlying database. + /// + /// # Layout + /// + /// The data on each tree is stored across a series of RocksDB column families, along with + /// additional metadata. The layout is fixed (for the moment), and has the following column + /// families. + /// + /// - [`LEAVES_CF`]: Stores the [`SmtLeaf`] data, keyed by a [`LeafKey`] instance. + /// - [`METADATA_CF`]: Stores a [`TreeMetadata`] instance for each tree, keyed by + /// [`LineageId`]. This acts like a mirror of the in-memory `lineages` data, which exists to + /// speed up common queries. + /// - `SUBTREE_XX_CF`: Stores the [`Subtree`]s with their root at level `XX` in the backend, + /// keyed on the [`SubtreeKey`]. + db: Arc, + + /// An in-memory cache of the tree metadata enabling the more rapid servicing of certain kinds + /// of queries. + /// + /// Wrapped in an `Arc` for copy-on-write sharing with reader snapshots. Readers clone the + /// `Arc` cheaply; mutations use `Arc::make_mut` to fork a private copy only when needed. + /// + /// Care must be taken that this is _always_ kept in sync with the on-disk copy in the + /// [`METADATA_CF`] column. + lineages: Arc>, + + /// Whether writes should be synchronously flushed to disk. 
+ /// + /// Setting this to true will result in reduced throughput but may result in higher durability + /// in the presence of crashes. + sync_writes: bool, +} + impl PersistentBackend { - /// Computes the merkle path for the provided `lineage` beginning at the provided `leaf_index` - /// using the pre-loaded `subtrees`. - fn compute_path( - &self, - mut leaf_index: NodeIndex, - subtrees: &HashMap, - ) -> SparseMerklePath { - let mut path = Vec::with_capacity(SMT_DEPTH as usize); - - while leaf_index.depth() > 0 { - let is_right = leaf_index.is_position_odd(); - leaf_index = leaf_index.parent(); - - let root = Subtree::find_subtree_root(leaf_index); - let subtree = &subtrees[&root]; // Known to exist by construction. - let InnerNode { left, right } = - subtree.get_inner_node(leaf_index).unwrap_or_else(|| { - EmptySubtreeRoots::get_inner_node(SMT_DEPTH, leaf_index.depth()) - }); + /// Constructs an instance of the persistent backend, either opening or creating the data store + /// at the location specified in the `config`. + /// + /// # Errors + /// + /// - [`BackendError::CorruptedData`] if data corruption is encountered when loading the forest + /// from disk. + /// - [`BackendError::Internal`] if the backend cannot be started up properly. + pub fn load(config: Config) -> Result { + let db = Arc::new(Self::build_db_with_options(&config)?); + let lineages = Arc::new(Self::read_all_metadata(db.clone())?); + let sync_writes = config.sync_writes; - path.push(if is_right { left } else { right }); - } + Ok(Self { db, lineages, sync_writes }) + } - SparseMerklePath::from_sized_iter(path).expect("Always succeeds by construction") + // Triggers copy-on-write: clones the shared lineages map only if other references exist. + pub(crate) fn lineages_mut(&mut self) -> &mut HashMap { + Arc::make_mut(&mut self.lineages) } - /// Performs `updates` on the tree in the specified lineage, assigning the new tree the - /// provided `new_version`. 
+ // INTERNAL / UTILITY + // -------------------------------------------------------------------------------------------- + + /// Computes the mutation set for `updates` on the tree in the specified lineage, assigning the + /// new tree the provided `new_version`. /// - /// All operations in this method take place within the context of the provided `batch`. This - /// method will only stage operations using the transaction associated with that batch, and is - /// guaranteed to not commit those changes. + /// This method will only compute the mutation set required to do the updates but does not + /// update the tree. /// /// # Errors /// /// - [`BackendError::Internal`] if the backend fails to read to or write from storage. /// - [`BackendError::Merkle`] if an error occurs with the merkle tree semantics in the backend. - fn update_tree_in_write_batch( + fn prepare_tree_update( &self, - batch: WriteBatch, lineage: LineageId, mut tree_metadata: TreeMetadata, new_version: VersionId, mut updates: Vec<(Word, Word)>, - ) -> Result<(WriteBatch, MutationSet, TreeMetadata)> { + kind: LineageMutationKind, + ) -> Result<(LineageMutation, PersistentPreparedLineageMutation)> { // We start by ensuring that our updates are sorted, as this is necessary for the efficiency // of various other operations. updates.sort_by_key(|(k, _)| LeafIndex::from(*k).position()); @@ -727,20 +738,42 @@ impl PersistentBackend { reversion_pairs, } = self.sorted_pairs_to_mutated_leaves(updates, &leaf_map)?; + let old_version = tree_metadata.version; + let old_root = tree_metadata.root_value; + let old_entry_count = tree_metadata + .entry_count + .try_into() + .expect("Count of entries should fit into usize"); + // If we have no mutations to perform, we return early for performance and to satisfy the // contract required of `add_lineage`. if leaves.is_empty() { - // As a result, our mutation set is empty. 
- return Ok(( - batch, - MutationSet { - old_root: tree_metadata.root_value, - node_mutations: NodeMutations::default(), - new_pairs: Map::default(), - new_root: tree_metadata.root_value, - }, - tree_metadata, - )); + let empty = MutationSet { + old_root, + node_mutations: NodeMutations::default(), + new_pairs: Map::default(), + new_root: old_root, + }; + let mutation = LineageMutation::new( + lineage, + (kind == LineageMutationKind::UpdateTree).then_some(old_version), + new_version, + old_root, + old_root, + kind, + ); + let prepared = PersistentPreparedLineageMutation { + lineage, + old_version: (kind == LineageMutationKind::UpdateTree).then_some(old_version), + new_version, + old_root, + old_entry_count, + reverse: empty, + metadata: tree_metadata, + storage_updates: StorageUpdates::default(), + kind, + }; + return Ok((mutation, prepared)); } // We can then preallocate capacity for the subtree updates. @@ -768,14 +801,14 @@ impl PersistentBackend { )) }, |result, processed_tree| match (result, processed_tree) { - (Ok((mut roots, mut subtrees, mut node_reversions)), Ok(tree)) => { + (Ok((mut roots, mut subtrees, mut reversions)), Ok(tree)) => { roots.push(tree.subtree_root); - node_reversions.extend(tree.reversion_nodes); + reversions.extend(tree.reversion_nodes); if let Some(action) = tree.storage_action { subtrees.push(action); } - Ok((roots, subtrees, node_reversions)) + Ok((roots, subtrees, reversions)) }, (Err(e), _) | (_, Err(e)) => Err(e), }, @@ -808,41 +841,62 @@ impl PersistentBackend { for (idx, mutated_leaf) in leaf_updates { let leaf_opt = match mutated_leaf { - SmtLeaf::Empty(_) => None, // Delete from storage + SmtLeaf::Empty(_) => None, _ => Some(mutated_leaf), }; leaf_update_map.insert(idx, leaf_opt); } - let updates = StorageUpdates::from_parts( + let storage_updates = StorageUpdates::from_parts( leaf_update_map, subtree_updates, leaf_count_delta, entry_count_delta, ); - // And we apply the updates to the lineage in the storage under the current 
transaction. - let batch = self.apply_updates_to_lineage(batch, lineage, updates)?; - // And then compute the new root. - let root_after_modification = leaves[0][0].hash; + let new_root = leaves[0][0].hash; // We then write the node metadata into a copy - let root_before_modification = tree_metadata.root_value; tree_metadata.entry_count = tree_metadata.entry_count.saturating_add_signed( entry_count_delta.try_into().expect("Delta should always fit into i64"), ); - tree_metadata.root_value = root_after_modification; + tree_metadata.root_value = new_root; tree_metadata.version = new_version; - let mutation_set = MutationSet { - old_root: root_after_modification, + // Construct the reverse mutation set. + let reverse = MutationSet { + old_root: new_root, node_mutations: global_node_reversions, new_pairs: reversion_pairs.into_iter().collect(), - new_root: root_before_modification, + new_root: old_root, }; - Ok((batch, mutation_set, tree_metadata)) + // The forward mutation set. + let mutation = LineageMutation::new( + lineage, + (kind == LineageMutationKind::UpdateTree).then_some(old_version), + new_version, + old_root, + new_root, + kind, + ); + + // And the prepared mutation set that contains _all_ information that is required + // to _apply_ these changes in [`apply_mutations`]. + let prepared = PersistentPreparedLineageMutation { + lineage, + old_version: (kind == LineageMutationKind::UpdateTree).then_some(old_version), + new_version, + old_root, + old_entry_count, + reverse, + metadata: tree_metadata, + storage_updates, + kind, + }; + + Ok((mutation, prepared)) } /// Applies the `updates` to the specified `lineage` in the context of the provided `batch`. 
@@ -1363,14 +1417,10 @@ impl PersistentBackend { Ok(()) } -} -// INTERNAL / STARTUP -// ================================================================================================ + // INTERNAL / STARTUP + // -------------------------------------------------------------------------------------------- -/// This impl block contains internal functionality to do with starting up the backend and -/// performing its initialization work. -impl PersistentBackend { /// Sets up the basic configuration for the underlying RocksDB database. fn build_db_with_options(config: &Config) -> Result { let mut db_opts = db::Options::default(); @@ -1478,29 +1528,6 @@ impl PersistentBackend { Ok(batch) } - /// Finalizes the update by committing the provided `batch` to disk and updating the in-memory - /// cache with the provided `metadata`. - /// - /// # Errors - /// - /// - [`BackendError::Internal`] if the underlying database cannot be written to. - fn finalize_update( - &mut self, - batch: WriteBatch, - metadata: impl Iterator, - ) -> Result> { - // We first write the full atomic update to disk. If it errors, we bail. - self.write(batch)?; - - // If it hasn't errored, we can now safely update the in-memory metadata cache. - Ok(metadata - .map(|(l, d, r)| { - self.lineages.insert(l, d); - (l, r) - }) - .collect()) - } - /// Reads all the lineages and their corresponding metadata out of the on-disk storage as part /// of the startup work. 
/// diff --git a/miden-crypto/src/merkle/smt/large_forest/backend/persistent/property_tests.rs b/miden-crypto/src/merkle/smt/large_forest/backend/persistent/property_tests.rs index 2f30129432..cdaf4cb2a0 100644 --- a/miden-crypto/src/merkle/smt/large_forest/backend/persistent/property_tests.rs +++ b/miden-crypto/src/merkle/smt/large_forest/backend/persistent/property_tests.rs @@ -10,7 +10,7 @@ use super::tests::default_backend; use crate::{ EMPTY_WORD, merkle::smt::{ - Backend, Smt, SmtForestUpdateBatch, SmtUpdateBatch, TreeWithRoot, + BackendReader, Smt, SmtForestUpdateBatch, SmtUpdateBatch, TreeWithRoot, large_forest::test_utils::{ arbitrary_batch, arbitrary_lineage, arbitrary_version, arbitrary_word, }, diff --git a/miden-crypto/src/merkle/smt/large_forest/backend/persistent/snapshot.rs b/miden-crypto/src/merkle/smt/large_forest/backend/persistent/snapshot.rs new file mode 100644 index 0000000000..bcef0fb177 --- /dev/null +++ b/miden-crypto/src/merkle/smt/large_forest/backend/persistent/snapshot.rs @@ -0,0 +1,281 @@ +use alloc::{sync::Arc, vec::Vec}; +use core::mem::ManuallyDrop; +use std::collections::HashMap; + +use miden_serde_utils::{Deserializable, Serializable}; +use rocksdb as db; + +use super::{ + super::{BackendError, Result}, + LEAVES_CF, + iterator::PersistentBackendEntriesIterator, + keys::{LeafKey, SubtreeKey}, + subtree_cf_name, + tree_metadata::TreeMetadata, +}; +use crate::{ + Word, + merkle::{ + EmptySubtreeRoots, NodeIndex, SparseMerklePath, + smt::{ + BackendReader, InnerNode, LeafIndex, LineageId, SMT_DEPTH, SmtLeaf, SmtProof, Subtree, + TreeEntry, TreeWithRoot, VersionId, full::concurrent::SUBTREE_DEPTH, + }, + }, +}; + +// PERSISTENT BACKEND SNAPSHOT INNER +// ================================================================================================ + +/// Inner state shared by all clones of a [`PersistentBackendReader`]. 
+/// +/// Pairs a RocksDB point-in-time snapshot with the `Arc` that owns the database, so that +/// the database is guaranteed to outlive the snapshot. +/// +/// # Safety +/// +/// `snapshot` contains an internal pointer into the `DB` allocation. `db` must not be dropped +/// (i.e. its refcount must not reach zero) while `snapshot` is live. The `Drop` impl enforces +/// this by explicitly dropping `snapshot` before the `Arc` field is automatically decremented. +pub(super) struct SnapshotInner { + /// The RocksDB snapshot providing the consistent read view. + /// + /// The `'static` lifetime is a sound lie: the real lifetime is tied to `db`. The `Drop` impl + /// guarantees we drop this before `db`. + snapshot: ManuallyDrop>, + /// Keeps the database alive for at least as long as `snapshot`. + db: Arc, + /// Point-in-time view of the lineage metadata, shared with the backend via copy-on-write. + lineages: Arc>, +} + +impl Drop for SnapshotInner { + fn drop(&mut self) { + // SAFETY: Drop the snapshot before the Arc refcount is decremented. + unsafe { + ManuallyDrop::drop(&mut self.snapshot); + } + } +} + +impl core::fmt::Debug for SnapshotInner { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("SnapshotInner").finish_non_exhaustive() + } +} + +// PERSISTENT BACKEND READER +// ================================================================================================ + +/// A read-only, point-in-time snapshot of a [`PersistentBackend`](super::PersistentBackend). +/// +/// This type intentionally implements only [`BackendReader`], not +/// [`Backend`](crate::merkle::smt::Backend). It is returned by +/// [`Backend::reader`](crate::merkle::smt::Backend::reader) for +/// [`PersistentBackend`](super::PersistentBackend) to provide read-only access to a consistent +/// snapshot of the backend state without exposing any mutation capabilities. 
+/// +/// All reads go through a RocksDB snapshot, so the view is frozen at the instant +/// [`Backend::reader`](crate::merkle::smt::Backend::reader) was called; concurrent writes to the +/// underlying database are invisible to this reader. +/// +/// Cloning is O(1): both the snapshot and the lineage metadata are owned by the inner `Arc`. +#[derive(Clone, Debug)] +pub struct PersistentBackendReader { + inner: Arc, +} + +impl PersistentBackendReader { + pub(super) fn new( + db: Arc, + snapshot: db::Snapshot<'static>, + lineages: Arc>, + ) -> Self { + Self { + inner: Arc::new(SnapshotInner { + snapshot: ManuallyDrop::new(snapshot), + db, + lineages, + }), + } + } + + fn load_subtree(&self, tree_key: SubtreeKey) -> Result> { + let cf = self.subtree_cf(tree_key.index)?; + let key_bytes = tree_key.to_bytes(); + let result = match self.inner.snapshot.get_cf(cf, key_bytes) { + Ok(Some(bytes)) => Some(Subtree::from_vec(tree_key.index, &bytes)?), + Ok(None) => None, + Err(e) => return Err(e.into()), + }; + Ok(result) + } + + fn load_leaf_raw(&self, key: &LeafKey) -> Result> { + let col = self.cf(LEAVES_CF)?; + let key_bytes = key.to_bytes(); + let leaf_bytes = self.inner.snapshot.get_cf(col, key_bytes)?; + Ok(match leaf_bytes { + Some(bytes) => Some(SmtLeaf::read_from_bytes_with_budget(&bytes, bytes.len())?), + None => None, + }) + } + + fn load_leaf_for(&self, lineage: LineageId, key: Word) -> Result> { + let key = LeafKey { + lineage, + index: LeafIndex::from(key).position(), + }; + self.load_leaf_raw(&key) + } + + #[inline(always)] + fn subtree_cf(&self, index: NodeIndex) -> Result<&db::ColumnFamily> { + self.subtree_cf_depth(index.depth()) + } + + #[inline(always)] + fn subtree_cf_depth(&self, depth: u8) -> Result<&db::ColumnFamily> { + let cf_name = subtree_cf_name(depth); + self.cf(cf_name) + } + + #[inline(always)] + fn cf(&self, name: &str) -> Result<&db::ColumnFamily> { + self.inner.db.cf_handle(name).ok_or_else(|| { + 
BackendError::internal_from_message(format!("Could not load column with name {name}")) + }) + } +} + +impl BackendReader for PersistentBackendReader { + fn open(&self, lineage: LineageId, key: Word) -> Result { + open_proof( + &self.inner.lineages, + lineage, + key, + |l, k| self.load_leaf_for(l, k), + |k| self.load_subtree(k), + ) + } + + fn get_leaf(&self, lineage: LineageId, leaf_index: LeafIndex) -> Result { + if !self.inner.lineages.contains_key(&lineage) { + return Err(BackendError::UnknownLineage(lineage)); + } + let key = LeafKey { lineage, index: leaf_index.position() }; + Ok(self.load_leaf_raw(&key)?.unwrap_or_else(|| SmtLeaf::new_empty(leaf_index))) + } + + fn get(&self, lineage: LineageId, key: Word) -> Result> { + if !self.inner.lineages.contains_key(&lineage) { + return Err(BackendError::UnknownLineage(lineage)); + } + let leaf = self.load_leaf_for(lineage, key)?; + Ok(leaf.and_then(|l| { + let val = l.get_value(&key); + val.and_then(|e| if e.is_empty() { None } else { Some(e) }) + })) + } + + fn version(&self, lineage: LineageId) -> Result { + let metadata = + self.inner.lineages.get(&lineage).ok_or(BackendError::UnknownLineage(lineage))?; + Ok(metadata.version) + } + + fn lineages(&self) -> Result> { + Ok(self.inner.lineages.keys().copied()) + } + + fn trees(&self) -> Result> { + Ok(self + .inner + .lineages + .iter() + .map(|(l, m)| TreeWithRoot::new(*l, m.version, m.root_value))) + } + + fn entry_count(&self, lineage: LineageId) -> Result { + let metadata = + self.inner.lineages.get(&lineage).ok_or(BackendError::UnknownLineage(lineage))?; + Ok(metadata.entry_count.try_into().expect("Count of entries should fit into usize")) + } + + fn entries(&self, lineage: LineageId) -> Result>> { + if !self.inner.lineages.contains_key(&lineage) { + return Err(BackendError::UnknownLineage(lineage)); + } + let lineage_bytes = lineage.to_bytes(); + let cf = self.cf(LEAVES_CF)?; + let mut read_opts = db::ReadOptions::default(); + 
read_opts.set_prefix_same_as_start(true); + let pfx_iterator = self.inner.snapshot.iterator_cf_opt( + cf, + read_opts, + db::IteratorMode::From(&lineage_bytes, db::Direction::Forward), + ); + Ok(PersistentBackendEntriesIterator::new(lineage, pfx_iterator)) + } +} + +// HELPERS +// ================================================================================================ + +fn compute_merkle_path( + mut leaf_index: NodeIndex, + subtrees: &HashMap, +) -> SparseMerklePath { + let mut path = Vec::with_capacity(SMT_DEPTH as usize); + + while leaf_index.depth() > 0 { + let is_right = leaf_index.is_position_odd(); + leaf_index = leaf_index.parent(); + + let root = Subtree::find_subtree_root(leaf_index); + let subtree = &subtrees[&root]; + let InnerNode { left, right } = subtree + .get_inner_node(leaf_index) + .unwrap_or_else(|| EmptySubtreeRoots::get_inner_node(SMT_DEPTH, leaf_index.depth())); + + path.push(if is_right { left } else { right }); + } + + SparseMerklePath::from_sized_iter(path).expect("Always succeeds by construction") +} + +pub(super) fn open_proof( + lineages: &HashMap, + lineage: LineageId, + key: Word, + load_leaf: impl Fn(LineageId, Word) -> Result>, + load_subtree: impl Fn(SubtreeKey) -> Result>, +) -> Result { + if !lineages.contains_key(&lineage) { + return Err(BackendError::UnknownLineage(lineage)); + } + + let leaf = load_leaf(lineage, key)?.unwrap_or_else(|| SmtLeaf::new_empty(LeafIndex::from(key))); + let leaf_index: NodeIndex = LeafIndex::from(key).into(); + + // An opening needs exactly one subtree per level; collect their roots up front so we can + // load them all before constructing the path. 
+ let subtree_roots = (0..SMT_DEPTH / SUBTREE_DEPTH) + .scan(leaf_index.parent(), |cursor, _| { + let subtree_root = Subtree::find_subtree_root(*cursor); + *cursor = subtree_root.parent(); + Some(subtree_root) + }) + .collect::>(); + + // Loading subtrees as a separate step (rather than inline during path construction) + // exhibits better performance due to improved pipelining and branch-predictor behavior. + let mut subtree_cache = HashMap::::new(); + for root in subtree_roots { + let maybe_tree = load_subtree(SubtreeKey { lineage, index: root })?; + subtree_cache.insert(root, maybe_tree.unwrap_or_else(|| Subtree::new(root))); + } + + let merkle_path = compute_merkle_path(leaf_index, &subtree_cache); + Ok(SmtProof::new_unchecked(merkle_path, leaf)) +} diff --git a/miden-crypto/src/merkle/smt/large_forest/backend/persistent/tests.rs b/miden-crypto/src/merkle/smt/large_forest/backend/persistent/tests.rs index 57645b96a4..306db30b42 100644 --- a/miden-crypto/src/merkle/smt/large_forest/backend/persistent/tests.rs +++ b/miden-crypto/src/merkle/smt/large_forest/backend/persistent/tests.rs @@ -15,8 +15,9 @@ use super::{PersistentBackend, Result}; use crate::{ EMPTY_WORD, Word, merkle::smt::{ - Backend, BackendError, LineageId, Smt, SmtForestUpdateBatch, SmtUpdateBatch, TreeEntry, - TreeWithRoot, VersionId, large_forest::backend::persistent::config::Config, + Backend, BackendError, BackendReader, LargeSmtForest, LineageId, Smt, SmtForestUpdateBatch, + SmtUpdateBatch, TreeEntry, TreeWithRoot, VersionId, + large_forest::backend::persistent::config::Config, }, rand::test_utils::ContinuousRng, }; @@ -822,3 +823,142 @@ fn update_forest() -> Result<()> { Ok(()) } + +#[test] +fn forest_apply_noop_update_tree_does_not_panic() { + let (_dir, backend) = default_backend().unwrap(); + let mut forest = LargeSmtForest::new(backend).unwrap(); + let mut rng = ContinuousRng::new([0x99; 32]); + + let lineage: LineageId = rng.value(); + let key: Word = rng.value(); + let value: Word = 
rng.value(); + + let mut initial = SmtUpdateBatch::default(); + initial.add_insert(key, value); + forest.add_lineage(lineage, 1, initial).unwrap(); + + let mut noop = SmtUpdateBatch::default(); + noop.add_insert(key, value); + let mutations = forest.compute_update_tree_mutations(lineage, 2, noop).unwrap(); + let roots = forest.apply_mutations(mutations).unwrap(); + + assert_eq!(roots.len(), 1); + assert_eq!(roots[0].version(), 1); +} + +#[test] +fn apply_mutations_rejects_stale_prepared_update() -> Result<()> { + let (_dir, mut backend) = default_backend()?; + let mut rng = ContinuousRng::new([0xa5; 32]); + + let lineage: LineageId = rng.value(); + let key_1: Word = rng.value(); + let value_1: Word = rng.value(); + let key_2: Word = rng.value(); + let value_2: Word = rng.value(); + let key_3: Word = rng.value(); + let value_3: Word = rng.value(); + + let mut initial = SmtUpdateBatch::default(); + initial.add_insert(key_1, value_1); + backend.add_lineage(lineage, 1, initial)?; + + let mut stale_updates = SmtUpdateBatch::default(); + stale_updates.add_insert(key_2, value_2); + let mut stale_batch = SmtForestUpdateBatch::empty(); + stale_batch.operations(lineage).add_operations(stale_updates.into_iter()); + let (_visible, stale_prepared) = backend.compute_mutations(2, stale_batch)?; + + let mut intervening_updates = SmtUpdateBatch::default(); + intervening_updates.add_insert(key_3, value_3); + backend.update_tree(lineage, 2, intervening_updates)?; + + assert!( + backend.apply_mutations(stale_prepared).is_err(), + "stale prepared mutations must not apply after the lineage root changes" + ); + + Ok(()) +} + +#[test] +fn reader_snapshot_isolation() -> Result<()> { + // Writes committed to the backend after the reader is created must be invisible to the reader. + let (_dir, mut backend) = default_backend()?; + let mut rng = ContinuousRng::new([0xc7; 32]); + let version: VersionId = rng.value(); + + // Add lineage_1 and create the reader while lineage_2 does not yet exist. 
+ let lineage_1: LineageId = rng.value(); + let k1: Word = rng.value(); + let v1: Word = rng.value(); + let mut ops = SmtUpdateBatch::default(); + ops.add_insert(k1, v1); + backend.add_lineage(lineage_1, version, ops)?; + + let reader = backend.reader()?; + + // Now add lineage_2 after the reader was created. + let lineage_2: LineageId = rng.value(); + let k2: Word = rng.value(); + let v2: Word = rng.value(); + let mut ops = SmtUpdateBatch::default(); + ops.add_insert(k2, v2); + backend.add_lineage(lineage_2, version, ops)?; + + // Also mutate lineage_1 after the snapshot. + let k3: Word = rng.value(); + let v3: Word = rng.value(); + let mut ops = SmtUpdateBatch::default(); + ops.add_insert(k3, v3); + backend.update_tree(lineage_1, version + 1, ops)?; + + // The reader must not see lineage_2 at all. + assert_eq!(reader.lineages()?.count(), 1); + assert!(!reader.lineages()?.any(|l| l == lineage_2)); + assert_matches!(reader.open(lineage_2, k2).unwrap_err(), BackendError::UnknownLineage(l) if l == lineage_2); + assert_matches!(reader.get(lineage_2, k2).unwrap_err(), BackendError::UnknownLineage(l) if l == lineage_2); + + // The reader must see lineage_1 at the pre-snapshot state (k3 absent, version unchanged). + assert_eq!(reader.version(lineage_1)?, version); + assert_eq!(reader.entry_count(lineage_1)?, 1); + assert!(reader.get(lineage_1, k3)?.is_none()); + assert_eq!(reader.get(lineage_1, k1)?, Some(v1)); + + Ok(()) +} + +#[test] +fn reader_clone() -> Result<()> { + // Cloning a reader must produce an independent handle to the same snapshot. 
+ let (_dir, mut backend) = default_backend()?; + let mut rng = ContinuousRng::new([0xc8; 32]); + let version: VersionId = rng.value(); + + let lineage_1: LineageId = rng.value(); + let k1: Word = rng.value(); + let v1: Word = rng.value(); + let mut ops = SmtUpdateBatch::default(); + ops.add_insert(k1, v1); + backend.add_lineage(lineage_1, version, ops)?; + + let reader = backend.reader()?; + let reader_clone = reader.clone(); + + // Write to the backend after cloning — neither handle should see it. + let lineage_2: LineageId = rng.value(); + let mut ops = SmtUpdateBatch::default(); + ops.add_insert(rng.value(), rng.value()); + backend.add_lineage(lineage_2, version, ops)?; + + // Both handles see exactly lineage_1 and agree on its data. + for r in [&reader, &reader_clone] { + assert_eq!(r.lineages()?.count(), 1); + assert!(r.lineages()?.any(|l| l == lineage_1)); + assert_eq!(r.get(lineage_1, k1)?, Some(v1)); + assert_matches!(r.get(lineage_2, k1).unwrap_err(), BackendError::UnknownLineage(l) if l == lineage_2); + } + + Ok(()) +} diff --git a/miden-crypto/src/merkle/smt/large_forest/error.rs b/miden-crypto/src/merkle/smt/large_forest/error.rs index 21a6ebf2bb..b342b7cfb7 100644 --- a/miden-crypto/src/merkle/smt/large_forest/error.rs +++ b/miden-crypto/src/merkle/smt/large_forest/error.rs @@ -84,6 +84,9 @@ impl LargeSmtForestError { impl From for LargeSmtForestError { fn from(value: BackendError) -> Self { match value { + BackendError::BadVersion { provided, latest } => { + LargeSmtForestError::BadVersion { provided, latest } + }, BackendError::CorruptedData(_) => LargeSmtForestError::fatal_from(value), BackendError::DuplicateLineage(l) => LargeSmtForestError::DuplicateLineage(l), BackendError::Internal(e) => LargeSmtForestError::Fatal(e), diff --git a/miden-crypto/src/merkle/smt/large_forest/mod.rs b/miden-crypto/src/merkle/smt/large_forest/mod.rs index a5fbb88ed3..cc0d3beb9d 100644 --- a/miden-crypto/src/merkle/smt/large_forest/mod.rs +++ 
b/miden-crypto/src/merkle/smt/large_forest/mod.rs @@ -314,15 +314,22 @@ mod utils; use alloc::vec::Vec; use core::num::NonZeroU8; -pub use backend::{Backend, BackendError, memory::InMemoryBackend}; +pub use backend::{ + Backend, BackendError, BackendReader, + memory::{InMemoryBackend, InMemoryBackendSnapshot}, +}; #[cfg(feature = "persistent-forest")] pub use backend::{ - persistent::PersistentBackend, persistent::config::Config as PersistentBackendConfig, + persistent::config::Config as PersistentBackendConfig, + persistent::{PersistentBackend, PersistentBackendReader}, }; pub use config::{Config, DEFAULT_MAX_HISTORY_VERSIONS, MIN_HISTORY_VERSIONS}; pub use error::{LargeSmtForestError, Result}; -pub use operation::{ForestOperation, SmtForestUpdateBatch, SmtUpdateBatch}; +pub use operation::{SmtForestOperation, SmtForestUpdateBatch, SmtUpdateBatch}; pub use root::{LineageId, RootInfo, TreeEntry, TreeId, TreeWithRoot, VersionId}; +pub use utils::{ + AppliedLineageMutation, LineageMutation, LineageMutationKind, SmtForestMutationSet, +}; use crate::{ EMPTY_WORD, Map, Set, Word, @@ -346,8 +353,8 @@ use crate::{ /// A high-performance forest of sparse merkle trees with pluggable storage backends. /// /// See the module documentation for more information. -#[derive(Debug)] -pub struct LargeSmtForest { +#[derive(Clone, Debug)] +pub struct LargeSmtForest { /// The configuration for how the forest functions. 
config: Config, @@ -369,17 +376,6 @@ pub struct LargeSmtForest { non_empty_histories: Set, } -impl Clone for LargeSmtForest { - fn clone(&self) -> Self { - Self { - config: self.config.clone(), - backend: self.backend.clone(), - lineage_data: self.lineage_data.clone(), - non_empty_histories: self.non_empty_histories.clone(), - } - } -} - // CONSTRUCTION AND BASIC QUERIES // ================================================================================================ @@ -394,7 +390,7 @@ impl Clone for LargeSmtForest { /// /// Where anything more specific can be said about performance, the method documentation will /// contain more detail. -impl LargeSmtForest { +impl LargeSmtForest { /// Constructs a new forest backed by the provided `backend` using the default [`Config`] for /// the forest's behavior. /// @@ -471,7 +467,7 @@ impl LargeSmtForest { /// All of these methods can be performed fully in-memory, and hence their performance is /// predictable on a given machine regardless of the choice of [`Backend`] instance being used by /// the forest. -impl LargeSmtForest { +impl LargeSmtForest { /// Returns an iterator that yields all the (uniquely identified) roots that the forest knows /// about, including those from historical versions. /// @@ -576,7 +572,7 @@ impl LargeSmtForest { /// /// Where anything more specific can be said about performance, the method documentation will /// contain more detail. -impl LargeSmtForest { +impl LargeSmtForest { /// Returns an opening for the specified `key` in the specified `tree`, regardless of whether /// the `tree` has a value associated with `key` or not. /// @@ -809,9 +805,14 @@ impl LargeSmtForest { /// Where anything more specific can be said about performance, the method documentation will /// contain more detail. 
impl LargeSmtForest { - /// Adds a new `lineage` to the tree, creating an empty tree and modifying it as specified by + /// Adds a new `lineage` to the forest, creating an empty tree and modifying it as specified by /// `updates`, with the result taking the provided `new_version`. /// + /// This is the one-phase convenience API. It is equivalent to ensuring that the lineage + /// does not exist and then calling [`Self::compute_tree_mutations`] followed immediately + /// by [`Self::apply_mutations`]. Use the two-phase API directly when the proposed root + /// commitment must be inspected before committing the backend changes. + /// /// If the provided `updates` batch is empty, then the **empty tree will be added** as the first /// version in the lineage. /// @@ -819,120 +820,187 @@ impl LargeSmtForest { /// /// - [`LargeSmtForestError::DuplicateLineage`] if the provided `lineage` is the same as an /// already-known lineage. - /// - [`LargeSmtForestError::Fatal`] if the backend fails while being accessed. - /// - [`BackendError::Merkle`] if the provided `updates` cannot be applied to the empty tree. + /// - [`LargeSmtForestError::Fatal`] if the backend fails while computing or applying the + /// mutation. + /// - [`LargeSmtForestError::Merkle`] if the provided `updates` cannot be applied to the empty + /// tree. pub fn add_lineage( &mut self, lineage: LineageId, new_version: VersionId, updates: SmtUpdateBatch, ) -> Result { - // We can immediately add lineage in the backend, as by its contract it should return - // `DuplicateLineage` if the new lineage is a duplicate. We forward that, and any other - // errors, as this is the correct behavior for correctly-implemented backends. - let tree_info = self.backend.add_lineage(lineage, new_version, updates)?; - - // We then construct the lineage tracking data and shove it into the corresponding map. The - // history is guaranteed to be empty here, so we do not need to put an entry in the - // non-empty histories set. 
- let lineage_data = LineageData { - history: History::empty(self.config.max_history_versions()), - latest_version: tree_info.version(), - latest_root: tree_info.root(), - }; - self.lineage_data.insert(lineage, lineage_data); + let mutations = self.compute_add_lineage_mutations(lineage, new_version, updates)?; + let mut roots = self.apply_mutations(mutations)?; + Ok(roots.pop().expect("single lineage mutation returns one root")) + } + + /// Computes the mutations required to add a new `lineage`, without applying them. + /// + /// This is the first phase of [`Self::add_lineage`]. It computes the proposed new root for the + /// lineage and the backend-specific data needed to commit the update later via + /// [`Self::apply_mutations`]. Reverse mutations needed for history are returned by the backend + /// during the apply phase. + /// + /// The forest and backend are not modified by this method. Callers can inspect the returned + /// mutation set before committing it, which is useful when root commitments must be published, + /// checked, or combined with other data before the backend write occurs. + /// + /// If the provided `updates` batch is empty, then the **empty tree will be added** as the first + /// version in the lineage. + /// + /// # Errors + /// + /// - [`LargeSmtForestError::DuplicateLineage`] if the provided `lineage` is the same as an + /// already-known lineage. + /// - [`LargeSmtForestError::Fatal`] if the backend fails while computing the mutation. + /// - [`LargeSmtForestError::Merkle`] if the provided `updates` cannot be applied to the empty + /// tree. + /// + /// # Applying the Result + /// + /// The returned mutation set is valid only while the lineage remains at the same latest + /// version and root. [`Self::apply_mutations`] rejects stale mutation sets if the lineage has + /// changed since this method was called. 
+ fn compute_add_lineage_mutations( + &self, + lineage: LineageId, + new_version: VersionId, + updates: SmtUpdateBatch, + ) -> Result> { + if self.lineage_data.contains_key(&lineage) { + return Err(LargeSmtForestError::DuplicateLineage(lineage)); + } - Ok(tree_info) + self.compute_tree_mutations(lineage, new_version, updates) } /// Performs the provided `updates` on the latest tree in the specified `lineage`, adding a /// single new root to the forest (corresponding to `new_version`) for the entire batch, and /// returning the data for the new root of the tree. /// + /// This is the one-phase convenience API. It is equivalent to ensuring that the lineage exists + /// then calling [`Self::compute_tree_mutations`] followed immediately by + /// [`Self::apply_mutations`]. Use the two-phase API directly when the proposed root commitment + /// must be inspected before committing the backend changes. + /// /// If applying the provided `operations` results in no changes to the tree, then the root data /// will be returned unchanged and no new tree will be allocated. It will retain its original /// version, and not be returned with `new_version`. /// /// # Errors /// - /// - [`LargeSmtForestError::BadVersion`] if the `new_version` is older than the latest version + /// - [`LargeSmtForestError::BadVersion`] if `new_version` is not newer than the latest version /// for the provided `lineage`. - /// - [`LargeSmtForestError::Fatal`] if the backend fails while being accessed. - /// - [`LargeSmtForestError::UnknownLineage`] if the provided `tree` specifies a lineage that is - /// not one known by the forest. + /// - [`LargeSmtForestError::Fatal`] if the backend fails while computing or applying the + /// mutation. + /// - [`LargeSmtForestError::Merkle`] if the provided `updates` cannot be applied to the latest + /// tree. + /// - [`LargeSmtForestError::UnknownLineage`] if `lineage` is not known by the forest. 
pub fn update_tree( &mut self, lineage: LineageId, new_version: VersionId, updates: SmtUpdateBatch, ) -> Result { - // We initially check that the lineage is known and that the version is greater than the - // last known version for that lineage. - let lineage_data = if let Some(lineage_data) = self.lineage_data.get_mut(&lineage) { - if lineage_data.latest_version < new_version { - lineage_data - } else { - return Err(LargeSmtForestError::BadVersion { - provided: new_version, - latest: lineage_data.latest_version, - }); - } - } else { + let mutations = self.compute_update_tree_mutations(lineage, new_version, updates)?; + let mut roots = self.apply_mutations(mutations)?; + Ok(roots.pop().expect("single lineage mutation returns one root")) + } + + /// Computes the mutations required to update the latest tree in `lineage`, without applying + /// them. + /// + /// This is the first phase of [`Self::update_tree`]. It computes the proposed new root for the + /// lineage and the backend-specific data needed to commit the update later via + /// [`Self::apply_mutations`]. Reverse mutations needed for history are returned by the backend + /// during the apply phase. + /// + /// The forest and backend are not modified by this method. Callers can inspect the returned + /// mutation set before committing it, which is useful when root commitments must be published, + /// checked, or combined with other data before the backend write occurs. + /// + /// If applying `updates` would not change the tree, the returned mutation set represents a + /// no-op. Applying it will not advance the lineage version or allocate a new tree version, and + /// [`SmtForestMutationSet::roots`] will report the current latest version/root. + /// + /// # Errors + /// + /// - [`LargeSmtForestError::UnknownLineage`] if `lineage` is not known by the forest. + /// - [`LargeSmtForestError::BadVersion`] if `new_version` is not newer than the latest version + /// for `lineage`. 
+ /// - [`LargeSmtForestError::Fatal`] if the backend fails while computing prepared data. + /// - [`LargeSmtForestError::Merkle`] if the provided `updates` cannot be applied to the latest + /// tree. + /// + /// # Applying the Result + /// + /// The returned mutation set is valid only while the lineage remains at the same latest + /// version and root. [`Self::apply_mutations`] rejects stale mutation sets if the lineage has + /// changed since this method was called. + fn compute_update_tree_mutations( + &self, + lineage: LineageId, + new_version: VersionId, + updates: SmtUpdateBatch, + ) -> Result> { + let Some(lineage_data) = self.lineage_data.get(&lineage) else { return Err(LargeSmtForestError::UnknownLineage(lineage)); }; - // We capture the entry count before the backend update, as this is the count for the - // version being pushed into history. This must precede `update_tree` because the backend - // entry count changes after that call. - // - // By the contract of the `Backend` trait, `entry_count` is expected to be extremely cheap, - // and only fail if the lineage in question is missing. Versions of this method that are - // expensive to call, or that can fail due to I/O or other such reasons, are considered - // non-conformant. - let old_entry_count = self.backend.entry_count(lineage)?; - - // We now know that we have a valid lineage and a valid version, so we perform the update in - // the backend. - let reversion_set = self.backend.update_tree(lineage, new_version, updates)?; - - // We do not want to actually change anything if the tree would not change. - if reversion_set.is_empty() { - return Ok(TreeWithRoot::new( - lineage, - lineage_data.latest_version, - lineage_data.latest_root, - )); - } - - // The new root of the latest tree is actually given by the **old root** in our reverse - // mutation set. 
- let updated_root = reversion_set.old_root; - - // The call to `add_version_from_mutation_set` should only yield an error if the - // provided version does not pass the version check. This check has already been - // performed as a precondition for reaching this point of the tree update, and - // hence should only ever fail due to a programmer bug so we panic if it does fail. - lineage_data - .history - .add_version_from_mutation_set( - lineage_data.latest_version, - reversion_set, - old_entry_count, - ) - .unwrap_or_else(|_| { - panic!("Unable to add valid version {} to history", lineage_data.latest_version) + if lineage_data.latest_version >= new_version { + return Err(LargeSmtForestError::BadVersion { + provided: new_version, + latest: lineage_data.latest_version, }); + } - // At this point we now have a historical version added, so we track that the lineage has a - // non-empty history. - self.non_empty_histories.insert(lineage); + self.compute_tree_mutations(lineage, new_version, updates) + } - // Now we just have to update the other portions of the lineage data in place... - lineage_data.latest_root = updated_root; - lineage_data.latest_version = new_version; + /// Computes the mutations required to mutate one lineage, without applying them. + /// + /// If the lineage is already present it is updated, while unknown lineages are added from the + /// empty tree. + /// + /// The forest and backend are not modified by this method. Callers can inspect the returned + /// mutation set, especially via [`SmtForestMutationSet::roots`] or + /// [`SmtForestMutationSet::lineage_mutations`], before deciding whether to commit it. + /// + /// If the provided `updates` batch is empty, the returned mutation set still represents adding + /// the empty tree as the first version of the lineage. + /// + /// # Errors + /// + /// - [`LargeSmtForestError::BadVersion`] if `lineage` is known and `new_version` is not newer.
+ /// - [`LargeSmtForestError::Fatal`] if the backend fails while computing prepared data. + /// - [`LargeSmtForestError::Merkle`] if the provided `updates` cannot be applied to the empty + /// tree. + /// + /// # Applying the Result + /// + /// The returned mutation set is valid only for the forest state against which it was computed. + /// [`Self::apply_mutations`] rejects stale mutation sets if the lineage has changed since this + /// method was called. + pub fn compute_tree_mutations( + &self, + lineage: LineageId, + new_version: VersionId, + updates: SmtUpdateBatch, + ) -> Result> { + if let Some(lineage_data) = self.lineage_data.get(&lineage) + && lineage_data.latest_version >= new_version + { + return Err(LargeSmtForestError::BadVersion { + provided: new_version, + latest: lineage_data.latest_version, + }); + } - // ...and return the correct value. - Ok(TreeWithRoot::new(lineage, new_version, updated_root)) + let mut batch = SmtForestUpdateBatch::empty(); + batch.operations(lineage).add_operations(updates.into_iter()); + let (entries, prepared) = self.backend.compute_mutations(new_version, batch)?; + Ok(SmtForestMutationSet::new(entries, prepared)) } } @@ -951,10 +1019,15 @@ impl LargeSmtForest { /// Where anything more specific can be said about performance, the method documentation will /// contain more detail. impl LargeSmtForest { - /// Adds multiple new `lineages` to the tree, creating an empty tree for each and applying the + /// Adds multiple new `lineages` to the forest, creating an empty tree for each and applying the /// provided modifications to it, with the result being given the specified `version`. /// - /// If the provide batch of modifications is empty for any given lineage, then the **empty tree + /// This is the one-phase convenience API. It is equivalent to ensuring that none of the + /// lineages exists and then calling [`Self::compute_forest_mutations`] followed immediately + /// by [`Self::apply_mutations`].
Use the two-phase API directly when the proposed root + /// commitments must be inspected before committing the backend changes. + /// + /// If the provided batch of modifications is empty for any given lineage, then the **empty tree /// will be added** as the first version in that lineage. /// /// # Performance @@ -964,51 +1037,72 @@ impl LargeSmtForest { /// [`Self::add_lineage`] in a loop, but in some cases it may be significantly more performant. /// /// The exact scope of any speed-up is determined by the backend in use, so it is worth reading - /// the documentation for the Backend's `add_lineages` method. + /// the documentation for the backend's [`Backend::compute_mutations`] method. /// /// # Errors /// /// - [`LargeSmtForestError::DuplicateLineage`] if any of the provided lineages share an ID with /// an already-known lineage. - /// - [`LargeSmtForestError::Fatal`] if the backend fails while being accessed. - /// - [`BackendError::Merkle`] if the provided `updates` cannot be applied to the empty tree. + /// - [`LargeSmtForestError::Fatal`] if the backend fails while computing or applying the + /// mutations. + /// - [`LargeSmtForestError::Merkle`] if any provided updates cannot be applied to an empty + /// tree. pub fn add_lineages( &mut self, version: VersionId, lineages: SmtForestUpdateBatch, ) -> Result> { - // We start by performing our precondition checks: none of the lineages in the batch should - // already exist in the forest. + let mutations = self.compute_add_lineages_mutations(version, lineages)?; + self.apply_mutations(mutations) + } + + /// Computes mutations that would add multiple new lineages without applying them. + /// + /// This is the first phase of [`Self::add_lineages`]. It returns an [`SmtForestMutationSet`] + /// containing one proposed result per lineage in `lineages`, plus backend-specific prepared + /// data that can later be committed with [`Self::apply_mutations`].
+ /// + /// The forest and backend are not modified by this method. Callers can inspect the returned + /// roots before committing the batch. This is useful when a higher-level transaction must know + /// all new root commitments before the forest backend is updated. + /// + /// Empty per-lineage operation batches still produce new empty lineages in the mutation set. + /// + /// # Performance + /// + /// Backends may compute the per-lineage mutations in parallel and may prepare a batched apply + /// representation. This method should generally be preferred over repeated + /// [`Self::compute_add_lineage_mutations`] calls when adding multiple lineages. + /// + /// # Errors + /// + /// - [`LargeSmtForestError::DuplicateLineage`] if any provided lineage is already known by the + /// forest. + /// - [`LargeSmtForestError::Fatal`] if the backend fails while computing prepared data. + /// - [`LargeSmtForestError::Merkle`] if any provided updates cannot be applied to an empty + /// tree. + fn compute_add_lineages_mutations( + &self, + version: VersionId, + lineages: SmtForestUpdateBatch, + ) -> Result> { for lineage in lineages.lineages() { if self.lineage_data.contains_key(lineage) { return Err(LargeSmtForestError::DuplicateLineage(*lineage)); } } - // With the preconditions checked we can call into the backend to perform the additions, and - // we forward all errors as this will be correct for conformant backend implementations. - let results = self.backend.add_lineages(version, lineages)?; - - // Now we have to update the lineage data for each newly-added lineage. New lineages have - // empty histories, so we do not need to insert into `non_empty_histories`. 
- results - .into_iter() - .map(|(lineage, tree_info)| { - let lineage_data = LineageData { - history: History::empty(self.config.max_history_versions()), - latest_version: tree_info.version(), - latest_root: tree_info.root(), - }; - self.lineage_data.insert(lineage, lineage_data); - - Ok(tree_info) - }) - .collect() + let (entries, prepared) = self.backend.compute_mutations(version, lineages)?; + Ok(SmtForestMutationSet::new(entries, prepared)) } /// Performs the provided `updates` on the forest, adding at most one new root with version - /// `new_version` to the forest for each target root in `updates` and returning a mapping - /// from old root to the new root data. + /// `new_version` for each targeted lineage and returning the resulting root data. + /// + /// This is the one-phase convenience API. It is equivalent to checking that all of the lineages + /// exist then calling [`Self::compute_forest_mutations`] followed immediately by + /// [`Self::apply_mutations`]. Use the two-phase API directly when the proposed root + /// commitments must be inspected before committing the backend changes. /// /// If applying the associated batch to any given lineage in the forest results in no changes to /// that tree, the initial root for that lineage will be returned and no new tree will be @@ -1016,20 +1110,58 @@ impl LargeSmtForest { /// /// # Errors /// - /// - [`LargeSmtForestError::UnknownLineage`] If any lineage in the batch of modifications is + /// - [`LargeSmtForestError::UnknownLineage`] if any lineage in the batch of modifications is /// one that is not known by the forest. - /// - [`LargeSmtForestError::Fatal`] if any error occurs to leave the forest in an inconsistent - /// state. - /// - [`LargeSmtForestError::BadVersion`] if the `new_version` is older than the latest version - /// for the provided `lineage`. + /// - [`LargeSmtForestError::Fatal`] if the backend fails while computing or applying the + /// mutations. 
+ /// - [`LargeSmtForestError::BadVersion`] if `new_version` is not newer than the latest version + /// for any targeted lineage. + /// - [`LargeSmtForestError::Merkle`] if any provided updates cannot be applied to the relevant + /// latest tree. pub fn update_forest( &mut self, new_version: VersionId, updates: SmtForestUpdateBatch, ) -> Result> { - // We start by performing our precondition checks on the lineages and versions. We have to - // ensure both that all the lineages exist, and that the specified version transition is - // valid for all of those lineages. + let mutations = self.compute_update_forest_mutations(new_version, updates)?; + self.apply_mutations(mutations) + } + + /// Computes mutations that would update multiple existing lineages without applying them. + /// + /// This is the first phase of [`Self::update_forest`]. It returns an [`SmtForestMutationSet`] + /// containing one proposed result per lineage in `updates`, plus backend-specific prepared data + /// that can later be committed with [`Self::apply_mutations`]. + /// + /// The forest and backend are not modified by this method. Callers can inspect all proposed + /// root commitments before committing the batch. This is useful when the forest update is only + /// one part of a larger transaction or when root commitments must be signed, checked, or stored + /// elsewhere before the backend write occurs. + /// + /// If a per-lineage update would not change the tree, the corresponding mutation is a no-op. + /// Applying the mutation set will not advance that lineage's version or allocate a new tree + /// version for it. + /// + /// # Performance + /// + /// Backends may compute per-lineage mutations in parallel and may prepare a batched apply + /// representation. This method should generally be preferred over repeated + /// [`Self::compute_update_tree_mutations`] calls when updating multiple lineages. 
+ /// + /// # Errors + /// + /// - [`LargeSmtForestError::UnknownLineage`] if any targeted lineage is not known by the + /// forest. + /// - [`LargeSmtForestError::BadVersion`] if `new_version` is not newer than the latest version + /// for any targeted lineage. + /// - [`LargeSmtForestError::Fatal`] if the backend fails while computing prepared data. + /// - [`LargeSmtForestError::Merkle`] if any provided updates cannot be applied to the relevant + /// latest tree. + fn compute_update_forest_mutations( + &self, + new_version: VersionId, + updates: SmtForestUpdateBatch, + ) -> Result> { updates .lineages() .map(|lineage| { @@ -1048,77 +1180,159 @@ impl LargeSmtForest { }) .collect::>>()?; - // We capture the entry counts before the backend update, as these are the counts for the - // versions being pushed into history. This must precede `update_forest` because the - // backend entry counts change after that call. We capture for all lineages eagerly - // (including those that may produce no-op updates) to keep the logic simple and consistent - // with `update_tree`, and use a map to eliminate any accidental dependencies on iteration - // order. - // - // By the contract of the `Backend` trait, `entry_count` is expected to be extremely cheap, - // and only fail if the lineage in question is missing. Versions of this method that are - // expensive to call, or that can fail due to I/O or other such reasons, are considered - // non-conformant. - let old_entry_counts: Map = updates + self.compute_forest_mutations(new_version, updates) + } + + /// Computes mutations that would add or update multiple lineages without applying them. + /// + /// Lineages that are already present are updated, while unknown lineages are added from the + /// empty tree. + /// + /// The forest and backend are not modified by this method. Callers can inspect all proposed + /// root commitments before committing the batch. 
This is useful when the forest update is only + /// one part of a larger transaction or when root commitments must be signed, checked, or stored + /// elsewhere before the backend write occurs. + /// + /// If a per-lineage update would not change the tree, the corresponding mutation is a no-op. + /// Applying the mutation set will not advance that lineage's version or allocate a new tree + /// version for it. + /// + /// # Performance + /// + /// Backends may compute the per-lineage mutations in parallel and may prepare a batched apply + /// representation. This method should generally be preferred over repeated + /// [`Self::compute_tree_mutations`] calls when updating multiple lineages. + /// + /// # Errors + /// + /// - [`LargeSmtForestError::Fatal`] if the backend fails while computing prepared + /// data. + /// - [`LargeSmtForestError::BadVersion`] if `new_version` is not newer than the latest version + /// for any targeted lineage that is already present. + /// - [`LargeSmtForestError::Merkle`] if any provided updates cannot be applied to the relevant + /// latest tree. + pub fn compute_forest_mutations( + &self, + new_version: VersionId, + updates: SmtForestUpdateBatch, + ) -> Result> { + updates .lineages() - .map(|lineage| Ok((*lineage, self.backend.entry_count(*lineage)?))) - .collect::>()?; - - // With the preconditions checked we can call into the backend to perform the updates, and - // we forward all errors as this will be correct for conformant backend implementations. - let reversion_sets = self.backend.update_forest(new_version, updates)?; - - // Now we have to update the lineage data (including the history) to ensure that the state - // remains consistent, and we build our return values while doing so. - reversion_sets - .into_iter() - .map(|(lineage, reversion)| { - // Known to exist by construction, so the bare index is safe.
- let old_entry_count = old_entry_counts[&lineage]; - let lineage_data = self - .lineage_data - .get_mut(&lineage) - .expect("Lineage has been checked to be present"); - - // If the operations change nothing we want to short-circuit for that tree. - if reversion.is_empty() { - return Ok(TreeWithRoot::new( - lineage, - lineage_data.latest_version, - lineage_data.latest_root, - )); - } + .map(|lineage| { + let Some(lineage_data) = self.lineage_data.get(lineage) else { + return Ok(()); + }; - let updated_root = reversion.old_root; - - // The call to `add_version_from_mutation_set` should only yield an error if the - // provided version does not pass the version check. This check has already been - // performed as a precondition for reaching this point of the forest update, and - // hence should only ever fail due to a programmer bug so we panic if it does fail. - lineage_data - .history - .add_version_from_mutation_set( - lineage_data.latest_version, - reversion, - old_entry_count, - ) - .unwrap_or_else(|_| { - panic!( - "Unable to add valid version {} to history", - lineage_data.latest_version - ) - }); - - // At this point we know that we have a historical version for that tree, so we - // should track it as having a non-empty history. - self.non_empty_histories.insert(lineage); - - lineage_data.latest_root = updated_root; - lineage_data.latest_version = new_version; - - Ok(TreeWithRoot::new(lineage, new_version, updated_root)) + if lineage_data.latest_version < new_version { + Ok(()) + } else { + Err(LargeSmtForestError::BadVersion { + provided: new_version, + latest: lineage_data.latest_version, + }) + } }) - .collect::>>() + .collect::>>()?; + + let (entries, prepared) = self.backend.compute_mutations(new_version, updates)?; + Ok(SmtForestMutationSet::new(entries, prepared)) + } + + /// Applies mutations previously computed by one of the forest compute methods. + /// + /// This is the second phase of the two-phase update API. 
It consumes an + /// [`SmtForestMutationSet`] returned by one of: + /// + /// - [`Self::compute_tree_mutations`], or + /// - [`Self::compute_forest_mutations`]. + /// + /// The backend validates that the mutation set is still applicable before committing anything. + /// For update mutations, the latest version and root of every affected lineage must still match + /// the version/root captured during the compute phase. For new-lineage mutations, the lineage + /// must still be absent. These checks prevent stale mutation sets from being applied after + /// intervening forest changes. + /// + /// If backend validation succeeds, the opaque backend-prepared data is committed first. Only + /// after the backend apply succeeds does the forest update its in-memory lineage metadata and + /// history. This ordering keeps the forest consistent if the backend reports an error. + /// + /// The returned roots match [`SmtForestMutationSet::roots`] for the applied mutation set. No-op + /// update mutations return the existing latest version/root for their lineages. + /// + /// # Errors + /// + /// - [`LargeSmtForestError::DuplicateLineage`] if a mutation set attempts to add a lineage that + /// is now already known. + /// - [`LargeSmtForestError::UnknownLineage`] if a mutation set attempts to update a lineage + /// that is no longer known. + /// - [`LargeSmtForestError::BadVersion`] if an update mutation was computed against a version + /// that is no longer the latest version. + /// - [`LargeSmtForestError::Merkle`] if an update mutation was computed against a root that is + /// no longer the latest root. + /// - [`LargeSmtForestError::Fatal`] if the backend fails while applying its prepared data. 
+ pub fn apply_mutations( + &mut self, + mutations: SmtForestMutationSet, + ) -> Result> { + let (entries, prepared) = mutations.into_parts(); + + let applied_entries = self.backend.apply_mutations(prepared)?; + + assert_eq!( + entries.len(), + applied_entries.len(), + "backend returned an unexpected number of applied mutations" + ); + let mut roots = Vec::with_capacity(applied_entries.len()); + for (entry, applied) in entries.into_iter().zip(applied_entries) { + assert_eq!(entry.lineage(), applied.lineage()); + assert_eq!(entry.old_version(), applied.old_version()); + assert_eq!(entry.new_version(), applied.new_version()); + assert_eq!(entry.old_root(), applied.old_root()); + assert_eq!(entry.new_root(), applied.new_root()); + assert_eq!(entry.kind(), applied.kind()); + + let root = applied.result(); + match entry.kind() { + LineageMutationKind::AddLineage => { + let lineage_data = LineageData { + history: History::empty(self.config.max_history_versions()), + latest_version: root.version(), + latest_root: root.root(), + }; + self.lineage_data.insert(entry.lineage(), lineage_data); + }, + LineageMutationKind::UpdateTree => { + if entry.old_root() != entry.new_root() { + let lineage_data = self + .lineage_data + .get_mut(&entry.lineage()) + .expect("Lineage has been checked to be present"); + + let old_version = + entry.old_version().expect("update mutations have old versions"); + let old_entry_count = applied.old_entry_count(); + lineage_data + .history + .add_version_from_mutation_set( + old_version, + applied.into_reverse(), + old_entry_count, + ) + .unwrap_or_else(|_| { + panic!("Unable to add valid version {} to history", old_version) + }); + + self.non_empty_histories.insert(entry.lineage()); + lineage_data.latest_root = entry.new_root(); + lineage_data.latest_version = entry.new_version(); + } + }, + } + roots.push(root); + } + + Ok(roots) } } @@ -1127,7 +1341,7 @@ impl LargeSmtForest { /// This block contains internal functions that exist to de-duplicate or 
modularize functionality /// within the forest. These should not be exposed. -impl LargeSmtForest { +impl LargeSmtForest { /// Applies the history delta given by `history_view` on top of the provided `full_tree_leaf` to /// produce the correct leaf for a historical opening. /// @@ -1230,6 +1444,22 @@ impl LargeSmtForest { } } +impl LargeSmtForest { + /// Returns a read-only `LargeSmtForest` backed by a reader view of this forest's backend. + /// + /// The new forest shares the same config, lineage data, and history as `self`, and its backend + /// is a point-in-time snapshot produced by [`Backend::reader`]. The returned forest's backend + /// type is `B::Reader: BackendReader`, so it cannot be used for mutations. + pub fn reader(&self) -> Result> { + Ok(LargeSmtForest { + config: self.config.clone(), + backend: self.backend.reader()?, + lineage_data: self.lineage_data.clone(), + non_empty_histories: self.non_empty_histories.clone(), + }) + } +} + // TESTING FUNCTIONALITY // ================================================================================================ @@ -1237,7 +1467,7 @@ impl LargeSmtForest { /// inspect the internal state of the forest that are unsafe to make part of the forest's public /// API. #[cfg(test)] -impl LargeSmtForest { +impl LargeSmtForest { /// Gets an immutable reference to the underlying backend of the forest. pub fn get_backend(&self) -> &B { &self.backend diff --git a/miden-crypto/src/merkle/smt/large_forest/operation.rs b/miden-crypto/src/merkle/smt/large_forest/operation.rs index f739a9fdc3..1f6457a609 100644 --- a/miden-crypto/src/merkle/smt/large_forest/operation.rs +++ b/miden-crypto/src/merkle/smt/large_forest/operation.rs @@ -12,7 +12,7 @@ use crate::{EMPTY_WORD, Map, Set, Word, merkle::smt::large_forest::root::Lineage /// The operations that can be performed on an arbitrary leaf in a tree in a forest. 
#[derive(Clone, Debug, Eq, PartialEq)] -pub enum ForestOperation { +pub enum SmtForestOperation { /// An insertion of `value` under `key` into the tree. /// /// If `key` already exists in the tree, the associated value will be replaced with `value` @@ -22,7 +22,7 @@ pub enum ForestOperation { /// The removal of the `key` and its associated value from the tree. Remove { key: Word }, } -impl ForestOperation { +impl SmtForestOperation { /// Insert the provided `value` into a tree under the provided `key`. pub fn insert(key: Word, value: Word) -> Self { Self::Insert { key, value } @@ -36,17 +36,17 @@ impl ForestOperation { /// Retrieves the key from the operation. pub fn key(&self) -> Word { match self { - ForestOperation::Insert { key, .. } => *key, - ForestOperation::Remove { key } => *key, + SmtForestOperation::Insert { key, .. } => *key, + SmtForestOperation::Remove { key } => *key, } } } -impl From for (Word, Word) { - fn from(value: ForestOperation) -> Self { +impl From for (Word, Word) { + fn from(value: SmtForestOperation) -> Self { match value { - ForestOperation::Insert { key, value } => (key, value), - ForestOperation::Remove { key } => (key, EMPTY_WORD), + SmtForestOperation::Insert { key, value } => (key, value), + SmtForestOperation::Remove { key } => (key, EMPTY_WORD), } } } @@ -58,7 +58,7 @@ impl From for (Word, Word) { #[derive(Clone, Debug, Eq, PartialEq)] pub struct SmtUpdateBatch { /// The operations to be performed on a tree. - operations: Vec, + operations: Vec, } impl SmtUpdateBatch { /// Creates an empty batch of operations. @@ -67,33 +67,33 @@ impl SmtUpdateBatch { } /// Creates a batch containing the provided `operations`. - pub fn new(operations: impl Iterator) -> Self { + pub fn new(operations: impl Iterator) -> Self { Self { operations: operations.collect::>(), } } /// Adds the provided `operations` to the batch. 
- pub fn add_operations(&mut self, operations: impl Iterator) { + pub fn add_operations(&mut self, operations: impl Iterator) { self.operations.extend(operations); } - /// Adds the [`ForestOperation::Insert`] operation for the provided `key` and `value` pair to + /// Adds the [`SmtForestOperation::Insert`] operation for the provided `key` and `value` pair to /// the batch. pub fn add_insert(&mut self, key: Word, value: Word) { - self.operations.push(ForestOperation::insert(key, value)); + self.operations.push(SmtForestOperation::insert(key, value)); } - /// Adds the [`ForestOperation::Remove`] operation for the provided `key` to the batch. + /// Adds the [`SmtForestOperation::Remove`] operation for the provided `key` to the batch. pub fn add_remove(&mut self, key: Word) { - self.operations.push(ForestOperation::remove(key)); + self.operations.push(SmtForestOperation::remove(key)); } /// Consumes the batch as a vector of operations, containing the last operation for any given /// `key` in the case that multiple operations per key are encountered. /// /// This vector is guaranteed to be sorted by the key on which an operation is performed. - pub fn consume(self) -> Vec { + pub fn consume(self) -> Vec { // As we want to keep the LAST operation for each key, rather than the first, we filter in // reverse. 
let mut seen_keys: Set = Set::new(); @@ -103,13 +103,13 @@ impl SmtUpdateBatch { .rev() .filter(|o| seen_keys.insert(o.key())) .collect::>(); - ops.sort_by_key(ForestOperation::key); + ops.sort_by_key(SmtForestOperation::key); ops } } impl IntoIterator for SmtUpdateBatch { - type Item = ForestOperation; + type Item = SmtForestOperation; type IntoIter = alloc::vec::IntoIter; /// Consumes the batch as an iterator yielding operations while respecting the guarantees given @@ -127,8 +127,8 @@ impl From for Vec<(Word, Word)> { .consume() .into_iter() .map(|op| match op { - ForestOperation::Insert { key, value } => (key, value), - ForestOperation::Remove { key } => (key, EMPTY_WORD), + SmtForestOperation::Insert { key, value } => (key, value), + SmtForestOperation::Remove { key } => (key, EMPTY_WORD), }) .collect() } @@ -141,9 +141,9 @@ where fn from(value: I) -> Self { Self::new(value.map(|(k, v)| { if v.is_empty() { - ForestOperation::Remove { key: k } + SmtForestOperation::Remove { key: k } } else { - ForestOperation::Insert { key: k, value: v } + SmtForestOperation::Insert { key: k, value: v } } })) } @@ -176,7 +176,7 @@ impl SmtForestUpdateBatch { pub fn add_operations( &mut self, lineage: LineageId, - operations: impl Iterator, + operations: impl Iterator, ) { let batch = self.operations.entry(lineage).or_insert_with(SmtUpdateBatch::empty); batch.add_operations(operations); @@ -198,14 +198,14 @@ impl SmtForestUpdateBatch { /// Consumes the batch as a map of batches, with each individual batch guaranteed to be in /// sorted order and contain only the last operation in the batch for any given key. 
- pub fn consume(self) -> Map> { + pub fn consume(self) -> Map> { self.operations.into_iter().map(|(k, v)| (k, v.consume())).collect() } } impl IntoIterator for SmtForestUpdateBatch { - type Item = (LineageId, Vec); - type IntoIter = crate::MapIntoIter>; + type Item = (LineageId, Vec); + type IntoIter = crate::MapIntoIter>; /// Consumes the batch as an iterator yielding pairs of `(lineage, operations)` while respecting /// the guarantees given by [`Self::consume`]. @@ -240,9 +240,9 @@ mod test { let o3_key: Word = rng.value(); let o3_value: Word = rng.value(); - let o1 = ForestOperation::insert(o1_key, o1_value); - let o2 = ForestOperation::remove(o2_key); - let o3 = ForestOperation::insert(o3_key, o3_value); + let o1 = SmtForestOperation::insert(o1_key, o1_value); + let o2 = SmtForestOperation::remove(o2_key); + let o3 = SmtForestOperation::insert(o3_key, o3_value); // ... and stick them in the batch in various ways batch.add_operations(vec![o1.clone()].into_iter()); @@ -254,7 +254,7 @@ mod test { // If we then consume the batch, we should have the operations ordered by their key. let ops = batch.consume(); - assert!(ops.is_sorted_by_key(ForestOperation::key)); + assert!(ops.is_sorted_by_key(SmtForestOperation::key)); // Let's now make two additional operations with keys that overlay with keys from the first // three... @@ -262,8 +262,8 @@ mod test { let o4_value: Word = rng.value(); let o5_key = o1_key; - let o4 = ForestOperation::insert(o4_key, o4_value); - let o5 = ForestOperation::remove(o5_key); + let o4 = SmtForestOperation::insert(o4_key, o4_value); + let o5 = SmtForestOperation::remove(o5_key); // ... and also stick them into the batch. 
let mut batch = batch_tmp; @@ -274,7 +274,7 @@ mod test { let ops = batch.consume(); assert_eq!(ops.len(), 3); - assert!(ops.is_sorted_by_key(ForestOperation::key)); + assert!(ops.is_sorted_by_key(SmtForestOperation::key)); assert!(ops.contains(&o3)); assert!(ops.contains(&o4)); @@ -292,14 +292,14 @@ mod test { // Let's start by adding a few operations to a tree. let t1_lineage: LineageId = rng.value(); - let t1_o1 = ForestOperation::insert(rng.value(), rng.value()); - let t1_o2 = ForestOperation::remove(rng.value()); + let t1_o1 = SmtForestOperation::insert(rng.value(), rng.value()); + let t1_o2 = SmtForestOperation::remove(rng.value()); batch.add_operations(t1_lineage, vec![t1_o1, t1_o2].into_iter()); // We can also add them differently. let t2_lineage: LineageId = rng.value(); - let t2_o1 = ForestOperation::remove(rng.value()); - let t2_o2 = ForestOperation::insert(rng.value(), rng.value()); + let t2_o1 = SmtForestOperation::remove(rng.value()); + let t2_o2 = SmtForestOperation::insert(rng.value(), rng.value()); batch.operations(t2_lineage).add_operations(vec![t2_o1, t2_o2].into_iter()); // When we consume the batch, each per-tree batch should be unique by key and sorted. 
@@ -307,11 +307,11 @@ mod test { assert_eq!(ops.len(), 2); let t1_ops = ops.get(&t1_lineage).unwrap(); - assert!(t1_ops.is_sorted_by_key(ForestOperation::key)); + assert!(t1_ops.is_sorted_by_key(SmtForestOperation::key)); assert_eq!(t1_ops.iter().unique_by(|o| o.key()).count(), 2); let t2_ops = ops.get(&t2_lineage).unwrap(); - assert!(t2_ops.is_sorted_by_key(ForestOperation::key)); + assert!(t2_ops.is_sorted_by_key(SmtForestOperation::key)); assert_eq!(t2_ops.iter().unique_by(|o| o.key()).count(), 2); } } diff --git a/miden-crypto/src/merkle/smt/large_forest/property_tests.rs b/miden-crypto/src/merkle/smt/large_forest/property_tests.rs index d2af7a5807..96f66e310b 100644 --- a/miden-crypto/src/merkle/smt/large_forest/property_tests.rs +++ b/miden-crypto/src/merkle/smt/large_forest/property_tests.rs @@ -9,9 +9,8 @@ use proptest::prelude::*; use crate::{ EMPTY_WORD, Word, merkle::smt::{ - Backend, ForestConfig, ForestInMemoryBackend, ForestOperation, LargeSmtForest, - LargeSmtForestError, LineageId, RootInfo, Smt, SmtForestUpdateBatch, SmtUpdateBatch, - TreeId, + ForestConfig, ForestInMemoryBackend, LargeSmtForest, LargeSmtForestError, LineageId, + RootInfo, Smt, SmtForestOperation, SmtForestUpdateBatch, SmtUpdateBatch, TreeId, large_forest::test_utils::{ apply_batch, arbitrary_batch, arbitrary_distinct_lineages, arbitrary_lineage, arbitrary_non_empty_word, arbitrary_version, arbitrary_word, assert_lineage_metadata, @@ -117,50 +116,50 @@ proptest! 
{ .add_lineage( lineage, version, - SmtUpdateBatch::new([ForestOperation::insert(key_1, value_1)].into_iter()), + SmtUpdateBatch::new([SmtForestOperation::insert(key_1, value_1)].into_iter()), ) .map_err(to_fail)?; forest .update_tree( lineage, version + 1, - SmtUpdateBatch::new([ForestOperation::insert(key_2, value_2)].into_iter()), + SmtUpdateBatch::new([SmtForestOperation::insert(key_2, value_2)].into_iter()), ) .map_err(to_fail)?; forest .update_tree( lineage, version + 2, - SmtUpdateBatch::new([ForestOperation::insert(key_3, value_3)].into_iter()), + SmtUpdateBatch::new([SmtForestOperation::insert(key_3, value_3)].into_iter()), ) .map_err(to_fail)?; forest .update_tree( lineage, version + 3, - SmtUpdateBatch::new([ForestOperation::insert(key_4, value_4)].into_iter()), + SmtUpdateBatch::new([SmtForestOperation::insert(key_4, value_4)].into_iter()), ) .map_err(to_fail)?; let mut tree_v1 = Smt::new(); apply_batch( &mut tree_v1, - SmtUpdateBatch::new([ForestOperation::insert(key_1, value_1)].into_iter()), + SmtUpdateBatch::new([SmtForestOperation::insert(key_1, value_1)].into_iter()), )?; let mut tree_v2 = tree_v1.clone(); apply_batch( &mut tree_v2, - SmtUpdateBatch::new([ForestOperation::insert(key_2, value_2)].into_iter()), + SmtUpdateBatch::new([SmtForestOperation::insert(key_2, value_2)].into_iter()), )?; let mut tree_v3 = tree_v2.clone(); apply_batch( &mut tree_v3, - SmtUpdateBatch::new([ForestOperation::insert(key_3, value_3)].into_iter()), + SmtUpdateBatch::new([SmtForestOperation::insert(key_3, value_3)].into_iter()), )?; let mut tree_v4 = tree_v3.clone(); apply_batch( &mut tree_v4, - SmtUpdateBatch::new([ForestOperation::insert(key_4, value_4)].into_iter()), + SmtUpdateBatch::new([SmtForestOperation::insert(key_4, value_4)].into_iter()), )?; let sample_keys = vec![key_1, key_2, key_3, key_4]; @@ -524,7 +523,7 @@ proptest! 
{ let invalid_value = Word::from([1u32, 1, 1, 1]); invalid_updates.add_operations( lineage_1, - SmtUpdateBatch::new([ForestOperation::insert(query_key, invalid_value)].into_iter()) + SmtUpdateBatch::new([SmtForestOperation::insert(query_key, invalid_value)].into_iter()) .consume() .into_iter(), ); diff --git a/miden-crypto/src/merkle/smt/large_forest/test_utils.rs b/miden-crypto/src/merkle/smt/large_forest/test_utils.rs index ea6aac4f05..aab6822432 100644 --- a/miden-crypto/src/merkle/smt/large_forest/test_utils.rs +++ b/miden-crypto/src/merkle/smt/large_forest/test_utils.rs @@ -14,13 +14,13 @@ use proptest::prelude::*; use crate::{ EMPTY_WORD, Map, ONE, ZERO, merkle::smt::{ - Backend, ForestInMemoryBackend, ForestOperation, LargeSmtForest, LeafIndex, LineageId, - MAX_LEAF_ENTRIES, RootInfo, SMT_DEPTH, Smt, SmtForestUpdateBatch, SmtProof, SmtUpdateBatch, - TreeId, VersionId, + Backend, BackendReader, ForestInMemoryBackend, LargeSmtForest, LeafIndex, LineageId, + MAX_LEAF_ENTRIES, RootInfo, SMT_DEPTH, Smt, SmtForestOperation, SmtForestUpdateBatch, + SmtProof, SmtUpdateBatch, TreeId, VersionId, large_forest::{ backend::{BackendError, Result as BackendResult}, root::{TreeEntry, TreeWithRoot}, - utils::MutationSet, + utils::{AppliedLineageMutation, LineageMutation}, }, }, }; @@ -121,9 +121,9 @@ pub fn arbitrary_batch() -> impl Strategy { arbitrary_entries().prop_map(|e| { SmtUpdateBatch::new(e.into_iter().map(|(k, v)| { if v == EMPTY_WORD { - ForestOperation::remove(k) + SmtForestOperation::remove(k) } else { - ForestOperation::insert(k, v) + SmtForestOperation::insert(k, v) } })) }) @@ -162,7 +162,7 @@ pub fn sorted_tree_entries(tree: &Smt) -> Vec { /// Sorts forest entries explicitly by `(key, value)` so tests compare observable contents rather /// than relying on unspecified iterator ordering. 
-pub fn sorted_forest_entries( +pub fn sorted_forest_entries( forest: &LargeSmtForest, tree: TreeId, ) -> Result, TestCaseError> { @@ -180,7 +180,7 @@ fn word_to_option(value: Word) -> Option { } /// Asserts that the forest and reference tree agree on entries, counts, key lookups, and openings. -pub fn assert_tree_queries_match( +pub fn assert_tree_queries_match( forest: &LargeSmtForest, tree_id: TreeId, reference: &Smt, @@ -207,7 +207,7 @@ pub fn assert_tree_queries_match( } /// Asserts that the forest metadata for `lineage` matches the provided sequence of versions. -pub fn assert_lineage_metadata( +pub fn assert_lineage_metadata( forest: &LargeSmtForest, lineage: LineageId, versions: &[(VersionId, Word)], @@ -275,7 +275,7 @@ impl>> Iterator for FallibleIter } } -impl Backend for FallibleEntriesBackend { +impl BackendReader for FallibleEntriesBackend { fn open(&self, lineage: LineageId, key: Word) -> BackendResult { self.inner.open(lineage, key) } @@ -315,38 +315,28 @@ impl Backend for FallibleEntriesBackend { let inner_iter = self.inner.entries(lineage)?; Ok(FallibleIter { inner: inner_iter, count: 0 }) } +} - fn add_lineage( - &mut self, - lineage: LineageId, - version: VersionId, - updates: SmtUpdateBatch, - ) -> BackendResult { - self.inner.add_lineage(lineage, version, updates) - } +impl Backend for FallibleEntriesBackend { + type Reader = ::Reader; + type PreparedMutations = ::PreparedMutations; - fn update_tree( - &mut self, - lineage: LineageId, - new_version: VersionId, - updates: SmtUpdateBatch, - ) -> BackendResult { - self.inner.update_tree(lineage, new_version, updates) + fn reader(&self) -> BackendResult { + self.inner.reader() } - fn add_lineages( - &mut self, - version: VersionId, - lineages: SmtForestUpdateBatch, - ) -> BackendResult> { - self.inner.add_lineages(version, lineages) + fn compute_mutations( + &self, + new_version: VersionId, + updates: SmtForestUpdateBatch, + ) -> BackendResult<(Vec, Self::PreparedMutations)> { + 
self.inner.compute_mutations(new_version, updates) } - fn update_forest( + fn apply_mutations( &mut self, - new_version: VersionId, - updates: SmtForestUpdateBatch, - ) -> BackendResult> { - self.inner.update_forest(new_version, updates) + mutations: Self::PreparedMutations, + ) -> BackendResult> { + self.inner.apply_mutations(mutations) } } diff --git a/miden-crypto/src/merkle/smt/large_forest/tests.rs b/miden-crypto/src/merkle/smt/large_forest/tests.rs index ea09019101..110c838a5d 100644 --- a/miden-crypto/src/merkle/smt/large_forest/tests.rs +++ b/miden-crypto/src/merkle/smt/large_forest/tests.rs @@ -18,8 +18,8 @@ use crate::{ merkle::{ EmptySubtreeRoots, smt::{ - Backend, ForestInMemoryBackend, ForestOperation, LargeSmtForest, LargeSmtForestError, - RootInfo, Smt, SmtForestUpdateBatch, SmtUpdateBatch, TreeId, VersionId, + BackendReader, ForestInMemoryBackend, LargeSmtForest, LargeSmtForestError, RootInfo, + Smt, SmtForestOperation, SmtForestUpdateBatch, SmtUpdateBatch, TreeId, VersionId, large_forest::{ LineageData, history::{ChangedKeys, History, NodeChanges}, @@ -320,7 +320,7 @@ fn lineage_count() -> Result<()> { // This should stay the same if we update a tree. 
let operations = - SmtUpdateBatch::new([ForestOperation::insert(rng.value(), rng.value())].into_iter()); + SmtUpdateBatch::new([SmtForestOperation::insert(rng.value(), rng.value())].into_iter()); forest.update_tree(lineage_1, version + 1, operations)?; assert_eq!(forest.lineage_count(), 3); @@ -337,12 +337,12 @@ fn root_info() -> Result<()> { let lineage_1: LineageId = rng.value(); let version_1: VersionId = rng.value(); let operations = - SmtUpdateBatch::new([ForestOperation::insert(rng.value(), rng.value())].into_iter()); + SmtUpdateBatch::new([SmtForestOperation::insert(rng.value(), rng.value())].into_iter()); let historical_root = forest.add_lineage(lineage_1, version_1, operations)?; let version_2 = version_1 + 1; let operations = - SmtUpdateBatch::new([ForestOperation::insert(rng.value(), rng.value())].into_iter()); + SmtUpdateBatch::new([SmtForestOperation::insert(rng.value(), rng.value())].into_iter()); let current_root = forest.update_tree(lineage_1, version_2, operations)?; // When we query for a root (lineage_1, version_1), we should get back HistoricalVersion. 
@@ -399,8 +399,8 @@ fn open() -> Result<()> { version_1, SmtUpdateBatch::new( [ - ForestOperation::insert(key_1, value_1_v1), - ForestOperation::insert(key_2, value_2_v1), + SmtForestOperation::insert(key_1, value_1_v1), + SmtForestOperation::insert(key_2, value_2_v1), ] .into_iter(), ), @@ -442,9 +442,9 @@ fn open() -> Result<()> { version_2, SmtUpdateBatch::new( [ - ForestOperation::insert(key_1, value_1_v2), - ForestOperation::insert(key_3, value_3_v1), - ForestOperation::remove(key_2), + SmtForestOperation::insert(key_1, value_1_v2), + SmtForestOperation::insert(key_3, value_3_v1), + SmtForestOperation::remove(key_2), ] .into_iter(), ), @@ -499,8 +499,8 @@ fn get() -> Result<()> { version_1, SmtUpdateBatch::new( [ - ForestOperation::insert(key_1, value_1_v1), - ForestOperation::insert(key_2, value_2_v1), + SmtForestOperation::insert(key_1, value_1_v1), + SmtForestOperation::insert(key_2, value_2_v1), ] .into_iter(), ), @@ -540,8 +540,8 @@ fn get() -> Result<()> { version_2, SmtUpdateBatch::new( [ - ForestOperation::insert(key_1, value_1_v2), - ForestOperation::insert(key_3, value_3_v1), + SmtForestOperation::insert(key_1, value_1_v2), + SmtForestOperation::insert(key_3, value_3_v1), ] .into_iter(), ), @@ -913,8 +913,11 @@ fn entries_never_returns_empty_entry() -> Result<()> { let key_2: Word = rng.value(); let value_2: Word = rng.value(); let operations = SmtUpdateBatch::new( - [ForestOperation::insert(key_1, value_1), ForestOperation::insert(key_2, value_2)] - .into_iter(), + [ + SmtForestOperation::insert(key_1, value_1), + SmtForestOperation::insert(key_2, value_2), + ] + .into_iter(), ); forest.update_tree(lineage_1, version_2, operations)?; @@ -930,7 +933,7 @@ fn entries_never_returns_empty_entry() -> Result<()> { forest.add_lineage( lineage_2, version_1, - SmtUpdateBatch::new([ForestOperation::insert(key_1, value_1)].into_iter()), + SmtUpdateBatch::new([SmtForestOperation::insert(key_1, value_1)].into_iter()), )?; // Now we add an update to a different 
leaf. @@ -939,7 +942,7 @@ fn entries_never_returns_empty_entry() -> Result<()> { forest.update_tree( lineage_2, version_2, - SmtUpdateBatch::new([ForestOperation::insert(key_2, value_2)].into_iter()), + SmtUpdateBatch::new([SmtForestOperation::insert(key_2, value_2)].into_iter()), )?; // Now, when we query for entries on the historical version, we should only see one entry, and @@ -957,7 +960,7 @@ fn entries_never_returns_empty_entry() -> Result<()> { forest.add_lineage( lineage_3, version_1, - SmtUpdateBatch::new([ForestOperation::insert(key_1, value_1)].into_iter()), + SmtUpdateBatch::new([SmtForestOperation::insert(key_1, value_1)].into_iter()), )?; // We now add an update in the same leaf. @@ -966,7 +969,7 @@ fn entries_never_returns_empty_entry() -> Result<()> { forest.update_tree( lineage_3, version_2, - SmtUpdateBatch::new([ForestOperation::insert(key_2, value_2)].into_iter()), + SmtUpdateBatch::new([SmtForestOperation::insert(key_2, value_2)].into_iter()), )?; // Now when we query the historical version, we should only see one entry, and no reversions. 
@@ -997,8 +1000,8 @@ fn entries_history_empty_values_do_not_reorder() -> Result<()> { version_1, SmtUpdateBatch::new( [ - ForestOperation::insert(key_a, value_a), - ForestOperation::insert(key_c, value_c_v1), + SmtForestOperation::insert(key_a, value_a), + SmtForestOperation::insert(key_c, value_c_v1), ] .into_iter(), ), @@ -1014,8 +1017,8 @@ fn entries_history_empty_values_do_not_reorder() -> Result<()> { version_2, SmtUpdateBatch::new( [ - ForestOperation::insert(key_b, value_b), - ForestOperation::insert(key_c, value_c_v2), + SmtForestOperation::insert(key_b, value_b), + SmtForestOperation::insert(key_c, value_c_v2), ] .into_iter(), ), @@ -1160,6 +1163,132 @@ fn update_tree() -> Result<()> { Ok(()) } +#[test] +fn compute_and_apply_update_tree_mutations() -> Result<()> { + let backend = ForestInMemoryBackend::new(); + let mut forest = Forest::new(backend)?; + let mut rng = ContinuousRng::new([0x71; 32]); + + let lineage: LineageId = rng.value(); + let version_1: VersionId = 10; + let version_2: VersionId = 11; + let key_1: Word = rng.value(); + let value_1: Word = rng.value(); + let key_2: Word = rng.value(); + let value_2: Word = rng.value(); + + let mut initial = SmtUpdateBatch::default(); + initial.add_insert(key_1, value_1); + let original = forest.add_lineage(lineage, version_1, initial)?; + + let mut updates = SmtUpdateBatch::default(); + updates.add_insert(key_2, value_2); + let mutations = forest.compute_update_tree_mutations(lineage, version_2, updates)?; + let proposed_roots = mutations.roots().collect::>(); + assert_eq!(proposed_roots.len(), 1); + assert_eq!(proposed_roots[0].lineage(), lineage); + assert_eq!(proposed_roots[0].version(), version_2); + assert_ne!(proposed_roots[0].root(), original.root()); + + assert_eq!( + forest.root_info(TreeId::new(lineage, version_1)), + RootInfo::LatestVersion(original.root()) + ); + assert_eq!(forest.get(TreeId::new(lineage, version_1), key_2)?, None); + assert_eq!(forest.get_history(lineage).num_versions(), 0); 
+ + let applied_roots = forest.apply_mutations(mutations)?; + assert_eq!(applied_roots, proposed_roots); + assert_eq!( + forest.root_info(TreeId::new(lineage, version_2)), + RootInfo::LatestVersion(proposed_roots[0].root()) + ); + assert_eq!(forest.get(TreeId::new(lineage, version_2), key_2)?, Some(value_2)); + assert_eq!(forest.get_history(lineage).num_versions(), 1); + + Ok(()) +} + +#[test] +fn compute_and_apply_forest_mutations_can_mix_additions_and_updates() -> Result<()> { + let backend = ForestInMemoryBackend::new(); + let mut forest = Forest::new(backend)?; + let mut rng = ContinuousRng::new([0x74; 32]); + + let existing_lineage: LineageId = rng.value(); + let new_lineage: LineageId = rng.value(); + let version_1: VersionId = 10; + let version_2: VersionId = 11; + + let existing_key: Word = rng.value(); + let existing_value: Word = rng.value(); + let new_key: Word = rng.value(); + let new_value: Word = rng.value(); + + forest.add_lineage(existing_lineage, version_1, SmtUpdateBatch::default())?; + + let mut batch = SmtForestUpdateBatch::empty(); + batch.operations(existing_lineage).add_insert(existing_key, existing_value); + batch.operations(new_lineage).add_insert(new_key, new_value); + + let mutations = forest.compute_forest_mutations(version_2, batch)?; + let proposed_roots = mutations.roots().collect::>(); + assert_eq!(proposed_roots.len(), 2); + assert!(proposed_roots.iter().any(|root| root.lineage() == existing_lineage)); + assert!(proposed_roots.iter().any(|root| root.lineage() == new_lineage)); + + assert_eq!(forest.get(TreeId::new(existing_lineage, version_1), existing_key)?, None); + assert_matches!(forest.root_info(TreeId::new(new_lineage, version_2)), RootInfo::Missing); + + let applied_roots = forest.apply_mutations(mutations)?; + assert_eq!(applied_roots.len(), 2); + assert_eq!( + forest.get(TreeId::new(existing_lineage, version_2), existing_key)?, + Some(existing_value) + ); + assert_eq!(forest.get(TreeId::new(new_lineage, version_2), 
new_key)?, Some(new_value)); + + Ok(()) +} + +#[test] +fn apply_update_tree_mutations_rejects_stale_state() -> Result<()> { + let backend = ForestInMemoryBackend::new(); + let mut forest = Forest::new(backend)?; + let mut rng = ContinuousRng::new([0x72; 32]); + + let lineage: LineageId = rng.value(); + let version_1: VersionId = 20; + let version_2: VersionId = 21; + let version_3: VersionId = 22; + let key_1: Word = rng.value(); + let value_1: Word = rng.value(); + let key_2: Word = rng.value(); + let value_2: Word = rng.value(); + let key_3: Word = rng.value(); + let value_3: Word = rng.value(); + + forest.add_lineage(lineage, version_1, SmtUpdateBatch::default())?; + + let mut pending_updates = SmtUpdateBatch::default(); + pending_updates.add_insert(key_1, value_1); + let pending = forest.compute_update_tree_mutations(lineage, version_2, pending_updates)?; + + let mut intervening_updates = SmtUpdateBatch::default(); + intervening_updates.add_insert(key_2, value_2); + forest.update_tree(lineage, version_2, intervening_updates)?; + + let stale_result = forest.apply_mutations(pending); + assert!(stale_result.is_err()); + + let mut fresh_updates = SmtUpdateBatch::default(); + fresh_updates.add_insert(key_3, value_3); + let fresh = forest.compute_update_tree_mutations(lineage, version_3, fresh_updates)?; + assert!(forest.apply_mutations(fresh).is_ok()); + + Ok(()) +} + // MULTI-TREE MODIFIER TESTS // ================================================================================================ diff --git a/miden-crypto/src/merkle/smt/large_forest/utils.rs b/miden-crypto/src/merkle/smt/large_forest/utils.rs index ff0a2f4bfe..130a447619 100644 --- a/miden-crypto/src/merkle/smt/large_forest/utils.rs +++ b/miden-crypto/src/merkle/smt/large_forest/utils.rs @@ -1,6 +1,17 @@ -//! Contains utility type aliases and functions for use as part of the SMT forest. +//! Contains utility types, aliases, and functions for use as part of the SMT forest. 
-use crate::{Word, merkle::smt::full::SMT_DEPTH}; +use alloc::vec::Vec; + +use crate::{ + Word, + merkle::smt::{ + full::SMT_DEPTH, + large_forest::{ + backend::Backend, + root::{LineageId, TreeWithRoot, VersionId}, + }, + }, +}; // TYPE ALIASES // ================================================================================================ @@ -8,3 +19,285 @@ use crate::{Word, merkle::smt::full::SMT_DEPTH}; /// The mutation set used by the forest backends to provide reverse mutations that describe the /// changes necessary to revert the tree to its previous state. pub type MutationSet = crate::merkle::smt::MutationSet; + +// FOREST MUTATIONS +// ================================================================================================ + +/// A prospective set of mutations to a forest. +/// +/// This is the forest-level analogue of [`crate::merkle::smt::MutationSet`]. It represents changes +/// that have already been computed but have not yet been committed to the underlying backend or to +/// the forest's lineage metadata. +/// +/// A mutation set has two parts: +/// +/// - inspectable [`LineageMutation`] entries, which expose the affected lineages, requested +/// versions, old roots, and proposed new roots; and +/// - backend-specific prepared data, which is intentionally opaque and is consumed by +/// [`crate::merkle::smt::LargeSmtForest::apply_mutations`]. +/// +/// The type is parameterized by the backend because different backend implementations may prepare +/// different internal data. For example, an in-memory backend can keep regular SMT mutation sets, +/// while a persistent backend can keep storage-level updates that avoid recomputing the tree walk +/// during application. +/// +/// Values of this type are only valid for the forest state against which they were computed. +/// Applying them after the target lineage has changed will fail during forest-level validation. 
+pub struct SmtForestMutationSet { + entries: Vec, + prepared: B::PreparedMutations, +} + +impl SmtForestMutationSet { + /// Constructs a forest mutation set from inspectable lineage entries and backend-prepared data. + /// + /// This constructor is crate-private because only forest/backend code can maintain the + /// invariant that the public lineage metadata and opaque prepared data describe the same set of + /// changes. + pub(crate) fn new(entries: Vec, prepared: B::PreparedMutations) -> Self { + Self { entries, prepared } + } + + /// Returns the lineage-level mutations in this set. + /// + /// Callers can use this to inspect the old and new roots for each affected lineage before + /// committing the mutation set. The returned entries are read-only; the opaque backend portion + /// of the mutation set remains unavailable so that callers cannot accidentally break the link + /// between the visible metadata and the prepared backend data. + pub fn lineage_mutations(&self) -> &[LineageMutation] { + &self.entries + } + + /// Returns the roots that would be observed after successfully applying this mutation set. + /// + /// This is a convenience view over [`Self::lineage_mutations`]. For update mutations that do + /// not change the underlying tree, the returned [`TreeWithRoot`] uses the existing latest + /// version rather than the requested new version, matching the behavior of + /// [`crate::merkle::smt::LargeSmtForest::update_tree`] and + /// [`crate::merkle::smt::LargeSmtForest::update_forest`]. + pub fn roots(&self) -> impl Iterator + '_ { + self.entries.iter().map(LineageMutation::result) + } + + /// Consumes this value into its inspectable entries and backend-prepared mutation data. + /// + /// This is crate-private for the same reason as [`Self::new`]: only the forest can safely + /// coordinate applying the backend data and then updating lineage metadata/history from the + /// inspectable entries. 
+ pub(crate) fn into_parts(self) -> (Vec, B::PreparedMutations) { + (self.entries, self.prepared) + } +} + +/// A prospective mutation to one lineage in a forest. +/// +/// This type records only inspectable metadata that callers need before committing a mutation set: +/// affected lineage, version transition, root transition, and mutation kind. Backend-specific +/// mutation data, including forward and reverse SMT mutation sets, stays in opaque backend +/// prepared data until [`crate::merkle::smt::LargeSmtForest::apply_mutations`] commits it. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct LineageMutation { + lineage: LineageId, + old_version: Option, + new_version: VersionId, + old_root: Word, + new_root: Word, + kind: LineageMutationKind, +} + +impl LineageMutation { + /// Constructs a lineage mutation. + /// + /// This constructor is crate-private because callers must not be able to fabricate mutation + /// metadata that is inconsistent with the backend-prepared data in an [`SmtForestMutationSet`]. + pub(crate) fn new( + lineage: LineageId, + old_version: Option, + new_version: VersionId, + old_root: Word, + new_root: Word, + kind: LineageMutationKind, + ) -> Self { + Self { + lineage, + old_version, + new_version, + old_root, + new_root, + kind, + } + } + + /// Returns the affected lineage. + pub fn lineage(&self) -> LineageId { + self.lineage + } + + /// Returns the previous version for update mutations, or `None` for new lineages. + /// + /// This value is used by [`crate::merkle::smt::LargeSmtForest::apply_mutations`] to reject + /// stale mutation sets. For updates, it must still match the latest version in the forest when + /// the mutation set is applied. + pub fn old_version(&self) -> Option { + self.old_version + } + + /// Returns the requested new version. + /// + /// For an update that does not change the tree, this version is the version requested by the + /// compute call, but the lineage will not advance when the mutation set is applied. 
+ pub fn new_version(&self) -> VersionId { + self.new_version + } + + /// Returns the root before this mutation. + /// + /// For updates, this must match the current latest root when the mutation set is applied. For a + /// new lineage, this is the empty SMT root from which the initial tree is computed. + pub fn old_root(&self) -> Word { + self.old_root + } + + /// Returns the root after this mutation. + /// + /// This is the root commitment that callers usually inspect before deciding whether to commit a + /// computed mutation set. + pub fn new_root(&self) -> Word { + self.new_root + } + + /// Returns the mutation kind. + pub fn kind(&self) -> LineageMutationKind { + self.kind + } + + /// Returns the root information this mutation would produce if applied. + /// + /// For no-op update mutations, this returns the old version and old root, since applying such a + /// mutation does not allocate a new tree version. For new-lineage mutations and non-empty + /// update mutations, this returns the requested new version and computed new root. + pub fn result(&self) -> TreeWithRoot { + let version = + if self.kind == LineageMutationKind::UpdateTree && self.old_root == self.new_root { + self.old_version.expect("update tree mutations always have an old version") + } else { + self.new_version + }; + TreeWithRoot::new(self.lineage, version, self.new_root) + } +} + +/// Data returned by a backend after applying prepared mutations. +/// +/// The forest uses this data to update lineage metadata and historical views after the backend has +/// committed its latest-tree state. Unlike [`LineageMutation`], this type includes the reverse SMT +/// mutation set and old entry count because those are only needed after a successful apply. 
+#[derive(Clone, Debug, Eq, PartialEq)] +pub struct AppliedLineageMutation { + lineage: LineageId, + old_version: Option, + new_version: VersionId, + old_root: Word, + new_root: Word, + old_entry_count: usize, + reverse: MutationSet, + kind: LineageMutationKind, +} + +impl AppliedLineageMutation { + /// Constructs an applied lineage mutation. + /// + /// This constructor is crate-private because backend implementations must keep the returned + /// history payload consistent with the prepared mutation data they just applied. + #[allow(clippy::too_many_arguments)] + pub(crate) fn new( + lineage: LineageId, + old_version: Option, + new_version: VersionId, + old_root: Word, + new_root: Word, + old_entry_count: usize, + reverse: MutationSet, + kind: LineageMutationKind, + ) -> Self { + Self { + lineage, + old_version, + new_version, + old_root, + new_root, + old_entry_count, + reverse, + kind, + } + } + + /// Returns the affected lineage. + pub fn lineage(&self) -> LineageId { + self.lineage + } + + /// Returns the previous version for update mutations, or `None` for new lineages. + pub fn old_version(&self) -> Option { + self.old_version + } + + /// Returns the requested new version. + pub fn new_version(&self) -> VersionId { + self.new_version + } + + /// Returns the root before this mutation. + pub fn old_root(&self) -> Word { + self.old_root + } + + /// Returns the root after this mutation. + pub fn new_root(&self) -> Word { + self.new_root + } + + /// Returns the entry count before this mutation. + pub fn old_entry_count(&self) -> usize { + self.old_entry_count + } + + /// Returns the reverse mutation set for this lineage. + pub fn reverse(&self) -> &MutationSet { + &self.reverse + } + + /// Consumes this mutation and returns the reverse mutation set. + pub(crate) fn into_reverse(self) -> MutationSet { + self.reverse + } + + /// Returns the mutation kind. 
+ pub fn kind(&self) -> LineageMutationKind { + self.kind + } + + /// Returns the root information produced by this applied mutation. + /// + /// For no-op update mutations, this returns the old version, since applying such a + /// mutation does not allocate a new tree version. For new-lineage mutations and non-empty + /// update mutations, this returns the requested new version and computed new root. + pub fn result(&self) -> TreeWithRoot { + let version = + if self.kind == LineageMutationKind::UpdateTree && self.old_root == self.new_root { + self.old_version.expect("update tree mutations always have an old version") + } else { + self.new_version + }; + TreeWithRoot::new(self.lineage, version, self.new_root) + } +} + +/// The operation represented by a lineage mutation. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum LineageMutationKind { + /// A new lineage is being added. + AddLineage, + /// An existing lineage is being updated. + UpdateTree, +} diff --git a/miden-crypto/src/merkle/smt/mod.rs b/miden-crypto/src/merkle/smt/mod.rs index 6a0f6988cb..044e89963b 100644 --- a/miden-crypto/src/merkle/smt/mod.rs +++ b/miden-crypto/src/merkle/smt/mod.rs @@ -22,22 +22,29 @@ mod large; pub use full::concurrent::{SubtreeLeaf, build_subtree_for_bench}; #[cfg(feature = "concurrent")] pub use large::{ - LargeSmt, LargeSmtError, MemoryStorage, MemoryStorageSnapshot, SmtStorage, SmtStorageReader, - StorageError, StorageUpdateParts, StorageUpdates, Subtree, SubtreeError, SubtreeUpdate, + LargeSmt, LargeSmtError, LargeSmtResult, MemoryStorage, MemoryStorageSnapshot, SmtStorage, + SmtStorageReader, StorageError, StorageResult, StorageUpdateParts, StorageUpdates, Subtree, + SubtreeError, SubtreeUpdate, }; #[cfg(feature = "rocksdb")] pub use large::{RocksDbConfig, RocksDbSnapshotStorage, RocksDbStorage}; mod large_forest; pub use large_forest::{ - Backend, BackendError, Config as ForestConfig, - DEFAULT_MAX_HISTORY_VERSIONS as FOREST_DEFAULT_MAX_HISTORY_VERSIONS, 
ForestOperation, - InMemoryBackend as ForestInMemoryBackend, LargeSmtForest, LargeSmtForestError, LineageId, - MIN_HISTORY_VERSIONS as FOREST_MIN_HISTORY_VERSIONS, RootInfo, SmtForestUpdateBatch, - SmtUpdateBatch, TreeEntry, TreeId, TreeWithRoot, VersionId, + AppliedLineageMutation, Backend, BackendError, BackendReader, Config as ForestConfig, + DEFAULT_MAX_HISTORY_VERSIONS as FOREST_DEFAULT_MAX_HISTORY_VERSIONS, + InMemoryBackend as ForestInMemoryBackend, + InMemoryBackendSnapshot as ForestInMemoryBackendReader, LargeSmtForest, LargeSmtForestError, + LineageId, LineageMutation, LineageMutationKind, + MIN_HISTORY_VERSIONS as FOREST_MIN_HISTORY_VERSIONS, RootInfo, SmtForestMutationSet, + SmtForestOperation, SmtForestUpdateBatch, SmtUpdateBatch, TreeEntry, TreeId, TreeWithRoot, + VersionId, }; #[cfg(feature = "persistent-forest")] -pub use large_forest::{PersistentBackend as ForestPersistentBackend, PersistentBackendConfig}; +pub use large_forest::{ + PersistentBackend as ForestPersistentBackend, PersistentBackendConfig, + PersistentBackendReader as ForestPersistentBackendReader, +}; mod simple; pub use simple::{SimpleSmt, SimpleSmtProof}; diff --git a/miden-crypto/src/merkle/smt/partial/tests.rs b/miden-crypto/src/merkle/smt/partial/tests.rs index 1664044fe6..66f1e7476b 100644 --- a/miden-crypto/src/merkle/smt/partial/tests.rs +++ b/miden-crypto/src/merkle/smt/partial/tests.rs @@ -6,7 +6,7 @@ use alloc::{ use assert_matches::assert_matches; use itertools::Itertools; use proptest::prelude::*; -use rand::{Rng, SeedableRng}; +use rand::{Rng, RngExt, SeedableRng}; use rand_chacha::ChaCha20Rng; use super::{PartialSmt, SMT_DEPTH, serialization::property_tests::arbitrary_valid_word}; diff --git a/miden-crypto/src/rand/coin.rs b/miden-crypto/src/rand/coin.rs index fdcacc9aa9..ef34b70245 100644 --- a/miden-crypto/src/rand/coin.rs +++ b/miden-crypto/src/rand/coin.rs @@ -1,8 +1,11 @@ use alloc::string::ToString; -use rand_core::impls; +use rand::{ + Rng, + 
rand_core::{Infallible, TryRng, utils}, +}; -use super::{Felt, FeltRng, RngCore}; +use super::{Felt, FeltRng}; use crate::{ Word, ZERO, field::ExtensionField, @@ -68,7 +71,7 @@ impl RandomCoin { /// Fills `dest` with random data. pub fn fill_bytes(&mut self, dest: &mut [u8]) { - ::fill_bytes(self, dest) + ::fill_bytes(self, dest) } /// Draws a random base field element from the random coin. @@ -142,20 +145,22 @@ impl FeltRng for RandomCoin { } } -// RNGCORE IMPLEMENTATION +// RNG IMPLEMENTATION // ------------------------------------------------------------------------------------------------ -impl RngCore for RandomCoin { - fn next_u32(&mut self) -> u32 { - self.draw_basefield().as_canonical_u64() as u32 +impl TryRng for RandomCoin { + type Error = Infallible; + + fn try_next_u32(&mut self) -> Result { + Ok(self.draw_basefield().as_canonical_u64() as u32) } - fn next_u64(&mut self) -> u64 { - impls::next_u64_via_u32(self) + fn try_next_u64(&mut self) -> Result { + utils::next_u64_via_u32(self) } - fn fill_bytes(&mut self, dest: &mut [u8]) { - impls::fill_bytes_via_next(self, dest) + fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Self::Error> { + utils::fill_bytes_via_next_word(dest, || self.try_next_u32()) } } diff --git a/miden-crypto/src/rand/compat.rs b/miden-crypto/src/rand/compat.rs new file mode 100644 index 0000000000..b1090e6e6b --- /dev/null +++ b/miden-crypto/src/rand/compat.rs @@ -0,0 +1,34 @@ +use rand::{CryptoRng, Rng}; + +/// Adapts rand 0.10 RNGs for stable crypto crates that still use rand_core 0.6. 
+pub(crate) struct RandCore06<'a, R: ?Sized>(&'a mut R); + +impl<'a, R: ?Sized> RandCore06<'a, R> { + pub(crate) fn new(rng: &'a mut R) -> Self { + Self(rng) + } +} + +impl k256::elliptic_curve::rand_core::RngCore for RandCore06<'_, R> { + fn next_u32(&mut self) -> u32 { + self.0.next_u32() + } + + fn next_u64(&mut self) -> u64 { + self.0.next_u64() + } + + fn fill_bytes(&mut self, dest: &mut [u8]) { + self.0.fill_bytes(dest); + } + + fn try_fill_bytes( + &mut self, + dest: &mut [u8], + ) -> Result<(), k256::elliptic_curve::rand_core::Error> { + self.0.fill_bytes(dest); + Ok(()) + } +} + +impl k256::elliptic_curve::rand_core::CryptoRng for RandCore06<'_, R> {} diff --git a/miden-crypto/src/rand/mod.rs b/miden-crypto/src/rand/mod.rs index d4d62ee663..654301d629 100644 --- a/miden-crypto/src/rand/mod.rs +++ b/miden-crypto/src/rand/mod.rs @@ -1,11 +1,12 @@ //! Pseudo-random element generation. -use rand::RngCore; +use rand::Rng; use crate::{Felt, Word}; mod coin; pub use coin::RandomCoin; +pub(crate) mod compat; // Test utilities for generating random data (used in tests and benchmarks) #[cfg(any(test, feature = "std"))] @@ -127,7 +128,7 @@ impl Randomizable for [u8; N] { /// Pseudo-random element generator. /// /// An instance can be used to draw, uniformly at random, base field elements as well as [Word]s. -pub trait FeltRng: RngCore { +pub trait FeltRng: Rng { /// Draw, uniformly at random, a base field element. fn draw_element(&mut self) -> Felt; @@ -143,7 +144,7 @@ pub trait FeltRng: RngCore { /// This function is only available with the `std` feature. #[cfg(feature = "std")] pub fn random_felt() -> Felt { - use rand::Rng; + use rand::RngExt; let mut rng = rand::rng(); // We use the `Felt::new` constructor to do rejection sampling here. It should effectively // never repeat, but nevertheless gives us the correct distribution. 
diff --git a/miden-crypto/src/rand/test_utils.rs b/miden-crypto/src/rand/test_utils.rs index 22d52176a4..443f81043b 100644 --- a/miden-crypto/src/rand/test_utils.rs +++ b/miden-crypto/src/rand/test_utils.rs @@ -18,7 +18,7 @@ use alloc::{vec, vec::Vec}; -use rand::{Rng, SeedableRng}; +use rand::{Rng, RngExt, SeedableRng}; use rand_chacha::ChaCha20Rng; use crate::rand::Randomizable; @@ -32,7 +32,7 @@ use crate::rand::Randomizable; /// ``` /// # use miden_crypto::rand::test_utils::seeded_rng; /// let mut rng = seeded_rng([0u8; 32]); -/// // Use rng with any function that accepts impl RngCore +/// // Use rng with any function that accepts impl Rng /// ``` pub fn seeded_rng(seed: [u8; 32]) -> ChaCha20Rng { ChaCha20Rng::from_seed(seed) diff --git a/miden-crypto/tests/rocksdb_large_smt.rs b/miden-crypto/tests/rocksdb_large_smt.rs index 781f089837..2367253208 100644 --- a/miden-crypto/tests/rocksdb_large_smt.rs +++ b/miden-crypto/tests/rocksdb_large_smt.rs @@ -2,11 +2,29 @@ use miden_crypto::{ EMPTY_WORD, Felt, ONE, Word, ZERO, merkle::{ InnerNodeInfo, - smt::{LargeSmt, LargeSmtError, RocksDbConfig, RocksDbSnapshotStorage, RocksDbStorage}, + smt::{ + LargeSmt, LargeSmtError, LargeSmtResult, RocksDbConfig, RocksDbSnapshotStorage, + RocksDbStorage, SmtStorageReader, StorageError, + }, }, }; +use rocksdb::{DB, IteratorMode, Options}; use tempfile::TempDir; +const LEAVES_CF: &str = "leaves"; +const SUBTREE_CFS: [&str; 6] = ["st16", "st24", "st32", "st40", "st48", "st56"]; +const ROCKSDB_CFS: [&str; 9] = [ + "in_mem_depth", + "leaves", + "st16", + "st24", + "st32", + "st40", + "st48", + "st56", + "metadata", +]; + fn setup_storage() -> (RocksDbStorage, TempDir) { let temp_dir = tempfile::Builder::new() .prefix("test_smt_rocksdb_") @@ -35,6 +53,34 @@ fn generate_entries(pair_count: usize) -> Vec<(Word, Word)> { .collect() } +fn open_raw_db(path: &std::path::Path) -> DB { + let opts = Options::default(); + DB::open_cf(&opts, path, ROCKSDB_CFS).expect("failed to open raw RocksDB 
handle") +} + +fn corrupt_leaf_value(path: &std::path::Path, leaf_index: u64) { + let db = open_raw_db(path); + let cf = db.cf_handle(LEAVES_CF).expect("leaves column family missing"); + db.put_cf(cf, leaf_index.to_be_bytes(), b"not a valid leaf") + .expect("failed to corrupt leaf value"); +} + +fn corrupt_first_subtree_value(path: &std::path::Path) { + let db = open_raw_db(path); + + for cf_name in SUBTREE_CFS { + let cf = db.cf_handle(cf_name).expect("subtree column family missing"); + if let Some(result) = db.iterator_cf(cf, IteratorMode::Start).next() { + let (key, _value) = result.expect("failed to read subtree entry"); + db.put_cf(cf, key, b"not a valid subtree") + .expect("failed to corrupt subtree value"); + return; + } + } + + panic!("expected at least one subtree entry"); +} + #[test] fn rocksdb_sanity_insert_and_get() { let (storage, _tmp) = setup_storage(); @@ -91,14 +137,16 @@ fn rocksdb_persistence_reopen() { let smt = LargeSmt::::with_entries(initial_storage, entries).unwrap(); let root = smt.root(); - let mut inner_nodes: Vec = smt.inner_nodes().unwrap().collect(); + let mut inner_nodes: Vec = + smt.inner_nodes().unwrap().collect::, _>>().unwrap(); inner_nodes.sort_by_key(|info| info.value); drop(smt); let reopened_storage = RocksDbStorage::open(RocksDbConfig::new(db_path)).unwrap(); let smt = LargeSmt::::load(reopened_storage).unwrap(); - let mut inner_nodes_2: Vec = smt.inner_nodes().unwrap().collect(); + let mut inner_nodes_2: Vec = + smt.inner_nodes().unwrap().collect::, _>>().unwrap(); inner_nodes_2.sort_by_key(|info| info.value); assert_eq!(inner_nodes.len(), inner_nodes_2.len()); @@ -114,35 +162,61 @@ fn rocksdb_persistence_after_insertion() { let db_path = temp_dir_guard.path().to_path_buf(); let mut smt = LargeSmt::::with_entries(initial_storage, entries).unwrap(); - let key = Word::new([ONE, ONE, ONE, ONE]); + let initial_num_leaves = smt.num_leaves(); + let initial_num_entries = smt.num_entries(); + let key = Word::new([ONE, ONE, 
Felt::new_unchecked(20_000), Felt::new_unchecked(20_000)]); let new_value = Word::new([ Felt::new_unchecked(2), Felt::new_unchecked(2), Felt::new_unchecked(2), Felt::new_unchecked(2), ]); - smt.insert(key, new_value).unwrap(); + let previous_value = smt.insert(key, new_value).unwrap(); + assert_eq!(previous_value, EMPTY_WORD); + assert_eq!(smt.get_value(&key), new_value); + assert_eq!(smt.num_leaves(), initial_num_leaves + 1); + assert_eq!(smt.num_entries(), initial_num_entries + 1); let root = smt.root(); + let num_leaves = smt.num_leaves(); + let num_entries = smt.num_entries(); - let mut inner_nodes: Vec = smt.inner_nodes().unwrap().collect(); + let mut inner_nodes: Vec = + smt.inner_nodes().unwrap().collect::, _>>().unwrap(); inner_nodes.sort_by_key(|info| info.value); drop(smt); let reopened_storage = RocksDbStorage::open(RocksDbConfig::new(db_path)).unwrap(); let smt = LargeSmt::::load(reopened_storage).unwrap(); - let mut inner_nodes_2: Vec = smt.inner_nodes().unwrap().collect(); + let mut inner_nodes_2: Vec = + smt.inner_nodes().unwrap().collect::, _>>().unwrap(); inner_nodes_2.sort_by_key(|info| info.value); assert_eq!(inner_nodes.len(), inner_nodes_2.len()); assert_eq!(inner_nodes, inner_nodes_2); assert_eq!(smt.root(), root); + assert_eq!(smt.num_leaves(), num_leaves); + assert_eq!(smt.num_entries(), num_entries); + assert_eq!(smt.get_value(&key), new_value); } #[test] fn rocksdb_persistence_after_insert_batch_with_deletions() { // Create a tree with initial entries let entries = generate_entries(10_000); + let unchanged_key = entries[1_500].0; + let unchanged_value = entries[1_500].1; + let updated_key = entries[1_501].0; + let updated_value = Word::new([ + Felt::new_unchecked(3), + Felt::new_unchecked(3), + Felt::new_unchecked(3), + Felt::new_unchecked(3), + ]); + let deleted_key = entries[100].0; + let newly_inserted_key = + Word::new([ONE, ONE, Felt::new_unchecked(20_000), Felt::new_unchecked(20_000 % 1000)]); + let newly_inserted_value = 
Word::new([ONE, ONE, ONE, Felt::new_unchecked(20_000)]); let (initial_storage, temp_dir_guard) = setup_storage(); let db_path = temp_dir_guard.path().to_path_buf(); @@ -164,6 +238,9 @@ fn rocksdb_persistence_after_insert_batch_with_deletions() { batch_entries.push((key, value)); } + // Update an existing entry that is not deleted below + batch_entries.push((updated_key, updated_value)); + // Delete some existing entries for i in 0..1000 { let key = Word::new([ @@ -178,7 +255,8 @@ fn rocksdb_persistence_after_insert_batch_with_deletions() { smt.insert_batch(batch_entries).unwrap(); let root = smt.root(); - let mut inner_nodes: Vec = smt.inner_nodes().unwrap().collect(); + let mut inner_nodes: Vec = + smt.inner_nodes().unwrap().collect::, _>>().unwrap(); inner_nodes.sort_by_key(|info| info.value); let num_leaves = smt.num_leaves(); let num_entries = smt.num_entries(); @@ -187,7 +265,8 @@ fn rocksdb_persistence_after_insert_batch_with_deletions() { let reopened_storage = RocksDbStorage::open(RocksDbConfig::new(db_path)).unwrap(); let smt = LargeSmt::::load(reopened_storage).unwrap(); - let mut inner_nodes_2: Vec = smt.inner_nodes().unwrap().collect(); + let mut inner_nodes_2: Vec = + smt.inner_nodes().unwrap().collect::, _>>().unwrap(); inner_nodes_2.sort_by_key(|info| info.value); let num_leaves_2 = smt.num_leaves(); let num_entries_2 = smt.num_entries(); @@ -197,6 +276,10 @@ fn rocksdb_persistence_after_insert_batch_with_deletions() { assert_eq!(num_leaves, num_leaves_2); assert_eq!(num_entries, num_entries_2); assert_eq!(smt.root(), root, "Tree reconstruction failed - root mismatch after deletions"); + assert_eq!(smt.get_value(&unchanged_key), unchanged_value); + assert_eq!(smt.get_value(&newly_inserted_key), newly_inserted_value); + assert_eq!(smt.get_value(&updated_key), updated_value); + assert_eq!(smt.get_value(&deleted_key), EMPTY_WORD); } #[test] @@ -265,6 +348,49 @@ fn rocksdb_load_skips_validation() { assert_eq!(smt.root(), expected_root); } +#[test] +fn 
rocksdb_iter_leaves_returns_error_for_corrupt_leaf() { + let entries = generate_entries(1); + let leaf_index = entries[0].0[3].as_canonical_u64(); + + let (initial_storage, temp_dir_guard) = setup_storage(); + let db_path = temp_dir_guard.path().to_path_buf(); + + let smt = LargeSmt::::with_entries(initial_storage, entries).unwrap(); + drop(smt); + + corrupt_leaf_value(&db_path, leaf_index); + + let storage = RocksDbStorage::open(RocksDbConfig::new(db_path)).unwrap(); + let result = storage.iter_leaves().unwrap().collect::, StorageError>>(); + + assert!( + matches!(result, Err(StorageError::Value(_))), + "expected corrupt leaf deserialization to fail, got {result:?}", + ); +} + +#[test] +fn rocksdb_iter_subtrees_returns_error_for_corrupt_subtree() { + let entries = generate_entries(1000); + + let (initial_storage, temp_dir_guard) = setup_storage(); + let db_path = temp_dir_guard.path().to_path_buf(); + + let smt = LargeSmt::::with_entries(initial_storage, entries).unwrap(); + drop(smt); + + corrupt_first_subtree_value(&db_path); + + let storage = RocksDbStorage::open(RocksDbConfig::new(db_path)).unwrap(); + let result = storage.iter_subtrees().unwrap().collect::, StorageError>>(); + + assert!( + matches!(result, Err(StorageError::Subtree(_))), + "expected corrupt subtree deserialization to fail, got {result:?}", + ); +} + #[test] fn rocksdb_new_fails_on_non_empty_storage() { let entries = generate_entries(1000); @@ -337,3 +463,29 @@ fn rocksdb_entry_count_through_leaf_lifecycle() { assert_eq!(smt.num_entries(), 0, "persisted entry count should be 0"); assert_eq!(smt.num_leaves(), 0, "persisted leaf count should be 0"); } + +#[test] +fn rocksdb_inner_nodes_match_full_smt() { + use miden_crypto::merkle::smt::Smt; + + let entries = generate_entries(1000); + let control_smt = Smt::with_entries(entries.clone()).unwrap(); + + let (storage, _tmp) = setup_storage(); + let large_smt = LargeSmt::::with_entries(storage, entries).unwrap(); + + let mut control_nodes: Vec = 
control_smt.inner_nodes().collect(); + let mut rocksdb_nodes: Vec = large_smt + .inner_nodes() + .unwrap() + .try_fold(Vec::new(), |mut acc, info| { + acc.push(info?); + LargeSmtResult::Ok(acc) + }) + .unwrap(); + control_nodes.sort_by_key(|info| info.value); + rocksdb_nodes.sort_by_key(|info| info.value); + + assert_eq!(control_nodes.len(), rocksdb_nodes.len()); + assert_eq!(control_nodes, rocksdb_nodes); +} diff --git a/miden-serde-utils/fuzz/Cargo.lock b/miden-serde-utils/fuzz/Cargo.lock index 6104a30fa2..1ba121e81c 100644 --- a/miden-serde-utils/fuzz/Cargo.lock +++ b/miden-serde-utils/fuzz/Cargo.lock @@ -102,7 +102,7 @@ dependencies = [ [[package]] name = "miden-serde-utils" -version = "0.24.0" +version = "0.26.0" dependencies = [ "p3-field", "p3-goldilocks", @@ -147,9 +147,9 @@ dependencies = [ [[package]] name = "p3-challenger" -version = "0.5.2" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4a0b490c745a7d2adeeafff06411814c8078c432740162332b3cd71be0158a76" +checksum = "8972ccd1d5dc90e46cdb1f2ab4ee2bae49b3917e5e98aa533f0c2b779c010445" dependencies = [ "p3-field", "p3-maybe-rayon", @@ -161,9 +161,9 @@ dependencies = [ [[package]] name = "p3-dft" -version = "0.5.2" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55301e91544440254977108b85c32c09d7ea05f2f0dd61092a2825339906a4a7" +checksum = "17771aca44632f9cc11f2718d7ea7ec06794946c4190ef3a985bfc893f14c18a" dependencies = [ "itertools", "p3-field", @@ -176,9 +176,9 @@ dependencies = [ [[package]] name = "p3-field" -version = "0.5.2" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85affca7fc983889f260655c4cf74163eebb94605f702e4b6809ead707cba54f" +checksum = "6f3eb24d0591fd4d282d89cbe4e4efba5571c699375006f80b2cbf53ce83461c" dependencies = [ "itertools", "num-bigint", @@ -192,9 +192,9 @@ dependencies = [ [[package]] name = "p3-goldilocks" -version = "0.5.2" +version = 
"0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ca1081f5c47b940f2d75a11c04f62ea1cc58a5d480dd465fef3861c045c63cd" +checksum = "5751c6591a0d2397d726620c2c29a7436ec6c5e19d2ed74ca5d078d4fbb18eb5" dependencies = [ "num-bigint", "p3-challenger", @@ -212,9 +212,9 @@ dependencies = [ [[package]] name = "p3-matrix" -version = "0.5.2" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53428126b009071563d1d07305a9de8be0d21de00b57d2475289ee32ffca6577" +checksum = "ea9c94c0714944e7b8a9a62e6340b1e3e1d3f8ecfd3e35c08798360200e73eff" dependencies = [ "itertools", "p3-field", @@ -227,15 +227,15 @@ dependencies = [ [[package]] name = "p3-maybe-rayon" -version = "0.5.2" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "082bf467011c06c768c579ec6eb9accb5e1e62108891634cc770396e917f978a" +checksum = "eebc233a34b1ab0273f35b4052fa2eeb3114b22ba4575bd7da00716e878ffb77" [[package]] name = "p3-mds" -version = "0.5.2" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35209e6214102ea6ec6b8cb1b9c15a9b8e597a39f9173597c957f123bced81b3" +checksum = "6b5441fa8116246ec9e6c835f15273cb27777ca572960ec87476b67fef13e01e" dependencies = [ "p3-dft", "p3-field", @@ -246,9 +246,9 @@ dependencies = [ [[package]] name = "p3-monty-31" -version = "0.5.2" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ffa8c99ec50c035020bbf5457c6a729ba6a975719c1a8dd3f16421081e4f650c" +checksum = "8724f330ea6d19dd4f2436aa0f88b5fcbf88f0f55ca7fccd3fea8b736dbcddad" dependencies = [ "itertools", "num-bigint", @@ -270,9 +270,9 @@ dependencies = [ [[package]] name = "p3-poseidon1" -version = "0.5.2" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a018b618e3fa0aec8be933b1d8e404edd23f46991f6bf3f5c2f3f95e9413fe9" +checksum = 
"04e2a562fea210baae390a32f9ecf0dd8724ae3f4352d1c8e413077b6f00a162" dependencies = [ "p3-field", "p3-symmetric", @@ -281,9 +281,9 @@ dependencies = [ [[package]] name = "p3-poseidon2" -version = "0.5.2" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "256a668a9ba916f8767552f13d0ba50d18968bc74a623bfdafa41e2970c944d0" +checksum = "06394851c161d17e4aa4ad2aad5557d32f14cadd1dc838f965d8e1821a63b8c5" dependencies = [ "p3-field", "p3-mds", @@ -294,9 +294,9 @@ dependencies = [ [[package]] name = "p3-symmetric" -version = "0.5.2" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c60a71a1507c13611b0f2b0b6e83669fd5b76f8e3115bcbced5ccfdf3ca7807" +checksum = "9ac1a276d421f8ef3361bb7d8c39a02c93c6b3f10eeaa559cc4c50222f9a5b82" dependencies = [ "itertools", "p3-field", @@ -306,9 +306,9 @@ dependencies = [ [[package]] name = "p3-util" -version = "0.5.2" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8b766b9e9254bf3fa98d76e42cf8a5b30628c182dfd5272d270076ee12f0fc0" +checksum = "d08a58162a4c264269ef454f0b28dcda89939490eecacb2b2cf5b00f719b80f6" dependencies = [ "serde", "transpose", diff --git a/miden-serde-utils/src/lib.rs b/miden-serde-utils/src/lib.rs index 47899d24c6..4bdaf15d43 100644 --- a/miden-serde-utils/src/lib.rs +++ b/miden-serde-utils/src/lib.rs @@ -482,6 +482,10 @@ where let v1 = T1::read_from(source)?; Ok((v1,)) } + + fn min_serialized_size() -> usize { + T1::min_serialized_size() + } } impl Deserializable for (T1, T2) @@ -690,7 +694,16 @@ impl Deserializable for Vec { impl Deserializable for BTreeMap { fn read_from(source: &mut R) -> Result { let len = source.read_usize()?; - source.read_many_iter(len)?.collect() + let mut map = BTreeMap::new(); + for entry in source.read_many_iter(len)? 
{ + let (key, value) = entry?; + if map.insert(key, value).is_some() { + return Err(DeserializationError::InvalidValue(String::from( + "duplicate key in BTreeMap encoding", + ))); + } + } + Ok(map) } fn min_serialized_size() -> usize { @@ -701,7 +714,15 @@ impl Deserializable for BTreeMap Deserializable for BTreeSet { fn read_from(source: &mut R) -> Result { let len = source.read_usize()?; - source.read_many_iter(len)?.collect() + let mut set = BTreeSet::new(); + for item in source.read_many_iter(len)? { + if !set.insert(item?) { + return Err(DeserializationError::InvalidValue(String::from( + "duplicate item in BTreeSet encoding", + ))); + } + } + Ok(set) } fn min_serialized_size() -> usize { @@ -830,4 +851,40 @@ mod tests { let as_arc = Arc::::read_from_bytes(&bytes).unwrap(); assert_eq!(&*as_arc, "other direction"); } + + #[test] + fn btree_map_rejects_duplicate_keys() { + let mut bytes = Vec::new(); + bytes.extend_from_slice(&2usize.to_bytes()); + bytes.extend_from_slice(&7u8.to_bytes()); + bytes.extend_from_slice(&1u8.to_bytes()); + bytes.extend_from_slice(&7u8.to_bytes()); + bytes.extend_from_slice(&2u8.to_bytes()); + + let result = BTreeMap::::read_from_bytes(&bytes); + + assert!(matches!(result, Err(DeserializationError::InvalidValue(_)))); + } + + #[test] + fn btree_set_rejects_duplicate_items() { + let mut bytes = Vec::new(); + bytes.extend_from_slice(&2usize.to_bytes()); + bytes.extend_from_slice(&7u8.to_bytes()); + bytes.extend_from_slice(&7u8.to_bytes()); + + let result = BTreeSet::::read_from_bytes(&bytes); + + assert!(matches!(result, Err(DeserializationError::InvalidValue(_)))); + } + + #[test] + fn budgeted_vec_of_one_element_tuples_accepts_exact_budget() { + let values = Vec::<(usize,)>::from([(0,), (1,), (2,)]); + let bytes = values.to_bytes(); + + let decoded = Vec::<(usize,)>::read_from_bytes_with_budget(&bytes, bytes.len()).unwrap(); + + assert_eq!(decoded, values); + } } diff --git a/stark/CHANGELOG.md b/stark/CHANGELOG.md index 
3ec81fb93c..351d019ebe 100644 --- a/stark/CHANGELOG.md +++ b/stark/CHANGELOG.md @@ -1,5 +1,14 @@ ## Unreleased +- [BREAKING] Preprocessed (fixed) traces are now supported. AIRs declare their shape via `LiftedAir::preprocessed_width()` and provide content through `BaseAir::preprocessed_trace()`; `ProverInstance` / `VerifierInstance` carry the `StarkConfig` alongside the statement and optional preprocessed data / commitment. The prover commits the preprocessed LDE tree once with `Preprocessed::build`, observes its commitment before per-proof statement data, and opens it as the first PCS group before main, aux, and quotient traces; preprocessed groups shorter than the max trace height are virtually lifted by the PCS. ([#1021](https://github.com/0xMiden/crypto/pull/1021)) +- [BREAKING] Replaced the `(AIR, AirWitness, AuxBuilder)` proving model with a `MultiAir` trait that owns its AIRs, plus validated `Statement` / `ProverStatement` structs; each AIR builds its own auxiliary trace via the new `LiftedAir::build_aux_trace`. `ProverInstance::prove` / `VerifierInstance::verify` consume those validated statements together with the STARK config; `prove_single` / `verify_single` were removed. Heterogeneous AIRs are expressed via caller-defined enum wrappers (the `LiftedBenchAir` pattern in `miden-bench`). ([#992](https://github.com/0xMiden/crypto/pull/992)) +- [BREAKING] Replaced `LiftedAir::reduced_aux_values` and `num_var_len_public_inputs` with `MultiAir::eval_external`, which returns the cross-AIR external assertions as a flat list of extension-field values that must equal zero. It reads an `aux_inputs: &[F]` slice (budget declared by `MultiAir::max_aux_inputs`, default `0`; the default `eval_external` rejects a non-empty slice) whose schema each `MultiAir` owns and validates. Canonically binding `airs()` and `eval_external` into Fiat-Shamir remains a known soundness gap ([#970](https://github.com/0xMiden/crypto/issues/970)), called out on the trait's `observe` doc. 
([#992](https://github.com/0xMiden/crypto/pull/992)) +- [BREAKING] Moved runtime validation onto the validating constructors `Statement::new` / `ProverStatement::new` (returning `InstanceError` — public-values length, `aux_inputs` budget, trace count ≤ 256, per-AIR width, power-of-two height, height ≥ max periodic length) and `TraceOrder` construction (returning `ShapeError`, validating the untrusted proof heights against the AIRs). Removed the `validate` module and its free functions, `LiftedAir::validate`, `LiftedAir::is_valid_builder`, `AirStructureError`, `TracePart`, and `InstanceValidationError`. `miden_lifted_air::debug` collapses to `assert_multi_air_valid` (whole-`MultiAir` structural contract) and `check_builder_shape`; `assert_prover_setup` asserts only that contract and drops its `params` argument plus the thin `assert_*` wrappers. Added overridable `MultiAir::num_air_inputs` and `LiftedAir::max_periodic_length`. ([#992](https://github.com/0xMiden/crypto/pull/992)) +- [BREAKING] `LiftedAir::log_quotient_degree` removed; quotient chunking is now the free function `miden_lifted_stark::log_quotient_degree`. `LiftedAir::constraint_degree` returns `ConstraintDegrees { base, ext }` (the raw base/extension symbolic-degree split, no clamping) instead of a `usize`; `log_quotient_degree` combines them and clamps the quotient degree to `D ≥ 1`, so trivial/degenerate AIRs are now supported rather than rejected (`AirStructureError::TrivialConstraints` is gone). The AIR ↔ PCS compatibility check (`log_quotient_degree ≤ log_blowup`) is inlined into the prover and verifier as `DomainError::ConstraintDegreeTooHigh`. ([#992](https://github.com/0xMiden/crypto/pull/992)) +- [BREAKING] AIR ordering is no longer public: `InstanceShape` / `InstanceShapes` are removed and `StarkProofData` carries a single `log_trace_heights: Vec` (instance order, no public accessor; the prior `air_indices` field is gone). 
The wire-format ordering is derived deterministically from the heights (stable sort on `(log_h, instance_idx)`); read the heights and derived order via parsed `StarkProof::log_trace_heights()` / `StarkProof::air_order()`. ([#992](https://github.com/0xMiden/crypto/pull/992)) +- [BREAKING] `ProverError` is now `Instance(InstanceError) | Domain(DomainError)`; `VerifierError` gains `Shape` and drops `InvalidAuxShape` (PCS row widths are now validated upstream by `verify_aligned`) and `ConstraintDegreeTooHigh` (now `DomainError::ConstraintDegreeTooHigh`). ([#992](https://github.com/0xMiden/crypto/pull/992)) +- [BREAKING] `check_constraints` no longer takes a `challenges: &[EF]` argument — it takes a challenger and derives aux randomness via `Statement::observe`, mirroring the prover's seeding. ([#992](https://github.com/0xMiden/crypto/pull/992)) +- [BREAKING] Reduced the public API surface to `ProverInstance` / `VerifierInstance` plus a structured proof-inspection view. The wide crate-root re-export list is dropped (callers import from `air` and from `lmcs` / `pcs` / `proof` / `prover` / `verifier`); `domain` and `order` become crate-private (only their `DomainError` / `ShapeError` stay reachable, since they surface through `ProverError` / `VerifierError`); `pcs` is promoted to public for its structured sub-proof types; and the `transcript` module is folded into `proof`. The proof view types are renamed: `StarkProof` → `StarkProofData` (the serialized wire artifact) and `StarkTranscript` → `StarkProof` (the parse-only view, built via `StarkProof::from_data`), with the same renaming applied to the PCS sub-proofs (`PcsProof`, `DeepProof`, `FriProof`, `FriRoundProof`; `from_verifier_channel` → `read_from_channel`). The panicking domain constructors (`TwoAdicCoset::unshifted`, `LiftedDomain::canonical` / `sub_domain`) are removed in favour of the fallible `try_*` variants. 
([#1020](https://github.com/0xMiden/crypto/pull/1020)) - Consolidated `p3-miden-lmcs`, `p3-miden-lifted-fri`, `p3-miden-dev-utils`, and `p3-miden-lifted-examples` into `miden-lifted-stark`; extracted profiling binary into `miden-bench` ([#66](https://github.com/0xMiden/p3-miden/pull/66)). - Dropped BabyBear support; simplified tests, benchmarks, and dev-utils to Goldilocks-only ([#52](https://github.com/0xMiden/p3-miden/pull/52)). - Added crate-local `testing` modules to `p3-miden-lmcs` and `p3-miden-lifted-fri` behind a `testing` feature flag ([#52](https://github.com/0xMiden/p3-miden/pull/52)). diff --git a/stark/README.md b/stark/README.md index 029c1de21d..c3d172acd3 100644 --- a/stark/README.md +++ b/stark/README.md @@ -28,8 +28,8 @@ miden-lifted-stark (prover, verifier, PCS, LMCS, shared types) ## Docs -- `docs/faq.md` (architecture Q&A) -- `docs/lifting.md` (math background for lifting) +- `miden-lifted-stark/README.md` (protocol-level overview) +- `miden-lifted-stark/src/prover/README.md`, `src/verifier/README.md` (per-side detail + lifting math) - `SECURITY.md` (audit/review guide; transcript and composition notes) ## Where To Start (Code) @@ -37,7 +37,7 @@ miden-lifted-stark (prover, verifier, PCS, LMCS, shared types) - Protocol flow: `miden-lifted-stark/src/prover/mod.rs` and `miden-lifted-stark/src/verifier/mod.rs` - PCS layer: `miden-lifted-stark/src/pcs/prover.rs` and `miden-lifted-stark/src/pcs/verifier.rs` - Commitment layer: `miden-lifted-stark/src/lmcs/mod.rs` and `miden-lifted-stark/src/lmcs/lifted_tree.rs` -- Math background: `docs/lifting.md` +- Math background: the "Mathematical background" in `miden-lifted-stark/src/prover/README.md` and `miden-lifted-stark/src/verifier/README.md` ## Build / Test diff --git a/stark/SECURITY.md b/stark/SECURITY.md index da57e3a1fa..3b431f861c 100644 --- a/stark/SECURITY.md +++ b/stark/SECURITY.md @@ -64,7 +64,8 @@ These are requirements on *applications* composing these crates. 
- You MUST ensure evaluation points used by DEEP/PCS lie outside the trace subgroup `H` and outside the LDE coset `gK`. - You MUST only use LMCS lifting with AIRs that are compatible with the lifted - view (see `docs/lifting.md`). + view (see the "Mathematical background" in + `miden-lifted-stark/src/prover/README.md` and `src/verifier/README.md`). Concrete examples of statement data that the application must treat explicitly: @@ -115,7 +116,7 @@ at the outer protocol layer. ## What To Review First (Suggested Order) -1. `miden-lifted-stark/src/verifier/mod.rs` (`verify_multi`) +1. `miden-lifted-stark/src/verifier/mod.rs` (`verify`) 2. `miden-lifted-stark/src/pcs/verifier.rs` (`verify`) 3. `miden-lifted-stark/src/lmcs/mod.rs` (`Lmcs::open_batch`) 4. `miden-lifted-stark/src/pcs/deep/verifier.rs` (DEEP reduction + quotient eval) @@ -224,8 +225,8 @@ If the AIR is *liftable* (roughly: it does not depend on wrap-around "next row" semantics unless explicitly constrained), then proving the lifted identity is as sound as proving the non-lifted identity. -For more detail on liftable AIR conditions and periodicity constraints, see -`docs/lifting.md`. +For more detail on liftable AIR conditions and periodicity constraints, see the +"Mathematical background" in `miden-lifted-stark/src/{prover,verifier}/README.md`. 
## Parameter Guidance (Non-Normative) diff --git a/stark/miden-lifted-air/Cargo.toml b/stark/miden-lifted-air/Cargo.toml index 3401252493..ee5751e517 100644 --- a/stark/miden-lifted-air/Cargo.toml +++ b/stark/miden-lifted-air/Cargo.toml @@ -14,10 +14,11 @@ doctest = false test = false [dependencies] -p3-air.workspace = true -p3-field.workspace = true -p3-matrix.workspace = true -p3-util.workspace = true +p3-air.workspace = true +p3-challenger.workspace = true +p3-field.workspace = true +p3-matrix.workspace = true +p3-util.workspace = true thiserror.workspace = true diff --git a/stark/miden-lifted-air/src/air.rs b/stark/miden-lifted-air/src/air.rs index d7e349d4c3..28b18cd317 100644 --- a/stark/miden-lifted-air/src/air.rs +++ b/stark/miden-lifted-air/src/air.rs @@ -1,10 +1,12 @@ -//! The `LiftedAir` super-trait for AIR definitions in the lifted STARK system. +//! AIR traits for the lifted STARK system: [`LiftedAir`] (a single AIR) and +//! [`MultiAir`] (the AIR collection plus the cross-AIR behavior on it). //! //! # Panic safety of `eval()` //! //! [`LiftedAir::eval`] is generic over `AB: LiftedAirBuilder`, so it cannot branch //! on the concrete builder type. All builders expose data through the same trait -//! methods — [`main()`](crate::AirBuilder::main), +//! methods — [`preprocessed()`](crate::AirBuilder::preprocessed), +//! [`main()`](crate::AirBuilder::main), //! [`permutation()`](crate::PermutationAirBuilder::permutation), //! [`public_values()`](crate::AirBuilder::public_values), //! [`permutation_randomness()`](crate::PermutationAirBuilder::permutation_randomness), @@ -12,25 +14,23 @@ //! [`periodic_values()`](crate::PeriodicAirBuilder::periodic_values) — which return //! matrices or slices. //! -//! If the symbolic evaluation in [`LiftedAir::log_quotient_degree`] succeeds (i.e. +//! If the symbolic evaluation in [`LiftedAir::constraint_degree`] succeeds (i.e. //! does not panic), it proves that the AIR's `eval()` only accesses indices within //! 
the declared dimensions. Any concrete builder constructed with matching dimensions //! is therefore safe from out-of-bounds panics. //! -//! Use [`LiftedAir::is_valid_builder`] to verify that a concrete builder's +//! Use [`crate::debug::check_builder_shape`] to verify a concrete builder's //! dimensions match the AIR before calling `eval()`. -use alloc::vec::Vec; +use alloc::{boxed::Box, vec::Vec}; -use p3_air::{BaseAir, WindowAccess}; +use p3_air::BaseAir; +use p3_challenger::CanObserve; use p3_field::{ExtensionField, Field}; use p3_matrix::dense::RowMajorMatrix; -use p3_util::log2_ceil_usize; -use thiserror::Error; use crate::{ LiftedAirBuilder, - auxiliary::{ReducedAuxValues, ReductionError, VarLenPublicInputs}, symbolic::{AirLayout, SymbolicAirBuilder, SymbolicExpression, SymbolicExpressionExt}, }; @@ -45,6 +45,17 @@ use crate::{ /// - `F`: Base field /// - `EF`: Extension field (for aux trace challenges and aux values) pub trait LiftedAir: Sync + BaseAir { + /// Number of base-field columns in the preprocessed trace. + /// + /// A preprocessed trace is data fixed per AIR — typically a lookup table + /// or selector polynomial — committed once and reused across proofs. AIRs + /// without preprocessed columns return 0 (the default). The content comes + /// from [`BaseAir::preprocessed_trace`]; this is the cheap width the + /// verifier reads without materialising the table. + fn preprocessed_width(&self) -> usize { + 0 + } + /// Return the periodic table data: a list of columns, each a `Vec` of evaluations. /// /// Each inner `Vec` represents one periodic column. Its length is the period of @@ -79,6 +90,28 @@ pub trait LiftedAir: Sync + BaseAir { Some(RowMajorMatrix::new(values, num_cols)) } + /// Maximum periodic-column length, or `0` if there are none. A trace's height + /// must be at least this, so it is the per-AIR lower bound on trace height. 
+ /// + /// The default derives it from [`periodic_columns`](Self::periodic_columns), + /// asserting each column is a non-empty power of two. Override to return it + /// directly when known statically; the override is cross-checked by + /// [`crate::debug::assert_multi_air_valid`]. + fn max_periodic_length(&self) -> usize { + self.periodic_columns() + .iter() + .map(|col| { + assert!( + !col.is_empty() && col.len().is_power_of_two(), + "periodic column length must be a positive power of two, got {}", + col.len(), + ); + col.len() + }) + .max() + .unwrap_or(0) + } + /// Number of extension-field challenges required for the auxiliary trace. fn num_randomness(&self) -> usize; @@ -87,69 +120,32 @@ pub trait LiftedAir: Sync + BaseAir { /// Number of extension-field aux values committed to the Fiat-Shamir transcript. /// - /// These are the values returned by - /// [`AuxBuilder::build_aux_trace`](crate::AuxBuilder::build_aux_trace) alongside the aux - /// trace matrix. Their count may differ from [`aux_width`](Self::aux_width) (the number of - /// aux trace columns). - /// - /// These values are exposed to AIR constraints as *permutation values* via + /// Returned by [`build_aux_trace`](Self::build_aux_trace) alongside the aux trace + /// matrix; the count is independent of [`aux_width`](Self::aux_width) (the number + /// of aux trace columns). Exposed to constraints as *permutation values* via /// [`PermutationAirBuilder::permutation_values`](crate::PermutationAirBuilder::permutation_values). fn num_aux_values(&self) -> usize; - /// Number of variable-length public inputs this AIR expects. - /// - /// Each input is a slice of base-field elements that - /// [`reduced_aux_values`](Self::reduced_aux_values) reduces to a single value. - /// The prover validates that witnesses provide exactly this many slices. 
- /// - /// Implementors of [`reduced_aux_values`](Self::reduced_aux_values) should verify - /// that `var_len_public_inputs` contains exactly this many slices, returning - /// [`ReductionError`] otherwise. - fn num_var_len_public_inputs(&self) -> usize; - - /// Reduce this AIR's aux values to a [`ReducedAuxValues`] contribution. - /// - /// Called by the verifier (with concrete field values, not symbolic expressions) - /// to compute each AIR's contribution to the global cross-AIR bus identity check. - /// The verifier accumulates contributions across all AIRs and checks that the - /// combined result is identity (prod=1, sum=0). - /// - /// # Arguments - /// - `aux_values`: prover-supplied aux values (from the proof) - /// - `challenges`: extension-field challenges (same as used for aux trace building) - /// - `public_values`: this AIR's public values (base field) - /// - `var_len_public_inputs`: reducible inputs for the cross-AIR identity check + /// Build this AIR's auxiliary trace and aux values. /// - /// # Errors - /// - /// The verifier validates instance dimensions (public values length, - /// var-len public inputs count) before calling this method, so - /// implementations can assume correct input counts. However, the - /// *length of each individual var-len slice* is not validated upfront — - /// implementations that index into these slices must check lengths - /// themselves or use the `Result` return type to report errors. - /// - /// Default: returns identity (correct for AIRs without buses). - fn reduced_aux_values( + /// `challenges` contains exactly [`num_randomness`](Self::num_randomness) + /// extension-field elements for this AIR. The returned `aux_trace` has width + /// [`aux_width`](Self::aux_width) and the same height as `main`; `aux_values` has + /// length [`num_aux_values`](Self::num_aux_values). 
+ fn build_aux_trace( &self, - _aux_values: &[EF], - _challenges: &[EF], - _public_values: &[F], - _var_len_public_inputs: VarLenPublicInputs<'_, F>, - ) -> Result, ReductionError> - where - EF: ExtensionField, - { - Ok(ReducedAuxValues::identity()) - } + main: &RowMajorMatrix, + air_inputs: &[F], + aux_inputs: &[F], + challenges: &[EF], + ) -> (RowMajorMatrix, Vec); /// Return the [`AirLayout`] describing this AIR's dimensions. /// /// This is the single source of truth for building symbolic or layout builders. - /// `preprocessed_width` is always 0 because lifted AIRs forbid preprocessed traces. fn air_layout(&self) -> AirLayout { AirLayout { - preprocessed_width: 0, + preprocessed_width: self.preprocessed_width(), main_width: self.width(), num_public_values: self.num_public_values(), permutation_width: self.aux_width(), @@ -159,178 +155,214 @@ pub trait LiftedAir: Sync + BaseAir { } } - /// Validate that this AIR satisfies the [`LiftedAir`] contract. - /// - /// The lifted STARK protocol relies on several structural properties of the AIR - /// that can be checked statically (i.e. without a witness). This method verifies - /// the subset that is machine-checkable; the full list of trust assumptions is - /// documented in the module docs of `miden-lifted-stark`. Both the prover and - /// verifier call this before proceeding, so a malformed AIR is caught early. - /// - /// # Checked properties - /// - /// - **No preprocessed trace** — the lifted STARK protocol does not support preprocessed - /// (fixed) columns; their presence is an error. - /// - **Positive auxiliary width** — every lifted AIR must declare at least one auxiliary column - /// (`aux_width() > 0`). - /// - **Well-formed periodic columns** — each periodic column must be non-empty and have a - /// power-of-two length. 
- fn validate(&self) -> Result<(), AirStructureError> { - if self.preprocessed_trace().is_some() { - return Err(AirStructureError::PreprocessedTrace); - } - if self.aux_width() == 0 { - return Err(AirStructureError::ZeroAuxWidth); - } - for (i, col) in self.periodic_columns().iter().enumerate() { - if col.is_empty() || !col.len().is_power_of_two() { - return Err(AirStructureError::InvalidPeriodicColumn { - index: i, - length: col.len(), - }); - } - } - Ok(()) - } - /// Evaluate all AIR constraints using the provided builder. fn eval>(&self, builder: &mut AB); - /// Log₂ of the number of quotient chunks, inferred from symbolic constraint analysis. - /// - /// Evaluates the AIR on a [`SymbolicAirBuilder`](crate::symbolic::SymbolicAirBuilder) to - /// determine the maximum constraint degree M, then returns `log2_ceil(M - 1)` (padded so M - /// ≥ 2). + /// Symbolic constraint degree multiples, split into base-field and + /// extension-field maxima (see [`ConstraintDegrees`]). /// - /// Uses `SymbolicAirBuilder` (i.e. `EF = F`) which is sufficient for degree - /// computation since extension-field operations have the same degree structure. - /// - /// # Why `M − 1` chunks? - /// - /// Let N be the trace height (so trace columns are polynomials of degree < N). - /// Symbolic evaluation assigns each constraint a *degree multiple* M, meaning the - /// resulting numerator polynomial C(X) has degree bounded by roughly M·(N − 1). - /// - /// In a STARK, the constraint numerator is divisible by the trace vanishing - /// polynomial `Z_H(X) = Xᴺ − 1`, so the quotient polynomial - /// `Q(X) = C(X) / Z_H(X)` has - /// - /// `deg(Q) ≤ deg(C) − N ≤ M·(N − 1) − N < (M − 1)·N`. - /// - /// We commit to Q(X) by splitting it into D chunks of degree < N. The bound above - /// shows that D = M − 1 chunks suffice; we then round D up to a power of two and - /// return `log2(D)`. - /// - /// We clamp M ≥ 2 so that D ≥ 1. 
If M = 1 then `deg(C) < N`, and divisibility by - /// `Z_H` would force C(X) to be the zero polynomial (i.e. the constraint carries no - /// information about the trace). - fn log_quotient_degree(&self) -> usize + /// The split lets callers report base and extension constraint degree maxima + /// separately; STARK prover/verifier code derives quotient degrees from these + /// raw symbolic bounds later. Override this when the split is known statically + /// so a per-AIR bound can be sharp without redoing the symbolic pass. + fn constraint_degree(&self) -> ConstraintDegrees where Self: Sized, { - let mut builder = SymbolicAirBuilder::::new(self.air_layout()); - self.eval(&mut builder); + ConstraintDegrees::from_air::(self) + } +} + +/// Symbolic constraint degree multiples, split by constraint kind. +/// +/// `base` is the maximum degree multiple over the base-field constraints and +/// `ext` over the extension-field constraints (each `0` if the AIR has none of +/// that kind). Consumers that need a single value take [`max`](Self::max). +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub struct ConstraintDegrees { + /// Max degree multiple over base-field constraints (`0` if there are none). + pub base: usize, + /// Max degree multiple over extension-field constraints (`0` if there are none). + pub ext: usize, +} - let base_degree_multiple = - |constraint: &SymbolicExpression| constraint.degree_multiple(); - let ext_degree_multiple = - |constraint: &SymbolicExpressionExt| constraint.degree_multiple(); +impl ConstraintDegrees { + /// Compute symbolic constraint degree multiples from an AIR. 
+ pub fn from_air(air: &A) -> Self + where + F: Field, + A: LiftedAir, + { + let mut builder = SymbolicAirBuilder::::new(air.air_layout()); + air.eval(&mut builder); - let base_degree = - builder.base_constraints().iter().map(base_degree_multiple).max().unwrap_or(0); - let ext_degree = builder + let base = builder + .base_constraints() + .iter() + .map(SymbolicExpression::degree_multiple) + .max() + .unwrap_or(0); + let ext = builder .extension_constraints() .iter() - .map(ext_degree_multiple) + .map(SymbolicExpressionExt::degree_multiple) .max() .unwrap_or(0); - let constraint_degree = base_degree.max(ext_degree).max(2); - - log2_ceil_usize(constraint_degree - 1) + Self { base, ext } } - /// Number of quotient chunks: `2^log_quotient_degree()`. - fn constraint_degree(&self) -> usize - where - Self: Sized, - { - 1 << self.log_quotient_degree() + /// The combined degree multiple: the larger of [`base`](Self::base) and + /// [`ext`](Self::ext). + pub fn max(&self) -> usize { + self.base.max(self.ext) } +} - /// Check that a builder's dimensions match this AIR. - /// - /// Verifies every data-carrying accessor on [`LiftedAirBuilder`]: main trace, - /// preprocessed trace, aux trace, public values, randomness, aux values, and - /// periodic values. - /// - /// This guards the invariant that makes [`eval`](Self::eval) panic-free: if - /// the symbolic evaluation in [`log_quotient_degree`](Self::log_quotient_degree) - /// succeeds and this check passes, then `eval()` cannot panic from - /// out-of-bounds access on the builder's accessors. - fn is_valid_builder>( - &self, - builder: &AB, - ) -> Result<(), AirStructureError> { - let check = - |part: TracePart, expected: usize, actual: usize| -> Result<(), AirStructureError> { - if actual != expected { - return Err(AirStructureError::BuilderMismatch { part, expected, actual }); - } - Ok(()) - }; +/// Boxed error returned by [`MultiAir::eval_external`]. 
+pub type ReductionError = Box; + +// ============================================================================ +// MultiAir trait +// ============================================================================ - let main = builder.main(); - // Check current and next slices of the main trace. - check(TracePart::Main, self.width(), main.current_slice().len())?; - check(TracePart::Main, self.width(), main.next_slice().len())?; +/// Trusted statement definition for a multi-AIR proof. +/// +/// A `MultiAir` owns the AIR instances and defines the cross-AIR assertions and +/// Fiat-Shamir binding hook for the statement. +/// +/// Methods take `&self` so an impl can carry the AIRs and any protocol-level +/// state (closures, lookup tables, shared parameters). Because the AIRs are +/// owned here, a `MultiAir` is constructed per proof. +/// +/// The framework defines instance order as the position of an AIR within +/// [`Self::airs`]. Every per-AIR slice elsewhere on [`Statement`](crate::Statement) / +/// [`ProverStatement`](crate::ProverStatement) uses the same ordering. +/// +/// # Structural contract +/// +/// Prover/verifier hot paths assume the structural invariants checked by +/// [`crate::debug::assert_multi_air_valid`]. In particular, `airs()` must be +/// non-empty, all AIRs must agree on their public value count, and overridable +/// helpers such as [`Self::num_air_inputs`] must agree with the raw AIR data. +pub trait MultiAir +where + F: Field, + EF: ExtensionField, +{ + /// AIR type. Heterogeneous AIRs are expressed via caller-defined enum wrappers. + type Air: LiftedAir; - // Check current and next slices of the aux trace. - let perm = builder.permutation(); - check(TracePart::Aux, self.aux_width(), perm.current_slice().len())?; - check(TracePart::Aux, self.aux_width(), perm.next_slice().len())?; + /// The AIRs in instance order — the single source of that ordering. 
+ fn airs(&self) -> &[Self::Air]; - check(TracePart::PublicValues, self.num_public_values(), builder.public_values().len())?; - check( - TracePart::Randomness, - self.num_randomness(), - builder.permutation_randomness().len(), - )?; - check(TracePart::AuxValues, self.num_aux_values(), builder.permutation_values().len())?; - check( - TracePart::PeriodicValues, - self.periodic_columns().len(), - builder.periodic_values().len(), - )?; + /// Number of public inputs shared by every AIR. + /// + /// All AIRs read the same `air_inputs`, so they must agree on + /// [`num_public_values`](p3_air::BaseAir::num_public_values). The default + /// returns that shared count. It assumes the structural contract checked by + /// [`crate::debug::assert_multi_air_valid`], including non-empty `airs()`; + /// override to return it directly when known statically. + /// [`Statement::new`](crate::Statement::new) checks `air_inputs.len()` against it. + fn num_air_inputs(&self) -> usize { + let airs = self.airs(); + let n = airs + .first() + .expect("a MultiAir must carry at least one AIR") + .num_public_values(); + assert!( + airs.iter().all(|air| air.num_public_values() == n), + "AIRs disagree on num_public_values", + ); + n + } - Ok(()) + /// Upper bound on the extra statement inputs accepted by + /// [`Self::eval_external`]. + /// + /// These inputs are public/verifier-visible, but are not passed to each AIR + /// as `air_inputs` and are unrelated to aux trace columns. They are + /// available only to the statement-level cross-AIR assertions. Validated by + /// [`Statement::new`](crate::Statement::new) before any cryptographic work. + /// Default `0`; implementations that consume `aux_inputs` must override so + /// the budget matches the schema their `eval_external` decodes. + fn max_aux_inputs(&self) -> usize { + 0 } -} -/// Which part of the trace a builder mismatch refers to. 
-#[derive(Copy, Clone, Debug)] -pub enum TracePart { - Main, - Aux, - PublicValues, - Randomness, - AuxValues, - PeriodicValues, -} + /// Evaluate statement-level cross-AIR assertions. + /// + /// Returns one value per assertion expression. Each expression is expected + /// to vanish for a valid statement; the prover/verifier accept only if every + /// returned value is zero. Implementations should return `Ok(Vec::new())` + /// when the statement has no cross-AIR assertions. + /// + /// An implementation could perform zero checks internally, but returning + /// assertion expression values keeps this hook close to the AIR constraint + /// model: build the algebraic expression for each assertion, then let the + /// protocol batch or individually assert those expressions equal zero. This + /// also keeps the logic usable by a future symbolic-expression pipeline that + /// extracts the assertion polynomials. + /// + /// # Arguments + /// - `challenges`: shared extension-field challenge pool; each AIR consumes the prefix of + /// length `air.num_randomness()`. + /// - `air_inputs`, `aux_inputs`: the inputs from the [`Statement`](crate::Statement). + /// - `aux_values`, `log_trace_heights`: parallel per-AIR slices in instance order. + /// `aux_values[i]` and `log_trace_heights[i]` both describe `self.airs()[i]`. The protocol + /// derives proof order by stable-sorting instance indices by `(log_trace_height, + /// instance_index)`. + /// + /// Default: refuses to be called with non-empty `aux_inputs`; otherwise + /// emits no assertions. 
+ fn eval_external( + &self, + challenges: &[EF], + air_inputs: &[F], + aux_inputs: &[F], + aux_values: &[&[EF]], + log_trace_heights: &[u8], + ) -> Result, ReductionError> { + if !aux_inputs.is_empty() { + return Err("default `eval_external` received non-empty `aux_inputs` — override \ + `eval_external` to consume them" + .into()); + } + let _ = (challenges, air_inputs, aux_values, log_trace_heights); + Ok(Vec::new()) + } -/// Errors intrinsic to a single AIR definition, independent of any instance -/// data. Returned by [`LiftedAir::validate`] and [`LiftedAir::is_valid_builder`]. -#[derive(Debug, Error)] -pub enum AirStructureError { - #[error("periodic column {index}: length must be positive power of two, got {length}")] - InvalidPeriodicColumn { index: usize, length: usize }, - #[error("preprocessed traces are not supported")] - PreprocessedTrace, - #[error("{part:?} dimension mismatch: expected {expected}, got {actual}")] - BuilderMismatch { - part: TracePart, - expected: usize, - actual: usize, - }, - #[error("aux width must be positive")] - ZeroAuxWidth, + /// Absorb statement-owned public inputs into the Fiat-Shamir challenger. + /// + /// The default order is `air_inputs.len()`, `air_inputs`, + /// `max_aux_inputs()`, `aux_inputs.len()`, then `aux_inputs`. The protocol + /// observes the instance count and `log_trace_heights` separately after this + /// hook, but passes the heights here so custom bindings can include AIR + /// metadata that depends on the prover-chosen trace ordering or heights. + /// + /// # Soundness gap (TODO) + /// + /// The default binds inputs but does NOT canonically bind the `MultiAir` + /// itself — neither its AIR collection nor `eval_external` logic — into + /// Fiat-Shamir. Until the symbolic-graph binding lands (tracked in + /// ), callers MUST observe the + /// `MultiAir`'s AIR configurations into the challenger before calling the + /// prover or verifier. 
+ fn observe>( + &self, + challenger: &mut C, + air_inputs: &[F], + aux_inputs: &[F], + log_trace_heights: &[u8], + ) { + challenger.observe(F::from_usize(air_inputs.len())); + for &v in air_inputs { + challenger.observe(v); + } + challenger.observe(F::from_usize(self.max_aux_inputs())); + challenger.observe(F::from_usize(aux_inputs.len())); + for &v in aux_inputs { + challenger.observe(v); + } + let _ = log_trace_heights; + } } diff --git a/stark/miden-lifted-air/src/auxiliary/builder.rs b/stark/miden-lifted-air/src/auxiliary/builder.rs deleted file mode 100644 index 0846b88b38..0000000000 --- a/stark/miden-lifted-air/src/auxiliary/builder.rs +++ /dev/null @@ -1,36 +0,0 @@ -//! The `AuxBuilder` trait for constructing auxiliary traces. -//! -//! This trait decouples auxiliary trace *building* from the AIR definition, -//! allowing the prover to supply a separate builder per instance. - -use alloc::vec::Vec; - -use p3_field::{ExtensionField, Field}; -use p3_matrix::dense::RowMajorMatrix; - -/// Builder for constructing the auxiliary trace from a main trace and challenges. -/// -/// Decoupled from [`LiftedAir`](crate::LiftedAir) so that prover-side trace -/// construction is not part of the AIR trait. Each prover instance can supply -/// its own `AuxBuilder`. -pub trait AuxBuilder> { - /// Build the auxiliary trace and return aux values. - /// - /// # Arguments - /// - `main`: The main trace matrix - /// - `challenges`: Extension-field challenges for aux trace construction - /// - /// # Returns - /// `(aux_trace, aux_values)` where: - /// - `aux_trace`: The auxiliary trace matrix (EF columns) - /// - `aux_values`: Extension-field scalars committed to the Fiat-Shamir transcript. Their - /// meaning is AIR-defined — typically the aux trace's last row, but the protocol does not - /// require this. 
The AIR's [`eval`](crate::LiftedAir::eval) should constrain how they relate - /// to the committed trace, and [`reduced_aux_values`](crate::LiftedAir::reduced_aux_values) - /// uses them for cross-AIR bus identity checking. - fn build_aux_trace( - &self, - main: &RowMajorMatrix, - challenges: &[EF], - ) -> (RowMajorMatrix, Vec); -} diff --git a/stark/miden-lifted-air/src/auxiliary/mod.rs b/stark/miden-lifted-air/src/auxiliary/mod.rs deleted file mode 100644 index 89116a68e8..0000000000 --- a/stark/miden-lifted-air/src/auxiliary/mod.rs +++ /dev/null @@ -1,54 +0,0 @@ -//! Auxiliary trace types: builder and cross-AIR identity checking. -//! -//! # Protocol Overview -//! -//! The auxiliary trace enables cross-AIR buses (multiset / logup) in the lifted STARK. -//! -//! ## Prover -//! -//! 1. [`AuxBuilder::build_aux_trace`] constructs the aux trace and returns aux values -//! (extension-field elements whose meaning is AIR-defined). -//! 2. The aux trace is committed (Merkle commitment). -//! 3. Aux values are sent via the Fiat-Shamir transcript. -//! -//! ## AIR constraints ([`eval`](crate::LiftedAir::eval)) -//! -//! 4. The AIR defines how aux values relate to the committed aux trace. A common pattern is to -//! constrain them to equal the aux trace's last row, but the protocol does not impose this — the -//! AIR is free to define whatever relationship it needs. -//! 5. Transition constraints enforce the aux trace's internal logic (e.g. running product -//! accumulation). -//! -//! ## Verifier -//! -//! 6. The verifier receives aux values from the transcript. -//! 7. Constraint evaluation (steps 4–5) is checked at a random point. -//! 8. [`reduced_aux_values`](crate::LiftedAir::reduced_aux_values) computes each AIR's bus -//! contribution from the aux values, challenges, and public inputs. -//! 9. Global check: all contributions combine to identity (prod=1, sum=0). 
- -mod builder; -mod values; - -pub use builder::AuxBuilder; -pub use values::ReducedAuxValues; - -/// Variable-length public inputs for an AIR instance. -/// -/// A list of *reducible inputs*: each `&[F]` is a slice of base-field elements -/// that [`LiftedAir::reduced_aux_values`](crate::LiftedAir::reduced_aux_values) -/// reduces to a single extension-field value. The AIR defines how to group and -/// interpret them (e.g. which inputs belong to which bus). -/// -/// The number of slices must equal -/// [`LiftedAir::num_var_len_public_inputs`](crate::LiftedAir::num_var_len_public_inputs). -/// -/// **Commitment:** callers **must** bind these inputs to the Fiat-Shamir -/// challenger state, just like the AIR's public values. -pub type VarLenPublicInputs<'a, F> = &'a [&'a [F]]; - -/// Boxed error returned by -/// [`LiftedAir::reduced_aux_values`](crate::LiftedAir::reduced_aux_values). -/// -/// Each AIR defines its own concrete error type and boxes it into this alias. -pub type ReductionError = alloc::boxed::Box; diff --git a/stark/miden-lifted-air/src/auxiliary/values.rs b/stark/miden-lifted-air/src/auxiliary/values.rs deleted file mode 100644 index 98a3375212..0000000000 --- a/stark/miden-lifted-air/src/auxiliary/values.rs +++ /dev/null @@ -1,48 +0,0 @@ -//! Types for auxiliary trace value reduction and cross-AIR identity checking. -//! -//! Each AIR's aux trace has associated aux values (extension field scalars), -//! sent via the transcript by the prover. The verifier calls -//! [`LiftedAir::reduced_aux_values`](crate::LiftedAir::reduced_aux_values) -//! on each AIR to compute a [`ReducedAuxValues`] contribution, then checks that -//! the global combination is identity (prod=1, sum=0). - -use p3_field::{Field, PrimeCharacteristicRing}; - -/// Accumulated contribution from reducing aux values across one or more AIRs. 
-/// -/// The global identity check requires: -/// - `prod == 1` (multiset buses: all ratios multiply to 1) -/// - `sum == 0` (logup buses: all differences sum to 0) -#[derive(Clone, Debug)] -pub struct ReducedAuxValues { - /// Accumulated product for multiset buses. - pub prod: EF, - /// Accumulated sum for logup buses. - pub sum: EF, -} - -impl ReducedAuxValues { - /// The identity contribution (no buses): prod=1, sum=0. - pub fn identity() -> Self { - Self { prod: EF::ONE, sum: EF::ZERO } - } - - /// Combine another contribution into this one. - pub fn combine_in_place(&mut self, other: &Self) { - self.prod *= other.prod.clone(); - self.sum += other.sum.clone(); - } - - /// Combine two contributions, returning a new one. - pub fn combine(mut self, other: &Self) -> Self { - self.combine_in_place(other); - self - } -} - -impl ReducedAuxValues { - /// Check whether this contribution is the identity (all buses satisfied). - pub fn is_identity(&self) -> bool { - self.prod == EF::ONE && self.sum == EF::ZERO - } -} diff --git a/stark/miden-lifted-air/src/debug.rs b/stark/miden-lifted-air/src/debug.rs new file mode 100644 index 0000000000..2e1f03da4b --- /dev/null +++ b/stark/miden-lifted-air/src/debug.rs @@ -0,0 +1,158 @@ +//! Debug-only structural checks for AIRs. These panic on contract violation and +//! are meant for tests / setup; the prover and verifier hot paths assume AIRs +//! are well-formed. +//! +//! # When to use what +//! +//! - [`assert_multi_air_valid`]: verify a [`MultiAir`] satisfies the structural contract assumed by +//! the rest of the protocol — per AIR positive auxiliary width and power-of-two periodic columns; +//! across AIRs a shared `num_public_values`. This also cross-checks the overridable +//! [`MultiAir::num_air_inputs`] / [`LiftedAir::max_periodic_length`] against the raw AIR data, so +//! an override that lies about either is caught here. +//! 
- [`check_builder_shape`]: verify a concrete builder's accessor dimensions match an AIR before +//! calling [`LiftedAir::eval`]. Used as a `cfg(debug_assertions)` belt-and-suspenders inside the +//! prover and verifier loops. +//! +//! Runtime checks on caller-supplied data live on the constructors +//! [`Statement::new`](crate::Statement::new) / +//! [`ProverStatement::new`](crate::ProverStatement::new). + +use p3_field::{ExtensionField, Field}; +use p3_matrix::Matrix; + +use crate::{BaseAir, LiftedAir, LiftedAirBuilder, MultiAir, WindowAccess}; + +/// Assert a [`MultiAir`] is structurally well-formed. +/// +/// This checks the structural invariants that prover/verifier hot paths trust: +/// - [`MultiAir::airs`] is non-empty. +/// - All AIRs agree on [`BaseAir::num_public_values`]. +/// - [`MultiAir::num_air_inputs`] agrees with the per-AIR public value count. +/// - Each AIR has positive auxiliary width. +/// - Each AIR's [`LiftedAir::preprocessed_width`] agrees with [`BaseAir::preprocessed_trace`] +/// presence and width. +/// - Each periodic column is non-empty and has power-of-two length. +/// - [`LiftedAir::max_periodic_length`] agrees with the raw periodic columns. +/// +/// Panics on any violation. An empty [`MultiAir`] is a malformed trusted AIR +/// definition, not a typed [`Statement::new`](crate::Statement::new) error. +pub fn assert_multi_air_valid(multi_air: &MA) +where + F: Field, + EF: ExtensionField, + MA: MultiAir, +{ + let airs = multi_air.airs(); + assert!(!airs.is_empty(), "MultiAir::airs() must be non-empty"); + + // Derive the shared count from the raw AIRs and confirm the overridable + // `num_air_inputs` agrees. 
+ let num_air_inputs = airs[0].num_public_values(); + assert!( + airs.iter().all(|air| air.num_public_values() == num_air_inputs), + "AIRs disagree on num_public_values", + ); + assert!( + multi_air.num_air_inputs() == num_air_inputs, + "num_air_inputs() = {} disagrees with per-AIR num_public_values() = {num_air_inputs}", + multi_air.num_air_inputs(), + ); + + for (idx, air) in airs.iter().enumerate() { + check_one_air::(idx, air); + } +} + +/// Assert one AIR satisfies the structural contract. +fn check_one_air(idx: usize, air: &A) +where + F: Field, + A: LiftedAir, +{ + assert!(air.aux_width() > 0, "AIR {idx}: aux_width must be positive"); + + let preprocessed_width = air.preprocessed_width(); + match air.preprocessed_trace() { + Some(trace) => { + assert!( + preprocessed_width > 0, + "AIR {idx}: preprocessed_trace returned Some but preprocessed_width() is 0", + ); + assert_eq!( + trace.width(), + preprocessed_width, + "AIR {idx}: preprocessed_trace width disagrees with preprocessed_width()", + ); + assert!( + trace.height().is_power_of_two(), + "AIR {idx}: preprocessed_trace height must be a positive power of two, got {height}", + height = trace.height(), + ); + }, + None => { + assert_eq!( + preprocessed_width, 0, + "AIR {idx}: preprocessed_width() is {preprocessed_width} but preprocessed_trace returned None", + ); + }, + } + + // Derive the max period from the raw columns (asserting positive-power-of-two) + // and confirm the overridable `max_periodic_length` agrees. 
+ let mut max_period = 0; + for (i, col) in air.periodic_columns().iter().enumerate() { + assert!( + !col.is_empty() && col.len().is_power_of_two(), + "AIR {idx}: periodic column {i}: length must be a positive power of two, \ + got {len}", + len = col.len(), + ); + max_period = max_period.max(col.len()); + } + assert!( + air.max_periodic_length() == max_period, + "AIR {idx}: max_periodic_length() = {} disagrees with periodic_columns() max = {max_period}", + air.max_periodic_length(), + ); +} + +/// Assert a concrete builder's accessor dimensions match `air` — preprocessed, +/// main and aux trace, public values, randomness, aux values, and periodic values. +/// +/// Guards the invariant that makes [`LiftedAir::eval`] panic-free: if symbolic +/// evaluation in `constraint_degree` succeeds and this check passes, `eval()` +/// cannot panic from out-of-bounds accessor access. +pub fn check_builder_shape(air: &A, builder: &AB) +where + F: Field, + A: LiftedAir, + AB: LiftedAirBuilder, +{ + let check = |part: &str, expected: usize, actual: usize| { + assert!( + actual == expected, + "{part} dimension mismatch: expected {expected}, got {actual}" + ); + }; + + let preprocessed = builder.preprocessed(); + check( + "preprocessed (current)", + air.preprocessed_width(), + preprocessed.current_slice().len(), + ); + check("preprocessed (next)", air.preprocessed_width(), preprocessed.next_slice().len()); + + let main = builder.main(); + check("main (current)", air.width(), main.current_slice().len()); + check("main (next)", air.width(), main.next_slice().len()); + + let perm = builder.permutation(); + check("aux (current)", air.aux_width(), perm.current_slice().len()); + check("aux (next)", air.aux_width(), perm.next_slice().len()); + + check("public values", air.num_public_values(), builder.public_values().len()); + check("randomness", air.num_randomness(), builder.permutation_randomness().len()); + check("aux values", air.num_aux_values(), builder.permutation_values().len()); + 
check("periodic values", air.periodic_columns().len(), builder.periodic_values().len()); +} diff --git a/stark/miden-lifted-air/src/empty_window.rs b/stark/miden-lifted-air/src/empty_window.rs deleted file mode 100644 index 1eaad1135c..0000000000 --- a/stark/miden-lifted-air/src/empty_window.rs +++ /dev/null @@ -1,38 +0,0 @@ -//! A zero-width window that prevents access to preprocessed columns. -//! -//! Lifted AIRs have no preprocessed trace. [`EmptyWindow`] encodes this invariant: -//! AIR validation prevents preprocessed access, and the window methods are -//! unreachable as a defence-in-depth measure. - -use core::marker::PhantomData; - -use p3_air::WindowAccess; - -/// A window type for traces that must never be accessed. -/// -/// Satisfies the `WindowAccess + Clone` bound required by -/// [`AirBuilder::PreprocessedWindow`](p3_air::AirBuilder::PreprocessedWindow). -/// Lifted AIRs have no preprocessed trace, so these methods should never be -/// called; AIR validation prevents this at a higher level. -#[derive(Debug, Clone, Copy)] -pub struct EmptyWindow(PhantomData); - -impl EmptyWindow { - /// Static reference to an empty window. - /// - /// Safe because `EmptyWindow` is a ZST — no actual `T` is stored, - /// so the `'static` lifetime is always valid. - pub fn empty_ref() -> &'static Self { - &EmptyWindow(PhantomData) - } -} - -impl WindowAccess for EmptyWindow { - fn current_slice(&self) -> &[T] { - unreachable!("preprocessed trace does not exist in lifted AIRs") - } - - fn next_slice(&self) -> &[T] { - unreachable!("preprocessed trace does not exist in lifted AIRs") - } -} diff --git a/stark/miden-lifted-air/src/lib.rs b/stark/miden-lifted-air/src/lib.rs index df03316935..bf7ea6a71e 100644 --- a/stark/miden-lifted-air/src/lib.rs +++ b/stark/miden-lifted-air/src/lib.rs @@ -1,34 +1,34 @@ //! AIR traits for the Miden lifted STARK protocol. //! //! This crate provides: -//! 
- [`LiftedAir`]: Super-trait for AIR definitions (inherits upstream + adds aux trace support and +//! - [`LiftedAir`]: super-trait for AIR definitions (inherits upstream + adds aux trace support and //! periodic column data) -//! - [`LiftedAirBuilder`]: Super-trait for constraint builders -//! - [`auxiliary`]: Auxiliary trace types (builder, cross-AIR identity checking). +//! - [`LiftedAirBuilder`]: super-trait for constraint builders +//! - [`MultiAir`]: trusted statement definition for a multi-AIR proof +//! - [`Statement`]: validated per-proof caller inputs over a `MultiAir` +//! - [`ProverStatement`]: validated proving input — a `Statement` plus per-AIR main witness traces +//! - [`debug`]: panic-based AIR structural checks for tests / setup #![no_std] extern crate alloc; mod air; -pub mod auxiliary; mod builder; +pub mod debug; +mod statement; mod util; -pub use air::{AirStructureError, LiftedAir, TracePart}; -pub use auxiliary::{AuxBuilder, ReducedAuxValues, ReductionError, VarLenPublicInputs}; +pub use air::{ConstraintDegrees, LiftedAir, MultiAir, ReductionError}; pub use builder::LiftedAirBuilder; -pub use util::log2_strict_u8; - -mod empty_window; - -pub use empty_window::EmptyWindow; // Re-export upstream p3-air types so downstream crates never need to depend on p3-air // directly. pub use p3_air::{ Air, AirBuilder, AirBuilderWithContext, BaseAir, ExtensionBuilder, FilteredAirBuilder, PeriodicAirBuilder, PermutationAirBuilder, RowWindow, WindowAccess, }; +pub use statement::{InstanceError, ProverStatement, Statement}; +pub use util::{log2_ceil_u8, log2_strict_u8}; /// Symbolic constraint analysis types from upstream p3-air. pub mod symbolic { diff --git a/stark/miden-lifted-air/src/statement.rs b/stark/miden-lifted-air/src/statement.rs new file mode 100644 index 0000000000..b4fe6e331f --- /dev/null +++ b/stark/miden-lifted-air/src/statement.rs @@ -0,0 +1,251 @@ +//! Validated runtime inputs for trusted multi-AIR definitions: [`Statement`] +//! 
holds per-proof caller inputs, and [`ProverStatement`] adds per-AIR main +//! traces. + +use alloc::vec::Vec; +use core::marker::PhantomData; + +use p3_challenger::CanObserve; +use p3_field::{ExtensionField, Field}; +use p3_matrix::{Matrix, dense::RowMajorMatrix}; +use thiserror::Error; + +use crate::{ + BaseAir, LiftedAir, + air::{MultiAir, ReductionError}, +}; + +// ============================================================================ +// Statement +// ============================================================================ + +/// Validated per-proof inputs over a [`MultiAir`]: `air_inputs` and `aux_inputs`. +/// +/// Holding one guarantees the caller-input length checks in [`Statement::new`] passed. +/// The `MultiAir` itself is trusted application code; run +/// [`crate::debug::assert_multi_air_valid`] in tests/setup to check structural +/// invariants such as non-empty `airs()`, shared public-value counts, and +/// periodic-column shape. +pub struct Statement +where + F: Field, + EF: ExtensionField, + MA: MultiAir, +{ + multi_air: MA, + air_inputs: Vec, + aux_inputs: Vec, + _ef: PhantomData, +} + +impl Statement +where + F: Field, + EF: ExtensionField, + MA: MultiAir, +{ + /// Construct a [`Statement`], validating caller-supplied input lengths + /// against `multi_air`. + /// + /// This assumes `multi_air` satisfies the structural contract checked by + /// [`crate::debug::assert_multi_air_valid`]; malformed AIR definitions (for + /// example an empty [`MultiAir::airs`] collection) may panic in trusted + /// helper methods such as [`MultiAir::num_air_inputs`]. 
+ pub fn new( + multi_air: MA, + air_inputs: Vec, + aux_inputs: Vec, + ) -> Result { + let expected = multi_air.num_air_inputs(); + if expected != air_inputs.len() { + return Err(InstanceError::PublicValuesLengthMismatch { + expected, + actual: air_inputs.len(), + }); + } + let max = multi_air.max_aux_inputs(); + if aux_inputs.len() > max { + return Err(InstanceError::AuxInputsTooLong { actual: aux_inputs.len(), max }); + } + Ok(Self { + multi_air, + air_inputs, + aux_inputs, + _ef: PhantomData, + }) + } + + pub fn multi_air(&self) -> &MA { + &self.multi_air + } + + pub fn airs(&self) -> &[MA::Air] { + self.multi_air.airs() + } + + pub fn air_inputs(&self) -> &[F] { + &self.air_inputs + } + + pub fn aux_inputs(&self) -> &[F] { + &self.aux_inputs + } + + /// Evaluate cross-AIR assertions via [`MultiAir::eval_external`]. + pub fn eval_external( + &self, + challenges: &[EF], + aux_values: &[&[EF]], + log_trace_heights: &[u8], + ) -> Result, ReductionError> { + self.multi_air.eval_external( + challenges, + &self.air_inputs, + &self.aux_inputs, + aux_values, + log_trace_heights, + ) + } + + /// Absorb statement-owned data into the Fiat-Shamir challenger via [`MultiAir::observe`]. + /// + /// The protocol separately observes the instance count and + /// `log_trace_heights` in instance order after this call. Heights are passed + /// here only so custom `MultiAir` bindings can include height-dependent + /// statement data. + pub fn observe>(&self, challenger: &mut C, log_trace_heights: &[u8]) { + self.multi_air + .observe(challenger, &self.air_inputs, &self.aux_inputs, log_trace_heights); + } +} + +// ============================================================================ +// ProverStatement +// ============================================================================ + +/// A [`Statement`] plus per-AIR main traces. +/// +/// Holding one guarantees the trace-shape checks in [`ProverStatement::new`] passed. 
+pub struct ProverStatement +where + F: Field, + EF: ExtensionField, + MA: MultiAir, +{ + statement: Statement, + traces: Vec>, +} + +impl ProverStatement +where + F: Field, + EF: ExtensionField, + MA: MultiAir, +{ + /// Construct a [`ProverStatement`], validating each trace's count, height, and + /// width against its AIR. + /// + /// This assumes the underlying [`MultiAir`] satisfies the structural contract + /// checked by [`crate::debug::assert_multi_air_valid`]. + pub fn new( + statement: Statement, + traces: Vec>, + ) -> Result { + // TraceOrder stores instance indices as u8, so it can represent 256 + // instances: indices 0 through u8::MAX. + let max_instances = u8::MAX as usize + 1; + if traces.len() > max_instances { + return Err(InstanceError::TooManyInstances { count: traces.len() }); + } + let airs = statement.airs(); + if airs.len() != traces.len() { + return Err(InstanceError::TraceCountMismatch { + airs: airs.len(), + traces: traces.len(), + }); + } + for (idx, trace) in traces.iter().enumerate() { + let h = trace.height(); + if h < 2 { + return Err(InstanceError::TraceHeightTooSmall { air: idx, height: h }); + } + if !h.is_power_of_two() { + return Err(InstanceError::TraceHeightNotPowerOfTwo { air: idx, height: h }); + } + } + for (idx, (air, trace)) in airs.iter().zip(traces.iter()).enumerate() { + let trace_height = trace.height(); + let max_period = air.max_periodic_length(); + if trace_height < max_period { + return Err(InstanceError::TraceHeightBelowPeriod { + air: idx, + trace_height, + max_period, + }); + } + if trace.width() != air.width() { + return Err(InstanceError::TraceWidthMismatch { + air: idx, + expected: air.width(), + actual: trace.width(), + }); + } + } + Ok(Self { statement, traces }) + } + + pub fn statement(&self) -> &Statement { + &self.statement + } + + pub fn traces(&self) -> &[RowMajorMatrix] { + &self.traces + } +} + +// ============================================================================ +// InstanceError +// 
============================================================================ + +/// Errors from constructing a [`Statement`] / [`ProverStatement`] — the runtime +/// trust boundary on caller-supplied inputs. +#[derive(Debug, Error)] +pub enum InstanceError { + #[error("num_air_inputs() = {expected}, but air_inputs().len() = {actual}")] + PublicValuesLengthMismatch { expected: usize, actual: usize }, + + #[error("aux_inputs().len() = {actual} exceeds max_aux_inputs() = {max}")] + AuxInputsTooLong { actual: usize, max: usize }, + + #[error("airs().len() = {airs} does not match traces().len() = {traces}")] + TraceCountMismatch { airs: usize, traces: usize }, + + #[error( + "too many instances ({count}); the per-proof limit is {max} = u8::MAX + 1", + max = u8::MAX as usize + 1 + )] + TooManyInstances { count: usize }, + + #[error("AIR {air}: trace width = {actual}, but air.width() = {expected}")] + TraceWidthMismatch { + air: usize, + expected: usize, + actual: usize, + }, + + #[error("AIR {air}: trace height = {height} is too small; expected at least 2 rows")] + TraceHeightTooSmall { air: usize, height: usize }, + + #[error("AIR {air}: trace height = {height} is not a power of two")] + TraceHeightNotPowerOfTwo { air: usize, height: usize }, + + #[error( + "AIR {air}: trace height = {trace_height} is less than max periodic column \ + length {max_period}" + )] + TraceHeightBelowPeriod { + air: usize, + trace_height: usize, + max_period: usize, + }, +} diff --git a/stark/miden-lifted-air/src/util.rs b/stark/miden-lifted-air/src/util.rs index 8202692694..41a07f0e19 100644 --- a/stark/miden-lifted-air/src/util.rs +++ b/stark/miden-lifted-air/src/util.rs @@ -1,6 +1,6 @@ //! Small utility helpers shared across lifted-STARK crates. -use p3_util::log2_strict_usize; +use p3_util::{log2_ceil_usize, log2_strict_usize}; /// Strict log₂ returning `u8`. 
/// @@ -10,3 +10,12 @@ use p3_util::log2_strict_usize; pub fn log2_strict_u8(n: usize) -> u8 { log2_strict_usize(n) as u8 } + +/// Ceiling log₂ returning `u8`. +/// +/// Returns `0` for `n = 0` or `n = 1`; otherwise the smallest `k` such that +/// `2^k >= n`. Never panics: the result is at most `usize::BITS`, so the cast is lossless. +#[inline] +pub fn log2_ceil_u8(n: usize) -> u8 { + log2_ceil_usize(n) as u8 +} diff --git a/stark/miden-lifted-stark/Cargo.toml b/stark/miden-lifted-stark/Cargo.toml index 6d2d8a6485..4f13ed9a1f 100644 --- a/stark/miden-lifted-stark/Cargo.toml +++ b/stark/miden-lifted-stark/Cargo.toml @@ -97,5 +97,10 @@ harness = false name = "plonky3" required-features = ["testing"] +[[bench]] +harness = false +name = "per_air_degree_opt" +required-features = ["testing"] + [lints] workspace = true diff --git a/stark/miden-lifted-stark/README.md b/stark/miden-lifted-stark/README.md index 4035b48e8b..3104b54a0e 100644 --- a/stark/miden-lifted-stark/README.md +++ b/stark/miden-lifted-stark/README.md @@ -21,9 +21,9 @@ miden-lifted-stark ← this crate └── miden-lifted-air ← AIR traits (aux columns, periodic columns) ``` -The system supports **multiple traces of different power-of-two heights**. -Shorter traces are virtually lifted to the maximum height via LMCS upsampling, -so the PCS and verifier operate on a single uniform view. +The system supports **multiple traces of different power-of-two heights of at +least 2 rows**. Shorter traces are virtually lifted to the maximum height via +LMCS upsampling, so the PCS and verifier operate on a single uniform view. ## Notation @@ -32,7 +32,7 @@ so the PCS and verifier operate on a single uniform view. - `r_j = N / n_j`: lift ratio (a power of two). - `H`: two-adic subgroup of size `N` with generator `omega_H`. - `g`: multiplicative coset shift (`F::GENERATOR` by convention). -- `D`: constraint degree blowup (here fixed at `D = 4`). 
+- `D`: quotient-domain blowup, derived per AIR from its constraint degree; the batch value is the max over AIRs. - `gJ`: quotient-domain coset (size `N * D`). - `gK`: PCS/LDE coset (size `N * B`, where `B` is the FRI blowup). - `z`: global out-of-domain point sampled once. @@ -48,43 +48,52 @@ view. Informally, an AIR is "liftable" if transition constraints do not rely on the wrap-around row (last -> first) unless that behavior is explicitly constrained. -See `docs/lifting.md` for a deeper discussion and sufficient conditions. +See the "Mathematical background" in `src/prover/README.md` and +`src/verifier/README.md` for a deeper discussion and sufficient conditions. ## Protocol Summary -### Prover (`prove_multi`) +### Prover (`prove`) -1. **Commit main traces** — LDE each trace on its lifted coset, bit-reverse +1. **Validate and bind instance shape** — Validate runtime inputs, call + `Statement::observe`, then observe the instance count and log trace heights + in instance order. +2. **Commit main traces** — LDE each trace on its lifted coset, bit-reverse rows, build LMCS tree. Send root. -2. **Sample randomness** — Squeeze auxiliary randomness from the Fiat-Shamir +3. **Sample randomness** — Squeeze auxiliary randomness from the Fiat-Shamir channel. Build and commit auxiliary traces. -3. **Sample challenges** — `alpha` (constraint folding) and `beta` +4. **Sample challenges** — `alpha` (constraint folding) and `beta` (cross-trace accumulation). -4. **Evaluate constraints** — For each trace in ascending height order, - evaluate AIR constraints on the quotient domain using SIMD-packed - arithmetic. Produces a numerator N_j per trace (no vanishing division). -5. **Accumulate numerators** — Fold across traces: - `acc = cyclic_extend(acc) * beta + N_j`. -6. **Divide by vanishing polynomial** — One pass on the full quotient domain, - exploiting Z_H periodicity for batch inverse. +5. 
**Evaluate per-AIR quotients** — For each trace in ascending height order, + evaluate AIR constraints on that AIR's native quotient domain using + SIMD-packed arithmetic and divide by its trace vanishing polynomial. + Produces Q_j per trace. +6. **Accumulate quotients** — Fold across traces: + `acc = cyclic_extend(acc) * beta + Q_j`. 7. **Commit quotient** — Decompose Q into D chunks via fused iDFT + coefficient scaling + flatten + DFT pipeline. Commit via LMCS. 8. **Sample OOD point z** — Rejection-sampled to lie outside H and the LDE coset. 9. **Open via PCS** — Delegate to the internal `pcs` modules. -### Verifier (`verify_multi`) +### Verifier (`verify`) -1. **Receive commitments** — Main, auxiliary, and quotient roots from transcript. -2. **Re-derive challenges** — Same `alpha`, `beta`, `z` via Fiat-Shamir. -3. **Verify PCS openings** — At `[z, z_next]` where `z_next = z * omega_H`. -4. **Reconstruct Q(z)** — Barycentric interpolation over the D quotient +1. **Validate and bind instance shape** — Validate proof heights against the + AIRs, call `Statement::observe`, then observe the instance count and log + trace heights in instance order. +2. **Receive commitments** — Main, auxiliary, and quotient roots from transcript. +3. **Re-derive challenges** — Same `alpha`, `beta`, `z` via Fiat-Shamir. +4. **Verify PCS openings** — At `[z, z_next]` where `z_next = z * omega_H`. +5. **Reconstruct Q(z)** — Barycentric interpolation over the D quotient chunks. -5. **Evaluate constraints at OOD** — For each AIR at the lifted OOD point +6. **Evaluate constraints at OOD** — For each AIR at the lifted OOD point `y_j = z^{r_j}`: compute selectors, evaluate periodic polynomials, fold constraints with alpha, accumulate with beta. -6. **Check identity** — `accumulated == Q(z) * Z_H(z)`. -7. **Ensure transcript is fully consumed** — Canonicality enforcement. +7. 
**Evaluate external assertions** — Call `Statement::eval_external` + once over the global view (challenges, aux values, log heights); each + returned EF value must equal zero. +8. **Check identity** — `accumulated == Q(z) * Z_H(z)`. +9. **Ensure transcript is fully consumed** — Canonicality enforcement. ## Math Sketch @@ -109,31 +118,20 @@ constraint count ahead of time. ### Cross-Trace Accumulation -Numerators from traces of increasing height are combined: +Per-AIR quotients from traces of increasing height are combined: ``` -acc = cyclic_extend(acc) * beta + N_j +acc = cyclic_extend(acc) * beta + Q_j ``` -where `cyclic_extend` repeats the accumulator via modular indexing -(`i & (len - 1)`) to match the next trace's quotient domain size. -This works because: +where `Q_j` is the AIR's folded constraint numerator divided by its trace +vanishing polynomial on the native quotient coset `gJ_j`. If an AIR uses a +smaller quotient degree than the batch maximum, `Q_j` is first low-degree +extended along the quotient-degree axis. `cyclic_extend` then repeats the +accumulator via modular indexing (`i & (len - 1)`) to match the next trace's +quotient domain size. -``` -Z_H(x) = Z_{H^r}(x) * Phi_r(x) -``` - -so cyclic extension of a polynomial divisible by `Z_{H^r}` preserves -divisibility by `Z_H`. - -### Vanishing Division - -After accumulation, the combined numerator is divided by `Z_H(x) = x^N - 1` -once on the full quotient domain. - -On the quotient coset `gJ` (where `|J| = N * D`), the values `x^N` range over a -size-`D` subgroup, so `Z_H(x)` takes only `D` distinct values. The prover can -batch-invert those `D` values once and index them by `i mod D`. +Vanishing division is therefore per-AIR, not a final global division pass. ### Quotient Decomposition @@ -169,8 +167,9 @@ at `y_j`, and the opened trace values already correspond to `p_j(y_j)`. 
- **Fused quotient pipeline** — iDFT, coefficient scaling by `(omega^t)^{-k}`, flatten to base field, zero-pad, forward DFT — all in one pass, no redundant coset operations. -- **Periodic vanishing exploit** — On the quotient coset `gJ`, `Z_H(x)` takes - only `D` distinct values; batch inverse computes those once. +- **Periodic vanishing exploit** — On each AIR's quotient coset `gJ_j`, + `Z_{H_j}(x)` takes only `D_j` distinct values; batch inverse computes those + once. - **Zero-copy quotient domain** — `split_rows().bit_reverse_rows()` gives a natural-order view of committed LDE data without copying. - **Efficient periodic columns** — Only `max_period * blowup` LDE values @@ -178,53 +177,58 @@ at `y_j`, and the opened trace values already correspond to `p_j(y_j)`. - **Cyclic extension** — Cross-trace accumulation uses bitwise AND for modular indexing (power-of-two sizes). - **Parallel execution** — Rayon parallelism throughout constraint evaluation - and vanishing division (gated by `parallel` feature). + and per-AIR quotient division (gated by `parallel` feature). 
## Entry Points | Item | Purpose | |------|---------| -| `prover::prove_single` | Prove a single-AIR STARK | -| `prover::prove_multi` | Prove a multi-trace STARK | -| `AirWitness` | Prover witness (trace + public values) | -| `verifier::verify_single` | Verify a single-AIR proof | -| `verifier::verify_multi` | Verify a multi-trace proof | -| `AirInstance` | Verifier instance (public values + variable-length inputs) | -| `Transcript` | Structured transcript view (alias for `proof::StarkTranscript`) | +| `prover::prove` | Prove one or more AIR instances | +| `ProverStatement` | Validated proving input: a `Statement` plus per-AIR main witness traces in instance order | +| `Statement` | A `MultiAir` plus validated per-proof caller inputs (`air_inputs`, optional `aux_inputs`) | +| `MultiAir` | Trusted statement definition: AIR instances, cross-AIR assertions, and statement observation hooks | +| `verifier::verify` | Verify a multi-trace proof | +| `MultiAir::eval_external` | Cross-AIR assertion hook: returns assertion expression values to be checked for zero (default: no assertions) | +| `Statement::aux_inputs` | Auxiliary public inputs consumed only by `eval_external` (empty unless provided) | +| `StarkProof` | Structured parse-only view of the proof; `log_trace_heights()` exposes instance-order heights and `air_order()` exposes the derived proof-order mapping | | `StarkConfig` | PCS params + LMCS + DFT configuration | -| `coset::LiftedCoset` | Domain operations: selectors, vanishing, coset shifts | +| `pcs` | Structured PCS sub-proof types (DEEP / FRI) for inspection and error matching | ## Modules | Path | Purpose | |------|---------| | `src/config.rs` | `StarkConfig` — wraps `PcsParams`, LMCS, and DFT | -| `src/coset.rs` | `LiftedCoset` — domain queries, selector computation, vanishing | +| `src/domain.rs` | `TwoAdicSubgroup`, `TwoAdicCoset`, `LiftedDomain` — the domain hierarchy; `log_quotient_degree`, `DomainError` (incl. 
the `log_quotient_degree ≤ log_blowup` compat bound) | | `src/selectors.rs` | `Selectors` — generic container for row selectors | -| `src/prover/mod.rs` | `prove_single`, `prove_multi` — orchestration and protocol flow | +| `src/prover/mod.rs` | `prove` — orchestration and protocol flow | | `src/prover/commit.rs` | `Committed` — LDE, bit-reverse, LMCS tree construction | | `src/prover/constraints/` | Constraint evaluation (SIMD) and layout discovery | | `src/prover/periodic.rs` | `PeriodicLde` — precomputed periodic column LDEs | -| `src/prover/quotient.rs` | Quotient construction, cyclic extension, vanishing division | -| `src/verifier/mod.rs` | `verify_single`, `verify_multi` — orchestration and identity check | +| `src/prover/quotient.rs` | Quotient upsampling, cyclic extension, and commitment | +| `src/verifier/mod.rs` | `verify` — orchestration and identity check | | `src/verifier/constraints.rs` | `ConstraintFolder` — OOD constraint evaluation, quotient reconstruction | | `src/verifier/periodic.rs` | `PeriodicPolys` — polynomial coefficients for OOD evaluation | -| `src/proof.rs` | `StarkProof`, `StarkTranscript` — proof artifact and structured transcript view | -| `src/instance.rs` | `AirInstance`, `AirWitness`, `InstanceShapes` — protocol-level instance types | +| `src/proof.rs` | `StarkProofData`, `StarkProof` — wire artifact and structured parse-only view | +| `src/order.rs` | public `ShapeError` plus the crate-internal instance↔proof ordering helper; `TraceOrder` construction validates the proof's log heights against the AIRs | +| `src/debug.rs` | `check_constraints` (row-by-row), structural assertion (`assert_prover_setup`) over `miden_lifted_air::debug::assert_multi_air_valid` | ## Conventions & Assumptions -- **AIR ordering** — The proof defines an ordering of AIR instances - (queryable via `InstanceShapes::air_order`). The caller must bind AIR - configurations and `air_order` into the Fiat-Shamir challenger. See the - prover module-level docs. 
-- **Power-of-two heights** — All trace heights are powers of two. +- **AIR ordering** — The proof orders AIR instances deterministically by trace + height (stable sort on `(log_trace_height, instance_index)`), materialised + internally from the heights stored on `StarkProof`; the ordering type is + crate-private. The caller must bind the AIR list into the Fiat-Shamir + challenger. See the prover module-level docs. +- **Power-of-two heights** — All trace heights are powers of two and at least 2 rows. - **Bit-reversed storage** — All evaluation matrices are in bit-reversed order. -- **Constraint degree** — Fixed at `D = 4` (`LOG_CONSTRAINT_DEGREE = 2`). - Both prover and verifier must agree on this constant. -- **Transcript ordering** — The Fiat-Shamir transcript follows a strict - observe/squeeze protocol. Prover and verifier must process commitments and - challenges in identical order. This is security-critical. +- **Quotient degree** — Derived per AIR from symbolic constraint-degree analysis + (`log_quotient_degree`); the proof uses the max over AIRs. Degree-2 AIRs are + valid and use the protocol's minimum quotient chunk count. Each AIR must + satisfy `log_quotient_degree(air) ≤ log_blowup`. +- **Transcript ordering** — `Statement::observe` absorbs statement-owned inputs; + prover and verifier then observe the instance count and log trace heights in + instance order. All later observe/squeeze steps must match exactly. - **Extension field discipline** — Main trace and preprocessed data stay in the base field. Only auxiliary columns, challenges, alpha powers, and the accumulator use the extension field. @@ -233,17 +237,18 @@ at `y_j`, and the opened trace values already correspond to `p_j(y_j)`. 
## Tests -The end-to-end test suite lives in `tests/`: +The end-to-end test suite lives in `src/testing/`, behind the `testing` feature: -- **`tiny_air.rs`** — `TinyAir` exercising single-trace, multi-trace +- **`test_tiny_air.rs`** — `TinyAir` exercising single-trace, multi-trace (same and different heights), periodic columns, and malformed transcript rejection. -- **`aux_shape.rs`** — Validates that mismatched auxiliary trace dimensions - are caught. +- **`test_external_assertions.rs`** — `MultiAir::eval_external` and `aux_inputs`. +- **`test_multi_aux_alignment.rs`** — aux-trace alignment across multiple AIRs. +- **`test_per_air_degree.rs`** — per-AIR quotient degrees. Run with: ```bash -cargo test -p miden-lifted-stark +cargo test -p miden-lifted-stark --features testing ``` ## Security diff --git a/stark/miden-lifted-stark/benches/deep_quotient.rs b/stark/miden-lifted-stark/benches/deep_quotient.rs index 3f248c86e5..9abb947de3 100644 --- a/stark/miden-lifted-stark/benches/deep_quotient.rs +++ b/stark/miden-lifted-stark/benches/deep_quotient.rs @@ -17,13 +17,11 @@ use std::hint::black_box; use criterion::{BenchmarkId, Criterion, Throughput, criterion_group, criterion_main}; -use miden_lifted_stark::{ - Lmcs, LmcsTree, - testing::{ - LOG_HEIGHTS, PARALLEL_STR, PointQuotients, RELATIVE_SPECS, bit_reversed_coset_points, - configs::goldilocks_poseidon2::{Felt, QuadFelt, test_lmcs}, - generate_matrices_from_specs, total_elements, - }, +use miden_lifted_stark::testing::{ + Coset, LOG_HEIGHTS, Lmcs, LmcsTree, PARALLEL_STR, PointQuotients, RELATIVE_SPECS, + canonical_domain, + configs::goldilocks_poseidon2::{Felt, QuadFelt, test_lmcs}, + generate_matrices_from_specs, total_elements, }; use p3_field::FieldArray; use p3_matrix::dense::RowMajorMatrix; @@ -53,7 +51,8 @@ fn bench_deep_quotient(c: &mut Criterion) { matrix_groups.iter().map(|matrices| lmcs.build_tree(matrices.clone())).collect(); // Precompute coset points (LDE domain matches max matrix height) - let 
coset_points = bit_reversed_coset_points::(log_lde_height); + let domain = canonical_domain::(log_lde_height, 0); + let coset_points = domain.lde_coset().bit_reversed_points(); // Get matrix references from trees (stored as BitReversedMatrixView after build_tree) let matrices_refs: Vec> = diff --git a/stark/miden-lifted-stark/benches/merkle_commit.rs b/stark/miden-lifted-stark/benches/merkle_commit.rs index 00b2c91ea7..fae9ed10c6 100644 --- a/stark/miden-lifted-stark/benches/merkle_commit.rs +++ b/stark/miden-lifted-stark/benches/merkle_commit.rs @@ -17,13 +17,10 @@ use std::hint::black_box; use criterion::{BenchmarkId, Criterion, Throughput, criterion_group, criterion_main}; -use miden_lifted_stark::{ - Lmcs, LmcsTree, - testing::{ - LOG_HEIGHTS, PARALLEL_STR, RELATIVE_SPECS, - configs::goldilocks_poseidon2::{Felt, QuadFelt, test_lmcs}, - generate_matrices_from_specs, total_elements, - }, +use miden_lifted_stark::testing::{ + LOG_HEIGHTS, Lmcs, LmcsTree, PARALLEL_STR, RELATIVE_SPECS, + configs::goldilocks_poseidon2::{Felt, QuadFelt, test_lmcs}, + generate_matrices_from_specs, total_elements, }; use p3_matrix::{bitrev::BitReversalPerm, dense::RowMajorMatrix, extension::FlatMatrixView}; use rand::{SeedableRng, rngs::SmallRng}; diff --git a/stark/miden-lifted-stark/benches/pcs.rs b/stark/miden-lifted-stark/benches/pcs.rs index 6dfe59391a..e01899efec 100644 --- a/stark/miden-lifted-stark/benches/pcs.rs +++ b/stark/miden-lifted-stark/benches/pcs.rs @@ -7,26 +7,23 @@ use std::hint::black_box; use criterion::{Criterion, Throughput, criterion_group, criterion_main}; -use miden_lifted_stark::{ - Lmcs, LmcsTree, log2_strict_u8, - testing::{ - BENCH_PCS_PARAMS, LOG_HEIGHTS, PARALLEL_STR, RELATIVE_SPECS, - configs::goldilocks_poseidon2::{Felt, QuadFelt, test_challenger, test_lmcs}, - generate_matrices_from_specs, open_with_channel, total_elements, - }, +use miden_lifted_stark::testing::{ + BENCH_PCS_PARAMS, LOG_HEIGHTS, Lmcs, LmcsTree, PARALLEL_STR, RELATIVE_SPECS, 
canonical_domain, + configs::goldilocks_poseidon2::{Felt, QuadFelt, test_challenger, test_lmcs}, + generate_matrices_from_specs, open_with_channel, total_elements, }; use miden_stark_transcript::ProverTranscript; use p3_challenger::{CanObserve, FieldChallenger}; use p3_dft::{Radix2DitParallel, TwoAdicSubgroupDft}; -use p3_field::Field; use p3_matrix::{Matrix, dense::RowMajorMatrix}; fn bench_pcs(c: &mut Criterion) { let dft = Radix2DitParallel::::default(); - let shift = Felt::GENERATOR; let lmcs = test_lmcs(); for &log_lde_height in LOG_HEIGHTS { + let domain = canonical_domain::(log_lde_height, 0); + let shift = domain.lde_shift(); let max_lde_size = 1usize << log_lde_height; let group_name = format!("PCS_Open/{max_lde_size}/goldilocks/poseidon2/{PARALLEL_STR}"); let mut group = c.benchmark_group(&group_name); @@ -48,7 +45,6 @@ fn bench_pcs(c: &mut Criterion) { let tree = lmcs.build_aligned_tree(all_lde_matrices); let commitment = tree.root(); - let log_lde_height = log2_strict_u8(tree.height()); let base_challenger = test_challenger(); @@ -65,7 +61,7 @@ fn bench_pcs(c: &mut Criterion) { open_with_channel::( &BENCH_PCS_PARAMS, &lmcs, - log_lde_height, + &domain, [z1, z2], trace_trees, &mut channel, diff --git a/stark/miden-lifted-stark/benches/pcs_trace.rs b/stark/miden-lifted-stark/benches/pcs_trace.rs index b7dcad3b55..d500befcd6 100644 --- a/stark/miden-lifted-stark/benches/pcs_trace.rs +++ b/stark/miden-lifted-stark/benches/pcs_trace.rs @@ -10,18 +10,14 @@ use std::time::Instant; -use miden_lifted_stark::{ - Lmcs, LmcsTree, PcsParams, log2_strict_u8, - testing::{ - LOG_HEIGHTS, RELATIVE_SPECS, - configs::goldilocks_poseidon2::{Felt, QuadFelt, test_challenger, test_lmcs}, - generate_matrices_from_specs, open_with_channel, - }, +use miden_lifted_stark::testing::{ + LOG_HEIGHTS, Lmcs, LmcsTree, PcsParams, RELATIVE_SPECS, canonical_domain, + configs::goldilocks_poseidon2::{Felt, QuadFelt, test_challenger, test_lmcs}, + generate_matrices_from_specs, 
open_with_channel, }; use miden_stark_transcript::ProverTranscript; use p3_challenger::{CanObserve, FieldChallenger}; use p3_dft::{Radix2DitParallel, TwoAdicSubgroupDft}; -use p3_field::Field; use p3_matrix::{Matrix, bitrev::BitReversibleMatrix, dense::RowMajorMatrix}; use tracing_subscriber::EnvFilter; @@ -36,7 +32,6 @@ fn main() { .init(); let dft = Radix2DitParallel::::default(); - let shift = Felt::GENERATOR; let params = PcsParams::new( 2, // log_blowup @@ -55,6 +50,10 @@ fn main() { eprintln!("=== Goldilocks lifted/arity4 log_height={log_lde_height} (n={size}) ==="); eprintln!("{}\n", "=".repeat(60)); + // LDE coset for this batch — sole source of `F::GENERATOR`. + let domain = canonical_domain::(log_lde_height, 0); + let shift = domain.lde_shift(); + let matrix_groups: Vec>> = generate_matrices_from_specs(RELATIVE_SPECS, log_lde_height); @@ -76,7 +75,6 @@ fn main() { let tree = lmcs.build_aligned_tree(all_lde_matrices); let commitment = tree.root(); - let log_lde_height = log2_strict_u8(tree.height()); let mut challenger = test_challenger(); challenger.observe(commitment); @@ -90,7 +88,7 @@ fn main() { open_with_channel::( &params, &lmcs, - log_lde_height, + &domain, [z1, z2], trace_trees, &mut channel, diff --git a/stark/miden-lifted-stark/benches/per_air_degree_opt.rs b/stark/miden-lifted-stark/benches/per_air_degree_opt.rs new file mode 100644 index 0000000000..42e3526e55 --- /dev/null +++ b/stark/miden-lifted-stark/benches/per_air_degree_opt.rs @@ -0,0 +1,325 @@ +//! Illustrates the per-AIR quotient-degree optimization. +//! +//! This bench compares two synthetic configurations of the **same** code path, +//! toggled via [`OverrideConstraintDegree`]. It is **not** a comparison against +//! the codebase's pre-PR state; it isolates the speedup attributable to evaluating +//! a low-degree AIR on its native quotient domain rather than the global one. +//! +//! Runs `prove` with two AIRs: +//! - Core AIR: width 72, max constraint degree 9 (so `D = 8`). +//! 
- Chip AIR: width 72, max constraint degree 5 (so `D = 4`). +//! +//! Two variants: +//! - **Baseline**: both AIRs are forced to `constraint_degree = 9` (so `D = 8`). Chip's constraints +//! are evaluated on `chip_height * 8` points (the full target domain) rather than its natural +//! `chip_height * 4`. +//! - **Optimized**: Core forced to `constraint_degree = 9` (`D = 8`), Chip to `constraint_degree = +//! 5` (`D = 4`). Chip evaluates on its native domain (size `chip_height * 4`), divides by the +//! native vanishing polynomial, and `upsample_evals` lifts the resulting quotient to the target +//! (`chip_height * 8`). +//! +//! Run: +//! ```bash +//! RUSTFLAGS="-Ctarget-cpu=native" cargo bench -p miden-lifted-stark \ +//! --bench per_air_degree_opt --features testing,parallel +//! ``` + +use std::time::Instant; + +use miden_lifted_air::{ + AirBuilder, BaseAir, ConstraintDegrees, LiftedAir, LiftedAirBuilder, WindowAccess, +}; +use miden_lifted_stark::{ + GenericStarkConfig, ProverInstance, + testing::{ + MultiAir, PcsParams, ProverStatement, Statement, + configs::goldilocks_poseidon2::{Dft, Felt, QuadFelt, test_challenger, test_lmcs}, + }, +}; +use p3_field::PrimeCharacteristicRing; +use p3_matrix::{Matrix, dense::RowMajorMatrix}; +use tracing_subscriber::EnvFilter; + +// ----------------------------------------------------------------------------- +// AIR +// ----------------------------------------------------------------------------- + +const WIDTH: usize = 72; + +/// Number of redundant recurrence constraints per column. Each copy is a separate +/// `assert_eq` call (so the symbolic analyzer counts them independently) but all +/// reduce to the same `next[c] == local[c]^power` identity, so any trace satisfying +/// one copy satisfies all of them. This dials up per-point constraint-evaluation +/// work without changing trace semantics or constraint degree. +const CONSTRAINTS_PER_COLUMN: usize = 20; + +/// Per-row transition kind used by the benchmark AIR. 
+#[derive(Clone, Copy, Debug)] +enum BenchAirKind { + Core, + Chip, +} + +impl BenchAirKind { + fn log2_hi(self) -> usize { + match self { + BenchAirKind::Core => 3, + BenchAirKind::Chip => 2, + } + } + + fn recurrence_power(self) -> u64 { + match self { + BenchAirKind::Core => 9, + BenchAirKind::Chip => 5, + } + } +} + +#[derive(Clone, Copy, Debug)] +struct BenchAir { + kind: BenchAirKind, +} + +impl BaseAir for BenchAir { + fn width(&self) -> usize { + WIDTH + } +} + +impl LiftedAir for BenchAir { + fn num_randomness(&self) -> usize { + 1 + } + + fn aux_width(&self) -> usize { + 1 + } + + fn num_aux_values(&self) -> usize { + 0 + } + + fn build_aux_trace( + &self, + main: &RowMajorMatrix, + _air_inputs: &[Felt], + _aux_inputs: &[Felt], + challenges: &[QuadFelt], + ) -> (RowMajorMatrix, Vec) { + // Trivial aux: a single constant-challenge column. + (RowMajorMatrix::new(vec![challenges[0]; main.height()], 1), vec![]) + } + + fn eval>(&self, builder: &mut AB) { + let main = builder.main(); + let (local, next) = (main.current_slice().to_vec(), main.next_slice().to_vec()); + let log2_hi = self.kind.log2_hi(); + + // Main recurrence on every column: `next[c] = local[c]^(2^log2_hi + 1)`. + // Duplicate the constraint `CONSTRAINTS_PER_COLUMN` times to densify eval work + // without changing the trace or the max constraint degree. + for _ in 0..CONSTRAINTS_PER_COLUMN { + for c in 0..WIDTH { + let x: AB::Expr = local[c].into(); + let x_hi: AB::Expr = x.clone().exp_power_of_2(log2_hi); + builder.when_transition().assert_eq(next[c].into(), x_hi * x); + } + } + + // Aux constraint: `aux_local == challenge`, trivial degree-1 identity. 
+ let aux = builder.permutation(); + let aux_local = aux.current_slice().to_vec(); + let challenge = builder.permutation_randomness()[0]; + let aux_expr: AB::ExprEF = aux_local[0].into(); + let challenge_expr: AB::ExprEF = challenge.into(); + builder.assert_eq_ext(aux_expr, challenge_expr); + } +} + +/// Test wrapper that forces [`LiftedAir::constraint_degree`] to a chosen value +/// (quotient chunking — and hence `log_quotient_degree` — is derived from it). +/// +/// Delegates everything else to the inner AIR. Used by this bench to toggle between +/// the baseline (force every AIR to the global-max degree) and the optimized path +/// (each AIR at its natural degree). +/// +/// Overriding *higher* than the AIR actually needs is safe: the prover/verifier just +/// use a larger quotient domain than necessary. Overriding *lower* than needed would +/// produce an invalid proof. +#[derive(Clone, Copy, Debug)] +struct OverrideConstraintDegree { + inner: A, + constraint_degree: usize, +} + +/// Constraint degree that yields `log_quotient_degree == l`: +/// `log2_ceil(((1 << l) + 1) - 1) == log2_ceil(1 << l) == l`. 
+fn constraint_degree_for_log_qd(l: usize) -> usize { + (1 << l) + 1 +} + +impl> BaseAir for OverrideConstraintDegree { + fn width(&self) -> usize { + self.inner.width() + } + fn num_public_values(&self) -> usize { + self.inner.num_public_values() + } +} + +impl> LiftedAir for OverrideConstraintDegree { + fn periodic_columns(&self) -> Vec> { + self.inner.periodic_columns() + } + fn num_randomness(&self) -> usize { + self.inner.num_randomness() + } + fn aux_width(&self) -> usize { + self.inner.aux_width() + } + fn num_aux_values(&self) -> usize { + self.inner.num_aux_values() + } + fn build_aux_trace( + &self, + main: &RowMajorMatrix, + air_inputs: &[Felt], + aux_inputs: &[Felt], + challenges: &[QuadFelt], + ) -> (RowMajorMatrix, Vec) { + self.inner.build_aux_trace(main, air_inputs, aux_inputs, challenges) + } + fn eval>(&self, builder: &mut AB) { + self.inner.eval(builder) + } + fn constraint_degree(&self) -> ConstraintDegrees + where + Self: Sized, + { + ConstraintDegrees { base: self.constraint_degree, ext: 0 } + } +} + +// ----------------------------------------------------------------------------- +// MultiAir (trivial: constant-challenge aux column for each trace) +// ----------------------------------------------------------------------------- + +struct BenchMultiAir { + airs: Vec>, +} + +impl MultiAir for BenchMultiAir { + type Air = OverrideConstraintDegree; + + fn airs(&self) -> &[Self::Air] { + &self.airs + } +} + +// ----------------------------------------------------------------------------- +// Trace generation +// ----------------------------------------------------------------------------- + +/// Generate a `WIDTH x height` trace satisfying `next[c] = local[c]^power`. +/// +/// Column `c` starts at `c + 2` (so rows are non-constant and all-distinct at t=0). +fn generate_trace(power: u64, height: usize) -> RowMajorMatrix { + let mut data = Vec::with_capacity(WIDTH * height); + // Row 0: col c = c + 2 (base field). 
+ for c in 0..WIDTH { + data.push(Felt::from_u64((c + 2) as u64)); + } + // Rows 1..height: col c = prev^power. + for r in 1..height { + let prev_row_start = (r - 1) * WIDTH; + for c in 0..WIDTH { + let prev = data[prev_row_start + c]; + data.push(prev.exp_u64(power)); + } + } + RowMajorMatrix::new(data, WIDTH) +} + +// ----------------------------------------------------------------------------- +// Driver +// ----------------------------------------------------------------------------- + +/// Custom PCS params for this bench: `log_blowup = 3` to permit `log_qd = 3`. +fn bench_pcs_params() -> PcsParams { + PcsParams::new( + 3, // log_blowup (must be >= max log_qd = 3) + 2, // log_folding_arity (arity 4) + 2, // log_final_degree + 0, // folding_pow_bits + 0, // deep_pow_bits + 30, // num_queries + 0, // query_pow_bits + ) + .expect("valid PCS params") +} + +fn run_prove( + label: &str, + core_log_qd: usize, + chip_log_qd: usize, + core_height: usize, + chip_height: usize, +) { + let config = + GenericStarkConfig::new(bench_pcs_params(), test_lmcs(), Dft::default(), test_challenger()); + + let core_air = OverrideConstraintDegree { + inner: BenchAir { kind: BenchAirKind::Core }, + constraint_degree: constraint_degree_for_log_qd(core_log_qd), + }; + let chip_air = OverrideConstraintDegree { + inner: BenchAir { kind: BenchAirKind::Chip }, + constraint_degree: constraint_degree_for_log_qd(chip_log_qd), + }; + + let core_trace = generate_trace(BenchAirKind::Core.recurrence_power(), core_height); + let chip_trace = generate_trace(BenchAirKind::Chip.recurrence_power(), chip_height); + + let statement = + Statement::new(BenchMultiAir { airs: vec![core_air, chip_air] }, Vec::new(), Vec::new()) + .unwrap(); + let prover_statement = ProverStatement::new(statement, vec![core_trace, chip_trace]).unwrap(); + + eprintln!("\n{}", "=".repeat(70)); + eprintln!( + "=== {label}: core(h={core_height}, log_qd={core_log_qd}) chip(h={chip_height}, log_qd={chip_log_qd}) ===" + ); + 
eprintln!("{}\n", "=".repeat(70)); + + let start = Instant::now(); + let _output = ProverInstance::new(&config, &prover_statement, None) + .expect("no preprocessed columns") + .prove(test_challenger()) + .expect("prove succeeds"); + let elapsed = start.elapsed(); + + eprintln!(">>> Total prove time: {elapsed:.3?}\n"); +} + +fn main() { + tracing_subscriber::fmt() + .with_env_filter( + EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("debug")), + ) + .with_span_events(tracing_subscriber::fmt::format::FmtSpan::CLOSE) + .init(); + + // Sizes chosen so eval makes a substantive fraction of prove time while keeping + // the bench under ~1 minute. core @ 2^14, chip @ 2^16. + let core_height = 1 << 14; + let chip_height = 1 << 16; + + // Baseline: force both AIRs to report log_qd = 3 (global max). Chip loses the + // optimization and evaluates on chip_height*8 points directly, no upsample. + run_prove("baseline (no upsample)", 3, 3, core_height, chip_height); + + // Optimized: Core at natural log_qd=3, Chip at natural log_qd=2. Chip evaluates + // on chip_height*4 points natively, then upsamples to chip_height*8. 
+ run_prove("optimized (upsample fires)", 3, 2, core_height, chip_height); +} diff --git a/stark/miden-lifted-stark/benches/plonky3.rs b/stark/miden-lifted-stark/benches/plonky3.rs index 058802104f..5a4e5cfb36 100644 --- a/stark/miden-lifted-stark/benches/plonky3.rs +++ b/stark/miden-lifted-stark/benches/plonky3.rs @@ -15,24 +15,21 @@ use std::hint::black_box; use criterion::{BenchmarkId, Criterion, Throughput, criterion_group, criterion_main}; -use miden_lifted_stark::{ - LiftedCoset, Lmcs, LmcsTree, log2_strict_u8, - testing::{ - BENCH_PCS_PARAMS, LOG_HEIGHTS, PARALLEL_STR, QC_CONSTRAINT_DEGREE, QC_PCS_PARAMS, - RELATIVE_SPECS, commit_quotient, - configs::{ - goldilocks_blake3_192 as gl_blake3_192, goldilocks_keccak as gl_keccak, - goldilocks_poseidon2 as gl, - }, - generate_matrices_from_specs, open_with_channel, total_elements, +use miden_lifted_stark::testing::{ + BENCH_PCS_PARAMS, Coset, LOG_HEIGHTS, Lmcs, LmcsTree, PARALLEL_STR, QC_CONSTRAINT_DEGREE, + QC_PCS_PARAMS, RELATIVE_SPECS, canonical_domain, commit_quotient, + configs::{ + goldilocks_blake3_192 as gl_blake3_192, goldilocks_keccak as gl_keccak, + goldilocks_poseidon2 as gl, }, + generate_matrices_from_specs, log2_strict_u8, open_with_channel, total_elements, }; use miden_stark_transcript::ProverTranscript; use p3_blake3::Blake3; use p3_challenger::{CanObserve, FieldChallenger}; use p3_commit::{ExtensionMmcs, Mmcs, Pcs}; use p3_dft::{Radix2DitParallel, TwoAdicSubgroupDft}; -use p3_field::{Field, coset::TwoAdicMultiplicativeCoset}; +use p3_field::coset::TwoAdicMultiplicativeCoset; use p3_fri::{FriParameters, TwoAdicFriPcs}; use p3_keccak::KeccakF; use p3_matrix::{Matrix, dense::RowMajorMatrix}; @@ -172,9 +169,10 @@ fn bench_lmcs_vs_mmcs(c: &mut Criterion) { fn bench_pcs_open(c: &mut Criterion) { let dft = Radix2DitParallel::::default(); - let shift = gl::Felt::GENERATOR; for &log_lde_height in LOG_HEIGHTS { + let domain = canonical_domain::(log_lde_height, 0); + let shift = domain.lde_shift(); let 
max_lde_size = 1usize << log_lde_height; let group_name = format!("PCS_Open/{max_lde_size}/goldilocks/poseidon2/{PARALLEL_STR}"); let mut group = c.benchmark_group(&group_name); @@ -263,7 +261,6 @@ fn bench_pcs_open(c: &mut Criterion) { let tree = lmcs.build_aligned_tree(all_lde_matrices); let commitment = tree.root(); - let log_lde_height = log2_strict_u8(tree.height()); let base_challenger = gl::test_challenger(); @@ -280,7 +277,7 @@ fn bench_pcs_open(c: &mut Criterion) { open_with_channel::( &BENCH_PCS_PARAMS, &lmcs, - log_lde_height, + &domain, [z1, z2], trace_trees, &mut channel, @@ -325,13 +322,14 @@ fn bench_quotient_commit(c: &mut Criterion) { // --- Lifted --- { let config = lifted_config(); - let coset = LiftedCoset::unlifted(log_n, QC_PCS_PARAMS.log_blowup()); + let domain = canonical_domain::(log_n, QC_PCS_PARAMS.log_blowup()) + .evaluation_domain(log_d); group.bench_function(BenchmarkId::new("lifted", &label), |bench| { bench.iter(|| { let mut q_evals = random_quotient_evals(n, QC_CONSTRAINT_DEGREE, 42); q_evals.reserve(n * b - n * QC_CONSTRAINT_DEGREE); - let committed = commit_quotient(&config, q_evals, &coset); + let committed = commit_quotient(&config, q_evals, &domain); black_box(committed) }); }); @@ -340,8 +338,10 @@ fn bench_quotient_commit(c: &mut Criterion) { // --- Plonky3 PCS --- { let pcs = workspace_pcs(QC_PCS_PARAMS.log_blowup() as usize, 0, 1, 1); + // Quotient evaluation domain (order N·D, sharing the LDE shift). 
+ let q_domain = canonical_domain::(log_n, log_d).evaluation_domain(log_d); let quotient_domain = - TwoAdicMultiplicativeCoset::new(gl::Felt::GENERATOR, (log_n + log_d) as usize) + TwoAdicMultiplicativeCoset::new(q_domain.shift(), q_domain.log_size() as usize) .unwrap(); group.bench_function(BenchmarkId::new("plonky3_pcs", &label), |bench| { diff --git a/stark/miden-lifted-stark/benches/quotient_commit.rs b/stark/miden-lifted-stark/benches/quotient_commit.rs index 8b113a3509..6c28cf02bb 100644 --- a/stark/miden-lifted-stark/benches/quotient_commit.rs +++ b/stark/miden-lifted-stark/benches/quotient_commit.rs @@ -11,9 +11,10 @@ use std::hint::black_box; use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; use miden_lifted_stark::{ - GenericStarkConfig, LiftedCoset, + GenericStarkConfig, testing::{ - QC_CONSTRAINT_DEGREE, QC_PCS_PARAMS, commit_quotient, configs::goldilocks_poseidon2 as gl, + QC_CONSTRAINT_DEGREE, QC_PCS_PARAMS, canonical_domain, commit_quotient, + configs::goldilocks_poseidon2 as gl, log2_strict_u8, }, }; use p3_dft::Radix2DitParallel; @@ -38,13 +39,15 @@ fn bench_quotient_commit(c: &mut Criterion) { let b = 1usize << QC_PCS_PARAMS.log_blowup(); let label = format!("N=2^{log_n}"); - let coset = LiftedCoset::unlifted(log_n, QC_PCS_PARAMS.log_blowup()); + let log_d = log2_strict_u8(QC_CONSTRAINT_DEGREE); + let domain = canonical_domain::(log_n, QC_PCS_PARAMS.log_blowup()) + .evaluation_domain(log_d); group.bench_function(BenchmarkId::new("lifted", &label), |bench| { bench.iter(|| { let mut q_evals = random_quotient_evals(n, QC_CONSTRAINT_DEGREE, 42); q_evals.reserve(n * b - n * QC_CONSTRAINT_DEGREE); - let committed = commit_quotient(&config, q_evals, &coset); + let committed = commit_quotient(&config, q_evals, &domain); black_box(committed) }); }); diff --git a/stark/miden-lifted-stark/src/coset.rs b/stark/miden-lifted-stark/src/coset.rs deleted file mode 100644 index 2d7fb1ddb7..0000000000 --- 
a/stark/miden-lifted-stark/src/coset.rs +++ /dev/null @@ -1,419 +0,0 @@ -//! Lifted coset domain abstraction with selector and vanishing computation. -//! -//! This module provides [`LiftedCoset`](crate::coset::LiftedCoset), the central abstraction for -//! domain operations in lifted STARKs where traces of different heights share a common evaluation -//! domain. - -use alloc::vec::Vec; - -use miden_stark_transcript::Channel; -use p3_field::{ExtensionField, TwoAdicField, batch_multiplicative_inverse}; -use p3_maybe_rayon::prelude::*; - -use crate::{pcs::params::MAX_LOG_DOMAIN_SIZE, selectors::Selectors}; - -// ============================================================================ -// LiftedCoset -// ============================================================================ - -/// Lifted coset for polynomial evaluation. -/// -/// Represents a coset (gK)ʳ where: -/// - K is the evaluation domain of size 2^log_lde_height -/// - r = 2^log_lift_ratio is the lift factor (row repetition) -/// - The shift is gʳ where g = F::GENERATOR -/// -/// Key relationships: -/// - log_blowup = log_lde_height - log_trace_height -/// - log_lift_ratio = log_max_lde_height - log_lde_height -/// - lde_shift = gʳ = F::GENERATOR.exp_power_of_2(log_lift_ratio) -/// -/// # Invariants -/// -/// - `log_lde_height = log_trace_height + log_blowup` -/// - `log_lde_height <= log_max_lde_height` -/// - All heights are powers of two (stored as log values) -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub struct LiftedCoset { - /// Log₂ of the original trace height. - pub log_trace_height: u8, - /// Log₂ of this matrix's LDE height. - pub log_lde_height: u8, - /// Log₂ of the maximum LDE height in the commitment. - pub log_max_lde_height: u8, -} - -impl LiftedCoset { - /// Create a new `LiftedCoset`. - /// - /// Both `log_lde_height` and `log_max_lde_height` are derived by adding - /// `log_blowup` to the respective trace heights. 
- /// - /// # Panics - /// - /// Panics if `log_trace_height > log_max_trace_height` or if - /// `log_max_trace_height + log_blowup > MAX_LOG_DOMAIN_SIZE`. - #[inline] - pub fn new(log_trace_height: u8, log_blowup: u8, log_max_trace_height: u8) -> Self { - assert!( - log_trace_height <= log_max_trace_height, - "trace height cannot exceed max trace height" - ); - let log_lde_height = log_trace_height as u16 + log_blowup as u16; - let log_max_lde_height = log_max_trace_height as u16 + log_blowup as u16; - assert!( - log_max_lde_height <= MAX_LOG_DOMAIN_SIZE as u16, - "LDE height 2^{log_max_lde_height} exceeds maximum 2^{MAX_LOG_DOMAIN_SIZE}", - ); - Self { - log_trace_height, - log_lde_height: log_lde_height as u8, - log_max_lde_height: log_max_lde_height as u8, - } - } - - /// Create a `LiftedCoset` at max height (no lifting). - /// - /// Convenience for the common single-trace case where the LDE height - /// equals the max LDE height. - #[inline] - pub fn unlifted(log_trace_height: u8, log_blowup: u8) -> Self { - let log_lde_height = log_trace_height as u16 + log_blowup as u16; - assert!( - log_lde_height <= MAX_LOG_DOMAIN_SIZE as u16, - "LDE height 2^{log_lde_height} exceeds maximum 2^{MAX_LOG_DOMAIN_SIZE}", - ); - Self { - log_trace_height, - log_lde_height: log_lde_height as u8, - log_max_lde_height: log_lde_height as u8, - } - } - - // ============ Existing methods ============ - - /// Log₂ of the blowup factor for this matrix. - /// - /// Returns `log_lde_height - log_trace_height`. - #[inline] - pub fn log_blowup(&self) -> usize { - (self.log_lde_height - self.log_trace_height) as usize - } - - /// Log₂ of the lift ratio for this matrix. - /// - /// The lift ratio is how many times this matrix's rows are virtually repeated - /// to match the max LDE height: `max_lde_height / lde_height`. - /// - /// Returns `log_max_lde_height - log_lde_height`. 
- #[inline] - pub fn log_lift_ratio(&self) -> usize { - (self.log_max_lde_height - self.log_lde_height) as usize - } - - /// Whether this matrix is lifted (its LDE height is less than the max). - #[inline] - pub fn is_lifted(&self) -> bool { - self.log_lde_height < self.log_max_lde_height - } - - /// Compute the coset shift for this matrix's LDE domain. - /// - /// For a matrix with lift ratio `r = 2^log_lift_ratio`, the coset shift is gʳ - /// where g is the field generator. - /// - /// Why gʳ: lifting embeds a smaller-domain polynomial into the max domain by - /// composition `p_lift(X) = p(Xʳ)`. Evaluating `p_lift` on the max coset `g·K_max` - /// corresponds to evaluating `p` on the nested coset `gʳ·K`, because - /// `(g·ω)ʳ = gʳ·ωʳ` and ωʳ ranges over K when ω ranges over `K_max`. - #[inline] - pub fn lde_shift(&self) -> F { - F::GENERATOR.exp_power_of_2(self.log_lift_ratio()) - } - - /// The trace height (number of constraint rows). - #[inline] - pub fn trace_height(&self) -> usize { - 1 << self.log_trace_height as usize - } - - /// The LDE height for this matrix. - #[inline] - pub fn lde_height(&self) -> usize { - 1 << self.log_lde_height as usize - } - - /// The maximum LDE height across all matrices. - #[inline] - pub fn max_lde_height(&self) -> usize { - 1 << self.log_max_lde_height as usize - } - - /// The blowup factor for this matrix. - #[inline] - pub fn blowup(&self) -> usize { - 1 << self.log_blowup() - } - - // ============ Domain derivation ============ - - /// Derive the quotient domain coset from this LDE coset. - /// - /// For constraint evaluation, we need a coset of size `trace_height * constraint_degree`. - /// This transforms (gK)ʳ into (gJ)ʳ while preserving the lift ratio. - /// - /// # Panics - /// Panics if log_constraint_degree > log_blowup. - /// - /// The quotient domain is a strict subset of the committed LDE domain. 
- /// - /// If the constraint degree is `D`, the resulting quotient polynomial has degree - /// `< N * (D - 1)`, so `N * D` evaluation points suffice for commitment and for the - /// verifier's reconstruction. The PCS uses a larger blowup `B`, so the committed - /// LDE domain `gK` has `N * B` points, but constraint evaluation only needs the - /// sub-coset `gJ` of size `N * D` (with `D <= B`). - pub fn quotient_domain(&self, log_constraint_degree: u8) -> Self { - let log_blowup = self.log_lde_height - self.log_trace_height; - assert!(log_constraint_degree <= log_blowup, "constraint degree cannot exceed blowup"); - let log_max_trace_height = self.log_max_lde_height - log_blowup; - Self { - log_trace_height: self.log_trace_height, - log_lde_height: self.log_trace_height + log_constraint_degree, - log_max_lde_height: log_max_trace_height + log_constraint_degree, - } - } - - // ============ Selector computation ============ - - /// Compute selectors for evaluation over this coset in natural order. - /// - /// Returns is_first_row, is_last_row, is_transition for each point in the coset - /// (gK)ʳ. The trace domain H has size `2^log_trace_height`. - /// - /// Selectors use unnormalized Lagrange basis polynomials. The is_first_row selector - /// is `L₀(x) = Z_H(x) / (x − 1)`, which equals 0 on all of H except the first row. - /// When multiplied by a constraint C(x), it enforces C only at the first row: - /// `L₀(x)·C(x)` vanishes on H iff `C(1) = 0`. Similarly, - /// `is_last_row = Z_H(x) / (x − ω⁻¹)`. - /// - /// The is_transition selector is `(x − ω⁻¹)`, which is nonzero everywhere except the - /// last row, enforcing transition constraints on all consecutive row pairs. - /// These are "unnormalized" because we omit the constant factor 1/N that would make - /// them evaluate to exactly 1 at their target row. 
This is fine because both prover - /// and verifier evaluate the same unnormalized form: multiplying all boundary constraints - /// by a common nonzero constant does not affect whether the quotient is a polynomial. - pub fn selectors(&self) -> Selectors> { - let shift: F = self.lde_shift(); - let coset_size = self.lde_height(); - let log_blowup = self.log_blowup(); - - // Z_H(x) = xⁿ − 1 evaluated at coset points. - // Periodic with 2^log_blowup distinct values; expand to full coset size for zip. - let s_pow_n = shift.exp_power_of_2(self.log_trace_height as usize); - let z_h_periodic: Vec = F::two_adic_generator(log_blowup) - .shifted_powers(s_pow_n) - .take(1 << log_blowup) - .map(|x| x - F::ONE) - .collect(); - let period = z_h_periodic.len(); - - // Coset points in natural order: shift·ω_Jⁱ - let omega_j = F::two_adic_generator(self.log_lde_height as usize); - let xs: Vec = omega_j.shifted_powers(shift).collect_n(coset_size); - - let omega_h_inv = F::two_adic_generator(self.log_trace_height as usize).inverse(); - - // Unnormalized Lagrange selector: selᵢ = Z_H(xᵢ) / (xᵢ − basis_point) - // Uses modular indexing into z_h_periodic to avoid a full-size allocation. - let single_point_selector = |basis_point: F| -> Vec { - let denoms: Vec = xs.par_iter().map(|&x| x - basis_point).collect(); - let invs = batch_multiplicative_inverse(&denoms); - (0..coset_size) - .into_par_iter() - .map(|i| z_h_periodic[i % period] * invs[i]) - .collect() - }; - - Selectors { - is_first_row: single_point_selector(F::ONE), - is_last_row: single_point_selector(omega_h_inv), - is_transition: xs.into_par_iter().map(|x| x - omega_h_inv).collect(), - } - } - - /// Lifted selectors at the OOD point (verifier). - /// - /// For a selector `s(x)` defined over the original trace domain of size `n_j`, - /// lifting evaluates `s(z_lift)` where `z_lift = z^r` and - /// `r = 2^log_lift_ratio = max_n / n_j`. 
This maps the OOD point `z` - /// (sampled in the max-trace domain) into the per-instance trace domain. - /// - /// # Formulas (unnormalized) - /// - `is_first_row = Z_H(z_lift) / (z_lift − 1)` - /// - `is_last_row = Z_H(z_lift) / (z_lift − ω_{n_j}⁻¹)` - /// - `is_transition = z_lift − ω_{n_j}⁻¹` - /// - /// where `Z_H(z_lift) = z_lift^{n_j} − 1 = z^{max_n} − 1`. - pub fn selectors_at(&self, z: EF) -> Selectors - where - F: TwoAdicField, - EF: ExtensionField, - { - let z_lift = z.exp_power_of_2(self.log_lift_ratio()); - let vanishing = self.vanishing_at::(z_lift); - let omega_inv = F::two_adic_generator(self.log_trace_height as usize).inverse(); - - Selectors { - is_first_row: vanishing / (z_lift - F::ONE), - is_last_row: vanishing / (z_lift - omega_inv), - is_transition: z_lift - omega_inv, - } - } - - // ============ Vanishing polynomial ============ - - /// Vanishing polynomial at an out-of-domain point. - /// - /// Returns `Z_H(z) = zⁿ − 1` using `exp_power_of_2` (log-many squarings). - pub fn vanishing_at(&self, z: EF) -> EF - where - F: TwoAdicField, - EF: ExtensionField, - { - z.exp_power_of_2(self.log_trace_height as usize) - EF::ONE - } - - // ============ Domain membership ============ - - /// Check if a point is in the trace domain H. - /// - /// Returns true if `z^N == 1` where N is the trace height. - /// Points in H cause division by zero in vanishing polynomial inversion. - #[inline] - pub fn is_in_trace_domain(&self, z: EF) -> bool - where - F: TwoAdicField, - EF: ExtensionField, - { - z.exp_power_of_2(self.log_trace_height as usize) == EF::ONE - } - - /// Check if a point is in the LDE coset gK. - /// - /// Returns true if `(z/g)^|K| == 1` where g is the generator shift - /// and K is the LDE domain. Points in gK cause division by zero in DEEP quotients. 
- #[inline] - pub fn is_in_lde_coset(&self, z: EF) -> bool - where - F: TwoAdicField, - EF: ExtensionField, - { - let shift: F = self.lde_shift(); - let z_over_shift = z * shift.inverse(); - z_over_shift.exp_power_of_2(self.log_lde_height as usize) == EF::ONE - } - - // ============ OOD point sampling ============ - - /// Sample an OOD evaluation point from the channel that lies outside both the - /// trace-domain subgroup `H` and the LDE evaluation coset `gK`. - /// - /// Repeatedly draws `sample_algebra_element` candidates until one satisfies - /// both exclusion tests. This terminates with overwhelming probability because - /// `|H ∪ gK|` is negligible relative to the extension field size. - pub fn sample_ood_point(&self, channel: &mut impl Channel) -> EF - where - F: TwoAdicField, - EF: ExtensionField, - { - loop { - let candidate: EF = channel.sample_algebra_element(); - if !self.is_in_trace_domain::(candidate) - && !self.is_in_lde_coset::(candidate) - { - break candidate; - } - } - } -} - -// ============================================================================ -// Tests -// ============================================================================ - -#[cfg(test)] -mod tests { - use p3_field::{Field, PrimeCharacteristicRing}; - - use super::*; - use crate::testing::configs::goldilocks_poseidon2::Felt; - - #[test] - fn domain_info_basic() { - // Trace height 2^10, blowup 2^3, max trace 2^12 - let info = LiftedCoset::new(10, 3, 12); - - assert_eq!(info.log_trace_height, 10); - assert_eq!(info.log_lde_height, 13); - assert_eq!(info.log_max_lde_height, 15); - - assert_eq!(info.log_blowup(), 3); - assert_eq!(info.log_lift_ratio(), 2); - assert!(info.is_lifted()); - - assert_eq!(info.trace_height(), 1024); - assert_eq!(info.lde_height(), 8192); - assert_eq!(info.max_lde_height(), 32768); - assert_eq!(info.blowup(), 8); - } - - #[test] - fn domain_info_no_lift() { - // Matrix at max height (no lifting needed) - let info = LiftedCoset::unlifted(10, 3); - - 
assert_eq!(info.log_lift_ratio(), 0); - assert!(!info.is_lifted()); - } - - #[test] - fn domain_info_lde_shift() { - // Trace height 2^10, blowup 2^3, max trace 2^12 - let info = LiftedCoset::new(10, 3, 12); - let shift: Felt = info.lde_shift(); - - // shift = g^(2^2) = g^4 - let expected = Felt::GENERATOR.exp_power_of_2(2); - assert_eq!(shift, expected); - } - - #[test] - fn domain_info_no_lift_shift() { - // When not lifted, shift should be g^1 = g - let info = LiftedCoset::unlifted(10, 3); - let shift: Felt = info.lde_shift(); - - // shift = g^(2^0) = g - assert_eq!(shift, Felt::GENERATOR); - } - - #[test] - fn quotient_domain_preserves_lift_ratio_and_updates_blowup() { - // Trace height 2^10, blowup 2^3 (B=8), max trace 2^12. - let lde = LiftedCoset::new(10, 3, 12); - - // Constraint degree D = 4 (log D = 2), so quotient domain size is N*D. - let q = lde.quotient_domain(2); - - // Trace height is unchanged; the evaluation domain becomes N*D. - assert_eq!(q.log_trace_height, 10); - assert_eq!(q.log_blowup(), 2); - assert_eq!(q.log_lde_height, 12); - - // Max evaluation domain becomes N_max*D. - assert_eq!(q.log_max_lde_height, 14); - - // Lift ratio is preserved. - assert_eq!(q.log_lift_ratio(), lde.log_lift_ratio()); - } -} diff --git a/stark/miden-lifted-stark/src/debug.rs b/stark/miden-lifted-stark/src/debug.rs index 42aa358c49..739c7132d2 100644 --- a/stark/miden-lifted-stark/src/debug.rs +++ b/stark/miden-lifted-stark/src/debug.rs @@ -1,175 +1,145 @@ -//! Debug constraint checker for lifted AIRs. +//! Debug helpers for lifted AIRs. //! -//! Evaluates constraints row-by-row on concrete trace values and panics if any constraint -//! is nonzero. This avoids the full STARK pipeline (DFT, commitment, FRI) and provides -//! immediate feedback on constraint violations during development. +//! Two flavours of helpers live here: //! -//! # Usage -//! -//! ```ignore -//! use miden_lifted_stark::AirWitness; -//! -//! // Single instance -//! 
let witness = AirWitness::new(&trace, &public_values, &[]); -//! check_constraints(&air, &witness, &aux_builder, &challenges); -//! -//! // Multiple instances -//! check_constraints_multi( -//! &[(&air_a, witness_a, &builder_a), (&air_b, witness_b, &builder_b)], -//! &challenges, -//! ); -//! ``` +//! - **Structural assertion** ([`assert_prover_setup`]) — a panic-based check over +//! [`miden_lifted_air::debug::assert_multi_air_valid`]. Call it from tests / setup; the prover +//! and verifier hot paths trust the AIR's structural contract. +//! - **Constraint checker** ([`check_constraints`]) — evaluates AIR constraints row-by-row on +//! concrete trace values and panics on the first nonzero constraint. It derives deterministic +//! debug challenges; it does not replay the prover transcript. extern crate alloc; use alloc::vec::Vec; use miden_lifted_air::{ - AirBuilder, AuxBuilder, EmptyWindow, ExtensionBuilder, LiftedAir, PeriodicAirBuilder, - PermutationAirBuilder, RowWindow, + AirBuilder, ExtensionBuilder, LiftedAir, MultiAir, PeriodicAirBuilder, PermutationAirBuilder, + ProverStatement, RowWindow, debug::assert_multi_air_valid, }; +use p3_challenger::{CanObserve, CanSample}; use p3_field::{ExtensionField, Field}; use p3_matrix::{Matrix, dense::RowMajorMatrix}; -use crate::instance::AirWitness; +use crate::order::TraceOrder; // ============================================================================ -// Public API +// Structural assertions (over miden_lifted_air::debug) // ============================================================================ -/// Evaluate every AIR constraint against a concrete trace and panic on failure. +/// Assert the AIR's structural contract via [`assert_multi_air_valid`]. /// -/// Convenience wrapper around [`check_constraints_multi`] for a single instance. +/// Only the *trusted* structural contract is asserted here. 
The AIR ↔ PCS +/// compatibility bound (`log_quotient_degree <= log_blowup`) is a validated +/// runtime input — prover and verifier surface it as +/// [`DomainError::ConstraintDegreeTooHigh`](crate::DomainError) — so it is not +/// re-checked here. /// -/// # Panics -/// -/// - If the AIR fails validation -/// - If trace dimensions don't match the AIR -/// - If challenges are insufficient -/// - If any constraint evaluates to nonzero on any row -pub fn check_constraints( - air: &A, - witness: AirWitness<'_, F>, - aux_builder: &B, - challenges: &[EF], -) where +/// The preprocessed bundle's shape (tree presence, per-trace width, per-AIR +/// height) is validated at [`ProverInstance::new`](crate::ProverInstance::new) +/// construction time, so it is not re-checked here. +pub fn assert_prover_setup(prover_statement: &ProverStatement) +where F: Field, EF: ExtensionField, - A: LiftedAir, - B: AuxBuilder, + MA: MultiAir, { - check_constraints_multi(&[(air, witness, aux_builder)], challenges); + assert_multi_air_valid::(prover_statement.statement().multi_air()); } -/// Evaluate constraints for multiple AIR instances and panic on failure. +// ============================================================================ +// Public API +// ============================================================================ + +/// Evaluate AIR constraints against concrete trace values and panic on failure. /// -/// Each instance is a tuple of `(air, witness, aux_builder)`. +/// Constraints are checked row-by-row using the trace + aux trace built by +/// [`ProverStatement`]. All AIRs see the same `air_inputs` from `statement`. /// -/// Builds the auxiliary trace for each instance and checks constraints row by row. -/// Uses shared challenges across all instances (caller samples from RNG in test code). +/// Derives auxiliary-trace challenges from the supplied challenger using only +/// statement-owned data plus the instance count and log trace heights. 
This is +/// a local constraint debugger: it intentionally skips protocol commitments +/// (including any preprocessed setup commitment), so sampled challenges need +/// not match a full proof transcript produced by +/// [`ProverInstance::prove`](crate::ProverInstance::prove). /// /// # Panics /// -/// - If any AIR fails validation /// - If trace dimensions don't match their AIR -/// - If challenges are insufficient /// - If any constraint evaluates to nonzero on any row -pub fn check_constraints_multi( - instances: &[(&A, AirWitness<'_, F>, &B)], - challenges: &[EF], +pub fn check_constraints( + prover_statement: &ProverStatement, + mut challenger: Ch, ) where F: Field, EF: ExtensionField, - A: LiftedAir, - B: AuxBuilder, + MA: MultiAir, + Ch: CanObserve + CanSample, { - assert!(!instances.is_empty(), "no instances provided"); - - // Sort by (trace_height, caller_index) to match InstanceShapes::from_trace_heights. - let mut perm: Vec = (0..instances.len()).collect(); - perm.sort_by_key(|&i| (instances[i].1.trace.height(), i)); - - for (i, &orig_idx) in perm.iter().enumerate() { - let &(air, ref witness, aux_builder) = &instances[orig_idx]; - - air.validate() - .unwrap_or_else(|e| panic!("AIR validation failed for instance {i}: {e}")); - - let main = witness.trace; - let height = main.height(); - - // Main trace dimensions. 
- assert!( - height.is_power_of_two(), - "instance {i}: trace height {height} is not a power of two" - ); - assert_eq!( - main.width, - air.width(), - "instance {i}: main trace width mismatch: expected {}, got {}", - air.width(), - main.width - ); - assert_eq!( - witness.public_values.len(), - air.num_public_values(), - "instance {i}: public values length mismatch: expected {}, got {}", - air.num_public_values(), - witness.public_values.len() - ); - assert_eq!( - witness.var_len_public_inputs.len(), - air.num_var_len_public_inputs(), - "instance {i}: var-len public inputs count mismatch: expected {}, got {}", - air.num_var_len_public_inputs(), - witness.var_len_public_inputs.len() - ); - assert!( - challenges.len() >= air.num_randomness(), - "instance {i}: not enough challenges: need {}, got {}", - air.num_randomness(), - challenges.len() - ); - - // Build auxiliary trace. + let statement = prover_statement.statement(); + let airs = statement.airs(); + let traces = prover_statement.traces(); + let air_inputs = statement.air_inputs(); + let aux_inputs = statement.aux_inputs(); + assert!(!airs.is_empty(), "no instances provided"); + assert_eq!(airs.len(), traces.len(), "airs and traces counts must match"); + + // Seed deterministic debug challenges from statement/height observations only. + // Do not observe setup/trace commitments or replay the prover transcript here. 
+ let trace_heights: Vec = traces.iter().map(Matrix::height).collect(); + let trace_order = TraceOrder::from_trace_heights::(airs, &trace_heights) + .expect("ProverStatement::new should reject malformed heights"); + statement.observe(&mut challenger, trace_order.log_heights()); + trace_order.observe_shape::(&mut challenger); + let max_num_randomness = airs.iter().map(LiftedAir::num_randomness).max().unwrap_or(0); + let challenges: Vec = (0..max_num_randomness) + .map(|_| EF::from_basis_coefficients_fn(|_| challenger.sample())) + .collect(); + + let mut aux_traces = Vec::with_capacity(airs.len()); + let mut aux_values_per_air = Vec::with_capacity(airs.len()); + for (air, main) in airs.iter().zip(traces.iter()) { + let num_randomness = air.num_randomness(); let (aux_trace, aux_values) = - aux_builder.build_aux_trace(main, &challenges[..air.num_randomness()]); + air.build_aux_trace(main, air_inputs, aux_inputs, &challenges[..num_randomness]); + aux_traces.push(aux_trace); + aux_values_per_air.push(aux_values); + } + + // Mirror the verifier's external-assertion check: the cross-AIR + // interactions must hold for these aux values and public inputs. Each + // assertion is a concrete value, so a zero-check is exact. + let aux_views: Vec<&[EF]> = aux_values_per_air.iter().map(Vec::as_slice).collect(); + let assertions = statement + .eval_external(&challenges, &aux_views, trace_order.log_heights()) + .expect("eval_external failed during check_constraints"); + for (k, assertion) in assertions.iter().enumerate() { + assert_eq!(*assertion, EF::ZERO, "external assertion {k} is non-zero"); + } - // Auxiliary trace dimensions. 
+ for (i, ((air, main), (aux_trace, aux_values))) in airs + .iter() + .zip(traces.iter()) + .zip(aux_traces.iter().zip(aux_values_per_air.iter())) + .enumerate() + { + // `check_builder_shape` validates row-window widths per row, but aux trace + // height is invisible to a single row window — check it here so a short aux + // trace fails cleanly rather than via an opaque row-slice panic. assert_eq!( aux_trace.height(), - height, - "instance {i}: aux trace height mismatch: expected {height}, got {}", + main.height(), + "instance {i}: aux trace height mismatch: expected {}, got {}", + main.height(), aux_trace.height() ); - assert_eq!( - aux_trace.width, - air.aux_width(), - "instance {i}: aux trace width mismatch: expected {}, got {}", - air.aux_width(), - aux_trace.width - ); - assert_eq!( - aux_values.len(), - air.num_aux_values(), - "instance {i}: aux values count mismatch: expected {}, got {}", - air.num_aux_values(), - aux_values.len() - ); - check_single_trace( - air, - main, - &aux_trace, - &aux_values, - witness.public_values, - challenges, - i, - ); + check_single_trace(air, main, aux_trace, aux_values, air_inputs, &challenges, i); } } /// Check constraints for one instance's traces row by row. +#[allow(clippy::too_many_arguments)] fn check_single_trace( air: &A, main: &RowMajorMatrix, @@ -184,6 +154,20 @@ fn check_single_trace( A: LiftedAir, { let height = main.height(); + + // Preprocessed matrix comes straight off the AIR (debug-only; this + // re-materialises `BaseAir::preprocessed_trace`). Its height must match the main + // trace; width is checked per row by `check_builder_shape`. 
+ let preprocessed = air.preprocessed_trace(); + if let Some(preproc) = &preprocessed { + assert_eq!( + preproc.height(), + height, + "instance {instance_index}: preprocessed trace height mismatch: expected {height}, got {}", + preproc.height() + ); + } + let periodic_matrix = air.periodic_columns_matrix(); for row in 0..height { let next_row = (row + 1) % height; @@ -200,8 +184,16 @@ fn check_single_trace( let periodic_row = periodic_matrix.as_ref().map(|m| m.row_slice(row % m.height()).unwrap()); let periodic_values: &[F] = periodic_row.as_deref().unwrap_or(&[]); + // Preprocessed rows (empty window when the AIR declares none). + let preprocessed_current = preprocessed.as_ref().map(|m| m.row_slice(row).unwrap()); + let preprocessed_next = preprocessed.as_ref().map(|m| m.row_slice(next_row).unwrap()); + let mut builder = DebugConstraintBuilder { main: RowWindow::from_two_rows(&main_current, &main_next), + preprocessed: RowWindow::from_two_rows( + preprocessed_current.as_deref().unwrap_or(&[]), + preprocessed_next.as_deref().unwrap_or(&[]), + ), permutation: RowWindow::from_two_rows(&aux_current, &aux_next), randomness: &challenges[..air.num_randomness()], public_values, @@ -214,7 +206,8 @@ fn check_single_trace( row_index: row, }; - debug_assert!(air.is_valid_builder(&builder).is_ok()); + #[cfg(debug_assertions)] + miden_lifted_air::debug::check_builder_shape(air, &builder); air.eval(&mut builder); } @@ -231,6 +224,7 @@ fn check_single_trace( /// (permutation) trace, matching the actual field layout of lifted STARK traces. 
struct DebugConstraintBuilder<'a, F: Field, EF: ExtensionField> { main: RowWindow<'a, F>, + preprocessed: RowWindow<'a, F>, permutation: RowWindow<'a, EF>, randomness: &'a [EF], public_values: &'a [F], @@ -251,7 +245,7 @@ where type F = F; type Expr = F; type Var = F; - type PreprocessedWindow = EmptyWindow; + type PreprocessedWindow = RowWindow<'a, F>; type MainWindow = RowWindow<'a, F>; type PublicVar = F; @@ -260,7 +254,7 @@ where } fn preprocessed(&self) -> &Self::PreprocessedWindow { - EmptyWindow::empty_ref() + &self.preprocessed } fn is_first_row(&self) -> Self::Expr { diff --git a/stark/miden-lifted-stark/src/domain.rs b/stark/miden-lifted-stark/src/domain.rs new file mode 100644 index 0000000000..d3b05517df --- /dev/null +++ b/stark/miden-lifted-stark/src/domain.rs @@ -0,0 +1,1064 @@ +//! Domain hierarchy: subgroup → coset → lifted domain. +//! +//! This module hosts three concrete types representing a STARK's evaluation +//! domains, each adding one level of structure: +//! +//! 1. [`TwoAdicSubgroup`] — the bare multiplicative subgroup `H = ⟨ω⟩ ⊂ F^*` of order +//! `2^log_size`. **Single home of `F::two_adic_generator(...)` in this crate.** +//! +//! 2. [`TwoAdicCoset`] — a coset `s · H` of a [`TwoAdicSubgroup`], carrying a multiplicative +//! shift `s ∈ F`. +//! +//! 3. [`LiftedDomain`] — the STARK-protocol object: an LDE coset `g^r · K` together with a +//! smaller trace subgroup `H ⊆ K` and a lift ratio `r` relative to the max coset `g · K_max`. +//! **Single home of `F::GENERATOR` in this crate** (used in [`LiftedDomain::try_canonical`] and +//! [`LiftedDomain::try_sub_domain`] to compute the canonical LDE coset shift). +//! +//! [`TwoAdicSubgroup`] and [`TwoAdicCoset`] both implement the [`Coset`] +//! trait, which captures the shared `(log_size, shift, generator)` interface +//! and provides default bodies for `size`, `point_at`, `points`, +//! `bit_reversed_points`, `vanishing_at`, `contains`, `generator_inverse`. +//! 
[`LiftedDomain`] does *not* implement [`Coset`] — it composes a trace +//! subgroup and an LDE coset, and exposing a single coset interface would +//! silently pick one of two distinct vanishing polynomials. Callers say +//! `domain.trace_subgroup()` or `domain.lde_coset()` to disambiguate. + +use alloc::vec::Vec; +use core::marker::PhantomData; + +use miden_lifted_air::{LiftedAir, log2_ceil_u8}; +use miden_stark_transcript::Channel; +use p3_field::{ExtensionField, Field, TwoAdicField, batch_multiplicative_inverse}; +use p3_maybe_rayon::prelude::*; +use p3_util::reverse_slice_index_bits; +use thiserror::Error; + +use crate::selectors::Selectors; + +// ============================================================================ +// Errors +// ============================================================================ + +/// Errors from validated `LiftedDomain` construction (the `try_*` family). +/// +/// `LiftedDomain::try_canonical` and `LiftedDomain::try_sub_domain` read +/// parameters that may come from untrusted inputs (proofs, instance metadata) +/// and surface a recoverable error rather than panic. Test and benchmark +/// fixtures pick statically valid sizes and use the panicking +/// `testing::canonical_domain` helper instead. +#[derive(Debug, Clone, PartialEq, Eq, Error)] +pub enum DomainError { + /// `log_lde_order = log_trace_height + log_blowup` exceeds the smaller of + /// `F::TWO_ADICITY` (no `2^log`-th root of unity exists) and + /// `usize::BITS - 1` (32-bit overflow guard). + #[error( + "LDE log order {log_lde_order} exceeds bound {bound} (min of F::TWO_ADICITY and usize::BITS-1)" + )] + LdeOrderTooLarge { log_lde_order: usize, bound: usize }, + /// Sub-domain construction with a trace height larger than the parent's. + #[error("sub-domain trace log size {smaller} exceeds parent {parent}")] + SubDomainTooLarge { smaller: u8, parent: u8 }, + /// No heights supplied to a multi-height constructor. 
+ #[error("no trace heights supplied")] + EmptyHeights, + /// Heights are not in non-decreasing order. + #[error("trace heights are not in non-decreasing order")] + HeightsNotAscending, + + /// Some AIR's `log_quotient_degree` exceeds the PCS blowup, so its + /// quotient polynomial would not fit the committed LDE. The recoverable + /// twin of the invariant `EvaluationDomain::new` asserts; the prover and + /// verifier check it inline (they already compute the max degree for the + /// quotient domain). + #[error("log_quotient_degree {log_quotient} > log_blowup {log_blowup}")] + ConstraintDegreeTooHigh { log_quotient: u8, log_blowup: u8 }, +} + +// ============================================================================ +// Coset trait +// ============================================================================ + +/// Shared interface for two-adic coset-like multiplicative domains. +/// +/// A coset `s · H` is parameterised by: +/// - `log_size`: log₂ of the order `|H|`, +/// - `shift`: the multiplicative offset `s ∈ F` (`F::ONE` for a plain subgroup), +/// - `generator`: a primitive `2^log_size`-th root of unity. +/// +/// Implemented by `TwoAdicSubgroup` (with `shift = F::ONE`) and +/// `TwoAdicCoset`. The default bodies for `point_at`, `points`, +/// `bit_reversed_points`, `vanishing_at`, `contains`, `size`, and +/// `generator_inverse` are written once here in terms of the three required +/// methods — for a subgroup, the shift collapses to `F::ONE` and the formulas +/// reduce to the unshifted case. +/// +/// [`LiftedDomain`] deliberately does **not** implement this trait: it composes +/// a trace subgroup and an LDE coset, each with its own vanishing polynomial, +/// and exposing a single `Coset` interface would force one to silently win. +/// Callers reach into the parts: `domain.trace_subgroup()` or +/// `domain.lde_coset()`. +pub trait Coset: Sized { + /// Log₂ of the domain order. 
+ fn log_size(&self) -> u8; + + /// Multiplicative shift `s` (`F::ONE` for a subgroup). + fn shift(&self) -> F; + + /// Primitive `2^log_size`-th root of unity. + fn generator(&self) -> F; + + /// Cached inverse of [`Self::shift`]. The default body recomputes it on + /// every call; implementations should override either with a free + /// constant (`F::ONE` for a subgroup) or by returning a value computed + /// once at construction time. `vanishing_at` and `contains` route through + /// this method so multiple invocations on the same coset don't redo the + /// inversion. + #[inline] + fn shift_inverse(&self) -> F { + self.shift().inverse() + } + + /// Domain order: `2^log_size`. + #[inline] + fn size(&self) -> usize { + 1 << self.log_size() as usize + } + + /// Inverse of the generator. Convenient for FRI twiddle factors and + /// last-row Lagrange denominators. + #[inline] + fn generator_inverse(&self) -> F { + self.generator().inverse() + } + + /// The `i`-th point in natural order: `s · ωⁱ`. + #[inline] + fn point_at(&self, i: u64) -> F { + self.shift() * self.generator().exp_u64(i) + } + + /// All points in natural order, length `2^log_size`. + /// + /// Single-pass iteration via `shifted_powers(self.shift())`: produces + /// `s, s·ω, s·ω², …` directly without an intermediate "unshifted then map" + /// step. + fn points(&self) -> Vec { + self.generator().shifted_powers(self.shift()).take(self.size()).collect() + } + + /// All points in bit-reversed order. + fn bit_reversed_points(&self) -> Vec { + let mut pts = self.points(); + reverse_slice_index_bits(&mut pts); + pts + } + + /// Vanishing polynomial of the domain at `z`: `(z/s)^|H| − 1`. + /// + /// For a subgroup (`s = 1`) this reduces to `z^|H| − 1`. + #[inline] + fn vanishing_at>(&self, z: EF) -> EF { + (z * self.shift_inverse()).exp_power_of_2(self.log_size() as usize) - EF::ONE + } + + /// Membership test: returns `true` iff `(z/s)^|H| == 1`. 
+ #[inline] + fn contains>(&self, z: EF) -> bool { + (z * self.shift_inverse()).exp_power_of_2(self.log_size() as usize) == EF::ONE + } +} + +// ============================================================================ +// TwoAdicSubgroup +// ============================================================================ + +/// Multiplicative subgroup of `F^*` of order `2^log_size`. +/// +/// Implements [`Coset`] with `shift = F::ONE`. +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub struct TwoAdicSubgroup { + log_size: u8, + _phantom: PhantomData, +} + +impl TwoAdicSubgroup { + /// Create a subgroup of order `2^log_size`. + /// + /// # Panics + /// + /// Panics if `log_size > F::TWO_ADICITY` — a primitive `2^log_size`-th root + /// of unity does not exist in `F` for that size. + #[inline] + pub fn new(log_size: u8) -> Self { + assert!( + (log_size as usize) <= F::TWO_ADICITY, + "subgroup log size {log_size} exceeds field two-adicity {}", + F::TWO_ADICITY, + ); + Self { log_size, _phantom: PhantomData } + } + + /// Smaller subgroup of order `2^(log_size − log_factor)`. + /// + /// Used by FRI to derive the next round's domain. + /// + /// # Panics + /// + /// Panics if `log_factor > self.log_size`. + #[inline] + pub fn shrink(self, log_factor: u8) -> Self { + assert!( + log_factor <= self.log_size, + "cannot shrink subgroup of log size {} by {log_factor}", + self.log_size, + ); + Self::new(self.log_size - log_factor) + } +} + +impl Coset for TwoAdicSubgroup { + #[inline] + fn log_size(&self) -> u8 { + self.log_size + } + + #[inline] + fn shift(&self) -> F { + F::ONE + } + + /// `F::ONE` is its own inverse — skip the inversion entirely. + #[inline] + fn shift_inverse(&self) -> F { + F::ONE + } + + /// **The only place in the crate that calls `F::two_adic_generator(...)` + /// directly.** All other domain-indexed two-adic roots flow from this + /// method via the [`Coset`] trait. 
+ #[inline] + fn generator(&self) -> F { + F::two_adic_generator(self.log_size as usize) + } +} + +// ============================================================================ +// TwoAdicCoset +// ============================================================================ + +/// A coset `s · H` of a [`TwoAdicSubgroup`] `H`, carrying a multiplicative shift +/// `s ∈ F`. +/// +/// `TwoAdicSubgroup` is the special case `s = 1`. The two are kept distinct so +/// the type system can express "no shift" without runtime checks. Both +/// implement [`Coset`]. +/// +/// The shift must be non-zero. The constructor panics on zero shift because a +/// zero-shifted set is not a multiplicative coset. +#[derive(Copy, Clone, Debug)] +pub struct TwoAdicCoset { + subgroup: TwoAdicSubgroup, + shift: F, + shift_inverse: F, +} + +impl TwoAdicCoset { + /// Create a coset `shift · subgroup`. Computes and caches `shift⁻¹`. + /// + /// Panics if `shift` is zero. + #[inline] + pub fn new(subgroup: TwoAdicSubgroup, shift: F) -> Self { + assert!(shift != F::ZERO, "coset shift must be non-zero"); + Self { + subgroup, + shift, + shift_inverse: shift.inverse(), + } + } + + /// The underlying subgroup `H`. + #[inline] + pub fn subgroup(&self) -> &TwoAdicSubgroup { + &self.subgroup + } +} + +impl PartialEq for TwoAdicCoset { + /// Equality ignores the cached `shift_inverse` field — it is fully + /// determined by `shift`. + #[inline] + fn eq(&self, other: &Self) -> bool { + self.subgroup == other.subgroup && self.shift == other.shift + } +} + +impl Eq for TwoAdicCoset {} + +impl Coset for TwoAdicCoset { + #[inline] + fn log_size(&self) -> u8 { + self.subgroup.log_size() + } + + #[inline] + fn shift(&self) -> F { + self.shift + } + + /// Returns the cached inverse computed at construction time. 
+ #[inline] + fn shift_inverse(&self) -> F { + self.shift_inverse + } + + #[inline] + fn generator(&self) -> F { + self.subgroup.generator() + } +} + +// ============================================================================ +// LiftedDomain +// ============================================================================ + +/// STARK lifted-domain object: an LDE coset `g^r·K` together with a smaller +/// trace subgroup `H ⊆ K` and a lift ratio `r` relative to the max coset +/// `g·K_max`. +/// +/// **Single home of `F::GENERATOR` in this crate** — exposed via +/// [`LiftedDomain::canonical_lde_shift`] and used internally by +/// [`LiftedDomain::try_canonical`] and [`LiftedDomain::try_sub_domain`]. +/// +/// # Invariants +/// +/// - `lde_coset.log_size() = trace_subgroup.log_size() + log_blowup` (where `log_blowup = lde − +/// trace`) +/// - `lde_coset.shift() = F::GENERATOR.exp_power_of_2(F::TWO_ADICITY − lde_coset.log_size())` — the +/// canonical, batch-independent shift for this LDE order. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct LiftedDomain { + /// Trace domain `H` — size `2^log_trace_height`. + trace_subgroup: TwoAdicSubgroup, + /// LDE evaluation coset `g^r·K` — size `2^log_lde_height`, shift `g^r`. + lde_coset: TwoAdicCoset, + /// Log₂ of the lift ratio: `log_max_lde_height − log_lde_height`. + log_lift_ratio: u8, +} + +impl LiftedDomain { + /// The canonical LDE coset shift `g^(2^(F::TWO_ADICITY − log_lde_order))` + /// for an order-`2^log_lde_order` domain. + /// + /// Equivalent to the LDE shift of `try_canonical(t, b)` for + /// `log_lde_order = t + b`, without building a [`LiftedDomain`]. Returns + /// `None` if `log_lde_order > F::TWO_ADICITY`. + /// + /// Single source of `F::GENERATOR` in this crate. 
+ #[inline] + pub fn canonical_lde_shift(log_lde_order: u8) -> Option { + let exp = F::TWO_ADICITY.checked_sub(log_lde_order as usize)?; + Some(F::GENERATOR.exp_power_of_2(exp)) + } + + /// Create the canonical domain for `(trace, blowup)`: trace height + /// `2^log_trace_height`, LDE blowup `2^log_blowup`, `log_lift_ratio = 0`. + /// + /// The LDE coset shift is the **canonical generator power** + /// `g^(2^(F::TWO_ADICITY − log_lde_order))`, where `log_lde_order = + /// log_trace_height + log_blowup`. This shift depends only on the LDE + /// order — it is invariant of which other matrices appear in the same + /// batch — making the LDE shift of `try_canonical(t, b)` a function of + /// the field and `(t, b)` alone. + /// + /// This is the **primary constructor**. Per-batch sub-domains derive + /// from this one via [`try_sub_domain`](Self::try_sub_domain), which + /// shrinks the trace subgroup; the shift is recomputed from the new LDE + /// order (still canonical for that order). + /// + /// Validates parameters that may come from untrusted input (proofs, + /// instance metadata), returning [`DomainError::LdeOrderTooLarge`] if + /// `log_trace_height + log_blowup` exceeds the smaller of `F::TWO_ADICITY` + /// (no `2^log`-th root of unity exists) and `usize::BITS - 1` (32-bit + /// overflow guard). Fixtures with statically valid sizes use the + /// `testing::canonical_domain` helper. + #[inline] + pub fn try_canonical(log_trace_height: u8, log_blowup: u8) -> Result { + let log_lde_order = log_trace_height as usize + log_blowup as usize; + let bound = F::TWO_ADICITY.min((usize::BITS - 1) as usize); + if log_lde_order > bound { + return Err(DomainError::LdeOrderTooLarge { log_lde_order, bound }); + } + // Bound check passed → both sub-sizes fit in u8 and inside F's two-adicity. 
+ let log_lde_height = log_lde_order as u8; + let shift = Self::canonical_lde_shift(log_lde_height) + .expect("log_lde_order ≤ F::TWO_ADICITY validated above"); + Ok(Self { + trace_subgroup: TwoAdicSubgroup::new(log_trace_height), + lde_coset: TwoAdicCoset::new(TwoAdicSubgroup::new(log_lde_height), shift), + log_lift_ratio: 0, + }) + } + + /// Derive a sub-domain with a smaller trace subgroup, sharing this + /// domain's blowup. The new domain's shift is the canonical shift for + /// its own (smaller) LDE order — independent of the parent's lift ratio. + /// The new `log_lift_ratio` grows by the trace shrink amount, recording + /// the batch context for OOD lifting. + /// + /// Validates parameters that may come from untrusted input, returning + /// [`DomainError::SubDomainTooLarge`] if `smaller_log_trace_height > + /// self.log_trace_height()`. + #[inline] + pub fn try_sub_domain(&self, smaller_log_trace_height: u8) -> Result { + let log_trace = self.log_trace_height(); + if smaller_log_trace_height > log_trace { + return Err(DomainError::SubDomainTooLarge { + smaller: smaller_log_trace_height, + parent: log_trace, + }); + } + let log_blowup = self.log_blowup(); + let log_lift_ratio_inc = log_trace - smaller_log_trace_height; + let new_log_lift_ratio = self.log_lift_ratio + log_lift_ratio_inc; + let new_log_lde = smaller_log_trace_height + log_blowup; + let shift = Self::canonical_lde_shift(new_log_lde) + .expect("new_log_lde ≤ parent log_lde_height ≤ F::TWO_ADICITY"); + Ok(Self { + trace_subgroup: TwoAdicSubgroup::new(smaller_log_trace_height), + lde_coset: TwoAdicCoset::new(TwoAdicSubgroup::new(new_log_lde), shift), + log_lift_ratio: new_log_lift_ratio, + }) + } + + // ============ Subgroup / coset accessors ============ + + /// The trace domain `H` as a `TwoAdicSubgroup`. + #[inline] + pub fn trace_subgroup(&self) -> &TwoAdicSubgroup { + &self.trace_subgroup + } + + /// The LDE evaluation coset `g^r·K` as a `TwoAdicCoset`. 
+ #[inline] + pub fn lde_coset(&self) -> &TwoAdicCoset { + &self.lde_coset + } + + // ============ Protocol-named height / shift sugar ============ + + /// Log₂ of the original trace height. + #[inline] + pub fn log_trace_height(&self) -> u8 { + self.trace_subgroup.log_size() + } + + /// Log₂ of this matrix's LDE height. + #[inline] + pub fn log_lde_height(&self) -> u8 { + self.lde_coset.log_size() + } + + /// Log₂ of the blowup factor for this matrix: `log_lde_height − log_trace_height`. + #[inline] + pub fn log_blowup(&self) -> u8 { + self.log_lde_height() - self.log_trace_height() + } + + /// The trace height (number of constraint rows). + #[inline] + pub fn trace_height(&self) -> usize { + self.trace_subgroup.size() + } + + /// The LDE height for this matrix. + #[inline] + pub fn lde_height(&self) -> usize { + self.lde_coset.size() + } + + /// The coset shift `g^r` for this matrix's LDE domain. + #[inline] + pub fn lde_shift(&self) -> F { + self.lde_coset.shift() + } + + /// Pair this domain with a quotient degree to form an + /// `EvaluationDomain`, the value carried through the constraint / + /// quotient layer of the protocol. + /// + /// # Panics + /// Panics if `log_quotient_degree > self.log_blowup()`. + #[inline] + pub fn evaluation_domain(self, log_quotient_degree: u8) -> EvaluationDomain { + EvaluationDomain::new(self, log_quotient_degree) + } + + // ============ Selector computation ============ + + /// Unnormalized Lagrange row selectors at an OOD extension-field point `z`, + /// evaluated as if `z` lives on the (lifted) trace subgroup. + /// + /// Internally the OOD point is first lifted via `z' = z^(2^log_lift_ratio)` + /// so the selectors line up with the trace domain regardless of the + /// per-instance lift ratio. For unlifted domains (`log_lift_ratio = 0`) + /// this reduces to plain trace-subgroup selectors at `z`. 
+ /// + /// - `is_first_row = Z_H(z') / (z' − 1)` + /// - `is_last_row = Z_H(z') / (z' − ω_H⁻¹)` + /// - `is_transition = z' − ω_H⁻¹` + /// + /// where `Z_H(z') = z'^N_H − 1` and `ω_H` is the trace subgroup generator. + /// + /// # Panics + /// + /// Panics if `z_lift ∈ {1, ω_H⁻¹}` (denominator zero in the row selectors). + /// For an OOD `z` sampled via [`sample_ood_point`](Self::sample_ood_point) + /// this is statistically impossible; callers passing arbitrary `z` must + /// avoid those two values. + pub fn selectors_at(&self, z: EF) -> Selectors + where + EF: ExtensionField, + { + let z_lift = z.exp_power_of_2(self.log_lift_ratio as usize); + let vanishing = self.trace_subgroup.vanishing_at(z_lift); + let omega_h_inv = self.trace_subgroup.generator_inverse(); + Selectors { + is_first_row: vanishing / (z_lift - F::ONE), + is_last_row: vanishing / (z_lift - omega_h_inv), + is_transition: z_lift - omega_h_inv, + } + } + + // ============ OOD point sampling ============ + + /// Sample an OOD evaluation point outside both `H` and the LDE coset. + /// + /// Repeatedly draws candidates from `channel` until one falls outside both exclusion + /// sets. Terminates with overwhelming probability because `|H ∪ gK|` is negligible + /// relative to the extension field size. + pub fn sample_ood_point(&self, channel: &mut impl Channel) -> EF + where + EF: ExtensionField, + { + loop { + let candidate: EF = channel.sample_algebra_element(); + if !self.trace_subgroup.contains(candidate) && !self.lde_coset.contains(candidate) { + break candidate; + } + } + } +} + +// ============================================================================ +// Quotient degree +// ============================================================================ + +/// Log₂ of the number of quotient chunks for `air`, clamped so the quotient +/// degree `D = 2^log_quotient_degree` is always ≥ 1. +/// +/// # Why `M − 1` chunks? 
+/// +/// Let N be the trace height (so trace columns are polynomials of degree < N). +/// Symbolic evaluation assigns each constraint a *degree multiple* M +/// (the AIR's combined +/// [`constraint_degree().max()`](miden_lifted_air::ConstraintDegrees::max)), +/// so the numerator polynomial C(X) has degree bounded by roughly M·(N − 1). +/// +/// In a STARK, the constraint numerator is divisible by the trace vanishing +/// polynomial `Z_H(X) = Xᴺ − 1`, so the quotient polynomial +/// `Q(X) = C(X) / Z_H(X)` has +/// +/// `deg(Q) ≤ deg(C) − N ≤ M·(N − 1) − N < (M − 1)·N`. +/// +/// We commit to Q(X) by splitting it into D chunks of degree < N; D = M − 1 +/// suffices, rounded up to a power of two. +/// +/// # Low symbolic degrees +/// +/// Quotient construction still needs at least one chunk. The prover and +/// verifier therefore clamp the derived quotient degree to `D = 1`. The air +/// crate reports raw symbolic degrees; this STARK-layer helper applies the +/// protocol clamp. +pub fn log_quotient_degree(air: &A) -> u8 +where + F: Field, + A: LiftedAir, +{ + // Maximum degree over base and extension field constraints. + let constraint_degree = air.constraint_degree().max(); + // Subtract one quotient chunk for division by the vanishing polynomial. + let quotient_chunks = constraint_degree.saturating_sub(1); + // Clamp to the protocol minimum of one quotient chunk. + let quotient_chunks = quotient_chunks.max(1); + // Return the log₂ quotient chunk count. + log2_ceil_u8(quotient_chunks) +} + +// ============================================================================ +// EvaluationDomain +// ============================================================================ + +/// The order-`2^(log_trace_height + log_quotient_degree)` coset on which the +/// quotient polynomial is evaluated, paired with its protocol context (the +/// parent [`LiftedDomain`] for OOD lifting and PCS interop). 
+/// +/// Implements [`Coset`] directly: `&eval_domain` *is* the evaluation coset, +/// sharing the parent's LDE shift. Use [`Coset::points`], [`Coset::shift`], +/// [`Coset::vanishing_at`] etc. for coset-level operations. +/// +/// Reach the parent context via [`lifted`](Self::lifted) for accessors that +/// belong to the LDE / trace side (`trace_subgroup`, `lde_coset`, +/// `log_blowup`, `selectors_at`, `sample_ood_point`, …). +/// +/// # Invariant +/// +/// `log_quotient_degree ≤ lifted.log_blowup()` (enforced at construction). +/// The "quotient degree" `D = 2^log_quotient_degree` is the number of chunks +/// the quotient polynomial Q is decomposed into; the value comes from +/// [`log_quotient_degree`]. +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub struct EvaluationDomain { + domain: LiftedDomain, + log_quotient_degree: u8, +} + +impl EvaluationDomain { + /// Pair `domain` with `log_quotient_degree`. + /// + /// # Panics + /// + /// Panics if `log_quotient_degree > domain.log_blowup()`. + #[inline] + pub fn new(domain: LiftedDomain, log_quotient_degree: u8) -> Self { + let log_blowup = domain.log_blowup(); + assert!( + log_quotient_degree <= log_blowup, + "quotient log degree {log_quotient_degree} exceeds blowup {log_blowup}" + ); + Self { domain, log_quotient_degree } + } + + /// The parent [`LiftedDomain`] — used to reach trace/LDE-side accessors + /// and to hand off to PCS-layer APIs that take `&LiftedDomain`. + #[inline] + pub fn lifted(&self) -> &LiftedDomain { + &self.domain + } + + /// Log₂ of the quotient degree (number of chunks Q is decomposed into). + #[inline] + pub fn log_quotient_degree(&self) -> u8 { + self.log_quotient_degree + } + + /// Quotient degree: `2^log_quotient_degree` — the number of chunks Q is + /// decomposed into. + #[inline] + pub fn quotient_degree(&self) -> usize { + 1 << self.log_quotient_degree as usize + } + + /// Log₂ of the trace height — the trace subgroup `H`'s order. 
+ /// Defines `Z_H(x) = x^N − 1` for vanishing and selector periodicity. + #[inline] + pub fn log_trace_height(&self) -> u8 { + self.domain.log_trace_height() + } + + /// Trace height `N = 2^log_trace_height`. + #[inline] + pub fn trace_height(&self) -> usize { + self.domain.trace_height() + } + + /// The evaluation coset's underlying subgroup (order + /// `2^(log_trace_height + log_quotient_degree)`). + #[inline] + pub fn subgroup(&self) -> TwoAdicSubgroup { + TwoAdicSubgroup::new(self.log_size()) + } + + /// Unnormalized prover-side Lagrange row selectors, evaluated at every + /// point of this coset (natural order). + /// + /// For each coset point `xᵢ = s · ωⁱ` (where `s = self.shift()` is the LDE + /// shift and `ω = self.generator()` generates this coset's size-`N_H · 2^log_quotient_degree` + /// subgroup): + /// - `is_first_row[i] = Z_H(xᵢ) / (xᵢ − 1)` + /// - `is_last_row[i] = Z_H(xᵢ) / (xᵢ − ω_H⁻¹)` + /// - `is_transition[i] = xᵢ − ω_H⁻¹` + /// + /// where `Z_H(x) = x^N_H − 1` is the trace subgroup's vanishing and `ω_H` + /// is its generator. `Z_H(xᵢ)` is periodic with `2^log_quotient_degree` + /// distinct values across the coset; we batch-invert the unique + /// denominators. + pub fn selectors(&self) -> Selectors> { + let log_trace_height = self.log_trace_height(); + let coset_size = self.size(); + let shift = self.shift(); + + // Z_H(x) = x^N_H − 1 over this coset is periodic with + // 2^log_quotient_degree distinct values. + let s_pow_n = shift.exp_power_of_2(log_trace_height as usize); + let blowup_subgroup = self.subgroup().shrink(log_trace_height); + let z_h_periodic: Vec = blowup_subgroup + .generator() + .shifted_powers(s_pow_n) + .take(1 << self.log_quotient_degree as usize) + .map(|x| x - F::ONE) + .collect(); + let period = z_h_periodic.len(); + + // Coset points in natural order. 
+ let xs: Vec = self.generator().shifted_powers(shift).collect_n(coset_size); + let omega_h_inv = self.domain.trace_subgroup().generator_inverse(); + + // Unnormalized Lagrange selector: selᵢ = Z_H(xᵢ) / (xᵢ − basis_point). + let single_point_selector = |basis_point: F| -> Vec { + let denoms: Vec = xs.par_iter().map(|&x| x - basis_point).collect(); + let invs = batch_multiplicative_inverse(&denoms); + (0..coset_size) + .into_par_iter() + .map(|i| z_h_periodic[i % period] * invs[i]) + .collect() + }; + + Selectors { + is_first_row: single_point_selector(F::ONE), + is_last_row: single_point_selector(omega_h_inv), + is_transition: xs.into_par_iter().map(|x| x - omega_h_inv).collect(), + } + } + + /// The `D = self.quotient_degree()` distinct values of `1 / Z_H` on this + /// evaluation coset, where `Z_H(X) = X^N − 1` and `N = 2^log_trace_height`. + /// + /// `Z_H` cycles with period `D` over the `N · D` points of the coset, so + /// callers index with `evals[i & (D − 1)]` to get `1 / Z_H(xᵢ)` for the + /// `i`-th point of the domain. + pub fn inv_vanishing_evals(&self) -> Vec { + let num_distinct = self.quotient_degree(); + let s_pow_n = self.shift().exp_power_of_2(self.log_trace_height() as usize); + let omega_d = self.subgroup().shrink(self.log_trace_height()).generator(); + let z_h_evals: Vec = + omega_d.powers().take(num_distinct).map(|x| s_pow_n * x - F::ONE).collect(); + batch_multiplicative_inverse(&z_h_evals) + } + + /// Reconstruct `Q(z)` from `D` quotient chunk evaluations. + /// + /// The quotient `Q` is committed as `D = self.quotient_degree()` chunk + /// polynomials `qₜ` of degree `< N`, one per `H`-coset inside `J`: `qₜ` + /// agrees with `Q` on the coset `s · ω_Jᵗ · H`. The verifier opens all + /// `qₜ(z)` at the same OOD point `z` and recombines them into `Q(z)` here. + /// + /// The map `x → (x / s)ᴺ` collapses each coset `s · ω_Jᵗ · H` to a single + /// `D`-th root of unity. 
Let `ωₛ = ω_Jᴺ` (a `D`-th root of unity) and + /// `u = (z / s)ᴺ` where `s = self.shift()`. Then `Q(z)` is the + /// barycentric interpolation of the values `qₜ(z)` at the points `ωₛᵗ`: + /// + /// ```text + /// wₜ = ωₛᵗ / (u − ωₛᵗ) + /// Q(z) = (Σₜ wₜ · qₜ(z)) / (Σₜ wₜ) + /// ``` + /// + /// # Panics + /// + /// Panics if `chunks.len() != self.quotient_degree()` — the verifier must + /// have unpacked exactly `D = 2^log_quotient_degree` quotient chunks + /// before calling this method. + pub fn reconstruct_quotient>(&self, z: EF, chunks: &[EF]) -> EF { + assert_eq!( + chunks.len(), + self.quotient_degree(), + "chunk count must equal quotient degree D" + ); + let omega_s = self.subgroup().shrink(self.log_trace_height()).generator(); + let u = (z * self.shift_inverse()).exp_power_of_2(self.log_trace_height() as usize); + + let mut numerator = EF::ZERO; + let mut denominator = EF::ZERO; + let mut omega_s_t = F::ONE; + + for &q_t in chunks.iter() { + let a_t = u - omega_s_t; + let w_t = a_t.inverse() * omega_s_t; + numerator += w_t * q_t; + denominator += w_t; + omega_s_t *= omega_s; + } + + numerator * denominator.inverse() + } +} + +impl Coset for EvaluationDomain { + /// `log_trace_height + log_quotient_degree` — the eval coset's order. + #[inline] + fn log_size(&self) -> u8 { + self.domain.log_trace_height() + self.log_quotient_degree + } + + /// The evaluation coset shares the parent LDE coset's shift. + #[inline] + fn shift(&self) -> F { + self.domain.lde_shift() + } + + /// Cached parent LDE shift inverse — reused by `vanishing_at` / `contains`. + #[inline] + fn shift_inverse(&self) -> F { + self.domain.lde_coset().shift_inverse() + } + + #[inline] + fn generator(&self) -> F { + // Routed through TwoAdicSubgroup so F::two_adic_generator stays + // confined to the single canonical site. 
+ self.subgroup().generator() + } +} + +// ============================================================================ +// Tests +// ============================================================================ + +#[cfg(test)] +mod tests { + use p3_field::{Field, PrimeCharacteristicRing}; + + use super::*; + use crate::testing::{ + canonical_domain, + configs::goldilocks_poseidon2::{Felt, QuadFelt}, + }; + + // ========== TwoAdicSubgroup ========== + + #[test] + fn subgroup_basic_dimensions() { + let h: TwoAdicSubgroup = TwoAdicSubgroup::new(5); + assert_eq!(h.log_size(), 5); + assert_eq!(h.size(), 32); + } + + #[test] + fn subgroup_generator_matches_two_adic_generator() { + let h: TwoAdicSubgroup = TwoAdicSubgroup::new(7); + assert_eq!(h.generator(), Felt::two_adic_generator(7)); + } + + #[test] + fn subgroup_generator_inverse_is_inverse() { + let h: TwoAdicSubgroup = TwoAdicSubgroup::new(6); + assert_eq!(h.generator() * h.generator_inverse(), Felt::ONE); + } + + #[test] + fn subgroup_point_at_matches_powers() { + let h: TwoAdicSubgroup = TwoAdicSubgroup::new(4); + let g = h.generator(); + for i in 0..h.size() as u64 { + assert_eq!(h.point_at(i), g.exp_u64(i)); + } + } + + #[test] + fn subgroup_points_length_and_first_two() { + let h: TwoAdicSubgroup = TwoAdicSubgroup::new(3); + let pts = h.points(); + assert_eq!(pts.len(), h.size()); + assert_eq!(pts[0], Felt::ONE); + assert_eq!(pts[1], h.generator()); + } + + #[test] + fn subgroup_bit_reversed_points() { + let h: TwoAdicSubgroup = TwoAdicSubgroup::new(4); + let natural = h.points(); + let br = h.bit_reversed_points(); + assert_eq!(natural[0], br[0]); + // Adjacent-negation: br[1] = ω^{n/2} = -1 + assert_eq!(br[1], -Felt::ONE); + } + + #[test] + fn subgroup_shrink_halves() { + let h: TwoAdicSubgroup = TwoAdicSubgroup::new(8); + let h2 = h.shrink(1); + assert_eq!(h2.log_size(), 7); + assert_eq!(h2.generator(), h.generator() * h.generator()); + } + + #[test] + fn subgroup_vanishing_zero_in_subgroup() { + let h: 
TwoAdicSubgroup = TwoAdicSubgroup::new(4); + for k in 0..h.size() as u64 { + assert_eq!(h.vanishing_at(h.point_at(k)), Felt::ZERO); + } + } + + #[test] + fn subgroup_vanishing_outside() { + let h: TwoAdicSubgroup = TwoAdicSubgroup::new(4); + let z = Felt::from_u32(7); + let expected = z.exp_u64(h.size() as u64) - Felt::ONE; + assert_eq!(h.vanishing_at(z), expected); + } + + #[test] + fn subgroup_contains() { + let h: TwoAdicSubgroup = TwoAdicSubgroup::new(5); + for k in 0..h.size() as u64 { + assert!(h.contains(h.point_at(k))); + } + assert!(!h.contains(QuadFelt::from(Felt::from_u32(12345)))); + } + + // ========== TwoAdicCoset ========== + + #[test] + fn coset_unshifted_matches_subgroup() { + let h: TwoAdicSubgroup = TwoAdicSubgroup::new(4); + let coset = TwoAdicCoset::new(h, Felt::ONE); + assert_eq!(coset.shift(), Felt::ONE); + assert_eq!(coset.points(), h.points()); + assert_eq!(coset.bit_reversed_points(), h.bit_reversed_points()); + } + + #[test] + fn coset_points_match_shift_times_subgroup() { + let h: TwoAdicSubgroup = TwoAdicSubgroup::new(5); + let shift = Felt::from_u32(7); + let coset = TwoAdicCoset::new(h, shift); + + let expected: Vec = h.points().into_iter().map(|p| shift * p).collect(); + assert_eq!(coset.points(), expected); + + for i in 0..coset.size() as u64 { + assert_eq!(coset.point_at(i), shift * h.point_at(i)); + } + } + + #[test] + fn coset_bit_reversed_points_explicit() { + let h: TwoAdicSubgroup = TwoAdicSubgroup::new(4); + let shift = Felt::from_u32(11); + let coset = TwoAdicCoset::new(h, shift); + let expected: Vec = h.bit_reversed_points().into_iter().map(|p| shift * p).collect(); + assert_eq!(coset.bit_reversed_points(), expected); + } + + #[test] + fn coset_vanishing_zero_at_coset_points() { + let h: TwoAdicSubgroup = TwoAdicSubgroup::new(4); + let coset = TwoAdicCoset::new(h, Felt::from_u32(13)); + for k in 0..coset.size() as u64 { + assert_eq!(coset.vanishing_at(coset.point_at(k)), Felt::ZERO); + } + } + + #[test] + fn 
coset_vanishing_outside() { + let h: TwoAdicSubgroup = TwoAdicSubgroup::new(4); + let shift = Felt::from_u32(13); + let coset = TwoAdicCoset::new(h, shift); + let z = Felt::from_u32(999); + let expected = (z * shift.inverse()).exp_u64(coset.size() as u64) - Felt::ONE; + assert_eq!(coset.vanishing_at(z), expected); + } + + #[test] + fn coset_contains_its_own_points() { + let h: TwoAdicSubgroup = TwoAdicSubgroup::new(5); + let coset = TwoAdicCoset::new(h, Felt::from_u32(31)); + for k in 0..coset.size() as u64 { + assert!(coset.contains(coset.point_at(k))); + } + assert!(!coset.contains(QuadFelt::from(Felt::from_u32(54321)))); + } + + // ========== LiftedDomain ========== + + #[test] + fn domain_canonical_is_unlifted() { + // Canonical of trace 2^10, blowup 2^3 — no lifting. + let info: LiftedDomain = canonical_domain(10, 3); + assert_eq!(info.log_trace_height(), 10); + assert_eq!(info.log_lde_height(), 13); + assert_eq!(info.log_blowup(), 3); + // Field-relative shift: g^(2^(TWO_ADICITY − log_lde_order)). + let expected = Felt::GENERATOR.exp_power_of_2(Felt::TWO_ADICITY - 13); + assert_eq!(info.lde_shift(), expected); + } + + #[test] + fn canonical_lde_shift_matches_domain_shift() { + let from_static = LiftedDomain::::canonical_lde_shift(13).unwrap(); + let from_domain = canonical_domain::(10, 3).lde_shift(); + assert_eq!(from_static, from_domain); + } + + #[test] + fn sub_domain_lifts_relative_to_canonical() { + // Canonical: max trace 2^12, blowup 2^3 — sub-domain at trace 2^10 → lift_ratio 2. + let parent: LiftedDomain = canonical_domain(12, 3); + let sub = parent.try_sub_domain(10).expect("sub-domain trace height out of range"); + + assert_eq!(sub.log_trace_height(), 10); + assert_eq!(sub.log_lde_height(), 13); + assert_eq!(sub.log_blowup(), 3); + assert_eq!(sub.trace_height(), 1024); + assert_eq!(sub.lde_height(), 8192); + + // Shift is canonical for the sub-domain's own LDE order (13), not derived + // from the parent's lift ratio. 
Crucially, equal to canonical_domain(10, 3).lde_shift(). + let expected_shift = Felt::GENERATOR.exp_power_of_2(Felt::TWO_ADICITY - 13); + assert_eq!(sub.lde_shift(), expected_shift); + assert_eq!(sub.lde_shift(), canonical_domain::(10, 3).lde_shift()); + } + + #[test] + fn sub_domain_at_same_trace_is_identity() { + let tallest: LiftedDomain = canonical_domain(10, 3); + let same = tallest.try_sub_domain(10).expect("sub-domain trace height out of range"); + assert_eq!(same.lde_shift(), tallest.lde_shift()); + assert_eq!(same.log_trace_height(), tallest.log_trace_height()); + assert_eq!(same.log_lde_height(), tallest.log_lde_height()); + } + + #[test] + fn lde_coset_point_at_matches_shift_times_omega() { + let info: LiftedDomain = canonical_domain::(5, 2) + .try_sub_domain(4) + .expect("sub-domain trace height out of range"); + let shift = info.lde_shift(); + let omega = info.lde_coset().generator(); + for i in 0..4 { + assert_eq!(info.lde_coset().point_at(i as u64), shift * omega.exp_u64(i as u64)); + } + } + + // ========== EvaluationDomain ========== + + #[test] + fn evaluation_domain_is_a_coset_sharing_parent_shift() { + // Sub-domain with lift_ratio = 2. + let eval: EvaluationDomain = canonical_domain::(12, 3) + .try_sub_domain(10) + .expect("sub-domain trace height out of range") + .evaluation_domain(2); + // Order N · D = 2^(10 + 2) = 2^12. + assert_eq!(eval.log_size(), 12); + assert_eq!(eval.size(), 1 << 12); + // Shift is borrowed from the parent LDE coset (literally a sub-coset). 
+ assert_eq!(eval.shift(), eval.lifted().lde_shift()); + } + + #[test] + fn evaluation_domain_quotient_degree() { + let eval: EvaluationDomain = canonical_domain::(8, 2).evaluation_domain(2); + assert_eq!(eval.log_quotient_degree(), 2); + assert_eq!(eval.quotient_degree(), 4); + } +} diff --git a/stark/miden-lifted-stark/src/instance.rs b/stark/miden-lifted-stark/src/instance.rs deleted file mode 100644 index fb64d46671..0000000000 --- a/stark/miden-lifted-stark/src/instance.rs +++ /dev/null @@ -1,342 +0,0 @@ -//! Protocol-level instance types for the lifted STARK prover and verifier. -//! -//! - [`AirInstance`]: Verifier instance — public values + variable-length inputs -//! - [`AirWitness`]: Prover witness — trace + public values -//! - [`InstanceShapes`]: Per-instance trace heights carried on [`StarkProof`](crate::StarkProof) - -extern crate alloc; - -use alloc::{vec, vec::Vec}; - -use miden_lifted_air::{AirStructureError, LiftedAir, VarLenPublicInputs, log2_strict_u8}; -use p3_challenger::CanObserve; -use p3_field::{Field, PrimeCharacteristicRing, TwoAdicField}; -use p3_matrix::{Matrix, dense::RowMajorMatrix}; -use serde::{Deserialize, Serialize}; -use thiserror::Error; - -// ============================================================================ -// Instance data -// ============================================================================ - -/// Verifier instance: public values and variable-length inputs. -/// -/// Both the prover and verifier carry `var_len_public_inputs`. The verifier uses -/// them in [`LiftedAir::reduced_aux_values`] for the cross-AIR identity check. -/// -/// Log trace heights are not part of the instance — they are carried on the -/// [`StarkProof`](crate::StarkProof) as [`InstanceShapes`] and absorbed into -/// the Fiat-Shamir state. -#[derive(Clone, Copy, Debug)] -pub struct AirInstance<'a, F> { - /// Public values for this AIR. - pub public_values: &'a [F], - /// Reducible inputs for the cross-AIR identity check. 
Empty slice if no buses. - pub var_len_public_inputs: VarLenPublicInputs<'a, F>, -} - -/// Prover witness: trace matrix, public values, and variable-length public inputs. -/// -/// Validates on construction that the trace height is a power of two. -/// -/// **Commitment:** callers **must** bind both `public_values` and -/// `var_len_public_inputs` to the Fiat-Shamir challenger state before proving. -#[derive(Clone, Copy, Debug)] -pub struct AirWitness<'a, F> { - /// Main trace matrix. - pub trace: &'a RowMajorMatrix, - /// Public values for this AIR. - pub public_values: &'a [F], - /// Variable-length public inputs (reducible inputs for bus identity checks). - pub var_len_public_inputs: VarLenPublicInputs<'a, F>, -} - -impl<'a, F> AirWitness<'a, F> { - /// Create a new prover witness with validation. - /// - /// # Panics - /// - /// - If `trace.height()` is not a power of two - pub fn new( - trace: &'a RowMajorMatrix, - public_values: &'a [F], - var_len_public_inputs: VarLenPublicInputs<'a, F>, - ) -> Self - where - F: Field, - { - assert!( - trace.height().is_power_of_two(), - "trace height must be power of two, got {}", - trace.height() - ); - Self { - trace, - public_values, - var_len_public_inputs, - } - } - - /// Convert to a verifier instance (drops the trace). - pub fn to_instance(&self) -> AirInstance<'a, F> { - AirInstance { - public_values: self.public_values, - var_len_public_inputs: self.var_len_public_inputs, - } - } -} - -// ============================================================================ -// Shape metadata -// ============================================================================ - -/// Per-instance shape metadata carried on [`StarkProof`](crate::StarkProof). -/// -/// Stores log₂ trace heights (absorbed into the Fiat-Shamir challenger) -/// and the AIR ordering (not absorbed — see [`air_order`](Self::air_order)). 
-#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct InstanceShapes { - // `pub(crate)` so in-crate tests can construct malformed shapes to - // exercise the verifier-path validation in `validate_inputs`. External - // callers go through `InstanceShapes::from_trace_heights`. - pub(crate) log_trace_heights: Vec, - /// The AIR ordering: `air_order[j]` is the caller's original index of - /// the instance at position `j` in the proof's ordering. - pub(crate) air_order: Vec, -} - -impl InstanceShapes { - /// Construct from raw trace heights (must be powers of two). - /// - /// Determines the proof's AIR ordering by sorting instances by - /// `(log_trace_height, caller_index)`. The resulting - /// [`air_order`](Self::air_order) maps each position in the proof's - /// ordering back to the caller's original index. - pub fn from_trace_heights(trace_heights: Vec) -> Result { - let log_heights: Vec = trace_heights - .iter() - .map(|&h| { - if !h.is_power_of_two() { - return Err(InstanceValidationError::InvalidTraceHeight { height: h }); - } - Ok(log2_strict_u8(h)) - }) - .collect::>()?; - - // Sort by (log_height, caller_index) for a canonical ordering. - let mut perm: Vec = (0..log_heights.len()).collect(); - perm.sort_by_key(|&i| (log_heights[i], i)); - - let sorted_log_heights: Vec = perm.iter().map(|&i| log_heights[i]).collect(); - let air_order: Vec = perm.iter().map(|&i| i as u32).collect(); - - Ok(Self { - log_trace_heights: sorted_log_heights, - air_order, - }) - } - - /// Log₂ of the trace height for each instance, in the proof's AIR - /// ordering. - pub fn log_trace_heights(&self) -> &[u8] { - &self.log_trace_heights - } - - /// The AIR ordering used by the proof: `air_order()[j]` is the caller's - /// original index of the instance at position `j` in the proof's - /// ordering. Not absorbed into the Fiat-Shamir transcript. 
- pub fn air_order(&self) -> &[u32] { - &self.air_order - } - - pub(crate) fn len(&self) -> usize { - self.log_trace_heights.len() - } - - /// Reorder `data` from the caller's natural order to the proof's AIR - /// ordering. Returns a `Vec` where position `j` holds - /// `data[air_order[j]]`. - /// - /// Validates that `air_order` is a valid permutation before applying it. - /// Returns an error if lengths mismatch or if `air_order` is malformed. - pub(crate) fn reorder(&self, mut data: Vec) -> Result, InstanceValidationError> { - let n = data.len(); - validate_air_order(&self.air_order, n)?; - let mut placed = vec![false; n]; - for start in 0..n { - if placed[start] { - continue; - } - let mut j = start; - loop { - let src = self.air_order[j] as usize; - placed[j] = true; - if src == start { - break; - } - data.swap(j, src); - j = src; - } - } - Ok(data) - } - - pub fn size_in_bytes(&self) -> usize { - size_of_val(self.log_trace_heights.as_slice()) + size_of_val(self.air_order.as_slice()) - } - - /// Absorb the log trace heights into a Fiat-Shamir challenger as one - /// base field element per `log_h`. The `air_order` values are **not** - /// absorbed. - pub(crate) fn observe_heights(&self, challenger: &mut C) - where - F: Field + PrimeCharacteristicRing, - C: CanObserve, - { - for &h in &self.log_trace_heights { - challenger.observe(F::from_u8(h)); - } - } -} - -// ============================================================================ -// Validation -// ============================================================================ - -/// Errors from validating instance- and protocol-level inputs. 
-#[derive(Debug, Error)] -pub enum InstanceValidationError { - #[error(transparent)] - AirStructure(#[from] AirStructureError), - #[error("no instances provided")] - Empty, - #[error("trace height {height} is not a power of two")] - InvalidTraceHeight { height: usize }, - #[error("trace width mismatch: expected {expected}, got {actual}")] - WidthMismatch { expected: usize, actual: usize }, - #[error("public values length mismatch: expected {expected}, got {actual}")] - PublicValuesMismatch { expected: usize, actual: usize }, - #[error("var-len public inputs count mismatch: expected {expected}, got {actual}")] - VarLenPublicInputsMismatch { expected: usize, actual: usize }, - #[error("trace height {trace_height} is less than max periodic column length {max_period}")] - TraceHeightBelowPeriod { trace_height: usize, max_period: usize }, - #[error( - "instance count {instances} does not match log trace heights count {log_trace_heights}" - )] - HeightCountMismatch { - instances: usize, - log_trace_heights: usize, - }, - #[error("LDE domain log-size {log_h} + {log_blowup} exceeds field two-adicity {two_adicity}")] - LdeDomainExceedsTwoAdicity { - log_h: u8, - log_blowup: u8, - two_adicity: usize, - }, - #[error("air_order length {air_order} does not match instance count {instances}")] - AirOrderLengthMismatch { instances: usize, air_order: usize }, - #[error("invalid air_order permutation for {count} instances")] - InvalidAirOrder { count: usize }, - #[error("log trace heights are not in ascending order")] - HeightsNotAscending, -} - -/// Cross-check instances against their shapes and return the log of the -/// maximum trace height. -/// -/// Instances and shapes must already be in the proof's AIR ordering. 
-/// -/// Checks: -/// - shape count matches instance count -/// - each `log_h + log_blowup` fits in both `F::TWO_ADICITY` and `usize::BITS - 1` (guards -/// downstream `two_adic_generator` and `1usize << log_lde_height` against wire-format shapes; the -/// `usize` bound only bites on 32-bit targets) -/// - each AIR is structurally valid ([`LiftedAir::validate`]) -/// - each instance's public values / var-len inputs match its AIR -/// - max height ≥ 2 (needed for the 2-row transition window) -/// - each trace height covers the AIR's longest periodic column -pub(crate) fn validate_inputs( - instances: &[(&A, AirInstance<'_, F>)], - shapes: &InstanceShapes, - log_blowup: u8, -) -> Result -where - F: TwoAdicField, - A: LiftedAir, -{ - if instances.len() != shapes.len() { - return Err(InstanceValidationError::HeightCountMismatch { - instances: instances.len(), - log_trace_heights: shapes.len(), - }); - } - // Upper bound on `log_h + log_blowup`: the two-adic generator must exist, - // and `1usize << log_lde_height` must not overflow on this target. 
- let max_log_lde_height = F::TWO_ADICITY.min((usize::BITS - 1) as usize); - let mut log_prev: u8 = 0; - for ((air, inst), &log_h) in instances.iter().zip(shapes.log_trace_heights()) { - if log_h as usize + log_blowup as usize > max_log_lde_height { - return Err(InstanceValidationError::LdeDomainExceedsTwoAdicity { - log_h, - log_blowup, - two_adicity: F::TWO_ADICITY, - }); - } - air.validate()?; - let expected_pv = air.num_public_values(); - if inst.public_values.len() != expected_pv { - return Err(InstanceValidationError::PublicValuesMismatch { - expected: expected_pv, - actual: inst.public_values.len(), - }); - } - let expected_vl = air.num_var_len_public_inputs(); - if inst.var_len_public_inputs.len() != expected_vl { - return Err(InstanceValidationError::VarLenPublicInputsMismatch { - expected: expected_vl, - actual: inst.var_len_public_inputs.len(), - }); - } - if log_h < log_prev { - return Err(InstanceValidationError::HeightsNotAscending); - } - let trace_height = 1usize << log_h as usize; - let max_period = air.periodic_columns().iter().map(Vec::len).max().unwrap_or(0); - if trace_height < max_period { - return Err(InstanceValidationError::TraceHeightBelowPeriod { - trace_height, - max_period, - }); - } - log_prev = log_h; - } - // `log_prev == 0` catches both "no instances" and "all traces are - // height 1" — both break the 2-row transition window. - if log_prev == 0 { - return Err(InstanceValidationError::Empty); - } - Ok(log_prev) -} - -/// Validate that `air_order` is a valid permutation of `0..n`. -/// -/// Called on the verifier side where `air_order` comes from an untrusted proof. 
-pub(crate) fn validate_air_order( - air_order: &[u32], - n: usize, -) -> Result<(), InstanceValidationError> { - if air_order.len() != n { - return Err(InstanceValidationError::AirOrderLengthMismatch { - instances: n, - air_order: air_order.len(), - }); - } - let mut seen = vec![false; n]; - for &idx in air_order { - let Some(slot @ false) = seen.get_mut(idx as usize) else { - return Err(InstanceValidationError::InvalidAirOrder { count: n }); - }; - *slot = true; - } - Ok(()) -} diff --git a/stark/miden-lifted-stark/src/lib.rs b/stark/miden-lifted-stark/src/lib.rs index 3e872c41f5..72da1c473e 100644 --- a/stark/miden-lifted-stark/src/lib.rs +++ b/stark/miden-lifted-stark/src/lib.rs @@ -7,40 +7,53 @@ //! //! The lifted STARK has three trust domains: //! -//! 1. **AIR = trusted** — [`air::LiftedAir`] implementations are correct application code. It is -//! the AIR implementer's responsibility to satisfy the contract below. -//! [`air::LiftedAir::validate`] checks the statically-verifiable subset. +//! 1. **AIR = trusted** — [`air::LiftedAir`] / [`air::MultiAir`] implementations are application +//! code. Structural mistakes, including an empty AIR collection, may panic; check them in +//! tests/setup with [`miden_lifted_air::debug::assert_multi_air_valid`] or +//! [`debug::assert_prover_setup`]. //! -//! 2. **Instance = validated** — The prover validates that its witness matches the AIR spec. The -//! verifier validates instance metadata. Both return structured errors. +//! 2. **Runtime inputs = validated** — [`Statement::new`](air::Statement::new), +//! [`ProverStatement::new`](air::ProverStatement::new), internal trace-order reconstruction, and +//! domain checks validate caller/proof shape data and return typed [`ProverError`] / +//! [`VerifierError`] values. //! //! 3. **Proof = untrusted** — Transcript data is verified cryptographically (PCS errors, constraint //! mismatch, etc.). //! -//! ## Validated properties +//! ## Validated at runtime //! -//! 
These are checked by [`air::LiftedAir::validate`] and by the internal -//! instance validator in [`instance`], and enforced by both prover and -//! verifier before proceeding: +//! Checked before cryptographic work begins: [`Statement::new`](air::Statement::new) / +//! [`ProverStatement::new`](air::ProverStatement::new) for caller inputs and prover trace shape, +//! internal trace-order reconstruction for proof/caller heights, and domain checks for +//! `log_quotient_degree(air) ≤ log_blowup`: //! -//! - **No preprocessed trace** — the lifted protocol does not support them. -//! - **Positive aux width** — every AIR must have an auxiliary trace. -//! - **Periodic columns** — each has positive, power-of-two length ≤ trace height. -//! - **Constraint degree** — `log_quotient_degree() ≤ log_blowup`. -//! - **Instance dimensions** — trace width, public values length, var-len public inputs count, and -//! trace height (power of two) all match the AIR specification. +//! - **Shape well-formedness** — ≤ 256 instances, with the maximum LDE order `log_trace_height + +//! log_blowup` bounded by the field's two-adicity and the host's `usize` width. +//! - **Compat** — `log_quotient_degree(air) ≤ log_blowup`, per AIR. +//! - **Per-AIR instance dimensions** — public values length matches `num_public_values()`, trace +//! height is at least 2 rows, trace height ≥ max periodic column length, trace width matches +//! `width()` (prover-only), raw height is a power of two (prover-only), `aux_inputs.len() ≤ +//! max_aux_inputs`. //! -//! ## Unchecked trust assumptions +//! ## Trusted AIR contracts //! -//! These cannot be verified statically and are the AIR implementer's responsibility: +//! These are AIR implementer responsibilities. Run [`debug::assert_prover_setup`] (or its +//! components) from your test harness to catch structural mistakes early: //! -//! 1. **Window size** — Only transition window size 2. -//! 2. 
**Deterministic constraints** — `eval()` emits the same number and types of constraints +//! 1. **AIR structural contract** — non-empty AIR collection, shared public-value count, positive +//! aux width, power-of-two periodic column lengths, and matching override helpers. Checked by +//! [`miden_lifted_air::debug::assert_multi_air_valid`]. These are not typed `Statement::new` +//! errors. +//! 2. **Window size** — only transition window size 2. +//! 3. **Deterministic constraints** — `eval()` emits the same number and types of constraints //! regardless of builder implementation. -//! 3. **Consistent aux builder** — `AuxBuilder::build_aux_trace` returns width = `aux_width()`, -//! height = main trace height, and exactly `num_aux_values()` values. (The prover asserts these -//! at runtime as a defense-in-depth sanity check.) -//! 4. **Sound `reduced_aux_values`** — Returns correct bus contributions for valid inputs. +//! 4. **[`LiftedAir::build_aux_trace`](air::LiftedAir::build_aux_trace) output** — per AIR, an aux +//! trace of width `aux_width()`, height matching the main trace, and exactly `num_aux_values()` +//! aux values. A malformed output is caught by the prover (LDE/commit panic) or by verification, +//! since the verifier re-derives these shapes from the AIR contract. +//! 5. **Sound [`Statement::eval_external`](air::Statement::eval_external)** — Returns cross-AIR +//! assertion values that are zero iff the proof's cross-AIR interactions are well-formed for the +//! given aux values and public inputs. 
#![no_std] @@ -51,98 +64,33 @@ extern crate alloc; // ============================================================================ mod config; -mod coset; pub mod debug; -pub mod instance; +pub(crate) mod domain; pub mod lmcs; -mod pcs; +mod order; +pub mod pcs; +mod preprocessed; pub mod proof; pub mod prover; mod selectors; +pub(crate) mod util; pub mod verifier; pub use config::{GenericStarkConfig, StarkConfig}; -pub use coset::LiftedCoset; -pub use debug::{check_constraints, check_constraints_multi}; -pub use instance::{AirInstance, AirWitness, InstanceShapes, InstanceValidationError}; -pub use lmcs::{ - Lmcs, LmcsError, LmcsTree, OpenedRows, - bitrev::{BitReversibleMatrix, materialize_bitrev}, - config::LmcsConfig, - hiding_config::HidingLmcsConfig, - lifted_tree::LiftedMerkleTree, - merkle_witness::MerkleWitness, - node_id::NodeId, - proof::{ - BatchProof as LmcsBatchProof, BatchProofView as LmcsBatchProofView, - LeafOpening as LmcsLeafOpening, Proof as LmcsProof, - }, - row_list::RowList, - tree_indices::{MissingSiblingsIter, TreeIndices}, - utils::log2_strict_u8, -}; -pub use pcs::{ - deep::{ - proof::{DeepTranscript, OpenedValues as PcsOpenedValues}, - verifier::DeepError, - }, - fri::{ - proof::{FriRoundTranscript, FriTranscript}, - verifier::FriError, - }, - params::{PcsParams, PcsParamsError}, - proof::PcsTranscript, - verifier::PcsError, -}; -pub use proof::{StarkDigest, StarkOutput, StarkProof, StarkTranscript}; -pub use prover::{ProverError, prove_multi, prove_single}; -pub use verifier::{VerifierError, verify_multi, verify_single}; - -/// Backward-compatible PCS namespace. -/// -/// Older consumers accessed DEEP/FRI/PCS types through `miden_lifted_stark::fri`. -/// The current implementation organizes them under an internal `pcs` module, so this -/// public facade preserves the earlier module path. 
-pub mod fri { - pub use crate::{ - DeepError, DeepTranscript, FriError, FriRoundTranscript, FriTranscript, PcsError, - PcsOpenedValues, PcsParams, PcsParamsError, PcsTranscript, - }; - - pub mod deep { - pub use crate::{DeepError, DeepTranscript, PcsOpenedValues}; - - pub mod proof { - pub use crate::{DeepTranscript, PcsOpenedValues}; - } - - pub mod verifier { - pub use crate::DeepError; - } - } - - pub mod params { - pub use crate::{PcsParams, PcsParamsError}; - } - - pub mod proof { - pub use crate::PcsTranscript; - } - - pub mod round_proof { - pub use crate::{FriRoundTranscript, FriTranscript}; - } - - pub mod verifier { - pub use crate::{FriError, PcsError}; - } -} +pub use debug::check_constraints; +// `domain` and `order` are internal modules, but these error types surface through the public +// `ProverError` / `VerifierError`, so they need a public path of their own. +pub use domain::DomainError; +pub use order::ShapeError; +pub use preprocessed::{Preprocessed, PreprocessedValidationError}; +pub use prover::{ProverError, ProverInstance}; +pub use verifier::{VerifierError, VerifierInstance}; // ============================================================================ // Namespaced re-exports from upstream crates // ============================================================================ -/// AIR traits, instance/witness types, and upstream `p3-air` re-exports. +/// AIR traits, statement/proving-input types, and upstream `p3-air` re-exports. /// /// This module re-exports items from [`miden_lifted_air`], which in turn /// re-exports `p3-air` types. 
Consumers should never need to depend on `p3-air` @@ -153,23 +101,23 @@ pub mod air { Air, AirBuilder, AirBuilderWithContext, - // Lifted AIR types - AirStructureError, - AuxBuilder, BaseAir, - EmptyWindow, + ConstraintDegrees, ExtensionBuilder, FilteredAirBuilder, + // Lifted AIR types + InstanceError, LiftedAir, LiftedAirBuilder, + MultiAir, PeriodicAirBuilder, PermutationAirBuilder, - ReducedAuxValues, + ProverStatement, ReductionError, RowWindow, - TracePart, - VarLenPublicInputs, + Statement, WindowAccess, + debug, log2_strict_u8, }; @@ -178,22 +126,12 @@ pub mod air { pub use miden_lifted_air::symbolic::*; } - /// Auxiliary trace types (builder, cross-AIR identity checking). - pub mod auxiliary { - pub use miden_lifted_air::auxiliary::*; - } - /// AIR constraint utility functions from upstream p3-air. pub mod utils { pub use miden_lifted_air::utils::*; } } -/// Fiat-Shamir transcript channels and data types. -pub mod transcript { - pub use miden_stark_transcript::{TranscriptChallenger, TranscriptData, TranscriptError}; -} - /// Stateful hasher primitives for LMCS construction. 
pub mod hasher { pub use miden_stateful_hasher::{ diff --git a/stark/miden-lifted-stark/src/lmcs/config.rs b/stark/miden-lifted-stark/src/lmcs/config.rs index f636428d4d..0824c7c435 100644 --- a/stark/miden-lifted-stark/src/lmcs/config.rs +++ b/stark/miden-lifted-stark/src/lmcs/config.rs @@ -9,14 +9,16 @@ use p3_field::PackedValue; use p3_matrix::Matrix; use p3_symmetric::{Hash, PseudoCompressionFunction}; -use crate::lmcs::{ - Lmcs, LmcsError, OpenedRows, - bitrev::BitReversibleMatrix, - lifted_tree::LiftedMerkleTree, - merkle_witness::MerkleWitness, - proof::{BatchProof, LeafOpening}, - row_list::RowList, - tree_indices::TreeIndices, +use crate::{ + lmcs::{ + Lmcs, LmcsError, OpenedRows, + lifted_tree::LiftedMerkleTree, + merkle_witness::MerkleWitness, + proof::{BatchProof, LeafOpening}, + row_list::RowList, + tree_indices::TreeIndices, + }, + util::bitrev::BitReversibleMatrix, }; /// LMCS configuration holding cryptographic primitives (sponge + compression). @@ -24,19 +26,19 @@ use crate::lmcs::{ /// This implementation defines the transcript hint layout used by /// [`LmcsTree::prove_batch`](crate::lmcs::LmcsTree::prove_batch) and consumed by /// `open_batch` and [`Lmcs::read_batch_proof`]: -/// - For each *distinct* query index (in caller order, skipping duplicates): one row per matrix (in -/// leaf order), then `SALT_ELEMS` field elements of salt. -/// - After all indices: missing sibling hashes, level-by-level, left-to-right, bottom-to-top. +/// - For each distinct tree index (sorted ascending): one row per committed matrix (in committed +/// matrix order), then `SALT_ELEMS` field elements of salt. +/// - After all leaves: missing sibling hashes, level-by-level, left-to-right, bottom-to-top. /// /// Hints are not observed into the Fiat-Shamir challenger. 
/// -/// `open_batch` expects `widths` and `log_max_height` to match the committed tree, +/// `open_batch` expects `widths` and `indices.depth()` to match the committed tree, /// rejects empty `indices`, and ignores extra hint data. Widths must match the /// committed row lengths (including any alignment padding if `build_aligned_tree` /// was used). Duplicate indices are coalesced in the returned openings. /// [`read_batch_proof`](crate::lmcs::Lmcs::read_batch_proof) parses -/// the same hint stream, hashes leaves, and reconstructs per-index authentication paths -/// without verifying against a commitment. Empty indices yield an empty map, and +/// the same hint stream, hashes leaves, and reconstructs leaf authentication paths without +/// verifying against a commitment. Empty indices are accepted by `read_batch_proof`; /// out-of-range indices return `InvalidProof`. /// /// Padding note: @@ -156,22 +158,20 @@ where Hash::from(self.compress.compress([left_digest, right_digest])) } - /// Verify a batch opening from transcript hints. + /// Verify an exact batch opening from transcript hints. /// /// Security notes: - /// - `widths` and `log_max_height` must describe the committed tree; they are not checked. + /// - `widths` and `indices.depth()` must describe the committed tree; they are not checked. /// - `widths` must match the committed row lengths (including any alignment padding if /// `build_aligned_tree` was used); LMCS does not enforce that padded values are zero. /// Verifiers cannot distinguish zero padding from arbitrary values unless they check the /// opened rows or constrain them elsewhere. /// - Empty `indices` returns `InvalidProof`. - /// - Duplicate indices are coalesced in the returned map (unique keys only). - /// - Out-of-range indices (>= 2^log_max_height) return `InvalidProof`. /// - Missing siblings or malformed hints return `InvalidProof`. /// - Extra hints are ignored and left unread. 
/// - Returns `RootMismatch` only after a well-formed proof yields a different root. /// - /// Leaf openings are read in **sorted tree index order** (ascending, deduplicated). + /// Leaf openings are read in sorted tree index order (ascending, deduplicated). fn open_batch( &self, commitment: &Self::Commitment, @@ -186,15 +186,15 @@ where return Err(LmcsError::InvalidProof); } - // 1. Read openings and hash each into a leaf hash. - let mut opened_rows: BTreeMap> = BTreeMap::new(); + // 1. Read one opening per unique leaf and hash it. + let mut leaf_rows: BTreeMap> = BTreeMap::new(); let mut leaf_hashes: Vec<(usize, Self::Commitment)> = Vec::with_capacity(indices.len()); - for &index in indices.iter() { + for &leaf in indices.iter() { let opening = LeafOpening::<_, SALT_ELEMS>::read_from_channel(widths.to_vec(), channel)?; - leaf_hashes.push((index, opening.leaf_hash(self))); - opened_rows.insert(index, opening.rows); + leaf_hashes.push((leaf, opening.leaf_hash(self))); + leaf_rows.insert(leaf, opening.rows); } // 2. Recompute root by streaming siblings directly from the channel. @@ -210,10 +210,10 @@ where return Err(LmcsError::RootMismatch); } - Ok(opened_rows) + Ok(leaf_rows) } - /// Parse batch hints into per-index opening proofs. + /// Parse batch hints into per-leaf opening proofs. /// /// Reads openings, hashes leaves, builds a pruned tree, and extracts /// authentication paths. Salt is stored as `Vec` in the output. @@ -233,11 +233,11 @@ where let mut openings = BTreeMap::new(); let mut leaf_hashes: Vec<(usize, Self::Commitment)> = Vec::with_capacity(indices.len()); - for &index in indices.iter() { + for &leaf in indices.iter() { let opening = LeafOpening::<_, SALT_ELEMS>::read_from_channel(widths.to_vec(), channel)?; - leaf_hashes.push((index, opening.leaf_hash(self))); - openings.insert(index, opening); + leaf_hashes.push((leaf, opening.leaf_hash(self))); + openings.insert(leaf, opening); } // 2. Build PrunedTree from leaf hashes + channel siblings. 
@@ -263,15 +263,13 @@ where mod tests { use alloc::vec; + use miden_lifted_air::log2_strict_u8; use miden_stark_transcript::{ProverTranscript, TranscriptData, VerifierTranscript}; use p3_field::PrimeCharacteristicRing; use p3_matrix::dense::RowMajorMatrix; use super::*; - use crate::{ - lmcs::{LmcsTree, utils::log2_strict_u8}, - testing::configs::goldilocks_poseidon2 as gl, - }; + use crate::{lmcs::LmcsTree, testing::configs::goldilocks_poseidon2 as gl}; fn small_matrix(height: usize, width: usize, seed: u64) -> RowMajorMatrix { let values = (0..height * width).map(|i| gl::Felt::from_u64(seed + i as u64)).collect(); @@ -343,7 +341,7 @@ mod tests { let mut verifier_channel = gl::verifier_channel(&transcript); let wrong_tree = lmcs.build_tree(vec![small_matrix(4, 2, 999)]); assert_eq!( - lmcs.open_batch(&wrong_tree.root(), &widths, &tree_indices_0, &mut verifier_channel,), + lmcs.open_batch(&wrong_tree.root(), &widths, &tree_indices_0, &mut verifier_channel), Err(LmcsError::RootMismatch) ); @@ -354,7 +352,7 @@ mod tests { let truncated = TranscriptData::new(fields, commitments); let mut verifier_channel = gl::verifier_channel(&truncated); assert_eq!( - lmcs.open_batch(&commitment, &widths, &tree_indices_0, &mut verifier_channel,), + lmcs.open_batch(&commitment, &widths, &tree_indices_0, &mut verifier_channel), Err(LmcsError::TranscriptError( miden_stark_transcript::TranscriptError::NoMoreCommitments )) @@ -370,6 +368,78 @@ mod tests { ); } + #[test] + fn virtual_lifted_indices_fold_to_committed_leaves() { + let lmcs = gl::test_lmcs(); + let tree = lmcs.build_tree(vec![small_matrix(4, 2, 0)]); + let widths = tree.aligned_widths(); + let commitment = tree.root(); + let tree_log_height = log2_strict_u8(tree.height()); + let query_depth = tree_log_height + 1; + let indices = TreeIndices::new([0usize, 4, 5, 7], query_depth).unwrap(); + + let make_transcript = || { + let mut prover_channel = gl::prover_channel(); + tree.prove_lifted_batch(&indices, &mut prover_channel); + 
prover_channel.finalize() + }; + + let (prover_digest, transcript) = make_transcript(); + let mut verifier_channel = gl::verifier_channel(&transcript); + let opened = lmcs + .open_lifted_batch( + &commitment, + &widths, + &indices, + tree_log_height, + &mut verifier_channel, + ) + .unwrap(); + + assert_eq!(opened[&0], tree.aligned_rows(0)); + assert_eq!(opened[&4], tree.aligned_rows(0)); + assert_eq!(opened[&5], tree.aligned_rows(1)); + assert_eq!(opened[&7], tree.aligned_rows(3)); + let verifier_digest = + verifier_channel.finalize().expect("transcript should finalize cleanly"); + assert_eq!(prover_digest, verifier_digest); + + let (_, transcript) = make_transcript(); + let mut verifier_channel = gl::verifier_channel(&transcript); + let batch = lmcs + .read_lifted_batch_proof(&widths, &indices, tree_log_height, &mut verifier_channel) + .unwrap(); + assert_eq!(batch.openings.len(), 3); + assert!(batch.openings.contains_key(&0)); + assert!(batch.openings.contains_key(&1)); + assert!(batch.openings.contains_key(&3)); + + let invalid_tree_log_height = query_depth + 1; + let (_, transcript) = gl::prover_channel().finalize(); + let mut verifier_channel = gl::verifier_channel(&transcript); + assert_eq!( + lmcs.open_lifted_batch( + &commitment, + &widths, + &indices, + invalid_tree_log_height, + &mut verifier_channel, + ), + Err(LmcsError::InvalidProof) + ); + + let mut verifier_channel = gl::verifier_channel(&transcript); + assert!(matches!( + lmcs.read_lifted_batch_proof( + &widths, + &indices, + invalid_tree_log_height, + &mut verifier_channel + ), + Err(LmcsError::InvalidProof) + )); + } + /// Reproduces the "root mismatch" bug when using Goldilocks + Blake3 (byte-based hash). /// /// The lifted STARK only tests with field-based Poseidon2, never with byte-based hashes. 
diff --git a/stark/miden-lifted-stark/src/lmcs/hiding_config.rs b/stark/miden-lifted-stark/src/lmcs/hiding_config.rs index 26ed69dc78..67e0afaf72 100644 --- a/stark/miden-lifted-stark/src/lmcs/hiding_config.rs +++ b/stark/miden-lifted-stark/src/lmcs/hiding_config.rs @@ -13,9 +13,12 @@ use rand::{ distr::{Distribution, StandardUniform}, }; -use crate::lmcs::{ - Lmcs, LmcsError, OpenedRows, bitrev::BitReversibleMatrix, config::LmcsConfig, - lifted_tree::LiftedMerkleTree, proof::BatchProof, tree_indices::TreeIndices, +use crate::{ + lmcs::{ + Lmcs, LmcsError, OpenedRows, config::LmcsConfig, lifted_tree::LiftedMerkleTree, + proof::BatchProof, tree_indices::TreeIndices, + }, + util::bitrev::BitReversibleMatrix, }; /// Configuration for hiding LMCS with random salt. diff --git a/stark/miden-lifted-stark/src/lmcs/lifted_tree.rs b/stark/miden-lifted-stark/src/lmcs/lifted_tree.rs index 268a10aa73..c45aa80310 100644 --- a/stark/miden-lifted-stark/src/lmcs/lifted_tree.rs +++ b/stark/miden-lifted-stark/src/lmcs/lifted_tree.rs @@ -10,9 +10,9 @@ use p3_symmetric::{Hash, PseudoCompressionFunction}; use p3_util::{log2_strict_usize, reverse_bits_len}; use tracing::info_span; -use crate::lmcs::{ - LmcsTree, bitrev::BitReversibleMatrix, proof::LeafOpening, row_list::RowList, - tree_indices::TreeIndices, +use crate::{ + lmcs::{LmcsTree, proof::LeafOpening, row_list::RowList, tree_indices::TreeIndices}, + util::bitrev::BitReversibleMatrix, }; /// A uniform binary Merkle tree whose leaves are constructed from matrices with power-of-two @@ -61,22 +61,25 @@ use crate::lmcs::{ /// /// ## Transcript Hints /// -/// `prove_batch` streams transcript hints in the format expected by +/// `prove_batch` streams exact transcript hints in the format expected by /// [`Lmcs::open_batch`](crate::lmcs::Lmcs::open_batch): -/// - For each unique query index **in sorted tree index order** (ascending, deduplicated): one row -/// per matrix (in leaf order), then `SALT_ELEMS` field elements of salt. 
+/// - For each unique tree index **in sorted tree index order** (ascending, deduplicated): one row +/// per committed matrix (in committed matrix order), then `SALT_ELEMS` field elements of salt. /// - Each row is padded with explicit zeros to the LMCS alignment. This allows verifiers to absorb /// fixed-size chunks without special-casing the final partial chunk; padding is not enforced to /// be zero. /// - After all indices: missing sibling hashes, level-by-level, left-to-right, bottom-to-top. /// +/// Use [`LmcsTree::prove_lifted_batch`] to open query indices from a larger domain against a +/// shorter committed tree. +/// /// Hints are not observed into the Fiat-Shamir challenger. /// /// This generally shouldn't be used directly. If you're using a Merkle tree as an MMCS, /// see the MMCS wrapper types. #[derive(Debug)] pub struct LiftedMerkleTree { - /// All leaf matrices in insertion order. + /// All committed matrices in insertion order. /// /// Matrices must be sorted by height (shortest to tallest) and all heights must be /// powers of two. Each matrix's rows are absorbed into sponge states that are @@ -145,17 +148,24 @@ where self.leaves.iter().map(Matrix::width).collect() } - /// Prove a batch opening and stream it into a transcript channel. + /// Prove an exact batch opening and stream it into a transcript channel. /// - /// Panics if any index is out of range. Rows are padded to `alignment` and those - /// padding values are not validated by verification; callers that require zero - /// padding must check the opened rows explicitly. + /// Panics if `indices.depth()` is not this tree's depth or any index is out of range. Rows are + /// padded to `alignment` and those padding values are not validated by verification; callers + /// that require zero padding must check the opened rows explicitly. /// /// Leaf openings are written in **sorted tree index order** (ascending, deduplicated). 
fn prove_batch(&self, indices: &TreeIndices, channel: &mut Ch) where Ch: ProverChannel>, { + let tree_log_height = log2_strict_usize(self.height()) as u8; + assert_eq!( + indices.depth(), + tree_log_height, + "exact batch indices must be in the committed tree's index space", + ); + // Stream leaf openings in sorted tree index order. for &index in indices.iter() { let opening = LeafOpening { @@ -536,10 +546,11 @@ mod tests { use super::*; use crate::{ - lmcs::{tests::build_leaves_single, utils::aligned_len}, + lmcs::tests::build_leaves_single, testing::configs::goldilocks_poseidon2::{ self as gl, DIGEST, Felt, PackedFelt, RATE, Sponge, }, + util::align::aligned_len, }; /// Common matrix group scenarios for testing lifting with varying heights. diff --git a/stark/miden-lifted-stark/src/lmcs/merkle_witness.rs b/stark/miden-lifted-stark/src/lmcs/merkle_witness.rs index 666acf421c..4d81e98363 100644 --- a/stark/miden-lifted-stark/src/lmcs/merkle_witness.rs +++ b/stark/miden-lifted-stark/src/lmcs/merkle_witness.rs @@ -35,7 +35,7 @@ impl MerkleWitness { /// /// `fetch_sibling` is called with the [`NodeId`] of each missing sibling, /// level-by-level, left-to-right, bottom-to-top, matching transcript order. - pub fn build( + pub(super) fn build( leaves: impl IntoIterator, tree_depth: usize, mut fetch_sibling: impl FnMut(NodeId) -> Result, @@ -95,6 +95,10 @@ impl MerkleWitness { /// Authentication path for a leaf index (sibling hashes, bottom-to-top). 
pub fn path(&self, index: usize) -> Option> { + if index >= (1usize << self.tree_depth) { + return None; + } + let mut path = Vec::with_capacity(self.tree_depth); let mut id = NodeId::new(self.tree_depth, index); for _ in 0..self.tree_depth { @@ -158,5 +162,6 @@ mod tests { // path: sibling hashes from leaf to root assert_eq!(tree.path(0).unwrap(), vec![2, 7]); assert_eq!(tree.path(3).unwrap(), vec![3, 3]); + assert!(tree.path(4).is_none()); } } diff --git a/stark/miden-lifted-stark/src/lmcs/mod.rs b/stark/miden-lifted-stark/src/lmcs/mod.rs index 422ea7c9c2..3b5c1b1c08 100644 --- a/stark/miden-lifted-stark/src/lmcs/mod.rs +++ b/stark/miden-lifted-stark/src/lmcs/mod.rs @@ -12,7 +12,7 @@ //! - [`Lmcs`]: Trait for LMCS configurations, providing type-erased access to commitment //! operations. //! - [`LmcsTree`]: Trait for built LMCS trees, providing opening operations. -//! - [`lifted_tree::LiftedMerkleTree`]: The underlying Merkle tree data structure. +//! - `lifted_tree::LiftedMerkleTree`: The underlying Merkle tree data structure. //! - [`proof::Proof`]: Single-opening proof with rows, optional salt, and authentication path. //! - [`proof::BatchProof`]: Batch opening data with Merkle witness for path extraction. //! @@ -69,23 +69,20 @@ //! This equivalence follows from the bit-reversal identity: for `r = N/n = 2^k`, //! `bitrev_N(i) mod n = bitrev_n(i >> k)`. 
-pub mod bitrev; pub mod config; pub mod hiding_config; -pub mod lifted_tree; +pub(crate) mod lifted_tree; pub mod merkle_witness; -pub mod node_id; +pub(crate) mod node_id; pub mod proof; pub mod row_list; -pub mod tree_indices; -pub mod utils; +pub(crate) mod tree_indices; #[cfg(test)] mod tests; use alloc::{collections::BTreeMap, vec::Vec}; -use bitrev::BitReversibleMatrix; use miden_stark_transcript::{ProverChannel, TranscriptError, VerifierChannel}; use p3_matrix::Matrix; use proof::BatchProofView; @@ -93,11 +90,13 @@ use row_list::RowList; use thiserror::Error; use tree_indices::TreeIndices; +use crate::util::{align::aligned_len, bitrev::BitReversibleMatrix}; + // ============================================================================ // Type Aliases // ============================================================================ -/// Opened rows keyed by leaf index, returned by [`Lmcs::open_batch`]. +/// Opened rows keyed by tree or query index, returned by LMCS opening APIs. pub type OpenedRows = BTreeMap>; // ============================================================================ @@ -114,13 +113,14 @@ pub trait Lmcs: Clone { type Commitment: Clone + Eq; /// Tree type (prover data), parameterized by stored matrix type. type Tree>: LmcsTree; - /// Batch witness type returned by [`read_batch_proof`](Self::read_batch_proof). + /// Batch witness type returned by [`read_batch_proof`](Self::read_batch_proof) and + /// [`read_lifted_batch_proof`](Self::read_lifted_batch_proof). type BatchProof: BatchProofView; /// Build a tree from domain-ordered matrices with no transcript padding (alignment = 1). /// /// The LMCS extracts the inner bit-reversed matrices via - /// [`BitReversibleMatrix::bit_reverse_rows`] and stores them. The tree is indexed + /// `BitReversibleMatrix::bit_reverse_rows` and stores them. The tree is indexed /// by domain order; [`LmcsTree::leaves`] returns the stored bit-reversed matrices. 
/// /// This affects only transcript hint formatting; the commitment root is unchanged. @@ -148,20 +148,22 @@ pub trait Lmcs: Clone { /// Compress two hashes into their parent (2-to-1 compression). fn compress(&self, left: Self::Commitment, right: Self::Commitment) -> Self::Commitment; - /// Open a batch proof by reading hint data from a transcript channel. + /// Open an exact batch proof by reading hint data from a transcript channel. /// /// The hint format is implementation-defined; callers must use the matching - /// `LmcsTree::prove_batch` implementation to produce compatible hints. + /// exact [`LmcsTree::prove_batch`] implementation to produce compatible hints. /// `widths` must match the committed tree (including any alignment padding - /// if `build_aligned_tree` was used). + /// if `build_aligned_tree` was used), and `indices` must already be in the + /// committed tree's own index space. /// /// # Preconditions - /// - `indices` must be non-empty and have depth matching `log₂(tree height)`. + /// - `indices` must be non-empty. + /// - `indices.depth()` must be the committed tree depth. /// /// # Postconditions - /// On success, the returned map contains exactly one entry per unique index. - /// Each entry's `RowList` has one row per width in `widths`, with that - /// row's length matching the corresponding width. + /// On success, the returned map is keyed by the exact tree indices — one + /// entry per unique index. Each entry's `RowList` has one row per width + /// in `widths`, with that row's length matching the corresponding width. fn open_batch( &self, commitment: &Self::Commitment, @@ -172,11 +174,37 @@ pub trait Lmcs: Clone { where Ch: VerifierChannel; - /// Parse a batch opening from transcript hints without verification. + /// Open a virtually lifted batch proof. + /// + /// `query_indices` live in the query domain. 
They are projected to + /// `tree_log_height`, opened with [`Self::open_batch`], then expanded back to the + /// original query indices. The returned map is keyed by the original query + /// indices, so callers can reduce all commitment groups uniformly. + /// + /// Returns [`LmcsError::InvalidProof`] if `tree_log_height > query_indices.depth()`. + fn open_lifted_batch( + &self, + commitment: &Self::Commitment, + widths: &[usize], + query_indices: &TreeIndices, + tree_log_height: u8, + channel: &mut Ch, + ) -> Result, LmcsError> + where + Ch: VerifierChannel, + { + let leaf_indices = query_indices.fold_to_depth(tree_log_height)?; + let rows_by_leaf = self.open_batch(commitment, widths, &leaf_indices, channel)?; + query_indices.expand_leaf_values(tree_log_height, &rows_by_leaf) + } + + /// Parse an exact batch opening from transcript hints without verification. /// /// Reads leaf openings and sibling hashes from the channel, hashes leaves, /// and reconstructs the Merkle witness. Does not verify against a commitment; - /// validation happens in [`open_batch`](Lmcs::open_batch). + /// validation happens in [`open_batch`](Lmcs::open_batch). The returned + /// witness and openings are keyed by the exact tree indices, because Merkle + /// paths are defined for leaves of the actual tree. /// /// Use [`merkle_witness::MerkleWitness::path`] on the returned witness to extract /// authentication paths. @@ -189,6 +217,27 @@ pub trait Lmcs: Clone { where Ch: VerifierChannel; + /// Parse a virtually lifted batch opening from transcript hints. + /// + /// `query_indices` are projected to `tree_log_height`, then parsed with + /// [`Self::read_batch_proof`]. The returned proof is the same `BatchProof` + /// type as the exact parser and remains keyed by projected tree indices. + /// + /// Returns [`LmcsError::InvalidProof`] if `tree_log_height > query_indices.depth()`. 
+ fn read_lifted_batch_proof( + &self, + widths: &[usize], + query_indices: &TreeIndices, + tree_log_height: u8, + channel: &mut Ch, + ) -> Result + where + Ch: VerifierChannel, + { + let leaf_indices = query_indices.fold_to_depth(tree_log_height)?; + self.read_batch_proof(widths, &leaf_indices, channel) + } + /// Get the alignment used by `build_aligned_tree`. /// /// This is the hasher's rate, used to pad rows when streaming hints. @@ -227,19 +276,35 @@ pub trait LmcsTree { /// Get aligned widths for each committed matrix (padded to alignment). fn aligned_widths(&self) -> Vec { let alignment = self.alignment(); - self.widths().into_iter().map(|w| utils::aligned_len(w, alignment)).collect() + self.widths().into_iter().map(|w| aligned_len(w, alignment)).collect() } - /// Prove a batch opening and stream it into a transcript channel. + /// Prove an exact batch opening and stream it into a transcript channel. /// /// The hint format is implementation-defined and must be consumed by the - /// corresponding `Lmcs::open_batch` implementation. Rows are padded to the - /// tree's alignment before being written to the channel. + /// corresponding exact `Lmcs::open_batch` implementation. Rows are padded to + /// the tree's alignment before being written to the channel. `indices` must + /// already be in this tree's own index space. /// /// Leaf openings are written in **sorted tree index order** (ascending, deduplicated). fn prove_batch(&self, indices: &TreeIndices, channel: &mut Ch) where Ch: ProverChannel; + + /// Prove a virtually lifted batch opening. + /// + /// Projects `query_indices` to this tree's depth and then delegates to exact + /// [`Self::prove_batch`]. 
+ fn prove_lifted_batch(&self, query_indices: &TreeIndices, channel: &mut Ch) + where + Ch: ProverChannel, + { + let tree_log_height = miden_lifted_air::log2_strict_u8(self.height()); + let leaf_indices = query_indices + .fold_to_depth(tree_log_height) + .expect("query index depth must be at least the committed tree depth"); + self.prove_batch(&leaf_indices, channel); + } } // ============================================================================ diff --git a/stark/miden-lifted-stark/src/lmcs/node_id.rs b/stark/miden-lifted-stark/src/lmcs/node_id.rs index 35331d31d1..d40dfdec13 100644 --- a/stark/miden-lifted-stark/src/lmcs/node_id.rs +++ b/stark/miden-lifted-stark/src/lmcs/node_id.rs @@ -12,7 +12,7 @@ /// - `depth()` = `ilog2(id)` /// - `position()` = `id − 2^depth` #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct NodeId(usize); +pub(super) struct NodeId(usize); impl NodeId { /// Create a node ID from a (depth, position) pair. @@ -26,25 +26,25 @@ impl NodeId { /// The depth in the tree (0 = root). #[inline] - pub const fn depth(&self) -> usize { + pub const fn depth(self) -> usize { self.0.ilog2() as usize } /// The position within the depth level. #[inline] - pub const fn position(&self) -> usize { + pub const fn position(self) -> usize { self.0 - (1 << self.depth()) } /// The sibling node (same depth, position XOR 1). #[inline] - pub const fn sibling(&self) -> Self { + pub const fn sibling(self) -> Self { Self(self.0 ^ 1) } /// The parent node (depth − 1, position >> 1). #[inline] - pub const fn parent(&self) -> Self { + pub const fn parent(self) -> Self { Self(self.0 >> 1) } } diff --git a/stark/miden-lifted-stark/src/lmcs/proof.rs b/stark/miden-lifted-stark/src/lmcs/proof.rs index d542b130e7..527d04433b 100644 --- a/stark/miden-lifted-stark/src/lmcs/proof.rs +++ b/stark/miden-lifted-stark/src/lmcs/proof.rs @@ -1,7 +1,7 @@ //! LMCS proof structures. //! //! 
- [`Proof`]: Single-opening proof with rows, optional salt, and authentication path. -//! - [`BatchProof`]: Batch opening data with per-index rows/salt and a [`MerkleWitness`]. +//! - [`BatchProof`]: Batch opening data with per-leaf rows/salt and a [`MerkleWitness`]. //! //! Use [`Lmcs::read_batch_proof`] to parse transcript hints //! into a [`BatchProof`] without verifying against a commitment. @@ -29,8 +29,8 @@ pub struct Proof { /// Batch opening data parsed from transcript hints without verification. /// -/// Bundles opened leaf data (rows + salt) per index with the reconstructed -/// [`MerkleWitness`] for authentication path queries. +/// Bundles opened leaf data (rows + salt) with the reconstructed [`MerkleWitness`] for +/// authentication path queries. Indices here are leaf indices after any virtual-lift folding. pub struct BatchProof { /// Opened leaf data keyed by leaf index. pub openings: BTreeMap>, @@ -40,7 +40,7 @@ pub struct BatchProof { /// Accessor trait for batch proof data. /// -/// Provides read access to individual openings, authentication paths, and query indices. +/// Provides read access to individual openings, authentication paths, and leaf indices. /// This allows consumers (e.g. the Miden VM recursive verifier) to work with batch proofs /// through the opaque `Lmcs::BatchProof` associated type. pub trait BatchProofView { @@ -55,7 +55,7 @@ pub trait BatchProofView { /// Get the authentication path (bottom-to-top sibling hashes) for a given leaf index. fn path(&self, index: usize) -> Option>; - /// Iterate over the unique query indices (in sorted order). + /// Iterate over the unique leaf indices (in sorted order). 
fn indices(&self) -> impl Iterator + '_; } @@ -140,14 +140,13 @@ impl LeafOpening { #[cfg(test)] mod tests { + use miden_lifted_air::log2_strict_u8; use p3_matrix::dense::RowMajorMatrix; use rand::{SeedableRng, rngs::SmallRng}; use super::*; use crate::{ - lmcs::{ - LmcsTree, tests::roundtrip_open_batch, tree_indices::TreeIndices, utils::log2_strict_u8, - }, + lmcs::{LmcsTree, tests::roundtrip_open_batch, tree_indices::TreeIndices}, testing::configs::goldilocks_poseidon2 as gl, }; diff --git a/stark/miden-lifted-stark/src/lmcs/row_list.rs b/stark/miden-lifted-stark/src/lmcs/row_list.rs index 93f72d8064..1315826b43 100644 --- a/stark/miden-lifted-stark/src/lmcs/row_list.rs +++ b/stark/miden-lifted-stark/src/lmcs/row_list.rs @@ -2,7 +2,7 @@ use alloc::vec::Vec; -use crate::lmcs::utils::aligned_len; +use crate::util::align::aligned_len; /// Flat storage of variable-width rows. /// diff --git a/stark/miden-lifted-stark/src/lmcs/tests.rs b/stark/miden-lifted-stark/src/lmcs/tests.rs index b19d71014e..3aacb6281f 100644 --- a/stark/miden-lifted-stark/src/lmcs/tests.rs +++ b/stark/miden-lifted-stark/src/lmcs/tests.rs @@ -8,17 +8,17 @@ use gl::{ }; use hiding_config::HidingLmcsConfig; use lifted_tree::LiftedMerkleTree; +use miden_lifted_air::log2_strict_u8; use miden_stateful_hasher::{Alignable, StatefulHasher}; use p3_field::PrimeCharacteristicRing; use p3_matrix::dense::RowMajorMatrix; use rand::{RngExt, SeedableRng, rngs::SmallRng}; -use utils::{aligned_len, log2_strict_u8}; use super::*; // ============================================================================ // Test Helpers and Re-exports // ============================================================================ -use crate::testing::configs::goldilocks_poseidon2 as gl; +use crate::{testing::configs::goldilocks_poseidon2 as gl, util::align::aligned_len}; type OpenedRows = BTreeMap>; diff --git a/stark/miden-lifted-stark/src/lmcs/tree_indices.rs b/stark/miden-lifted-stark/src/lmcs/tree_indices.rs index 
907a9f6b17..786b5830a6 100644 --- a/stark/miden-lifted-stark/src/lmcs/tree_indices.rs +++ b/stark/miden-lifted-stark/src/lmcs/tree_indices.rs @@ -1,19 +1,20 @@ -//! Validated Merkle tree leaf indices and missing sibling iteration. +//! Validated Merkle tree indices and missing sibling iteration. //! -//! [`TreeIndices`] bundles a sorted, deduplicated set of leaf indices with -//! the tree depth, enforcing the invariant that every index is in `0..2^depth`. +//! [`TreeIndices`] bundles a sorted, deduplicated set of domain indices with +//! their source depth, enforcing the invariant that every index is in `0..2^depth`. //! //! [`MissingSiblingsIter`] walks the tree upward from a set of leaf positions //! and yields the sibling nodes absent from the set — exactly the nodes whose //! hashes must be provided to reconstruct the root. -use alloc::vec::Vec; +use alloc::{collections::BTreeMap, vec::Vec}; use crate::lmcs::{LmcsError, node_id::NodeId}; -/// A validated set of Merkle tree leaf indices at a given depth. +/// A validated set of Merkle tree indices at a given source depth. /// -/// Invariants (enforced by [`new`](Self::new) and [`shrink_depth`](Self::shrink_depth)): +/// Invariants (enforced by [`new`](Self::new), [`fold_to_depth`](Self::fold_to_depth), +/// and [`shrink_depth`](Self::shrink_depth)): /// - `indices` is sorted ascending with no duplicates. /// - Every index satisfies `index < 2^depth`. /// @@ -64,30 +65,77 @@ impl TreeIndices { } /// Iterator over sibling nodes absent from this leaf set, bottom-to-top. - pub fn missing_siblings(&self) -> MissingSiblingsIter { + pub(super) fn missing_siblings(&self) -> MissingSiblingsIter { MissingSiblingsIter::new(&self.indices, self.depth) } - /// Map domain indices to folded domain indices `shift` levels down, in place. + /// Fold these source-domain indices onto a committed tree at `target_depth`. 
/// - /// In natural (domain) order, folding by `2^shift` maps each index to its - /// low `(depth - shift)` bits: `index & ((1 << (depth - shift)) - 1)`. - /// The depth is reduced and duplicates (from indices in the same coset) - /// are removed. + /// In natural (domain) order, a query index opens the leaf selected by its low + /// `target_depth` bits: `query & ((1 << target_depth) - 1)`. Returns the + /// deduplicated leaf indices, leaving `self` unchanged. Use + /// [`expand_leaf_values`](Self::expand_leaf_values) to re-key leaf-keyed + /// results back to the original query indices. /// - /// Unlike the bit-reversed right-shift, masking can reorder indices, - /// so `sort_unstable()` is needed before `dedup()`. + /// Returns `InvalidProof` if `target_depth` is above the current source depth. + pub fn fold_to_depth(&self, target_depth: u8) -> Result { + let mut leaf_indices = self.clone(); + leaf_indices.fold_in_place(target_depth)?; + Ok(leaf_indices) + } + + /// Re-key leaf-keyed values back to the original query indices. /// - /// Shrinking a root-only set (depth 0) has no effect. - pub fn shrink_depth(&mut self, shift: u8) { - let new_depth = self.depth.saturating_sub(shift); - let mask = (1usize << new_depth as usize) - 1; + /// Each query index maps to the leaf selected by its low `target_depth` bits; + /// `target_depth` must match the depth passed to the [`fold_to_depth`](Self::fold_to_depth) + /// that produced `leaf_values`. Returns `InvalidProof` if a required leaf is absent. + pub fn expand_leaf_values( + &self, + target_depth: u8, + leaf_values: &BTreeMap, + ) -> Result, LmcsError> { + let leaf_mask = (1usize << target_depth as usize) - 1; + self.indices + .iter() + .map(|&query| { + let value = leaf_values.get(&(query & leaf_mask)).ok_or(LmcsError::InvalidProof)?; + Ok((query, value.clone())) + }) + .collect() + } + + /// Map domain indices to folded domain indices at `target_depth`, in place. 
+ /// + /// In natural (domain) order, folding maps each index to its low `target_depth` + /// bits: `index & ((1 << target_depth) - 1)`. The depth is reduced and duplicates + /// from indices in the same coset are removed. Returns `InvalidProof` if + /// `target_depth` is above the current source depth. + fn fold_in_place(&mut self, target_depth: u8) -> Result<(), LmcsError> { + if target_depth > self.depth { + return Err(LmcsError::InvalidProof); + } + + let mask = (1usize << target_depth as usize) - 1; for idx in &mut self.indices { *idx &= mask; } self.indices.sort_unstable(); self.indices.dedup(); - self.depth = new_depth; + self.depth = target_depth; + Ok(()) + } + + /// Map domain indices to folded domain indices `shift` levels down, in place. + /// + /// Shifts at or beyond the current depth collapse every index to depth 0, + /// where the tree has a single root/leaf at index 0 (for example, the + /// commitment tree of a one-row matrix). Use + /// [`fold_to_depth`](Self::fold_to_depth) when invalid target depths + /// should be rejected instead of saturated. + pub fn shrink_depth(&mut self, shift: u8) { + let target_depth = self.depth.saturating_sub(shift); + self.fold_in_place(target_depth) + .expect("target depth is derived from current depth"); } } @@ -122,7 +170,7 @@ impl TreeIndices { /// /// The gap never closes because each pair of siblings produces at most one /// parent, so `next_len` grows slower than `current.start` advances. -pub struct MissingSiblingsIter { +pub(super) struct MissingSiblingsIter { /// Shared buffer: `nodes[current]` are unprocessed nodes in the current /// layer; `nodes[..next_len]` accumulates their parents for the next layer. nodes: Vec, @@ -214,6 +262,29 @@ mod tests { assert!(TreeIndices::new([0, 4], 2).is_err()); } + #[test] + fn fold_to_depth_then_expand_leaf_values() { + // Query indices at depth 3 fold to their low 2 bits: 0,4→0; 5→1; 7→3. 
+ let ti = TreeIndices::new([0, 4, 5, 7], 3).unwrap(); + let leaves = ti.fold_to_depth(2).unwrap(); + assert_eq!(leaves.iter().copied().collect::>(), [0, 1, 3]); + assert_eq!(leaves.depth(), 2); + + // Folding leaves the source untouched and rejects a depth above it. + assert_eq!(ti.iter().copied().collect::>(), [0, 4, 5, 7]); + assert_eq!(ti.depth(), 3); + assert!(ti.fold_to_depth(4).is_err()); + + // Each query index reads its low-2-bit leaf: 0,4→'a'; 5→'b'; 7→'c'. + let leaf_values: BTreeMap = + [(0, 'a'), (1, 'b'), (3, 'c')].into_iter().collect(); + let expanded = ti.expand_leaf_values(2, &leaf_values).unwrap(); + assert_eq!( + expanded.into_iter().collect::>(), + vec![(0, 'a'), (4, 'a'), (5, 'b'), (7, 'c')] + ); + } + #[test] fn shrink_depth() { // Domain indices [4,5,6,7] at depth 3: low-bit mask with new_depth=1 → mask=1. diff --git a/stark/miden-lifted-stark/src/lmcs/utils.rs b/stark/miden-lifted-stark/src/lmcs/utils.rs deleted file mode 100644 index 8ef58fc142..0000000000 --- a/stark/miden-lifted-stark/src/lmcs/utils.rs +++ /dev/null @@ -1,53 +0,0 @@ -//! Utility functions for LMCS operations. - -use alloc::vec::Vec; -use core::array; - -use p3_field::PackedValue; -use p3_util::log2_strict_usize; - -/// Strict log₂ returning `u8`. -/// -/// Panics if `n` is not a power of two. -#[inline] -pub fn log2_strict_u8(n: usize) -> u8 { - log2_strict_usize(n) as u8 -} - -/// Extension trait for `PackedValue` providing columnar pack/unpack operations. -/// -/// These methods perform transpose operations on packed data, useful for -/// SIMD-parallelized Merkle tree construction. -pub trait PackedValueExt: PackedValue { - /// Pack columns from `WIDTH` rows of scalar values. - /// - /// Given `WIDTH` rows of `N` scalar values, extract each column and pack it - /// into a single packed value. This performs a transpose operation. 
- #[inline] - #[must_use] - fn pack_columns(rows: &[[Self::Value; N]]) -> [Self; N] { - assert_eq!(rows.len(), Self::WIDTH); - array::from_fn(|col| Self::from_fn(|lane| rows[lane][col])) - } -} - -// Blanket implementation for all PackedValue types -impl PackedValueExt for T {} - -/// Compute the aligned length for `len` given an alignment. -#[inline] -pub const fn aligned_len(len: usize, alignment: usize) -> usize { - if alignment <= 1 { - len - } else { - len.next_multiple_of(alignment) - } -} - -/// Align each width in place, returning the same `Vec`. -pub fn aligned_widths(mut widths: Vec, alignment: usize) -> Vec { - for w in &mut widths { - *w = aligned_len(*w, alignment); - } - widths -} diff --git a/stark/miden-lifted-stark/src/order.rs b/stark/miden-lifted-stark/src/order.rs new file mode 100644 index 0000000000..008595c37d --- /dev/null +++ b/stark/miden-lifted-stark/src/order.rs @@ -0,0 +1,408 @@ +//! Internal instance↔proof ordering helper: the crate-internal [`TraceOrder`] +//! plus the public [`ShapeError`]. +//! +//! The air crate's [`MultiAir`](miden_lifted_air::MultiAir) trait is order-agnostic — every list it +//! exposes is in **instance order** (the position returned by +//! [`MultiAir::airs`](miden_lifted_air::MultiAir::airs)). [`TraceOrder`] carries the permutation +//! between instance order and the proof's wire-format **proof order** (a deterministic stable sort +//! of the per-AIR heights), and validates the proof-supplied log heights against the AIRs at +//! construction. 
+ +extern crate alloc; + +use alloc::vec::Vec; + +use miden_lifted_air::{LiftedAir, log2_strict_u8}; +use p3_challenger::CanObserve; +use p3_field::Field; +use thiserror::Error; + +// ============================================================================ +// TraceOrder +// ============================================================================ + +/// The permutation between **instance order** (AIR positions from +/// [`MultiAir::airs`](miden_lifted_air::MultiAir::airs)) and **proof order**, the wire-format +/// ordering used inside the prover/verifier. +/// +/// Proof order is the stable sort of instance indices by `(log_trace_height, +/// instance_index)`. Both sides recompute it from the heights, so the proof commits +/// to heights only. Use [`Self::to_proof_order`] / [`Self::to_instance_order`] (or +/// [`Self::reorder_to_proof_in_place`]) to move data between the two views. +#[derive(Clone, Debug)] +pub(crate) struct TraceOrder { + /// Log trace heights in instance order. + log_heights: Vec, + /// `instance_indices[j]` = instance index at proof position `j`. Length + /// matches `log_heights`. + instance_indices: Vec, +} + +impl TraceOrder { + /// Build from raw (non-log) trace heights in instance order, validated + /// against `airs`. + /// + /// Validates that every height is a power of two and at least 2, that the + /// log-height fits in `u8` and within the host's `usize` width, that the + /// number of instances fits in `u8`, and (via [`Self::from_log_heights`]) + /// that the heights match the AIRs. 
+ pub(crate) fn from_trace_heights( + airs: &[A], + trace_heights: &[usize], + ) -> Result + where + F: Field, + A: LiftedAir, + { + if trace_heights.is_empty() { + return Err(ShapeError::Empty); + } + if trace_heights.len() > u8::MAX as usize + 1 { + return Err(ShapeError::TooManyInstances { count: trace_heights.len() }); + } + let log_heights: Vec = trace_heights + .iter() + .map(|&h| { + if !h.is_power_of_two() { + return Err(ShapeError::InvalidTraceHeight { height: h }); + } + Ok(log2_strict_u8(h)) + }) + .collect::>()?; + Self::from_log_heights::(airs, log_heights) + } + + /// Build from instance-order log trace heights, validated against `airs`. + /// + /// Used on the verifier side, where heights are read straight off the + /// (untrusted) proof as `u8`s. Power-of-two-ness is automatic (heights are + /// stored as log₂). Checks: non-emptiness, at least 2 rows per trace, + /// host-`usize` bound, the u8 instance-count limit, `airs.len()` matches the + /// height count, and per AIR `(1 << log_h) >= air.max_periodic_length()`. + /// Holding a `TraceOrder` thus guarantees the proof's heights are feasible + /// for the AIRs. 
+ pub(crate) fn from_log_heights( + airs: &[A], + log_heights: Vec, + ) -> Result + where + F: Field, + A: LiftedAir, + { + if log_heights.is_empty() { + return Err(ShapeError::Empty); + } + if log_heights.len() > u8::MAX as usize + 1 { + return Err(ShapeError::TooManyInstances { count: log_heights.len() }); + } + let max_log = (usize::BITS - 1) as u8; + for (idx, &h) in log_heights.iter().enumerate() { + if h == 0 { + return Err(ShapeError::TraceHeightTooSmall { air: idx }); + } + if h > max_log { + return Err(ShapeError::LogTraceHeightTooLarge { log_h: h, max: max_log }); + } + } + if airs.len() != log_heights.len() { + return Err(ShapeError::TraceCountMismatch { + airs: airs.len(), + heights: log_heights.len(), + }); + } + for (idx, (air, &log_h)) in airs.iter().zip(log_heights.iter()).enumerate() { + let trace_height = 1usize << log_h as usize; + let max_period = air.max_periodic_length(); + if trace_height < max_period { + return Err(ShapeError::TraceHeightBelowPeriod { + air: idx, + trace_height, + max_period, + }); + } + } + let n = log_heights.len(); + // `0..n as u8` would wrap to an empty range at the boundary n == 256 + // (`256 as u8 == 0`), which the `TooManyInstances` guard above permits. + let mut instance_indices: Vec = (0..n).map(|i| i as u8).collect(); + instance_indices.sort_by_key(|&i| (log_heights[i as usize], i)); + Ok(Self { log_heights, instance_indices }) + } + + /// Number of AIR instances. + pub(crate) fn len(&self) -> usize { + self.log_heights.len() + } + + /// Log trace heights in instance order. Matches + /// [`MultiAir::airs`](miden_lifted_air::MultiAir::airs). + pub(crate) fn log_heights(&self) -> &[u8] { + &self.log_heights + } + + /// Instance indices in proof order: `instance_indices()[j]` is the + /// instance index of the AIR at proof position `j`. + pub(crate) fn instance_indices(&self) -> &[u8] { + &self.instance_indices + } + + /// Bind protocol-owned instance shape into Fiat-Shamir. 
+ /// + /// The instance count is observed first, followed by log trace heights in + /// instance order. Proof order is derived deterministically from these + /// heights, so it is not observed separately. + pub(crate) fn observe_shape(&self, challenger: &mut C) + where + F: Field, + C: CanObserve, + { + challenger.observe(F::from_usize(self.len())); + for &log_h in self.log_heights() { + challenger.observe(F::from_u8(log_h)); + } + } + + /// Log trace heights in proof order (ascending by construction). + pub(crate) fn log_heights_proof(&self) -> Vec { + self.instance_indices.iter().map(|&i| self.log_heights[i as usize]).collect() + } + + /// The largest log trace height (= last entry of [`Self::log_heights_proof`]). + pub(crate) fn max_log_height(&self) -> u8 { + // `instance_indices` is non-empty (constructor rejects empty input). + let last = *self.instance_indices.last().expect("TraceOrder is non-empty"); + self.log_heights[last as usize] + } + + /// Reorder instance-order data to proof order, cloning. + /// + /// Returns a `Vec` of length [`Self::len`] where position `j` holds + /// `instance_data[instance_indices()[j]]`. + pub(crate) fn to_proof_order(&self, instance_data: &[T]) -> Vec { + debug_assert_eq!(instance_data.len(), self.len()); + self.instance_indices + .iter() + .map(|&i| instance_data[i as usize].clone()) + .collect() + } + + /// Permute `data` in place from instance order to proof order. + /// + /// After the call, `data[j] == data_original[instance_indices()[j]]`. + /// Avoids the clone in [`Self::to_proof_order`] for owned data like + /// `RowMajorMatrix`. + pub(crate) fn reorder_to_proof_in_place(&self, data: &mut [T]) { + assert_eq!(data.len(), self.len()); + let n = self.len(); + let perm = &self.instance_indices; + let mut visited = alloc::vec![false; n]; + // Cycle decomposition: each cycle of the permutation is rotated in + // place via swaps along the cycle. 
+ for start in 0..n { + if visited[start] { + continue; + } + visited[start] = true; + let mut current = start; + loop { + let next = perm[current] as usize; + if next == start { + break; + } + data.swap(current, next); + visited[next] = true; + current = next; + } + } + } + + /// Reorder proof-order data back to instance order, cloning. + /// + /// Returns a `Vec` of length [`Self::len`] where position `i` holds the + /// element at the proof position whose instance index is `i`. + pub(crate) fn to_instance_order(&self, proof_data: &[T]) -> Vec { + debug_assert_eq!(proof_data.len(), self.len()); + let n = self.len(); + let mut out: Vec> = (0..n).map(|_| None).collect(); + for (j, &i) in self.instance_indices.iter().enumerate() { + out[i as usize] = Some(proof_data[j].clone()); + } + out.into_iter() + .map(|o| o.expect("instance_indices is a permutation of 0..n")) + .collect() + } + + /// AIR instance index backing each committed preprocessed trace. + /// + /// The preprocessed commitment contains one committed LDE trace per AIR with + /// [`preprocessed_width`](miden_lifted_air::LiftedAir::preprocessed_width) + /// `> 0`, in proof order (the LMCS height-monotone committed-trace order). The result + /// length is the number of preprocessed AIRs, which is `<= len()`. + pub(crate) fn preprocessed_air_for_trace_index(&self, airs: &[A]) -> Vec + where + F: Field, + A: LiftedAir, + { + self.instance_indices + .iter() + .copied() + .filter(|&i| airs[i as usize].preprocessed_width() > 0) + .collect() + } + + /// Preprocessed trace index for each AIR, or `None` when the AIR declares no + /// preprocessed columns. Length is [`Self::len`]; the inverse of + /// [`Self::preprocessed_air_for_trace_index`]. 
+ pub(crate) fn preprocessed_trace_index_for_air( + &self, + airs: &[A], + ) -> Vec> + where + F: Field, + A: LiftedAir, + { + let air_for_preprocessed_trace = self.preprocessed_air_for_trace_index::(airs); + let mut v = alloc::vec![None; airs.len()]; + for (preprocessed_trace_idx, &air_idx) in air_for_preprocessed_trace.iter().enumerate() { + v[air_idx as usize] = Some(preprocessed_trace_idx); + } + v + } +} + +// ============================================================================ +// Errors +// ============================================================================ + +/// Errors from parsing or validating proof shape metadata (the +/// caller-order `&[u8]` of log trace heights carried on the proof). +#[derive(Debug, Error)] +pub enum ShapeError { + #[error("no instances provided")] + Empty, + #[error("trace height {height} is not a power of two")] + InvalidTraceHeight { height: usize }, + #[error("AIR {air}: trace height must be at least 2 rows")] + TraceHeightTooSmall { air: usize }, + #[error("log trace height {log_h} exceeds {max} (would overflow usize on this target)")] + LogTraceHeightTooLarge { log_h: u8, max: u8 }, + #[error("more than 256 instances ({count}) — exceeds the u8 caller-index limit")] + TooManyInstances { count: usize }, + #[error("airs().len() = {airs} does not match log trace heights length {heights}")] + TraceCountMismatch { airs: usize, heights: usize }, + #[error( + "AIR {air}: trace height = {trace_height} is less than max periodic column \ + length {max_period}" + )] + TraceHeightBelowPeriod { + air: usize, + trace_height: usize, + max_period: usize, + }, +} + +#[cfg(test)] +mod tests { + use alloc::{vec, vec::Vec}; + + use miden_lifted_air::{BaseAir, LiftedAirBuilder}; + use p3_goldilocks::Goldilocks; + use p3_matrix::dense::RowMajorMatrix; + + use super::*; + + type TF = Goldilocks; + + /// Minimal AIR with no periodic columns, for the ordering tests (which only + /// exercise the height permutation, not the 
periodic-feasibility check). + #[derive(Clone)] + struct OrderTestAir; + + impl BaseAir for OrderTestAir { + fn width(&self) -> usize { + 1 + } + } + + impl LiftedAir for OrderTestAir { + fn num_randomness(&self) -> usize { + 0 + } + fn aux_width(&self) -> usize { + 1 + } + fn num_aux_values(&self) -> usize { + 0 + } + fn build_aux_trace( + &self, + _main: &RowMajorMatrix, + _air_inputs: &[TF], + _aux_inputs: &[TF], + _challenges: &[TF], + ) -> (RowMajorMatrix, Vec) { + // Unused: these tests only exercise the height permutation. + (RowMajorMatrix::new(Vec::new(), 1), Vec::new()) + } + fn eval>(&self, _builder: &mut AB) {} + } + + fn airs(n: usize) -> Vec { + vec![OrderTestAir; n] + } + + #[test] + fn trace_order_canonical_ordering() { + // Instance order: heights [8, 2, 8, 4]. Sort by (log_h, idx) → + // [1 (log=1), 3 (log=2), 0 (log=3), 2 (log=3)]. + let order = TraceOrder::from_trace_heights::(&airs(4), &[8, 2, 8, 4]).unwrap(); + assert_eq!(order.instance_indices(), &[1, 3, 0, 2]); + assert_eq!(order.log_heights(), &[3, 1, 3, 2]); + assert_eq!(order.log_heights_proof(), vec![1, 2, 3, 3]); + assert_eq!(order.max_log_height(), 3); + } + + #[test] + fn trace_order_roundtrip() { + let order = TraceOrder::from_trace_heights::(&airs(4), &[8, 2, 8, 4]).unwrap(); + let instance_data = vec!["a", "b", "c", "d"]; + let proof_data = order.to_proof_order(&instance_data); + assert_eq!(proof_data, vec!["b", "d", "a", "c"]); + let back = order.to_instance_order(&proof_data); + assert_eq!(back, instance_data); + } + + #[test] + fn trace_order_reorder_in_place_matches_clone() { + // Mix of singletons and longer cycles in the permutation, plus + // ties broken by instance index. 
+ let cases: &[&[usize]] = &[ + &[8, 2, 8, 4], + &[2, 4, 8, 16], // already sorted (identity permutation) + &[16, 8, 4, 2], // reverse-sorted + &[4, 4, 4, 4], // all equal (identity by tiebreak) + &[8, 2, 4, 16, 4], // mixed + ]; + for &heights in cases { + let order = + TraceOrder::from_trace_heights::(&airs(heights.len()), heights).unwrap(); + let instance_data: Vec = (0..heights.len()).collect(); + let expected = order.to_proof_order(&instance_data); + let mut data = instance_data; + order.reorder_to_proof_in_place(&mut data); + assert_eq!(data, expected, "in-place mismatch for heights {heights:?}"); + } + } + + #[test] + fn trace_order_accepts_max_instances() { + // The boundary n == 256 (= u8::MAX + 1) is the largest accepted count; + // index construction must not wrap `256 as u8` to an empty range. + let n = 256; + let order = TraceOrder::from_log_heights::(&airs(n), vec![1; n]).unwrap(); + assert_eq!(order.instance_indices().len(), n); + let mut seen = order.instance_indices().to_vec(); + seen.sort_unstable(); + assert!(seen.iter().copied().eq(0..=u8::MAX), "indices must be a permutation of 0..=255"); + } +} diff --git a/stark/miden-lifted-stark/src/pcs/deep/interpolate.rs b/stark/miden-lifted-stark/src/pcs/deep/interpolate.rs index 38aa1005cb..93c8761c15 100644 --- a/stark/miden-lifted-stark/src/pcs/deep/interpolate.rs +++ b/stark/miden-lifted-stark/src/pcs/deep/interpolate.rs @@ -205,7 +205,7 @@ mod tests { use alloc::vec; use p3_dft::{NaiveDft, TwoAdicSubgroupDft}; - use p3_field::{Field, PrimeCharacteristicRing}; + use p3_field::PrimeCharacteristicRing; use p3_interpolation::{interpolate_coset, interpolate_coset_with_precomputation}; use p3_matrix::{bitrev::BitReversibleMatrix, dense::RowMajorMatrix}; use p3_util::reverse_slice_index_bits; @@ -213,8 +213,11 @@ mod tests { use super::*; use crate::{ - pcs::utils::bit_reversed_coset_points, - testing::configs::goldilocks_poseidon2::{Felt, QuadFelt}, + domain::Coset, + testing::{ + canonical_domain, + 
configs::goldilocks_poseidon2::{Felt, QuadFelt}, + }, }; /// Verify `batch_eval_lifted` matches `interpolate_coset` for various lift factors. @@ -229,10 +232,11 @@ mod tests { let log_blowup = 2; let log_n = 8; // Full LDE domain size = 256 let n = 1 << log_n; - let shift = Felt::GENERATOR; + let domain = canonical_domain::(log_n, 0); + let shift = domain.lde_shift(); // Coset points in bit-reversed order for our barycentric evaluation - let coset_points_br = bit_reversed_coset_points::(log_n); + let coset_points_br = domain.lde_coset().bit_reversed_points(); // Random out-of-domain evaluation point let z: QuadFelt = rng.sample(StandardUniform); @@ -305,10 +309,11 @@ mod tests { let log_blowup = 2; let log_n = 8; let n = 1 << log_n; - let shift = Felt::GENERATOR; + let domain = canonical_domain::(log_n, 0); + let shift = domain.lde_shift(); // Coset points in both orderings - let coset_points_br = bit_reversed_coset_points::(log_n); + let coset_points_br = domain.lde_coset().bit_reversed_points(); let mut coset_points_std = coset_points_br.clone(); reverse_slice_index_bits(&mut coset_points_std); // Convert to standard order @@ -371,10 +376,11 @@ mod tests { let log_blowup = 2; let log_n = 8; let n = 1 << log_n; - let shift = Felt::GENERATOR; + let domain = canonical_domain::(log_n, 0); + let shift = domain.lde_shift(); // Coset points in bit-reversed order - let coset_points_br = bit_reversed_coset_points::(log_n); + let coset_points_br = domain.lde_coset().bit_reversed_points(); // Two random out-of-domain evaluation points let z1: QuadFelt = rng.sample(StandardUniform); @@ -454,9 +460,10 @@ mod tests { let log_blowup = 2; let log_n = 8; let n = 1 << log_n; - let shift = Felt::GENERATOR; + let domain = canonical_domain::(log_n, 0); + let shift = domain.lde_shift(); - let coset_points_br = bit_reversed_coset_points::(log_n); + let coset_points_br = domain.lde_coset().bit_reversed_points(); let z1: QuadFelt = rng.sample(StandardUniform); let z2: QuadFelt = 
rng.sample(StandardUniform); diff --git a/stark/miden-lifted-stark/src/pcs/deep/mod.rs b/stark/miden-lifted-stark/src/pcs/deep/mod.rs index 4eb8551bd1..da0a7ad462 100644 --- a/stark/miden-lifted-stark/src/pcs/deep/mod.rs +++ b/stark/miden-lifted-stark/src/pcs/deep/mod.rs @@ -49,10 +49,10 @@ //! - Prover LDE evaluation (`prover::DeepPoly::from_trees` via explicit dot-product with reversed //! negated coefficients — see comments there) -pub mod interpolate; -pub mod proof; -pub mod prover; -pub mod verifier; +pub(crate) mod interpolate; +pub(crate) mod proof; +pub(crate) mod prover; +pub(crate) mod verifier; use alloc::vec::Vec; @@ -66,7 +66,7 @@ use proof::OpenedValues; /// Controls proof-of-work grinding for DEEP challenge sampling. /// Column alignment is handled at the LMCS layer and by padding evaluations. #[derive(Clone, Copy, Debug)] -pub struct DeepParams { +pub(crate) struct DeepParams { /// Grinding bits before DEEP challenge sampling. pub(crate) deep_pow_bits: usize, } @@ -76,7 +76,7 @@ pub struct DeepParams { /// The prover sends one flat slice per evaluation point containing all matrices' /// column values concatenated. This function splits by widths and reshapes into /// per-group, per-matrix `RowMajorMatrix` with `num_eval_points` rows each. -pub fn read_eval_matrices( +fn read_eval_matrices( group_widths: &[&[usize]], num_eval_points: usize, channel: &mut Ch, diff --git a/stark/miden-lifted-stark/src/pcs/deep/proof.rs b/stark/miden-lifted-stark/src/pcs/deep/proof.rs index 1f16bf7583..890059042c 100644 --- a/stark/miden-lifted-stark/src/pcs/deep/proof.rs +++ b/stark/miden-lifted-stark/src/pcs/deep/proof.rs @@ -1,4 +1,4 @@ -//! DEEP transcript data structures. +//! DEEP structured proof types — parsed view of the DEEP sub-transcript. use alloc::vec::Vec; @@ -14,7 +14,7 @@ use crate::pcs::deep::{DeepParams, read_eval_matrices}; /// where `g` is the commitment group index and `m` the matrix index within that group. 
pub type OpenedValues = Vec>>; -/// Structured transcript view for the DEEP interaction. +/// Structured view of the DEEP sub-proof. /// /// This records the prover's PoW witness and the two challenges sampled /// from the Fiat-Shamir transcript after observing evaluations. @@ -22,7 +22,7 @@ pub type OpenedValues = Vec>>; /// `evals[g][m]` is a `RowMajorMatrix` with `num_eval_points` rows for /// commitment group `g`, matrix `m`. Widths include alignment padding (matching /// the committed rows). -pub struct DeepTranscript> { +pub struct DeepProof> { /// `evals[g][m]` is a `RowMajorMatrix` with `num_eval_points` rows. pub evals: OpenedValues, /// Proof-of-work witness sampled before DEEP challenges. @@ -33,19 +33,19 @@ pub struct DeepTranscript> { pub challenge_points: EF, } -impl DeepTranscript +impl DeepProof where F: TwoAdicField, EF: ExtensionField, { - /// Parse DEEP transcript data from a verifier channel. + /// Parse a [`DeepProof`] from a verifier channel. /// /// Reads OOD evaluations, verifies the PoW witness, and samples batching /// challenges. Does not verify the DEEP quotient itself; that validation /// happens in the DEEP verifier. Commitment widths must match the /// committed rows (including any alignment padding). 
- pub fn from_verifier_channel( - params: &DeepParams, + pub(in crate::pcs) fn read_from_channel( + params: DeepParams, commitments: &[(::Commitment, Vec)], num_eval_points: usize, channel: &mut Ch, diff --git a/stark/miden-lifted-stark/src/pcs/deep/prover.rs b/stark/miden-lifted-stark/src/pcs/deep/prover.rs index 710cbd51d9..72258c027a 100644 --- a/stark/miden-lifted-stark/src/pcs/deep/prover.rs +++ b/stark/miden-lifted-stark/src/pcs/deep/prover.rs @@ -10,15 +10,10 @@ use p3_maybe_rayon::prelude::*; use tracing::info_span; use crate::{ - lmcs::{ - Lmcs, LmcsTree, - row_list::RowList, - utils::{aligned_widths, log2_strict_u8}, - }, - pcs::{ - deep::{DeepParams, interpolate::PointQuotients}, - utils::{PackedFieldExtensionExt, bit_reversed_coset_points, horner}, - }, + domain::{Coset, LiftedDomain}, + lmcs::{Lmcs, LmcsTree, row_list::RowList}, + pcs::deep::{DeepParams, interpolate::PointQuotients}, + util::{align::aligned_widths, horner::horner, packing::PackedFieldExtensionExt}, }; /// The DEEP quotient `Q(X)` evaluated over the LDE domain. @@ -47,16 +42,20 @@ pub struct DeepPoly { impl DeepPoly { /// Construct `Q(X)` by evaluating trace trees at the opening points. /// - /// This computes the LDE coset points from the trace tree height, evaluates the committed + /// This computes the LDE coset points from `domain`, evaluates the committed /// matrices at `eval_points`, and then calls [`Self::from_evals`]. /// - /// Preconditions: `eval_points` must be distinct and lie outside the trace subgroup `H` - /// and LDE evaluation coset `gK`. The outer protocol is expected to enforce this. + /// Preconditions: + /// - `eval_points` must be distinct and lie outside the trace subgroup `H` and LDE evaluation + /// coset `gK`. The outer protocol is expected to enforce this. + /// - Every trace tree height must be a power of two `≤ domain.lde_height()`. Shorter trees are + /// virtually lifted to the max domain during batched evaluation. 
+ /// - At least one trace tree must have height `domain.lde_height()`. pub fn from_trees( params: DeepParams, + domain: &LiftedDomain, trace_trees: &[&L::Tree], eval_points: [EF; N], - log_blowup: u8, channel: &mut Ch, ) -> Self where @@ -66,14 +65,21 @@ impl DeepPoly { M: Matrix, Ch: ProverChannel, { - let lde_height = trace_trees.first().expect("at least one trace tree required").height(); + assert!(!trace_trees.is_empty(), "at least one trace tree required"); + let lde_height = domain.lde_height(); assert!( - trace_trees.iter().all(|tree| tree.height() == lde_height), - "mixed trace tree heights are not supported" + trace_trees + .iter() + .all(|tree| tree.height().is_power_of_two() && tree.height() <= lde_height), + "tree heights must be powers of two ≤ the max LDE height" + ); + assert!( + trace_trees.iter().any(|tree| tree.height() == lde_height), + "at least one tree must fill the max LDE height" ); - let log_lde_height = log2_strict_u8(lde_height); - let coset_points = bit_reversed_coset_points::(log_lde_height); + let coset_points = domain.lde_coset().bit_reversed_points(); + let log_blowup = domain.log_blowup(); let matrices_groups: Vec> = trace_trees.iter().map(|tree| tree.leaves().iter().collect()).collect(); @@ -131,9 +137,8 @@ impl DeepPoly { "mixed trace tree alignments are not supported" ); - // Collect the LDE matrices from each committed tree, grouped by commitment. - // matrices_groups[group_idx][matrix_idx] is a reference to the LDE matrix - // whose rows are bit-reversed coset evaluations at height `lde_height`. + // Collect LDE matrices grouped by commitment. Each matrix is at its + // committed tree height; shorter groups are lifted when combined/opened. let matrices_groups: Vec> = trace_trees.iter().map(|tree| tree.leaves().iter().collect()).collect(); @@ -207,23 +212,52 @@ impl DeepPoly { // full-domain allocation and improving cache locality. 
let deep_evals = info_span!("DEEP reduce + assemble").in_scope(|| { + // Fast path: when every matrix has the same height (the common single-AIR / + // uniform-height case for `prove_*`), accumulate every group directly into one + // shared buffer. + let uniform_height = + matrices_groups.iter().flat_map(|g| g.iter()).all(|m| m.height() == n); + let mut neg_column_coeffs_iter = neg_column_coeffs.iter(); - let mut neg_f_reduced = zip(matrices_groups.iter(), &group_sizes) - .map(|(matrices_group, &size)| { + let mut neg_f_reduced = if uniform_height { + let mut acc = EF::zero_vec(n); + let mut packed_coeffs: Vec = Vec::new(); + for (matrices_group, &size) in zip(matrices_groups.iter(), &group_sizes) { let group_coeffs: Vec<&Vec> = neg_column_coeffs_iter.by_ref().take(size).collect(); - accumulate_matrices(matrices_group, &group_coeffs) - }) - .reduce(|mut acc, next| { - debug_assert_eq!(acc.len(), next.len()); - acc.par_chunks_mut(w).zip(next.par_chunks(w)).for_each( - |(acc_chunk, next_chunk)| { - EF::add_slices(acc_chunk, next_chunk); - }, - ); - acc - }) - .unwrap_or_else(|| EF::zero_vec(n)); + for (&matrix, coeffs) in zip(matrices_group, group_coeffs.iter()) { + let active_coeffs = &coeffs[..matrix.width()]; + packed_coeffs.clear(); + packed_coeffs.extend(active_coeffs.chunks(w).map(|chunk| { + if chunk.len() == w { + EF::ExtensionPacking::from_ext_slice(chunk) + } else { + let mut padded = EF::zero_vec(w); + padded[..chunk.len()].copy_from_slice(chunk); + EF::ExtensionPacking::from_ext_slice(&padded) + } + })); + matrix + .rowwise_packed_dot_product::(&packed_coeffs) + .zip(acc.par_iter_mut()) + .for_each(|(dot_result, acc_val)| { + *acc_val += dot_result; + }); + } + } + acc + } else { + zip(matrices_groups.iter(), &group_sizes) + .map(|(matrices_group, &size)| { + let group_coeffs: Vec<&Vec> = + neg_column_coeffs_iter.by_ref().take(size).collect(); + accumulate_matrices(matrices_group, &group_coeffs) + }) + // Combine groups of different heights; `add_lifted` 
lifts the shorter + // buffer onto the taller one before adding. + .reduce(|a, b| add_lifted(w, a, b)) + .unwrap_or_else(|| EF::zero_vec(n)) + }; // Pre-compute βʲ for all N points let point_coeffs: [EF; N] = @@ -300,7 +334,11 @@ fn accumulate_matrices, M: Matrix, C: AsRef<[ let n = matrices.last().unwrap().height(); let mut acc = EF::zero_vec(n); - let mut scratch = EF::zero_vec(n); + // When all matrices in this group share a single height, + // it is a tad more efficient to skip preallocating the buffer. + let mut scratch: Vec = Vec::new(); + let w = F::Packing::WIDTH; + let mut packed_coeffs: Vec = Vec::new(); let mut active_height = matrices.first().unwrap().height(); @@ -317,6 +355,9 @@ fn accumulate_matrices, M: Matrix, C: AsRef<[ // Upsample: [a, b] → [a, a, b, b] when height doubles if height > active_height { + if scratch.len() < height { + scratch.resize(height, EF::ZERO); + } let scaling_factor = height / active_height; scratch[..height] .par_chunks_mut(scaling_factor) @@ -327,21 +368,18 @@ fn accumulate_matrices, M: Matrix, C: AsRef<[ // SIMD path using horizontal packing. // Slice to matrix width to avoid packing alignment-padding coefficients. 
- let w = F::Packing::WIDTH; let active_coeffs = &coeffs[..matrix.width()]; - let packed_coeffs: Vec = active_coeffs - .chunks(w) - .map(|chunk| { - if chunk.len() == w { - EF::ExtensionPacking::from_ext_slice(chunk) - } else { - // Pad with zeros for the last chunk - let mut padded = EF::zero_vec(w); - padded[..chunk.len()].copy_from_slice(chunk); - EF::ExtensionPacking::from_ext_slice(&padded) - } - }) - .collect(); + packed_coeffs.clear(); + packed_coeffs.extend(active_coeffs.chunks(w).map(|chunk| { + if chunk.len() == w { + EF::ExtensionPacking::from_ext_slice(chunk) + } else { + // Pad with zeros for the last chunk + let mut padded = EF::zero_vec(w); + padded[..chunk.len()].copy_from_slice(chunk); + EF::ExtensionPacking::from_ext_slice(&padded) + } + })); matrix .rowwise_packed_dot_product::(&packed_coeffs) @@ -356,6 +394,34 @@ fn accumulate_matrices, M: Matrix, C: AsRef<[ acc } +/// Sum two DEEP reduced-eval buffers whose power-of-two heights may differ, lifting +/// the shorter onto the taller by bit-reversed nearest-neighbor repetition +/// (`long[j*r + k] += short[j]`). Returns the taller buffer, reused in place — no fresh +/// allocation. `w` is the base-field packing width for the equal-height SIMD add. +/// +/// `reduce` folds groups left-to-right, so a short group (e.g. a setup-fixed +/// preprocessed tree, opened first) is lifted into the first full-height group it meets. +fn add_lifted(w: usize, a: Vec, b: Vec) -> Vec { + let (mut long, short) = if a.len() >= b.len() { (a, b) } else { (b, a) }; + debug_assert!( + !short.is_empty() && long.len() % short.len() == 0, + "DEEP group heights must be nested powers of two" + ); + let r = long.len() / short.len(); + if r == 1 { + long.par_chunks_mut(w) + .zip(short.par_chunks(w)) + .for_each(|(x, y)| EF::add_slices(x, y)); + } else { + // `short[j]` covers the `r` contiguous slots `[j*r, (j+1)*r)` of the taller + // buffer — the bit-reversed nearest-neighbor lift, fused with the add. 
+ long.par_chunks_mut(r) + .zip(short.par_iter()) + .for_each(|(chunk, &v)| chunk.iter_mut().for_each(|x| *x += v)); + } + long +} + #[cfg(test)] mod tests { use alloc::vec; diff --git a/stark/miden-lifted-stark/src/pcs/deep/tests.rs b/stark/miden-lifted-stark/src/pcs/deep/tests.rs index c7250c1df2..321f3b2080 100644 --- a/stark/miden-lifted-stark/src/pcs/deep/tests.rs +++ b/stark/miden-lifted-stark/src/pcs/deep/tests.rs @@ -2,17 +2,23 @@ use alloc::vec; -use proof::DeepTranscript; +use p3_util::reverse_bits_len; +use proof::DeepProof; use prover::DeepPoly; use rand::{RngExt, SeedableRng, distr::StandardUniform, prelude::SmallRng}; use verifier::DeepOracle; use super::*; use crate::{ + domain::LiftedDomain, lmcs::{Lmcs, LmcsTree, tree_indices::TreeIndices}, - testing::configs::goldilocks_poseidon2::{ - Felt, Lmcs as BaseLmcs, QuadFelt, prover_channel_with_commitment, test_lmcs, - verifier_channel_with_commitment, + pcs::verifier::CommitmentGroup, + testing::{ + canonical_domain, + configs::goldilocks_poseidon2::{ + Felt, Lmcs as BaseLmcs, QuadFelt, prover_channel_with_commitment, test_lmcs, + verifier_channel_with_commitment, + }, }, }; @@ -26,6 +32,7 @@ fn deep_quotient_end_to_end() { let log_blowup: u8 = 2; let log_lde_height: u8 = 10; let lde_height = 1 << log_lde_height as usize; + let max_domain: LiftedDomain = canonical_domain(log_lde_height - log_blowup, log_blowup); let params = DeepParams { deep_pow_bits: 1 }; // Two random opening points @@ -53,9 +60,9 @@ fn deep_quotient_end_to_end() { let trace_trees: &[&_] = &[&tree]; let deep_poly = DeepPoly::from_trees::( params, + &max_domain, trace_trees, [z1, z2], - log_blowup, &mut prover_channel, ); // Sample domain indices. The LMCS tree is indexed by domain order. 
@@ -66,12 +73,16 @@ fn deep_quotient_end_to_end() { let (prover_digest, transcript) = prover_channel.finalize(); // Create commitments slice for multi-commitment API (single commitment in this case) - let commitments = vec![(commitment, widths)]; + let commitments = vec![CommitmentGroup { + root: commitment, + widths, + log_height: log_lde_height, + }]; // Step 4: Verifier constructs DeepOracle with same transcript state let mut verifier_channel = verifier_channel_with_commitment(&transcript, &commitment); let (deep_oracle, _evals) = - DeepOracle::new(params, &[z1, z2], commitments, log_lde_height, &mut verifier_channel) + DeepOracle::new(params, &[z1, z2], commitments, &max_domain, &mut verifier_channel) .expect("DeepOracle construction should succeed"); // Step 5: Verify at multiple query tree indices (proofs are read from transcript) @@ -82,7 +93,7 @@ fn deep_quotient_end_to_end() { for &tree_idx in tree_indices.iter() { // Prover's deep_evals are in bit-reversed order internally: // deep_evals[bitrev(d)] = Q(g·ω^d). For domain index d, access bitrev(d). - let bitrev_idx = p3_util::reverse_bits_len(tree_idx, log_lde_height as usize); + let bitrev_idx = reverse_bits_len(tree_idx, log_lde_height as usize); let prover_eval = deep_poly.deep_evals[bitrev_idx]; let verifier_eval = verifier_evals[&tree_idx]; assert_eq!( @@ -94,14 +105,14 @@ fn deep_quotient_end_to_end() { let verifier_digest = verifier_channel.finalize().expect("transcript should finalize cleanly"); assert_eq!(prover_digest, verifier_digest); - // Re-parse DeepTranscript (DEEP phase only) from a fresh channel. + // Re-parse DeepProof (DEEP phase only) from a fresh channel. 
let reparse_commitments = vec![(commitment, tree.aligned_widths())]; let mut reparse_channel = verifier_channel_with_commitment(&transcript, &commitment); - DeepTranscript::::from_verifier_channel( - ¶ms, + DeepProof::::read_from_channel( + params, &reparse_commitments, 2, // num_eval_points &mut reparse_channel, ) - .expect("DeepTranscript re-parse should succeed"); + .expect("DeepProof re-parse should succeed"); } diff --git a/stark/miden-lifted-stark/src/pcs/deep/verifier.rs b/stark/miden-lifted-stark/src/pcs/deep/verifier.rs index c0b265fc36..b328688764 100644 --- a/stark/miden-lifted-stark/src/pcs/deep/verifier.rs +++ b/stark/miden-lifted-stark/src/pcs/deep/verifier.rs @@ -7,11 +7,13 @@ use p3_matrix::Matrix; use thiserror::Error; use crate::{ + domain::{Coset, LiftedDomain}, lmcs::{Lmcs, LmcsError, tree_indices::TreeIndices}, pcs::{ deep::{DeepParams, proof::OpenedValues, read_eval_matrices}, - utils::horner_acc, + verifier::CommitmentGroup, }, + util::horner::horner_acc, }; /// Verifier's view of the DEEP quotient as a point-query oracle. @@ -31,18 +33,19 @@ use crate::{ /// - reduces the opened row to `f_red(X)` using Horner with the same `α`, /// - reconstructs `Q(X)` and returns it to the FRI verifier. /// -/// Lifting is transparent at this layer: the prover commits to lifted codewords, so -/// every opened column is interpreted as a polynomial over the same max domain. -pub struct DeepOracle, L: Lmcs> { - /// Trace commitments with their widths (one per trace tree). +/// Full-height groups live on the max domain. Shorter groups are interpreted by +/// folding query indices to their committed depth. +pub(in crate::pcs) struct DeepOracle, L: Lmcs> { + /// Committed groups (root + widths + tree depth), one per tree. /// /// Widths must match the committed rows (including any alignment padding if - /// `build_aligned_tree` was used). - commitments: Vec<(L::Commitment, Vec)>, + /// `build_aligned_tree` was used). 
A group's `log_height` may be below the + /// query depth — a virtually-lifted, setup-fixed preprocessed tree. + commitments: Vec>, - /// Log₂ of the LDE domain height (tree has 2^log_lde_height leaves). - /// Verifier expects all commitments to be lifted to this same LDE height. - log_lde_height: u8, + /// Max LDE coset; query indices are sampled at `domain.log_lde_height()` and + /// folded down to each group's committed depth when shorter. + domain: LiftedDomain, /// Reduced openings: pairs of `(zⱼ, f_reduced(zⱼ))` from the prover's claims. reduced_openings: Vec<(EF, EF)>, @@ -59,28 +62,25 @@ impl, L: Lmcs> DeepOracle` with one row per evaluation point. pub fn new( params: DeepParams, eval_points: &[EF], - commitments: Vec<(L::Commitment, Vec)>, - log_lde_height: u8, + commitments: Vec>, + domain: &LiftedDomain, channel: &mut Ch, ) -> Result<(Self, OpenedValues), DeepError> where Ch: VerifierChannel, { - let group_widths: Vec<&[usize]> = commitments.iter().map(|(_, gw)| gw.as_slice()).collect(); + let group_widths: Vec<&[usize]> = commitments.iter().map(|g| g.widths.as_slice()).collect(); let evals = read_eval_matrices::(&group_widths, eval_points.len(), channel)?; // 1. Check grinding witness @@ -109,7 +109,7 @@ impl, L: Lmcs> DeepOracle, L: Lmcs> DeepOracle = tree_indices.iter().map(|&idx| (idx, EF::ZERO)).collect(); - for (group_idx, (commit, widths)) in self.commitments.iter().enumerate() { + for (group_idx, group) in self.commitments.iter().enumerate() { + // `open_lifted_batch` returns rows keyed by the original query indices, even when + // this group is committed at a shorter depth, so the reduction is uniform. let opened_rows = lmcs - .open_batch(commit, widths, tree_indices, channel) + .open_lifted_batch( + &group.root, + &group.widths, + tree_indices, + group.log_height, + channel, + ) .map_err(|source| DeepError::LmcsError { source, tree: group_idx })?; // Reduce opened rows via Horner: f_reduced(X) = Σᵢ αᵂ⁻¹⁻ⁱ · fᵢ(X). 
@@ -162,9 +170,6 @@ impl, L: Lmcs> DeepOracle, L: Lmcs> DeepOracle usize { + pub const fn arity(self) -> usize { 1 << self.log_arity as usize } #[inline] - pub const fn log_arity(&self) -> u8 { + pub const fn log_arity(self) -> u8 { self.log_arity } @@ -190,6 +190,7 @@ impl FriFold { pub mod tests { use alloc::vec::Vec; + use p3_dft::{NaiveDft, Radix2DFTSmallBatch, TwoAdicSubgroupDft}; use p3_field::{ExtensionField, Field, PrimeCharacteristicRing, TwoAdicField}; use p3_matrix::dense::RowMajorMatrix; use p3_util::reverse_slice_index_bits; @@ -201,11 +202,12 @@ pub mod tests { use super::*; use crate::{ - pcs::utils::horner, + domain::{Coset, TwoAdicSubgroup}, testing::{ configs::goldilocks_poseidon2::{Felt, QuadFelt}, params::{FRI_FOLD_ARITY_2, FRI_FOLD_ARITY_4, FRI_FOLD_ARITY_8}, }, + util::horner::horner, }; // Type alias for tests using packed fields @@ -216,8 +218,6 @@ pub mod tests { /// Generates a random polynomial, computes evaluations on a coset using NaiveDft, /// then verifies fold_evals correctly recovers f(β). fn test_fold_evals_naive_dft(fold: FriFold) { - use p3_dft::{NaiveDft, TwoAdicSubgroupDft}; - let mut rng = SmallRng::seed_from_u64(42); let arity = fold.arity(); @@ -258,15 +258,13 @@ pub mod tests { let rng = &mut SmallRng::seed_from_u64(1); let beta: Ext = rng.sample(StandardUniform); let arity = fold.arity(); - let log_arity = fold.log_arity() as usize; + let log_arity = fold.log_arity(); // Random polynomial of degree arity - 1 let poly: Vec = (0..arity).map(|_| rng.sample(StandardUniform)).collect(); - // Compute roots of unity in bit-reversed order for this arity - let mut roots: Vec = - Base::two_adic_generator(log_arity).powers().take(arity).collect(); - reverse_slice_index_bits(&mut roots); + // Roots of unity in bit-reversed order for this arity. 
+ let roots: Vec = TwoAdicSubgroup::::new(log_arity).bit_reversed_points(); let s: Base = rng.sample(StandardUniform); let s_inv = s.inverse(); @@ -324,13 +322,13 @@ pub mod tests { fn test_folding_preserves_low_degree(fold: FriFold) { let rng = &mut SmallRng::seed_from_u64(42); let arity = fold.arity(); - let log_arity = fold.log_arity() as usize; + let log_arity: u8 = fold.log_arity(); - let log_blowup = 2; - let log_poly_degree = 4; // degree 16 polynomial - let poly_degree = 1 << log_poly_degree; - let log_lde_size = log_poly_degree + log_blowup; - let lde_size = 1 << log_lde_size; + let log_blowup: u8 = 2; + let log_poly_degree: u8 = 4; // degree 16 polynomial + let poly_degree = 1usize << log_poly_degree; + let log_lde_size: u8 = log_poly_degree + log_blowup; + let lde_size = 1usize << log_lde_size; // Generate random low-degree polynomial let coeffs: Vec = (0..poly_degree).map(|_| rng.sample(StandardUniform)).collect(); @@ -338,14 +336,14 @@ pub mod tests { // Compute LDE in bit-reversed order let mut full_coeffs = coeffs; full_coeffs.resize(lde_size, QuadFelt::ZERO); - let dft = p3_dft::Radix2DFTSmallBatch::::default(); - let mut evals = p3_dft::TwoAdicSubgroupDft::dft_algebra(&dft, full_coeffs); + let dft = Radix2DFTSmallBatch::::default(); + let mut evals = dft.dft_algebra(full_coeffs); reverse_slice_index_bits(&mut evals); // Compute s_invs let log_num_cosets = log_lde_size - log_arity; - let num_cosets = 1 << log_num_cosets; - let g_inv = Felt::two_adic_generator(log_lde_size).inverse(); + let num_cosets = 1usize << log_num_cosets; + let g_inv = TwoAdicSubgroup::::new(log_lde_size).generator_inverse(); let mut s_invs: Vec = g_inv.powers().take(num_cosets).collect(); reverse_slice_index_bits(&mut s_invs); @@ -357,7 +355,7 @@ pub mod tests { // IDFT the result to get coefficients let mut folded_for_idft = folded; reverse_slice_index_bits(&mut folded_for_idft); - let folded_coeffs = p3_dft::TwoAdicSubgroupDft::idft_algebra(&dft, folded_for_idft); + let 
folded_coeffs = dft.idft_algebra(folded_for_idft); // Check that all coefficients beyond degree/arity are zero let expected_degree = poly_degree / arity; diff --git a/stark/miden-lifted-stark/src/pcs/fri/mod.rs b/stark/miden-lifted-stark/src/pcs/fri/mod.rs index b89373ac8f..174919ad86 100644 --- a/stark/miden-lifted-stark/src/pcs/fri/mod.rs +++ b/stark/miden-lifted-stark/src/pcs/fri/mod.rs @@ -5,36 +5,50 @@ //! //! ## Domain Convention //! -//! This FRI implementation treats inputs as evaluations over the unshifted two-adic subgroup. -//! If the PCS evaluates over a coset `gK`, the shift is absorbed into the polynomial: -//! `Q'(X) = Q(g·X)`. The low-degree test is run on `Q'` using subgroup points. +//! This FRI implementation treats inputs as evaluations over the unshifted two-adic +//! subgroup. If the PCS evaluates over a coset `gK`, the shift is absorbed into +//! the polynomial: `Q'(X) = Q(g·X)`. The low-degree test is run on `Q'` using +//! subgroup points. +//! +//! ## Type vocabulary +//! +//! FRI takes its initial domain as a [`LiftedDomain`](crate::domain::LiftedDomain) — +//! the protocol-level LDE object carrying both the LDE subgroup size and the blowup +//! ratio relative to the trace. The unshifted subgroup view is reached via +//! `domain.lde_coset().subgroup()`. Each fold round shrinks the working subgroup by the +//! folding arity, derived via [`TwoAdicSubgroup::shrink`](crate::domain::TwoAdicSubgroup::shrink) +//! (or per-round generator squaring inside the round loop, which is equivalent and avoids +//! re-querying `F::two_adic_generator`). Internal `arity`-th roots of unity used by the fold +//! operations come from `TwoAdicSubgroup::::new(log_arity).generator()`. Routing every +//! two-adic root and every multiplicative coset shift through these encapsulation types keeps +//! `F::two_adic_generator` and `F::GENERATOR` confined to their single canonical sites. 
-pub mod fold; -pub mod proof; -pub mod prover; -pub mod verifier; +pub(crate) mod fold; +pub(crate) mod proof; +pub(crate) mod prover; +pub(crate) mod verifier; use fold::FriFold; +use p3_field::TwoAdicField; + +use crate::domain::LiftedDomain; /// FRI protocol parameters. /// /// Controls the trade-off between proof size, prover time, and verifier time. /// -/// Higher `log_blowup` increases soundness per query (fewer queries needed) but introduces -/// larger Merkle trees — increasing both proof size (longer authentication paths) and prover time -/// (LDE over a larger domain). Higher arity reduces the number of FRI rounds (fewer Merkle -/// tree commitments) but increases per-query proof size (each opening reveals `arity` -/// siblings). `log_final_degree` reduces the number of rounds and therefore the number of -/// Merkle commitments; if too large, the final polynomial's coefficients dominate the proof -/// size. +/// Higher arity reduces the number of FRI rounds (fewer Merkle tree commitments) but increases +/// per-query proof size (each opening reveals `arity` siblings). `log_final_degree` reduces the +/// number of rounds and therefore the number of Merkle commitments; if too large, the final +/// polynomial's coefficients dominate the proof size. +/// +/// The LDE blowup factor is **not** stored here — it is a structural property of the codeword +/// being tested and is read from the [`LiftedDomain`](crate::domain::LiftedDomain) passed to +/// [`num_rounds`](Self::num_rounds), [`final_poly_degree`](Self::final_poly_degree), +/// [`FriPolys::new`](prover::FriPolys::new), and +/// [`FriOracle::new`](verifier::FriOracle::new). #[derive(Clone, Copy, Debug)] -pub struct FriParams { - /// Log₂ of the blowup factor (LDE domain size / polynomial degree). - /// - /// Higher values increase soundness but also proof size and prover time. - /// Typical values: 2-4 (blowup factors of 4-16). 
- pub(crate) log_blowup: u8, - +pub(crate) struct FriParams { /// The FRI folding strategy. /// /// Determines the folding arity (2, 4, or 8). @@ -52,45 +66,52 @@ pub struct FriParams { } impl FriParams { - /// Compute the number of folding rounds for a given initial evaluation domain size. + /// Compute the number of folding rounds for an LDE codeword evaluated on `domain`. /// /// Each round reduces the domain by `2^log_folding_factor`. We fold until the domain /// size reaches `2^(log_final_degree + log_blowup)`, at which point the polynomial /// degree is at most `2^log_final_degree`. /// /// Uses `div_ceil` to round up, ensuring we always reach the target degree even if - /// the domain size doesn't divide evenly by the folding factor. + /// the domain size doesn't divide evenly by the folding factor. `PcsParams::new` + /// rejects parameter sets whose target final domain is too small to be reachable by + /// fixed-arity folding for all valid domains. #[inline] - pub fn num_rounds(&self, log_domain_size: u8) -> usize { - // Final domain size = final_degree × blowup = 2^(log_final_degree + log_blowup). - // Safety: PcsParams::new() validates this sum does not exceed MAX_LOG_DOMAIN_SIZE. - debug_assert!( - (self.log_final_degree as u16 + self.log_blowup as u16) - <= crate::pcs::params::MAX_LOG_DOMAIN_SIZE as u16, - "log_final_degree + log_blowup overflows; construct FriParams via PcsParams::new()", - ); - let log_max_final_size = self.log_final_degree + self.log_blowup; - // Number of times we need to divide by 2^log_folding_factor - log_domain_size - .saturating_sub(log_max_final_size) - .div_ceil(self.fold.log_arity()) as usize + pub fn num_rounds(&self, domain: &LiftedDomain) -> usize { + // Maximum domain size needed to accommodate a degree-`2^log_final_degree` polynomial + // after folding `num_rounds` times. 
+ let log_max_final_size = u16::from(self.log_final_degree) + u16::from(domain.log_blowup()); + // Number of domain squarings required to reach a domain of size at most + // `log_max_final_size`. `saturating_sub` covers the degenerate "LDE already at or + // below target" case. + let num_steps = u16::from(domain.log_lde_height()).saturating_sub(log_max_final_size); + // Divide the number of steps by the folding factor to get the number of rounds. + // Round up so the final domain is ≤ `2^log_max_final_size` even when the + // folding factor doesn't divide `num_steps` evenly. The last round may + // overshoot, leaving the actual final degree strictly below the bound — + // see [`final_poly_degree`](Self::final_poly_degree). + num_steps.div_ceil(u16::from(self.fold.log_arity())) as usize } - /// Compute the final polynomial degree after folding. + /// Compute the final polynomial degree after folding the codeword evaluated on `domain`. /// - /// After `num_rounds` folding rounds, the domain shrinks from `2^log_domain_size` - /// to `2^(log_domain_size - num_rounds × log_folding_factor)`. The polynomial - /// degree is then `domain_size / blowup`. + /// After `num_rounds` folding rounds, the LDE domain shrinks from + /// `2^domain.log_lde_height()` to `2^(log_lde_height − num_rounds × log_folding_factor)`. + /// The polynomial degree is then `domain_size / blowup`. /// /// Due to `div_ceil` in `num_rounds`, the actual final degree may be smaller than /// `2^log_final_degree` when the folding doesn't divide evenly. 
#[inline] - pub fn final_poly_degree(&self, log_domain_size: u8) -> usize { - let num_rounds = self.num_rounds(log_domain_size); - // log of final domain size after folding - let log_final_size = log_domain_size as usize - num_rounds * self.fold.log_arity() as usize; - // degree = domain_size / blowup = 2^(log_final_size - log_blowup) - 1 << log_final_size.saturating_sub(self.log_blowup as usize) + pub fn final_poly_degree(&self, domain: &LiftedDomain) -> usize { + let num_rounds = self.num_rounds(domain); + // log of the final domain size: starting LDE shrunk by num_rounds folds of + // factor 2^log_arity. + let log_final_domain_size = (domain.log_lde_height() as usize) + .saturating_sub(num_rounds * self.fold.log_arity() as usize); + let log_final_poly_degree = + log_final_domain_size.saturating_sub(domain.log_blowup() as usize); + // Poly degree = final domain size / blowup. + 1 << log_final_poly_degree } } diff --git a/stark/miden-lifted-stark/src/pcs/fri/proof.rs b/stark/miden-lifted-stark/src/pcs/fri/proof.rs index 64c1849d6b..1b3cb2de28 100644 --- a/stark/miden-lifted-stark/src/pcs/fri/proof.rs +++ b/stark/miden-lifted-stark/src/pcs/fri/proof.rs @@ -1,14 +1,14 @@ -//! FRI transcript data structures. +//! FRI structured proof types — parsed view of the FRI sub-transcript. use alloc::vec::Vec; use miden_stark_transcript::{TranscriptError, VerifierChannel}; use p3_field::{ExtensionField, TwoAdicField}; -use crate::pcs::fri::FriParams; +use crate::{domain::LiftedDomain, pcs::fri::FriParams}; -/// Structured transcript view for a single FRI folding round. -pub struct FriRoundTranscript { +/// Structured view of a single FRI folding round. +pub struct FriRoundProof { /// Commitment to the folded evaluation matrix for this round. pub commitment: Commitment, /// Proof-of-work witness sampled before `beta`. @@ -17,35 +17,35 @@ pub struct FriRoundTranscript { pub beta: EF, } -/// Structured transcript view for the full FRI interaction. 
-pub struct FriTranscript { +/// Structured view of the full FRI sub-proof. +pub struct FriProof { /// Per-round commitments and challenges. - pub rounds: Vec>, + pub rounds: Vec>, /// Coefficients of the final low-degree polynomial in descending degree order /// `[cₙ, ..., c₁, c₀]`, ready for direct Horner evaluation. pub final_poly: Vec, } -impl FriTranscript +impl FriProof where F: TwoAdicField, EF: ExtensionField, Commitment: Clone, { - /// Parse a FRI transcript from a verifier channel. + /// Parse a [`FriProof`] from a verifier channel. /// /// Reads commitments, verifies PoW witnesses, samples challenges, and /// reads the final polynomial. Does not verify low-degree claims; /// that validation happens in `FriOracle::test_low_degree`. - pub fn from_verifier_channel( + pub(in crate::pcs) fn read_from_channel( params: &FriParams, - log_domain_size: u8, + domain: &LiftedDomain, channel: &mut Ch, ) -> Result where Ch: VerifierChannel, { - let num_rounds = params.num_rounds(log_domain_size); + let num_rounds = params.num_rounds(domain); let mut rounds = Vec::with_capacity(num_rounds); for _ in 0..num_rounds { @@ -54,10 +54,10 @@ where let pow_witness = channel.grind(params.folding_pow_bits)?; let beta: EF = channel.sample_algebra_element(); - rounds.push(FriRoundTranscript { commitment, pow_witness, beta }); + rounds.push(FriRoundProof { commitment, pow_witness, beta }); } - let final_degree = params.final_poly_degree(log_domain_size); + let final_degree = params.final_poly_degree(domain); let final_poly = channel.receive_algebra_slice(final_degree)?; Ok(Self { rounds, final_poly }) diff --git a/stark/miden-lifted-stark/src/pcs/fri/prover.rs b/stark/miden-lifted-stark/src/pcs/fri/prover.rs index 76c5bc3b31..08cfbfa917 100644 --- a/stark/miden-lifted-stark/src/pcs/fri/prover.rs +++ b/stark/miden-lifted-stark/src/pcs/fri/prover.rs @@ -9,7 +9,8 @@ use p3_util::reverse_slice_index_bits; use tracing::{debug_span, info_span}; use crate::{ - lmcs::{Lmcs, LmcsTree, 
tree_indices::TreeIndices, utils::log2_strict_u8}, + domain::{Coset, LiftedDomain}, + lmcs::{Lmcs, LmcsTree, tree_indices::TreeIndices}, pcs::fri::FriParams, }; @@ -89,19 +90,26 @@ where /// evaluations and sampling a random challenge `beta` — until the degree is small enough to /// send the polynomial directly. The query phase then spot-checks that each fold was /// performed correctly. - pub fn new(params: &FriParams, lmcs: &L, evals: Vec, channel: &mut Ch) -> Self + pub fn new( + params: &FriParams, + lmcs: &L, + domain: &LiftedDomain, + evals: Vec, + channel: &mut Ch, + ) -> Self where Ch: ProverChannel, { let log_arity = params.fold.log_arity(); let arity = params.fold.arity(); + let log_blowup = domain.log_blowup(); + let mut subgroup = *domain.lde_coset().subgroup(); let mut folded_trees = Vec::new(); - let mut domain_size = evals.len(); - let log_domain_size = log2_strict_u8(domain_size); - let final_poly_degree = params.final_poly_degree(log_domain_size); - let final_domain_size = final_poly_degree << params.log_blowup; + debug_assert_eq!(evals.len(), subgroup.size(), "evals length must match subgroup size"); + let final_poly_degree = params.final_poly_degree(domain); + let final_domain_size = final_poly_degree << log_blowup; // ───────────────────────────────────────────────────────────────────────── // Precompute s_inv for all cosets @@ -122,20 +130,28 @@ where // and g has order 2^log_domain_size. // // We generate sequential powers of g_inv and bit-reverse to get s_inv values - // in the correct order for each row. - let g_inv = F::two_adic_generator(log_domain_size as usize).inverse(); - let mut s_invs: Vec = g_inv.powers().take(domain_size >> log_arity as usize).collect(); + // in the correct order for each row. Per-round we shrink `s_invs` by selecting + // every `arity`-th element and raising to the `arity`-th power, which is + // equivalent to using the next subgroup's `generator_inverse` without paying + // for an inversion. 
+ let mut s_invs: Vec = subgroup + .generator_inverse() + .powers() + .take(subgroup.size() >> log_arity as usize) + .collect(); reverse_slice_index_bits(&mut s_invs); let mut folded_evals = evals; - while domain_size > final_domain_size { + while subgroup.size() > final_domain_size { let round = folded_trees.len(); + let domain_size = subgroup.size(); + let folded_subgroup = subgroup.shrink(log_arity); + // ───────────────────────────────────────────────────────────────────── // Reshape into matrix and wrap with FlatMatrixView for commitment // ───────────────────────────────────────────────────────────────────── - // domain_size evaluations → matrix with folded_domain_size rows × arity columns. + // domain_size evaluations → matrix with folded_subgroup.size() rows × arity columns. // FlatMatrixView presents the EF matrix as F matrix without copying. - let folded_domain_size = domain_size >> log_arity as usize; let matrix = RowMajorMatrix::new(folded_evals, arity); let flat_view = FlatMatrixView::new(matrix); // FRI round commitments use `build_tree` (unaligned) rather than @@ -172,7 +188,7 @@ where folded_trees.push(tree); // Output of folding becomes the input domain for the next round. - domain_size = folded_domain_size; + subgroup = folded_subgroup; // ───────────────────────────────────────────────────────────────────── // Update s⁻¹ for next round @@ -187,7 +203,7 @@ where // = s_inv[k·arity]^arity // // So we select every `arity`-th element and raise to power `arity`. 
- let next_folded_size = domain_size >> log_arity as usize; + let next_folded_size = subgroup.size() >> log_arity as usize; s_invs = (0..next_folded_size) .into_par_iter() .map(|k| s_invs[k * arity].exp_power_of_2(log_arity as usize)) diff --git a/stark/miden-lifted-stark/src/pcs/fri/tests.rs b/stark/miden-lifted-stark/src/pcs/fri/tests.rs index a6ea6a173c..ad6bdbf178 100644 --- a/stark/miden-lifted-stark/src/pcs/fri/tests.rs +++ b/stark/miden-lifted-stark/src/pcs/fri/tests.rs @@ -2,20 +2,24 @@ use alloc::{collections::BTreeMap, vec, vec::Vec}; +use miden_lifted_air::log2_strict_u8; use miden_stark_transcript::VerifierTranscript; use p3_challenger::CanObserve; use p3_dft::{Radix2DFTSmallBatch, TwoAdicSubgroupDft}; use p3_field::PrimeCharacteristicRing; use p3_matrix::{Matrix, bitrev::BitReversibleMatrix, dense::RowMajorMatrix}; -use proof::FriTranscript; +use p3_util::{log2_strict_usize, reverse_bits_len}; +use proof::FriProof; use prover::FriPolys; use rand::{RngExt, SeedableRng, distr::StandardUniform, prelude::SmallRng}; use verifier::{FriError, FriOracle}; use super::*; use crate::{ - lmcs::{tree_indices::TreeIndices, utils::log2_strict_u8}, + domain::LiftedDomain, + lmcs::tree_indices::TreeIndices, testing::{ + canonical_domain, configs::goldilocks_poseidon2::{ Challenger, Felt, Lmcs as BaseLmcs, QuadFelt, TestDigest, TestTranscriptData, prover_channel, random_lde_matrix, test_challenger, test_lmcs, verifier_channel, @@ -91,11 +95,11 @@ fn build_initial_evals( evals: &[QuadFelt], tree_indices: &TreeIndices, ) -> BTreeMap { - let log_n = log2_strict_u8(evals.len()) as usize; + let log_n = log2_strict_usize(evals.len()); tree_indices .iter() .map(|&domain_idx| { - let bitrev_idx = p3_util::reverse_bits_len(domain_idx, log_n); + let bitrev_idx = reverse_bits_len(domain_idx, log_n); (domain_idx, evals[bitrev_idx]) }) .collect() @@ -104,11 +108,13 @@ fn build_initial_evals( fn prove_queries( params: &FriParams, lmcs: &BaseLmcs, + domain: &LiftedDomain, evals: Vec, 
tree_indices: TreeIndices, ) -> (TestDigest, TestTranscriptData) { let mut prover_channel = prover_channel(); - let fri_polys = FriPolys::::new(params, lmcs, evals, &mut prover_channel); + let fri_polys = + FriPolys::::new(params, lmcs, domain, evals, &mut prover_channel); fri_polys.prove_queries(params, tree_indices, &mut prover_channel); prover_channel.finalize() } @@ -117,7 +123,7 @@ fn verify_queries( params: &FriParams, lmcs: &BaseLmcs, transcript: &TestTranscriptData, - lde_size: usize, + domain: &LiftedDomain, initial_evals: &BTreeMap, tree_indices: TreeIndices, challenger: Option, @@ -126,8 +132,7 @@ fn verify_queries( Some(challenger) => VerifierTranscript::from_data(challenger, transcript), None => verifier_channel(transcript), }; - let log_domain_size = log2_strict_u8(lde_size); - let oracle = FriOracle::new(params, log_domain_size, &mut channel)?; + let oracle = FriOracle::new(params, domain, &mut channel)?; oracle.test_low_degree(lmcs, params, initial_evals.clone(), tree_indices, &mut channel)?; let digest = channel.finalize().expect("transcript should finalize cleanly"); Ok(digest) @@ -138,7 +143,6 @@ fn run_roundtrip_case(case: &FriRoundtripCase, seed: u64) -> Result<(), FriError let lmcs = test_lmcs(); let params = FriParams { - log_blowup: case.log_blowup, fold: case.fold, log_final_degree: case.log_final_degree, folding_pow_bits: case.folding_pow_bits, @@ -154,25 +158,23 @@ fn run_roundtrip_case(case: &FriRoundtripCase, seed: u64) -> Result<(), FriError .values; let lde_size = evals.len(); let log_domain_size = log2_strict_u8(lde_size); + let domain = canonical_domain::(case.log_poly_degree, case.log_blowup); // Sample domain indices (no bit-reversal needed — tree is in domain order) let tree_indices = TreeIndices::new(sample_indices(&mut rng, lde_size, case.num_queries), log_domain_size) .expect("indices are in range"); let initial_evals = build_initial_evals(&evals, &tree_indices); - let (prover_digest, transcript) = prove_queries(¶ms, &lmcs, 
evals, tree_indices.clone()); + let (prover_digest, transcript) = + prove_queries(¶ms, &lmcs, &domain, evals, tree_indices.clone()); let verifier_digest = - verify_queries(¶ms, &lmcs, &transcript, lde_size, &initial_evals, tree_indices, None)?; + verify_queries(¶ms, &lmcs, &transcript, &domain, &initial_evals, tree_indices, None)?; assert_eq!(prover_digest, verifier_digest); - // Re-parse FriTranscript (commit phase only) from a fresh channel. + // Re-parse FriProof (commit phase only) from a fresh channel. let mut reparse_channel = verifier_channel(&transcript); - FriTranscript::::from_verifier_channel( - ¶ms, - log_domain_size, - &mut reparse_channel, - ) - .expect("FriTranscript re-parse should succeed"); + FriProof::::read_from_channel(¶ms, &domain, &mut reparse_channel) + .expect("FriProof re-parse should succeed"); Ok(()) } @@ -198,7 +200,6 @@ fn test_fri_verify_wrong_eval() { let log_final_degree: u8 = 2; let params = FriParams { - log_blowup, fold: FRI_FOLD_ARITY_2, log_final_degree, folding_pow_bits: 1, @@ -208,6 +209,7 @@ fn test_fri_verify_wrong_eval() { random_lde_matrix::(&mut rng, log_poly_degree, log_blowup, 1, Felt::ONE).values; let lde_size = evals.len(); let log_domain_size = log2_strict_u8(lde_size); + let domain = canonical_domain::(log_poly_degree, log_blowup); let tree_indices = TreeIndices::new(sample_indices(&mut rng, lde_size, 2), log_domain_size) .expect("indices are in range"); let mut initial_evals = build_initial_evals(&evals, &tree_indices); @@ -221,9 +223,10 @@ fn test_fri_verify_wrong_eval() { } initial_evals.insert(first_idx, wrong_eval); - let (_prover_digest, transcript) = prove_queries(¶ms, &lmcs, evals, tree_indices.clone()); + let (_prover_digest, transcript) = + prove_queries(¶ms, &lmcs, &domain, evals, tree_indices.clone()); let result = - verify_queries(¶ms, &lmcs, &transcript, lde_size, &initial_evals, tree_indices, None); + verify_queries(¶ms, &lmcs, &transcript, &domain, &initial_evals, tree_indices, None); assert!( 
matches!(result, Err(FriError::EvaluationMismatch { .. })), @@ -246,7 +249,6 @@ fn test_fri_verify_wrong_beta() { let log_final_degree: u8 = 2; let params = FriParams { - log_blowup, fold: FRI_FOLD_ARITY_2, log_final_degree, folding_pow_bits: 0, // No grinding to simplify test @@ -259,16 +261,19 @@ fn test_fri_verify_wrong_beta() { random_lde_matrix::(&mut rng, log_poly_degree, log_blowup, 1, Felt::ONE).values; let lde_size = evals1.len(); let log_domain_size = log2_strict_u8(lde_size); + let domain = canonical_domain::(log_poly_degree, log_blowup); // Prover 1: generate FRI transcript (grinds per-round internally). let tree_indices = TreeIndices::new(sample_indices(&mut rng, lde_size, 2), log_domain_size) .expect("indices are in range"); let initial_evals = build_initial_evals(&evals1, &tree_indices); - let (_prover_digest, transcript) = prove_queries(¶ms, &lmcs, evals1, tree_indices.clone()); + let (_prover_digest, transcript) = + prove_queries(¶ms, &lmcs, &domain, evals1, tree_indices.clone()); // Prover 2: generate different transcript (different commitments = different betas). 
let mut prover2_channel = prover_channel(); - let _ = FriPolys::::new(¶ms, &lmcs, evals2, &mut prover2_channel); + let _ = + FriPolys::::new(¶ms, &lmcs, &domain, evals2, &mut prover2_channel); let (_, prover2_transcript) = prover2_channel.finalize(); let other_commitment = prover2_transcript .commitments() @@ -283,7 +288,7 @@ fn test_fri_verify_wrong_beta() { ¶ms, &lmcs, &transcript, - lde_size, + &domain, &initial_evals, tree_indices, Some(wrong_challenger), @@ -310,7 +315,6 @@ fn test_fri_zero_rounds_final_poly_only() { let log_final_degree: u8 = log_poly_degree; // final degree >= domain size => zero rounds let params = FriParams { - log_blowup, fold: FRI_FOLD_ARITY_2, log_final_degree, folding_pow_bits: 0, @@ -320,14 +324,16 @@ fn test_fri_zero_rounds_final_poly_only() { random_lde_matrix::(&mut rng, log_poly_degree, log_blowup, 1, Felt::ONE).values; let lde_size = evals.len(); let log_domain_size = log2_strict_u8(lde_size); + let domain = canonical_domain::(log_poly_degree, log_blowup); let tree_indices = TreeIndices::new(sample_indices(&mut rng, lde_size, 2), log_domain_size) .expect("indices are in range"); let initial_evals = build_initial_evals(&evals, &tree_indices); - let (prover_digest, transcript) = prove_queries(¶ms, &lmcs, evals, tree_indices.clone()); + let (prover_digest, transcript) = + prove_queries(¶ms, &lmcs, &domain, evals, tree_indices.clone()); let mut channel = verifier_channel(&transcript); - let fri_transcript: FriTranscript = - FriTranscript::from_verifier_channel(¶ms, log_domain_size, &mut channel) + let fri_transcript: FriProof = + FriProof::read_from_channel(¶ms, &domain, &mut channel) .expect("transcript parsing should succeed"); assert!(fri_transcript.rounds.is_empty(), "expected zero folding rounds"); @@ -338,7 +344,7 @@ fn test_fri_zero_rounds_final_poly_only() { ); let verifier_digest = - verify_queries(¶ms, &lmcs, &transcript, lde_size, &initial_evals, tree_indices, None) + verify_queries(¶ms, &lmcs, &transcript, &domain, 
&initial_evals, tree_indices, None) .expect("zero-round FRI should verify"); assert_eq!(prover_digest, verifier_digest); } @@ -355,7 +361,6 @@ fn test_final_polynomial_correctness() { let log_final_degree: u8 = 3; let params = FriParams { - log_blowup, fold: FRI_FOLD_ARITY_2, log_final_degree, folding_pow_bits: 0, // No grinding for this test @@ -379,15 +384,16 @@ fn test_final_polynomial_correctness() { let lde = dft.coset_lde_algebra_batch(evals_h, log_blowup as usize, Felt::ONE); let evals = lde.bit_reverse_rows().to_row_major_matrix().values; - let log_domain_size = log_poly_degree + log_blowup; + let domain = canonical_domain::(log_poly_degree, log_blowup); let mut prover_channel = prover_channel(); - let _fri_polys = FriPolys::::new(¶ms, &lmcs, evals, &mut prover_channel); + let _fri_polys = + FriPolys::::new(¶ms, &lmcs, &domain, evals, &mut prover_channel); let (_, transcript) = prover_channel.finalize(); let mut v_channel = verifier_channel(&transcript); - let fri_transcript: FriTranscript = - FriTranscript::from_verifier_channel(¶ms, log_domain_size, &mut v_channel) + let fri_transcript: FriProof = + FriProof::read_from_channel(¶ms, &domain, &mut v_channel) .expect("transcript parsing should succeed"); assert_eq!( diff --git a/stark/miden-lifted-stark/src/pcs/fri/verifier.rs b/stark/miden-lifted-stark/src/pcs/fri/verifier.rs index fadf0346b3..23a3046ba0 100644 --- a/stark/miden-lifted-stark/src/pcs/fri/verifier.rs +++ b/stark/miden-lifted-stark/src/pcs/fri/verifier.rs @@ -24,8 +24,10 @@ use p3_util::reverse_bits_len; use thiserror::Error; use crate::{ + domain::{Coset, LiftedDomain, TwoAdicSubgroup}, lmcs::{Lmcs, LmcsError, tree_indices::TreeIndices}, - pcs::{fri::FriParams, utils::horner}, + pcs::fri::FriParams, + util::horner::horner, }; /// FRI low-degree test oracle. @@ -38,14 +40,14 @@ use crate::{ /// /// Uses a single base-field LMCS. Opened base field values are reconstructed /// to extension field for folding verification. 
-pub struct FriOracle +pub(in crate::pcs) struct FriOracle where F: TwoAdicField, EF: ExtensionField, L: Lmcs, { - /// Log₂ of the initial domain size. - log_domain_size: u8, + /// Initial round's domain (the LDE evaluation subgroup). + subgroup: TwoAdicSubgroup, /// Per-round commitment and folding challenge. rounds: Vec>, /// Coefficients of the final low-degree polynomial in descending degree order @@ -67,13 +69,14 @@ where /// Create oracle by reading from a verifier channel. pub fn new( params: &FriParams, - log_domain_size: u8, + domain: &LiftedDomain, channel: &mut Ch, ) -> Result where Ch: VerifierChannel, { - let num_rounds = params.num_rounds(log_domain_size); + let subgroup = *domain.lde_coset().subgroup(); + let num_rounds = params.num_rounds(domain); let mut rounds = Vec::with_capacity(num_rounds); for _ in 0..num_rounds { @@ -85,10 +88,10 @@ where rounds.push(FriRoundOracle { commitment, beta }); } - let final_degree = params.final_poly_degree(log_domain_size); + let final_degree = params.final_poly_degree(domain); let final_poly = channel.receive_algebra_slice(final_degree)?; - Ok(Self { log_domain_size, rounds, final_poly }) + Ok(Self { subgroup, rounds, final_poly }) } /// Test low-degree proximity by reading openings from a verifier channel. @@ -119,15 +122,21 @@ where let base_width = arity * EF::DIMENSION; let widths = [base_width]; - let mut log_domain_size = self.log_domain_size; - let mut g_inv = F::two_adic_generator(log_domain_size as usize).inverse(); + // Per-round state: the subgroup carries the working domain size as a typed + // value, and `g_inv` is raised to the power `2^log_arity` in lockstep + // (matching `subgroup.shrink(log_arity)`) to avoid re-inverting each round. 
+ let mut subgroup = self.subgroup; + let mut g_inv = subgroup.generator_inverse(); for (round_idx, round) in self.rounds.iter().enumerate() { - let log_folded_domain_size = log_domain_size - log_arity; + let folded_subgroup = subgroup.shrink(log_arity); + let folded_size = folded_subgroup.size(); + let log_folded_domain_size = folded_subgroup.log_size(); // Shrink indices by log_arity to get this round's row indices. tree_indices.shrink_depth(log_arity); + // FRI round trees are full-height at the (already shrunk) round depth. let opened_rows = lmcs .open_batch(&round.commitment, &widths, &tree_indices, channel) .map_err(|source| FriError::LmcsError { source, round: round_idx })?; @@ -146,7 +155,6 @@ where // // 3. The prover cannot provide different row data for the same row_idx. LMCS opens each // row exactly once via `opened_rows[&row_idx]`. - let folded_size = 1usize << log_folded_domain_size; evals = evals .into_iter() .map(|(idx, eval)| { @@ -182,19 +190,19 @@ where }) .collect::>()?; - log_domain_size = log_folded_domain_size; + subgroup = folded_subgroup; g_inv = g_inv.exp_power_of_2(log_arity as usize); } // After all folding rounds, the polynomial has been reduced to degree < final_degree. // The prover sent this final polynomial's coefficients; we evaluate it at each - // folded query point on the final domain and check consistency with the folded - // values. This closes the FRI proximity argument: if the original codeword was - // far from low-degree, at least one query fails with high probability. + // folded query point on the final-round subgroup and check consistency with the + // folded values. This closes the FRI proximity argument: if the original codeword + // was far from low-degree, at least one query fails with high probability. // // `final_poly` is in descending degree order [cₙ, ..., c₁, c₀], which is // the native order for Horner evaluation. 
- let generator = F::two_adic_generator(log_domain_size as usize); + let generator = subgroup.generator(); for (idx, eval) in evals { // Domain index directly gives the exponent (no bit-reversal needed). let x = generator.exp_u64(idx as u64); diff --git a/stark/miden-lifted-stark/src/pcs/mod.rs b/stark/miden-lifted-stark/src/pcs/mod.rs index 006b6187e5..33cb59a8be 100644 --- a/stark/miden-lifted-stark/src/pcs/mod.rs +++ b/stark/miden-lifted-stark/src/pcs/mod.rs @@ -7,11 +7,11 @@ //! //! This module provides: //! -//! - **[`deep`]**: DEEP (Domain Extension for Eliminating Pretenders) quotient construction for -//! batching polynomial evaluation claims into a single low-degree polynomial. +//! - **`deep`** (internal): DEEP (Domain Extension for Eliminating Pretenders) quotient +//! construction for batching polynomial evaluation claims into a single low-degree polynomial. //! -//! - **[`fri`]**: FRI (Fast Reed-Solomon IOP) protocol for low-degree testing, with configurable -//! folding arities and final polynomial degree. +//! - **`fri`** (internal): FRI (Fast Reed-Solomon IOP) protocol for low-degree testing, with +//! configurable folding arities and final polynomial degree. //! //! - **PCS API (module root)**: complete PCS implementation combining DEEP quotient and FRI via //! `prover::open_with_channel` and `verifier::verify`, plus `PcsParams`. @@ -25,18 +25,25 @@ //! caller's (or AIR's) responsibility. (FRI openings still ignore the padded tail because //! FRI expects a fixed single-column width.) -/// DEEP quotient construction for batched polynomial evaluation. -pub mod deep; - -/// FRI protocol for low-degree testing. 
-pub mod fri; - -pub mod params; -pub mod proof; -pub mod prover; -pub mod verifier; - -pub mod utils; +pub(crate) mod deep; +pub(crate) mod fri; +pub(crate) mod params; +pub(crate) mod proof; +pub(crate) mod prover; +pub(crate) mod verifier; #[cfg(test)] mod tests; + +// Structured proof types and errors needed for inspection / error pattern matching. +pub use deep::{ + proof::{DeepProof, OpenedValues}, + verifier::DeepError, +}; +pub use fri::{ + proof::{FriProof, FriRoundProof}, + verifier::FriError, +}; +pub use params::{PcsParams, PcsParamsError}; +pub use proof::PcsProof; +pub use verifier::PcsError; diff --git a/stark/miden-lifted-stark/src/pcs/params.rs b/stark/miden-lifted-stark/src/pcs/params.rs index 361f8f9a85..03e0c6199d 100644 --- a/stark/miden-lifted-stark/src/pcs/params.rs +++ b/stark/miden-lifted-stark/src/pcs/params.rs @@ -7,9 +7,6 @@ use crate::pcs::{ fri::{FriParams, fold::FriFold}, }; -/// Maximum log₂ of any domain size. Domains cannot exceed 2⁶⁴ elements. -pub const MAX_LOG_DOMAIN_SIZE: u8 = 64; - /// Errors from invalid PCS parameter combinations. #[derive(Clone, Debug, Error)] pub enum PcsParamsError { @@ -17,10 +14,17 @@ pub enum PcsParamsError { InvalidFoldingArity(u8), #[error("log_blowup must be > 0")] ZeroBlowup, - #[error("log_final_degree ({log_final_degree}) + log_blowup ({log_blowup}) exceeds 64")] - FinalDomainTooLarge { log_final_degree: u8, log_blowup: u8 }, #[error("num_queries must be > 0")] ZeroQueries, + #[error( + "log_final_degree + log_blowup must be at least log_folding_arity - 1 \ + (got {log_final_degree} + {log_blowup} < {min_target})" + )] + FinalDegreeUnreachable { + log_final_degree: u8, + log_blowup: u8, + min_target: u8, + }, } /// Complete PCS parameters combining DEEP and FRI parameters. @@ -29,6 +33,11 @@ pub enum PcsParamsError { /// Internal sub-parameters are accessible to crate-internal code only. 
#[derive(Clone, Copy, Debug)] pub struct PcsParams { + /// Log₂ of the LDE blowup factor (LDE domain size / trace size). + /// + /// Higher values increase soundness per query but also proof size and prover time + /// (LDE over a larger domain). Typical values: 2-4 (blowup factors of 4-16). + pub(crate) log_blowup: u8, /// DEEP quotient parameters. pub(crate) deep: DeepParams, /// FRI protocol parameters. @@ -46,8 +55,13 @@ impl PcsParams { /// /// - [`PcsParamsError::InvalidFoldingArity`] if `log_folding_arity` is not 1, 2, or 3. /// - [`PcsParamsError::ZeroBlowup`] if `log_blowup` is 0. - /// - [`PcsParamsError::FinalDomainTooLarge`] if `log_final_degree + log_blowup > 64`. /// - [`PcsParamsError::ZeroQueries`] if `num_queries` is 0. + /// - [`PcsParamsError::FinalDegreeUnreachable`] if the final target domain is too small to be + /// reachable by fixed-arity FRI folding for all valid domains. + /// + /// Field-relative bound checking (`log_final_degree + log_blowup ≤ F::TWO_ADICITY`) + /// is deferred to `TwoAdicSubgroup::new` at the point a + /// concrete domain is constructed; `PcsParams` itself is field-agnostic. 
pub fn new( log_blowup: u8, log_folding_arity: u8, @@ -62,20 +76,23 @@ impl PcsParams { if log_blowup == 0 { return Err(PcsParamsError::ZeroBlowup); } - if log_final_degree as u16 + log_blowup as u16 > MAX_LOG_DOMAIN_SIZE as u16 { - return Err(PcsParamsError::FinalDomainTooLarge { log_final_degree, log_blowup }); - } if num_queries == 0 { return Err(PcsParamsError::ZeroQueries); } + + let min_target = fold.log_arity() - 1; + if log_final_degree.saturating_add(log_blowup) < min_target { + return Err(PcsParamsError::FinalDegreeUnreachable { + log_final_degree, + log_blowup, + min_target, + }); + } + Ok(Self { + log_blowup, deep: DeepParams { deep_pow_bits }, - fri: FriParams { - log_blowup, - fold, - log_final_degree, - folding_pow_bits, - }, + fri: FriParams { fold, log_final_degree, folding_pow_bits }, num_queries, query_pow_bits, }) @@ -84,7 +101,7 @@ impl PcsParams { /// Log₂ of the blowup factor. #[inline] pub fn log_blowup(&self) -> u8 { - self.fri.log_blowup + self.log_blowup } /// Number of query repetitions. @@ -123,3 +140,30 @@ impl PcsParams { self.fri.log_final_degree } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn rejects_final_target_too_small_for_fixed_arity_folding() { + let err = PcsParams::new(1, 3, 0, 0, 0, 1, 0).unwrap_err(); + assert!(matches!( + err, + PcsParamsError::FinalDegreeUnreachable { + log_final_degree: 0, + log_blowup: 1, + min_target: 2, + } + )); + } + + #[test] + fn accepts_minimum_universally_reachable_final_target() { + let params = PcsParams::new(1, 3, 1, 0, 0, 1, 0) + .expect("final target log size equals log_folding_arity - 1"); + assert_eq!(params.log_blowup(), 1); + assert_eq!(params.log_folding_arity(), 3); + assert_eq!(params.log_final_degree(), 1); + } +} diff --git a/stark/miden-lifted-stark/src/pcs/proof.rs b/stark/miden-lifted-stark/src/pcs/proof.rs index 0c1fef4029..3e8e8eafe8 100644 --- a/stark/miden-lifted-stark/src/pcs/proof.rs +++ b/stark/miden-lifted-stark/src/pcs/proof.rs @@ -1,4 +1,4 @@ -//! 
PCS transcript data structures. +//! PCS structured proof types — parsed view of the PCS sub-transcript. use alloc::vec::Vec; @@ -6,74 +6,83 @@ use miden_stark_transcript::{TranscriptError, VerifierChannel}; use p3_field::{ExtensionField, Field, TwoAdicField}; use crate::{ - lmcs::{Lmcs, tree_indices::TreeIndices}, - pcs::{deep::proof::DeepTranscript, fri::proof::FriTranscript, params::PcsParams}, + domain::LiftedDomain, + lmcs::{Lmcs, LmcsError, tree_indices::TreeIndices}, + pcs::{ + deep::proof::DeepProof, fri::proof::FriProof, params::PcsParams, verifier::CommitmentGroup, + }, }; -/// Structured transcript view for the full PCS interaction. +/// Structured view of the full PCS sub-proof. /// /// Captures observed transcript data plus parsed LMCS batch openings for inspection. -pub struct PcsTranscript +pub struct PcsProof where L: Lmcs, L::F: Field, EF: ExtensionField, { - /// DEEP transcript data (evals, PoW witness, challenges). - pub deep_transcript: DeepTranscript, - /// FRI transcript data (round commitments/challenges, final polynomial). - pub fri_transcript: FriTranscript, + /// DEEP sub-proof (evals, PoW witness, challenges). + pub deep_proof: DeepProof, + /// FRI sub-proof (round commitments/challenges, final polynomial). + pub fri_proof: FriProof, /// Proof-of-work witness for query sampling. pub query_pow_witness: L::F, - /// Query indices in sampling order (domain indices, may contain duplicates). + /// Query indices in sampling order (original domain indices, may contain duplicates). pub query_indices: Vec, - /// Batch witness per trace tree (leaf data + Merkle witness). + /// Batch witness per committed group (leaf data + Merkle witness). + /// + /// The order matches the commitment groups passed to `read_from_channel`; each + /// witness is keyed by leaf index after folding `query_indices` to that + /// group's depth. pub deep_witnesses: Vec, /// Batch witness per FRI round (leaf data + Merkle witness). 
pub fri_witnesses: Vec, } -impl PcsTranscript +impl PcsProof where L: Lmcs, L::F: TwoAdicField, EF: ExtensionField, { - /// Parse a PCS transcript from a verifier channel without validation. + /// Parse a [`PcsProof`] from a verifier channel without validation. /// - /// Composes [`DeepTranscript`], [`FriTranscript`], and per-query LMCS batch proofs. + /// Composes [`DeepProof`], [`FriProof`], and per-query LMCS batch proofs. /// Does not verify any claims; validation happens in - /// [`verify_multi`](crate::verify_multi). - /// Commitment widths must match the committed rows (including any alignment padding), - /// and all commitments are expected to be lifted to the same `log_lde_height`. - /// - /// `log_lde_height` is the log₂ of the LDE evaluation domain height (i.e. the height of - /// the committed LDE matrices). When a trace degree is known, it is typically - /// `log_trace_height + params.fri.log_blowup` (plus any extension used by the caller). - pub fn from_verifier_channel( + /// [`verify`](crate::VerifierInstance::verify). + /// Commitment widths must match the committed rows (including any alignment padding). + /// Each commitment carries its tree depth; groups below the query domain are parsed + /// using folded leaf indices. 
+ pub(crate) fn read_from_channel( params: &PcsParams, lmcs: &L, - commitments: &[(L::Commitment, Vec)], - log_lde_height: u8, + commitments: &[CommitmentGroup], + domain: &LiftedDomain, eval_points: [EF; N], channel: &mut Ch, ) -> Result where + L::F: TwoAdicField, Ch: VerifierChannel, { + let log_lde_height = domain.log_lde_height(); if commitments.is_empty() { return Err(TranscriptError::NoMoreFields); } - let deep_transcript = DeepTranscript::from_verifier_channel::( - ¶ms.deep, - commitments, + let deep_commitments: Vec<_> = commitments + .iter() + .map(|group| (group.root.clone(), group.widths.clone())) + .collect(); + let deep_proof = DeepProof::read_from_channel::( + params.deep, + &deep_commitments, eval_points.len(), channel, )?; - let fri_transcript = - FriTranscript::from_verifier_channel(¶ms.fri, log_lde_height, channel)?; + let fri_proof = FriProof::read_from_channel(¶ms.fri, domain, channel)?; let query_pow_witness = channel.grind(params.query_pow_bits())?; @@ -86,9 +95,15 @@ where let deep_witnesses: Vec<_> = commitments .iter() - .map(|(_commitment, widths)| { - lmcs.read_batch_proof(widths, &tree_indices, channel).map_err(|e| match e { - crate::lmcs::LmcsError::TranscriptError(te) => te, + .map(|group| { + lmcs.read_lifted_batch_proof( + &group.widths, + &tree_indices, + group.log_height, + channel, + ) + .map_err(|e| match e { + LmcsError::TranscriptError(te) => te, _ => TranscriptError::NoMoreFields, }) }) @@ -96,7 +111,7 @@ where let log_arity = params.fri.fold.log_arity(); let arity = params.fri.fold.arity(); - let num_rounds = params.fri.num_rounds(log_lde_height); + let num_rounds = params.fri.num_rounds(domain); let mut fri_witnesses = Vec::with_capacity(num_rounds); let mut round_indices = tree_indices; @@ -107,7 +122,7 @@ where let round_widths = [base_width]; let batch = lmcs.read_batch_proof(&round_widths, &round_indices, channel).map_err( |e| match e { - crate::lmcs::LmcsError::TranscriptError(te) => te, + 
LmcsError::TranscriptError(te) => te, _ => TranscriptError::NoMoreFields, }, )?; @@ -115,8 +130,8 @@ where } Ok(Self { - deep_transcript, - fri_transcript, + deep_proof, + fri_proof, query_pow_witness, query_indices, deep_witnesses, diff --git a/stark/miden-lifted-stark/src/pcs/prover.rs b/stark/miden-lifted-stark/src/pcs/prover.rs index 6faf70de80..7ceb8659c2 100644 --- a/stark/miden-lifted-stark/src/pcs/prover.rs +++ b/stark/miden-lifted-stark/src/pcs/prover.rs @@ -8,6 +8,7 @@ use p3_matrix::Matrix; use tracing::{info_span, instrument}; use crate::{ + domain::LiftedDomain, lmcs::{Lmcs, LmcsTree, tree_indices::TreeIndices}, pcs::{deep::prover::DeepPoly, fri::prover::FriPolys, params::PcsParams}, }; @@ -18,14 +19,14 @@ use crate::{ /// - `eval_points` must lie outside both the trace-domain subgroup `H` and the LDE evaluation coset /// `gK` used by the PCS. If a point lies in either set, denominators `(zⱼ − X)` in the DEEP /// quotient become zero for some domain element, making the quotient undefined. -/// - All trace trees must be built at the same LDE height `2^log_lde_height`. Multiple LDE heights -/// are not supported yet and will panic. +/// - Every trace tree's height must be a power of two `≤ domain.lde_height()`. A tree shorter than +/// the max (e.g. a setup-fixed preprocessed tree) is virtually lifted — the query phase folds the +/// sampled indices down to each tree's own depth. +/// - At least one trace tree must have height `domain.lde_height()`, so the DEEP quotient is +/// constructed over the full max domain. /// -/// `log_lde_height` is the log₂ of the LDE evaluation domain height (i.e. the height of -/// the committed LDE matrices). When a trace degree is known, it is typically -/// `log_trace_height + params.fri.log_blowup` (plus any extension used by the caller). -/// In that common case, the trace subgroup `H` has size `2^(log_lde_height - -/// params.fri.log_blowup)`, while the LDE coset `gK` has size `2^log_lde_height`. 
+/// `domain` is the max LDE coset the trace trees were committed on; `domain.log_lde_height` +/// equals `log_trace_height + domain.log_blowup()` for the tallest trace. /// /// Alignment is derived from the trace trees to pad DEEP evaluations consistently. /// Trace trees must be built with `build_aligned_tree` to match this padding. @@ -33,7 +34,7 @@ use crate::{ pub fn open_with_channel( params: &PcsParams, lmcs: &L, - log_lde_height: u8, + domain: &LiftedDomain, eval_points: [EF; N], trace_trees: &[&L::Tree], channel: &mut Ch, @@ -46,25 +47,15 @@ pub fn open_with_channel( { const { assert!(N > 0, "at least one evaluation point required") }; - // Determine LDE domain size from the supplied LDE height. - // For now, all trace trees must share this height; mixed LDE heights are not supported yet. - assert!(!trace_trees.is_empty(), "at least one trace tree required"); - let expected_height = 1 << log_lde_height as usize; - assert!( - trace_trees.iter().all(|tree| tree.height() == expected_height), - "mixed LDE heights are not supported yet", - ); + // Trees shorter than the max (e.g. a setup-fixed preprocessed tree) are virtually + // lifted: `prove_lifted_batch` projects sampled query indices down to each tree's + // own depth. `DeepPoly::from_trees` validates the tree-height preconditions below. 
+ let log_lde_height = domain.log_lde_height(); // ───────────────────────────────────────────────────────────────────────── // Construct DEEP quotient (observes evals, grinds, samples alpha and beta) // ───────────────────────────────────────────────────────────────────────── let deep_poly = info_span!("DEEP quotient").in_scope(|| { - DeepPoly::from_trees::( - params.deep, - trace_trees, - eval_points, - params.fri.log_blowup, - channel, - ) + DeepPoly::from_trees::(params.deep, domain, trace_trees, eval_points, channel) }); // ───────────────────────────────────────────────────────────────────────── @@ -72,8 +63,9 @@ pub fn open_with_channel( // ───────────────────────────────────────────────────────────────────────── // The deep_poly contains evaluations on the LDE domain (size 2^log_lde_height). // FRI will prove that this polynomial is low-degree. - let fri_polys = info_span!("FRI commit phase") - .in_scope(|| FriPolys::::new(¶ms.fri, lmcs, deep_poly.deep_evals, channel)); + let fri_polys = info_span!("FRI commit phase").in_scope(|| { + FriPolys::::new(¶ms.fri, lmcs, domain, deep_poly.deep_evals, channel) + }); // ───────────────────────────────────────────────────────────────────────── // Grind for query sampling @@ -98,7 +90,7 @@ pub fn open_with_channel( // Open input trees at all query indices at once (one proof per tree) info_span!("open input trees", n_trees = trace_trees.len()).in_scope(|| { for tree in trace_trees { - tree.prove_batch(&tree_indices, channel); + tree.prove_lifted_batch(&tree_indices, channel); } }); diff --git a/stark/miden-lifted-stark/src/pcs/tests.rs b/stark/miden-lifted-stark/src/pcs/tests.rs index c0f3bc430e..c742fac7de 100644 --- a/stark/miden-lifted-stark/src/pcs/tests.rs +++ b/stark/miden-lifted-stark/src/pcs/tests.rs @@ -2,25 +2,27 @@ use alloc::{vec, vec::Vec}; +use miden_lifted_air::log2_strict_u8; use miden_stark_transcript::{ProverTranscript, VerifierTranscript}; use p3_challenger::CanObserve; -use p3_field::Field; use 
p3_matrix::{Matrix, bitrev::BitReversibleMatrix, dense::RowMajorMatrix}; use params::PcsParams; -use proof::PcsTranscript; +use proof::PcsProof; use prover::open_with_channel; use rand::{RngExt, SeedableRng, distr::StandardUniform, prelude::SmallRng}; -use verifier::{PcsError, verify_aligned}; +use verifier::{CommitmentGroup, PcsError, verify_aligned}; use super::*; use crate::{ - lmcs::{ - Lmcs, LmcsTree, - utils::{aligned_widths, log2_strict_u8}, - }, - testing::configs::goldilocks_poseidon2::{ - self as gl, Felt, Lmcs as BaseLmcs, QuadFelt, TestTree, random_lde_matrix, test_lmcs, + domain::LiftedDomain, + lmcs::{Lmcs, LmcsTree}, + testing::{ + canonical_domain, + configs::goldilocks_poseidon2::{ + self as gl, Felt, Lmcs as BaseLmcs, QuadFelt, TestTree, random_lde_matrix, test_lmcs, + }, }, + util::align::aligned_widths, }; fn test_params() -> PcsParams { @@ -48,9 +50,19 @@ fn run_pcs_case(params: &PcsParams, trees: Vec, seed: u64) -> Result<( let lde_height = trees[0].leaves().last().map(Matrix::height).unwrap_or(0); let log_lde_height = log2_strict_u8(lde_height); + let log_blowup = params.log_blowup; + let max_domain: LiftedDomain = canonical_domain(log_lde_height - log_blowup, log_blowup); let eval_points: [QuadFelt; 2] = [rng.sample(StandardUniform), rng.sample(StandardUniform)]; let commitments: Vec<_> = trees.iter().map(|t| (t.root(), t.widths())).collect(); + let commitment_groups: Vec<_> = trees + .iter() + .map(|t| CommitmentGroup { + root: t.root(), + widths: t.widths(), + log_height: log2_strict_u8(t.height()), + }) + .collect(); let trace_trees: Vec<&_> = trees.iter().collect(); // Prover: observe all commitments before opening. 
@@ -63,7 +75,7 @@ fn run_pcs_case(params: &PcsParams, trees: Vec, seed: u64) -> Result<( open_with_channel::( params, &lmcs, - log_lde_height, + &max_domain, eval_points, &trace_trees, &mut prover_channel, @@ -80,8 +92,8 @@ fn run_pcs_case(params: &PcsParams, trees: Vec, seed: u64) -> Result<( let result = verify_aligned::( params, &lmcs, - &commitments, - log_lde_height, + &commitment_groups, + &max_domain, eval_points, &mut verifier_channel, ); @@ -91,11 +103,15 @@ fn run_pcs_case(params: &PcsParams, trees: Vec, seed: u64) -> Result<( verifier_channel.finalize().expect("transcript should finalize cleanly"); assert_eq!(prover_digest, verifier_digest); - // Re-parse PcsTranscript from a fresh channel and verify digest agreement. + // Re-parse PcsProof from a fresh channel and verify digest agreement. let alignment = lmcs.alignment(); - let aligned_commitments: Vec<_> = commitments + let aligned_commitments: Vec<_> = commitment_groups .iter() - .map(|(c, widths)| (*c, aligned_widths(widths.clone(), alignment))) + .map(|group| CommitmentGroup { + root: group.root, + widths: aligned_widths(group.widths.clone(), alignment), + log_height: group.log_height, + }) .collect(); let mut challenger = gl::test_challenger(); @@ -104,15 +120,15 @@ fn run_pcs_case(params: &PcsParams, trees: Vec, seed: u64) -> Result<( } let mut reparse_channel = VerifierTranscript::from_data(challenger, &transcript); - PcsTranscript::::from_verifier_channel::<_, 2>( + PcsProof::::read_from_channel::<_, 2>( params, &lmcs, &aligned_commitments, - log_lde_height, + &max_domain, eval_points, &mut reparse_channel, ) - .expect("PcsTranscript re-parse should succeed"); + .expect("PcsProof re-parse should succeed"); let reparse_digest = reparse_channel .finalize() @@ -127,24 +143,29 @@ fn test_pcs_cases() { let lmcs = test_lmcs(); let params = test_params(); + let log_blowup = params.log_blowup; + // Pass the LDE shift through `LiftedDomain` (the only sanctioned access path). 
+ let lde_shift = LiftedDomain::::canonical_lde_shift(6 + log_blowup) + .expect("test parameters in range"); + // Case 1: single matrix, single tree. let rng = &mut SmallRng::seed_from_u64(42); - let matrix = random_lde_matrix(rng, 6, params.fri.log_blowup, 3, Felt::GENERATOR); + let matrix = random_lde_matrix(rng, 6, log_blowup, 3, lde_shift); let tree = lmcs.build_aligned_tree(vec![matrix.bit_reverse_rows()]); run_pcs_case(¶ms, vec![tree], 100).expect("single-tree roundtrip"); // Case 2: two separate trees with different column counts. let rng = &mut SmallRng::seed_from_u64(24); - let mat_a = random_lde_matrix(rng, 6, params.fri.log_blowup, 2, Felt::GENERATOR); - let mat_b = random_lde_matrix(rng, 6, params.fri.log_blowup, 4, Felt::GENERATOR); + let mat_a = random_lde_matrix(rng, 6, log_blowup, 2, lde_shift); + let mat_b = random_lde_matrix(rng, 6, log_blowup, 4, lde_shift); let tree_a = lmcs.build_aligned_tree(vec![mat_a.bit_reverse_rows()]); let tree_b = lmcs.build_aligned_tree(vec![mat_b.bit_reverse_rows()]); run_pcs_case(¶ms, vec![tree_a, tree_b], 200).expect("multi-tree roundtrip"); // Case 3: mixed heights in one commitment group (LMCS upsampling). 
let rng = &mut SmallRng::seed_from_u64(99); - let short = random_lde_matrix(rng, 4, params.fri.log_blowup, 2, Felt::GENERATOR); - let tall = random_lde_matrix(rng, 6, params.fri.log_blowup, 3, Felt::GENERATOR); + let short = random_lde_matrix(rng, 4, log_blowup, 2, lde_shift); + let tall = random_lde_matrix(rng, 6, log_blowup, 3, lde_shift); let tree = lmcs.build_aligned_tree(vec![short.bit_reverse_rows(), tall.bit_reverse_rows()]); run_pcs_case(¶ms, vec![tree], 300).expect("mixed-height roundtrip"); diff --git a/stark/miden-lifted-stark/src/pcs/utils.rs b/stark/miden-lifted-stark/src/pcs/utils.rs deleted file mode 100644 index 8bfef3ae0f..0000000000 --- a/stark/miden-lifted-stark/src/pcs/utils.rs +++ /dev/null @@ -1,106 +0,0 @@ -use alloc::vec::Vec; -use core::{ - array, - ops::{Add, Mul}, -}; - -use p3_field::{ - ExtensionField, Field, PackedFieldExtension, PackedValue, TwoAdicField, - coset::TwoAdicMultiplicativeCoset, -}; -use p3_util::reverse_slice_index_bits; - -// ============================================================================ -// Extension trait for PackedFieldExtension methods not in upstream -// ============================================================================ - -/// Horner fold with an explicit accumulator. -/// -/// Computes `acc·xⁿ + v₀·xⁿ⁻¹ + v₁·xⁿ⁻² + ... + vₙ₋₁·x⁰` where n = len(vals). -/// Equivalently: `((acc·x + v₀)·x + v₁)·x + ... + vₙ₋₁`. -/// The first element gets the highest power of `x`. -/// -/// For polynomial evaluation `p(x) = Σᵢ cᵢ·xⁱ`, pass coefficients in -/// descending degree order `[cₙ, ..., c₁, c₀]`. -#[inline] -pub fn horner_acc(acc: Acc, x: X, vals: I) -> Acc -where - I: IntoIterator, - Acc: Mul + Add, - X: Clone, -{ - vals.into_iter().fold(acc, |acc, val| acc * x.clone() + val) -} - -/// Horner fold starting from zero. -/// -/// See [`horner_acc`] for the evaluation convention. 
-#[inline] -pub fn horner(x: X, vals: I) -> Acc -where - I: IntoIterator, - Acc: Default + Mul + Add, - X: Clone, -{ - horner_acc(Acc::default(), x, vals) -} - -/// Extension trait adding `pack_ext_columns` and `to_ext_slice` methods. -/// -/// These methods enable efficient SIMD operations on arrays of extension field elements -/// by providing column-wise packing and unpacking utilities. -pub trait PackedFieldExtensionExt< - BaseField: Field, - ExtField: ExtensionField, ->: PackedFieldExtension -{ - /// Pack N columns from WIDTH rows into N packed extension field elements. - /// - /// Input: `rows[lane][col]` - WIDTH rows, each with N extension field elements. - /// Output: `result[col]` - N packed values, where each packs WIDTH lanes. - fn pack_ext_columns(rows: &[[ExtField; N]]) -> [Self; N] { - let width = BaseField::Packing::WIDTH; - debug_assert_eq!(rows.len(), width); - array::from_fn(|col| { - let col_elems: Vec = (0..width).map(|lane| rows[lane][col]).collect(); - Self::from_ext_slice(&col_elems) - }) - } - - /// Extract all lanes to an output slice. - fn to_ext_slice(&self, out: &mut [ExtField]) { - let width = BaseField::Packing::WIDTH; - for (lane, slot) in out.iter_mut().enumerate().take(width) { - *slot = self.extract(lane); - } - } -} - -impl< - BaseField: Field, - ExtField: ExtensionField, - P: PackedFieldExtension, -> PackedFieldExtensionExt for P -{ -} - -/// Coset points `gK` in bit-reversed order. -/// -/// Note: the coset shift `g` is fixed to `F::GENERATOR` by convention in this PCS. -/// -/// Bit-reversal gives two properties essential for lifting: -/// - **Adjacent negation**: `gK[2i+1] = -gK[2i]`, so both square to the same value -/// - **Squaring gives prefix**: `(gK[2i])² = (gK)²[i]` — the even-indexed elements, when squared, -/// form the half-size sub-coset. Generalizes to r-th powers. -/// -/// Together these enable iterative weight folding in barycentric evaluation. 
-/// -/// # Panics -/// Panics if the two-adic coset construction fails (e.g., `log_n` exceeds the field's -/// two-adicity), since this unwraps `TwoAdicMultiplicativeCoset::new`. -pub fn bit_reversed_coset_points(log_n: u8) -> Vec { - let coset = TwoAdicMultiplicativeCoset::new(F::GENERATOR, log_n as usize).unwrap(); - let mut pts: Vec = coset.iter().collect(); - reverse_slice_index_bits(&mut pts); - pts -} diff --git a/stark/miden-lifted-stark/src/pcs/verifier.rs b/stark/miden-lifted-stark/src/pcs/verifier.rs index 51fc4b9114..3fbd3f57fc 100644 --- a/stark/miden-lifted-stark/src/pcs/verifier.rs +++ b/stark/miden-lifted-stark/src/pcs/verifier.rs @@ -21,7 +21,8 @@ use p3_matrix::{Matrix, horizontally_truncated::HorizontallyTruncated}; use thiserror::Error; use crate::{ - lmcs::{Lmcs, tree_indices::TreeIndices, utils::aligned_widths}, + domain::LiftedDomain, + lmcs::{Lmcs, tree_indices::TreeIndices}, pcs::{ deep::{ proof::OpenedValues, @@ -30,8 +31,23 @@ use crate::{ fri::verifier::{FriError, FriOracle}, params::PcsParams, }, + util::align::aligned_widths, }; +/// A committed group to open: its root, per-matrix (unpadded) widths, and tree depth. +/// +/// `log_height` is `log₂` of the committed tree's leaf count. It is `≤ +/// domain.log_lde_height()`: a tree committed below the query domain (e.g. a +/// setup-fixed preprocessed tree) is virtually lifted with +/// [`Lmcs::open_lifted_batch`](crate::lmcs::Lmcs::open_lifted_batch). +/// Full-height groups set it to `domain.log_lde_height()`. +#[derive(Clone, Debug)] +pub(crate) struct CommitmentGroup { + pub root: C, + pub widths: Vec, + pub log_height: u8, +} + /// Verify polynomial evaluation claims against commitments. /// /// Commitment widths must match the committed rows (including any alignment padding @@ -47,16 +63,17 @@ use crate::{ /// # Preconditions /// - `eval_points` must lie outside both the trace-domain subgroup `H` and the LDE evaluation coset /// `gK`. 
Otherwise denominators `(zⱼ − X)` in the DEEP quotient become zero, making it undefined. -/// - All commitments must be lifted to the same LDE height `2^log_lde_height`. +/// - Each group's `log_height` must be `≤ domain.log_lde_height()`; shorter groups are virtually +/// lifted (their query indices fold down to the committed depth). /// /// # Returns /// `opened[group][matrix]` as a `RowMajorMatrix` with `N` rows /// (one per evaluation point), using the same widths that were passed in. -pub fn verify( +pub(crate) fn verify( params: &PcsParams, lmcs: &L, - commitments: &[(L::Commitment, Vec)], - log_lde_height: u8, + commitments: &[CommitmentGroup], + domain: &LiftedDomain, eval_points: [EF; N], channel: &mut Ch, ) -> Result, PcsError> @@ -72,17 +89,19 @@ where return Err(PcsError::NoCommitments); } + let log_lde_height = domain.log_lde_height(); + // Construct verifier's DEEP oracle (observes evals, checks PoW, samples α/β) let (deep_oracle, evals) = DeepOracle::::new( params.deep, &eval_points, commitments.to_vec(), - log_lde_height, + domain, channel, )?; // Create FRI oracle (observes commitments + final poly, checks per-round PoW) - let fri_oracle = FriOracle::new(¶ms.fri, log_lde_height, channel)?; + let fri_oracle = FriOracle::new(¶ms.fri, domain, channel)?; // Check query PoW witness and sample query indices channel.grind(params.query_pow_bits())?; @@ -93,8 +112,7 @@ where let tree_indices = TreeIndices::new(sampled_indices_iter, log_lde_height) .expect("sampled indices are in range"); - // Verify DEEP openings for all queries at once - // tree_indices are bit-reversed positions; deep_evals is keyed by tree index + // Verify DEEP openings for all sampled domain indices at once. let deep_evals = deep_oracle.open_batch(lmcs, &tree_indices, channel)?; // Test low-degree proximity for all queries at once @@ -109,11 +127,11 @@ where /// 1. Aligns widths to `lmcs.alignment()` /// 2. Calls [`verify`] with aligned widths /// 3. 
Truncates returned evals back to original widths -pub fn verify_aligned( +pub(crate) fn verify_aligned( params: &PcsParams, lmcs: &L, - commitments: &[(L::Commitment, Vec)], - log_lde_height: u8, + commitments: &[CommitmentGroup], + domain: &LiftedDomain, eval_points: [EF; N], channel: &mut Ch, ) -> Result, PcsError> @@ -126,19 +144,23 @@ where let alignment = lmcs.alignment(); let aligned_commitments: Vec<_> = commitments .iter() - .map(|(c, widths)| (c.clone(), aligned_widths(widths.clone(), alignment))) + .map(|g| CommitmentGroup { + root: g.root.clone(), + widths: aligned_widths(g.widths.clone(), alignment), + log_height: g.log_height, + }) .collect(); - let evals = verify(params, lmcs, &aligned_commitments, log_lde_height, eval_points, channel)?; + let evals = verify(params, lmcs, &aligned_commitments, domain, eval_points, channel)?; // Truncate each matrix back to original widths, removing alignment padding. let truncated = evals .into_iter() .zip(commitments) - .map(|(group, (_, orig_widths))| { + .map(|(group, g)| { group .into_iter() - .zip(orig_widths) + .zip(&g.widths) .map(|(mat, &orig_w)| { HorizontallyTruncated::new(mat, orig_w) .expect("original width must not exceed aligned width") diff --git a/stark/miden-lifted-stark/src/preprocessed.rs b/stark/miden-lifted-stark/src/preprocessed.rs new file mode 100644 index 0000000000..2cb17b3860 --- /dev/null +++ b/stark/miden-lifted-stark/src/preprocessed.rs @@ -0,0 +1,346 @@ +//! Preprocessed data: the fixed per-AIR matrices and their committed LDE tree. +//! +//! Preprocessed columns are *fixed circuit data* (lookup tables, selectors) +//! declared by the AIR via [`BaseAir::preprocessed_trace`] and committed once +//! at setup. The prover holds the cached raw matrices plus their LDE tree (the +//! [`Preprocessed`] bundle, built once and borrowed across proofs); the +//! verifier holds only the commitment (a root hash, trusted like the AIR list +//! itself). +//! +//! 
[`Preprocessed::build`] caches the by-value [`BaseAir::preprocessed_trace`] +//! evals and builds the aligned LDE tree using the supplied STARK config. The +//! resulting bundle is tied to that config's PCS blowup and LMCS alignment. +//! `validate_preprocessed` checks a bundle against a prover statement and +//! config; it runs at [`ProverInstance::new`](crate::ProverInstance::new) +//! construction time, so the prover never re-checks the shape. + +use alloc::vec::Vec; + +use miden_lifted_air::{BaseAir, LiftedAir, MultiAir, ProverStatement, Statement, log2_strict_u8}; +use p3_dft::TwoAdicSubgroupDft; +use p3_field::{ExtensionField, TwoAdicField}; +use p3_matrix::{Matrix, dense::RowMajorMatrix}; +use thiserror::Error; +use tracing::info_span; + +use crate::{ + StarkConfig, + domain::LiftedDomain, + lmcs::{Lmcs, LmcsTree}, + order::TraceOrder, + prover::commit::Committed, + util::bitrev::materialize_bitrev, +}; + +// ============================================================================ +// Preprocessed +// ============================================================================ + +/// Fixed per-AIR preprocessed data: the cached raw matrices plus their +/// committed LDE tree. +/// +/// `traces[i]` is `Some` exactly when AIR `i` declares preprocessed columns; +/// the LDE tree commits one LDE trace per such AIR, in proof order. Built once +/// at setup via [`Preprocessed::build`] and borrowed across proofs. +/// +/// Parameterized over the LMCS `L` rather than a full [`StarkConfig`] so the +/// value can be borrowed by prover instances with the same commitment type. The +/// bundle is still tied to the PCS blowup and LMCS alignment used at build time; +/// `validate_preprocessed` checks those config-dependent dimensions. +pub struct Preprocessed +where + F: TwoAdicField, + L: Lmcs, +{ + /// Per-AIR raw preprocessed matrices in instance order; `None` where the + /// AIR declares none. 
The cached [`BaseAir::preprocessed_trace`] evals — + /// `preprocessed_trace` re-allocates on every call, so they are computed + /// once here and retained for validation and `check_constraints`. + traces: Vec>>, + /// Committed LDE tree, one committed LDE trace per preprocessed AIR. + committed: Committed, L>, +} + +impl Preprocessed +where + F: TwoAdicField, + L: Lmcs, +{ + /// Build the preprocessed bundle from a statement's AIRs, or `None` when no + /// AIR declares preprocessed columns. + /// + /// Calls [`BaseAir::preprocessed_trace`] once per AIR (caching the by-value + /// result), then LDEs the declared matrices — sorted height-ascending, + /// tiebroken by AIR index, the committed trace order both sides reproduce — + /// and builds the aligned tree. + /// + /// # Panics + /// + /// Panics if a declared preprocessed matrix has non-power-of-two height or + /// its LDE order exceeds the field's two-adicity — programmer errors at + /// setup, not untrusted input. + pub fn build(statement: &Statement, config: &C) -> Option + where + EF: ExtensionField, + MA: MultiAir, + C: StarkConfig, + { + let traces: Vec>> = + statement.airs().iter().map(BaseAir::preprocessed_trace).collect(); + if traces.iter().all(Option::is_none) { + return None; + } + + // Committed trace order: preprocessed AIRs sorted by `(height, air_idx)`. This must + // match the trace↔AIR mapping the prover/verifier reconstruct via + // `TraceOrder::preprocessed_air_for_trace_index` (preprocessed AIRs in proof + // order, i.e. sorted by `(main_trace_height, air_idx)`). The two + // coincide because `validate_preprocessed` rejects any bundle whose + // preprocessed height differs from the main trace height — so sorting by + // the preprocessed matrix height here yields the same order. `build` + // sees only the AIR list (fixed circuit data), not the witness traces, + // so it cannot call `TraceOrder` directly. 
+ let mut pairs: Vec<(usize, &RowMajorMatrix)> = traces + .iter() + .enumerate() + .filter_map(|(i, t)| t.as_ref().map(|m| (i, m))) + .collect(); + pairs.sort_by_key(|(air_idx, m)| (m.height(), *air_idx)); + + let log_blowup = config.pcs().log_blowup(); + let ldes: Vec<_> = pairs + .into_iter() + .map(|(air_idx, trace)| { + let height = trace.height(); + assert!( + height.is_power_of_two(), + "preprocessed matrix for AIR {air_idx} has non-power-of-two height {height}", + ); + let log_h = log2_strict_u8(height); + let coset_shift = LiftedDomain::::canonical_lde_shift(log_h + log_blowup) + .expect("preprocessed LDE order exceeds field two-adicity"); + let width = trace.width(); + info_span!("preprocessed LDE", air = air_idx, log_height = log_h, width).in_scope( + || { + let lde = config.dft().coset_lde_batch( + trace.clone(), + log_blowup.into(), + coset_shift, + ); + materialize_bitrev(lde) + }, + ) + }) + .collect(); + + Some(Self { + traces, + committed: Committed::new(config.lmcs().build_aligned_tree(ldes)), + }) + } + + /// Commitment (Merkle root) of the preprocessed LDE tree — handed to the + /// verifier via [`VerifierInstance::new`](crate::VerifierInstance::new). + pub fn commitment(&self) -> L::Commitment { + self.committed.root() + } + + /// The committed LDE tree, for opening and per-AIR quotient-domain views. + pub(crate) fn committed(&self) -> &Committed, L> { + &self.committed + } +} + +// ============================================================================ +// Validation +// ============================================================================ + +/// Validate a [`Preprocessed`] bundle against a prover statement and STARK config: +/// per-AIR presence, raw matrix shape, committed LDE shape, and LMCS alignment. +/// +/// Called by [`ProverInstance::new`](crate::ProverInstance::new) only when both +/// the AIRs declare preprocessed columns and a bundle is supplied; aggregate +/// presence parity is checked separately by the constructor. 
+pub(crate) fn validate_preprocessed( + config: &SC, + prover_statement: &ProverStatement, + preprocessed: &Preprocessed, +) -> Result<(), PreprocessedValidationError> +where + F: TwoAdicField, + EF: ExtensionField, + MA: MultiAir, + SC: StarkConfig, +{ + let airs = prover_statement.statement().airs(); + let main_traces = prover_statement.traces(); + let log_blowup = config.pcs().log_blowup(); + + if preprocessed.traces.len() != airs.len() { + return Err(PreprocessedValidationError::RawTraceCountMismatch { + expected: airs.len(), + actual: preprocessed.traces.len(), + }); + } + + let expected_alignment = config.lmcs().alignment(); + let actual_alignment = preprocessed.committed.tree().alignment(); + if actual_alignment != expected_alignment { + return Err(PreprocessedValidationError::AlignmentMismatch { + expected: expected_alignment, + actual: actual_alignment, + }); + } + + // Reconstruct the trace↔AIR mapping the prover/verifier use. + // Heights are already validated by `ProverStatement::new`, so this cannot fail. + let heights: Vec = main_traces.iter().map(Matrix::height).collect(); + let trace_order = TraceOrder::from_trace_heights::(airs, &heights) + .expect("ProverStatement guarantees valid trace shapes"); + let preprocessed_trace_to_air = trace_order.preprocessed_air_for_trace_index::(airs); + let air_to_preprocessed_trace = trace_order.preprocessed_trace_index_for_air::(airs); + + // Raw cached matrices must line up with AIR declarations in instance order. 
+ for (air_idx, (air, raw_preprocessed)) in airs.iter().zip(&preprocessed.traces).enumerate() { + let expected_presence = air.preprocessed_width() > 0; + let actual_presence = raw_preprocessed.is_some(); + if actual_presence != expected_presence { + return Err(PreprocessedValidationError::TracePresenceMismatch { + air: air_idx, + expected: expected_presence, + actual: actual_presence, + }); + } + + let Some(raw_preprocessed) = raw_preprocessed else { + continue; + }; + + let trace = air_to_preprocessed_trace[air_idx] + .expect("presence validation guarantees a declared preprocessed AIR"); + let expected_width = air.preprocessed_width(); + let actual_width = raw_preprocessed.width(); + if actual_width != expected_width { + return Err(PreprocessedValidationError::WidthMismatch { + trace, + air: air_idx, + expected: expected_width, + actual: actual_width, + }); + } + + let main = main_traces[air_idx].height(); + if raw_preprocessed.height() != main { + return Err(PreprocessedValidationError::HeightMismatch { + air: air_idx, + main, + preprocessed: raw_preprocessed.height(), + }); + } + } + + let committed_traces = preprocessed.committed.tree().leaves(); + if committed_traces.len() != preprocessed_trace_to_air.len() { + return Err(PreprocessedValidationError::TreeLengthMismatch { + expected: preprocessed_trace_to_air.len(), + actual: committed_traces.len(), + }); + } + + // Validate the committed leaves directly against the AIRs and config, not by + // trusting them to be the LDE of `traces`. A `Preprocessed` need not have come + // from `build` (e.g. a deserialized bundle), so its raw and committed halves are + // checked independently: width against the declared AIR, and LDE height against + // this proving config's blowup applied to the main trace height (catching a + // bundle built under a different blowup). 
+ for (preprocessed_trace_idx, &air_idx_u8) in preprocessed_trace_to_air.iter().enumerate() { + let air_idx = air_idx_u8 as usize; + let expected_width = airs[air_idx].preprocessed_width(); + let committed_trace = &committed_traces[preprocessed_trace_idx]; + let actual_width = committed_trace.width(); + if actual_width != expected_width { + return Err(PreprocessedValidationError::WidthMismatch { + trace: preprocessed_trace_idx, + air: air_idx, + expected: expected_width, + actual: actual_width, + }); + } + + let main_height = main_traces[air_idx].height(); + let expected_lde_height = main_height.checked_shl(u32::from(log_blowup)).ok_or( + PreprocessedValidationError::LdeHeightOverflow { + air: air_idx, + main: main_height, + log_blowup, + }, + )?; + let actual_lde_height = committed_trace.height(); + if actual_lde_height != expected_lde_height { + return Err(PreprocessedValidationError::LdeHeightMismatch { + trace: preprocessed_trace_idx, + air: air_idx, + log_blowup, + expected: expected_lde_height, + actual: actual_lde_height, + }); + } + } + + Ok(()) +} + +/// Errors from constructing a stark-layer instance: preprocessed presence +/// parity and (prover side) the bundle's shape against the AIR declarations. 
+#[derive(Debug, Error)] +pub enum PreprocessedValidationError { + #[error( + "preprocessed setup presence mismatch: AIRs declare preprocessed columns = {expected}, setup supplied = {actual}" + )] + PresenceMismatch { expected: bool, actual: bool }, + #[error("raw preprocessed trace count {actual} does not match AIR count {expected}")] + RawTraceCountMismatch { expected: usize, actual: usize }, + #[error( + "AIR {air}: preprocessed trace presence mismatch: AIR declares preprocessed columns = {expected}, raw trace supplied = {actual}" + )] + TracePresenceMismatch { air: usize, expected: bool, actual: bool }, + #[error( + "preprocessed setup alignment mismatch: config expects {expected}, setup uses {actual}" + )] + AlignmentMismatch { expected: usize, actual: usize }, + #[error( + "preprocessed trace {trace} (AIR {air}) width mismatch: AIR declares {expected}, setup has {actual}" + )] + WidthMismatch { + trace: usize, + air: usize, + expected: usize, + actual: usize, + }, + #[error( + "preprocessed trace count {actual} does not match the preprocessed-AIR count {expected}" + )] + TreeLengthMismatch { expected: usize, actual: usize }, + #[error( + "AIR {air}: preprocessed matrix height ({preprocessed}) does not match main trace height ({main})" + )] + HeightMismatch { + air: usize, + main: usize, + preprocessed: usize, + }, + #[error( + "AIR {air}: main trace height {main} overflows usize when shifted by log_blowup {log_blowup}" + )] + LdeHeightOverflow { air: usize, main: usize, log_blowup: u8 }, + #[error( + "preprocessed trace {trace} (AIR {air}) LDE height mismatch for log_blowup {log_blowup}: expected {expected}, setup has {actual}" + )] + LdeHeightMismatch { + trace: usize, + air: usize, + log_blowup: u8, + expected: usize, + actual: usize, + }, +} diff --git a/stark/miden-lifted-stark/src/proof.rs b/stark/miden-lifted-stark/src/proof.rs index c25d20cfa1..3ebd990cf8 100644 --- a/stark/miden-lifted-stark/src/proof.rs +++ b/stark/miden-lifted-stark/src/proof.rs @@ 
-1,54 +1,68 @@ //! STARK proof types and structured transcript. //! //! This module defines the proof artifact types shared by prover and verifier: -//! - [`StarkProof`]: raw transcript data (field elements and commitments) +//! - [`StarkProofData`]: raw transcript data (field elements and commitments) //! - [`StarkDigest`]: binding digest committing to the entire interaction //! - [`StarkOutput`]: combined prover output (proof + digest) -//! - [`StarkTranscript`]: structured parse-only view of the full protocol interaction +//! - [`StarkProof`]: structured parse-only view of the full protocol interaction //! -//! [`StarkTranscript`] has a [`from_proof`](StarkTranscript::from_proof) constructor -//! that parses it from proof data and a challenger, following the same pattern as -//! [`PcsTranscript`] alongside the PCS verifier. +//! [`StarkProof`] has a [`from_data`](StarkProof::from_data) constructor +//! that parses it from a verifier instance, proof data, and a challenger, following +//! the same pattern as [`PcsProof`] alongside the PCS verifier. After parsing, +//! custom verifiers can use the structured proof plus the same +//! [`VerifierInstance`] without replaying the challenger. extern crate alloc; use alloc::{vec, vec::Vec}; -use miden_lifted_air::LiftedAir; -use miden_stark_transcript::{Channel, TranscriptData, VerifierChannel, VerifierTranscript}; -use p3_challenger::CanFinalizeDigest; +use miden_lifted_air::{BaseAir, LiftedAir, MultiAir}; +use miden_stark_transcript::{Channel, VerifierChannel, VerifierTranscript}; +// Re-exported here so the crate root does not need a dedicated `transcript` module: callers +// reach these via `proof::` alongside the proof types they parameterize. 
+pub use miden_stark_transcript::{TranscriptChallenger, TranscriptData, TranscriptError}; +use p3_challenger::{CanFinalizeDigest, CanObserve}; use p3_field::{ExtensionField, Field, TwoAdicField}; use serde::{Deserialize, Serialize}; use crate::{ StarkConfig, - coset::LiftedCoset, - instance::{AirInstance, InstanceShapes, validate_air_order, validate_inputs}, - lmcs::{Lmcs, utils::aligned_len}, - pcs::proof::PcsTranscript, - verifier::VerifierError, + domain::{Coset, DomainError, LiftedDomain, log_quotient_degree}, + lmcs::Lmcs, + order::TraceOrder, + pcs::{proof::PcsProof, verifier::CommitmentGroup}, + util::align::aligned_len, + verifier::{VerifierError, VerifierInstance}, }; /// Commitment type alias for convenience. type Commitment = <>::Lmcs as Lmcs>::Commitment; -/// STARK proof: per-instance shape metadata plus raw transcript data. +/// STARK proof: per-AIR log trace heights (instance order) plus the raw +/// transcript data. /// -/// Fields are opaque. The accessors below expose wire-format summaries -/// (trace count, transcript sizes). Read per-instance log trace heights by -/// parsing via [`StarkTranscript::from_proof`], which validates the shape -/// metadata and binds it into the Fiat-Shamir challenger — -/// [`verify_multi`](crate::verifier::verify_multi) runs the same validation. +/// The proof's AIR ordering is *not* stored. Both sides reconstruct it with a +/// stable sort by `(log_trace_height, instance_index)`, where `instance_index` +/// is the AIR's position in [`miden_lifted_air::Statement::airs`]. The derived +/// order then drives commitment grouping, aux-value layout, quotient accumulation, and PCS +/// opening widths. The proof therefore commits to the instance-order heights, +/// not an explicit ordering permutation. +/// +/// The heights themselves are not exposed as a direct accessor: parse the +/// proof through [`StarkProof::from_data`] and read them via +/// [`StarkProof::log_trace_heights`]. 
// Bounds target `Commitment` directly; `SC` itself isn't `Serialize`/`Debug`. #[derive(Clone, Serialize, Deserialize)] #[serde(bound(serialize = "TranscriptData>: Serialize"))] #[serde(bound(deserialize = "TranscriptData>: Deserialize<'de>"))] -pub struct StarkProof, SC: StarkConfig> { - pub(crate) instance_shapes: InstanceShapes, +pub struct StarkProofData, SC: StarkConfig> { + /// Per-AIR log₂ trace heights, in instance order. Matches + /// [`miden_lifted_air::Statement::airs`] position-for-position. + pub(crate) log_trace_heights: Vec, pub(crate) transcript: TranscriptData>, } -impl core::fmt::Debug for StarkProof +impl core::fmt::Debug for StarkProofData where F: TwoAdicField + core::fmt::Debug, EF: ExtensionField, @@ -56,31 +70,22 @@ where Commitment: core::fmt::Debug, { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.debug_struct("StarkProof") - .field("instance_shapes", &self.instance_shapes) + f.debug_struct("StarkProofData") + .field("log_trace_heights", &self.log_trace_heights) .field("transcript", &self.transcript) .finish() } } -impl StarkProof +impl StarkProofData where F: TwoAdicField, EF: ExtensionField, SC: StarkConfig, { - /// The AIR ordering used by the proof: `air_order()[j]` is the caller's - /// original index of the instance at position `j`. - /// - /// Read this before building the Fiat-Shamir challenger so you can bind - /// AIR configurations and the ordering — see the prover module-level docs. - pub fn air_order(&self) -> &[u32] { - self.instance_shapes.air_order() - } - /// Number of traces (instances) the proof was produced for. pub fn num_traces(&self) -> usize { - self.instance_shapes.log_trace_heights.len() + self.log_trace_heights.len() } /// Number of base-field elements in the transcript. @@ -95,7 +100,7 @@ where /// Total byte size of the proof. 
pub fn size_in_bytes(&self) -> usize { - self.instance_shapes.size_in_bytes() + self.transcript.size_in_bytes() + self.log_trace_heights.len() + self.transcript.size_in_bytes() } } @@ -105,13 +110,13 @@ where pub type StarkDigest = <>::Challenger as CanFinalizeDigest>::Digest; -/// Output of [`crate::prover::prove_single`] / [`crate::prover::prove_multi`]: the proof data and -/// its transcript digest. +/// Output of [`ProverInstance::prove`](crate::ProverInstance::prove): the proof +/// data and its transcript digest. pub struct StarkOutput, SC: StarkConfig> { /// Transcript digest committing to the entire prover–verifier interaction. pub digest: StarkDigest, /// Proof data consumed by the verifier. - pub proof: StarkProof, + pub proof: StarkProofData, } impl core::fmt::Debug for StarkOutput @@ -132,31 +137,36 @@ where /// Structured transcript view for the full lifted STARK protocol. /// -/// Captures instance shape metadata, commitments, sampled challenges, the OOD -/// evaluation point, and the PCS sub-transcript. Constructed via -/// [`from_proof`](Self::from_proof), which mirrors steps 0–9 of -/// [`verify_multi`](crate::verifier::verify_multi) but skips the constraint -/// check. -pub struct StarkTranscript +/// Captures the proof-carried commitments, sampled challenges, OOD point, parsed +/// PCS sub-transcript, and LMCS hint data needed to implement verification logic +/// without replaying the Fiat-Shamir challenger. Verification context that is not +/// proof data — the STARK config, statement, and optional preprocessed setup +/// commitment — remains on [`VerifierInstance`]. +/// +/// Constructed via [`from_data`](Self::from_data), which mirrors the transcript +/// steps of [`VerifierInstance::verify`](crate::VerifierInstance::verify), +/// finalizes the challenger once, and returns the binding digest alongside this +/// parsed view. 
Custom verifiers in external crates should use this type together +/// with the same verifier instance to avoid re-parsing raw transcript data. +pub struct StarkProof where L: Lmcs, L::F: Field, EF: ExtensionField, { - /// Per-instance shape metadata. Validated and observed into the challenger - /// by [`from_proof`](Self::from_proof). - pub instance_shapes: InstanceShapes, - /// Throwaway challenge squeezed right after observing the instance shapes, - /// used to clear the challenger's absorb buffer so that later sampled - /// challenges depend on the full shape metadata regardless of sponge state. - pub instance_challenge: EF, + /// AIR ordering reconstructed from the proof's log trace heights. + /// Validated and observed into the challenger by + /// [`from_data`](Self::from_data). Read its data through + /// [`log_trace_heights`](Self::log_trace_heights) and + /// [`air_order`](Self::air_order). + pub(crate) trace_order: TraceOrder, /// Main trace commitment. pub main_commit: L::Commitment, /// Randomness sampled for auxiliary traces. pub randomness: Vec, /// Auxiliary trace commitment. pub aux_commit: L::Commitment, - /// Aux values per AIR instance, observed into the transcript after the aux commitment. + /// Aux values per AIR instance, in the proof's AIR ordering. pub all_aux_values: Vec>, /// Constraint folding challenge alpha. pub alpha: EF, @@ -166,73 +176,118 @@ where pub quotient_commit: L::Commitment, /// Out-of-domain evaluation point z. pub z: EF, - /// PCS sub-transcript (DEEP evals, FRI rounds, query openings). - pub pcs_transcript: PcsTranscript, + /// PCS sub-proof (DEEP evals, FRI rounds, query openings). + pub pcs_proof: PcsProof, } -impl StarkTranscript +impl StarkProof where L: Lmcs, L::F: TwoAdicField, EF: ExtensionField, { - /// Parse a STARK transcript from proof data and a challenger. + /// Per-AIR log₂ trace heights in instance order (matches + /// [`miden_lifted_air::Statement::airs`] position-for-position). 
+ pub fn log_trace_heights(&self) -> &[u8] { + self.trace_order.log_heights() + } + + /// The proof's AIR ordering: position `j` holds the instance index of the + /// AIR at proof position `j`. Derived deterministically from the heights. + pub fn air_order(&self) -> Vec { + self.trace_order.instance_indices().to_vec() + } + + /// Parse a STARK transcript from a verifier instance, proof data, and a challenger. /// - /// Mirrors steps 0–9 of [`verify_multi`](crate::verifier::verify_multi): - /// 0. Validate instance shapes, then observe log trace heights into the challenger and squeeze - /// a throwaway `instance_challenge` to clear the absorb buffer - /// 1. Receive main trace commitment - /// 2. Sample randomness for auxiliary traces - /// 3. Receive auxiliary trace commitment - /// 4. Receive aux values (per AIR instance) - /// 5. Sample constraint folding alpha and accumulation beta - /// 6. Receive quotient commitment - /// 7. Sample OOD point z - /// 8. Build commitment widths for PCS - /// 9. Parse PCS sub-transcript via [`PcsTranscript::from_verifier_channel`] + /// Mirrors the transcript-facing steps of + /// [`VerifierInstance::verify`](crate::VerifierInstance::verify): + /// 0. Reconstruct and validate AIR ordering from the proof heights, observe the optional + /// preprocessed commitment, then absorb statement-owned inputs and trace-shape data. + /// 1. Receive the main trace commitment. + /// 2. Sample randomness for auxiliary traces. + /// 3. Receive the auxiliary trace commitment. + /// 4. Receive auxiliary values, per AIR instance in proof order. + /// 5. Sample the constraint folding challenge `alpha` and accumulation challenge `beta`. + /// 6. Receive the quotient commitment. + /// 7. Sample the OOD point `z`. + /// 8. Build PCS commitment groups, including per-group tree depths for any virtually lifted + /// preprocessed tree. + /// 9. Parse the PCS sub-proof via `PcsProof::read_from_channel`. 
/// /// Does **not** verify constraints or check the quotient identity. /// Finalizes the transcript and returns the digest alongside the parsed view. #[allow(clippy::type_complexity)] - pub fn from_proof( - config: &SC, - instances: &[(&A, AirInstance<'_, L::F>)], - proof: &StarkProof, + pub fn from_data( + instance: &VerifierInstance<'_, L::F, EF, MA, SC>, + data: &StarkProofData, mut challenger: SC::Challenger, ) -> Result<(Self, StarkDigest), VerifierError> where - A: LiftedAir, + MA: MultiAir, SC: StarkConfig, { - validate_air_order(proof.instance_shapes.air_order(), instances.len())?; - let instances = proof.instance_shapes.reorder(instances.to_vec())?; + let config = instance.config(); + let statement = instance.statement(); + let preprocessed_commitment = instance.preprocessed_commitment(); + + // Shape well-formedness and per-AIR periodic-height feasibility, both + // against the (untrusted) proof heights — catches malicious + // `log_h` > usize::BITS and infeasible heights before any later use. 
+ let trace_order = TraceOrder::from_log_heights::( + statement.airs(), + data.log_trace_heights.clone(), + )?; + + let preprocessed_trace_to_air = if preprocessed_commitment.is_some() { + trace_order.preprocessed_air_for_trace_index::(statement.airs()) + } else { + Vec::new() + }; + + let air_refs: Vec<&MA::Air> = statement.airs().iter().collect(); + let proof_ordered_airs = trace_order.to_proof_order(&air_refs); let log_blowup = config.pcs().log_blowup(); - let log_max_trace_height = validate_inputs(&instances, &proof.instance_shapes, log_blowup)?; - proof.instance_shapes.observe_heights::(&mut challenger); + let log_max_trace_height = trace_order.max_log_height(); + let max_lde_domain = LiftedDomain::::try_canonical(log_max_trace_height, log_blowup)?; - let mut channel = VerifierTranscript::from_data(challenger, &proof.transcript); + // The preprocessed commitment is a trusted statement input and must be bound + // before the rest of the instance data, matching the prover/verifier transcript. + if let Some(commitment) = preprocessed_commitment { + challenger.observe(commitment.clone()); + } - // Clear the challenger's absorb buffer after observing instance shapes. - // Mirrors `prove_multi` / `verify_multi`. - let instance_challenge: EF = channel.sample_algebra_element::(); + // `Statement::observe` absorbs statement-owned inputs. The protocol then + // binds the proof's instance count and log trace heights in instance order. 
+ statement.observe(&mut challenger, trace_order.log_heights()); + trace_order.observe_shape::(&mut challenger); - let alignment = config.lmcs().alignment(); + let mut channel = VerifierTranscript::from_data(challenger, &data.transcript); - // Infer constraint degree from symbolic AIR analysis (max across all AIRs) - let constraint_degree = - instances.iter().map(|(air, _)| air.constraint_degree()).max().unwrap_or(2); - let log_lde_height = log_max_trace_height + log_blowup; + let alignment = config.lmcs().alignment(); - // Max LDE coset (for the largest trace, no lifting) - let max_lde_coset = LiftedCoset::unlifted(log_max_trace_height, log_blowup); + // Infer quotient degree from symbolic AIR analysis (max across all AIRs) + let max_log_quotient_degree = proof_ordered_airs + .iter() + .map(|&air| log_quotient_degree::(air)) + .max() + .expect("TraceOrder construction rejects empty AIR sets"); + if max_log_quotient_degree > log_blowup { + return Err(DomainError::ConstraintDegreeTooHigh { + log_quotient: max_log_quotient_degree, + log_blowup, + } + .into()); + } + let quotient_degree = 1usize << max_log_quotient_degree as usize; // 1. Receive main trace commitment let main_commit = channel.receive_commitment()?.clone(); // 2. Sample randomness for aux traces let max_num_randomness = - instances.iter().map(|(air, _)| air.num_randomness()).max().unwrap_or(0); + proof_ordered_airs.iter().map(|air| air.num_randomness()).max().unwrap_or(0); let randomness: Vec = (0..max_num_randomness) .map(|_| channel.sample_algebra_element::()) @@ -242,9 +297,9 @@ where let aux_commit = channel.receive_commitment()?.clone(); // 4. Receive aux values from the transcript (one EF element per aux value, per instance). 
- let all_aux_values: Vec> = instances + let all_aux_values: Vec> = proof_ordered_airs .iter() - .map(|(air, _)| { + .map(|air| { let count = air.num_aux_values(); (0..count) .map(|_| channel.receive_algebra_element::()) @@ -260,38 +315,69 @@ where let quotient_commit = channel.receive_commitment()?.clone(); // 7. Sample OOD point (outside max trace domain H and max LDE coset gK) - let z: EF = max_lde_coset.sample_ood_point(&mut channel); - let h = L::F::two_adic_generator(log_max_trace_height.into()); + let z: EF = max_lde_domain.sample_ood_point(&mut channel); + let h = max_lde_domain.trace_subgroup().generator(); let z_next = z * h; - // 8. Build commitment widths for PCS. + // 8. Build commitment groups for PCS. // - // The LMCS commits to rows padded to `alignment` boundary, so DEEP evals and - // batch openings are stored at aligned widths in the transcript. We must use - // aligned widths here to parse the transcript correctly. - // (The verifier's `verify_aligned` does the same alignment internally, then - // truncates the returned evals back to original widths for constraint checking.) - let main_widths: Vec = - instances.iter().map(|(air, _)| aligned_len(air.width(), alignment)).collect(); - let quotient_width = aligned_len(constraint_degree * EF::DIMENSION, alignment); - - let aux_widths: Vec = instances + // The structured parser consumes the transcript directly, so widths here are + // the aligned widths stored in LMCS committed matrices. Each group also carries its tree + // depth for virtually lifted preprocessed traces. 
+ let main_widths: Vec = proof_ordered_airs .iter() - .map(|(air, _)| aligned_len(air.aux_width() * EF::DIMENSION, alignment)) + .map(|air| aligned_len(air.width(), alignment)) .collect(); + let quotient_width = aligned_len(quotient_degree * EF::DIMENSION, alignment); - let commitments = vec![ - (main_commit.clone(), main_widths), - (aux_commit.clone(), aux_widths), - (quotient_commit.clone(), vec![quotient_width]), - ]; + let aux_widths: Vec = proof_ordered_airs + .iter() + .map(|air| aligned_len(air.aux_width() * EF::DIMENSION, alignment)) + .collect(); - // 9. Parse PCS sub-transcript - let pcs_transcript = PcsTranscript::from_verifier_channel::<_, 2>( + let full_log_height = log_max_trace_height + log_blowup; + let mut commitments = Vec::with_capacity(4); + if let Some(commitment) = preprocessed_commitment { + let preprocessed_widths: Vec = preprocessed_trace_to_air + .iter() + .map(|&air_idx| { + aligned_len(statement.airs()[air_idx as usize].preprocessed_width(), alignment) + }) + .collect(); + let preprocessed_log_height = preprocessed_trace_to_air + .iter() + .map(|&air_idx| trace_order.log_heights()[air_idx as usize]) + .max() + .expect("preprocessed commitment implies at least one preprocessed AIR") + + log_blowup; + commitments.push(CommitmentGroup { + root: commitment.clone(), + widths: preprocessed_widths, + log_height: preprocessed_log_height, + }); + } + commitments.push(CommitmentGroup { + root: main_commit.clone(), + widths: main_widths, + log_height: full_log_height, + }); + commitments.push(CommitmentGroup { + root: aux_commit.clone(), + widths: aux_widths, + log_height: full_log_height, + }); + commitments.push(CommitmentGroup { + root: quotient_commit.clone(), + widths: vec![quotient_width], + log_height: full_log_height, + }); + + // 9. 
Parse PCS sub-proof + let pcs_proof = PcsProof::read_from_channel::<_, 2>( config.pcs(), config.lmcs(), &commitments, - log_lde_height, + &max_lde_domain, [z, z_next], &mut channel, )?; @@ -301,8 +387,7 @@ where Ok(( Self { - instance_shapes: proof.instance_shapes.clone(), - instance_challenge, + trace_order, main_commit, randomness, aux_commit, @@ -311,7 +396,7 @@ where beta, quotient_commit, z, - pcs_transcript, + pcs_proof, }, digest, )) diff --git a/stark/miden-lifted-stark/src/prover/README.md b/stark/miden-lifted-stark/src/prover/README.md index 6096c6e90a..14d06da83c 100644 --- a/stark/miden-lifted-stark/src/prover/README.md +++ b/stark/miden-lifted-stark/src/prover/README.md @@ -2,7 +2,7 @@ End-to-end proving for the lifted STARK protocol using LMCS commitments and the lifted FRI PCS. Supports multiple traces of different power-of-two -heights via virtual lifting. +heights of at least 2 rows via virtual lifting. Protocol-level overview lives in `miden-lifted-stark/README.md`. @@ -10,38 +10,55 @@ Protocol-level overview lives in `miden-lifted-stark/README.md`. 
| Item | Purpose | |------|---------| -| `prove_single` | Prove a single-AIR STARK | -| `prove_multi` | Prove a multi-trace STARK | -| `AirWitness` | Bundle a trace with its public values | +| `prove` | Prove one or more AIR instances | +| `ProverStatement` | Validated proving input: a `Statement` plus per-AIR main witness traces in instance order | +| `Statement` | A `MultiAir` plus the per-proof inputs (`air_inputs`, optional `aux_inputs`) | +| `MultiAir` | Trusted statement definition: AIR instances, cross-AIR assertions, and a Fiat-Shamir `observe` hook | ```text -prove_single(config, air, trace, public_values, var_len_public_inputs, aux_builder, challenger) -prove_multi(config, &[(air, witness, aux_builder), ...], challenger) +prove(config, &prover_statement, challenger) ``` +A `MultiAir` impl exposes its AIRs via `type Air` + `fn airs() -> &[Self::Air]` +and optionally overrides `max_aux_inputs()`, `eval_external(...)`, and +`observe(challenger, ...)` (defaults: zero `aux_inputs` budget, no cross-AIR assertions, +and framed observation of `air_inputs.len()`, `air_inputs`, `max_aux_inputs()`, +`aux_inputs.len()`, then `aux_inputs`; the protocol observes instance count and +`log_heights` after that hook). Each AIR builds its +own auxiliary trace via `LiftedAir::build_aux_trace(main, air_inputs, aux_inputs, +challenges)`. A `Statement` wraps a `MultiAir` with the +`air_inputs` shared by every AIR and the optional `aux_inputs`; `Statement::new` +validates the inputs against the AIRs. A `ProverStatement` wraps a `Statement` +with `traces()` (per-AIR main witness traces in instance order); `ProverStatement::new` +validates the trace shape. The same `MultiAir` drives both proving and +verification — the verifier takes the `Statement`, the prover the `ProverStatement`. + The proof is written into the provided transcript channel. This crate does not prescribe the *initial* challenger state used for Fiat-Shamir. 
## Fiat-Shamir / transcript binding -The caller must bind protocol parameters, public values, variable-length -public inputs, AIR configurations, and `air_order` into the challenger -before calling `prove_multi`. See the Rust module-level docs for the full contract -and code examples. +The caller must bind protocol parameters and AIR configurations into the +challenger before calling `prove`. The wire-format AIR ordering is derived +deterministically from the trace heights (no explicit `air_order` to bind). +The statement's `air_inputs` and `aux_inputs` are absorbed by +`Statement::observe` using the `MultiAir::observe` framing; the protocol then +observes the instance count and log trace heights in instance order. See the +Rust module-level docs for the full contract and code examples. ## Protocol flow +0. Absorb caller-supplied inputs via `Statement::observe`, then absorb the instance count and per-instance log trace heights into the challenger. 1. Validate trace dimensions against AIR definition. 2. Commit main trace LDE on nested coset (bit-reversed), observe commitment. 3. Sample aux randomness, build aux trace, commit aux LDE. 4. Sample constraint folding challenge `alpha` and cross-trace accumulator `beta`. 5. Build periodic LDEs for periodic columns. -6. Compute folded constraint numerators on each trace's quotient domain `gJ`. -7. Lift and beta-accumulate numerators onto the max quotient domain. -8. Divide by the max vanishing polynomial to obtain Q(gJ). -9. Commit quotient chunks via fused iDFT + scaling + DFT pipeline. -10. Sample OOD point `z` (rejection-sampled outside trace domain), derive `z_next`. -11. Open via PCS at `[z, z_next]` for main, aux, and quotient trees. +6. Compute each AIR's quotient on its native quotient coset. +7. Lift and accumulate per-AIR quotients onto the max quotient domain. +8. Commit quotient chunks via fused iDFT + scaling + DFT pipeline. +9. Sample OOD point `z` (rejection-sampled outside trace domain), derive `z_next`. 
+10. Open via PCS at `[z, z_next]` for main, aux, and quotient trees. ## Mathematical background @@ -53,9 +70,13 @@ changes with **lifting** — how the prover avoids work on the largest Let: -- $N = 2^n$ be the **maximum** trace height across all traces in the proof. -- $D = 2^d$ be the **constraint degree** (quotient-domain blowup). -- $B = 2^b$ be the **PCS/FRI blowup** (commitment-domain blowup), with $D \le B$. +- $N$ be the **maximum** trace height across all AIRs in the proof; $n_j$ is AIR + $j$'s trace height and $r_j = N / n_j$. +- $D_j$ be AIR $j$'s **quotient degree factor**, derived from its + constraint-degree bound; its native quotient evaluation domain has size + $n_j D_j$. Let $D_{\max} = \max_j D_j$. +- $B$ be the **PCS/FRI blowup** used for commitment domains, with + $D_{\max} \le B$. - $g$ be the fixed multiplicative shift (`F::GENERATOR`). Define two-adic subgroups: @@ -63,7 +84,7 @@ Define two-adic subgroups: $$ H = \langle \omega_H \rangle,\ |H| = N \qquad -J = \langle \omega_J \rangle,\ |J| = N\,D +J = \langle \omega_J \rangle,\ |J| = N\,D_{\max} \qquad K = \langle \omega_K \rangle,\ |K| = N\,B $$ @@ -71,23 +92,19 @@ $$ with the usual relationships: $$ -\omega_H = \omega_J^D = \omega_K^B +\omega_H = \omega_J^{D_{\max}} = \omega_K^B \qquad -H = J^D = K^B +H = J^{D_{\max}} = K^B \qquad -J = K^{B/D}\ \text{(when } D \le B\text{)} +J = K^{B/D_{\max}}. $$ -We work over shifted cosets $gH, gJ, gK$. +We work over shifted cosets $gH, gJ, gK$. The **global quotient coset** is +$gJ$ of size $N D_{\max}$. ### Mixed heights via lifting -Suppose trace $T_j$ has height - -$$ -n_j = N / r_j -\qquad\text{where } r_j = 2^{\ell_j} \text{ is a power of two.} -$$ +Trace $T_j$ has height $n_j = N / r_j$ with $r_j$ a power of two. Intuitively, **lifting** makes $T_j$ look like a height-$N$ trace by stacking $r_j$ copies of it. 
Algebraically, if $t_j(X)$ is the degree-$`) /// - `L`: LMCS configuration type -/// -/// # Usage -/// -/// ```ignore -/// let committed = commit_traces(config, traces); -/// let root = committed.root(); -/// let view = committed.evals_on_quotient_domain(0, constraint_degree); -/// ``` -/// -/// Storing the blowup also avoids re-deriving `trace_height = lde_height / blowup` for each -/// matrix, which is needed for quotient-domain views and lifting shifts. pub struct Committed where F: TwoAdicField, @@ -60,8 +49,6 @@ where { /// The underlying LMCS tree. tree: L::Tree, - /// Log₂ of the blowup factor used during LDE. - log_blowup: u8, } impl Committed @@ -70,15 +57,10 @@ where L: Lmcs, M: Matrix, { - /// Create a new `Committed` wrapper. - /// - /// # Arguments - /// - /// - `tree`: The LMCS tree containing committed LDE matrices - /// - `log_blowup`: Log₂ of the blowup factor used during LDE + /// Create a new `Committed` wrapper around an LMCS tree. #[inline] - pub fn new(tree: L::Tree, log_blowup: u8) -> Self { - Self { tree, log_blowup } + pub fn new(tree: L::Tree) -> Self { + Self { tree } } /// Get the commitment root. @@ -92,28 +74,6 @@ where pub fn tree(&self) -> &L::Tree { &self.tree } - - /// Get log₂ of the maximum LDE height across all matrices. - /// - /// This is the height of the tree (the largest matrix height). - #[inline] - fn log_max_lde_height(&self) -> u8 { - log2_strict_u8(self.tree.height()) - } - - /// Returns the [`LiftedCoset`] the `m`-th matrix was committed on. - /// - /// # Panics - /// - /// Panics if `m >= num_matrices()`. 
- fn lifted_coset(&self, m: usize) -> LiftedCoset { - let matrix = &self.tree.leaves()[m]; - let log_lde_height = log2_strict_u8(matrix.height()); - let log_trace_height = log_lde_height - self.log_blowup; - let log_max_trace_height = self.log_max_lde_height() - self.log_blowup; - - LiftedCoset::new(log_trace_height, self.log_blowup, log_max_trace_height) - } } impl Committed, L> @@ -123,7 +83,8 @@ where { /// Return a zero-copy view of matrix `m` on the quotient evaluation domain. /// - /// This returns evaluations over the quotient coset `gJ ⊆ gK`. + /// This returns evaluations over the quotient coset `gJ ⊆ gK` for matrix `m`, + /// sized to the per-matrix trace height times `eval_domain.quotient_degree()`. /// /// The tree commits to LDE evaluations on `gK` (size `N·B`). The `RowMajorMatrix` /// stores bit-reversed evaluations; `gJ` appears as the first `N·D` rows, so this is @@ -135,10 +96,14 @@ where pub fn evals_on_quotient_domain( &self, m: usize, - constraint_degree: usize, + eval_domain: &EvaluationDomain, ) -> BitReversedMatrixView> { - let quotient_height = self.lifted_coset(m).trace_height() * constraint_degree; - self.tree.leaves()[m].split_rows(quotient_height).0.bit_reverse_rows() + debug_assert_eq!( + eval_domain.lifted().lde_height(), + self.tree.leaves()[m].height(), + "eval_domain LDE height must match matrix m's tree height", + ); + self.tree.leaves()[m].split_rows(eval_domain.size()).0.bit_reverse_rows() } } @@ -161,20 +126,23 @@ where /// /// # Arguments /// - `config`: STARK configuration containing PCS params, LMCS, and DFT -/// - `traces`: Trace matrices sorted by height (ascending) +/// - `domains`: One pre-validated [`LiftedDomain`] per trace, in the same order as `traces`. The +/// last entry is the batch's max-trace domain (heights are sorted ascending). +/// - `traces`: Trace matrices, in the same order as `domains`. Each must have height matching its +/// paired `domains[i].trace_height()`. 
/// /// # Panics -/// - If `traces` is empty -/// - If trace heights are not powers of two -/// - If traces are not sorted by height in ascending order +/// - If `domains` and `traces` have different lengths +/// - If any trace's height doesn't match its paired domain's `trace_height()` /// /// Lifting note: for a trace of height `n` embedded into a max height `n_max`, let /// `r = n_max / n`. The commitment should behave as if it contains evaluations of the /// lifted polynomial `f_lift(X) = f(Xʳ)` on the max LDE coset. This is achieved by /// evaluating the original trace on a *nested* coset with shift gʳ: the map /// `(g·ω)ʳ = gʳ·ωʳ` sends the max domain down to the smaller one. -pub fn commit_traces( +pub(super) fn commit_traces( config: &SC, + domains: &[LiftedDomain], traces: Vec>, ) -> Committed, SC::Lmcs> where @@ -182,37 +150,25 @@ where EF: ExtensionField, SC: StarkConfig, { + assert_eq!(domains.len(), traces.len(), "domains and traces must have matching lengths"); assert!(!traces.is_empty(), "at least one trace required"); - assert!( - traces.windows(2).all(|w| w[0].height() <= w[1].height()), - "traces must be sorted by height in ascending order" - ); - let log_blowup = config.pcs().log_blowup(); - // Find max trace height - let max_trace_height = traces.last().unwrap().height(); - let log_max_trace_height = log2_strict_u8(max_trace_height); - let ldes: Vec<_> = traces .into_iter() + .zip(domains) .enumerate() - .map(|(idx, trace)| { - let trace_height = trace.height(); + .map(|(idx, (trace, domain))| { let width = trace.width(); - - // Validate height is power of two - assert!( - trace_height.is_power_of_two(), - "trace height must be power of two (index {idx})" + assert_eq!( + trace.height(), + domain.trace_height(), + "trace {idx} height does not match its domain", ); - let log_trace_height = log2_strict_u8(trace_height); - - // Use LiftedCoset to compute the coset shift - let coset = LiftedCoset::new(log_trace_height, log_blowup, 
log_max_trace_height); - let coset_shift = coset.lde_shift::(); + let log_trace_height = domain.log_trace_height(); + let coset_shift = domain.lde_shift(); info_span!("LDE", trace = idx, log_height = log_trace_height, width).in_scope(|| { let lde = config.dft().coset_lde_batch(trace, log_blowup.into(), coset_shift); @@ -223,7 +179,7 @@ where // Build aligned LMCS tree and wrap in Committed let tree = config.lmcs().build_aligned_tree(ldes); - Committed::new(tree, log_blowup) + Committed::new(tree) } // ============================================================================ diff --git a/stark/miden-lifted-stark/src/prover/constraints/folder.rs b/stark/miden-lifted-stark/src/prover/constraints/folder.rs index e65d7c7e63..c3307c75b6 100644 --- a/stark/miden-lifted-stark/src/prover/constraints/folder.rs +++ b/stark/miden-lifted-stark/src/prover/constraints/folder.rs @@ -8,7 +8,7 @@ use alloc::vec::Vec; use core::marker::PhantomData; use miden_lifted_air::{ - AirBuilder, EmptyWindow, ExtensionBuilder, PeriodicAirBuilder, PermutationAirBuilder, RowWindow, + AirBuilder, ExtensionBuilder, PeriodicAirBuilder, PermutationAirBuilder, RowWindow, }; use p3_field::{ Algebra, BasedVectorSpace, ExtensionField, Field, PackedField, PrimeCharacteristicRing, @@ -83,7 +83,7 @@ fn batched_base_linear_combination(coeffs: &[P::Scalar], values: /// - `EF`: Extension field scalar /// - `P`: Packed base field (with `P::Scalar = F`) /// - `PE`: Packed extension field (must implement appropriate algebra traits) -pub struct ProverConstraintFolder<'a, F, EF, P, PE> +pub(super) struct ProverConstraintFolder<'a, F, EF, P, PE> where F: Field, EF: ExtensionField, @@ -92,6 +92,9 @@ where { /// Main trace two-row window (packed base field) pub main: RowWindow<'a, P>, + /// Preprocessed trace two-row window (packed base field); empty when the + /// AIR declares no preprocessed columns. 
+ pub preprocessed: RowWindow<'a, P>, /// Aux/permutation trace two-row window (packed extension field) pub aux: RowWindow<'a, PE>, /// Randomness for aux trace (packed extension field) @@ -169,7 +172,7 @@ where type F = F; type Expr = P; type Var = P; - type PreprocessedWindow = EmptyWindow

; + type PreprocessedWindow = RowWindow<'a, P>; type MainWindow = RowWindow<'a, P>; type PublicVar = F; @@ -178,8 +181,9 @@ where self.main } + #[inline] fn preprocessed(&self) -> &Self::PreprocessedWindow { - EmptyWindow::empty_ref() + &self.preprocessed } #[inline] diff --git a/stark/miden-lifted-stark/src/prover/constraints/layout.rs b/stark/miden-lifted-stark/src/prover/constraints/layout.rs index 43df5c4569..9a30d0098c 100644 --- a/stark/miden-lifted-stark/src/prover/constraints/layout.rs +++ b/stark/miden-lifted-stark/src/prover/constraints/layout.rs @@ -6,8 +6,8 @@ use alloc::{vec, vec::Vec}; use miden_lifted_air::{ - AirBuilder, EmptyWindow, ExtensionBuilder, LiftedAir, PeriodicAirBuilder, - PermutationAirBuilder, + AirBuilder, ExtensionBuilder, LiftedAir, PeriodicAirBuilder, PermutationAirBuilder, + WindowAccess, symbolic::{AirLayout, ConstraintLayout}, }; use p3_field::{ExtensionField, Field}; @@ -24,14 +24,15 @@ use tracing::instrument; /// for all variables. This discovers which constraints are base-field vs extension-field /// without building symbolic expression trees — only the emission order matters. #[instrument(name = "compute constraint layout", skip_all, level = "debug")] -pub fn get_constraint_layout(air: &A) -> ConstraintLayout +pub(crate) fn get_constraint_layout(air: &A) -> ConstraintLayout where F: Field, EF: ExtensionField, A: LiftedAir, { let mut builder = ConstraintLayoutBuilder::::new(air.air_layout()); - debug_assert!(air.is_valid_builder(&builder).is_ok()); + #[cfg(debug_assertions)] + miden_lifted_air::debug::check_builder_shape(air, &builder); air.eval(&mut builder); builder.into_layout() } @@ -42,11 +43,33 @@ where /// tracking, no `Arc` allocations. Builds a [`ConstraintLayout`] directly by recording /// which `assert_*` method is called for each constraint. /// -/// Uses `RowMajorMatrix` as `MainWindow` because the builder owns its trace data. 
-/// `RowWindow` cannot be used here — it borrows, but the associated type can't -/// capture the `&self` lifetime from `main()`. +/// Uses owned windows because the builder owns its trace data. `RowWindow` cannot be used here — +/// it borrows, but the associated type can't capture the `&self` lifetime from `main()`. +#[derive(Clone)] +struct OwnedRowWindow { + values: Vec, + width: usize, +} + +impl OwnedRowWindow { + fn zeros(width: usize) -> Self { + Self { values: vec![F::ZERO; 2 * width], width } + } +} + +impl WindowAccess for OwnedRowWindow { + fn current_slice(&self) -> &[F] { + &self.values[..self.width] + } + + fn next_slice(&self) -> &[F] { + &self.values[self.width..] + } +} + struct ConstraintLayoutBuilder { main: RowMajorMatrix, + preprocessed: OwnedRowWindow, public_values: Vec, periodic_values: Vec, permutation: RowMajorMatrix, @@ -59,16 +82,17 @@ struct ConstraintLayoutBuilder { impl ConstraintLayoutBuilder { fn new(layout: AirLayout) -> Self { let AirLayout { + preprocessed_width, main_width, num_public_values, permutation_width, num_permutation_challenges, num_permutation_values, num_periodic_columns, - .. 
} = layout; Self { main: RowMajorMatrix::new(vec![F::ZERO; 2 * main_width], main_width), + preprocessed: OwnedRowWindow::zeros(preprocessed_width), public_values: vec![F::ZERO; num_public_values], periodic_values: vec![F::ZERO; num_periodic_columns], permutation: RowMajorMatrix::new( @@ -91,7 +115,7 @@ impl AirBuilder for ConstraintLayoutBuilder { type F = F; type Expr = F; type Var = F; - type PreprocessedWindow = EmptyWindow; + type PreprocessedWindow = OwnedRowWindow; type MainWindow = RowMajorMatrix; type PublicVar = F; @@ -100,7 +124,7 @@ impl AirBuilder for ConstraintLayoutBuilder { } fn preprocessed(&self) -> &Self::PreprocessedWindow { - EmptyWindow::empty_ref() + &self.preprocessed } fn is_first_row(&self) -> Self::Expr { diff --git a/stark/miden-lifted-stark/src/prover/constraints/mod.rs b/stark/miden-lifted-stark/src/prover/constraints/mod.rs index a3818b6fdf..91ac745b70 100644 --- a/stark/miden-lifted-stark/src/prover/constraints/mod.rs +++ b/stark/miden-lifted-stark/src/prover/constraints/mod.rs @@ -20,9 +20,12 @@ use p3_field::{ use p3_matrix::{Matrix, bitrev::BitReversedMatrixView, dense::RowMajorMatrixView}; #[cfg(feature = "parallel")] use p3_maybe_rayon::prelude::*; -use packed_row_bitrev::RowMajorMatrixBitrevPackedExt; +use packed_row_bitrev::collect_vertically_packed_row_pair_bitrev_into; -use crate::{coset::LiftedCoset, prover::periodic::PeriodicLde}; +use crate::{ + domain::{Coset, EvaluationDomain}, + prover::periodic::PeriodicLde, +}; /// Row-blocks (`i_start = r * packing_width`) processed per rayon task. const ROW_BLOCKS_PER_PARALLEL_TASK: usize = 32; @@ -33,19 +36,31 @@ type PackedVal = ::Packing; /// Type alias for packed extension field from EF. type PackedExt = >::ExtensionPacking; -/// Evaluate constraints on the quotient domain, adding results into `output`. +/// Evaluate an AIR's constraints on its native quotient coset and write the +/// per-AIR quotient evaluations (constraint numerator divided by `Z_{H_j}`) +/// into `output`. 
+/// +/// `coset` is the AIR's native quotient evaluation coset `gJ_j` of size `n_j * D_j`, +/// where `n_j` is the AIR's trace height and `D_j = 2^log_quotient_degree` is its +/// per-AIR constraint-degree bound. For each point on `gJ_j` we evaluate every +/// constraint, fold with powers of `alpha`, multiply by the precomputed `1 / Z_H` +/// value, and write the result: +/// +/// `output[i] = folded_constraints(x_i) / Z_{H_j}(x_i)`. /// -/// Here `gJ` is the quotient evaluation coset of size `N * D`, the subset of the -/// committed LDE coset `gK` (size `N * B`) that contains just enough points to -/// evaluate the quotient point-wise. For each point on `gJ`, we evaluate all AIR -/// constraints, fold them with powers of `alpha`, and add the resulting numerator value: +/// `inv_z_h` is a length-`D_j` slice: `Z_{H_j}(x)` takes only `D_j` distinct +/// values over `gJ_j` by periodicity, so batch-inverting them once suffices +/// (use [`crate::domain::EvaluationDomain::inv_vanishing_evals`]). Fusing the +/// divide into the write loop saves a second pass over the `n_j · D_j`-point +/// output buffer. /// -/// `output[i] += folded_constraints(xᵢ)`. +/// `output` must be a fresh zero-initialized buffer of length `n_j * D_j`; each +/// point is written once. Upsampling to the batch-wide target and beta-accumulation +/// into the shared quotient accumulator happen in the caller. /// -/// The caller is responsible for preparing `output` before calling this function -/// (e.g. cyclically extending and scaling by beta for multi-trace accumulation). -/// Trace views must be [`BitReversedMatrixView`] over dense row-major storage (as returned by -/// [`crate::prover::commit::Committed::evals_on_quotient_domain`]), in natural order on gJ. +/// Trace views must be [`BitReversedMatrixView`] over dense row-major storage (as +/// returned by [`crate::prover::commit::Committed::evals_on_quotient_domain`]), in +/// natural order on `gJ_j`. 
/// /// Uses SIMD-packed parallel iteration via rayon for optimal performance: /// - Processes `WIDTH` points simultaneously using packed field types @@ -60,23 +75,25 @@ type PackedExt = >::ExtensionPacking; /// collapses them into one numerator polynomial while preserving soundness (a non-zero /// constraint survives with high probability). /// -/// Why we only evaluate on `gJ`: `gJ` (size `N * D`) is a subset of the committed LDE -/// coset `gK` (size `N * B`). For `B >= D`, these `N * D` points are sufficient for -/// the quotient-degree bounds used by the protocol; division by the vanishing polynomial -/// happens later. +/// Why we evaluate on the native coset: the quotient `Q_j = C_j / Z_{H_j}` has degree +/// `< n_j * D_j` by construction, so `n_j * D_j` evaluation points suffice to determine +/// it. The committed LDE coset (size `n_j * B`, with `B >= D_j`) contains `gJ_j` as a +/// subset, so the truncated view the caller passes in is zero-copy. #[allow(clippy::too_many_arguments)] -pub fn evaluate_constraints_into( +pub(super) fn evaluate_constraints_into( output: &mut [EF], air: &A, main_on_gj: &BitReversedMatrixView>, + preprocessed_on_gj: Option<&BitReversedMatrixView>>, aux_on_gj: &BitReversedMatrixView>, - coset: &LiftedCoset, + eval_domain: &EvaluationDomain, alpha: EF, randomness: &[EF], public_values: &[F], periodic_lde: &PeriodicLde, layout: &ConstraintLayout, permutation_values: &[EF], + inv_z_h: &[F], ) where F: TwoAdicField, EF: ExtensionField, @@ -86,15 +103,18 @@ pub fn evaluate_constraints_into( type P = PackedVal; type PE = PackedExt; - let gj_height = coset.lde_height(); + let quotient_degree = eval_domain.quotient_degree(); + let gj_height = eval_domain.size(); assert_eq!(output.len(), gj_height); - let constraint_degree = coset.blowup(); let width = P::::WIDTH; assert_eq!(gj_height % width, 0, "quotient height must be divisible by packing width"); + assert_eq!(inv_z_h.len(), quotient_degree, "inv_z_h length must equal D_j"); + // 
Bitmask for `i % inv_z_h.len()`; len is `D_j = 2^log_quotient_degree`, a power of two by construction. + let inv_z_h_mask: usize = inv_z_h.len() - 1; - // Precompute selectors via coset method - let sels = coset.selectors::(); + // Precompute selectors over the quotient evaluation coset. + let sels = eval_domain.selectors(); // ─── Decompose alpha powers by constraint layout ─── let aux_ef_width = air.aux_width(); @@ -105,6 +125,12 @@ pub fn evaluate_constraints_into( // Main trace width let main_width = main_on_gj.width(); + // Preprocessed trace view, constructed only when the AIR declares one. + let preproc_trace_view = preprocessed_on_gj.map(|m| { + let w = m.width(); + RowMajorMatrixView::new(m.inner.values, w) + }); + let preprocessed_width = preproc_trace_view.as_ref().map_or(0, Matrix::width); // Pack randomness for aux trace let packed_randomness: Vec> = randomness.iter().copied().map(Into::into).collect(); @@ -122,6 +148,7 @@ pub fn evaluate_constraints_into( let points_per_task = width * ROW_BLOCKS_PER_PARALLEL_TASK; let eval_big_slice = |main_buf: &mut Vec>, + preproc_buf: &mut Vec>, aux_base_buf: &mut Vec>, aux_pe_buf: &mut Vec>, g: usize, @@ -134,17 +161,36 @@ pub fn evaluate_constraints_into( let selectors = sels.packed_at::>(i_start); // Get main trace as packed row pair (stays in base field) - main_trace_view.collect_vertically_packed_row_pair_bitrev_into( + collect_vertically_packed_row_pair_bitrev_into::>( + &main_trace_view, i_start, - constraint_degree, + quotient_degree, main_buf, ); let main_mat = RowMajorMatrixView::new(main_buf.as_slice(), main_width); + // Get preprocessed trace as packed row pair (when present). For AIRs + // without preprocessed columns, the window is empty and the AIR must + // not call `builder.preprocessed()`. 
+ let preprocessed = if let Some(view) = preproc_trace_view.as_ref() { + collect_vertically_packed_row_pair_bitrev_into::>( + view, + i_start, + quotient_degree, + preproc_buf, + ); + let m = RowMajorMatrixView::new(preproc_buf.as_slice(), preprocessed_width); + RowWindow::from_view(&m) + } else { + let empty: &[P] = &[]; + RowWindow::from_two_rows(empty, empty) + }; + // Get aux trace as packed row pair and convert to packed extension field - aux_trace_view.collect_vertically_packed_row_pair_bitrev_into( + collect_vertically_packed_row_pair_bitrev_into::>( + &aux_trace_view, i_start, - constraint_degree, + quotient_degree, aux_base_buf, ); @@ -166,6 +212,7 @@ pub fn evaluate_constraints_into( let mut folder: ProverConstraintFolder<'_, F, EF, P, PE> = ProverConstraintFolder { main: RowWindow::from_view(&main_mat), + preprocessed, aux: RowWindow::from_view(&aux_mat), packed_randomness: &packed_randomness, public_values, @@ -182,32 +229,50 @@ pub fn evaluate_constraints_into( }; #[cfg(debug_assertions)] - air.is_valid_builder(&folder).expect("builder dimensions must match AIR"); + miden_lifted_air::debug::check_builder_shape(air, &folder); air.eval(&mut folder); let folded = folder.finalize_constraints(); - // Unpack folded result and add scalars directly into the output chunk. - for (slot, val) in chunk.iter_mut().zip(PE::::to_ext_iter([folded])) { - *slot += val; + // Unpack the folded result, multiply by 1/Z_H (modular indexing since Z_H + // takes only D_j distinct values on gJ_j), and write into the output chunk. 
+ for (k, (slot, val)) in + chunk.iter_mut().zip(PE::::to_ext_iter([folded])).enumerate() + { + *slot = val * inv_z_h[(i_start + k) & inv_z_h_mask]; } } }; #[cfg(feature = "parallel")] output.par_chunks_mut(points_per_task).enumerate().for_each_init( - || (Vec::>::new(), Vec::>::new(), Vec::>::new()), - |(main_buf, aux_base_buf, aux_pe_buf), (g, big_slice)| { - eval_big_slice(main_buf, aux_base_buf, aux_pe_buf, g, big_slice); + || { + ( + Vec::>::new(), + Vec::>::new(), + Vec::>::new(), + Vec::>::new(), + ) + }, + |(main_buf, preproc_buf, aux_base_buf, aux_pe_buf), (g, big_slice)| { + eval_big_slice(main_buf, preproc_buf, aux_base_buf, aux_pe_buf, g, big_slice); }, ); #[cfg(not(feature = "parallel"))] { let mut main_buf = Vec::>::new(); + let mut preproc_buf = Vec::>::new(); let mut aux_base_buf = Vec::>::new(); let mut aux_pe_buf = Vec::>::new(); output.chunks_mut(points_per_task).enumerate().for_each(|(g, big_slice)| { - eval_big_slice(&mut main_buf, &mut aux_base_buf, &mut aux_pe_buf, g, big_slice); + eval_big_slice( + &mut main_buf, + &mut preproc_buf, + &mut aux_base_buf, + &mut aux_pe_buf, + g, + big_slice, + ); }); } } diff --git a/stark/miden-lifted-stark/src/prover/constraints/packed_row_bitrev.rs b/stark/miden-lifted-stark/src/prover/constraints/packed_row_bitrev.rs index 20bda82d2d..eab271ff9b 100644 --- a/stark/miden-lifted-stark/src/prover/constraints/packed_row_bitrev.rs +++ b/stark/miden-lifted-stark/src/prover/constraints/packed_row_bitrev.rs @@ -8,87 +8,38 @@ use p3_field::PackedValue; use p3_matrix::dense::RowMajorMatrixView; use p3_util::{log2_strict_usize, reverse_bits_len}; -/// Collect logical vertically packed rows from bit-reversed row-major storage into a reusable -/// buffer. -pub trait RowMajorMatrixBitrevPackedExt { - /// One logical row block starting at logical row index `i_start` (a multiple of `P::WIDTH`). - #[expect(dead_code)] - fn collect_vertically_packed_row_bitrev_into>( - &self, - i_start: usize, - out: &mut Vec

, - ); - - /// Two logical row blocks: rows `i_start` and `i_start + step` (mod height), packed like - /// `Matrix::vertically_packed_row_pair`. - fn collect_vertically_packed_row_pair_bitrev_into>( - &self, - i_start: usize, - step: usize, - out: &mut Vec

, - ); -} - -impl<'a, F: Copy> RowMajorMatrixBitrevPackedExt for RowMajorMatrixView<'a, F> { - fn collect_vertically_packed_row_bitrev_into>( - &self, - i_start: usize, - out: &mut Vec

, - ) { - let values = self.values; - let width = self.width; - let height = values.len() / width; - let log_h = log2_strict_usize(height); - debug_assert_eq!(1usize << log_h, height); - - const MAX_WIDTH: usize = 16; - const { - debug_assert!(P::WIDTH <= MAX_WIDTH); - } - - let mut cur_off = [0usize; MAX_WIDTH]; - for (lane_idx, lane) in cur_off.iter_mut().enumerate().take(P::WIDTH) { - *lane = reverse_bits_len((i_start + lane_idx) % height, log_h) * width; - } - - out.clear(); - out.reserve(width); - for c in 0..width { - out.push(P::from_fn(|lane| values[cur_off[lane] + c])); - } +/// Two logical row blocks: rows `i_start` and `i_start + step` (mod height), packed like +/// `Matrix::vertically_packed_row_pair`. +pub(super) fn collect_vertically_packed_row_pair_bitrev_into>( + matrix: &RowMajorMatrixView<'_, F>, + i_start: usize, + step: usize, + out: &mut Vec

, +) { + let values = matrix.values; + let width = matrix.width; + let height = values.len() / width; + let log_h = log2_strict_usize(height); + debug_assert_eq!(1usize << log_h, height); + + const MAX_WIDTH: usize = 16; + const { + debug_assert!(P::WIDTH <= MAX_WIDTH); } - fn collect_vertically_packed_row_pair_bitrev_into>( - &self, - i_start: usize, - step: usize, - out: &mut Vec

, - ) { - let values = self.values; - let width = self.width; - let height = values.len() / width; - let log_h = log2_strict_usize(height); - debug_assert_eq!(1usize << log_h, height); - - const MAX_WIDTH: usize = 16; - const { - debug_assert!(P::WIDTH <= MAX_WIDTH); - } - - let mut cur_off = [0usize; MAX_WIDTH]; - let mut nxt_off = [0usize; MAX_WIDTH]; - for lane in 0..P::WIDTH { - cur_off[lane] = reverse_bits_len((i_start + lane) % height, log_h) * width; - nxt_off[lane] = reverse_bits_len((i_start + step + lane) % height, log_h) * width; - } + let mut cur_off = [0usize; MAX_WIDTH]; + let mut nxt_off = [0usize; MAX_WIDTH]; + for lane in 0..P::WIDTH { + cur_off[lane] = reverse_bits_len((i_start + lane) % height, log_h) * width; + nxt_off[lane] = reverse_bits_len((i_start + step + lane) % height, log_h) * width; + } - out.clear(); - out.reserve(2 * width); - for c in 0..width { - out.push(P::from_fn(|lane| values[cur_off[lane] + c])); - } - for c in 0..width { - out.push(P::from_fn(|lane| values[nxt_off[lane] + c])); - } + out.clear(); + out.reserve(2 * width); + for c in 0..width { + out.push(P::from_fn(|lane| values[cur_off[lane] + c])); + } + for c in 0..width { + out.push(P::from_fn(|lane| values[nxt_off[lane] + c])); } } diff --git a/stark/miden-lifted-stark/src/prover/mod.rs b/stark/miden-lifted-stark/src/prover/mod.rs index 57ffb65118..a6af3f9424 100644 --- a/stark/miden-lifted-stark/src/prover/mod.rs +++ b/stark/miden-lifted-stark/src/prover/mod.rs @@ -1,102 +1,87 @@ //! Lifted STARK prover. //! //! This module provides: -//! - [`prove_single`]: Prove a single AIR instance. -//! - [`prove_multi`]: Prove multiple AIR instances with traces of different heights. +//! - [`ProverInstance::prove`](crate::ProverInstance::prove): Prove one or more AIR instances with +//! traces of (possibly) different heights. //! -//! These functions write the proof into a [`miden_stark_transcript::ProverChannel`] -//! (commitments, grinding witnesses, and openings). +//! 
[`ProverInstance::prove`](crate::ProverInstance::prove) writes the proof into a +//! [`miden_stark_transcript::ProverChannel`] (commitments, grinding witnesses, +//! and openings). //! //! # Fiat-Shamir / transcript binding (initial challenger state) //! //! This crate does **not** prescribe the *initial* transcript state. The caller -//! must bind the full statement into the Fiat-Shamir challenger before calling -//! [`prove_multi`]. Both prover and verifier must produce identical challenger -//! states. Concretely, the caller **MUST** observe: +//! must bind protocol and AIR configuration data before calling +//! [`ProverInstance::prove`](crate::ProverInstance::prove). Both prover and verifier must produce +//! identical challenger states. Concretely, the caller **MUST** observe: //! //! 1. **Protocol parameters** — e.g. the STARK configuration, blowup factor, and any //! application-level domain separator. //! -//! 2. **Public values and variable-length inputs** — `public_values` and `var_len_public_inputs` -//! for every instance. Without this, Fiat-Shamir challenges are independent of the statement. +//! 2. **AIR configurations** — The framework does not commit to the [`MultiAir::airs`] list. The +//! caller MUST bind every AIR configuration into the challenger before calling +//! [`ProverInstance::prove`](crate::ProverInstance::prove) / +//! [`VerifierInstance::verify`](crate::VerifierInstance::verify). The AIR ordering on the wire +//! is derived deterministically from the trace heights (stable sort on `(log_trace_height, +//! instance_index)`), so callers do not need to commit to it separately as long as they commit +//! to the AIR list and trace heights match. //! -//! 3. **AIR configurations and `air_order`** — The proof defines an ordering of AIR instances -//! (`air_order()[j]` is the caller's original index at proof position `j`), queryable via -//! [`InstanceShapes::air_order`]. The ordering is deterministic: instances are sorted by -//! 
`(log_trace_height, caller_index)`. Neither the AIR configurations nor `air_order` are -//! absorbed into the transcript, so the caller must bind both into the challenger. How this is -//! done is up to the caller — see the examples below. The prover can precompute `air_order` via -//! [`InstanceShapes::from_trace_heights`]; the verifier reads it from the proof. +//! The statement's `air_inputs` and `aux_inputs` are absorbed automatically by +//! [`Statement::observe`](crate::air::Statement::observe), followed by protocol-level +//! absorption of the instance count and each AIR's log trace height in instance +//! order. Callers do not bind these themselves. //! //! ## Recommended pattern //! //! Pre-seed the challenger so statement data stays out of the proof: //! //! ```ignore -//! // --- Bind statement into Fiat-Shamir --- +//! // --- Bind protocol parameters + AIR configurations into Fiat-Shamir --- //! let mut ch = Challenger::new(perm.clone()); //! ch.observe_slice(&b"MY_APP_V1".map(|b| F::from_u8(b))); // domain separator //! ch.observe(F::from_u8(config.pcs().log_blowup())); // protocol parameters -//! // ... observe remaining protocol parameters ... -//! ch.observe_slice(&public_values); -//! for vl in &var_len_public_inputs { -//! ch.observe_slice(vl); -//! } -//! // For multi-AIR: bind AIR configurations and air_order (see below). +//! // ... bind AIR configurations + air ordering (see below) ... //! //! // --- Prove --- -//! let output = prove_multi(&config, &instances, ch)?; +//! let prover_instance = ProverInstance::new(&config, &prover_statement, None)?; +//! let output = prover_instance.prove(ch)?; //! -//! // --- Verify (identical binding) --- +//! // --- Verify (identical binding + the same statement) --- //! let mut ch = Challenger::new(perm); //! ch.observe_slice(&b"MY_APP_V1".map(|b| F::from_u8(b))); //! ch.observe(F::from_u8(config.pcs().log_blowup())); -//! // ... observe remaining protocol parameters ... -//! 
ch.observe_slice(&public_values); -//! for vl in &var_len_public_inputs { -//! ch.observe_slice(vl); -//! } -//! let verifier_digest = verify_multi(&config, &verifier_instances, &output.proof, ch)?; +//! let verifier_instance = VerifierInstance::new(&config, prover_instance.statement(), None)?; +//! let verifier_digest = verifier_instance.verify(&output.proof, ch)?; //! assert_eq!(output.digest, verifier_digest); //! ``` //! -//! ## Multi-AIR binding examples +//! ## Multi-AIR binding example //! //! ```text -//! // Prover: precompute air_order before building the challenger. -//! let shapes = InstanceShapes::from_trace_heights(trace_heights)?; -//! let air_order = shapes.air_order(); -//! -//! // Verifier: read air_order from the proof. -//! let air_order = proof.air_order(); -//! -//! // Option A: reorder AIRs to proof order and commit — the ordering is -//! // implicit in the commitment. -//! let ordered_airs: Vec<_> = air_order.iter().map(|&idx| &airs[idx as usize]).collect(); -//! let circuit = Circuit::from_airs(&ordered_airs); -//! challenger.observe(circuit.commitment()); -//! -//! // Option B: commit to AIRs in their natural order, then observe -//! // air_order to bind the ordering explicitly. -//! for air in &airs { +//! // Commit to AIRs in instance order — the proof's wire-format ordering is +//! // derived from the heights inside the framework, so binding the instance +//! // order is enough. +//! for air in statement.airs() { //! challenger.observe(air.commitment()); //! } -//! challenger.observe_slice(air_order); //! 
``` extern crate alloc; -pub mod commit; -pub mod constraints; -pub mod periodic; -pub mod quotient; +pub(crate) mod commit; +pub(crate) mod constraints; +pub(crate) mod periodic; +pub(crate) mod quotient; -use alloc::{vec, vec::Vec}; +use alloc::vec::Vec; use commit::commit_traces; use constraints::{evaluate_constraints_into, layout::get_constraint_layout}; -use miden_lifted_air::{AuxBuilder, LiftedAir, VarLenPublicInputs, log2_strict_u8}; +use miden_lifted_air::{ + InstanceError, LiftedAir, MultiAir, ProverStatement, ReductionError, Statement, +}; use miden_stark_transcript::{Channel, ProverChannel, ProverTranscript}; +use p3_challenger::CanObserve; use p3_field::{BasedVectorSpace, ExtensionField, TwoAdicField}; use p3_matrix::{Matrix, dense::RowMajorMatrix}; use periodic::PeriodicLde; @@ -105,180 +90,308 @@ use tracing::{info_span, instrument}; use crate::{ StarkConfig, - coset::LiftedCoset, - instance::{AirWitness, InstanceShapes, InstanceValidationError, validate_inputs}, + domain::{Coset, DomainError, LiftedDomain, log_quotient_degree}, + lmcs::Lmcs, + order::TraceOrder, pcs::prover::open_with_channel, - proof::{StarkOutput, StarkProof}, + preprocessed::{Preprocessed, PreprocessedValidationError, validate_preprocessed}, + proof::{StarkOutput, StarkProofData}, }; -/// Errors that can occur during proving. -#[derive(Debug, Error)] -pub enum ProverError { - #[error("instance validation failed: {0}")] - Instance(#[from] InstanceValidationError), - #[error( - "constraint degree exceeds blowup: \ - log_quotient_degree {log_quotient_degree} > log_blowup {log_blowup}" - )] - ConstraintDegreeTooHigh { log_quotient_degree: u8, log_blowup: u8 }, -} +// ============================================================================ +// ProverInstance +// ============================================================================ -/// Prove a single AIR. 
-/// -/// The caller's challenger must already be bound to the full statement -/// (protocol parameters, AIR configuration, public values, and -/// variable-length inputs) — see the module-level docs. -/// -/// This is a convenience wrapper around [`prove_multi`] for the single-AIR case. +/// Prover-side bundle: a [`StarkConfig`], a borrowed [`ProverStatement`], and +/// the optional borrowed [`Preprocessed`] data. /// -/// # Returns -/// `Ok(StarkOutput { digest, proof })` on success, or a `ProverError` if validation fails. -pub fn prove_single( - config: &SC, - air: &A, - trace: &RowMajorMatrix, - public_values: &[F], - var_len_public_inputs: VarLenPublicInputs<'_, F>, - aux_builder: &B, - challenger: SC::Challenger, -) -> Result, ProverError> +/// Construction validates preprocessed presence parity (and, when present, the +/// bundle's shape against the AIRs and STARK config), so holding a +/// `ProverInstance` is a guarantee its preprocessed shape is consistent; +/// proving never re-checks it. +pub struct ProverInstance<'a, F, EF, MA, SC> where F: TwoAdicField, EF: ExtensionField, + MA: MultiAir, SC: StarkConfig, - A: LiftedAir, - B: AuxBuilder, { - let witness = AirWitness::new(trace, public_values, var_len_public_inputs); - prove_multi(config, &[(air, witness, aux_builder)], challenger) + config: &'a SC, + prover_statement: &'a ProverStatement, + preprocessed: Option<&'a Preprocessed>, } -/// Prove multiple AIRs with traces of different heights. +impl<'a, F, EF, MA, SC> ProverInstance<'a, F, EF, MA, SC> +where + F: TwoAdicField, + EF: ExtensionField, + MA: MultiAir, + SC: StarkConfig, +{ + /// Bundle a config + prover statement with an optional preprocessed bundle. + /// + /// `preprocessed` must be `Some` exactly when some AIR declares preprocessed + /// columns; otherwise this errors with + /// [`PreprocessedValidationError::PresenceMismatch`]. When both hold, the + /// bundle's raw and committed shapes are validated against the AIRs and + /// config. 
+ pub fn new( + config: &'a SC, + prover_statement: &'a ProverStatement, + preprocessed: Option<&'a Preprocessed>, + ) -> Result { + let expected = + prover_statement.statement().airs().iter().any(|a| a.preprocessed_width() > 0); + let actual = preprocessed.is_some(); + if expected != actual { + return Err(PreprocessedValidationError::PresenceMismatch { expected, actual }); + } + if let Some(p) = preprocessed { + validate_preprocessed(config, prover_statement, p)?; + } + Ok(Self { config, prover_statement, preprocessed }) + } + + /// Prove this instance. + pub fn prove(&self, challenger: SC::Challenger) -> Result, ProverError> { + prove(self, challenger) + } + + /// Borrow the STARK configuration. + pub fn config(&self) -> &SC { + self.config + } + + /// Borrow the wrapped air-crate prover statement. + pub fn prover_statement(&self) -> &ProverStatement { + self.prover_statement + } + + /// Borrow the verifier-side statement (the AIRs + public inputs). + pub fn statement(&self) -> &Statement { + self.prover_statement.statement() + } + + /// Commitment to the preprocessed tree, for the verifier's + /// [`VerifierInstance::new`](crate::VerifierInstance::new); `None` when there + /// is none. + pub fn preprocessed_commitment(&self) -> Option<::Commitment> { + self.preprocessed.map(Preprocessed::commitment) + } + + /// Borrow the preprocessed bundle, if any. + pub(crate) fn preprocessed(&self) -> Option<&Preprocessed> { + self.preprocessed + } +} + +/// Prove a [`ProverInstance`]. +/// +/// The caller's challenger must already be bound to protocol parameters and +/// AIR configurations — see the module-level docs. The statement's `air_inputs` +/// and `aux_inputs` are absorbed internally via +/// [`Statement::observe`](crate::air::Statement::observe); both prover and verifier +/// must carry the same statement. +/// +/// # Trust contract +/// +/// `prove` validates ONLY untrusted runtime inputs and returns typed +/// [`ProverError`] on failure. 
AIR structural correctness is the +/// implementer's contract — call [`crate::debug::assert_prover_setup`] +/// from your test harness to enforce it in debug builds. /// -/// The caller's challenger must already be bound to the full statement -/// (protocol parameters, AIR configurations, AIR ordering, and public -/// inputs — both fixed and variable-length) — see the module-level docs. +/// ## Validated +/// - per AIR: `air.num_public_values() == statement.air_inputs().len()` +/// - `statement.aux_inputs().len() <= multi_air.max_aux_inputs()` +/// - `prover_statement.traces().len() == statement.airs().len()` and `<= u8::MAX + 1` +/// - per AIR: `trace.width() == air.width()` +/// - per AIR: `trace.height().is_power_of_two()` +/// - per AIR: `trace.height() >= max periodic column length` +/// - per AIR: `log_quotient_degree(air) <= config.pcs().log_blowup()` +/// - LDE domain fits the field's two-adicity (via `LiftedDomain::try_canonical`) +/// - Preprocessed bundle presence, shape, and config-dependent LDE dimensions via `ProverInstance` +/// construction +/// +/// ## Trusted (NOT validated) +/// - AIR structural shape (positive `aux_width`, power-of-two periodic columns, window size 2) +/// - [`crate::air::ProverStatement::build_aux_traces`] output dimensions — a malformed output is +/// caught by the LDE/commit (panic) or by verification, since the verifier re-derives these +/// shapes. 
/// /// # Arguments -/// - `config`: STARK configuration (PCS params, LMCS, DFT) -/// - `instances`: Pairs of (AIR, witness, aux_builder) -/// - `challenger`: Fiat-Shamir challenger (heights are observed before use) +/// - `instance`: the config, the validated prover statement (AIRs, shared `air_inputs`, and per-AIR +/// traces, all in instance order), and the optional preprocessed bundle +/// - `challenger`: Fiat-Shamir challenger pre-bound to protocol parameters and AIR configurations /// /// # Returns -/// `Ok(StarkOutput { digest, proof })` on success, or a `ProverError` if validation fails. +/// `Ok(StarkOutput { digest, proof })`, or a [`ProverError`] if validation fails. #[instrument(name = "prove", skip_all)] -pub fn prove_multi( - config: &SC, - instances: &[(&A, AirWitness<'_, F>, &B)], +pub(crate) fn prove( + instance: &ProverInstance<'_, F, EF, MA, SC>, mut challenger: SC::Challenger, ) -> Result, ProverError> where F: TwoAdicField, EF: ExtensionField, SC: StarkConfig, - A: LiftedAir, - B: AuxBuilder, + MA: MultiAir, { - let trace_heights: Vec = instances.iter().map(|(_, w, _)| w.trace.height()).collect(); - let instance_shapes = InstanceShapes::from_trace_heights(trace_heights)?; - - // Reorder instances to the proof's AIR ordering. - let instances = instance_shapes.reorder(instances.to_vec())?; - - let verifier_instances: Vec<_> = - instances.iter().map(|(air, w, _)| (*air, w.to_instance())).collect(); + // --- Trust boundary (see doc-block above). 
------------------------------- + let config = instance.config(); + let prover_statement = instance.prover_statement(); + let preprocessed = instance.preprocessed(); + let statement = prover_statement.statement(); + let airs = statement.airs(); + let air_inputs = statement.air_inputs(); + let traces = prover_statement.traces(); + let trace_heights: Vec = traces.iter().map(Matrix::height).collect(); + let trace_order = TraceOrder::from_trace_heights::(airs, &trace_heights) + .expect("ProverStatement::new should reject malformed heights"); + + // Map each AIR to its preprocessed trace index. Only used when a preprocessed + // tree is present. + let air_to_preprocessed_trace = if preprocessed.is_some() { + trace_order.preprocessed_trace_index_for_air::(airs) + } else { + Vec::new() + }; + + // Borrow each AIR and trace, then reorder both into ascending-height (proof) + // order. AIRs are passed as `&MA::Air` (the existing constraint code expects + // a reference); traces likewise via `&RowMajorMatrix`. + let air_refs: Vec<&MA::Air> = airs.iter().collect(); + let trace_refs: Vec<&RowMajorMatrix> = traces.iter().collect(); + let proof_ordered_airs = trace_order.to_proof_order(&air_refs); + let proof_ordered_traces = trace_order.to_proof_order(&trace_refs); + let proof_ordered: Vec<_> = proof_ordered_airs + .iter() + .copied() + .zip(proof_ordered_traces.iter().copied()) + .collect(); let log_blowup = config.pcs().log_blowup(); + let log_max_trace_height = trace_order.max_log_height(); + let max_lde_domain = LiftedDomain::::try_canonical(log_max_trace_height, log_blowup)?; + let instance_domains: Vec<_> = trace_order + .log_heights_proof() + .iter() + .map(|&log_h| max_lde_domain.try_sub_domain(log_h)) + .collect::>()?; - // Validate AIR structure, instance dimensions, heights, and trace widths. 
- let log_max_trace_height = validate_inputs(&verifier_instances, &instance_shapes, log_blowup)?; - for &(air, w, _) in &instances { - if w.trace.width() != air.width() { - return Err(InstanceValidationError::WidthMismatch { - expected: air.width(), - actual: w.trace.width(), - } - .into()); - } + // Observe the preprocessed commitment first (when present); it is part of + // the statement and binds Fiat-Shamir before any other instance data. + if let Some(preprocessed) = preprocessed { + challenger.observe(preprocessed.commitment()); } - // Observe shape metadata before creating the transcript. - instance_shapes.observe_heights::(&mut challenger); + // `Statement::observe` absorbs statement-owned inputs. The protocol then + // binds the instance count and each log trace height in instance order. + statement.observe(&mut challenger, trace_order.log_heights()); + trace_order.observe_shape::(&mut challenger); let mut channel = ProverTranscript::new(challenger); - // Clear the challenger's absorb buffer after observing instance shapes by - // squeezing a throwaway extension element. This guarantees later sampled - // challenges depend on all prior inputs regardless of sponge state. - let _instance_challenge: EF = channel.sample_algebra_element::(); - - // Infer constraint degree from symbolic AIR analysis (max across all AIRs) - let log_constraint_degree = - instances.iter().map(|(air, ..)| air.log_quotient_degree()).max().unwrap_or(1) as u8; - - if log_constraint_degree > log_blowup { - return Err(ProverError::ConstraintDegreeTooHigh { - log_quotient_degree: log_constraint_degree, + // Infer per-AIR quotient degrees from symbolic analysis (per-AIR optimization). 
+ let log_quotient_degrees: Vec = proof_ordered + .iter() + .map(|&(air, _)| log_quotient_degree::(air)) + .collect(); + let log_quotient_degree = log_quotient_degrees + .iter() + .copied() + .max() + .expect("TraceOrder construction rejects empty AIR sets"); + if log_quotient_degree > log_blowup { + return Err(DomainError::ConstraintDegreeTooHigh { + log_quotient: log_quotient_degree, log_blowup, - }); + } + .into()); } - let log_lde_height = log_max_trace_height + log_blowup; + // Pair the max LDE domain with the quotient degree. The `EvaluationDomain` + // now flows through the constraint and quotient layers. Per-instance variants + // are pre-built so the constraint loop just indexes into a `Vec`. + let max_eval_domain = max_lde_domain.evaluation_domain(log_quotient_degree); + let instance_eval_domains: Vec<_> = instance_domains + .iter() + .map(|d| d.evaluation_domain(log_quotient_degree)) + .collect(); - // Max LDE coset (for the largest trace, no lifting) - let max_lde_coset = LiftedCoset::unlifted(log_max_trace_height, log_blowup); - let max_quotient_coset = max_lde_coset.quotient_domain(log_constraint_degree); - let max_quotient_height = max_quotient_coset.lde_height(); + // Quotient evaluation coset: a sub-coset of the LDE coset where Q is evaluated + // before being decomposed into D chunks and committed on the LDE coset itself. + let max_quotient_height = max_eval_domain.size(); // 1. Commit all main traces (trace order — ascending height). // // Clone with blowup × capacity so the DFT resize doesn't reallocate. 
let blowup = 1 << log_blowup as usize; - let main_traces: Vec<_> = instances + let main_traces: Vec<_> = proof_ordered .iter() - .map(|(_, w, _)| { - let src = &w.trace.values; + .map(|&(_, trace)| { + let src = &trace.values; let mut values = Vec::with_capacity(src.len() * blowup); values.extend_from_slice(src); - RowMajorMatrix::new(values, w.trace.width()) + RowMajorMatrix::new(values, trace.width()) }) .collect(); - let main_committed = - info_span!("commit to main traces").in_scope(|| commit_traces(config, main_traces)); + let main_committed = info_span!("commit to main traces") + .in_scope(|| commit_traces(config, &instance_domains, main_traces)); channel.send_commitment(main_committed.root()); - // 2. Sample randomness and build aux traces for all AIRs + // 2. Sample randomness, build aux traces, and commit them let max_num_randomness = - instances.iter().map(|(air, ..)| air.num_randomness()).max().unwrap_or(0); + proof_ordered.iter().map(|&(air, _)| air.num_randomness()).max().unwrap_or(0); let randomness: Vec = (0..max_num_randomness) .map(|_| channel.sample_algebra_element::()) .collect(); - // Build aux traces via AuxBuilder - let (aux_traces_ef, all_aux_values): (Vec>, Vec>) = - info_span!("build aux traces").in_scope(|| { - let mut traces = Vec::with_capacity(instances.len()); - let mut values = Vec::with_capacity(instances.len()); - for (air, w, aux_builder) in &instances { - let num_rand = air.num_randomness(); - let (aux, aux_vals) = aux_builder.build_aux_trace(w.trace, &randomness[..num_rand]); - - assert_eq!(aux.width(), air.aux_width(), "aux trace width mismatch"); - assert_eq!( - aux_vals.len(), - air.num_aux_values(), - "aux values length mismatch: build_aux_trace returned {} values, \ - but num_aux_values() is {}", - aux_vals.len(), - air.num_aux_values() - ); - assert_eq!(aux.height(), w.trace.height()); - traces.push(aux); - values.push(aux_vals); - } - (traces, values) - }); + // Build aux traces in instance order. 
The output shapes are trusted (see + // trust contract above); a malformed output is caught downstream by the + // LDE/commit or by verification. + let (mut aux_traces_ef, mut all_aux_values) = info_span!("build aux traces").in_scope(|| { + let mut aux_traces = Vec::with_capacity(airs.len()); + let mut aux_values = Vec::with_capacity(airs.len()); + for (air, main) in airs.iter().zip(traces.iter()) { + let num_randomness = air.num_randomness(); + debug_assert!( + randomness.len() >= num_randomness, + "AIR requested more aux randomness than the shared challenge pool contains", + ); + let (trace, values) = air.build_aux_trace( + main, + air_inputs, + statement.aux_inputs(), + &randomness[..num_randomness], + ); + debug_assert_eq!(trace.height(), main.height(), "aux trace height mismatch"); + debug_assert_eq!(trace.width(), air.aux_width(), "aux trace width mismatch"); + debug_assert_eq!(values.len(), air.num_aux_values(), "aux values length mismatch"); + aux_traces.push(trace); + aux_values.push(values); + } + (aux_traces, aux_values) + }); + + // Mirror the verifier's external assertion evaluation while aux values are + // still in instance order. This is cheap and catches malformed statements + // early; it could become a debug assertion if proving needs to skip this + // verifier-side sanity check. + let aux_views: Vec<&[EF]> = all_aux_values.iter().map(Vec::as_slice).collect(); + let assertions = statement + .eval_external(&randomness, &aux_views, trace_order.log_heights()) + .map_err(ProverError::Reduction)?; + for (k, assertion) in assertions.iter().enumerate() { + if *assertion != EF::ZERO { + return Err(ProverError::ExternalAssertionFailed { assertion: k }); + } + } + + // External assertions are defined in instance-order terms; now reorder aux + // traces and aux values to proof order for commitment and the prover loop. 
+ trace_order.reorder_to_proof_in_place(&mut aux_traces_ef); + trace_order.reorder_to_proof_in_place(&mut all_aux_values); // Flatten EF -> F and commit aux traces let aux_traces: Vec> = aux_traces_ef @@ -290,8 +403,8 @@ where }) .collect(); - let aux_committed = - info_span!("commit to aux traces").in_scope(|| commit_traces(config, aux_traces)); + let aux_committed = info_span!("commit to aux traces") + .in_scope(|| commit_traces(config, &instance_domains, aux_traces)); channel.send_commitment(aux_committed.root()); // Observe aux values into the transcript (binds to Fiat-Shamir state). @@ -302,109 +415,148 @@ where } } - // 4. Sample constraint folding alpha and accumulation beta + // 3. Sample constraint folding alpha and accumulation beta let alpha: EF = channel.sample_algebra_element::(); let beta: EF = channel.sample_algebra_element::(); - // 5. Evaluate constraints and accumulate with beta folding. + // 4. Evaluate constraints and accumulate quotient evaluations with beta folding. // - // Single accumulator, processed in trace order (ascending height): - // 1. Cyclically extend accumulator to the next quotient height - // 2. Multiply every element by beta (Horner) - // 3. Add constraint evaluations in-place: acc[i] += eval(i) + // Per AIR (ascending height): + // 1. Evaluate Q_j = (alpha-folded constraints) / Z_{H_j} on the native quotient domain + // (divide fused into the eval write). + // 2. If D_j < D_max, upsample Q_j to the per-trace target domain. + // 3. Cyclically extend the accumulator and Horner-fold: acc <- acc * beta + Q_j. // // Pre-allocate with LDE capacity so commit_quotient's resize doesn't reallocate. - let constraint_degree = 1 << log_constraint_degree as usize; let mut accumulator: Vec = Vec::with_capacity(max_quotient_height * blowup); - // Pre-compute constraint layouts for each AIR (base/ext index mapping) - let layouts: Vec<_> = instances + // Pre-compute per-AIR constraint layouts. 
+ let layouts: Vec<_> = proof_ordered .iter() - .map(|(air, ..)| get_constraint_layout::(*air)) + .map(|&(air, _)| get_constraint_layout::(air)) .collect(); info_span!("evaluate constraints").in_scope(|| { - for (i, (air, w, _)) in instances.iter().enumerate() { - let trace_height = w.trace.height(); - let log_trace_height = log2_strict_u8(trace_height); - - // Create LiftedCoset for this trace (may be lifted relative to max) - let this_lde_coset = - LiftedCoset::new(log_trace_height, log_blowup, log_max_trace_height); - let this_quotient_coset = this_lde_coset.quotient_domain(log_constraint_degree); - let this_quotient_height = this_quotient_coset.lde_height(); - - // Truncate the committed LDE to the quotient evaluation domain gJ (size N·D). - // Since B ≥ D, the committed LDE on gK (size N·B) contains gJ as a prefix in + for (i, &(air, _)) in proof_ordered.iter().enumerate() { + let this_log_quotient_degree = log_quotient_degrees[i]; + let this_quotient_degree = 1usize << this_log_quotient_degree; + + // Per-AIR native quotient evaluation domain `gJ_j` (size n_j · D_j, + // before upsampling to n_j · D_max). + let this_quotient_eval_domain = + instance_domains[i].evaluation_domain(this_log_quotient_degree); + // Target after upsample to D_max (size n_j · D_max). + let this_target_quotient_height = instance_eval_domains[i].size(); + + // Truncate the committed LDE to the AIR's native quotient evaluation domain gJ_j. + // Since B >= D_j, the committed LDE on gK (size N*B) contains gJ_j as a prefix in // bit-reversed storage, so this is a zero-copy view. 
- let main_on_gj = main_committed.evals_on_quotient_domain(i, constraint_degree); - let aux_on_gj = aux_committed.evals_on_quotient_domain(i, constraint_degree); + let main_on_gj = main_committed.evals_on_quotient_domain(i, &this_quotient_eval_domain); + let aux_on_gj = aux_committed.evals_on_quotient_domain(i, &this_quotient_eval_domain); + + // Preprocessed view: fetched only when this AIR declares preprocessed + // columns. Resolve the committed trace via the inverse mapping + // `air_to_preprocessed_trace[instance_idx]`; it shares the max-LDE coset with + // the main trace, so `evals_on_quotient_domain` truncates it to this + // AIR's quotient domain the same way. + let preproc_on_gj = preprocessed.and_then(|p| { + let instance_idx = trace_order.instance_indices()[i] as usize; + air_to_preprocessed_trace[instance_idx].map(|preprocessed_trace_idx| { + p.committed().evals_on_quotient_domain( + preprocessed_trace_idx, + &this_quotient_eval_domain, + ) + }) + }); - // Build periodic LDE for this trace via coset method let periodic_lde = - PeriodicLde::build(&this_quotient_coset, air.periodic_columns_matrix()); + PeriodicLde::build(&this_quotient_eval_domain, air.periodic_columns_matrix()); + + let mut quotient_evals = EF::zero_vec(this_quotient_eval_domain.size()); + let aux_values_i = &all_aux_values[i]; + let inv_z_h = this_quotient_eval_domain.inv_vanishing_evals(); - // Cyclically extend accumulator to this quotient height and scale by beta. - // On the first iteration the accumulator is empty, so this is a no-op - // and evaluate_constraints_into writes into a zero-filled buffer. 
tracing::debug_span!( - "cyclic_extend", - acc_len = accumulator.len(), - target = this_quotient_height + "eval_instance", + instance = i, + native_height = this_quotient_eval_domain.size(), + target_height = this_target_quotient_height, + native_degree = this_quotient_degree, + target_degree = 1 << log_quotient_degree as usize, ) .in_scope(|| { - quotient::cyclic_extend_and_scale(&mut accumulator, this_quotient_height, beta); + evaluate_constraints_into::( + &mut quotient_evals, + air, + &main_on_gj, + preproc_on_gj.as_ref(), + &aux_on_gj, + &this_quotient_eval_domain, + alpha, + &randomness[..air.num_randomness()], + air_inputs, + &periodic_lde, + &layouts[i], + aux_values_i, + &inv_z_h, + ); }); - let aux_values_i = &all_aux_values[i]; + if this_log_quotient_degree < log_quotient_degree { + let added_bits = (log_quotient_degree - this_log_quotient_degree) as usize; + quotient_evals = tracing::debug_span!( + "upsample_quotient", + instance = i, + from = this_quotient_eval_domain.size(), + to = this_target_quotient_height, + ) + .in_scope(|| { + quotient::upsample_evals::(config.dft(), quotient_evals, added_bits) + }); + } - // Add constraint evaluations in-place: accumulator[i] += eval(i) - info_span!("eval_instance", instance = i, height = this_quotient_height).in_scope( - || { - evaluate_constraints_into::( - &mut accumulator, - *air, - &main_on_gj, - &aux_on_gj, - &this_quotient_coset, - alpha, - &randomness[..air.num_randomness()], - w.public_values, - &periodic_lde, - &layouts[i], - aux_values_i, - ); - }, - ); + debug_assert_eq!(quotient_evals.len(), this_target_quotient_height); + + // Cyclically extend the running accumulator to the per-AIR target height and + // Horner-fold this AIR's contribution in: acc <- acc * beta + Q_j. 
+ tracing::debug_span!( + "cyclic_extend_and_accumulate", + acc_len = accumulator.len(), + target = this_target_quotient_height + ) + .in_scope(|| { + quotient::cyclic_extend_and_accumulate(&mut accumulator, quotient_evals, beta); + }); } }); - // Verify we have the expected size (max quotient domain) - assert_eq!(accumulator.len(), max_quotient_height); - - // 6. Divide by vanishing polynomial once on full gJ (in-place) - tracing::debug_span!("divide_by_vanishing", height = max_quotient_height).in_scope(|| { - quotient::divide_by_vanishing_in_place::(&mut accumulator, &max_quotient_coset); - }); + debug_assert_eq!(accumulator.len(), max_quotient_height); - // 7. Commit quotient + // 5. Commit quotient. let quotient_committed = info_span!("commit to quotient poly chunks") - .in_scope(|| quotient::commit_quotient(config, accumulator, &max_lde_coset)); + .in_scope(|| quotient::commit_quotient(config, accumulator, &max_eval_domain)); channel.send_commitment(quotient_committed.root()); - // 8. Sample OOD point (outside H and gK) - let z: EF = max_lde_coset.sample_ood_point(&mut channel); - let h = F::two_adic_generator(log_max_trace_height.into()); + // 6. Sample OOD point (outside H and gK) + let z: EF = max_lde_domain.sample_ood_point(&mut channel); + let h = max_lde_domain.trace_subgroup().generator(); let z_next = z * h; - // 9. Open via PCS - let trees = vec![main_committed.tree(), aux_committed.tree(), quotient_committed.tree()]; + // 7. Open via PCS. The prover and verifier use the same group order: + // `[preprocessed?, main, aux, quotient]`. 
+ let mut trees = Vec::with_capacity(4); + if let Some(p) = preprocessed { + trees.push(p.committed().tree()); + } + trees.push(main_committed.tree()); + trees.push(aux_committed.tree()); + trees.push(quotient_committed.tree()); info_span!("open").in_scope(|| { open_with_channel::, _, 2>( config.pcs(), config.lmcs(), - log_lde_height, + &max_lde_domain, [z, z_next], &trees, &mut channel, @@ -412,6 +564,27 @@ where }); let (digest, transcript) = channel.finalize(); - let proof = StarkProof { instance_shapes, transcript }; + let proof = StarkProofData { + log_trace_heights: trace_order.log_heights().to_vec(), + transcript, + }; Ok(StarkOutput { digest, proof }) } + +/// Errors from proving — runtime validation failures of caller-supplied data. +/// The AIR's structural contract is trusted (see the crate-level trust model). +#[derive(Debug, Error)] +pub enum ProverError { + #[error(transparent)] + Instance(#[from] InstanceError), + #[error(transparent)] + Domain(#[from] DomainError), + #[error("external assertion evaluation failed: {0}")] + Reduction(ReductionError), + #[error("external assertion {assertion} is non-zero")] + ExternalAssertionFailed { + /// Index into the assertions vector returned by + /// [`Statement::eval_external`](miden_lifted_air::Statement::eval_external). + assertion: usize, + }, +} diff --git a/stark/miden-lifted-stark/src/prover/periodic.rs b/stark/miden-lifted-stark/src/prover/periodic.rs index d6ae065ff3..bd93f4e4d8 100644 --- a/stark/miden-lifted-stark/src/prover/periodic.rs +++ b/stark/miden-lifted-stark/src/prover/periodic.rs @@ -12,7 +12,7 @@ use p3_dft::{NaiveDft, TwoAdicSubgroupDft}; use p3_field::{PackedValue, TwoAdicField}; use p3_matrix::{Matrix, dense::RowMajorMatrix}; -use crate::coset::LiftedCoset; +use crate::domain::{Coset, EvaluationDomain}; /// Prover-side periodic LDE values for constraint evaluation. 
/// @@ -26,7 +26,7 @@ use crate::coset::LiftedCoset; /// we repeat each column up to `max_period` and LDE-extend once; columns with smaller periods /// are accessed via modular indexing. #[derive(Clone, Debug)] -pub struct PeriodicLde { +pub(super) struct PeriodicLde { /// LDE values in natural order (height = max_period * blowup). /// `None` when there are no periodic columns. ldes: Option>, @@ -40,13 +40,13 @@ impl PeriodicLde { /// Uses NaiveDft since periodic column periods are typically small. /// /// # Arguments - /// - `coset`: The lifted coset providing domain information + /// - `domain`: The evaluation domain (trace + LDE + constraint degree) /// - `repeated_matrix`: Periodic columns extended to a common height (max period), or `None` if /// there are no periodic columns /// /// # Panics /// Panics if the matrix height exceeds the trace height or is not a power of two. - pub fn build(coset: &LiftedCoset, repeated_matrix: Option>) -> Self { + pub fn build(domain: &EvaluationDomain, repeated_matrix: Option>) -> Self { let Some(repeated_matrix) = repeated_matrix else { return Self { ldes: None }; }; @@ -54,19 +54,19 @@ impl PeriodicLde { let max_period = repeated_matrix.height(); let log_max_period = log2_strict_u8(max_period); assert!( - coset.log_trace_height >= log_max_period, + domain.log_trace_height() >= log_max_period, "periodic column period ({max_period}) exceeds trace height ({})", - 1 << coset.log_trace_height as usize, + 1 << domain.log_trace_height() as usize, ); - let log_blowup = coset.log_blowup(); + let log_blowup = domain.log_quotient_degree() as usize; // Compute the coset shift for the max-period subgroup. // // Periodic polynomials are naturally defined on a subgroup of order `max_period`. - // We derive the corresponding coset shift by taking the lifted coset shift - // gʳ and mapping from trace height down to `max_period` via a power-of-two ratio. 
- let log_ratio = coset.log_trace_height - log_max_period; - let period_shift: F = coset.lde_shift::().exp_power_of_2(log_ratio as usize); + // The quotient evaluation coset shares the LDE coset's shift; map it down + // from trace height to `max_period` via a power-of-two ratio. + let log_ratio = domain.log_trace_height() - log_max_period; + let period_shift: F = domain.shift().exp_power_of_2(log_ratio as usize); // Compute LDE using NaiveDft (periods are small) let ldes = NaiveDft @@ -105,10 +105,13 @@ mod tests { use alloc::{vec, vec::Vec}; use p3_dft::TwoAdicSubgroupDft; - use p3_field::{Field, PackedValue, PrimeCharacteristicRing}; + use p3_field::{PackedValue, PrimeCharacteristicRing}; use super::*; - use crate::testing::configs::goldilocks_poseidon2 as gl; + use crate::{ + domain::LiftedDomain, + testing::{canonical_domain, configs::goldilocks_poseidon2 as gl}, + }; /// Verify that periodic LDE values match the full LDE computation. fn assert_periodic_lde_matches_full( @@ -119,8 +122,10 @@ mod tests { let trace_height = 1 << log_trace_height as usize; let lde_height = trace_height << log_blowup as usize; - // Create a coset at max height (no lifting) - let coset = LiftedCoset::unlifted(log_trace_height, log_blowup); + // Create an evaluation domain at max height (no lifting), with constraint + // degree = log_blowup (the max-degree case for this test). 
+ let lifted: LiftedDomain = canonical_domain(log_trace_height, log_blowup); + let domain = lifted.evaluation_domain(log_blowup); // Build the repeated matrix (same logic as periodic_columns_matrix) let max_period = columns.iter().map(Vec::len).max().unwrap(); @@ -133,16 +138,18 @@ mod tests { } let repeated_matrix = RowMajorMatrix::new(values, num_cols); - let periodic_lde = PeriodicLde::build(&coset, Some(repeated_matrix)); + let periodic_lde = PeriodicLde::build(&domain, Some(repeated_matrix)); - // Compute expected LDE for each column via full expansion (natural order) + // Compute expected LDE for each column via full expansion (natural order). + // The PeriodicLde uses `domain.shift()` internally — match it here. + let expected_shift = domain.shift(); let expected: Vec> = columns .iter() .map(|col| { let full: Vec = (0..trace_height).map(|i| col[i % col.len()]).collect(); let matrix = RowMajorMatrix::new(full, 1); NaiveDft - .coset_lde_batch(matrix, log_blowup.into(), gl::Felt::GENERATOR) + .coset_lde_batch(matrix, log_blowup.into(), expected_shift) .to_row_major_matrix() .values }) diff --git a/stark/miden-lifted-stark/src/prover/quotient.rs b/stark/miden-lifted-stark/src/prover/quotient.rs index e994b294d6..09e1961613 100644 --- a/stark/miden-lifted-stark/src/prover/quotient.rs +++ b/stark/miden-lifted-stark/src/prover/quotient.rs @@ -1,103 +1,113 @@ -//! Quotient polynomial helpers: accumulation, vanishing division, decomposition. +//! Quotient polynomial helpers used by the prover's per-AIR pipeline. //! -//! The prover orchestrates the quotient pipeline (loop over instances, accumulate, -//! divide, commit). This module provides the building blocks: -//! -//! - [`cyclic_extend_and_scale`]: Horner-style beta scaling + cyclic extension -//! - [`divide_by_vanishing_in_place`]: Divide by Z_H on the quotient evaluation domain -//! - [`commit_quotient`]: Decompose Q(gJ) into chunks and commit on gK +//! 
- [`upsample_evals`]: Low-degree extend coset evaluations onto a larger two-adic coset (the +//! constraint-degree axis: `D_j -> D_max`, same polynomial, denser evaluations). +//! - [`cyclic_extend_and_accumulate`]: Lift the running accumulator along the trace-height axis +//! (`n_j -> N`) by cyclic repetition, and Horner-fold a new AIR's contribution in via beta. +//! - [`commit_quotient`]: Decompose Q(gJ) into chunks and commit on gK. use alloc::{format, vec, vec::Vec}; use p3_dft::TwoAdicSubgroupDft; use p3_field::{ - BasedVectorSpace, ExtensionField, Field, TwoAdicField, batch_multiplicative_inverse, + BasedVectorSpace, ExtensionField, Field, TwoAdicField, par_add_scaled_slice_in_place, + par_scale_slice_in_place, }; use p3_matrix::dense::RowMajorMatrix; use p3_maybe_rayon::prelude::*; -use p3_util::log2_strict_usize; use tracing::info_span; use crate::{ StarkConfig, - coset::LiftedCoset, - lmcs::{Lmcs, bitrev::materialize_bitrev}, + domain::{Coset, EvaluationDomain}, + lmcs::Lmcs, prover::commit::Committed, + util::bitrev::materialize_bitrev, }; // ============================================================================ -// Accumulation +// Domain lifting and accumulation // ============================================================================ -/// Cyclically extend the accumulator to `target_len` and scale every element by `β`. +/// Low-degree extend coset evaluations onto a larger two-adic coset. /// -/// On the first call (empty accumulator) this simply zero-fills to `target_len`. -/// On subsequent calls it scales the existing buffer by `β` (Horner folding) -/// then doubles via `extend_from_within` until it reaches `target_len`. +/// Treats `evals` as evaluations of a polynomial `p` on a coset `g*H` of size +/// `evals.len()`, and returns evaluations of the same `p` on the coset `g*K` +/// of size `evals.len() << added_bits` (same shift `g`, larger two-adic +/// subgroup). 
/// -/// Both `accumulator.len()` and `target_len` must be powers of two, and -/// `target_len ≥ accumulator.len()`. +/// # Precondition /// -/// Cyclic extension is valid because H_small is a subgroup of H_big, so -/// evaluations repeat cyclically. The β scaling implements Horner folding for -/// multi-trace accumulation: `acc = acc·β + Nⱼ`. -pub fn cyclic_extend_and_scale(accumulator: &mut Vec, target_len: usize, beta: EF) { - if accumulator.is_empty() { - accumulator.resize(target_len, EF::ZERO); - } else { - // Horner: scale the smaller buffer by beta before upsampling - accumulator.par_iter_mut().for_each(|v| *v *= beta); - // Cyclic extension by repeated doubling (all sizes are powers of 2) - while accumulator.len() < target_len { - accumulator.extend_from_within(..); - } +/// `deg(p) < evals.len()`. If the input evaluations are of a polynomial whose +/// actual degree is `>= evals.len()`, this function silently returns evaluations +/// of a different polynomial (the unique degree-`< evals.len()` interpolant of +/// the input). The caller is responsible for ensuring the degree bound. +pub(crate) fn upsample_evals(dft: &DFT, evals: Vec, added_bits: usize) -> Vec +where + F: TwoAdicField, + EF: ExtensionField, + DFT: TwoAdicSubgroupDft, +{ + if added_bits == 0 { + return evals; } -} -// ============================================================================ -// Vanishing division -// ============================================================================ + dft.lde_algebra_batch(RowMajorMatrix::new_col(evals), added_bits).values +} -/// Divide quotient numerator by vanishing polynomial in-place (natural order). +/// Fold a new AIR's quotient contribution into the cross-AIR accumulator via one +/// Horner step, after lifting the prior accumulator onto the larger coset: /// -/// Replaces each `numerator[i]` with `numerator[i] / Z_H(xᵢ)` where -/// `Z_H(X) = Xᴺ − 1` and `N` is the trace height. 
+/// ```text +/// acc <- lift_r(acc) * beta + contribution, r = contribution.len() / acc.len() +/// ``` /// -/// This uses a periodicity trick: on the quotient evaluation coset `gJ` of size `N·D`, -/// the values `Z_H(x)` take only `D` distinct values, so we can batch-invert those `D` -/// values once and reuse them by modular indexing. +/// On return, `accumulator.len() == contribution.len()`. `contribution.len()` and +/// (when non-empty) `accumulator.len()` must be powers of two, with +/// `contribution.len() >= accumulator.len()`. /// -/// Note that here `coset.log_blowup()` is `log2(D)` because `coset` is the *quotient* -/// domain (blowup = constraint degree), not the PCS/FRI blowup `B`. -pub fn divide_by_vanishing_in_place(numerator: &mut [EF], coset: &LiftedCoset) -where - F: TwoAdicField, - EF: ExtensionField, -{ - // D = constraint degree. On the quotient coset, log_blowup() = log₂(D). - let log_blowup = coset.log_blowup(); - let num_distinct = 1 << log_blowup; - - // The D distinct values of Z_H on gJ: - // Z_H(g·ω_Jⁱ) = sᴺ·ω_Dⁱ − 1 where - // - s is the coset shift - // - ω_D is a D-th root of unity. - let shift: F = coset.lde_shift(); - let s_pow_n = shift.exp_power_of_2(coset.log_trace_height as usize); - let z_h_evals: Vec = F::two_adic_generator(log_blowup) - .powers() - .take(num_distinct) - .map(|x| s_pow_n * x - F::ONE) - .collect(); - - let inv_van = batch_multiplicative_inverse(&z_h_evals); - - // Parallel division using modular indexing for periodicity. - // Z_H has only num_distinct unique values on gJ; power-of-2 size - // lets us use bitmask: i & (num_distinct - 1) == i % num_distinct. - numerator.par_iter_mut().enumerate().for_each(|(i, n)| { - *n *= inv_van[i & (num_distinct - 1)]; - }); +/// # Why +/// +/// - **Scale before extend.** Horner-multiplying the smaller buffer is strictly less work than the +/// lifted one, and the lifted values are determined entirely by the smaller buffer via cyclic +/// repetition. 
+/// - **Cyclic repetition = polynomial lift on two-adic cosets.** In natural-order coset +/// evaluations, `extended[i] = original[i mod n_old]` realises the composition `P(X) -> P(X^r)`: +/// iterating `gJ` in natural order and raising to the `r`-th power cycles through `(gJ)^r` with +/// period `|(gJ)^r|`. +pub(super) fn cyclic_extend_and_accumulate( + accumulator: &mut Vec, + contribution: Vec, + beta: EF, +) { + debug_assert!(contribution.len().is_power_of_two()); + debug_assert!(accumulator.is_empty() || accumulator.len().is_power_of_two()); + debug_assert!(contribution.len() >= accumulator.len()); + + if accumulator.is_empty() { + accumulator.extend(contribution); + return; + } + + if accumulator.len() == contribution.len() { + // No lift needed; fuse the Horner mul and add into a single packed pass by + // computing `contribution + beta * accumulator` in place on the (owned) + // contribution buffer and swapping it in. + let mut contribution = contribution; + par_add_scaled_slice_in_place(&mut contribution, accumulator, beta); + *accumulator = contribution; + return; + } + + par_scale_slice_in_place(accumulator, beta); + while accumulator.len() < contribution.len() { + accumulator.extend_from_within(..); + } + // TODO: use parallel packed addition + accumulator + .par_iter_mut() + .zip(contribution.into_par_iter()) + .for_each(|(a, c)| *a += c); } // ============================================================================ @@ -133,21 +143,19 @@ where pub fn commit_quotient( config: &SC, q_evals: Vec, - coset: &LiftedCoset, + domain: &EvaluationDomain, ) -> Committed, SC::Lmcs> where F: TwoAdicField, EF: ExtensionField, SC: StarkConfig, { - let n = coset.trace_height(); - let d = q_evals.len() / n; - let log_d = log2_strict_usize(d); - let log_blowup = config.pcs().log_blowup(); - let b = 1usize << log_blowup; + let n = domain.trace_height(); + let d = domain.quotient_degree(); + let lde_height = domain.lifted().lde_height(); - 
debug_assert_eq!(q_evals.len() % n, 0, "q_evals length must be divisible by N"); - debug_assert!(b >= d, "blowup B must be >= constraint degree D"); + debug_assert_eq!(q_evals.len(), n * d, "q_evals length must equal N · D"); + // D ≤ B (i.e. lde_height ≥ N · D) is enforced by `EvaluationDomain::new`. // ═══════════════════════════════════════════════════════════════════════ // Step 0: Reshape to N × D matrix @@ -172,7 +180,7 @@ where // Multiply c_hat[t, k] by (ω_Jᵗ)⁻ᵏ → a[t, k]·gᵏ. // This removes the per-coset shift ω_Jᵗ while keeping gᵏ baked in. info_span!("quotient scaling", n).in_scope(|| { - let omega_j_inv = F::two_adic_generator(coset.log_trace_height as usize + log_d).inverse(); + let omega_j_inv = domain.subgroup().generator_inverse(); // Precompute ω_J⁻ᵏ for k = 0..N with sequential multiplications let row_bases: Vec = omega_j_inv.powers().take(n).collect(); @@ -186,15 +194,15 @@ where }); // ═══════════════════════════════════════════════════════════════════════ - // Step 3: Flatten EF → F, zero-pad to N·B rows + // Step 3: Flatten EF → F, zero-pad to LDE height (N·B rows) // ═══════════════════════════════════════════════════════════════════════ // We flatten before the DFT (rather than using dft_algebra_batch) because // we need base field for commitment anyway — this skips the reconstitute. // - // Zero-padding from N to N·B rows is needed because `dft_batch` expects - // the full target-size buffer. The extra rows are zero because each qₜ has - // degree < N. We pad here (after iDFT + scaling) so those two steps work - // on the smaller N-row buffer. + // Zero-padding from N to lde_height rows is needed because `dft_batch` + // expects the full target-size buffer. The extra rows are zero because each + // qₜ has degree < N. We pad here (after iDFT + scaling) so those two steps + // work on the smaller N-row buffer. 
// // PERF: the full N·B-size DFT processes N·(B−1) zero rows through every // butterfly stage, costing O(N·B·log(N·B)) instead of O(N·B·log N). For @@ -214,7 +222,7 @@ where // already does internally after the iDFT phase. let base_width = d * EF::DIMENSION; let mut base_coeffs = >::flatten_to_base(coeffs.values); - base_coeffs.resize(n * b * base_width, F::ZERO); + base_coeffs.resize(lde_height * base_width, F::ZERO); let coeffs_padded = RowMajorMatrix::new(base_coeffs, base_width); // ═══════════════════════════════════════════════════════════════════════ @@ -222,7 +230,7 @@ where // ═══════════════════════════════════════════════════════════════════════ // Because gᵏ is baked into the coefficients, the plain DFT evaluates // on gK directly: entry (i, t) gives qₜ(g·ω_Kⁱ). - let quotient_matrix = info_span!("quotient DFT", dims = %format!("{}x{base_width}", n * b)) + let quotient_matrix = info_span!("quotient DFT", dims = %format!("{lde_height}x{base_width}")) .in_scope(|| { let lde = config.dft().dft_batch(coeffs_padded); @@ -234,5 +242,62 @@ where let tree = config.lmcs().build_aligned_tree(vec![quotient_matrix]); - Committed::new(tree, log_blowup) + // The quotient is committed on the same LDE coset as the trace commits. + Committed::new(tree) +} + +#[cfg(test)] +mod tests { + use alloc::vec::Vec; + + use p3_dft::{NaiveDft, TwoAdicSubgroupDft}; + use p3_field::{Field, PrimeCharacteristicRing}; + use p3_matrix::dense::RowMajorMatrix; + + use super::upsample_evals; + use crate::testing::configs::goldilocks_poseidon2::{Felt, QuadFelt}; + + fn coeffs(height: usize) -> Vec { + (0..height).map(|i| QuadFelt::from_u64((i as u64) + 1)).collect() + } + + /// Checks that `upsample_evals` on `dft` produces the same result as a direct + /// coset DFT of zero-padded coefficients. 
+ fn assert_upsample_matches_direct>(dft: &D, shift: Felt) { + let small_height = 8; + let added_bits = 2; + let large_height = small_height << added_bits; + + let small_coeffs = RowMajorMatrix::new(coeffs(small_height), 1); + let small_evals = NaiveDft.coset_dft_algebra_batch(small_coeffs, shift).values; + + let mut large_coeffs = coeffs(small_height); + large_coeffs.resize(large_height, QuadFelt::ZERO); + let direct_large = NaiveDft + .coset_dft_algebra_batch(RowMajorMatrix::new(large_coeffs, 1), shift) + .values; + + let upsampled = upsample_evals::(dft, small_evals, added_bits); + assert_eq!(upsampled, direct_large); + } + + #[test] + fn upsample_evals_matches_direct_coset_dft() { + assert_upsample_matches_direct(&NaiveDft, Felt::GENERATOR.exp_power_of_2(2)); + } + + /// Same check with the production DFT backend. + #[test] + fn upsample_evals_with_radix2_dit_parallel_matches_naive() { + use p3_dft::Radix2DitParallel; + let dft = Radix2DitParallel::::default(); + assert_upsample_matches_direct(&dft, Felt::GENERATOR.exp_power_of_2(2)); + } + + #[test] + fn upsample_evals_with_zero_added_bits_returns_input_unchanged() { + let evals = coeffs(8); + let out = upsample_evals::(&NaiveDft, evals.clone(), 0); + assert_eq!(out, evals); + } } diff --git a/stark/miden-lifted-stark/src/selectors.rs b/stark/miden-lifted-stark/src/selectors.rs index 441937069d..adc454d09e 100644 --- a/stark/miden-lifted-stark/src/selectors.rs +++ b/stark/miden-lifted-stark/src/selectors.rs @@ -1,9 +1,10 @@ //! Selector container for constraint folding. //! //! The [`Selectors`] struct is a plain container holding selector values. -//! Computation is done via [`LiftedCoset`](crate::coset::LiftedCoset) methods: -//! - [`LiftedCoset::selectors`](crate::coset::LiftedCoset::selectors) for coset evaluation (prover) -//! - [`LiftedCoset::selectors_at`](crate::coset::LiftedCoset::selectors_at) for lifted OOD point +//! 
Computation is done via [`LiftedDomain`](crate::domain::LiftedDomain) methods: +//! - [`LiftedDomain::selectors`](crate::domain::LiftedDomain::selectors) for coset evaluation +//! (prover) +//! - [`LiftedDomain::selectors_at`](crate::domain::LiftedDomain::selectors_at) for lifted OOD point //! evaluation (verifier) use alloc::vec::Vec; @@ -12,7 +13,7 @@ use p3_field::{PackedField, TwoAdicField}; /// Selector values for constraint evaluation. /// -/// Plain container for selector values. Use [`LiftedCoset`](crate::coset::LiftedCoset) methods +/// Plain container for selector values. Use [`LiftedDomain`](crate::domain::LiftedDomain) methods /// to compute selectors. /// /// Generic over `T` to support: @@ -52,22 +53,25 @@ mod tests { use super::*; use crate::{ - coset::LiftedCoset, - testing::configs::goldilocks_poseidon2::{Felt, QuadFelt}, + domain::{Coset, LiftedDomain}, + testing::{ + canonical_domain, + configs::goldilocks_poseidon2::{Felt, QuadFelt}, + }, }; #[test] fn test_selectors_at_point() { let log_n = 4; - let coset = LiftedCoset::unlifted(log_n, 0); + let domain: LiftedDomain = canonical_domain(log_n, 0); // Sample a point outside the domain let z = QuadFelt::from(Felt::from_u32(12345)); - let _sels = coset.selectors_at::(z); + let _sels = domain.selectors_at(z); - // Verify vanishing_at matches manual computation - let vanishing = coset.vanishing_at::(z); + // Verify trace-subgroup vanishing matches manual computation + let vanishing = domain.trace_subgroup().vanishing_at(z); let n = 1usize << log_n; let expected = z.exp_u64(n as u64) - QuadFelt::ONE; assert_eq!(vanishing, expected); @@ -77,9 +81,9 @@ mod tests { fn test_selectors_on_coset() { let log_trace = 3; let log_blowup = 2; // 4x blowup - let coset = LiftedCoset::unlifted(log_trace, log_blowup); + let domain = canonical_domain::(log_trace, log_blowup).evaluation_domain(log_blowup); - let sels: Selectors> = coset.selectors(); + let sels: Selectors> = domain.selectors(); // Check lengths let 
coset_size = 1 << (log_trace + log_blowup); diff --git a/stark/miden-lifted-stark/src/testing/airs/blake3.rs b/stark/miden-lifted-stark/src/testing/airs/blake3.rs index 58cc9246ad..6b0fe8d178 100644 --- a/stark/miden-lifted-stark/src/testing/airs/blake3.rs +++ b/stark/miden-lifted-stark/src/testing/airs/blake3.rs @@ -5,7 +5,7 @@ use alloc::vec::Vec; use miden_lifted_air::{Air, BaseAir, LiftedAir, LiftedAirBuilder}; pub use p3_blake3_air::{Blake3Air, NUM_BLAKE3_COLS}; use p3_field::{Field, PrimeField64}; -use p3_matrix::dense::RowMajorMatrix; +use p3_matrix::{Matrix, dense::RowMajorMatrix}; /// [`Blake3Air`] adapted for the lifted STARK prover. /// @@ -38,8 +38,15 @@ impl LiftedAir for LiftedBlake3Air { 0 } - fn num_var_len_public_inputs(&self) -> usize { - 0 + fn build_aux_trace( + &self, + main: &RowMajorMatrix, + _air_inputs: &[F], + _aux_inputs: &[F], + _challenges: &[EF], + ) -> (RowMajorMatrix, Vec) { + // Main-trace-only AIR: a single all-zero aux column. + (RowMajorMatrix::new(EF::zero_vec(main.height()), 1), Vec::new()) } fn eval>(&self, builder: &mut AB) { diff --git a/stark/miden-lifted-stark/src/testing/airs/keccak.rs b/stark/miden-lifted-stark/src/testing/airs/keccak.rs index f10a047168..227e22e229 100644 --- a/stark/miden-lifted-stark/src/testing/airs/keccak.rs +++ b/stark/miden-lifted-stark/src/testing/airs/keccak.rs @@ -5,7 +5,7 @@ use alloc::vec::Vec; use miden_lifted_air::{Air, BaseAir, LiftedAir, LiftedAirBuilder}; use p3_field::{Field, PrimeField64}; use p3_keccak_air::{KeccakAir, NUM_KECCAK_COLS, NUM_ROUNDS}; -use p3_matrix::dense::RowMajorMatrix; +use p3_matrix::{Matrix, dense::RowMajorMatrix}; /// [`KeccakAir`] adapted for the lifted STARK prover. 
/// @@ -37,8 +37,15 @@ impl LiftedAir for LiftedKeccakAir { 0 } - fn num_var_len_public_inputs(&self) -> usize { - 0 + fn build_aux_trace( + &self, + main: &RowMajorMatrix, + _air_inputs: &[F], + _aux_inputs: &[F], + _challenges: &[EF], + ) -> (RowMajorMatrix, Vec) { + // Main-trace-only AIR: a single all-zero aux column. + (RowMajorMatrix::new(EF::zero_vec(main.height()), 1), Vec::new()) } fn eval>(&self, builder: &mut AB) { diff --git a/stark/miden-lifted-stark/src/testing/airs/miden.rs b/stark/miden-lifted-stark/src/testing/airs/miden.rs index 895bb49679..fcd23189d1 100644 --- a/stark/miden-lifted-stark/src/testing/airs/miden.rs +++ b/stark/miden-lifted-stark/src/testing/airs/miden.rs @@ -4,9 +4,11 @@ //! single degree-9 base constraint producing 8 quotient chunks, and 8 extension-field //! auxiliary columns (= 16 base-field columns with Goldilocks `ext_degree=2`). +use alloc::vec::Vec; + use miden_lifted_air::{AirBuilder, BaseAir, LiftedAir, LiftedAirBuilder, WindowAccess}; use p3_field::{Field, PrimeCharacteristicRing}; -use p3_matrix::dense::RowMajorMatrix; +use p3_matrix::{Matrix, dense::RowMajorMatrix}; // --------------------------------------------------------------------------- // Constants @@ -74,8 +76,17 @@ impl LiftedAir for DummyMidenAir { self.num_aux_cols } - fn num_var_len_public_inputs(&self) -> usize { - 0 + fn build_aux_trace( + &self, + main: &RowMajorMatrix, + _air_inputs: &[F], + _aux_inputs: &[F], + _challenges: &[EF], + ) -> (RowMajorMatrix, Vec) { + // Constraints touch only the main trace: emit an all-zero aux trace. 
+ let aux = + RowMajorMatrix::new(EF::zero_vec(main.height() * self.num_aux_cols), self.num_aux_cols); + (aux, EF::zero_vec(self.num_aux_cols)) } fn eval>(&self, builder: &mut AB) { diff --git a/stark/miden-lifted-stark/src/testing/airs/mod.rs b/stark/miden-lifted-stark/src/testing/airs/mod.rs index 5be7bea35d..7af52fda2b 100644 --- a/stark/miden-lifted-stark/src/testing/airs/mod.rs +++ b/stark/miden-lifted-stark/src/testing/airs/mod.rs @@ -3,12 +3,6 @@ //! Each module adapts an upstream Plonky3 AIR into a `LiftedAir` so it can be proven //! and verified with the lifted STARK protocol. -use alloc::{vec, vec::Vec}; - -use miden_lifted_air::AuxBuilder; -use p3_field::{ExtensionField, Field}; -use p3_matrix::{Matrix, dense::RowMajorMatrix}; - #[cfg(feature = "testing")] pub mod blake3; #[cfg(feature = "testing")] @@ -16,38 +10,3 @@ pub mod keccak; pub mod miden; #[cfg(feature = "testing")] pub mod poseidon2; - -/// Aux builder that produces an all-zero auxiliary trace. -/// -/// Every `LiftedAir` must have at least one aux column, so this builder -/// satisfies the requirement with minimal cost. -/// -/// Use [`ZeroAuxBuilder::dummy()`] for AIRs with `num_aux_values() == 0` -/// (1-column all-zero trace, no aux values). -pub struct ZeroAuxBuilder { - pub num_aux_cols: usize, - pub num_aux_values: usize, -} - -impl ZeroAuxBuilder { - /// 1-column all-zero auxiliary trace with no aux values. - /// - /// Suitable for AIRs where `num_aux_values() == 0`. 
- pub fn dummy() -> Self { - Self { num_aux_cols: 1, num_aux_values: 0 } - } -} - -impl> AuxBuilder for ZeroAuxBuilder { - fn build_aux_trace( - &self, - main: &RowMajorMatrix, - _challenges: &[EF], - ) -> (RowMajorMatrix, Vec) { - let height = main.height(); - let values = EF::zero_vec(height * self.num_aux_cols); - let aux_trace = RowMajorMatrix::new(values, self.num_aux_cols); - let aux_values = vec![EF::ZERO; self.num_aux_values]; - (aux_trace, aux_values) - } -} diff --git a/stark/miden-lifted-stark/src/testing/airs/poseidon2.rs b/stark/miden-lifted-stark/src/testing/airs/poseidon2.rs index b9b1dee24a..fdb8183eba 100644 --- a/stark/miden-lifted-stark/src/testing/airs/poseidon2.rs +++ b/stark/miden-lifted-stark/src/testing/airs/poseidon2.rs @@ -8,7 +8,7 @@ use alloc::vec::Vec; use miden_lifted_air::{Air, BaseAir, LiftedAir, LiftedAirBuilder}; use p3_field::Field; use p3_goldilocks::{GenericPoseidon2LinearLayersGoldilocks, Goldilocks}; -use p3_matrix::dense::RowMajorMatrix; +use p3_matrix::{Matrix, dense::RowMajorMatrix}; use p3_poseidon2_air::{Poseidon2Air, RoundConstants, num_cols}; /// Goldilocks Poseidon2 configuration constants. @@ -67,8 +67,15 @@ impl LiftedAir for LiftedPoseidon2Air { 0 } - fn num_var_len_public_inputs(&self) -> usize { - 0 + fn build_aux_trace( + &self, + main: &RowMajorMatrix, + _air_inputs: &[Goldilocks], + _aux_inputs: &[Goldilocks], + _challenges: &[EF], + ) -> (RowMajorMatrix, Vec) { + // Main-trace-only AIR: a single all-zero aux column. 
+ (RowMajorMatrix::new(EF::zero_vec(main.height()), 1), Vec::new()) } fn eval>(&self, builder: &mut AB) { diff --git a/stark/miden-lifted-stark/src/testing/configs/goldilocks_poseidon2.rs b/stark/miden-lifted-stark/src/testing/configs/goldilocks_poseidon2.rs index fff799e0dd..7c87d251ca 100644 --- a/stark/miden-lifted-stark/src/testing/configs/goldilocks_poseidon2.rs +++ b/stark/miden-lifted-stark/src/testing/configs/goldilocks_poseidon2.rs @@ -6,14 +6,18 @@ use alloc::vec::Vec; use p3_challenger::DuplexChallenger; -use p3_field::PrimeCharacteristicRing; +use p3_dft::Radix2DitParallel; +use p3_field::{BasedVectorSpace, PrimeCharacteristicRing}; use p3_goldilocks::Poseidon2Goldilocks; use p3_matrix::dense::RowMajorMatrix; use p3_symmetric::{Hash, TruncatedPermutation}; use rand::{SeedableRng, rngs::SmallRng}; pub use super::{Felt, PackedFelt, QuadFelt}; -use crate::{AirWitness, testing::TEST_SEED}; +use crate::{ + air::{MultiAir, ProverStatement, Statement}, + testing::TEST_SEED, +}; // ============================================================================= // Base field/hash configuration @@ -93,7 +97,7 @@ pub fn random_lde_matrix( shift: Felt, ) -> RowMajorMatrix where - V: p3_field::BasedVectorSpace + Clone + Send + Sync + Default, + V: BasedVectorSpace + Clone + Send + Sync + Default, rand::distr::StandardUniform: rand::distr::Distribution, { use p3_dft::{Radix2DFTSmallBatch, TwoAdicSubgroupDft}; @@ -111,7 +115,7 @@ where // STARK layer // ============================================================================= -pub type Dft = p3_dft::Radix2DitParallel; +pub type Dft = Radix2DitParallel; pub type TestConfig = crate::config::GenericStarkConfig; @@ -135,57 +139,65 @@ pub fn generate_pow4_trace(start: Felt, height: usize) -> RowMajorMatrix { RowMajorMatrix::new(values, 1) } -/// Prove and verify from pre-built prover instances. -/// -/// Runs the full prove → verify → transcript-reparse cycle. 
-pub fn prove_and_verify_instances(instances: &[(&A, AirWitness<'_, Felt>, &B)]) +/// Run the full prove → verify → transcript-reparse cycle. +pub fn prove_and_verify_statement(prover_statement: &ProverStatement) where - A: crate::air::LiftedAir, - B: crate::air::AuxBuilder, + MA: MultiAir, { let config = test_config(); - let output = crate::prover::prove_multi(&config, instances, test_challenger()) - .expect("proving should succeed"); - - let verifier_instances: Vec<_> = - instances.iter().map(|(a, w, _)| (*a, w.to_instance())).collect(); + let prover_instance = crate::ProverInstance::new(&config, prover_statement, None) + .expect("no preprocessed columns"); + let output = prover_instance.prove(test_challenger()).expect("proving should succeed"); - let verifier_digest = crate::verifier::verify_multi( - &config, - &verifier_instances, - &output.proof, - test_challenger(), - ) - .expect("verification should succeed"); + let verifier_instance = + crate::VerifierInstance::new(&config, prover_statement.statement(), None) + .expect("no preprocessed columns"); + let verifier_digest = verifier_instance + .verify(&output.proof, test_challenger()) + .expect("verification should succeed"); assert_eq!(output.digest, verifier_digest); // Re-parse transcript from a fresh challenger and verify digest agreement. - let (_, reparse_digest) = crate::proof::StarkTranscript::from_proof( - &config, - &verifier_instances, - &output.proof, - test_challenger(), - ) - .expect("transcript re-parse should succeed"); + let (_, reparse_digest) = + crate::proof::StarkProof::from_data(&verifier_instance, &output.proof, test_challenger()) + .expect("transcript re-parse should succeed"); assert_eq!(output.digest, reparse_digest); } -/// Prove and verify multiple traces, each with its own public values. -/// -/// `instances` is a slice of `(trace, public_values)` pairs. 
-pub fn prove_and_verify( - air: &A, - aux_builder: &B, - instances: &[(RowMajorMatrix, Vec)], -) where +/// Minimal [`MultiAir`] wrapper for tests: a list of AIRs, each building its own +/// aux trace via [`LiftedAir::build_aux_trace`](crate::air::LiftedAir::build_aux_trace). +pub struct TestMultiAir { + pub airs: Vec, +} + +impl TestMultiAir { + pub fn new(airs: Vec) -> Self { + Self { airs } + } +} + +impl MultiAir for TestMultiAir +where A: crate::air::LiftedAir, - B: crate::air::AuxBuilder, { - let prover_instances: Vec<_> = instances - .iter() - .map(|(t, pv)| (air, AirWitness::new(t, pv, &[]), aux_builder)) - .collect(); + type Air = A; - prove_and_verify_instances(&prover_instances); + fn airs(&self) -> &[Self::Air] { + &self.airs + } +} + +/// Prove and verify multiple traces sharing one AIR. +pub fn prove_and_verify(air: &A, air_inputs: &[Felt], traces: &[RowMajorMatrix]) +where + A: crate::air::LiftedAir + Clone, +{ + let airs: Vec = core::iter::repeat_n(air.clone(), traces.len()).collect(); + let traces_owned: Vec> = traces.to_vec(); + let statement = Statement::new(TestMultiAir::new(airs), air_inputs.to_vec(), Vec::new()) + .expect("statement inputs valid"); + let prover_statement = + ProverStatement::new(statement, traces_owned).expect("trace shape valid"); + prove_and_verify_statement(&prover_statement); } diff --git a/stark/miden-lifted-stark/src/testing/mod.rs b/stark/miden-lifted-stark/src/testing/mod.rs index dc2e019e92..b35887ff8f 100644 --- a/stark/miden-lifted-stark/src/testing/mod.rs +++ b/stark/miden-lifted-stark/src/testing/mod.rs @@ -14,18 +14,24 @@ pub mod configs; pub mod params; #[cfg(test)] -mod test_aux_shape; -#[cfg(test)] -mod test_bus; +mod test_external_assertions; #[cfg(test)] mod test_multi_aux_alignment; #[cfg(test)] +mod test_per_air_degree; +#[cfg(test)] +mod test_preprocessed; +#[cfg(test)] mod test_tiny_air; // Re-export commonly used params at the module level for convenience. 
use alloc::vec::Vec; -use p3_field::Field; +// Re-exports for integration benches (which consume `miden_lifted_stark` as an +// external crate via the `testing` feature). Gated so production callers don't +// see this surface. +pub use miden_lifted_air::{MultiAir, ProverStatement, Statement, log2_strict_u8}; +use p3_field::{Field, TwoAdicField}; use p3_matrix::{Matrix, dense::RowMajorMatrix}; pub use params::{ BENCH_PCS_PARAMS, FRI_FOLD_ARITY_2, FRI_FOLD_ARITY_4, FRI_FOLD_ARITY_8, LOG_HEIGHTS, @@ -37,6 +43,32 @@ use rand::{ rngs::SmallRng, }; +pub use crate::{ + domain::{Coset, LiftedDomain}, + lmcs::{Lmcs, LmcsTree}, + pcs::{ + deep::interpolate::PointQuotients, fri::fold::FriFold, params::PcsParams, + prover::open_with_channel, + }, + prover::quotient::commit_quotient, +}; + +// ============================================================================= +// Domain fixtures +// ============================================================================= + +/// Build the canonical [`LiftedDomain`] for `(log_trace_height, log_blowup)`, +/// panicking on out-of-range parameters. +/// +/// Fixtures pick their own sizes, so an out-of-range pair is a programmer error +/// rather than a recoverable condition; this wraps the validated +/// [`LiftedDomain::try_canonical`] so tests and benches don't repeat the +/// `.expect(...)`. +pub fn canonical_domain(log_trace_height: u8, log_blowup: u8) -> LiftedDomain { + LiftedDomain::try_canonical(log_trace_height, log_blowup) + .expect("canonical domain parameters out of range") +} + // ============================================================================= // Matrix generation // ============================================================================= @@ -137,12 +169,3 @@ macro_rules! 
define_lmcs_test_helpers { } pub(crate) use define_lmcs_test_helpers; - -// ============================================================================= -// Internal re-exports for benchmarks -// ============================================================================= -pub use crate::pcs::{ - deep::interpolate::PointQuotients, fri::fold::FriFold, prover::open_with_channel, - utils::bit_reversed_coset_points, -}; -pub use crate::prover::quotient::commit_quotient; diff --git a/stark/miden-lifted-stark/src/testing/params.rs b/stark/miden-lifted-stark/src/testing/params.rs index 5abf036de9..a3fcda3e0c 100644 --- a/stark/miden-lifted-stark/src/testing/params.rs +++ b/stark/miden-lifted-stark/src/testing/params.rs @@ -56,10 +56,13 @@ pub const PARALLEL_STR: &str = if cfg!(feature = "parallel") { // ============================================================================= /// PCS parameters for unit tests (fast, minimal security). +/// +/// `log_blowup = 3` supports AIRs with symbolic degree up to 9 +/// (`log_quotient_degree = 3`). pub const TEST_PCS_PARAMS: PcsParams = PcsParams { + log_blowup: 3, deep: DeepParams { deep_pow_bits: 0 }, fri: FriParams { - log_blowup: 2, fold: FRI_FOLD_ARITY_4, log_final_degree: 2, folding_pow_bits: 0, @@ -70,9 +73,9 @@ pub const TEST_PCS_PARAMS: PcsParams = PcsParams { /// PCS parameters for benchmarks (realistic security, zero PoW). pub const BENCH_PCS_PARAMS: PcsParams = PcsParams { + log_blowup: 2, deep: DeepParams { deep_pow_bits: 0 }, fri: FriParams { - log_blowup: 2, fold: FRI_FOLD_ARITY_4, log_final_degree: 8, folding_pow_bits: 0, @@ -83,9 +86,9 @@ pub const BENCH_PCS_PARAMS: PcsParams = PcsParams { /// PCS parameters for quotient commit benchmarks (lower blowup, single query). 
pub const QC_PCS_PARAMS: PcsParams = PcsParams { + log_blowup: 1, deep: DeepParams { deep_pow_bits: 0 }, fri: FriParams { - log_blowup: 1, fold: FRI_FOLD_ARITY_4, log_final_degree: 0, folding_pow_bits: 0, diff --git a/stark/miden-lifted-stark/src/testing/test_aux_shape.rs b/stark/miden-lifted-stark/src/testing/test_aux_shape.rs deleted file mode 100644 index 390cfecb5c..0000000000 --- a/stark/miden-lifted-stark/src/testing/test_aux_shape.rs +++ /dev/null @@ -1,70 +0,0 @@ -//! Tests that the prover rejects aux trace width mismatches. - -use alloc::{vec, vec::Vec}; - -use p3_field::PrimeCharacteristicRing; -use p3_matrix::{Matrix, dense::RowMajorMatrix}; - -use crate::{ - air::{AuxBuilder, BaseAir, LiftedAir, LiftedAirBuilder}, - prove_single, - testing::configs::goldilocks_poseidon2::{Felt, QuadFelt, test_challenger, test_config}, -}; - -#[derive(Clone, Copy, Debug)] -struct BadAuxWidthAir; - -impl BaseAir for BadAuxWidthAir { - fn width(&self) -> usize { - 1 - } -} - -impl LiftedAir for BadAuxWidthAir { - fn num_randomness(&self) -> usize { - 1 - } - - fn aux_width(&self) -> usize { - 1 - } - - fn num_aux_values(&self) -> usize { - 0 - } - - fn num_var_len_public_inputs(&self) -> usize { - 0 - } - - fn eval>(&self, _builder: &mut AB) {} -} - -/// AuxBuilder that returns 2 EF columns when BadAuxWidthAir declares 1. 
-struct BadAuxBuilder; - -impl AuxBuilder for BadAuxBuilder { - fn build_aux_trace( - &self, - main: &RowMajorMatrix, - _challenges: &[QuadFelt], - ) -> (RowMajorMatrix, Vec) { - let height = main.height(); - // Return 2 QuadFelt columns when aux_width() declares 1 - let aux = RowMajorMatrix::new(vec![QuadFelt::ZERO; height * 2], 2); - (aux, vec![QuadFelt::ZERO, QuadFelt::ZERO]) - } -} - -#[test] -#[should_panic(expected = "aux trace width mismatch")] -fn aux_width_mismatch_panics() { - let config = test_config(); - let air = BadAuxWidthAir; - - let trace = RowMajorMatrix::new(vec![Felt::ZERO, Felt::ONE, Felt::ONE, Felt::ZERO], 1); - let public_values = vec![]; - - let _result = - prove_single(&config, &air, &trace, &public_values, &[], &BadAuxBuilder, test_challenger()); -} diff --git a/stark/miden-lifted-stark/src/testing/test_bus.rs b/stark/miden-lifted-stark/src/testing/test_bus.rs deleted file mode 100644 index 213f240b74..0000000000 --- a/stark/miden-lifted-stark/src/testing/test_bus.rs +++ /dev/null @@ -1,312 +0,0 @@ -//! Tests reduced auxiliary values (multiset and logup bus identities). - -use alloc::{vec, vec::Vec}; - -use miden_lifted_air::ReductionError; -use p3_field::{Field, PrimeCharacteristicRing}; -use p3_matrix::{Matrix, dense::RowMajorMatrix}; - -use crate::{ - AirInstance, AirWitness, - air::{ - AirBuilder, AuxBuilder, BaseAir, ExtensionBuilder, LiftedAir, LiftedAirBuilder, - ReducedAuxValues, VarLenPublicInputs, WindowAccess, - }, - prove_multi, - testing::configs::goldilocks_poseidon2::{ - Felt, QuadFelt, generate_pow4_trace, test_challenger, test_config, - }, - verify_multi, -}; - -// --------------------------------------------------------------------------- -// BusTestAir: exercises reduced_aux_values with multiset + logup buses. -// -// Main trace: 1 column, power-of-4 chain (same as TinyAir). 
-// Aux trace: 2 constant columns (all rows identical): -// col 0: 1/(pi_0 + challenge[0]) — inverse for multiset bus -// col 1: pi_1 + challenge[1] — accumulator for logup bus -// -// Aux values (committed to transcript, constrained to match aux trace last row): -// aux_values[0] = col 0 value = 1/(pi_0 + c0) -// aux_values[1] = col 1 value = pi_1 + c1 -// -// reduced_aux_values (verifier-side bus identity check): -// Bus 0 (multiset): prod = aux_values[0] * (c0 + pi_0) == 1 -// Bus 1 (logup): sum = (aux_values[1] - c1) - pi_1 == 0 -// -// pi_0, pi_1 appear in two places: -// - public_values[1..]: used by eval() for aux trace constraints -// - var_len_public_inputs: used by reduced_aux_values() for bus check -// Both must agree for the proof to verify. -// --------------------------------------------------------------------------- - -#[derive(Clone, Debug)] -struct BusTestAir; - -impl BaseAir for BusTestAir { - fn width(&self) -> usize { - 1 - } - - fn num_public_values(&self) -> usize { - 3 // [start, pi_0, pi_1] - } -} - -impl LiftedAir for BusTestAir { - fn num_randomness(&self) -> usize { - 2 - } - - fn aux_width(&self) -> usize { - 2 - } - - fn num_aux_values(&self) -> usize { - 2 - } - - fn num_var_len_public_inputs(&self) -> usize { - 2 - } - - fn reduced_aux_values( - &self, - aux_values: &[QuadFelt], - challenges: &[QuadFelt], - _public_values: &[Felt], - var_len_public_inputs: VarLenPublicInputs<'_, Felt>, - ) -> Result, ReductionError> { - // Bus 0 (multiset): prod = aux_values[0] * (challenges[0] + pi_0) - // aux_values[0] = 1/(pi_0 + c0), so prod == 1 when pi_0 matches. - let pi_0 = QuadFelt::from(var_len_public_inputs[0][0]); - let prod = aux_values[0] * (challenges[0] + pi_0); - - // Bus 1 (logup): sum = (aux_values[1] - challenges[1]) - pi_1 - // aux_values[1] = pi_1 + c1, so sum == 0 when pi_1 matches. 
- let pi_1 = QuadFelt::from(var_len_public_inputs[1][0]); - let sum = (aux_values[1] - challenges[1]) - pi_1; - - Ok(ReducedAuxValues { prod, sum }) - } - - fn eval>(&self, builder: &mut AB) { - // Copy public values upfront (PublicVar: Copy) to release borrow. - let pv0 = builder.public_values()[0]; - let pv1 = builder.public_values()[1]; - let pv2 = builder.public_values()[2]; - - let main = builder.main(); - let (local, next) = (main.current_slice(), main.next_slice()); - - // Main trace: power-of-4 chain - builder.when_first_row().assert_eq(local[0], pv0); - let main_pow4: AB::Expr = local[0].into().exp_power_of_2(2); - builder.when_transition().assert_eq(next[0], main_pow4); - - // Copy challenges and aux values (RandomVar/VarEF: Copy) to release borrow. - let c0: AB::RandomVar = builder.permutation_randomness()[0]; - let c1: AB::RandomVar = builder.permutation_randomness()[1]; - let av0: AB::PermutationVar = builder.permutation_values()[0].clone(); - let av1: AB::PermutationVar = builder.permutation_values()[1].clone(); - - let aux = builder.permutation(); - let aux_local = aux.current_slice(); - let aux_next = aux.next_slice(); - - // pi_0 = public_values[1], pi_1 = public_values[2] - let pi_0: AB::ExprEF = Into::::into(pv1).into(); - let pi_1: AB::ExprEF = Into::::into(pv2).into(); - let c0: AB::ExprEF = c0.into(); - let c1: AB::ExprEF = c1.into(); - - // First row: aux[0] * (pi_0 + c0) == 1 - let a0: AB::ExprEF = aux_local[0].into(); - builder.when_first_row().assert_eq_ext(a0 * (pi_0 + c0), AB::ExprEF::ONE); - - // First row: aux[1] == pi_1 + c1 - let a1: AB::ExprEF = aux_local[1].into(); - builder.when_first_row().assert_eq_ext(a1, pi_1 + c1); - - // Transition: constant columns - builder - .when_transition() - .assert_eq_ext::(aux_next[0].into(), aux_local[0].into()); - builder - .when_transition() - .assert_eq_ext::(aux_next[1].into(), aux_local[1].into()); - - // Last row: aux columns match aux_values - builder - .when_last_row() - 
.assert_eq_ext::(aux_local[0].into(), av0.into()); - builder - .when_last_row() - .assert_eq_ext::(aux_local[1].into(), av1.into()); - } -} - -// --------------------------------------------------------------------------- -// AuxBuilder: constant aux columns. -// --------------------------------------------------------------------------- - -struct BusTestAuxBuilder { - pi_0: Felt, - pi_1: Felt, -} - -impl AuxBuilder for BusTestAuxBuilder { - fn build_aux_trace( - &self, - main: &RowMajorMatrix, - challenges: &[QuadFelt], - ) -> (RowMajorMatrix, Vec) { - let height = main.height(); - let c0 = challenges[0]; - let c1 = challenges[1]; - - // col 0: 1/(pi_0 + c0), col 1: pi_1 + c1 - let col0_val = (QuadFelt::from(self.pi_0) + c0).inverse(); - let col1_val = QuadFelt::from(self.pi_1) + c1; - - let mut values = Vec::with_capacity(height * 2); - for _ in 0..height { - values.push(col0_val); - values.push(col1_val); - } - - let aux_trace = RowMajorMatrix::new(values, 2); - let aux_values = vec![col0_val, col1_val]; - (aux_trace, aux_values) - } -} - -// --------------------------------------------------------------------------- -// Tests -// --------------------------------------------------------------------------- - -#[test] -fn bus_identity_check() { - let config = test_config(); - - let pi_0 = Felt::from_u64(42); - let pi_1 = Felt::from_u64(67); - let start = Felt::from_u64(2); - let height = 8; - - let air = BusTestAir; - let aux_builder = BusTestAuxBuilder { pi_0, pi_1 }; - let trace = generate_pow4_trace(start, height); - let public_values = vec![start, pi_0, pi_1]; - - // Build var_len_public_inputs (one reducible input per bus) - let input_0 = [pi_0]; - let input_1 = [pi_1]; - let var_len_pi: [&[Felt]; 2] = [&input_0, &input_1]; - - // Prove - let prover_instances = - [(&air, AirWitness::new(&trace, &public_values, &var_len_pi), &aux_builder)]; - let output = - prove_multi(&config, &prover_instances, test_challenger()).expect("proving should succeed"); - - let 
instance = AirInstance { - public_values: &public_values, - var_len_public_inputs: &var_len_pi, - }; - - // Verify - let verifier_digest = - verify_multi(&config, &[(&air, instance)], &output.proof, test_challenger()) - .expect("verification should succeed"); - assert_eq!(output.digest, verifier_digest); -} - -#[test] -fn bus_wrong_var_len_pi_fails() { - let config = test_config(); - - let pi_0 = Felt::from_u64(42); - let pi_1 = Felt::from_u64(67); - let start = Felt::from_u64(2); - let height = 8; - - let air = BusTestAir; - let aux_builder = BusTestAuxBuilder { pi_0, pi_1 }; - let trace = generate_pow4_trace(start, height); - let public_values = vec![start, pi_0, pi_1]; - - // Prove with correct values - let input_0 = [pi_0]; - let input_1 = [pi_1]; - let var_len_pi: [&[Felt]; 2] = [&input_0, &input_1]; - - let prover_instances = - [(&air, AirWitness::new(&trace, &public_values, &var_len_pi), &aux_builder)]; - let output = - prove_multi(&config, &prover_instances, test_challenger()).expect("proving should succeed"); - - // Verify with WRONG var_len_public_inputs (99 instead of 42) - let wrong_pi_0 = Felt::from_u64(99); - let wrong_input_0 = [wrong_pi_0]; - let wrong_var_len_pi: [&[Felt]; 2] = [&wrong_input_0, &input_1]; - - let instance = AirInstance { - public_values: &public_values, - var_len_public_inputs: &wrong_var_len_pi, - }; - - let err = verify_multi(&config, &[(&air, instance)], &output.proof, test_challenger()) - .expect_err("wrong var_len_pi should fail verification"); - - assert!( - matches!(err, crate::VerifierError::InvalidReducedAux), - "expected InvalidReducedAux, got {err:?}" - ); -} - -#[test] -fn bus_wrong_input_count_fails() { - let config = test_config(); - - let pi_0 = Felt::from_u64(42); - let pi_1 = Felt::from_u64(67); - let start = Felt::from_u64(2); - let height = 8; - - let air = BusTestAir; - let aux_builder = BusTestAuxBuilder { pi_0, pi_1 }; - let trace = generate_pow4_trace(start, height); - let public_values = vec![start, pi_0, 
pi_1]; - - // Prove with correct values - let input_0 = [pi_0]; - let input_1 = [pi_1]; - let var_len_pi: [&[Felt]; 2] = [&input_0, &input_1]; - - let prover_instances = - [(&air, AirWitness::new(&trace, &public_values, &var_len_pi), &aux_builder)]; - let output = - prove_multi(&config, &prover_instances, test_challenger()).expect("proving should succeed"); - - // Verify with WRONG input count (1 instead of 2) - let only_one: [&[Felt]; 1] = [&input_0]; - let instance = AirInstance { - public_values: &public_values, - var_len_public_inputs: &only_one, - }; - - let err = verify_multi(&config, &[(&air, instance)], &output.proof, test_challenger()) - .expect_err("wrong input count should fail verification"); - - assert!( - matches!( - err, - crate::VerifierError::Instance( - crate::InstanceValidationError::VarLenPublicInputsMismatch { .. } - ) - ), - "expected VarLenPublicInputsMismatch, got {err:?}" - ); -} diff --git a/stark/miden-lifted-stark/src/testing/test_external_assertions.rs b/stark/miden-lifted-stark/src/testing/test_external_assertions.rs new file mode 100644 index 0000000000..26332a8e71 --- /dev/null +++ b/stark/miden-lifted-stark/src/testing/test_external_assertions.rs @@ -0,0 +1,197 @@ +//! Tests cross-AIR assertions from `eval_external`: aux values committed to the +//! transcript are tied to public `aux_inputs` by extension-field assertion +//! expressions checked outside the per-row AIR constraints. 
+ +use alloc::{vec, vec::Vec}; + +use miden_lifted_air::ReductionError; +use p3_field::PrimeCharacteristicRing; +use p3_matrix::{Matrix, dense::RowMajorMatrix}; + +use crate::{ + ProverInstance, VerifierInstance, + air::{ + AirBuilder, BaseAir, ExtensionBuilder, LiftedAir, LiftedAirBuilder, MultiAir, + ProverStatement, Statement, WindowAccess, + }, + testing::configs::goldilocks_poseidon2::{ + Felt, QuadFelt, generate_pow4_trace, test_challenger, test_config, + }, +}; + +// --------------------------------------------------------------------------- +// ExternalAir: a power-of-4 main trace plus a single constant aux column equal +// to `challenge + input`. The committed aux value is that constant, bound to +// the aux column's last row by AIR constraints; `eval_external` then ties it to +// `aux_inputs[0]` shifted by the challenge. +// --------------------------------------------------------------------------- + +#[derive(Clone, Debug)] +struct ExternalAir { + input: Felt, +} + +impl BaseAir for ExternalAir { + fn width(&self) -> usize { + 1 + } + + fn num_public_values(&self) -> usize { + 1 // [start] + } +} + +impl LiftedAir for ExternalAir { + fn num_randomness(&self) -> usize { + 1 + } + + fn aux_width(&self) -> usize { + 1 + } + + fn num_aux_values(&self) -> usize { + 1 + } + + fn build_aux_trace( + &self, + main: &RowMajorMatrix, + _air_inputs: &[Felt], + _aux_inputs: &[Felt], + challenges: &[QuadFelt], + ) -> (RowMajorMatrix, Vec) { + // One constant column: challenge + input. + let value = QuadFelt::from(self.input) + challenges[0]; + let values = vec![value; main.height()]; + (RowMajorMatrix::new(values, 1), vec![value]) + } + + fn eval>(&self, builder: &mut AB) { + let start = builder.public_values()[0]; + + let main = builder.main(); + let (local, next) = (main.current_slice(), main.next_slice()); + + // Main trace: power-of-4 chain anchored at the public start value. 
+ builder.when_first_row().assert_eq(local[0], start); + let main_pow4: AB::Expr = local[0].into().exp_power_of_2(2); + builder.when_transition().assert_eq(next[0], main_pow4); + + // Aux column is constant and exposes its value as the committed aux value + // on the last row — the value `eval_external` reasons about. + let aux = builder.permutation(); + let aux_local = aux.current_slice(); + let aux_next = aux.next_slice(); + let aux_value: AB::PermutationVar = builder.permutation_values()[0].clone(); + + builder + .when_transition() + .assert_eq_ext::(aux_next[0].into(), aux_local[0].into()); + builder + .when_last_row() + .assert_eq_ext::(aux_local[0].into(), aux_value.into()); + } +} + +// --------------------------------------------------------------------------- +// ExternalMultiAir: the cross-AIR `eval_external` reduction over `aux_inputs`. +// --------------------------------------------------------------------------- + +struct ExternalMultiAir { + airs: Vec, +} + +impl MultiAir for ExternalMultiAir { + type Air = ExternalAir; + + fn airs(&self) -> &[Self::Air] { + &self.airs + } + + fn max_aux_inputs(&self) -> usize { + 1 + } + + fn eval_external( + &self, + challenges: &[QuadFelt], + _air_inputs: &[Felt], + aux_inputs: &[Felt], + aux_values: &[&[QuadFelt]], + _log_trace_heights: &[u8], + ) -> Result, ReductionError> { + let aux = aux_values.first().ok_or("expected aux values for the instance")?; + let input = *aux_inputs.first().ok_or("missing external input")?; + + // The committed aux value must equal `challenge + input`. 
+ Ok(vec![aux[0] - challenges[0] - QuadFelt::from(input)]) + } +} + +fn external_prover_statement( + input: Felt, + trace: RowMajorMatrix, + air_inputs: Vec, + aux_inputs: Vec, +) -> ProverStatement { + let statement = Statement::new( + ExternalMultiAir { airs: vec![ExternalAir { input }] }, + air_inputs, + aux_inputs, + ) + .expect("statement inputs valid"); + ProverStatement::new(statement, vec![trace]).expect("trace shape valid") +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[test] +fn external_assertion_holds() { + let config = test_config(); + + let input = Felt::from_u64(42); + let start = Felt::from_u64(2); + + let trace = generate_pow4_trace(start, 8); + let prover_statement = external_prover_statement(input, trace, vec![start], vec![input]); + + let output = ProverInstance::new(&config, &prover_statement, None) + .expect("no preprocessed columns") + .prove(test_challenger()) + .expect("proving should succeed"); + + let verifier_digest = VerifierInstance::new(&config, prover_statement.statement(), None) + .expect("no preprocessed columns") + .verify(&output.proof, test_challenger()) + .expect("verification should succeed"); + assert_eq!(output.digest, verifier_digest); +} + +#[test] +fn missing_external_input_fails_proving() { + let config = test_config(); + + let input = Felt::from_u64(42); + let start = Felt::from_u64(2); + + let trace = generate_pow4_trace(start, 8); + + // Empty `aux_inputs`: `eval_external` runs out of inputs and surfaces a + // ReductionError. The prover mirrors the verifier's external assertion + // evaluation after aux values are available, so malformed statements fail + // early; this could become a debug assertion if proving needs to skip this + // verifier-side sanity check. 
+ let broken = external_prover_statement(input, trace, vec![start], vec![]); + + let err = ProverInstance::new(&config, &broken, None) + .expect("no preprocessed columns") + .prove(test_challenger()) + .expect_err("missing external input should fail proving"); + assert!( + matches!(err, crate::ProverError::Reduction(_)), + "expected Reduction, got {err:?}" + ); +} diff --git a/stark/miden-lifted-stark/src/testing/test_multi_aux_alignment.rs b/stark/miden-lifted-stark/src/testing/test_multi_aux_alignment.rs index 2bbaf71d28..b58a83ed97 100644 --- a/stark/miden-lifted-stark/src/testing/test_multi_aux_alignment.rs +++ b/stark/miden-lifted-stark/src/testing/test_multi_aux_alignment.rs @@ -6,19 +6,18 @@ use p3_field::PrimeCharacteristicRing; use p3_matrix::{Matrix, dense::RowMajorMatrix}; use crate::{ - AirWitness, Lmcs, VerifierError, air::{ - AirBuilder, AuxBuilder, BaseAir, ExtensionBuilder, LiftedAir, LiftedAirBuilder, - WindowAccess, + AirBuilder, BaseAir, ExtensionBuilder, LiftedAir, LiftedAirBuilder, MultiAir, + ProverStatement, Statement, WindowAccess, }, - prove_multi, + lmcs::Lmcs, testing::configs::goldilocks_poseidon2::{ - Felt, QuadFelt, prove_and_verify_instances, test_challenger, test_config, + Felt, QuadFelt, prove_and_verify_statement, test_config, }, - transcript::TranscriptData, - verify_multi, }; +const START: u64 = 2; + #[derive(Clone, Debug)] struct PaddingAir { width: usize, @@ -54,8 +53,21 @@ impl LiftedAir for PaddingAir { 0 } - fn num_var_len_public_inputs(&self) -> usize { - 0 + fn build_aux_trace( + &self, + main: &RowMajorMatrix, + _air_inputs: &[Felt], + _aux_inputs: &[Felt], + challenges: &[QuadFelt], + ) -> (RowMajorMatrix, Vec) { + // Column 0 holds the challenge; the rest pad with zeros up to `aux_width`. 
+ let challenge = challenges[0]; + let mut values = Vec::with_capacity(main.height() * self.aux_width); + for _ in 0..main.height() { + values.push(challenge); + values.extend(core::iter::repeat_n(QuadFelt::ZERO, self.aux_width - 1)); + } + (RowMajorMatrix::new(values, self.aux_width), vec![]) } fn eval>(&self, builder: &mut AB) { @@ -75,25 +87,15 @@ impl LiftedAir for PaddingAir { } } -struct PaddingAuxBuilder { - aux_width: usize, +struct PaddingMultiAir { + airs: Vec, } -impl AuxBuilder for PaddingAuxBuilder { - fn build_aux_trace( - &self, - main: &RowMajorMatrix, - challenges: &[QuadFelt], - ) -> (RowMajorMatrix, Vec) { - let height = main.height(); - let mut values = Vec::with_capacity(height * self.aux_width); - let challenge = challenges[0]; - for _ in 0..height { - values.push(challenge); - values.extend(core::iter::repeat_n(QuadFelt::ZERO, self.aux_width - 1)); - } - let aux_trace = RowMajorMatrix::new(values, self.aux_width); - (aux_trace, vec![]) +impl MultiAir for PaddingMultiAir { + type Air = PaddingAir; + + fn airs(&self) -> &[Self::Air] { + &self.airs } } @@ -106,9 +108,18 @@ fn generate_trace(start: Felt, height: usize, width: usize) -> RowMajorMatrix (RowMajorMatrix, Vec) { - let start = Felt::from_u64((idx + 2) as u64); - (generate_trace(start, height, width), vec![start]) +fn padding_prover_statement( + width: usize, + aux_width: usize, + start: Felt, +) -> ProverStatement { + let air = PaddingAir::new(width, aux_width); + let t0 = generate_trace(start, 8, width); + let t1 = generate_trace(start, 16, width); + let statement = + Statement::new(PaddingMultiAir { airs: vec![air.clone(), air] }, vec![start], Vec::new()) + .unwrap(); + ProverStatement::new(statement, vec![t0, t1]).unwrap() } #[test] @@ -117,50 +128,9 @@ fn multi_trace_with_aux_padding() { let alignment = config.lmcs.alignment(); let width = alignment + 1; let aux_width = alignment + 1; + let start = Felt::from_u64(START); - let air = PaddingAir::new(width, aux_width); - let 
aux_builder = PaddingAuxBuilder { aux_width }; - let instances = [instance(0, 8, width), instance(1, 16, width)]; - - let prover_instances: Vec<_> = instances - .iter() - .map(|(t, pv)| (&air, AirWitness::new(t, pv, &[]), &aux_builder)) - .collect(); + let prover_statement = padding_prover_statement(width, aux_width, start); - prove_and_verify_instances(&prover_instances); -} - -#[test] -fn multi_trace_rejects_trailing_transcript_data() { - let config = test_config(); - let alignment = config.lmcs.alignment(); - let width = alignment + 1; - let aux_width = alignment + 1; - - let air = PaddingAir::new(width, aux_width); - let aux_builder = PaddingAuxBuilder { aux_width }; - let instances = [instance(0, 8, width), instance(1, 16, width)]; - - let prover_instances: Vec<_> = instances - .iter() - .map(|(t, pv)| (&air, AirWitness::new(t, pv, &[]), &aux_builder)) - .collect(); - - let output = - prove_multi(&config, &prover_instances, test_challenger()).expect("proving should succeed"); - - let mut bad_proof = output.proof; - let (mut fields, commitments) = bad_proof.transcript.into_parts(); - fields.push(Felt::ONE); - bad_proof.transcript = TranscriptData::new(fields, commitments); - - let verifier_instances: Vec<_> = - prover_instances.iter().map(|(a, w, _)| (*a, w.to_instance())).collect(); - - let err = verify_multi(&config, &verifier_instances, &bad_proof, test_challenger()) - .expect_err("extra transcript data should fail verification"); - assert!(matches!( - err, - VerifierError::Transcript(crate::transcript::TranscriptError::TrailingData) - )); + prove_and_verify_statement(&prover_statement); } diff --git a/stark/miden-lifted-stark/src/testing/test_per_air_degree.rs b/stark/miden-lifted-stark/src/testing/test_per_air_degree.rs new file mode 100644 index 0000000000..138d4c6c93 --- /dev/null +++ b/stark/miden-lifted-stark/src/testing/test_per_air_degree.rs @@ -0,0 +1,255 @@ +//! End-to-end prove+verify tests for the per-AIR quotient-degree optimization. +//! +//! 
Exercises the branch in the prover loop where an AIR's native quotient degree +//! `D_j` is strictly less than the global `D_max`, so the prover divides on the +//! native domain and then `upsample_evals` lifts the resulting quotient evaluations. + +extern crate alloc; + +use alloc::{vec, vec::Vec}; + +use p3_field::PrimeCharacteristicRing; +use p3_matrix::{Matrix, dense::RowMajorMatrix}; + +use crate::{ + air::{ + AirBuilder, BaseAir, LiftedAir, LiftedAirBuilder, MultiAir, ProverStatement, Statement, + WindowAccess, + }, + domain::log_quotient_degree, + testing::configs::goldilocks_poseidon2::{Felt, QuadFelt, prove_and_verify_statement}, +}; + +// --------------------------------------------------------------------------- +// PowerAir: single-column AIR with `next[0] = local[0]^power`. +// --------------------------------------------------------------------------- + +#[derive(Clone, Copy)] +struct PowerAir { + power: u64, +} + +impl BaseAir for PowerAir { + fn width(&self) -> usize { + 1 + } +} + +impl LiftedAir for PowerAir { + fn num_randomness(&self) -> usize { + 1 + } + + fn aux_width(&self) -> usize { + 1 + } + + fn num_aux_values(&self) -> usize { + 0 + } + + fn build_aux_trace( + &self, + main: &RowMajorMatrix, + _air_inputs: &[Felt], + _aux_inputs: &[Felt], + challenges: &[QuadFelt], + ) -> (RowMajorMatrix, Vec) { + // Trivial aux: a single constant-challenge column. 
+ (RowMajorMatrix::new(vec![challenges[0]; main.height()], 1), vec![]) + } + + fn eval>(&self, builder: &mut AB) { + let main = builder.main(); + let (local, next) = (main.current_slice().to_vec(), main.next_slice().to_vec()); + + let x: AB::Expr = local[0].into(); + let x_power: AB::Expr = match self.power { + 2 => x.clone() * x, + 5 => x.clone().exp_power_of_2(2) * x, + 9 => x.clone().exp_power_of_2(3) * x, + _ => unreachable!("tests only use power in {{2, 5, 9}}"), + }; + builder.when_transition().assert_eq(next[0].into(), x_power); + + // Trivial aux: aux_local == challenge (extension-field identity, degree 1). + let aux = builder.permutation(); + let aux_local = aux.current_slice().to_vec(); + let challenge = builder.permutation_randomness()[0]; + let aux_expr: AB::ExprEF = aux_local[0].into(); + let challenge_expr: AB::ExprEF = challenge.into(); + builder.assert_eq_ext(aux_expr, challenge_expr); + } +} + +// --------------------------------------------------------------------------- +// PeriodicPowerAir +// --------------------------------------------------------------------------- + +#[derive(Clone, Copy)] +struct PeriodicPowerAir { + power: u64, +} + +impl BaseAir for PeriodicPowerAir { + fn width(&self) -> usize { + 1 + } +} + +impl LiftedAir for PeriodicPowerAir { + fn periodic_columns(&self) -> Vec> { + // Period 2: entries [1, 0] repeat across the trace. + vec![vec![Felt::ONE, Felt::ZERO]] + } + + fn num_randomness(&self) -> usize { + 1 + } + + fn aux_width(&self) -> usize { + 1 + } + + fn num_aux_values(&self) -> usize { + 0 + } + + fn build_aux_trace( + &self, + main: &RowMajorMatrix, + _air_inputs: &[Felt], + _aux_inputs: &[Felt], + challenges: &[QuadFelt], + ) -> (RowMajorMatrix, Vec) { + // Trivial aux: a single constant-challenge column. 
+ (RowMajorMatrix::new(vec![challenges[0]; main.height()], 1), vec![]) + } + + fn eval>(&self, builder: &mut AB) { + let main = builder.main(); + let (local, next) = (main.current_slice().to_vec(), main.next_slice().to_vec()); + let periodic = builder.periodic_values().to_vec(); + + let x: AB::Expr = local[0].into(); + let x_power: AB::Expr = match self.power { + 3 => x.clone().exp_power_of_2(1) * x, + 5 => x.clone().exp_power_of_2(2) * x, + _ => unreachable!("periodic test uses power in {{3, 5}}"), + }; + builder.when_transition().assert_eq(next[0].into(), x_power); + + // Periodic column starts at 1 on the first trace row. + let p: AB::Expr = periodic[0].into(); + builder.when_first_row().assert_one(p); + + // Trivial aux: aux_local == challenge. + let aux = builder.permutation(); + let aux_local = aux.current_slice().to_vec(); + let challenge = builder.permutation_randomness()[0]; + let aux_expr: AB::ExprEF = aux_local[0].into(); + let challenge_expr: AB::ExprEF = challenge.into(); + builder.assert_eq_ext(aux_expr, challenge_expr); + } +} + +// --------------------------------------------------------------------------- +// MultiAir: trivial constant-challenge aux column for each trace. +// --------------------------------------------------------------------------- + +struct TwoTraceMultiAir { + airs: Vec, +} + +impl TwoTraceMultiAir { + fn new(airs: Vec) -> Self { + Self { airs } + } +} + +impl MultiAir for TwoTraceMultiAir +where + A: LiftedAir, +{ + type Air = A; + + fn airs(&self) -> &[Self::Air] { + &self.airs + } +} + +// --------------------------------------------------------------------------- +// Trace generator for `next = local^power`. 
+// --------------------------------------------------------------------------- + +fn generate_pow_trace(power: u64, start: Felt, height: usize) -> RowMajorMatrix { + let mut data = Vec::with_capacity(height); + let mut cur = start; + for _ in 0..height { + data.push(cur); + cur = cur.exp_u64(power); + } + RowMajorMatrix::new(data, 1) +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[test] +fn quadratic_air_uses_one_quotient_chunk() { + assert_eq!(log_quotient_degree::(&PowerAir { power: 2 }), 0); +} + +#[test] +fn one_chunk_quadratic_quotient_proves() { + let air = PowerAir { power: 2 }; + let trace = generate_pow_trace(2, Felt::from_u64(7), 16); + + let statement = + Statement::new(TwoTraceMultiAir::new(vec![air]), Vec::new(), Vec::new()).unwrap(); + let prover_statement = ProverStatement::new(statement, vec![trace]).unwrap(); + prove_and_verify_statement(&prover_statement); +} + +fn run_upsample_case(low_power: u64, low_height: usize, high_power: u64, high_height: usize) { + let low = PowerAir { power: low_power }; + let high = PowerAir { power: high_power }; + + let t_low = generate_pow_trace(low_power, Felt::from_u64(7), low_height); + let t_high = generate_pow_trace(high_power, Felt::from_u64(11), high_height); + + let statement = + Statement::new(TwoTraceMultiAir::new(vec![low, high]), Vec::new(), Vec::new()).unwrap(); + let prover_statement = ProverStatement::new(statement, vec![t_low, t_high]).unwrap(); + prove_and_verify_statement(&prover_statement); +} + +#[test] +fn upsample_fires_on_d5_under_d9() { + run_upsample_case(5, 16, 9, 16); +} + +#[test] +fn upsample_fires_low_degree_on_taller_trace() { + run_upsample_case(2, 64, 9, 16); +} + +#[test] +fn upsample_fires_high_degree_on_taller_trace() { + run_upsample_case(2, 16, 9, 64); +} + +#[test] +fn upsample_fires_with_periodic_columns() { + let low = PeriodicPowerAir { power: 3 
}; + let high = PeriodicPowerAir { power: 5 }; + + let t_low = generate_pow_trace(3, Felt::from_u64(7), 16); + let t_high = generate_pow_trace(5, Felt::from_u64(11), 16); + + let statement = + Statement::new(TwoTraceMultiAir::new(vec![low, high]), Vec::new(), Vec::new()).unwrap(); + let prover_statement = ProverStatement::new(statement, vec![t_low, t_high]).unwrap(); + prove_and_verify_statement(&prover_statement); +} diff --git a/stark/miden-lifted-stark/src/testing/test_preprocessed.rs b/stark/miden-lifted-stark/src/testing/test_preprocessed.rs new file mode 100644 index 0000000000..5f6df4869c --- /dev/null +++ b/stark/miden-lifted-stark/src/testing/test_preprocessed.rs @@ -0,0 +1,546 @@ +//! End-to-end tests for preprocessed traces on the stark-instance API. +//! +//! These exercise the real preprocessed path: the commitment is observed +//! first, the tree is opened via the PCS, and the per-AIR window is fed to the +//! constraint folders. Preprocessed content is served through +//! [`BaseAir::preprocessed_trace`]; the prover bundles it via +//! [`Preprocessed::build`] + [`ProverInstance::new`], and the verifier receives +//! only the commitment via [`VerifierInstance::new`]. 
+ +use alloc::{vec, vec::Vec}; + +use p3_field::PrimeCharacteristicRing; +use p3_matrix::{Matrix, dense::RowMajorMatrix}; + +use crate::{ + Preprocessed, PreprocessedValidationError, ProverInstance, VerifierInstance, + air::{ + AirBuilder, BaseAir, ExtensionBuilder, LiftedAir, LiftedAirBuilder, MultiAir, + ProverStatement, Statement, WindowAccess, + }, + pcs::params::PcsParams, + proof::{StarkOutput, StarkProof}, + testing::configs::goldilocks_poseidon2::{ + Felt, Lmcs, QuadFelt, TestConfig, test_challenger, test_config, + }, +}; + +// --------------------------------------------------------------------------- +// AIR fixtures +// --------------------------------------------------------------------------- + +/// AIR with a preprocessed column carrying the row index `0, 1, 2, …`, served +/// by value through [`BaseAir::preprocessed_trace`]. +/// +/// Constraints (gated so symbolic degree ≥ 2): first row `main[0] == +/// preprocessed[0]`; transition `Δmain == Δpreprocessed` (uses the +/// preprocessed window non-trivially); first row `aux[0] == challenge`. 
+#[derive(Clone, Debug)] +struct RowCounterAir { + preprocessed: RowMajorMatrix, +} + +impl BaseAir for RowCounterAir { + fn width(&self) -> usize { + 1 + } + fn num_public_values(&self) -> usize { + 0 + } + fn preprocessed_trace(&self) -> Option> { + Some(self.preprocessed.clone()) + } +} + +impl LiftedAir for RowCounterAir { + fn preprocessed_width(&self) -> usize { + 1 + } + fn aux_width(&self) -> usize { + 1 + } + fn num_randomness(&self) -> usize { + 1 + } + fn num_aux_values(&self) -> usize { + 0 + } + + fn build_aux_trace( + &self, + main: &RowMajorMatrix, + _air_inputs: &[Felt], + _aux_inputs: &[Felt], + challenges: &[QuadFelt], + ) -> (RowMajorMatrix, Vec) { + build_aux(main.height(), challenges) + } + + fn eval>(&self, builder: &mut AB) { + let main = builder.main(); + let local_main: AB::Expr = main.current_slice()[0].into(); + let next_main: AB::Expr = main.next_slice()[0].into(); + + let preproc = builder.preprocessed(); + let local_preproc: AB::Expr = preproc.current_slice()[0].into(); + let next_preproc: AB::Expr = preproc.next_slice()[0].into(); + + let aux = builder.permutation(); + let aux_local: AB::ExprEF = aux.current_slice()[0].into(); + let challenge: AB::ExprEF = builder.permutation_randomness()[0].into(); + + builder.when_first_row().assert_eq(local_main.clone(), local_preproc.clone()); + builder + .when_transition() + .assert_eq(next_main - local_main, next_preproc - local_preproc); + builder.when_first_row().assert_eq_ext(aux_local, challenge); + } +} + +/// AIR with no preprocessed columns. Transition `next == local²`. 
+#[derive(Clone, Copy, Debug)] +struct ConstantAir; + +impl BaseAir for ConstantAir { + fn width(&self) -> usize { + 1 + } + fn num_public_values(&self) -> usize { + 0 + } +} + +impl LiftedAir for ConstantAir { + fn aux_width(&self) -> usize { + 1 + } + fn num_randomness(&self) -> usize { + 1 + } + fn num_aux_values(&self) -> usize { + 0 + } + + fn build_aux_trace( + &self, + main: &RowMajorMatrix, + _air_inputs: &[Felt], + _aux_inputs: &[Felt], + challenges: &[QuadFelt], + ) -> (RowMajorMatrix, Vec) { + build_aux(main.height(), challenges) + } + + fn eval>(&self, builder: &mut AB) { + let main = builder.main(); + let local: AB::Expr = main.current_slice()[0].into(); + let next: AB::Expr = main.next_slice()[0].into(); + + let aux = builder.permutation(); + let aux_local: AB::ExprEF = aux.current_slice()[0].into(); + let challenge: AB::ExprEF = builder.permutation_randomness()[0].into(); + + builder.when_transition().assert_eq(next, local.clone() * local); + builder.when_first_row().assert_eq_ext(aux_local, challenge); + } +} + +/// Heterogeneous AIR for mixed-instance tests. 
+#[derive(Clone, Debug)] +enum MixedAir { + Constant(ConstantAir), + RowCounter(RowCounterAir), +} + +impl BaseAir for MixedAir { + fn width(&self) -> usize { + 1 + } + fn num_public_values(&self) -> usize { + 0 + } + fn preprocessed_trace(&self) -> Option> { + match self { + Self::Constant(_) => None, + Self::RowCounter(a) => a.preprocessed_trace(), + } + } +} + +impl LiftedAir for MixedAir { + fn preprocessed_width(&self) -> usize { + match self { + Self::Constant(_) => 0, + Self::RowCounter(a) => a.preprocessed_width(), + } + } + fn aux_width(&self) -> usize { + 1 + } + fn num_randomness(&self) -> usize { + 1 + } + fn num_aux_values(&self) -> usize { + 0 + } + + fn build_aux_trace( + &self, + main: &RowMajorMatrix, + air_inputs: &[Felt], + aux_inputs: &[Felt], + challenges: &[QuadFelt], + ) -> (RowMajorMatrix, Vec) { + match self { + Self::Constant(a) => a.build_aux_trace(main, air_inputs, aux_inputs, challenges), + Self::RowCounter(a) => a.build_aux_trace(main, air_inputs, aux_inputs, challenges), + } + } + + fn eval>(&self, builder: &mut AB) { + match self { + Self::Constant(a) => a.eval(builder), + Self::RowCounter(a) => a.eval(builder), + } + } +} + +/// AIR that declares a wider preprocessed trace (`preprocessed_width() == 2`) +/// than the matrix it serves (width 1), to drive the width-mismatch check. 
+#[derive(Clone, Debug)] +struct WrongWidthAir { + preprocessed: RowMajorMatrix, +} + +impl BaseAir for WrongWidthAir { + fn width(&self) -> usize { + 1 + } + fn num_public_values(&self) -> usize { + 0 + } + fn preprocessed_trace(&self) -> Option> { + Some(self.preprocessed.clone()) + } +} + +impl LiftedAir for WrongWidthAir { + fn preprocessed_width(&self) -> usize { + 2 + } + fn aux_width(&self) -> usize { + 1 + } + fn num_randomness(&self) -> usize { + 1 + } + fn num_aux_values(&self) -> usize { + 0 + } + + fn build_aux_trace( + &self, + main: &RowMajorMatrix, + _air_inputs: &[Felt], + _aux_inputs: &[Felt], + challenges: &[QuadFelt], + ) -> (RowMajorMatrix, Vec) { + build_aux(main.height(), challenges) + } + + fn eval>(&self, builder: &mut AB) { + // Well-formed degree-2 constraint; never reached because validation + // rejects the bundle before proving. + let main = builder.main(); + let local: AB::Expr = main.current_slice()[0].into(); + let next: AB::Expr = main.next_slice()[0].into(); + let aux = builder.permutation(); + let aux_local: AB::ExprEF = aux.current_slice()[0].into(); + let challenge: AB::ExprEF = builder.permutation_randomness()[0].into(); + builder.when_transition().assert_eq(next, local.clone() * local); + builder.when_first_row().assert_eq_ext(aux_local, challenge); + } +} + +// --------------------------------------------------------------------------- +// MultiAir + helpers +// --------------------------------------------------------------------------- + +/// Minimal [`MultiAir`] over a homogeneous AIR list. 
+#[derive(Clone, Debug)] +struct PreprocMultiAir { + airs: Vec, +} + +impl> MultiAir for PreprocMultiAir { + type Air = A; + + fn airs(&self) -> &[A] { + &self.airs + } +} + +fn row_index_trace(height: usize) -> RowMajorMatrix { + RowMajorMatrix::new((0..height).map(|r| Felt::from_u64(r as u64)).collect(), 1) +} + +fn shifted_row_index_trace(height: usize) -> RowMajorMatrix { + RowMajorMatrix::new((0..height).map(|r| Felt::from_u64((r + 1) as u64)).collect(), 1) +} + +/// Trace satisfying [`ConstantAir`]'s `next == local²` transition. +fn squaring_trace(height: usize) -> RowMajorMatrix { + let mut values = Vec::with_capacity(height); + let mut current = Felt::from_u64(2); + for _ in 0..height { + values.push(current); + current = current * current; + } + RowMajorMatrix::new(values, 1) +} + +/// Constant aux trace `[challenge; height]`, matching every fixture AIR. +fn build_aux(height: usize, challenges: &[QuadFelt]) -> (RowMajorMatrix, Vec) { + (RowMajorMatrix::new(vec![challenges[0]; height], 1), Vec::new()) +} + +/// Build a no-public-input prover statement for `airs` + `traces`. 
+fn prover_statement>( + airs: Vec, + traces: Vec>, +) -> ProverStatement> { + let statement = Statement::new(PreprocMultiAir { airs }, vec![], vec![]).expect("statement"); + ProverStatement::new(statement, traces).expect("prover statement") +} + +type TestOutput = StarkOutput; +type TestPreprocessed = Preprocessed; + +fn prove_with_preprocessed( + config: &TestConfig, + ps: &ProverStatement, +) -> (TestOutput, TestPreprocessed) +where + MA: MultiAir, +{ + let preprocessed = Preprocessed::build(ps.statement(), config).expect("has preprocessed"); + let output = ProverInstance::new(config, ps, Some(&preprocessed)) + .expect("valid preprocessed setup") + .prove(test_challenger()) + .expect("prove succeeds"); + (output, preprocessed) +} + +fn verify_and_reparse( + config: &TestConfig, + ps: &ProverStatement, + output: &TestOutput, + preprocessed: &TestPreprocessed, +) where + MA: MultiAir, +{ + let verifier_instance = + VerifierInstance::new(config, ps.statement(), Some(preprocessed.commitment())) + .expect("valid preprocessed setup"); + let digest = verifier_instance + .verify(&output.proof, test_challenger()) + .expect("verify succeeds"); + assert_eq!(output.digest, digest); + + let (_, reparse_digest) = + StarkProof::from_data(&verifier_instance, &output.proof, test_challenger()) + .expect("preprocessed transcript re-parse should succeed"); + assert_eq!(output.digest, reparse_digest); +} + +fn prove_verify_reparse(ps: &ProverStatement) +where + MA: MultiAir, +{ + let config = test_config(); + let (output, preprocessed) = prove_with_preprocessed(&config, ps); + verify_and_reparse(&config, ps, &output, &preprocessed); +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[test] +fn single_air_with_preprocessed() { + let height = 8; + let ps = prover_statement( + vec![RowCounterAir { preprocessed: row_index_trace(height) }], + 
vec![row_index_trace(height)], + ); + let config = test_config(); + let (output, preprocessed) = prove_with_preprocessed(&config, &ps); + verify_and_reparse(&config, &ps, &output, &preprocessed); + + let missing_commitment = VerifierInstance::new(&config, ps.statement(), None); + assert!( + matches!( + missing_commitment, + Err(PreprocessedValidationError::PresenceMismatch { expected: true, actual: false }) + ), + "preprocessed statements require the setup commitment", + ); +} + +#[test] +fn mixed_airs_preprocessed_at_index_1() { + let height = 8; + let ps = prover_statement( + vec![ + MixedAir::Constant(ConstantAir), + MixedAir::RowCounter(RowCounterAir { preprocessed: row_index_trace(height) }), + ], + vec![squaring_trace(height), row_index_trace(height)], + ); + + prove_verify_reparse(&ps); +} + +#[test] +fn rejects_width_mismatch() { + let height = 8; + let ps = prover_statement( + vec![WrongWidthAir { preprocessed: row_index_trace(height) }], + vec![row_index_trace(height)], + ); + let config = test_config(); + + let preprocessed = Preprocessed::build(ps.statement(), &config).expect("has preprocessed"); + let result = ProverInstance::new(&config, &ps, Some(&preprocessed)); + assert!( + matches!( + result, + Err(PreprocessedValidationError::WidthMismatch { expected: 2, actual: 1, .. }) + ), + "expected WidthMismatch {{ expected: 2, actual: 1 }}", + ); +} + +#[test] +fn rejects_height_mismatch() { + // Preprocessed matrix height (4) differs from the main trace height (8). + let ps = prover_statement( + vec![RowCounterAir { preprocessed: row_index_trace(4) }], + vec![row_index_trace(8)], + ); + let config = test_config(); + + let preprocessed = Preprocessed::build(ps.statement(), &config).expect("has preprocessed"); + let result = ProverInstance::new(&config, &ps, Some(&preprocessed)); + assert!( + matches!( + result, + Err(PreprocessedValidationError::HeightMismatch { main: 8, preprocessed: 4, .. 
}) + ), + "expected HeightMismatch {{ main: 8, preprocessed: 4 }}", + ); +} + +#[test] +fn rejects_log_blowup_mismatch() { + let height = 8; + let ps = prover_statement( + vec![RowCounterAir { preprocessed: row_index_trace(height) }], + vec![row_index_trace(height)], + ); + let build_config = test_config(); + let preprocessed = + Preprocessed::build(ps.statement(), &build_config).expect("has preprocessed"); + + let mut proving_config = test_config(); + proving_config.pcs = PcsParams::new(2, 2, 2, 0, 0, 2, 0).expect("valid PCS params"); + let result = ProverInstance::new(&proving_config, &ps, Some(&preprocessed)); + assert!( + matches!( + result, + Err(PreprocessedValidationError::LdeHeightMismatch { + log_blowup: 2, + expected: 32, + actual: 64, + .. + }) + ), + "expected LdeHeightMismatch for setup built with a different log_blowup", + ); +} + +#[test] +fn rejects_wrong_trusted_preprocessed_commitment() { + let height = 8; + let ps = prover_statement( + vec![RowCounterAir { preprocessed: row_index_trace(height) }], + vec![row_index_trace(height)], + ); + let config = test_config(); + let (output, _preprocessed) = prove_with_preprocessed(&config, &ps); + + let wrong_ps = prover_statement( + vec![RowCounterAir { + preprocessed: shifted_row_index_trace(height), + }], + vec![row_index_trace(height)], + ); + let wrong_preprocessed = + Preprocessed::build(wrong_ps.statement(), &config).expect("has preprocessed"); + let verifier_instance = + VerifierInstance::new(&config, ps.statement(), Some(wrong_preprocessed.commitment())) + .expect("presence is valid"); + + assert!( + verifier_instance.verify(&output.proof, test_challenger()).is_err(), + "verification must reject a proof checked against the wrong trusted setup commitment", + ); +} + +#[test] +fn preprocessed_shorter_than_max_trace() { + // The tallest AIR (ConstantAir, height 8) has no preprocessed columns, so the + // tallest preprocessed trace (RowCounter, height 4) sits below the max trace + // height. 
The PCS virtually lifts the shorter preprocessed tree. + let ps = prover_statement( + vec![ + MixedAir::Constant(ConstantAir), + MixedAir::RowCounter(RowCounterAir { preprocessed: row_index_trace(4) }), + ], + vec![squaring_trace(8), row_index_trace(4)], + ); + + prove_verify_reparse(&ps); +} + +#[test] +fn preprocessed_much_shorter_than_max_trace() { + // Larger lift ratio: the preprocessed tree (height 2) is folded by 4 at query + // time against a max trace height of 8. + let ps = prover_statement( + vec![ + MixedAir::Constant(ConstantAir), + MixedAir::RowCounter(RowCounterAir { preprocessed: row_index_trace(2) }), + ], + vec![squaring_trace(8), row_index_trace(2)], + ); + + prove_verify_reparse(&ps); +} + +#[test] +fn preprocessed_multiple_heights_below_max() { + // Two preprocessed AIRs at heights 2 and 4 (so the preprocessed tree lifts + // internally to 4), under a taller non-preprocessed AIR at height 8. Exercises + // both within-tree lifting and the global query fold. + let ps = prover_statement( + vec![ + MixedAir::Constant(ConstantAir), + MixedAir::RowCounter(RowCounterAir { preprocessed: row_index_trace(4) }), + MixedAir::RowCounter(RowCounterAir { preprocessed: row_index_trace(2) }), + ], + vec![squaring_trace(8), row_index_trace(4), row_index_trace(2)], + ); + + prove_verify_reparse(&ps); +} diff --git a/stark/miden-lifted-stark/src/testing/test_tiny_air.rs b/stark/miden-lifted-stark/src/testing/test_tiny_air.rs index 2f09430d9a..03240ca3b9 100644 --- a/stark/miden-lifted-stark/src/testing/test_tiny_air.rs +++ b/stark/miden-lifted-stark/src/testing/test_tiny_air.rs @@ -7,24 +7,29 @@ use p3_field::PrimeCharacteristicRing; use p3_matrix::{Matrix, dense::RowMajorMatrix}; use crate::{ - AirWitness, InstanceValidationError, ProverError, VerifierError, + ProverInstance, VerifierError, VerifierInstance, air::{ - AirBuilder, AuxBuilder, BaseAir, ExtensionBuilder, LiftedAir, LiftedAirBuilder, - WindowAccess, + AirBuilder, BaseAir, ExtensionBuilder, 
InstanceError, LiftedAir, LiftedAirBuilder, + MultiAir, ProverStatement, Statement, WindowAccess, }, - prove_multi, prove_single, + domain::DomainError, + order::{ShapeError, TraceOrder}, + proof::{TranscriptData, TranscriptError}, testing::configs::goldilocks_poseidon2::{ Felt, QuadFelt, generate_pow4_trace, prove_and_verify, test_challenger, test_config, }, - transcript::TranscriptData, - verify_single, }; // --------------------------------------------------------------------------- // TinyAir: main[0] starts at public_values[0], each row is previous^4. // Optional periodic columns with pattern [1, 0, ..., 0, 1] per period. +// +// All AIRs in a proof share the same public_values, so multi-trace tests +// give every instance an identical starting value. // --------------------------------------------------------------------------- +const START: u64 = 2; + #[derive(Clone, Debug)] struct TinyAir { /// Pre-computed periodic column data. @@ -73,8 +78,14 @@ impl LiftedAir for TinyAir { 1 } - fn num_var_len_public_inputs(&self) -> usize { - 0 + fn build_aux_trace( + &self, + main: &RowMajorMatrix, + _air_inputs: &[Felt], + _aux_inputs: &[Felt], + challenges: &[QuadFelt], + ) -> (RowMajorMatrix, Vec) { + tiny_aux(main, challenges) } fn eval>(&self, builder: &mut AB) { @@ -111,68 +122,97 @@ impl LiftedAir for TinyAir { } } -/// AuxBuilder for TinyAir: aux column = challenge^{4^row}. -struct TinyAuxBuilder; +/// Aux build for TinyAir: aux column = challenge^{4^row}. 
+fn tiny_aux( + main: &RowMajorMatrix, + challenges: &[QuadFelt], +) -> (RowMajorMatrix, Vec) { + let height = main.height(); + let challenge = challenges[0]; + + let mut col_values = Vec::with_capacity(height); + let mut current = challenge; + for _ in 0..height { + col_values.push(current); + current = current.exp_power_of_2(2); + } -impl AuxBuilder for TinyAuxBuilder { - fn build_aux_trace( - &self, - main: &RowMajorMatrix, - challenges: &[QuadFelt], - ) -> (RowMajorMatrix, Vec) { - let height = main.height(); - let challenge = challenges[0]; - - let mut col_values = Vec::with_capacity(height); - let mut current = challenge; - for _ in 0..height { - col_values.push(current); - current = current.exp_power_of_2(2); - } + let aux_values = vec![col_values[height - 1]]; + let aux_trace = RowMajorMatrix::new(col_values, 1); + (aux_trace, aux_values) +} + +struct TinyMultiAir { + airs: Vec, +} + +impl MultiAir for TinyMultiAir { + type Air = TinyAir; - let aux_trace = RowMajorMatrix::new(col_values.clone(), 1); - let aux_values = vec![col_values[height - 1]]; - (aux_trace, aux_values) + fn airs(&self) -> &[Self::Air] { + &self.airs } } -/// Build a (trace, public_values) pair for instance `idx`. -fn instance(idx: usize, height: usize) -> (RowMajorMatrix, Vec) { - let start = Felt::from_u64((idx + 2) as u64); - (generate_pow4_trace(start, height), vec![start]) +/// Build a [`ProverStatement`] for tests. +/// +/// Returns `Err` if validation fails (so callers can exercise error paths). 
+fn tiny_prover_statement( + airs: Vec, + traces: Vec>, + air_inputs: Vec, +) -> Result, InstanceError> { + let statement = Statement::new(TinyMultiAir { airs }, air_inputs, Vec::new())?; + ProverStatement::new(statement, traces) +} + +fn trace_of_height(height: usize) -> RowMajorMatrix { + generate_pow4_trace(Felt::from_u64(START), height) } // --------------------------------------------------------------------------- // Single-trace tests // --------------------------------------------------------------------------- +#[test] +fn prover_statement_rejects_one_row_trace() { + assert!(matches!( + tiny_prover_statement( + vec![TinyAir::new(vec![])], + vec![trace_of_height(1)], + vec![Felt::from_u64(START)], + ), + Err(InstanceError::TraceHeightTooSmall { air: 0, height: 1 }) + )); +} + #[test] fn single_trace() { - prove_and_verify(&TinyAir::new(vec![]), &TinyAuxBuilder, &[instance(0, 8)]); + let pv = vec![Felt::from_u64(START)]; + prove_and_verify(&TinyAir::new(vec![]), &pv, &[trace_of_height(8)]); } #[test] fn malformed_transcript_is_rejected() { let config = test_config(); - let air = TinyAir::new(vec![]); - - let (trace, public_values) = instance(0, 4); - - let output = prove_single( - &config, - &air, - &trace, - &public_values, - &[], - &TinyAuxBuilder, - test_challenger(), + let prover_statement = tiny_prover_statement( + vec![TinyAir::new(vec![])], + vec![trace_of_height(4)], + vec![Felt::from_u64(START)], ) - .expect("proving should succeed"); + .expect("valid"); + + let output = ProverInstance::new(&config, &prover_statement, None) + .expect("no preprocessed columns") + .prove(test_challenger()) + .expect("proving should succeed"); // Baseline should verify - let _digest = - verify_single(&config, &air, &public_values, &[], &output.proof, test_challenger()) - .expect("baseline proof should verify"); + let baseline_statement = VerifierInstance::new(&config, prover_statement.statement(), None) + .expect("no preprocessed columns"); + let _digest = 
baseline_statement + .verify(&output.proof, test_challenger()) + .expect("baseline proof should verify"); // Extra field element should cause rejection let mut bad_proof = output.proof; @@ -180,117 +220,54 @@ fn malformed_transcript_is_rejected() { fields.push(Felt::ONE); bad_proof.transcript = TranscriptData::new(fields, commitments); - let err = verify_single(&config, &air, &public_values, &[], &bad_proof, test_challenger()) + let err = baseline_statement + .verify(&bad_proof, test_challenger()) .expect_err("extra transcript data should fail verification"); - assert!(matches!( - err, - VerifierError::Transcript(crate::transcript::TranscriptError::TrailingData) - )); + assert!(matches!(err, VerifierError::Transcript(TranscriptError::TrailingData))); } #[test] fn malformed_log_trace_heights_is_rejected() { let config = test_config(); - let air = TinyAir::new(vec![]); - - let (trace, public_values) = instance(0, 4); - - let output = prove_single( - &config, - &air, - &trace, - &public_values, - &[], - &TinyAuxBuilder, - test_challenger(), + let prover_statement = tiny_prover_statement( + vec![TinyAir::new(vec![])], + vec![trace_of_height(4)], + vec![Felt::from_u64(START)], ) - .expect("proving should succeed"); - - // Push straight to the `pub(crate)` field to bypass - // `InstanceShapes::from_trace_heights` and exercise the verifier-path - // bound check in `validate_inputs`. - let mut bad_proof = output.proof.clone(); - bad_proof.instance_shapes.log_trace_heights.push(2); - bad_proof.instance_shapes.air_order.push(1); - let err = verify_single(&config, &air, &public_values, &[], &bad_proof, test_challenger()) - .expect_err("extra log trace height should fail verification"); - assert!(matches!( - err, - VerifierError::Instance(InstanceValidationError::AirOrderLengthMismatch { - instances: 1, - air_order: 2, - }) - )); - - // Empty heights → air_order / instance count mismatch. 
- let mut bad_proof = output.proof.clone(); - bad_proof.instance_shapes.log_trace_heights.clear(); - bad_proof.instance_shapes.air_order.clear(); - let err = verify_single(&config, &air, &public_values, &[], &bad_proof, test_challenger()) - .expect_err("empty log trace heights should fail verification"); - assert!(matches!( - err, - VerifierError::Instance(InstanceValidationError::AirOrderLengthMismatch { - instances: 1, - air_order: 0, - }) - )); - - // Out-of-range log height must surface as an error, not panic on - // `1usize << log_h` or `two_adic_generator(log_h + log_blowup)`. + .expect("valid"); + let statement = prover_statement.statement(); + let stark_statement = + VerifierInstance::new(&config, statement, None).expect("no preprocessed columns"); + + let output = ProverInstance::new(&config, &prover_statement, None) + .expect("no preprocessed columns") + .prove(test_challenger()) + .expect("proving should succeed"); + + // Poke the `pub(crate)` `log_trace_heights` directly to feed the verifier + // malformed shapes that bypass `ProverStatement` construction — the cases + // that would otherwise panic or overflow rather than return a clean error. + + // log_h = 200 trips the `usize` overflow guard in `TraceOrder::from_log_heights`, + // before `1usize << log_h` could overflow. let mut bad_proof = output.proof.clone(); - bad_proof.instance_shapes.log_trace_heights = vec![200]; - let err = verify_single(&config, &air, &public_values, &[], &bad_proof, test_challenger()) + bad_proof.log_trace_heights = vec![200]; + let err = stark_statement + .verify(&bad_proof, test_challenger()) .expect_err("oversized log trace height should fail verification"); assert!(matches!( err, - VerifierError::Instance(InstanceValidationError::LdeDomainExceedsTwoAdicity { - log_h: 200, - .. - }) + VerifierError::Shape(ShapeError::LogTraceHeightTooLarge { log_h: 200, .. 
}) )); - // Boundary case: `log_h` fits the raw bound (`log_h ≤ TWO_ADICITY`) but - // the LDE domain `log_h + log_blowup` does not. With `log_blowup = 2` - // from `TEST_PCS_PARAMS` and `Felt::TWO_ADICITY = 32`, `31 + 2 = 33 > 32` - // must be rejected before any `two_adic_generator` call on the LDE domain. + // log_h = 30 with log_blowup = 3 and Felt::TWO_ADICITY = 32 overflows the LDE + // domain (33 > 32), rejected by `try_canonical` before any generator lookup. let mut bad_proof = output.proof; - bad_proof.instance_shapes.log_trace_heights = vec![31]; - let err = verify_single(&config, &air, &public_values, &[], &bad_proof, test_challenger()) + bad_proof.log_trace_heights = vec![30]; + let err = stark_statement + .verify(&bad_proof, test_challenger()) .expect_err("log_h + log_blowup exceeding two-adicity should fail verification"); - assert!(matches!( - err, - VerifierError::Instance(InstanceValidationError::LdeDomainExceedsTwoAdicity { - log_h: 31, - log_blowup: 2, - .. - }) - )); -} - -#[test] -fn prover_rejects_non_power_of_two_trace_height() { - // Build the witness directly (via `pub` fields) to skip the - // power-of-two assertion in `AirWitness::new`. `InstanceShapes::from_trace_heights` - // must reject it rather than panicking inside `log2_strict_u8`. 
- let config = test_config(); - let air = TinyAir::new(vec![]); - - let trace = - RowMajorMatrix::new(vec![Felt::from_u64(2), Felt::from_u64(16), Felt::from_u64(65536)], 1); - let public_values = vec![Felt::from_u64(2)]; - let bad_witness = AirWitness { - trace: &trace, - public_values: &public_values, - var_len_public_inputs: &[], - }; - - let result = prove_multi(&config, &[(&air, bad_witness, &TinyAuxBuilder)], test_challenger()); - match result { - Err(ProverError::Instance(InstanceValidationError::InvalidTraceHeight { height: 3 })) => {}, - Err(other) => panic!("expected InvalidTraceHeight {{ height: 3 }}, got {other:?}"), - Ok(_) => panic!("non-power-of-two trace height should fail proving"), - } + assert!(matches!(err, VerifierError::Domain(DomainError::LdeOrderTooLarge { .. }))); } // --------------------------------------------------------------------------- @@ -299,20 +276,23 @@ fn prover_rejects_non_power_of_two_trace_height() { #[test] fn two_traces_same_height() { - prove_and_verify(&TinyAir::new(vec![]), &TinyAuxBuilder, &[instance(0, 8), instance(1, 8)]); + let pv = vec![Felt::from_u64(START)]; + prove_and_verify(&TinyAir::new(vec![]), &pv, &[trace_of_height(8), trace_of_height(8)]); } #[test] fn two_traces_different_heights() { - prove_and_verify(&TinyAir::new(vec![]), &TinyAuxBuilder, &[instance(0, 4), instance(1, 8)]); + let pv = vec![Felt::from_u64(START)]; + prove_and_verify(&TinyAir::new(vec![]), &pv, &[trace_of_height(4), trace_of_height(8)]); } #[test] fn three_traces_ascending_heights() { + let pv = vec![Felt::from_u64(START)]; prove_and_verify( &TinyAir::new(vec![]), - &TinyAuxBuilder, - &[instance(0, 4), instance(1, 8), instance(2, 16)], + &pv, + &[trace_of_height(4), trace_of_height(8), trace_of_height(16)], ); } @@ -322,64 +302,74 @@ fn three_traces_ascending_heights() { #[test] fn two_traces_reversed_order() { - prove_and_verify(&TinyAir::new(vec![]), &TinyAuxBuilder, &[instance(1, 8), instance(0, 4)]); + let pv = 
vec![Felt::from_u64(START)]; + prove_and_verify(&TinyAir::new(vec![]), &pv, &[trace_of_height(8), trace_of_height(4)]); } #[test] fn three_traces_descending_heights() { + let pv = vec![Felt::from_u64(START)]; prove_and_verify( &TinyAir::new(vec![]), - &TinyAuxBuilder, - &[instance(2, 16), instance(1, 8), instance(0, 4)], + &pv, + &[trace_of_height(16), trace_of_height(8), trace_of_height(4)], ); } #[test] fn three_traces_shuffled_order() { + let pv = vec![Felt::from_u64(START)]; prove_and_verify( &TinyAir::new(vec![]), - &TinyAuxBuilder, - &[instance(1, 8), instance(2, 16), instance(0, 4)], + &pv, + &[trace_of_height(8), trace_of_height(16), trace_of_height(4)], ); } #[test] fn periodic_columns_reversed_order() { - prove_and_verify(&TinyAir::new(vec![2, 4]), &TinyAuxBuilder, &[instance(1, 8), instance(0, 4)]); + let pv = vec![Felt::from_u64(START)]; + prove_and_verify(&TinyAir::new(vec![2, 4]), &pv, &[trace_of_height(8), trace_of_height(4)]); } #[test] fn air_order_reflects_caller_order() { let config = test_config(); - let air = TinyAir::new(vec![]); + let prover_statement = tiny_prover_statement( + vec![TinyAir::new(vec![]), TinyAir::new(vec![])], + // Pass traces in reverse height order: [height=8, height=4]. + vec![trace_of_height(8), trace_of_height(4)], + vec![Felt::from_u64(START)], + ) + .expect("valid"); - // Pass instances in reverse height order: [height=8, height=4]. - let (t0, pv0) = instance(0, 8); - let (t1, pv1) = instance(1, 4); + let output = ProverInstance::new(&config, &prover_statement, None) + .expect("no preprocessed columns") + .prove(test_challenger()) + .expect("proving should succeed"); - let w0 = AirWitness::new(&t0, &pv0, &[]); - let w1 = AirWitness::new(&t1, &pv1, &[]); + // The proof carries heights in instance order: [height=8, height=4] + // → [log_h=3, log_h=2]. The proof's AIR ordering itself is implicit + // (recomputed from the heights via TraceOrder). 
+ assert_eq!( + output.proof.log_trace_heights.as_slice(), + &[3, 2], + "log heights should be in instance order (8=2^3, 4=2^2)" + ); - let output = prove_multi( - &config, - &[(&air, w0, &TinyAuxBuilder), (&air, w1, &TinyAuxBuilder)], - test_challenger(), + // The derived proof ordering is ascending height: instance index 1 + // (log_h=2) ends up at proof position 0, instance index 0 (log_h=3) at + // position 1. + let trace_order = TraceOrder::from_log_heights::( + prover_statement.statement().airs(), + output.proof.log_trace_heights, ) - .expect("proving should succeed"); - - // Proof ordering is ascending height: [height=4, height=8]. - // Caller index 1 (height=4) is at position 0 in the proof's ordering. - // Caller index 0 (height=8) is at position 1 in the proof's ordering. - let air_order = output.proof.instance_shapes.air_order(); + .expect("valid heights"); assert_eq!( - air_order, + trace_order.instance_indices(), &[1, 0], - "air_order should map ascending-height position → caller index" + "trace order should map ascending-height position → instance index" ); - - // Log trace heights should be in ascending order. 
- let log_heights = output.proof.instance_shapes.log_trace_heights(); - assert_eq!(log_heights, &[2, 3], "log heights should be ascending (4=2^2, 8=2^3)"); } // --------------------------------------------------------------------------- @@ -388,34 +378,40 @@ fn air_order_reflects_caller_order() { #[test] fn single_periodic_column() { - prove_and_verify(&TinyAir::new(vec![2]), &TinyAuxBuilder, &[instance(0, 8)]); + let pv = vec![Felt::from_u64(START)]; + prove_and_verify(&TinyAir::new(vec![2]), &pv, &[trace_of_height(8)]); } #[test] fn periodic_column_period_4() { - prove_and_verify(&TinyAir::new(vec![4]), &TinyAuxBuilder, &[instance(0, 8)]); + let pv = vec![Felt::from_u64(START)]; + prove_and_verify(&TinyAir::new(vec![4]), &pv, &[trace_of_height(8)]); } #[test] fn multiple_periodic_columns() { - prove_and_verify(&TinyAir::new(vec![2, 4]), &TinyAuxBuilder, &[instance(0, 8)]); + let pv = vec![Felt::from_u64(START)]; + prove_and_verify(&TinyAir::new(vec![2, 4]), &pv, &[trace_of_height(8)]); } #[test] fn periodic_columns_multi_trace_same_height() { - prove_and_verify(&TinyAir::new(vec![2]), &TinyAuxBuilder, &[instance(0, 8), instance(1, 8)]); + let pv = vec![Felt::from_u64(START)]; + prove_and_verify(&TinyAir::new(vec![2]), &pv, &[trace_of_height(8), trace_of_height(8)]); } #[test] fn periodic_columns_multi_trace_different_heights() { - prove_and_verify(&TinyAir::new(vec![2, 4]), &TinyAuxBuilder, &[instance(0, 4), instance(1, 8)]); + let pv = vec![Felt::from_u64(START)]; + prove_and_verify(&TinyAir::new(vec![2, 4]), &pv, &[trace_of_height(4), trace_of_height(8)]); } #[test] fn periodic_columns_three_traces() { + let pv = vec![Felt::from_u64(START)]; prove_and_verify( &TinyAir::new(vec![2, 4]), - &TinyAuxBuilder, - &[instance(0, 4), instance(1, 8), instance(2, 16)], + &pv, + &[trace_of_height(4), trace_of_height(8), trace_of_height(16)], ); } diff --git a/stark/miden-lifted-stark/src/util/align.rs b/stark/miden-lifted-stark/src/util/align.rs new file mode 100644 index 
0000000000..f4129f1757 --- /dev/null +++ b/stark/miden-lifted-stark/src/util/align.rs @@ -0,0 +1,21 @@ +//! Width / length alignment helpers. + +use alloc::vec::Vec; + +/// Compute the aligned length for `len` given an alignment. +#[inline] +pub(crate) const fn aligned_len(len: usize, alignment: usize) -> usize { + if alignment <= 1 { + len + } else { + len.next_multiple_of(alignment) + } +} + +/// Align each width in place, returning the same `Vec`. +pub(crate) fn aligned_widths(mut widths: Vec, alignment: usize) -> Vec { + for w in &mut widths { + *w = aligned_len(*w, alignment); + } + widths +} diff --git a/stark/miden-lifted-stark/src/lmcs/bitrev.rs b/stark/miden-lifted-stark/src/util/bitrev.rs similarity index 88% rename from stark/miden-lifted-stark/src/lmcs/bitrev.rs rename to stark/miden-lifted-stark/src/util/bitrev.rs index d2f9216392..4f5189cd6d 100644 --- a/stark/miden-lifted-stark/src/lmcs/bitrev.rs +++ b/stark/miden-lifted-stark/src/util/bitrev.rs @@ -1,4 +1,4 @@ -//! Local copy of `BitReversibleMatrix` with additional impls for [`FlatMatrixView`]. +//! Bit-reversal helpers and a stopgap [`BitReversibleMatrix`] trait. //! //! # Temporary stopgap //! @@ -6,8 +6,8 @@ //! [`DenseMatrix`], not for [`FlatMatrixView`]. This module provides an identical //! trait with impls for all matrix types used by the LMCS and FRI. //! -//! Once an upstream impl is available, this module can be removed and all uses -//! replaced with `p3_matrix::bitrev::BitReversibleMatrix`. +//! Once an upstream impl is available, this trait (and `materialize_bitrev`) can +//! be removed and all uses replaced with `p3_matrix::bitrev::BitReversibleMatrix`. use p3_field::{ExtensionField, Field}; use p3_matrix::{ @@ -17,21 +17,6 @@ use p3_matrix::{ extension::FlatMatrixView, }; -/// Materialize a matrix into domain-ordered `BitReversedMatrixView>`. 
-/// -/// Temporary adapter for types that implement the upstream -/// [`p3_matrix::bitrev::BitReversibleMatrix`] but not this crate's local copy. -/// The returned type implements both traits and can be passed directly to -/// [`Lmcs::build_tree`](crate::lmcs::Lmcs::build_tree) / -/// [`Lmcs::build_aligned_tree`](crate::lmcs::Lmcs::build_aligned_tree). -/// -/// Remove alongside this module when upstream impls cover all DFT output types. -pub fn materialize_bitrev( - evals: impl p3_matrix::bitrev::BitReversibleMatrix, -) -> BitReversedMatrixView> { - BitReversalPerm::new_view(evals.bit_reverse_rows().to_row_major_matrix()) -} - /// A matrix that supports bit-reversed row reordering. /// /// Local copy of `p3_matrix::bitrev::BitReversibleMatrix` extended with impls for @@ -101,3 +86,18 @@ where self.inner } } + +/// Materialize a matrix into domain-ordered `BitReversedMatrixView>`. +/// +/// Temporary adapter for types that implement the upstream +/// [`p3_matrix::bitrev::BitReversibleMatrix`] but not this crate's local copy. +/// The returned type implements both traits and can be passed directly to +/// [`Lmcs::build_tree`](crate::lmcs::Lmcs::build_tree) / +/// [`Lmcs::build_aligned_tree`](crate::lmcs::Lmcs::build_aligned_tree). +/// +/// Remove alongside [`BitReversibleMatrix`] when upstream impls cover all DFT output types. +pub(crate) fn materialize_bitrev( + evals: impl p3_matrix::bitrev::BitReversibleMatrix, +) -> BitReversedMatrixView> { + BitReversalPerm::new_view(evals.bit_reverse_rows().to_row_major_matrix()) +} diff --git a/stark/miden-lifted-stark/src/util/horner.rs b/stark/miden-lifted-stark/src/util/horner.rs new file mode 100644 index 0000000000..c51a63f6bf --- /dev/null +++ b/stark/miden-lifted-stark/src/util/horner.rs @@ -0,0 +1,34 @@ +//! Horner-style polynomial evaluation helpers. + +use core::ops::{Add, Mul}; + +/// Horner fold with an explicit accumulator. +/// +/// Computes `acc·xⁿ + v₀·xⁿ⁻¹ + v₁·xⁿ⁻² + ... + vₙ₋₁·x⁰` where n = len(vals). 
+/// Equivalently: `((acc·x + v₀)·x + v₁)·x + ... + vₙ₋₁`. +/// The first element gets the highest power of `x`. +/// +/// For polynomial evaluation `p(x) = Σᵢ cᵢ·xⁱ`, pass coefficients in +/// descending degree order `[cₙ, ..., c₁, c₀]`. +#[inline] +pub(crate) fn horner_acc(acc: Acc, x: X, vals: I) -> Acc +where + I: IntoIterator, + Acc: Mul + Add, + X: Clone, +{ + vals.into_iter().fold(acc, |acc, val| acc * x.clone() + val) +} + +/// Horner fold starting from zero. +/// +/// See [`horner_acc`] for the evaluation convention. +#[inline] +pub(crate) fn horner(x: X, vals: I) -> Acc +where + I: IntoIterator, + Acc: Default + Mul + Add, + X: Clone, +{ + horner_acc(Acc::default(), x, vals) +} diff --git a/stark/miden-lifted-stark/src/util/mod.rs b/stark/miden-lifted-stark/src/util/mod.rs new file mode 100644 index 0000000000..b45ab05619 --- /dev/null +++ b/stark/miden-lifted-stark/src/util/mod.rs @@ -0,0 +1,6 @@ +//! Crate-wide utility helpers shared across LMCS, PCS, prover, and verifier. + +pub(crate) mod align; +pub(crate) mod bitrev; +pub(crate) mod horner; +pub(crate) mod packing; diff --git a/stark/miden-lifted-stark/src/util/packing.rs b/stark/miden-lifted-stark/src/util/packing.rs new file mode 100644 index 0000000000..5b928c20b8 --- /dev/null +++ b/stark/miden-lifted-stark/src/util/packing.rs @@ -0,0 +1,72 @@ +//! Packing / column-transpose helpers for SIMD operations on packed values +//! and packed extension-field elements. + +use alloc::vec::Vec; +use core::array; + +use p3_field::{ExtensionField, Field, PackedFieldExtension, PackedValue}; + +/// Extension trait for [`PackedFieldExtension`] adding `pack_ext_columns` and +/// `to_ext_slice` methods for column-wise SIMD operations on extension field elements. +pub(crate) trait PackedFieldExtensionExt< + BaseField: Field, + ExtField: ExtensionField, +>: PackedFieldExtension +{ + /// Pack N columns from WIDTH rows into N packed extension field elements. 
+ /// + /// Input: `rows[lane][col]` - WIDTH rows, each with N extension field elements. + /// Output: `result[col]` - N packed values, where each packs WIDTH lanes. + fn pack_ext_columns(rows: &[[ExtField; N]]) -> [Self; N] { + let width = BaseField::Packing::WIDTH; + debug_assert_eq!(rows.len(), width); + array::from_fn(|col| { + let col_elems: Vec = (0..width).map(|lane| rows[lane][col]).collect(); + Self::from_ext_slice(&col_elems) + }) + } + + /// Extract all lanes to an output slice. + fn to_ext_slice(&self, out: &mut [ExtField]) { + let width = BaseField::Packing::WIDTH; + for (lane, slot) in out.iter_mut().enumerate().take(width) { + *slot = self.extract(lane); + } + } +} + +impl< + BaseField: Field, + ExtField: ExtensionField, + P: PackedFieldExtension, +> PackedFieldExtensionExt for P +{ +} + +/// Reconstitute EF elements from opened base field polynomial evaluations. +/// +/// When an EF polynomial is committed, it becomes DIM base field polynomials. +/// Opening at EF point z gives DIM EF values (F-polys evaluated at EF point). +/// Reconstruct each EF element: `vᵢ = Σⱼ basisⱼ·row[i·DIM + j]`. +/// +/// Returns `None` if `row.len()` is not a multiple of `EF::DIMENSION`. +pub(crate) fn row_to_packed_ext(row: &[EF]) -> Option> +where + F: Field, + EF: ExtensionField, +{ + if !row.len().is_multiple_of(EF::DIMENSION) { + return None; + } + Some( + row.chunks_exact(EF::DIMENSION) + .map(|chunk| { + chunk + .iter() + .enumerate() + .map(|(j, &c)| EF::ith_basis_element(j).unwrap() * c) + .sum() + }) + .collect(), + ) +} diff --git a/stark/miden-lifted-stark/src/verifier/README.md b/stark/miden-lifted-stark/src/verifier/README.md index 7fc5c69a4a..dcd4e0bada 100644 --- a/stark/miden-lifted-stark/src/verifier/README.md +++ b/stark/miden-lifted-stark/src/verifier/README.md @@ -2,7 +2,7 @@ End-to-end verification for the lifted STARK protocol using LMCS commitments and the lifted FRI PCS. 
Supports multiple traces of different -power-of-two heights via virtual lifting. +power-of-two heights of at least 2 rows via virtual lifting. Protocol-level overview lives in `miden-lifted-stark/README.md`. @@ -10,16 +10,21 @@ Protocol-level overview lives in `miden-lifted-stark/README.md`. | Item | Purpose | |------|---------| -| `verify_single` | Verify a single-AIR proof | -| `verify_multi` | Verify a multi-trace proof | -| `AirInstance` | Public values + variable-length inputs for one AIR | +| `verify` | Verify a `Statement` | +| `Statement` | A `MultiAir` plus the per-proof inputs (`air_inputs`, optional `aux_inputs`) | +| `MultiAir` | Trusted statement definition: AIR instances, cross-AIR assertions, and a Fiat-Shamir `observe` hook | | `StarkProof` | Log trace heights + raw transcript data | ```text -verify_single(config, air, public_values, var_len_public_inputs, proof, challenger) -verify_multi(config, &[(air, instance), ...], proof, challenger) +verify(config, &statement, proof, challenger) ``` +The `statement` carries its `MultiAir` (the AIRs and the cross-AIR assertions +from `eval_external`) plus the statement-owned `air_inputs` and (if any) +`aux_inputs`. The framework absorbs both `air_inputs` and `aux_inputs` into +Fiat-Shamir automatically via `Statement::observe` — callers must pass a +`Statement` carrying the same data on prover and verifier sides. + The proof is read from the provided transcript channel. This crate does not prescribe the *initial* challenger state used for Fiat-Shamir. @@ -30,13 +35,13 @@ prover module-level docs for the full binding contract. ## Transcript boundaries -`verify_multi` rejects trailing transcript data (`TranscriptNotConsumed`). If you +`verify` rejects trailing transcript data (`TranscriptError::TrailingData`). If you bundle extra data in the same transcript, you must manage boundaries yourself. ## Protocol flow -0. Validate `air_order` from the proof and reorder caller instances to match. -1. 
Observe log trace heights into the challenger (from proof, not transcript). +0. Reconstruct `TraceOrder` from the proof's log trace heights and reorder caller AIRs into the proof's ascending-height ordering. +1. Absorb statement-owned inputs via `Statement::observe`, then absorb the instance count and per-instance log trace heights into the challenger. 2. Receive main trace commitment. 3. Sample aux randomness. 4. Receive aux trace commitment. @@ -47,8 +52,9 @@ bundle extra data in the same transcript, you must manage boundaries yourself. 9. Reconstruct `Q(z)` from the opened quotient chunks. 10. For each trace instance j, set `y_j = z^{r_j}` and evaluate folded constraints at `y_j`. 11. Accumulate across traces with `beta`. -12. Check quotient identity: `accumulated == Q(z) * (z^N - 1)`. -13. Ensure transcript is fully consumed. +12. Call `statement.eval_external(...)` once with the global view (challenges, all aux values in instance order, log heights in instance order) and check each returned EF value is zero. +13. Check quotient identity: `accumulated == Q(z) * (z^N - 1)`. +14. Ensure transcript is fully consumed. ## Mathematical background @@ -160,7 +166,7 @@ $$ Z_{H^{r_j}}(y_j) = y_j^{n_j} - 1, $$ -and the unnormalized selector formulas (matching `LiftedCoset::selectors_at`): +and the unnormalized selector formulas (matching `LiftedDomain::selectors_at`): $$ \mathrm{is\_first}(y) = \frac{Z_{H^{r_j}}(y)}{y-1}, diff --git a/stark/miden-lifted-stark/src/verifier/constraints.rs b/stark/miden-lifted-stark/src/verifier/constraints.rs index c331e8e45c..75b37beab1 100644 --- a/stark/miden-lifted-stark/src/verifier/constraints.rs +++ b/stark/miden-lifted-stark/src/verifier/constraints.rs @@ -1,20 +1,16 @@ -//! Constraint evaluation and quotient reconstruction for the verifier. +//! Constraint evaluation for the verifier. //! -//! This module provides: -//! - [`ConstraintFolder`]: Minimal EF-only folder for verifier constraint evaluation -//! 
- [`reconstruct_quotient`]: Reconstructs Q(z) from quotient chunk evaluations -//! - [`row_to_packed_ext`]: Reconstitutes EF elements from opened base field evaluations +//! Provides [`ConstraintFolder`], a minimal EF-only folder evaluating the AIR's +//! constraints at a single OOD extension-field point. -use alloc::vec::Vec; use core::marker::PhantomData; use miden_lifted_air::{ - AirBuilder, EmptyWindow, ExtensionBuilder, PeriodicAirBuilder, PermutationAirBuilder, RowWindow, + AirBuilder, ExtensionBuilder, PeriodicAirBuilder, PermutationAirBuilder, RowWindow, }; -use p3_field::{ExtensionField, Field, TwoAdicField}; -use p3_util::log2_strict_usize; +use p3_field::{ExtensionField, Field}; -use crate::{coset::LiftedCoset, selectors::Selectors, verifier::VerifierError}; +use crate::selectors::Selectors; // ============================================================================ // ConstraintFolder @@ -34,12 +30,13 @@ use crate::{coset::LiftedCoset, selectors::Selectors, verifier::VerifierError}; /// `Σₖ α^{K−1−k}·Cₖ(z)`, but is cheaper for a single-point evaluation. /// The prover computes an equivalent fold over the whole quotient domain, optimized /// with base-field SIMD where possible. 
-pub struct ConstraintFolder<'a, F, EF> +pub(super) struct ConstraintFolder<'a, F, EF> where F: Field, EF: ExtensionField, { pub main: RowWindow<'a, EF>, + pub preprocessed: RowWindow<'a, EF>, pub aux: RowWindow<'a, EF>, pub randomness: &'a [EF], pub public_values: &'a [F], @@ -59,7 +56,7 @@ where type F = F; type Expr = EF; type Var = EF; - type PreprocessedWindow = EmptyWindow; + type PreprocessedWindow = RowWindow<'a, EF>; type MainWindow = RowWindow<'a, EF>; type PublicVar = F; @@ -68,7 +65,7 @@ where } fn preprocessed(&self) -> &Self::PreprocessedWindow { - EmptyWindow::empty_ref() + &self.preprocessed } fn is_first_row(&self) -> Self::Expr { @@ -80,11 +77,8 @@ where } fn is_transition_window(&self, size: usize) -> Self::Expr { - if size == 2 { - self.selectors.is_transition - } else { - panic!("only window size 2 supported in this prototype") - } + assert_eq!(size, 2, "AIR uses window size {size}; only 2 supported"); + self.selectors.is_transition } fn assert_zero>(&mut self, x: I) { @@ -146,87 +140,3 @@ where self.periodic_values } } - -// ============================================================================ -// Quotient Reconstruction -// ============================================================================ - -/// Reconstruct `Q(z)` from `D` quotient chunk evaluations. -/// -/// The quotient `Q` is committed as `D` chunk polynomials qₜ of degree `< N`, one for -/// each `H`-coset inside `J`: -/// -/// qₜ agrees with `Q` on the coset `g·ω_Jᵗ·H`. -/// -/// During verification we open all qₜ(z) at the same OOD point `z` and need to -/// recombine them into `Q(z)`. -/// -/// The key observation is that the map `x → xᴺ` collapses each coset -/// `g·ω_Jᵗ·H` to a single `D`-th root of unity. Let -/// - ωₛ = ω_Jᴺ (a `D`-th root of unity), -/// - u = (z/s)ᴺ where s = coset.lde_shift(). 
-/// -/// Then `Q(z)` is the barycentric interpolation of the values qₜ(z) at the points -/// ωₛᵗ: -/// -/// ```text -/// wₜ = ωₛᵗ / (u − ωₛᵗ) -/// Q(z) = (Σₜ wₜ·qₜ(z)) / (Σₜ wₜ) -/// ``` -pub fn reconstruct_quotient(z: EF, coset: &LiftedCoset, chunks: &[EF]) -> EF -where - F: TwoAdicField, - EF: ExtensionField, -{ - let log_d = log2_strict_usize(chunks.len()); - let shift: F = coset.lde_shift(); - let omega_s = F::two_adic_generator(log_d); - - // u = (z/s)ᴺ where s = lde_shift - let u = (z * shift.inverse()).exp_power_of_2(coset.log_trace_height as usize); - - // Compute weighted sum: Σₜ wₜ·qₜ(z) and Σₜ wₜ - let mut numerator = EF::ZERO; - let mut denominator = EF::ZERO; - let mut omega_s_t = F::ONE; // ωₛᵗ - - for &q_t in chunks.iter() { - let a_t = u - omega_s_t; // aₜ = u − ωₛᵗ - let w_t = a_t.inverse() * omega_s_t; // wₜ = ωₛᵗ / aₜ - - numerator += w_t * q_t; - denominator += w_t; - - omega_s_t *= omega_s; - } - - numerator * denominator.inverse() -} - -/// Reconstitute EF elements from opened base field polynomial evaluations. -/// -/// When an EF polynomial is committed, it becomes DIM base field polynomials. -/// Opening at EF point z gives DIM EF values (F-polys evaluated at EF point). -/// Reconstruct each EF element: `vᵢ = Σⱼ basisⱼ·row[i·DIM + j]`. -/// -/// An EF element `v = Σⱼ cⱼ·basisⱼ` is committed as DIM base field polynomials pⱼ -/// (one per basis coordinate cⱼ). Opening at `z` returns the DIM values pⱼ(z), and we -/// recover the original EF value as `v(z) = Σⱼ basisⱼ·pⱼ(z)`. 
-pub fn row_to_packed_ext(row: &[EF]) -> Result, VerifierError> -where - F: TwoAdicField, - EF: ExtensionField, -{ - if !row.len().is_multiple_of(EF::DIMENSION) { - return Err(VerifierError::InvalidAuxShape); - } - let num_elements = row.len() / EF::DIMENSION; - Ok((0..num_elements) - .map(|i| { - let start = i * EF::DIMENSION; - (0..EF::DIMENSION) - .map(|j| EF::ith_basis_element(j).unwrap() * row[start + j]) - .sum() - }) - .collect()) -} diff --git a/stark/miden-lifted-stark/src/verifier/mod.rs b/stark/miden-lifted-stark/src/verifier/mod.rs index 0c33fff097..c2efb59c63 100644 --- a/stark/miden-lifted-stark/src/verifier/mod.rs +++ b/stark/miden-lifted-stark/src/verifier/mod.rs @@ -1,11 +1,10 @@ //! Lifted STARK verifier. //! //! This module provides: -//! - [`verify_single`]: Verify a single AIR instance. -//! - [`verify_multi`]: Verify multiple AIR instances with traces of different heights. +//! - [`VerifierInstance::verify`](crate::VerifierInstance::verify): Verify a [`Statement`]. //! -//! These functions take a challenger (consumed by value) and proof data, construct -//! the verifier transcript internally, and return a [`StarkDigest`] on success. +//! Takes a challenger (consumed by value) and proof data, constructs the +//! verifier transcript internally, and returns a [`StarkDigest`] on success. //! The caller must check that the digest matches the prover's digest. //! //! # Fiat-Shamir / transcript binding @@ -14,41 +13,45 @@ //! prover module-level docs for the full binding contract and recommended //! pattern. //! -//! Log trace heights are carried on the [`StarkProof`] and observed into the -//! challenger by [`verify_multi`]. Callers must not pre-observe them. +//! The proof's per-AIR log trace heights are carried on [`StarkProofData`] (in +//! instance order); [`VerifierInstance::verify`](crate::VerifierInstance::verify) +//! observes the derived instance count and those heights at the protocol layer. +//! Callers must not pre-observe them. //! //! 
# Statement-bound trace heights //! //! The verifier accepts whatever trace heights the proof carries; it never //! compares them against a caller-supplied expectation. If your statement //! fixes the trace size (e.g. a proof for a 2^16-row execution), parse it -//! with -//! [`StarkTranscript::from_proof`](crate::proof::StarkTranscript::from_proof) -//! and check `transcript.instance_shapes.log_trace_heights()` yourself. +//! with [`StarkProof::from_data`](crate::proof::StarkProof::from_data) using the same +//! [`VerifierInstance`], and check `proof.log_trace_heights()` yourself. //! //! # Transcript boundaries (strict consumption) //! -//! [`verify_multi`] finalizes the transcript internally: it rejects proofs with -//! trailing data (via [`TranscriptError::TrailingData`]) and returns a binding -//! digest that must match the prover's digest. +//! [`VerifierInstance::verify`](crate::VerifierInstance::verify) finalizes the +//! transcript internally: it rejects proofs with trailing data (via +//! [`TranscriptError::TrailingData`]) and returns a binding digest that must +//! match the prover's digest. //! //! If you want to bundle extra data alongside the proof, you must manage //! boundaries yourself (e.g. parse and validate that data first, then pass the -//! remaining transcript to [`verify_multi`]). +//! remaining transcript to +//! [`VerifierInstance::verify`](crate::VerifierInstance::verify)). 
extern crate alloc; -pub mod constraints; -pub mod periodic; +pub(crate) mod constraints; +pub(crate) mod periodic; use alloc::{vec, vec::Vec}; use core::marker::PhantomData; -use constraints::{ConstraintFolder, reconstruct_quotient, row_to_packed_ext}; +use constraints::ConstraintFolder; use miden_lifted_air::{ - LiftedAir, ReducedAuxValues, ReductionError, RowWindow, VarLenPublicInputs, + BaseAir, InstanceError, LiftedAir, MultiAir, ReductionError, RowWindow, Statement, }; use miden_stark_transcript::{Channel, TranscriptError, VerifierChannel, VerifierTranscript}; +use p3_challenger::CanObserve; use p3_field::{ExtensionField, TwoAdicField}; use p3_matrix::Matrix; use periodic::PeriodicPolys; @@ -56,73 +59,141 @@ use thiserror::Error; use crate::{ StarkConfig, - coset::LiftedCoset, - instance::{AirInstance, InstanceValidationError, validate_air_order, validate_inputs}, - pcs::verifier::{PcsError, verify_aligned}, - proof::{StarkDigest, StarkProof}, + domain::{Coset, DomainError, LiftedDomain, log_quotient_degree}, + lmcs::Lmcs, + order::{ShapeError, TraceOrder}, + pcs::verifier::{CommitmentGroup, PcsError, verify_aligned}, + preprocessed::PreprocessedValidationError, + proof::{StarkDigest, StarkProofData}, + util::packing::row_to_packed_ext, }; -/// Errors that can occur during verification. 
-#[derive(Debug, Error)] -pub enum VerifierError { - #[error("instance validation failed: {0}")] - Instance(#[from] InstanceValidationError), - #[error("PCS verification failed: {0}")] - Pcs(#[from] PcsError), - #[error("transcript error: {0}")] - Transcript(#[from] TranscriptError), - #[error("invalid aux shape")] - InvalidAuxShape, - #[error("constraint mismatch: quotient * vanishing != folded constraints")] - ConstraintMismatch, - #[error( - "constraint degree exceeds blowup: \ - log_quotient_degree {log_quotient_degree} > log_blowup {log_blowup}" - )] - ConstraintDegreeTooHigh { log_quotient_degree: u8, log_blowup: u8 }, - #[error("global reduced aux identity check failed")] - InvalidReducedAux, - #[error("aux value reduction failed: {0}")] - Reduction(ReductionError), -} +// ============================================================================ +// VerifierInstance +// ============================================================================ -/// Verify a single AIR. Convenience wrapper around [`verify_multi`]. +/// Verifier-side bundle: a [`StarkConfig`], a borrowed [`Statement`], and the +/// optional preprocessed commitment (a trusted setup input, not read from the +/// proof). /// -/// The caller's challenger must already be bound to the full statement -/// — see the prover module-level docs. -pub fn verify_single( - config: &SC, - air: &A, - public_values: &[F], - var_len_public_inputs: VarLenPublicInputs<'_, F>, - proof: &StarkProof, - challenger: SC::Challenger, -) -> Result, VerifierError> +/// Construction validates preprocessed presence parity, so holding a +/// `VerifierInstance` guarantees the commitment is present exactly when the +/// AIRs declare preprocessed columns. 
+pub struct VerifierInstance<'a, F, EF, MA, SC> +where + F: TwoAdicField, + EF: ExtensionField, + MA: MultiAir, + SC: StarkConfig, +{ + config: &'a SC, + statement: &'a Statement, + preprocessed_commitment: Option<::Commitment>, +} + +impl<'a, F, EF, MA, SC> VerifierInstance<'a, F, EF, MA, SC> where F: TwoAdicField, EF: ExtensionField, + MA: MultiAir, SC: StarkConfig, - A: LiftedAir, { - let instance = AirInstance { public_values, var_len_public_inputs }; - verify_multi(config, &[(air, instance)], proof, challenger) + /// Bundle a config + statement with an optional preprocessed commitment. + /// + /// The commitment must be `Some` exactly when some AIR declares preprocessed + /// columns; otherwise this errors with + /// [`PreprocessedValidationError::PresenceMismatch`]. + pub fn new( + config: &'a SC, + statement: &'a Statement, + preprocessed_commitment: Option<::Commitment>, + ) -> Result { + let expected = statement.airs().iter().any(|a| a.preprocessed_width() > 0); + let actual = preprocessed_commitment.is_some(); + if expected != actual { + return Err(PreprocessedValidationError::PresenceMismatch { expected, actual }); + } + Ok(Self { + config, + statement, + preprocessed_commitment, + }) + } + + /// Verify a proof against this instance. + pub fn verify( + &self, + proof: &StarkProofData, + challenger: SC::Challenger, + ) -> Result, VerifierError> { + verify(self, proof, challenger) + } + + /// Borrow the STARK configuration. + pub fn config(&self) -> &SC { + self.config + } + + /// Borrow the wrapped air-crate statement. + pub fn statement(&self) -> &Statement { + self.statement + } + + /// Borrow the preprocessed commitment, if any. + pub fn preprocessed_commitment(&self) -> Option<&::Commitment> { + self.preprocessed_commitment.as_ref() + } +} + +/// Errors that can occur during verification. +/// +/// Returned exclusively for runtime instance / proof-shape failures or +/// cryptographic verification failures. 
AIR structural correctness is +/// trusted — call [`crate::debug::assert_prover_setup`] (or +/// [`miden_lifted_air::debug::assert_multi_air_valid`]) from tests. +#[derive(Debug, Error)] +pub enum VerifierError { + #[error(transparent)] + Instance(#[from] InstanceError), + #[error(transparent)] + Shape(#[from] ShapeError), + #[error(transparent)] + Domain(#[from] DomainError), + #[error(transparent)] + Pcs(#[from] PcsError), + #[error(transparent)] + Transcript(#[from] TranscriptError), + #[error("external assertion evaluation failed: {0}")] + Reduction(ReductionError), + #[error("constraint mismatch: quotient * vanishing != folded constraints")] + ConstraintMismatch, + #[error("external assertion {assertion} is non-zero")] + ExternalAssertionFailed { + /// Index into the assertions vector returned by + /// [`Statement::eval_external`]. + assertion: usize, + }, } -/// Verify multiple AIRs with traces of different heights. +/// Verify a [`Statement`]. /// -/// The verifier uses [`InstanceShapes::air_order`](crate::InstanceShapes::air_order) from the proof -/// to match the caller's instances to the proof's ordering. The caller's challenger -/// must already be bound to the full statement (protocol parameters, AIR -/// configurations, AIR ordering, and public inputs — both fixed and -/// variable-length) — see the prover module-level docs. +/// The verifier reads per-AIR log trace heights from the proof (in caller +/// order, matching [`Statement::airs`]) and reconstructs the proof's AIR +/// ordering deterministically from those heights. The caller's challenger +/// must already be bound to protocol parameters and AIR configurations — +/// see the prover module-level docs. The statement's inputs are absorbed via +/// [`Statement::observe`], then the instance count and proof's log trace heights +/// are observed in instance order. /// /// The verifier mirrors the prover's protocol: /// -/// 1. 
Validate instance shapes and observe log trace heights into the challenger +/// 1. Validate runtime statement/proof shape data, absorb statement-owned inputs, and observe the +/// instance count plus log trace heights in instance order /// 2. Receive commitments and sample challenges in the same order as the prover -/// 3. For each AIR, evaluate constraints at the lifted OOD point yⱼ = z^{rⱼ} +/// 3. For each AIR (in proof order), evaluate constraints at the lifted OOD point yⱼ = z^{rⱼ} /// 4. Accumulate folded constraints with β: acc = acc·β + foldedⱼ /// 5. Check quotient identity: `acc == Q(z) * Z_{H_max}(z)` +/// 6. Evaluate [`Statement::eval_external`] with aux values reordered back to instance order /// /// Lifting: for a trace of height nⱼ lifted by factor rⱼ, the committed /// codeword encodes `p_lift(X) = p(X^{rⱼ})`; opening at `[z, z · h_max]` @@ -131,68 +202,117 @@ where /// **Statement-bound heights:** this function does not compare the proof's /// declared heights against any caller expectation. If your statement fixes /// trace dimensions, parse via -/// [`StarkTranscript::from_proof`](crate::proof::StarkTranscript::from_proof) -/// and check `instance_shapes.log_trace_heights()` before calling this. See -/// the module-level docs for the full contract. -pub fn verify_multi( - config: &SC, - instances: &[(&A, AirInstance<'_, F>)], - proof: &StarkProof, +/// [`StarkProof::from_data`](crate::proof::StarkProof::from_data) using this instance and check +/// `proof.log_trace_heights()` before calling this. See the module-level docs for the full +/// contract. +/// +/// # Trust contract +/// +/// `verify` validates the runtime statement plus everything carried on the +/// proof; the AIR list is **trusted** (run +/// [`miden_lifted_air::debug::assert_multi_air_valid`] from tests). 
+/// +/// ## Validated +/// - Same statement checks as [`prove`](crate::ProverInstance::prove) minus trace shape (no traces +/// here) +/// - Same `log_quotient_degree <= log_blowup` compat check +/// - Proof shape via the internal trace-order reconstruction from log heights +/// - Proof byte parsing (transcript channel) +/// - PCS / FRI / DEEP / LMCS / transcript / constraint identity +/// - External assertions from [`Statement::eval_external`] +/// - Preprocessed openings against the trusted commitment (via the PCS) +/// +/// ## Trusted (NOT validated) +/// - AIR structural shape (same list as in [`prove`](crate::ProverInstance::prove)) +pub(crate) fn verify( + instance: &VerifierInstance<'_, F, EF, MA, SC>, + proof: &StarkProofData, mut challenger: SC::Challenger, ) -> Result, VerifierError> where F: TwoAdicField, EF: ExtensionField, SC: StarkConfig, - A: LiftedAir, + MA: MultiAir, { - let instance_shapes = &proof.instance_shapes; - let air_order = instance_shapes.air_order(); + // --- Trust boundary (see doc-block above). ------------------------------- + let config = instance.config(); + let statement: &Statement = instance.statement(); + let preprocessed_commitment = instance.preprocessed_commitment(); + // + // `TraceOrder::from_log_heights` validates the (untrusted) proof heights + // against the AIRs: it bounds `log_h` within the host's `usize` width + // (later code dereferences `1usize << log_h` and would otherwise overflow on + // a malicious proof) and checks per-AIR periodic-height feasibility. + // Statement::new already enforced the input-side contracts before construction. + let trace_order = TraceOrder::from_log_heights::( + statement.airs(), + proof.log_trace_heights.clone(), + )?; + + // Preprocessed trace↔AIR mappings (instance order), reconstructed from the + // heights — the same ones the prover used. 
+ let preprocessed_trace_to_air = + trace_order.preprocessed_air_for_trace_index::(statement.airs()); + let air_to_preprocessed_trace = + trace_order.preprocessed_trace_index_for_air::(statement.airs()); - // Validate air_order and reorder caller instances to the proof's AIR ordering. - validate_air_order(air_order, instances.len())?; - let instances = instance_shapes.reorder(instances.to_vec())?; + let air_refs: Vec<&MA::Air> = statement.airs().iter().collect(); + let proof_ordered_airs = trace_order.to_proof_order(&air_refs); + let air_inputs = statement.air_inputs(); let log_blowup = config.pcs().log_blowup(); + let log_max_trace_height = trace_order.max_log_height(); + let max_lde_domain = LiftedDomain::::try_canonical(log_max_trace_height, log_blowup)?; + let instance_domains: Vec<_> = trace_order + .log_heights_proof() + .iter() + .map(|&log_h| max_lde_domain.try_sub_domain(log_h)) + .collect::>()?; - let log_max_trace_height = validate_inputs(&instances, instance_shapes, log_blowup)?; - let log_trace_heights = instance_shapes.log_trace_heights(); + // Observe the preprocessed commitment first (when present); mirrors the + // prover. It is a trusted statement input, not read from the proof. + if let Some(commitment) = preprocessed_commitment { + challenger.observe(commitment.clone()); + } - instance_shapes.observe_heights::(&mut challenger); + // `Statement::observe` absorbs statement-owned inputs. The protocol then + // binds the instance count and each log trace height in instance order. + statement.observe(&mut challenger, trace_order.log_heights()); + trace_order.observe_shape::(&mut challenger); let mut channel = VerifierTranscript::from_data(challenger, &proof.transcript); - // Clear the challenger's absorb buffer after observing instance shapes by - // squeezing a throwaway extension element. Must mirror the prover exactly. 
- let _instance_challenge: EF = channel.sample_algebra_element::(); - // Infer constraint degree from symbolic AIR analysis (max across all AIRs). - // NOTE: `log_quotient_degree()` runs symbolic eval and may panic if the AIR is - // invalid. Callers must ensure `validate_inputs` (above) passes first. - let log_constraint_degree = - instances.iter().map(|(air, _)| air.log_quotient_degree()).max().unwrap_or(1) as u8; - - if log_constraint_degree > log_blowup { - return Err(VerifierError::ConstraintDegreeTooHigh { - log_quotient_degree: log_constraint_degree, + // NOTE: `log_quotient_degree` runs symbolic eval and may panic if the AIR is + // invalid. The AIR is trusted (see `miden_lifted_air::debug::assert_multi_air_valid` + // for the debug-only structural check). + let max_log_quotient_degree = proof_ordered_airs + .iter() + .map(|&air| log_quotient_degree::(air)) + .max() + .expect("TraceOrder construction rejects empty AIR sets"); + if max_log_quotient_degree > log_blowup { + return Err(DomainError::ConstraintDegreeTooHigh { + log_quotient: max_log_quotient_degree, log_blowup, - }); + } + .into()); } - let constraint_degree = 1 << log_constraint_degree as usize; + // Pair the max LDE domain with the constraint degree for the constraint layer. + let max_eval_domain = max_lde_domain.evaluation_domain(max_log_quotient_degree); - let max_trace_height = 1 << log_max_trace_height as usize; - let log_lde_height = log_max_trace_height + log_blowup; + let quotient_degree = 1 << max_log_quotient_degree as usize; - // Max LDE coset (for the largest trace, no lifting) - let max_lde_coset = LiftedCoset::unlifted(log_max_trace_height, log_blowup); + let max_trace_height = max_lde_domain.trace_height(); // 1. Receive main trace commitment let main_commit = channel.receive_commitment()?.clone(); // 2. 
Sample randomness for aux traces let max_num_randomness = - instances.iter().map(|(air, _)| air.num_randomness()).max().unwrap_or(0); + proof_ordered_airs.iter().map(|air| air.num_randomness()).max().unwrap_or(0); let randomness: Vec = (0..max_num_randomness) .map(|_| channel.sample_algebra_element::()) @@ -203,9 +323,9 @@ where // Receive aux values from the transcript (one EF element per aux value, per instance). // When no AIR has aux columns, each entry is empty so nothing is received. - let all_aux_values: Vec> = instances + let all_aux_values: Vec> = proof_ordered_airs .iter() - .map(|(air, _)| { + .map(|air| { let count = air.num_aux_values(); (0..count) .map(|_| channel.receive_algebra_element::()) @@ -221,63 +341,104 @@ where let quotient_commit = channel.receive_commitment()?.clone(); // 6. Sample OOD point (outside max trace domain H and max LDE coset gK) - let z: EF = max_lde_coset.sample_ood_point(&mut channel); - let h = F::two_adic_generator(log_max_trace_height.into()); + let z: EF = max_lde_domain.sample_ood_point(&mut channel); + let h = max_lde_domain.trace_subgroup().generator(); let z_next = z * h; // 7. Widths per commitment group (unpadded data widths). - let main_widths: Vec = instances.iter().map(|(air, _)| air.width()).collect(); + let main_widths: Vec = proof_ordered_airs.iter().map(|air| air.width()).collect(); let aux_widths: Vec = - instances.iter().map(|(air, _)| air.aux_width() * EF::DIMENSION).collect(); - let quotient_widths: Vec = vec![constraint_degree * EF::DIMENSION]; + proof_ordered_airs.iter().map(|air| air.aux_width() * EF::DIMENSION).collect(); + let quotient_widths: Vec = vec![quotient_degree * EF::DIMENSION]; // Build commitments with original (unpadded) widths. // The PCS aligned wrapper handles alignment and truncation internally. 
- let commitments = vec![ - (main_commit, main_widths), - (aux_commit, aux_widths), - (quotient_commit, quotient_widths), - ]; + // The preprocessed group (when present) is first, mirroring the prover; its + // widths are in committed preprocessed trace order, while main/aux are in proof order. + // + // Group indices, in batch order `[preprocessed?, main, aux, quotient]`: the + // preprocessed group occupies index 0 only when present, shifting the rest + // up by one. The prover builds its `trees` vector in this same order. + let s = preprocessed_commitment.is_some() as usize; + let (preproc_g, main_g, aux_g, quot_g) = + (preprocessed_commitment.is_some().then_some(0), s, s + 1, s + 2); + let full_log_height = log_max_trace_height + log_blowup; + let mut commitments = Vec::with_capacity(4); + if let Some(commitment) = preprocessed_commitment { + let preprocessed_widths: Vec = preprocessed_trace_to_air + .iter() + .map(|&air_idx| statement.airs()[air_idx as usize].preprocessed_width()) + .collect(); + // The preprocessed tree is committed at its own setup-fixed depth, determined by the + // tallest preprocessed trace. The PCS virtually lifts it to the max when shorter. + let preprocessed_log_height = preprocessed_trace_to_air + .iter() + .map(|&air_idx| trace_order.log_heights()[air_idx as usize]) + .max() + .expect("preprocessed group is non-empty when a commitment is present") + + log_blowup; + commitments.push(CommitmentGroup { + root: commitment.clone(), + widths: preprocessed_widths, + log_height: preprocessed_log_height, + }); + } + commitments.push(CommitmentGroup { + root: main_commit, + widths: main_widths, + log_height: full_log_height, + }); + commitments.push(CommitmentGroup { + root: aux_commit, + widths: aux_widths, + log_height: full_log_height, + }); + commitments.push(CommitmentGroup { + root: quotient_commit, + widths: quotient_widths, + log_height: full_log_height, + }); // 8. 
Verify PCS openings (returns per-matrix RowMajorMatrix, truncated to original widths) let opened = verify_aligned::( config.pcs(), config.lmcs(), &commitments, - log_lde_height, + &max_lde_domain, [z, z_next], &mut channel, )?; - // 9. Group indices for accessing opened matrices: [main, aux, quotient]. - let (main_g, aux_g, quot_g) = (0, 1, 2); - - // 10. Per-AIR constraint evaluation and beta accumulation. + // 9. Per-AIR constraint evaluation and beta accumulation. // // opened[g] has one matrix per AIR (for main/aux) or one matrix total (quotient). // Each matrix has N=2 rows: row 0 = local (z), row 1 = next (z·h). // - // Instances are in the proof's AIR ordering (ascending height), so j - // indexes both AIR and trace position directly. - debug_assert_eq!(opened[main_g].len(), instances.len()); - debug_assert_eq!(opened[aux_g].len(), instances.len()); + // AIRs are in the proof's ordering (ascending height), so j indexes both + // AIR and trace position directly. + debug_assert_eq!(opened[main_g].len(), proof_ordered_airs.len()); + debug_assert_eq!(opened[aux_g].len(), proof_ordered_airs.len()); let mut accumulated = EF::ZERO; - let mut reduced_aux = ReducedAuxValues::::identity(); - for (j, (air, inst)) in instances.iter().enumerate() { - let coset_j = LiftedCoset::new(log_trace_heights[j], log_blowup, log_max_trace_height); + for (j, air) in proof_ordered_airs.iter().enumerate() { + let domain_j = instance_domains[j]; // opened[main_g][j] is a 2-row RowMajorMatrix (local, next) already truncated. let main_window = RowWindow::from_view(&opened[main_g][j].as_view()); // Extract aux trace opened values (reconstitute EF from base field components). + // Row widths were validated against `air.aux_width() * EF::DIMENSION` + // by `verify_aligned` upstream; reaching here with a mismatch would + // indicate a framework bug. 
let aux_mat = &opened[aux_g][j]; - let aux_local = row_to_packed_ext::(&aux_mat.row_slice(0).expect("aux row 0"))?; - let aux_next = row_to_packed_ext::(&aux_mat.row_slice(1).expect("aux row 1"))?; + let aux_local = row_to_packed_ext::(&aux_mat.row_slice(0).expect("aux row 0")) + .expect("aux row width should match: PCS verify_aligned validates this upstream"); + let aux_next = row_to_packed_ext::(&aux_mat.row_slice(1).expect("aux row 1")) + .expect("aux row width should match: PCS verify_aligned validates this upstream"); let aux_window = RowWindow::from_two_rows(&aux_local, &aux_next); - // Selectors at the lifted OOD point yⱼ = z^{rⱼ} (encapsulated in LiftedCoset). - let selectors = coset_j.selectors_at::(z); + // Selectors at the lifted OOD point yⱼ = z^{rⱼ} (encapsulated in LiftedDomain). + let selectors = domain_j.selectors_at(z); // Periodic values: for a column with period p, eval_at computes z^{n/p}. // Using (max_trace_height, z) gives z^{max_n / p}, which equals @@ -288,11 +449,27 @@ where let aux_values_j = &all_aux_values[j]; let num_rand = air.num_randomness(); + + // Extract the opened preprocessed window when this AIR declares + // preprocessed columns. The preprocessed trace index comes from the inverse + // `air_to_preprocessed_trace` mapping; the opened matrix is a 2-row + // `RowMajorMatrix` already truncated to the declared width by `verify_aligned`, so this + // is a zero-copy view (mirrors the main window above). 
+ let instance_idx = trace_order.instance_indices()[j] as usize; + let preprocessed_window = match air_to_preprocessed_trace[instance_idx] { + Some(preprocessed_trace_idx) => RowWindow::from_view( + &opened[preproc_g.expect("preproc group present")][preprocessed_trace_idx] + .as_view(), + ), + None => RowWindow::from_two_rows(&[], &[]), + }; + let mut folder = ConstraintFolder { main: main_window, + preprocessed: preprocessed_window, aux: aux_window, randomness: &randomness[..num_rand], - public_values: inst.public_values, + public_values: air_inputs, periodic_values: &periodic_values, permutation_values: aux_values_j, selectors, @@ -301,40 +478,41 @@ where _phantom: PhantomData, }; - air.is_valid_builder(&folder).map_err(InstanceValidationError::from)?; + #[cfg(debug_assertions)] + miden_lifted_air::debug::check_builder_shape(*air, &folder); air.eval(&mut folder); // Accumulate: acc = acc * beta + folded_j accumulated = accumulated * beta + folder.accumulator; + } - // Compute reduced aux contribution and accumulate. - let contribution = air - .reduced_aux_values( - aux_values_j, - &randomness[..num_rand], - inst.public_values, - inst.var_len_public_inputs, - ) - .map_err(VerifierError::Reduction)?; - reduced_aux.combine_in_place(&contribution); + // 11. Evaluate the proof's external assertions. Aux values came off the + // wire in proof order; reorder them back to instance order before handing + // them to `eval_external`, which is defined in instance-order terms. + let aux_instance = trace_order.to_instance_order(&all_aux_values); + let aux_views: Vec<&[EF]> = aux_instance.iter().map(Vec::as_slice).collect(); + let assertions = statement + .eval_external(&randomness, &aux_views, trace_order.log_heights()) + .map_err(VerifierError::Reduction)?; + for (k, assertion) in assertions.iter().enumerate() { + if *assertion != EF::ZERO { + return Err(VerifierError::ExternalAssertionFailed { assertion: k }); + } } - // 11. 
Reconstruct Q(z) and check quotient identity Q(z) * Z_{H_max}(z) + // 12. Reconstruct Q(z) and check quotient identity Q(z) * Z_{H_max}(z) // Quotient group has a single matrix; row 0 is the evaluation at z. let quot_row = opened[quot_g][0].row_slice(0).expect("quotient row 0"); - let quotient_chunks = row_to_packed_ext::(&quot_row)?; - let quotient_z = reconstruct_quotient::(z, &max_lde_coset, &quotient_chunks); + let quotient_chunks = row_to_packed_ext::(&quot_row) + .expect("quotient row width should match: PCS verify_aligned validates this upstream"); + let quotient_z = max_eval_domain.reconstruct_quotient::(z, &quotient_chunks); - let vanishing = max_lde_coset.vanishing_at::(z); + // `max_lde_domain` is the tallest (lift_ratio = 0), so lifted == unlifted here. + let vanishing = max_lde_domain.trace_subgroup().vanishing_at(z); if accumulated != quotient_z * vanishing { return Err(VerifierError::ConstraintMismatch); } - // 12. Check global reduced aux identity (all bus contributions combine to identity) - if !reduced_aux.is_identity() { - return Err(VerifierError::InvalidReducedAux); - } - // 13. Finalize transcript: check emptiness and return digest Ok(channel.finalize()?) } diff --git a/stark/miden-lifted-stark/src/verifier/periodic.rs b/stark/miden-lifted-stark/src/verifier/periodic.rs index d5b773b33f..36e46430a2 100644 --- a/stark/miden-lifted-stark/src/verifier/periodic.rs +++ b/stark/miden-lifted-stark/src/verifier/periodic.rs @@ -10,12 +10,14 @@ use alloc::vec::Vec; use p3_dft::{NaiveDft, TwoAdicSubgroupDft}; use p3_field::{ExtensionField, TwoAdicField}; +use crate::util::horner::horner_acc; + /// Verifier-side periodic polynomials for OOD evaluation. /// /// Stores polynomial coefficients computed from the AIR's periodic columns. /// Used to evaluate periodic values at the OOD point during verification. #[derive(Clone, Debug)] -pub struct PeriodicPolys { +pub(super) struct PeriodicPolys { /// Polynomial coefficients for each column.
polys: Vec>, } @@ -27,8 +29,9 @@ impl PeriodicPolys { /// /// # Panics /// Panics if any column length is zero or not a power of two. - /// This is a trusted path — the AIR should pass - /// [`LiftedAir::validate`](miden_lifted_air::LiftedAir::validate). + /// This is a trusted path — the AIR is assumed structurally valid (see + /// [`assert_multi_air_valid`](miden_lifted_air::debug::assert_multi_air_valid) for the + /// debug-only check). pub fn new(column_evals: &[Vec]) -> Self { let dft = NaiveDft; let mut polys = Vec::with_capacity(column_evals.len()); @@ -71,22 +74,11 @@ impl PeriodicPolys { for coeffs in &self.polys { let period = coeffs.len(); let y = z.exp_u64((trace_height / period) as u64); - result.push(horner_eval(coeffs, y)); + // Coefficients are stored in ascending degree (idft output): [c₀, c₁, ..., cₙ₋₁]. + // Horner needs descending order (highest degree first), hence `.rev()`. + result.push(horner_acc(EF::ZERO, y, coeffs.iter().rev().copied())); } result } } - -/// Evaluate a polynomial at a point using Horner's method. -fn horner_eval(coeffs: &[F], x: EF) -> EF -where - F: TwoAdicField, - EF: ExtensionField, -{ - let mut acc = EF::ZERO; - for coeff in coeffs.iter().rev() { - acc = acc * x + *coeff; - } - acc -} diff --git a/stark/miden-stark-transcript/src/channel.rs b/stark/miden-stark-transcript/src/channel.rs index cf6146e366..20d6bfbbcd 100644 --- a/stark/miden-stark-transcript/src/channel.rs +++ b/stark/miden-stark-transcript/src/channel.rs @@ -42,12 +42,15 @@ pub trait Channel { type Challenger: TranscriptChallenger; /// Sample a random field element from the challenger. + #[must_use = "sampled transcript challenges must be consumed or explicitly bound"] fn sample(&mut self) -> Self::F; /// Sample a random `bits`-bit integer from the challenger. + #[must_use = "sampled transcript challenges must be consumed or explicitly bound"] fn sample_bits(&mut self, bits: usize) -> usize; /// Sample a random algebra element (e.g. 
extension field) from the challenger. + #[must_use = "sampled transcript challenges must be consumed or explicitly bound"] fn sample_algebra_element>(&mut self) -> A { A::from_basis_coefficients_fn(|_| self.sample()) }