Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions lib/mkl/fft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,13 @@ function _create_descriptor(sz::NTuple{N,Int}, T::Type, complex::Bool) where {N}
desc = desc_ref[]
# Do not program descriptor scaling; we'll perform inverse normalization manually.
# Set placement explicitly based on plan type later
# Construct a SYCL queue from current Level Zero context/device (reuse global queue)
# Use the task-local cached SYCL queue wrapping the global Level Zero queue, like the
# other oneMKL wrappers do. Creating fresh syclContext/syclQueue objects per plan is
# unsound: once they become garbage their finalizers (syclQueueDestroy etc.) tear down
# SYCL runtime state for the still-in-use underlying queue, corrupting later DFT
# commits and crashing at process exit.
ze_ctx = oneAPI.context(); ze_dev = oneAPI.device()
sycl_dev = SYCL.syclDevice(SYCL.syclPlatform(oneAPI.driver()), ze_dev)
sycl_ctx = SYCL.syclContext([sycl_dev], ze_ctx)
q = SYCL.syclQueue(sycl_ctx, sycl_dev, oneAPI.global_queue(ze_ctx, ze_dev))
q = oneAPI.sycl_queue(oneAPI.global_queue(ze_ctx, ze_dev))
return desc, q
end

Expand Down
23 changes: 23 additions & 0 deletions test/fft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,4 +79,27 @@ end
end
end
end

@testset "shared queue lifetime across plans" begin
# Plans must share the single cached task-local SYCL queue rather than each owning a
# throwaway one (whose finalizer would tear down shared SYCL/oneMKL state). Assert the
# shared handle deterministically, independent of whether a stale queue would crash.
cached_handle = Base.unsafe_convert(oneAPI.oneMKL.syclQueue_t,
oneAPI.sycl_queue(oneAPI.global_queue(oneAPI.context(), oneAPI.device())))

dX1 = gpu(rand(ComplexF32, 8))
p1 = AbstractFFTs.plan_fft(dX1)
@test p1.queue == cached_handle
dY1 = p1 * dX1
p1i = AbstractFFTs.plan_ifft(dX1)
p1i * dY1

GC.gc(true) # run finalizers of any throwaway per-plan SYCL wrappers

X2 = rand(ComplexF32, 8, 32)
dX2 = gpu(X2)
p2 = AbstractFFTs.plan_fft(dX2)
@test p2.queue == cached_handle
cmp(p2 * dX2, fft(X2))
end
end
Loading