@@ -138,3 +138,86 @@ def test_psdcone():
138138
139139 assert np .abs (np .trace (sol ) - 1.0 ) < 1e-6
140140 assert (np .linalg .eigvals (sol ) >= - 1e-6 ).all ()
141+
142+
143+ def test_solve_only_batch ():
144+ """Test solve_only_batch with Ps=None (default)."""
145+ np .random .seed (0 )
146+ m = 20
147+ n = 10
148+ batch_size = 5
149+
150+ As , bs , cs , cone_dicts = [], [], [], []
151+ for _ in range (batch_size ):
152+ A , b , c , cone_dims = utils .least_squares_eq_scs_data (m , n )
153+ As .append (A )
154+ bs .append (b )
155+ cs .append (c )
156+ cone_dicts .append (cone_dims )
157+
158+ # Test serial path (n_jobs_forward=1) with Ps=None
159+ xs , ys , ss = cone_prog .solve_only_batch (
160+ As , bs , cs , cone_dicts , n_jobs_forward = 1 , solve_method = 'Clarabel' )
161+ assert len (xs ) == batch_size
162+ assert len (ys ) == batch_size
163+ assert len (ss ) == batch_size
164+
165+ # Verify solutions satisfy optimality conditions
166+ for i in range (batch_size ):
167+ np .testing .assert_allclose (As [i ] @ xs [i ] + ss [i ], bs [i ], atol = 1e-7 )
168+
169+ # Test parallel path (n_jobs_forward=-1) with Ps=None
170+ xs_par , ys_par , ss_par = cone_prog .solve_only_batch (
171+ As , bs , cs , cone_dicts , n_jobs_forward = - 1 , solve_method = 'Clarabel' )
172+ assert len (xs_par ) == batch_size
173+
174+ # Verify parallel solutions also satisfy optimality conditions
175+ for i in range (batch_size ):
176+ np .testing .assert_allclose (As [i ] @ xs_par [i ] + ss_par [i ], bs [i ], atol = 1e-7 )
177+
178+
179+ def test_derivative_batch_parallel ():
180+ """Test that parallel D_batch works correctly."""
181+ np .random .seed (0 )
182+ m = 20
183+ n = 10
184+ batch_size = 5
185+
186+ As , bs , cs , cone_dicts = [], [], [], []
187+ for _ in range (batch_size ):
188+ A , b , c , cone_dims = utils .least_squares_eq_scs_data (m , n )
189+ As .append (A )
190+ bs .append (b )
191+ cs .append (c )
192+ cone_dicts .append (cone_dims )
193+
194+ # Solve with serial backward pass
195+ xs_ser , ys_ser , ss_ser , D_ser , DT_ser = cone_prog .solve_and_derivative_batch (
196+ As , bs , cs , cone_dicts , n_jobs_forward = 1 , n_jobs_backward = 1 , solve_method = 'Clarabel' )
197+
198+ # Solve with parallel backward pass
199+ xs_par , ys_par , ss_par , D_par , DT_par = cone_prog .solve_and_derivative_batch (
200+ As , bs , cs , cone_dicts , n_jobs_forward = - 1 , n_jobs_backward = - 1 , solve_method = 'Clarabel' )
201+
202+ # Create perturbations
203+ dAs = [utils .get_random_like (A , lambda n : np .random .normal (0 , 1e-6 , size = n )) for A in As ]
204+ dbs = [np .random .normal (0 , 1e-6 , size = b .size ) for b in bs ]
205+ dcs = [np .random .normal (0 , 1e-6 , size = c .size ) for c in cs ]
206+
207+ # Test D_batch (forward derivative)
208+ dxs_ser , dys_ser , dss_ser = D_ser (dAs , dbs , dcs )
209+ dxs_par , dys_par , dss_par = D_par (dAs , dbs , dcs )
210+
211+ for i in range (batch_size ):
212+ np .testing .assert_allclose (dxs_ser [i ], dxs_par [i ], rtol = 1e-5 , atol = 1e-10 )
213+ np .testing .assert_allclose (dys_ser [i ], dys_par [i ], rtol = 1e-5 , atol = 1e-10 )
214+ np .testing .assert_allclose (dss_ser [i ], dss_par [i ], rtol = 1e-5 , atol = 1e-10 )
215+
216+ # Test DT_batch (adjoint derivative)
217+ dAs_ser , dbs_ser , dcs_ser = DT_ser (xs_ser , ys_ser , ss_ser )
218+ dAs_par , dbs_par , dcs_par = DT_par (xs_par , ys_par , ss_par )
219+
220+ for i in range (batch_size ):
221+ np .testing .assert_allclose (dAs_ser [i ].todense (), dAs_par [i ].todense (), rtol = 1e-5 , atol = 1e-10 )
222+ np .testing .assert_allclose (dbs_ser [i ], dbs_par [i ], rtol = 1e-5 , atol = 1e-10 )
223+ np .testing .assert_allclose (dcs_ser [i ], dcs_par [i ], rtol = 1e-5 , atol = 1e-10 )
0 commit comments