diff --git a/README.md b/README.md index 8ae4b3f0..40a5924c 100644 --- a/README.md +++ b/README.md @@ -104,7 +104,7 @@ Benchmarks comparing cuTile.jl against cuTile Python on an RTX 5080 (`tileiras` | Layer Norm bwd | 4096² f32 | 246 GB/s | 251 GB/s | OK (-2%) | | Matrix Multiplication | 4096³ f32 | 47.4 TFLOPS | 43.5 TFLOPS | +9% | | Batch Matrix Multiply | 1024×512×2048 ×8 f32 | 34.2 TFLOPS | 30.9 TFLOPS | +11% | -| FFT (3-stage Cooley-Tukey) | 512-pt ×64 c64 | 545 μs | 550 μs | OK (+1%) | +| FFT (3-stage Cooley-Tukey) | 4096-pt ×256 c64 | 209 μs | 204 μs | OK (-2%) | | Mixture of Experts | 256tok 1024h 32e 2048i f16 | 27.7 TFLOPS | 20.3 TFLOPS | +36% | | Attention (FMHA) | 8×16×1024² ×64 f16 causal | 102.7 TFLOPS | 63.3 TFLOPS | +62% | | Softmax (TMA) | 4096² f32 | 838 GB/s | 843 GB/s | OK (-1%) | diff --git a/examples/fft.jl b/examples/fft.jl index af74cd70..612b317c 100644 --- a/examples/fft.jl +++ b/examples/fft.jl @@ -21,40 +21,29 @@ using FFTW # Python left-multiply W @ X ↔ Julia right-multiply X * W (batch dims trailing). # Python ct.permute(x, (0,2,3,1)) ↔ Julia permutedims(x, (3,1,2,4)). function fft_kernel( - x_packed_in::ct.TileArray{Float32, 3}, # Input (D, N2D, BS) - y_packed_out::ct.TileArray{Float32, 3}, # Output (D, N2D, BS) + x_packed_in::ct.TileArray{Float32, 3}, # Input (D, 2N ÷ D, BS) + y_packed_out::ct.TileArray{Float32, 3}, # Output (D, 2N ÷ D, BS) W0::ct.TileArray{Float32, 3}, # W0 (2, F0, F0) DFT matrix W1::ct.TileArray{Float32, 3}, # W1 (2, F1, F1) W2::ct.TileArray{Float32, 3}, # W2 (2, F2, F2) T0::ct.TileArray{Float32, 3}, # T0 (2, F1F2, F0) twiddle factors T1::ct.TileArray{Float32, 3}, # T1 (2, F2, F1) twiddle factors - n_const::Int, - f0_const::Int, - f1_const::Int, - f2_const::Int, - f0f1_const::Int, - f1f2_const::Int, - f0f2_const::Int, - bs_const::Int, - d_const::Int, - n2d_const::Int + N::Int, + F0::Int, + F1::Int, + F2::Int, + BS::Int, + D::Int, ) - N = n_const - F0 = f0_const - F1 = f1_const - F2 = f2_const - F0F1 = f0f1_const - F1F2 = f1f2_const - F0F2 = f0f2_const - BS = bs_const - D = d_const - N2D = n2d_const + F0F1 = F0 * F1 + F1F2 = F1 * F2 + F0F2 = F0 * F2 bid = ct.bid(1) # --- Load Input Data --- - # Input is (D, N2D, BS). Load and reshape to (2, N, BS). - X_ri = reshape(ct.load(x_packed_in; index=(Int32(1), Int32(1), bid), shape=(D, N2D, BS)), (2, N, BS)) + # Input is (D, 2N ÷ D, BS). Load and reshape to (2, N, BS). + X_ri = reshape(ct.load(x_packed_in; index=(Int32(1), Int32(1), bid), shape=(D, 2N ÷ D, BS)), (2, N, BS)) # Split real and imaginary parts, reshape to 4D factored form X_r = reshape(ct.extract(X_ri, (1, 1, 1), (1, N, BS)), (F2, F1, F0, BS)) @@ -131,7 +120,7 @@ function fft_kernel( # --- Concatenate and Store --- X_r_final = reshape(X_r10, (1, N, BS)) X_i_final = reshape(X_i10, (1, N, BS)) - Y_ri = reshape(ct.cat((X_r_final, X_i_final), 1), (D, N2D, BS)) + Y_ri = reshape(ct.cat((X_r_final, X_i_final), 1), (D, 2N ÷ D, BS)) ct.store(y_packed_out; index=(Int32(1), Int32(1), bid), tile=Y_ri) return @@ -193,14 +182,14 @@ end =============================================================================# function prepare(; benchmark::Bool=false, - batch::Int=benchmark ? 64 : 2, - factors::NTuple{3,Int}=benchmark ? (8, 8, 8) : (2, 2, 2), + batch::Int=benchmark ? 256 : 2, + factors::NTuple{3,Int}=benchmark ? (16, 16, 16) : (2, 2, 2), atom_packing_dim::Int=min(64, 2 * prod(factors))) - n = prod(factors) - @assert (n * 2) % atom_packing_dim == 0 "N*2 must be divisible by atom_packing_dim" + N = prod(factors) + @assert 2N % atom_packing_dim == 0 "2 * N must be divisible by atom_packing_dim" cuRAND.seed!(42) - input = cuRAND.randn(ComplexF32, n, batch) + input = cuRAND.randn(ComplexF32, N, batch) W0, W1, W2, T0, T1 = make_twiddles(factors) W0_gpu = CuArray(W0) @@ -210,46 +199,43 @@ function prepare(; benchmark::Bool=false, T1_gpu = CuArray(T1) D = atom_packing_dim - N2D = n * 2 ÷ D - # Pack complex input as (D, N2D, batch) Float32 — matches Python's (batch, N2D, D) row-major. - # When D=2, reinterpret gives (2, n, batch) directly. For D>2, reshape the flat layout. - x_ri = reinterpret(reshape, Float32, input) # (2, n, batch) - x_packed = D == 2 ? x_ri : reshape(x_ri, D, N2D, batch) - y_packed = CuArray{Float32}(undef, D, N2D, batch) + # Pack complex input as (D, 2N ÷ D, batch) Float32 — matches Python's (batch, 2N ÷ D, D) row-major. + # When D=2, reinterpret gives (2, N, batch) directly. For D>2, reshape the flat layout. + x_ri = reinterpret(reshape, Float32, input) # (2, N, batch) + x_packed = D == 2 ? x_ri : reshape(x_ri, D, 2N ÷ D, batch) + y_packed = CuArray{Float32}(undef, D, 2N ÷ D, batch) return (; input, x_packed, y_packed, W0_gpu, W1_gpu, W2_gpu, T0_gpu, T1_gpu, - factors, batch, n, D, N2D + factors, batch, N, D ) end function run(data; nruns::Int=1, warmup::Int=0) (; x_packed, y_packed, W0_gpu, W1_gpu, W2_gpu, T0_gpu, T1_gpu, - factors, batch, n, D, N2D) = data + factors, batch, N, D) = data F0, F1, F2 = factors - F0F1 = F0 * F1 - F1F2 = F1 * F2 - F0F2 = F0 * F2 - grid = (batch, 1, 1) + BS = 1 + grid = (batch ÷ BS, 1, 1) CUDACore.@sync for _ in 1:warmup - @cuda backend=cuTile blocks=grid fft_kernel(x_packed, y_packed, W0_gpu, W1_gpu, W2_gpu, T0_gpu, T1_gpu, ct.Constant(n), ct.Constant(F0), ct.Constant(F1), ct.Constant(F2), ct.Constant(F0F1), ct.Constant(F1F2), ct.Constant(F0F2), ct.Constant(batch), ct.Constant(D), ct.Constant(N2D)) + @cuda backend=cuTile blocks=grid fft_kernel(x_packed, y_packed, W0_gpu, W1_gpu, W2_gpu, T0_gpu, T1_gpu, ct.Constant(N), ct.Constant(F0), ct.Constant(F1), ct.Constant(F2), ct.Constant(BS), ct.Constant(D)) end times = Float64[] NVTX.@range "cuTile" begin for i in 1:nruns NVTX.@range "run $i" begin - t = CUDACore.@elapsed @cuda backend=cuTile blocks=grid fft_kernel(x_packed, y_packed, W0_gpu, W1_gpu, W2_gpu, T0_gpu, T1_gpu, ct.Constant(n), ct.Constant(F0), ct.Constant(F1), ct.Constant(F2), ct.Constant(F0F1), ct.Constant(F1F2), ct.Constant(F0F2), ct.Constant(batch), ct.Constant(D), ct.Constant(N2D)) + t = CUDACore.@elapsed @cuda backend=cuTile blocks=grid fft_kernel(x_packed, y_packed, W0_gpu, W1_gpu, W2_gpu, T0_gpu, T1_gpu, ct.Constant(N), ct.Constant(F0), ct.Constant(F1), ct.Constant(F2), ct.Constant(BS), ct.Constant(D)) push!(times, t * 1000) # ms end end end - # Unpack output: (D, N2D, batch) → (2, n, batch) → ComplexF32(n, batch) - y_ri = D == 2 ? y_packed : reshape(y_packed, 2, n, batch) + # Unpack output: (D, 2n ÷ D, batch) → (2, N, batch) → ComplexF32(n, batch) + y_ri = D == 2 ? y_packed : reshape(y_packed, 2, N, batch) y_complex = reinterpret(reshape, ComplexF32, y_ri) output = copy(y_complex) @@ -272,18 +258,19 @@ end =============================================================================# function run_others(data; nruns::Int=1, warmup::Int=0) - (; input, batch, n) = data + (; input, batch, N) = data results = Dict{String, Vector{Float64}}() + plan = cuFFT.plan_fft!(input, 1) CUDACore.@sync for _ in 1:warmup - cuFFT.fft!(copy(input), 1) + plan * copy(input) end times_cufft = Float64[] NVTX.@range "cuFFT" begin for i in 1:nruns NVTX.@range "run $i" begin input_copy = copy(input) - t = CUDACore.@elapsed cuFFT.fft!(input_copy, 1) + t = CUDACore.@elapsed plan * input_copy push!(times_cufft, t * 1000) end end diff --git a/examples/fft.py b/examples/fft.py index d190ca43..e4dd7c1d 100644 --- a/examples/fft.py +++ b/examples/fft.py @@ -115,9 +115,9 @@ def fft_make_twiddles(factors, precision, device): def prepare(*, benchmark: bool = False, batch: int = None, factors: tuple = None, atom_packing_dim: int = None): """Allocate and initialize data for FFT.""" if batch is None: - batch = 64 if benchmark else 2 + batch = 256 if benchmark else 2 if factors is None: - factors = (8, 8, 8) if benchmark else (2, 2, 2) + factors = (16, 16, 16) if benchmark else (2, 2, 2) F0, F1, F2 = factors N = F0 * F1 * F2 D = min(64, N * 2) if atom_packing_dim is None else atom_packing_dim @@ -152,12 +152,13 @@ def run(data, *, nruns: int = 1, warmup: int = 0): F0, F1, F2 = data["factors"] batch, N, D = data["batch"], data["N"], data["D"] - grid = (batch, 1, 1) + BS = 1 + grid = (batch // BS, 1, 1) # Warmup for _ in range(warmup): ct.launch(torch.cuda.current_stream(), grid, fft_kernel, - (x_packed, y_packed, W0, W1, W2, T0, T1, N, F0, F1, F2, batch, D)) + (x_packed, y_packed, W0, W1, W2, T0, T1, N, F0, F1, F2, BS, D)) torch.cuda.synchronize() # Timed runs @@ -169,7 +170,7 @@ def run(data, *, nruns: int = 1, warmup: int = 0): end = torch.cuda.Event(enable_timing=True) start.record() ct.launch(torch.cuda.current_stream(), grid, fft_kernel, - (x_packed, y_packed, W0, W1, W2, T0, T1, N, F0, F1, F2, batch, D)) + (x_packed, y_packed, W0, W1, W2, T0, T1, N, F0, F1, F2, BS, D)) end.record() torch.cuda.synchronize() times.append(start.elapsed_time(end)) # ms