
Speedups to gemm_conv_forward and gemm_conv_backward kernels #522

Open
jonahsamost wants to merge 9 commits into PufferAI:4.0 from jonahsamost:jonah_4_8_nmmo3

Conversation

@jonahsamost (Contributor) commented Apr 9, 2026

Note that the old kernels are kept in the code; the new kernels are simply named <old_kernel_name>_fast. Much of the speedup comes from using the fast divmod operation and from caching repeated computations.
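For context, fast divmod replaces the hardware integer divide in per-element index math with a precomputed multiply-and-shift. The following is a minimal host-side sketch of the idea, not the PR's actual implementation; the struct name, fields, and the magic-constant choice are illustrative:

```cpp
#include <cassert>
#include <cstdint>

// Sketch of fast divide-and-modulo by a fixed divisor. The reciprocal is
// precomputed once (e.g. on the host, at layer setup), so each kernel-side
// div/mod becomes a 64-bit multiply, a shift, and a multiply-subtract
// instead of a hardware integer divide.
//
// The simple magic constant below is exact for the modest index ranges a
// conv kernel uses (here: n and divisor both < 2^16); a fully general
// 32-bit divisor needs the usual round-up-division correction terms.
struct FastDivMod {
    uint32_t divisor;
    uint64_t magic;  // floor(2^32 / divisor) + 1

    explicit FastDivMod(uint32_t d)
        : divisor(d), magic((UINT64_C(1) << 32) / d + 1) {}

    void divmod(uint32_t n, uint32_t& q, uint32_t& r) const {
        q = static_cast<uint32_t>((magic * n) >> 32);  // multiply-high
        r = n - q * divisor;                           // remainder by back-substitution
    }
};
```

In a conv kernel the same divisors (output width, height, channel strides) are reused for every element, so hoisting them into precomputed `FastDivMod`-style objects pays off across the whole grid.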

nvcc native /root/PufferLib/tests/bench_gemm_conv_end2end  (fp32)
bench_gemm_conv_end2end  precision=fp32  warmup=50 iters=200

layer 1  B=1024  IC=59 OC=128  11x15 K=5 S=3 -> 3x4  relu=1
  --- correctness (reference = gemm_conv forward/backward) ---
  forward  gemm_fast vs gemm:  max|diff| 0  mean|diff| 0  max rel err 0
  forward  cudnn vs gemm:      max|diff| 2.74181e-06  mean|diff| 8.2099e-08  max rel err 0.0325513
  backward wgrad gemm_fast vs gemm: max|diff| 0  mean|diff| 0  max rel err 0
  backward d_input gemm_fast vs gemm: max|diff| 0  mean|diff| 0  max rel err 0
  backward wgrad cudnn vs gemm:     max|diff| 0.0477219  mean|diff| 0.00805478  max rel err 39.7875
  backward d_input cudnn vs gemm: max|diff| 0.000272185  mean|diff| 2.73545e-05  max rel err 13.4153
  --- timing (50 warmup / 200 iters): forward and backward measured separately ---
  forward:
    gemm (slow): 202.8187 us/iter
    gemm_fast:   192.4102 us/iter  (1.05x vs gemm)
    cudnn:       198.3834 us/iter  (1.02x vs gemm, 0.97x vs gemm_fast)
  backward:
    gemm (slow): 681.7363 us/iter
    gemm_fast:   541.7361 us/iter  (1.26x vs gemm)
    cudnn:       391.7222 us/iter  (1.74x vs gemm, 1.38x vs gemm_fast)

...

layer 2  B=1024  IC=128 OC=128  3x4 K=3 S=1 -> 1x2  relu=0
  --- correctness (reference = gemm_conv forward/backward) ---
  forward  gemm_fast vs gemm:  max|diff| 0  mean|diff| 0  max rel err 0
  forward  cudnn vs gemm:      max|diff| 0.000462592  mean|diff| 7.1466e-05  max rel err 21.2912
  backward wgrad gemm_fast vs gemm: max|diff| 0  mean|diff| 0  max rel err 0
  backward d_input gemm_fast vs gemm: max|diff| 0  mean|diff| 0  max rel err 0
  backward wgrad cudnn vs gemm:     max|diff| 0.0218544  mean|diff| 0.00328068  max rel err 21.8746
  backward d_input cudnn vs gemm: max|diff| 9.43989e-06  mean|diff| 2.11144e-07  max rel err 0.197254
  --- timing (50 warmup / 200 iters): forward and backward measured separately ---
  forward:
    gemm (slow):  62.6403 us/iter
    gemm_fast:    56.8698 us/iter  (1.10x vs gemm)
    cudnn:        67.7363 us/iter  (0.92x vs gemm, 0.84x vs gemm_fast)
  backward:
    gemm (slow): 116.6816 us/iter
    gemm_fast:   100.2875 us/iter  (1.16x vs gemm)
    cudnn:       131.9123 us/iter  (0.88x vs gemm, 0.76x vs gemm_fast)

...

nvcc native /root/PufferLib/tests/bench_gemm_conv_end2end  (bf16)
bench_gemm_conv_end2end  precision=bf16  warmup=50 iters=200

layer 1  B=1024  IC=59 OC=128  11x15 K=5 S=3 -> 3x4  relu=1
  --- correctness (reference = gemm_conv forward/backward) ---
  forward  gemm_fast vs gemm:  max|diff| 0  mean|diff| 0  max rel err 0
  forward  cudnn vs gemm:      max|diff| 0.0078125  mean|diff| 0.00019535  max rel err 167.084
  backward wgrad gemm_fast vs gemm: max|diff| 0  mean|diff| 0  max rel err 0
  backward d_input gemm_fast vs gemm: max|diff| 0  mean|diff| 0  max rel err 0
  backward wgrad cudnn vs gemm:     max|diff| 2  mean|diff| 0.159601  max rel err 2574.97
  backward d_input cudnn vs gemm: max|diff| 0.00585938  mean|diff| 0.000110408  max rel err 139.618
  --- timing (50 warmup / 200 iters): forward and backward measured separately ---
  forward:
    gemm (slow): 139.1566 us/iter
    gemm_fast:   106.4277 us/iter  (1.31x vs gemm)
    cudnn:       239.0912 us/iter  (0.58x vs gemm, 0.45x vs gemm_fast)
  backward:
    gemm (slow): 572.0896 us/iter
    gemm_fast:   403.5356 us/iter  (1.42x vs gemm)
    cudnn:       1495.2573 us/iter  (0.38x vs gemm, 0.27x vs gemm_fast)

...

layer 2  B=1024  IC=128 OC=128  3x4 K=3 S=1 -> 1x2  relu=0
  --- correctness (reference = gemm_conv forward/backward) ---
  forward  gemm_fast vs gemm:  max|diff| 0  mean|diff| 0  max rel err 0
  forward  cudnn vs gemm:      max|diff| 0.0078125  mean|diff| 2.17161e-07  max rel err 0.0193715
  backward wgrad gemm_fast vs gemm: max|diff| 0  mean|diff| 0  max rel err 0
  backward d_input gemm_fast vs gemm: max|diff| 0  mean|diff| 0  max rel err 0
  backward wgrad cudnn vs gemm:     max|diff| 0.125  mean|diff| 5.90477e-06  max rel err 0.786781
  backward d_input cudnn vs gemm: max|diff| 0.00390625  mean|diff| 0.000104174  max rel err 93.8416
  --- timing (50 warmup / 200 iters): forward and backward measured separately ---
  forward:
    gemm (slow):  24.5946 us/iter
    gemm_fast:    18.4427 us/iter  (1.33x vs gemm)
    cudnn:       117.7077 us/iter  (0.21x vs gemm, 0.16x vs gemm_fast)
  backward:
    gemm (slow):  65.5693 us/iter
    gemm_fast:    49.1440 us/iter  (1.33x vs gemm)
    cudnn:       198.4363 us/iter  (0.33x vs gemm, 0.25x vs gemm_fast)
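The warmup/iters measurement pattern reported above (50 warmup launches, 200 timed iterations, microseconds per iteration) can be sketched as follows. This is a simplified host-clock stand-in; the actual harness presumably times GPU work with CUDA events and stream synchronization, and the function name here is illustrative:

```cpp
#include <cassert>
#include <chrono>

// Sketch of a warmup/iters microbenchmark loop: run the workload a few
// times untimed (to warm caches, JIT, and clocks), then time a fixed
// number of iterations and report the mean in microseconds per iteration.
template <typename F>
double bench_us_per_iter(F&& launch, int warmup = 50, int iters = 200) {
    for (int i = 0; i < warmup; ++i) launch();       // untimed warmup
    auto t0 = std::chrono::steady_clock::now();
    for (int i = 0; i < iters; ++i) launch();        // timed region
    auto t1 = std::chrono::steady_clock::now();
    return std::chrono::duration<double, std::micro>(t1 - t0).count() / iters;
}
```

For asynchronous GPU launches the timed region must end with a device synchronization (or bracket the loop with CUDA events), otherwise only launch overhead is measured.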

embedding and multihot kernels

in float
n3_multihot  B=32768  obs_size=1707  multihot_elems=318996480
  reference=n3_multihot_kernel  fast=n3_multihot_kernel_fast (dm_n3_hwf+dm_n3_hw+dm_n3_w, n3_hwf=1650)
  correctness  fast vs reference: max|diff| 0  mean|diff| 0  max rel err 0
  timing (50 warmup / 200 iters), kernel only (correctness above uses cudaMemsetAsync like encoder):
    n3_multihot_kernel (ref): 2092.5452 us/iter
    n3_multihot_kernel_fast:  1810.9296 us/iter  (1.16x vs ref)
...
in bf16
n3_multihot  B=32768  obs_size=1707  multihot_elems=318996480
  reference=n3_multihot_kernel  fast=n3_multihot_kernel_fast (dm_n3_hwf+dm_n3_hw+dm_n3_w, n3_hwf=1650)
  correctness  fast vs reference: max|diff| 0  mean|diff| 0  max rel err 0
  timing (50 warmup / 200 iters), kernel only (correctness above uses cudaMemsetAsync like encoder):
    n3_multihot_kernel (ref): 1092.8813 us/iter
    n3_multihot_kernel_fast:  1002.1514 us/iter  (1.09x vs ref)


...
in float
n3_embedding  B=32768  obs_size=1707  out_elems=49283072  embed_elems=4096
  reference=n3_embedding_kernel  fast=n3_embedding_kernel_fast (kDmN3Player + vec copy)
  correctness  fast vs reference: max|diff| 0  mean|diff| 0  max rel err 0
  timing (50 warmup / 200 iters), kernel only:
    n3_embedding_kernel (ref): 473.7963 us/iter
    n3_embedding_kernel_fast: 145.2106 us/iter  (3.26x vs ref)

in bf16
n3_embedding  B=32768  obs_size=1707  out_elems=49283072  embed_elems=4096
  reference=n3_embedding_kernel  fast=n3_embedding_kernel_fast (kDmN3Player + vec copy)
  correctness  fast vs reference: max|diff| 0  mean|diff| 0  max rel err 0
  timing (50 warmup / 200 iters), kernel only:
    n3_embedding_kernel (ref): 695.4757 us/iter
    n3_embedding_kernel_fast:  73.6362 us/iter  (9.44x vs ref)
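The "vec copy" noted in the fast embedding kernel's description presumably refers to vectorized loads/stores: moving each embedding row in 16-byte (float4-sized) chunks rather than element by element. A host-side sketch of the access pattern, with an illustrative name and no claim to match the PR's code:

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>

// Sketch of a vectorized row copy: move the bulk of the row in 16-byte
// chunks (the size of a float4 load/store in a CUDA kernel), then finish
// any remainder with a scalar tail loop.
void copy_row_vec4(const float* src, float* dst, size_t n) {
    size_t n4 = n / 4 * 4;
    for (size_t i = 0; i < n4; i += 4)
        std::memcpy(dst + i, src + i, 4 * sizeof(float));  // one 16B chunk
    for (size_t i = n4; i < n; ++i)                        // scalar tail
        dst[i] = src[i];
}
```

On a GPU the payoff comes from issuing a quarter as many memory instructions and keeping global-memory transactions wide and coalesced, which is consistent with the large embedding-kernel speedups above.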

n3_conv_bias_grad  conv2  B=32768  OC=128  spatial=2  grad_elems=8388608
  reference=n3_conv_bias_grad_nchw  fast=n3_conv_bias_grad_nchw_fast (FastDivMod)
  correctness  fast vs reference: max|diff| 0  mean|diff| 0  max rel err 0
  timing (50 warmup / 200 iters), kernel only:
    n3_conv_bias_grad_nchw (ref):  51.2400 us/iter
    n3_conv_bias_grad_nchw_fast:   40.9008 us/iter  (1.25x vs ref)

n3_conv_bias_grad  conv1  B=32768  OC=128  spatial=12  grad_elems=50331648
  reference=n3_conv_bias_grad_nchw  fast=n3_conv_bias_grad_nchw_fast (FastDivMod)
  correctness  fast vs reference: max|diff| 0  mean|diff| 0  max rel err 0
  timing (50 warmup / 200 iters), kernel only:
    n3_conv_bias_grad_nchw (ref): 581.6385 us/iter
    n3_conv_bias_grad_nchw_fast:  136.5968 us/iter  (4.26x vs ref)
