[mlir] MemrefToLLVM: Support llvm.address_space as memory space and do not generate noop addrspacecasts#173387
Closed
Hardcode84 wants to merge 1 commit into
Closed
[mlir] MemrefToLLVM: Support llvm.address_space as memory space and do not generate noop addrspacecasts#173387Hardcode84 wants to merge 1 commit into
llvm.address_space as memory space and do not generate noop addrspacecasts#173387Hardcode84 wants to merge 1 commit into
Conversation
… do not generate noop `addrspacecast`s Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
Member
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-llvm Author: Ivan Butygin (Hardcode84) ChangesHaving Full diff: https://github.com/llvm/llvm-project/pull/173387.diff 3 Files Affected:
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index cb9dea108cc48..07661550d436e 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -269,6 +269,12 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
// Integer memory spaces map to themselves.
addTypeAttributeConversion(
[](BaseMemRefType memref, IntegerAttr addrspace) { return addrspace; });
+
+ // LLVM address spaces map to themselves.
+ addTypeAttributeConversion(
+ [](BaseMemRefType memref, LLVM::AddressSpaceAttr addrspace) {
+ return addrspace;
+ });
}
/// Returns the MLIR context.
@@ -575,17 +581,24 @@ FailureOr<unsigned>
LLVMTypeConverter::getMemRefAddressSpace(BaseMemRefType type) const {
if (!type.getMemorySpace()) // Default memory space -> 0.
return 0;
+
std::optional<Attribute> converted =
convertTypeAttribute(type, type.getMemorySpace());
if (!converted)
return failure();
+
if (!(*converted)) // Conversion to default is 0.
return 0;
- if (auto explicitSpace = dyn_cast_if_present<IntegerAttr>(*converted)) {
+
+ if (auto explicitSpace = dyn_cast<IntegerAttr>(*converted)) {
if (explicitSpace.getType().isIndex() ||
explicitSpace.getType().isSignlessInteger())
return explicitSpace.getInt();
}
+
+ if (auto explicitSpace = dyn_cast<LLVM::AddressSpaceAttr>(*converted))
+ return explicitSpace.getAddressSpace();
+
return failure();
}
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 91a0c4b55fa84..d37895d1fb1ad 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -114,6 +114,14 @@ static unsigned getMemRefEltSizeInBytes(const LLVMTypeConverter *typeConverter,
return layout->getTypeSize(elementType);
}
+static Value createAddrSpaceCast(ConversionPatternRewriter &rewriter,
+ Location loc, Type type, Value value) {
+ if (value.getType() == type)
+ return value;
+
+ return LLVM::AddrSpaceCastOp::create(rewriter, loc, type, value);
+}
+
static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
Location loc, Value allocatedPtr,
MemRefType memRefType, Type elementPtrType,
@@ -124,7 +132,7 @@ static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
assert(succeeded(maybeMemrefAddrSpace) && "unsupported address space");
unsigned memrefAddrSpace = *maybeMemrefAddrSpace;
if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace)
- allocatedPtr = LLVM::AddrSpaceCastOp::create(
+ allocatedPtr = createAddrSpaceCast(
rewriter, loc,
LLVM::LLVMPointerType::get(rewriter.getContext(), memrefAddrSpace),
allocatedPtr);
@@ -1262,10 +1270,8 @@ struct MemorySpaceCastOpLowering
SmallVector<Value> descVals;
MemRefDescriptor::unpack(rewriter, loc, adaptor.getSource(), resultTypeR,
descVals);
- descVals[0] =
- LLVM::AddrSpaceCastOp::create(rewriter, loc, newPtrType, descVals[0]);
- descVals[1] =
- LLVM::AddrSpaceCastOp::create(rewriter, loc, newPtrType, descVals[1]);
+ descVals[0] = createAddrSpaceCast(rewriter, loc, newPtrType, descVals[0]);
+ descVals[1] = createAddrSpaceCast(rewriter, loc, newPtrType, descVals[1]);
Value result = MemRefDescriptor::pack(rewriter, loc, *getTypeConverter(),
resultTypeR, descVals);
rewriter.replaceOp(op, result);
@@ -1314,10 +1320,10 @@ struct MemorySpaceCastOpLowering
Value alignedPtr =
sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(),
sourceUnderlyingDesc, sourceElemPtrType);
- allocatedPtr = LLVM::AddrSpaceCastOp::create(
- rewriter, loc, resultElemPtrType, allocatedPtr);
- alignedPtr = LLVM::AddrSpaceCastOp::create(rewriter, loc,
- resultElemPtrType, alignedPtr);
+ allocatedPtr =
+ createAddrSpaceCast(rewriter, loc, resultElemPtrType, allocatedPtr);
+ alignedPtr =
+ createAddrSpaceCast(rewriter, loc, resultElemPtrType, alignedPtr);
result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc,
resultElemPtrType, allocatedPtr);
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 0cbe064572911..ff20ccba123af 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -289,6 +289,26 @@ func.func @address_space(%arg0 : memref<32xf32, affine_map<(d0) -> (d0)>, 7>) {
// -----
+// ALL-LABEL: func @llvm_address_space_cast
+// ALL-SAME: (%[[ARG:.*]]: memref<f32, #llvm.address_space<3>>)
+func.func @llvm_address_space_cast(%arg0 : memref<f32, #llvm.address_space<3>>) -> memref<f32, 3 : i32> {
+ // ALL: %[[UNREALIZED:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<f32, #llvm.address_space<3>> to !llvm.struct<(ptr<3>, ptr<3>, i64)>
+ // ALL: %[[ALLOC:.*]] = llvm.extractvalue %[[UNREALIZED]][0] : !llvm.struct<(ptr<3>, ptr<3>, i64)>
+ // ALL: %[[ALIGNED:.*]] = llvm.extractvalue %[[UNREALIZED]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64)>
+ // ALL: %[[OFFSET:.*]] = llvm.extractvalue %[[UNREALIZED]][2] : !llvm.struct<(ptr<3>, ptr<3>, i64)>
+ // ALL: %[[POISON:.*]] = llvm.mlir.poison : !llvm.struct<(ptr<3>, ptr<3>, i64)>
+ // ALL: %[[INS0:.*]] = llvm.insertvalue %[[ALLOC]], %[[POISON]][0] : !llvm.struct<(ptr<3>, ptr<3>, i64)>
+ // ALL: %[[INS1:.*]] = llvm.insertvalue %[[ALIGNED]], %[[INS0]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64)>
+ // ALL: %[[INS2:.*]] = llvm.insertvalue %[[OFFSET]], %[[INS1]][2] : !llvm.struct<(ptr<3>, ptr<3>, i64)>
+ // ALL: %[[RECAST:.*]] = builtin.unrealized_conversion_cast %[[INS2]] : !llvm.struct<(ptr<3>, ptr<3>, i64)> to memref<f32, 3 : i32>
+ // ALL: return %[[RECAST]] : memref<f32, 3 : i32>
+
+ %0 = memref.memory_space_cast %arg0 : memref<f32, #llvm.address_space<3>> to memref<f32, 3 : i32>
+ func.return %0 : memref<f32, 3 : i32>
+}
+
+// -----
+
// CHECK-LABEL: func @transpose
// CHECK: llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
// CHECK: llvm.insertvalue {{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
|
|
|
||
| // ALL-LABEL: func @llvm_address_space_cast | ||
| // ALL-SAME: (%[[ARG:.*]]: memref<f32, #llvm.address_space<3>>) | ||
| func.func @llvm_address_space_cast(%arg0 : memref<f32, #llvm.address_space<3>>) -> memref<f32, 3 : i32> { |
Member
There was a problem hiding this comment.
I'm a bit confused by #llvm.address_space. Never seen it before. Why would we have memref<..., #llvm.address_space<3>> instead of memref<..., 3>?
Contributor
Author
There was a problem hiding this comment.
This is part of ptr dialect infra as I understand. But I actually no longer strictly need this PR.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Having
llvm.address_spacein memref is useful for progressive lowering andaddrspacecastwith same address spaces are forbidden by LLVM spec.