Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 44 additions & 52 deletions SeQuant/core/utility/permutation.hpp
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
#ifndef SEQUANT_PERMUTATION_HPP
#define SEQUANT_PERMUTATION_HPP

#include <SeQuant/core/container.hpp>
#include <SeQuant/core/index.hpp>
#include <SeQuant/core/utility/macros.hpp>

#include <range/v3/algorithm.hpp>

#include <algorithm>
#include <cstddef>
#include <cstdlib>
#include <set>
#include <type_traits>
#include <utility>
#include <ranges>

namespace sequant {

Expand All @@ -21,63 +18,58 @@ namespace sequant {
/// by stacking \p v0 and \p v1 on top of each other.
/// @tparam Seq0 (reference to) a container type
/// @tparam Seq1 (reference to) a container type
/// @param v0 first sequence; if passed as an rvalue reference, it is moved from
/// @param[in] v1 second sequence
/// @param v0 first range
/// @param v1 second range
/// @pre \p v0 is a permutation of \p v1
/// @return the number of cycles
template <typename Seq0, typename Seq1>
std::size_t count_cycles(Seq0&& v0, const Seq1& v1) {
std::remove_reference_t<Seq0> v(std::forward<Seq0>(v0));
using T = std::decay_t<decltype(v[0])>;
SEQUANT_ASSERT(ranges::is_permutation(v, v1));
std::size_t count_cycles(Seq0&& v0, Seq1&& v1) {
using std::ranges::begin;
using std::ranges::end;
using std::ranges::size;
SEQUANT_ASSERT(std::ranges::is_permutation(v0, v1));
// This function can't deal with duplicate entries in v0 or v1
SEQUANT_ASSERT(std::set(std::begin(v0), std::end(v0)).size() == v0.size());
SEQUANT_ASSERT(std::set(std::begin(v1), std::end(v1)).size() == v1.size());

auto make_null = []() -> T {
if constexpr (std::is_arithmetic_v<T>) {
return -1;
} else if constexpr (std::is_same_v<T, Index>) {
return L"p_50";
}

SEQUANT_UNREACHABLE;
};
SEQUANT_ASSERT(std::set(begin(v0), end(v0)).size() == size(v0));
SEQUANT_ASSERT(std::set(begin(v1), end(v1)).size() == size(v1));

const auto null = make_null();
SEQUANT_ASSERT(ranges::contains(v, null) == false);
SEQUANT_ASSERT(ranges::contains(v1, null) == false);
container::svector<bool> visited;
visited.resize(size(v0), false);

std::size_t n_cycles = 0;
for (auto it = v.begin(); it != v.end(); ++it) {
if (*it != null) {
n_cycles++;

auto idx = std::distance(v.begin(), it);
SEQUANT_ASSERT(idx >= 0);

auto it0 = it;

auto it1 = std::find(v1.begin(), v1.end(), *it0);
SEQUANT_ASSERT(it1 != v1.end());

auto idx1 = std::distance(v1.begin(), it1);
SEQUANT_ASSERT(idx1 >= 0);
std::size_t start_col = 0;
for (auto it = begin(v0); it != end(v0); ++it, ++start_col) {
if (visited[start_col]) {
// This column has already been part of a previous cycle
continue;
}

do {
it0 = std::find(v.begin(), v.end(), v[idx1]);
SEQUANT_ASSERT(it0 != v.end());
n_cycles++;

it1 = std::find(v1.begin(), v1.end(), *it0);
SEQUANT_ASSERT(it1 != v1.end());
std::size_t current_col = 0;
auto it0 = it;
do {
// Find corresponding element in v1
auto it1 = std::ranges::find(v1, *it0);
SEQUANT_ASSERT(std::distance(begin(v1), it1) >= 0);

// Determine column of the determined corresponding element
current_col = static_cast<std::size_t>(std::distance(begin(v1), it1));
SEQUANT_ASSERT(current_col < size(v0));

// Set it0 to the element in the determined column in v0
it0 = begin(v0);
std::advance(it0, current_col);

// Mark current_col as visited
SEQUANT_ASSERT(!visited[current_col]);
visited[current_col] = true;
} while (start_col != current_col);
}

idx1 = std::distance(v1.begin(), it1);
SEQUANT_ASSERT(idx1 >= 0);
// All columns must have been visited (otherwise, we'll have missed
// at least one cycle)
SEQUANT_ASSERT(std::ranges::all_of(visited, [](bool val) { return val; }));

*it0 = null;
} while (idx1 != idx);
}
}
return n_cycles;
};

Expand Down Expand Up @@ -105,7 +97,7 @@ int permutation_parity(std::span<T> p, bool overwrite = false) {
}

if (overwrite) {
ranges::for_each(p, [N](auto& e) { e -= N; });
std::ranges::for_each(p, [N](auto& e) { e -= N; });
}

return parity;
Expand Down