Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,24 @@ case class OnHeapSpill(hsm: OnHeapSpillManager, id: Int) extends Logging {
// Cumulative disk I/O time reported by the underlying spill buffer
// (presumably nanoseconds — TODO confirm unit against SpillBuf's contract).
def diskIOTime: Long = spillBuf.diskIOTime

/**
 * Appends the given buffer to this spill, first spilling the in-memory
 * buffer to disk when enough on-heap memory cannot be acquired for it.
 *
 * The spill itself is deliberately performed OUTSIDE the `synchronized`
 * block: the spill decision is made under the lock, the lock is released
 * before the (potentially slow) disk I/O in `spillInternal()`, and the lock
 * is re-acquired for the final write. This avoids holding this object's
 * monitor during disk I/O.
 *
 * NOTE(review): between the decision and the write another thread could act
 * on `spillBuf` — assumes writes to a single spill are serialized by the
 * caller, or that `spillInternal()` tolerates being raced. TODO confirm.
 *
 * @param buf data to append to this spill
 */
def write(buf: ByteBuffer): Unit = {
  var needSpill = false

  // Phase 1 (locked): if the buffer is still memory-backed, try to reserve
  // `buf.capacity()` bytes on-heap; on failure, mark it for spilling.
  synchronized {
    spillBuf match {
      case _: MemBasedSpillBuf =>
        val acquiredMemory = hsm.acquireMemory(buf.capacity())
        if (acquiredMemory < buf.capacity()) { // cannot allocate memory, will spill buffer
          // Return the partial reservation before spilling.
          hsm.freeMemory(acquiredMemory)
          needSpill = true
        }
      case _ => // already disk-backed: no reservation needed
    }
  }

  // Phase 2 (unlocked): perform the slow disk spill without holding the lock.
  if (needSpill) {
    spillInternal()
  }

  // Phase 3 (locked): append the data to whichever buffer is now current.
  synchronized {
    spillBuf.write(buf)
  }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package org.apache.spark.sql.blaze.memory

import java.nio.ByteBuffer

import scala.collection.concurrent
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

Expand Down Expand Up @@ -180,10 +181,10 @@ class OnHeapSpillManager(taskContext: TaskContext)
}

/**
 * Registry of per-task [[OnHeapSpillManager]] instances, keyed by task
 * attempt id.
 */
object OnHeapSpillManager extends Logging {
  // Tasks on different threads may register managers concurrently, so a
  // thread-safe TrieMap backs the registry (its getOrElseUpdate is atomic)
  // instead of a plain, unsynchronized mutable.Map.
  val all: mutable.Map[Long, OnHeapSpillManager] = concurrent.TrieMap[Long, OnHeapSpillManager]()

  /**
   * Returns the spill manager bound to the current task, creating and
   * registering one on first use.
   *
   * NOTE(review): assumes the caller runs on a thread with an active task —
   * `TaskContext.get` returns null otherwise and this would NPE. TODO confirm.
   */
  def current: OnHeapSpillManager = {
    val tc = TaskContext.get
    all.getOrElseUpdate(tc.taskAttemptId(), new OnHeapSpillManager(tc))
  }
}