Skip to content
3 changes: 2 additions & 1 deletion cmd/juno/juno.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"syscall"
"time"

"github.com/NethermindEth/juno/cmd/juno/verify"
_ "github.com/NethermindEth/juno/encoder/registry"
_ "github.com/NethermindEth/juno/jemalloc"
"github.com/NethermindEth/juno/node"
Expand Down Expand Up @@ -453,7 +454,7 @@ func NewCmd(config *node.Config, run func(*cobra.Command, []string) error) *cobr
junoCmd.Flags().Bool(
transactionCombinedLayoutF, defaultTransactionCombinedLayout, transactionCombinedLayoutUsage,
)
junoCmd.AddCommand(GenP2PKeyPair(), DBCmd(defaultDBPath))
junoCmd.AddCommand(GenP2PKeyPair(), DBCmd(defaultDBPath), verify.VerifyCmd(defaultDBPath))

return junoCmd
}
107 changes: 107 additions & 0 deletions cmd/juno/verify/trie.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package verify

import (
"fmt"
"slices"

"github.com/NethermindEth/juno/core/felt"
"github.com/NethermindEth/juno/utils"
verifytrie "github.com/NethermindEth/juno/verify/trie"
"github.com/spf13/cobra"
)

const (
verifyTrieType = "type"
verifyContractAddr = "address"
)

func verifyTrieCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "trie",
Short: "Verify trie integrity",
Long: `Verify trie integrity by rebuilding tries and comparing root hashes.`,
RunE: runTrieVerify,
SilenceUsage: true,
SilenceErrors: true,
}

cmd.Flags().StringSlice(
verifyTrieType,
nil,
"Trie types to verify (contract, class, contract-storage)."+
"If not specified, all trie types are verified.",
)

cmd.Flags().String(
verifyContractAddr,
"",
"Contract address to verify (only used with --type contract-storage). "+
"If not specified, all contract storage tries are verified.",
)

return cmd
}

func runTrieVerify(cmd *cobra.Command, args []string) error {
dbPath, err := cmd.Flags().GetString(verifyDBPathF)
if err != nil {
return err
}

database, err := openDB(dbPath)
if err != nil {
return err
}
defer database.Close()

trieTypes, err := cmd.Flags().GetStringSlice(verifyTrieType)
if err != nil {
return err
}

contractAddrStr, err := cmd.Flags().GetString(verifyContractAddr)
if err != nil {
return err
}

var tries []verifytrie.TrieType
if len(trieTypes) > 0 {
tries = make([]verifytrie.TrieType, 0, len(trieTypes))
for _, t := range trieTypes {
tt := verifytrie.TrieType(t)
if !tt.IsValid() {
return fmt.Errorf("invalid trie type %q (allowed: contract, class, contract-storage)", t)
}
tries = append(tries, tt)
}
}

var contractAddr *felt.Felt
if contractAddrStr != "" {
hasContractStorage := slices.Contains(tries, verifytrie.ContractStorageTrie)
if len(tries) == 0 {
hasContractStorage = true
}

if !hasContractStorage {
return fmt.Errorf("--address flag can only be used with --type contract-storage")
}

var addr felt.Felt
_, err = addr.SetString(contractAddrStr)
if err != nil {
return fmt.Errorf("invalid contract address %s: %w", contractAddrStr, err)
}
contractAddr = &addr
}

logLevel := utils.NewLogLevel(utils.INFO)
logger, err := utils.NewZapLogger(logLevel, true)
if err != nil {
return fmt.Errorf("failed to create logger: %w", err)
}

verifier := verifytrie.NewTrieVerifier(database, logger, tries, contractAddr)
ctx := cmd.Context()
return verifier.Run(ctx)
}
98 changes: 98 additions & 0 deletions cmd/juno/verify/trie_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package verify

import (
"context"
"os"
"path/filepath"
"testing"

"github.com/NethermindEth/juno/db/pebblev2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestRunTrieVerify_AddressFlagValidation(t *testing.T) {
tests := []struct {
name string
trieTypes []string
address string
expectError bool
expectedErrMsg string
}{
{
name: "address with contract-storage type should succeed",
trieTypes: []string{"contract-storage"},
address: "0x123",
expectError: false,
},
{
name: "address with contract and class types should fail",
trieTypes: []string{"contract", "class"},
address: "0x123",
expectError: true,
expectedErrMsg: "--address flag can only be used with --type contract-storage",
},
{
name: "address with no type specified should succeed (default includes contract-storage)",
trieTypes: []string{},
address: "0x123",
expectError: false,
},
{
name: "invalid type should fail",
trieTypes: []string{"invalid-type"},
address: "",
expectError: true,
expectedErrMsg: "invalid trie type",
},
{
name: "invalid address format should fail",
trieTypes: []string{"contract-storage"},
address: "not-a-hex",
expectError: true,
expectedErrMsg: "invalid contract address",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tempDir := t.TempDir()
dbPath := filepath.Join(tempDir, "test.db")

testDB, err := pebblev2.New(dbPath)
require.NoError(t, err)
testDB.Close()

parentCmd := VerifyCmd("")
args := []string{"--db-path", dbPath, "trie"}

for _, trieType := range tt.trieTypes {
args = append(args, "--type", trieType)
}

if tt.address != "" {
args = append(args, "--address", tt.address)
}

parentCmd.SetArgs(args)
parentCmd.SetOut(os.Stderr)
parentCmd.SetErr(os.Stderr)

err = parentCmd.ExecuteContext(context.Background())

if tt.expectError {
require.Error(t, err)
if tt.expectedErrMsg != "" {
assert.Contains(t, err.Error(), tt.expectedErrMsg)
}
} else if err != nil {
// For "success" cases, we're testing flag validation, not full execution.
// The command may fail downstream (empty DB, no data) - that's expected.
// We only verify that the specific flag validation error we're testing didn't occur.
addrFlagErr := "--address flag can only be used with --type contract-storage"
assert.NotContains(t, err.Error(), addrFlagErr,
"flag validation should pass; downstream errors are acceptable")
}
})
}
}
82 changes: 82 additions & 0 deletions cmd/juno/verify/verify.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package verify

import (
"context"
"errors"
"fmt"
"os"

"github.com/NethermindEth/juno/db"
"github.com/NethermindEth/juno/db/pebblev2"
"github.com/NethermindEth/juno/utils"
"github.com/NethermindEth/juno/verify/trie"
"github.com/spf13/cobra"
)

type Verifier interface {
Name() string
Run(ctx context.Context) error
}

const verifyDBPathF = "db-path"

func VerifyCmd(defaultDBPath string) *cobra.Command {
verifyCmd := &cobra.Command{
Use: "verify",
Short: "Verify database integrity",
Long: `Verify database integrity using various verification methods.`,
}

verifyCmd.PersistentFlags().String(verifyDBPathF, defaultDBPath, "Path to the database")
verifyCmd.AddCommand(verifyTrieCmd())
verifyCmd.RunE = verifyAll

return verifyCmd
}

func verifyAll(cmd *cobra.Command, args []string) error {
dbPath, err := cmd.Flags().GetString(verifyDBPathF)
if err != nil {
return err
}

database, err := openDB(dbPath)
if err != nil {
return err
}
defer database.Close()

logLevel := utils.NewLogLevel(utils.INFO)
logger, err := utils.NewZapLogger(logLevel, true)
if err != nil {
return fmt.Errorf("failed to create logger: %w", err)
}

ctx := cmd.Context()

verifiers := []Verifier{
trie.NewTrieVerifier(database, logger, nil, nil),
}

for _, v := range verifiers {
if err := v.Run(ctx); err != nil {
return fmt.Errorf("%s verification stopped: %w", v.Name(), err)
}
}

return nil
}

func openDB(path string) (db.KeyValueStore, error) {
_, err := os.Stat(path)
if os.IsNotExist(err) {
return nil, errors.New("database path does not exist")
}

database, err := pebblev2.New(path)
if err != nil {
return nil, fmt.Errorf("failed to open db: %w", err)
}

return database, nil
}
56 changes: 56 additions & 0 deletions verify/trie/traversal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package trie

import (
"context"
"sync"
)

func TraverseBinary[T any](
ctx context.Context,
depth uint8,
maxConcurrentDepth uint8,
leftFn func(ctx context.Context) (T, error),
rightFn func(ctx context.Context) (T, error),
) (left, right T, err error) {
if depth <= maxConcurrentDepth {
return traverseConcurrently(ctx, leftFn, rightFn)
}
return traverseSequentially(ctx, leftFn, rightFn)
}

func traverseConcurrently[T any](
ctx context.Context,
leftFn func(ctx context.Context) (T, error),
rightFn func(ctx context.Context) (T, error),
) (left, right T, err error) {
var leftErr, rightErr error
var wg sync.WaitGroup

wg.Go(func() {
left, leftErr = leftFn(ctx)
})

right, rightErr = rightFn(ctx)
wg.Wait()

if leftErr != nil {
return left, right, leftErr
}
if rightErr != nil {
return left, right, rightErr
}
return left, right, nil
}

func traverseSequentially[T any](
ctx context.Context,
leftFn func(ctx context.Context) (T, error),
rightFn func(ctx context.Context) (T, error),
) (left, right T, err error) {
left, err = leftFn(ctx)
if err != nil {
return left, right, err
}
right, err = rightFn(ctx)
return left, right, err
}
Loading
Loading