diff --git a/native-engine/datafusion-ext-plans/src/generate/explode.rs b/native-engine/datafusion-ext-plans/src/generate/explode.rs index d138e9a5b..070251ed1 100644 --- a/native-engine/datafusion-ext-plans/src/generate/explode.rs +++ b/native-engine/datafusion-ext-plans/src/generate/explode.rs @@ -64,10 +64,12 @@ impl Generator for ExplodeArray { let mut sub_lists = vec![]; while row_idx < state.input_array.len() && row_ids.len() < batch_size { - let sub_list = state.input_array.value(row_idx); - row_ids.resize(row_ids.len() + sub_list.len(), row_idx as i32); - pos_ids.extend(0..sub_list.len() as i32); - sub_lists.push(sub_list); + if state.input_array.is_valid(row_idx) { + let sub_list = state.input_array.value(row_idx); + row_ids.resize(row_ids.len() + sub_list.len(), row_idx as i32); + pos_ids.extend(0..sub_list.len() as i32); + sub_lists.push(sub_list); + } row_idx += 1; } state.cur_row_id = row_idx; @@ -147,13 +149,15 @@ impl Generator for ExplodeMap { let mut sub_val_lists = vec![]; while row_idx < state.input_array.len() && row_ids.len() < batch_size { - let sub_struct = state.input_array.value(row_idx); - let sub_key_list = sub_struct.column(0); - let sub_val_list = sub_struct.column(1); - row_ids.resize(row_ids.len() + sub_key_list.len(), row_idx as i32); - pos_ids.extend(0..sub_key_list.len() as i32); - sub_key_lists.push(sub_key_list.clone()); - sub_val_lists.push(sub_val_list.clone()); + if state.input_array.is_valid(row_idx) { + let sub_struct = state.input_array.value(row_idx); + let sub_key_list = sub_struct.column(0); + let sub_val_list = sub_struct.column(1); + row_ids.resize(row_ids.len() + sub_key_list.len(), row_idx as i32); + pos_ids.extend(0..sub_key_list.len() as i32); + sub_key_lists.push(sub_key_list.clone()); + sub_val_lists.push(sub_val_list.clone()); + } row_idx += 1; } state.cur_row_id = row_idx; diff --git a/native-engine/datafusion-ext-plans/src/generate_exec.rs b/native-engine/datafusion-ext-plans/src/generate_exec.rs index bea75698a..852a2a8e5 100644 --- a/native-engine/datafusion-ext-plans/src/generate_exec.rs +++ b/native-engine/datafusion-ext-plans/src/generate_exec.rs @@ -15,14 +15,11 @@ use std::{ any::Any, fmt::{Debug, Formatter}, - sync::{ - atomic::{AtomicUsize, Ordering::SeqCst}, - Arc, - }, + sync::Arc, }; use arrow::{ - array::{new_empty_array, Array, ArrayRef, Int32Array, Int32Builder}, + array::{new_null_array, Array, ArrayBuilder, ArrayRef, Int32Builder}, datatypes::{Field, Schema, SchemaRef}, record_batch::{RecordBatch, RecordBatchOptions}, }; @@ -37,7 +34,7 @@ use datafusion::{ PlanProperties, SendableRecordBatchStream, }, }; -use datafusion_ext_commons::arrow::cast::cast; +use datafusion_ext_commons::arrow::{cast::cast, selection::take_cols}; use futures::StreamExt; use once_cell::sync::OnceCell; use parking_lot::Mutex; @@ -164,11 +161,18 @@ impl ExecutionPlan for GenerateExec { ) -> Result { let exec_ctx = ExecutionContext::new(context, partition, self.schema(), &self.metrics); let generator = self.generator.clone(); + let generator_output_schema = self.generator_output_schema.clone(); let outer = self.outer; let child_output_cols = self.required_child_output_cols.clone(); let input = exec_ctx.execute_with_input_stats(&self.input)?; - let output = - execute_generate(input, exec_ctx.clone(), generator, outer, child_output_cols)?; + let output = execute_generate( + input, + exec_ctx.clone(), + generator, + generator_output_schema, + outer, + child_output_cols, + )?; Ok(exec_ctx.coalesce_with_default_batch_size(output)) } @@ -185,6 +189,7 @@ fn execute_generate( mut input_stream: SendableRecordBatchStream, exec_ctx: Arc, generator: Arc, + generator_output_schema: SchemaRef, outer: bool, child_output_cols: Vec, ) -> Result { @@ -200,7 +205,6 @@ fn execute_generate( sender.exclude_time(exec_ctx.baseline_metrics().elapsed_compute()); let _timer = exec_ctx.baseline_metrics().elapsed_compute().timer(); let last_child_outputs: Arc>>> = Arc::default(); - let last_child_outtpus_len = AtomicUsize::new(0); while let Some(batch) = exec_ctx .baseline_metrics() @@ -220,101 +224,106 @@ fn execute_generate( .collect::>>() .map_err(|err| err.context("generate: evaluating child output arrays error"))?; last_child_outputs.lock().replace(child_outputs.clone()); - last_child_outtpus_len.store(batch.num_rows(), SeqCst); let mut generate_state = generator.eval_start(&batch)?; let mut cur_row_id = 0; - while let Some(generated_outputs) = generator.eval_loop(&mut generate_state)? { - let capacity = generated_outputs.row_ids.len(); - let mut child_output_row_ids = Int32Builder::with_capacity(capacity); - let mut generated_ids = Int32Builder::with_capacity(capacity); + while cur_row_id < batch.num_rows() { + let mut child_output_row_ids = Int32Builder::new(); + let mut generated_ids = Int32Builder::new(); + let end_row_id; - // build ids for joining - for i in 0..generated_outputs.row_ids.len() { - let row_id = generated_outputs.row_ids.value(i); - while cur_row_id < row_id { + macro_rules! build_ids_for_outer_generate { + ($from_row_id:expr, $to_row_id:expr) => {{ if outer { - child_output_row_ids.append_value(cur_row_id); - generated_ids.append_null(); + for i in $from_row_id..$to_row_id { + child_output_row_ids.append_value(i as i32); + generated_ids.append_null(); + } } - cur_row_id += 1; - } - child_output_row_ids.append_value(row_id); - generated_ids.append_value(i as i32); - cur_row_id = row_id + 1; + $from_row_id = $to_row_id; + }}; } - while cur_row_id < generate_state.cur_row_id() as i32 { - if outer { - child_output_row_ids.append_value(cur_row_id); - generated_ids.append_null(); + + // generate one output batch + let generated_outputs = generator.eval_loop(&mut generate_state)?; + + // build ids for joining + if let Some(generated_outputs) = &generated_outputs { + for i in 0..generated_outputs.row_ids.len() { + let row_id = generated_outputs.row_ids.value(i) as usize; + build_ids_for_outer_generate!(cur_row_id, row_id); + child_output_row_ids.append_value(row_id as i32); + generated_ids.append_value(i as i32); + cur_row_id = row_id + 1; } - cur_row_id += 1; + end_row_id = generate_state.cur_row_id(); + } else { + end_row_id = batch.num_rows(); } - let child_output_row_ids = child_output_row_ids.finish(); - let generated_ids = generated_ids.finish(); - - let child_outputs = child_outputs - .iter() - .map(|col| Ok(arrow::compute::take(col, &child_output_row_ids, None)?)) - .collect::>>()?; - let generated_outputs = generated_outputs - .cols - .iter() - .map(|col| Ok(arrow::compute::take(col, &generated_ids, None)?)) - .collect::>>()?; + build_ids_for_outer_generate!(cur_row_id, end_row_id); + // build output cols let num_rows = generated_ids.len(); - let outputs: Vec = child_outputs - .iter() - .chain(&generated_outputs) - .zip(exec_ctx.output_schema().fields()) - .map(|(array, field)| { - if array.data_type() != field.data_type() { - return cast(&array, field.data_type()); + if num_rows > 0 { + let child_output_row_ids = child_output_row_ids.finish(); + let generated_ids = generated_ids.finish(); + let child_outputs = take_cols(&child_outputs, child_output_row_ids)?; + let generated_outputs = match generated_outputs { + Some(generated_outputs) => { + take_cols(&generated_outputs.cols, generated_ids)? } - Ok(array.clone()) - }) - .collect::>()?; - let output_batch = RecordBatch::try_new_with_options( - exec_ctx.output_schema(), - outputs, - &RecordBatchOptions::new().with_row_count(Some(num_rows)), - )?; - - exec_ctx - .baseline_metrics() - .record_output(output_batch.num_rows()); - sender.send(output_batch).await; + None => generator_output_schema + .fields() + .iter() + .map(|f| new_null_array(f.data_type(), generated_ids.len())) + .collect(), + }; + let output_cols = [child_outputs, generated_outputs].concat(); + let outputs: Vec = output_cols + .iter() + .zip(exec_ctx.output_schema().fields()) + .map(|(array, field)| { + if array.data_type() != field.data_type() { + return cast(&array, field.data_type()); + } + Ok(array.clone()) + }) + .collect::>()?; + + // output + let output_batch = RecordBatch::try_new_with_options( + exec_ctx.output_schema(), + outputs, + &RecordBatchOptions::new().with_row_count(Some(num_rows)), + )?; + exec_ctx + .baseline_metrics() + .record_output(output_batch.num_rows()); + sender.send(output_batch).await; + } } } // execute generator.terminate() while let Some(generated_outputs) = generator.terminate_loop()? { let last_child_outputs = last_child_outputs.lock().take(); - let last_child_outputs_len = last_child_outtpus_len.load(SeqCst); + let num_rows = generated_outputs.row_ids.len(); - let child_output_row_ids = generated_outputs - .row_ids - .iter() - .flatten() - .map(|row_id| ((row_id as usize) < last_child_outputs_len).then_some(row_id)) - .collect::(); + let child_output_row_ids = generated_outputs.row_ids; let child_outputs = last_child_outputs .unwrap_or( child_output_dts .iter() - .map(|dt| new_empty_array(dt)) + .map(|dt| new_null_array(dt, 1)) .collect(), ) .iter() .map(|c| Ok(arrow::compute::take(c, &child_output_row_ids, None)?)) .collect::>>()?; - - let num_rows = generated_outputs.row_ids.len(); - let outputs: Vec = child_outputs + let output_cols = [child_outputs, generated_outputs.cols].concat(); + let outputs: Vec = output_cols .iter() - .chain(&generated_outputs.cols) .zip(exec_ctx.output_schema().fields()) .map(|(array, field)| { if array.data_type() != field.data_type() { @@ -323,6 +332,7 @@ fn execute_generate( Ok(array.clone()) }) .collect::>()?; + let output_batch = RecordBatch::try_new_with_options( exec_ctx.output_schema(), outputs,