diff --git a/lib/cunumeric_jl_wrapper/include/types.h b/lib/cunumeric_jl_wrapper/include/types.h
index c75754db..88a18dc9 100644
--- a/lib/cunumeric_jl_wrapper/include/types.h
+++ b/lib/cunumeric_jl_wrapper/include/types.h
@@ -65,3 +65,6 @@ void wrap_unary_reds(jlcxx::Module&);
 
 // Binary op codes
 void wrap_binary_ops(jlcxx::Module&);
+
+// Linear algebra op codes
+void wrap_linalg_ops(jlcxx::Module&);
diff --git a/lib/cunumeric_jl_wrapper/src/types.cpp b/lib/cunumeric_jl_wrapper/src/types.cpp
index 2e73ebd5..f181dc2e 100644
--- a/lib/cunumeric_jl_wrapper/src/types.cpp
+++ b/lib/cunumeric_jl_wrapper/src/types.cpp
@@ -162,3 +162,10 @@ void wrap_binary_ops(jlcxx::Module& mod) {
   mod.set_const("SUBTRACT",
                 CuPyNumericBinaryOpCode::CUPYNUMERIC_BINOP_SUBTRACT);
 }
+
+// Registers the linear-algebra task IDs (SOLVE, MP_SOLVE) as constants on the
+// Julia module so that tasks can be created by ID from the Julia side.
+void wrap_linalg_ops(jlcxx::Module& mod) {
+  mod.set_const("SOLVE", legate::LocalTaskID{CuPyNumericOpCode::CUPYNUMERIC_SOLVE});
+  mod.set_const("MP_SOLVE", legate::LocalTaskID{CuPyNumericOpCode::CUPYNUMERIC_MP_SOLVE});
+}
diff --git a/lib/cunumeric_jl_wrapper/src/wrapper.cpp b/lib/cunumeric_jl_wrapper/src/wrapper.cpp
index 29334333..e4b19331 100644
--- a/lib/cunumeric_jl_wrapper/src/wrapper.cpp
+++ b/lib/cunumeric_jl_wrapper/src/wrapper.cpp
@@ -68,6 +68,7 @@ JLCXX_MODULE define_julia_module(jlcxx::Module& mod) {
   wrap_unary_ops(mod);
   wrap_binary_ops(mod);
   wrap_unary_reds(mod);
+  wrap_linalg_ops(mod);
 
   using jlcxx::ParameterList;
   using jlcxx::Parametric;
diff --git a/src/ndarray/detail/ndarray.jl b/src/ndarray/detail/ndarray.jl
index b11c4101..282dd568 100644
--- a/src/ndarray/detail/ndarray.jl
+++ b/src/ndarray/detail/ndarray.jl
@@ -494,3 +494,15 @@ function compare(arr::NDArray{T,N}, arr2::NDArray{T,N}, atol::Real, rtol::Real)
     # successful completion
     return true
 end
+
+"""
+    nda_to_logical_store(arr::NDArray{T,N}) -> Legate.LogicalStore{T,N}
+
+Wrap the store handle of `arr` in a `Legate.LogicalStore{T,N}` of size `size(arr)`.
+"""
+function nda_to_logical_store(arr::NDArray{T,N}) where {T,N}
+    dims = size(arr)
+    la_handle = cuNumeric.get_store(arr)
+    st_handle = Legate.data(Legate.LogicalArray{T,N}(la_handle[], dims))
+    return Legate.LogicalStore{T,N}(st_handle, dims)
+end
diff --git a/src/ndarray/linalg.jl b/src/ndarray/linalg.jl
new file mode 100644
index 00000000..29bab673
--- /dev/null
+++ b/src/ndarray/linalg.jl
@@ -0,0 +1,157 @@
+"""
+    choose_nd_color_shape(shape::NTuple{N,Int}) -> NTuple{N,Int}
+
+Choose how many colors (tiles) to use along each dimension of a batched-matrix
+array of shape `shape`.
+
+For `N <= 2` every entry is `1`. Otherwise all `Legate.num_procs()` colors start
+on the first dimension and are moved, a factor of two at a time, to the batch
+dimension with the largest extent per color while that extent is more than
+twice the first dimension's. The last two (matrix) dimensions always stay `1`.
+"""
+function choose_nd_color_shape(shape::NTuple{N,Int}) where {N}
+    color_shape = Base.ones(Int, N)
+    if N > 2
+        color_shape[1] = Legate.num_procs()
+        done = false
+        while !done && color_shape[1] % 2 == 0
+            # Extent per color of each batch dimension (matrix dims excluded).
+            weight_per_dim = [shape[i] / color_shape[i] for i in 1:(N - 2)]
+            _, idx = findmax(weight_per_dim)
+            if weight_per_dim[idx] > 2 * weight_per_dim[1]
+                color_shape[1] ÷= 2
+                color_shape[idx] *= 2
+            else
+                done = true
+            end
+        end
+    end
+    return Tuple(color_shape)
+end
+
+"""
+    prepare_manual_task_for_batched_matrices(full_shape::NTuple{N,Int})
+
+Return `(tilesize, color_shape)` for tiling an array of shape `full_shape`:
+`tilesize[i]` is `full_shape[i]` divided by the color count chosen by
+`choose_nd_color_shape`, rounded up, and `color_shape[i]` is the number of such
+tiles needed to cover dimension `i`. A zero extent raises `DivideError`.
+"""
+function prepare_manual_task_for_batched_matrices(full_shape::NTuple{N,Int}) where {N}
+    initial_color_shape = choose_nd_color_shape(full_shape)
+    tilesize = ntuple(i -> cld(full_shape[i], initial_color_shape[i]), Val(N))
+    color_shape = ntuple(i -> cld(full_shape[i], tilesize[i]), Val(N))
+    return tilesize, color_shape
+end
+
+"""
+    solve_batched(a::NDArray{T,N}, b::NDArray, x::NDArray)
+
+Submit the `cuNumeric.SOLVE` task as a manual Legate task with `a` and `b` as
+inputs and `x` as output. `a` is tiled as computed by
+`prepare_manual_task_for_batched_matrices(size(a))`; `b` and `x` share a tiling
+whose last extent is `size(b)[end]`, so right-hand sides are never split.
+"""
+function solve_batched(a::NDArray{T,N}, b::NDArray, x::NDArray) where {T,N}
+    nrhs = size(b)[end]
+    tilesize_a, color_shape = prepare_manual_task_for_batched_matrices(size(a))
+    # NOTE(review): this tile shape always has N entries, so a 1-D `b` (the
+    # (m,m),(m)->(m) case) is tiled with a rank-2 shape -- confirm Legate and
+    # the SOLVE task accept that, or reshape `b` to (m,1) first.
+    tilesize_b = (tilesize_a[1:(end - 1)]..., nrhs)
+
+    store_a = nda_to_logical_store(a)
+    store_b = nda_to_logical_store(b)
+    store_x = nda_to_logical_store(x)
+
+    tiled_a = Legate.partition_by_tiling(store_a, collect(tilesize_a))
+    tiled_b = Legate.partition_by_tiling(store_b, collect(tilesize_b))
+    tiled_x = Legate.partition_by_tiling(store_x, collect(tilesize_b))
+
+    rt = Legate.get_runtime()
+    domain = Legate.domain_from_shape(Legate.Shape(Legate.to_cxx_vector(color_shape)))
+    lib = cuNumeric.get_lib()
+    task = Legate.create_manual_task(rt, lib, cuNumeric.SOLVE, domain)
+
+    Legate.add_input(task, tiled_a)
+    Legate.add_input(task, tiled_b)
+    Legate.add_output(task, tiled_x)
+
+    return Legate.submit_manual_task(rt, task)
+end
+
+# Build the shape-mismatch error shared by every `solve` signature.
+function _solve_mismatch_error(dim, signature, got, expected)
+    return ArgumentError(
+        "Input operand 1 has a mismatch in its dimension $(dim), " *
+        "with signature $(signature) (size $(got) is different from $(expected))"
+    )
+end
+
+# Signature (m,m),(m)->(m).
+function _check_solve_operands(a::NDArray{T,2}, b::NDArray{S,1}) where {T,S}
+    size(a)[2] != size(b)[1] &&
+        throw(_solve_mismatch_error(0, "(m,m),(m)->(m)", size(b)[1], size(a)[2]))
+    return nothing
+end
+
+# Signature (m,m),(m,n)->(m,n).
+function _check_solve_operands(a::NDArray{T,2}, b::NDArray{S,2}) where {T,S}
+    size(a)[2] != size(b)[1] &&
+        throw(_solve_mismatch_error(0, "(m,m),(m,n)->(m,n)", size(b)[1], size(a)[2]))
+    return nothing
+end
+
+# Batched signature (...,m,m),(...,m,n)->(...,m,n).
+function _check_solve_operands(a::NDArray{T,N}, b::NDArray{S,N}) where {T,S,N}
+    signature = "(...,m,m),(...,m,n)->(...,m,n)"
+    size(a)[end] != size(b)[end - 1] &&
+        throw(_solve_mismatch_error(N - 2, signature, size(b)[end - 1], size(a)[end]))
+    return nothing
+end
+
+# Any other combination of ranks is unsupported.
+function _check_solve_operands(a::NDArray, b::NDArray)
+    throw(ArgumentError(
+        "Batched matrices require signature (...,m,m),(...,m,n)->(...,m,n)"
+    ))
+end
+
+"""
+    solve(a::NDArray, b::NDArray) -> NDArray
+
+Solve the linear system `a * x = b` and return `x`, a new array with the
+element type of `a` and the size of `b`.
+
+Supported signatures are `(m,m),(m)->(m)`, `(m,m),(m,n)->(m,n)` and the batched
+form `(...,m,m),(...,m,n)->(...,m,n)`. If `a` or `b` has no elements, zeros of
+the size of `b` are returned without launching a task.
+
+# Throws
+- `ArgumentError`: `a` has fewer than two dimensions, `b` has none, either
+  element type is `Float16`, `a` is not square in its last two dimensions, or
+  the shapes of `a` and `b` do not match one of the signatures above.
+"""
+function solve(a::NDArray{T,N}, b::NDArray{S,M}) where {T,N,S,M}
+    # The checks run in order inside this one method rather than as separate
+    # guard methods: guards keyed on rank and on element type overlap, which
+    # makes dispatch ambiguous (MethodError) for inputs that hit two of them.
+    N < 2 && throw(ArgumentError(
+        "$(N)-dimensional array given. Array must be at least two-dimensional"
+    ))
+    M < 1 && throw(ArgumentError(
+        "$(M)-dimensional array given. Array must be at least one-dimensional"
+    ))
+    (T === Float16 || S === Float16) &&
+        throw(ArgumentError("array type float16 is unsupported in linalg"))
+    size(a)[end - 1] != size(a)[end] &&
+        throw(ArgumentError("Last 2 dimensions of the array must be square"))
+    _check_solve_operands(a, b)
+
+    x = zeros(T, size(b)...)
+    # Parenthesised on purpose: `&&` binds tighter than `||`, so without the
+    # parentheses an empty `a` would not take the early return.
+    (prod(size(a)) == 0 || prod(size(b)) == 0) && return x
+    solve_batched(a, b, x)
+    return x
+end
diff --git a/test/tests/linalg.jl b/test/tests/linalg.jl
index 32a18100..e2e66d03 100644
--- a/test/tests/linalg.jl
+++ b/test/tests/linalg.jl
@@ -102,3 +102,53 @@ end
 
     @test sort(Array(out)) == sort(ref)
 end
+
+@testset "solve diagonal" begin
+    n = 4
+    A = cuNumeric.zeros(Float64, n, n)
+    b = cuNumeric.zeros(Float64, n, 1)
+    cuNumeric.@allowscalar for i in 1:n
+        A[i, i] = 4.0
+        b[i, 1] = 1.0
+    end
+    x = cuNumeric.solve(A, b)
+    allowscalar() do
+        @test cuNumeric.compare(fill(0.25, n, 1), x, atol(Float64), rtol(Float64))
+    end
+end
+
+@testset "solve identity" begin
+    n = 4
+    A = cuNumeric.NDArray(Matrix{Float64}(I, n, n))
+    b = cuNumeric.NDArray(reshape(collect(1.0:n), n, 1))
+    x = cuNumeric.solve(A, b)
+    ref = reshape(collect(1.0:n), n, 1)
+    allowscalar() do
+        @test cuNumeric.compare(ref, x, atol(Float64), rtol(Float64))
+    end
+end
+
+@testset "solve general" begin
+    A_ref = [2.0 1.0; 5.0 7.0]
+    b_ref = [11.0; 13.0;;]
+    A = cuNumeric.NDArray(A_ref)
+    b = cuNumeric.NDArray(b_ref)
+    x = cuNumeric.solve(A, b)
+    ref = A_ref \ b_ref
+    allowscalar() do
+        @test cuNumeric.compare(ref, x, atol(Float64), rtol(Float64))
+    end
+end
+
+@testset "solve argument validation" begin
+    A = cuNumeric.zeros(Float64, 2, 2)
+    v = cuNumeric.zeros(Float64, 2)
+    # `a` must have at least two dimensions (1-D with 1-D must not be ambiguous)
+    @test_throws ArgumentError cuNumeric.solve(v, v)
+    # `a` must be square in its last two dimensions
+    @test_throws ArgumentError cuNumeric.solve(cuNumeric.zeros(Float64, 2, 3), v)
+    # rows of `b` must match `a`
+    @test_throws ArgumentError cuNumeric.solve(A, cuNumeric.zeros(Float64, 3, 1))
+    # a batched `a` needs a `b` of the same rank
+    @test_throws ArgumentError cuNumeric.solve(cuNumeric.zeros(Float64, 2, 2, 2), A)
+end