diff --git a/src/host/linalg.jl b/src/host/linalg.jl index 410b88a9..536f345b 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -459,9 +459,10 @@ function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::Abstrac @inbounds if i <= size(A,1) && j <= size(B,2) z2 = zero(A[i, 1]*B[1, j] + A[i, 1]*B[1, j]) - Cij = convert(promote_type(R, typeof(z2)), z2) + Tacc = promote_type(R, typeof(z2)) + Cij = convert(Tacc, z2) for k in 1:size(A,2) - Cij += A[i, k]*B[k, j] + Cij += convert(Tacc, A[i, k]) * convert(Tacc, B[k, j]) end C[i,j] = add(Cij, C[i,j]) end @@ -511,10 +512,11 @@ function _triangular_matmatmul!(C::AbstractGPUVecOrMat{R}, A::AbstractTriangular @inbounds if i <= l && j <= n z2 = zero(A[i, 1]*B[1, j] + A[i, 1]*B[1, j]) - Cij = convert(promote_type(R, typeof(z2)), z2) - Cij += A[i,i] * B[i,j] + Tacc = promote_type(R, typeof(z2)) + Cij = convert(Tacc, z2) + Cij += convert(Tacc, A[i,i]) * convert(Tacc, B[i,j]) for k in (upperA ? (i + 1) : 1):(upperA ? m : (i - 1)) - Cij += A[i,k] * B[k,j] + Cij += convert(Tacc, A[i,k]) * convert(Tacc, B[k,j]) end # treat C as write-only when beta is zero (it may hold NaN/Inf) C[i,j] = iszero(beta) ? alpha * Cij : alpha * Cij + beta * C[i,j] @@ -596,10 +598,11 @@ function generic_trimatmul!(C::AbstractGPUVecOrMat{R}, uploc, isunitc, tfun::Fun @inbounds if i <= l && j <= n z2 = zero(A[i,1] * B[1,j] + A[i,1] * B[1,j]) - Cij = convert(promote_type(R, typeof(z2)), z2) - Cij += (unit ? one(Cij) : A[i,i]) * B[i,j] + Tacc = promote_type(R, typeof(z2)) + Cij = convert(Tacc, z2) + Cij += (unit ? one(Cij) : convert(Tacc, A[i,i])) * convert(Tacc, B[i,j]) for k in (upper ? (i + 1) : 1):(upper ? 
m : (i - 1)) - Cij += A[i,k] * B[k,j] + Cij += convert(Tacc, A[i,k]) * convert(Tacc, B[k,j]) end C[i,j] += Cij end @@ -613,10 +616,11 @@ function generic_trimatmul!(C::AbstractGPUVecOrMat{R}, uploc, isunitc, tfun::Fun @inbounds if i <= l && j <= n z2 = zero(A[i,1] * B[1,j] + A[i,1] * B[1,j]) - Cij = convert(promote_type(R, typeof(z2)), z2) - Cij += (unit ? one(Cij) : transpose(A[i,i])) * B[i,j] + Tacc = promote_type(R, typeof(z2)) + Cij = convert(Tacc, z2) + Cij += (unit ? one(Cij) : transpose(convert(Tacc, A[i,i]))) * convert(Tacc, B[i,j]) for k in (upper ? (i + 1) : 1):(upper ? m : (i - 1)) - Cij += transpose(A[k,i]) * B[k,j] + Cij += transpose(convert(Tacc, A[k,i])) * convert(Tacc, B[k,j]) end C[i,j] += Cij end @@ -630,10 +634,11 @@ function generic_trimatmul!(C::AbstractGPUVecOrMat{R}, uploc, isunitc, tfun::Fun @inbounds if i <= l && j <= n z2 = zero(A[i,1] * B[1,j] + A[i,1] * B[1,j]) - Cij = convert(promote_type(R, typeof(z2)), z2) - Cij += (unit ? one(Cij) : adjoint(A[i,i])) * B[i,j] + Tacc = promote_type(R, typeof(z2)) + Cij = convert(Tacc, z2) + Cij += (unit ? one(Cij) : adjoint(convert(Tacc, A[i,i]))) * convert(Tacc, B[i,j]) for k in (upper ? (i + 1) : 1):(upper ? m : (i - 1)) - Cij += adjoint(A[k,i]) * B[k,j] + Cij += adjoint(convert(Tacc, A[k,i])) * convert(Tacc, B[k,j]) end C[i,j] += Cij end @@ -674,10 +679,11 @@ function generic_mattrimul!(C::AbstractGPUVecOrMat{R}, uploc, isunitc, tfun::Fun @inbounds if i <= l && j <= n z2 = zero(A[i,1] * B[1,j] + A[i,1] * B[1,j]) - Cij = convert(promote_type(R, typeof(z2)), z2) - Cij += A[i,j] * (unit ? one(Cij) : B[j,j]) + Tacc = promote_type(R, typeof(z2)) + Cij = convert(Tacc, z2) + Cij += convert(Tacc, A[i,j]) * (unit ? one(Cij) : convert(Tacc, B[j,j])) for k in (upper ? 1 : (j + 1)):(upper ? 
(j - 1) : m) - Cij += A[i,k] * B[k,j] + Cij += convert(Tacc, A[i,k]) * convert(Tacc, B[k,j]) end C[i,j] += Cij end @@ -691,10 +697,11 @@ function generic_mattrimul!(C::AbstractGPUVecOrMat{R}, uploc, isunitc, tfun::Fun @inbounds if i <= l && j <= n z2 = zero(A[i,1] * B[1,j] + A[i,1] * B[1,j]) - Cij = convert(promote_type(R, typeof(z2)), z2) - Cij += A[i,j] * (unit ? one(Cij) : transpose(B[j,j])) + Tacc = promote_type(R, typeof(z2)) + Cij = convert(Tacc, z2) + Cij += convert(Tacc, A[i,j]) * (unit ? one(Cij) : transpose(convert(Tacc, B[j,j]))) for k in (upper ? 1 : (j + 1) ):(upper ? (j - 1) : m) - Cij += A[i,k] * transpose(B[j,k]) + Cij += convert(Tacc, A[i,k]) * transpose(convert(Tacc, B[j,k])) end C[i,j] += Cij end @@ -708,10 +715,11 @@ function generic_mattrimul!(C::AbstractGPUVecOrMat{R}, uploc, isunitc, tfun::Fun @inbounds if i <= l && j <= n z2 = zero(A[i,1] * B[1,j] + A[i,1] * B[1,j]) - Cij = convert(promote_type(R, typeof(z2)), z2) - Cij += A[i,j] * (unit ? one(Cij) : adjoint(B[j,j])) + Tacc = promote_type(R, typeof(z2)) + Cij = convert(Tacc, z2) + Cij += convert(Tacc, A[i,j]) * (unit ? one(Cij) : adjoint(convert(Tacc, B[j,j]))) for k in (upper ? 1 : (j + 1)):(upper ? (j - 1) : m) - Cij += A[i,k] * adjoint(B[j,k]) + Cij += convert(Tacc, A[i,k]) * adjoint(convert(Tacc, B[j,k])) end C[i,j] += Cij end diff --git a/test/testsuite/linalg.jl b/test/testsuite/linalg.jl index d8e587e0..a7ba91e6 100644 --- a/test/testsuite/linalg.jl +++ b/test/testsuite/linalg.jl @@ -193,7 +193,7 @@ A = AT(rand(T, n, n)) B = AT(rand(T, n, n)) Ct = AT(rand(T, n, n)) - C = collect(Ct) + C = collect(Ct) mul!(Ct, TR1(A), TR2(B), 1, -1) mul!(C, TR1(collect(A)), TR2(collect(B)), 1, -1) @test collect(Ct) ≈ C @@ -576,6 +576,31 @@ end end end +@testsuite "linalg/mul!/integer-accumulate" (AT, eltypes)->begin + # Regression test: mul! into a wider integer output must form each product in the + # accumulator type, not the narrow input type, else the narrow-integer products overflow. 
+ gpu = AT <: AbstractGPUArray + @testset "$Tin -> $Tout" for (Tin, Tout) in ((Int16, Int32), (Int16, Int64), (Int32, Int64)) + Tin in eltypes || continue + n = 16 + hi = Tin(4) * isqrt(typemax(Tin)) # hi^2 is about 16*typemax(Tin), so products a*b routinely overflow Tin + rng = (-hi):hi # range of values to sample from, not a random-number generator + A, B, x = rand(rng, n, n), rand(rng, n, n), rand(rng, n) + dA, dB, dx = AT(A), AT(B), AT(x) + + if gpu || VERSION >= v"1.11" # NOTE(review): non-GPU array types are only checked on Julia >= 1.11; presumably older LinearAlgebra lacks the fix (confirm) + C = AT(zeros(Tout, n, n)) + mul!(C, dA, dB) + @test Array(C) == Tout.(Int64.(A) * Int64.(B)) + end + if gpu # matrix-vector product is checked on GPU arrays only; see JuliaLang/LinearAlgebra.jl#1659 + c = AT(zeros(Tout, n)) + mul!(c, dA, dx) + @test Array(c) == Tout.(Int64.(A) * Int64.(x)) + end + end +end + @testsuite "linalg/norm" (AT, eltypes)->begin @testset "$p-norm($sz x $T)" for sz in [(2,), (2,0), (2,2,2)], p in Any[0, 0.5, 1, 1.5, 2, Inf, -Inf],