Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions native-engine/datafusion-ext-plans/src/agg/acc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,18 @@ pub trait AccColumn: Send {
fn unfreeze_from_rows(&mut self, array: &[&[u8]], offsets: &mut [usize]) -> Result<()>;
fn spill(&self, idx: IdxSelection<'_>, w: &mut SpillCompressedWriter) -> Result<()>;
fn unspill(&mut self, num_rows: usize, r: &mut SpillCompressedReader) -> Result<()>;

fn ensure_size(&mut self, idx: IdxSelection<'_>) {
let idx_max_value = match idx {
IdxSelection::Single(v) => v,
IdxSelection::Indices(v) => v.iter().copied().max().unwrap_or(0),
IdxSelection::IndicesU32(v) => v.iter().copied().max().unwrap_or(0) as usize,
IdxSelection::Range(_begin, end) => end,
};
if idx_max_value >= self.num_records() {
self.resize(idx_max_value + 1);
}
}
}

pub type AccColumnRef = Box<dyn AccColumn>;
Expand Down
6 changes: 2 additions & 4 deletions native-engine/datafusion-ext-plans/src/agg/bloom_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,10 @@ impl Agg for AggBloomFilter {
partial_arg_idx: IdxSelection<'_>,
) -> Result<()> {
let accs = downcast_any!(accs, mut AccBloomFilterColumn).unwrap();
accs.ensure_size(acc_idx);

let bloom_filter = match acc_idx {
IdxSelection::Single(idx) => {
if idx >= accs.num_records() {
accs.resize(idx + 1);
}

let bf = &mut accs.bloom_filters[idx];
if bf.is_none() {
*bf = Some(SparkBloomFilter::new_with_expected_num_items(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,11 @@ impl Agg for AggCollect {
partial_args: &[ArrayRef],
partial_arg_idx: IdxSelection<'_>,
) -> Result<()> {
accs.ensure_size(acc_idx);
let list = partial_args[0].as_list::<i32>();

idx_for_zipped! {
((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
if acc_idx >= accs.num_records() {
accs.resize(acc_idx + 1);
}

if list.is_valid(partial_arg_idx) {
let values = list.value(partial_arg_idx);
let values_len = values.len();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,11 @@ impl Agg for AggCombineUnique {
partial_args: &[ArrayRef],
partial_arg_idx: IdxSelection<'_>,
) -> Result<()> {
accs.ensure_size(acc_idx);
let list = partial_args[0].as_list::<i32>();

idx_for_zipped! {
((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
if acc_idx >= accs.num_records() {
accs.resize(acc_idx + 1);
}

if list.is_valid(partial_arg_idx) {
let values = list.value(partial_arg_idx);
let values_len = values.len();
Expand Down
10 changes: 4 additions & 6 deletions native-engine/datafusion-ext-plans/src/agg/collect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,10 @@ impl<C: AccCollectionColumn> Agg for AggGenericCollect<C> {
partial_arg_idx: IdxSelection<'_>,
) -> Result<()> {
let accs = downcast_any!(accs, mut C).unwrap();
accs.ensure_size(acc_idx);

idx_for_zipped! {
((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
if acc_idx >= accs.num_records() {
accs.resize(acc_idx + 1);
}
let scalar = ScalarValue::try_from_array(&partial_args[0], partial_arg_idx)?;
if !scalar.is_null() {
accs.append_item(acc_idx, &scalar);
Expand All @@ -138,12 +137,11 @@ impl<C: AccCollectionColumn> Agg for AggGenericCollect<C> {
merging_acc_idx: IdxSelection<'_>,
) -> Result<()> {
let accs = downcast_any!(accs, mut C).unwrap();
accs.ensure_size(acc_idx);

let merging_accs = downcast_any!(merging_accs, mut C).unwrap();
idx_for_zipped! {
((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => {
if acc_idx >= accs.num_records() {
accs.resize(acc_idx + 1);
}
accs.merge_items(acc_idx, merging_accs, merging_acc_idx);
}
}
Expand Down
25 changes: 4 additions & 21 deletions native-engine/datafusion-ext-plans/src/agg/first.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ impl Agg for AggFirst {
) -> Result<()> {
let partial_arg = &partial_args[0];
let accs = downcast_any!(accs, mut AccFirstColumn).unwrap();
accs.ensure_size(acc_idx);

let old_heap_mem_used = accs.values.items_heap_mem_used(acc_idx);

macro_rules! handle_bytes {
Expand All @@ -101,9 +103,6 @@ impl Agg for AggFirst {
let partial_arg = downcast_any!(partial_arg, TArray).unwrap();
idx_for_zipped! {
((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
if acc_idx >= accs.num_records() {
accs.resize(acc_idx + 1);
}
if !accs.flags.prim_valid(acc_idx) {
accs.flags.set_prim_valid(acc_idx, true);
if partial_arg.is_valid(partial_arg_idx) {
Expand All @@ -121,9 +120,6 @@ impl Agg for AggFirst {
partial_arg => {
idx_for_zipped! {
((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
if acc_idx >= accs.num_records() {
accs.resize(acc_idx + 1);
}
if !accs.flags.prim_valid(acc_idx) {
accs.flags.set_prim_valid(acc_idx, true);
accs.values.set_prim_valid(acc_idx, partial_arg.is_valid(partial_arg_idx));
Expand All @@ -137,9 +133,6 @@ impl Agg for AggFirst {
_other => {
idx_for_zipped! {
((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
if acc_idx >= accs.num_records() {
accs.resize(acc_idx + 1);
}
if accs.flags.prim_valid(acc_idx) {
accs.flags.set_prim_valid(acc_idx, true);
accs.values.scalar_values_mut()[acc_idx] = ScalarValue::try_from_array(partial_arg, partial_arg_idx)?;
Expand All @@ -164,6 +157,8 @@ impl Agg for AggFirst {
) -> Result<()> {
let accs = downcast_any!(accs, mut AccFirstColumn).unwrap();
let merging_accs = downcast_any!(merging_accs, mut AccFirstColumn).unwrap();
accs.ensure_size(acc_idx);

let old_heap_mem_used = accs.values.items_heap_mem_used(acc_idx);

// safety: bypass borrow checker
Expand All @@ -184,10 +179,6 @@ impl Agg for AggFirst {
) => {
idx_for_zipped! {
((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => {
if acc_idx >= accs.num_records() {
accs.resize(acc_idx + 1);
}

if !accs.flags.prim_valid(acc_idx) && merging_accs.flags.prim_valid(merging_acc_idx) {
let acc_offset = *prim_size * acc_idx;
let merging_acc_offset = *prim_size * merging_acc_idx;
Expand All @@ -207,10 +198,6 @@ impl Agg for AggFirst {
) => {
idx_for_zipped! {
((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => {
if acc_idx >= accs.num_records() {
accs.resize(acc_idx + 1);
}

if !accs.flags.prim_valid(acc_idx) && merging_accs.flags.prim_valid(merging_acc_idx) {
let item = &mut items[acc_idx];
let mut other_item = &mut other_items[merging_acc_idx];
Expand All @@ -228,10 +215,6 @@ impl Agg for AggFirst {
) => {
idx_for_zipped! {
((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => {
if acc_idx >= accs.num_records() {
accs.resize(acc_idx + 1);
}

if !accs.flags.prim_valid(acc_idx) && merging_accs.flags.prim_valid(merging_acc_idx) {
let item = & mut items[acc_idx];
let mut other_item = & mut other_items[merging_acc_idx];
Expand Down
25 changes: 4 additions & 21 deletions native-engine/datafusion-ext-plans/src/agg/first_ignores_null.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ impl Agg for AggFirstIgnoresNull {
) -> Result<()> {
let partial_arg = &partial_args[0];
let accs = downcast_any!(accs, mut AccGenericColumn).unwrap();
accs.ensure_size(acc_idx);

let old_heap_mem_used = accs.items_heap_mem_used(acc_idx);

macro_rules! handle_bytes {
Expand All @@ -97,9 +99,6 @@ impl Agg for AggFirstIgnoresNull {
let partial_arg = downcast_any!(partial_arg, TArray).unwrap();
idx_for_zipped! {
((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
if acc_idx >= accs.num_records() {
accs.resize(acc_idx + 1);
}
if accs.bytes_value(acc_idx).is_none() && partial_arg.is_valid(partial_arg_idx) {
accs.set_bytes_value(acc_idx, Some(AccBytes::from(partial_arg.value(partial_arg_idx).as_ref())));
}
Expand All @@ -112,9 +111,6 @@ impl Agg for AggFirstIgnoresNull {
partial_arg => {
idx_for_zipped! {
((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
if acc_idx >= accs.num_records() {
accs.resize(acc_idx + 1);
}
if !accs.prim_valid(acc_idx) && partial_arg.is_valid(partial_arg_idx) {
accs.set_prim_valid(acc_idx, true);
accs.set_prim_value(acc_idx, partial_arg.value(partial_arg_idx));
Expand All @@ -127,9 +123,6 @@ impl Agg for AggFirstIgnoresNull {
_other => {
idx_for_zipped! {
((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
if acc_idx >= accs.num_records() {
accs.resize(acc_idx + 1);
}
if accs.scalar_values()[acc_idx].is_null() && partial_arg.is_valid(partial_arg_idx) {
accs.scalar_values_mut()[acc_idx] = ScalarValue::try_from_array(partial_arg, partial_arg_idx)?;
}
Expand All @@ -151,6 +144,8 @@ impl Agg for AggFirstIgnoresNull {
merging_acc_idx: IdxSelection<'_>,
) -> Result<()> {
let accs = downcast_any!(accs, mut AccGenericColumn).unwrap();
accs.ensure_size(acc_idx);

let mut merging_accs = downcast_any!(merging_accs, mut AccGenericColumn).unwrap();
let old_heap_mem_used = accs.items_heap_mem_used(acc_idx);

Expand All @@ -173,10 +168,6 @@ impl Agg for AggFirstIgnoresNull {
) => {
idx_for_zipped! {
((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => {
if acc_idx >= accs.num_records() {
accs.resize(acc_idx + 1);
}

if !valids[acc_idx] && other_valids[merging_acc_idx] {
valids.set(acc_idx, true);
let acc_offset = *prim_size * acc_idx;
Expand All @@ -195,10 +186,6 @@ impl Agg for AggFirstIgnoresNull {
) => {
idx_for_zipped! {
((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => {
if acc_idx >= accs.num_records() {
accs.resize(acc_idx + 1);
}

let item = &mut items[acc_idx];
let mut other_item = &mut other_items[merging_acc_idx];
if item.is_none() && other_item.is_some() {
Expand All @@ -215,10 +202,6 @@ impl Agg for AggFirstIgnoresNull {
) => {
idx_for_zipped! {
((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => {
if acc_idx >= accs.num_records() {
accs.resize(acc_idx + 1);
}

let item = &mut items[acc_idx];
let mut other_item = &mut other_items[merging_acc_idx];
if item.is_null() && !other_item.is_null() {
Expand Down
16 changes: 4 additions & 12 deletions native-engine/datafusion-ext-plans/src/agg/maxmin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ impl<P: AggMaxMinParams> Agg for AggMaxMin<P> {
partial_arg_idx: IdxSelection<'_>,
) -> Result<()> {
let accs = downcast_any!(accs, mut AccGenericColumn).unwrap();
accs.ensure_size(acc_idx);

let old_heap_mem_used = accs.items_heap_mem_used(acc_idx);

macro_rules! handle_prim {
Expand All @@ -103,9 +105,6 @@ impl<P: AggMaxMinParams> Agg for AggMaxMin<P> {
let partial_arg = downcast_any!(&partial_args[0], TArray).unwrap();
idx_for_zipped! {
((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
if acc_idx >= accs.num_records() {
accs.resize(acc_idx + 1);
}
if !partial_arg.is_valid(partial_arg_idx) {
continue;
}
Expand All @@ -130,9 +129,6 @@ impl<P: AggMaxMinParams> Agg for AggMaxMin<P> {
let partial_arg = downcast_any!(&partial_args[0], TArray).unwrap();
idx_for_zipped! {
((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
if acc_idx >= accs.num_records() {
accs.resize(acc_idx + 1);
}
if !partial_arg.is_valid(partial_arg_idx) {
continue;
}
Expand Down Expand Up @@ -177,9 +173,6 @@ impl<P: AggMaxMinParams> Agg for AggMaxMin<P> {
_ => {
idx_for_zipped! {
((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
if acc_idx >= accs.num_records() {
accs.resize(acc_idx + 1);
}
let partial_arg_scalar = ScalarValue::try_from_array(&partial_args[0], partial_arg_idx)?;
if !partial_arg_scalar.is_null() {
let acc_scalar = &mut accs.scalar_values_mut()[acc_idx];
Expand Down Expand Up @@ -210,16 +203,15 @@ impl<P: AggMaxMinParams> Agg for AggMaxMin<P> {
merging_acc_idx: IdxSelection<'_>,
) -> Result<()> {
let accs = downcast_any!(accs, mut AccGenericColumn).unwrap();
accs.ensure_size(acc_idx);

let merging_accs = downcast_any!(merging_accs, mut AccGenericColumn).unwrap();
let old_mem_used = accs.items_heap_mem_used(acc_idx);

macro_rules! handle_prim {
($ty:ty) => {{
idx_for_zipped! {
((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => {
if acc_idx >= accs.num_records() {
accs.resize(acc_idx + 1);
}
if !merging_accs.prim_valid(merging_acc_idx) {
continue;
}
Expand Down
8 changes: 2 additions & 6 deletions native-engine/datafusion-ext-plans/src/agg/sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ impl Agg for AggSum {
partial_arg_idx: IdxSelection<'_>,
) -> Result<()> {
let accs = downcast_any!(accs, mut AccGenericColumn).unwrap();
accs.ensure_size(acc_idx);

macro_rules! handle {
($ty:ident) => {{
Expand All @@ -102,9 +103,6 @@ impl Agg for AggSum {
let partial_arg = downcast_any!(&partial_args[0], TArray).unwrap();
idx_for_zipped! {
((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
if acc_idx >= accs.num_records() {
accs.resize(acc_idx + 1);
}
if partial_arg.is_valid(partial_arg_idx) {
let partial_value = partial_arg.value(partial_arg_idx);
if !accs.prim_valid(acc_idx) {
Expand Down Expand Up @@ -145,14 +143,12 @@ impl Agg for AggSum {
) -> Result<()> {
let accs = downcast_any!(accs, mut AccGenericColumn).unwrap();
let merging_accs = downcast_any!(merging_accs, mut AccGenericColumn).unwrap();
accs.ensure_size(acc_idx);

macro_rules! handle {
($ty:ty) => {{
idx_for_zipped! {
((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => {
if acc_idx >= accs.num_records() {
accs.resize(acc_idx + 1);
}
if merging_accs.prim_valid(merging_acc_idx) {
let merging_value = merging_accs.prim_value::<$ty>(merging_acc_idx);
if !accs.prim_valid(acc_idx) {
Expand Down
Loading