Skip to content

Commit 0701e2e

Browse files
richox and zhangli20 authored
fix agg failure: index out of bounds (#899)
Co-authored-by: zhangli20 <zhangli20@kuaishou.com>
1 parent 3b578da commit 0701e2e

9 files changed

Lines changed: 34 additions & 78 deletions

File tree

native-engine/datafusion-ext-plans/src/agg/acc.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,18 @@ pub trait AccColumn: Send {
5353
fn unfreeze_from_rows(&mut self, array: &[&[u8]], offsets: &mut [usize]) -> Result<()>;
5454
fn spill(&self, idx: IdxSelection<'_>, w: &mut SpillCompressedWriter) -> Result<()>;
5555
fn unspill(&mut self, num_rows: usize, r: &mut SpillCompressedReader) -> Result<()>;
56+
57+
fn ensure_size(&mut self, idx: IdxSelection<'_>) {
58+
let idx_max_value = match idx {
59+
IdxSelection::Single(v) => v,
60+
IdxSelection::Indices(v) => v.iter().copied().max().unwrap_or(0),
61+
IdxSelection::IndicesU32(v) => v.iter().copied().max().unwrap_or(0) as usize,
62+
IdxSelection::Range(_begin, end) => end,
63+
};
64+
if idx_max_value >= self.num_records() {
65+
self.resize(idx_max_value + 1);
66+
}
67+
}
5668
}
5769

5870
pub type AccColumnRef = Box<dyn AccColumn>;

native-engine/datafusion-ext-plans/src/agg/bloom_filter.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,12 +115,10 @@ impl Agg for AggBloomFilter {
115115
partial_arg_idx: IdxSelection<'_>,
116116
) -> Result<()> {
117117
let accs = downcast_any!(accs, mut AccBloomFilterColumn).unwrap();
118+
accs.ensure_size(acc_idx);
119+
118120
let bloom_filter = match acc_idx {
119121
IdxSelection::Single(idx) => {
120-
if idx >= accs.num_records() {
121-
accs.resize(idx + 1);
122-
}
123-
124122
let bf = &mut accs.bloom_filters[idx];
125123
if bf.is_none() {
126124
*bf = Some(SparkBloomFilter::new_with_expected_num_items(

native-engine/datafusion-ext-plans/src/agg/brickhouse/collect.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,14 +87,11 @@ impl Agg for AggCollect {
8787
partial_args: &[ArrayRef],
8888
partial_arg_idx: IdxSelection<'_>,
8989
) -> Result<()> {
90+
accs.ensure_size(acc_idx);
9091
let list = partial_args[0].as_list::<i32>();
9192

9293
idx_for_zipped! {
9394
((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
94-
if acc_idx >= accs.num_records() {
95-
accs.resize(acc_idx + 1);
96-
}
97-
9895
if list.is_valid(partial_arg_idx) {
9996
let values = list.value(partial_arg_idx);
10097
let values_len = values.len();

native-engine/datafusion-ext-plans/src/agg/brickhouse/combine_unique.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,14 +87,11 @@ impl Agg for AggCombineUnique {
8787
partial_args: &[ArrayRef],
8888
partial_arg_idx: IdxSelection<'_>,
8989
) -> Result<()> {
90+
accs.ensure_size(acc_idx);
9091
let list = partial_args[0].as_list::<i32>();
9192

9293
idx_for_zipped! {
9394
((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
94-
if acc_idx >= accs.num_records() {
95-
accs.resize(acc_idx + 1);
96-
}
97-
9895
if list.is_valid(partial_arg_idx) {
9996
let values = list.value(partial_arg_idx);
10097
let values_len = values.len();

native-engine/datafusion-ext-plans/src/agg/collect.rs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -116,11 +116,10 @@ impl<C: AccCollectionColumn> Agg for AggGenericCollect<C> {
116116
partial_arg_idx: IdxSelection<'_>,
117117
) -> Result<()> {
118118
let accs = downcast_any!(accs, mut C).unwrap();
119+
accs.ensure_size(acc_idx);
120+
119121
idx_for_zipped! {
120122
((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
121-
if acc_idx >= accs.num_records() {
122-
accs.resize(acc_idx + 1);
123-
}
124123
let scalar = ScalarValue::try_from_array(&partial_args[0], partial_arg_idx)?;
125124
if !scalar.is_null() {
126125
accs.append_item(acc_idx, &scalar);
@@ -138,12 +137,11 @@ impl<C: AccCollectionColumn> Agg for AggGenericCollect<C> {
138137
merging_acc_idx: IdxSelection<'_>,
139138
) -> Result<()> {
140139
let accs = downcast_any!(accs, mut C).unwrap();
140+
accs.ensure_size(acc_idx);
141+
141142
let merging_accs = downcast_any!(merging_accs, mut C).unwrap();
142143
idx_for_zipped! {
143144
((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => {
144-
if acc_idx >= accs.num_records() {
145-
accs.resize(acc_idx + 1);
146-
}
147145
accs.merge_items(acc_idx, merging_accs, merging_acc_idx);
148146
}
149147
}

native-engine/datafusion-ext-plans/src/agg/first.rs

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ impl Agg for AggFirst {
9393
) -> Result<()> {
9494
let partial_arg = &partial_args[0];
9595
let accs = downcast_any!(accs, mut AccFirstColumn).unwrap();
96+
accs.ensure_size(acc_idx);
97+
9698
let old_heap_mem_used = accs.values.items_heap_mem_used(acc_idx);
9799

98100
macro_rules! handle_bytes {
@@ -101,9 +103,6 @@ impl Agg for AggFirst {
101103
let partial_arg = downcast_any!(partial_arg, TArray).unwrap();
102104
idx_for_zipped! {
103105
((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
104-
if acc_idx >= accs.num_records() {
105-
accs.resize(acc_idx + 1);
106-
}
107106
if !accs.flags.prim_valid(acc_idx) {
108107
accs.flags.set_prim_valid(acc_idx, true);
109108
if partial_arg.is_valid(partial_arg_idx) {
@@ -121,9 +120,6 @@ impl Agg for AggFirst {
121120
partial_arg => {
122121
idx_for_zipped! {
123122
((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
124-
if acc_idx >= accs.num_records() {
125-
accs.resize(acc_idx + 1);
126-
}
127123
if !accs.flags.prim_valid(acc_idx) {
128124
accs.flags.set_prim_valid(acc_idx, true);
129125
accs.values.set_prim_valid(acc_idx, partial_arg.is_valid(partial_arg_idx));
@@ -137,9 +133,6 @@ impl Agg for AggFirst {
137133
_other => {
138134
idx_for_zipped! {
139135
((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
140-
if acc_idx >= accs.num_records() {
141-
accs.resize(acc_idx + 1);
142-
}
143136
if accs.flags.prim_valid(acc_idx) {
144137
accs.flags.set_prim_valid(acc_idx, true);
145138
accs.values.scalar_values_mut()[acc_idx] = ScalarValue::try_from_array(partial_arg, partial_arg_idx)?;
@@ -164,6 +157,8 @@ impl Agg for AggFirst {
164157
) -> Result<()> {
165158
let accs = downcast_any!(accs, mut AccFirstColumn).unwrap();
166159
let merging_accs = downcast_any!(merging_accs, mut AccFirstColumn).unwrap();
160+
accs.ensure_size(acc_idx);
161+
167162
let old_heap_mem_used = accs.values.items_heap_mem_used(acc_idx);
168163

169164
// safety: bypass borrow checker
@@ -184,10 +179,6 @@ impl Agg for AggFirst {
184179
) => {
185180
idx_for_zipped! {
186181
((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => {
187-
if acc_idx >= accs.num_records() {
188-
accs.resize(acc_idx + 1);
189-
}
190-
191182
if !accs.flags.prim_valid(acc_idx) && merging_accs.flags.prim_valid(merging_acc_idx) {
192183
let acc_offset = *prim_size * acc_idx;
193184
let merging_acc_offset = *prim_size * merging_acc_idx;
@@ -207,10 +198,6 @@ impl Agg for AggFirst {
207198
) => {
208199
idx_for_zipped! {
209200
((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => {
210-
if acc_idx >= accs.num_records() {
211-
accs.resize(acc_idx + 1);
212-
}
213-
214201
if !accs.flags.prim_valid(acc_idx) && merging_accs.flags.prim_valid(merging_acc_idx) {
215202
let item = &mut items[acc_idx];
216203
let mut other_item = &mut other_items[merging_acc_idx];
@@ -228,10 +215,6 @@ impl Agg for AggFirst {
228215
) => {
229216
idx_for_zipped! {
230217
((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => {
231-
if acc_idx >= accs.num_records() {
232-
accs.resize(acc_idx + 1);
233-
}
234-
235218
if !accs.flags.prim_valid(acc_idx) && merging_accs.flags.prim_valid(merging_acc_idx) {
236219
let item = & mut items[acc_idx];
237220
let mut other_item = & mut other_items[merging_acc_idx];

native-engine/datafusion-ext-plans/src/agg/first_ignores_null.rs

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ impl Agg for AggFirstIgnoresNull {
8989
) -> Result<()> {
9090
let partial_arg = &partial_args[0];
9191
let accs = downcast_any!(accs, mut AccGenericColumn).unwrap();
92+
accs.ensure_size(acc_idx);
93+
9294
let old_heap_mem_used = accs.items_heap_mem_used(acc_idx);
9395

9496
macro_rules! handle_bytes {
@@ -97,9 +99,6 @@ impl Agg for AggFirstIgnoresNull {
9799
let partial_arg = downcast_any!(partial_arg, TArray).unwrap();
98100
idx_for_zipped! {
99101
((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
100-
if acc_idx >= accs.num_records() {
101-
accs.resize(acc_idx + 1);
102-
}
103102
if accs.bytes_value(acc_idx).is_none() && partial_arg.is_valid(partial_arg_idx) {
104103
accs.set_bytes_value(acc_idx, Some(AccBytes::from(partial_arg.value(partial_arg_idx).as_ref())));
105104
}
@@ -112,9 +111,6 @@ impl Agg for AggFirstIgnoresNull {
112111
partial_arg => {
113112
idx_for_zipped! {
114113
((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
115-
if acc_idx >= accs.num_records() {
116-
accs.resize(acc_idx + 1);
117-
}
118114
if !accs.prim_valid(acc_idx) && partial_arg.is_valid(partial_arg_idx) {
119115
accs.set_prim_valid(acc_idx, true);
120116
accs.set_prim_value(acc_idx, partial_arg.value(partial_arg_idx));
@@ -127,9 +123,6 @@ impl Agg for AggFirstIgnoresNull {
127123
_other => {
128124
idx_for_zipped! {
129125
((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
130-
if acc_idx >= accs.num_records() {
131-
accs.resize(acc_idx + 1);
132-
}
133126
if accs.scalar_values()[acc_idx].is_null() && partial_arg.is_valid(partial_arg_idx) {
134127
accs.scalar_values_mut()[acc_idx] = ScalarValue::try_from_array(partial_arg, partial_arg_idx)?;
135128
}
@@ -151,6 +144,8 @@ impl Agg for AggFirstIgnoresNull {
151144
merging_acc_idx: IdxSelection<'_>,
152145
) -> Result<()> {
153146
let accs = downcast_any!(accs, mut AccGenericColumn).unwrap();
147+
accs.ensure_size(acc_idx);
148+
154149
let mut merging_accs = downcast_any!(merging_accs, mut AccGenericColumn).unwrap();
155150
let old_heap_mem_used = accs.items_heap_mem_used(acc_idx);
156151

@@ -173,10 +168,6 @@ impl Agg for AggFirstIgnoresNull {
173168
) => {
174169
idx_for_zipped! {
175170
((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => {
176-
if acc_idx >= accs.num_records() {
177-
accs.resize(acc_idx + 1);
178-
}
179-
180171
if !valids[acc_idx] && other_valids[merging_acc_idx] {
181172
valids.set(acc_idx, true);
182173
let acc_offset = *prim_size * acc_idx;
@@ -195,10 +186,6 @@ impl Agg for AggFirstIgnoresNull {
195186
) => {
196187
idx_for_zipped! {
197188
((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => {
198-
if acc_idx >= accs.num_records() {
199-
accs.resize(acc_idx + 1);
200-
}
201-
202189
let item = &mut items[acc_idx];
203190
let mut other_item = &mut other_items[merging_acc_idx];
204191
if item.is_none() && other_item.is_some() {
@@ -215,10 +202,6 @@ impl Agg for AggFirstIgnoresNull {
215202
) => {
216203
idx_for_zipped! {
217204
((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => {
218-
if acc_idx >= accs.num_records() {
219-
accs.resize(acc_idx + 1);
220-
}
221-
222205
let item = &mut items[acc_idx];
223206
let mut other_item = &mut other_items[merging_acc_idx];
224207
if item.is_null() && !other_item.is_null() {

native-engine/datafusion-ext-plans/src/agg/maxmin.rs

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ impl<P: AggMaxMinParams> Agg for AggMaxMin<P> {
9595
partial_arg_idx: IdxSelection<'_>,
9696
) -> Result<()> {
9797
let accs = downcast_any!(accs, mut AccGenericColumn).unwrap();
98+
accs.ensure_size(acc_idx);
99+
98100
let old_heap_mem_used = accs.items_heap_mem_used(acc_idx);
99101

100102
macro_rules! handle_prim {
@@ -103,9 +105,6 @@ impl<P: AggMaxMinParams> Agg for AggMaxMin<P> {
103105
let partial_arg = downcast_any!(&partial_args[0], TArray).unwrap();
104106
idx_for_zipped! {
105107
((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
106-
if acc_idx >= accs.num_records() {
107-
accs.resize(acc_idx + 1);
108-
}
109108
if !partial_arg.is_valid(partial_arg_idx) {
110109
continue;
111110
}
@@ -130,9 +129,6 @@ impl<P: AggMaxMinParams> Agg for AggMaxMin<P> {
130129
let partial_arg = downcast_any!(&partial_args[0], TArray).unwrap();
131130
idx_for_zipped! {
132131
((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
133-
if acc_idx >= accs.num_records() {
134-
accs.resize(acc_idx + 1);
135-
}
136132
if !partial_arg.is_valid(partial_arg_idx) {
137133
continue;
138134
}
@@ -177,9 +173,6 @@ impl<P: AggMaxMinParams> Agg for AggMaxMin<P> {
177173
_ => {
178174
idx_for_zipped! {
179175
((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
180-
if acc_idx >= accs.num_records() {
181-
accs.resize(acc_idx + 1);
182-
}
183176
let partial_arg_scalar = ScalarValue::try_from_array(&partial_args[0], partial_arg_idx)?;
184177
if !partial_arg_scalar.is_null() {
185178
let acc_scalar = &mut accs.scalar_values_mut()[acc_idx];
@@ -210,16 +203,15 @@ impl<P: AggMaxMinParams> Agg for AggMaxMin<P> {
210203
merging_acc_idx: IdxSelection<'_>,
211204
) -> Result<()> {
212205
let accs = downcast_any!(accs, mut AccGenericColumn).unwrap();
206+
accs.ensure_size(acc_idx);
207+
213208
let merging_accs = downcast_any!(merging_accs, mut AccGenericColumn).unwrap();
214209
let old_mem_used = accs.items_heap_mem_used(acc_idx);
215210

216211
macro_rules! handle_prim {
217212
($ty:ty) => {{
218213
idx_for_zipped! {
219214
((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => {
220-
if acc_idx >= accs.num_records() {
221-
accs.resize(acc_idx + 1);
222-
}
223215
if !merging_accs.prim_valid(merging_acc_idx) {
224216
continue;
225217
}

native-engine/datafusion-ext-plans/src/agg/sum.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ impl Agg for AggSum {
9393
partial_arg_idx: IdxSelection<'_>,
9494
) -> Result<()> {
9595
let accs = downcast_any!(accs, mut AccGenericColumn).unwrap();
96+
accs.ensure_size(acc_idx);
9697

9798
macro_rules! handle {
9899
($ty:ident) => {{
@@ -102,9 +103,6 @@ impl Agg for AggSum {
102103
let partial_arg = downcast_any!(&partial_args[0], TArray).unwrap();
103104
idx_for_zipped! {
104105
((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
105-
if acc_idx >= accs.num_records() {
106-
accs.resize(acc_idx + 1);
107-
}
108106
if partial_arg.is_valid(partial_arg_idx) {
109107
let partial_value = partial_arg.value(partial_arg_idx);
110108
if !accs.prim_valid(acc_idx) {
@@ -145,14 +143,12 @@ impl Agg for AggSum {
145143
) -> Result<()> {
146144
let accs = downcast_any!(accs, mut AccGenericColumn).unwrap();
147145
let merging_accs = downcast_any!(merging_accs, mut AccGenericColumn).unwrap();
146+
accs.ensure_size(acc_idx);
148147

149148
macro_rules! handle {
150149
($ty:ty) => {{
151150
idx_for_zipped! {
152151
((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => {
153-
if acc_idx >= accs.num_records() {
154-
accs.resize(acc_idx + 1);
155-
}
156152
if merging_accs.prim_valid(merging_acc_idx) {
157153
let merging_value = merging_accs.prim_value::<$ty>(merging_acc_idx);
158154
if !accs.prim_valid(acc_idx) {

0 commit comments

Comments
 (0)