Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ public enum BlazeConf {
/// enable extra metrics of input batch statistics
INPUT_BATCH_STATISTICS_ENABLE("spark.blaze.enableInputBatchStatistics", true),

/// enable conversion of UDAFs and other aggregate functions that are not natively implemented
UDAF_CONVERT_ENABLE("spark.blaze.enable.udaf", false),

/// ignore corrupted input files
IGNORE_CORRUPTED_FILES("spark.files.ignoreCorruptFiles", false),

Expand Down Expand Up @@ -88,7 +91,10 @@ public enum BlazeConf {
// suggested memory size for k-way merging
// use smaller batch memory size for kway merging since there will be multiple
// batches in memory at the same time
SUGGESTED_BATCH_MEM_SIZE_KWAY_MERGE("spark.blaze.suggested.batch.memSize.multiwayMerging", 1048576);
SUGGESTED_BATCH_MEM_SIZE_KWAY_MERGE("spark.blaze.suggested.batch.memSize.multiwayMerging", 1048576),

// estimated memory usage of a single TypedImperativeAggregate buffer row
SUGGESTED_UDAF_ROW_MEM_USAGE("spark.blaze.suggested.udaf.memUsedSize", 64);

public final String key;
private final Object defaultValue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1161,48 +1161,54 @@ object NativeConverters extends Logging {
case Some(converted) => return converted
case _ =>
}
// other udaf aggFunction
aggBuilder.setAggFunction(pb.AggFunction.UDAF)
val convertedChildren = mutable.LinkedHashMap[pb.PhysicalExprNode, BoundReference]()
if (BlazeConf.UDAF_CONVERT_ENABLE.booleanConf()) {
// other udaf aggFunction
aggBuilder.setAggFunction(pb.AggFunction.UDAF)
val convertedChildren = mutable.LinkedHashMap[pb.PhysicalExprNode, BoundReference]()

val bound = udaf match {
case declarativeAggregate: DeclarativeAggregate =>
declarativeAggregate.mapChildren { p =>
val convertedChild = convertExpr(p)
val nextBindIndex =
convertedChildren.size + declarativeAggregate.inputAggBufferAttributes.length
convertedChildren.getOrElseUpdate(
convertedChild,
BoundReference(nextBindIndex, p.dataType, p.nullable))
}
case imperativeAggregate: ImperativeAggregate =>
imperativeAggregate.mapChildren { p =>
val convertedChild = convertExpr(p)
val nextBindIndex = convertedChildren.size
convertedChildren.getOrElseUpdate(
convertedChild,
BoundReference(nextBindIndex, p.dataType, p.nullable))
}
}

val bound = udaf match {
case declarativeAggregate: DeclarativeAggregate =>
declarativeAggregate.mapChildren { p =>
val convertedChild = convertExpr(p)
val nextBindIndex =
convertedChildren.size + declarativeAggregate.inputAggBufferAttributes.length
convertedChildren.getOrElseUpdate(
convertedChild,
BoundReference(nextBindIndex, p.dataType, p.nullable))
}
case imperativeAggregate: ImperativeAggregate =>
imperativeAggregate.mapChildren { p =>
val convertedChild = convertExpr(p)
val nextBindIndex = convertedChildren.size
convertedChildren.getOrElseUpdate(
convertedChild,
BoundReference(nextBindIndex, p.dataType, p.nullable))
}
}
val paramsSchema = StructType(
convertedChildren.values
.map(ref => StructField("", ref.dataType, ref.nullable))
.toSeq)

val paramsSchema = StructType(
convertedChildren.values
.map(ref => StructField("", ref.dataType, ref.nullable))
.toSeq)
val serialized =
serializeExpression(
bound.asInstanceOf[AggregateFunction with Serializable],
paramsSchema)

val serialized =
serializeExpression(
bound.asInstanceOf[AggregateFunction with Serializable],
paramsSchema)
aggBuilder.setUdaf(
pb.AggUdaf
.newBuilder()
.setSerialized(ByteString.copyFrom(serialized))
.setInputSchema(NativeConverters.convertSchema(paramsSchema))
.setReturnType(convertDataType(bound.dataType))
.setReturnNullable(bound.nullable))
aggBuilder.addAllChildren(convertedChildren.keys.asJava)
} else {
throw new NotImplementedError(s"unsupported aggregate expression: (${e.getClass}) $e," +
s" set spark.blaze.enable.udaf true to enable")
}

aggBuilder.setUdaf(
pb.AggUdaf
.newBuilder()
.setSerialized(ByteString.copyFrom(serialized))
.setInputSchema(NativeConverters.convertSchema(paramsSchema))
.setReturnType(convertDataType(bound.dataType))
.setReturnNullable(bound.nullable))
aggBuilder.addAllChildren(convertedChildren.keys.asJava)
}
pb.PhysicalExprNode
.newBuilder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@ package org.apache.spark.sql.blaze
import java.io.ByteArrayOutputStream
import java.io.DataOutputStream
import java.nio.ByteBuffer

import scala.collection.mutable.ArrayBuffer

import org.apache.arrow.c.ArrowArray
import org.apache.arrow.c.Data
import org.apache.arrow.vector.VectorSchemaRoot
Expand Down Expand Up @@ -242,7 +240,7 @@ class DeclarativeEvaluator(agg: DeclarativeAggregate, inputAttributes: Seq[Attri
}

override def merge(row1: UnsafeRow, row2: UnsafeRow): UnsafeRow = {
merger(joiner(row1, row2))
merger(joiner(row1, row2)).copy()
}

override def eval(row: UnsafeRow): UnsafeRow = {
Expand Down Expand Up @@ -281,10 +279,11 @@ class TypedImperativeEvaluator[B](agg: TypedImperativeAggregate[B])
extends AggregateEvaluator[B] {
private val evalRow = InternalRow(0)

private val rowMemUsage = BlazeConf.SUGGESTED_UDAF_ROW_MEM_USAGE.intConf()
override def createEmptyColumn(): BufferRowsColumn[B] = {
new BufferRowsColumn[B]() {
override def getRowMemUsage(row: B): Int = {
64 // estimated size of object
rowMemUsage // estimated size of object
}

override def update(i: Int, updater: B => B): Unit = {
Expand Down
Loading