Skip to content

Commit 37145b9

Browse files
lihao712 and lihao29 authored
optimize sort merge join and avoid oom (#970)
* get_array_mem_size() prefers capacity to len
* optimize sort merge join and avoid oom
---------
Co-authored-by: lihao29 <lihao29@kuaishou.com>
1 parent 43389c8 commit 37145b9

7 files changed

Lines changed: 239 additions & 308 deletions

File tree

native-engine/datafusion-ext-commons/src/arrow/array_size.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -54,7 +54,7 @@ fn get_array_data_mem_size(array_data: &ArrayData) -> usize {
5454
mem_size += size_of::<Option<Buffer>>();
5555
mem_size += array_data
5656
.nulls()
57-
.map(|nb| nb.buffer().len())
57+
.map(|nb| nb.buffer().len().max(nb.buffer().capacity()))
5858
.unwrap_or_default();
5959

6060
// summing child data size

native-engine/datafusion-ext-plans/src/joins/mod.rs

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -20,16 +20,17 @@ use arrow::{
2020
datatypes::{DataType, SchemaRef},
2121
};
2222
use datafusion::{common::Result, physical_expr::PhysicalExprRef};
23+
use stream_cursor::StreamCursor;
2324

24-
use crate::joins::{join_utils::JoinType, stream_cursor::StreamCursor};
25+
use crate::joins::join_utils::JoinType;
2526

26-
pub mod join_hash_map;
2727
pub mod join_utils;
28-
pub mod stream_cursor;
2928

3029
// join implementations
3130
pub mod bhj;
31+
pub mod join_hash_map;
3232
pub mod smj;
33+
pub mod stream_cursor;
3334
mod test;
3435

3536
#[derive(Debug, Clone)]

native-engine/datafusion-ext-plans/src/joins/smj/existence_join.rs

Lines changed: 57 additions & 51 deletions
Original file line number | Diff line number | Diff line change
@@ -17,9 +17,7 @@ use std::{cmp::Ordering, pin::Pin, sync::Arc};
1717
use arrow::array::{ArrayRef, RecordBatch, RecordBatchOptions};
1818
use async_trait::async_trait;
1919
use datafusion::common::Result;
20-
use datafusion_ext_commons::{
21-
arrow::selection::create_batch_interleaver, suggested_batch_mem_size,
22-
};
20+
use datafusion_ext_commons::arrow::selection::create_batch_interleaver;
2321

2422
use crate::{
2523
common::execution_context::WrappedRecordBatchSender,
@@ -47,21 +45,8 @@ impl ExistenceJoiner {
4745
}
4846
}
4947

50-
fn should_flush(&self, curs: &StreamCursors) -> bool {
51-
if self.indices.len() >= self.join_params.batch_size {
52-
return true;
53-
}
54-
55-
if curs.0.num_buffered_batches() + curs.1.num_buffered_batches() >= 5
56-
|| curs.0.mem_size() + curs.1.mem_size() > suggested_batch_mem_size()
57-
{
58-
if let Some(first_idx) = self.indices.first() {
59-
if first_idx.0 < curs.0.cur_idx.0 {
60-
return true;
61-
}
62-
}
63-
}
64-
false
48+
fn should_flush(&self) -> bool {
49+
self.indices.len() >= self.join_params.batch_size
6550
}
6651

6752
async fn flush(mut self: Pin<&mut Self>, curs: &mut StreamCursors) -> Result<()> {
@@ -91,53 +76,75 @@ impl ExistenceJoiner {
9176
impl Joiner for ExistenceJoiner {
9277
async fn join(mut self: Pin<&mut Self>, curs: &mut StreamCursors) -> Result<()> {
9378
while !curs.0.finished && !curs.1.finished {
94-
let mut lidx = curs.0.cur_idx;
95-
let mut ridx = curs.1.cur_idx;
79+
if self.should_flush()
80+
|| curs.0.num_buffered_batches() > 1
81+
|| curs.1.num_buffered_batches() > 1
82+
{
83+
self.as_mut().flush(curs).await?;
84+
curs.0.clean_out_dated_batches();
85+
curs.1.clean_out_dated_batches();
86+
}
9687

9788
match compare_cursor!(curs) {
9889
Ordering::Less => {
9990
self.indices.push(curs.0.cur_idx);
10091
self.exists.push(false);
10192
cur_forward!(curs.0);
102-
if self.should_flush(curs) {
103-
self.as_mut().flush(curs).await?;
104-
}
105-
curs.0
106-
.set_min_reserved_idx(*self.indices.first().unwrap_or(&curs.0.cur_idx));
10793
}
10894
Ordering::Greater => {
10995
cur_forward!(curs.1);
110-
curs.1
111-
.set_min_reserved_idx(*self.indices.first().unwrap_or(&curs.1.cur_idx));
11296
}
11397
Ordering::Equal => {
114-
loop {
115-
self.indices.push(lidx);
116-
self.exists.push(true);
117-
cur_forward!(curs.0);
118-
if self.should_flush(curs) {
119-
self.as_mut().flush(curs).await?;
120-
}
121-
curs.0
122-
.set_min_reserved_idx(*self.indices.first().unwrap_or(&lidx));
98+
let l_key_idx = curs.0.cur_idx;
99+
let r_key_idx = curs.1.cur_idx;
100+
101+
self.indices.push(curs.0.cur_idx);
102+
self.exists.push(true);
103+
cur_forward!(curs.0);
104+
cur_forward!(curs.1);
123105

124-
if !curs.0.finished && curs.0.key(curs.0.cur_idx) == curs.0.key(lidx) {
125-
lidx = curs.0.cur_idx;
126-
continue;
106+
// iterate both streams, find the smaller one, and use it for probing
107+
let mut l_equal = true;
108+
let mut r_equal = true;
109+
while l_equal && r_equal {
110+
if l_equal {
111+
l_equal = !curs.0.finished && curs.0.cur_key() == curs.0.key(l_key_idx);
112+
if l_equal {
113+
self.indices.push(curs.0.cur_idx);
114+
self.exists.push(true);
115+
cur_forward!(curs.0);
116+
}
117+
}
118+
if r_equal {
119+
r_equal = !curs.1.finished && curs.1.cur_key() == curs.1.key(r_key_idx);
120+
if r_equal {
121+
cur_forward!(curs.1);
122+
}
127123
}
128-
break;
129124
}
130125

131-
// skip all right equal rows
132-
loop {
133-
cur_forward!(curs.1);
134-
curs.1.set_min_reserved_idx(ridx);
126+
if l_equal {
127+
// stream left side
128+
while !curs.0.finished && curs.0.cur_key() == curs.1.key(r_key_idx) {
129+
self.indices.push(curs.0.cur_idx);
130+
self.exists.push(true);
131+
cur_forward!(curs.0);
132+
if self.should_flush() || curs.0.num_buffered_batches() > 1 {
133+
self.as_mut().flush(curs).await?;
134+
curs.0.clean_out_dated_batches();
135+
}
136+
}
137+
}
135138

136-
if !curs.1.finished && curs.1.key(curs.1.cur_idx) == curs.1.key(ridx) {
137-
ridx = curs.1.cur_idx;
138-
continue;
139+
if r_equal {
140+
// stream right side
141+
while !curs.1.finished && curs.1.cur_key() == curs.0.key(l_key_idx) {
142+
cur_forward!(curs.1);
143+
if self.should_flush() || curs.1.num_buffered_batches() > 1 {
144+
self.as_mut().flush(curs).await?;
145+
curs.1.clean_out_dated_batches();
146+
}
139147
}
140-
break;
141148
}
142149
}
143150
}
@@ -147,11 +154,10 @@ impl Joiner for ExistenceJoiner {
147154
self.indices.push(curs.0.cur_idx);
148155
self.exists.push(false);
149156
cur_forward!(curs.0);
150-
if self.should_flush(curs) {
157+
if self.should_flush() {
151158
self.as_mut().flush(curs).await?;
159+
curs.0.clean_out_dated_batches();
152160
}
153-
curs.0
154-
.set_min_reserved_idx(*self.indices.first().unwrap_or(&curs.0.cur_idx));
155161
}
156162
if !self.indices.is_empty() {
157163
self.flush(curs).await?;

0 commit comments

Comments
 (0)