diff --git a/lib/mkl/fft.jl b/lib/mkl/fft.jl index 745ff65f..4429801b 100644 --- a/lib/mkl/fft.jl +++ b/lib/mkl/fft.jl @@ -105,11 +105,13 @@ function _create_descriptor(sz::NTuple{N,Int}, T::Type, complex::Bool) where {N} desc = desc_ref[] # Do not program descriptor scaling; we'll perform inverse normalization manually. # Set placement explicitly based on plan type later - # Construct a SYCL queue from current Level Zero context/device (reuse global queue) + # Use the task-local cached SYCL queue wrapping the global Level Zero queue, like the + # other oneMKL wrappers do. Creating fresh syclContext/syclQueue objects per plan is + # unsound: once they become garbage their finalizers (syclQueueDestroy etc.) tear down + # SYCL runtime state for the still-in-use underlying queue, corrupting later DFT + # commits and crashing at process exit. ze_ctx = oneAPI.context(); ze_dev = oneAPI.device() - sycl_dev = SYCL.syclDevice(SYCL.syclPlatform(oneAPI.driver()), ze_dev) - sycl_ctx = SYCL.syclContext([sycl_dev], ze_ctx) - q = SYCL.syclQueue(sycl_ctx, sycl_dev, oneAPI.global_queue(ze_ctx, ze_dev)) + q = oneAPI.sycl_queue(oneAPI.global_queue(ze_ctx, ze_dev)) return desc, q end diff --git a/test/fft.jl b/test/fft.jl index 1b148dfe..d4419462 100644 --- a/test/fft.jl +++ b/test/fft.jl @@ -79,4 +79,27 @@ end end end end + +@testset "shared queue lifetime across plans" begin + # Plans must share the single cached task-local SYCL queue rather than each owning a + # throwaway one (whose finalizer would tear down shared SYCL/oneMKL state). Assert the + # shared handle deterministically, independent of whether a stale queue would crash. + cached_handle = Base.unsafe_convert(oneAPI.oneMKL.syclQueue_t, + oneAPI.sycl_queue(oneAPI.global_queue(oneAPI.context(), oneAPI.device()))) + + dX1 = gpu(rand(ComplexF32, 8)) + p1 = AbstractFFTs.plan_fft(dX1) + @test p1.queue == cached_handle + dY1 = p1 * dX1 + p1i = AbstractFFTs.plan_ifft(dX1) + p1i * dY1 + + GC.gc(true) # run finalizers of any throwaway per-plan SYCL wrappers + + X2 = rand(ComplexF32, 8, 32) + dX2 = gpu(X2) + p2 = AbstractFFTs.plan_fft(dX2) + @test p2.queue == cached_handle + cmp(p2 * dX2, fft(X2)) +end end