diff --git a/.github/workflows/stimulus-artifact-ensemble.yml b/.github/workflows/stimulus-artifact-ensemble.yml index f3a582d..cf11c2e 100644 --- a/.github/workflows/stimulus-artifact-ensemble.yml +++ b/.github/workflows/stimulus-artifact-ensemble.yml @@ -41,7 +41,7 @@ on: nested_selection_metrics: description: Comma-separated leave-subject-out selector metrics to compare. required: true - default: balanced_accuracy,balanced_top2_top3_rank_lcb + default: balanced_accuracy,balanced_accuracy_delta_lcb,balanced_top2_top3_rank_lcb,balanced_top2_top3_rank_delta_lcb type: string nested_weight_selector_ensemble: description: >- diff --git a/src/pymegdec/stimulus_artifact_ensemble.py b/src/pymegdec/stimulus_artifact_ensemble.py index 8fc782f..afb19e8 100644 --- a/src/pymegdec/stimulus_artifact_ensemble.py +++ b/src/pymegdec/stimulus_artifact_ensemble.py @@ -51,8 +51,12 @@ ARTIFACT_NESTED_SELECTION_METRIC_CHOICES = ( "balanced_accuracy", "balanced_accuracy_lcb", + "balanced_accuracy_delta", + "balanced_accuracy_delta_lcb", "balanced_top2_top3_rank", "balanced_top2_top3_rank_lcb", + "balanced_top2_top3_rank_delta", + "balanced_top2_top3_rank_delta_lcb", ) @@ -502,9 +506,21 @@ def _normalize_artifact_nested_selection_metric(selection_metric: str) -> str: aliases = { "balanced": "balanced_accuracy", "balanced_lcb": "balanced_accuracy_lcb", + "balanced_delta": "balanced_accuracy_delta", + "balanced_delta_lcb": "balanced_accuracy_delta_lcb", + "paired_balanced": "balanced_accuracy_delta", + "paired_balanced_lcb": "balanced_accuracy_delta_lcb", + "paired_balanced_delta": "balanced_accuracy_delta", + "paired_balanced_delta_lcb": "balanced_accuracy_delta_lcb", "balanced_rank": "balanced_top2_top3_rank", "balanced_rank_lcb": "balanced_top2_top3_rank_lcb", "rank_lcb": "balanced_top2_top3_rank_lcb", + "balanced_rank_delta": "balanced_top2_top3_rank_delta", + "balanced_rank_delta_lcb": "balanced_top2_top3_rank_delta_lcb", + "rank_delta": "balanced_top2_top3_rank_delta", + "rank_delta_lcb": "balanced_top2_top3_rank_delta_lcb", + "paired_rank": "balanced_top2_top3_rank_delta", + "paired_rank_lcb": "balanced_top2_top3_rank_delta_lcb", } normalized = aliases.get(normalized, normalized) if normalized not in ARTIFACT_NESTED_SELECTION_METRIC_CHOICES: @@ -1679,17 +1695,69 @@ def _artifact_recipe_rank_score(row: dict[str, object], *, n_classes: int) -> fl return balanced + 0.25 * (top2 - top2_chance) + 0.125 * (top3 - top3_chance) + 0.10 * rank_gain -def _nested_selection_metric_value(rows: Sequence[dict[str, object]], *, selection_metric: str, n_classes: int) -> float: +def _nested_selection_metric_base(selection_metric: str) -> str: + """Return the non-delta metric family used for one outer-row score.""" + normalized = _normalize_artifact_nested_selection_metric(selection_metric) - if normalized == "balanced_accuracy": - return _metric_mean(rows, "balanced_accuracy") - if normalized == "balanced_accuracy_lcb": - values = _outer_metric_values(rows, "balanced_accuracy") - return _mean(values) - _sem(values) + if normalized.startswith("balanced_accuracy"): + return "balanced_accuracy" + return "balanced_top2_top3_rank" + + +def _nested_selection_uses_lcb(selection_metric: str) -> bool: + return _normalize_artifact_nested_selection_metric(selection_metric).endswith("_lcb") + + +def _nested_selection_is_delta(selection_metric: str) -> bool: + return "_delta" in _normalize_artifact_nested_selection_metric(selection_metric) + + +def _nested_selection_row_score(row: dict[str, object], *, selection_metric: str, n_classes: int) -> float: + base = _nested_selection_metric_base(selection_metric) + if base == "balanced_accuracy": + return _to_float(row["balanced_accuracy"]) + return _artifact_recipe_rank_score(row, n_classes=n_classes) - values = [_artifact_recipe_rank_score(row, n_classes=n_classes) for row in rows] + +def _outer_rows_by_participant(rows: Sequence[dict[str, object]]) -> dict[str, dict[str, object]]: + return {str(row.get("test_participant", "")): row for row in rows} + + +def _nested_selection_metric_value( + rows: Sequence[dict[str, object]], + *, + selection_metric: str, + n_classes: int, + reference_rows: Sequence[dict[str, object]] | None = None, +) -> float: + normalized = _normalize_artifact_nested_selection_metric(selection_metric) + if _nested_selection_is_delta(normalized): + if reference_rows is None: + raise ValueError(f"Nested selection metric {normalized!r} requires reference rows.") + reference_by_participant = _outer_rows_by_participant(reference_rows) + paired_deltas: list[float] = [] + for row in rows: + participant = str(row.get("test_participant", "")) + reference = reference_by_participant.get(participant) + if reference is None: + continue + paired_deltas.append( + _nested_selection_row_score(row, selection_metric=normalized, n_classes=n_classes) + - _nested_selection_row_score(reference, selection_metric=normalized, n_classes=n_classes) + ) + if not paired_deltas: + return float("-inf") + score = _mean(paired_deltas) + if _nested_selection_uses_lcb(normalized): + score -= _sem(paired_deltas) + return score + + values = [ + _nested_selection_row_score(row, selection_metric=normalized, n_classes=n_classes) + for row in rows + ] score = _mean(values) - if normalized.endswith("_lcb"): + if _nested_selection_uses_lcb(normalized): score -= _sem(values) return score @@ -1897,6 +1965,7 @@ def _nested_source_weight_selector( ) if not participants: raise ValueError("Nested source-weight selector requires test_participant values.") + reference_candidate = min(candidates, key=lambda candidate: (candidate.uniform_distance, candidate.candidate_index)) selected_predictions: list[dict] = [] selection_rows: list[dict] = [] @@ -1910,10 +1979,16 @@ def _nested_source_weight_selector( ] if not train_outer_rows: raise ValueError(f"Cannot select artifact source weights for participant {participant}; no source subjects remain.") + reference_train_outer_rows = [ + row + for other_participant, row in reference_candidate.outer_by_participant.items() + if other_participant != participant + ] selection_score = _nested_selection_metric_value( train_outer_rows, selection_metric=nested_selection_metric, n_classes=n_classes, + reference_rows=reference_train_outer_rows, ) balanced = _metric_mean(train_outer_rows, "balanced_accuracy") scored_candidates.append( @@ -1941,6 +2016,10 @@ def _nested_source_weight_selector( "selected_artifact_ensemble_sources": ";".join(source_names), "selected_source_weights": weight_text, "selected_weight_grid_step": grid_step, + "reference_source_weights": reference_candidate.weight_text, + "reference_weight_grid_step": grid_step, + "reference_weight_candidate_index": reference_candidate.candidate_index, + "selected_weight_candidate_index": selected_candidate.candidate_index, "selection_metric": _nested_selection_metric_label(nested_selection_metric), "selection_metric_name": nested_selection_metric, "selection_metric_value": selected_score, @@ -1961,6 +2040,10 @@ def _nested_source_weight_selector( selected_row["selected_artifact_ensemble_sources"] = ";".join(source_names) selected_row["selected_source_weights"] = weight_text selected_row["selected_weight_grid_step"] = grid_step + selected_row["reference_source_weights"] = reference_candidate.weight_text + selected_row["reference_weight_grid_step"] = grid_step + selected_row["reference_weight_candidate_index"] = reference_candidate.candidate_index + selected_row["selected_weight_candidate_index"] = selected_candidate.candidate_index selected_row["selection_metric"] = _nested_selection_metric_label(nested_selection_metric) selected_row["selection_metric_name"] = nested_selection_metric selected_row["selection_metric_value"] = selected_score @@ -1983,6 +2066,7 @@ def _nested_source_weight_selector( summary["selected_source_weight_counts"] = _counts_text(str(row["selected_source_weights"]) for row in selection_rows) summary["candidate_source_weight_count"] = len(candidates) summary["selected_weight_grid_step"] = grid_step + summary["reference_source_weights"] = reference_candidate.weight_text return selected_predictions, outer_rows, selection_rows, summary @@ -2069,6 +2153,12 @@ def _nested_subject_selector( ) if not participants: raise ValueError("Nested subject selector requires test_participant values.") + reference_ensemble = ensemble_order[0] + if reference_ensemble not in outer_by_ensemble_participant: + raise ValueError( + f"Nested subject selector reference ensemble {reference_ensemble!r} has no outer rows." + ) + reference_outer_by_participant = outer_by_ensemble_participant[reference_ensemble] selected_predictions: list[dict] = [] selection_rows: list[dict] = [] @@ -2082,7 +2172,14 @@ def _nested_subject_selector( ] if not train_outer_rows: raise ValueError(f"Cannot select an artifact ensemble for participant {participant}; no source subjects remain.") - selection_score = _nested_selection_metric_value(train_outer_rows, selection_metric=nested_selection_metric, n_classes=n_classes) + reference_train_outer_rows = [ + row + for other_participant, row in reference_outer_by_participant.items() + if other_participant != participant + ] + selection_score = _nested_selection_metric_value( + train_outer_rows, selection_metric=nested_selection_metric, n_classes=n_classes, reference_rows=reference_train_outer_rows + ) candidates.append((selection_score, -ensemble_index, ensemble, train_outer_rows)) selected_score, _, selected_ensemble, train_outer_rows = max(candidates) @@ -2097,6 +2194,7 @@ def _nested_subject_selector( "artifact_ensemble": selector_name, "selected_artifact_ensemble": selected_ensemble, "selected_artifact_ensemble_sources": ";".join(selected_sources), + "reference_artifact_ensemble": reference_ensemble, "selection_metric": _nested_selection_metric_label(nested_selection_metric), "selection_metric_name": nested_selection_metric, "selection_metric_value": selected_score, @@ -2115,6 +2213,7 @@ def _nested_subject_selector( selected_row["artifact_ensemble_recipe_selection"] = "leave_subject_out" selected_row["selected_artifact_ensemble"] = selected_ensemble selected_row["selected_artifact_ensemble_sources"] = ";".join(selected_sources) + selected_row["reference_artifact_ensemble"] = reference_ensemble selected_row["selection_metric"] = _nested_selection_metric_label(nested_selection_metric) selected_row["selection_metric_name"] = nested_selection_metric selected_row["selection_metric_value"] = selected_score @@ -2135,6 +2234,7 @@ def _nested_subject_selector( summary["selected_artifact_ensemble_counts"] = _counts_text( str(row["selected_artifact_ensemble"]) for row in selection_rows ) + summary["reference_artifact_ensemble"] = reference_ensemble return selected_predictions, outer_rows, selection_rows, summary diff --git a/src/pymegdec/stimulus_latent_autoencoder.py b/src/pymegdec/stimulus_latent_autoencoder.py index b334ea7..436e6c5 100644 --- a/src/pymegdec/stimulus_latent_autoencoder.py +++ b/src/pymegdec/stimulus_latent_autoencoder.py @@ -126,6 +126,7 @@ class LatentAutoencoderConfig: # pylint: disable=too-many-instance-attributes validation_source_count: int = 2 validation_source_strategy: str = DEFAULT_LATENT_VALIDATION_SOURCE_STRATEGY validation_selection_metric: str = "balanced_accuracy" + validation_min_epochs: int = 0 patience: int = 12 refit_all_sources: bool = True final_epoch_multiplier: float = 1.0 @@ -294,6 +295,11 @@ def _apply_latent_training_preset(config: LatentAutoencoderConfig, preset: str) 0.03, ) final_min_epochs = max(int(config.final_min_epochs), 8) + # The initial smoke run peaked at epoch 3, which can be an unstable choice for + # a jointly trained encoder/reconstruction model. Anti-collapse presets wait + # for a few epochs before allowing source-validation model selection/early + # stopping; this remains source-only and never inspects the held-out subject. + validation_min_epochs = max(int(config.validation_min_epochs), 6) soft_worst_class_recall_weight = max(float(config.soft_worst_class_recall_weight), 0.01) margin_loss_weight = max(float(config.margin_loss_weight), 0.005) confidence_penalty_weight = max(float(config.confidence_penalty_weight), 0.002) @@ -314,6 +320,7 @@ def _apply_latent_training_preset(config: LatentAutoencoderConfig, preset: str) validation_source_count=validation_source_count, validation_prediction_balance_weight=validation_prediction_balance_weight, validation_selection_metric="balanced_top2_top3_rank_balance", + validation_min_epochs=validation_min_epochs, final_min_epochs=final_min_epochs, ) @@ -333,6 +340,7 @@ def _apply_latent_training_preset(config: LatentAutoencoderConfig, preset: str) validation_source_count=validation_source_count, validation_prediction_balance_weight=validation_prediction_balance_weight, validation_selection_metric="balanced_top2_top3_rank_balance", + validation_min_epochs=validation_min_epochs, final_min_epochs=final_min_epochs, # The neural classifier head is trained jointly with reconstruction. # In the smoke run, the latent space still carried useful rank signal @@ -370,6 +378,7 @@ def _apply_latent_training_preset(config: LatentAutoencoderConfig, preset: str) validation_source_count=validation_source_count, validation_prediction_balance_weight=validation_prediction_balance_weight, validation_selection_metric="balanced_top2_top3_rank_balance", + validation_min_epochs=validation_min_epochs, final_min_epochs=final_min_epochs, supervised_contrastive_weight=max(float(config.supervised_contrastive_weight), 0.02), supervised_contrastive_temperature=_min_positive_temperature( @@ -423,6 +432,7 @@ def _apply_latent_training_preset(config: LatentAutoencoderConfig, preset: str) validation_source_count=validation_source_count, validation_prediction_balance_weight=validation_prediction_balance_weight, validation_selection_metric="balanced_top2_top3_rank_balance", + validation_min_epochs=validation_min_epochs, final_min_epochs=final_min_epochs, supervised_contrastive_temperature=_min_positive_temperature(config.supervised_contrastive_temperature, 0.20), score_calibration="validation_selected_guarded", @@ -974,6 +984,8 @@ def _train_model( # pylint: disable=too-many-arguments,too-many-locals epochs_since_improvement = 0 history = [] rng = np.random.default_rng(config.seed) + configured_validation_min_epochs = max(1, int(config.validation_min_epochs)) + effective_validation_min_epochs = min(max(1, int(max_epochs)), configured_validation_min_epochs) for epoch in range(1, max_epochs + 1): model.train() @@ -1073,7 +1085,8 @@ def _train_model( # pylint: disable=too-many-arguments,too-many-locals validation_selection_score = float(validation_metrics["selection_score"]) - float( config.validation_prediction_balance_weight ) * validation_prediction_balance_penalty - if validation_selection_score > best_validation_selection_score + 1e-8: + can_select_epoch = epoch >= effective_validation_min_epochs + if can_select_epoch and validation_selection_score > best_validation_selection_score + 1e-8: best_validation_balanced = validation_balanced best_validation_selection_score = validation_selection_score best_validation_prediction_balance_penalty = validation_prediction_balance_penalty @@ -1081,9 +1094,12 @@ def _train_model( # pylint: disable=too-many-arguments,too-many-locals best_epoch = epoch best_state = copy.deepcopy(model.state_dict()) epochs_since_improvement = 0 - else: + elif can_select_epoch: epochs_since_improvement += 1 - if config.patience > 0 and epochs_since_improvement >= config.patience: + else: + # Do not spend patience before the model is eligible for selection. + epochs_since_improvement = 0 + if can_select_epoch and config.patience > 0 and epochs_since_improvement >= config.patience: break else: best_epoch = epoch @@ -1106,6 +1122,8 @@ def _train_model( # pylint: disable=too-many-arguments,too-many-locals model.load_state_dict(best_state) return model, { "best_epoch": int(best_epoch), + "validation_min_epochs": int(config.validation_min_epochs), + "effective_validation_min_epochs": int(effective_validation_min_epochs), "best_validation_balanced_accuracy": float(best_validation_balanced), "best_validation_selection_score": float(best_validation_selection_score), "best_validation_prediction_balance_penalty": float(best_validation_prediction_balance_penalty), @@ -2887,6 +2905,8 @@ def _outer_row( # pylint: disable=too-many-arguments "prediction_balance_penalty": _prediction_balance_penalty(predicted_labels, classes), "validation_source_count": config.validation_source_count, "validation_source_strategy": config.validation_source_strategy, + "validation_min_epochs": config.validation_min_epochs, + "effective_validation_min_epochs": fit_metadata.get("effective_validation_min_epochs", np.nan), "refit_all_sources": config.refit_all_sources, "final_epoch_multiplier": config.final_epoch_multiplier, "final_min_epochs": config.final_min_epochs, @@ -3019,6 +3039,7 @@ def _group_summary(outer_rows: list[dict], config: LatentAutoencoderConfig) -> l "validation_selection_metric": config.validation_selection_metric, "validation_source_count": config.validation_source_count, "validation_source_strategy": config.validation_source_strategy, + "validation_min_epochs": config.validation_min_epochs, "refit_all_sources": config.refit_all_sources, "final_epoch_multiplier": config.final_epoch_multiplier, "final_min_epochs": config.final_min_epochs, @@ -4227,6 +4248,16 @@ def _build_parser(prog: str | None = None) -> argparse.ArgumentParser: default="balanced_accuracy", help="Source-validation metric used for epoch selection and early stopping.", ) + parser.add_argument( + "--validation-min-epochs", + type=int, + default=0, + help=( + "Minimum epoch eligible for source-validation model selection and early stopping. " + "Use this to avoid selecting very early unstable epochs; anti-collapse presets set " + "a conservative floor automatically." + ), + ) parser.add_argument("--patience", type=int, default=12) parser.add_argument("--refit-all-sources", action=argparse.BooleanOptionalAction, default=True) parser.add_argument( @@ -4494,6 +4525,7 @@ def main(argv: Sequence[str] | None = None, prog: str | None = None) -> int: validation_source_count=args.validation_source_count, validation_source_strategy=args.validation_source_strategy, validation_selection_metric=args.validation_selection_metric, + validation_min_epochs=args.validation_min_epochs, patience=args.patience, refit_all_sources=bool(args.refit_all_sources), final_epoch_multiplier=args.final_epoch_multiplier, diff --git a/tests/test_stimulus_artifact_ensemble.py b/tests/test_stimulus_artifact_ensemble.py index c7726e9..7fa8a66 100644 --- a/tests/test_stimulus_artifact_ensemble.py +++ b/tests/test_stimulus_artifact_ensemble.py @@ -121,6 +121,26 @@ def _participant_three_score_row( return row +def _nested_weight_pair_sources() -> tuple[PredictionSource, PredictionSource]: + source_a = _source( + "source_a", + [ + _participant_scored_row(1, 0, 1, 0.10, 0.90), + _participant_scored_row(2, 0, 0, 0.90, 0.10), + _participant_scored_row(3, 0, 0, 0.90, 0.10), + ], + ) + source_b = _source( + "source_b", + [ + _participant_scored_row(1, 0, 0, 0.90, 0.10), + _participant_scored_row(2, 0, 1, 0.10, 0.90), + _participant_scored_row(3, 0, 1, 0.10, 0.90), + ], + ) + return source_a, source_b + + def _ranked_row(true_label: int, predicted_label: int, class_0_rank: float, class_1_rank: float) -> dict[str, str]: return { **_row(1, 1, true_label, predicted_label, true_label_rank=class_0_rank if true_label == 0 else class_1_rank), @@ -751,24 +771,53 @@ def test_nested_subject_selector_can_use_rank_aware_metric(self) -> None: ) self.assertEqual(nested_summary["selection_metric_name"], "balanced_top2_top3_rank") - def test_nested_weight_selector_uses_other_subjects_only(self) -> None: - source_a = _source( - "source_a", + def test_nested_subject_selector_can_use_paired_delta_lcb_metric(self) -> None: + compact = _source( + "compact", [ - _participant_scored_row(1, 0, 1, 0.10, 0.90), - _participant_scored_row(2, 0, 0, 0.90, 0.10), - _participant_scored_row(3, 0, 0, 0.90, 0.10), + _row(1, 1, 0, 0), + _row(2, 1, 0, 1), + _row(3, 1, 0, 1), ], ) - source_b = _source( - "source_b", + robust_delta = _source( + "robust_delta", [ - _participant_scored_row(1, 0, 0, 0.90, 0.10), - _participant_scored_row(2, 0, 1, 0.10, 0.90), - _participant_scored_row(3, 0, 1, 0.10, 0.90), + _row(1, 1, 0, 1), + _row(2, 1, 0, 0), + _row(3, 1, 0, 0), ], ) + artifacts = ensemble_prediction_sources( + [compact, robust_delta], + [ + ("compact", ("compact",)), + ("robust_delta", ("robust_delta",)), + ], + nested_selector_name="nested_subject_selector", + nested_selection_metric="balanced_accuracy_delta_lcb", + ) + + selections = { + row["test_participant"]: row["selected_artifact_ensemble"] + for row in artifacts["nested_selection"] + } + self.assertEqual(selections["1"], "robust_delta") + first_selection = next(row for row in artifacts["nested_selection"] if row["test_participant"] == "1") + self.assertEqual(first_selection["selection_metric"], "other_subjects_balanced_accuracy_delta_lcb") + self.assertEqual(first_selection["selection_metric_name"], "balanced_accuracy_delta_lcb") + self.assertEqual(first_selection["reference_artifact_ensemble"], "compact") + + nested_summary = next( + row for row in artifacts["group_summary"] if row["artifact_ensemble"] == "nested_subject_selector" + ) + self.assertEqual(nested_summary["selection_metric_name"], "balanced_accuracy_delta_lcb") + self.assertEqual(nested_summary["reference_artifact_ensemble"], "compact") + + def test_nested_weight_selector_uses_other_subjects_only(self) -> None: + source_a, source_b = _nested_weight_pair_sources() + artifacts = ensemble_prediction_sources( [source_a, source_b], [("source_a_b", ("source_a", "source_b"))], @@ -835,23 +884,33 @@ def test_nested_weight_selector_can_use_rank_aware_metric(self) -> None: participant_1 = next(row for row in nested_predictions if row["test_participant"] == "1") self.assertEqual(participant_1["predicted_label"], 1) - def test_nested_weight_selector_can_expand_all_multi_source_ensembles(self) -> None: - source_a = _source( - "source_a", - [ - _participant_scored_row(1, 0, 1, 0.10, 0.90), - _participant_scored_row(2, 0, 0, 0.90, 0.10), - _participant_scored_row(3, 0, 0, 0.90, 0.10), - ], - ) - source_b = _source( - "source_b", - [ - _participant_scored_row(1, 0, 0, 0.90, 0.10), - _participant_scored_row(2, 0, 1, 0.10, 0.90), - _participant_scored_row(3, 0, 1, 0.10, 0.90), - ], + def test_nested_weight_selector_can_use_paired_delta_lcb_metric(self) -> None: + source_a, source_b = _nested_weight_pair_sources() + + artifacts = ensemble_prediction_sources( + [source_a, source_b], + [("source_a_b", ("source_a", "source_b"))], + aggregation_mode="mean_score", + nested_weight_selector_name="nested_weight_selector", + nested_weight_selector_ensemble="source_a_b", + nested_weight_grid_step=1.0, + nested_selection_metric="balanced_accuracy_delta_lcb", ) + + selections = { + row["test_participant"]: row["selected_source_weights"] + for row in artifacts["nested_weight_selection"] + } + self.assertEqual(selections["1"], "source_a:1;source_b:0") + first_selection = next(row for row in artifacts["nested_weight_selection"] if row["test_participant"] == "1") + self.assertEqual(first_selection["selection_metric"], "other_subjects_balanced_accuracy_delta_lcb") + self.assertEqual(first_selection["selection_metric_name"], "balanced_accuracy_delta_lcb") + self.assertIn("reference_source_weights", first_selection) + nested_summary = next(row for row in artifacts["group_summary"] if row["artifact_ensemble"] == "nested_weight_selector") + self.assertEqual(nested_summary["selection_metric_name"], "balanced_accuracy_delta_lcb") + + def test_nested_weight_selector_can_expand_all_multi_source_ensembles(self) -> None: + source_a, source_b = _nested_weight_pair_sources() source_c = _source( "source_c", [ diff --git a/tests/test_stimulus_latent_autoencoder_controls.py b/tests/test_stimulus_latent_autoencoder_controls.py index eda1f4e..3bece1e 100644 --- a/tests/test_stimulus_latent_autoencoder_controls.py +++ b/tests/test_stimulus_latent_autoencoder_controls.py @@ -392,6 +392,7 @@ def test_latent_model_maps_sparse_participant_ids_for_subject_adversary_when_tor hidden_dim=8, latent_dim=5, dropout=0.0, + input_dropout=0.0, ) targets = model.subject_targets(torch.tensor([8, 2, 4, 8])) diff --git a/tests/test_stimulus_latent_autoencoder_presets.py b/tests/test_stimulus_latent_autoencoder_presets.py index 3f04471..65cd014 100644 --- a/tests/test_stimulus_latent_autoencoder_presets.py +++ b/tests/test_stimulus_latent_autoencoder_presets.py @@ -25,6 +25,7 @@ def test_anti_collapse_train_preset_enables_source_only_regularizers(): assert config.validation_source_count >= 4 assert config.validation_prediction_balance_weight >= 0.03 assert config.validation_selection_metric == "balanced_top2_top3_rank_balance" + assert config.validation_min_epochs >= 6 assert config.final_min_epochs >= 8 assert config.score_calibration == "none" assert config.prediction_postprocessing == "none" @@ -53,6 +54,7 @@ def test_anti_collapse_refit_preset_adds_source_only_latent_logistic_probe(): assert config.balanced_batch_sampling is True assert config.validation_source_count >= 4 assert config.validation_selection_metric == "balanced_top2_top3_rank_balance" + assert config.validation_min_epochs >= 6 assert config.latent_head_refit == "validation_selected_source_logistic" assert config.latent_head_refit_selection_metric == "balanced_top2_top3_rank_balance" assert config.score_calibration == "validation_selected_guarded" @@ -112,6 +114,7 @@ def test_anti_collapse_contrastive_preset_adds_source_only_latent_clustering(): assert config.soft_macro_recall_weight >= 0.02 assert config.validation_source_count >= 4 assert config.validation_selection_metric == "balanced_top2_top3_rank_balance" + assert config.validation_min_epochs >= 6 assert config.final_min_epochs >= 8 assert config.supervised_contrastive_weight >= 0.02 assert config.supervised_contrastive_temperature <= 0.20 @@ -141,6 +144,7 @@ def test_none_preset_preserves_explicit_config_values(): original = LatentAutoencoderConfig( label_smoothing=0.2, validation_source_count=6, + validation_min_epochs=5, input_dropout=0.07, ) diff --git a/tests/test_stimulus_source_inner_stacking_postprocessing.py b/tests/test_stimulus_source_inner_stacking_postprocessing.py index 96ecbd3..4211545 100644 --- a/tests/test_stimulus_source_inner_stacking_postprocessing.py +++ b/tests/test_stimulus_source_inner_stacking_postprocessing.py @@ -15,7 +15,7 @@ def _stack_config(**kwargs): ) -def test_stacker_postprocessing_none_returns_argmax_predictions(): +def _postprocessing_inputs(): classes = np.asarray([1, 2, 3]) scores = np.asarray( [ @@ -24,11 +24,17 @@ def test_stacker_postprocessing_none_returns_argmax_predictions(): [2.0, 0.0, 4.0], ] ) + source_labels = np.repeat(classes, 4) + return scores, classes, source_labels + + +def test_stacker_postprocessing_none_returns_argmax_predictions(): + scores, classes, source_labels = _postprocessing_inputs() predictions, metadata = _postprocess_stacked_predictions( scores, classes, - np.repeat(classes, 4), + source_labels, _stack_config(), ) @@ -37,19 +43,12 @@ def test_stacker_postprocessing_none_returns_argmax_predictions(): def test_stacker_postprocessing_can_reuse_source_prior_balanced_assignment(): - classes = np.asarray([1, 2, 3]) - scores = np.asarray( - [ - [3.0, 2.0, 0.0], - [2.9, 2.8, 0.0], - [2.0, 0.0, 4.0], - ] - ) + scores, classes, source_labels = _postprocessing_inputs() predictions, metadata = _postprocess_stacked_predictions( scores, classes, - np.repeat(classes, 4), + source_labels, _stack_config(prediction_postprocessing="source_prior_balanced_assignment"), )