diff --git a/samples/FFT.py b/samples/FFT.py index 8d2815c..dfc5807 100644 --- a/samples/FFT.py +++ b/samples/FFT.py @@ -33,7 +33,9 @@ def fft_kernel(x_packed_in, y_packed_out, N (ConstInt): Total FFT size (e.g., 256, 1024). F0, F1, F2 (ConstInt): Factors of N, such that N = F0 * F1 * F2. These define the logical 3D shape for the FFT decomposition. - BS (ConstInt): Batch size of the input data. + BS (ConstInt): Per-block minibatch size — the number of batch items + processed by a single thread block. The full input batch + is split across blocks as grid_x = Batch // BS. D (ConstInt): Atom packing dimension. This parameter controls how the real and imaginary data are interleaved and packed into memory for optimal coalesced access on the GPU. @@ -43,14 +45,13 @@ def fft_kernel(x_packed_in, y_packed_out, F1F2 = F1 * F2 F0F2 = F0 * F2 - bid = ct.bid(0) # Get the Batch ID for the current block. - # In this kernel, each block processes one item from the batch. + bid = ct.bid(0) # Block index along the batch dimension. + # Each block processes a contiguous minibatch of BS items from the full batch. # --- Load Input Data --- - # Load input data for the current batch from `x_packed_in`. - # `x_packed_in` is initially (BS, N * 2 // D, D) due to the packing scheme. - # `ct.load` reads the specified tile from global memory. - # Then, `ct.reshape` transforms it to (BS, N, 2) to logically separate + # Load this block's minibatch from `x_packed_in` (shape (Batch, N*2//D, D)). + # `ct.load` reads a (BS, N*2//D, D) tile at minibatch index `bid`. + # Then `ct.reshape` transforms it to (BS, N, 2) to logically separate # the real and imaginary components for each of the N elements. X_ri = ct.reshape(ct.load(x_packed_in, index=(bid, 0, 0), shape=(BS, N * 2 // D, D)), (BS, N, 2)) @@ -233,7 +234,8 @@ def make_twiddles(decomp: tuple, precision: torch.dtype, device: torch.device): def cutile_fft( x: torch.Tensor, factors: tuple, # (F0, F1, F2) - factors of N - atom_packing_dim: int = 64 # The 'D' parameter for data packing/unpacking + atom_packing_dim: int = 64, # The 'D' parameter for data packing/unpacking + minibatch: int = 1, # The 'BS' kernel constant: items processed per block ) -> torch.Tensor: """ Performs a Batched 1D Fast Fourier Transform (FFT) using a cuTile kernel @@ -253,6 +255,11 @@ def cutile_fft( in the kernel. This value affects memory access patterns. The total number of real/imaginary elements (N*2) must be divisible by this dimension. Default is 64. + minibatch (int): Number of batch items processed by each thread block (the + kernel's `BS` constant). The grid is sized as Batch // minibatch, + so `Batch` must be divisible by `minibatch`. Larger values let + a block reuse the loaded W/T matrices across more items, but + increase register/shared-memory pressure. Default is 1. Returns: torch.Tensor: Output tensor of shape (Batch, N) containing the FFT results. @@ -260,8 +267,9 @@ def cutile_fft( Raises: ValueError: If input tensor dimensions, device, or data type are incorrect, - if the provided factors do not multiply to N, or if N*2 is not - divisible by atom_packing_dim. + if the provided factors do not multiply to N, if N*2 is not + divisible by atom_packing_dim, or if Batch is not divisible + by minibatch. """ # --- Input Validation --- if x.ndim != 2: @@ -271,8 +279,8 @@ def cutile_fft( if x.dtype != torch.complex64: raise ValueError("Input tensor dtype must be torch.complex64.") - BS = x.shape[0] # Extract Batch Size from the input tensor's shape. - N = x.shape[1] # Extract Total FFT size from the input tensor's shape. + Batch = x.shape[0] # Total batch size from the input tensor's shape. + N = x.shape[1] # Total FFT size from the input tensor's shape. F0, F1, F2 = factors # Validate that the provided factors correctly decompose the total FFT size N. @@ -285,18 +293,20 @@ def cutile_fft( PRECISION_DTYPE = x.real.dtype # --- Prepare Input Data for Kernel (Split real/imag, pack) --- - # Convert the complex input tensor (BS, N) to a real tensor (BS, N, 2) + # Convert the complex input tensor (Batch, N) to a real tensor (Batch, N, 2) # where the last dimension explicitly separates real and imaginary parts. x_ri = torch.view_as_real(x) - # Reshape the real/imaginary tensor to the packed format (BS, N*2 // D, D) + # Reshape the real/imaginary tensor to the packed format (Batch, N*2 // D, D) # that the kernel expects for efficient memory access. # This step assumes that the total number of real/imaginary elements (N*2) # is perfectly divisible by the `atom_packing_dim` (D). if (N * 2) % atom_packing_dim != 0: raise ValueError(f"Total real/imag elements (N*2 = {N*2}) must be divisible by " f"atom_packing_dim ({atom_packing_dim}) for kernel packing.") - x_packed_in = x_ri.reshape(BS, N * 2 // atom_packing_dim, atom_packing_dim).contiguous() + if Batch % minibatch != 0: + raise ValueError(f"Batch ({Batch}) must be divisible by minibatch ({minibatch}).") + x_packed_in = x_ri.reshape(Batch, N * 2 // atom_packing_dim, atom_packing_dim).contiguous() # --- Generate W (Rotation) and T (Twiddle) Matrices --- # These matrices are pre-computed mathematically based on the FFT decomposition. @@ -310,23 +320,23 @@ def cutile_fft( y_packed_out = torch.empty_like(x_packed_in) # --- Calculate Grid Dimensions --- - # For this FFT kernel, one thread block is launched for each item in the batch. + # One thread block is launched per minibatch of size `minibatch`. # The grid is a 3-tuple (grid_x, grid_y, grid_z). - grid = (BS, 1, 1) + grid = (Batch // minibatch, 1, 1) # --- Launch the cuTile Kernel --- # The `fft_kernel` is launched on the GPU with the calculated grid dimensions. # All necessary input tensors (packed data, W and T matrices) and constant parameters - # (N, F0, F1, F2, BS, D) are passed to the kernel. + # (N, F0, F1, F2, BS=minibatch, D) are passed to the kernel. ct.launch(torch.cuda.current_stream(), grid, fft_kernel, (x_packed_in, y_packed_out, W0_gmem, W1_gmem, W2_gmem, T0_gmem, T1_gmem, - N, F0, F1, F2, BS, atom_packing_dim)) + N, F0, F1, F2, minibatch, atom_packing_dim)) # --- Unpack Output from Kernel (Reshape, combine real/imag) --- - # Reshape the packed output tensor back to (BS, N, 2) to separate real/imaginary parts. - y_ri = y_packed_out.reshape(BS, N, 2) + # Reshape the packed output tensor back to (Batch, N, 2) to separate real/imaginary parts. + y_ri = y_packed_out.reshape(Batch, N, 2) # Convert the real/imaginary pair tensor back to a complex tensor (torch.complex64). y_complex = torch.view_as_complex(y_ri)