forked from pytorch/executorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmetal_backend.py
More file actions
118 lines (97 loc) · 4.28 KB
/
metal_backend.py
File metadata and controls
118 lines (97 loc) · 4.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import subprocess
import typing
from typing import Any, Dict, final, List
from executorch.backends.aoti.aoti_backend import AotiBackend
from executorch.exir._warnings import experimental
from executorch.exir.backend.backend_details import BackendDetails
from executorch.exir.backend.compile_spec_schema import CompileSpec
@final
@experimental(
    "This API and all of Metal backend related functionality are experimental."
)
class MetalBackend(AotiBackend, BackendDetails):
    """Compile models for Metal/MPS devices using AOTInductor.

    AOTInductor is configured to produce optimized Metal kernels for the
    model's operators without linking against libtorch. The compiled
    artifact is then executed on Metal devices by the ExecuTorch runtime.
    """

    @classmethod
    def get_device_name(cls) -> str:
        """Return this backend's device name, ``"metal"``."""
        device = "metal"
        return device

    @classmethod
    def get_supported_fallback_kernels(cls) -> Dict[str, Any]:
        """Return the fallback kernel names supported by the Metal backend.

        Keys are kernel symbol names and every value is ``None``. A new dict
        is built on each call.
        """
        kernel_names = (
            "aoti_torch_mps_bmm_out",
            "aoti_torch_mps_convolution",
            "aoti_torch_mps_mm_out",
            "at::_ops::_scaled_dot_product_attention_math_for_mps::call",
            "at::_ops::_scaled_dot_product_attention_math_for_mps_v2::call",
            "torchao::_linear_fp_act_4bit_weight",
            "at::_ops::topk::call",
            "metal::gather_qmv",
            "metal::gated_delta_rule",
        )
        return dict.fromkeys(kernel_names)

    @classmethod
    def get_decomposition_table(cls) -> Dict[Any, Any]:
        """Return an empty decomposition table (no extra decompositions)."""
        return dict()

    @classmethod
    def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[typing.Any]:
        """Return the Metal-specific passes (currently ``DecomposeLinearPass``).

        ``compile_specs`` is accepted for interface compatibility and is not
        consulted. The pass module is imported lazily at call time.
        """
        from executorch.backends.apple.metal.passes.decompose_linear_pass import (
            DecomposeLinearPass,
        )

        passes: List[typing.Any] = [DecomposeLinearPass()]
        return passes

    @classmethod
    def get_aoti_compile_options(
        cls, compile_specs: List[CompileSpec]
    ) -> Dict[str, typing.Any]:
        """Build the Inductor/AOTInductor config dict for the Metal backend.

        Args:
            compile_specs: Required by the backend interface; not read here.

        Returns:
            A dict mapping Inductor config keys to values. Its
            ``aot_inductor.custom_ops_to_c_shims`` entry holds torchao's MPS
            C-shim table, plus the Metal ``gather_qmv`` and
            ``gated_delta_rule`` shims when their modules can be imported.

        Raises:
            ImportError: if ``torchao.experimental.ops.mps.cshim`` cannot be
                imported (unlike the Metal shims, it is not optional here).
        """
        del compile_specs  # part of the interface, intentionally unused

        options: Dict[str, typing.Any] = {
            # Build without linking against the full PyTorch/libtorch library.
            "aot_inductor.link_libtorch": False,
            # Produce a package and keep weight constants out of the .so ...
            "aot_inductor.package": True,
            "aot_inductor.package_constants_in_so": False,
            # ... storing them on disk as a single binary blob instead.
            "aot_inductor.package_constants_on_disk_format": "binary_blob",
            # Enable Inductor's maximum autotuning mode for performance.
            "max_autotune": True,
            # An infinite threshold avoids stride padding.
            "padding_stride_threshold": float("inf"),
            # Debug toggles kept disabled:
            #   "aot_inductor.debug_compile": True
            #   "aot_inductor.force_mmap_weights": False
        }

        from torchao.experimental.ops.mps.cshim import torchao_op_c_shim

        # Copy into a fresh dict so the optional merges below leave torchao's
        # own mapping untouched.
        c_shims = {**torchao_op_c_shim}

        # Optional Metal custom-op shims: merged only when importable.
        try:
            from executorch.backends.apple.metal.ops.gather_qmv import (
                metal_gather_qmv_c_shim,
            )
        except ImportError:
            pass
        else:
            c_shims.update(metal_gather_qmv_c_shim)

        try:
            from executorch.backends.apple.metal.ops.gated_delta_rule import (
                metal_gated_delta_rule_c_shim,
            )
        except ImportError:
            pass
        else:
            c_shims.update(metal_gated_delta_rule_c_shim)

        options["aot_inductor.custom_ops_to_c_shims"] = c_shims
        return options

    @classmethod
    def codesign_so(cls, so_path: str, compile_specs: List[CompileSpec]) -> None:
        """Code-sign the compiled .so for macOS Hardened Runtime compatibility.

        Signing runs only when ``compile_specs`` contains a
        ``codesign_identity`` entry. Only the first such entry is used. Its
        value is decoded as UTF-8 and passed to ``codesign -f -s``. Use
        ``"-"`` for ad-hoc signing or a Developer ID for distribution.

        Raises:
            subprocess.CalledProcessError: if ``codesign`` exits non-zero.
            OSError: if the ``codesign`` executable cannot be launched.
        """
        identity_spec = next(
            (spec for spec in compile_specs if spec.key == "codesign_identity"),
            None,
        )
        if identity_spec is None:
            return
        identity = identity_spec.value.decode("utf-8")
        subprocess.run(["codesign", "-f", "-s", identity, so_path], check=True)