From 2267236f5485afde28ef153d3a20ca3f330cca9e Mon Sep 17 00:00:00 2001 From: Florian Pfaff Date: Sat, 6 Jun 2026 22:57:05 +0200 Subject: [PATCH] Add all-source nested weight selector --- .../workflows/stimulus-artifact-ensemble.yml | 7 +- src/pymegdec/stimulus_artifact_ensemble.py | 99 ++++++++++++++----- tests/test_stimulus_artifact_ensemble.py | 50 ++++++++++ 3 files changed, 132 insertions(+), 24 deletions(-) diff --git a/.github/workflows/stimulus-artifact-ensemble.yml b/.github/workflows/stimulus-artifact-ensemble.yml index bf32b58..044c6a6 100644 --- a/.github/workflows/stimulus-artifact-ensemble.yml +++ b/.github/workflows/stimulus-artifact-ensemble.yml @@ -44,9 +44,12 @@ on: default: balanced_accuracy,balanced_top2_top3_rank_lcb type: string nested_weight_selector_ensemble: - description: Multi-source ensemble to optimize with leakage-safe leave-subject-out source-weight selection. + description: >- + Multi-source ensemble(s) to optimize with leakage-safe + leave-subject-out source-weight selection. Use a comma-separated + list or 'all'. required: true - default: compact_small_finetune_w150_pca128 + default: all type: string nested_weight_grid_step: description: Simplex grid step for nested source-weight selection. diff --git a/src/pymegdec/stimulus_artifact_ensemble.py b/src/pymegdec/stimulus_artifact_ensemble.py index 9d97f1b..828f44c 100644 --- a/src/pymegdec/stimulus_artifact_ensemble.py +++ b/src/pymegdec/stimulus_artifact_ensemble.py @@ -1494,6 +1494,52 @@ def _counts_text(values: Iterable[str]) -> str: return ";".join(f"{value}:{counts[value]}" for value in sorted(counts, key=_participant_sort_key)) +def _resolve_nested_weight_selector_ensembles( + ensemble_sources: dict[str, Sequence[str]], + raw_ensemble_spec: str | None, +) -> list[str]: + """Return multi-source ensembles requested for nested weight selection. + + The historical behavior selected the first multi-source recipe when no + explicit ensemble was passed. The new ``all``/comma-list syntax lets the + artifact workflow evaluate leakage-safe source-weight grids for every + promising recipe in one invocation, without changing the old default API + behavior for programmatic callers. + """ + + multi_source_ensembles = [ + name + for name, source_names in ensemble_sources.items() + if len(source_names) > 1 + ] + if not multi_source_ensembles: + raise ValueError("Nested weight selection requires at least one multi-source ensemble.") + + raw = "" if raw_ensemble_spec is None else str(raw_ensemble_spec).strip() + if raw == "": + return [multi_source_ensembles[0]] + + requested = [token.strip() for token in raw.split(",") if token.strip()] + if any(token.lower() in {"all", "*"} for token in requested): + if len(requested) != 1: + raise ValueError( + "Use 'all' by itself for nested source-weight selection, " + "not mixed with explicit ensemble names." + ) + return multi_source_ensembles + + missing = [name for name in requested if name not in ensemble_sources] + if missing: + raise ValueError(f"Unknown nested weight-selector ensemble(s): {', '.join(missing)}") + single_source = [name for name in requested if len(ensemble_sources[name]) < 2] + if single_source: + raise ValueError( + "Nested weight selection requires multi-source ensembles; " + f"got single-source {', '.join(single_source)}" + ) + return list(dict.fromkeys(requested)) + + @dataclass(frozen=True) class WeightCandidate: candidate_index: int @@ -1989,28 +2035,33 @@ def ensemble_prediction_sources( artifacts["group_summary"].append(nested_summary) artifacts["nested_selection"] = nested_selection if nested_weight_selector_name: - if nested_weight_selector_ensemble is None: - multi_source_ensembles = [name for name in ensemble_sources if len(ensemble_sources[name]) > 1] - if not multi_source_ensembles: - raise ValueError("Nested weight selection requires at least one multi-source ensemble.") - nested_weight_selector_ensemble = multi_source_ensembles[0] - weight_predictions, weight_outer, weight_selection, weight_summary = _nested_source_weight_selector( - selector_name=nested_weight_selector_name, - selector_ensemble=nested_weight_selector_ensemble, - ensemble_sources=ensemble_sources, - indexed_sources=indexed_sources, - key_columns=key_columns, - class_labels=class_labels, - n_classes=len(class_labels), - score_normalization=score_normalization, - aggregation_mode=aggregation_mode, - grid_step=nested_weight_grid_step, - nested_selection_metric=nested_selection_metric, + selector_ensembles = _resolve_nested_weight_selector_ensembles( + ensemble_sources, + nested_weight_selector_ensemble, ) - artifacts["predictions"].extend(weight_predictions) - artifacts["outer"].extend(weight_outer) - artifacts["group_summary"].append(weight_summary) - artifacts["nested_weight_selection"] = weight_selection + nested_weight_selection_rows: list[dict] = [] + for selector_ensemble in selector_ensembles: + selector_name = nested_weight_selector_name + if len(selector_ensembles) > 1: + selector_name = f"{nested_weight_selector_name}_{selector_ensemble}" + weight_predictions, weight_outer, weight_selection, weight_summary = _nested_source_weight_selector( + selector_name=selector_name, + selector_ensemble=selector_ensemble, + ensemble_sources=ensemble_sources, + indexed_sources=indexed_sources, + key_columns=key_columns, + class_labels=class_labels, + n_classes=len(class_labels), + score_normalization=score_normalization, + aggregation_mode=aggregation_mode, + grid_step=nested_weight_grid_step, + nested_selection_metric=nested_selection_metric, + ) + artifacts["predictions"].extend(weight_predictions) + artifacts["outer"].extend(weight_outer) + artifacts["group_summary"].append(weight_summary) + nested_weight_selection_rows.extend(weight_selection) + artifacts["nested_weight_selection"] = nested_weight_selection_rows return artifacts @@ -2093,7 +2144,11 @@ def main(argv: list[str] | None = None) -> int: ) parser.add_argument( "--nested-weight-selector-ensemble", - help="Ensemble whose source weights should be grid-selected by --nested-weight-selector-name. Defaults to the first multi-source ensemble.", + help=( + "Ensemble(s) whose source weights should be grid-selected by " + "--nested-weight-selector-name. Use a comma-separated list or 'all'. " + "Defaults to the first multi-source ensemble." + ), ) parser.add_argument( "--nested-weight-grid-step", diff --git a/tests/test_stimulus_artifact_ensemble.py b/tests/test_stimulus_artifact_ensemble.py index 1e1b810..3d90607 100644 --- a/tests/test_stimulus_artifact_ensemble.py +++ b/tests/test_stimulus_artifact_ensemble.py @@ -751,6 +751,56 @@ def test_nested_weight_selector_can_use_rank_aware_metric(self) -> None: participant_1 = next(row for row in nested_predictions if row["test_participant"] == "1") self.assertEqual(participant_1["predicted_label"], 1) + def test_nested_weight_selector_can_expand_all_multi_source_ensembles(self) -> None: + source_a = _source( + "source_a", + [ + _participant_scored_row(1, 0, 1, 0.10, 0.90), + _participant_scored_row(2, 0, 0, 0.90, 0.10), + _participant_scored_row(3, 0, 0, 0.90, 0.10), + ], + ) + source_b = _source( + "source_b", + [ + _participant_scored_row(1, 0, 0, 0.90, 0.10), + _participant_scored_row(2, 0, 1, 0.10, 0.90), + _participant_scored_row(3, 0, 1, 0.10, 0.90), + ], + ) + source_c = _source( + "source_c", + [ + _participant_scored_row(1, 0, 0, 0.85, 0.15), + _participant_scored_row(2, 0, 1, 0.15, 0.85), + _participant_scored_row(3, 0, 1, 0.15, 0.85), + ], + ) + + artifacts = ensemble_prediction_sources( + [source_a, source_b, source_c], + [ + ("source_a_b", ("source_a", "source_b")), + ("source_a_c", ("source_a", "source_c")), + ], + aggregation_mode="mean_score", + nested_weight_selector_name="nested_weight_selector", + nested_weight_selector_ensemble="all", + nested_weight_grid_step=1.0, + ) + + selector_names = { + row["artifact_ensemble"] + for row in artifacts["group_summary"] + if str(row["artifact_ensemble"]).startswith("nested_weight_selector") + } + self.assertEqual( + selector_names, + {"nested_weight_selector_source_a_b", "nested_weight_selector_source_a_c"}, + ) + selection_names = {row["artifact_ensemble"] for row in artifacts["nested_weight_selection"]} + self.assertEqual(selection_names, selector_names) + if __name__ == "__main__": unittest.main()