diff --git a/native-engine/blaze-serde/proto/blaze.proto b/native-engine/blaze-serde/proto/blaze.proto
index e12adb5a3..ee992b906 100644
--- a/native-engine/blaze-serde/proto/blaze.proto
+++ b/native-engine/blaze-serde/proto/blaze.proto
@@ -53,11 +53,6 @@ message PhysicalPlanNode {
   }
 }
 
-enum JoinConstraint {
-  ON = 0;
-  USING = 1;
-}
-
 // physical expressions
 message PhysicalExprNode {
   oneof ExprType {
@@ -387,6 +382,12 @@ message ScanLimit {
   uint32 limit = 1;
 }
 
+message ColumnStats {
+  ScalarValue min_value = 1;
+  ScalarValue max_value = 2;
+  uint32 null_count = 3;
+  uint32 distinct_count = 4;
+}
 
 message Statistics {
   int64 num_rows = 1;
@@ -636,7 +637,7 @@ message PhysicalRoundRobinRepartition {
 message PhysicalRangeRepartition {
   SortExecNode sort_expr = 1;
   uint64 partition_count = 2;
-  repeated ScalarValue list_value= 3;
+  repeated ScalarValue list_value = 3;
 }
 
@@ -717,52 +718,6 @@ message PartitionId {
   uint32 partition_id = 4;
 }
 
-message PartitionStats {
-  int64 num_rows = 1;
-  int64 num_batches = 2;
-  int64 num_bytes = 3;
-  repeated ColumnStats column_stats = 4;
-}
-
-message ColumnStats {
-  ScalarValue min_value = 1;
-  ScalarValue max_value = 2;
-  uint32 null_count = 3;
-  uint32 distinct_count = 4;
-}
-
-message RunningTask {
-  string executor_id = 1;
-}
-
-message FailedTask {
-  string error = 1;
-}
-
-message CompletedTask {
-  string executor_id = 1;
-  // TODO tasks are currently always shuffle writes but this will not always be the case
-  // so we might want to think about some refactoring of the task definitions
-  repeated ShuffleWritePartition partitions = 2;
-}
-
-message ShuffleWritePartition {
-  uint64 partition_id = 1;
-  string path = 2;
-  uint64 num_batches = 3;
-  uint64 num_rows = 4;
-  uint64 num_bytes = 5;
-}
-
-message TaskStatus {
-  PartitionId partition_id = 1;
-  oneof status {
-    RunningTask running = 2;
-    FailedTask failed = 3;
-    CompletedTask completed = 4;
-  }
-}
-
 message TaskDefinition {
   PartitionId task_id = 1;
   PhysicalPlanNode plan = 2;
@@ -770,7 +725,6 @@ message TaskDefinition {
   PhysicalRepartition output_partitioning = 3;
 }
 
-
 ///////////////////////////////////////////////////////////////////////////////////////////////////
 // Arrow Data Types
 ///////////////////////////////////////////////////////////////////////////////////////////////////
@@ -853,40 +807,8 @@ message Union {
   UnionMode union_mode = 2;
 }
 
-message ScalarListValue {
-  ScalarType datatype = 1;
-  repeated ScalarValue values = 2;
-}
-
-message ScalarDecimalValue {
-  Decimal decimal = 1;
-  int64 long_value = 2; // datafusion has i128 decimal value, only use i64 for blaze
-}
-
 message ScalarValue {
-  oneof value {
-    bool bool_value = 1;
-    string utf8_value = 2;
-    string large_utf8_value = 3;
-    int32 int8_value = 4;
-    int32 int16_value = 5;
-    int32 int32_value = 6;
-    int64 int64_value = 7;
-    uint32 uint8_value = 8;
-    uint32 uint16_value = 9;
-    uint32 uint32_value = 10;
-    uint64 uint64_value = 11;
-    float float32_value = 12;
-    double float64_value = 13;
-    int32 date32_value = 14;
-    int64 timestamp_second_value = 15;
-    int64 timestamp_millisecond_value = 16;
-    int64 timestamp_microsecond_value = 17;
-    int64 timestamp_nanosecond_value = 18;
-    ScalarListValue list_value = 19;
-    ScalarDecimalValue decimal_value = 20;
-    ScalarType null_value = 1000;
-  }
+  bytes ipc_bytes = 1;
 }
 
 // Contains all valid datafusion scalar type except for
@@ -917,17 +839,6 @@ enum PrimitiveScalarType {
   INTERVAL_DAYTIME = 22;
 }
 
-message ScalarListType {
-  ScalarType element_type = 1;
-}
-
-message ScalarType {
-  oneof datatype {
-    PrimitiveScalarType scalar = 1;
-    ScalarListType list = 2;
-  }
-}
-
 // Broke out into multiple message types so that type
 // metadata did not need to be in separate message
 //All types that are of the empty message types contain no additional metadata
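The new `ScalarValue` message above replaces the per-type `oneof` with a single Arrow IPC payload: every literal is shipped as an IPC stream containing a one-row, one-column record batch, and the Arrow schema inside the stream carries the full data type. A minimal round-trip sketch (not part of this patch; it only assumes the `arrow` and `datafusion` crates the native engine already depends on):

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::ipc::reader::StreamReader;
use arrow::ipc::writer::StreamWriter;
use arrow::record_batch::RecordBatch;
use datafusion::common::ScalarValue;

fn main() -> datafusion::error::Result<()> {
    // encode: a one-row, one-column batch holding the literal value
    let schema = Arc::new(Schema::new(vec![Field::new("", DataType::Int32, true)]));
    let column: ArrayRef = Arc::new(Int32Array::from(vec![42]));
    let batch = RecordBatch::try_new(schema.clone(), vec![column])?;

    let mut ipc_bytes = Vec::new();
    {
        let mut writer = StreamWriter::try_new(&mut ipc_bytes, &schema)?;
        writer.write(&batch)?;
        writer.finish()?;
    }

    // decode: read the stream back and turn row 0 into a ScalarValue,
    // mirroring what blaze-serde now does on the native side
    let mut reader = StreamReader::try_new(&ipc_bytes[..], None)?;
    let decoded = reader.next().expect("missing record batch")?;
    let scalar = ScalarValue::try_from_array(decoded.column(0), 0)?;
    assert_eq!(scalar, ScalarValue::Int32(Some(42)));
    Ok(())
}
```

Because the type travels in the IPC schema, list, decimal and timestamp literals no longer need dedicated protobuf messages, which is what allows the large deletions below.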
diff --git a/native-engine/blaze-serde/src/from_proto.rs b/native-engine/blaze-serde/src/from_proto.rs
index afe477ed0..a50866ab3 100644
--- a/native-engine/blaze-serde/src/from_proto.rs
+++ b/native-engine/blaze-serde/src/from_proto.rs
@@ -827,7 +827,7 @@ fn try_parse_physical_expr(
     let pexpr: Arc<dyn PhysicalExpr> = match expr_type {
         ExprType::Column(c) => Arc::new(Column::new(&c.name, input_schema.index_of(&c.name)?)),
-        ExprType::Literal(scalar) => Arc::new(Literal::new(convert_required!(scalar.value)?)),
+        ExprType::Literal(scalar) => Arc::new(Literal::new(scalar.try_into()?)),
         ExprType::BoundReference(bound_reference) => {
             let pcol: Column = bound_reference.into();
             Arc::new(pcol)
@@ -1134,7 +1134,11 @@ pub fn parse_protobuf_partitioning(
             let sort = range_part.sort_expr.clone().unwrap();
             let exprs = try_parse_physical_sort_expr(&input, &sort).unwrap();
 
-            let value_list = &range_part.list_value;
+            let value_list: Vec<ScalarValue> = range_part
+                .list_value
+                .iter()
+                .map(|v| v.try_into())
+                .collect::<Result<Vec<ScalarValue>, _>>()?;
 
             let sort_row_converter = Arc::new(SyncMutex::new(RowConverter::new(
                 exprs
@@ -1151,30 +1155,13 @@ pub fn parse_protobuf_partitioning(
             let bound_cols: Vec<ArrayRef> = value_list
                 .iter()
                 .map(|x| {
-                    let xx = x.clone().value.unwrap();
-                    let values_ref = match xx {
-                        protobuf::scalar_value::Value::ListValue(scalar_list) => {
-                            let protobuf::ScalarListValue {
-                                values,
-                                datatype: _opt_scalar_type,
-                            } = scalar_list;
-                            let value_vec: Vec<ScalarValue> = values
-                                .iter()
-                                .map(|val| val.try_into())
-                                .collect::<Result<Vec<ScalarValue>, _>>()
-                                .map_err(|_| {
-                                    proto_error("partition::from_proto() error")
-                                })?;
-                            ScalarValue::iter_to_array(value_vec)
-                                .map_err(|_| proto_error("partition::from_proto() error"))
-                        }
-                        _ => Err(proto_error(
-                            "partition::from_proto() bound_list type error",
-                        )),
-                    };
-                    values_ref
+                    if let ScalarValue::List(single) = x {
+                        return single.value(0);
+                    } else {
+                        unreachable!("expect list scalar value");
+                    }
                 })
-                .collect::<Result<Vec<_>, _>>()?;
+                .collect::<Vec<_>>();
 
             let bound_rows = sort_row_converter.lock().convert_columns(&bound_cols)?;
             Ok(Some(Partitioning::RangePartitioning(
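For context on the `parse_protobuf_partitioning` change above: each range bound now arrives as an ordinary `ScalarValue::List`, is unwrapped to its underlying array, and the bound arrays are row-encoded so that incoming keys can be compared against them. A hedged, self-contained illustration of that row-comparison idea (the `partition_for_key` helper is hypothetical, not code from this patch):

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array};
use arrow::datatypes::DataType;
use arrow::error::Result;
use arrow::row::{RowConverter, SortField};

/// Returns the index of the first bound that is >= the key,
/// or one past the last bound when the key exceeds all of them.
fn partition_for_key(bounds: &ArrayRef, key: &ArrayRef) -> Result<usize> {
    let mut converter = RowConverter::new(vec![SortField::new(DataType::Int32)])?;
    let bound_rows = converter.convert_columns(&[bounds.clone()])?;
    let key_rows = converter.convert_columns(&[key.clone()])?;
    let key_row = key_rows.row(0);
    Ok((0..bound_rows.num_rows())
        .find(|&i| key_row <= bound_rows.row(i))
        .unwrap_or(bound_rows.num_rows()))
}

fn main() -> Result<()> {
    let bounds: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 30]));
    let key: ArrayRef = Arc::new(Int32Array::from(vec![25]));
    assert_eq!(partition_for_key(&bounds, &key)?, 2);
    Ok(())
}
```

The row encoding makes multi-column bounds comparable with a plain byte comparison, which is why the real code keeps a `RowConverter` built from the sort expressions.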
diff --git a/native-engine/blaze-serde/src/lib.rs b/native-engine/blaze-serde/src/lib.rs
index ed193ebb1..20005c5e8 100644
--- a/native-engine/blaze-serde/src/lib.rs
+++ b/native-engine/blaze-serde/src/lib.rs
@@ -446,256 +446,11 @@ impl TryInto<Schema> for &protobuf::Schema {
 impl TryInto<ScalarValue> for &protobuf::ScalarValue {
     type Error = PlanSerDeError;
     fn try_into(self) -> Result<ScalarValue, Self::Error> {
-        let value = self.value.as_ref().ok_or_else(|| {
-            proto_error("Protobuf deserialization error: missing required field 'value'")
-        })?;
-        Ok(match value {
-            protobuf::scalar_value::Value::BoolValue(v) => ScalarValue::Boolean(Some(*v)),
-            protobuf::scalar_value::Value::Utf8Value(v) => ScalarValue::Utf8(Some(v.to_owned())),
-            protobuf::scalar_value::Value::LargeUtf8Value(v) => {
-                ScalarValue::LargeUtf8(Some(v.to_owned()))
-            }
-            protobuf::scalar_value::Value::Int8Value(v) => ScalarValue::Int8(Some(*v as i8)),
-            protobuf::scalar_value::Value::Int16Value(v) => ScalarValue::Int16(Some(*v as i16)),
-            protobuf::scalar_value::Value::Int32Value(v) => ScalarValue::Int32(Some(*v)),
-            protobuf::scalar_value::Value::Int64Value(v) => ScalarValue::Int64(Some(*v)),
-            protobuf::scalar_value::Value::Uint8Value(v) => ScalarValue::UInt8(Some(*v as u8)),
-            protobuf::scalar_value::Value::Uint16Value(v) => ScalarValue::UInt16(Some(*v as u16)),
-            protobuf::scalar_value::Value::Uint32Value(v) => ScalarValue::UInt32(Some(*v)),
-            protobuf::scalar_value::Value::Uint64Value(v) => ScalarValue::UInt64(Some(*v)),
-            protobuf::scalar_value::Value::Float32Value(v) => ScalarValue::Float32(Some(*v)),
-            protobuf::scalar_value::Value::Float64Value(v) => ScalarValue::Float64(Some(*v)),
-            protobuf::scalar_value::Value::Date32Value(v) => ScalarValue::Date32(Some(*v)),
-            protobuf::scalar_value::Value::TimestampSecondValue(v) => {
-                ScalarValue::TimestampSecond(Some(*v), None)
-            }
-            protobuf::scalar_value::Value::TimestampMillisecondValue(v) => {
-                ScalarValue::TimestampMillisecond(Some(*v), None)
-            }
-            protobuf::scalar_value::Value::TimestampMicrosecondValue(v) => {
-                ScalarValue::TimestampMicrosecond(Some(*v), None)
-            }
-            protobuf::scalar_value::Value::TimestampNanosecondValue(v) => {
-                ScalarValue::TimestampNanosecond(Some(*v), None)
-            }
-            protobuf::scalar_value::Value::DecimalValue(v) => {
-                let decimal = v.decimal.as_ref().unwrap();
-                ScalarValue::Decimal128(
-                    Some(v.long_value as i128),
-                    decimal.whole as u8,
-                    decimal.fractional as i8,
-                )
-            }
-            protobuf::scalar_value::Value::ListValue(scalar_list) => {
-                let protobuf::ScalarListValue {
-                    values,
-                    datatype: opt_scalar_type,
-                } = &scalar_list;
-                let pb_scalar_type = opt_scalar_type
-                    .as_ref()
-                    .ok_or_else(|| proto_error("Protobuf deserialization err: ScalaListValue missing required field 'datatype'"))?;
-                let typechecked_values: Vec<ScalarValue> = values
-                    .iter()
-                    .map(|val| val.try_into())
-                    .collect::<Result<Vec<ScalarValue>, _>>()?;
-                let scalar_type: DataType = pb_scalar_type.try_into()?;
-                ScalarValue::List(ScalarValue::new_list(
-                    &typechecked_values,
-                    &scalar_type,
-                    true,
-                ))
-            }
-            protobuf::scalar_value::Value::NullValue(v) => {
-                match v.datatype.as_ref().expect("missing scalar data type") {
-                    protobuf::scalar_type::Datatype::Scalar(scalar) => {
-                        let null_type_enum = protobuf::PrimitiveScalarType::try_from(*scalar)
-                            .expect("invalid PrimitiveScalarType");
-                        null_type_enum.try_into()?
-                    }
-                    protobuf::scalar_type::Datatype::List(list) => {
-                        let pb_scalar_type = list
-                            .element_type
-                            .as_ref()
-                            .expect("missing list element type");
-                        let scalar_type: DataType = pb_scalar_type.as_ref().try_into()?;
-                        ScalarValue::try_from(DataType::new_list(scalar_type, true))?
-                    }
-                }
-            }
-        })
-    }
-}
-
-impl TryInto<DataType> for &protobuf::ScalarType {
-    type Error = PlanSerDeError;
-    fn try_into(self) -> Result<DataType, Self::Error> {
-        let pb_scalartype = self.datatype.as_ref().expect("missing data type");
-        pb_scalartype.try_into()
-    }
-}
-
-impl TryInto<DataType> for &protobuf::scalar_type::Datatype {
-    type Error = PlanSerDeError;
-    fn try_into(self) -> Result<DataType, Self::Error> {
-        use protobuf::scalar_type::Datatype;
-        Ok(match self {
-            Datatype::Scalar(scalar_type) => {
-                let pb_scalar_enum = protobuf::PrimitiveScalarType::try_from(*scalar_type)
-                    .expect("invalid PrimitiveScalarType");
-                pb_scalar_enum.into()
-            }
-            Datatype::List(list_type) => {
-                let element_scalar_type: DataType = list_type
-                    .element_type
-                    .as_ref()
-                    .expect("missing element type")
-                    .as_ref()
-                    .try_into()?;
-                DataType::new_list(element_scalar_type, true)
-            }
-        })
-    }
-}
-
-impl TryInto<ScalarValue> for &protobuf::scalar_value::Value {
-    type Error = PlanSerDeError;
-    fn try_into(self) -> Result<ScalarValue, Self::Error> {
-        use protobuf::PrimitiveScalarType;
-        let scalar = match self {
-            protobuf::scalar_value::Value::BoolValue(v) => ScalarValue::Boolean(Some(*v)),
-            protobuf::scalar_value::Value::Utf8Value(v) => ScalarValue::Utf8(Some(v.to_owned())),
-            protobuf::scalar_value::Value::LargeUtf8Value(v) => {
-                ScalarValue::LargeUtf8(Some(v.to_owned()))
-            }
-            protobuf::scalar_value::Value::Int8Value(v) => ScalarValue::Int8(Some(*v as i8)),
-            protobuf::scalar_value::Value::Int16Value(v) => ScalarValue::Int16(Some(*v as i16)),
-            protobuf::scalar_value::Value::Int32Value(v) => ScalarValue::Int32(Some(*v)),
-            protobuf::scalar_value::Value::Int64Value(v) => ScalarValue::Int64(Some(*v)),
-            protobuf::scalar_value::Value::Uint8Value(v) => ScalarValue::UInt8(Some(*v as u8)),
-            protobuf::scalar_value::Value::Uint16Value(v) => ScalarValue::UInt16(Some(*v as u16)),
-            protobuf::scalar_value::Value::Uint32Value(v) => ScalarValue::UInt32(Some(*v)),
-            protobuf::scalar_value::Value::Uint64Value(v) => ScalarValue::UInt64(Some(*v)),
-            protobuf::scalar_value::Value::Float32Value(v) => ScalarValue::Float32(Some(*v)),
-            protobuf::scalar_value::Value::Float64Value(v) => ScalarValue::Float64(Some(*v)),
-            protobuf::scalar_value::Value::Date32Value(v) => ScalarValue::Date32(Some(*v)),
-            protobuf::scalar_value::Value::TimestampSecondValue(v) => {
-                ScalarValue::TimestampSecond(Some(*v), None)
-            }
-            protobuf::scalar_value::Value::TimestampMillisecondValue(v) => {
-                ScalarValue::TimestampMillisecond(Some(*v), None)
-            }
-            protobuf::scalar_value::Value::TimestampMicrosecondValue(v) => {
-                ScalarValue::TimestampMicrosecond(Some(*v), None)
-            }
-            protobuf::scalar_value::Value::TimestampNanosecondValue(v) => {
-                ScalarValue::TimestampNanosecond(Some(*v), None)
-            }
-            protobuf::scalar_value::Value::ListValue(v) => v.try_into()?,
-            protobuf::scalar_value::Value::NullValue(v) => {
-                match v.datatype.as_ref().expect("missing null value type") {
-                    protobuf::scalar_type::Datatype::Scalar(scalar) => {
-                        PrimitiveScalarType::try_from(*scalar)
-                            .expect("invalid PrimitiveScalarType")
-                            .try_into()?
-                    }
-                    protobuf::scalar_type::Datatype::List(list) => {
-                        let element_scalar_type: DataType = list
-                            .element_type
-                            .as_ref()
-                            .expect("missing list element type")
-                            .as_ref()
-                            .try_into()?;
-                        ScalarValue::try_from(DataType::new_list(element_scalar_type, true))?
-                    }
-                }
-            }
-            protobuf::scalar_value::Value::DecimalValue(v) => {
-                let decimal = v.decimal.as_ref().unwrap();
-                ScalarValue::Decimal128(
-                    Some(v.long_value as i128),
-                    decimal.whole as u8,
-                    decimal.fractional as i8,
-                )
-            }
-        };
+        let mut ipc_reader =
+            arrow::ipc::reader::StreamReader::try_new(&self.ipc_bytes[..], None)
+                .map_err(|e| proto_error(format!("error deserializing arrow stream: {e}")))?;
+        let batch = ipc_reader.next().expect("missing record batch")?;
+        let scalar = ScalarValue::try_from_array(batch.column(0), 0)?;
         Ok(scalar)
     }
 }
-
-impl TryInto<ScalarValue> for &protobuf::ScalarListValue {
-    type Error = PlanSerDeError;
-    fn try_into(self) -> Result<ScalarValue, Self::Error> {
-        let element_scalar_type: DataType = self
-            .datatype
-            .as_ref()
-            .expect("missing list data type")
-            .try_into()?;
-        let values: Vec<ScalarValue> = self
-            .values
-            .iter()
-            .map(|value| Ok(value.try_into()?))
-            .collect::<Result<_, Self::Error>>()?;
-        Ok(ScalarValue::List(ScalarValue::new_list(
-            &values,
-            &element_scalar_type,
-            true,
-        )))
-    }
-}
-
-impl TryInto<DataType> for &protobuf::ScalarListType {
-    type Error = PlanSerDeError;
-    fn try_into(self) -> Result<DataType, Self::Error> {
-        let element_scalar_type: DataType = self
-            .element_type
-            .as_ref()
-            .expect("missing list element type")
-            .as_ref()
-            .try_into()?;
-        Ok(DataType::new_list(element_scalar_type, true))
-    }
-}
-
-impl TryInto<ScalarValue> for protobuf::PrimitiveScalarType {
-    type Error = PlanSerDeError;
-    fn try_into(self) -> Result<ScalarValue, Self::Error> {
-        Ok(match self {
-            // protobuf::PrimitiveScalarType::Null => {
-            //     return Err(proto_error("Untyped null is an invalid scalar value"))
-            // }
-            protobuf::PrimitiveScalarType::Null => ScalarValue::Null,
-            protobuf::PrimitiveScalarType::Bool => ScalarValue::Boolean(None),
-            protobuf::PrimitiveScalarType::Uint8 => ScalarValue::UInt8(None),
-            protobuf::PrimitiveScalarType::Int8 => ScalarValue::Int8(None),
-            protobuf::PrimitiveScalarType::Uint16 => ScalarValue::UInt16(None),
-            protobuf::PrimitiveScalarType::Int16 => ScalarValue::Int16(None),
-            protobuf::PrimitiveScalarType::Uint32 => ScalarValue::UInt32(None),
-            protobuf::PrimitiveScalarType::Int32 => ScalarValue::Int32(None),
-            protobuf::PrimitiveScalarType::Uint64 => ScalarValue::UInt64(None),
-            protobuf::PrimitiveScalarType::Int64 => ScalarValue::Int64(None),
-            protobuf::PrimitiveScalarType::Float32 => ScalarValue::Float32(None),
-            protobuf::PrimitiveScalarType::Float64 => ScalarValue::Float64(None),
-            protobuf::PrimitiveScalarType::Utf8 => ScalarValue::Utf8(None),
-            protobuf::PrimitiveScalarType::LargeUtf8 => ScalarValue::LargeUtf8(None),
-            protobuf::PrimitiveScalarType::Date32 => ScalarValue::Date32(None),
-            protobuf::PrimitiveScalarType::Decimal128 => ScalarValue::Decimal128(None, 1, 0),
-            protobuf::PrimitiveScalarType::Date64 => ScalarValue::Date64(None),
-            protobuf::PrimitiveScalarType::TimestampSecond => {
-                ScalarValue::TimestampSecond(None, None)
-            }
-            protobuf::PrimitiveScalarType::TimestampMillisecond => {
-                ScalarValue::TimestampMillisecond(None, None)
-            }
-            protobuf::PrimitiveScalarType::TimestampMicrosecond => {
-                ScalarValue::TimestampMicrosecond(None, None)
-            }
-            protobuf::PrimitiveScalarType::TimestampNanosecond => {
-                ScalarValue::TimestampNanosecond(None, None)
-            }
-            protobuf::PrimitiveScalarType::IntervalYearmonth => {
-                ScalarValue::IntervalYearMonth(None)
-            }
-            protobuf::PrimitiveScalarType::IntervalDaytime => ScalarValue::IntervalDayTime(None),
-        })
-    }
-}
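The next hunk makes `GetIndexedFieldExpr` preserve "scalarness": a scalar argument is widened to a one-row array, the usual array kernel runs, and the result is converted back into a scalar. A stripped-down sketch of the pattern (the `apply_kernel_preserving_scalar` helper is hypothetical; it assumes a DataFusion version where `ColumnarValue::into_array` returns a `Result`, as the patched code does):

```rust
use datafusion::common::ScalarValue;
use datafusion::error::Result;
use datafusion::physical_plan::ColumnarValue;

fn apply_kernel_preserving_scalar(value: ColumnarValue) -> Result<ColumnarValue> {
    let was_scalar = matches!(value, ColumnarValue::Scalar(_));
    // a scalar only needs a one-row array for element-wise kernels
    let array = value.into_array(1)?;
    // ... run an arbitrary array kernel here; the identity kernel is used
    // for illustration, so `array` is returned unchanged ...
    if was_scalar {
        return Ok(ColumnarValue::Scalar(ScalarValue::try_from_array(&array, 0)?));
    }
    Ok(ColumnarValue::Array(array))
}
```

Without the wrap-back step, a literal argument would silently become a one-row array whose length no longer matches the batch.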
diff --git a/native-engine/datafusion-ext-exprs/src/get_indexed_field.rs b/native-engine/datafusion-ext-exprs/src/get_indexed_field.rs
index 0a94ec509..c7106e9c1 100644
--- a/native-engine/datafusion-ext-exprs/src/get_indexed_field.rs
+++ b/native-engine/datafusion-ext-exprs/src/get_indexed_field.rs
@@ -79,7 +79,10 @@ impl PhysicalExpr for GetIndexedFieldExpr {
     }
 
     fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
-        let array = self.arg.evaluate(batch)?.into_array(batch.num_rows())?;
+        let array_value = self.arg.evaluate(batch)?;
+        let array_is_scalar = matches!(array_value, ColumnarValue::Scalar(_));
+        let array = array_value.into_array(1)?;
+
         match (array.data_type(), &self.key) {
             (DataType::List(_) | DataType::Struct(_), _) if self.key.is_null() => {
                 let scalar_null: ScalarValue = array.data_type().try_into()?;
@@ -110,13 +113,22 @@ impl PhysicalExpr for GetIndexedFieldExpr {
                     &take_indices_builder.finish(),
                     None,
                 )?;
+                if array_is_scalar {
+                    return Ok(ColumnarValue::Scalar(ScalarValue::try_from_array(
+                        &taken, 0,
+                    )?));
+                }
                 Ok(ColumnarValue::Array(taken))
             }
             (DataType::Struct(_), ScalarValue::Int32(Some(k))) => {
                 let as_struct_array = as_struct_array(&array)?;
-                Ok(ColumnarValue::Array(
-                    as_struct_array.column(*k as usize).clone(),
-                ))
+                let taken = as_struct_array.column(*k as usize).clone();
+                if array_is_scalar {
+                    return Ok(ColumnarValue::Scalar(ScalarValue::try_from_array(
+                        &taken, 0,
+                    )?));
+                }
+                Ok(ColumnarValue::Array(taken))
             }
             (DataType::List(_), key) => df_execution_err!(
                 "get indexed field is only possible on lists with int64 indexes. \
diff --git a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala
index c5bf9a131..bd507aed1 100644
--- a/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala
+++ b/spark-extension-shims-spark3/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala
@@ -521,10 +521,11 @@ class ShimsImpl extends Shims with Logging {
       case agg =>
         convertBloomFilterAgg(agg) match {
           case Some(aggExpr) =>
-            return Some(pb.PhysicalExprNode
-              .newBuilder()
-              .setAggExpr(aggExpr)
-              .build())
+            return Some(
+              pb.PhysicalExprNode
+                .newBuilder()
+                .setAggExpr(aggExpr)
+                .build())
           case None =>
         }
         None
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala
index e421c7d09..7a81f2c50 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala
@@ -26,6 +26,10 @@ import scala.language.postfixOps
 import scala.math.max
 import scala.math.min
 
+import org.apache.arrow.c.CDataDictionaryProvider
+import org.apache.arrow.vector.VectorSchemaRoot
+import org.apache.arrow.vector.dictionary.DictionaryProvider.MapDictionaryProvider
+import org.apache.arrow.vector.ipc.ArrowStreamWriter
 import org.apache.commons.lang3.reflect.FieldUtils
 import org.apache.commons.lang3.reflect.MethodUtils
 
@@ -33,6 +37,7 @@ import com.google.protobuf.ByteString
 import org.apache.spark.SparkEnv
 import org.blaze.{protobuf => pb}
 import org.apache.spark.internal.Logging
+import org.apache.spark.sql.blaze.util.Using
 import org.apache.spark.sql.catalyst.expressions.{Abs, Acos, Add, Alias, And, Asin, Atan, Attribute, AttributeReference, BitwiseAnd, BitwiseOr, BoundReference, CaseWhen, Cast, Ceil, CheckOverflow, Coalesce, Concat, ConcatWs, Contains, Cos, CreateArray, CreateNamedStruct, DayOfMonth, Divide, EndsWith, EqualTo, Exp, Expression, Floor, GetArrayItem, GetJsonObject, GetMapValue, GetStructField, GreaterThan, GreaterThanOrEqual, If, In, InSet, IsNotNull, IsNull, LeafExpression, Length, LessThan, LessThanOrEqual, Like, Literal, Log, Log10, Log2, Lower, MakeDecimal, Md5, Month, Multiply, Murmur3Hash, Not, NullIf, OctetLength, Or, Remainder, Sha2, ShiftLeft, ShiftRight, Signum, Sin, Sqrt, StartsWith, StringRepeat, StringSpace, StringTrim, StringTrimLeft, StringTrimRight, Substring, Subtract, Tan, TruncDate, Unevaluable, UnscaledValue, Upper, XxHash64, Year}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Average, CollectList, CollectSet, Count, DeclarativeAggregate, First, ImperativeAggregate, Max, Min, Sum}
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
@@ -48,10 +53,14 @@ import org.apache.spark.sql.catalyst.util.ArrayData
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.plans.ExistenceJoin
 import org.apache.spark.sql.catalyst.plans.QueryPlan
+import org.apache.spark.sql.catalyst.util.MapData
 import org.apache.spark.sql.execution.blaze.plan.Util
 import org.apache.spark.sql.execution.ExecSubqueryExpression
 import org.apache.spark.sql.execution.InSubqueryExec
 import org.apache.spark.sql.execution.ScalarSubquery
+import org.apache.spark.sql.execution.blaze.arrowio.util.ArrowUtils
+import org.apache.spark.sql.execution.blaze.arrowio.util.ArrowUtils.ROOT_ALLOCATOR
+import org.apache.spark.sql.execution.blaze.arrowio.util.ArrowWriter
 import org.apache.spark.sql.hive.blaze.HiveUDFUtil
 import org.apache.spark.sql.hive.blaze.HiveUDFUtil.getFunctionClassName
 import org.apache.spark.sql.internal.SQLConf
@@ -83,35 +92,6 @@ object NativeConverters extends Logging {
   val udfJsonEnabled: Boolean =
     SparkEnv.get.conf.getBoolean("spark.blaze.udf.UDFJson.enabled", defaultValue = true)
 
-  def convertScalarType(dataType: DataType): pb.ScalarType = {
-    val scalarTypeBuilder = dataType match {
-      case NullType => pb.ScalarType.newBuilder().setScalar(pb.PrimitiveScalarType.NULL)
-      case BooleanType => pb.ScalarType.newBuilder().setScalar(pb.PrimitiveScalarType.BOOL)
-      case ByteType => pb.ScalarType.newBuilder().setScalar(pb.PrimitiveScalarType.INT8)
-      case ShortType => pb.ScalarType.newBuilder().setScalar(pb.PrimitiveScalarType.INT16)
-      case IntegerType => pb.ScalarType.newBuilder().setScalar(pb.PrimitiveScalarType.INT32)
-      case LongType => pb.ScalarType.newBuilder().setScalar(pb.PrimitiveScalarType.INT64)
-      case FloatType => pb.ScalarType.newBuilder().setScalar(pb.PrimitiveScalarType.FLOAT32)
-      case DoubleType => pb.ScalarType.newBuilder().setScalar(pb.PrimitiveScalarType.FLOAT64)
-      case StringType => pb.ScalarType.newBuilder().setScalar(pb.PrimitiveScalarType.UTF8)
-      case DateType => pb.ScalarType.newBuilder().setScalar(pb.PrimitiveScalarType.DATE32)
-      case TimestampType =>
-        pb.ScalarType.newBuilder().setScalar(pb.PrimitiveScalarType.TIMESTAMP_MICROSECOND)
-      case _: DecimalType =>
-        pb.ScalarType.newBuilder().setScalar(pb.PrimitiveScalarType.DECIMAL128)
-      case at: ArrayType =>
-        pb.ScalarType
-          .newBuilder()
-          .setList(
-            pb.ScalarListType
-              .newBuilder()
-              .setElementType(convertScalarType(at.elementType)))
-
-      case _ => throw new NotImplementedError(s"Value conversion not implemented ${dataType}")
-    }
-    scalarTypeBuilder.build()
-  }
-
   def scalarTypeSupported(dataType: DataType): Boolean = {
     dataType match {
       case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType |
@@ -212,50 +192,6 @@ object NativeConverters extends Logging {
     arrowTypeBuilder.build()
   }
 
-  def convertValue(sparkValue: Any, dataType: DataType): pb.ScalarValue = {
-    val scalarValueBuilder = pb.ScalarValue.newBuilder()
-    dataType match {
-      case _ if sparkValue == null => scalarValueBuilder.setNullValue(convertScalarType(dataType))
-      case BooleanType => scalarValueBuilder.setBoolValue(sparkValue.asInstanceOf[Boolean])
-      case ByteType => scalarValueBuilder.setInt8Value(sparkValue.asInstanceOf[Byte])
-      case ShortType => scalarValueBuilder.setInt16Value(sparkValue.asInstanceOf[Short])
-      case IntegerType => scalarValueBuilder.setInt32Value(sparkValue.asInstanceOf[Int])
-      case LongType => scalarValueBuilder.setInt64Value(sparkValue.asInstanceOf[Long])
-      case FloatType => scalarValueBuilder.setFloat32Value(sparkValue.asInstanceOf[Float])
-      case DoubleType => scalarValueBuilder.setFloat64Value(sparkValue.asInstanceOf[Double])
-      case StringType =>
-        scalarValueBuilder.setUtf8Value(if (sparkValue != null) {
-          sparkValue.toString
-        } else {
-          null
-        })
-      case DateType => scalarValueBuilder.setDate32Value(sparkValue.asInstanceOf[Int])
-      case TimestampType =>
-        scalarValueBuilder.setTimestampMicrosecondValue(sparkValue.asInstanceOf[Long])
-      case t: DecimalType =>
-        val decimalValue = sparkValue.asInstanceOf[Decimal]
-        val decimalType = convertDataType(t).getDECIMAL
-        scalarValueBuilder.setDecimalValue(
-          pb.ScalarDecimalValue
-            .newBuilder()
-            .setDecimal(decimalType)
-            .setLongValue(decimalValue.toUnscaledLong))
-
-      case at: ArrayType =>
-        val values =
-          pb.ScalarListValue.newBuilder().setDatatype(convertScalarType(at.elementType))
-        sparkValue
-          .asInstanceOf[ArrayData]
-          .foreach(
-            at.elementType,
-            (_, value) => {
-              values.addValues(convertValue(value, at.elementType))
-            })
-        scalarValueBuilder.setListValue(values)
-    }
-    scalarValueBuilder.build()
-  }
-
   def convertField(sparkField: StructField): pb.Field = {
     pb.Field
       .newBuilder()
@@ -432,26 +368,27 @@ object NativeConverters extends Logging {
     sparkExpr match {
       case e: NativeExprWrapperBase =>
         e.wrapped
-      case Literal(value, dataType) =>
-        buildExprNode { b =>
-          if (value == null) {
-            dataType match {
-              case at: ArrayType =>
-                b.setTryCast(
-                  pb.PhysicalTryCastNode
-                    .newBuilder()
-                    .setArrowType(convertDataType(at))
-                    .setExpr(buildExprNode {
-                      _.setLiteral(
-                        pb.ScalarValue.newBuilder().setNullValue(convertScalarType(NullType)))
-                    }))
-              case _ =>
-                b.setLiteral(
-                  pb.ScalarValue.newBuilder().setNullValue(convertScalarType(dataType)))
-            }
-          } else {
-            b.setLiteral(convertValue(value, dataType))
+      case e: Literal =>
+        val schema = StructType(Seq(StructField("", e.dataType, e.nullable)))
+        val row = InternalRow(e.eval(null))
+        Using.resource(
+          VectorSchemaRoot.create(ArrowUtils.toArrowSchema(schema), ROOT_ALLOCATOR)) { root =>
+          val arrowWriter = ArrowWriter.create(root)
+          arrowWriter.write(row)
+          arrowWriter.finish()
+
+          val dictionaryProvider = new CDataDictionaryProvider()
+          val bo = new ByteArrayOutputStream()
+          Using(new ArrowStreamWriter(root, dictionaryProvider, bo)) { ipcWriter =>
+            ipcWriter.start()
+            ipcWriter.writeBatch()
+            ipcWriter.end()
           }
+          val ipcBytes = bo.toByteArray
+          pb.PhysicalExprNode
+            .newBuilder()
+            .setLiteral(pb.ScalarValue.newBuilder().setIpcBytes(ByteString.copyFrom(ipcBytes)))
+            .build()
         }
 
       case bound: BoundReference =>
@@ -532,7 +469,7 @@ object NativeConverters extends Logging {
             Literal(utf8string, StringType),
             isPruningExpr,
             fallback)
-        case v => convertExprWithFallback(Literal.apply(v), isPruningExpr, fallback)
+        case v => convertExpr(Literal.apply(v))
       }.asJava))
   }
@@ -755,7 +692,7 @@ object NativeConverters extends Logging {
     val resultType = e.dataType
     rhs match {
       case rhs: Literal if rhs == Literal.default(rhs.dataType) =>
-        buildExprNode(_.setLiteral(convertValue(null, e.dataType)))
+        convertExpr(Literal(null, e.dataType))
       case rhs: Literal if rhs != Literal.default(rhs.dataType) =>
         buildBinaryExprNode(lhs, rhs, "Modulo")
       case rhs =>
@@ -911,8 +848,7 @@ object NativeConverters extends Logging {
             .setName("starts_with")
             .setFun(pb.ScalarFunction.StartsWith)
             .addArgs(convertExprWithFallback(expr, isPruningExpr, fallback))
-            .addArgs(
-              convertExprWithFallback(Literal(prefix, StringType), isPruningExpr, fallback))
+            .addArgs(convertExpr(Literal(prefix, StringType)))
             .setReturnType(convertDataType(BooleanType))))
       case StartsWith(expr, Literal(prefix, StringType)) =>
         buildExprNode(
@@ -967,7 +903,7 @@ object NativeConverters extends Logging {
         val children = e.children.map(Cast(_, e.dataType))
         buildScalarFunction(pb.ScalarFunction.Coalesce, children, e.dataType)
 
-      case e@If(predicate, trueValue, falseValue) =>
+      case e @ If(predicate, trueValue, falseValue) =>
         val castedTrueValue = trueValue match {
           case t if t.dataType != e.dataType => Cast(t, e.dataType)
          case t => t
@@ -979,14 +915,15 @@ object NativeConverters extends Logging {
         val caseWhen = CaseWhen(Seq((predicate, castedTrueValue)), castedFalseValue)
         convertExprWithFallback(caseWhen, isPruningExpr, fallback)
 
-      case e@CaseWhen(branches, elseValue) =>
+      case e @ CaseWhen(branches, elseValue) =>
         val caseExpr = pb.PhysicalCaseNode.newBuilder()
         val whenThens = branches.map { case (w, t) =>
           val casted = t match {
             case t if t.dataType != e.dataType => Cast(t, e.dataType)
             case t => t
           }
-          pb.PhysicalWhenThen.newBuilder()
+          pb.PhysicalWhenThen
+            .newBuilder()
             .setWhenExpr(convertExprWithFallback(w, isPruningExpr, fallback))
             .setThenExpr(convertExprWithFallback(casted, isPruningExpr, fallback))
             .build()
@@ -1043,15 +980,14 @@ object NativeConverters extends Logging {
             .asInstanceOf[Literal]
             .value
             .isInstanceOf[Number] =>
+        // NOTE: data-fusion index starts from 1
         val ordinalValue = e.ordinal.asInstanceOf[Literal].value.asInstanceOf[Number]
         buildExprNode {
           _.setGetIndexedFieldExpr(
             pb.PhysicalGetIndexedFieldExprNode
              .newBuilder()
              .setExpr(convertExprWithFallback(e.child, isPruningExpr, fallback))
-              .setKey(convertValue(
-                ordinalValue.longValue() + 1, // NOTE: data-fusion index starts from 1
-                LongType)))
+              .setKey(convertExpr(Literal(ordinalValue.longValue() + 1, LongType)).getLiteral))
        }
 
      case e: GetMapValue if e.key.isInstanceOf[Literal] =>
@@ -1062,7 +998,7 @@ object NativeConverters extends Logging {
            pb.PhysicalGetMapValueExprNode
              .newBuilder()
              .setExpr(convertExprWithFallback(e.child, isPruningExpr, fallback))
-              .setKey(convertValue(value, dataType)))
+              .setKey(convertExpr(Literal(value, dataType)).getLiteral))
        }
 
      case e: GetStructField =>
@@ -1071,7 +1007,7 @@ object NativeConverters extends Logging {
            pb.PhysicalGetIndexedFieldExprNode
              .newBuilder()
              .setExpr(convertExprWithFallback(e.child, isPruningExpr, fallback))
-              .setKey(convertValue(e.ordinal, IntegerType)))
+              .setKey(convertExpr(Literal(e.ordinal, IntegerType)).getLiteral))
        }
 
      case StubExpr("RowNum", _, _) =>
@@ -1342,7 +1278,7 @@ object NativeConverters extends Logging {
   def prepareExecSubquery(subquery: ExecSubqueryExpression): Unit = {
     val isCanonicalized =
       MethodUtils.invokeMethod(subquery.plan, true, "isCanonicalizedPlan").asInstanceOf[Boolean]
-
+
     if (!isCanonicalized) {
       subquery match {
         case e if e.getClass.getName == "org.apache.spark.sql.execution.ScalarSubquery" =>
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/arrowio/util/ArrowWriter.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/arrowio/util/ArrowWriter.scala
index b2f4b770c..7963cb1ac 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/arrowio/util/ArrowWriter.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/arrowio/util/ArrowWriter.scala
@@ -65,11 +65,10 @@ object ArrowWriter {
       val elementVector = createFieldWriter(vector.getDataVector())
       new ArrayWriter(vector, elementVector)
     case (MapType(_, _, _), vector: MapVector) =>
-      val entryWriter = createFieldWriter(vector.getDataVector).asInstanceOf[StructWriter]
-      val keyWriter = createFieldWriter(entryWriter.valueVector.getChild(MapVector.KEY_NAME))
-      val valueWriter = createFieldWriter(
-        entryWriter.valueVector.getChild(MapVector.VALUE_NAME))
-      new MapWriter(vector, keyWriter, valueWriter)
+      val structVector = vector.getDataVector.asInstanceOf[StructVector]
+      val keyWriter = createFieldWriter(structVector.getChild(MapVector.KEY_NAME))
+      val valueWriter = createFieldWriter(structVector.getChild(MapVector.VALUE_NAME))
+      new MapWriter(vector, structVector, keyWriter, valueWriter)
     case (StructType(_), vector: StructVector) =>
       val children = (0 until vector.size()).map { ordinal =>
         createFieldWriter(vector.getChildByOrdinal(ordinal))
@@ -357,6 +356,7 @@ private[sql] class StructWriter(val valueVector: StructVector, children: Array[A
 
 private[sql] class MapWriter(
     val valueVector: MapVector,
+    val structVector: StructVector,
     val keyWriter: ArrowFieldWriter,
     val valueWriter: ArrowFieldWriter)
   extends ArrowFieldWriter {
@@ -370,6 +370,7 @@ private[sql] class MapWriter(
     val values = map.valueArray()
     var i = 0
     while (i < map.numElements()) {
+      structVector.setIndexDefined(keyWriter.count)
       keyWriter.write(keys, i)
       valueWriter.write(values, i)
       i += 1
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeFileSourceScanBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeFileSourceScanBase.scala
index 7c4e0adf0..0a5d43992 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeFileSourceScanBase.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeFileSourceScanBase.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.blaze.NativeHelper
 import org.apache.spark.sql.blaze.NativeSupports
 import org.apache.spark.sql.blaze.Shims
 import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.execution.FileSourceScanExec
 import org.apache.spark.sql.execution.LeafExecNode
@@ -91,9 +92,9 @@ abstract class NativeFileSourceScanBase(basedFileScan: FileSourceScanExec)
     // list input file statuses
     val nativePartitionedFile = (file: PartitionedFile) => {
       val nativePartitionValues = partitionSchema.zipWithIndex.map { case (field, index) =>
-        NativeConverters.convertValue(
-          file.partitionValues.get(index, field.dataType),
-          field.dataType)
+        NativeConverters
+          .convertExpr(Literal(file.partitionValues.get(index, field.dataType), field.dataType))
+          .getLiteral
       }
       pb.PartitionedFile
         .newBuilder()
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeShuffleExchangeBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeShuffleExchangeBase.scala
index ddb7390ef..ff460b5a3 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeShuffleExchangeBase.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeShuffleExchangeBase.scala
@@ -16,7 +16,9 @@
 package org.apache.spark.sql.execution.blaze.plan
 
 import java.util.UUID
+
 import scala.collection.JavaConverters._
+
 import org.apache.spark.{OneToOneDependency, Partitioner, RangePartitioner, ShuffleDependency, SparkEnv, TaskContext}
 import org.blaze.protobuf.{IpcReaderExecNode, PhysicalExprNode, PhysicalHashRepartition, PhysicalPlanNode, PhysicalRangeRepartition, PhysicalRepartition, PhysicalRoundRobinRepartition, PhysicalSingleRepartition, PhysicalSortExprNode, Schema, SortExecNode}
 import org.apache.spark.rdd.{PartitionPruningRDD, RDD}
@@ -34,7 +36,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, BoundReference, NullsFirst, UnsafeProjection}
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter}
-import org.apache.spark.sql.execution.{SQLExecution, SparkPlan, UnsafeRowSerializer}
+import org.apache.spark.sql.execution.{SparkPlan, SQLExecution, UnsafeRowSerializer}
 import org.apache.spark.sql.execution.blaze.shuffle.BlazeBlockStoreShuffleReaderBase
 import org.apache.spark.sql.execution.blaze.shuffle.BlazeShuffleDependency
 import org.apache.spark.util.{CompletionIterator, MutablePair}
@@ -42,12 +44,13 @@ import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
 import org.apache.spark.sql.catalyst.util.ArrayData
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.ArrayType
-
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 import scala.reflect.ClassTag
 import scala.util.hashing.byteswap32
 
+import org.apache.spark.sql.catalyst.expressions.Literal
+
 abstract class NativeShuffleExchangeBase(
     override val outputPartitioning: Partitioning,
     override val child: SparkPlan)
@@ -231,7 +234,7 @@ abstract class NativeShuffleExchangeBase(
             internal_row.get(index, field.dataType)
           }
           val arrayData = ArrayData.toArrayData(valueList)
-          NativeConverters.convertValue(arrayData, ArrayType(field.dataType))
+          NativeConverters.convertExpr(Literal(arrayData, ArrayType(field.dataType))).getLiteral
         }.toList
 
       case _ => null
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/hive/execution/blaze/plan/NativeHiveTableScanBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/hive/execution/blaze/plan/NativeHiveTableScanBase.scala
index af0498e06..44eb9e9ad 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/hive/execution/blaze/plan/NativeHiveTableScanBase.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/hive/execution/blaze/plan/NativeHiveTableScanBase.scala
@@ -39,10 +39,12 @@ import org.apache.spark.sql.types.StructField
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.SerializableConfiguration
 import org.blaze.{protobuf => pb}
-
 import java.net.URI
 import java.security.PrivilegedExceptionAction
 
+import org.apache.spark.sql.catalyst.expressions.Literal
+import org.blaze.sparkver
+
 abstract class NativeHiveTableScanBase(basedHiveScan: HiveTableScanExec)
     extends LeafExecNode
     with NativeSupports {
@@ -81,9 +83,9 @@ abstract class NativeHiveTableScanBase(basedHiveScan: HiveTableScanExec)
     // list input file statuses
     val nativePartitionedFile = (file: PartitionedFile) => {
       val nativePartitionValues = partitionSchema.zipWithIndex.map { case (field, index) =>
-        NativeConverters.convertValue(
-          file.partitionValues.get(index, field.dataType),
-          field.dataType)
+        NativeConverters
+          .convertExpr(Literal(file.partitionValues.get(index, field.dataType), field.dataType))
+          .getLiteral
      }
      pb.PartitionedFile
        .newBuilder()
@@ -140,6 +142,7 @@ abstract class NativeHiveTableScanBase(basedHiveScan: HiveTableScanExec)
 
   override protected def doCanonicalize(): SparkPlan = basedHiveScan.canonicalized
 
+  @sparkver("3.0 / 3.1 / 3.2 / 3.3 / 3.4 / 3.5")
   override def simpleString(maxFields: Int): String =
     s"$nodeName (${basedHiveScan.simpleString(maxFields)})"
 }
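One consequence of the IPC-based literal encoding that the Scala and Rust changes above rely on: typed nulls need no special casing anymore, because the IPC schema carries the type and the single row is simply null. A hedged sketch (again only assuming the `arrow`/`datafusion` crates; not code from this patch):

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::ipc::reader::StreamReader;
use arrow::ipc::writer::StreamWriter;
use arrow::record_batch::RecordBatch;
use datafusion::common::ScalarValue;

fn main() -> datafusion::error::Result<()> {
    // a null Int32 literal: the value is null, the type lives in the schema
    let schema = Arc::new(Schema::new(vec![Field::new("", DataType::Int32, true)]));
    let column: ArrayRef = Arc::new(Int32Array::from(vec![None::<i32>]));
    let batch = RecordBatch::try_new(schema.clone(), vec![column])?;

    let mut buf = Vec::new();
    {
        let mut writer = StreamWriter::try_new(&mut buf, &schema)?;
        writer.write(&batch)?;
        writer.finish()?;
    }

    let mut reader = StreamReader::try_new(&buf[..], None)?;
    let decoded = reader.next().expect("missing record batch")?;
    let scalar = ScalarValue::try_from_array(decoded.column(0), 0)?;
    assert_eq!(scalar, ScalarValue::Int32(None));
    Ok(())
}
```

This is why the old `NullValue`/`ScalarType` machinery and the Scala `convertScalarType` helper could be deleted outright.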