From 89d2fb0c844df2efd75e1457a4e88ecc99a618db Mon Sep 17 00:00:00 2001 From: Antoine Richard Date: Tue, 2 Jun 2026 14:39:16 +0200 Subject: [PATCH 1/2] Fix pre-trained policy action grad leak Run low-level policy inference without autograd and detach the output before copying it into the persistent action buffer. This keeps Warp-facing action tensors usable when environments are stepped outside inference_mode. --- .../antoine-fix-pretrained-policy-action-grad.rst | 5 +++++ .../contrib/navigation/mdp/pre_trained_policy_action.py | 3 ++- 2 files changed, 7 insertions(+), 1 deletion(-) create mode 100644 source/isaaclab_tasks/changelog.d/antoine-fix-pretrained-policy-action-grad.rst diff --git a/source/isaaclab_tasks/changelog.d/antoine-fix-pretrained-policy-action-grad.rst b/source/isaaclab_tasks/changelog.d/antoine-fix-pretrained-policy-action-grad.rst new file mode 100644 index 000000000000..e5ddee1be1ff --- /dev/null +++ b/source/isaaclab_tasks/changelog.d/antoine-fix-pretrained-policy-action-grad.rst @@ -0,0 +1,5 @@ +Fixed +^^^^^ + +* Fixed :class:`~isaaclab_tasks.contrib.navigation.mdp.PreTrainedPolicyAction` + low-level policy inference to keep generated actions detached from autograd. 
diff --git a/source/isaaclab_tasks/isaaclab_tasks/contrib/navigation/mdp/pre_trained_policy_action.py b/source/isaaclab_tasks/isaaclab_tasks/contrib/navigation/mdp/pre_trained_policy_action.py index 4857d63711e1..19d95bc3daf4 100644 --- a/source/isaaclab_tasks/isaaclab_tasks/contrib/navigation/mdp/pre_trained_policy_action.py +++ b/source/isaaclab_tasks/isaaclab_tasks/contrib/navigation/mdp/pre_trained_policy_action.py @@ -96,7 +96,8 @@ def process_actions(self, actions: torch.Tensor): def apply_actions(self): if self._counter % self.cfg.low_level_decimation == 0: low_level_obs = self._low_level_obs_manager.compute_group("ll_policy") - self.low_level_actions[:] = self.policy(low_level_obs) + with torch.no_grad(): + self.low_level_actions[:] = self.policy(low_level_obs).detach() self._low_level_action_term.process_actions(self.low_level_actions) self._counter = 0 self._low_level_action_term.apply_actions() From 62cb4b34c701e85c0ae53972f9299296cfeec79e Mon Sep 17 00:00:00 2001 From: Antoine Richard Date: Thu, 4 Jun 2026 09:40:10 +0200 Subject: [PATCH 2/2] Fix startup benchmark inference mode Run the first env.step() call of the startup benchmark under torch.inference_mode(). This also reverts the torch.no_grad()/detach() change to PreTrainedPolicyAction.apply_actions from the previous patch and removes its changelog fragment. --- scripts/benchmarks/benchmark_startup.py | 3 ++- .../antoine-fix-pretrained-policy-action-grad.rst | 5 ----- .../contrib/navigation/mdp/pre_trained_policy_action.py | 3 +-- 3 files changed, 3 insertions(+), 8 deletions(-) delete mode 100644 source/isaaclab_tasks/changelog.d/antoine-fix-pretrained-policy-action-grad.rst diff --git a/scripts/benchmarks/benchmark_startup.py b/scripts/benchmarks/benchmark_startup.py index 93d92257ca11..aaf5df6e6075 100644 --- a/scripts/benchmarks/benchmark_startup.py +++ b/scripts/benchmarks/benchmark_startup.py @@ -233,7 +233,8 @@ def main( first_step_time_begin = time.perf_counter_ns() first_step_profile.enable() try: - env.step(actions) + with torch.inference_mode(): + env.step(actions) finally: first_step_profile.disable() diff --git a/source/isaaclab_tasks/changelog.d/antoine-fix-pretrained-policy-action-grad.rst 
b/source/isaaclab_tasks/changelog.d/antoine-fix-pretrained-policy-action-grad.rst deleted file mode 100644 index e5ddee1be1ff..000000000000 --- a/source/isaaclab_tasks/changelog.d/antoine-fix-pretrained-policy-action-grad.rst +++ /dev/null @@ -1,5 +0,0 @@ -Fixed -^^^^^ - -* Fixed :class:`~isaaclab_tasks.contrib.navigation.mdp.PreTrainedPolicyAction` - low-level policy inference to keep generated actions detached from autograd. diff --git a/source/isaaclab_tasks/isaaclab_tasks/contrib/navigation/mdp/pre_trained_policy_action.py b/source/isaaclab_tasks/isaaclab_tasks/contrib/navigation/mdp/pre_trained_policy_action.py index 19d95bc3daf4..4857d63711e1 100644 --- a/source/isaaclab_tasks/isaaclab_tasks/contrib/navigation/mdp/pre_trained_policy_action.py +++ b/source/isaaclab_tasks/isaaclab_tasks/contrib/navigation/mdp/pre_trained_policy_action.py @@ -96,8 +96,7 @@ def process_actions(self, actions: torch.Tensor): def apply_actions(self): if self._counter % self.cfg.low_level_decimation == 0: low_level_obs = self._low_level_obs_manager.compute_group("ll_policy") - with torch.no_grad(): - self.low_level_actions[:] = self.policy(low_level_obs).detach() + self.low_level_actions[:] = self.policy(low_level_obs) self._low_level_action_term.process_actions(self.low_level_actions) self._counter = 0 self._low_level_action_term.apply_actions()