Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions beacon-chain/state/state-native/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ go_library(
"gloas.go",
"hasher.go",
"log.go",
"merkle_layers.go",
"multi_value_slices.go",
"proofs.go",
"readonly_validator.go",
Expand Down
2 changes: 1 addition & 1 deletion beacon-chain/state/state-native/beacon_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ type BeaconState struct {
stateFieldLeaves map[types.FieldIndex]*fieldtrie.FieldTrie
rebuildTrie map[types.FieldIndex]bool
valMapHandler *stateutil.ValidatorMapHandler
merkleLayers [][][]byte
merkle *sharedMerkleLayers
sharedFieldReferences map[types.FieldIndex]*stateutil.Reference
}

Expand Down
66 changes: 66 additions & 0 deletions beacon-chain/state/state-native/merkle_layers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package state_native

import (
"sync"

"github.com/OffchainLabs/prysm/v7/beacon-chain/state/stateutil"
)

// sharedMerkleLayers wraps the beacon state's top-level Merkle tree layers with
// reference counting so that BeaconState.Copy can share them instead of
// deep-copying, making the layers copy-on-write. Within a single BeaconState,
// access is serialized by that state's lock; however, the same instance may be
// shared by several states (each with its own lock), so the check-then-detach
// sequence in ensureUnique is additionally guarded by mu.
type sharedMerkleLayers struct {
	// layers holds the tree bottom-up: layers[0] contains the per-field roots
	// (indexed by FieldIndex.RealPosition) and the last layer's single entry is
	// the state root.
	layers [][][]byte
	// ref counts how many BeaconState instances currently share this container.
	ref *stateutil.Reference
	// mu serializes ensureUnique calls made concurrently by sharing states.
	mu sync.Mutex
}

// newSharedMerkleLayers wraps the given layers in a ref-counted container
// whose initial reference count is 1 (the caller is the sole owner).
func newSharedMerkleLayers(layers [][][]byte) *sharedMerkleLayers {
	s := &sharedMerkleLayers{layers: layers}
	s.ref = stateutil.NewRef(1)
	return s
}

// copy increments the reference count and returns the receiver unchanged, so
// BeaconState.Copy is O(1) for this field: the new state aliases the same
// layers rather than deep-copying them. The caller must call ensureUnique
// before mutating the layers.
func (s *sharedMerkleLayers) copy() *sharedMerkleLayers {
	s.ref.AddRef()
	return s
}

// ensureUnique returns a sharedMerkleLayers exclusively owned by the caller,
// deep-copying the layers when this instance is shared (refs > 1). The caller
// must replace its field with the returned value:
//
//	b.merkle = b.merkle.ensureUnique()
//
// mu guards the whole check-then-detach sequence so that two sharing states
// cannot race on the copy-on-write decision.
func (s *sharedMerkleLayers) ensureUnique() *sharedMerkleLayers {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Sole owner: safe to mutate in place, no copy needed.
	if s.ref.Refs() == 1 {
		return s
	}

	// Shared: drop our reference to the old container, then detach with a
	// fully independent deep copy of every layer and leaf.
	s.ref.MinusRef()

	cloned := make([][][]byte, len(s.layers))
	for i := range s.layers {
		cloned[i] = make([][]byte, len(s.layers[i]))
		for j := range s.layers[i] {
			leaf := make([]byte, len(s.layers[i][j]))
			copy(leaf, s.layers[i][j])
			cloned[i][j] = leaf
		}
	}

	return newSharedMerkleLayers(cloned)
}

// release decrements the reference count when a sharing BeaconState is
// collected (called from finalizerCleanup). No explicit freeing happens here;
// the layers are reclaimed by the garbage collector once no state holds them.
func (s *sharedMerkleLayers) release() {
	s.ref.MinusRef()
}
2 changes: 1 addition & 1 deletion beacon-chain/state/state-native/proofs.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func (b *BeaconState) proofByFieldIndex(ctx context.Context, f types.FieldIndex)
if err := b.recomputeDirtyFields(ctx); err != nil {
return nil, err
}
return trie.ProofFromMerkleLayers(b.merkleLayers, f.RealPosition()), nil
return trie.ProofFromMerkleLayers(b.merkle.layers, f.RealPosition()), nil
}

func (b *BeaconState) validateFieldIndex(f types.FieldIndex) error {
Expand Down
5 changes: 2 additions & 3 deletions beacon-chain/state/state-native/setters_misc.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,13 +160,13 @@ func (b *BeaconState) AppendHistoricalSummaries(summary *ethpb.HistoricalSummary
// hold the lock before calling this method.
func (b *BeaconState) recomputeRoot(idx int) {
hashFunc := hash.CustomSHA256Hasher()
layers := b.merkleLayers
layers := b.merkle.layers
// The merkle tree structure looks as follows:
// [[r1, r2, r3, r4], [parent1, parent2], [root]]
// Using information about the index which changed, idx, we recompute
// only its branch up the tree.
currentIndex := idx
root := b.merkleLayers[0][idx]
root := layers[0][idx]
for i := 0; i < len(layers)-1; i++ {
isLeft := currentIndex%2 == 0
neighborIdx := currentIndex ^ 1
Expand All @@ -187,7 +187,6 @@ func (b *BeaconState) recomputeRoot(idx int) {
layers[i+1][parentIdx] = root
currentIndex = parentIdx
}
b.merkleLayers = layers
}

func (b *BeaconState) markFieldAsDirty(field types.FieldIndex) {
Expand Down
2 changes: 1 addition & 1 deletion beacon-chain/state/state-native/state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ func TestBeaconState_AppendBalanceWithTrie(t *testing.T) {
}
_, err = st.HashTreeRoot(t.Context())
assert.NoError(t, err)
newRt := bytesutil.ToBytes32(st.merkleLayers[0][types.Balances])
newRt := bytesutil.ToBytes32(st.merkle.layers[0][types.Balances])
wantedRt, err := stateutil.Uint64ListRootWithRegistryLimit(st.Balances())
assert.NoError(t, err)
assert.Equal(t, wantedRt, newRt, "state roots are unequal")
Expand Down
27 changes: 14 additions & 13 deletions beacon-chain/state/state-native/state_trie.go
Original file line number Diff line number Diff line change
Expand Up @@ -1031,18 +1031,12 @@ func (b *BeaconState) Copy() state.BeaconState {
}
}

if b.merkleLayers != nil {
dst.merkleLayers = make([][][]byte, len(b.merkleLayers))
for i, layer := range b.merkleLayers {
dst.merkleLayers[i] = make([][]byte, len(layer))
for j, content := range layer {
dst.merkleLayers[i][j] = make([]byte, len(content))
copy(dst.merkleLayers[i][j], content)
}
}
if b.merkle != nil {
dst.merkle = b.merkle.copy()
}

state.Count.Inc()

// Finalizer runs when dst is being destroyed in garbage collection.
runtime.SetFinalizer(dst, finalizerCleanup)
return dst
Expand All @@ -1062,22 +1056,22 @@ func (b *BeaconState) HashTreeRoot(ctx context.Context) ([32]byte, error) {
if err := b.recomputeDirtyFields(ctx); err != nil {
return [32]byte{}, err
}
return bytesutil.ToBytes32(b.merkleLayers[len(b.merkleLayers)-1][0]), nil
return bytesutil.ToBytes32(b.merkle.layers[len(b.merkle.layers)-1][0]), nil
}

// Initializes the Merkle layers for the beacon state if they are empty.
//
// WARNING: Caller must acquire the mutex before using.
func (b *BeaconState) initializeMerkleLayers(ctx context.Context) error {
if len(b.merkleLayers) > 0 {
if b.merkle != nil && len(b.merkle.layers) > 0 {
return nil
}
fieldRoots, err := ComputeFieldRootsWithHasher(ctx, b)
if err != nil {
return err
}
layers := stateutil.Merkleize(fieldRoots)
b.merkleLayers = layers
b.merkle = newSharedMerkleLayers(layers)
switch b.version {
case version.Phase0:
b.dirtyFields = make(map[types.FieldIndex]bool, params.BeaconConfig().BeaconStateFieldCount)
Expand Down Expand Up @@ -1106,13 +1100,17 @@ func (b *BeaconState) initializeMerkleLayers(ctx context.Context) error {
//
// WARNING: Caller must acquire the mutex before using.
func (b *BeaconState) recomputeDirtyFields(ctx context.Context) error {
if len(b.dirtyFields) > 0 {
b.merkle = b.merkle.ensureUnique()
}

for field := range b.dirtyFields {
root, err := b.rootSelector(ctx, field)
if err != nil {
return err
}
idx := field.RealPosition()
b.merkleLayers[0][idx] = root[:]
b.merkle.layers[0][idx] = root[:]
b.recomputeRoot(idx)
delete(b.dirtyFields, field)
}
Expand Down Expand Up @@ -1473,6 +1471,9 @@ func finalizerCleanup(b *BeaconState) {
if b.validatorsMultiValue != nil {
b.validatorsMultiValue.Detach(b)
}
if b.merkle != nil {
b.merkle.release()
}

state.Count.Sub(1)
}
Expand Down
2 changes: 2 additions & 0 deletions changelog/manu-copy-on-write.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
### Changed
- Deferred the deep copy of the beacon state's Merkle layers until first mutation (lazy copy-on-write), making `BeaconState.Copy` cheaper for this field.
Loading