Skip to content

Commit 76b6e18

Browse files
lucas-ramirishabhmadan19
authored and committed
Re-apply "[AMDGPU][Scheduler] Scoring system for rematerializations (llvm#175050)" (llvm#177206)
This re-applies commit f21e359 along with the compile-failure fix introduced in 8ab7937 before the initial patch was reverted. It also fixes the previously observed assert failure. We were hitting the assert in the HIP Blender due to a combination of two issues that could happen when rematerializations are being rolled back. 1. Small changes in slot indices (while preserving instruction order) compared to the pre-re-scheduling state mean that we have to re-compute live ranges for all register operands of rolled back rematerializations. This was not being done before. 2. Re-scheduling can move registers that were rematerialized at arbitrary positions in their respective regions while their opcode is set to DBG_VALUE, even before their read operands are defined. This makes re-scheduling reverts mandatory before rolling back rematerializations, as otherwise def-use chains may be broken. The original patch did not guarantee that, but previous refactoring of the rollback/revert logic for the rematerialization stage now ensures that reverts always precede rollbacks.
1 parent 0de0fae commit 76b6e18

9 files changed

Lines changed: 2719 additions & 1602 deletions

llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp

Lines changed: 522 additions & 294 deletions
Large diffs are not rendered by default.

llvm/lib/Target/AMDGPU/GCNSchedStrategy.h

Lines changed: 210 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -202,8 +202,7 @@ class ScheduleMetrics {
202202
};
203203

204204
inline raw_ostream &operator<<(raw_ostream &OS, const ScheduleMetrics &Sm) {
205-
dbgs() << "\n Schedule Metric (scaled by "
206-
<< ScheduleMetrics::ScaleFactor
205+
dbgs() << "\n Schedule Metric (scaled by " << ScheduleMetrics::ScaleFactor
207206
<< " ) is: " << Sm.getMetric() << " [ " << Sm.getBubbles() << "/"
208207
<< Sm.getLength() << " ]\n";
209208
return OS;
@@ -305,21 +304,15 @@ class GCNScheduleDAGMILive final : public ScheduleDAGMILive {
305304
// Compute and cache live-ins and pressure for all regions in block.
306305
void computeBlockPressure(unsigned RegionIdx, const MachineBasicBlock *MBB);
307306

308-
/// If necessary, updates a region's boundaries following insertion ( \p NewMI
309-
/// != nullptr) or removal ( \p NewMI == nullptr) of a \p MI in the region.
310-
/// For an MI removal, this must be called before the MI is actually erased
311-
/// from its parent MBB.
312-
void updateRegionBoundaries(RegionBoundaries &RegionBounds,
313-
MachineBasicBlock::iterator MI,
314-
MachineInstr *NewMI);
315-
316307
/// Makes the scheduler try to achieve an occupancy of \p TargetOccupancy.
317308
void setTargetOccupancy(unsigned TargetOccupancy);
318309

319310
void runSchedStages();
320311

321312
std::unique_ptr<GCNSchedStage> createSchedStage(GCNSchedStageID SchedStageID);
322313

314+
void deleteMI(unsigned RegionIdx, MachineInstr *MI);
315+
323316
public:
324317
GCNScheduleDAGMILive(MachineSchedContext *C,
325318
std::unique_ptr<MachineSchedStrategy> S);
@@ -516,49 +509,184 @@ class ClusteredLowOccStage : public GCNSchedStage {
516509
};
517510

518511
/// Attempts to reduce function spilling or, if there is no spilling, to
519-
/// increase function occupancy by one with respect to ArchVGPR usage by sinking
520-
/// rematerializable instructions to their use. When the stage
521-
/// estimates reducing spilling or increasing occupancy is possible, as few
522-
/// instructions as possible are rematerialized to reduce potential negative
512+
/// increase function occupancy by one with respect to register usage by sinking
513+
/// rematerializable instructions to their use. When the stage estimates that
514+
/// reducing spilling or increasing occupancy is possible, it tries to
515+
/// rematerialize as few registers as possible to reduce potential negative
523516
/// effects on function latency.
517+
///
518+
/// The stage only supports rematerializing registers that meet all of the
519+
/// following constraints.
520+
/// 1. The register is virtual and has a single defining instruction.
521+
/// 2. The single defining instruction is either deemed rematerializable by the
522+
/// target-independent logic, or if not, has no non-constant and
523+
/// non-ignorable physical register use.
524+
/// 3. The register has no virtual register use whose live range would be
525+
/// extended by the rematerialization.
526+
/// 4. The register has a single non-debug user in a different region from its
527+
/// defining region.
528+
/// 5. The register is not used by or using another register that is going to be
529+
/// rematerialized.
524530
class PreRARematStage : public GCNSchedStage {
525531
private:
526-
/// Useful information about a rematerializable instruction.
527-
struct RematInstruction {
528-
/// Single use of the rematerializable instruction's defined register,
529-
/// located in a different block.
532+
/// A rematerializable register.
533+
struct RematReg {
534+
/// Single MI defining the rematerializable register.
535+
MachineInstr *DefMI;
536+
/// Single user of the rematerializable register.
530537
MachineInstr *UseMI;
531-
/// Rematerialized version of \p DefMI, set in
532-
/// PreRARematStage::rematerialize. Used for reverting rematerializations.
533-
MachineInstr *RematMI;
534-
/// Set of regions in which the rematerializable instruction's defined
535-
/// register is a live-in.
536-
SmallDenseSet<unsigned, 4> LiveInRegions;
538+
/// Regions in which the register is live-in/live-out/live anywhere.
539+
BitVector LiveIn, LiveOut, Live;
540+
/// The rematerializable register's lane bitmask.
541+
LaneBitmask Mask;
542+
/// Defining and using regions.
543+
unsigned DefRegion, UseRegion;
544+
545+
RematReg(MachineInstr *DefMI, MachineInstr *UseMI,
546+
GCNScheduleDAGMILive &DAG,
547+
const DenseMap<MachineInstr *, unsigned> &MIRegion);
548+
549+
/// Returns the rematerializable register. Do not call after deleting the
550+
/// original defining instruction.
551+
Register getReg() const { return DefMI->getOperand(0).getReg(); }
552+
553+
/// Determines whether this rematerialization may be beneficial in at least
554+
/// one target region.
555+
bool maybeBeneficial(const BitVector &TargetRegions,
556+
ArrayRef<GCNRPTarget> RPTargets) const;
557+
558+
/// Determines if the register is both unused and live-through in region \p
559+
/// I. This guarantees that rematerializing it will reduce RP in the region.
560+
bool isUnusedLiveThrough(unsigned I) const {
561+
assert(I < Live.size() && "region index out of range");
562+
return LiveIn[I] && LiveOut[I] && I != UseRegion;
563+
}
564+
565+
/// Updates internal structures following a MI rematerialization. Part of
566+
/// the stage instead of the DAG because it makes assumptions that are
567+
/// specific to the rematerialization process.
568+
void insertMI(unsigned RegionIdx, MachineInstr *RematMI,
569+
GCNScheduleDAGMILive &DAG) const;
570+
};
571+
572+
/// A scored rematerialization candidate. Higher scores indicate more
573+
/// beneficial rematerializations. A null score indicates the rematerialization
574+
/// is not helpful to reduce RP in target regions.
575+
struct ScoredRemat {
576+
/// The rematerializable register under consideration.
577+
RematReg *Remat;
578+
579+
/// Execution frequency information required by scoring heuristics.
580+
/// Frequencies are scaled down if they are high to avoid overflow/underflow
581+
/// when combining them.
582+
struct FreqInfo {
583+
/// Per-region execution frequencies. 0 when unknown.
584+
SmallVector<uint64_t> Regions;
585+
/// Minimum and maximum observed frequencies.
586+
uint64_t MinFreq, MaxFreq;
587+
588+
FreqInfo(MachineFunction &MF, const GCNScheduleDAGMILive &DAG);
589+
590+
private:
591+
static const uint64_t ScaleFactor = 1024;
592+
};
593+
594+
/// This only initializes state-independent characteristics of \p Remat, not
595+
/// the actual score.
596+
ScoredRemat(RematReg *Remat, const FreqInfo &Freq,
597+
const GCNScheduleDAGMILive &DAG);
598+
599+
/// Updates the rematerialization's score w.r.t. the current \p RPTargets.
600+
/// \p Freq indicates the frequency of each region.
601+
void update(const BitVector &TargetRegions, ArrayRef<GCNRPTarget> RPTargets,
602+
const FreqInfo &Freq, bool ReduceSpill);
603+
604+
/// Returns whether the current score is null, indicating the
605+
/// rematerialization is useless.
606+
bool hasNullScore() const { return !RegionImpact; }
607+
608+
/// Compare score components of non-null scores pair-wise. A null score is
609+
/// always strictly less than any non-null score.
610+
bool operator<(const ScoredRemat &O) const {
611+
if (hasNullScore())
612+
return !O.hasNullScore();
613+
if (O.hasNullScore())
614+
return false;
615+
if (MaxFreq != O.MaxFreq)
616+
return MaxFreq < O.MaxFreq;
617+
if (FreqDiff != O.FreqDiff)
618+
return FreqDiff < O.FreqDiff;
619+
if (RegionImpact != O.RegionImpact)
620+
return RegionImpact < O.RegionImpact;
621+
// Break ties using pointer to rematerializable register. Rematerializable
622+
// registers are collected in instruction order so, within the same
623+
// region, this will prefer registers defined earlier that have longer
624+
// live ranges in their defining region (since the registers we consider
625+
// are always live-out in their defining region).
626+
return Remat > O.Remat;
627+
}
628+
629+
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
630+
Printable print() const;
631+
#endif
537632

538-
RematInstruction(MachineInstr *UseMI) : UseMI(UseMI) {}
633+
private:
634+
/// Number of 32-bit registers this rematerialization covers.
635+
unsigned NumRegs;
636+
637+
// The three members below are the scoring components, top to bottom from
638+
// most important to least important when comparing candidates.
639+
640+
/// Frequency of impacted target region with highest known frequency. This
641+
/// only matters when the stage is trying to reduce spilling, so it is
642+
/// always 0 when it is not.
643+
uint64_t MaxFreq;
644+
/// Frequency difference between defining and using regions. Negative values
645+
/// indicate we are rematerializing to higher frequency regions; positive
646+
/// values indicate the contrary.
647+
int64_t FreqDiff;
648+
/// Expected number of target regions impacted by the rematerialization,
649+
/// scaled by the size of the register being rematerialized.
650+
unsigned RegionImpact;
651+
652+
unsigned getNumRegs(const GCNScheduleDAGMILive &DAG) const;
653+
654+
int64_t getFreqDiff(const FreqInfo &Freq) const;
539655
};
540656

541-
/// Maps all MIs to their parent region. MI terminators are considered to be
542-
/// outside the region they delimitate, and as such are not stored in the map.
543-
DenseMap<MachineInstr *, unsigned> MIRegion;
544657
/// Parent MBB to each region, in region order.
545658
SmallVector<MachineBasicBlock *> RegionBB;
546-
/// Collects instructions to rematerialize.
547-
MapVector<MachineInstr *, RematInstruction> Rematerializations;
548-
/// Collects regions whose live-ins or register pressure will change due to
549-
/// rematerializations.
550-
DenseMap<unsigned, GCNRegPressure> ImpactedRegions;
551-
/// In case we need to rollback rematerializations, save lane masks for all
552-
/// rematerialized registers in all regions in which they are live-ins.
553-
DenseMap<std::pair<unsigned, Register>, LaneBitmask> RegMasks;
554-
/// After successful stage initialization, indicates which regions should be
555-
/// rescheduled.
556-
BitVector RescheduleRegions;
557-
/// The target occupancy the stage is trying to achieve. Empty when the
659+
/// Register pressure targets for all regions.
660+
SmallVector<GCNRPTarget> RPTargets;
661+
/// Regions which are above the stage's RP target.
662+
BitVector TargetRegions;
663+
/// The target occupancy the stage is trying to achieve. Empty when the
558664
/// objective is spilling reduction.
559665
std::optional<unsigned> TargetOcc;
560666
/// Achieved occupancy *only* through rematerializations (pre-rescheduling).
561667
unsigned AchievedOcc;
668+
/// After successful stage initialization, indicates which regions should be
669+
/// rescheduled.
670+
BitVector RescheduleRegions;
671+
672+
/// List of rematerializable registers.
673+
SmallVector<RematReg> RematRegs;
674+
675+
/// Holds enough information to rollback a rematerialization decision post
676+
/// re-scheduling.
677+
struct RollbackInfo {
678+
/// The rematerializable register under consideration.
679+
const RematReg *Remat;
680+
/// The rematerialized MI replacing the original defining MI.
681+
MachineInstr *RematMI;
682+
/// Maps register machine operand indices to their original register.
683+
SmallDenseMap<unsigned, Register, 4> RegMap;
684+
685+
RollbackInfo(const RematReg *Remat) : Remat(Remat) {}
686+
};
687+
/// List of rematerializations to rollback if rematerialization does not end
688+
/// up being beneficial.
689+
SmallVector<RollbackInfo> Rollbacks;
562690

563691
/// State of a region pre-re-scheduling but post-rematerializations that we
564692
/// must keep to be able to revert re-scheduling effects.
@@ -582,20 +710,46 @@ class PreRARematStage : public GCNSchedStage {
582710
/// Returns the occupancy the stage is trying to achieve.
583711
unsigned getStageTargetOccupancy() const;
584712

585-
/// Returns whether remat can reduce spilling or increase function occupancy
586-
/// by 1 through rematerialization. If it can do one, collects instructions in
587-
/// PreRARematStage::Rematerializations and sets the target occupancy in
588-
/// PreRARematStage::TargetOccupancy.
589-
bool canIncreaseOccupancyOrReduceSpill();
713+
/// Determines the stage's objective (increasing occupancy or reducing
714+
/// spilling, set in \ref TargetOcc). Defines \ref RPTargets in all regions to
715+
/// achieve that objective and marks those that don't achieve it in \ref
716+
/// TargetRegions. Returns whether there is any target region.
717+
bool setObjective();
718+
719+
/// Unsets target regions in \p Regions whose RP target has been reached.
720+
void unsetSatisifedRPTargets(const BitVector &Regions);
721+
722+
/// Fully recomputes RP from the DAG in \p Regions. Among those regions, sets
723+
/// again all \ref TargetRegions that were optimistically marked as satisfied
724+
/// but are actually not, and returns whether there were any such regions.
725+
bool updateAndVerifyRPTargets(const BitVector &Regions);
726+
727+
/// Collects all rematerializable registers and appends them to \ref
728+
/// RematRegs. \p MIRegion maps MIs to their region. Returns whether any
729+
/// rematerializable register was found.
730+
bool collectRematRegs(const DenseMap<MachineInstr *, unsigned> &MIRegion);
731+
732+
/// Rematerializes \p Remat. This removes the rematerialized register from
733+
/// live-in/out lists in the DAG and updates RP targets in all affected
734+
/// regions, which are also marked in \ref RescheduleRegions. Regions in which
735+
/// RP savings are not guaranteed are set in \p RecomputeRP. When \p Rollback
736+
/// is non-null, fills it with required information to be able to rollback the
737+
/// rematerialization post-rescheduling.
738+
void rematerialize(const RematReg &Remat, BitVector &RecomputeRP,
739+
RollbackInfo *Rollback);
740+
741+
/// Rolls back the rematerialization decision represented by \p Rollback. This
742+
/// updates live-in/out lists in the DAG but does not update cached register
743+
/// pressures.
744+
void rollback(const RollbackInfo &Rollback) const;
745+
746+
/// Deletes all rematerialized MIs from the MIR when they were kept around for
747+
/// potential rollback.
748+
void commitRematerializations() const;
590749

591750
/// Whether the MI is rematerializable
592751
bool isReMaterializable(const MachineInstr &MI);
593752

594-
/// Rematerializes all instructions in PreRARematStage::Rematerializations
595-
/// and stores the achieved occupancy after remat in
596-
/// PreRARematStage::AchievedOcc.
597-
void rematerialize();
598-
599753
/// If remat alone did not increase occupancy to the target one, rolls back all
600754
/// rematerializations and resets live-ins/RP in all regions impacted by the
601755
/// stage to their pre-stage values.
@@ -611,7 +765,12 @@ class PreRARematStage : public GCNSchedStage {
611765
bool shouldRevertScheduling(unsigned WavesAfter) override;
612766

613767
PreRARematStage(GCNSchedStageID StageID, GCNScheduleDAGMILive &DAG)
614-
: GCNSchedStage(StageID, DAG), RescheduleRegions(DAG.Regions.size()) {}
768+
: GCNSchedStage(StageID, DAG), TargetRegions(DAG.Regions.size()),
769+
RescheduleRegions(DAG.Regions.size()) {
770+
const unsigned NumRegions = DAG.Regions.size();
771+
RPTargets.reserve(NumRegions);
772+
RegionBB.reserve(NumRegions);
773+
}
615774
};
616775

617776
class ILPInitialScheduleStage : public GCNSchedStage {

0 commit comments

Comments
 (0)