Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/stimulus-artifact-ensemble.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ on:
nested_selection_metrics:
description: Comma-separated leave-subject-out selector metrics to compare.
required: true
default: balanced_accuracy,balanced_top2_top3_rank_lcb
default: balanced_accuracy,balanced_accuracy_delta_lcb,balanced_top2_top3_rank_lcb,balanced_top2_top3_rank_delta_lcb
type: string
nested_weight_selector_ensemble:
description: >-
Expand Down
118 changes: 109 additions & 9 deletions src/pymegdec/stimulus_artifact_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,12 @@
ARTIFACT_NESTED_SELECTION_METRIC_CHOICES = (
"balanced_accuracy",
"balanced_accuracy_lcb",
"balanced_accuracy_delta",
"balanced_accuracy_delta_lcb",
"balanced_top2_top3_rank",
"balanced_top2_top3_rank_lcb",
"balanced_top2_top3_rank_delta",
"balanced_top2_top3_rank_delta_lcb",
)


Expand Down Expand Up @@ -502,9 +506,21 @@ def _normalize_artifact_nested_selection_metric(selection_metric: str) -> str:
aliases = {
"balanced": "balanced_accuracy",
"balanced_lcb": "balanced_accuracy_lcb",
"balanced_delta": "balanced_accuracy_delta",
"balanced_delta_lcb": "balanced_accuracy_delta_lcb",
"paired_balanced": "balanced_accuracy_delta",
"paired_balanced_lcb": "balanced_accuracy_delta_lcb",
"paired_balanced_delta": "balanced_accuracy_delta",
"paired_balanced_delta_lcb": "balanced_accuracy_delta_lcb",
"balanced_rank": "balanced_top2_top3_rank",
"balanced_rank_lcb": "balanced_top2_top3_rank_lcb",
"rank_lcb": "balanced_top2_top3_rank_lcb",
"balanced_rank_delta": "balanced_top2_top3_rank_delta",
"balanced_rank_delta_lcb": "balanced_top2_top3_rank_delta_lcb",
"rank_delta": "balanced_top2_top3_rank_delta",
"rank_delta_lcb": "balanced_top2_top3_rank_delta_lcb",
"paired_rank": "balanced_top2_top3_rank_delta",
"paired_rank_lcb": "balanced_top2_top3_rank_delta_lcb",
}
normalized = aliases.get(normalized, normalized)
if normalized not in ARTIFACT_NESTED_SELECTION_METRIC_CHOICES:
Expand Down Expand Up @@ -1679,17 +1695,69 @@ def _artifact_recipe_rank_score(row: dict[str, object], *, n_classes: int) -> fl
return balanced + 0.25 * (top2 - top2_chance) + 0.125 * (top3 - top3_chance) + 0.10 * rank_gain


def _nested_selection_metric_value(rows: Sequence[dict[str, object]], *, selection_metric: str, n_classes: int) -> float:
def _nested_selection_metric_base(selection_metric: str) -> str:
"""Return the non-delta metric family used for one outer-row score."""

normalized = _normalize_artifact_nested_selection_metric(selection_metric)
if normalized == "balanced_accuracy":
return _metric_mean(rows, "balanced_accuracy")
if normalized == "balanced_accuracy_lcb":
values = _outer_metric_values(rows, "balanced_accuracy")
return _mean(values) - _sem(values)
if normalized.startswith("balanced_accuracy"):
return "balanced_accuracy"
return "balanced_top2_top3_rank"


def _nested_selection_uses_lcb(selection_metric: str) -> bool:
return _normalize_artifact_nested_selection_metric(selection_metric).endswith("_lcb")


def _nested_selection_is_delta(selection_metric: str) -> bool:
return "_delta" in _normalize_artifact_nested_selection_metric(selection_metric)


def _nested_selection_row_score(row: dict[str, object], *, selection_metric: str, n_classes: int) -> float:
base = _nested_selection_metric_base(selection_metric)
if base == "balanced_accuracy":
return _to_float(row["balanced_accuracy"])
return _artifact_recipe_rank_score(row, n_classes=n_classes)

values = [_artifact_recipe_rank_score(row, n_classes=n_classes) for row in rows]

def _outer_rows_by_participant(rows: Sequence[dict[str, object]]) -> dict[str, dict[str, object]]:
return {str(row.get("test_participant", "")): row for row in rows}


def _nested_selection_metric_value(
rows: Sequence[dict[str, object]],
*,
selection_metric: str,
n_classes: int,
reference_rows: Sequence[dict[str, object]] | None = None,
) -> float:
normalized = _normalize_artifact_nested_selection_metric(selection_metric)
if _nested_selection_is_delta(normalized):
if reference_rows is None:
raise ValueError(f"Nested selection metric {normalized!r} requires reference rows.")
reference_by_participant = _outer_rows_by_participant(reference_rows)
paired_deltas: list[float] = []
for row in rows:
participant = str(row.get("test_participant", ""))
reference = reference_by_participant.get(participant)
if reference is None:
continue
paired_deltas.append(
_nested_selection_row_score(row, selection_metric=normalized, n_classes=n_classes)
- _nested_selection_row_score(reference, selection_metric=normalized, n_classes=n_classes)
)
if not paired_deltas:
return float("-inf")
score = _mean(paired_deltas)
if _nested_selection_uses_lcb(normalized):
score -= _sem(paired_deltas)
return score

values = [
_nested_selection_row_score(row, selection_metric=normalized, n_classes=n_classes)
for row in rows
]
score = _mean(values)
if normalized.endswith("_lcb"):
if _nested_selection_uses_lcb(normalized):
score -= _sem(values)
return score

Expand Down Expand Up @@ -1897,6 +1965,7 @@ def _nested_source_weight_selector(
)
if not participants:
raise ValueError("Nested source-weight selector requires test_participant values.")
reference_candidate = min(candidates, key=lambda candidate: (candidate.uniform_distance, candidate.candidate_index))

selected_predictions: list[dict] = []
selection_rows: list[dict] = []
Expand All @@ -1910,10 +1979,16 @@ def _nested_source_weight_selector(
]
if not train_outer_rows:
raise ValueError(f"Cannot select artifact source weights for participant {participant}; no source subjects remain.")
reference_train_outer_rows = [
row
for other_participant, row in reference_candidate.outer_by_participant.items()
if other_participant != participant
]
selection_score = _nested_selection_metric_value(
train_outer_rows,
selection_metric=nested_selection_metric,
n_classes=n_classes,
reference_rows=reference_train_outer_rows,
)
balanced = _metric_mean(train_outer_rows, "balanced_accuracy")
scored_candidates.append(
Expand Down Expand Up @@ -1941,6 +2016,10 @@ def _nested_source_weight_selector(
"selected_artifact_ensemble_sources": ";".join(source_names),
"selected_source_weights": weight_text,
"selected_weight_grid_step": grid_step,
"reference_source_weights": reference_candidate.weight_text,
"reference_weight_grid_step": grid_step,
"reference_weight_candidate_index": reference_candidate.candidate_index,
"selected_weight_candidate_index": selected_candidate.candidate_index,
"selection_metric": _nested_selection_metric_label(nested_selection_metric),
"selection_metric_name": nested_selection_metric,
"selection_metric_value": selected_score,
Expand All @@ -1961,6 +2040,10 @@ def _nested_source_weight_selector(
selected_row["selected_artifact_ensemble_sources"] = ";".join(source_names)
selected_row["selected_source_weights"] = weight_text
selected_row["selected_weight_grid_step"] = grid_step
selected_row["reference_source_weights"] = reference_candidate.weight_text
selected_row["reference_weight_grid_step"] = grid_step
selected_row["reference_weight_candidate_index"] = reference_candidate.candidate_index
selected_row["selected_weight_candidate_index"] = selected_candidate.candidate_index
selected_row["selection_metric"] = _nested_selection_metric_label(nested_selection_metric)
selected_row["selection_metric_name"] = nested_selection_metric
selected_row["selection_metric_value"] = selected_score
Expand All @@ -1983,6 +2066,7 @@ def _nested_source_weight_selector(
summary["selected_source_weight_counts"] = _counts_text(str(row["selected_source_weights"]) for row in selection_rows)
summary["candidate_source_weight_count"] = len(candidates)
summary["selected_weight_grid_step"] = grid_step
summary["reference_source_weights"] = reference_candidate.weight_text
return selected_predictions, outer_rows, selection_rows, summary


Expand Down Expand Up @@ -2069,6 +2153,12 @@ def _nested_subject_selector(
)
if not participants:
raise ValueError("Nested subject selector requires test_participant values.")
reference_ensemble = ensemble_order[0]
if reference_ensemble not in outer_by_ensemble_participant:
raise ValueError(
f"Nested subject selector reference ensemble {reference_ensemble!r} has no outer rows."
)
reference_outer_by_participant = outer_by_ensemble_participant[reference_ensemble]

selected_predictions: list[dict] = []
selection_rows: list[dict] = []
Expand All @@ -2082,7 +2172,14 @@ def _nested_subject_selector(
]
if not train_outer_rows:
raise ValueError(f"Cannot select an artifact ensemble for participant {participant}; no source subjects remain.")
selection_score = _nested_selection_metric_value(train_outer_rows, selection_metric=nested_selection_metric, n_classes=n_classes)
reference_train_outer_rows = [
row
for other_participant, row in reference_outer_by_participant.items()
if other_participant != participant
]
selection_score = _nested_selection_metric_value(
train_outer_rows, selection_metric=nested_selection_metric, n_classes=n_classes, reference_rows=reference_train_outer_rows
)
candidates.append((selection_score, -ensemble_index, ensemble, train_outer_rows))

selected_score, _, selected_ensemble, train_outer_rows = max(candidates)
Expand All @@ -2097,6 +2194,7 @@ def _nested_subject_selector(
"artifact_ensemble": selector_name,
"selected_artifact_ensemble": selected_ensemble,
"selected_artifact_ensemble_sources": ";".join(selected_sources),
"reference_artifact_ensemble": reference_ensemble,
"selection_metric": _nested_selection_metric_label(nested_selection_metric),
"selection_metric_name": nested_selection_metric,
"selection_metric_value": selected_score,
Expand All @@ -2115,6 +2213,7 @@ def _nested_subject_selector(
selected_row["artifact_ensemble_recipe_selection"] = "leave_subject_out"
selected_row["selected_artifact_ensemble"] = selected_ensemble
selected_row["selected_artifact_ensemble_sources"] = ";".join(selected_sources)
selected_row["reference_artifact_ensemble"] = reference_ensemble
selected_row["selection_metric"] = _nested_selection_metric_label(nested_selection_metric)
selected_row["selection_metric_name"] = nested_selection_metric
selected_row["selection_metric_value"] = selected_score
Expand All @@ -2135,6 +2234,7 @@ def _nested_subject_selector(
summary["selected_artifact_ensemble_counts"] = _counts_text(
str(row["selected_artifact_ensemble"]) for row in selection_rows
)
summary["reference_artifact_ensemble"] = reference_ensemble
return selected_predictions, outer_rows, selection_rows, summary


Expand Down
Loading