Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

24 changes: 9 additions & 15 deletions native-engine/datafusion-ext-commons/src/io/scalar_serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,7 @@ use std::{
sync::Arc,
};

use arrow::{
array::{AsArray, StructArray},
datatypes::*,
};
use arrow::{array::AsArray, datatypes::*};
use datafusion::{common::Result, parquet::data_type::AsBytes, scalar::ScalarValue};

use crate::{
Expand Down Expand Up @@ -89,9 +86,7 @@ pub fn write_scalar<W: Write>(value: &ScalarValue, nullable: bool, output: &mut
write_array(v.as_ref(), output, &mut TransposeOpt::Disabled)?;
}
ScalarValue::Struct(v) => {
for col in v.columns() {
write_array(col, output, &mut TransposeOpt::Disabled)?;
}
write_array(v.as_ref(), output, &mut TransposeOpt::Disabled)?;
}
ScalarValue::Map(v) => {
write_array(v.as_ref(), output, &mut TransposeOpt::Disabled)?;
Expand Down Expand Up @@ -184,15 +179,14 @@ pub fn read_scalar<R: Read>(
.clone();
ScalarValue::List(Arc::new(list))
}
DataType::Struct(fields) => {
let columns = fields
.iter()
.map(|field| read_array(input, field.data_type(), 1, &mut TransposeOpt::Disabled))
.collect::<Result<Vec<_>>>()?;
ScalarValue::Struct(Arc::new(StructArray::new(fields.clone(), columns, None)))
DataType::Struct(_) => {
let struct_ = read_array(input, data_type, 1, &mut TransposeOpt::Disabled)?
.as_struct()
.clone();
ScalarValue::Struct(Arc::new(struct_))
}
DataType::Map(field, _bool) => {
let map = read_array(input, field.data_type(), 1, &mut TransposeOpt::Disabled)?
DataType::Map(_, _bool) => {
let map = read_array(input, data_type, 1, &mut TransposeOpt::Disabled)?
.as_map()
.clone();
ScalarValue::Map(Arc::new(map))
Expand Down
51 changes: 16 additions & 35 deletions native-engine/datafusion-ext-exprs/src/get_indexed_field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,14 @@ impl PhysicalExpr for GetIndexedFieldExpr {

fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
let data_type = self.arg.data_type(input_schema)?;
let field = get_indexed_field(input_schema, &self.arg, &data_type, &self.key)?;
let field = get_indexed_field(&data_type, &self.key)?;
Ok(field.data_type().clone())
}

fn nullable(&self, input_schema: &Schema) -> Result<bool> {
let data_type = self.arg.data_type(input_schema)?;
let nullable = self.arg.nullable(input_schema)?;
let field = get_indexed_field(input_schema, &self.arg, &data_type, &self.key)?;
let field = get_indexed_field(&data_type, &self.key)?;
Ok(nullable || field.is_nullable())
}

Expand Down Expand Up @@ -133,18 +133,10 @@ impl PhysicalExpr for GetIndexedFieldExpr {
}
Ok(ColumnarValue::Array(taken))
}
(DataType::List(_), key) => df_execution_err!(
"get indexed field is only possible on lists with int64 indexes. \
Tried with {key:?} index"
),
(DataType::Struct(_), key) => df_execution_err!(
"get indexed field is only possible on struct with int32 indexes. \
Tried with {key:?} index"
),
(dt, key) => df_execution_err!(
"get indexed field is only possible on lists with int64 indexes or struct \
with utf8 indexes. Tried {dt:?} with {key:?} index"
),
(dt, key) => {
let key_dt = key.data_type();
df_execution_err!("unsupported data types for GetIndexedField: ({dt}, {key_dt})")
}
}
}

Expand Down Expand Up @@ -177,35 +169,24 @@ impl PartialEq<dyn Any> for GetIndexedFieldExpr {
}
}

fn get_indexed_field(
input_schema: &Schema,
arg: &Arc<dyn PhysicalExpr>,
data_type: &DataType,
key: &ScalarValue,
) -> Result<Field> {
fn get_indexed_field(data_type: &DataType, key: &ScalarValue) -> Result<Arc<Field>> {
match (data_type, key) {
(DataType::List(lt), ScalarValue::Int64(Some(i))) => {
Ok(Field::new(i.to_string(), lt.data_type().clone(), true))
}
(DataType::List(lt), ScalarValue::Int64(Some(i))) => Ok(Arc::new(Field::new(
i.to_string(),
lt.data_type().clone(),
true,
))),
(DataType::Struct(fields), ScalarValue::Int32(Some(k))) => {
let field: Option<&Arc<Field>> = fields.get(*k as usize);
match field {
None => df_execution_err!("Field {k} not found in struct"),
Some(f) => Ok(f
.as_ref()
.clone()
.with_nullable(arg.nullable(input_schema)?)),
Some(f) => Ok(f.clone()),
}
}
(DataType::Struct(_), _) => {
df_execution_err!("Only ints are valid as an indexed field in a struct",)
}
(DataType::List(_), _) => {
df_execution_err!("Only ints are valid as an indexed field in a list",)
(dt, key) => {
let key_dt = key.data_type();
df_execution_err!("unsupported data types for GetIndexedField: ({dt}, {key_dt})")
}
_ => df_execution_err!(
"The expression to get an indexed field is only valid for List or Struct types",
),
}
}

Expand Down
15 changes: 9 additions & 6 deletions native-engine/datafusion-ext-plans/src/agg/acc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -574,22 +574,25 @@ impl AccColumn for AccBytesColumn {
pub struct AccScalarValueColumn {
items: Vec<ScalarValue>,
dt: DataType,
null_value: ScalarValue,
heap_mem_used: usize,
}

impl AccScalarValueColumn {
pub fn new(dt: &DataType, num_rows: usize) -> Self {
let null_value = ScalarValue::try_from(dt).expect("unsupported data type");
Self {
items: vec![ScalarValue::Null; num_rows],
items: (0..num_rows).map(|_| null_value.clone()).collect(),
dt: dt.clone(),
null_value,
heap_mem_used: 0,
}
}

pub fn to_array(&mut self, _dt: &DataType, idx: IdxSelection<'_>) -> Result<ArrayRef> {
idx_with_iter!((idx @ idx) => {
ScalarValue::iter_to_array(idx.map(|i| {
std::mem::replace(&mut self.items[i], ScalarValue::Null)
std::mem::replace(&mut self.items[i], self.null_value.clone())
}))
})
}
Expand All @@ -600,7 +603,7 @@ impl AccScalarValueColumn {

pub fn take_value(&mut self, idx: usize) -> ScalarValue {
self.heap_mem_used -= scalar_value_heap_mem_size(&self.items[idx]);
std::mem::replace(&mut self.items[idx], ScalarValue::Null)
std::mem::replace(&mut self.items[idx], self.null_value.clone())
}

pub fn set_value(&mut self, idx: usize, value: ScalarValue) {
Expand All @@ -621,7 +624,7 @@ impl AccColumn for AccScalarValueColumn {

fn resize(&mut self, len: usize) {
if len > self.items.len() {
self.items.resize(len, ScalarValue::Null);
self.items.resize_with(len, || self.null_value.clone());
} else {
for idx in len..self.items.len() {
self.heap_mem_used -= scalar_value_heap_mem_size(&self.items[idx]);
Expand Down Expand Up @@ -652,7 +655,7 @@ impl AccColumn for AccScalarValueColumn {
}

fn unfreeze_from_rows(&mut self, cursors: &mut [Cursor<&[u8]>]) -> Result<()> {
self.items.resize(0, ScalarValue::Null);
self.items.truncate(0);
self.heap_mem_used = 0;

for cursor in cursors {
Expand All @@ -673,7 +676,7 @@ impl AccColumn for AccScalarValueColumn {
}

fn unspill(&mut self, num_rows: usize, r: &mut SpillCompressedReader) -> Result<()> {
self.items.resize(0, ScalarValue::Null);
self.items.truncate(0);
self.heap_mem_used = 0;

for _ in 0..num_rows {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1093,10 +1093,10 @@ object NativeConverters extends Logging {
})
aggBuilder.addChildren(convertExpr(child))

case CollectList(child, _, _) if child.dataType.isInstanceOf[AtomicType] =>
case CollectList(child, _, _) =>
aggBuilder.setAggFunction(pb.AggFunction.COLLECT_LIST)
aggBuilder.addChildren(convertExpr(child))
case CollectSet(child, _, _) if child.dataType.isInstanceOf[AtomicType] =>
case CollectSet(child, _, _) =>
aggBuilder.setAggFunction(pb.AggFunction.COLLECT_SET)
aggBuilder.addChildren(convertExpr(child))

Expand Down