
Eks performance #78

Merged
themattinthehatt merged 2 commits into paninski-lab:main from ksikka:eks-performance on Apr 10, 2026

ksikka commented Apr 10, 2026

Fixes #76

Performance impact

On a short fly-anipose video (30 keypoints, ensemble of 3, 12-core CPU):

● Smoother pass: 3.2s → 0.43s, a ~7.5× speedup on the bottleneck. Total runtime: 3.67s → 0.59s (~6× end-to-end).
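As a quick sanity check on the quoted speedups, the arithmetic can be reproduced from the timings above. This is only a sketch: the variable names are illustrative, and the inputs are the rounded figures from this description, so the ratios land slightly off the headline numbers.

```python
# Timings quoted in the PR description, in seconds (rounded values).
smoother_before, smoother_after = 3.2, 0.43
total_before, total_after = 3.67, 0.59

# Speedup = old time / new time.
smoother_speedup = smoother_before / smoother_after
total_speedup = total_before / total_after

# Fraction of total runtime spent in the smoother pass, before and after.
share_before = smoother_before / total_before
share_after = smoother_after / total_after

print(f"smoother: {smoother_speedup:.1f}x, end-to-end: {total_speedup:.1f}x")
print(f"smoother share of runtime: {share_before:.0%} -> {share_after:.0%}")
```

The smoother pass remains the largest single stage after the change, which matches the profiling output below.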

Example profiling output

WARNING:2026-04-10 15:45:13,488:jax._src.xla_bridge:969: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
Running EKS on 18 input files across 6 cameras
Saving to: /media/ksikka/data/untar_datasets/fly_anipose_subset/models/pleasant_ensemble/video_preds
[profile] format_data: 0.110s
[profile] input_dfs_to_markerArray: 0.010s
[profile] ensemble + centering: 0.094s
[profile] variance inflation (skipped): 0.000s
[EKS] Linear path: PCA subspace + linear emissions
[profile] PCA: 0.006s
[profile] KF init (PCA): 0.013s
[profile] build observations (linear): 0.001s
Correlated keypoint blocks: [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14], [15], [16], [17], [18], [19], [20], [21], [22], [23], [24], [25], [26], [27], [28], [29]]
[profile]    build_R: 0.002s
[profile]    final smoother pass (30 keypoints): 0.424s
[profile] run_kalman_smoother (total): 0.428s
[profile] reprojection + packaging: 0.040s
[profile] ensemble_kalman_smoother_multicam total: 0.585s

Diff review: eks/core.py final smoother pass

f_fn

The original used a default-argument capture to work around Python's late-binding closures, where a lambda defined in a loop sees the loop variable's final value:

f_fn = (lambda x, A=A_k: A @ x)

In _smooth_one, A_k is a function parameter, not a loop variable, so each call gets its own binding and there is no closure-over-loop-variable issue. The simpler form is correct and equivalent:

f_fn = lambda x: A_k @ x
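To make the closure argument concrete, here is a small pure-Python sketch of the three cases. The helper names are hypothetical, and scalars with `*` stand in for the matrices and `@` used in eks/core.py.

```python
def make_fns_late_binding(mats):
    fns = []
    for A in mats:
        # The lambda looks up A when it is called, so every lambda
        # ends up using the last value the loop assigned to A.
        fns.append(lambda x: A * x)
    return fns


def make_fns_default_arg(mats):
    fns = []
    for A in mats:
        # The default argument is evaluated at definition time,
        # freezing the current value of A for this lambda.
        fns.append(lambda x, A=A: A * x)
    return fns


def make_fn_from_param(A_k):
    # A_k is a parameter: each call to this function has its own binding,
    # so a plain closure is safe here.
    return lambda x: A_k * x


mats = [1, 2, 3]
late = [f(10) for f in make_fns_late_binding(mats)]    # [30, 30, 30]
frozen = [f(10) for f in make_fns_default_arg(mats)]   # [10, 20, 30]
param = [make_fn_from_param(A)(10) for A in mats]      # [10, 20, 30]
```

The third case mirrors the situation described for _smooth_one: the closed-over name is a parameter, so the default-argument workaround adds nothing.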

s_final

The original converted the value to a Python float before passing it:

s_final = float(s_finals[k])

The new code passes a JAX scalar s_k. Inside params_nlgssm_for_keypoint it is immediately wrapped:

dynamics_covariance=jnp.asarray(s) * jnp.asarray(Q)

Both a Python float and a JAX scalar become a JAX array at this point, so the type difference has no effect.
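A minimal sketch of that equivalence, using made-up values for `s_values` and `Q` (neither comes from the PR); only the `jnp.asarray(s) * jnp.asarray(Q)` expression is taken from the text above.

```python
import jax.numpy as jnp

# Illustrative inputs, not the values used in EKS.
Q = jnp.eye(2)
s_values = jnp.array([0.5, 2.0])

s_float = float(s_values[1])  # original style: Python float
s_jax = s_values[1]           # new style: 0-d JAX array

# Same expression as in the description, applied to both input types.
cov_from_float = jnp.asarray(s_float) * jnp.asarray(Q)
cov_from_jax = jnp.asarray(s_jax) * jnp.asarray(Q)

same_values = bool(jnp.array_equal(cov_from_float, cov_from_jax))
same_dtype = cov_from_float.dtype == cov_from_jax.dtype
```

A likely side benefit, stated here as an assumption rather than something the PR says: if the per-keypoint function is traced by `jax.vmap` or `jax.jit`, calling `float()` on a traced value raises an error, so passing the JAX scalar through keeps the function traceable.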

h_fn_k

Same reasoning as f_fn: the default-argument pattern was only needed to guard against late-binding closures in a loop, which does not apply inside _smooth_one.

Everything else

m0_k, S0_k, Q_k, R_k are direct per-keypoint slices in both versions. No difference.

Conclusion: no logic changes.
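One way to back up a "no logic changes" conclusion numerically is to run the per-keypoint function both in a Python loop and under `jax.vmap`, then compare the results. The sketch below assumes the refactor follows that loop-to-vmap pattern, as the linked issue title suggests. `smooth_one` is a toy stand-in, not the real _smooth_one, and all shapes and inputs are invented for illustration.

```python
import jax
import jax.numpy as jnp


def smooth_one(A_k, s_k, x0_k):
    # Toy per-keypoint computation: a closure over the parameter A_k,
    # scaled by the per-keypoint scalar s_k.
    f_fn = lambda x: A_k @ x
    return s_k * f_fn(x0_k)


K, D = 4, 3  # hypothetical: 4 keypoints, 3-dimensional state
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
A = jax.random.normal(k1, (K, D, D))
s = jax.random.uniform(k2, (K,))
x0 = jax.random.normal(k3, (K, D))

# Loop version: one call per keypoint, results stacked along axis 0.
looped = jnp.stack([smooth_one(A[k], s[k], x0[k]) for k in range(K)])

# Batched version: vmap maps over the leading axis of every argument.
batched = jax.vmap(smooth_one)(A, s, x0)

matches = bool(jnp.allclose(looped, batched, atol=1e-5))
```

Comparing with a tolerance rather than exact equality allows for small floating-point differences between the looped and batched matrix products.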

themattinthehatt merged commit f83656a into paninski-lab:main on Apr 10, 2026
4 checks passed
ksikka commented Apr 10, 2026

BTW: I compared the PDF plots before and after and they look identical

Development

Successfully merging this pull request may close these issues.

EKS performance optimization by using vmap over keypoints
