Skip to content

Commit b88ac0e

Browse files
author
zhangli20
committed
Improve OnHeapSpillManager concurrency:
1. Use a concurrent map (scala.collection.concurrent.TrieMap) to avoid creating multiple memory managers in one task. 2. Avoid a deadlock when a spill ends up spilling itself.
1 parent 9001938 commit b88ac0e

2 files changed

Lines changed: 14 additions & 5 deletions

File tree

spark-extension/src/main/scala/org/apache/spark/sql/blaze/memory/OnHeapSpill.scala

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -28,16 +28,24 @@ case class OnHeapSpill(hsm: OnHeapSpillManager, id: Int) extends Logging {
2828
def diskIOTime: Long = spillBuf.diskIOTime
2929

3030
def write(buf: ByteBuffer): Unit = {
31+
var needSpill = false
3132
synchronized {
3233
spillBuf match {
3334
case _: MemBasedSpillBuf =>
3435
val acquiredMemory = hsm.acquireMemory(buf.capacity())
3536
if (acquiredMemory < buf.capacity()) { // cannot allocate memory, will spill buffer
3637
hsm.freeMemory(acquiredMemory)
37-
spillInternal()
38+
needSpill = true
3839
}
39-
case _ =>
40+
case _ => false
4041
}
42+
}
43+
44+
if (needSpill) {
45+
spillInternal()
46+
}
47+
48+
synchronized {
4149
spillBuf.write(buf)
4250
}
4351
}

spark-extension/src/main/scala/org/apache/spark/sql/blaze/memory/OnHeapSpillManager.scala

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@ package org.apache.spark.sql.blaze.memory
1717

1818
import java.nio.ByteBuffer
1919

20+
import scala.collection.concurrent
2021
import scala.collection.mutable
2122
import scala.collection.mutable.ArrayBuffer
2223

@@ -180,10 +181,10 @@ class OnHeapSpillManager(taskContext: TaskContext)
180181
}
181182

182183
object OnHeapSpillManager extends Logging {
183-
val all: mutable.Map[Long, OnHeapSpillManager] = mutable.Map()
184+
val all: mutable.Map[Long, OnHeapSpillManager] = concurrent.TrieMap[Long, OnHeapSpillManager]()
184185

185186
def current: OnHeapSpillManager = {
186-
val taskContext = TaskContext.get
187-
all.getOrElseUpdate(taskContext.taskAttemptId(), new OnHeapSpillManager(taskContext))
187+
val tc = TaskContext.get
188+
all.getOrElseUpdate(tc.taskAttemptId(), new OnHeapSpillManager(tc))
188189
}
189190
}

0 commit comments

Comments
 (0)