Commit 6da34fa

CUDA: faster Deepseek FA, add Turing support (#13435)
1 parent 5e7d95e

4 files changed: +276 −70 lines

ggml/src/ggml-cuda/fattn-common.cuh

+13 −5
@@ -678,17 +678,25 @@ void launch_fattn(
 ) {
     constexpr int ncols = ncols1 * ncols2;
 
+    const bool is_mla = DV == 512; // TODO better parameterization
+
     const ggml_tensor * Q = dst->src[0];
     const ggml_tensor * K = dst->src[1];
     const ggml_tensor * V = dst->src[2];
 
+    GGML_ASSERT(V || is_mla);
+
     const ggml_tensor * mask = dst->src[3];
 
     ggml_tensor * KQV = dst;
 
     GGML_ASSERT(Q->type == GGML_TYPE_F32);
     GGML_ASSERT(KQV->type == GGML_TYPE_F32);
 
+    GGML_ASSERT( Q->nb[0] == ggml_element_size(Q));
+    GGML_ASSERT( K->nb[0] == ggml_element_size(K));
+    GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V));
+
     GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
     GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
         "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
@@ -713,10 +721,10 @@ void launch_fattn(
     size_t nb12 = K->nb[2];
     size_t nb13 = K->nb[3];
 
-    const char * V_data = (const char *) V->data;
-    size_t nb21 = V->nb[1];
-    size_t nb22 = V->nb[2];
-    size_t nb23 = V->nb[3];
+    const char * V_data = V ? (const char *) V->data : nullptr;
+    size_t nb21 = V ? V->nb[1] : nb11;
+    size_t nb22 = V ? V->nb[2] : nb12;
+    size_t nb23 = V ? V->nb[3] : nb13;
 
     if (need_f16_K && K->type != GGML_TYPE_F16) {
         GGML_ASSERT(ggml_is_contiguously_allocated(K));
@@ -733,7 +741,7 @@ void launch_fattn(
         nb13 = nb13*bs*sizeof(half)/ts;
     }
 
-    if (need_f16_V && V->type != GGML_TYPE_F16) {
+    if (V && need_f16_V && V->type != GGML_TYPE_F16) {
         GGML_ASSERT(ggml_is_contiguously_allocated(V));
         V_f16.alloc(ggml_nelements(V));
         to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
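
Context for the null-V path above: with DeepSeek's MLA (multi-head latent attention) the value data overlaps the key cache, so `dst->src[2]` may be absent (guarded by `GGML_ASSERT(V || is_mla)`), and `launch_fattn` now falls back to the K byte strides instead of dereferencing a missing V tensor. The standalone sketch below only illustrates that fallback pattern in isolation; the `tensor_view`, `v_params`, and `resolve_v` names are hypothetical and not part of ggml, and the stride values are arbitrary.

// Minimal standalone sketch of the fallback above (hypothetical names, not part
// of ggml): when no separate V tensor is provided (the MLA case), the V byte
// strides fall back to the K strides and the data pointer stays null.
#include <cstddef>
#include <cstdio>

struct tensor_view {        // stand-in for the few ggml_tensor fields used here
    const char * data;
    size_t       nb[4];     // byte strides per dimension
};

struct v_params {
    const char * data;      // nullptr when V is absent
    size_t       nb1, nb2, nb3;
};

static v_params resolve_v(const tensor_view * K, const tensor_view * V) {
    v_params p;
    p.data = V ? V->data  : nullptr;
    p.nb1  = V ? V->nb[1] : K->nb[1];
    p.nb2  = V ? V->nb[2] : K->nb[2];
    p.nb3  = V ? V->nb[3] : K->nb[3];
    return p;
}

int main() {
    static char buf[1];                               // placeholder buffer
    tensor_view K = { buf, {2, 1152, 73728, 73728} }; // arbitrary example strides
    v_params p = resolve_v(&K, /*V=*/nullptr);        // MLA case: no separate V
    std::printf("V data=%p nb1=%zu nb2=%zu nb3=%zu\n",
                (const void *) p.data, p.nb1, p.nb2, p.nb3);
    return 0;
}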
