Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 15 additions & 11 deletions native-engine/datafusion-ext-plans/src/generate/explode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,12 @@ impl Generator for ExplodeArray {
let mut sub_lists = vec![];

while row_idx < state.input_array.len() && row_ids.len() < batch_size {
let sub_list = state.input_array.value(row_idx);
row_ids.resize(row_ids.len() + sub_list.len(), row_idx as i32);
pos_ids.extend(0..sub_list.len() as i32);
sub_lists.push(sub_list);
if state.input_array.is_valid(row_idx) {
let sub_list = state.input_array.value(row_idx);
row_ids.resize(row_ids.len() + sub_list.len(), row_idx as i32);
pos_ids.extend(0..sub_list.len() as i32);
sub_lists.push(sub_list);
}
row_idx += 1;
}
state.cur_row_id = row_idx;
Expand Down Expand Up @@ -147,13 +149,15 @@ impl Generator for ExplodeMap {
let mut sub_val_lists = vec![];

while row_idx < state.input_array.len() && row_ids.len() < batch_size {
let sub_struct = state.input_array.value(row_idx);
let sub_key_list = sub_struct.column(0);
let sub_val_list = sub_struct.column(1);
row_ids.resize(row_ids.len() + sub_key_list.len(), row_idx as i32);
pos_ids.extend(0..sub_key_list.len() as i32);
sub_key_lists.push(sub_key_list.clone());
sub_val_lists.push(sub_val_list.clone());
if state.input_array.is_valid(row_idx) {
let sub_struct = state.input_array.value(row_idx);
let sub_key_list = sub_struct.column(0);
let sub_val_list = sub_struct.column(1);
row_ids.resize(row_ids.len() + sub_key_list.len(), row_idx as i32);
pos_ids.extend(0..sub_key_list.len() as i32);
sub_key_lists.push(sub_key_list.clone());
sub_val_lists.push(sub_val_list.clone());
}
row_idx += 1;
}
state.cur_row_id = row_idx;
Expand Down
158 changes: 84 additions & 74 deletions native-engine/datafusion-ext-plans/src/generate_exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,11 @@
use std::{
any::Any,
fmt::{Debug, Formatter},
sync::{
atomic::{AtomicUsize, Ordering::SeqCst},
Arc,
},
sync::Arc,
};

use arrow::{
array::{new_empty_array, Array, ArrayRef, Int32Array, Int32Builder},
array::{new_null_array, Array, ArrayBuilder, ArrayRef, Int32Builder},
datatypes::{Field, Schema, SchemaRef},
record_batch::{RecordBatch, RecordBatchOptions},
};
Expand All @@ -37,7 +34,7 @@ use datafusion::{
PlanProperties, SendableRecordBatchStream,
},
};
use datafusion_ext_commons::arrow::cast::cast;
use datafusion_ext_commons::arrow::{cast::cast, selection::take_cols};
use futures::StreamExt;
use once_cell::sync::OnceCell;
use parking_lot::Mutex;
Expand Down Expand Up @@ -164,11 +161,18 @@ impl ExecutionPlan for GenerateExec {
) -> Result<SendableRecordBatchStream> {
let exec_ctx = ExecutionContext::new(context, partition, self.schema(), &self.metrics);
let generator = self.generator.clone();
let generator_output_schema = self.generator_output_schema.clone();
let outer = self.outer;
let child_output_cols = self.required_child_output_cols.clone();
let input = exec_ctx.execute_with_input_stats(&self.input)?;
let output =
execute_generate(input, exec_ctx.clone(), generator, outer, child_output_cols)?;
let output = execute_generate(
input,
exec_ctx.clone(),
generator,
generator_output_schema,
outer,
child_output_cols,
)?;
Ok(exec_ctx.coalesce_with_default_batch_size(output))
}

Expand All @@ -185,6 +189,7 @@ fn execute_generate(
mut input_stream: SendableRecordBatchStream,
exec_ctx: Arc<ExecutionContext>,
generator: Arc<dyn Generator>,
generator_output_schema: SchemaRef,
outer: bool,
child_output_cols: Vec<Column>,
) -> Result<SendableRecordBatchStream> {
Expand All @@ -200,7 +205,6 @@ fn execute_generate(
sender.exclude_time(exec_ctx.baseline_metrics().elapsed_compute());
let _timer = exec_ctx.baseline_metrics().elapsed_compute().timer();
let last_child_outputs: Arc<Mutex<Option<Vec<ArrayRef>>>> = Arc::default();
let last_child_outtpus_len = AtomicUsize::new(0);

while let Some(batch) = exec_ctx
.baseline_metrics()
Expand All @@ -220,101 +224,106 @@ fn execute_generate(
.collect::<Result<Vec<_>>>()
.map_err(|err| err.context("generate: evaluating child output arrays error"))?;
last_child_outputs.lock().replace(child_outputs.clone());
last_child_outtpus_len.store(batch.num_rows(), SeqCst);

let mut generate_state = generator.eval_start(&batch)?;
let mut cur_row_id = 0;

while let Some(generated_outputs) = generator.eval_loop(&mut generate_state)? {
let capacity = generated_outputs.row_ids.len();
let mut child_output_row_ids = Int32Builder::with_capacity(capacity);
let mut generated_ids = Int32Builder::with_capacity(capacity);
while cur_row_id < batch.num_rows() {
let mut child_output_row_ids = Int32Builder::new();
let mut generated_ids = Int32Builder::new();
let end_row_id;

// build ids for joining
for i in 0..generated_outputs.row_ids.len() {
let row_id = generated_outputs.row_ids.value(i);
while cur_row_id < row_id {
macro_rules! build_ids_for_outer_generate {
($from_row_id:expr, $to_row_id:expr) => {{
if outer {
child_output_row_ids.append_value(cur_row_id);
generated_ids.append_null();
for i in $from_row_id..$to_row_id {
child_output_row_ids.append_value(i as i32);
generated_ids.append_null();
}
}
cur_row_id += 1;
}
child_output_row_ids.append_value(row_id);
generated_ids.append_value(i as i32);
cur_row_id = row_id + 1;
$from_row_id = $to_row_id;
}};
}
while cur_row_id < generate_state.cur_row_id() as i32 {
if outer {
child_output_row_ids.append_value(cur_row_id);
generated_ids.append_null();

// generate one output batch
let generated_outputs = generator.eval_loop(&mut generate_state)?;

// build ids for joining
if let Some(generated_outputs) = &generated_outputs {
for i in 0..generated_outputs.row_ids.len() {
let row_id = generated_outputs.row_ids.value(i) as usize;
build_ids_for_outer_generate!(cur_row_id, row_id);
child_output_row_ids.append_value(row_id as i32);
generated_ids.append_value(i as i32);
cur_row_id = row_id + 1;
}
cur_row_id += 1;
end_row_id = generate_state.cur_row_id();
} else {
end_row_id = batch.num_rows();
}
let child_output_row_ids = child_output_row_ids.finish();
let generated_ids = generated_ids.finish();

let child_outputs = child_outputs
.iter()
.map(|col| Ok(arrow::compute::take(col, &child_output_row_ids, None)?))
.collect::<Result<Vec<_>>>()?;
let generated_outputs = generated_outputs
.cols
.iter()
.map(|col| Ok(arrow::compute::take(col, &generated_ids, None)?))
.collect::<Result<Vec<_>>>()?;
build_ids_for_outer_generate!(cur_row_id, end_row_id);

// build output cols
let num_rows = generated_ids.len();
let outputs: Vec<ArrayRef> = child_outputs
.iter()
.chain(&generated_outputs)
.zip(exec_ctx.output_schema().fields())
.map(|(array, field)| {
if array.data_type() != field.data_type() {
return cast(&array, field.data_type());
if num_rows > 0 {
let child_output_row_ids = child_output_row_ids.finish();
let generated_ids = generated_ids.finish();
let child_outputs = take_cols(&child_outputs, child_output_row_ids)?;
let generated_outputs = match generated_outputs {
Some(generated_outputs) => {
take_cols(&generated_outputs.cols, generated_ids)?
}
Ok(array.clone())
})
.collect::<Result<_>>()?;
let output_batch = RecordBatch::try_new_with_options(
exec_ctx.output_schema(),
outputs,
&RecordBatchOptions::new().with_row_count(Some(num_rows)),
)?;

exec_ctx
.baseline_metrics()
.record_output(output_batch.num_rows());
sender.send(output_batch).await;
None => generator_output_schema
.fields()
.iter()
.map(|f| new_null_array(f.data_type(), generated_ids.len()))
.collect(),
};
let output_cols = [child_outputs, generated_outputs].concat();
let outputs: Vec<ArrayRef> = output_cols
.iter()
.zip(exec_ctx.output_schema().fields())
.map(|(array, field)| {
if array.data_type() != field.data_type() {
return cast(&array, field.data_type());
}
Ok(array.clone())
})
.collect::<Result<_>>()?;

// output
let output_batch = RecordBatch::try_new_with_options(
exec_ctx.output_schema(),
outputs,
&RecordBatchOptions::new().with_row_count(Some(num_rows)),
)?;
exec_ctx
.baseline_metrics()
.record_output(output_batch.num_rows());
sender.send(output_batch).await;
}
}
}

// execute generator.terminate()
while let Some(generated_outputs) = generator.terminate_loop()? {
let last_child_outputs = last_child_outputs.lock().take();
let last_child_outputs_len = last_child_outtpus_len.load(SeqCst);
let num_rows = generated_outputs.row_ids.len();

let child_output_row_ids = generated_outputs
.row_ids
.iter()
.flatten()
.map(|row_id| ((row_id as usize) < last_child_outputs_len).then_some(row_id))
.collect::<Int32Array>();
let child_output_row_ids = generated_outputs.row_ids;
let child_outputs = last_child_outputs
.unwrap_or(
child_output_dts
.iter()
.map(|dt| new_empty_array(dt))
.map(|dt| new_null_array(dt, 1))
.collect(),
)
.iter()
.map(|c| Ok(arrow::compute::take(c, &child_output_row_ids, None)?))
.collect::<Result<Vec<_>>>()?;

let num_rows = generated_outputs.row_ids.len();
let outputs: Vec<ArrayRef> = child_outputs
let output_cols = [child_outputs, generated_outputs.cols].concat();
let outputs: Vec<ArrayRef> = output_cols
.iter()
.chain(&generated_outputs.cols)
.zip(exec_ctx.output_schema().fields())
.map(|(array, field)| {
if array.data_type() != field.data_type() {
Expand All @@ -323,6 +332,7 @@ fn execute_generate(
Ok(array.clone())
})
.collect::<Result<_>>()?;

let output_batch = RecordBatch::try_new_with_options(
exec_ctx.output_schema(),
outputs,
Expand Down