diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 00f9e5b..95f441c 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -1,77 +1,77 @@ -name: CI -on: - push: - branches: - - main - tags: ['*'] - pull_request: - workflow_dispatch: -concurrency: - # Skip intermediate builds: always. - # Cancel intermediate builds: only if it is a pull request build. - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} -jobs: - test: - name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} - runs-on: ${{ matrix.os }} - timeout-minutes: 60 - permissions: # needed to allow julia-actions/cache to proactively delete old caches that it has created - actions: write - contents: read - strategy: - fail-fast: false - matrix: - version: - - '1.10' - - '1' - os: - - ubuntu-latest - arch: - - x64 - steps: - - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2 - with: - version: ${{ matrix.version }} - arch: ${{ matrix.arch }} - - uses: julia-actions/cache@v2 - - uses: julia-actions/julia-buildpkg@v1 - - uses: julia-actions/julia-runtest@v1 - - uses: julia-actions/julia-processcoverage@v1 - - uses: codecov/codecov-action@v5 - with: - files: lcov.info - token: ${{ secrets.CODECOV_TOKEN }} - fail_ci_if_error: false - # docs: - # name: Documentation - # runs-on: ubuntu-latest - # permissions: - # actions: write # needed to allow julia-actions/cache to proactively delete old caches that it has created - # contents: write - # statuses: write - # steps: - # - uses: actions/checkout@v4 - # - uses: julia-actions/setup-julia@v2 - # with: - # version: '1' - # - uses: julia-actions/cache@v2 - # - name: Configure doc environment - # shell: julia --project=docs --color=yes {0} - # run: | - # using Pkg - # Pkg.develop(PackageSpec(path=pwd())) - # Pkg.instantiate() - # - uses: julia-actions/julia-buildpkg@v1 - # - uses: julia-actions/julia-docdeploy@v1 - # env: - # GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - # DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} - # - name: Run doctests - # shell: julia --project=docs --color=yes {0} - # run: | - # using Documenter: DocMeta, doctest - # using NMFMerge - # DocMeta.setdocmeta!(NMFMerge, :DocTestSetup, :(using NMFMerge); recursive=true) - # doctest(NMFMerge) +name: CI +on: + push: + branches: + - main + tags: ['*'] + pull_request: + workflow_dispatch: +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} +jobs: + test: + name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} + runs-on: ${{ matrix.os }} + timeout-minutes: 60 + permissions: # needed to allow julia-actions/cache to proactively delete old caches that it has created + actions: write + contents: read + strategy: + fail-fast: false + matrix: + version: + - '1.10' + - '1' + os: + - ubuntu-latest + arch: + - x64 + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: ${{ matrix.version }} + arch: ${{ matrix.arch }} + - uses: julia-actions/cache@v2 + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-runtest@v1 + - uses: julia-actions/julia-processcoverage@v1 + - uses: codecov/codecov-action@v5 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + fail_ci_if_error: false + # docs: + # name: Documentation + # runs-on: ubuntu-latest + # permissions: + # actions: write # needed to allow julia-actions/cache to proactively delete old caches that it has created + # contents: write + # statuses: write + # steps: + # - uses: actions/checkout@v4 + # - uses: julia-actions/setup-julia@v2 + # with: + # version: '1' + # - uses: julia-actions/cache@v2 + # - name: Configure doc environment + # shell: julia --project=docs --color=yes {0} + # run: | + # using Pkg + # Pkg.develop(PackageSpec(path=pwd())) + # Pkg.instantiate() + # - uses: julia-actions/julia-buildpkg@v1 + # - uses: julia-actions/julia-docdeploy@v1 + # env: + # GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + # DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} + # - name: Run doctests + # shell: julia --project=docs --color=yes {0} + # run: | + # using Documenter: DocMeta, doctest + # using NMFMerge + # DocMeta.setdocmeta!(NMFMerge, :DocTestSetup, :(using NMFMerge); recursive=true) + # doctest(NMFMerge) diff --git a/.github/workflows/CompatHelper.yml b/.github/workflows/CompatHelper.yml index 7886911..97d7bb5 100644 --- a/.github/workflows/CompatHelper.yml +++ b/.github/workflows/CompatHelper.yml @@ -1,16 +1,16 @@ -name: CompatHelper -on: - schedule: - - cron: '0 0 * * 0' - workflow_dispatch: -jobs: - CompatHelper: - runs-on: ubuntu-latest - steps: - - name: Pkg.add("CompatHelper") - run: julia -e 'using Pkg; Pkg.add("CompatHelper")' - - name: CompatHelper.main() - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }} - run: julia -e 'using CompatHelper; CompatHelper.main()' +name: CompatHelper +on: + schedule: + - cron: '0 0 * * 0' + workflow_dispatch: +jobs: + CompatHelper: + runs-on: ubuntu-latest + steps: + - name: Pkg.add("CompatHelper") + run: julia -e 'using Pkg; Pkg.add("CompatHelper")' + - name: CompatHelper.main() + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }} + run: julia -e 'using CompatHelper; CompatHelper.main()' diff --git a/Project.toml b/Project.toml index c02f679..285a8a4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,29 +1,29 @@ -name = "NMFMerge" -uuid = "9cc52eda-dfaf-4e21-aae3-9f26bed153a3" -authors = ["youdongguo <1010705897@qq.com> and contributors"] -version = "1.0.1" - -[deps] -DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -GsvdInitialization = "2ac24108-be9c-42b8-8d78-6a4f62a87e7d" -LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -NMF = "6ef6ca0d-6ad7-5ff6-b225-e928bfa0a386" -TSVD = "9449cd9e-2762-5aa3-a617-5413e99d722e" - -[compat] -DataStructures = "0.18, 0.19" -ForwardDiff = "0.10" -GsvdInitialization = "1" -LinearAlgebra = "1" -NMF = "1" -TSVD = "0.4" -Test = "1" -julia = "1.10" - -[extras] -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -NMF = "6ef6ca0d-6ad7-5ff6-b225-e928bfa0a386" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[targets] -test = ["ForwardDiff", "NMF", "Test"] +name = "NMFMerge" +uuid = "9cc52eda-dfaf-4e21-aae3-9f26bed153a3" +authors = ["youdongguo <1010705897@qq.com> and contributors"] +version = "1.1.0" + +[deps] +DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +GsvdInitialization = "2ac24108-be9c-42b8-8d78-6a4f62a87e7d" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +NMF = "6ef6ca0d-6ad7-5ff6-b225-e928bfa0a386" +TSVD = "9449cd9e-2762-5aa3-a617-5413e99d722e" + +[compat] +DataStructures = "0.18, 0.19" +ForwardDiff = "0.10" +GsvdInitialization = "1" +LinearAlgebra = "1" +NMF = "1" +TSVD = "0.4" +Test = "1" +julia = "1.10" + +[extras] +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +NMF = "6ef6ca0d-6ad7-5ff6-b225-e928bfa0a386" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[targets] +test = ["ForwardDiff", "NMF", "Test"] \ No newline at end of file diff --git a/src/NMFMerge.jl b/src/NMFMerge.jl index 21eda04..1726029 100644 --- a/src/NMFMerge.jl +++ b/src/NMFMerge.jl @@ -8,12 +8,14 @@ export nmfmerge, mergecolumns """ - result = nmfmerge(X, ncomponents; tol_final=1e-4, tol_intermediate=sqrt(tol_final), W0=nothing, H0=nothing, kwargs...) + result = nmfmerge([queuepenalty], X, ncomponents; tol_final=1e-4, tol_intermediate=sqrt(tol_final), W0=nothing, H0=nothing, kwargs...) Performs "NMF-Merge" on data matrix `X`. Arguments: +-`queuepenalty`: a function of the form `f(E, h1sq, h2sq)` that computes the penalty for merging two components, where `E` is the the merge error described in the paper, default: f(E, h1sq, h2sq)=E. h1sq and h2sq are the squared norms of the corresponding rows in H. + - `X::AbstractMatrix`: the data matrix to be factorized - `ncomponents::Pair{Int,Int}`: in the form of `n1 => n2`, merging from `n1` components to `n2`components, @@ -35,7 +37,7 @@ Keyword arguments: Other keywords arguments are passed to `NMF.nnmf`. """ -function nmfmerge(X, ncomponents::Pair{Int,Int}; tol_final=1e-4, tol_intermediate=sqrt(tol_final), W0=nothing, H0=nothing, kwargs...) +function nmfmerge(queuepenalty, X, ncomponents::Pair{Int,Int}; tol_final=1e-4, tol_intermediate=sqrt(tol_final), W0=nothing, H0=nothing, kwargs...) n1, n2 = ncomponents f = tsvd(X, n2) Un, Sn, Vn = f @@ -50,11 +52,13 @@ function nmfmerge(X, ncomponents::Pair{Int,Int}; tol_final=1e-4, tol_intermediat result_over = nnmf(X, n1; kwargs..., init=:custom, tol=tol_intermediate, W0=W_over_init, H0=H_over_init) W_over, H_over = result_over.W, result_over.H W_over_normed, H_over_normed = colnormalize(W_over, H_over) - Wmerge, Hmerge, _ = colmerge2to1pq(W_over_normed, H_over_normed, n2) + Wmerge, Hmerge, _ = colmerge2to1pq(queuepenalty, W_over_normed, H_over_normed, n2) result_renmf = nnmf(X, n2; kwargs..., init=:custom, tol=tol_final, W0=Wmerge, H0=Hmerge) return result_renmf end -nmfmerge(X, ncomponents::Integer; kwargs...) = nmfmerge(X, ncomponents+max(1, round(Int, 0.2*ncomponents)) => Int(ncomponents); kwargs...) +nmfmerge(queuepenalty, X, ncomponents::Integer; kwargs...) = nmfmerge(queuepenalty, X, ncomponents+max(1, round(Int, 0.2*ncomponents)) => Int(ncomponents); kwargs...) +nmfmerge(X, ncomponents::Pair{Int,Int}; kwargs...) = nmfmerge(mergepenalty, X, ncomponents; kwargs...) +nmfmerge(X, ncomponents::Integer; kwargs...) = nmfmerge(mergepenalty, X, ncomponents::Integer; kwargs...) function colnormalize!(W, H, p::Integer=2) nonzerocolids = Int[] @@ -79,17 +83,29 @@ Normalize the factorization so that each column satisfies `||W[:, i]||_p ≈ 1`. colnormalize(W, H, p::Integer=2) = colnormalize!(float(copy(W)), float(copy(H)), p) """ - Wmerge, Hmerge, mergeseq = colmerge2to1pq(W::AbstractArray, H::AbstractArray, n::Integer) + Wmerge, Hmerge, mergeseq = colmerge2to1pq([queuepenalty], W::AbstractArray, H::AbstractArray, n::Integer) Merge components in `W` and `H` (columns in `W` and rows in `H`) until only `n` components remain. +Arguments: + +-`queuepenalty`: The same as in `nmfmerge`. Default: f(E, h1sq, h2sq)=E. + +- `W::AbstractArray`: The basis matrix with normalized columns. + +- `H::AbstractArray`: The coefficient matrix. + +- `n::Integer`: The final number of components after merging. + +Outputs: + `Wmerge` and `Hmerge` are the merged results with `n` components. `mergeseq` is the sequence of merge pair ids (id1, id2). Values larger than the number of columns in `W` indicate the output of previous merge steps. """ -function colmerge2to1pq(S::AbstractArray, T::AbstractArray, n::Integer) +function colmerge2to1pq(queuepenalty, S::AbstractArray, T::AbstractArray, n::Integer) mrgseq = Tuple{Int, Int}[] S = let S = S # julia #15276 [S[:, j] for j in axes(S, 2)] @@ -104,7 +120,10 @@ function colmerge2to1pq(S::AbstractArray, T::AbstractArray, n::Integer) Nt = length(S) Nt >= 2 || throw(ArgumentError("Cannot do 2 to 1 merge: Matrix size smaller than 2")) Nt >= n || throw(ArgumentError("Final solution more than original size")) - pq = initialize_pq_2to1(S, T) + pq = PriorityQueue{Tuple{Int,Int},Float64}() + for id0 in length(S):-1:2 + pq = pqupdate2to1!(queuepenalty, pq, S, T, id0, 1:id0-1) + end m = Nt while m > n id0, id1 = dequeue!(pq) @@ -113,43 +132,36 @@ function colmerge2to1pq(S::AbstractArray, T::AbstractArray, n::Integer) end push!(mrgseq, (id0, id1)) S, T, id01, _ = mergecol2to1!(S, T, id0, id1); - pqupdate2to1!(pq, S, T, id01, 1:id01-1); + pqupdate2to1!(queuepenalty, pq, S, T, id01, 1:id01-1); m -= 1 end Smtx, Tmtx = reduce(hcat, filter(!isempty, S)), reduce(hcat, filter(!isempty, T))' return Smtx, Matrix(Tmtx), mrgseq end +colmerge2to1pq(S::AbstractArray, T::AbstractArray, n::Integer) = colmerge2to1pq(mergepenalty, S, T, n) -function initialize_pq_2to1(S::AbstractVector, T::AbstractVector) - err_pq = PriorityQueue{Tuple{Int, Int},Float64}() - for id0 in length(S):-1:2 - err_pq = pqupdate2to1!(err_pq, S, T, id0, 1:id0-1) - end - return err_pq -end - -function pqupdate2to1!(pq, S::AbstractVector, T::AbstractVector, id01::Integer, overlapids::AbstractRange{To}) where To +function pqupdate2to1!(queuepenalty::Function, pq, S::AbstractVector, T::AbstractVector, id01::Integer, overlapids::AbstractRange{To}) where To for id in overlapids if !isempty(S[id]) && !isempty(S[id01]) - loss = solve_remix(S, T, id, id01)[2] - enqueue!(pq, (id, id01), loss) + _, loss, _, t1sq, t2sq = solve_remix(S, T, id, id01) + enqueue!(pq, (id, id01), queuepenalty(loss, t1sq, t2sq)) end end return pq end -function solve_remix(S, T, id1, id2) +function solve_remix(S::AbstractVector, T::AbstractVector, id1::Integer, id2::Integer) τ, δ, c, h1h1, h1h2, h2h2 = build_tr_det(S, T, id1, id2) if iszero(h1h1) - return c, zero(c), (zero(c), one(c)) + return c, zero(c), (zero(c), one(c)), h1h1, h2h2 end if iszero(h2h2) - return c, zero(c), (one(c), zero(c)) + return c, zero(c), (one(c), zero(c)), h1h1, h2h2 end if iszero(c) # Check whether W1 or W2 is zero - iszero(sum(abs2, S[id1])) && return c, zero(h1h1), (zero(c), one(c)) - iszero(sum(abs2, S[id2])) && return c, zero(h2h2), (one(c), zero(c)) + iszero(sum(abs2, S[id1])) && return c, zero(h1h1), (zero(c), one(c)), h1h1, h2h2 + iszero(sum(abs2, S[id2])) && return c, zero(h2h2), (one(c), zero(c)), h1h1, h2h2 end b = sqrt(τ^2/4-δ) λ_max = τ/2+b @@ -161,7 +173,7 @@ function solve_remix(S, T, id1, id2) ξ = (h1h1-h2h2+2b)/den u = (ξ, 1)./sqrt(1+2ξ*c+ξ^2) end - return c, λ_min, u + return c, λ_min, u, h1h1, h2h2 end function build_tr_det(W::AbstractVector, H::AbstractVector, id1::Integer, id2::Integer) @@ -184,7 +196,7 @@ function mergecol2to1!(S::AbstractVector, T::AbstractVector, id0::Integer, id1:: end function mergepair(S::AbstractVector, T::AbstractVector, id1::Integer, id2::Integer) - c, loss, u, = solve_remix(S, T, id1, id2) + c, loss, u, _, _ = solve_remix(S, T, id1, id2) S12, T12 = remix_enact(S, T, id1, id2, c, u) return S12, T12, loss end @@ -227,4 +239,6 @@ function mergecolumns(W::AbstractArray, H::AbstractArray, mergeseq::AbstractArra return Smtx, Matrix(Tmtx), STstage, Err end +mergepenalty(λ_min, t1sq, t2sq) = λ_min + end diff --git a/test/runtests.jl b/test/runtests.jl index b44b266..de14b61 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -51,23 +51,22 @@ H_GT = [6 10 8 2 0 1 2 10; @test sum(abs2, W_standard - W_renmf) <= 1e-12 @test sum(abs2, H_standard - H_renmf) <= 1e-12 - X = rand(30, 20) result_1 = nmfmerge(X, 10; alg=:cd) result_2 = nmfmerge(X, 12 => 10; alg=:cd) - @test sum(abs2, result_1.W - result_2.W) <= 1e-12 - @test sum(abs2, result_1.H - result_2.H) <= 1e-12 + @test sum(abs2, result_1.W - result_2.W) <= 1e-12*sum(abs2, result_1.W) + @test sum(abs2, result_1.H - result_2.H) <= 1e-12*sum(abs2, result_1.H) result_1 = nmfmerge(X, 4; alg=:cd) result_2 = nmfmerge(X, 5 => 4; alg=:cd) - @test sum(abs2, result_1.W - result_2.W) <= 1e-12 - @test sum(abs2, result_1.H - result_2.H) <= 1e-12 + @test sum(abs2, result_1.W - result_2.W) <= 1e-12*sum(abs2, result_1.W) + @test sum(abs2, result_1.H - result_2.H) <= 1e-12*sum(abs2, result_1.H) result_1 = nmfmerge(X, 8; alg=:cd) result_2 = nmfmerge(X, 10 => 8; alg=:cd) - @test sum(abs2, result_1.W - result_2.W) <= 1e-12 - @test sum(abs2, result_1.H - result_2.H) <= 1e-12 + @test sum(abs2, result_1.W - result_2.W) <= 1e-12*sum(abs2, result_1.W) + @test sum(abs2, result_1.H - result_2.H) <= 1e-12*sum(abs2, result_1.H) end @testset "merge coefficients" begin @@ -87,7 +86,7 @@ end w = Fvecs[:,idx] τ, δ, c, h1h1, h1h2, h2h2 = NMFMerge.build_tr_det(W_v, H_v, 1, 2) - c, p, u = NMFMerge.solve_remix(W_v, H_v, 1, 2) + c, p, u, h1h1, h2h2 = NMFMerge.solve_remix(W_v, H_v, 1, 2) u = [u[1], u[2]] b = sqrt(τ^2/4-δ) λ_max = τ/2+b @@ -134,8 +133,9 @@ end idx = argmax(Fvals) w = Fvecs[:,idx] + τ, δ, c, h1h1, h1h2, h2h2 = NMFMerge.build_tr_det(W2, H2, 1, 2) - c, p, u = NMFMerge.solve_remix(W2, H2, 1, 2) + c, p, u, h1h1, h2h2 = NMFMerge.solve_remix(W2, H2, 1, 2) u = [u[1], u[2]] b = sqrt(τ^2/4-δ) λ_max = τ/2+b @@ -150,7 +150,7 @@ end @test norm(Q1*u - maximum(F.values)*Q2*u) <= 1e-10 @test norm(Q1*u - λ_max*Q2*u) <= 1e-10 - W12, H12, loss = NMFMerge.mergepair(W2, H2, 1, 2) + W12, H12, _ = NMFMerge.mergepair(W2, H2, 1, 2) Err(Hm) = sum(abs2, W12*Hm'-W1*H1) @test norm(ForwardDiff.gradient(Err, H12)) <= 1e-12 @@ -199,3 +199,23 @@ end @test H12 ≈ Hs[1] @test iszero(loss) end + +@testset "test customized merge function" begin + Ws = [rand(5) rand(5) rand(5)] + Hs = [rand(10) rand(10) rand(10)]' + Wsn, Hsn = colnormalize(Ws, Hs) + Wsn1 = [Wsn[:, j] for j in axes(Wsn, 2)] + Hsn1 = [Hsn[i, :] for i in axes(Hsn, 1)] + mergepenalty_custom(E, t1sq, t2sq) = -E + idpair_loss = [] + for id1 in 1:2, id2 in id1+1:3 + W12, H12, loss2 = NMFMerge.mergepair(Wsn1, Hsn1, id1, id2) + push!(idpair_loss, ((id1, id2), loss2)) + end + idpair_loss = sort(idpair_loss, by=x->x[2]) + merge_sequence = colmerge2to1pq(Wsn, Hsn, 1)[end] + merge_sequence_custom = colmerge2to1pq(mergepenalty_custom, Wsn, Hsn, 1)[end] + @test merge_sequence[1] == idpair_loss[1][1] + @test merge_sequence_custom[1] == idpair_loss[3][1] + +end