Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ fn get_array_data_mem_size(array_data: &ArrayData) -> usize {
mem_size += size_of::<Option<Buffer>>();
mem_size += array_data
.nulls()
.map(|nb| nb.buffer().len())
.map(|nb| nb.buffer().len().max(nb.buffer().capacity()))
.unwrap_or_default();

// summing child data size
Expand Down
7 changes: 4 additions & 3 deletions native-engine/datafusion-ext-plans/src/joins/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,17 @@ use arrow::{
datatypes::{DataType, SchemaRef},
};
use datafusion::{common::Result, physical_expr::PhysicalExprRef};
use stream_cursor::StreamCursor;

use crate::joins::{join_utils::JoinType, stream_cursor::StreamCursor};
use crate::joins::join_utils::JoinType;

pub mod join_hash_map;
pub mod join_utils;
pub mod stream_cursor;

// join implementations
pub mod bhj;
pub mod join_hash_map;
pub mod smj;
pub mod stream_cursor;
mod test;

#[derive(Debug, Clone)]
Expand Down
108 changes: 57 additions & 51 deletions native-engine/datafusion-ext-plans/src/joins/smj/existence_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@ use std::{cmp::Ordering, pin::Pin, sync::Arc};
use arrow::array::{ArrayRef, RecordBatch, RecordBatchOptions};
use async_trait::async_trait;
use datafusion::common::Result;
use datafusion_ext_commons::{
arrow::selection::create_batch_interleaver, suggested_batch_mem_size,
};
use datafusion_ext_commons::arrow::selection::create_batch_interleaver;

use crate::{
common::execution_context::WrappedRecordBatchSender,
Expand Down Expand Up @@ -47,21 +45,8 @@ impl ExistenceJoiner {
}
}

fn should_flush(&self, curs: &StreamCursors) -> bool {
if self.indices.len() >= self.join_params.batch_size {
return true;
}

if curs.0.num_buffered_batches() + curs.1.num_buffered_batches() >= 5
|| curs.0.mem_size() + curs.1.mem_size() > suggested_batch_mem_size()
{
if let Some(first_idx) = self.indices.first() {
if first_idx.0 < curs.0.cur_idx.0 {
return true;
}
}
}
false
fn should_flush(&self) -> bool {
self.indices.len() >= self.join_params.batch_size
}

async fn flush(mut self: Pin<&mut Self>, curs: &mut StreamCursors) -> Result<()> {
Expand Down Expand Up @@ -91,53 +76,75 @@ impl ExistenceJoiner {
impl Joiner for ExistenceJoiner {
async fn join(mut self: Pin<&mut Self>, curs: &mut StreamCursors) -> Result<()> {
while !curs.0.finished && !curs.1.finished {
let mut lidx = curs.0.cur_idx;
let mut ridx = curs.1.cur_idx;
if self.should_flush()
|| curs.0.num_buffered_batches() > 1
|| curs.1.num_buffered_batches() > 1
{
self.as_mut().flush(curs).await?;
curs.0.clean_out_dated_batches();
curs.1.clean_out_dated_batches();
}

match compare_cursor!(curs) {
Ordering::Less => {
self.indices.push(curs.0.cur_idx);
self.exists.push(false);
cur_forward!(curs.0);
if self.should_flush(curs) {
self.as_mut().flush(curs).await?;
}
curs.0
.set_min_reserved_idx(*self.indices.first().unwrap_or(&curs.0.cur_idx));
}
Ordering::Greater => {
cur_forward!(curs.1);
curs.1
.set_min_reserved_idx(*self.indices.first().unwrap_or(&curs.1.cur_idx));
}
Ordering::Equal => {
loop {
self.indices.push(lidx);
self.exists.push(true);
cur_forward!(curs.0);
if self.should_flush(curs) {
self.as_mut().flush(curs).await?;
}
curs.0
.set_min_reserved_idx(*self.indices.first().unwrap_or(&lidx));
let l_key_idx = curs.0.cur_idx;
let r_key_idx = curs.1.cur_idx;

self.indices.push(curs.0.cur_idx);
self.exists.push(true);
cur_forward!(curs.0);
cur_forward!(curs.1);

if !curs.0.finished && curs.0.key(curs.0.cur_idx) == curs.0.key(lidx) {
lidx = curs.0.cur_idx;
continue;
// iterate both stream, find smaller one, use it for probing
let mut l_equal = true;
let mut r_equal = true;
while l_equal && r_equal {
if l_equal {
l_equal = !curs.0.finished && curs.0.cur_key() == curs.0.key(l_key_idx);
if l_equal {
self.indices.push(curs.0.cur_idx);
self.exists.push(true);
cur_forward!(curs.0);
}
}
if r_equal {
r_equal = !curs.1.finished && curs.1.cur_key() == curs.1.key(r_key_idx);
if r_equal {
cur_forward!(curs.1);
}
}
break;
}

// skip all right equal rows
loop {
cur_forward!(curs.1);
curs.1.set_min_reserved_idx(ridx);
if l_equal {
// stream left side
while !curs.0.finished && curs.0.cur_key() == curs.1.key(r_key_idx) {
self.indices.push(curs.0.cur_idx);
self.exists.push(true);
cur_forward!(curs.0);
if self.should_flush() || curs.0.num_buffered_batches() > 1 {
self.as_mut().flush(curs).await?;
curs.0.clean_out_dated_batches();
}
}
}

if !curs.1.finished && curs.1.key(curs.1.cur_idx) == curs.1.key(ridx) {
ridx = curs.1.cur_idx;
continue;
if r_equal {
// stream right side
while !curs.1.finished && curs.1.cur_key() == curs.0.key(l_key_idx) {
cur_forward!(curs.1);
if self.should_flush() || curs.1.num_buffered_batches() > 1 {
self.as_mut().flush(curs).await?;
curs.1.clean_out_dated_batches();
}
}
break;
}
}
}
Expand All @@ -147,11 +154,10 @@ impl Joiner for ExistenceJoiner {
self.indices.push(curs.0.cur_idx);
self.exists.push(false);
cur_forward!(curs.0);
if self.should_flush(curs) {
if self.should_flush() {
self.as_mut().flush(curs).await?;
curs.0.clean_out_dated_batches();
}
curs.0
.set_min_reserved_idx(*self.indices.first().unwrap_or(&curs.0.cur_idx));
}
if !self.indices.is_empty() {
self.flush(curs).await?;
Expand Down
Loading
Loading