Skip to content

[mlir] MemrefToLLVM: Support llvm.address_space as memory space and do not generate noop addrspacecasts#173387

Closed
Hardcode84 wants to merge 1 commit into
llvm:mainfrom
Hardcode84:fix-memref-addrspace
Closed

[mlir] MemrefToLLVM: Support llvm.address_space as memory space and do not generate noop addrspacecasts#173387
Hardcode84 wants to merge 1 commit into
llvm:mainfrom
Hardcode84:fix-memref-addrspace

Conversation

@Hardcode84

Copy link
Copy Markdown
Contributor

Having llvm.address_space in memref is useful for progressive lowering and addrspacecast with same address spaces are forbidden by LLVM spec.

… do not generate noop `addrspacecast`s

Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
@llvmbot

llvmbot commented Dec 23, 2025

Copy link
Copy Markdown
Member

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: Ivan Butygin (Hardcode84)

Changes

Having llvm.address_space in memref is useful for progressive lowering and addrspacecast with same address spaces are forbidden by LLVM spec.


Full diff: https://github.com/llvm/llvm-project/pull/173387.diff

3 Files Affected:

  • (modified) mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp (+14-1)
  • (modified) mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp (+15-9)
  • (modified) mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir (+20)
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index cb9dea108cc48..07661550d436e 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -269,6 +269,12 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
   // Integer memory spaces map to themselves.
   addTypeAttributeConversion(
       [](BaseMemRefType memref, IntegerAttr addrspace) { return addrspace; });
+
+  // LLVM address spaces map to themselves.
+  addTypeAttributeConversion(
+      [](BaseMemRefType memref, LLVM::AddressSpaceAttr addrspace) {
+        return addrspace;
+      });
 }
 
 /// Returns the MLIR context.
@@ -575,17 +581,24 @@ FailureOr<unsigned>
 LLVMTypeConverter::getMemRefAddressSpace(BaseMemRefType type) const {
   if (!type.getMemorySpace()) // Default memory space -> 0.
     return 0;
+
   std::optional<Attribute> converted =
       convertTypeAttribute(type, type.getMemorySpace());
   if (!converted)
     return failure();
+
   if (!(*converted)) // Conversion to default is 0.
     return 0;
-  if (auto explicitSpace = dyn_cast_if_present<IntegerAttr>(*converted)) {
+
+  if (auto explicitSpace = dyn_cast<IntegerAttr>(*converted)) {
     if (explicitSpace.getType().isIndex() ||
         explicitSpace.getType().isSignlessInteger())
       return explicitSpace.getInt();
   }
+
+  if (auto explicitSpace = dyn_cast<LLVM::AddressSpaceAttr>(*converted))
+    return explicitSpace.getAddressSpace();
+
   return failure();
 }
 
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 91a0c4b55fa84..d37895d1fb1ad 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -114,6 +114,14 @@ static unsigned getMemRefEltSizeInBytes(const LLVMTypeConverter *typeConverter,
   return layout->getTypeSize(elementType);
 }
 
+static Value createAddrSpaceCast(ConversionPatternRewriter &rewriter,
+                                 Location loc, Type type, Value value) {
+  if (value.getType() == type)
+    return value;
+
+  return LLVM::AddrSpaceCastOp::create(rewriter, loc, type, value);
+}
+
 static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
                                  Location loc, Value allocatedPtr,
                                  MemRefType memRefType, Type elementPtrType,
@@ -124,7 +132,7 @@ static Value castAllocFuncResult(ConversionPatternRewriter &rewriter,
   assert(succeeded(maybeMemrefAddrSpace) && "unsupported address space");
   unsigned memrefAddrSpace = *maybeMemrefAddrSpace;
   if (allocatedPtrTy.getAddressSpace() != memrefAddrSpace)
-    allocatedPtr = LLVM::AddrSpaceCastOp::create(
+    allocatedPtr = createAddrSpaceCast(
         rewriter, loc,
         LLVM::LLVMPointerType::get(rewriter.getContext(), memrefAddrSpace),
         allocatedPtr);
@@ -1262,10 +1270,8 @@ struct MemorySpaceCastOpLowering
       SmallVector<Value> descVals;
       MemRefDescriptor::unpack(rewriter, loc, adaptor.getSource(), resultTypeR,
                                descVals);
-      descVals[0] =
-          LLVM::AddrSpaceCastOp::create(rewriter, loc, newPtrType, descVals[0]);
-      descVals[1] =
-          LLVM::AddrSpaceCastOp::create(rewriter, loc, newPtrType, descVals[1]);
+      descVals[0] = createAddrSpaceCast(rewriter, loc, newPtrType, descVals[0]);
+      descVals[1] = createAddrSpaceCast(rewriter, loc, newPtrType, descVals[1]);
       Value result = MemRefDescriptor::pack(rewriter, loc, *getTypeConverter(),
                                             resultTypeR, descVals);
       rewriter.replaceOp(op, result);
@@ -1314,10 +1320,10 @@ struct MemorySpaceCastOpLowering
       Value alignedPtr =
           sourceDesc.alignedPtr(rewriter, loc, *getTypeConverter(),
                                 sourceUnderlyingDesc, sourceElemPtrType);
-      allocatedPtr = LLVM::AddrSpaceCastOp::create(
-          rewriter, loc, resultElemPtrType, allocatedPtr);
-      alignedPtr = LLVM::AddrSpaceCastOp::create(rewriter, loc,
-                                                 resultElemPtrType, alignedPtr);
+      allocatedPtr =
+          createAddrSpaceCast(rewriter, loc, resultElemPtrType, allocatedPtr);
+      alignedPtr =
+          createAddrSpaceCast(rewriter, loc, resultElemPtrType, alignedPtr);
 
       result.setAllocatedPtr(rewriter, loc, resultUnderlyingDesc,
                              resultElemPtrType, allocatedPtr);
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 0cbe064572911..ff20ccba123af 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -289,6 +289,26 @@ func.func @address_space(%arg0 : memref<32xf32, affine_map<(d0) -> (d0)>, 7>) {
 
 // -----
 
+// ALL-LABEL: func @llvm_address_space_cast
+//  ALL-SAME:   (%[[ARG:.*]]: memref<f32, #llvm.address_space<3>>)
+func.func @llvm_address_space_cast(%arg0 : memref<f32, #llvm.address_space<3>>) -> memref<f32, 3 : i32> {
+  // ALL: %[[UNREALIZED:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<f32, #llvm.address_space<3>> to !llvm.struct<(ptr<3>, ptr<3>, i64)>
+  // ALL: %[[ALLOC:.*]] = llvm.extractvalue %[[UNREALIZED]][0] : !llvm.struct<(ptr<3>, ptr<3>, i64)>
+  // ALL: %[[ALIGNED:.*]] = llvm.extractvalue %[[UNREALIZED]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64)>
+  // ALL: %[[OFFSET:.*]] = llvm.extractvalue %[[UNREALIZED]][2] : !llvm.struct<(ptr<3>, ptr<3>, i64)>
+  // ALL: %[[POISON:.*]] = llvm.mlir.poison : !llvm.struct<(ptr<3>, ptr<3>, i64)>
+  // ALL: %[[INS0:.*]] = llvm.insertvalue %[[ALLOC]], %[[POISON]][0] : !llvm.struct<(ptr<3>, ptr<3>, i64)>
+  // ALL: %[[INS1:.*]] = llvm.insertvalue %[[ALIGNED]], %[[INS0]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64)>
+  // ALL: %[[INS2:.*]] = llvm.insertvalue %[[OFFSET]], %[[INS1]][2] : !llvm.struct<(ptr<3>, ptr<3>, i64)>
+  // ALL: %[[RECAST:.*]] = builtin.unrealized_conversion_cast %[[INS2]] : !llvm.struct<(ptr<3>, ptr<3>, i64)> to memref<f32, 3 : i32>
+  // ALL: return %[[RECAST]] : memref<f32, 3 : i32>
+
+  %0 = memref.memory_space_cast %arg0 : memref<f32, #llvm.address_space<3>> to memref<f32, 3 : i32>
+  func.return %0 : memref<f32, 3 : i32>
+}
+
+// -----
+
 // CHECK-LABEL: func @transpose
 //       CHECK:   llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>
 //       CHECK:   llvm.insertvalue {{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>


// ALL-LABEL: func @llvm_address_space_cast
// ALL-SAME: (%[[ARG:.*]]: memref<f32, #llvm.address_space<3>>)
func.func @llvm_address_space_cast(%arg0 : memref<f32, #llvm.address_space<3>>) -> memref<f32, 3 : i32> {

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit confused by #llvm.address_space. Never seen it before. Why would we have memref<..., #llvm.address_space<3>> instead of memref<..., 3>?

@Hardcode84 Hardcode84 Dec 23, 2025

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is part of ptr dialect infra as I understand. But I actually no longer strictly need this PR.

@Hardcode84 Hardcode84 closed this Dec 24, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants