Performance: vmap over keypoints in smoother. add profiling#77
Closed
ksikka wants to merge 2 commits into
Closed
Conversation
Contributor
Author
|
Closing to use my own fork branch instead. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Fixes #76
Performance impact
On short fly-anipose vid, 30 keypoints, ensemble of 3, 12 core CPU:
● 3.2s → 0.43s on the smoother pass. 7.5× speedup on the bottleneck, total runtime down from 3.67s to 0.59s (~6× end-to-end).
Example profiling output
Diff review:
eks/core.pyfinal smoother passf_fnOriginal used a default-argument capture to avoid Python's loop closure bug:
In
_smooth_one,A_kis a function parameter, not a loop variable — there is no closure-over-loop-variable issue, so the simpler form is correct and equivalent:s_finalOriginal converted to a Python float before passing:
New code passes a JAX scalar
s_k. Insideparams_nlgssm_for_keypointit is immediately wrapped:The type difference has no effect.
h_fn_kSame reasoning as
f_fn— the default-argument pattern was only needed to guard against the loop closure bug, which does not apply inside_smooth_one.Everything else
m0_k,S0_k,Q_k,R_kare direct per-keypoint slices in both versions. No difference.Conclusion: no logic changes.