Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/stimulus-artifact-ensemble.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ on:
aggregation_modes:
description: Comma-separated artifact aggregation modes to compare.
required: true
default: auto,hard_vote,mean_score,confidence_weighted_mean_score,entropy_weighted_mean_score,agreement_weighted_mean_score,log_score_mean,score_rank_fusion,reciprocal_rank_fusion,mean_rank,borda,score_tiebreak_first_source,balanced_assignment,balanced_assignment_shrink25,balanced_assignment_shrink50,balanced_assignment_shrink75,uniform_prior_shift,uniform_prior_shift_shrink25,uniform_prior_shift_shrink50,uniform_prior_shift_shrink75
default: auto,hard_vote,mean_score,confidence_weighted_mean_score,entropy_weighted_mean_score,agreement_weighted_mean_score,log_score_mean,score_rank_fusion,reciprocal_rank_fusion,mean_rank,borda,score_tiebreak_first_source,balanced_assignment,balanced_assignment_shrink25,balanced_assignment_shrink50,balanced_assignment_shrink75,uniform_prior_shift,uniform_prior_shift_shrink25,uniform_prior_shift_shrink50,uniform_prior_shift_shrink75,prior_corrected_mean_score,prior_corrected_mean_score_shrink25,prior_corrected_mean_score_shrink50,prior_corrected_mean_score_shrink75
type: string
score_normalizations:
description: Comma-separated per-source score normalizations to compare before score aggregation.
Expand Down
158 changes: 157 additions & 1 deletion src/pymegdec/stimulus_artifact_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@
"uniform_prior_shift_shrink25",
"uniform_prior_shift_shrink50",
"uniform_prior_shift_shrink75",
"prior_corrected_mean_score",
"prior_corrected_mean_score_shrink25",
"prior_corrected_mean_score_shrink50",
"prior_corrected_mean_score_shrink75",
)
ARTIFACT_NESTED_SELECTION_METRIC_CHOICES = (
"balanced_accuracy",
Expand Down Expand Up @@ -487,6 +491,18 @@ def _normalize_artifact_aggregation_mode(aggregation_mode: str) -> str:
"uniform_prior_shift_50": "uniform_prior_shift_shrink50",
"uniform_prior_shift_75": "uniform_prior_shift_shrink75",
"prior_shift_shrink50": "uniform_prior_shift_shrink50",
"prior_corrected": "prior_corrected_mean_score",
"prior_correction": "prior_corrected_mean_score",
"prior_corrected_score": "prior_corrected_mean_score",
"prior_corrected_score_mean": "prior_corrected_mean_score",
"prior_corrected_mean": "prior_corrected_mean_score",
"prior_corrected_mean_score_shrinkage": "prior_corrected_mean_score_shrink50",
"prior_corrected_mean_score_shrinkage_25": "prior_corrected_mean_score_shrink25",
"prior_corrected_mean_score_shrinkage_50": "prior_corrected_mean_score_shrink50",
"prior_corrected_mean_score_shrinkage_75": "prior_corrected_mean_score_shrink75",
"prior_corrected_score_shrink25": "prior_corrected_mean_score_shrink25",
"prior_corrected_score_shrink50": "prior_corrected_mean_score_shrink50",
"prior_corrected_score_shrink75": "prior_corrected_mean_score_shrink75",
}
normalized = aliases.get(normalized, normalized)
if normalized not in ARTIFACT_AGGREGATION_MODE_CHOICES:
Expand Down Expand Up @@ -527,6 +543,12 @@ def _is_uniform_prior_shift_mode(aggregation_mode: str) -> bool:
return _normalize_artifact_aggregation_mode(aggregation_mode).startswith("uniform_prior_shift")


def _is_prior_corrected_mode(aggregation_mode: str) -> bool:
    """Tell whether *aggregation_mode* belongs to the prior-corrected mean-score family.

    The mode is first passed through ``_normalize_artifact_aggregation_mode``;
    the result counts as prior-corrected when it begins with
    ``"prior_corrected_mean_score"``, which covers the plain mode and its
    ``_shrinkNN`` variants.
    """

    canonical_mode = _normalize_artifact_aggregation_mode(aggregation_mode)
    return canonical_mode.startswith("prior_corrected_mean_score")


def _balanced_assignment_uniform_alpha(aggregation_mode: str) -> float:
"""Return the assignment quota shrinkage toward a uniform class prior."""

Expand Down Expand Up @@ -568,6 +590,19 @@ def _uniform_prior_shift_alpha(aggregation_mode: str) -> float:
return 1.0


def _prior_corrected_alpha(aggregation_mode: str) -> float:
    """Map a prior-corrected aggregation mode to its correction strength.

    The three ``_shrinkNN`` variants map to 0.25, 0.50 and 0.75 respectively.
    Every other normalized mode (including the plain
    ``prior_corrected_mean_score``) maps to the full strength of 1.0.
    """

    shrink_strength_by_mode = {
        "prior_corrected_mean_score_shrink25": 0.25,
        "prior_corrected_mean_score_shrink50": 0.50,
        "prior_corrected_mean_score_shrink75": 0.75,
    }
    canonical_mode = _normalize_artifact_aggregation_mode(aggregation_mode)
    # Anything that is not an explicit shrink variant applies the full correction.
    return shrink_strength_by_mode.get(canonical_mode, 1.0)


def _common_value_columns(rows: Sequence[dict[str, str]], pattern: re.Pattern[str], *, offset: int = 0) -> dict[int, str]:
"""Return value columns common to all rows for one regex pattern."""

Expand Down Expand Up @@ -857,7 +892,9 @@ def _rank_labels_by_scores(
*(
mode
for mode in ARTIFACT_AGGREGATION_MODE_CHOICES
if _is_balanced_assignment_mode(mode) or _is_uniform_prior_shift_mode(mode)
if _is_balanced_assignment_mode(mode)
or _is_uniform_prior_shift_mode(mode)
or _is_prior_corrected_mode(mode)
),
}
aggregated = None
Expand Down Expand Up @@ -1612,6 +1649,113 @@ def _apply_uniform_prior_shift_rows(
return rows


def _probability_vector_from_row_scores(row: dict[str, object], class_labels: Sequence[int]) -> list[float]:
    """Build a normalized per-class vector from a row's ``artifact_score_class_*`` cells.

    Each label in *class_labels* is looked up under the column
    ``artifact_score_class_<label>``. When every parsed score is non-negative
    and their sum is positive and finite, the scores are divided by that sum.
    In every other case the scores are handed to ``_softmax``.

    Raises:
        ValueError: if a parsed score is not finite, or if any class column is
            absent or blank (the message lists up to five such columns).
    """

    scores: list[float] = []
    blank_columns: list[str] = []
    for label in class_labels:
        column = f"artifact_score_class_{label}"
        cell_text = str(row.get(column, "")).strip()
        if not cell_text:
            # Keep scanning so the error below can report several missing columns at once.
            blank_columns.append(column)
            continue
        score = _to_float(cell_text)
        if not math.isfinite(score):
            raise ValueError(f"prior_corrected_mean_score requires finite class scores; got {column}={cell_text!r}.")
        scores.append(score)

    if blank_columns:
        raise ValueError(
            "prior_corrected_mean_score requires class score/probability columns for every class; "
            f"missing examples={blank_columns[:5]}."
        )

    # Scores are all finite here, so "any negative" is the exact complement of "all >= 0".
    if any(score < 0.0 for score in scores):
        return _softmax(scores)

    score_sum = sum(scores)
    if score_sum > 0.0 and math.isfinite(score_sum):
        return [score / score_sum for score in scores]
    # Non-negative scores whose sum is zero or overflowed: fall back to softmax.
    return _softmax(scores)


def _apply_prior_corrected_rows(
    prediction_rows: Sequence[dict[str, object]],
    class_labels: Sequence[int],
    *,
    correction_alpha: float = 1.0,
) -> list[dict[str, object]]:
    """Re-weight class scores per participant toward a uniform class prior.

    Rows are shallow-copied and grouped by their ``test_participant`` value.
    Within each group the mean probability mass per class is measured, and
    every class probability is multiplied by
    ``(1 - alpha) + alpha * uniform_prior / max(mass, 1e-12)``. The re-weighted
    values are renormalized, labels are re-ranked (ties broken by the smaller
    label), and the prediction, correctness, ensemble-mode and rank columns of
    each copied row are rewritten.

    Args:
        prediction_rows: Rows carrying ``true_label`` and the per-class
            ``artifact_score_class_*`` columns.
        class_labels: Class labels; each is coerced with ``int``.
        correction_alpha: Correction strength, clamped into ``[0, 1]``.

    Returns:
        The copied rows. When *class_labels* is empty they are returned
        without any correction applied.
    """

    corrected_rows = [dict(source_row) for source_row in prediction_rows]
    label_list = [int(label) for label in class_labels]
    if not label_list:
        return corrected_rows

    correction_alpha = min(max(float(correction_alpha), 0.0), 1.0)
    display_labels = _display_label_map(label_list)
    class_count = len(label_list)
    uniform_prior = 1.0 / class_count

    # Group row positions by participant, keeping first-seen participant order.
    positions_by_participant: dict[str, list[int]] = {}
    for position, candidate_row in enumerate(corrected_rows):
        participant_key = str(candidate_row.get("test_participant", ""))
        positions_by_participant.setdefault(participant_key, []).append(position)

    for positions in positions_by_participant.values():
        participant_vectors = [
            _probability_vector_from_row_scores(corrected_rows[position], label_list)
            for position in positions
        ]
        if not participant_vectors:
            continue

        # Mean probability mass each class receives across this participant's rows.
        vector_count = len(participant_vectors)
        predicted_mass = []
        for class_index in range(class_count):
            class_total = sum(vector[class_index] for vector in participant_vectors)
            predicted_mass.append(class_total / vector_count)

        # The 1e-12 floor keeps the ratio finite for classes with no mass.
        correction_factors = [
            (1.0 - correction_alpha) + correction_alpha * uniform_prior / max(mass, 1e-12)
            for mass in predicted_mass
        ]
        mass_text = ";".join(
            f"{label}:{mass:.6g}"
            for label, mass in zip(label_list, predicted_mass, strict=True)
        )
        if correction_alpha >= 1.0 - 1e-12:
            mode_suffix = ""
        else:
            mode_suffix = f"_shrink{int(round(100.0 * correction_alpha)):02d}"
        mode_name = f"class_score_prior_corrected_mean{mode_suffix}"

        for position, probabilities in zip(positions, participant_vectors, strict=True):
            row = corrected_rows[position]
            reweighted = [
                max(0.0, probability) * factor
                for probability, factor in zip(probabilities, correction_factors, strict=True)
            ]
            normalizer = sum(reweighted)
            if not (normalizer > 0.0 and math.isfinite(normalizer)):
                # Degenerate re-weighting: fall back to the uniform distribution.
                reweighted = [uniform_prior] * class_count
                normalizer = 1.0
            corrected_scores = dict(
                zip(label_list, (value / normalizer for value in reweighted), strict=True)
            )

            def _rank_key(label: int) -> tuple[float, int]:
                # Highest corrected score first; smaller label wins ties.
                return (-corrected_scores[label], label)

            ranked_labels = sorted(label_list, key=_rank_key)
            predicted_label = int(ranked_labels[0])
            true_label = _to_int(row["true_label"], field="true_label")
            row.update(
                {
                    "predicted_label": predicted_label,
                    "predicted_stimulus": display_labels.get(predicted_label, predicted_label),
                    "correct": predicted_label == true_label,
                    "artifact_ensemble_mode": mode_name,
                    "artifact_ensemble_prior_correction_alpha": f"{correction_alpha:.6g}",
                    "artifact_ensemble_prior_correction_class_mass": mass_text,
                    "artifact_ensemble_rank_source": mode_name,
                }
            )
            _add_score_alias_columns(
                row,
                scores=corrected_scores,
                class_labels=label_list,
                display_labels=display_labels,
            )
            _update_rank_metrics_from_labels(
                row,
                ranked_labels=ranked_labels,
                true_label=true_label,
                class_labels=label_list,
                display_labels=display_labels,
            )
    return corrected_rows


def _outer_rows(ensemble_name: str, prediction_rows: Sequence[dict[str, object]], *, n_classes: int) -> list[dict[str, object]]:
rows: list[dict[str, object]] = []
by_participant: dict[str, list[dict[str, object]]] = defaultdict(list)
Expand Down Expand Up @@ -1876,6 +2020,12 @@ def _nested_source_weight_selector(
class_labels,
alpha=_uniform_prior_shift_alpha(aggregation_mode),
)
if _is_prior_corrected_mode(aggregation_mode):
prediction_rows = _apply_prior_corrected_rows(
prediction_rows,
class_labels,
correction_alpha=_prior_corrected_alpha(aggregation_mode),
)
outer_rows = _outer_rows(f"{selector_name}__candidate_{candidate_index}", prediction_rows, n_classes=n_classes)
by_participant: dict[str, list[dict]] = defaultdict(list)
for row in prediction_rows:
Expand Down Expand Up @@ -2219,6 +2369,12 @@ def ensemble_prediction_sources(
class_labels,
alpha=_uniform_prior_shift_alpha(aggregation_mode),
)
if _is_prior_corrected_mode(aggregation_mode):
prediction_rows = _apply_prior_corrected_rows(
prediction_rows,
class_labels,
correction_alpha=_prior_corrected_alpha(aggregation_mode),
)
outer_rows = _outer_rows(ensemble_name, prediction_rows, n_classes=len(class_labels))
summary = _group_summary(
ensemble_name,
Expand Down
84 changes: 52 additions & 32 deletions tests/test_stimulus_artifact_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,26 @@ def _participant_three_score_row(
return row


def _nested_weight_pair_sources() -> tuple[PredictionSource, PredictionSource]:
    """Return the ``source_a`` / ``source_b`` fixture pair shared by the nested-weight tests.

    Both sources hold three rows, numbered 1 to 3. ``source_a`` uses one
    argument pattern for row 1 and the other for rows 2-3; ``source_b`` uses
    the mirrored assignment.
    """

    # NOTE(review): the four trailing arguments presumably are
    # (true_label, predicted_label, class-0 score, class-1 score) - confirm
    # against _participant_scored_row.
    leans_class_one = (0, 1, 0.10, 0.90)
    leans_class_zero = (0, 0, 0.90, 0.10)

    source_a = _source(
        "source_a",
        [
            _participant_scored_row(row_number, *pattern)
            for row_number, pattern in enumerate(
                (leans_class_one, leans_class_zero, leans_class_zero), start=1
            )
        ],
    )
    source_b = _source(
        "source_b",
        [
            _participant_scored_row(row_number, *pattern)
            for row_number, pattern in enumerate(
                (leans_class_zero, leans_class_one, leans_class_one), start=1
            )
        ],
    )
    return source_a, source_b


def _ranked_row(true_label: int, predicted_label: int, class_0_rank: float, class_1_rank: float) -> dict[str, str]:
return {
**_row(1, 1, true_label, predicted_label, true_label_rank=class_0_rank if true_label == 0 else class_1_rank),
Expand Down Expand Up @@ -647,6 +667,36 @@ def test_uniform_prior_shift_debiases_participant_score_distribution(self) -> No
self.assertEqual(predictions[1]["rank_class_1"], 1)
self.assertEqual(shifted["group_summary"][0]["balanced_accuracy_mean"], 1.0)

def test_prior_corrected_mean_score_softly_rebalances_subject_score_mass(self) -> None:
    """Prior correction flips the two near-tied rows that plain mean scoring assigns to class 0."""

    row_specs = (
        (1, 0, 0, 0.90, 0.10),
        (2, 0, 0, 0.80, 0.20),
        (3, 1, 0, 0.51, 0.49),
        (4, 1, 0, 0.50, 0.50),
    )
    latent = _source("latent", [_multi_two_class_scored_row(*spec) for spec in row_specs])

    def run_single_source_ensemble(ensemble_name: str, mode: str):
        # Same single-source ensemble each time; only the name and aggregation mode vary.
        return ensemble_prediction_sources(
            [latent],
            [(ensemble_name, ("latent",))],
            aggregation_mode=mode,
        )

    mean_artifacts = run_single_source_ensemble("mean", "mean_score")
    prior_artifacts = run_single_source_ensemble("prior", "prior_corrected_mean_score")

    # Baseline: uncorrected mean scoring predicts class 0 for every row.
    mean_labels = [row["predicted_label"] for row in mean_artifacts["predictions"]]
    self.assertEqual(mean_labels, [0, 0, 0, 0])

    predictions = prior_artifacts["predictions"]
    corrected_labels = [row["predicted_label"] for row in predictions]
    self.assertEqual(corrected_labels, [0, 0, 1, 1])

    observed_modes = {row["artifact_ensemble_mode"] for row in predictions}
    self.assertEqual(observed_modes, {"class_score_prior_corrected_mean"})
    observed_alphas = {row["artifact_ensemble_prior_correction_alpha"] for row in predictions}
    self.assertEqual(observed_alphas, {"1"})

    # Corrected scores are renormalized, so the two class scores sum to one.
    first_prediction = predictions[0]
    score_sum = float(first_prediction["score_class_0"]) + float(first_prediction["score_class_1"])
    self.assertAlmostEqual(score_sum, 1.0)
    self.assertEqual(prior_artifacts["group_summary"][0]["balanced_accuracy_mean"], 1.0)

def test_rejects_misaligned_source_prediction_keys(self) -> None:
compact = _source("compact", [_row(1, 1, 0, 0)])
finetune = _source("finetune", [_row(1, 2, 0, 0)])
Expand Down Expand Up @@ -752,22 +802,7 @@ def test_nested_subject_selector_can_use_rank_aware_metric(self) -> None:
self.assertEqual(nested_summary["selection_metric_name"], "balanced_top2_top3_rank")

def test_nested_weight_selector_uses_other_subjects_only(self) -> None:
source_a = _source(
"source_a",
[
_participant_scored_row(1, 0, 1, 0.10, 0.90),
_participant_scored_row(2, 0, 0, 0.90, 0.10),
_participant_scored_row(3, 0, 0, 0.90, 0.10),
],
)
source_b = _source(
"source_b",
[
_participant_scored_row(1, 0, 0, 0.90, 0.10),
_participant_scored_row(2, 0, 1, 0.10, 0.90),
_participant_scored_row(3, 0, 1, 0.10, 0.90),
],
)
source_a, source_b = _nested_weight_pair_sources()

artifacts = ensemble_prediction_sources(
[source_a, source_b],
Expand Down Expand Up @@ -836,22 +871,7 @@ def test_nested_weight_selector_can_use_rank_aware_metric(self) -> None:
self.assertEqual(participant_1["predicted_label"], 1)

def test_nested_weight_selector_can_expand_all_multi_source_ensembles(self) -> None:
source_a = _source(
"source_a",
[
_participant_scored_row(1, 0, 1, 0.10, 0.90),
_participant_scored_row(2, 0, 0, 0.90, 0.10),
_participant_scored_row(3, 0, 0, 0.90, 0.10),
],
)
source_b = _source(
"source_b",
[
_participant_scored_row(1, 0, 0, 0.90, 0.10),
_participant_scored_row(2, 0, 1, 0.10, 0.90),
_participant_scored_row(3, 0, 1, 0.10, 0.90),
],
)
source_a, source_b = _nested_weight_pair_sources()
source_c = _source(
"source_c",
[
Expand Down
21 changes: 10 additions & 11 deletions tests/test_stimulus_source_inner_stacking_postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def _stack_config(**kwargs):
)


def test_stacker_postprocessing_none_returns_argmax_predictions():
def _postprocessing_inputs():
classes = np.asarray([1, 2, 3])
scores = np.asarray(
[
Expand All @@ -24,11 +24,17 @@ def test_stacker_postprocessing_none_returns_argmax_predictions():
[2.0, 0.0, 4.0],
]
)
source_labels = np.repeat(classes, 4)
return scores, classes, source_labels


def test_stacker_postprocessing_none_returns_argmax_predictions():
scores, classes, source_labels = _postprocessing_inputs()

predictions, metadata = _postprocess_stacked_predictions(
scores,
classes,
np.repeat(classes, 4),
source_labels,
_stack_config(),
)

Expand All @@ -37,19 +43,12 @@ def test_stacker_postprocessing_none_returns_argmax_predictions():


def test_stacker_postprocessing_can_reuse_source_prior_balanced_assignment():
classes = np.asarray([1, 2, 3])
scores = np.asarray(
[
[3.0, 2.0, 0.0],
[2.9, 2.8, 0.0],
[2.0, 0.0, 4.0],
]
)
scores, classes, source_labels = _postprocessing_inputs()

predictions, metadata = _postprocess_stacked_predictions(
scores,
classes,
np.repeat(classes, 4),
source_labels,
_stack_config(prediction_postprocessing="source_prior_balanced_assignment"),
)

Expand Down