Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ Benchmarks comparing cuTile.jl against cuTile Python on an RTX 5080 (`tileiras`
| Layer Norm bwd | 4096² f32 | 246 GB/s | 251 GB/s | OK (-2%) |
| Matrix Multiplication | 4096³ f32 | 47.4 TFLOPS | 43.5 TFLOPS | +9% |
| Batch Matrix Multiply | 1024×512×2048 ×8 f32 | 34.2 TFLOPS | 30.9 TFLOPS | +11% |
| FFT (3-stage Cooley-Tukey) | 512-pt ×64 c64 | 545 μs | 550 μs | OK (+1%) |
| FFT (3-stage Cooley-Tukey) | 4096-pt ×256 c64 | 209 μs | 204 μs | OK (-2%) |
| Mixture of Experts | 256tok 1024h 32e 2048i f16 | 27.7 TFLOPS | 20.3 TFLOPS | +36% |
| Attention (FMHA) | 8×16×1024² ×64 f16 causal | 102.7 TFLOPS | 63.3 TFLOPS | +62% |
| Softmax (TMA) | 4096² f32 | 838 GB/s | 843 GB/s | OK (-1%) |
Expand Down
85 changes: 36 additions & 49 deletions examples/fft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,40 +21,29 @@ using FFTW
# Python left-multiply W @ X ↔ Julia right-multiply X * W (batch dims trailing).
# Python ct.permute(x, (0,2,3,1)) ↔ Julia permutedims(x, (3,1,2,4)).
function fft_kernel(
x_packed_in::ct.TileArray{Float32, 3}, # Input (D, N2D, BS)
y_packed_out::ct.TileArray{Float32, 3}, # Output (D, N2D, BS)
x_packed_in::ct.TileArray{Float32, 3}, # Input (D, 2N ÷ D, BS)
y_packed_out::ct.TileArray{Float32, 3}, # Output (D, 2N ÷ D, BS)
W0::ct.TileArray{Float32, 3}, # W0 (2, F0, F0) DFT matrix
W1::ct.TileArray{Float32, 3}, # W1 (2, F1, F1)
W2::ct.TileArray{Float32, 3}, # W2 (2, F2, F2)
T0::ct.TileArray{Float32, 3}, # T0 (2, F1F2, F0) twiddle factors
T1::ct.TileArray{Float32, 3}, # T1 (2, F2, F1) twiddle factors
n_const::Int,
f0_const::Int,
f1_const::Int,
f2_const::Int,
f0f1_const::Int,
f1f2_const::Int,
f0f2_const::Int,
bs_const::Int,
d_const::Int,
n2d_const::Int
N::Int,
F0::Int,
F1::Int,
F2::Int,
BS::Int,
D::Int,
)
N = n_const
F0 = f0_const
F1 = f1_const
F2 = f2_const
F0F1 = f0f1_const
F1F2 = f1f2_const
F0F2 = f0f2_const
BS = bs_const
D = d_const
N2D = n2d_const
F0F1 = F0 * F1
F1F2 = F1 * F2
F0F2 = F0 * F2

bid = ct.bid(1)

# --- Load Input Data ---
# Input is (D, N2D, BS). Load and reshape to (2, N, BS).
X_ri = reshape(ct.load(x_packed_in; index=(Int32(1), Int32(1), bid), shape=(D, N2D, BS)), (2, N, BS))
# Input is (D, 2N ÷ D, BS). Load and reshape to (2, N, BS).
X_ri = reshape(ct.load(x_packed_in; index=(Int32(1), Int32(1), bid), shape=(D, 2N ÷ D, BS)), (2, N, BS))

# Split real and imaginary parts, reshape to 4D factored form
X_r = reshape(ct.extract(X_ri, (1, 1, 1), (1, N, BS)), (F2, F1, F0, BS))
Expand Down Expand Up @@ -131,7 +120,7 @@ function fft_kernel(
# --- Concatenate and Store ---
X_r_final = reshape(X_r10, (1, N, BS))
X_i_final = reshape(X_i10, (1, N, BS))
Y_ri = reshape(ct.cat((X_r_final, X_i_final), 1), (D, N2D, BS))
Y_ri = reshape(ct.cat((X_r_final, X_i_final), 1), (D, 2N ÷ D, BS))
ct.store(y_packed_out; index=(Int32(1), Int32(1), bid), tile=Y_ri)

return
Expand Down Expand Up @@ -193,14 +182,14 @@ end
=============================================================================#

function prepare(; benchmark::Bool=false,
batch::Int=benchmark ? 64 : 2,
factors::NTuple{3,Int}=benchmark ? (8, 8, 8) : (2, 2, 2),
batch::Int=benchmark ? 256 : 2,
factors::NTuple{3,Int}=benchmark ? (16, 16, 16) : (2, 2, 2),
atom_packing_dim::Int=min(64, 2 * prod(factors)))
n = prod(factors)
@assert (n * 2) % atom_packing_dim == 0 "N*2 must be divisible by atom_packing_dim"
N = prod(factors)
@assert 2N % atom_packing_dim == 0 "2 * N must be divisible by atom_packing_dim"

cuRAND.seed!(42)
input = cuRAND.randn(ComplexF32, n, batch)
input = cuRAND.randn(ComplexF32, N, batch)

W0, W1, W2, T0, T1 = make_twiddles(factors)
W0_gpu = CuArray(W0)
Expand All @@ -210,46 +199,43 @@ function prepare(; benchmark::Bool=false,
T1_gpu = CuArray(T1)

D = atom_packing_dim
N2D = n * 2 ÷ D
# Pack complex input as (D, N2D, batch) Float32 — matches Python's (batch, N2D, D) row-major.
# When D=2, reinterpret gives (2, n, batch) directly. For D>2, reshape the flat layout.
x_ri = reinterpret(reshape, Float32, input) # (2, n, batch)
x_packed = D == 2 ? x_ri : reshape(x_ri, D, N2D, batch)
y_packed = CuArray{Float32}(undef, D, N2D, batch)
# Pack complex input as (D, 2N ÷ D, batch) Float32 — matches Python's (batch, 2N ÷ D, D) row-major.
# When D=2, reinterpret gives (2, N, batch) directly. For D>2, reshape the flat layout.
x_ri = reinterpret(reshape, Float32, input) # (2, N, batch)
x_packed = D == 2 ? x_ri : reshape(x_ri, D, 2N ÷ D, batch)
y_packed = CuArray{Float32}(undef, D, 2N ÷ D, batch)

return (;
input, x_packed, y_packed,
W0_gpu, W1_gpu, W2_gpu, T0_gpu, T1_gpu,
factors, batch, n, D, N2D
factors, batch, N, D
)
end

function run(data; nruns::Int=1, warmup::Int=0)
(; x_packed, y_packed, W0_gpu, W1_gpu, W2_gpu, T0_gpu, T1_gpu,
factors, batch, n, D, N2D) = data
factors, batch, N, D) = data

F0, F1, F2 = factors
F0F1 = F0 * F1
F1F2 = F1 * F2
F0F2 = F0 * F2
grid = (batch, 1, 1)
BS = 1
grid = (batch ÷ BS, 1, 1)

CUDACore.@sync for _ in 1:warmup
@cuda backend=cuTile blocks=grid fft_kernel(x_packed, y_packed, W0_gpu, W1_gpu, W2_gpu, T0_gpu, T1_gpu, ct.Constant(n), ct.Constant(F0), ct.Constant(F1), ct.Constant(F2), ct.Constant(F0F1), ct.Constant(F1F2), ct.Constant(F0F2), ct.Constant(batch), ct.Constant(D), ct.Constant(N2D))
@cuda backend=cuTile blocks=grid fft_kernel(x_packed, y_packed, W0_gpu, W1_gpu, W2_gpu, T0_gpu, T1_gpu, ct.Constant(N), ct.Constant(F0), ct.Constant(F1), ct.Constant(F2), ct.Constant(BS), ct.Constant(D))
end

times = Float64[]
NVTX.@range "cuTile" begin
for i in 1:nruns
NVTX.@range "run $i" begin
t = CUDACore.@elapsed @cuda backend=cuTile blocks=grid fft_kernel(x_packed, y_packed, W0_gpu, W1_gpu, W2_gpu, T0_gpu, T1_gpu, ct.Constant(n), ct.Constant(F0), ct.Constant(F1), ct.Constant(F2), ct.Constant(F0F1), ct.Constant(F1F2), ct.Constant(F0F2), ct.Constant(batch), ct.Constant(D), ct.Constant(N2D))
t = CUDACore.@elapsed @cuda backend=cuTile blocks=grid fft_kernel(x_packed, y_packed, W0_gpu, W1_gpu, W2_gpu, T0_gpu, T1_gpu, ct.Constant(N), ct.Constant(F0), ct.Constant(F1), ct.Constant(F2), ct.Constant(BS), ct.Constant(D))
push!(times, t * 1000) # ms
end
end
end

# Unpack output: (D, N2D, batch) → (2, n, batch) → ComplexF32(n, batch)
y_ri = D == 2 ? y_packed : reshape(y_packed, 2, n, batch)
# Unpack output: (D, 2N ÷ D, batch) → (2, N, batch) → ComplexF32(N, batch)
y_ri = D == 2 ? y_packed : reshape(y_packed, 2, N, batch)
y_complex = reinterpret(reshape, ComplexF32, y_ri)
output = copy(y_complex)

Expand All @@ -272,18 +258,19 @@ end
=============================================================================#

function run_others(data; nruns::Int=1, warmup::Int=0)
(; input, batch, n) = data
(; input, batch, N) = data
results = Dict{String, Vector{Float64}}()

plan = cuFFT.plan_fft!(input, 1)
CUDACore.@sync for _ in 1:warmup
cuFFT.fft!(copy(input), 1)
plan * copy(input)
end
times_cufft = Float64[]
NVTX.@range "cuFFT" begin
for i in 1:nruns
NVTX.@range "run $i" begin
input_copy = copy(input)
t = CUDACore.@elapsed cuFFT.fft!(input_copy, 1)
t = CUDACore.@elapsed plan * input_copy
push!(times_cufft, t * 1000)
end
end
Expand Down
11 changes: 6 additions & 5 deletions examples/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,9 @@ def fft_make_twiddles(factors, precision, device):
def prepare(*, benchmark: bool = False, batch: int = None, factors: tuple = None, atom_packing_dim: int = None):
"""Allocate and initialize data for FFT."""
if batch is None:
batch = 64 if benchmark else 2
batch = 256 if benchmark else 2
if factors is None:
factors = (8, 8, 8) if benchmark else (2, 2, 2)
factors = (16, 16, 16) if benchmark else (2, 2, 2)
F0, F1, F2 = factors
N = F0 * F1 * F2
D = min(64, N * 2) if atom_packing_dim is None else atom_packing_dim
Expand Down Expand Up @@ -152,12 +152,13 @@ def run(data, *, nruns: int = 1, warmup: int = 0):
F0, F1, F2 = data["factors"]
batch, N, D = data["batch"], data["N"], data["D"]

grid = (batch, 1, 1)
BS = 1
grid = (batch // BS, 1, 1)

# Warmup
for _ in range(warmup):
ct.launch(torch.cuda.current_stream(), grid, fft_kernel,
(x_packed, y_packed, W0, W1, W2, T0, T1, N, F0, F1, F2, batch, D))
(x_packed, y_packed, W0, W1, W2, T0, T1, N, F0, F1, F2, BS, D))
torch.cuda.synchronize()

# Timed runs
Expand All @@ -169,7 +170,7 @@ def run(data, *, nruns: int = 1, warmup: int = 0):
end = torch.cuda.Event(enable_timing=True)
start.record()
ct.launch(torch.cuda.current_stream(), grid, fft_kernel,
(x_packed, y_packed, W0, W1, W2, T0, T1, N, F0, F1, F2, batch, D))
(x_packed, y_packed, W0, W1, W2, T0, T1, N, F0, F1, F2, BS, D))
end.record()
torch.cuda.synchronize()
times.append(start.elapsed_time(end)) # ms
Expand Down