-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathexplain_a_lds_index_swapping_equiv.py
More file actions
191 lines (163 loc) · 7.64 KB
/
explain_a_lds_index_swapping_equiv.py
File metadata and controls
191 lines (163 loc) · 7.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
#!/usr/bin/env python3
"""
Debug script that reproduces, with pytensor's C++-equivalent APIs, the exact construction of:
constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
a_lds_block_desc_permuted,
make_tuple(
make_unmerge_transform(make_tuple(number<MLdsLayer>{}, number<kKPerBlock / kKPack>{})),
make_pass_through_transform(number<kMPerBlock / MLdsLayer>{}),
make_pass_through_transform(number<kKPack>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
This script constructs the full chain:
1) a_lds_block_desc_0 (naive)
2) a_lds_block_desc_permuted (xor on dims [1,0], pass-through dim 2)
3) a_lds_block_desc_xk0_mnldslayer_mn_xk1 (unmerge + pass-throughs with index swap)
It prints the top-dimension ordering and lengths, the transform chain, and a small sample of
index-to-offset mappings, and asserts the expected index placement and order.
"""
import itertools
from typing import List

from pytensor.tensor_descriptor import (
    make_naive_tensor_descriptor,
    make_unmerge_transform,
    make_pass_through_transform,
    make_xor_transform,
    transform_tensor_descriptor,
)
def get_defaults():
    """Return ``(kKPerBlock, kKPack, MLdsLayer, kMPerBlock)`` as ints.

    The values come from the "A LDS Block Desc Example" entry of
    ``tensor_transforms.examples.get_default_variables()`` when that module is
    importable and the entry is well-formed. On any failure (missing package,
    missing key, non-integer value, ...) the built-in fallback
    ``(32, 4, 2, 64)`` is returned instead.
    """
    wanted_keys = ("kKPerBlock", "kKPack", "MLdsLayer", "kMPerBlock")
    try:
        from tensor_transforms.examples import get_default_variables  # type: ignore

        example_vars = get_default_variables()["A LDS Block Desc Example"]
        return tuple(int(example_vars[key]) for key in wanted_keys)
    except Exception:
        # Best-effort: fall back to the hard-coded example configuration.
        return 32, 4, 2, 64
def _report_top_dim_ordering(desc, prev_desc) -> None:
    """Print and assert which new-transform output feeds each top dimension of ``desc``.

    Only the transforms that ``desc`` appended on top of ``prev_desc``
    contribute final top dimensions. Each of their outputs is labelled
    ``T<k>:<TransformClass>[out<j>]``, where ``k`` counts from the first new
    transform and ``j`` is the output slot. Hidden ids not produced by a new
    transform are shown as ``hid<id>``.
    """
    print("\nTop-dimension ordering (origin labels):")
    transforms_all = desc.get_transforms()
    upper_all = desc.get_upper_dimension_hidden_idss()
    top_hidden = desc.get_top_dimension_hidden_ids()
    # The first get_num_of_transform() transforms are inherited from prev_desc;
    # only the transforms after them produce the final top dims.
    new_tf_start = prev_desc.get_num_of_transform()

    # Map hidden-id -> human-readable label for the new transforms' outputs
    hidden_to_label = {}
    for i in range(new_tf_start, len(transforms_all)):
        tf_name = type(transforms_all[i]).__name__
        for j, hid in enumerate(upper_all[i]):
            hidden_to_label[hid] = f"T{i - new_tf_start}:{tf_name}[out{j}]"

    # Reorder by top positions
    labels_by_top_pos = [hidden_to_label.get(hid, f"hid{hid}") for hid in top_hidden]
    for pos, label in enumerate(labels_by_top_pos):
        print(f" top[{pos}] <- {label}")

    # The unmerge's upper ids [0, 2] place its two outputs around the first
    # pass-through: that interleaving is the "index swap" being demonstrated.
    assert len(labels_by_top_pos) == 4, f"Expected 4 top dims, got {len(labels_by_top_pos)}"
    expected_label_prefixes = [
        "T0:UnmergeTransform[out0]",  # position 0
        "T1:PassThroughTransform[out0]",  # position 1
        "T0:UnmergeTransform[out1]",  # position 2
        "T2:PassThroughTransform[out0]",  # position 3
    ]
    for i, exp in enumerate(expected_label_prefixes):
        assert labels_by_top_pos[i].startswith(exp), (
            f"Top position {i} expected {exp}, got {labels_by_top_pos[i]}"
        )


def _print_transform_chain(desc) -> None:
    """Print every transform of ``desc`` with its lower/upper hidden-id lists."""
    print("\nTransform chain (lower -> upper hidden ids):")
    # Fetch the id lists once instead of on every loop iteration.
    lower_all = desc.get_lower_dimension_hidden_idss()
    upper_all = desc.get_upper_dimension_hidden_idss()
    for i, tf in enumerate(desc.get_transforms()):
        print(f" TF[{i}] {type(tf).__name__}: lower={lower_all[i]} upper={upper_all[i]}")


def _report_sample_offsets(desc) -> None:
    """Print offsets of ``desc`` for up to 3 indices along each of its 4 top dims."""
    lens = desc.get_lengths()
    print("\nSample index -> linear offset (final descriptor):")
    sample_ranges = [range(min(lens[dim], 3)) for dim in range(4)]
    # product() varies the last index fastest, matching nested for-loops.
    for i0, i1, i2, i3 in itertools.product(*sample_ranges):
        off = desc.calculate_offset([i0, i1, i2, i3])
        print(f" [{i0},{i1},{i2},{i3}] -> {off}")


def _check_last_dim_contiguous(desc) -> None:
    """Assert that stepping the last top index from 0 to 1 moves the offset by 1.

    The check runs at up to 2 values of each of the first three indices. It is
    skipped when the last dimension has fewer than 2 elements.
    """
    lens = desc.get_lengths()
    if lens[3] < 2:
        return
    leading_ranges = [range(min(lens[dim], 2)) for dim in range(3)]
    for i0, i1, i2 in itertools.product(*leading_ranges):
        off0 = desc.calculate_offset([i0, i1, i2, 0])
        off1 = desc.calculate_offset([i0, i1, i2, 1])
        assert off1 - off0 == 1, (
            f"Last-dim stride should be 1; got {off1 - off0} at (i0,i1,i2)={(i0,i1,i2)}"
        )


def main() -> None:
    """Build the A-LDS descriptor chain, print diagnostics, and assert its layout.

    Steps: (1) naive descriptor, (2) XOR-permuted descriptor, (3) unmerge plus
    pass-through descriptor whose upper ids ``[0, 2] / [1] / [3]`` reorder the
    top dimensions. An AssertionError is raised if an expected ordering,
    length, or last-dim stride does not hold (assertions are skipped under
    ``python -O``).
    """
    kKPerBlock, kKPack, MLdsLayer, kMPerBlock = get_defaults()
    kKPerBlock_over_kKPack = kKPerBlock // kKPack
    kMPerBlock_over_MLdsLayer = kMPerBlock // MLdsLayer

    # 1) a_lds_block_desc_0
    # lengths: [kKPerBlock/kKPack*MLdsLayer, kMPerBlock/MLdsLayer, kKPack]
    # strides: [kKPack, kKPerBlock*MLdsLayer, 1]
    lengths0 = [kKPerBlock_over_kKPack * MLdsLayer, kMPerBlock_over_MLdsLayer, kKPack]
    strides0 = [kKPack, kKPerBlock * MLdsLayer, 1]
    a_lds_block_desc_0 = make_naive_tensor_descriptor(lengths0, strides0)

    # 2) a_lds_block_desc_permuted
    # XOR over dims [1,0], pass-through dim 2; upper mapping mirrors lower: [1,0] and [2]
    a_lds_block_desc_permuted = transform_tensor_descriptor(
        a_lds_block_desc_0,
        transforms=[
            make_xor_transform([kMPerBlock_over_MLdsLayer, kKPerBlock_over_kKPack * MLdsLayer]),
            make_pass_through_transform(kKPack),
        ],
        lower_dimension_hidden_idss=[
            [1, 0],  # sequence<1,0>
            [2],  # sequence<2>
        ],
        upper_dimension_hidden_idss=[
            [1, 0],  # sequence<1,0>
            [2],  # sequence<2>
        ],
    )

    # 3) a_lds_block_desc_xk0_mnldslayer_mn_xk1
    a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
        a_lds_block_desc_permuted,
        transforms=[
            make_unmerge_transform([MLdsLayer, kKPerBlock_over_kKPack]),
            make_pass_through_transform(kMPerBlock_over_MLdsLayer),
            make_pass_through_transform(kKPack),
        ],
        lower_dimension_hidden_idss=[
            [0],  # sequence<0>
            [1],  # sequence<1>
            [2],  # sequence<2>
        ],
        upper_dimension_hidden_idss=[
            [0, 2],  # sequence<0,2>
            [1],  # sequence<1>
            [3],  # sequence<3>
        ],
    )

    # Show how the index swapping materializes in top-dimension ordering
    _report_top_dim_ordering(a_lds_block_desc_xk0_mnldslayer_mn_xk1, a_lds_block_desc_permuted)

    # Report lengths
    print("Descriptor lengths (top dims):")
    print(" a_lds_block_desc_0: ", a_lds_block_desc_0.get_lengths())
    print(" a_lds_block_desc_permuted: ", a_lds_block_desc_permuted.get_lengths())
    print(" a_lds_block_desc_xk0_mnldslayer_mn_xk1: ", a_lds_block_desc_xk0_mnldslayer_mn_xk1.get_lengths())
    # Previously this printed only the bare "name=" label; show the descriptor itself.
    print(f"a_lds_block_desc_xk0_mnldslayer_mn_xk1={a_lds_block_desc_xk0_mnldslayer_mn_xk1!r}")

    # Expected final lengths: [MLdsLayer, kMPerBlock/MLdsLayer, kKPerBlock/kKPack, kKPack]
    expected = [MLdsLayer, kMPerBlock_over_MLdsLayer, kKPerBlock_over_kKPack, kKPack]
    print("Expected final lengths: ", expected)

    # Assertions on lengths. list() keeps the comparison valid even if
    # get_lengths() returns a tuple or other sequence rather than a list.
    assert list(a_lds_block_desc_0.get_lengths()) == lengths0, "Base descriptor lengths mismatch"
    assert list(a_lds_block_desc_permuted.get_lengths()) == lengths0, (
        "Permuted descriptor should preserve lengths"
    )
    assert list(a_lds_block_desc_xk0_mnldslayer_mn_xk1.get_lengths()) == expected, (
        "Final descriptor lengths do not match expected order/values"
    )

    # Print each transform's lower/upper mappings so the [0,2] impact is visible,
    # then sample index -> offset on the final descriptor and check contiguity.
    _print_transform_chain(a_lds_block_desc_xk0_mnldslayer_mn_xk1)
    _report_sample_offsets(a_lds_block_desc_xk0_mnldslayer_mn_xk1)
    _check_last_dim_contiguous(a_lds_block_desc_xk0_mnldslayer_mn_xk1)


if __name__ == "__main__":
    main()