|
26 | 26 | #include <tvm/arith/analyzer.h> |
27 | 27 | #include <tvm/runtime/registry.h> |
28 | 28 | #include <tvm/tir/stmt_functor.h> |
| 29 | +#include <tvm/tir/index_map.h> |
| 30 | +#include <tvm/arith/iter_affine_map.h> |
29 | 31 |
|
| 32 | +#include <algorithm> |
30 | 33 | #include <cmath> |
31 | 34 | #include <string> |
32 | 35 | #include <utility> |
33 | 36 | #include <vector> |
34 | 37 |
|
35 | 38 | #include "literal/cuda_half_t.h" |
36 | 39 | #include "ptx.h" |
37 | | -#include "tvm/arith/iter_affine_map.h" |
38 | 40 |
|
39 | 41 | namespace tvm { |
40 | 42 | namespace codegen { |
@@ -839,40 +841,26 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { |
839 | 841 | std::string dst = this->PrintExpr(op->args[2]); |
840 | 842 | std::string src = this->PrintExpr(op->args[3]); |
841 | 843 | std::string src_offset = this->PrintExpr(op->args[4]); |
| 844 | + PrimExpr stride = op->args[5]; |
| 845 | + |
| 846 | + ICHECK(m == 16 && n == 16) << "Only m == 16 && n == 16 case supported for now"; |
| 847 | + |
| 848 | + const auto* index_map_func = |
| 849 | + runtime::Registry::Get("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout"); |
| 850 | + ICHECK(index_map_func); |
| 851 | + |
| 852 | + auto inverse_index_map = |
| 853 | + IndexMap::FromFunc(2, *index_map_func).Inverse({Range(0, 16), Range(0, 16)}); |
| 854 | + auto indices_16x16 = inverse_index_map->final_indices; |
| 855 | + |
| 856 | + var_idmap_[inverse_index_map->initial_indices[0].get()] = "threadIdx.x"; |
| 857 | + var_idmap_[inverse_index_map->initial_indices[1].get()] = "local_id"; |
| 858 | + |
| 859 | + os << "for (int local_id = 0; local_id < 8; ++local_id) {\n"; |
| 860 | + os << dst << "[" + this->PrintExpr(indices_16x16[0] * stride + indices_16x16[1]) + "]" |
| 861 | + << " = " << src << "[" << src_offset << " + local_id];\n"; |
| 862 | + os << "}\n"; |
842 | 863 |
|
843 | | - if (m == 16 && n == 8) { |
844 | | - std::string stride = this->PrintExpr(op->args[5]); |
845 | | - os << "for (int i = 0; i < 4; ++i) {\n"; |
846 | | - os << dst << "[(i / 2 * 8 + threadIdx.x / 4) * " << stride |
847 | | - << " + (threadIdx.x % 4) * 2 + i % 2]" |
848 | | - << " = " << src << "[" << src_offset << " + i];\n"; |
849 | | - os << "}\n"; |
850 | | - } else if (m == 16 && n == 16) { |
851 | | - const auto* index_map = |
852 | | - runtime::Registry::Get("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout"); |
853 | | - ICHECK(index_map); |
854 | | - |
855 | | - Var var_i("i"); |
856 | | - Var var_j("j"); |
857 | | - Array<PrimExpr> forward_map = (*index_map)(var_i, var_j); |
858 | | - |
859 | | - arith::Analyzer ana; |
860 | | - auto iter_map = arith::DetectIterMap( |
861 | | - forward_map, {{var_i, Range(0, 16)}, {var_j, Range(0, 16)}}, true, true, &ana, true); |
862 | | - |
863 | | - Var thread_id("threadIdx.x"); |
864 | | - Var local_id("local_id"); |
865 | | - auto inverse_map = arith::InverseAffineIterMap(iter_map, {thread_id, local_id}); |
866 | | - PrimExpr stride = op->args[5]; |
867 | | - auto dst_idx = inverse_map[var_i] * stride + inverse_map[var_j]; |
868 | | - |
869 | | - var_idmap_[thread_id.get()] = "threadIdx.x"; |
870 | | - var_idmap_[local_id.get()] = "local_id"; |
871 | | - os << "for (int local_id = 0; local_id < 8; ++local_id) {\n"; |
872 | | - os << dst << "[" + this->PrintExpr(dst_idx) + "]" |
873 | | - << " = " << src << "[" << src_offset << " + local_id];\n"; |
874 | | - os << "}\n"; |
875 | | - } |
876 | 864 | } else if (op->op.same_as(builtin::mma_fill())) { |
877 | 865 | std::string num_elem = this->PrintExpr(op->args[0]); |
878 | 866 | std::string dst = this->PrintExpr(op->args[1]); |
|
0 commit comments