@@ -678,17 +678,25 @@ void launch_fattn(
 ) {
     constexpr int ncols = ncols1 * ncols2;
 
+    const bool is_mla = DV == 512; // TODO better parameterization
+
     const ggml_tensor * Q = dst->src[0];
     const ggml_tensor * K = dst->src[1];
     const ggml_tensor * V = dst->src[2];
 
+    GGML_ASSERT(V || is_mla);
+
     const ggml_tensor * mask = dst->src[3];
 
     ggml_tensor * KQV = dst;
 
     GGML_ASSERT(Q->type == GGML_TYPE_F32);
     GGML_ASSERT(KQV->type == GGML_TYPE_F32);
 
+    GGML_ASSERT(      Q->nb[0] == ggml_element_size(Q));
+    GGML_ASSERT(      K->nb[0] == ggml_element_size(K));
+    GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V));
+
     GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
     GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
         "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
@@ -713,10 +721,10 @@ void launch_fattn(
     size_t nb12 = K->nb[2];
     size_t nb13 = K->nb[3];
 
-    const char * V_data = (const char *) V->data;
-    size_t nb21 = V->nb[1];
-    size_t nb22 = V->nb[2];
-    size_t nb23 = V->nb[3];
+    const char * V_data = V ? (const char *) V->data : nullptr;
+    size_t nb21 = V ? V->nb[1] : nb11;
+    size_t nb22 = V ? V->nb[2] : nb12;
+    size_t nb23 = V ? V->nb[3] : nb13;
 
     if (need_f16_K && K->type != GGML_TYPE_F16) {
         GGML_ASSERT(ggml_is_contiguously_allocated(K));
@@ -733,7 +741,7 @@ void launch_fattn(
         nb13 = nb13*bs*sizeof(half)/ts;
     }
 
-    if (need_f16_V && V->type != GGML_TYPE_F16) {
+    if (V && need_f16_V && V->type != GGML_TYPE_F16) {
        GGML_ASSERT(ggml_is_contiguously_allocated(V));
        V_f16.alloc(ggml_nelements(V));
        to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
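
A minimal standalone sketch of the fallback rule these hunks implement, assuming only what is visible above: when no V tensor is passed (the MLA case, detected via DV == 512), V_data stays nullptr and the value strides fall back to K's strides, presumably because for MLA the value data is addressed through K. The names toy_tensor and resolve_v_strides below are illustrative and not part of the ggml API; the stride values are made up.

#include <cstddef>
#include <cstdio>

// Stand-in for the few ggml_tensor fields used here: a data pointer and
// byte strides nb[0..3], where nb[0] is the element size.
struct toy_tensor {
    const char * data;
    size_t       nb[4];
};

// Mirrors the guarded accesses in the patch: every dereference of V is
// conditional, and a missing V stride defaults to the corresponding K stride.
static void resolve_v_strides(const toy_tensor * K, const toy_tensor * V,
                              const char ** V_data,
                              size_t * nb21, size_t * nb22, size_t * nb23) {
    *V_data = V ? V->data  : nullptr;
    *nb21   = V ? V->nb[1] : K->nb[1];
    *nb22   = V ? V->nb[2] : K->nb[2];
    *nb23   = V ? V->nb[3] : K->nb[3];
}

int main() {
    static char buf[1 << 10];                             // dummy backing storage
    toy_tensor K = { buf, { 2, 1152, 589824, 589824 } };  // hypothetical f16 byte strides

    const char * V_data;
    size_t nb21, nb22, nb23;

    // MLA path: no V tensor, so the strides come from K.
    resolve_v_strides(&K, nullptr, &V_data, &nb21, &nb22, &nb23);
    std::printf("V_data=%p nb21=%zu nb22=%zu nb23=%zu\n",
                (const void *) V_data, nb21, nb22, nb23);
    return 0;
}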