From c6f81498f9493db4faac2c7f9fb476aa3cbd09cc Mon Sep 17 00:00:00 2001 From: youdongguo <1010705897@qq.com> Date: Mon, 4 Aug 2025 22:17:23 -0500 Subject: [PATCH 1/8] add queuepenalty --- src/NMFMerge.jl | 56 +++++++++++++++++++++++++----------------------- test/runtests.jl | 29 ++++++++++++++----------- 2 files changed, 45 insertions(+), 40 deletions(-) diff --git a/src/NMFMerge.jl b/src/NMFMerge.jl index 5ad9bee..df252f9 100644 --- a/src/NMFMerge.jl +++ b/src/NMFMerge.jl @@ -14,6 +14,8 @@ Performs "NMF-Merge" on data matrix `X`. Arguments: +-`queuepenalty`: a function of the form `f(λ_min, t1sq, t2sq)` that computes the penalty for merging two components, where `λ_min` is the smaller eigenvalue of the generalized eigenvalue problem. + - `X::AbstractMatrix`: the data matrix to be factorized - `ncomponents::Pair{Int,Int}`: in the form of `n1 => n2`, merging from `n1` components to `n2`components, @@ -35,7 +37,7 @@ Keyword arguments: Other keywords arguments are passed to `NMF.nnmf`. """ -function nmfmerge(X, ncomponents::Pair{Int,Int}; tol_final=1e-4, tol_intermediate=sqrt(tol_final), W0=nothing, H0=nothing, kwargs...) +function nmfmerge(queuepenalty, X, ncomponents::Pair{Int,Int}; tol_final=1e-4, tol_intermediate=sqrt(tol_final), W0=nothing, H0=nothing, kwargs...) n1, n2 = ncomponents f = tsvd(X, n2) Un, Sn, Vn = f @@ -50,11 +52,13 @@ function nmfmerge(X, ncomponents::Pair{Int,Int}; tol_final=1e-4, tol_intermediat result_over = nnmf(X, n1; kwargs..., init=:custom, tol=tol_intermediate, W0=W_over_init, H0=H_over_init) W_over, H_over = result_over.W, result_over.H W_over_normed, H_over_normed = colnormalize(W_over, H_over) - Wmerge, Hmerge, _ = colmerge2to1pq(W_over_normed, H_over_normed, n2) + Wmerge, Hmerge, _ = colmerge2to1pq(queuepenalty, W_over_normed, H_over_normed, n2) result_renmf = nnmf(X, n2; kwargs..., init=:custom, tol=tol_final, W0=Wmerge, H0=Hmerge) return result_renmf end -nmfmerge(X, ncomponents::Integer; kwargs...) = nmfmerge(X, ncomponents+max(1, round(Int, 0.2*ncomponents)) => Int(ncomponents); kwargs...) +nmfmerge(queuepenalty, X, ncomponents::Integer; kwargs...) = nmfmerge(queuepenalty, X, ncomponents+max(1, round(Int, 0.2*ncomponents)) => Int(ncomponents); kwargs...) +nmfmerge(X, ncomponents::Pair{Int,Int}; kwargs...) = nmfmerge(mergepenalty, X, ncomponents; kwargs...) +nmfmerge(X, ncomponents::Integer; kwargs...) = nmfmerge(mergepenalty, X, ncomponents::Integer; kwargs...) function colnormalize!(W, H, p::Integer=2) nonzerocolids = Int[] @@ -89,7 +93,7 @@ components remain. `mergeseq` is the sequence of merge pair ids (id1, id2). Values larger than the number of columns in `W` indicate the output of previous merge steps. """ -function colmerge2to1pq(S::AbstractArray, T::AbstractArray, n::Integer) +function colmerge2to1pq(queuepenalty, S::AbstractArray, T::AbstractArray, n::Integer) mrgseq = Tuple{Int, Int}[] S = let S = S # julia #15276 [S[:, j] for j in axes(S, 2)] @@ -103,7 +107,10 @@ function colmerge2to1pq(S::AbstractArray, T::AbstractArray, n::Integer) Nt = length(S) Nt >= 2 || throw(ArgumentError("Cannot do 2 to 1 merge: Matrix size smaller than 2")) Nt >= n || throw(ArgumentError("Final solution more than original size")) - pq = initialize_pq_2to1(S, T) + pq = PriorityQueue{Tuple{Int,Int},Float64}() + for id0 in length(S):-1:2 + pq = pqupdate2to1!(pq, queuepenalty, S, T, id0, 1:id0-1) + end m = Nt while m > n id0, id1 = dequeue!(pq) @@ -112,33 +119,28 @@ function colmerge2to1pq(S::AbstractArray, T::AbstractArray, n::Integer) end push!(mrgseq, (id0, id1)) S, T, id01, _ = mergecol2to1!(S, T, id0, id1); - pqupdate2to1!(pq, S, T, id01, 1:id01-1); + pqupdate2to1!(pq, queuepenalty, S, T, id01, 1:id01-1); m -= 1 end Smtx, Tmtx = reduce(hcat, filter(!isempty, S)), reduce(hcat, filter(!isempty, T))' return Smtx, Matrix(Tmtx), mrgseq end +colmerge2to1pq(S::AbstractArray, T::AbstractArray, n::Integer) = colmerge2to1pq(mergepenalty, S, T, n) -function initialize_pq_2to1(S::AbstractVector, T::AbstractVector) - err_pq = PriorityQueue{Tuple{Int, Int},Float64}() - for id0 in length(S):-1:2 - err_pq = pqupdate2to1!(err_pq, S, T, id0, 1:id0-1) - end - return err_pq -end - -function pqupdate2to1!(pq, S::AbstractVector, T::AbstractVector, id01::Integer, overlapids::AbstractRange{To}) where To +function pqupdate2to1!(pq, queuepenalty::Function, S::AbstractVector, T::AbstractVector, id01::Integer, overlapids::AbstractRange{To}) where To for id in overlapids if !isempty(S[id]) && !isempty(S[id01]) - loss = solve_remix(S, T, id, id01)[2] - enqueue!(pq, (id, id01), loss) + t1sq, t1t2, t2sq, c = build_tr_det(S, T, id, id01) + loss = solve_remix(t1sq, t1t2, t2sq, c)[2] + enqueue!(pq, (id, id01), queuepenalty(loss, t1sq, t2sq)) end end return pq end -function solve_remix(S, T, id1, id2) - τ, δ, c, h1h1, h1h2, h2h2 = build_tr_det(S, T, id1, id2) +function solve_remix(h1h1::AbstractFloat, h1h2::AbstractFloat, h2h2::AbstractFloat, c::AbstractFloat) + τ = h1h1+2c*h1h2+h2h2 + δ = (1-c^2)*(h1h1*h2h2-h1h2^2) if h1h1 == 0 return c, zero(c), (zero(c), one(c)) end @@ -159,13 +161,9 @@ function solve_remix(S, T, id1, id2) end function build_tr_det(W::AbstractVector, H::AbstractVector, id1::Integer, id2::Integer) - c = W[id1]'*W[id2] - h1h1 = H[id1]'*H[id1] - h1h2 = H[id1]'*H[id2] - h2h2 = H[id2]'*H[id2] - τ = h1h1+2c*h1h2+h2h2 - δ = (1-c^2)*(h1h1*h2h2-h1h2^2) - return τ, δ, c, h1h1, h1h2, h2h2 + h1sq, h1h2, h2sq = H[id1]'*H[id1], H[id1]'*H[id2], H[id2]'*H[id2] + c = W[id1]'*W[id2] # assumes normalization + return h1sq, h1h2, h2sq, c end function mergecol2to1!(S::AbstractVector, T::AbstractVector, id0::Integer, id1::Integer) @@ -178,7 +176,8 @@ function mergecol2to1!(S::AbstractVector, T::AbstractVector, id0::Integer, id1:: end function mergepair(S::AbstractVector, T::AbstractVector, id1::Integer, id2::Integer) - c, loss, u, = solve_remix(S, T, id1, id2) + t1sq, t1t2, t2sq, c = build_tr_det(S, T, id1, id2) + c, loss, u = solve_remix(t1sq, t1t2, t2sq, c) S12, T12 = remix_enact(S, T, id1, id2, c, u) return S12, T12, loss end @@ -221,4 +220,7 @@ function mergecolumns(W::AbstractArray, H::AbstractArray, mergeseq::AbstractArra return Smtx, Matrix(Tmtx), STstage, Err end +mergepenalty(λ_min, t1sq, t2sq) = λ_min +shotpenalty(λ_min, t1sq, t2sq) = λ_min / sqrt(min(t1sq, t2sq)) + end diff --git a/test/runtests.jl b/test/runtests.jl index 5f6f88f..598d887 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -32,7 +32,7 @@ H_GT = [6 10 8 2 0 1 2 10; 4 9 10 7 7 0 0 0 ] -@testset "test top wrapper" begin +@testset "test top wrapper" begin W = W_GT[:, 3:4] H = H_GT[3:4, :] X = W*H @@ -51,7 +51,6 @@ H_GT = [6 10 8 2 0 1 2 10; @test sum(abs2, W_standard - W_renmf) <= 1e-12 @test sum(abs2, H_standard - H_renmf) <= 1e-12 - X = rand(30, 20) result_1 = nmfmerge(X, 10; alg=:cd) result_2 = nmfmerge(X, 12 => 10; alg=:cd) @@ -67,7 +66,7 @@ H_GT = [6 10 8 2 0 1 2 10; result_2 = nmfmerge(X, 10 => 8; alg=:cd) @test sum(abs2, result_1.W - result_2.W) <= 1e-12 @test sum(abs2, result_1.H - result_2.H) <= 1e-12 - + end @testset "merge coefficients" begin @@ -86,8 +85,10 @@ end idx = argmax(Fvals) w = Fvecs[:,idx] - τ, δ, c, h1h1, h1h2, h2h2 = NMFMerge.build_tr_det(W_v, H_v, 1, 2) - c, p, u = NMFMerge.solve_remix(W_v, H_v, 1, 2) + h1h1, h1h2, h2h2, c = NMFMerge.build_tr_det(W_v, H_v, 1, 2) + τ = h1h1+2c*h1h2+h2h2 + δ = (1-c^2)*(h1h1*h2h2-h1h2^2) + c, p, u = NMFMerge.solve_remix(h1h1, h1h2, h2h2, c) u = [u[1], u[2]] b = sqrt(τ^2/4-δ) λ_max = τ/2+b @@ -101,8 +102,8 @@ end @test norm(u[1].*W_v[1].+u[2].*W_v[2]) ≈ 1 @test norm(Q1*u - maximum(F.values)*Q2*u) <= 1e-10 @test norm(Q1*u - λ_max*Q2*u) <= 1e-10 - - W12, H12, loss = NMFMerge.mergepair(W_v, H_v, 1, 2) + + W12, H12, _ = NMFMerge.mergepair(W_v, H_v, 1, 2) Err(Hm) = sum(abs2, W12 * Hm' - W * H) @test norm(ForwardDiff.gradient(Err, H12)) <= 1e-10 end @@ -121,7 +122,7 @@ end imgnf = NMF.solve!(NMF.CoordinateDescent{Float64}(), img, W0, H0) W1, H1 = imgnf.W, imgnf.H W1n, H1n = colnormalize(W1, H1) - [@test abs(norm(W1n[:,j], 2)-1) <= 1e-12 for j in axes(W1n, 2)] + [@test abs(norm(W1n[:,j], 2)-1) <= 1e-12 for j in axes(W1n, 2)] W2 = [W1n[:, j] for j in axes(W1n, 2)]; H2 = [H1n[i, :] for i in axes(H1n, 1)]; @@ -134,13 +135,15 @@ end idx = argmax(Fvals) w = Fvecs[:,idx] - τ, δ, c, h1h1, h1h2, h2h2 = NMFMerge.build_tr_det(W2, H2, 1, 2) - c, p, u = NMFMerge.solve_remix(W2, H2, 1, 2) + h1h1, h1h2, h2h2, c = NMFMerge.build_tr_det(W2, H2, 1, 2) + τ = h1h1+2c*h1h2+h2h2 + δ = (1-c^2)*(h1h1*h2h2-h1h2^2) + c, p, u = NMFMerge.solve_remix(h1h1, h1h2, h2h2, c) u = [u[1], u[2]] b = sqrt(τ^2/4-δ) λ_max = τ/2+b λ_min = δ/λ_max - + @test abs(λ_max - maximum(F.values))<=1e-12 @test abs(λ_min - minimum(F.values))<=1e-10 @@ -150,7 +153,7 @@ end @test norm(Q1*u - maximum(F.values)*Q2*u) <= 1e-10 @test norm(Q1*u - λ_max*Q2*u) <= 1e-10 - W12, H12, loss = NMFMerge.mergepair(W2, H2, 1, 2) + W12, H12, _ = NMFMerge.mergepair(W2, H2, 1, 2) Err(Hm) = sum(abs2, W12*Hm'-W1*H1) @test norm(ForwardDiff.gradient(Err, H12)) <= 1e-12 @@ -169,7 +172,7 @@ end N1b = randn(length(S1)); N1b = N1b / norm(N1b) * coef T2 = zero(T1) T2[15] = 0.25 * sqrt(min(sum(abs2, N1a) * sum(abs2, T1a), sum(abs2, N1b) * sum(abs2, T1b))) - + W, H = [S1 S1 S2], [T1a'; T1b'; T2'] W0, H0 = [S1 S2], [T1'; T2'] Wn, Hn = colnormalize(W, H) From b895242204965a93e12fbf8e0edafe732e83afd5 Mon Sep 17 00:00:00 2001 From: youdongguo <1010705897@qq.com> Date: Tue, 5 Aug 2025 12:37:59 -0500 Subject: [PATCH 2/8] fix some api --- src/NMFMerge.jl | 10 +++++----- test/runtests.jl | 13 ++++++------- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/src/NMFMerge.jl b/src/NMFMerge.jl index df252f9..7cf7c2d 100644 --- a/src/NMFMerge.jl +++ b/src/NMFMerge.jl @@ -8,13 +8,13 @@ export nmfmerge, mergecolumns """ - result = nmfmerge(X, ncomponents; tol_final=1e-4, tol_intermediate=sqrt(tol_final), W0=nothing, H0=nothing, kwargs...) + result = nmfmerge(queuepenalty, X, ncomponents; tol_final=1e-4, tol_intermediate=sqrt(tol_final), W0=nothing, H0=nothing, kwargs...) Performs "NMF-Merge" on data matrix `X`. Arguments: --`queuepenalty`: a function of the form `f(λ_min, t1sq, t2sq)` that computes the penalty for merging two components, where `λ_min` is the smaller eigenvalue of the generalized eigenvalue problem. +-`queuepenalty`: a function of the form `f(E, t1sq, t2sq)` that computes the penalty for merging two components, where `E` is the the merge error described in the paper, default: f(E, t1sq, t2sq)=E. - `X::AbstractMatrix`: the data matrix to be factorized @@ -109,7 +109,7 @@ function colmerge2to1pq(queuepenalty, S::AbstractArray, T::AbstractArray, n::Int Nt >= n || throw(ArgumentError("Final solution more than original size")) pq = PriorityQueue{Tuple{Int,Int},Float64}() for id0 in length(S):-1:2 - pq = pqupdate2to1!(pq, queuepenalty, S, T, id0, 1:id0-1) + pq = pqupdate2to1!(queuepenalty, pq, S, T, id0, 1:id0-1) end m = Nt while m > n @@ -119,7 +119,7 @@ function colmerge2to1pq(queuepenalty, S::AbstractArray, T::AbstractArray, n::Int end push!(mrgseq, (id0, id1)) S, T, id01, _ = mergecol2to1!(S, T, id0, id1); - pqupdate2to1!(pq, queuepenalty, S, T, id01, 1:id01-1); + pqupdate2to1!(queuepenalty, pq, S, T, id01, 1:id01-1); m -= 1 end Smtx, Tmtx = reduce(hcat, filter(!isempty, S)), reduce(hcat, filter(!isempty, T))' @@ -127,7 +127,7 @@ function colmerge2to1pq(queuepenalty, S::AbstractArray, T::AbstractArray, n::Int end colmerge2to1pq(S::AbstractArray, T::AbstractArray, n::Integer) = colmerge2to1pq(mergepenalty, S, T, n) -function pqupdate2to1!(pq, queuepenalty::Function, S::AbstractVector, T::AbstractVector, id01::Integer, overlapids::AbstractRange{To}) where To +function pqupdate2to1!(queuepenalty::Function, pq, S::AbstractVector, T::AbstractVector, id01::Integer, overlapids::AbstractRange{To}) where To for id in overlapids if !isempty(S[id]) && !isempty(S[id01]) t1sq, t1t2, t2sq, c = build_tr_det(S, T, id, id01) diff --git a/test/runtests.jl b/test/runtests.jl index 598d887..d40d4c2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -54,19 +54,18 @@ H_GT = [6 10 8 2 0 1 2 10; X = rand(30, 20) result_1 = nmfmerge(X, 10; alg=:cd) result_2 = nmfmerge(X, 12 => 10; alg=:cd) - @test sum(abs2, result_1.W - result_2.W) <= 1e-12 - @test sum(abs2, result_1.H - result_2.H) <= 1e-12 + @test sum(abs2, result_1.W - result_2.W) <= 1e-12*sum(abs2, result_1.W) + @test sum(abs2, result_1.H - result_2.H) <= 1e-12*sum(abs2, result_1.H) result_1 = nmfmerge(X, 4; alg=:cd) result_2 = nmfmerge(X, 5 => 4; alg=:cd) - @test sum(abs2, result_1.W - result_2.W) <= 1e-12 - @test sum(abs2, result_1.H - result_2.H) <= 1e-12 + @test sum(abs2, result_1.W - result_2.W) <= 1e-12*sum(abs2, result_1.W) + @test sum(abs2, result_1.H - result_2.H) <= 1e-12*sum(abs2, result_1.H) result_1 = nmfmerge(X, 8; alg=:cd) result_2 = nmfmerge(X, 10 => 8; alg=:cd) - @test sum(abs2, result_1.W - result_2.W) <= 1e-12 - @test sum(abs2, result_1.H - result_2.H) <= 1e-12 - + @test sum(abs2, result_1.W - result_2.W) <= 1e-12*sum(abs2, result_1.W) + @test sum(abs2, result_1.H - result_2.H) <= 1e-12*sum(abs2, result_1.H) end @testset "merge coefficients" begin From 04bfa3179b63b5032766d8ab705abccb49266b32 Mon Sep 17 00:00:00 2001 From: youdongguo <1010705897@qq.com> Date: Tue, 5 Aug 2025 12:45:44 -0500 Subject: [PATCH 3/8] delete shotpenalty --- src/NMFMerge.jl | 451 ++++++++++++++++++++++++------------------------ 1 file changed, 225 insertions(+), 226 deletions(-) diff --git a/src/NMFMerge.jl b/src/NMFMerge.jl index 7cf7c2d..9b2a4eb 100644 --- a/src/NMFMerge.jl +++ b/src/NMFMerge.jl @@ -1,226 +1,225 @@ -module NMFMerge - -using LinearAlgebra, DataStructures, NMF, GsvdInitialization, TSVD - -export nmfmerge, - colnormalize, - colmerge2to1pq, - mergecolumns - -""" - result = nmfmerge(queuepenalty, X, ncomponents; tol_final=1e-4, tol_intermediate=sqrt(tol_final), W0=nothing, H0=nothing, kwargs...) - -Performs "NMF-Merge" on data matrix `X`. - -Arguments: - --`queuepenalty`: a function of the form `f(E, t1sq, t2sq)` that computes the penalty for merging two components, where `E` is the the merge error described in the paper, default: f(E, t1sq, t2sq)=E. - -- `X::AbstractMatrix`: the data matrix to be factorized - -- `ncomponents::Pair{Int,Int}`: in the form of `n1 => n2`, merging from `n1` components to `n2`components, - where `n1` is the number of components for overcomplete NMF, and `n2` is the number of components for the final NMF. - We require `n1 >= n2`. - -Alternatively, `ncomponents` can be an integer denoting the final number of components. In this case, `nmfmerge` -defaults to an approximate 20% component excess before merging. - - -Keyword arguments: - -- `tol_final`: The tolerance of final NMF - -- `tol_intermediate`: The tolerence of initial and overcomplete NMF - -`W0`, `H0`: initialization for the initial NMF. If at least one of `W0` and `H0` is `nothing`, NNDSVD is used for initialization. - - -Other keywords arguments are passed to `NMF.nnmf`. -""" -function nmfmerge(queuepenalty, X, ncomponents::Pair{Int,Int}; tol_final=1e-4, tol_intermediate=sqrt(tol_final), W0=nothing, H0=nothing, kwargs...) - n1, n2 = ncomponents - f = tsvd(X, n2) - Un, Sn, Vn = f - if W0 === nothing || H0 === nothing - W0, H0 = NMF.nndsvd(X, n2, initdata=(U = Un, S = Sn, V = Vn)) - end - result_initial = nnmf(X, n2; kwargs..., init=:custom, tol=tol_intermediate, W0=copy(W0), H0=copy(H0)) - W_initial, H_initial = result_initial.W, result_initial.H - kadd = n1 - n2 - kadd >= 0 || throw(ArgumentError("Cannot merge to more components than original")) - W_over_init, H_over_init = gsvdrecover(X, W_initial, H_initial, kadd, f) - result_over = nnmf(X, n1; kwargs..., init=:custom, tol=tol_intermediate, W0=W_over_init, H0=H_over_init) - W_over, H_over = result_over.W, result_over.H - W_over_normed, H_over_normed = colnormalize(W_over, H_over) - Wmerge, Hmerge, _ = colmerge2to1pq(queuepenalty, W_over_normed, H_over_normed, n2) - result_renmf = nnmf(X, n2; kwargs..., init=:custom, tol=tol_final, W0=Wmerge, H0=Hmerge) - return result_renmf -end -nmfmerge(queuepenalty, X, ncomponents::Integer; kwargs...) = nmfmerge(queuepenalty, X, ncomponents+max(1, round(Int, 0.2*ncomponents)) => Int(ncomponents); kwargs...) -nmfmerge(X, ncomponents::Pair{Int,Int}; kwargs...) = nmfmerge(mergepenalty, X, ncomponents; kwargs...) -nmfmerge(X, ncomponents::Integer; kwargs...) = nmfmerge(mergepenalty, X, ncomponents::Integer; kwargs...) - -function colnormalize!(W, H, p::Integer=2) - nonzerocolids = Int[] - for (j, w) in pairs(eachcol(W)) - normw = norm(w, p) - if !iszero(normw) - W[:, j] = w/normw - H[j, :] = H[j, :]*normw - push!(nonzerocolids, j) - end - end - W, H = W[:, nonzerocolids], H[nonzerocolids, :] - return W, H -end - -""" - Wnormalized, Hnormalized = colnormalize(W, H, p=2) - -Normalize the factorization so that each column satisfies `||W[:, i]||_p ≈ 1`. - -""" -colnormalize(W, H, p::Integer=2) = colnormalize!(float(copy(W)), float(copy(H)), p) - -""" - Wmerge, Hmerge, mergeseq = colmerge2to1pq(W::AbstractArray, H::AbstractArray, n::Integer) - -Merge components in `W` and `H` (columns in `W` and rows in `H`) until only `n` -components remain. - -`Wmerge` and `Hmerge` are the merged results with `n` components. - -`mergeseq` is the sequence of merge pair ids (id1, id2). Values larger than the -number of columns in `W` indicate the output of previous merge steps. -""" -function colmerge2to1pq(queuepenalty, S::AbstractArray, T::AbstractArray, n::Integer) - mrgseq = Tuple{Int, Int}[] - S = let S = S # julia #15276 - [S[:, j] for j in axes(S, 2)] - end - T = let T = T - [T[i, :] for i in axes(T, 1)] - end - for s in S - abs(norm(s)-1)<1e-12 || throw(ArgumentError("W columns must be normalized")) - end - Nt = length(S) - Nt >= 2 || throw(ArgumentError("Cannot do 2 to 1 merge: Matrix size smaller than 2")) - Nt >= n || throw(ArgumentError("Final solution more than original size")) - pq = PriorityQueue{Tuple{Int,Int},Float64}() - for id0 in length(S):-1:2 - pq = pqupdate2to1!(queuepenalty, pq, S, T, id0, 1:id0-1) - end - m = Nt - while m > n - id0, id1 = dequeue!(pq) - if isempty(S[id0])||isempty(S[id1]) - continue - end - push!(mrgseq, (id0, id1)) - S, T, id01, _ = mergecol2to1!(S, T, id0, id1); - pqupdate2to1!(queuepenalty, pq, S, T, id01, 1:id01-1); - m -= 1 - end - Smtx, Tmtx = reduce(hcat, filter(!isempty, S)), reduce(hcat, filter(!isempty, T))' - return Smtx, Matrix(Tmtx), mrgseq -end -colmerge2to1pq(S::AbstractArray, T::AbstractArray, n::Integer) = colmerge2to1pq(mergepenalty, S, T, n) - -function pqupdate2to1!(queuepenalty::Function, pq, S::AbstractVector, T::AbstractVector, id01::Integer, overlapids::AbstractRange{To}) where To - for id in overlapids - if !isempty(S[id]) && !isempty(S[id01]) - t1sq, t1t2, t2sq, c = build_tr_det(S, T, id, id01) - loss = solve_remix(t1sq, t1t2, t2sq, c)[2] - enqueue!(pq, (id, id01), queuepenalty(loss, t1sq, t2sq)) - end - end - return pq -end - -function solve_remix(h1h1::AbstractFloat, h1h2::AbstractFloat, h2h2::AbstractFloat, c::AbstractFloat) - τ = h1h1+2c*h1h2+h2h2 - δ = (1-c^2)*(h1h1*h2h2-h1h2^2) - if h1h1 == 0 - return c, zero(c), (zero(c), one(c)) - end - if h2h2 == 0 - return c, zero(c), (one(c), zero(c)) - end - b = sqrt(τ^2/4-δ) - λ_max = τ/2+b - λ_min = δ/λ_max - den = (h1h2+c*h2h2)*2 - if iszero(den) - u = h1h1 >= h2h2 ? (one(c), zero(c)) : (zero(c), one(c)) - else - ξ = (h1h1-h2h2+2b)/den - u = (ξ, 1)./sqrt(1+2ξ*c+ξ^2) - end - return c, λ_min, u -end - -function build_tr_det(W::AbstractVector, H::AbstractVector, id1::Integer, id2::Integer) - h1sq, h1h2, h2sq = H[id1]'*H[id1], H[id1]'*H[id2], H[id2]'*H[id2] - c = W[id1]'*W[id2] # assumes normalization - return h1sq, h1h2, h2sq, c -end - -function mergecol2to1!(S::AbstractVector, T::AbstractVector, id0::Integer, id1::Integer) - S01, T01, loss = mergepair(S, T, id0, id1) - S[id0] = S[id1] = T[id0] = T[id1] = eltype(S[1])[] - id01 = length(S)+1 - push!(S, S01) - push!(T, T01) - return S, T, id01, loss -end - -function mergepair(S::AbstractVector, T::AbstractVector, id1::Integer, id2::Integer) - t1sq, t1t2, t2sq, c = build_tr_det(S, T, id1, id2) - c, loss, u = solve_remix(t1sq, t1t2, t2sq, c) - S12, T12 = remix_enact(S, T, id1, id2, c, u) - return S12, T12, loss -end - -function remix_enact(S::AbstractVector{TS}, T::AbstractVector, id1::Integer, id2::Integer, c::AbstractFloat, w::Tuple{Tw, Tw}) where {Tw, TS} - S12 = zeros(eltype(TS), length(S[id1])) - S12 += w[1]*S[id1] - S12 += w[2]*S[id2] - T1, T2 = (w[1]+w[2]*c)*T[id1], (w[1]*c+w[2])*T[id2] - T12 = T1+T2 - return S12, T12 -end - -""" - Wmerge, Hmerge, WHstage, Err = mergecolumns(W, H, mergeseq; tracemerge=false) - -Merge components in `W` and `H` (columns in `W` and rows in `H`) according to the sequence of merge pair ids `mergeseq`. - -`Wmerge` and `Hmerge` are the merged results. - -`WHstage::Vector{Tuple{Matrix, Matrix}}` includes the results of each merge stage. `WHstage` is empty if `tracemerge=false`. - -`Err::Vector` includes merge penalty of each merge stage. -""" -function mergecolumns(W::AbstractArray, H::AbstractArray, mergeseq::AbstractArray; tracemerge::Bool = false) - Err = Float64[] - S = [W[:, j] for j in axes(W, 2)] - T = [H[i, :] for i in axes(H, 1)] - STstage = [] - for mergeids in mergeseq - id0, id1 = mergeids - if tracemerge - push!(STstage, (copy(S), copy(T))) - end - S, T, _, loss = mergecol2to1!(S, T, id0, id1) - err = loss - push!(Err, err) - end - Smtx, Tmtx = hcat(filter(x -> x != [], S)...), hcat(filter(x -> x != [], T)...)' - return Smtx, Matrix(Tmtx), STstage, Err -end - -mergepenalty(λ_min, t1sq, t2sq) = λ_min -shotpenalty(λ_min, t1sq, t2sq) = λ_min / sqrt(min(t1sq, t2sq)) - -end +module NMFMerge + +using LinearAlgebra, DataStructures, NMF, GsvdInitialization, TSVD + +export nmfmerge, + colnormalize, + colmerge2to1pq, + mergecolumns + +""" + result = nmfmerge(queuepenalty, X, ncomponents; tol_final=1e-4, tol_intermediate=sqrt(tol_final), W0=nothing, H0=nothing, kwargs...) + +Performs "NMF-Merge" on data matrix `X`. + +Arguments: + +-`queuepenalty`: a function of the form `f(E, t1sq, t2sq)` that computes the penalty for merging two components, where `E` is the the merge error described in the paper, default: f(E, t1sq, t2sq)=E. + +- `X::AbstractMatrix`: the data matrix to be factorized + +- `ncomponents::Pair{Int,Int}`: in the form of `n1 => n2`, merging from `n1` components to `n2`components, + where `n1` is the number of components for overcomplete NMF, and `n2` is the number of components for the final NMF. + We require `n1 >= n2`. + +Alternatively, `ncomponents` can be an integer denoting the final number of components. In this case, `nmfmerge` +defaults to an approximate 20% component excess before merging. + + +Keyword arguments: + +- `tol_final`: The tolerance of final NMF + +- `tol_intermediate`: The tolerence of initial and overcomplete NMF + +`W0`, `H0`: initialization for the initial NMF. If at least one of `W0` and `H0` is `nothing`, NNDSVD is used for initialization. + + +Other keywords arguments are passed to `NMF.nnmf`. +""" +function nmfmerge(queuepenalty, X, ncomponents::Pair{Int,Int}; tol_final=1e-4, tol_intermediate=sqrt(tol_final), W0=nothing, H0=nothing, kwargs...) + n1, n2 = ncomponents + f = tsvd(X, n2) + Un, Sn, Vn = f + if W0 === nothing || H0 === nothing + W0, H0 = NMF.nndsvd(X, n2, initdata=(U = Un, S = Sn, V = Vn)) + end + result_initial = nnmf(X, n2; kwargs..., init=:custom, tol=tol_intermediate, W0=copy(W0), H0=copy(H0)) + W_initial, H_initial = result_initial.W, result_initial.H + kadd = n1 - n2 + kadd >= 0 || throw(ArgumentError("Cannot merge to more components than original")) + W_over_init, H_over_init = gsvdrecover(X, W_initial, H_initial, kadd, f) + result_over = nnmf(X, n1; kwargs..., init=:custom, tol=tol_intermediate, W0=W_over_init, H0=H_over_init) + W_over, H_over = result_over.W, result_over.H + W_over_normed, H_over_normed = colnormalize(W_over, H_over) + Wmerge, Hmerge, _ = colmerge2to1pq(queuepenalty, W_over_normed, H_over_normed, n2) + result_renmf = nnmf(X, n2; kwargs..., init=:custom, tol=tol_final, W0=Wmerge, H0=Hmerge) + return result_renmf +end +nmfmerge(queuepenalty, X, ncomponents::Integer; kwargs...) = nmfmerge(queuepenalty, X, ncomponents+max(1, round(Int, 0.2*ncomponents)) => Int(ncomponents); kwargs...) +nmfmerge(X, ncomponents::Pair{Int,Int}; kwargs...) = nmfmerge(mergepenalty, X, ncomponents; kwargs...) +nmfmerge(X, ncomponents::Integer; kwargs...) = nmfmerge(mergepenalty, X, ncomponents::Integer; kwargs...) + +function colnormalize!(W, H, p::Integer=2) + nonzerocolids = Int[] + for (j, w) in pairs(eachcol(W)) + normw = norm(w, p) + if !iszero(normw) + W[:, j] = w/normw + H[j, :] = H[j, :]*normw + push!(nonzerocolids, j) + end + end + W, H = W[:, nonzerocolids], H[nonzerocolids, :] + return W, H +end + +""" + Wnormalized, Hnormalized = colnormalize(W, H, p=2) + +Normalize the factorization so that each column satisfies `||W[:, i]||_p ≈ 1`. + +""" +colnormalize(W, H, p::Integer=2) = colnormalize!(float(copy(W)), float(copy(H)), p) + +""" + Wmerge, Hmerge, mergeseq = colmerge2to1pq(W::AbstractArray, H::AbstractArray, n::Integer) + +Merge components in `W` and `H` (columns in `W` and rows in `H`) until only `n` +components remain. + +`Wmerge` and `Hmerge` are the merged results with `n` components. + +`mergeseq` is the sequence of merge pair ids (id1, id2). Values larger than the +number of columns in `W` indicate the output of previous merge steps. +""" +function colmerge2to1pq(queuepenalty, S::AbstractArray, T::AbstractArray, n::Integer) + mrgseq = Tuple{Int, Int}[] + S = let S = S # julia #15276 + [S[:, j] for j in axes(S, 2)] + end + T = let T = T + [T[i, :] for i in axes(T, 1)] + end + for s in S + abs(norm(s)-1)<1e-12 || throw(ArgumentError("W columns must be normalized")) + end + Nt = length(S) + Nt >= 2 || throw(ArgumentError("Cannot do 2 to 1 merge: Matrix size smaller than 2")) + Nt >= n || throw(ArgumentError("Final solution more than original size")) + pq = PriorityQueue{Tuple{Int,Int},Float64}() + for id0 in length(S):-1:2 + pq = pqupdate2to1!(queuepenalty, pq, S, T, id0, 1:id0-1) + end + m = Nt + while m > n + id0, id1 = dequeue!(pq) + if isempty(S[id0])||isempty(S[id1]) + continue + end + push!(mrgseq, (id0, id1)) + S, T, id01, _ = mergecol2to1!(S, T, id0, id1); + pqupdate2to1!(queuepenalty, pq, S, T, id01, 1:id01-1); + m -= 1 + end + Smtx, Tmtx = reduce(hcat, filter(!isempty, S)), reduce(hcat, filter(!isempty, T))' + return Smtx, Matrix(Tmtx), mrgseq +end +colmerge2to1pq(S::AbstractArray, T::AbstractArray, n::Integer) = colmerge2to1pq(mergepenalty, S, T, n) + +function pqupdate2to1!(queuepenalty::Function, pq, S::AbstractVector, T::AbstractVector, id01::Integer, overlapids::AbstractRange{To}) where To + for id in overlapids + if !isempty(S[id]) && !isempty(S[id01]) + t1sq, t1t2, t2sq, c = build_tr_det(S, T, id, id01) + loss = solve_remix(t1sq, t1t2, t2sq, c)[2] + enqueue!(pq, (id, id01), queuepenalty(loss, t1sq, t2sq)) + end + end + return pq +end + +function solve_remix(h1h1::AbstractFloat, h1h2::AbstractFloat, h2h2::AbstractFloat, c::AbstractFloat) + τ = h1h1+2c*h1h2+h2h2 + δ = (1-c^2)*(h1h1*h2h2-h1h2^2) + if h1h1 == 0 + return c, zero(c), (zero(c), one(c)) + end + if h2h2 == 0 + return c, zero(c), (one(c), zero(c)) + end + b = sqrt(τ^2/4-δ) + λ_max = τ/2+b + λ_min = δ/λ_max + den = (h1h2+c*h2h2)*2 + if iszero(den) + u = h1h1 >= h2h2 ? (one(c), zero(c)) : (zero(c), one(c)) + else + ξ = (h1h1-h2h2+2b)/den + u = (ξ, 1)./sqrt(1+2ξ*c+ξ^2) + end + return c, λ_min, u +end + +function build_tr_det(W::AbstractVector, H::AbstractVector, id1::Integer, id2::Integer) + h1sq, h1h2, h2sq = H[id1]'*H[id1], H[id1]'*H[id2], H[id2]'*H[id2] + c = W[id1]'*W[id2] # assumes normalization + return h1sq, h1h2, h2sq, c +end + +function mergecol2to1!(S::AbstractVector, T::AbstractVector, id0::Integer, id1::Integer) + S01, T01, loss = mergepair(S, T, id0, id1) + S[id0] = S[id1] = T[id0] = T[id1] = eltype(S[1])[] + id01 = length(S)+1 + push!(S, S01) + push!(T, T01) + return S, T, id01, loss +end + +function mergepair(S::AbstractVector, T::AbstractVector, id1::Integer, id2::Integer) + t1sq, t1t2, t2sq, c = build_tr_det(S, T, id1, id2) + c, loss, u = solve_remix(t1sq, t1t2, t2sq, c) + S12, T12 = remix_enact(S, T, id1, id2, c, u) + return S12, T12, loss +end + +function remix_enact(S::AbstractVector{TS}, T::AbstractVector, id1::Integer, id2::Integer, c::AbstractFloat, w::Tuple{Tw, Tw}) where {Tw, TS} + S12 = zeros(eltype(TS), length(S[id1])) + S12 += w[1]*S[id1] + S12 += w[2]*S[id2] + T1, T2 = (w[1]+w[2]*c)*T[id1], (w[1]*c+w[2])*T[id2] + T12 = T1+T2 + return S12, T12 +end + +""" + Wmerge, Hmerge, WHstage, Err = mergecolumns(W, H, mergeseq; tracemerge=false) + +Merge components in `W` and `H` (columns in `W` and rows in `H`) according to the sequence of merge pair ids `mergeseq`. + +`Wmerge` and `Hmerge` are the merged results. + +`WHstage::Vector{Tuple{Matrix, Matrix}}` includes the results of each merge stage. `WHstage` is empty if `tracemerge=false`. + +`Err::Vector` includes merge penalty of each merge stage. +""" +function mergecolumns(W::AbstractArray, H::AbstractArray, mergeseq::AbstractArray; tracemerge::Bool = false) + Err = Float64[] + S = [W[:, j] for j in axes(W, 2)] + T = [H[i, :] for i in axes(H, 1)] + STstage = [] + for mergeids in mergeseq + id0, id1 = mergeids + if tracemerge + push!(STstage, (copy(S), copy(T))) + end + S, T, _, loss = mergecol2to1!(S, T, id0, id1) + err = loss + push!(Err, err) + end + Smtx, Tmtx = hcat(filter(x -> x != [], S)...), hcat(filter(x -> x != [], T)...)' + return Smtx, Matrix(Tmtx), STstage, Err +end + +mergepenalty(λ_min, t1sq, t2sq) = λ_min + +end From 84df5e895346589e0a51962b3a345ce895f94b5a Mon Sep 17 00:00:00 2001 From: youdongguo <1010705897@qq.com> Date: Wed, 3 Sep 2025 15:10:14 -0500 Subject: [PATCH 4/8] update the docstring of queuepenalty --- .github/workflows/CI.yml | 154 +++++----- .github/workflows/CompatHelper.yml | 32 +- Project.toml | 58 ++-- src/NMFMerge.jl | 462 +++++++++++++++-------------- 4 files changed, 359 insertions(+), 347 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 00f9e5b..95f441c 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -1,77 +1,77 @@ -name: CI -on: - push: - branches: - - main - tags: ['*'] - pull_request: - workflow_dispatch: -concurrency: - # Skip intermediate builds: always. - # Cancel intermediate builds: only if it is a pull request build. - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} -jobs: - test: - name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} - runs-on: ${{ matrix.os }} - timeout-minutes: 60 - permissions: # needed to allow julia-actions/cache to proactively delete old caches that it has created - actions: write - contents: read - strategy: - fail-fast: false - matrix: - version: - - '1.10' - - '1' - os: - - ubuntu-latest - arch: - - x64 - steps: - - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2 - with: - version: ${{ matrix.version }} - arch: ${{ matrix.arch }} - - uses: julia-actions/cache@v2 - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-runtest@v1 - - uses: julia-actions/julia-processcoverage@v1 - - uses: codecov/codecov-action@v5 - with: - files: lcov.info - token: ${{ secrets.CODECOV_TOKEN }} - fail_ci_if_error: false - # docs: - # name: Documentation - # runs-on: ubuntu-latest - # permissions: - # actions: write # needed to allow julia-actions/cache to proactively delete old caches that it has created - # contents: write - # statuses: write - # steps: - # - uses: actions/checkout@v4 - # - uses: julia-actions/setup-julia@v2 - # with: - # version: '1' - # - uses: julia-actions/cache@v2 - # - name: Configure doc environment - # shell: julia --project=docs --color=yes {0} - # run: | - # using Pkg - # Pkg.develop(PackageSpec(path=pwd())) - # Pkg.instantiate() - # - uses: julia-actions/julia-buildpkg@v1 - # - uses: julia-actions/julia-docdeploy@v1 - # env: - # GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - # DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} - # - name: Run doctests - # shell: julia --project=docs --color=yes {0} - # run: | - # using Documenter: DocMeta, doctest - # using NMFMerge - # DocMeta.setdocmeta!(NMFMerge, :DocTestSetup, :(using NMFMerge); recursive=true) - # doctest(NMFMerge) +name: CI +on: + push: + branches: + - main + tags: ['*'] + pull_request: + workflow_dispatch: +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} +jobs: + test: + name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} + runs-on: ${{ matrix.os }} + timeout-minutes: 60 + permissions: # needed to allow julia-actions/cache to proactively delete old caches that it has created + actions: write + contents: read + strategy: + fail-fast: false + matrix: + version: + - '1.10' + - '1' + os: + - ubuntu-latest + arch: + - x64 + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: ${{ matrix.version }} + arch: ${{ matrix.arch }} + - uses: julia-actions/cache@v2 + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-runtest@v1 + - uses: julia-actions/julia-processcoverage@v1 + - uses: codecov/codecov-action@v5 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + fail_ci_if_error: false + # docs: + # name: Documentation + # runs-on: ubuntu-latest + # permissions: + # actions: write # needed to allow julia-actions/cache to proactively delete old caches that it has created + # contents: write + # statuses: write + # steps: + # - uses: actions/checkout@v4 + # - uses: julia-actions/setup-julia@v2 + # with: + # version: '1' + # - uses: julia-actions/cache@v2 + # - name: Configure doc environment + # shell: julia --project=docs --color=yes {0} + # run: | + # using Pkg + # Pkg.develop(PackageSpec(path=pwd())) + # Pkg.instantiate() + # - uses: julia-actions/julia-buildpkg@v1 + # - uses: julia-actions/julia-docdeploy@v1 + # env: + # GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + # DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} + # - name: Run doctests + # shell: julia --project=docs --color=yes {0} + # run: | + # using Documenter: DocMeta, doctest + # using NMFMerge + # DocMeta.setdocmeta!(NMFMerge, :DocTestSetup, :(using NMFMerge); recursive=true) + # doctest(NMFMerge) diff --git a/.github/workflows/CompatHelper.yml b/.github/workflows/CompatHelper.yml index 7886911..97d7bb5 100644 --- a/.github/workflows/CompatHelper.yml +++ b/.github/workflows/CompatHelper.yml @@ -1,16 +1,16 @@ -name: CompatHelper -on: - schedule: - - cron: '0 0 * * 0' - workflow_dispatch: -jobs: - CompatHelper: - runs-on: ubuntu-latest - steps: - - name: Pkg.add("CompatHelper") - run: julia -e 'using Pkg; Pkg.add("CompatHelper")' - - name: CompatHelper.main() - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }} - run: julia -e 'using CompatHelper; CompatHelper.main()' +name: CompatHelper +on: + schedule: + - cron: '0 0 * * 0' + workflow_dispatch: +jobs: + CompatHelper: + runs-on: ubuntu-latest + steps: + - name: Pkg.add("CompatHelper") + run: julia -e 'using Pkg; Pkg.add("CompatHelper")' + - name: CompatHelper.main() + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }} + run: julia -e 'using CompatHelper; CompatHelper.main()' diff --git a/Project.toml b/Project.toml index d3af602..cc00902 100644 --- a/Project.toml +++ b/Project.toml @@ -1,29 +1,29 @@ -name = "NMFMerge" -uuid = "9cc52eda-dfaf-4e21-aae3-9f26bed153a3" -authors = ["youdongguo <1010705897@qq.com> and contributors"] -version = "1.0.1" - -[deps] -DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -GsvdInitialization = "2ac24108-be9c-42b8-8d78-6a4f62a87e7d" -LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -NMF = "6ef6ca0d-6ad7-5ff6-b225-e928bfa0a386" -TSVD = "9449cd9e-2762-5aa3-a617-5413e99d722e" - -[compat] -DataStructures = "0.18" -ForwardDiff = "0.10" -GsvdInitialization = "1" -LinearAlgebra = "1" -NMF = "1" -TSVD = "0.4" -Test = "1" -julia = "1.10" - -[extras] -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -NMF = "6ef6ca0d-6ad7-5ff6-b225-e928bfa0a386" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[targets] -test = ["ForwardDiff", "NMF", "Test"] +name = "NMFMerge" +uuid = "9cc52eda-dfaf-4e21-aae3-9f26bed153a3" +authors = ["youdongguo <1010705897@qq.com> and contributors"] +version = "1.1.0" + +[deps] +DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +GsvdInitialization = "2ac24108-be9c-42b8-8d78-6a4f62a87e7d" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +NMF = "6ef6ca0d-6ad7-5ff6-b225-e928bfa0a386" +TSVD = "9449cd9e-2762-5aa3-a617-5413e99d722e" + +[compat] +DataStructures = "0.18" +ForwardDiff = "0.10" +GsvdInitialization = "1" +LinearAlgebra = "1" +NMF = "1" +TSVD = "0.4" +Test = "1" +julia = "1.10" + +[extras] +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +NMF = "6ef6ca0d-6ad7-5ff6-b225-e928bfa0a386" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[targets] +test = ["ForwardDiff", "NMF", "Test"] diff --git a/src/NMFMerge.jl b/src/NMFMerge.jl index 9b2a4eb..7cd3c1f 100644 --- a/src/NMFMerge.jl +++ b/src/NMFMerge.jl @@ -1,225 +1,237 @@ -module NMFMerge - -using LinearAlgebra, DataStructures, NMF, GsvdInitialization, TSVD - -export nmfmerge, - colnormalize, - colmerge2to1pq, - mergecolumns - -""" - result = nmfmerge(queuepenalty, X, ncomponents; tol_final=1e-4, tol_intermediate=sqrt(tol_final), W0=nothing, H0=nothing, kwargs...) - -Performs "NMF-Merge" on data matrix `X`. - -Arguments: - --`queuepenalty`: a function of the form `f(E, t1sq, t2sq)` that computes the penalty for merging two components, where `E` is the the merge error described in the paper, default: f(E, t1sq, t2sq)=E. - -- `X::AbstractMatrix`: the data matrix to be factorized - -- `ncomponents::Pair{Int,Int}`: in the form of `n1 => n2`, merging from `n1` components to `n2`components, - where `n1` is the number of components for overcomplete NMF, and `n2` is the number of components for the final NMF. - We require `n1 >= n2`. - -Alternatively, `ncomponents` can be an integer denoting the final number of components. In this case, `nmfmerge` -defaults to an approximate 20% component excess before merging. - - -Keyword arguments: - -- `tol_final`: The tolerance of final NMF - -- `tol_intermediate`: The tolerence of initial and overcomplete NMF - -`W0`, `H0`: initialization for the initial NMF. If at least one of `W0` and `H0` is `nothing`, NNDSVD is used for initialization. - - -Other keywords arguments are passed to `NMF.nnmf`. -""" -function nmfmerge(queuepenalty, X, ncomponents::Pair{Int,Int}; tol_final=1e-4, tol_intermediate=sqrt(tol_final), W0=nothing, H0=nothing, kwargs...) - n1, n2 = ncomponents - f = tsvd(X, n2) - Un, Sn, Vn = f - if W0 === nothing || H0 === nothing - W0, H0 = NMF.nndsvd(X, n2, initdata=(U = Un, S = Sn, V = Vn)) - end - result_initial = nnmf(X, n2; kwargs..., init=:custom, tol=tol_intermediate, W0=copy(W0), H0=copy(H0)) - W_initial, H_initial = result_initial.W, result_initial.H - kadd = n1 - n2 - kadd >= 0 || throw(ArgumentError("Cannot merge to more components than original")) - W_over_init, H_over_init = gsvdrecover(X, W_initial, H_initial, kadd, f) - result_over = nnmf(X, n1; kwargs..., init=:custom, tol=tol_intermediate, W0=W_over_init, H0=H_over_init) - W_over, H_over = result_over.W, result_over.H - W_over_normed, H_over_normed = colnormalize(W_over, H_over) - Wmerge, Hmerge, _ = colmerge2to1pq(queuepenalty, W_over_normed, H_over_normed, n2) - result_renmf = nnmf(X, n2; kwargs..., init=:custom, tol=tol_final, W0=Wmerge, H0=Hmerge) - return result_renmf -end -nmfmerge(queuepenalty, X, ncomponents::Integer; kwargs...) = nmfmerge(queuepenalty, X, ncomponents+max(1, round(Int, 0.2*ncomponents)) => Int(ncomponents); kwargs...) -nmfmerge(X, ncomponents::Pair{Int,Int}; kwargs...) = nmfmerge(mergepenalty, X, ncomponents; kwargs...) -nmfmerge(X, ncomponents::Integer; kwargs...) = nmfmerge(mergepenalty, X, ncomponents::Integer; kwargs...) - -function colnormalize!(W, H, p::Integer=2) - nonzerocolids = Int[] - for (j, w) in pairs(eachcol(W)) - normw = norm(w, p) - if !iszero(normw) - W[:, j] = w/normw - H[j, :] = H[j, :]*normw - push!(nonzerocolids, j) - end - end - W, H = W[:, nonzerocolids], H[nonzerocolids, :] - return W, H -end - -""" - Wnormalized, Hnormalized = colnormalize(W, H, p=2) - -Normalize the factorization so that each column satisfies `||W[:, i]||_p ≈ 1`. - -""" -colnormalize(W, H, p::Integer=2) = colnormalize!(float(copy(W)), float(copy(H)), p) - -""" - Wmerge, Hmerge, mergeseq = colmerge2to1pq(W::AbstractArray, H::AbstractArray, n::Integer) - -Merge components in `W` and `H` (columns in `W` and rows in `H`) until only `n` -components remain. - -`Wmerge` and `Hmerge` are the merged results with `n` components. - -`mergeseq` is the sequence of merge pair ids (id1, id2). Values larger than the -number of columns in `W` indicate the output of previous merge steps. -""" -function colmerge2to1pq(queuepenalty, S::AbstractArray, T::AbstractArray, n::Integer) - mrgseq = Tuple{Int, Int}[] - S = let S = S # julia #15276 - [S[:, j] for j in axes(S, 2)] - end - T = let T = T - [T[i, :] for i in axes(T, 1)] - end - for s in S - abs(norm(s)-1)<1e-12 || throw(ArgumentError("W columns must be normalized")) - end - Nt = length(S) - Nt >= 2 || throw(ArgumentError("Cannot do 2 to 1 merge: Matrix size smaller than 2")) - Nt >= n || throw(ArgumentError("Final solution more than original size")) - pq = PriorityQueue{Tuple{Int,Int},Float64}() - for id0 in length(S):-1:2 - pq = pqupdate2to1!(queuepenalty, pq, S, T, id0, 1:id0-1) - end - m = Nt - while m > n - id0, id1 = dequeue!(pq) - if isempty(S[id0])||isempty(S[id1]) - continue - end - push!(mrgseq, (id0, id1)) - S, T, id01, _ = mergecol2to1!(S, T, id0, id1); - pqupdate2to1!(queuepenalty, pq, S, T, id01, 1:id01-1); - m -= 1 - end - Smtx, Tmtx = reduce(hcat, filter(!isempty, S)), reduce(hcat, filter(!isempty, T))' - return Smtx, Matrix(Tmtx), mrgseq -end -colmerge2to1pq(S::AbstractArray, T::AbstractArray, n::Integer) = colmerge2to1pq(mergepenalty, S, T, n) - -function pqupdate2to1!(queuepenalty::Function, pq, S::AbstractVector, T::AbstractVector, id01::Integer, overlapids::AbstractRange{To}) where To - for id in overlapids - if !isempty(S[id]) && !isempty(S[id01]) - t1sq, t1t2, t2sq, c = build_tr_det(S, T, id, id01) - loss = solve_remix(t1sq, t1t2, t2sq, c)[2] - enqueue!(pq, (id, id01), queuepenalty(loss, t1sq, t2sq)) - end - end - return pq -end - -function solve_remix(h1h1::AbstractFloat, h1h2::AbstractFloat, h2h2::AbstractFloat, c::AbstractFloat) - τ = h1h1+2c*h1h2+h2h2 - δ = (1-c^2)*(h1h1*h2h2-h1h2^2) - if h1h1 == 0 - return c, zero(c), (zero(c), one(c)) - end - if h2h2 == 0 - return c, zero(c), (one(c), zero(c)) - end - b = sqrt(τ^2/4-δ) - λ_max = τ/2+b - λ_min = δ/λ_max - den = (h1h2+c*h2h2)*2 - if iszero(den) - u = h1h1 >= h2h2 ? (one(c), zero(c)) : (zero(c), one(c)) - else - ξ = (h1h1-h2h2+2b)/den - u = (ξ, 1)./sqrt(1+2ξ*c+ξ^2) - end - return c, λ_min, u -end - -function build_tr_det(W::AbstractVector, H::AbstractVector, id1::Integer, id2::Integer) - h1sq, h1h2, h2sq = H[id1]'*H[id1], H[id1]'*H[id2], H[id2]'*H[id2] - c = W[id1]'*W[id2] # assumes normalization - return h1sq, h1h2, h2sq, c -end - -function mergecol2to1!(S::AbstractVector, T::AbstractVector, id0::Integer, id1::Integer) - S01, T01, loss = mergepair(S, T, id0, id1) - S[id0] = S[id1] = T[id0] = T[id1] = eltype(S[1])[] - id01 = length(S)+1 - push!(S, S01) - push!(T, T01) - return S, T, id01, loss -end - -function mergepair(S::AbstractVector, T::AbstractVector, id1::Integer, id2::Integer) - t1sq, t1t2, t2sq, c = build_tr_det(S, T, id1, id2) - c, loss, u = solve_remix(t1sq, t1t2, t2sq, c) - S12, T12 = remix_enact(S, T, id1, id2, c, u) - return S12, T12, loss -end - -function remix_enact(S::AbstractVector{TS}, T::AbstractVector, id1::Integer, id2::Integer, c::AbstractFloat, w::Tuple{Tw, Tw}) where {Tw, TS} - S12 = zeros(eltype(TS), length(S[id1])) - S12 += w[1]*S[id1] - S12 += w[2]*S[id2] - T1, T2 = (w[1]+w[2]*c)*T[id1], (w[1]*c+w[2])*T[id2] - T12 = T1+T2 - return S12, T12 -end - -""" - Wmerge, Hmerge, WHstage, Err = mergecolumns(W, H, mergeseq; tracemerge=false) - -Merge components in `W` and `H` (columns in `W` and rows in `H`) according to the sequence of merge pair ids `mergeseq`. - -`Wmerge` and `Hmerge` are the merged results. - -`WHstage::Vector{Tuple{Matrix, Matrix}}` includes the results of each merge stage. `WHstage` is empty if `tracemerge=false`. - -`Err::Vector` includes merge penalty of each merge stage. -""" -function mergecolumns(W::AbstractArray, H::AbstractArray, mergeseq::AbstractArray; tracemerge::Bool = false) - Err = Float64[] - S = [W[:, j] for j in axes(W, 2)] - T = [H[i, :] for i in axes(H, 1)] - STstage = [] - for mergeids in mergeseq - id0, id1 = mergeids - if tracemerge - push!(STstage, (copy(S), copy(T))) - end - S, T, _, loss = mergecol2to1!(S, T, id0, id1) - err = loss - push!(Err, err) - end - Smtx, Tmtx = hcat(filter(x -> x != [], S)...), hcat(filter(x -> x != [], T)...)' - return Smtx, Matrix(Tmtx), STstage, Err -end - -mergepenalty(λ_min, t1sq, t2sq) = λ_min - -end +module NMFMerge + +using LinearAlgebra, DataStructures, NMF, GsvdInitialization, TSVD + +export nmfmerge, + colnormalize, + colmerge2to1pq, + mergecolumns + +""" + result = nmfmerge([queuepenalty], X, ncomponents; tol_final=1e-4, tol_intermediate=sqrt(tol_final), W0=nothing, H0=nothing, kwargs...) + +Performs "NMF-Merge" on data matrix `X`. + +Arguments: + +-`queuepenalty`: a function of the form `f(E, h1sq, h2sq)` that computes the penalty for merging two components, where `E` is the the merge error described in the paper, default: f(E, h1sq, h2sq)=E. h1sq and h2sq are the squared norms of the corresponding rows in H. + +- `X::AbstractMatrix`: the data matrix to be factorized + +- `ncomponents::Pair{Int,Int}`: in the form of `n1 => n2`, merging from `n1` components to `n2`components, + where `n1` is the number of components for overcomplete NMF, and `n2` is the number of components for the final NMF. + We require `n1 >= n2`. + +Alternatively, `ncomponents` can be an integer denoting the final number of components. In this case, `nmfmerge` +defaults to an approximate 20% component excess before merging. + + +Keyword arguments: + +- `tol_final`: The tolerance of final NMF + +- `tol_intermediate`: The tolerence of initial and overcomplete NMF + +`W0`, `H0`: initialization for the initial NMF. If at least one of `W0` and `H0` is `nothing`, NNDSVD is used for initialization. + + +Other keywords arguments are passed to `NMF.nnmf`. +""" +function nmfmerge(queuepenalty, X, ncomponents::Pair{Int,Int}; tol_final=1e-4, tol_intermediate=sqrt(tol_final), W0=nothing, H0=nothing, kwargs...) + n1, n2 = ncomponents + f = tsvd(X, n2) + Un, Sn, Vn = f + if W0 === nothing || H0 === nothing + W0, H0 = NMF.nndsvd(X, n2, initdata=(U = Un, S = Sn, V = Vn)) + end + result_initial = nnmf(X, n2; kwargs..., init=:custom, tol=tol_intermediate, W0=copy(W0), H0=copy(H0)) + W_initial, H_initial = result_initial.W, result_initial.H + kadd = n1 - n2 + kadd >= 0 || throw(ArgumentError("Cannot merge to more components than original")) + W_over_init, H_over_init = gsvdrecover(X, W_initial, H_initial, kadd, f) + result_over = nnmf(X, n1; kwargs..., init=:custom, tol=tol_intermediate, W0=W_over_init, H0=H_over_init) + W_over, H_over = result_over.W, result_over.H + W_over_normed, H_over_normed = colnormalize(W_over, H_over) + Wmerge, Hmerge, _ = colmerge2to1pq(queuepenalty, W_over_normed, H_over_normed, n2) + result_renmf = nnmf(X, n2; kwargs..., init=:custom, tol=tol_final, W0=Wmerge, H0=Hmerge) + return result_renmf +end +nmfmerge(queuepenalty, X, ncomponents::Integer; kwargs...) = nmfmerge(queuepenalty, X, ncomponents+max(1, round(Int, 0.2*ncomponents)) => Int(ncomponents); kwargs...) +nmfmerge(X, ncomponents::Pair{Int,Int}; kwargs...) = nmfmerge(mergepenalty, X, ncomponents; kwargs...) +nmfmerge(X, ncomponents::Integer; kwargs...) = nmfmerge(mergepenalty, X, ncomponents::Integer; kwargs...) + +function colnormalize!(W, H, p::Integer=2) + nonzerocolids = Int[] + for (j, w) in pairs(eachcol(W)) + normw = norm(w, p) + if !iszero(normw) + W[:, j] = w/normw + H[j, :] = H[j, :]*normw + push!(nonzerocolids, j) + end + end + W, H = W[:, nonzerocolids], H[nonzerocolids, :] + return W, H +end + +""" + Wnormalized, Hnormalized = colnormalize(W, H, p=2) + +Normalize the factorization so that each column satisfies `||W[:, i]||_p ≈ 1`. + +""" +colnormalize(W, H, p::Integer=2) = colnormalize!(float(copy(W)), float(copy(H)), p) + +""" + Wmerge, Hmerge, mergeseq = colmerge2to1pq([queuepenalty], W::AbstractArray, H::AbstractArray, n::Integer) + +Merge components in `W` and `H` (columns in `W` and rows in `H`) until only `n` +components remain. + +Arguments: + +-`queuepenalty`: The same as in `nmfmerge`. Default: f(E, h1sq, h2sq)=E. + +- `W::AbstractArray`: The basis matrix with normalized columns. + +- `H::AbstractArray`: The coefficient matrix. + +- `n::Integer`: The final number of components after merging. + +Outputs: + +`Wmerge` and `Hmerge` are the merged results with `n` components. + +`mergeseq` is the sequence of merge pair ids (id1, id2). Values larger than the +number of columns in `W` indicate the output of previous merge steps. +""" +function colmerge2to1pq(queuepenalty, S::AbstractArray, T::AbstractArray, n::Integer) + mrgseq = Tuple{Int, Int}[] + S = let S = S # julia #15276 + [S[:, j] for j in axes(S, 2)] + end + T = let T = T + [T[i, :] for i in axes(T, 1)] + end + for s in S + abs(norm(s)-1)<1e-12 || throw(ArgumentError("W columns must be normalized")) + end + Nt = length(S) + Nt >= 2 || throw(ArgumentError("Cannot do 2 to 1 merge: Matrix size smaller than 2")) + Nt >= n || throw(ArgumentError("Final solution more than original size")) + pq = PriorityQueue{Tuple{Int,Int},Float64}() + for id0 in length(S):-1:2 + pq = pqupdate2to1!(queuepenalty, pq, S, T, id0, 1:id0-1) + end + m = Nt + while m > n + id0, id1 = dequeue!(pq) + if isempty(S[id0])||isempty(S[id1]) + continue + end + push!(mrgseq, (id0, id1)) + S, T, id01, _ = mergecol2to1!(S, T, id0, id1); + pqupdate2to1!(queuepenalty, pq, S, T, id01, 1:id01-1); + m -= 1 + end + Smtx, Tmtx = reduce(hcat, filter(!isempty, S)), reduce(hcat, filter(!isempty, T))' + return Smtx, Matrix(Tmtx), mrgseq +end +colmerge2to1pq(S::AbstractArray, T::AbstractArray, n::Integer) = colmerge2to1pq(mergepenalty, S, T, n) + +function pqupdate2to1!(queuepenalty::Function, pq, S::AbstractVector, T::AbstractVector, id01::Integer, overlapids::AbstractRange{To}) where To + for id in overlapids + if !isempty(S[id]) && !isempty(S[id01]) + t1sq, t1t2, t2sq, c = build_tr_det(S, T, id, id01) + loss = solve_remix(t1sq, t1t2, t2sq, c)[2] + enqueue!(pq, (id, id01), queuepenalty(loss, t1sq, t2sq)) + end + end + return pq +end + +function solve_remix(h1h1::AbstractFloat, h1h2::AbstractFloat, h2h2::AbstractFloat, c::AbstractFloat) + τ = h1h1+2c*h1h2+h2h2 + δ = (1-c^2)*(h1h1*h2h2-h1h2^2) + if h1h1 == 0 + return c, zero(c), (zero(c), one(c)) + end + if h2h2 == 0 + return c, zero(c), (one(c), zero(c)) + end + b = sqrt(τ^2/4-δ) + λ_max = τ/2+b + λ_min = δ/λ_max + den = (h1h2+c*h2h2)*2 + if iszero(den) + u = h1h1 >= h2h2 ? (one(c), zero(c)) : (zero(c), one(c)) + else + ξ = (h1h1-h2h2+2b)/den + u = (ξ, 1)./sqrt(1+2ξ*c+ξ^2) + end + return c, λ_min, u +end + +function build_tr_det(W::AbstractVector, H::AbstractVector, id1::Integer, id2::Integer) + h1sq, h1h2, h2sq = H[id1]'*H[id1], H[id1]'*H[id2], H[id2]'*H[id2] + c = W[id1]'*W[id2] # assumes normalization + return h1sq, h1h2, h2sq, c +end + +function mergecol2to1!(S::AbstractVector, T::AbstractVector, id0::Integer, id1::Integer) + S01, T01, loss = mergepair(S, T, id0, id1) + S[id0] = S[id1] = T[id0] = T[id1] = eltype(S[1])[] + id01 = length(S)+1 + push!(S, S01) + push!(T, T01) + return S, T, id01, loss +end + +function mergepair(S::AbstractVector, T::AbstractVector, id1::Integer, id2::Integer) + t1sq, t1t2, t2sq, c = build_tr_det(S, T, id1, id2) + c, loss, u = solve_remix(t1sq, t1t2, t2sq, c) + S12, T12 = remix_enact(S, T, id1, id2, c, u) + return S12, T12, loss +end + +function remix_enact(S::AbstractVector{TS}, T::AbstractVector, id1::Integer, id2::Integer, c::AbstractFloat, w::Tuple{Tw, Tw}) where {Tw, TS} + S12 = zeros(eltype(TS), length(S[id1])) + S12 += w[1]*S[id1] + S12 += w[2]*S[id2] + T1, T2 = (w[1]+w[2]*c)*T[id1], (w[1]*c+w[2])*T[id2] + T12 = T1+T2 + return S12, T12 +end + +""" + Wmerge, Hmerge, WHstage, Err = mergecolumns(W, H, mergeseq; tracemerge=false) + +Merge components in `W` and `H` (columns in `W` and rows in `H`) according to the sequence of merge pair ids `mergeseq`. + +`Wmerge` and `Hmerge` are the merged results. + +`WHstage::Vector{Tuple{Matrix, Matrix}}` includes the results of each merge stage. `WHstage` is empty if `tracemerge=false`. + +`Err::Vector` includes merge penalty of each merge stage. +""" +function mergecolumns(W::AbstractArray, H::AbstractArray, mergeseq::AbstractArray; tracemerge::Bool = false) + Err = Float64[] + S = [W[:, j] for j in axes(W, 2)] + T = [H[i, :] for i in axes(H, 1)] + STstage = [] + for mergeids in mergeseq + id0, id1 = mergeids + if tracemerge + push!(STstage, (copy(S), copy(T))) + end + S, T, _, loss = mergecol2to1!(S, T, id0, id1) + err = loss + push!(Err, err) + end + Smtx, Tmtx = hcat(filter(x -> x != [], S)...), hcat(filter(x -> x != [], T)...)' + return Smtx, Matrix(Tmtx), STstage, Err +end + +mergepenalty(λ_min, t1sq, t2sq) = λ_min + +end From 6350807e339c93a83a4cb09ae6623580e251fbaf Mon Sep 17 00:00:00 2001 From: youdongguo <1010705897@qq.com> Date: Thu, 4 Sep 2025 18:46:46 -0500 Subject: [PATCH 5/8] consistent_with_dealing with zero --- src/NMFMerge.jl | 23 ++++++++++++----------- test/runtests.jl | 13 +++++-------- 2 files changed, 17 insertions(+), 19 deletions(-) diff --git a/src/NMFMerge.jl b/src/NMFMerge.jl index 1e9d186..5bd317a 100644 --- a/src/NMFMerge.jl +++ b/src/NMFMerge.jl @@ -143,17 +143,15 @@ colmerge2to1pq(S::AbstractArray, T::AbstractArray, n::Integer) = colmerge2to1pq( function pqupdate2to1!(queuepenalty::Function, pq, S::AbstractVector, T::AbstractVector, id01::Integer, overlapids::AbstractRange{To}) where To for id in overlapids if !isempty(S[id]) && !isempty(S[id01]) - t1sq, t1t2, t2sq, c = build_tr_det(S, T, id, id01) - loss = solve_remix(t1sq, t1t2, t2sq, c)[2] + _, loss, _, t1sq, t2sq = solve_remix(S, T, id, id01) enqueue!(pq, (id, id01), queuepenalty(loss, t1sq, t2sq)) end end return pq end -function solve_remix(h1h1::AbstractFloat, h1h2::AbstractFloat, h2h2::AbstractFloat, c::AbstractFloat) - τ = h1h1+2c*h1h2+h2h2 - δ = (1-c^2)*(h1h1*h2h2-h1h2^2) +function solve_remix(S::AbstractVector, T::AbstractVector, id1::Integer, id2::Integer) + τ, δ, c, h1h1, h1h2, h2h2 = build_tr_det(S, T, id1, id2) if iszero(h1h1) return c, zero(c), (zero(c), one(c)) end @@ -175,13 +173,17 @@ function solve_remix(h1h1::AbstractFloat, h1h2::AbstractFloat, h2h2::AbstractFlo ξ = (h1h1-h2h2+2b)/den u = (ξ, 1)./sqrt(1+2ξ*c+ξ^2) end - return c, λ_min, u + return c, λ_min, u, h1h1, h2h2 end function build_tr_det(W::AbstractVector, H::AbstractVector, id1::Integer, id2::Integer) - h1sq, h1h2, h2sq = H[id1]'*H[id1], H[id1]'*H[id2], H[id2]'*H[id2] - c = W[id1]'*W[id2] # assumes normalization - return h1sq, h1h2, h2sq, c + c = W[id1]'*W[id2] + h1h1 = H[id1]'*H[id1] + h1h2 = H[id1]'*H[id2] + h2h2 = H[id2]'*H[id2] + τ = h1h1+2c*h1h2+h2h2 + δ = (1-c^2)*(h1h1*h2h2-h1h2^2) + return τ, δ, c, h1h1, h1h2, h2h2 end function mergecol2to1!(S::AbstractVector, T::AbstractVector, id0::Integer, id1::Integer) @@ -194,8 +196,7 @@ function mergecol2to1!(S::AbstractVector, T::AbstractVector, id0::Integer, id1:: end function mergepair(S::AbstractVector, T::AbstractVector, id1::Integer, id2::Integer) - t1sq, t1t2, t2sq, c = build_tr_det(S, T, id1, id2) - c, loss, u = solve_remix(t1sq, t1t2, t2sq, c) + c, loss, u, _, _ = solve_remix(S, T, id1, id2) S12, T12 = remix_enact(S, T, id1, id2, c, u) return S12, T12, loss end diff --git a/test/runtests.jl b/test/runtests.jl index 36c273e..39cb174 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -85,10 +85,8 @@ end idx = argmax(Fvals) w = Fvecs[:,idx] - h1h1, h1h2, h2h2, c = NMFMerge.build_tr_det(W_v, H_v, 1, 2) - τ = h1h1+2c*h1h2+h2h2 - δ = (1-c^2)*(h1h1*h2h2-h1h2^2) - c, p, u = NMFMerge.solve_remix(h1h1, h1h2, h2h2, c) + τ, δ, c, h1h1, h1h2, h2h2 = NMFMerge.build_tr_det(W_v, H_v, 1, 2) + c, p, u, h1h1, h2h2 = NMFMerge.solve_remix(W_v, H_v, 1, 2) u = [u[1], u[2]] b = sqrt(τ^2/4-δ) λ_max = τ/2+b @@ -135,10 +133,9 @@ end idx = argmax(Fvals) w = Fvecs[:,idx] - h1h1, h1h2, h2h2, c = NMFMerge.build_tr_det(W2, H2, 1, 2) - τ = h1h1+2c*h1h2+h2h2 - δ = (1-c^2)*(h1h1*h2h2-h1h2^2) - c, p, u = NMFMerge.solve_remix(h1h1, h1h2, h2h2, c) + + τ, δ, c, h1h1, h1h2, h2h2 = NMFMerge.build_tr_det(W2, H2, 1, 2) + c, p, u, h1h1, h2h2 = NMFMerge.solve_remix(W2, H2, 1, 2) u = [u[1], u[2]] b = sqrt(τ^2/4-δ) λ_max = τ/2+b From 10c241eea07fe69bdc6b2fb648045be679bcba58 Mon Sep 17 00:00:00 2001 From: youdongguo <1010705897@qq.com> Date: Thu, 4 Sep 2025 18:54:41 -0500 Subject: [PATCH 6/8] fix bugs --- src/NMFMerge.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/NMFMerge.jl b/src/NMFMerge.jl index 5bd317a..1726029 100644 --- a/src/NMFMerge.jl +++ b/src/NMFMerge.jl @@ -153,15 +153,15 @@ end function solve_remix(S::AbstractVector, T::AbstractVector, id1::Integer, id2::Integer) τ, δ, c, h1h1, h1h2, h2h2 = build_tr_det(S, T, id1, id2) if iszero(h1h1) - return c, zero(c), (zero(c), one(c)) + return c, zero(c), (zero(c), one(c)), h1h1, h2h2 end if iszero(h2h2) - return c, zero(c), (one(c), zero(c)) + return c, zero(c), (one(c), zero(c)), h1h1, h2h2 end if iszero(c) # Check whether W1 or W2 is zero - iszero(sum(abs2, S[id1])) && return c, zero(h1h1), (zero(c), one(c)) - iszero(sum(abs2, S[id2])) && return c, zero(h2h2), (one(c), zero(c)) + iszero(sum(abs2, S[id1])) && return c, zero(h1h1), (zero(c), one(c)), h1h1, h2h2 + iszero(sum(abs2, S[id2])) && return c, zero(h2h2), (one(c), zero(c)), h1h1, h2h2 end b = sqrt(τ^2/4-δ) λ_max = τ/2+b From 74856b6e1b43275056107481557583eb681555d9 Mon Sep 17 00:00:00 2001 From: youdongguo <1010705897@qq.com> Date: Fri, 5 Sep 2025 11:48:32 -0500 Subject: [PATCH 7/8] test_customized merge function --- test/runtests.jl | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index 39cb174..de14b61 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -199,3 +199,23 @@ end @test H12 ≈ Hs[1] @test iszero(loss) end + +@testset "test customized merge function" begin + Ws = [rand(5) rand(5) rand(5)] + Hs = [rand(10) rand(10) rand(10)]' + Wsn, Hsn = colnormalize(Ws, Hs) + Wsn1 = [Wsn[:, j] for j in axes(Wsn, 2)] + Hsn1 = [Hsn[i, :] for i in axes(Hsn, 1)] + mergepenalty_custom(E, t1sq, t2sq) = -E + idpair_loss = [] + for id1 in 1:2, id2 in id1+1:3 + W12, H12, loss2 = NMFMerge.mergepair(Wsn1, Hsn1, id1, id2) + push!(idpair_loss, ((id1, id2), loss2)) + end + idpair_loss = sort(idpair_loss, by=x->x[2]) + merge_sequence = colmerge2to1pq(Wsn, Hsn, 1)[end] + merge_sequence_custom = colmerge2to1pq(mergepenalty_custom, Wsn, Hsn, 1)[end] + @test merge_sequence[1] == idpair_loss[1][1] + @test merge_sequence_custom[1] == idpair_loss[3][1] + +end From ef9418a19a97c82d009f77eaf9352644c0fa8e8a Mon Sep 17 00:00:00 2001 From: youdongguo <1010705897@qq.com> Date: Wed, 10 Sep 2025 15:09:27 -0500 Subject: [PATCH 8/8] revert the datastructure version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 9b3652b..285a8a4 100644 --- a/Project.toml +++ b/Project.toml @@ -11,7 +11,7 @@ NMF = "6ef6ca0d-6ad7-5ff6-b225-e928bfa0a386" TSVD = "9449cd9e-2762-5aa3-a617-5413e99d722e" [compat] -DataStructures = "0.18" +DataStructures = "0.18, 0.19" ForwardDiff = "0.10" GsvdInitialization = "1" LinearAlgebra = "1"