@@ -652,3 +652,25 @@ def test_singular_array_init_or_prior(estimator, build_dataset, w0):
652652 with pytest .raises (LinAlgError ) as raised_err :
653653 model .fit (input_data , labels )
654654 assert str (raised_err .value ) == msg
655+
656+
657+ @pytest .mark .integration
658+ @pytest .mark .parametrize ('estimator, build_dataset' , metric_learners ,
659+ ids = ids_metric_learners )
660+ def test_deterministic_initialization (estimator , build_dataset ):
661+ """Test that estimators that have a prior or an init are deterministic
662+ when it is set to to random and when the random_state is fixed."""
663+ input_data , labels , _ , X = build_dataset ()
664+ model = clone (estimator )
665+ if hasattr (estimator , 'init' ):
666+ model .set_params (init = 'random' )
667+ if hasattr (estimator , 'prior' ):
668+ model .set_params (prior = 'random' )
669+ model1 = clone (model )
670+ set_random_state (model1 , 42 )
671+ model1 = model1 .fit (input_data , labels )
672+ model2 = clone (model )
673+ set_random_state (model2 , 42 )
674+ model2 = model2 .fit (input_data , labels )
675+ np .testing .assert_allclose (model1 .get_mahalanobis_matrix (),
676+ model2 .get_mahalanobis_matrix ())
0 commit comments