From eee920fc3f2e9860df37c95ae18e2675655f6f26 Mon Sep 17 00:00:00 2001 From: Florian Pfaff Date: Sun, 7 Jun 2026 00:38:01 +0200 Subject: [PATCH 1/2] Add prior-corrected artifact aggregation --- .../workflows/stimulus-artifact-ensemble.yml | 2 +- src/pymegdec/stimulus_artifact_ensemble.py | 158 +++++++++++++++++- tests/test_stimulus_artifact_ensemble.py | 30 ++++ 3 files changed, 188 insertions(+), 2 deletions(-) diff --git a/.github/workflows/stimulus-artifact-ensemble.yml b/.github/workflows/stimulus-artifact-ensemble.yml index 0761655..45153a0 100644 --- a/.github/workflows/stimulus-artifact-ensemble.yml +++ b/.github/workflows/stimulus-artifact-ensemble.yml @@ -31,7 +31,7 @@ on: aggregation_modes: description: Comma-separated artifact aggregation modes to compare. required: true - default: auto,hard_vote,mean_score,confidence_weighted_mean_score,entropy_weighted_mean_score,agreement_weighted_mean_score,log_score_mean,score_rank_fusion,reciprocal_rank_fusion,mean_rank,borda,score_tiebreak_first_source,balanced_assignment,balanced_assignment_shrink25,balanced_assignment_shrink50,balanced_assignment_shrink75,uniform_prior_shift,uniform_prior_shift_shrink25,uniform_prior_shift_shrink50,uniform_prior_shift_shrink75 + default: auto,hard_vote,mean_score,confidence_weighted_mean_score,entropy_weighted_mean_score,agreement_weighted_mean_score,log_score_mean,score_rank_fusion,reciprocal_rank_fusion,mean_rank,borda,score_tiebreak_first_source,balanced_assignment,balanced_assignment_shrink25,balanced_assignment_shrink50,balanced_assignment_shrink75,uniform_prior_shift,uniform_prior_shift_shrink25,uniform_prior_shift_shrink50,uniform_prior_shift_shrink75,prior_corrected_mean_score,prior_corrected_mean_score_shrink25,prior_corrected_mean_score_shrink50,prior_corrected_mean_score_shrink75 type: string score_normalizations: description: Comma-separated per-source score normalizations to compare before score aggregation. diff --git a/src/pymegdec/stimulus_artifact_ensemble.py b/src/pymegdec/stimulus_artifact_ensemble.py index 7a0e15c..a436723 100644 --- a/src/pymegdec/stimulus_artifact_ensemble.py +++ b/src/pymegdec/stimulus_artifact_ensemble.py @@ -47,6 +47,10 @@ "uniform_prior_shift_shrink25", "uniform_prior_shift_shrink50", "uniform_prior_shift_shrink75", + "prior_corrected_mean_score", + "prior_corrected_mean_score_shrink25", + "prior_corrected_mean_score_shrink50", + "prior_corrected_mean_score_shrink75", ) ARTIFACT_NESTED_SELECTION_METRIC_CHOICES = ( "balanced_accuracy", @@ -487,6 +491,18 @@ def _normalize_artifact_aggregation_mode(aggregation_mode: str) -> str: "uniform_prior_shift_50": "uniform_prior_shift_shrink50", "uniform_prior_shift_75": "uniform_prior_shift_shrink75", "prior_shift_shrink50": "uniform_prior_shift_shrink50", + "prior_corrected": "prior_corrected_mean_score", + "prior_correction": "prior_corrected_mean_score", + "prior_corrected_score": "prior_corrected_mean_score", + "prior_corrected_score_mean": "prior_corrected_mean_score", + "prior_corrected_mean": "prior_corrected_mean_score", + "prior_corrected_mean_score_shrinkage": "prior_corrected_mean_score_shrink50", + "prior_corrected_mean_score_shrinkage_25": "prior_corrected_mean_score_shrink25", + "prior_corrected_mean_score_shrinkage_50": "prior_corrected_mean_score_shrink50", + "prior_corrected_mean_score_shrinkage_75": "prior_corrected_mean_score_shrink75", + "prior_corrected_score_shrink25": "prior_corrected_mean_score_shrink25", + "prior_corrected_score_shrink50": "prior_corrected_mean_score_shrink50", + "prior_corrected_score_shrink75": "prior_corrected_mean_score_shrink75", } normalized = aliases.get(normalized, normalized) if normalized not in ARTIFACT_AGGREGATION_MODE_CHOICES: @@ -527,6 +543,12 @@ def _is_uniform_prior_shift_mode(aggregation_mode: str) -> bool: return _normalize_artifact_aggregation_mode(aggregation_mode).startswith("uniform_prior_shift") +def _is_prior_corrected_mode(aggregation_mode: str) -> bool: + """Return whether an aggregation mode applies unlabeled class-prior correction.""" + + return _normalize_artifact_aggregation_mode(aggregation_mode).startswith("prior_corrected_mean_score") + + def _balanced_assignment_uniform_alpha(aggregation_mode: str) -> float: """Return the assignment quota shrinkage toward a uniform class prior.""" @@ -568,6 +590,19 @@ def _uniform_prior_shift_alpha(aggregation_mode: str) -> float: return 1.0 +def _prior_corrected_alpha(aggregation_mode: str) -> float: + """Return the shrinkage strength for prior-corrected score aggregation.""" + + normalized = _normalize_artifact_aggregation_mode(aggregation_mode) + if normalized == "prior_corrected_mean_score_shrink25": + return 0.25 + if normalized == "prior_corrected_mean_score_shrink50": + return 0.50 + if normalized == "prior_corrected_mean_score_shrink75": + return 0.75 + return 1.0 + + def _class_value_columns(rows: Sequence[dict[str, str]], patterns: Sequence[tuple[re.Pattern[str], int]]) -> dict[int, str]: common: dict[int, str] | None = None for row in rows: @@ -792,7 +827,9 @@ def _rank_labels_by_scores( *( mode for mode in ARTIFACT_AGGREGATION_MODE_CHOICES - if _is_balanced_assignment_mode(mode) or _is_uniform_prior_shift_mode(mode) + if _is_balanced_assignment_mode(mode) + or _is_uniform_prior_shift_mode(mode) + or _is_prior_corrected_mode(mode) ), } aggregated = None @@ -1546,6 +1583,113 @@ def _apply_uniform_prior_shift_rows( return rows +def _probability_vector_from_row_scores(row: dict[str, object], class_labels: Sequence[int]) -> list[float]: + """Return a non-negative, normalized class vector for prior correction.""" + + values: list[float] = [] + missing: list[str] = [] + for label in class_labels: + column = f"artifact_score_class_{label}" + raw_value = str(row.get(column, "")).strip() + if raw_value == "": + missing.append(column) + continue + value = _to_float(raw_value) + if not math.isfinite(value): + raise ValueError(f"prior_corrected_mean_score requires finite class scores; got {column}={raw_value!r}.") + values.append(value) + if missing: + raise ValueError( + "prior_corrected_mean_score requires class score/probability columns for every class; " + f"missing examples={missing[:5]}." + ) + + if all(value >= 0.0 for value in values): + total = sum(values) + if total > 0.0 and math.isfinite(total): + return [value / total for value in values] + return _softmax(values) + + +def _apply_prior_corrected_rows( + prediction_rows: Sequence[dict[str, object]], + class_labels: Sequence[int], + *, + correction_alpha: float = 1.0, +) -> list[dict[str, object]]: + """Apply unlabeled per-participant class-prior correction to score rows.""" + + rows = [dict(row) for row in prediction_rows] + label_list = [int(label) for label in class_labels] + if not label_list: + return rows + correction_alpha = min(max(float(correction_alpha), 0.0), 1.0) + display_labels = _display_label_map(label_list) + by_participant: dict[str, list[int]] = defaultdict(list) + for row_index, row in enumerate(rows): + by_participant[str(row.get("test_participant", ""))].append(row_index) + + for indices in by_participant.values(): + probability_rows = [ + _probability_vector_from_row_scores(rows[index], label_list) + for index in indices + ] + if not probability_rows: + continue + uniform_prior = 1.0 / len(label_list) + predicted_mass = [ + sum(probability_row[class_index] for probability_row in probability_rows) / len(probability_rows) + for class_index in range(len(label_list)) + ] + correction_factors = [ + (1.0 - correction_alpha) + correction_alpha * uniform_prior / max(mass, 1e-12) + for mass in predicted_mass + ] + mass_text = ";".join( + f"{label}:{mass:.6g}" + for label, mass in zip(label_list, predicted_mass, strict=True) + ) + mode_suffix = "" if correction_alpha >= 1.0 - 1e-12 else f"_shrink{int(round(100.0 * correction_alpha)):02d}" + for row_index, probabilities in zip(indices, probability_rows, strict=True): + row = rows[row_index] + corrected_values = [ + max(0.0, probability) * factor + for probability, factor in zip(probabilities, correction_factors, strict=True) + ] + corrected_total = sum(corrected_values) + if corrected_total <= 0.0 or not math.isfinite(corrected_total): + corrected_values = [uniform_prior for _ in label_list] + corrected_total = 1.0 + corrected_scores = { + label: corrected_values[class_index] / corrected_total + for class_index, label in enumerate(label_list) + } + ranked_labels = sorted(label_list, key=lambda label: (-corrected_scores[label], label)) + predicted_label = int(ranked_labels[0]) + true_label = _to_int(row["true_label"], field="true_label") + row["predicted_label"] = predicted_label + row["predicted_stimulus"] = display_labels.get(predicted_label, predicted_label) + row["correct"] = predicted_label == true_label + row["artifact_ensemble_mode"] = f"class_score_prior_corrected_mean{mode_suffix}" + row["artifact_ensemble_prior_correction_alpha"] = f"{correction_alpha:.6g}" + row["artifact_ensemble_prior_correction_class_mass"] = mass_text + row["artifact_ensemble_rank_source"] = row["artifact_ensemble_mode"] + _add_score_alias_columns( + row, + scores=corrected_scores, + class_labels=label_list, + display_labels=display_labels, + ) + _update_rank_metrics_from_labels( + row, + ranked_labels=ranked_labels, + true_label=true_label, + class_labels=label_list, + display_labels=display_labels, + ) + return rows + + def _outer_rows(ensemble_name: str, prediction_rows: Sequence[dict[str, object]], *, n_classes: int) -> list[dict[str, object]]: rows: list[dict[str, object]] = [] by_participant: dict[str, list[dict[str, object]]] = defaultdict(list) @@ -1764,6 +1908,12 @@ def _nested_source_weight_selector( class_labels, alpha=_uniform_prior_shift_alpha(aggregation_mode), ) + if _is_prior_corrected_mode(aggregation_mode): + prediction_rows = _apply_prior_corrected_rows( + prediction_rows, + class_labels, + correction_alpha=_prior_corrected_alpha(aggregation_mode), + ) outer_rows = _outer_rows(f"{selector_name}__candidate_{candidate_index}", prediction_rows, n_classes=n_classes) by_participant: dict[str, list[dict]] = defaultdict(list) for row in prediction_rows: @@ -2107,6 +2257,12 @@ def ensemble_prediction_sources( class_labels, alpha=_uniform_prior_shift_alpha(aggregation_mode), ) + if _is_prior_corrected_mode(aggregation_mode): + prediction_rows = _apply_prior_corrected_rows( + prediction_rows, + class_labels, + correction_alpha=_prior_corrected_alpha(aggregation_mode), + ) outer_rows = _outer_rows(ensemble_name, prediction_rows, n_classes=len(class_labels)) summary = _group_summary( ensemble_name, diff --git a/tests/test_stimulus_artifact_ensemble.py b/tests/test_stimulus_artifact_ensemble.py index ddcd173..d143d64 100644 --- a/tests/test_stimulus_artifact_ensemble.py +++ b/tests/test_stimulus_artifact_ensemble.py @@ -597,6 +597,36 @@ def test_uniform_prior_shift_debiases_participant_score_distribution(self) -> No self.assertEqual(predictions[1]["rank_class_1"], 1) self.assertEqual(shifted["group_summary"][0]["balanced_accuracy_mean"], 1.0) + def test_prior_corrected_mean_score_softly_rebalances_subject_score_mass(self) -> None: + latent = _source( + "latent", + [ + _multi_two_class_scored_row(1, 0, 0, 0.90, 0.10), + _multi_two_class_scored_row(2, 0, 0, 0.80, 0.20), + _multi_two_class_scored_row(3, 1, 0, 0.51, 0.49), + _multi_two_class_scored_row(4, 1, 0, 0.50, 0.50), + ], + ) + + mean_artifacts = ensemble_prediction_sources( + [latent], + [("mean", ("latent",))], + aggregation_mode="mean_score", + ) + prior_artifacts = ensemble_prediction_sources( + [latent], + [("prior", ("latent",))], + aggregation_mode="prior_corrected_mean_score", + ) + + self.assertEqual([row["predicted_label"] for row in mean_artifacts["predictions"]], [0, 0, 0, 0]) + predictions = prior_artifacts["predictions"] + self.assertEqual([row["predicted_label"] for row in predictions], [0, 0, 1, 1]) + self.assertEqual({row["artifact_ensemble_mode"] for row in predictions}, {"class_score_prior_corrected_mean"}) + self.assertEqual({row["artifact_ensemble_prior_correction_alpha"] for row in predictions}, {"1"}) + self.assertAlmostEqual(float(predictions[0]["score_class_0"]) + float(predictions[0]["score_class_1"]), 1.0) + self.assertEqual(prior_artifacts["group_summary"][0]["balanced_accuracy_mean"], 1.0) + def test_rejects_misaligned_source_prediction_keys(self) -> None: compact = _source("compact", [_row(1, 1, 0, 0)]) finetune = _source("finetune", [_row(1, 2, 0, 0)]) From a0205838f2314c6aa1391fc0f1484b34e88c7a25 Mon Sep 17 00:00:00 2001 From: Florian Pfaff Date: Tue, 9 Jun 2026 14:35:10 +0200 Subject: [PATCH 2/2] Deduplicate paired-delta test fixtures --- tests/test_stimulus_artifact_ensemble.py | 54 ++++++++----------- ...us_source_inner_stacking_postprocessing.py | 21 ++++---- 2 files changed, 32 insertions(+), 43 deletions(-) diff --git a/tests/test_stimulus_artifact_ensemble.py b/tests/test_stimulus_artifact_ensemble.py index 391dc5c..55a9dc4 100644 --- a/tests/test_stimulus_artifact_ensemble.py +++ b/tests/test_stimulus_artifact_ensemble.py @@ -121,6 +121,26 @@ def _participant_three_score_row( return row +def _nested_weight_pair_sources() -> tuple[PredictionSource, PredictionSource]: + source_a = _source( + "source_a", + [ + _participant_scored_row(1, 0, 1, 0.10, 0.90), + _participant_scored_row(2, 0, 0, 0.90, 0.10), + _participant_scored_row(3, 0, 0, 0.90, 0.10), + ], + ) + source_b = _source( + "source_b", + [ + _participant_scored_row(1, 0, 0, 0.90, 0.10), + _participant_scored_row(2, 0, 1, 0.10, 0.90), + _participant_scored_row(3, 0, 1, 0.10, 0.90), + ], + ) + return source_a, source_b + + def _ranked_row(true_label: int, predicted_label: int, class_0_rank: float, class_1_rank: float) -> dict[str, str]: return { **_row(1, 1, true_label, predicted_label, true_label_rank=class_0_rank if true_label == 0 else class_1_rank), @@ -782,22 +802,7 @@ def test_nested_subject_selector_can_use_rank_aware_metric(self) -> None: self.assertEqual(nested_summary["selection_metric_name"], "balanced_top2_top3_rank") def test_nested_weight_selector_uses_other_subjects_only(self) -> None: - source_a = _source( - "source_a", - [ - _participant_scored_row(1, 0, 1, 0.10, 0.90), - _participant_scored_row(2, 0, 0, 0.90, 0.10), - _participant_scored_row(3, 0, 0, 0.90, 0.10), - ], - ) - source_b = _source( - "source_b", - [ - _participant_scored_row(1, 0, 0, 0.90, 0.10), - _participant_scored_row(2, 0, 1, 0.10, 0.90), - _participant_scored_row(3, 0, 1, 0.10, 0.90), - ], - ) + source_a, source_b = _nested_weight_pair_sources() artifacts = ensemble_prediction_sources( [source_a, source_b], @@ -866,22 +871,7 @@ def test_nested_weight_selector_can_use_rank_aware_metric(self) -> None: self.assertEqual(participant_1["predicted_label"], 1) def test_nested_weight_selector_can_expand_all_multi_source_ensembles(self) -> None: - source_a = _source( - "source_a", - [ - _participant_scored_row(1, 0, 1, 0.10, 0.90), - _participant_scored_row(2, 0, 0, 0.90, 0.10), - _participant_scored_row(3, 0, 0, 0.90, 0.10), - ], - ) - source_b = _source( - "source_b", - [ - _participant_scored_row(1, 0, 0, 0.90, 0.10), - _participant_scored_row(2, 0, 1, 0.10, 0.90), - _participant_scored_row(3, 0, 1, 0.10, 0.90), - ], - ) + source_a, source_b = _nested_weight_pair_sources() source_c = _source( "source_c", [ diff --git a/tests/test_stimulus_source_inner_stacking_postprocessing.py b/tests/test_stimulus_source_inner_stacking_postprocessing.py index 96ecbd3..4211545 100644 --- a/tests/test_stimulus_source_inner_stacking_postprocessing.py +++ b/tests/test_stimulus_source_inner_stacking_postprocessing.py @@ -15,7 +15,7 @@ def _stack_config(**kwargs): ) -def test_stacker_postprocessing_none_returns_argmax_predictions(): +def _postprocessing_inputs(): classes = np.asarray([1, 2, 3]) scores = np.asarray( [ @@ -24,11 +24,17 @@ def test_stacker_postprocessing_none_returns_argmax_predictions(): [2.0, 0.0, 4.0], ] ) + source_labels = np.repeat(classes, 4) + return scores, classes, source_labels + + +def test_stacker_postprocessing_none_returns_argmax_predictions(): + scores, classes, source_labels = _postprocessing_inputs() predictions, metadata = _postprocess_stacked_predictions( scores, classes, - np.repeat(classes, 4), + source_labels, _stack_config(), ) @@ -37,19 +43,12 @@ def test_stacker_postprocessing_none_returns_argmax_predictions(): def test_stacker_postprocessing_can_reuse_source_prior_balanced_assignment(): - classes = np.asarray([1, 2, 3]) - scores = np.asarray( - [ - [3.0, 2.0, 0.0], - [2.9, 2.8, 0.0], - [2.0, 0.0, 4.0], - ] - ) + scores, classes, source_labels = _postprocessing_inputs() predictions, metadata = _postprocess_stacked_predictions( scores, classes, - np.repeat(classes, 4), + source_labels, _stack_config(prediction_postprocessing="source_prior_balanced_assignment"), )