Skip to content

Commit a72cc0f

Browse files
richox and zhangli20 authored
refactor aggregate unfreeze_from_rows() and fix UDAF fallbacking error (#940)
Disable UDAF falling-back by default because of possible bugs.

Co-authored-by: zhangli20 <zhangli20@kuaishou.com>
1 parent 42ce943 commit a72cc0f

11 files changed

Lines changed: 140 additions & 168 deletions

File tree

native-engine/datafusion-ext-plans/src/agg/acc.rs

Lines changed: 24 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ use datafusion::{
3030
};
3131
use datafusion_ext_commons::{
3232
assume,
33-
io::{read_bytes_into_vec, read_bytes_slice, read_len, read_scalar, write_len, write_scalar},
33+
io::{read_bytes_slice, read_len, read_scalar, write_len, write_scalar},
3434
unchecked,
3535
};
3636
use smallvec::SmallVec;
@@ -50,7 +50,7 @@ pub trait AccColumn: Send {
5050
fn num_records(&self) -> usize;
5151
fn mem_used(&self) -> usize;
5252
fn freeze_to_rows(&self, idx: IdxSelection<'_>, array: &mut [Vec<u8>]) -> Result<()>;
53-
fn unfreeze_from_rows(&mut self, array: &[&[u8]], offsets: &mut [usize]) -> Result<()>;
53+
fn unfreeze_from_rows(&mut self, cursors: &mut [Cursor<&[u8]>]) -> Result<()>;
5454
fn spill(&self, idx: IdxSelection<'_>, w: &mut SpillCompressedWriter) -> Result<()>;
5555
fn unspill(&mut self, num_rows: usize, r: &mut SpillCompressedReader) -> Result<()>;
5656

@@ -442,7 +442,6 @@ impl AccColumn for AccGenericColumn {
442442
raw.set_len(new_len);
443443
} else {
444444
raw.truncate(new_len);
445-
raw.set_len(new_len);
446445
}
447446
}
448447
valids.resize(len, false);
@@ -559,44 +558,38 @@ impl AccColumn for AccGenericColumn {
559558
Ok(())
560559
}
561560

562-
fn unfreeze_from_rows(&mut self, array: &[&[u8]], offsets: &mut [usize]) -> Result<()> {
563-
let mut idx = self.num_records();
564-
self.resize(idx + array.len());
561+
fn unfreeze_from_rows(&mut self, cursors: &mut [Cursor<&[u8]>]) -> Result<()> {
562+
assert_eq!(self.num_records(), 0, "expect empty AccColumn");
563+
self.resize(cursors.len());
565564

566565
match self {
567566
&mut AccGenericColumn::Prim {
568567
ref mut raw,
569568
ref mut valids,
570569
prim_size,
571570
} => {
572-
for (data, offset) in array.iter().zip(offsets) {
573-
let mut r = Cursor::new(data);
574-
r.set_position(*offset as u64);
575-
576-
let valid = r.read_u8()?;
571+
for (idx, cursor) in cursors.iter_mut().enumerate() {
572+
let valid = cursor.read_u8()?;
577573
if valid == 1 {
578-
r.read_exact(&mut raw.as_raw_bytes_mut()[prim_size * idx..][..prim_size])?;
574+
cursor.read_exact(
575+
&mut raw.as_raw_bytes_mut()[prim_size * idx..][..prim_size],
576+
)?;
579577
valids.set(idx, true);
580578
} else {
581579
valids.set(idx, false);
582580
}
583-
*offset = r.position() as usize;
584-
idx += 1;
585581
}
586582
}
587583
AccGenericColumn::Bytes {
588584
items,
589585
heap_mem_used,
590586
} => {
591-
for (data, offset) in array.iter().zip(offsets) {
592-
let mut r = Cursor::new(data);
593-
r.set_position(*offset as u64);
594-
595-
let len = read_len(&mut r)?;
587+
for (idx, cursor) in cursors.iter_mut().enumerate() {
588+
let len = read_len(cursor)?;
596589
if len > 0 {
597590
let len = len - 1;
598591
let bytes = AccBytes::from_vec({
599-
let vec: Vec<u8> = read_bytes_slice(&mut r, len)?.into();
592+
let vec: Vec<u8> = read_bytes_slice(cursor, len)?.into();
600593
vec
601594
});
602595
if bytes.spilled() {
@@ -606,23 +599,16 @@ impl AccColumn for AccGenericColumn {
606599
} else {
607600
items[idx] = None;
608601
}
609-
*offset = r.position() as usize;
610-
idx += 1;
611602
}
612603
}
613604
AccGenericColumn::Scalar {
614605
items,
615606
dt,
616607
heap_mem_used,
617608
} => {
618-
for (data, offset) in array.iter().zip(offsets) {
619-
let mut r = Cursor::new(data);
620-
r.set_position(*offset as u64);
621-
622-
items[idx] = read_scalar(&mut r, dt, true)?;
609+
for (idx, cursor) in cursors.iter_mut().enumerate() {
610+
items[idx] = read_scalar(cursor, dt, true)?;
623611
*heap_mem_used += items[idx].size() - size_of::<ScalarValue>();
624-
*offset = r.position() as usize;
625-
idx += 1;
626612
}
627613
}
628614
}
@@ -678,23 +664,20 @@ impl AccColumn for AccGenericColumn {
678664
}
679665

680666
fn unspill(&mut self, num_rows: usize, r: &mut SpillCompressedReader) -> Result<()> {
681-
let idx = self.num_records();
682-
self.resize(idx + num_rows);
667+
assert_eq!(self.num_records(), 0, "expect empty AccColumn");
668+
self.resize(num_rows);
683669

684670
match self {
685671
&mut AccGenericColumn::Prim {
686672
ref mut raw,
687673
ref mut valids,
688674
prim_size,
689675
} => {
690-
let mut valid_buf = vec![];
691-
let valid_len = (num_rows + 7) / 8;
692-
read_bytes_into_vec(r, &mut valid_buf, valid_len)?;
693-
let unfreezed_valids = BitVec::<u8>::from_vec(valid_buf);
694-
valids.truncate(idx);
695-
valids.extend_from_bitslice(unfreezed_valids.as_bitslice());
696-
697-
for i in idx..idx + num_rows {
676+
let mut bits: BitVec<u8> = BitVec::repeat(false, num_rows);
677+
r.read_exact(bits.as_raw_mut_slice())?;
678+
valids.clear();
679+
valids.extend_from_bitslice(bits.as_bitslice());
680+
for i in 0..num_rows {
698681
if valids[i] {
699682
r.read_exact(&mut raw.as_raw_bytes_mut()[prim_size * i..][..prim_size])?;
700683
}
@@ -704,7 +687,7 @@ impl AccColumn for AccGenericColumn {
704687
items,
705688
heap_mem_used,
706689
} => {
707-
for i in idx..idx + num_rows {
690+
for i in 0..num_rows {
708691
let len = read_len(r)?;
709692
if len > 0 {
710693
let len = len - 1;
@@ -721,7 +704,7 @@ impl AccColumn for AccGenericColumn {
721704
dt,
722705
heap_mem_used,
723706
} => {
724-
for i in idx..idx + num_rows {
707+
for i in 0..num_rows {
725708
items[i] = read_scalar(r, dt, true)?;
726709
*heap_mem_used += items[i].size() - size_of::<ScalarValue>();
727710
}

native-engine/datafusion-ext-plans/src/agg/agg_ctx.rs

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,12 @@
1414

1515
use std::{
1616
fmt::{Debug, Formatter},
17+
io::Cursor,
1718
sync::Arc,
1819
};
1920

2021
use arrow::{
21-
array::{Array, ArrayRef, BinaryArray, RecordBatchOptions},
22+
array::{ArrayRef, BinaryArray, RecordBatchOptions},
2223
datatypes::{DataType, Field, Fields, Schema, SchemaRef},
2324
record_batch::RecordBatch,
2425
row::{RowConverter, Rows, SortField},
@@ -39,6 +40,7 @@ use crate::{
3940
agg::{
4041
acc::AccTable,
4142
agg::{Agg, IdxSelection},
43+
agg_hash_map::AggHashMapKey,
4244
spark_udaf_wrapper::{AccUDAFBufferRowsColumn, SparkUDAFMemTracker, SparkUDAFWrapper},
4345
AggExecMode, AggExpr, AggMode, GroupingExpr, AGG_BUF_COLUMN_NAME,
4446
},
@@ -244,7 +246,9 @@ impl AggContext {
244246
acc_table: &mut AccTable,
245247
acc_idx: IdxSelection,
246248
) -> Result<()> {
247-
let batch_selection = IdxSelection::Range(batch_start_idx, batch_end_idx);
249+
// NOTE:
250+
// arrow-ffi with sliced batch is buggy in older arrow-java, so we use unsliced
251+
// batch with explicit offsets
248252

249253
// partial update
250254
if self.need_partial_update {
@@ -263,6 +267,7 @@ impl AggContext {
263267
input_arrays.push(vec![]);
264268
}
265269
}
270+
let batch_selection = IdxSelection::Range(batch_start_idx, batch_end_idx);
266271
self.partial_update(acc_table, acc_idx, &input_arrays, batch_selection)?;
267272
}
268273

@@ -274,15 +279,21 @@ impl AggContext {
274279
let partial_merged_array = as_binary_array(batch.columns().last().unwrap())?;
275280
let array = partial_merged_array
276281
.iter()
282+
.skip(batch_start_idx)
283+
.take(batch_end_idx - batch_start_idx)
277284
.map(|bytes| bytes.unwrap())
278285
.collect::<Vec<_>>();
279-
let mut offsets = vec![0; partial_merged_array.len()];
286+
let mut cursors = array
287+
.iter()
288+
.map(|bytes| Cursor::new(bytes.as_bytes()))
289+
.collect::<Vec<_>>();
280290

281291
for (agg_idx, _agg) in &self.need_partial_merge_aggs {
282292
let acc_col = &mut merging_acc_table.cols_mut()[*agg_idx];
283-
acc_col.unfreeze_from_rows(&array, &mut offsets)?;
293+
acc_col.unfreeze_from_rows(&mut cursors)?;
284294
}
285295
}
296+
let batch_selection = IdxSelection::Range(0, batch_end_idx - batch_start_idx);
286297
self.partial_merge(acc_table, acc_idx, &mut merging_acc_table, batch_selection)?;
287298
}
288299
Ok(())
@@ -413,13 +424,6 @@ impl AggContext {
413424
Ok(vec)
414425
}
415426

416-
pub fn unfreeze_acc_table(&self, acc_table: &mut AccTable, data: &[&[u8]]) -> Result<()> {
417-
let mut offsets = vec![0; data.len()];
418-
for acc_col in acc_table.cols_mut() {
419-
acc_col.unfreeze_from_rows(data, &mut offsets)?;
420-
}
421-
Ok(())
422-
}
423427
pub async fn process_partial_skipped(
424428
&self,
425429
batch: RecordBatch,

native-engine/datafusion-ext-plans/src/agg/avg.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
use std::{
1616
any::Any,
1717
fmt::{Debug, Formatter},
18+
io::Cursor,
1819
sync::Arc,
1920
};
2021

@@ -209,9 +210,9 @@ impl AccColumn for AccAvgColumn {
209210
Ok(())
210211
}
211212

212-
fn unfreeze_from_rows(&mut self, array: &[&[u8]], offsets: &mut [usize]) -> Result<()> {
213-
self.sum.unfreeze_from_rows(array, offsets)?;
214-
self.count.unfreeze_from_rows(array, offsets)?;
213+
fn unfreeze_from_rows(&mut self, cursors: &mut [Cursor<&[u8]>]) -> Result<()> {
214+
self.sum.unfreeze_from_rows(cursors)?;
215+
self.count.unfreeze_from_rows(cursors)?;
215216
Ok(())
216217
}
217218

native-engine/datafusion-ext-plans/src/agg/bloom_filter.rs

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -259,21 +259,16 @@ impl AccColumn for AccBloomFilterColumn {
259259
Ok(())
260260
}
261261

262-
fn unfreeze_from_rows(&mut self, array: &[&[u8]], offsets: &mut [usize]) -> Result<()> {
263-
let mut idx = self.num_records();
264-
self.resize(idx + array.len());
265-
266-
for (data, offset) in array.iter().zip(offsets) {
267-
let mut cursor = Cursor::new(*data);
268-
cursor.set_position(*offset as u64);
269-
270-
if cursor.read_u8()? == 1 {
271-
self.bloom_filters[idx] = Some(SparkBloomFilter::read_from(&mut cursor)?);
272-
} else {
273-
self.bloom_filters[idx] = None;
274-
}
275-
*offset = cursor.position() as usize;
276-
idx += 1;
262+
fn unfreeze_from_rows(&mut self, cursors: &mut [Cursor<&[u8]>]) -> Result<()> {
263+
assert_eq!(self.num_records(), 0, "expect empty AccColumn");
264+
for r in cursors {
265+
self.bloom_filters.push({
266+
if r.read_u8()? == 1 {
267+
Some(SparkBloomFilter::read_from(r)?)
268+
} else {
269+
None
270+
}
271+
});
277272
}
278273
Ok(())
279274
}
@@ -293,15 +288,15 @@ impl AccColumn for AccBloomFilterColumn {
293288
}
294289

295290
fn unspill(&mut self, num_rows: usize, r: &mut SpillCompressedReader) -> Result<()> {
296-
let idx = self.num_records();
297-
self.resize(idx + num_rows);
298-
299-
for i in idx..idx + num_rows {
300-
if r.read_u8()? == 1 {
301-
self.bloom_filters[i] = Some(SparkBloomFilter::read_from(r)?);
302-
} else {
303-
self.bloom_filters[i] = None;
304-
}
291+
assert_eq!(self.num_records(), 0, "expect empty AccColumn");
292+
for _ in 0..num_rows {
293+
self.bloom_filters.push({
294+
if r.read_u8()? == 1 {
295+
Some(SparkBloomFilter::read_from(r)?)
296+
} else {
297+
None
298+
}
299+
});
305300
}
306301
Ok(())
307302
}

native-engine/datafusion-ext-plans/src/agg/collect.rs

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -191,16 +191,12 @@ pub trait AccCollectionColumn: AccColumn + Send + Sync + 'static {
191191
Ok(())
192192
}
193193

194-
fn unfreeze_from_rows(&mut self, array: &[&[u8]], offsets: &mut [usize]) -> Result<()> {
195-
let mut idx = self.num_records();
196-
self.resize(idx + array.len());
197-
198-
for (raw, offset) in array.iter().zip(offsets) {
199-
let mut cursor = Cursor::new(raw);
200-
cursor.set_position(*offset as u64);
201-
self.load_raw(idx, &mut cursor)?;
202-
*offset = cursor.position() as usize;
203-
idx += 1;
194+
fn unfreeze_from_rows(&mut self, cursors: &mut [Cursor<&[u8]>]) -> Result<()> {
195+
assert_eq!(self.num_records(), 0, "expect empty AccColumn");
196+
self.resize(cursors.len());
197+
198+
for (idx, cursor) in cursors.iter_mut().enumerate() {
199+
self.load_raw(idx, cursor)?;
204200
}
205201
Ok(())
206202
}
@@ -298,8 +294,8 @@ impl AccColumn for AccSetColumn {
298294
AccCollectionColumn::freeze_to_rows(self, idx, array)
299295
}
300296

301-
fn unfreeze_from_rows(&mut self, array: &[&[u8]], offsets: &mut [usize]) -> Result<()> {
302-
AccCollectionColumn::unfreeze_from_rows(self, array, offsets)
297+
fn unfreeze_from_rows(&mut self, cursors: &mut [Cursor<&[u8]>]) -> Result<()> {
298+
AccCollectionColumn::unfreeze_from_rows(self, cursors)
303299
}
304300

305301
fn spill(&self, idx: IdxSelection<'_>, w: &mut SpillCompressedWriter) -> Result<()> {
@@ -312,12 +308,11 @@ impl AccColumn for AccSetColumn {
312308
}
313309

314310
fn unspill(&mut self, num_rows: usize, r: &mut SpillCompressedReader) -> Result<()> {
315-
let mut idx = self.num_records();
316-
self.resize(idx + num_rows);
311+
assert_eq!(self.num_records(), 0, "expect empty AccColumn");
312+
self.resize(num_rows);
317313

318-
while idx < self.num_records() {
314+
for idx in 0..num_rows {
319315
self.load_raw(idx, r)?;
320-
idx += 1;
321316
}
322317
Ok(())
323318
}
@@ -411,8 +406,8 @@ impl AccColumn for AccListColumn {
411406
AccCollectionColumn::freeze_to_rows(self, idx, array)
412407
}
413408

414-
fn unfreeze_from_rows(&mut self, array: &[&[u8]], offsets: &mut [usize]) -> Result<()> {
415-
AccCollectionColumn::unfreeze_from_rows(self, array, offsets)
409+
fn unfreeze_from_rows(&mut self, cursors: &mut [Cursor<&[u8]>]) -> Result<()> {
410+
AccCollectionColumn::unfreeze_from_rows(self, cursors)
416411
}
417412

418413
fn spill(&self, idx: IdxSelection<'_>, w: &mut SpillCompressedWriter) -> Result<()> {
@@ -425,12 +420,10 @@ impl AccColumn for AccListColumn {
425420
}
426421

427422
fn unspill(&mut self, num_rows: usize, r: &mut SpillCompressedReader) -> Result<()> {
428-
let mut idx = self.num_records();
429-
self.resize(idx + num_rows);
430-
431-
while idx < self.num_records() {
423+
assert_eq!(self.num_records(), 0, "expect empty AccColumn");
424+
self.resize(num_rows);
425+
for idx in 0..num_rows {
432426
self.load_raw(idx, r)?;
433-
idx += 1;
434427
}
435428
Ok(())
436429
}

0 commit comments

Comments
 (0)