Commit a094b16

egraphs: don't let rematerialization override LICM.
This reworks the way that remat and LICM interact during aegraph elaboration. In principle, both happen during the same single-pass "code placement" algorithm: we decide where to place pure instructions (those that are eligible for movement), and remat pushes them one way while LICM pushes them the other.

The interaction is more subtle than a simple heuristic priority, though -- it's really a decision-ordering issue. A remat'd value wants to sink as deep into the loop nest as it can (to the use's block), but we don't know *where* the uses go until we process them (and make LICM-related choices), and we process uses after defs during elaboration. More precisely, we do some work at the use before recursively processing the def, and some work after the recursion returns; and the LICM decision happens after the recursion returns, because LICM needs to know where the defs are to know how high we can hoist. (The recursion is itself unrolled into a state machine on an explicit stack, so this is a little hard to see, but that is what is happening in principle.)

The solution here is to make remat a separate just-in-time step, once we have arg values. Just before we plug the final arg values into the elaborated instruction, we ask: is this a remat'd value, and if so, do we have a copy of the computation in this block yet? If not, we make one. This has to happen in two places (the main elab loop and the toplevel driver from the skeleton).

The one downside of this solution is that it doesn't handle *recursive* rematerialization by default. This means that if we, for example, decide to remat single-constant-arg adds (as we actually do in our current rules), we won't then also recursively remat the constant arg to those adds. This can be seen in the `licm.clif` test case.

This doesn't seem to be a dealbreaker, because most such cases will be able to fold the constants anyway. They happen mostly because of pointer pre-computations: a loop over structs in Wasm computes `heap_base + p + offset`, and naive LICM pulls a `heap_base + offset` out of the loop for every struct field accessed in the loop, with horrible register pressure resulting; that's why we have that remat rule. Most such offsets are pretty small.

Fixes #7283.
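To make the just-in-time idea concrete, here is a minimal standalone sketch of the per-block remat memoization, in the spirit of the new `maybe_remat_arg` helper. This is *not* the Cranelift API: `Block` and `Value` are plain integer stand-ins, and the fresh-value allocator is a hypothetical simplification of cloning the defining instruction. The key property it models is that a remat-eligible value used outside its defining block gets exactly one clone per using block, looked up through a `(Block, Value)` map.

```rust
use std::collections::HashMap;

// Stand-ins for Cranelift's entity types (assumption: plain integers
// are enough to show the memoization structure).
type Block = u32;
type Value = u32;

struct RematState {
    // Values flagged as cheap enough to recompute at each use site.
    remat_values: Vec<Value>,
    // Memoized per-block copies: (use block, original value) -> clone.
    remat_copies: HashMap<(Block, Value), Value>,
    // Hypothetical fresh-value allocator, standing in for cloning the
    // defining instruction and taking its result.
    next_value: Value,
}

impl RematState {
    // Just-in-time remat: called as an arg is plugged into an
    // elaborated instruction, once the insertion block is known.
    fn maybe_remat_arg(&mut self, def_block: Block, use_block: Block, arg: Value) -> Value {
        // Same block, or not flagged for remat: use the arg as-is.
        if def_block == use_block || !self.remat_values.contains(&arg) {
            return arg;
        }
        // Reuse the copy already made for this block, if any.
        if let Some(&copy) = self.remat_copies.get(&(use_block, arg)) {
            return copy;
        }
        // Otherwise "clone the computation" into this block, once.
        let copy = self.next_value;
        self.next_value += 1;
        self.remat_copies.insert((use_block, arg), copy);
        copy
    }
}

fn main() {
    let mut state = RematState {
        remat_values: vec![10],
        remat_copies: HashMap::new(),
        next_value: 100,
    };
    // Value 10 defined in block 0, used twice in block 2: one shared copy.
    let a = state.maybe_remat_arg(0, 2, 10);
    let b = state.maybe_remat_arg(0, 2, 10);
    assert_eq!(a, b);
    assert_ne!(a, 10);
    // A use in the defining block itself makes no copy.
    assert_eq!(state.maybe_remat_arg(0, 0, 10), 10);
    // A value not flagged for remat is never cloned.
    assert_eq!(state.maybe_remat_arg(0, 2, 11), 11);
    println!("ok");
}
```

Because the decision is made per use site after placement is fixed, LICM's hoisting choice is never overridden; remat only fires when the arg genuinely crosses a block boundary.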
1 parent 04fcb6a commit a094b16

7 files changed

Lines changed: 161 additions & 91 deletions


cranelift/codegen/src/egraph.rs

Lines changed: 1 addition & 1 deletion
@@ -670,7 +670,7 @@ pub(crate) struct Stats {
     pub(crate) elaborate_visit_node: u64,
     pub(crate) elaborate_memoize_hit: u64,
     pub(crate) elaborate_memoize_miss: u64,
-    pub(crate) elaborate_memoize_miss_remat: u64,
+    pub(crate) elaborate_remat: u64,
     pub(crate) elaborate_licm_hoist: u64,
     pub(crate) elaborate_func: u64,
     pub(crate) elaborate_func_pre_insts: u64,

cranelift/codegen/src/egraph/elaborate.rs

Lines changed: 100 additions & 52 deletions
@@ -5,7 +5,8 @@ use super::cost::{pure_op_cost, Cost};
 use super::domtree::DomTreeWithChildren;
 use super::Stats;
 use crate::dominator_tree::DominatorTree;
-use crate::fx::FxHashSet;
+use crate::fx::{FxHashMap, FxHashSet};
+use crate::hash_map::Entry as HashEntry;
 use crate::ir::{Block, Function, Inst, Value, ValueDef};
 use crate::loop_analysis::{Loop, LoopAnalysis};
 use crate::scoped_hash_map::ScopedHashMap;
@@ -56,6 +57,8 @@ pub(crate) struct Elaborator<'a> {
     elab_result_stack: Vec<ElaboratedValue>,
     /// Explicitly-unrolled block elaboration stack.
     block_stack: Vec<BlockStackEntry>,
+    /// Copies of values that have been rematerialized.
+    remat_copies: FxHashMap<(Block, Value), Value>,
     /// Stats for various events during egraph processing, to help
     /// with optimization of this infrastructure.
     stats: &'a mut Stats,
@@ -95,7 +98,6 @@ enum ElabStackEntry {
         inst: Inst,
         result_idx: usize,
         num_args: usize,
-        remat: bool,
         before: Inst,
     },
 }
@@ -134,6 +136,7 @@ impl<'a> Elaborator<'a> {
             elab_stack: vec![],
             elab_result_stack: vec![],
             block_stack: vec![],
+            remat_copies: FxHashMap::default(),
             stats,
         }
     }
@@ -256,6 +259,45 @@ impl<'a> Elaborator<'a> {
         self.elab_result_stack.pop().unwrap()
     }

+    /// Possibly rematerialize the instruction producing the value in
+    /// `arg` and rewrite `arg` to refer to it, if needed. Returns
+    /// `true` if a rewrite occurred.
+    fn maybe_remat_arg(
+        remat_values: &FxHashSet<Value>,
+        func: &mut Function,
+        remat_copies: &mut FxHashMap<(Block, Value), Value>,
+        insert_block: Block,
+        before: Inst,
+        arg: &mut ElaboratedValue,
+        stats: &mut Stats,
+    ) -> bool {
+        // TODO: we may want to consider recursive rematerialization
+        // as well. We could process the arguments of the
+        // rematerialized instruction up to a certain depth. This
+        // would affect, e.g., adds-with-one-constant-arg, which are
+        // currently rematerialized. Right now we don't do this, to
+        // avoid the need for another fixpoint loop here.
+        if arg.in_block != insert_block && remat_values.contains(&arg.value) {
+            let new_value = match remat_copies.entry((insert_block, arg.value)) {
+                HashEntry::Occupied(o) => *o.get(),
+                HashEntry::Vacant(v) => {
+                    let inst = func.dfg.value_def(arg.value).inst().unwrap();
+                    debug_assert_eq!(func.dfg.inst_results(inst).len(), 1);
+                    let new_inst = func.dfg.clone_inst(inst);
+                    func.layout.insert_inst(new_inst, before);
+                    let new_result = func.dfg.inst_results(new_inst)[0];
+                    *v.insert(new_result)
+                }
+            };
+            trace!("rematerialized {} as {}", arg.value, new_value);
+            arg.value = new_value;
+            stats.elaborate_remat += 1;
+            true
+        } else {
+            false
+        }
+    }
+
     fn process_elab_stack(&mut self) {
         while let Some(entry) = self.elab_stack.pop() {
             match entry {
@@ -278,39 +320,17 @@
                     // eclass.
                     trace!("looking up best value for {}", value);
                     let (_, best_value) = self.value_to_best_value[value];
-                    debug_assert_ne!(best_value, Value::reserved_value());
                     trace!("elaborate: value {} -> best {}", value, best_value);
+                    debug_assert_ne!(best_value, Value::reserved_value());
+
+                    if let Some(elab_val) = self.value_to_elaborated_value.get(&canonical_value) {
+                        // Value is available; use it.
+                        trace!("elaborate: value {} -> {:?}", value, elab_val);
+                        self.stats.elaborate_memoize_hit += 1;
+                        self.elab_result_stack.push(*elab_val);
+                        continue;
+                    }

-                    let remat = if let Some(elab_val) =
-                        self.value_to_elaborated_value.get(&canonical_value)
-                    {
-                        // Value is available. Look at the defined
-                        // block, and determine whether this node kind
-                        // allows rematerialization if the value comes
-                        // from another block. If so, ignore the hit
-                        // and recompute below.
-                        let remat = elab_val.in_block != self.cur_block
-                            && self.remat_values.contains(&best_value);
-                        if !remat {
-                            trace!("elaborate: value {} -> {:?}", value, elab_val);
-                            self.stats.elaborate_memoize_hit += 1;
-                            self.elab_result_stack.push(*elab_val);
-                            continue;
-                        }
-                        trace!("elaborate: value {} -> remat", canonical_value);
-                        self.stats.elaborate_memoize_miss_remat += 1;
-                        // The op is pure at this point, so it is always valid to
-                        // remove from this map.
-                        self.value_to_elaborated_value.remove(&canonical_value);
-                        true
-                    } else {
-                        // Value not available; but still look up
-                        // whether it's been flagged for remat because
-                        // this affects placement.
-                        let remat = self.remat_values.contains(&best_value);
-                        trace!(" -> not present in map; remat = {}", remat);
-                        remat
-                    };
                     self.stats.elaborate_memoize_miss += 1;

                     // Now resolve the value to its definition to see
@@ -358,7 +378,6 @@
                         inst,
                         result_idx,
                         num_args,
-                        remat,
                         before,
                     });

@@ -375,23 +394,21 @@
                     inst,
                     result_idx,
                     num_args,
-                    remat,
                     before,
                 } => {
                     trace!(
-                        "PendingInst: {} result {} args {} remat {} before {}",
+                        "PendingInst: {} result {} args {} before {}",
                         inst,
                         result_idx,
                         num_args,
-                        remat,
                         before
                     );

                     // We should have all args resolved at this
                     // point. Grab them and drain them out, removing
                     // them.
                     let arg_idx = self.elab_result_stack.len() - num_args;
-                    let arg_values = &self.elab_result_stack[arg_idx..];
+                    let arg_values = &mut self.elab_result_stack[arg_idx..];

                     // Compute max loop depth.
                     //
@@ -437,16 +454,15 @@

                     // We know that this is a pure inst, because
                     // non-pure roots have already been placed in the
-                    // value-to-elab'd-value map and are never subject
-                    // to remat, so they will not reach this stage of
-                    // processing.
+                    // value-to-elab'd-value map, so they will not
+                    // reach this stage of processing.
                     //
                     // We now must determine the location at which we
                     // place the instruction. This is the current
                     // block *unless* we hoist above a loop when all
                     // args are loop-invariant (and this op is pure).
                     let (scope_depth, before, insert_block) =
-                        if loop_hoist_level == self.loop_stack.len() || remat {
+                        if loop_hoist_level == self.loop_stack.len() {
                             // Depends on some value at the current
                             // loop depth, or remat forces it here:
                             // place it at the current location.
@@ -479,16 +495,39 @@
                         insert_block
                     );

-                    // Now we need to place `inst` at the computed
-                    // location (just before `before`). Note that
-                    // `inst` may already have been placed somewhere
-                    // else, because a pure node may be elaborated at
-                    // more than one place. In this case, we need to
-                    // duplicate the instruction (and return the
-                    // `Value`s for that duplicated instance
-                    // instead).
+                    // Now that we have the location for the
+                    // instruction, check if any of its args are remat
+                    // values. If so, and if we don't have a copy of
+                    // the rematerializing instruction for this block
+                    // yet, create one.
+                    let mut remat_arg = false;
+                    for arg_value in arg_values.iter_mut() {
+                        if Self::maybe_remat_arg(
+                            &self.remat_values,
+                            &mut self.func,
+                            &mut self.remat_copies,
+                            insert_block,
+                            before,
+                            arg_value,
+                            &mut self.stats,
+                        ) {
+                            remat_arg = true;
+                        }
+                    }
+
+                    // Now we need to place `inst` at the computed
+                    // location (just before `before`). Note that
+                    // `inst` may already have been placed somewhere
+                    // else, because a pure node may be elaborated at
+                    // more than one place. In this case, we need to
+                    // duplicate the instruction (and return the
+                    // `Value`s for that duplicated instance instead).
+                    //
+                    // Also clone if we rematerialized, because we
+                    // don't want to rewrite the args in the original
+                    // copy.
                     trace!("need inst {} before {}", inst, before);
-                    let inst = if self.func.layout.inst_block(inst).is_some() {
+                    let inst = if self.func.layout.inst_block(inst).is_some() || remat_arg {
                         // Clone the inst!
                         let new_inst = self.func.dfg.clone_inst(inst);
                         trace!(
@@ -605,7 +644,16 @@
            // Elaborate the arg, placing any newly-inserted insts
            // before `before`. Get the updated value, which may
            // be different than the original.
-            let new_arg = self.elaborate_eclass_use(*arg, before);
+            let mut new_arg = self.elaborate_eclass_use(*arg, before);
+            Self::maybe_remat_arg(
+                &self.remat_values,
+                &mut self.func,
+                &mut self.remat_copies,
+                block,
+                inst,
+                &mut new_arg,
+                &mut self.stats,
+            );
            trace!(" -> rewrote arg to {:?}", new_arg);
            *arg = new_arg.value;
        }

cranelift/codegen/src/fx.rs

Lines changed: 0 additions & 1 deletion
@@ -14,7 +14,6 @@ use super::{HashMap, HashSet};
 use core::default::Default;
 use core::hash::{BuildHasherDefault, Hash, Hasher};
 use core::ops::BitXor;
-
 pub type FxHashMap<K, V> = HashMap<K, V, BuildHasherDefault<FxHasher>>;
 pub type FxHashSet<V> = HashSet<V, BuildHasherDefault<FxHasher>>;

cranelift/codegen/src/scoped_hash_map.rs

Lines changed: 0 additions & 13 deletions
@@ -188,19 +188,6 @@ where
            .checked_sub(1)
            .expect("generation_by_depth cannot be empty")
    }
-
-    /// Remote an entry.
-    pub fn remove(&mut self, key: &K) -> Option<V> {
-        self.map.remove(key).and_then(|val| {
-            let entry_generation = val.generation;
-            let entry_depth = val.level as usize;
-            if self.generation_by_depth.get(entry_depth).cloned() == Some(entry_generation) {
-                Some(val.value)
-            } else {
-                None
-            }
-        })
-    }
 }

 #[cfg(test)]

cranelift/filetests/filetests/egraph/licm.clif

Lines changed: 57 additions & 20 deletions
@@ -20,21 +20,21 @@
 }

 ; check: block0(v0: i32, v1: i32):
-; nextln: jump block1(v0)
+; check: jump block1(v0)

 ; check: block1(v2: i32):
-;; constants are not lifted; they are rematerialized in each block where used
-; check: v5 = iconst.i32 40
-; check: v6 = icmp eq v2, v5
-; check: v3 = iconst.i32 1
-; check: v8 = iadd v2, v3
-; check: brif v6, block2, block1(v8)
+;; constants are rematerialized in each block where used
+; check: v10 = iconst.i32 40
+; check: v11 = icmp eq v2, v10
+; check: v12 = iconst.i32 1
+; check: v13 = iadd v2, v12
+; check: brif v11, block2, block1(v13)


 ; check: block2:
-; check: v10 = iconst.i32 1
-; check: v4 = iadd.i32 v1, v10
-; check: return v4
+; check: v14 = iconst.i32 1
+; check: v15 = iadd.i32 v1, v14
+; check: return v15

 function %f(i64x2, i32) -> i64x2 {
 block0(v0: i64x2, v1: i32):
@@ -52,14 +52,14 @@
 }

 ; check: block0(v0: i64x2, v1: i32):
-; nextln: v4 = vconst.i64x2 const0
+; check: v4 = vconst.i64x2 const0
 ; nextln: jump block1(v0, v1)
 ; check: block1(v2: i64x2, v3: i32):
-; check: v6 = iconst.i32 1
-; check: v7 = isub v3, v6
+; check: v9 = iconst.i32 1
+; check: v10 = isub v3, v9
 ; check: v5 = iadd v2, v4
 ; check: v8 -> v5
-; check: brif v7, block1(v5, v7), block2
+; check: brif v10, block1(v5, v10), block2
 ; check: block2:
 ; check: return v5

@@ -94,13 +94,50 @@
 ; check: v8 = vconst.i64x2 const0
 ; check: jump block2(v3, v1)
 ; check: block2(v6: i64x2, v7: i32):
-; check: v10 = iconst.i32 1
-; check: v11 = isub v7, v10
+; check: v15 = iconst.i32 1
+; check: v16 = isub v7, v15
 ; check: v9 = iadd v6, v8
-; check: brif v11, block2(v9, v11), block3
+; check: brif v16, block2(v9, v16), block3
 ; check: block3:
-; check: v15 = iconst.i32 1
-; check: v14 = isub.i32 v5, v15
-; check: brif v14, block1(v9, v14), block4
+; check: v17 = iconst.i32 1
+; check: v18 = isub.i32 v5, v17
+; check: brif v18, block1(v9, v18), block4
 ; check: block4:
 ; check: return v9
+
+;; Don't let a rematerialized iconst inhibit (or even reverse)
+;; LICM. See issue #7283.
+
+function %f(i64, i64) {
+block0(v0: i64, v1: i64):
+    ;; Create a loop-invariant value `v10` which is some operation which
+    ;; includes a constant somewhere.
+    v8 = load.f64 v0+100
+    v9 = f64const 0x1.0000000000000p1
+    v10 = fdiv v8, v9
+
+    ;; jump to the loop
+    v3 = iconst.i64 0
+    jump block2(v3) ; v3 = 0
+
+block2(v11: i64):
+    ;; store the loop-invariant `v10` to memory "somewhere"
+    v15 = iadd v0, v11
+    store.f64 v10, v15
+
+    ;; loop breakout condition
+    v17 = iadd_imm v11, 1
+    v19 = icmp_imm ne v17, 100
+    brif v19, block2(v17), block1
+
+block1:
+    return
+}
+
+; check: load
+; check: f64const
+; check: fdiv
+; check: block2(v11: i64)
+; check: iadd
+; check: store
+; check: brif
