Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ val scalaCollectionCompatVersion = "2.14.0"
// Dependency versions shared across modules.
val scalaMacrosVersion = "2.1.1"
val scalatestVersion = "3.2.19"
val shapelessVersion = "2.3.13"
// Upgraded from 3.5.1; the duplicate old definition was diff residue and is removed
// (two `val sparkeyVersion` lines would not compile).
val sparkeyVersion = "3.7.0"
val tensorFlowVersion = "1.1.0"
val tensorFlowMetadataVersion = "1.16.1"
val testContainersVersion = "0.44.1"
Expand Down Expand Up @@ -330,6 +330,10 @@ ThisBuild / mimaBinaryIssueFilters ++= Seq(
),
ProblemFilters.exclude[MissingClassProblem](
"org.apache.beam.sdk.extensions.sorter.BufferedExternalSorter$Options"
),
// checkMemory replaced by HostMemoryTracker (private object, no external callers)
ProblemFilters.exclude[DirectMissingMethodProblem](
"com.spotify.scio.extra.sparkey.package#SparkeySideInput.checkMemory"
)
)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*
* Copyright 2026 Spotify AB.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package com.spotify.scio.extra.sparkey

import com.sun.management.OperatingSystemMXBean
import java.lang.management.ManagementFactory
import java.util.concurrent.atomic.AtomicLong
import org.slf4j.LoggerFactory

/**
* Tracks memory budgets for sparkey readers on this JVM, covering both off-heap and on-heap memory.
* On Dataflow, the JVM heap is hardcoded to ~70% of worker memory, leaving only ~30% for off-heap
* use (page cache, OS, kernel, etc.).
*
* Sparkey readers are opened via mmap (off-heap) by default. When the off-heap budget is exhausted,
* readers can fall back to heap-backed mode. This tracker provides atomic budget claiming for both
* pools to coordinate across threads.
*
* Budget is claimed but never released — readers are cached for the JVM lifetime and never closed.
* If a reader fails to open partway through (e.g. a sharded sparkey with some shards already
* claimed), the budget for those shards is leaked. This is acceptable since a reader failure is
* fatal to the pipeline.
*/
private[sparkey] class HostMemoryTracker(offHeapBudget: Long, heapBudget: Long) {
  private val logger = LoggerFactory.getLogger(this.getClass)

  // Remaining budgets. Claims only ever decrease these; nothing is ever released
  // (readers live for the JVM lifetime — see class scaladoc).
  private val remainingOffHeap = new AtomicLong(offHeapBudget)
  private val remainingHeap = new AtomicLong(heapBudget)

  logger.info(
    "Memory budgets — off-heap: {} bytes, heap: {} bytes",
    Array[AnyRef](Long.box(offHeapBudget), Long.box(heapBudget)): _*
  )

  /**
   * Atomically try to claim `bytes` from the off-heap budget. Returns true if the claim succeeded
   * (enough budget remaining), false otherwise. On success, the budget is reduced by `bytes`. On
   * failure, the budget is unchanged.
   *
   * @throws IllegalArgumentException if `bytes` is negative
   */
  def tryClaimOffHeap(bytes: Long): Boolean = tryClaim(remainingOffHeap, "off-heap", bytes)

  /**
   * Atomically try to claim `bytes` from the heap budget. Returns true if the claim succeeded
   * (enough budget remaining), false otherwise.
   *
   * @throws IllegalArgumentException if `bytes` is negative
   */
  def tryClaimHeap(bytes: Long): Boolean = tryClaim(remainingHeap, "heap", bytes)

  private def tryClaim(budget: AtomicLong, name: String, bytes: Long): Boolean = {
    // A negative claim would pass the `current >= b` check below and silently *grow* the
    // budget; reject it up front as a caller bug.
    require(bytes >= 0, s"claimed bytes must be non-negative, got $bytes")
    // Single atomic step: subtract only when enough budget remains, otherwise leave unchanged.
    // `prev` is the value before the accumulate, so `prev >= bytes` iff the claim succeeded.
    val prev =
      budget.getAndAccumulate(bytes, (current, b) => if (current >= b) current - b else current)
    if (prev >= bytes) {
      logger.info(
        "Claimed {} bytes of {} budget, {} bytes remaining",
        Array[AnyRef](Long.box(bytes), name, Long.box(prev - bytes)): _*
      )
      true
    } else {
      logger.debug(
        "Cannot claim {} bytes of {} budget, only {} bytes remaining",
        Array[AnyRef](Long.box(bytes), name, Long.box(prev)): _*
      )
      false
    }
  }
}

private[sparkey] object HostMemoryTracker {
  private val logger = LoggerFactory.getLogger(this.getClass)

  // maxMemory() is stable for the JVM lifetime; read it once instead of three times.
  private val MaxHeapBytes: Long = Runtime.getRuntime.maxMemory()

  // Reserve 2 GB of off-heap memory for OS, kernel structures, JVM class files, Beam shuffle, etc.
  private val OffHeapReserveBytes: Long = 2L * 1024 * 1024 * 1024

  // Reserve 10% of max heap or 4 GB, whichever is larger, for GC headroom and non-sparkey
  // allocations. Note: with < 4 GB of max heap this reserves the entire heap, i.e. the heap
  // budget below is zero and only the off-heap pool is usable.
  private val HeapReserveBytes: Long =
    Math.max(4L * 1024 * 1024 * 1024, (MaxHeapBytes * 0.1).toLong)

  // Off-heap pool: physical RAM not owned by the JVM heap, minus the fixed reserve.
  // NOTE(review): getTotalPhysicalMemorySize is deprecated on newer JDKs in favor of
  // getTotalMemorySize — presumably kept for older-JDK compatibility; confirm target JDK.
  private val offHeapBudget: Long = {
    val totalPhysical = ManagementFactory.getOperatingSystemMXBean
      .asInstanceOf[OperatingSystemMXBean]
      .getTotalPhysicalMemorySize
    Math.max(0, totalPhysical - MaxHeapBytes - OffHeapReserveBytes)
  }

  // Heap pool: whatever max heap remains after the reserve (clamped at zero).
  private val heapBudget: Long = Math.max(0, MaxHeapBytes - HeapReserveBytes)

  logger.info(
    "Host memory — off-heap reserve: {}, heap reserve: {}",
    Array[AnyRef](Long.box(OffHeapReserveBytes), Long.box(HeapReserveBytes)): _*
  )

  // Single JVM-wide tracker shared by every sparkey reader opened in this process.
  val instance: HostMemoryTracker = new HostMemoryTracker(offHeapBudget, heapBudget)
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ import java.io.File
import java.net.URI
import com.spotify.scio.util.{RemoteFileUtil, ScioUtil}
import com.spotify.scio.extra.sparkey.instances.ShardedSparkeyReader
import com.spotify.sparkey.Sparkey
import com.spotify.sparkey.SparkeyReader
import com.spotify.sparkey.{Sparkey, SparkeyReader}
import org.apache.beam.sdk.io.FileSystems
import org.apache.beam.sdk.io.fs.{EmptyMatchTreatment, MatchResult, ResourceId}
import org.apache.beam.sdk.options.PipelineOptions
import org.slf4j.LoggerFactory

import java.nio.file.Path
import java.util.UUID
Expand Down Expand Up @@ -80,7 +80,7 @@ case class SparkeyUri(path: String) {
if (!isSharded) {
val path =
if (isLocal) new File(basePath) else downloadRemoteUris(Seq(basePath), rfu).head.toFile
Sparkey.open(path)
ShardedSparkeyUri.openWithMemoryTracking(path)
} else {
val (basePaths, numShards) =
ShardedSparkeyUri.basePathsAndCount(EmptyMatchTreatment.DISALLOW, globExpression)
Expand Down Expand Up @@ -113,6 +113,43 @@ case class SparkeyUri(path: String) {
}

private[sparkey] object ShardedSparkeyUri {
private val logger = LoggerFactory.getLogger(this.getClass)

// Conservative fixed overhead per reader for JVM array/object headers (heap) and page cache
// metadata (off-heap). The exact overhead is hard to measure; we prefer to overestimate slightly
// rather than risk overcommitting memory.
private val PerReaderOverheadBytes: Long = 64 * 1024

/**
 * Open a sparkey reader for `file`, choosing the backing store by available memory budget:
 * prefer mmap (off-heap page cache); if that budget is exhausted, load the reader on-heap;
 * if neither pool has room, fall back to untracked mmap and warn.
 */
private[sparkey] def openWithMemoryTracking(
  file: File,
  tracker: HostMemoryTracker = HostMemoryTracker.instance
): SparkeyReader = {
  val indexFile = Sparkey.getIndexFile(file)
  val dataFile = Sparkey.getLogFile(file)
  // Claim covers both sparkey files plus a fixed per-reader overhead estimate.
  val claimBytes = indexFile.length() + dataFile.length() + PerReaderOverheadBytes

  val gotOffHeap = tracker.tryClaimOffHeap(claimBytes)
  if (gotOffHeap) {
    // Default mmap-backed reader; pages are served from the OS page cache.
    Sparkey.open(indexFile)
  } else if (tracker.tryClaimHeap(claimBytes)) {
    logger.info(
      "Opening {} on heap ({} bytes)",
      Array[AnyRef](indexFile.getName, Long.box(claimBytes)): _*
    )
    Sparkey.reader().file(indexFile).useHeap(true).open()
  } else {
    // No budget anywhere: still open via mmap (untracked) rather than fail the pipeline.
    logger.warn(
      "Neither off-heap nor heap budget available for {} ({} bytes), falling back to mmap",
      Array[AnyRef](indexFile.getName, Long.box(claimBytes)): _*
    )
    Sparkey.open(indexFile)
  }
}

private[sparkey] def shardsFromPath(path: String): (Short, Short) = {
"part-([0-9]+)-of-([0-9]+)".r
.findFirstMatchIn(path)
Expand All @@ -131,7 +168,7 @@ private[sparkey] object ShardedSparkeyUri {
): Map[Short, SparkeyReader] =
localBasePaths.iterator.map { path =>
val (shardIndex, _) = shardsFromPath(path)
val reader = Sparkey.open(new File(path + ".spi"))
val reader = openWithMemoryTracking(new File(path + ".spi"))
(shardIndex, reader)
}.toMap

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ import org.apache.beam.sdk.util.CoderUtils
import org.apache.beam.sdk.values.PCollectionView
import org.slf4j.LoggerFactory

import java.util.concurrent.{CompletableFuture, ConcurrentHashMap, Executors}
import java.util.concurrent.atomic.AtomicInteger
import scala.util.hashing.MurmurHash3

/**
Expand Down Expand Up @@ -545,12 +547,11 @@ package object sparkey extends SparkeyReaderInstances with SparkeyCoders {
mapFn: SparkeyReader => T
) extends SideInput[T] {
override def updateCacheOnGlobalWindow: Boolean = false
override def get[I, O](context: DoFn[I, O]#ProcessContext): T =
mapFn(
SparkeySideInput.checkMemory(
context.sideInput(view).getReader(RemoteFileUtil.create(context.getPipelineOptions))
)
)
// Resolve the side-input sparkey URI for the current element's window, then apply the
// user-supplied mapFn to the JVM-cached reader (readers are shared across threads and
// never closed — see SparkeySideInput.getOrCreateReader).
override def get[I, O](context: DoFn[I, O]#ProcessContext): T = {
val uri = context.sideInput(view)
val rfu = RemoteFileUtil.create(context.getPipelineOptions)
mapFn(SparkeySideInput.getOrCreateReader(uri, rfu))
}
}

/**
Expand All @@ -561,8 +562,10 @@ package object sparkey extends SparkeyReaderInstances with SparkeyCoders {
extends SideInput[SparkeyMap[K, V]] {
override def updateCacheOnGlobalWindow: Boolean = false
override def get[I, O](context: DoFn[I, O]#ProcessContext): SparkeyMap[K, V] = {
val uri = context.sideInput(view)
val rfu = RemoteFileUtil.create(context.getPipelineOptions)
new SparkeyMap(
context.sideInput(view).getReader(RemoteFileUtil.create(context.getPipelineOptions)),
SparkeySideInput.getOrCreateReader(uri, rfu),
CoderMaterializer.beam(context.getPipelineOptions, Coder[K]),
CoderMaterializer.beam(context.getPipelineOptions, Coder[V])
)
Expand All @@ -576,29 +579,56 @@ package object sparkey extends SparkeyReaderInstances with SparkeyCoders {
/**
 * Side input exposing a sparkey-backed large set. Wraps the shared (JVM-cached, never-closed)
 * reader in a typed [[SparkeySet]] view on each access.
 *
 * Note: this block previously contained interleaved deleted-diff lines (the old
 * `getReader`-based body); only the new implementation is kept.
 */
private class LargeSetSideInput[K: Coder](val view: PCollectionView[SparkeyUri])
    extends SideInput[SparkeySet[K]] {
  override def updateCacheOnGlobalWindow: Boolean = false
  override def get[I, O](context: DoFn[I, O]#ProcessContext): SparkeySet[K] = {
    val uri = context.sideInput(view)
    val rfu = RemoteFileUtil.create(context.getPipelineOptions)
    new SparkeySet(
      SparkeySideInput.getOrCreateReader(uri, rfu),
      CoderMaterializer.beam(context.getPipelineOptions, Coder[K])
    )
  }
}

// Readers are cached for the lifetime of the JVM and never closed. This is intentional:
// Beam side inputs have no close/teardown lifecycle, and in batch pipelines the JVM exits
// when the pipeline finishes. The HostMemoryTracker budget is similarly never released.
// Note: the cache is keyed by URI path only. If the same path is rewritten with different
// data and a new pipeline is run in the same JVM (e.g. DirectRunner, REPL), stale readers
// will be returned. This is acceptable for Dataflow batch (one pipeline per JVM).
//
// This reconstruction removes interleaved deleted-diff residue (the old `checkMemory`
// body and stray `reader` / `}` lines) that corrupted the object.
private object SparkeySideInput {
  private val logger = LoggerFactory.getLogger(this.getClass)

  // Small dedicated pool for loading readers concurrently. Multiple side inputs can be
  // opened in parallel (useful when on-heap loading is CPU-bound), without risking
  // starvation of the common ForkJoinPool. Daemon threads so the pool never blocks JVM exit.
  private val threadCount = new AtomicInteger()
  private val loaderPool = Executors.newFixedThreadPool(
    4,
    r => {
      val t = new Thread(r)
      t.setDaemon(true)
      t.setName(s"sparkey-reader-loader-${threadCount.getAndIncrement()}")
      t
    }
  )

  // One future per URI path; computeIfAbsent guarantees a single load per path even under
  // concurrent first access. A failed load stays cached, so every subsequent join() rethrows
  // the same failure — acceptable, since a reader failure is fatal to the pipeline.
  private val readerCache =
    new ConcurrentHashMap[String, CompletableFuture[SparkeyReader]]()

  /** Return the JVM-cached reader for `uri`, creating (and caching) it on first access. */
  def getOrCreateReader(uri: SparkeyUri, rfu: RemoteFileUtil): SparkeyReader =
    readerCache
      .computeIfAbsent(
        uri.path,
        _ =>
          CompletableFuture.supplyAsync(
            () => {
              logger.info("Opening shared sparkey reader for {}", uri.path)
              uri.getReader(rfu)
            },
            loaderPool
          )
      )
      .join()
}

sealed trait SparkeyWritable[K, V] extends Serializable {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Copyright 2026 Spotify AB.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package com.spotify.scio.extra.sparkey

import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

class HostMemoryTrackerTest extends AnyFlatSpec with Matchers {

  // Factory for a fresh tracker per test; byte counts are arbitrary small values.
  private def newTracker(offHeap: Long, heap: Long): HostMemoryTracker =
    new HostMemoryTracker(offHeapBudget = offHeap, heapBudget = heap)

  "HostMemoryTracker" should "claim off-heap budget when sufficient" in {
    val t = newTracker(100, 50)
    t.tryClaimOffHeap(60) should be(true)
    t.tryClaimOffHeap(40) should be(true)
  }

  it should "reject off-heap claim when insufficient" in {
    val t = newTracker(100, 50)
    t.tryClaimOffHeap(60) should be(true)
    t.tryClaimOffHeap(50) should be(false)
  }

  it should "claim heap budget when sufficient" in {
    val t = newTracker(0, 100)
    t.tryClaimHeap(60) should be(true)
    t.tryClaimHeap(40) should be(true)
  }

  it should "reject heap claim when insufficient" in {
    val t = newTracker(0, 100)
    t.tryClaimHeap(60) should be(true)
    t.tryClaimHeap(50) should be(false)
  }

  it should "handle zero budgets" in {
    val t = newTracker(0, 0)
    t.tryClaimOffHeap(1) should be(false)
    t.tryClaimHeap(1) should be(false)
  }

  it should "handle exact budget claims" in {
    val t = newTracker(100, 50)
    // Claiming the full budget succeeds; anything further is rejected.
    t.tryClaimOffHeap(100) should be(true)
    t.tryClaimOffHeap(1) should be(false)
  }

  it should "track off-heap and heap budgets independently" in {
    val t = newTracker(100, 100)
    t.tryClaimOffHeap(100) should be(true)
    t.tryClaimOffHeap(1) should be(false)
    // Exhausting off-heap must leave the heap pool untouched.
    t.tryClaimHeap(100) should be(true)
    t.tryClaimHeap(1) should be(false)
  }

  "HostMemoryTracker.instance" should "exist as a singleton" in {
    val first = HostMemoryTracker.instance
    val second = HostMemoryTracker.instance
    first should not be null
    // Repeated accesses yield the same reference.
    first shouldBe theSameInstanceAs(second)
  }
}
Loading
Loading