Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
243 changes: 124 additions & 119 deletions pkg/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"os"
"path/filepath"
"testing"
"testing/synctest"
"time"

"github.com/spf13/cobra"
Expand Down Expand Up @@ -41,127 +42,131 @@ func TestDefaultConfig(t *testing.T) {
}

func TestAddFlags(t *testing.T) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's use testing/synctest instead.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah good suggestion, then we dont need to modify config/defaults.go updated 👍

// Create a command with flags
cmd := &cobra.Command{Use: "test"}
AddGlobalFlags(cmd, "test") // Add basic flags first
AddFlags(cmd)
// AddFlags and the assertions below each call DefaultConfig(), whose DA.Namespace
// is randString-seeded by time.Now().Unix(); run in a synctest bubble so both calls observe the same fake clock.
synctest.Test(t, func(t *testing.T) {
// Create a command with flags
cmd := &cobra.Command{Use: "test"}
AddGlobalFlags(cmd, "test") // Add basic flags first
AddFlags(cmd)

// Get both persistent and regular flags
flags := cmd.Flags()
persistentFlags := cmd.PersistentFlags()

// Test specific flags
assertFlagValue(t, flags, FlagDBPath, DefaultConfig().DBPath)
assertFlagValue(t, flags, FlagClearCache, DefaultConfig().ClearCache)

// Node flags
assertFlagValue(t, flags, FlagAggregator, DefaultConfig().Node.Aggregator)
assertFlagValue(t, flags, FlagBasedSequencer, DefaultConfig().Node.BasedSequencer)
assertFlagValue(t, flags, FlagLight, DefaultConfig().Node.Light)
assertFlagValue(t, flags, FlagBlockTime, DefaultConfig().Node.BlockTime.Duration)
assertFlagValue(t, flags, FlagLazyAggregator, DefaultConfig().Node.LazyMode)
assertFlagValue(t, flags, FlagMaxPendingHeadersAndData, DefaultConfig().Node.MaxPendingHeadersAndData)
assertFlagValue(t, flags, FlagLazyBlockTime, DefaultConfig().Node.LazyBlockInterval.Duration)
assertFlagValue(t, flags, FlagReadinessWindowSeconds, DefaultConfig().Node.ReadinessWindowSeconds)
assertFlagValue(t, flags, FlagReadinessMaxBlocksBehind, DefaultConfig().Node.ReadinessMaxBlocksBehind)
assertFlagValue(t, flags, FlagScrapeInterval, DefaultConfig().Node.ScrapeInterval)

// DA flags
assertFlagValue(t, flags, FlagDAAddress, DefaultConfig().DA.Address)
assertFlagValue(t, flags, FlagDAAuthToken, DefaultConfig().DA.AuthToken)
assertFlagValue(t, flags, FlagDABlockTime, DefaultConfig().DA.BlockTime.Duration)
assertFlagValue(t, flags, FlagDANamespace, DefaultConfig().DA.Namespace)
assertFlagValue(t, flags, FlagDADataNamespace, DefaultConfig().DA.DataNamespace)
assertFlagValue(t, flags, FlagDAForcedInclusionNamespace, DefaultConfig().DA.ForcedInclusionNamespace)
assertFlagValue(t, flags, FlagDASubmitOptions, DefaultConfig().DA.SubmitOptions)
assertFlagValue(t, flags, FlagDASigningAddresses, DefaultConfig().DA.SigningAddresses)
assertFlagValue(t, flags, FlagDAMempoolTTL, DefaultConfig().DA.MempoolTTL)
assertFlagValue(t, flags, FlagDAMaxSubmitAttempts, DefaultConfig().DA.MaxSubmitAttempts)
assertFlagValue(t, flags, FlagDARequestTimeout, DefaultConfig().DA.RequestTimeout.Duration)

// P2P flags
assertFlagValue(t, flags, FlagP2PListenAddress, DefaultConfig().P2P.ListenAddress)
assertFlagValue(t, flags, FlagP2PPeers, DefaultConfig().P2P.Peers)
assertFlagValue(t, flags, FlagP2PBlockedPeers, DefaultConfig().P2P.BlockedPeers)
assertFlagValue(t, flags, FlagP2PAllowedPeers, DefaultConfig().P2P.AllowedPeers)
assertFlagValue(t, flags, FlagP2PDisableConnectionGater, DefaultConfig().P2P.DisableConnectionGater)

// Instrumentation flags
instrDef := DefaultInstrumentationConfig()
assertFlagValue(t, flags, FlagPrometheus, instrDef.Prometheus)
assertFlagValue(t, flags, FlagPrometheusListenAddr, instrDef.PrometheusListenAddr)
assertFlagValue(t, flags, FlagMaxOpenConnections, instrDef.MaxOpenConnections)
assertFlagValue(t, flags, FlagPprof, instrDef.Pprof)
assertFlagValue(t, flags, FlagPprofListenAddr, instrDef.PprofListenAddr)
assertFlagValue(t, flags, FlagTracing, instrDef.Tracing)
assertFlagValue(t, flags, FlagTracingEndpoint, instrDef.TracingEndpoint)
assertFlagValue(t, flags, FlagTracingSampleRate, instrDef.TracingSampleRate)
assertFlagValue(t, flags, FlagTracingServiceName, instrDef.TracingServiceName)

// Logging flags (in persistent flags)
assertFlagValue(t, persistentFlags, FlagLogLevel, DefaultConfig().Log.Level)
assertFlagValue(t, persistentFlags, FlagLogFormat, "text")
assertFlagValue(t, persistentFlags, FlagLogTrace, false)
assertFlagValue(t, persistentFlags, FlagRootDir, DefaultRootDirWithName("test"))

// Signer flags
assertFlagValue(t, flags, FlagSignerPassphraseFile, "")
assertFlagValue(t, flags, FlagSignerType, "file")
assertFlagValue(t, flags, FlagSignerPath, DefaultConfig().Signer.SignerPath)
assertFlagValue(t, flags, FlagSignerKmsProvider, DefaultConfig().Signer.KMS.Provider)
assertFlagValue(t, flags, FlagSignerKmsAwsKeyID, DefaultConfig().Signer.KMS.AWS.KeyID)
assertFlagValue(t, flags, FlagSignerKmsAwsRegion, DefaultConfig().Signer.KMS.AWS.Region)
assertFlagValue(t, flags, FlagSignerKmsAwsProfile, DefaultConfig().Signer.KMS.AWS.Profile)
assertFlagValue(t, flags, FlagSignerKmsAwsTimeout, DefaultConfig().Signer.KMS.AWS.Timeout.Duration)
assertFlagValue(t, flags, FlagSignerKmsAwsMaxRetries, DefaultConfig().Signer.KMS.AWS.MaxRetries)
assertFlagValue(t, flags, FlagSignerKmsGcpKeyName, DefaultConfig().Signer.KMS.GCP.KeyName)
assertFlagValue(t, flags, FlagSignerKmsGcpCredentialsFile, DefaultConfig().Signer.KMS.GCP.CredentialsFile)
assertFlagValue(t, flags, FlagSignerKmsGcpTimeout, DefaultConfig().Signer.KMS.GCP.Timeout.Duration)
assertFlagValue(t, flags, FlagSignerKmsGcpMaxRetries, DefaultConfig().Signer.KMS.GCP.MaxRetries)

// RPC flags
assertFlagValue(t, flags, FlagRPCAddress, DefaultConfig().RPC.Address)
assertFlagValue(t, flags, FlagRPCEnableDAVisualization, DefaultConfig().RPC.EnableDAVisualization)

// Raft flags
assertFlagValue(t, flags, FlagRaftEnable, DefaultConfig().Raft.Enable)
assertFlagValue(t, flags, FlagRaftNodeID, DefaultConfig().Raft.NodeID)
assertFlagValue(t, flags, FlagRaftAddr, DefaultConfig().Raft.RaftAddr)
assertFlagValue(t, flags, FlagRaftDir, DefaultConfig().Raft.RaftDir)
assertFlagValue(t, flags, FlagRaftBootstrap, DefaultConfig().Raft.Bootstrap)
assertFlagValue(t, flags, FlagRaftPeers, DefaultConfig().Raft.Peers)
assertFlagValue(t, flags, FlagRaftSnapCount, DefaultConfig().Raft.SnapCount)
assertFlagValue(t, flags, FlagRaftSendTimeout, DefaultConfig().Raft.SendTimeout)
assertFlagValue(t, flags, FlagRaftHeartbeatTimeout, DefaultConfig().Raft.HeartbeatTimeout)
assertFlagValue(t, flags, FlagRaftLeaderLeaseTimeout, DefaultConfig().Raft.LeaderLeaseTimeout)
assertFlagValue(t, flags, FlagRaftElectionTimeout, DefaultConfig().Raft.ElectionTimeout)
assertFlagValue(t, flags, FlagRaftSnapshotThreshold, DefaultConfig().Raft.SnapshotThreshold)
assertFlagValue(t, flags, FlagRaftTrailingLogs, DefaultConfig().Raft.TrailingLogs)

// Pruning flags
assertFlagValue(t, flags, FlagPruningMode, DefaultConfig().Pruning.Mode)
assertFlagValue(t, flags, FlagPruningKeepRecent, DefaultConfig().Pruning.KeepRecent)
assertFlagValue(t, flags, FlagPruningInterval, DefaultConfig().Pruning.Interval.Duration)

// Count the number of flags we're explicitly checking
expectedFlagCount := 82 // Update this number if you add more flag checks above

// Get the actual number of flags (both regular and persistent)
actualFlagCount := 0
flags.VisitAll(func(flag *pflag.Flag) {
actualFlagCount++
})
persistentFlags.VisitAll(func(flag *pflag.Flag) {
actualFlagCount++
})
// Get both persistent and regular flags
flags := cmd.Flags()
persistentFlags := cmd.PersistentFlags()

// Test specific flags
assertFlagValue(t, flags, FlagDBPath, DefaultConfig().DBPath)
assertFlagValue(t, flags, FlagClearCache, DefaultConfig().ClearCache)

// Node flags
assertFlagValue(t, flags, FlagAggregator, DefaultConfig().Node.Aggregator)
assertFlagValue(t, flags, FlagBasedSequencer, DefaultConfig().Node.BasedSequencer)
assertFlagValue(t, flags, FlagLight, DefaultConfig().Node.Light)
assertFlagValue(t, flags, FlagBlockTime, DefaultConfig().Node.BlockTime.Duration)
assertFlagValue(t, flags, FlagLazyAggregator, DefaultConfig().Node.LazyMode)
assertFlagValue(t, flags, FlagMaxPendingHeadersAndData, DefaultConfig().Node.MaxPendingHeadersAndData)
assertFlagValue(t, flags, FlagLazyBlockTime, DefaultConfig().Node.LazyBlockInterval.Duration)
assertFlagValue(t, flags, FlagReadinessWindowSeconds, DefaultConfig().Node.ReadinessWindowSeconds)
assertFlagValue(t, flags, FlagReadinessMaxBlocksBehind, DefaultConfig().Node.ReadinessMaxBlocksBehind)
assertFlagValue(t, flags, FlagScrapeInterval, DefaultConfig().Node.ScrapeInterval)

// DA flags
assertFlagValue(t, flags, FlagDAAddress, DefaultConfig().DA.Address)
assertFlagValue(t, flags, FlagDAAuthToken, DefaultConfig().DA.AuthToken)
assertFlagValue(t, flags, FlagDABlockTime, DefaultConfig().DA.BlockTime.Duration)
assertFlagValue(t, flags, FlagDANamespace, DefaultConfig().DA.Namespace)
assertFlagValue(t, flags, FlagDADataNamespace, DefaultConfig().DA.DataNamespace)
assertFlagValue(t, flags, FlagDAForcedInclusionNamespace, DefaultConfig().DA.ForcedInclusionNamespace)
assertFlagValue(t, flags, FlagDASubmitOptions, DefaultConfig().DA.SubmitOptions)
assertFlagValue(t, flags, FlagDASigningAddresses, DefaultConfig().DA.SigningAddresses)
assertFlagValue(t, flags, FlagDAMempoolTTL, DefaultConfig().DA.MempoolTTL)
assertFlagValue(t, flags, FlagDAMaxSubmitAttempts, DefaultConfig().DA.MaxSubmitAttempts)
assertFlagValue(t, flags, FlagDARequestTimeout, DefaultConfig().DA.RequestTimeout.Duration)

// P2P flags
assertFlagValue(t, flags, FlagP2PListenAddress, DefaultConfig().P2P.ListenAddress)
assertFlagValue(t, flags, FlagP2PPeers, DefaultConfig().P2P.Peers)
assertFlagValue(t, flags, FlagP2PBlockedPeers, DefaultConfig().P2P.BlockedPeers)
assertFlagValue(t, flags, FlagP2PAllowedPeers, DefaultConfig().P2P.AllowedPeers)
assertFlagValue(t, flags, FlagP2PDisableConnectionGater, DefaultConfig().P2P.DisableConnectionGater)

// Instrumentation flags
instrDef := DefaultInstrumentationConfig()
assertFlagValue(t, flags, FlagPrometheus, instrDef.Prometheus)
assertFlagValue(t, flags, FlagPrometheusListenAddr, instrDef.PrometheusListenAddr)
assertFlagValue(t, flags, FlagMaxOpenConnections, instrDef.MaxOpenConnections)
assertFlagValue(t, flags, FlagPprof, instrDef.Pprof)
assertFlagValue(t, flags, FlagPprofListenAddr, instrDef.PprofListenAddr)
assertFlagValue(t, flags, FlagTracing, instrDef.Tracing)
assertFlagValue(t, flags, FlagTracingEndpoint, instrDef.TracingEndpoint)
assertFlagValue(t, flags, FlagTracingSampleRate, instrDef.TracingSampleRate)
assertFlagValue(t, flags, FlagTracingServiceName, instrDef.TracingServiceName)

// Logging flags (in persistent flags)
assertFlagValue(t, persistentFlags, FlagLogLevel, DefaultConfig().Log.Level)
assertFlagValue(t, persistentFlags, FlagLogFormat, "text")
assertFlagValue(t, persistentFlags, FlagLogTrace, false)
assertFlagValue(t, persistentFlags, FlagRootDir, DefaultRootDirWithName("test"))

// Signer flags
assertFlagValue(t, flags, FlagSignerPassphraseFile, "")
assertFlagValue(t, flags, FlagSignerType, "file")
assertFlagValue(t, flags, FlagSignerPath, DefaultConfig().Signer.SignerPath)
assertFlagValue(t, flags, FlagSignerKmsProvider, DefaultConfig().Signer.KMS.Provider)
assertFlagValue(t, flags, FlagSignerKmsAwsKeyID, DefaultConfig().Signer.KMS.AWS.KeyID)
assertFlagValue(t, flags, FlagSignerKmsAwsRegion, DefaultConfig().Signer.KMS.AWS.Region)
assertFlagValue(t, flags, FlagSignerKmsAwsProfile, DefaultConfig().Signer.KMS.AWS.Profile)
assertFlagValue(t, flags, FlagSignerKmsAwsTimeout, DefaultConfig().Signer.KMS.AWS.Timeout.Duration)
assertFlagValue(t, flags, FlagSignerKmsAwsMaxRetries, DefaultConfig().Signer.KMS.AWS.MaxRetries)
assertFlagValue(t, flags, FlagSignerKmsGcpKeyName, DefaultConfig().Signer.KMS.GCP.KeyName)
assertFlagValue(t, flags, FlagSignerKmsGcpCredentialsFile, DefaultConfig().Signer.KMS.GCP.CredentialsFile)
assertFlagValue(t, flags, FlagSignerKmsGcpTimeout, DefaultConfig().Signer.KMS.GCP.Timeout.Duration)
assertFlagValue(t, flags, FlagSignerKmsGcpMaxRetries, DefaultConfig().Signer.KMS.GCP.MaxRetries)

// RPC flags
assertFlagValue(t, flags, FlagRPCAddress, DefaultConfig().RPC.Address)
assertFlagValue(t, flags, FlagRPCEnableDAVisualization, DefaultConfig().RPC.EnableDAVisualization)

// Raft flags
assertFlagValue(t, flags, FlagRaftEnable, DefaultConfig().Raft.Enable)
assertFlagValue(t, flags, FlagRaftNodeID, DefaultConfig().Raft.NodeID)
assertFlagValue(t, flags, FlagRaftAddr, DefaultConfig().Raft.RaftAddr)
assertFlagValue(t, flags, FlagRaftDir, DefaultConfig().Raft.RaftDir)
assertFlagValue(t, flags, FlagRaftBootstrap, DefaultConfig().Raft.Bootstrap)
assertFlagValue(t, flags, FlagRaftPeers, DefaultConfig().Raft.Peers)
assertFlagValue(t, flags, FlagRaftSnapCount, DefaultConfig().Raft.SnapCount)
assertFlagValue(t, flags, FlagRaftSendTimeout, DefaultConfig().Raft.SendTimeout)
assertFlagValue(t, flags, FlagRaftHeartbeatTimeout, DefaultConfig().Raft.HeartbeatTimeout)
assertFlagValue(t, flags, FlagRaftLeaderLeaseTimeout, DefaultConfig().Raft.LeaderLeaseTimeout)
assertFlagValue(t, flags, FlagRaftElectionTimeout, DefaultConfig().Raft.ElectionTimeout)
assertFlagValue(t, flags, FlagRaftSnapshotThreshold, DefaultConfig().Raft.SnapshotThreshold)
assertFlagValue(t, flags, FlagRaftTrailingLogs, DefaultConfig().Raft.TrailingLogs)

// Pruning flags
assertFlagValue(t, flags, FlagPruningMode, DefaultConfig().Pruning.Mode)
assertFlagValue(t, flags, FlagPruningKeepRecent, DefaultConfig().Pruning.KeepRecent)
assertFlagValue(t, flags, FlagPruningInterval, DefaultConfig().Pruning.Interval.Duration)

// Count the number of flags we're explicitly checking
expectedFlagCount := 82 // Update this number if you add more flag checks above

// Get the actual number of flags (both regular and persistent)
actualFlagCount := 0
flags.VisitAll(func(flag *pflag.Flag) {
actualFlagCount++
})
persistentFlags.VisitAll(func(flag *pflag.Flag) {
actualFlagCount++
})

// Verify that the counts match
assert.Equal(
t,
expectedFlagCount,
actualFlagCount,
"Number of flags doesn't match. If you added a new flag, please update the test.",
)
// Verify that the counts match
assert.Equal(
t,
expectedFlagCount,
actualFlagCount,
"Number of flags doesn't match. If you added a new flag, please update the test.",
)
})
}

func TestLoad(t *testing.T) {
Expand Down
Loading