diff --git a/llvm/lib/CodeGen/GlobalISel/PatternGen.cpp b/llvm/lib/CodeGen/GlobalISel/PatternGen.cpp index 8b4981e33ed9..72d661b12b3e 100644 --- a/llvm/lib/CodeGen/GlobalISel/PatternGen.cpp +++ b/llvm/lib/CodeGen/GlobalISel/PatternGen.cpp @@ -579,7 +579,7 @@ struct UnopNode : public PatternNode { std::string TypeStr = lltToString(Type); // ignore bitcast ops for now - if (Op == TargetOpcode::G_BITCAST) + if ((Op == TargetOpcode::G_BITCAST) || (Op == TargetOpcode::G_CONSTANT_FOLD_BARRIER)) return Operand->patternString(); return "(" + TypeStr + " (" + std::string(UnopStr.at(Op)) + " " + @@ -587,7 +587,7 @@ struct UnopNode : public PatternNode { } LLT getRegisterTy(int OperandId) const override { - if (OperandId == -1 && Op != TargetOpcode::G_BITCAST) + if (OperandId == -1 && Op != TargetOpcode::G_BITCAST && Op != TargetOpcode::G_CONSTANT_FOLD_BARRIER) return Type; return Operand->getRegisterTy(OperandId); } @@ -937,8 +937,28 @@ static PatternOrError traverseRegLoad(MachineRegisterInfo &MRI, ReadOffset = Offset->getOperand(1).getCImm()->getLimitedValue(); } if (AddrI->getOpcode() == TargetOpcode::G_SELECT) { - // TODO: implement this! - return pError(FORMAT_LOAD, AddrI); + assert(AddrI->getOperand(1).isReg() && "expected register"); + auto CondInstr = AddrI->getOperand(1); + auto CondReg = CondInstr.getReg(); + auto [ErrCond, CondNode] = traverse(MRI, *MRI.getVRegDef(CondReg)); + if (ErrCond) + return PError(ErrCond); + assert(AddrI->getOperand(2).isReg() && "expected register"); + auto TrueInstr = AddrI->getOperand(2); + auto TrueReg = TrueInstr.getReg(); + auto [ErrTrue, TrueNode] = traverseRegLoad(MRI, Cur, ReadSize, MRI.getVRegDef(TrueReg)); + if (ErrTrue) + return PError(ErrTrue); + assert(AddrI->getOperand(3).isReg() && "expected register"); + auto FalseInstr = AddrI->getOperand(3); + auto FalseReg = FalseInstr.getReg(); + auto [ErrFalse, FalseNode] = traverseRegLoad(MRI, Cur, ReadSize, MRI.getVRegDef(FalseReg)); + if (ErrFalse) + return PError(ErrFalse); + auto Node = std::make_unique( + MRI.getType(Cur.getOperand(0).getReg()), AddrI->getOpcode(), + std::move(CondNode), std::move(TrueNode), std::move(FalseNode)); + return PPattern(std::move(Node)); } if (AddrI->getOpcode() != TargetOpcode::COPY) return pError(FORMAT_LOAD, AddrI); @@ -1040,6 +1060,7 @@ static PatternOrError traverse(MachineRegisterInfo &MRI, MachineInstr &Cur) { return std::make_pair(SUCCESS, std::move(Node)); } + case TargetOpcode::G_CONSTANT_FOLD_BARRIER: case TargetOpcode::G_ANYEXT: case TargetOpcode::G_SEXT: case TargetOpcode::G_ZEXT: @@ -1331,6 +1352,10 @@ bool PatternGen::runOnMachineFunction(MachineFunction &MF) { MayLoad = 0; MayStore = 0; + if (PatternGenArgs::Args.DumpMIR) { + MF.dump(); + } + std::string InstName = MF.getName().str().substr(4); std::string InstNameO = InstName; ++PatternGenNumInstructionsProcessed; @@ -1362,10 +1387,6 @@ bool PatternGen::runOnMachineFunction(MachineFunction &MF) { return true; } - llvm::outs() << "Pattern for " << InstName << ": " << Node->patternString() - << '\n'; - ++PatternGenNumPatternsGenerated; - LLT OutType = LLT(); std::string OutsString; std::string InsString; @@ -1409,6 +1430,11 @@ bool PatternGen::runOnMachineFunction(MachineFunction &MF) { } } + llvm::outs() << "Pattern for " << InstName << ": " << Node->patternString() + << '\n'; + ++PatternGenNumPatternsGenerated; + + InsString = InsString.substr(0, InsString.size() - 2); OutsString = OutsString.substr(0, OutsString.size() - 2); diff --git a/llvm/tools/pattern-gen/Main.cpp b/llvm/tools/pattern-gen/Main.cpp index 62eea631137a..14dca1923a32 100644 --- a/llvm/tools/pattern-gen/Main.cpp +++ b/llvm/tools/pattern-gen/Main.cpp @@ -55,6 +55,8 @@ static cl::opt SkipVerify("skip-verify", cl::cat(ToolOptions)); static cl::opt PrintIR("print-ir", cl::desc("Print LLVM-IR module."), cl::cat(ToolOptions)); +static cl::opt PrintMIR("print-mir", cl::desc("Print LLVM-MIR functions."), + cl::cat(ToolOptions)); static cl::opt NoExtend( "no-extend", cl::desc("Do not apply CDSL typing rules (Use C-like type inference)."), @@ -128,15 +130,6 @@ int main(int argc, char **argv) { auto Mod = std::make_unique("mod", Ctx); auto Instrs = ParseCoreDSL2(Ts, (XLen == 64), Mod.get(), NoExtend); - if (irOut) { - std::string Str; - raw_string_ostream OS(Str); - OS << *Mod; - OS.flush(); - irOut << Str << "\n"; - irOut.close(); - } - if (!SkipVerify) if (verifyModule(*Mod, &errs())) return -1; @@ -165,9 +158,14 @@ int main(int argc, char **argv) { PGArgsStruct Args{.Mattr = "", .OptLevel = Opt, .Predicates = Predicates, - .Is64Bit = (XLen == 64)}; + .Is64Bit = (XLen == 64), + .DumpMIR = PrintMIR.getValue()}; optimizeBehavior(Mod.get(), Instrs, irOut, Args); + + if (irOut) + irOut.close(); + if (PrintIR) llvm::outs() << *Mod << "\n"; if (!SkipFmt) diff --git a/llvm/tools/pattern-gen/PatternGen.hpp b/llvm/tools/pattern-gen/PatternGen.hpp index 18e8d8e2f526..cb8b288600e7 100644 --- a/llvm/tools/pattern-gen/PatternGen.hpp +++ b/llvm/tools/pattern-gen/PatternGen.hpp @@ -10,6 +10,7 @@ struct PGArgsStruct llvm::CodeGenOptLevel OptLevel; std::string Predicates; bool Is64Bit; + bool DumpMIR; }; int optimizeBehavior(llvm::Module* M, std::vector const& Instrs, std::ostream& OstreamIR, PGArgsStruct Args); diff --git a/llvm/tools/pattern-gen/lib/Parser.cpp b/llvm/tools/pattern-gen/lib/Parser.cpp index b7e4a9eac0bc..63e29b550cef 100644 --- a/llvm/tools/pattern-gen/lib/Parser.cpp +++ b/llvm/tools/pattern-gen/lib/Parser.cpp @@ -268,10 +268,12 @@ Value gen_subscript(TokenStream &ts, llvm::Function *func, llvm::Value *mask = (len == llLen) ? llvm::ConstantInt::get(upper.ll->getType(), 0) - : build.CreateShl(llvm::ConstantInt::get(upper.ll->getType(), 1), + : build.CreateShl( + llvm::ConstantInt::get(llvm::Type::getIntNTy(ctx, len + 1), 1), len); - mask = - build.CreateSub(mask, llvm::ConstantInt::get(upper.ll->getType(), 1)); + mask = build.CreateSub(mask, llvm::ConstantInt::get(mask->getType(), 1)); + mask = (len < left.ll->getType()->getIntegerBitWidth()) ? build.CreateZExt(mask, left.ll->getType()) : ((build.CreateTrunc(mask, left.ll->getType())) ? : mask); + left.ll = build.CreateAnd(left.ll, mask); left.bitWidth = len; @@ -805,6 +807,17 @@ Value ParseExpressionTerminal(TokenStream &ts, llvm::Function *func, if (t.ident.str == "X" || t.ident.str == "XW") { bool sizeIs32 = t.ident.str == "XW"; pop_cur(ts, ABrOpen); + if (ts.Peek().type == IntLiteral) { // Handle X[0] + auto idx = pop_cur(ts, IntLiteral); + pop_cur(ts, ABrClose); + if (idx.literal.value == 0) // X[0] -> 0 + return Value( + llvm::ConstantInt::get(llvm::Type::getIntNTy(ctx, sizeIs32 ? 32 : xlen), + 0, true), + true); + else // X[1],... + not_implemented(ts); + } auto ident = pop_cur(ts, Identifier).ident; pop_cur(ts, ABrClose);