Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 31 additions & 21 deletions samples/FFT.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -33,7 +33,9 @@ def fft_kernel(x_packed_in, y_packed_out,
N (ConstInt): Total FFT size (e.g., 256, 1024).
F0, F1, F2 (ConstInt): Factors of N, such that N = F0 * F1 * F2. These define
the logical 3D shape for the FFT decomposition.
BS (ConstInt): Batch size of the input data.
BS (ConstInt): Per-block minibatch size — the number of batch items
processed by a single thread block. The full input batch
is split across blocks as grid_x = Batch // BS.
D (ConstInt): Atom packing dimension. This parameter controls how the real and
imaginary data are interleaved and packed into memory for optimal
coalesced access on the GPU.
Expand All @@ -43,14 +45,13 @@ def fft_kernel(x_packed_in, y_packed_out,
F1F2 = F1 * F2
F0F2 = F0 * F2

bid = ct.bid(0) # Get the Batch ID for the current block.
# In this kernel, each block processes one item from the batch.
bid = ct.bid(0) # Block index along the batch dimension.
# Each block processes a contiguous minibatch of BS items from the full batch.

# --- Load Input Data ---
# Load input data for the current batch from `x_packed_in`.
# `x_packed_in` is initially (BS, N * 2 // D, D) due to the packing scheme.
# `ct.load` reads the specified tile from global memory.
# Then, `ct.reshape` transforms it to (BS, N, 2) to logically separate
# Load this block's minibatch from `x_packed_in` (shape (Batch, N*2//D, D)).
# `ct.load` reads a (BS, N*2//D, D) tile at minibatch index `bid`.
# Then `ct.reshape` transforms it to (BS, N, 2) to logically separate
# the real and imaginary components for each of the N elements.
X_ri = ct.reshape(ct.load(x_packed_in, index=(bid, 0, 0),
shape=(BS, N * 2 // D, D)), (BS, N, 2))
Expand Down Expand Up @@ -233,7 +234,8 @@ def make_twiddles(decomp: tuple, precision: torch.dtype, device: torch.device):
def cutile_fft(
x: torch.Tensor,
factors: tuple, # (F0, F1, F2) - factors of N
atom_packing_dim: int = 64 # The 'D' parameter for data packing/unpacking
atom_packing_dim: int = 64, # The 'D' parameter for data packing/unpacking
minibatch: int = 1, # The 'BS' kernel constant: items processed per block
) -> torch.Tensor:
"""
Performs a Batched 1D Fast Fourier Transform (FFT) using a cuTile kernel
Expand All @@ -253,15 +255,21 @@ def cutile_fft(
in the kernel. This value affects memory access patterns.
The total number of real/imaginary elements (N*2) must be
divisible by this dimension. Default is 64.
minibatch (int): Number of batch items processed by each thread block (the
kernel's `BS` constant). The grid is sized as Batch // minibatch,
so `Batch` must be divisible by `minibatch`. Larger values let
a block reuse the loaded W/T matrices across more items, but
increase register/shared-memory pressure. Default is 1.

Returns:
torch.Tensor: Output tensor of shape (Batch, N) containing the FFT results.
The output data type will be torch.complex64.

Raises:
ValueError: If input tensor dimensions, device, or data type are incorrect,
if the provided factors do not multiply to N, or if N*2 is not
divisible by atom_packing_dim.
if the provided factors do not multiply to N, if N*2 is not
divisible by atom_packing_dim, or if Batch is not divisible
by minibatch.
"""
# --- Input Validation ---
if x.ndim != 2:
Expand All @@ -271,8 +279,8 @@ def cutile_fft(
if x.dtype != torch.complex64:
raise ValueError("Input tensor dtype must be torch.complex64.")

BS = x.shape[0] # Extract Batch Size from the input tensor's shape.
N = x.shape[1] # Extract Total FFT size from the input tensor's shape.
Batch = x.shape[0] # Total batch size from the input tensor's shape.
N = x.shape[1] # Total FFT size from the input tensor's shape.

F0, F1, F2 = factors
# Validate that the provided factors correctly decompose the total FFT size N.
Expand All @@ -285,18 +293,20 @@ def cutile_fft(
PRECISION_DTYPE = x.real.dtype

# --- Prepare Input Data for Kernel (Split real/imag, pack) ---
# Convert the complex input tensor (BS, N) to a real tensor (BS, N, 2)
# Convert the complex input tensor (Batch, N) to a real tensor (Batch, N, 2)
# where the last dimension explicitly separates real and imaginary parts.
x_ri = torch.view_as_real(x)

# Reshape the real/imaginary tensor to the packed format (BS, N*2 // D, D)
# Reshape the real/imaginary tensor to the packed format (Batch, N*2 // D, D)
# that the kernel expects for efficient memory access.
# This step assumes that the total number of real/imaginary elements (N*2)
# is perfectly divisible by the `atom_packing_dim` (D).
if (N * 2) % atom_packing_dim != 0:
raise ValueError(f"Total real/imag elements (N*2 = {N*2}) must be divisible by "
f"atom_packing_dim ({atom_packing_dim}) for kernel packing.")
x_packed_in = x_ri.reshape(BS, N * 2 // atom_packing_dim, atom_packing_dim).contiguous()
if Batch % minibatch != 0:
raise ValueError(f"Batch ({Batch}) must be divisible by minibatch ({minibatch}).")
x_packed_in = x_ri.reshape(Batch, N * 2 // atom_packing_dim, atom_packing_dim).contiguous()

# --- Generate W (Rotation) and T (Twiddle) Matrices ---
# These matrices are pre-computed mathematically based on the FFT decomposition.
Expand All @@ -310,23 +320,23 @@ def cutile_fft(
y_packed_out = torch.empty_like(x_packed_in)

# --- Calculate Grid Dimensions ---
# For this FFT kernel, one thread block is launched for each item in the batch.
# One thread block is launched per minibatch of size `minibatch`.
# The grid is a 3-tuple (grid_x, grid_y, grid_z).
grid = (BS, 1, 1)
grid = (Batch // minibatch, 1, 1)

# --- Launch the cuTile Kernel ---
# The `fft_kernel` is launched on the GPU with the calculated grid dimensions.
# All necessary input tensors (packed data, W and T matrices) and constant parameters
# (N, F0, F1, F2, BS, D) are passed to the kernel.
# (N, F0, F1, F2, BS=minibatch, D) are passed to the kernel.
ct.launch(torch.cuda.current_stream(), grid, fft_kernel,
(x_packed_in, y_packed_out,
W0_gmem, W1_gmem, W2_gmem,
T0_gmem, T1_gmem,
N, F0, F1, F2, BS, atom_packing_dim))
N, F0, F1, F2, minibatch, atom_packing_dim))

# --- Unpack Output from Kernel (Reshape, combine real/imag) ---
# Reshape the packed output tensor back to (BS, N, 2) to separate real/imaginary parts.
y_ri = y_packed_out.reshape(BS, N, 2)
# Reshape the packed output tensor back to (Batch, N, 2) to separate real/imaginary parts.
y_ri = y_packed_out.reshape(Batch, N, 2)
# Convert the real/imaginary pair tensor back to a complex tensor (torch.complex64).
y_complex = torch.view_as_complex(y_ri)

Expand Down