diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index b7180d5955c29..a4fbd823638fa 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -678,10 +678,14 @@ void launch_fattn( ) { constexpr int ncols = ncols1 * ncols2; + const bool is_mla = DV == 512; // TODO better parameterization + const ggml_tensor * Q = dst->src[0]; const ggml_tensor * K = dst->src[1]; const ggml_tensor * V = dst->src[2]; + GGML_ASSERT(V || is_mla); + const ggml_tensor * mask = dst->src[3]; ggml_tensor * KQV = dst; @@ -689,6 +693,10 @@ void launch_fattn( GGML_ASSERT(Q->type == GGML_TYPE_F32); GGML_ASSERT(KQV->type == GGML_TYPE_F32); + GGML_ASSERT( Q->nb[0] == ggml_element_size(Q)); + GGML_ASSERT( K->nb[0] == ggml_element_size(K)); + GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V)); + GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) && "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); @@ -713,10 +721,10 @@ void launch_fattn( size_t nb12 = K->nb[2]; size_t nb13 = K->nb[3]; - const char * V_data = (const char *) V->data; - size_t nb21 = V->nb[1]; - size_t nb22 = V->nb[2]; - size_t nb23 = V->nb[3]; + const char * V_data = V ? (const char *) V->data : nullptr; + size_t nb21 = V ? V->nb[1] : nb11; + size_t nb22 = V ? V->nb[2] : nb12; + size_t nb23 = V ? V->nb[3] : nb13; if (need_f16_K && K->type != GGML_TYPE_F16) { GGML_ASSERT(ggml_is_contiguously_allocated(K)); @@ -733,7 +741,7 @@ void launch_fattn( nb13 = nb13*bs*sizeof(half)/ts; } - if (need_f16_V && V->type != GGML_TYPE_F16) { + if (V && need_f16_V && V->type != GGML_TYPE_F16) { GGML_ASSERT(ggml_is_contiguously_allocated(V)); V_f16.alloc(ggml_nelements(V)); to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type); diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index b2f95fa3f00e6..9fe86c403635f 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -33,9 +33,30 @@ struct fattn_mma_f16_config< 64, 64> { static constexpr int nwarps_max = 4; static constexpr bool Q_in_reg = true; static constexpr int nstages_target = 2; - static constexpr int nbatch_K2 = 32; - static constexpr int nbatch_V2 = 32; - static constexpr int nbatch_combine = 32; + + static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { + return 32; + } + + static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { + return 32; + } + + static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { + return 32; + } + + static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { + return 32; + } + + static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { + return 32; + } + + static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { + return 32; + } }; template <> @@ -44,9 +65,30 @@ struct fattn_mma_f16_config< 80, 80> { static constexpr int nwarps_max = 4; static constexpr bool Q_in_reg = true; static constexpr int nstages_target = 2; - static constexpr int nbatch_K2 = 40; - static constexpr int nbatch_V2 = 40; - static constexpr int nbatch_combine = 40; + + static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { + return 40; + } + + static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { + return 40; + } + + static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { + return 40; + } + + static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { + return 40; + } + + static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { + return 40; + } + + static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { + return 40; + } }; template <> @@ -55,9 +97,30 @@ struct fattn_mma_f16_config< 96, 96> { static constexpr int nwarps_max = 4; static constexpr bool Q_in_reg = true; static constexpr int nstages_target = 2; - static constexpr int nbatch_K2 = 48; - static constexpr int nbatch_V2 = 48; - static constexpr int nbatch_combine = 48; + + static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { + return 48; + } + + static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { + return 48; + } + + static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { + return 48; + } + + static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { + return 48; + } + + static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { + return 48; + } + + static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { + return 48; + } }; template <> @@ -66,9 +129,30 @@ struct fattn_mma_f16_config<112, 112> { static constexpr int nwarps_max = 4; static constexpr bool Q_in_reg = true; static constexpr int nstages_target = 2; - static constexpr int nbatch_K2 = 56; - static constexpr int nbatch_V2 = 56; - static constexpr int nbatch_combine = 56; + + static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { + return 56; + } + + static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { + return 56; + } + + static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { + return 56; + } + + static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { + return 56; + } + + static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { + return 56; + } + + static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { + return 56; + } }; template <> @@ -77,9 +161,30 @@ struct fattn_mma_f16_config<128, 128> { static constexpr int nwarps_max = 4; static constexpr bool Q_in_reg = true; static constexpr int nstages_target = 2; - static constexpr int nbatch_K2 = 64; - static constexpr int nbatch_V2 = 64; - static constexpr int nbatch_combine = 64; + + static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { + return 64; + } + + static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { + return 64; + } + + static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { + return 64; + } + + static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { + return 64; + } + + static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { + return 64; + } + + static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { + return 64; + } }; template <> @@ -88,9 +193,38 @@ struct fattn_mma_f16_config<256, 256> { static constexpr int nwarps_max = 4; static constexpr bool Q_in_reg = true; static constexpr int nstages_target = 2; - static constexpr int nbatch_K2 = 128; - static constexpr int nbatch_V2 = 128; - static constexpr int nbatch_combine = 128; + + static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { + return 128; + } + + static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { + return 128; + } + + static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { + return 128; + } + + static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { + return 128; + } + + static int get_nbatch_combine_host(const int cc, const int ncols) { + if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) { + return ncols <= 16 ? 128 : 64; + } + return 64; + } + + static constexpr __device__ int get_nbatch_combine_device(int ncols) { +#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING + return ncols <= 16 ? 128 : 64; +#else + GGML_UNUSED(ncols); + return 128; +#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING + } }; template <> @@ -99,9 +233,44 @@ struct fattn_mma_f16_config<576, 512> { static constexpr int nwarps_max = 8; static constexpr bool Q_in_reg = false; static constexpr int nstages_target = 1; - static constexpr int nbatch_K2 = 160; - static constexpr int nbatch_V2 = 128; - static constexpr int nbatch_combine = 128; + + static int get_nbatch_K2_host(const int cc, const int ncols) { + if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) { + return ncols <= 16 ? 96 : 160; + } + return ncols <= 16 ? 288 : 160; + } + + static constexpr __device__ int get_nbatch_K2_device(int ncols) { +#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING + return ncols <= 16 ? 96 : 160; +#else + return ncols <= 16 ? 288 : 160; +#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING + } + + static int get_nbatch_V2_host(const int cc, const int ncols) { + if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) { + return ncols <= 16 ? 64 : 128; + } + return ncols <= 16 ? 256 : 128; + } + + static constexpr __device__ int get_nbatch_V2_device(int ncols) { +#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING + return ncols <= 16 ? 64 : 128; +#else + return ncols <= 16 ? 256 : 128; +#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING + } + + static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { + return 128; + } + + static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { + return 128; + } }; // ------------------------------------------------------------------------------------------------------------------ @@ -120,7 +289,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV); - auto load = [&] __device__ (const int n) { + auto load = [&] __device__ (auto n) { const int stride_k = WARP_SIZE >> n; const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k); const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k); @@ -223,7 +392,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( } } -template +template static __device__ __forceinline__ void flash_attn_ext_f16_iter( const float2 * const __restrict__ Q_f2, const half2 * const __restrict__ K_h2, @@ -261,10 +430,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( constexpr int cols_per_warp = ntiles * tile_B::I; constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. + constexpr int ncols = ncols1 * ncols2; + constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols); + constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols); - constexpr int stride_tile_Q = DKQ/2 + 4; - constexpr int stride_tile_K = c::nbatch_K2 + 4; - constexpr int stride_tile_V = c::nbatch_V2 + 4; + constexpr int stride_tile_Q = DKQ/2 + 4; + constexpr int stride_tile_K = nbatch_K2 + 4; + + static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA"); + constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4; const int k_VKQ_0 = kb0 * c::nbatch_fa; tile_C_KQ KQ_C[c::nbatch_fa/(np*tile_C_KQ::I) * ntiles]; @@ -275,12 +449,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C; if constexpr (nstages > 1) { - static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading"); + static_assert(!mla, "multi-stage loading not implemented for MLA"); + static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading"); constexpr bool use_cp_async = true; cp_async_wait_all(); __syncthreads(); flash_attn_ext_f16_load_tile - (V_h2 + k_VKQ_0*stride_V, tile_V, c::nbatch_V2, stride_V); + (V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V); } else { constexpr bool use_cp_async = nstages == 1; if (ncols2 > 1 || mask_h2) { @@ -289,8 +464,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } #pragma unroll - for (int k0_start = 0; k0_start < DKQ/2; k0_start += c::nbatch_K2) { - const int k0_stop = k0_start + c::nbatch_K2 < DKQ/2 ? k0_start + c::nbatch_K2 : DKQ/2; + for (int k0_start = 0; k0_start < DKQ/2; k0_start += nbatch_K2) { + const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2; const int k0_diff = k0_stop - k0_start; if (nstages <= 1) { @@ -537,16 +712,21 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( (mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask); } flash_attn_ext_f16_load_tile - (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, c::nbatch_K2, stride_K); + (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K); } } + + // For MLA K and V have the same data. + // Therefore, iterate over V in reverse and re-use the data if possible. + static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented"); + constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV; #pragma unroll - for (int i0_start = 0; i0_start < DV; i0_start += 2*c::nbatch_V2) { - const int i0_stop = i0_start + 2*c::nbatch_V2 < DV ? i0_start + 2*c::nbatch_V2 : DV; - const int i0_diff = i0_stop - i0_start; + for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) { + const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0; + const int i0_diff = i0_stop - i0_start; - if (nstages <= 1) { + if (nstages <= 1 && i0_start < reusable_cutoff) { constexpr bool use_cp_async = nstages == 1; flash_attn_ext_f16_load_tile (V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V); @@ -555,6 +735,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } __syncthreads(); } + const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2; // Calculate VKQ tile: #pragma unroll @@ -565,7 +746,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int k0 = k00 + (threadIdx.y % np)*tile_A::J; tile_A A; - load_ldmatrix_trans(A, tile_V + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); + load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); if (ntiles == 1) { mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]); } else { @@ -596,7 +777,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #endif // NEW_MMA_AVAILABLE } -template +template static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const float2 * const __restrict__ Q_f2, const half2 * const __restrict__ K_h2, @@ -632,13 +813,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( constexpr int cols_per_warp = ntiles * tile_B::I; constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. + constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols); + constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols); static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps"); - constexpr int stride_tile_Q = DKQ/2 + 4; - constexpr int stride_tile_K = c::nbatch_K2 + 4; - constexpr int stride_tile_V = c::nbatch_V2 + 4; + constexpr int stride_tile_Q = DKQ/2 + 4; + constexpr int stride_tile_K = nbatch_K2 + 4; + static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA"); + constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4; constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V; extern __shared__ half2 tile_Q[]; @@ -726,26 +910,26 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( // Preload mask and K data for first iteration when using cp_async with multiple stages: if constexpr (nstages > 1) { - static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline"); + static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline"); constexpr bool use_cp_async = true; if (ncols2 > 1 || mask_h2) { flash_attn_ext_f16_load_mask (mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask); } flash_attn_ext_f16_load_tile - (K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, c::nbatch_K2, stride_K); + (K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K); } // Iterate over ne11 == previous tokens: for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) { constexpr bool last_iter = false; - flash_attn_ext_f16_iter + flash_attn_ext_f16_iter (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); } { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally. constexpr bool last_iter = true; - flash_attn_ext_f16_iter + flash_attn_ext_f16_iter (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1); } @@ -774,7 +958,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( // It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM. // So also write VKQ accumulators to shared memory in column-major format if np == 1. - constexpr int nbatch_combine = c::Q_in_reg ? DV/2 : DV/4; + constexpr int nbatch_combine = c::get_nbatch_combine_device(ncols); constexpr int tile_stride = nbatch_combine + 4; static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine"); @@ -1005,7 +1189,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #endif // NEW_MMA_AVAILABLE } -template +template __launch_bounds__(nwarps*WARP_SIZE, 1) static __global__ void flash_attn_ext_f16( const char * __restrict__ Q, @@ -1050,6 +1234,14 @@ static __global__ void flash_attn_ext_f16( NO_DEVICE_CODE; return; } +#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING + if (ncols1*ncols2 > 32) { + NO_DEVICE_CODE; + return; + } +#endif __CUDA_ARCH__ == GGML_CUDA_CC_TURING + + static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV"); typedef fattn_mma_f16_config c; @@ -1060,9 +1252,10 @@ static __global__ void flash_attn_ext_f16( const int stride_Q1 = nb01 / sizeof(float2); const int stride_Q2 = nb02 / sizeof(float2); const int stride_K = nb11 / sizeof(half2); - const int stride_V = nb21 / sizeof(half2); const int stride_mask = nb31 / sizeof(half2); + const int stride_V = mla ? stride_K : nb21 / sizeof(half2); + const int iter_k = ne11 / FATTN_KQ_STRIDE; const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; @@ -1085,10 +1278,11 @@ static __global__ void flash_attn_ext_f16( const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2); + const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; const int kb0_start_kernel = kb0_start * kb_niter; @@ -1097,12 +1291,12 @@ static __global__ void flash_attn_ext_f16( constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. if (kb0_start == 0) { constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. - flash_attn_ext_f16_process_tile + flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); } else { constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile. - flash_attn_ext_f16_process_tile + flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); } @@ -1123,10 +1317,11 @@ static __global__ void flash_attn_ext_f16( const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); // K and V have same shape const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2); + const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; const int kb0_start_kernel = kb0_start * kb_niter; @@ -1134,7 +1329,7 @@ static __global__ void flash_attn_ext_f16( constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. constexpr bool needs_fixup = false; - flash_attn_ext_f16_process_tile + flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); #else @@ -1160,10 +1355,6 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml typedef fattn_mma_f16_config c; - constexpr int nbatch_K2 = c::nbatch_K2 < 1 ? DKQ/2 : c::nbatch_K2; - constexpr int nbatch_V2 = c::nbatch_V2 < 1 ? DV /2 : c::nbatch_V2; - constexpr int nbatch_combine = c::nbatch_combine < 1 ? DV /2 : c::nbatch_combine; - const int nstages = cp_async_available(cc) ? c::nstages_target : 0; constexpr int ncols = ncols1 * ncols2; @@ -1173,15 +1364,21 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I; constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max; + constexpr bool mla = DKQ == 576; + + const int nbatch_K2 = c::get_nbatch_K2_host (cc, ncols); + const int nbatch_V2 = c::get_nbatch_K2_host (cc, ncols); + const int nbatch_combine = c::get_nbatch_combine_host(cc, ncols); + static_assert(DKQ % tile_B::J == 0, "bad DKQ"); static_assert(DV % tile_A::J == 0, "bad DV"); static_assert(ncols % cols_per_warp == 0, "bad ncols"); - const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(c::nbatch_K2 + 4, c::nbatch_V2 + 4) * sizeof(half2); - const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (c::nbatch_K2 + 4 + c::nbatch_V2 + 4) * sizeof(half2); - const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2); - const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2); - const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2); + const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2); + const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2); + const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2); + const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2); + const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2); const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage; @@ -1195,7 +1392,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml fattn_kernel_t fattn_kernel; if (logit_softcap == 0.0f) { constexpr bool use_logit_softcap = false; - fattn_kernel = flash_attn_ext_f16; + fattn_kernel = flash_attn_ext_f16; #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; @@ -1206,7 +1403,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) } else { constexpr bool use_logit_softcap = true; - fattn_kernel = flash_attn_ext_f16; + fattn_kernel = flash_attn_ext_f16; #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 9c5c803d02bc7..6bc0096cc65e6 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -10,6 +10,7 @@ template static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; const ggml_tensor * Q = dst->src[0]; if constexpr (ncols2 <= 8) { @@ -24,7 +25,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con return; } - if (Q->ne[1] <= 32/ncols2) { + if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || Q->ne[1] <= 32/ncols2) { ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); return; } diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 7643c4b8bfa2b..7ea9d2ff06c53 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3216,7 +3216,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g #endif // FLASH_ATTN_AVAILABLE if (op->src[1]->ne[0] != op->src[2]->ne[0]) { const int cc = ggml_cuda_info().devices[dev_ctx->device].cc; - if (!new_mma_available(cc) || cc < GGML_CUDA_CC_AMPERE) { + if (!new_mma_available(cc)) { return false; } const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];