Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,17 @@ import org.apache.arrow.c.ArrowSchema
import org.apache.arrow.c.Data
import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.dictionary.DictionaryProvider.MapDictionaryProvider
import org.apache.hadoop.security.UserGroupInformation
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.blaze.arrowio.util.ArrowUtils
import org.apache.spark.sql.execution.blaze.arrowio.util.ArrowWriter
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.blaze.BlazeConf
import org.apache.spark.sql.blaze.{BlazeConf, NativeHelper}
import org.apache.spark.sql.blaze.util.Using
import org.apache.spark.TaskContext

import java.security.PrivilegedExceptionAction

class ArrowFFIExporter(rowIter: Iterator[InternalRow], schema: StructType) {
private val maxBatchNumRows = BlazeConf.BATCH_SIZE.intConf()
private val maxBatchMemorySize = 1 << 24 // 16MB
Expand All @@ -53,6 +56,9 @@ class ArrowFFIExporter(rowIter: Iterator[InternalRow], schema: StructType) {

def exportNextBatch(exportArrowArrayPtr: Long): Boolean = {
val tc = TaskContext.get()
val currentUserInfo = UserGroupInformation.getCurrentUser
val nativeCurrentUser = NativeHelper.currentUser
val isNativeCurrentUser = currentUserInfo.equals(nativeCurrentUser)

if (tc != null && (tc.isCompleted() || tc.isInterrupted())) return false
if (!rowIter.hasNext) return false
Expand All @@ -64,11 +70,26 @@ class ArrowFFIExporter(rowIter: Iterator[InternalRow], schema: StructType) {
val arrowWriter = ArrowWriter.create(root)
var rowCount = 0

while (rowIter.hasNext
&& rowCount < maxBatchNumRows
&& batchAllocator.getAllocatedMemory < maxBatchMemorySize) {
arrowWriter.write(rowIter.next())
rowCount += 1
// Drain rows from the iterator into the Arrow writer until one of the
// batch limits is reached: no more input rows, the configured row cap
// (maxBatchNumRows), or the per-batch memory budget (maxBatchMemorySize)
// as reported by batchAllocator. Mutates the enclosing rowCount counter.
def processRows(): Unit = {
  // True while another row may be appended to the current batch.
  def canAppendRow: Boolean =
    rowIter.hasNext &&
      rowCount < maxBatchNumRows &&
      batchAllocator.getAllocatedMemory < maxBatchMemorySize
  while (canAppendRow) {
    arrowWriter.write(rowIter.next())
    rowCount += 1
  }
}
// if current user is native user, process rows directly
if (isNativeCurrentUser) {
processRows()
} else {
// otherwise, process rows as native user
nativeCurrentUser.doAs(
new PrivilegedExceptionAction[Unit] {
override def run(): Unit = {
processRows()
}
}
)
}
arrowWriter.finish()

Expand Down