Skip to content

Vectorize triangulation and Jacobian projection for nonlinear EKS#79

Merged
themattinthehatt merged 2 commits intopaninski-lab:mainfrom
LennyAharon:perf/vectorize-triangulation-and-jacobians
Apr 20, 2026
Merged

Vectorize triangulation and Jacobian projection for nonlinear EKS#79
themattinthehatt merged 2 commits intopaninski-lab:mainfrom
LennyAharon:perf/vectorize-triangulation-and-jacobians

Conversation

@LennyAharon
Copy link
Copy Markdown
Contributor

Summary

  • triangulate_3d_models: replace nested for m, for k loop with joblib.Parallel(prefer='threads') over all M×K pairs — 72s → 7s for K=28 keypoints
  • project_3d_covariance_to_2d: replace per-frame jax.jacfwd loop with vmap(jax.jacfwd(h_cam)) and batched numpy covariance projection (J @ V @ J^T) — ~13,659s → 7s for T=30k frames (was firing 5M individual JAX dispatches)

Benchmark (30k frames, 28 kps, 6 views, nonlinear, fixed smooth_param)

Phase PR#78 alone + this PR
Triangulation 72s 7s
Reprojection + packaging ~13,659s 7s
Final smoother 12s 16s
Total ~3.8 hrs 34s
fps ~2 873

Root cause

Both bottlenecks were Python loops dispatching individual JAX calls over large dimensions:

  • Triangulation: for m in range(M): for k in range(K) — replaced with joblib.Parallel
  • Jacobians: for t in range(T): jax.jacfwd(h_cam)(ms_k[t]) — 30k individual dispatches replaced with vmap(jax.jacfwd(h_cam))(ms_k)

- triangulate_3d_models: replace nested for-loop with joblib.Parallel(prefer='threads')
  over all M*K (model, keypoint) pairs; 72s -> 7s for K=28 keypoints

- project_3d_covariance_to_2d: replace per-frame jax.jacfwd loop with
  vmap(jax.jacfwd(h_cam)) and batched numpy covariance projection (J @ V @ J^T);
  ~13,659s -> 7s for T=30k frames (was firing 5M individual JAX dispatches)

Benchmark (30k frames, 28 kps, 6 views, nonlinear, smooth_param=10000):
  PR#78 alone:       ~3.8 hrs  (reprojection dominated by per-frame dispatch)
  + these changes:     34s     (873 fps)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Copy link
Copy Markdown
Collaborator

@themattinthehatt themattinthehatt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks great! a minor request to update the test tests/test_multicam_smoother.py::test_project_3d_covariance_to_2d_matches_fd_linearization (at the very bottom of the file)

after this line:

var_x, var_y = project_3d_covariance_to_2d(ms_k, Vs_k, h_cam, inflated)

add this:

assert var_x.shape == (T,)
assert var_y.shape == (T,)

and then in the loop below, check all instead of a subset of timestamps:

for t in range(T):

…riance_to_2d

Per reviewer request:
- Assert var_x.shape == (T,) and var_y.shape == (T,)
- Check all T timesteps instead of a subset [0, 3, 5, 7]

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@themattinthehatt themattinthehatt merged commit 101cb40 into paninski-lab:main Apr 20, 2026
2 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants