diff --git a/src/infiniop/ops/paged_caching/bang/paged_caching_bang.mlu b/src/infiniop/ops/paged_caching/bang/paged_caching_bang.mlu index 0f7597fa5..bbe57a216 100644 --- a/src/infiniop/ops/paged_caching/bang/paged_caching_bang.mlu +++ b/src/infiniop/ops/paged_caching/bang/paged_caching_bang.mlu @@ -5,6 +5,8 @@ namespace { +__nram__ char paged_caching_nram_buffer[NRAM_MAX_SIZE]; + template __mlu_global__ void pagedCachingKernel( Tdata *k_cache, @@ -48,8 +50,12 @@ __mlu_global__ void pagedCachingKernel( + head_idx * v_cache_head_stride + block_offset * v_cache_slot_stride; - __memcpy(k_dst, k_src, head_size * sizeof(Tdata), GDRAM2GDRAM); - __memcpy(v_dst, v_src, head_size * sizeof(Tdata), GDRAM2GDRAM); + char *nram_base = reinterpret_cast(((reinterpret_cast(paged_caching_nram_buffer) + ALIGN_SIZE - 1) / ALIGN_SIZE) * ALIGN_SIZE); + Tdata *tmp = reinterpret_cast(nram_base); + __memcpy(tmp, k_src, head_size * sizeof(Tdata), GDRAM2NRAM); + __memcpy(tmp + head_size, v_src, head_size * sizeof(Tdata), GDRAM2NRAM); + __memcpy(k_dst, tmp, head_size * sizeof(Tdata), NRAM2GDRAM); + __memcpy(v_dst, tmp + head_size, head_size * sizeof(Tdata), NRAM2GDRAM); } }