diff --git a/Project.toml b/Project.toml index 707ee08d..032844e3 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "GPUArrays" uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -version = "11.5.8" +version = "11.5.9" [workspace] projects = ["lib/GPUArraysCore", "lib/JLArrays", "test", "docs"] diff --git a/src/host/random.jl b/src/host/random.jl index 5beb7d31..389ac790 100644 --- a/src/host/random.jl +++ b/src/host/random.jl @@ -447,7 +447,7 @@ end end end -function Random.rand!(rng::RNG, A::AnyGPUArray{T}) where T <: Number +function Random.rand!(rng::RNG, A::AnyGPUArray{T}) where T isempty(A) && return A rand_generic_kernel!(get_backend(A))(rng.seed, rng.counter, A; ndrange=length(A)) advance_counter!(rng) @@ -586,7 +586,8 @@ function Random.rand!(rng::RNG{AT}, A::AbstractArray{T}) where {AT, T} Random.rand!(rng, B) copyto!(A, B) end -function Random.randn!(rng::RNG{AT}, A::AbstractArray{T}) where {AT, T} +function Random.randn!(rng::RNG{AT}, A::AbstractArray{T}) where {AT, T<:Union{AbstractFloat, + Complex{<:AbstractFloat}}} isempty(A) && return A B = similar(AT{T}, size(A)) Random.randn!(rng, B) diff --git a/test/testsuite/random.jl b/test/testsuite/random.jl index 2cd65a3c..bb028a02 100644 --- a/test/testsuite/random.jl +++ b/test/testsuite/random.jl @@ -1,3 +1,14 @@ +# A non-`Number` isbits eltype with a custom sampler, mirroring how ColorTypes +# hooks colorants into the Random API. Exercises the generic element-wise rand! +# path; regression cover for JuliaGPU/CUDA.jl#3179. +struct RGBTriplet + r::Float32 + g::Float32 + b::Float32 +end +Random.rand(rng::AbstractRNG, ::Random.SamplerType{RGBTriplet}) = + RGBTriplet(rand(rng, Float32), rand(rng, Float32), rand(rng, Float32)) + @testsuite "random" (AT, eltypes)->begin rng = if AT <: AbstractGPUArray GPUArrays.RNG{AT}() @@ -88,4 +99,21 @@ @test (randn!(cpu_rng, A); true) end end + + # non-`Number` eltypes with a custom sampler must generate element-wise + # rather than hit an ambiguous fallback (JuliaGPU/CUDA.jl#3179) + if AT <: AbstractGPUArray + @testset "non-Number eltype" begin + A = AT{RGBTriplet}(undef, 1024) + rand!(rng, A) + h = Array(A) + @test all(t -> 0f0 <= t.r <= 1f0 && 0f0 <= t.g <= 1f0 && 0f0 <= t.b <= 1f0, h) + @test any(t -> t.r != h[1].r, h) + end + + # randn! on a non-float eltype must error cleanly, not recurse forever + @testset "randn! rejects non-float" begin + @test_throws MethodError randn!(rng, AT{Int32}(undef, 4)) + end + end end