
feat(compute): fused PatchTST encoder CUDA kernels#86

Merged
dndungu merged 2 commits into main from feat/fused-encoder-kernel on Apr 10, 2026
Conversation

dndungu (Contributor) commented Apr 10, 2026

Summary

  • Add fused encoder forward/backward CUDA kernel orchestrators that replace ~78 discrete Engine operations per encoder layer with a single C function call
  • Forward kernel uses 8 cuBLAS GEMMs + 7 custom sub-kernels (LayerNorm, head transpose, GELU, softmax, residual)
  • Backward kernel uses 14 cuBLAS GEMMs + 6 custom sub-kernels for gradient computation
  • Full Go binding stack: purego + CGo wrappers, KernelRunner interface, FusedEncoderProvider optional interface on GPUEngine
  • Makefile updated to link libkernels.so against libcublas

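The "optional interface" mentioned above is a common Go pattern: the caller type-asserts the engine against a narrow interface and falls back to the discrete per-op path when the assertion fails. A minimal sketch of that dispatch, with the method name `FusedEncoderForward` and all types invented for illustration (the PR only names `FusedEncoderProvider` and `GPUEngine`):

```go
package main

import "fmt"

// Engine is a stand-in for the library's compute engine interface.
type Engine interface {
	Name() string
}

// FusedEncoderProvider mirrors the optional interface described above;
// the method name and signature here are assumptions for illustration.
type FusedEncoderProvider interface {
	FusedEncoderForward(input []float32) ([]float32, error)
}

type cpuEngine struct{}

func (cpuEngine) Name() string { return "cpu" }

type gpuEngine struct{}

func (gpuEngine) Name() string { return "gpu" }

// FusedEncoderForward stands in for the single C call into the fused
// CUDA orchestrator; here it just copies its input.
func (gpuEngine) FusedEncoderForward(in []float32) ([]float32, error) {
	out := make([]float32, len(in))
	copy(out, in)
	return out, nil
}

// encoderForward prefers the fused path when the engine provides it,
// falling back to discrete per-op execution otherwise.
func encoderForward(e Engine, in []float32) string {
	if p, ok := e.(FusedEncoderProvider); ok {
		p.FusedEncoderForward(in)
		return "fused"
	}
	return "per-op"
}

func main() {
	fmt.Println(encoderForward(cpuEngine{}, nil)) // per-op
	fmt.Println(encoderForward(gpuEngine{}, nil)) // fused
}
```

Because the interface is optional, CPU-only builds and non-CUDA backends need no changes: they simply never satisfy the assertion.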
Files

CUDA kernels (new):

  • internal/cuda/kernels/fused_encoder_fwd.{cu,h} — forward orchestrator + sub-kernels
  • internal/cuda/kernels/fused_encoder_bwd.{cu,h} — backward orchestrator + sub-kernels

Go bindings (new):

  • internal/cuda/kernels/fused_encoder_{fwd,bwd}_purego.go — purego wrappers + buffer index constants
  • internal/cuda/kernels/fused_encoder_{fwd,bwd}.go — CGo wrappers
  • compute/fused_encoder.go — FusedEncoderProvider optional interface
  • compute/gpu_fused_encoder.go — GPUEngine implementation
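The buffer index constants in the purego wrappers let Go and C agree on which slot of the forward cache holds which intermediate. A sketch of the pattern using `iota`; only the `FEB_` prefix and the count of 16 buffers come from this PR, the individual names are invented:

```go
package main

import "fmt"

// Hypothetical forward-cache buffer indices shared between the Go
// bindings and the C orchestrator. The real constant names in
// fused_encoder_fwd_purego.go are not shown in the PR text.
const (
	FEB_LN1Out = iota // output of first LayerNorm
	FEB_Q             // query projection
	FEB_K             // key projection
	FEB_V             // value projection
	FEB_Attn          // softmax output
	// ... remaining indices, 16 buffers in total
	FEB_Count = 16
)

func main() {
	fmt.Println(FEB_Q, FEB_Count) // 1 16
}
```

Keeping the indices as named constants on the Go side means a reordered buffer array in the C header only has to be mirrored in one place.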

Modified:

  • internal/cuda/kernels/purego.go — KernelLib symbol registration (optional)
  • internal/cuda/kernels/Makefile — add .cu files + -lcublas link
  • internal/cublas/cublas_purego.go — Handle.Ptr() for raw pointer access
  • internal/gpuapi/kernels.go — KernelRunner interface additions
  • internal/gpuapi/{cuda,opencl,rocm,sycl,fpga,metal}_kernels.go — adapter implementations/stubs

Test plan

  • go build ./... passes
  • go test ./compute/ passes (CPU tests)
  • Compile kernels on DGX: cd internal/cuda/kernels && make CUDA_ARCH=sm_121 shared
  • GPU unit test: fused forward matches per-op forward within 1e-4
  • PatchTST training benchmark: 28K×20×10 epochs with fused encoder

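The 1e-4 check in the GPU unit test amounts to comparing the fused output against the per-op output elementwise. A minimal sketch of such a tolerance helper, with the sample values invented:

```go
package main

import (
	"fmt"
	"math"
)

// maxAbsDiff returns the largest elementwise absolute difference
// between two equal-length float32 slices.
func maxAbsDiff(a, b []float32) float64 {
	worst := 0.0
	for i := range a {
		d := math.Abs(float64(a[i]) - float64(b[i]))
		if d > worst {
			worst = d
		}
	}
	return worst
}

func main() {
	// Stand-ins for the fused and per-op forward outputs.
	fused := []float32{0.10001, -2.0, 3.5}
	perOp := []float32{0.10000, -2.0, 3.5}
	const tol = 1e-4
	fmt.Println(maxAbsDiff(fused, perOp) <= tol) // true
}
```

An absolute tolerance is reasonable here because fused and per-op paths reorder the same float32 arithmetic; for wildly varying magnitudes a relative tolerance would be safer.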
dndungu added 2 commits April 10, 2026 12:57
Implement fused encoder forward and backward pass orchestrators that
replace ~78 discrete Engine operations per layer with a single C
function call. The orchestrator launches cuBLAS GEMMs for matrix
multiplications and custom CUDA sub-kernels for LayerNorm, head
transpose, GELU, softmax, and residual operations.

Forward kernel (fused_encoder_fwd.cu):
- 7 sub-kernels: layernorm_fwd, bias_add, head_split, head_merge,
  softmax_fwd, bias_gelu_fwd, bias_residual_add
- 8 cuBLAS Sgemm/SgemmStridedBatched calls
- Caches 16 intermediate buffers (FEB_*) for backward use

Backward kernel (fused_encoder_bwd.cu):
- 6 sub-kernels: layernorm_bwd, gelu_bwd, softmax_bwd,
  bias_grad_reduce, add, head_split/merge
- 14 cuBLAS calls for weight/input gradient computation
- Reads forward cache; accumulates into gradient buffers

Go bindings:
- Purego and CGo wrappers for both kernels
- KernelLib symbol registration (optional, non-fatal if absent)
- KernelRunner interface methods + all backend stubs
- FusedEncoderProvider optional interface on GPUEngine
- cublas.Handle.Ptr() for passing raw handle to C

Build: Makefile adds -lcublas to libkernels.so link step.

Closes zerfoo/zerfoo E55 tasks T55.1.1, T55.1.2, T55.2.1, T55.2.2.

The cublas package uses the same uintptr-to-unsafe.Pointer pattern as
other GPU runtime binding packages for storing C library handles via
purego/dlopen. Add it to the vet exclusion list alongside internal/cuda,
internal/hip, etc.

Triggered by the new Handle.Ptr() method added for the fused encoder
kernel orchestrator.
dndungu merged commit d7bb08e into main Apr 10, 2026
1 check passed
dndungu deleted the feat/fused-encoder-kernel branch April 10, 2026 20:00