diff --git a/SeQuant/core/utility/permutation.hpp b/SeQuant/core/utility/permutation.hpp index 0a3ad424b1..a440ad60a4 100644 --- a/SeQuant/core/utility/permutation.hpp +++ b/SeQuant/core/utility/permutation.hpp @@ -1,17 +1,14 @@ #ifndef SEQUANT_PERMUTATION_HPP #define SEQUANT_PERMUTATION_HPP +#include #include #include -#include - #include #include #include -#include -#include -#include +#include namespace sequant { @@ -21,63 +18,58 @@ namespace sequant { /// by stacking \p v0 and \p v1 on top of each other. /// @tparam Seq0 (reference to) a container type /// @tparam Seq1 (reference to) a container type -/// @param v0 first sequence; if passed as an rvalue reference, it is moved from -/// @param[in] v1 second sequence +/// @param v0 first range +/// @param v1 second range /// @pre \p v0 is a permutation of \p v1 /// @return the number of cycles template -std::size_t count_cycles(Seq0&& v0, const Seq1& v1) { - std::remove_reference_t v(std::forward(v0)); - using T = std::decay_t; - SEQUANT_ASSERT(ranges::is_permutation(v, v1)); +std::size_t count_cycles(Seq0&& v0, Seq1&& v1) { + using std::ranges::begin; + using std::ranges::end; + using std::ranges::size; + SEQUANT_ASSERT(std::ranges::is_permutation(v0, v1)); // This function can't deal with duplicate entries in v0 or v1 - SEQUANT_ASSERT(std::set(std::begin(v0), std::end(v0)).size() == v0.size()); - SEQUANT_ASSERT(std::set(std::begin(v1), std::end(v1)).size() == v1.size()); - - auto make_null = []() -> T { - if constexpr (std::is_arithmetic_v) { - return -1; - } else if constexpr (std::is_same_v) { - return L"p_50"; - } - - SEQUANT_UNREACHABLE; - }; + SEQUANT_ASSERT(std::set(begin(v0), end(v0)).size() == size(v0)); + SEQUANT_ASSERT(std::set(begin(v1), end(v1)).size() == size(v1)); - const auto null = make_null(); - SEQUANT_ASSERT(ranges::contains(v, null) == false); - SEQUANT_ASSERT(ranges::contains(v1, null) == false); + container::svector visited; + visited.resize(size(v0), false); std::size_t n_cycles = 0; - for (auto it = v.begin(); it != v.end(); ++it) { - if (*it != null) { - n_cycles++; - - auto idx = std::distance(v.begin(), it); - SEQUANT_ASSERT(idx >= 0); - - auto it0 = it; - - auto it1 = std::find(v1.begin(), v1.end(), *it0); - SEQUANT_ASSERT(it1 != v1.end()); - - auto idx1 = std::distance(v1.begin(), it1); - SEQUANT_ASSERT(idx1 >= 0); + std::size_t start_col = 0; + for (auto it = begin(v0); it != end(v0); ++it, ++start_col) { + if (visited[start_col]) { + // This column has already been part of a previous cycle + continue; + } - do { - it0 = std::find(v.begin(), v.end(), v[idx1]); - SEQUANT_ASSERT(it0 != v.end()); + n_cycles++; - it1 = std::find(v1.begin(), v1.end(), *it0); - SEQUANT_ASSERT(it1 != v1.end()); + std::size_t current_col = 0; + auto it0 = it; + do { + // Find corresponding element in v1 + auto it1 = std::ranges::find(v1, *it0); + SEQUANT_ASSERT(std::distance(begin(v1), it1) >= 0); + + // Determine column of the determined corresponding element + current_col = static_cast(std::distance(begin(v1), it1)); + SEQUANT_ASSERT(current_col < size(v0)); + + // Set it0 to the element in the determined column in v0 + it0 = begin(v0); + std::advance(it0, current_col); + + // Mark current_col as visited + SEQUANT_ASSERT(!visited[current_col]); + visited[current_col] = true; + } while (start_col != current_col); + } - idx1 = std::distance(v1.begin(), it1); - SEQUANT_ASSERT(idx1 >= 0); + // All columns must have been visited (otherwise, we'll have missed + // at least one cycle) + SEQUANT_ASSERT(std::ranges::all_of(visited, [](bool val) { return val; })); - *it0 = null; - } while (idx1 != idx); - } - } return n_cycles; }; @@ -105,7 +97,7 @@ int permutation_parity(std::span p, bool overwrite = false) { } if (overwrite) { - ranges::for_each(p, [N](auto& e) { e -= N; }); + std::ranges::for_each(p, [N](auto& e) { e -= N; }); } return parity;