diff --git a/source/isaaclab_contrib/isaaclab_contrib/rl/rlinf/extension.py b/source/isaaclab_contrib/isaaclab_contrib/rl/rlinf/extension.py index d38e577811c0..5cf851ce63eb 100644 --- a/source/isaaclab_contrib/isaaclab_contrib/rl/rlinf/extension.py +++ b/source/isaaclab_contrib/isaaclab_contrib/rl/rlinf/extension.py @@ -276,17 +276,38 @@ def _register_gr00t_converters(cfg: dict) -> None: Args: cfg: The IsaacLab-specific configuration dictionary (``env.train.isaaclab``). """ - from rlinf.models.embodiment.gr00t import simulation_io - obs_converter_type = cfg.get("obs_converter_type", "dex3") - if obs_converter_type not in simulation_io.OBS_CONVERSION: - simulation_io.OBS_CONVERSION[obs_converter_type] = _convert_isaaclab_obs_to_gr00t - logger.info(f"Registered obs converter: {obs_converter_type}") + simulation_modules = [] + try: + from rlinf.models.embodiment.gr00t import simulation_io as gr00t_simulation_io + + simulation_modules.append(("gr00t", gr00t_simulation_io)) + except Exception as exc: + logger.debug(f"Could not import GR00T N1.5 simulation_io: {exc}") + + try: + from rlinf.models.embodiment.gr00t_n1d6 import simulation_io as gr00t_n1d6_simulation_io + + simulation_modules.append(("gr00t_n1d6", gr00t_n1d6_simulation_io)) + except Exception as exc: + logger.debug(f"Could not import GR00T N1.6 simulation_io: {exc}") + + try: + from rlinf.models.embodiment.gr00t_n1d7 import simulation_io as gr00t_n1d7_simulation_io + + simulation_modules.append(("gr00t_n1d7", gr00t_n1d7_simulation_io)) + except Exception as exc: + logger.debug(f"Could not import GR00T N1.7 simulation_io: {exc}") - if obs_converter_type not in simulation_io.ACTION_CONVERSION: - simulation_io.ACTION_CONVERSION[obs_converter_type] = _convert_gr00t_to_isaaclab_action - logger.info(f"Registered action converter: {obs_converter_type}") + for module_name, simulation_io in simulation_modules: + if obs_converter_type not in simulation_io.OBS_CONVERSION: + simulation_io.OBS_CONVERSION[obs_converter_type] = _convert_isaaclab_obs_to_gr00t + logger.info(f"Registered {module_name} obs converter: {obs_converter_type}") + + if obs_converter_type not in simulation_io.ACTION_CONVERSION: + simulation_io.ACTION_CONVERSION[obs_converter_type] = _convert_gr00t_to_isaaclab_action + logger.info(f"Registered {module_name} action converter: {obs_converter_type}") def _convert_isaaclab_obs_to_gr00t(env_obs: dict) -> dict: @@ -338,10 +359,18 @@ def _convert_isaaclab_obs_to_gr00t(env_obs: dict) -> dict: gr00t_key = spec.get("gr00t_key") slice_range = spec.get("slice", [0, states_np.shape[-1]]) if gr00t_key: - groot_obs[gr00t_key] = states_np[:, :, slice_range[0] : slice_range[1]] - - # Pass through task descriptions - groot_obs["annotation.human.action.task_description"] = env_obs.get("task_descriptions", []) + state_part = states_np[:, :, slice_range[0] : slice_range[1]] + if "scale" in spec: + state_part = state_part * np.asarray(spec["scale"], dtype=state_part.dtype) + if "offset" in spec: + state_part = state_part + np.asarray(spec["offset"], dtype=state_part.dtype) + groot_obs[gr00t_key] = state_part + + # Pass through task descriptions. SO-101 N1.6 checkpoints use + # annotation.human.task_description, while older LIBERO-style configs use + # annotation.human.action.task_description. + language_key = gr00t_mapping.get("language_key", "annotation.human.action.task_description") + groot_obs[language_key] = env_obs.get("task_descriptions", []) return groot_obs @@ -367,8 +396,24 @@ def _convert_gr00t_to_isaaclab_action(action_chunk: dict, chunk_size: int = 1) - prefix_pad = action_mapping.get("prefix_pad", 0) suffix_pad = action_mapping.get("suffix_pad", 0) - # Concatenate all action parts - action_parts = [v[:, :chunk_size, :] for v in action_chunk.values()] + # Concatenate action parts in the configured order when provided. + action_keys = action_mapping.get("gr00t_action_keys") or list(action_chunk.keys()) + action_parts = [] + for key in action_keys: + if key in action_chunk: + action_parts.append(action_chunk[key][:, :chunk_size, :]) + continue + short_key = key.split(".", 1)[1] if key.startswith("action.") else f"action.{key}" + if short_key in action_chunk: + action_parts.append(action_chunk[short_key][:, :chunk_size, :]) + continue + logger.warning( + f"GR00T action key '{key}' (also tried '{short_key}') not found in action chunk " + f"(available: {list(action_chunk)}); this entry will be skipped and the action tensor " + "will be narrower than expected." + ) + if not action_parts: + raise KeyError(f"No configured GR00T action keys found in action chunk: keys={list(action_chunk)}") action_concat = np.concatenate(action_parts, axis=-1) # Apply padding @@ -379,6 +424,10 @@ def _convert_gr00t_to_isaaclab_action(action_chunk: dict, chunk_size: int = 1) - mode="constant", constant_values=0, ) + if "scale" in action_mapping: + action_concat = action_concat * np.asarray(action_mapping["scale"], dtype=action_concat.dtype) + if "offset" in action_mapping: + action_concat = action_concat + np.asarray(action_mapping["offset"], dtype=action_concat.dtype) return action_concat