diff --git a/beacon-chain/state/state-native/BUILD.bazel b/beacon-chain/state/state-native/BUILD.bazel index 90ab3929a44b..f37e579618f2 100644 --- a/beacon-chain/state/state-native/BUILD.bazel +++ b/beacon-chain/state/state-native/BUILD.bazel @@ -27,6 +27,7 @@ go_library( "gloas.go", "hasher.go", "log.go", + "merkle_layers.go", "multi_value_slices.go", "proofs.go", "readonly_validator.go", diff --git a/beacon-chain/state/state-native/beacon_state.go b/beacon-chain/state/state-native/beacon_state.go index 2ebd9ebc6797..5d2c2ca39f74 100644 --- a/beacon-chain/state/state-native/beacon_state.go +++ b/beacon-chain/state/state-native/beacon_state.go @@ -87,7 +87,7 @@ type BeaconState struct { stateFieldLeaves map[types.FieldIndex]*fieldtrie.FieldTrie rebuildTrie map[types.FieldIndex]bool valMapHandler *stateutil.ValidatorMapHandler - merkleLayers [][][]byte + merkle *sharedMerkleLayers sharedFieldReferences map[types.FieldIndex]*stateutil.Reference } diff --git a/beacon-chain/state/state-native/merkle_layers.go b/beacon-chain/state/state-native/merkle_layers.go new file mode 100644 index 000000000000..ab9ebe225f7d --- /dev/null +++ b/beacon-chain/state/state-native/merkle_layers.go @@ -0,0 +1,66 @@ +package state_native + +import ( + "sync" + + "github.com/OffchainLabs/prysm/v7/beacon-chain/state/stateutil" +) + +// sharedMerkleLayers wraps the beacon state's top-level Merkle tree layers with +// reference counting so that Copy can share them instead of deep-copying. +// The mutex serializes ensureUnique's check-refs-then-clone step, because the +// layers may be shared by BeaconState instances with independent locks. +type sharedMerkleLayers struct { + layers [][][]byte + ref *stateutil.Reference + mu sync.Mutex +} + +// newSharedMerkleLayers wraps existing layers in a ref-counted container. 
+func newSharedMerkleLayers(layers [][][]byte) *sharedMerkleLayers { + return &sharedMerkleLayers{ + layers: layers, + ref: stateutil.NewRef(1), + } +} + +// copy increments the reference count and returns the same pointer, making +// BeaconState.Copy O(1) for this field. The caller must call ensureUnique +// before mutating the layers. +func (s *sharedMerkleLayers) copy() *sharedMerkleLayers { + s.ref.AddRef() + return s +} + +// ensureUnique deep-copies the layers if this instance is shared (refs > 1) +// and returns the (possibly new) sharedMerkleLayers to use. The caller must +// replace its field with the returned value: +// +// b.merkle = b.merkle.ensureUnique() +func (s *sharedMerkleLayers) ensureUnique() *sharedMerkleLayers { + s.mu.Lock() + defer s.mu.Unlock() + + if s.ref.Refs() == 1 { + return s + } + + // Shared. Deep-copy and detach. + s.ref.MinusRef() + + newLayers := make([][][]byte, len(s.layers)) + for i, layer := range s.layers { + newLayers[i] = make([][]byte, len(layer)) + for j, content := range layer { + newLayers[i][j] = make([]byte, len(content)) + copy(newLayers[i][j], content) + } + } + + return newSharedMerkleLayers(newLayers) +} + +// release decrements the reference count. Called during finalizer cleanup. 
+func (s *sharedMerkleLayers) release() { + s.ref.MinusRef() +} diff --git a/beacon-chain/state/state-native/proofs.go b/beacon-chain/state/state-native/proofs.go index 11243f7543b4..4bc9612e7e6e 100644 --- a/beacon-chain/state/state-native/proofs.go +++ b/beacon-chain/state/state-native/proofs.go @@ -93,7 +93,7 @@ func (b *BeaconState) proofByFieldIndex(ctx context.Context, f types.FieldIndex) if err := b.recomputeDirtyFields(ctx); err != nil { return nil, err } - return trie.ProofFromMerkleLayers(b.merkleLayers, f.RealPosition()), nil + return trie.ProofFromMerkleLayers(b.merkle.layers, f.RealPosition()), nil } func (b *BeaconState) validateFieldIndex(f types.FieldIndex) error { diff --git a/beacon-chain/state/state-native/setters_misc.go b/beacon-chain/state/state-native/setters_misc.go index 4a45081878d9..74489da61984 100644 --- a/beacon-chain/state/state-native/setters_misc.go +++ b/beacon-chain/state/state-native/setters_misc.go @@ -160,13 +160,13 @@ func (b *BeaconState) AppendHistoricalSummaries(summary *ethpb.HistoricalSummary // hold the lock before calling this method. func (b *BeaconState) recomputeRoot(idx int) { hashFunc := hash.CustomSHA256Hasher() - layers := b.merkleLayers + layers := b.merkle.layers // The merkle tree structure looks as follows: // [[r1, r2, r3, r4], [parent1, parent2], [root]] // Using information about the index which changed, idx, we recompute // only its branch up the tree. 
currentIndex := idx - root := b.merkleLayers[0][idx] + root := layers[0][idx] for i := 0; i < len(layers)-1; i++ { isLeft := currentIndex%2 == 0 neighborIdx := currentIndex ^ 1 @@ -187,7 +187,6 @@ func (b *BeaconState) recomputeRoot(idx int) { layers[i+1][parentIdx] = root currentIndex = parentIdx } - b.merkleLayers = layers } func (b *BeaconState) markFieldAsDirty(field types.FieldIndex) { diff --git a/beacon-chain/state/state-native/state_test.go b/beacon-chain/state/state-native/state_test.go index e6e26266740e..7147a07e5742 100644 --- a/beacon-chain/state/state-native/state_test.go +++ b/beacon-chain/state/state-native/state_test.go @@ -307,7 +307,7 @@ func TestBeaconState_AppendBalanceWithTrie(t *testing.T) { } _, err = st.HashTreeRoot(t.Context()) assert.NoError(t, err) - newRt := bytesutil.ToBytes32(st.merkleLayers[0][types.Balances]) + newRt := bytesutil.ToBytes32(st.merkle.layers[0][types.Balances]) wantedRt, err := stateutil.Uint64ListRootWithRegistryLimit(st.Balances()) assert.NoError(t, err) assert.Equal(t, wantedRt, newRt, "state roots are unequal") diff --git a/beacon-chain/state/state-native/state_trie.go b/beacon-chain/state/state-native/state_trie.go index ba4091556cd2..e43a0b8fa7ee 100644 --- a/beacon-chain/state/state-native/state_trie.go +++ b/beacon-chain/state/state-native/state_trie.go @@ -1031,18 +1031,12 @@ func (b *BeaconState) Copy() state.BeaconState { } } - if b.merkleLayers != nil { - dst.merkleLayers = make([][][]byte, len(b.merkleLayers)) - for i, layer := range b.merkleLayers { - dst.merkleLayers[i] = make([][]byte, len(layer)) - for j, content := range layer { - dst.merkleLayers[i][j] = make([]byte, len(content)) - copy(dst.merkleLayers[i][j], content) - } - } + if b.merkle != nil { + dst.merkle = b.merkle.copy() } state.Count.Inc() + // Finalizer runs when dst is being destroyed in garbage collection. 
runtime.SetFinalizer(dst, finalizerCleanup) return dst @@ -1062,14 +1056,14 @@ func (b *BeaconState) HashTreeRoot(ctx context.Context) ([32]byte, error) { if err := b.recomputeDirtyFields(ctx); err != nil { return [32]byte{}, err } - return bytesutil.ToBytes32(b.merkleLayers[len(b.merkleLayers)-1][0]), nil + return bytesutil.ToBytes32(b.merkle.layers[len(b.merkle.layers)-1][0]), nil } // Initializes the Merkle layers for the beacon state if they are empty. // // WARNING: Caller must acquire the mutex before using. func (b *BeaconState) initializeMerkleLayers(ctx context.Context) error { - if len(b.merkleLayers) > 0 { + if b.merkle != nil && len(b.merkle.layers) > 0 { return nil } fieldRoots, err := ComputeFieldRootsWithHasher(ctx, b) @@ -1077,7 +1071,7 @@ func (b *BeaconState) initializeMerkleLayers(ctx context.Context) error { return err } layers := stateutil.Merkleize(fieldRoots) - b.merkleLayers = layers + b.merkle = newSharedMerkleLayers(layers) switch b.version { case version.Phase0: b.dirtyFields = make(map[types.FieldIndex]bool, params.BeaconConfig().BeaconStateFieldCount) @@ -1106,13 +1100,17 @@ func (b *BeaconState) initializeMerkleLayers(ctx context.Context) error { // // WARNING: Caller must acquire the mutex before using. 
func (b *BeaconState) recomputeDirtyFields(ctx context.Context) error { + if len(b.dirtyFields) > 0 { + b.merkle = b.merkle.ensureUnique() + } + for field := range b.dirtyFields { root, err := b.rootSelector(ctx, field) if err != nil { return err } idx := field.RealPosition() - b.merkleLayers[0][idx] = root[:] + b.merkle.layers[0][idx] = root[:] b.recomputeRoot(idx) delete(b.dirtyFields, field) } @@ -1473,6 +1471,9 @@ func finalizerCleanup(b *BeaconState) { if b.validatorsMultiValue != nil { b.validatorsMultiValue.Detach(b) } + if b.merkle != nil { + b.merkle.release() + } state.Count.Sub(1) } diff --git a/changelog/manu-copy-on-write.md b/changelog/manu-copy-on-write.md new file mode 100644 index 000000000000..b843727ca7d4 --- /dev/null +++ b/changelog/manu-copy-on-write.md @@ -0,0 +1,2 @@ +### Changed +- Defer deep copy of merkle layers until mutation via lazy copy-on-write.