
CUDA: faster Deepseek FA, add Turing support #13435

Merged
merged 1 commit into ggml-org:master on May 14, 2025

Conversation

JohannesGaessler
Collaborator

This PR optimizes the CUDA FlashAttention kernel for better Deepseek performance. Since there still seem to be issues with my last FA PR on master, I would suggest only merging this PR after they are sorted out.

  • The batch sizes for K, V, and combining the results can now be set as a function of compute capability and the number of Q columns per CUDA block. This enables running the kernel on Turing with 64 KiB of SRAM per streaming multiprocessor (vs. 99+ KiB on Ampere and newer); a hedged sketch of the idea follows after this list. I don't have a Turing GPU for testing, but it should at least be faster than running the kernel on the CPU. For Deepseek I found parameters that are faster for batch size 1 than master on Ampere or newer. For Gemma, prompt processing on Turing should now also be a bit faster.
  • Following CUDA: FA support for Deepseek (Ampere or newer) #13306 (comment) I re-wrote the kernel to load both the K and V data from the K tensor. For Deepseek the V tensor is in fact now not being used at all by the CUDA code. For batch size 1, 100% of the V data loads can be skipped; for other batch sizes, 50% can be skipped. Going forward I would suggest re-writing the code in other backends as well to use only the K tensor. The KV cache size could then be reduced by ~47% by simply not allocating and filling the V cache. At least as long as FlashAttention is used this should be relatively simple. So for a first version I think it would be fine to only deduplicate K and V if FA is used. @jukofyork are you interested in working on this?
  • Following CUDA: FA support for Deepseek (Ampere or newer) #13306 (comment) I tried to make the suggested change to ggml_cuda_unroll. However, the code then only compiles if a compile flag for relaxed constexpr is used. Notably, this compile flag is declared experimental with no guarantee of being forward compatible. Since the compiler was able to do the correct loop optimization with the code as it is now, I would suggest keeping it that way.
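
To make the first bullet concrete, here is a minimal, hypothetical sketch (not the PR's actual kernel code; all names and numbers are illustrative assumptions) of how the K batch size could be derived from the compute capability and the number of Q columns per CUDA block, so that the tiles fit into ~64 KiB of shared memory on Turing versus 99+ KiB on Ampere and newer:

// Illustrative sketch only: shrink the K batch until the Q tile plus one K tile
// fit into the per-SM shared-memory budget of the target architecture.
#include <cstdio>

struct fattn_tile_config {
    int kq_nbatch;      // K rows loaded per iteration (hypothetical name)
    int nbatch_combine; // columns combined per iteration when merging results (hypothetical)
};

static fattn_tile_config choose_tile_config(int compute_capability, int ncols_q) {
    const int smem_budget = compute_capability >= 80 ? 99*1024 : 64*1024;
    const int head_dim    = 576;                // DeepSeek MLA K head size (512 latent + 64 RoPE)
    const int q_bytes     = ncols_q*head_dim*2; // Q tile in fp16
    fattn_tile_config cfg = {128, 128};
    while (cfg.kq_nbatch > 16 && q_bytes + cfg.kq_nbatch*head_dim*2 > smem_budget) {
        cfg.kq_nbatch /= 2; // halve the K batch until the tiles fit
    }
    cfg.nbatch_combine = cfg.kq_nbatch;
    return cfg;
}

int main() {
    const fattn_tile_config turing = choose_tile_config(75, 16);
    const fattn_tile_config ampere = choose_tile_config(86, 16);
    printf("Turing: kq_nbatch = %d, Ampere: kq_nbatch = %d\n", turing.kq_nbatch, ampere.kq_nbatch);
    return 0;
}
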
Performance
| GPU      | Model              | Microbatch size | Test    | t/s master | t/s PR  | Speedup |
|----------|--------------------|-----------------|---------|------------|---------|---------|
| RTX 3090 | deepseek2 16B Q4_0 | 1               | pp16384 | 146.30     | 163.07  | 1.11    |
| RTX 3090 | deepseek2 16B Q4_0 | 2               | pp16384 | 143.80     | 145.86  | 1.01    |
| RTX 3090 | deepseek2 16B Q4_0 | 4               | pp16384 | 228.95     | 231.63  | 1.01    |
| RTX 3090 | deepseek2 16B Q4_0 | 8               | pp16384 | 342.11     | 345.60  | 1.01    |
| RTX 3090 | deepseek2 16B Q4_0 | 16              | pp16384 | 473.21     | 478.07  | 1.01    |
| RTX 3090 | deepseek2 16B Q4_0 | 32              | pp16384 | 751.96     | 762.99  | 1.01    |
| RTX 3090 | deepseek2 16B Q4_0 | 64              | pp16384 | 1068.48    | 1092.93 | 1.02    |
| RTX 3090 | deepseek2 16B Q4_0 | 128             | pp16384 | 1572.56    | 1625.06 | 1.03    |
| RTX 3090 | deepseek2 16B Q4_0 | 256             | pp16384 | 2104.45    | 2188.24 | 1.04    |
| RTX 3090 | deepseek2 16B Q4_0 | 512             | pp16384 | 2441.73    | 2543.48 | 1.04    |
| RTX 3090 | deepseek2 16B Q4_0 | 1024            | pp16384 | 2608.48    | 2751.08 | 1.05    |
| RTX 3090 | deepseek2 16B Q4_0 | 2048            | pp16384 | 2773.21    | 2923.00 | 1.05    |
| RTX 4090 | deepseek2 16B Q4_0 | 1               | pp16384 | 196.35     | 223.38  | 1.14    |
| RTX 4090 | deepseek2 16B Q4_0 | 2               | pp16384 | 186.94     | 191.89  | 1.03    |
| RTX 4090 | deepseek2 16B Q4_0 | 4               | pp16384 | 318.09     | 325.71  | 1.02    |
| RTX 4090 | deepseek2 16B Q4_0 | 8               | pp16384 | 521.87     | 529.75  | 1.02    |
| RTX 4090 | deepseek2 16B Q4_0 | 16              | pp16384 | 784.25     | 794.17  | 1.01    |
| RTX 4090 | deepseek2 16B Q4_0 | 32              | pp16384 | 1317.60    | 1337.75 | 1.02    |
| RTX 4090 | deepseek2 16B Q4_0 | 64              | pp16384 | 1959.90    | 2000.23 | 1.02    |
| RTX 4090 | deepseek2 16B Q4_0 | 128             | pp16384 | 3028.04    | 3114.23 | 1.03    |
| RTX 4090 | deepseek2 16B Q4_0 | 256             | pp16384 | 4222.41    | 4407.39 | 1.04    |
| RTX 4090 | deepseek2 16B Q4_0 | 512             | pp16384 | 5087.84    | 5337.91 | 1.05    |
| RTX 4090 | deepseek2 16B Q4_0 | 1024            | pp16384 | 5564.05    | 5866.69 | 1.05    |
| RTX 4090 | deepseek2 16B Q4_0 | 2048            | pp16384 | 5594.26    | 5835.99 | 1.04    |
| GPU      | Model              | Test           | t/s no FA | t/s FA  |
|----------|--------------------|----------------|-----------|---------|
| RTX 3090 | deepseek2 16B Q4_0 | pp512          | 4435.67   | 4400.56 |
| RTX 3090 | deepseek2 16B Q4_0 | tg128          | 181.77    | 162.23  |
| RTX 3090 | deepseek2 16B Q4_0 | pp512 @ d512   | 4112.98   | 4212.77 |
| RTX 3090 | deepseek2 16B Q4_0 | tg128 @ d512   | 164.48    | 158.07  |
| RTX 3090 | deepseek2 16B Q4_0 | pp512 @ d1024  | 3855.40   | 4014.16 |
| RTX 3090 | deepseek2 16B Q4_0 | tg128 @ d1024  | 155.14    | 156.58  |
| RTX 3090 | deepseek2 16B Q4_0 | pp512 @ d2048  | 3535.88   | 3627.22 |
| RTX 3090 | deepseek2 16B Q4_0 | tg128 @ d2048  | 142.18    | 152.69  |
| RTX 3090 | deepseek2 16B Q4_0 | pp512 @ d4096  | 2982.95   | 3162.64 |
| RTX 3090 | deepseek2 16B Q4_0 | tg128 @ d4096  | 121.94    | 149.42  |
| RTX 3090 | deepseek2 16B Q4_0 | pp512 @ d8192  | 2308.45   | 2433.54 |
| RTX 3090 | deepseek2 16B Q4_0 | tg128 @ d8192  | 67.52     | 143.38  |
| RTX 4090 | deepseek2 16B Q4_0 | pp512          | 8304.31   | 8353.35 |
| RTX 4090 | deepseek2 16B Q4_0 | tg128          | 223.64    | 204.81  |
| RTX 4090 | deepseek2 16B Q4_0 | pp512 @ d512   | 7744.87   | 8250.84 |
| RTX 4090 | deepseek2 16B Q4_0 | tg128 @ d512   | 209.76    | 198.06  |
| RTX 4090 | deepseek2 16B Q4_0 | pp512 @ d1024  | 7373.56   | 7944.04 |
| RTX 4090 | deepseek2 16B Q4_0 | tg128 @ d1024  | 201.47    | 195.62  |
| RTX 4090 | deepseek2 16B Q4_0 | pp512 @ d2048  | 6643.92   | 7367.00 |
| RTX 4090 | deepseek2 16B Q4_0 | tg128 @ d2048  | 189.58    | 192.90  |
| RTX 4090 | deepseek2 16B Q4_0 | pp512 @ d4096  | 5530.89   | 6483.69 |
| RTX 4090 | deepseek2 16B Q4_0 | tg128 @ d4096  | 169.63    | 188.48  |
| RTX 4090 | deepseek2 16B Q4_0 | pp512 @ d8192  | 4181.64   | 5147.22 |
| RTX 4090 | deepseek2 16B Q4_0 | tg128 @ d8192  | 138.76    | 179.80  |

On my RTX 3090 and 4090, -fa now seems to universally improve performance starting at a context size of ~2000 tokens; previously this was at ~6000 tokens.

github-actions bot added the Nvidia GPU (Issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) labels on May 10, 2025
@jukofyork
Collaborator

jukofyork commented May 10, 2025

Following CUDA: FA support for Deepseek (Ampere or newer) #13306 (comment) I re-wrote the kernel to load both the K and V data from the K tensor. For Deepseek the V tensor is in fact now not being used at all by the CUDA code. For batch size 1 100% of the V data load can be skipped, for other batch sizes 50% can be skipped. Going forward I would suggest re-writing the code in other backends as well to use only the K tensor. The KV cache size could then be reduced by ~47% by simply not allocating and filling the V cache. At least as long as FlashAttention is used this should be relatively simple. So for a first version it would I think be fine to only deduplicate K and V if FA is used. @jukofyork are you interested in working on this?

Yeah, I'll try and have a look at this next week.

@jukofyork
Collaborator

@JohannesGaessler Is there any advantage to making Q contiguous by using this alternative permutation order:

https://github.com/ggml-org/llama.cpp/compare/master...jukofyork:llama.cpp:alt-perm-mla?expand=1

which then gets undone here:

ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);

This, along with the 2D view, improved the non-FA performance, but I'm not sure whether it could also help with the FA version.

@JohannesGaessler
Collaborator Author

For CUDA FA I think it doesn't matter because each CUDA block loads the Q data from VRAM exactly once. For the CPU code I don't know because there is no explicit allocation of SRAM.

@jukofyork
Collaborator

jukofyork commented May 12, 2025

Following CUDA: FA support for Deepseek (Ampere or newer) #13306 (comment) I re-wrote the kernel to load both the K and V data from the K tensor. For Deepseek the V tensor is in fact now not being used at all by the CUDA code. For batch size 1 100% of the V data load can be skipped, for other batch sizes 50% can be skipped. Going forward I would suggest re-writing the code in other backends as well to use only the K tensor. The KV cache size could then be reduced by ~47% by simply not allocating and filling the V cache. At least as long as FlashAttention is used this should be relatively simple. So for a first version it would I think be fine to only deduplicate K and V if FA is used. @jukofyork are you interested in working on this?

So I'm just looking at this now and could do with some input from @ggerganov on how best to proceed here, I think:

If we go back to disabling context shifting:

https://github.com/jukofyork/llama.cpp/blob/95e18884fc7ea4031f70f1a518d5d1df616e5717/src/llama-kv-cache.cpp#L36

cache.can_shift = !(model.arch == LLM_ARCH_DEEPSEEK2 && cparams.flash_attn);

then it looks very easy to do:

https://github.com/jukofyork/llama.cpp/blob/95e18884fc7ea4031f70f1a518d5d1df616e5717/src/llama-kv-cache.cpp#L104

ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, model.arch == LLM_ARCH_DEEPSEEK2 && cparams.flash_attn ? 0 : n_embd_v_gqa*kv_size);

and then similarly to not copy it here:

https://github.com/jukofyork/llama.cpp/blob/95e18884fc7ea4031f70f1a518d5d1df616e5717/src/llama-graph.cpp#L1436

        if (!v_trans) {
            v_cache_view = ggml_view_1d(ctx0, kv_self->v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa)*kv_head);
        } else if (!(model.arch == LLM_ARCH_DEEPSEEK2 && cparams.flash_attn)) {
            // note: the V cache is transposed when not using flash attention
            v_cache_view = ggml_view_2d(ctx0, kv_self->v_l[il], n_tokens, n_embd_v_gqa,
                    (  n_ctx)*ggml_element_size(kv_self->v_l[il]),
                    (kv_head)*ggml_element_size(kv_self->v_l[il]));

            v_cur = ggml_transpose(ctx0, v_cur);
        }

(I've not run this code, so there may be typos, but I think this is the general idea.)


The beauty of the "treat MLA as MQA with large heads" method was that all the context-shifting stuff "just worked", but I fear this could end up a huge mess with lots of special cases in llama-kv-cache.cpp if I'm not careful here, due to having to alter every n_embd_v_gqa value, etc.

@JohannesGaessler
Collaborator Author

The FA issues on master seem to have been sorted out and in terms of debugging I think this PR would now be fine to merge.

@Panchovix

Testing this PR on top of the latest commits and all is working fine. I get about 10-11% better performance on DeepSeek 14GB (Q5_K_M) on a single 5090, and about 1% on DeepSeek V3 0324 with CUDA + CPU (Q2_K_XL/IQ3_XXS).

@jukofyork
Collaborator

@JohannesGaessler Is it expected that speculative sampling performance is worse for the FA kernels now compared to non-FA?

I can't remember the exact test case I was running, but for coding tasks I was getting around a 2x increase using speculative sampling before (e.g. 5.5 tokens/s -> 10.5 tokens/s).

I have now got up to 6.0 tokens/s using this PR, but haven't managed to get more than about 7.5 tokens/s when using speculative sampling. Is this expected (e.g. due to the performance of the very small batches that speculative sampling uses)?

@JohannesGaessler
Collaborator Author

Looking at the numbers I posted in the OP again, while the performance did increase for all batch sizes, the performance for batch size 1 (16 Q columns/CUDA block) now seems to be better than with batch size 2 (32 Q columns/CUDA block). So far the kernel selection logic always just picks the largest possible tile sizes, but I'll check whether this is suboptimal here.
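
For reference, a minimal sketch of the "always pick the largest tile" behaviour described above (illustrative only, not the actual selection code in the CUDA backend):

#include <cstdio>

// Hypothetical stand-in for the selection heuristic: take the widest supported tile width
// that the available Q columns can fill.
static int pick_ncols_per_block(int n_q_columns) {
    const int candidates[] = {64, 32, 16, 8}; // assumed supported tile widths, largest first
    for (int c : candidates) {
        if (n_q_columns >= c) {
            return c;
        }
    }
    return 8;
}

int main() {
    // deepseek2 16B has 16 Q heads, so batch size 1 -> 16 Q columns and batch size 2 -> 32.
    printf("bs=1: %d cols/block, bs=2: %d cols/block\n", pick_ncols_per_block(16), pick_ncols_per_block(32));
    return 0;
}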

@jukofyork
Collaborator

#13529 now works with this PR to give the 47% reduction in KV-cache size (less than 11GB for the full 160k tokens!).

I don't feel confident making all the changes needed to get context-shifting working with the empty V-cache, so have just disabled it for now.
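
(For reference, the ~47% figure follows directly from the MLA cache layout, assuming DeepSeek's usual dimensions: each cached token stores a 512-wide latent plus 64 RoPE values in K and a 512-wide V, so dropping V saves

$$\frac{512}{(512 + 64) + 512} = \frac{512}{1088} \approx 47\%$$

of the cache.)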

@jukofyork
Collaborator

jukofyork commented May 14, 2025

slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 8192, n_tokens = 8192, progress = 0.156695
/home/juk/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu:75: CUDA error
ggml_cuda_compute_forward: MUL_MAT_ID failed
CUDA error: invalid configuration argument
  current device: 0, in function ggml_cuda_compute_forward at /home/juk/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu:2359
  err

There seems to be a hard limit of 8191 for ubatch size?

@JohannesGaessler
Collaborator Author

Probably the same problem as with #13384; please tell me the exact model and command you're using.

@jukofyork
Collaborator

Probably the same problem as with #13384; please tell me the exact model and command you're using.

#!/bin/bash

host_address=192.168.1.1
port_number=8080

# Turn off NUMA balancing
echo 0 | sudo tee /proc/sys/kernel/numa_balancing > /dev/null

# Ask for permission to drop caches
read -p "Do you want to drop caches? (y/n) " -n 1 -r
echo    # Move to a new line
if [[ $REPLY =~ ^[Yy]$ ]]
then
    echo "Dropping caches..."
    echo 3 | sudo tee /proc/sys/vm/drop_caches > /dev/null
fi

# Run the main command
~/llama.cpp/build/bin/llama-server \
        --host "$host_address" \
        --port "$port_number" \
        --alias "deepseek-r1" \
        --chat-template deepseek3 \
        --temp 0.6 \
        --min-p 0.05 \
        --model ~/models/gguf/deepseek-r1-Q4_K_XL.gguf \
        --n-gpu-layers 99 \
        --numa distribute \
        --threads 80 \
        --override-tensor exps=CPU \
        --flash-attn \
        --ctx_size 65536 \
        --batch-size 8191 \
        --ubatch-size 8191

This works but 8192 fails.


deepseek-r1-Q4_K_XL.gguf

llama_model_loader: - type  f32:  361 tensors
llama_model_loader: - type q4_K:  174 tensors
llama_model_loader: - type q6_K:  429 tensors
llama_model_loader: - type bf16:  122 tensors

Everything is Q6_K apart from the two 3D MLA tensors, which are kept as BF16, and the non-shared experts, which are Q4_K.

@jukofyork
Collaborator

It's those Q4_K tensors that are getting pulled into VRAM now.

I also hacked ggml/src/ggml-cuda/ggml-cuda.cu to use:

const int min_batch_size = 2048;

as this is the break-even point where pulling the weights through the PCI-e 3.0 x16 bus is slower than just computing on the CPU cores, but I doubt this makes any difference as it only triggers at the end of the PP stage for any small remainder chunks, etc.
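
A paraphrased sketch of the heuristic being tuned here (illustrative, not a verbatim excerpt of ggml-cuda.cu): ops whose batch dimension falls below the threshold stay on the CPU instead of having their weights streamed over PCIe.

#include <cstdint>
#include <cstdio>

// Hypothetical helper: only copy host-resident weights (here, the non-shared experts) to VRAM
// when the batch is large enough to amortize the PCI-e transfer; otherwise run the op on the CPU.
static bool should_offload(int64_t batch_size) {
    const int64_t min_batch_size = 2048; // break-even value found above; master uses a much smaller default
    return batch_size >= min_batch_size;
}

int main() {
    const int64_t sizes[] = {512, 2048, 8192};
    for (int64_t n : sizes) {
        printf("batch %5lld -> %s\n", (long long) n, should_offload(n) ? "copy experts to VRAM" : "compute on CPU");
    }
    return 0;
}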

@jukofyork
Collaborator

Also the version I'm using is just the current master, but with this and the other PR merged:

git clone https://github.com/ggml-org/llama.cpp
cd llama.cpp

git fetch origin pull/13435/head:pr-13435
git fetch origin pull/13529/head:pr-13529

git merge pr-13435 --no-edit
git merge pr-13529 --no-edit

so I should have the fix that helped whatever problem was solved for @danielhanchen in #13384

@JohannesGaessler
Collaborator Author

so I should have the fix that helped whatever problem was solved for danielhanchen in #13384

By "same problem" I meant the same problem on a conceptual level, but for a different kernel.

@JohannesGaessler
Collaborator Author

Please confirm whether #13537 fixes the issue.

JohannesGaessler merged commit 6da34fa into ggml-org:master on May 14, 2025
44 checks passed
@jukofyork
Collaborator

jukofyork commented May 14, 2025

Please confirm whether #13537 fixes the issue.

With this fix and ubatch-size = 10240, I'm now getting nearly 100 tokens/s PP on a 20k token prompt:

slot update_slots: id  0 | task 0 | new prompt, n_ctx_slot = 65536, n_keep = 0, n_prompt_tokens = 21171
slot update_slots: id  0 | task 0 | kv cache rm [0, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 10240, n_tokens = 10240, progress = 0.483681
slot update_slots: id  0 | task 0 | kv cache rm [10240, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 20480, n_tokens = 10240, progress = 0.967361
slot update_slots: id  0 | task 0 | kv cache rm [20480, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 21171, n_tokens = 691, progress = 1.000000
slot update_slots: id  0 | task 0 | prompt done, n_past = 21171, n_tokens = 691
slot      release: id  0 | task 0 | stop processing: n_past = 22709, truncated = 0
slot print_timing: id  0 | task 0 | 
prompt eval time =  216221.70 ms / 21171 tokens (   10.21 ms per token,    97.91 tokens per second)
       eval time =  271587.49 ms /  1539 tokens (  176.47 ms per token,     5.67 tokens per second)
      total time =  487809.19 ms / 22710 tokens

This is with the experts offloaded (as Q4_K) to a dual Xeon Gold 6248 system. I'm trying 6 experts instead of 8 for a fair comparison with KTransformers, but I think llama.cpp is getting pretty close now, and I'm pretty sure this should be more accurate than KTransformers since I'm using Q6_K for everything else. This is also using an RTX 5000 Ada, which is pretty gimped compared to the 4090 they used!

@jukofyork
Collaborator

Using:

#!/bin/bash

host_address=192.168.1.1
port_number=8080

# Turn off NUMA balancing
echo 0 | sudo tee /proc/sys/kernel/numa_balancing > /dev/null

# Ask for permission to drop caches
read -p "Do you want to drop caches? (y/n) " -n 1 -r
echo    # Move to a new line
if [[ $REPLY =~ ^[Yy]$ ]]
then
    echo "Dropping caches..."
    echo 3 | sudo tee /proc/sys/vm/drop_caches > /dev/null
fi

# Run the main command
~/llama.cpp/build/bin/llama-server \
        --host "$host_address" \
        --port "$port_number" \
        --alias "deepseek-r1" \
        --chat-template deepseek3 \
        --temp 0.6 \
        --min-p 0.05 \
        --model ~/models/gguf/deepseek-r1-Q4_K_XL.gguf \
        --n-gpu-layers 99 \
        --numa distribute \
        --threads 80 \
        --override-tensor exps=CPU \
        --flash-attn \
        --ctx_size 65536 \
        --batch-size 10240 \
        --ubatch-size 10240 \
        --override-kv "deepseek2.expert_used_count=int:6" \
        --override-kv "deepseek2.expert_weights_scale=float:2.35"

slot update_slots: id  0 | task 0 | new prompt, n_ctx_slot = 65536, n_keep = 0, n_prompt_tokens = 21171
slot update_slots: id  0 | task 0 | kv cache rm [0, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 10240, n_tokens = 10240, progress = 0.483681
slot update_slots: id  0 | task 0 | kv cache rm [10240, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 20480, n_tokens = 10240, progress = 0.967361
slot update_slots: id  0 | task 0 | kv cache rm [20480, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 21171, n_tokens = 691, progress = 1.000000
slot update_slots: id  0 | task 0 | prompt done, n_past = 21171, n_tokens = 691
slot      release: id  0 | task 0 | stop processing: n_past = 23118, truncated = 0
slot print_timing: id  0 | task 0 | 
prompt eval time =  201432.48 ms / 21171 tokens (    9.51 ms per token,   105.10 tokens per second)
       eval time =  305629.72 ms /  1948 tokens (  156.89 ms per token,     6.37 tokens per second)
      total time =  507062.20 ms / 23119 tokens

The reason I settled on using Q4_K for the non-shared experts is explained in: #11446 (comment) (NOTE: IQ4_XS seems to perform significantly worse, and unlike in standard LLMs, which have a high degree of fan-out for the MLP layers, bumping the down_proj quant seems to make little difference here).

The reason for scaling expert_weights_scale is explained in: #11446 (comment)

@Panchovix

100 t/s PP on Q4_K_XL is really, really impressive running most of the model on RAM. How much RAM and bandwidth do you have on your system?

@jukofyork
Collaborator

jukofyork commented May 14, 2025

100 t/s PP on Q4_K_XL is really, really impressive running most of the model on RAM. How much RAM and bandwidth do you have on your system?

The machine I'm running this on has 1.5TB of DDR4-2666 (it could run DDR4-2933 in theory, but I got this RAM for cheap). Going by the older Skylake Xeon Gold 6148, which can only run DDR4-2666, it has around 131.13 GiB/s per node (but NUMA is barely helping here and I can get nearly the same TG speed using just one node).

The biggest limiting factor before this PR was the fact that it uses PCI-e 3.0 x16 and therefore needs a huge ubatch size to make it worth copying the non-shared experts from RAM to VRAM (the break-even ubatch size seems to be around 2048, where both GPU and CPU get around 30 tokens/s PP).
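
A rough back-of-the-envelope model of that break-even point (all symbols here are illustrative assumptions, not measured values): with $W$ the amount of expert weight data copied per pass, $B$ the PCI-e bandwidth, and $t_\mathrm{CPU}$, $t_\mathrm{GPU}$ the per-token compute times, offloading only wins once the fixed copy cost is amortized over the ubatch of $n$ tokens:

$$\frac{W}{B} + n\,t_\mathrm{GPU} < n\,t_\mathrm{CPU} \quad\Longrightarrow\quad n > \frac{W/B}{t_\mathrm{CPU} - t_\mathrm{GPU}}$$

With PCI-e 3.0 x16 ($B \approx 16$ GB/s theoretical) the fixed term is large, which is consistent with the crossover only happening around a ubatch of ~2048 here.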

@jukofyork
Collaborator

jukofyork commented May 14, 2025

Here's another test (but using all 8 experts as I'm not convinced it's worth the drop in quality using 6...):

Using the "simpler: single-page edition without illustrations" version of Joe Dever's Flight from the Dark (279kb HTML file) and this prompt:

Based on your reading of the gamebook in the context, create a helpful guide for me about the choices I should make without spoiling the game.

slot update_slots: id  0 | task 0 | new prompt, n_ctx_slot = 65536, n_keep = 0, n_prompt_tokens = 52280
slot update_slots: id  0 | task 0 | kv cache rm [0, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 10240, n_tokens = 10240, progress = 0.195868
slot update_slots: id  0 | task 0 | kv cache rm [10240, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 20480, n_tokens = 10240, progress = 0.391737
slot update_slots: id  0 | task 0 | kv cache rm [20480, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 30720, n_tokens = 10240, progress = 0.587605
slot update_slots: id  0 | task 0 | kv cache rm [30720, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 40960, n_tokens = 10240, progress = 0.783474
slot update_slots: id  0 | task 0 | kv cache rm [40960, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 51200, n_tokens = 10240, progress = 0.979342
slot update_slots: id  0 | task 0 | kv cache rm [51200, end)
slot update_slots: id  0 | task 0 | prompt processing progress, n_past = 52280, n_tokens = 1080, progress = 1.000000
slot update_slots: id  0 | task 0 | prompt done, n_past = 52280, n_tokens = 1080
slot      release: id  0 | task 0 | stop processing: n_past = 53461, truncated = 0
slot print_timing: id  0 | task 0 | 
prompt eval time =  636323.17 ms / 52280 tokens (   12.17 ms per token,    82.16 tokens per second)
       eval time =  219954.50 ms /  1182 tokens (  186.09 ms per token,     5.37 tokens per second)
      total time =  856277.67 ms / 53462 tokens

I should add that I've found that using between 32k and 64k of context is optimal for deepseek-r1; it starts to get dumb quite quickly if you use more than about 80k (as also found by Fiction.LiveBench).

I think this is because llama.cpp uses static YaRN (to be able to store the RoPEed values in the KV cache, I think?).

Interestingly though, setting the context much closer to the original / pre-YaRN training context of 4k also makes the model dumber... I'm guessing that during fine-tuning it must have had most of its examples in the 32k-64k range?

@Panchovix

I see, that's a lot of RAM! After this PR got merged, I tested DeepSeek V3 0324 with a consumer CPU (7800X3D, 192GB RAM) and a 5090 + 2x 4090 + 3090 + A6000 at x8/x4/x4/x4/x4 PCIe, so there's a huge bottleneck there. CPU RAM bandwidth is about 64 GB/s at 6000 MHz (limited by 1 CCD).

On IQ3_XXS running with

./llama-server -m 'models_llm/DeepSeek-V3-0324-UD-IQ3_XXS-00001-of-00006.gguf' -c 65536 --no-mmap -ngl 999 -ot "blk.(0|1|2|3|4|5|6).ffn.=CUDA0" -ot "blk.(7|8|9|10).ffn.=CUDA1" -ot "blk.(11|12|13|14).ffn.=CUDA2" -ot "blk.(15|16|17).ffn.=CUDA3"  -ot "blk.(18|19|20|21|22|23|24|25|26).ffn.=CUDA4" -ot "ffn.*=CPU" -fa -mg 0 -ub 2048

I get

prompt eval time =  177615.10 ms / 21321 tokens (    8.33 ms per token,   120.04 tokens per second)
       eval time =  256254.45 ms /  1135 tokens (  225.77 ms per token,     4.43 tokens per second)

While on Q3_K_XL, running with

./llama-server -m '/models_llm/DeepSeek-V3-0324-UD-Q3_K_XL-00001-of-00007.gguf' -c 65536 --no-mmap -ngl 999 -ot "blk.(0|1|2|3|4|5|6).ffn.=CUDA0" -ot "blk.(7|8|9|10).ffn.=CUDA1" -ot "blk.(11|12|13|14).ffn.=CUDA2" -ot "blk.(15|16|17).ffn.=CUDA3"  -ot "blk.(18|19|20|21|22|23|24|25).ffn.=CUDA4" -ot "ffn.*=CPU" -fa -mg 0 -ub 2048

I get

prompt eval time =  164564.80 ms / 21321 tokens (    7.72 ms per token,   129.56 tokens per second)
       eval time =  148085.76 ms /   967 tokens (  153.14 ms per token,     6.53 tokens per second)

So despite Q3_K_XL being bigger and using one less layer on the GPU, it is faster. Maybe that is related to what you mentioned a bit above about IQ quants.

@jukofyork
Collaborator

So despite Q3_K_XL being bigger and using one less layer on the GPU, it is faster. Maybe that is related to what you mentioned a bit above about IQ quants.

Yeah, I think the other IQ quants are even worse on CPU than IQ4_XS.

Silver267 pushed a commit to Silver267/llama.cpp that referenced this pull request May 14, 2025