Skip to content

Performance: vmap over keypoints in smoother. add profiling#77

Closed
ksikka wants to merge 2 commits into
mainfrom
eks-performance
Closed

Performance: vmap over keypoints in smoother. add profiling#77
ksikka wants to merge 2 commits into
mainfrom
eks-performance

Conversation

@ksikka
Copy link
Copy Markdown
Contributor

@ksikka ksikka commented Apr 10, 2026

Fixes #76

Performance impact

On short fly-anipose vid, 30 keypoints, ensemble of 3, 12 core CPU:

● 3.2s → 0.43s on the smoother pass. 7.5× speedup on the bottleneck, total runtime down from 3.67s to 0.59s (~6× end-to-end).

Example profiling output

WARNING:2026-04-10 15:45:13,488:jax._src.xla_bridge:969: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
Running EKS on 18 input files across 6 cameras
Saving to: /media/ksikka/data/untar_datasets/fly_anipose_subset/models/pleasant_ensemble/video_preds
[profile] format_data: 0.110s
[profile] input_dfs_to_markerArray: 0.010s
[profile] ensemble + centering: 0.094s
[profile] variance inflation (skipped): 0.000s
[EKS] Linear path: PCA subspace + linear emissions
[profile] PCA: 0.006s
[profile] KF init (PCA): 0.013s
[profile] build observations (linear): 0.001s
Correlated keypoint blocks: [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14], [15], [16], [17], [18], [19], [20], [21], [22], [23], [24], [25], [26], [27], [28], [29]]
[profile]    build_R: 0.002s
[profile]    final smoother pass (30 keypoints): 0.424s
[profile] run_kalman_smoother (total): 0.428s
[profile] reprojection + packaging: 0.040s
[profile] ensemble_kalman_smoother_multicam total: 0.585s

Diff review: eks/core.py final smoother pass

f_fn

Original used a default-argument capture to avoid Python's loop closure bug:

f_fn = (lambda x, A=A_k: A @ x)                                                                                                                                                                     

In _smooth_one, A_k is a function parameter, not a loop variable — there is no closure-over-loop-variable issue, so the simpler form is correct and equivalent:

f_fn = lambda x: A_k @ x                                                                                                                                                                            

s_final

Original converted to a Python float before passing:

s_final = float(s_finals[k])

New code passes a JAX scalar s_k. Inside params_nlgssm_for_keypoint it is immediately wrapped:

dynamics_covariance=jnp.asarray(s) * jnp.asarray(Q)

The type difference has no effect.

h_fn_k

Same reasoning as f_fn — the default-argument pattern was only needed to guard against the loop closure bug, which does not apply inside _smooth_one.

Everything else

m0_k, S0_k, Q_k, R_k are direct per-keypoint slices in both versions. No difference.

Conclusion: no logic changes.

@ksikka ksikka closed this Apr 10, 2026
@ksikka
Copy link
Copy Markdown
Contributor Author

ksikka commented Apr 10, 2026

Closing to use my own fork branch instead.

@ksikka ksikka deleted the eks-performance branch April 10, 2026 20:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

EKS performance optimization by using vmap over keypoints

1 participant