Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions .github/workflows/stimulus-artifact-ensemble.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,12 @@ on:
default: balanced_accuracy,balanced_top2_top3_rank_lcb
type: string
nested_weight_selector_ensemble:
description: Multi-source ensemble to optimize with leakage-safe leave-subject-out source-weight selection.
description: >-
Multi-source ensemble(s) to optimize with leakage-safe
leave-subject-out source-weight selection. Use a comma-separated
list or 'all'.
required: true
default: compact_small_finetune_w150_pca128
default: all
type: string
nested_weight_grid_step:
description: Simplex grid step for nested source-weight selection.
Expand Down
99 changes: 77 additions & 22 deletions src/pymegdec/stimulus_artifact_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -1699,6 +1699,52 @@ def _counts_text(values: Iterable[str]) -> str:
return ";".join(f"{value}:{counts[value]}" for value in sorted(counts, key=_participant_sort_key))


def _resolve_nested_weight_selector_ensembles(
ensemble_sources: dict[str, Sequence[str]],
raw_ensemble_spec: str | None,
) -> list[str]:
"""Return multi-source ensembles requested for nested weight selection.

The historical behavior selected the first multi-source recipe when no
explicit ensemble was passed. The new ``all``/comma-list syntax lets the
artifact workflow evaluate leakage-safe source-weight grids for every
promising recipe in one invocation, without changing the old default API
behavior for programmatic callers.
"""

multi_source_ensembles = [
name
for name, source_names in ensemble_sources.items()
if len(source_names) > 1
]
if not multi_source_ensembles:
raise ValueError("Nested weight selection requires at least one multi-source ensemble.")

raw = "" if raw_ensemble_spec is None else str(raw_ensemble_spec).strip()
if raw == "":
return [multi_source_ensembles[0]]

requested = [token.strip() for token in raw.split(",") if token.strip()]
if any(token.lower() in {"all", "*"} for token in requested):
if len(requested) != 1:
raise ValueError(
"Use 'all' by itself for nested source-weight selection, "
"not mixed with explicit ensemble names."
)
return multi_source_ensembles

missing = [name for name in requested if name not in ensemble_sources]
if missing:
raise ValueError(f"Unknown nested weight-selector ensemble(s): {', '.join(missing)}")
single_source = [name for name in requested if len(ensemble_sources[name]) < 2]
if single_source:
raise ValueError(
"Nested weight selection requires multi-source ensembles; "
f"got single-source {', '.join(single_source)}"
)
return list(dict.fromkeys(requested))


@dataclass(frozen=True)
class WeightCandidate:
candidate_index: int
Expand Down Expand Up @@ -2206,28 +2252,33 @@ def ensemble_prediction_sources(
artifacts["group_summary"].append(nested_summary)
artifacts["nested_selection"] = nested_selection
if nested_weight_selector_name:
if nested_weight_selector_ensemble is None:
multi_source_ensembles = [name for name in ensemble_sources if len(ensemble_sources[name]) > 1]
if not multi_source_ensembles:
raise ValueError("Nested weight selection requires at least one multi-source ensemble.")
nested_weight_selector_ensemble = multi_source_ensembles[0]
weight_predictions, weight_outer, weight_selection, weight_summary = _nested_source_weight_selector(
selector_name=nested_weight_selector_name,
selector_ensemble=nested_weight_selector_ensemble,
ensemble_sources=ensemble_sources,
indexed_sources=indexed_sources,
key_columns=key_columns,
class_labels=class_labels,
n_classes=len(class_labels),
score_normalization=score_normalization,
aggregation_mode=aggregation_mode,
grid_step=nested_weight_grid_step,
nested_selection_metric=nested_selection_metric,
selector_ensembles = _resolve_nested_weight_selector_ensembles(
ensemble_sources,
nested_weight_selector_ensemble,
)
artifacts["predictions"].extend(weight_predictions)
artifacts["outer"].extend(weight_outer)
artifacts["group_summary"].append(weight_summary)
artifacts["nested_weight_selection"] = weight_selection
nested_weight_selection_rows: list[dict] = []
for selector_ensemble in selector_ensembles:
selector_name = nested_weight_selector_name
if len(selector_ensembles) > 1:
selector_name = f"{nested_weight_selector_name}_{selector_ensemble}"
weight_predictions, weight_outer, weight_selection, weight_summary = _nested_source_weight_selector(
selector_name=selector_name,
selector_ensemble=selector_ensemble,
ensemble_sources=ensemble_sources,
indexed_sources=indexed_sources,
key_columns=key_columns,
class_labels=class_labels,
n_classes=len(class_labels),
score_normalization=score_normalization,
aggregation_mode=aggregation_mode,
grid_step=nested_weight_grid_step,
nested_selection_metric=nested_selection_metric,
)
artifacts["predictions"].extend(weight_predictions)
artifacts["outer"].extend(weight_outer)
artifacts["group_summary"].append(weight_summary)
nested_weight_selection_rows.extend(weight_selection)
artifacts["nested_weight_selection"] = nested_weight_selection_rows
return artifacts


Expand Down Expand Up @@ -2310,7 +2361,11 @@ def main(argv: list[str] | None = None) -> int:
)
parser.add_argument(
"--nested-weight-selector-ensemble",
help="Ensemble whose source weights should be grid-selected by --nested-weight-selector-name. Defaults to the first multi-source ensemble.",
help=(
"Ensemble(s) whose source weights should be grid-selected by "
"--nested-weight-selector-name. Use a comma-separated list or 'all'. "
"Defaults to the first multi-source ensemble."
),
)
parser.add_argument(
"--nested-weight-grid-step",
Expand Down
50 changes: 50 additions & 0 deletions tests/test_stimulus_artifact_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,6 +835,56 @@ def test_nested_weight_selector_can_use_rank_aware_metric(self) -> None:
participant_1 = next(row for row in nested_predictions if row["test_participant"] == "1")
self.assertEqual(participant_1["predicted_label"], 1)

def test_nested_weight_selector_can_expand_all_multi_source_ensembles(self) -> None:
source_a = _source(
"source_a",
[
_participant_scored_row(1, 0, 1, 0.10, 0.90),
_participant_scored_row(2, 0, 0, 0.90, 0.10),
_participant_scored_row(3, 0, 0, 0.90, 0.10),
],
)
source_b = _source(
"source_b",
[
_participant_scored_row(1, 0, 0, 0.90, 0.10),
_participant_scored_row(2, 0, 1, 0.10, 0.90),
_participant_scored_row(3, 0, 1, 0.10, 0.90),
],
)
source_c = _source(
"source_c",
[
_participant_scored_row(1, 0, 0, 0.85, 0.15),
_participant_scored_row(2, 0, 1, 0.15, 0.85),
_participant_scored_row(3, 0, 1, 0.15, 0.85),
],
)

artifacts = ensemble_prediction_sources(
[source_a, source_b, source_c],
[
("source_a_b", ("source_a", "source_b")),
("source_a_c", ("source_a", "source_c")),
],
aggregation_mode="mean_score",
nested_weight_selector_name="nested_weight_selector",
nested_weight_selector_ensemble="all",
nested_weight_grid_step=1.0,
)

selector_names = {
row["artifact_ensemble"]
for row in artifacts["group_summary"]
if str(row["artifact_ensemble"]).startswith("nested_weight_selector")
}
self.assertEqual(
selector_names,
{"nested_weight_selector_source_a_b", "nested_weight_selector_source_a_c"},
)
selection_names = {row["artifact_ensemble"] for row in artifacts["nested_weight_selection"]}
self.assertEqual(selection_names, selector_names)


if __name__ == "__main__":
unittest.main()