From 89ad47ea455fe9f3674997c74355a6abc0202ad1 Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Fri, 24 Oct 2025 15:34:41 +0100 Subject: [PATCH 1/2] Allow another plan first in MulPlan --- Project.toml | 2 +- src/plans.jl | 46 ++++++++++++++++++++++++++++------------------ 2 files changed, 29 insertions(+), 19 deletions(-) diff --git a/Project.toml b/Project.toml index dd6da321..57bd01b0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ContinuumArrays" uuid = "7ae1f121-cc2c-504b-ac30-9b923412ae5c" -version = "0.20.0" +version = "0.20.1" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" diff --git a/src/plans.jl b/src/plans.jl index 497041c4..cc06d42e 100644 --- a/src/plans.jl +++ b/src/plans.jl @@ -31,47 +31,55 @@ end Takes a factorization and supports it applied to different dimensions. """ -struct InvPlan{T, Facts<:Tuple, Dims} <: Plan{T} +struct InvPlan{T, Facts<:Tuple, Pln, Dims} <: Plan{T} factorizations::Facts + plan::Pln dims::Dims end -InvPlan(fact::Tuple, dims) = InvPlan{eltype(fact), typeof(fact), typeof(dims)}(fact, dims) -InvPlan(fact, dims) = InvPlan((fact,), dims) +InvPlan(fact::Tuple, plan, dims) = InvPlan{eltype(fact), typeof(fact), typeof(plan), typeof(dims)}(fact, dims) +InvPlan(fact::Tuple, dims) = InvPlan(fact, nothing, dims) +InvPlan(fact, dims...) = InvPlan((fact,), dims...) size(F::InvPlan) = size.(F.factorizations, 1) """ - MulPlan(matrix, dims) + MulPlan(matrix, [plan], dims) -Takes a matrix and supports it applied to different dimensions. +Takes a matrix and supports it applied to different dimensions, after applying a plan. """ -struct MulPlan{T, Fact<:Tuple, Dims} <: Plan{T} +struct MulPlan{T, Fact<:Tuple, Pln, Dims} <: Plan{T} matrices::Fact + plan::Pln dims::Dims end -MulPlan(mats::Tuple, dims) = MulPlan{eltype(mats), typeof(mats), typeof(dims)}(mats, dims) -MulPlan(mats::AbstractMatrix, dims) = MulPlan((mats,), dims) +MulPlan(mats::Tuple, plan, dims) = MulPlan{eltype(mats), typeof(mats), typeof(plan), typeof(dims)}(mats, plan, dims) +MulPlan(mats::Tuple, dims) = MulPlan(mats, nothing, dims) +MulPlan(mats::AbstractMatrix, dims...) = MulPlan((mats,), dims...) + +_transformifnotnothing(::Nothing, x) = x +_transformifnotnothing(P, x) = P*x for (Pln,op,fld) in ((:MulPlan, :*, :(:matrices)), (:InvPlan, :\, :(:factorizations))) @eval begin - function *(P::$Pln{<:Any,<:Tuple,Int}, x::AbstractVector) + function *(P::$Pln{<:Any,<:Tuple,<:Any,Int}, x::AbstractVector) @assert P.dims == 1 - $op(only(getfield(P, $fld)), x) # Only a single factorization when dims isa Int + $op(only(getfield(P, $fld)), _transformifnotnothing(P.plan, x)) # Only a single factorization when dims isa Int end - function *(P::$Pln{<:Any,<:Tuple,Int}, X::AbstractMatrix) + function *(P::$Pln{<:Any,<:Tuple,<:Any,Int}, X::AbstractMatrix) if P.dims == 1 $op(only(getfield(P, $fld)), X) # Only a single factorization when dims isa Int else @assert P.dims == 2 - permutedims($op(only(getfield(P, $fld)), permutedims(X))) + permutedims($op(only(getfield(P, $fld)), permutedims(_transformifnotnothing(P.plan, X)))) end end - function *(P::$Pln{<:Any,<:Tuple,Int}, X::AbstractArray{<:Any,3}) + function *(P::$Pln{<:Any,<:Tuple,<:Any,Int}, Xin::AbstractArray{<:Any,3}) + X = _transformifnotnothing(P.plan, Xin) Y = similar(X) if P.dims == 1 for j in axes(X,3) @@ -90,7 +98,8 @@ for (Pln,op,fld) in ((:MulPlan, :*, :(:matrices)), (:InvPlan, :\, :(:factorizati Y end - function *(P::$Pln{<:Any,<:Tuple,Int}, X::AbstractArray{<:Any,4}) + function *(P::$Pln{<:Any,<:Tuple,<:Any,Int}, Xin::AbstractArray{<:Any,4}) + X = _transformifnotnothing(P.plan, Xin) Y = similar(X) if P.dims == 1 for j in axes(X,3), l in axes(X,4) @@ -116,7 +125,8 @@ for (Pln,op,fld) in ((:MulPlan, :*, :(:matrices)), (:InvPlan, :\, :(:factorizati *(P::$Pln{<:Any,<:Tuple,Int}, X::AbstractArray) = error("Overload") - function *(P::$Pln, X::AbstractArray) + function *(P::$Pln, Xin::AbstractArray) + X = _transformifnotnothing(P.plan, Xin) for (fac,dim) in zip(getfield(P, $fld), P.dims) X = $Pln(fac, dim) * X end @@ -125,7 +135,7 @@ for (Pln,op,fld) in ((:MulPlan, :*, :(:matrices)), (:InvPlan, :\, :(:factorizati end end -*(A::AbstractMatrix, P::MulPlan) = MulPlan(Ref(A) .* P.matrices, P.dims) +*(A::AbstractMatrix, P::MulPlan) = MulPlan(Ref(A) .* P.matrices, P.plan, P.dims) -inv(P::MulPlan) = InvPlan(map(factorize,P.matrices), P.dims) -inv(P::InvPlan) = MulPlan(convert.(Matrix,P.factorizations), P.dims) \ No newline at end of file +inv(P::MulPlan{<:Any,<:Any,Nothing}) = InvPlan(map(factorize,P.matrices), P.dims) +inv(P::InvPlan{<:Any,<:Any,Nothing}) = MulPlan(convert.(Matrix,P.factorizations), P.dims) \ No newline at end of file From f39b0670bdd71a07d6c0f8f6c3e6cee281306c54 Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Fri, 24 Oct 2025 17:28:10 +0100 Subject: [PATCH 2/2] Update plans.jl --- src/plans.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/plans.jl b/src/plans.jl index cc06d42e..882ee544 100644 --- a/src/plans.jl +++ b/src/plans.jl @@ -37,7 +37,7 @@ struct InvPlan{T, Facts<:Tuple, Pln, Dims} <: Plan{T} dims::Dims end -InvPlan(fact::Tuple, plan, dims) = InvPlan{eltype(fact), typeof(fact), typeof(plan), typeof(dims)}(fact, dims) +InvPlan(fact::Tuple, plan, dims) = InvPlan{mapreduce(eltype,promote_type,fact), typeof(fact), typeof(plan), typeof(dims)}(fact, plan, dims) InvPlan(fact::Tuple, dims) = InvPlan(fact, nothing, dims) InvPlan(fact, dims...) = InvPlan((fact,), dims...) @@ -55,7 +55,7 @@ struct MulPlan{T, Fact<:Tuple, Pln, Dims} <: Plan{T} dims::Dims end -MulPlan(mats::Tuple, plan, dims) = MulPlan{eltype(mats), typeof(mats), typeof(plan), typeof(dims)}(mats, plan, dims) +MulPlan(mats::Tuple, plan, dims) = MulPlan{mapreduce(eltype,promote_type,mats), typeof(mats), typeof(plan), typeof(dims)}(mats, plan, dims) MulPlan(mats::Tuple, dims) = MulPlan(mats, nothing, dims) MulPlan(mats::AbstractMatrix, dims...) = MulPlan((mats,), dims...) @@ -123,7 +123,7 @@ for (Pln,op,fld) in ((:MulPlan, :*, :(:matrices)), (:InvPlan, :\, :(:factorizati - *(P::$Pln{<:Any,<:Tuple,Int}, X::AbstractArray) = error("Overload") + *(P::$Pln{<:Any,<:Tuple,<:Any,Int}, X::AbstractArray) = error("Overload") function *(P::$Pln, Xin::AbstractArray) X = _transformifnotnothing(P.plan, Xin)