Updated mul_mat_f16_f32 metal kernel to allow llama-2-70B on metal #2459
ggerganov merged 4 commits into ggml-org:master
Conversation
It seems that in certain situations, assertion failures may be triggered.
Sorry, I used an old version and found that this issue has already been fixed on the mbosc/master branch.
Thanks! The best solution would be to implement the broadcast logic in the original Metal kernel. See the reference C implementation in ggml.c.
I can surely give it a shot later. I think I'll need to pass the full shape of src0 and src1 in the shaders to do that though. If that's no problem, I'll change the signature for all mul_mat kernels accordingly. |
I think we already pass the full shapes |
If I understand the kernel code correctly, we only pass it ne00 and ne01 (plus ne10 and ne11). I'll work on it! Thanks! |
Co-authored-by: Cebtenzzre <cebtenzzre@gmail.com>
Ok, I added a new commit where I simply regrouped the code for broadcasting within the original kernel_mul_mat_f16_f32 kernel, and removed my previous extra kernel. As I anticipated, I needed to add ne02 and ne12 as arguments to the kernel and, consequently, to pass their values prior to kernel dispatch in ggml-metal.m (around line 849). I am still somewhat confused by the fact that all other mul_mat kernels have a different signature with far fewer arguments, while the dispatch code in ggml-metal.m seems to be designed for kernel_mul_mat_f16_f32. Still, I get coherent generations, so I guess I am missing something and this is not an issue...
ggerganov
left a comment
Yes the signatures were a bit problematic. It works since ne11 is always 1, but we should fix it at some point.
ref #2429 #2276
Hi all!
I worked out a quick fix to get llama-2-70b working with metal.
The fix is rather inelegant.
I've tried with q4_K_M and q5_K_M quantized models from TheBloke, and generations seem coherent; q8_0 fails because there is no GGML_OP_GET_ROWS kernel for GGML_TYPE_Q8_0.
A better solution would probably require all matmul metal kernels to also take gqa as input, but I didn't want to alter the codebase too much, so I opted for a quicker patch instead.
Note that I rebased against the latest commit (a113689) to submit this PR, but I can no longer compile with LLAMA_METAL=1 after that. Instead, my fix works when applied to what is currently the penultimate commit (11f3ca0).