Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ggml/src/ggml-zendnn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF")
ExternalProject_Add(
zendnn
GIT_REPOSITORY https://github.com/amd/ZenDNN.git
GIT_TAG f79f7321a1add65ced6397a6bfab7edba6e3e14e # ZenDNN-2026-WW13
GIT_TAG ac9e580d9434b7b98985f2627a7ebfb5eba4bb0d # ZenDNN-2026-WW17
PREFIX ${ZENDNN_PREFIX}
SOURCE_DIR ${ZENDNN_SOURCE_DIR}
BINARY_DIR ${ZENDNN_BUILD_DIR}
Expand Down
27 changes: 23 additions & 4 deletions ggml/src/ggml-zendnn/ggml-zendnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ static bool ggml_zendnn_matmul(ggml_backend_zendnn_context * ctx, int64_t m, int
params.dtypes.dst = ggml_to_zendnn_type<TC>();
params.num_threads = ctx->n_threads;

zendnnl::lowoha::matmul::matmul_batch_params_t batch_params;
zendnnl::error_handling::status_t status = zendnnl::lowoha::matmul::matmul_direct(
'r', false, true, // row-major, don't transpose B, transpose A (because it's column-major)
n, // M: rows of B and C
Expand All @@ -59,7 +60,7 @@ static bool ggml_zendnn_matmul(ggml_backend_zendnn_context * ctx, int64_t m, int
0.0f, // beta
C, ldc, // output C[n,m]
true, // is_weights_const
{}, // batch_params
batch_params, // batch_params
params // params
);

Expand Down Expand Up @@ -520,6 +521,12 @@ static ggml_backend_buffer_t ggml_backend_zendnn_device_buffer_from_host_ptr(ggm
GGML_UNUSED(max_tensor_size);
}

// Returns true when the adaptive CPU-fallback heuristic is enabled.
//
// Controlled by the GGML_ZENDNN_ADAPTIVE_FALLBACK environment variable:
//   - unset            -> enabled (default)
//   - set to "0"       -> disabled
//   - any other value  -> enabled (atoi of a non-numeric string is 0, so
//                         e.g. "off" also disables it — TODO confirm intended)
//
// The result is computed once and cached for the lifetime of the process.
static bool ggml_zendnn_adaptive_fallback_enabled() {
    static const bool enabled = [] {
        // Fetch the variable exactly once instead of calling std::getenv
        // twice; this avoids relying on two calls returning consistent
        // pointers and reads more clearly.
        const char * val = std::getenv("GGML_ZENDNN_ADAPTIVE_FALLBACK");
        return val == nullptr || std::atoi(val) != 0;
    }();
    return enabled;
}

static bool ggml_backend_zendnn_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
switch (op->op) {
case GGML_OP_NONE:
Expand All @@ -538,12 +545,24 @@ static bool ggml_backend_zendnn_device_supports_op(ggml_backend_dev_t dev, const
const int64_t ne10 = inputs->ne[0];
const int64_t ne0 = op->ne[0];
const int64_t ne1 = op->ne[1];

const int64_t min_batch = 1;
if (!ggml_is_contiguous(weights) || !ggml_is_contiguous(inputs) ||
ne0 < min_batch || ne1 < min_batch || ne10 < min_batch) {

if(!ggml_is_contiguous(weights) || !ggml_is_contiguous(inputs)) {
return false;
}

if (ggml_zendnn_adaptive_fallback_enabled()) {
const int64_t K = inputs->ne[0];
const int64_t N = (inputs->ne[1]*inputs->ne[2]*inputs->ne[3]);
const int64_t M = weights->ne[1];
if(K <= 256 || N <= 128 || M <= 96) {
return false;
}
}
else if (ne0 < min_batch || ne1 < min_batch || ne10 < min_batch) {
return false;
}

// MUL_MAT_ID performs best with a moderate number of experts due to its
// gather + batched matmul + scatter approach. Future versions will leverage
// ZenDNN's grouped_gemm for better scalability with larger expert counts:
Expand Down
Loading