Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,94 @@ class BlazeQuerySuite
}
}

test("repartition over MapType") {
withTable("t_map") {
sql("create table t_map using parquet as select map('a', '1', 'b', '2') as data_map")
val df = sql("SELECT /*+ repartition(10) */ data_map FROM t_map")
checkAnswer(df, Seq(Row(Map("a" -> "1", "b" -> "2"))))
}
}

test("repartition over MapType with ArrayType") {
withTable("t_map_struct") {
sql(
"create table t_map_struct using parquet as select named_struct('m', map('x', '1')) as data_struct")
val df = sql("SELECT /*+ repartition(10) */ data_struct FROM t_map_struct")
checkAnswer(df, Seq(Row(Row(Map("x" -> "1")))))
}
}

test("repartition over ArrayType with MapType") {
withTable("t_array_map") {
sql("""
|create table t_array_map using parquet as
|select array(map('k1', 1, 'k2', 2), map('k3', 3)) as array_of_map
|""".stripMargin)
val df = sql("SELECT /*+ repartition(10) */ array_of_map FROM t_array_map")
checkAnswer(df, Seq(Row(Seq(Map("k1" -> 1, "k2" -> 2), Map("k3" -> 3)))))
}
}

test("repartition over StructType with MapType") {
withTable("t_struct_map") {
sql("""
|create table t_struct_map using parquet as
|select named_struct('id', 101, 'metrics', map('ctr', 0.123, 'cvr', 0.045)) as user_metrics
|""".stripMargin)
val df = sql("SELECT /*+ repartition(10) */ user_metrics FROM t_struct_map")
checkAnswer(df, Seq(Row(Row(101, Map("ctr" -> 0.123, "cvr" -> 0.045)))))
}
}

test("repartition over MapType with StructType") {
withTable("t_map_struct_value") {
sql("""
|create table t_map_struct_value using parquet as
|select map(
| 'item1', named_struct('count', 3, 'score', 4.5),
| 'item2', named_struct('count', 7, 'score', 9.1)
|) as map_struct_value
|""".stripMargin)
val df = sql("SELECT /*+ repartition(10) */ map_struct_value FROM t_map_struct_value")
checkAnswer(df, Seq(Row(Map("item1" -> Row(3, 4.5), "item2" -> Row(7, 9.1)))))
}
}

test("repartition over nested MapType") {
withTable("t_nested_map") {
sql("""
|create table t_nested_map using parquet as
|select map(
| 'outer1', map('inner1', 10, 'inner2', 20),
| 'outer2', map('inner3', 30)
|) as nested_map
|""".stripMargin)
val df = sql("SELECT /*+ repartition(10) */ nested_map FROM t_nested_map")
checkAnswer(
df,
Seq(Row(
Map("outer1" -> Map("inner1" -> 10, "inner2" -> 20), "outer2" -> Map("inner3" -> 30)))))
}
}

test("repartition over ArrayType of StructType with MapType") {
withTable("t_array_struct_map") {
sql("""
|create table t_array_struct_map using parquet as
|select array(
| named_struct('name', 'user1', 'features', map('f1', 1.0, 'f2', 2.0)),
| named_struct('name', 'user2', 'features', map('f3', 3.5))
|) as user_feature_array
|""".stripMargin)
val df = sql("SELECT /*+ repartition(10) */ user_feature_array FROM t_array_struct_map")
checkAnswer(
df,
Seq(
Row(
Seq(Row("user1", Map("f1" -> 1.0f, "f2" -> 2.0f)), Row("user2", Map("f3" -> 3.5f))))))
}
}

test("log function with negative input") {
withTable("t1") {
sql("create table t1 using parquet as select -1 as c1")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import org.apache.spark.sql.blaze.BlazeConvertStrategy.convertStrategyTag
import org.apache.spark.sql.blaze.BlazeConvertStrategy.convertToNonNativeTag
import org.apache.spark.sql.blaze.BlazeConvertStrategy.isNeverConvert
import org.apache.spark.sql.blaze.BlazeConvertStrategy.joinSmallerSideTag
import org.apache.spark.sql.blaze.NativeConverters.{scalarTypeSupported, StubExpr}
import org.apache.spark.sql.blaze.NativeConverters.{roundRobinTypeSupported, scalarTypeSupported, StubExpr}
import org.apache.spark.sql.catalyst.expressions.Alias
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.AttributeReference
Expand Down Expand Up @@ -283,20 +283,27 @@ object BlazeConverters extends Logging {
logDebug(s"Converting ShuffleExchangeExec: ${Shims.get.simpleStringWithNodeId(exec)}")

assert(
exec.outputPartitioning.numPartitions == 1 || exec.outputPartitioning
.isInstanceOf[HashPartitioning] || exec.outputPartitioning
.isInstanceOf[RoundRobinPartitioning] || exec.outputPartitioning
outputPartitioning.numPartitions == 1 || outputPartitioning
.isInstanceOf[HashPartitioning] || outputPartitioning
.isInstanceOf[RoundRobinPartitioning] || outputPartitioning
.isInstanceOf[RangePartitioning],
s"partitioning not supported: ${exec.outputPartitioning}")

if (exec.outputPartitioning.isInstanceOf[RangePartitioning]) {
val unsupportedOrderType = exec.outputPartitioning
.asInstanceOf[RangePartitioning]
.ordering
.find(e => !scalarTypeSupported(e.dataType))
assert(
unsupportedOrderType.isEmpty,
s"Unsupported order type in range partitioning: ${unsupportedOrderType.get}")
s"partitioning not supported: $outputPartitioning")

outputPartitioning match {
case partitioning: RangePartitioning =>
val unsupportedOrderType = partitioning.ordering
.find(e => !scalarTypeSupported(e.dataType))
assert(
unsupportedOrderType.isEmpty,
s"Unsupported order type in range partitioning: ${unsupportedOrderType.get}")
case _: RoundRobinPartitioning =>
val unsupportedTypeInRR =
exec.output.find(attr => !roundRobinTypeSupported(attr.dataType))
assert(
unsupportedTypeInRR.isEmpty,
s"Unsupported data type in $outputPartitioning: attribute=${unsupportedTypeInRR.get.name}" +
s", dataType=${unsupportedTypeInRR.get.dataType}")
case _ =>
}

val convertedChild = outputPartitioning match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,13 @@ object NativeConverters extends Logging {
}
}

def roundRobinTypeSupported(dataType: DataType): Boolean = dataType match {
case MapType(_, _, _) => false
case ArrayType(elementType, _) => roundRobinTypeSupported(elementType)
case StructType(fields) => fields.forall(f => roundRobinTypeSupported(f.dataType))
case _ => true
}

def convertDataType(sparkDataType: DataType): pb.ArrowType = {
val arrowTypeBuilder = pb.ArrowType.newBuilder()
sparkDataType match {
Expand Down