From 6f4807188c5ac36fc309f64b0705773d616565c8 Mon Sep 17 00:00:00 2001 From: liangkaiz Date: Fri, 26 Jun 2026 01:15:15 -0700 Subject: [PATCH 1/3] Add umi --- .../data/vfm/action/datasets/__init__.py | 2 + .../datasets/stats/umi_lerobot_stats.json | 4 + .../action/datasets/umi_lerobot_dataset.py | 134 ++++++++++++++++++ 3 files changed, 140 insertions(+) create mode 100644 cosmos_framework/data/vfm/action/datasets/stats/umi_lerobot_stats.json create mode 100644 cosmos_framework/data/vfm/action/datasets/umi_lerobot_dataset.py diff --git a/cosmos_framework/data/vfm/action/datasets/__init__.py b/cosmos_framework/data/vfm/action/datasets/__init__.py index 0b01e6b3..64b1b278 100644 --- a/cosmos_framework/data/vfm/action/datasets/__init__.py +++ b/cosmos_framework/data/vfm/action/datasets/__init__.py @@ -13,6 +13,7 @@ from cosmos_framework.data.vfm.action.datasets.bridge_orig_lerobot_dataset import BridgeOrigLeRobotDataset from cosmos_framework.data.vfm.action.datasets.droid_lerobot_dataset import DROIDLeRobotDataset from cosmos_framework.data.vfm.action.datasets.robomind_franka_dataset import RoboMINDFrankaDataset +from cosmos_framework.data.vfm.action.datasets.umi_lerobot_dataset import UMILeRobotDataset __all__ = [ "ActionBaseDataset", @@ -20,4 +21,5 @@ "BridgeOrigLeRobotDataset", "DROIDLeRobotDataset", "RoboMINDFrankaDataset", + "UMILeRobotDataset", ] diff --git a/cosmos_framework/data/vfm/action/datasets/stats/umi_lerobot_stats.json b/cosmos_framework/data/vfm/action/datasets/stats/umi_lerobot_stats.json new file mode 100644 index 00000000..44b9c8ce --- /dev/null +++ b/cosmos_framework/data/vfm/action/datasets/stats/umi_lerobot_stats.json @@ -0,0 +1,4 @@ +{ + "q01": [-0.035246, -0.037122, -0.035762, 0.984050, -0.108706, -0.065188, -0.110908, 0.982889, -0.085106, 0.000000, -0.027468, -0.036971, -0.029396, 0.993522, -0.076207, -0.061227, -0.074231, 0.992173, -0.076929, 0.000000], + "q99": [ 0.038095, 0.033082, 0.032447, 1.000000, 0.110573, 0.068087, 0.108972, 1.000000, 
0.089037, 0.096749, 0.033588, 0.032840, 0.027391, 1.000000, 0.073697, 0.057541, 0.075834, 1.000000, 0.086257, 0.085000] +} diff --git a/cosmos_framework/data/vfm/action/datasets/umi_lerobot_dataset.py b/cosmos_framework/data/vfm/action/datasets/umi_lerobot_dataset.py new file mode 100644 index 00000000..52e3fc71 --- /dev/null +++ b/cosmos_framework/data/vfm/action/datasets/umi_lerobot_dataset.py @@ -0,0 +1,134 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +"""UMI LeRobot dataset.""" + +from __future__ import annotations + +import random +from pathlib import Path +from typing import Any, Literal + +import numpy as np +import torch +from lerobot.datasets.video_utils import decode_video_frames + +from cosmos_framework.data.vfm.action.action_spec import ActionSpec, Gripper, Pos, Rot, build_action_spec +from cosmos_framework.data.vfm.action.datasets.base_dataset import ActionBaseDataset +from cosmos_framework.data.vfm.action.pose_utils import ( + build_abs_pose_from_components, + pose_abs_to_rel, +) + +PoseConvention = Literal["backward_framewise"] +Viewpoint = Literal["wrist_view"] + +# Default image key for wrist camera in UMI LeRobot datasets. +_IMAGE_FEATURE = "observation.images.image" +_STATE_FEATURE = "observation.state" +_ACTION_FEATURE = "action" + +_NORMALIZER_PATH = Path(__file__).parent / "stats/umi_lerobot_stats.json" + + +class UMILeRobotDataset(ActionBaseDataset): + """UMI dataset converted to LeRobot format with 10D cartesian actions: + + [pos_delta(3), rot6d_delta(6), gripper_width(1)] + + Expects a LeRobot v2 dataset with: + * ``observation.images.image``: wrist-mounted RGB video (configurable via + ``image_key``). + * ``observation.state``: 7D EEF state ``[pos(3), rot_axisangle(3), + gripper_width(1)]``. + * ``action``: 7D commanded state in the same format. 
+ + Absolute axis-angle EEF poses are converted to backward-framewise rot6d + relative poses, and the gripper width is taken from the commanded action. + """ + + def __init__( + self, + root: str, + fps: float = 10.0, + chunk_length: int = 16, + mode: str = "joint", + pose_convention: PoseConvention = "backward_framewise", + tolerance_s: float = 2e-4, + viewpoint: Viewpoint = "wrist_view", + action_normalization: str | None = "quantile", + sample_stride: int = 1, + image_key: str = _IMAGE_FEATURE, + ) -> None: + if viewpoint != "wrist_view": + raise NotImplementedError("This UMI dataset only supports wrist_view.") + super().__init__( + root=root, + domain_name="umi_lerobot", + fps=fps, + chunk_length=chunk_length, + mode=mode, + pose_convention=pose_convention, + tolerance_s=tolerance_s, + viewpoint=viewpoint, + action_normalization=action_normalization, + sample_stride=sample_stride, + ) + self._image_key = image_key + + @property + def action_dim(self) -> int: + return 10 + + def _action_spec(self) -> ActionSpec: + return build_action_spec(Pos(), Rot("rot6d"), Gripper()) + + @classmethod + def _stats_path(cls) -> Path: + return _NORMALIZER_PATH + + def __getitem__(self, idx: int) -> dict[str, Any]: + mode = self._choose_mode() + idx = int(idx) + row_idx = idx * self._sample_stride + observation_rows = self._rows[row_idx : row_idx + self._chunk_length + 1] + action_rows = observation_rows[: self._chunk_length] + + episode = self._episodes[int(observation_rows[0]["episode_index"])] + video = self._load_video(episode, observation_rows) + raw_action, initial_pose = self._build_raw_action(observation_rows, action_rows) + task = self._tasks[int(observation_rows[0]["task_index"])] + ai_caption = random.choice([part.strip() for part in task.split(" | ") if part.strip()] or [task]) + + return self._build_result( + mode=mode, + video=video, + action=raw_action, + ai_caption=ai_caption, + initial_pose=initial_pose, + ) + + def _load_video(self, episode: dict[str, Any], 
observation_rows: list[dict[str, Any]]) -> torch.Tensor: + timestamps = [float(row["timestamp"]) for row in observation_rows] + return decode_video_frames( + self._video_path(episode, self._image_key), + [float(episode.get(f"videos/{self._image_key}/from_timestamp", 0.0)) + ts for ts in timestamps], + self._tolerance_s, + ) + + def _build_raw_action( + self, + observation_rows: list[dict[str, Any]], + action_rows: list[dict[str, Any]], + ) -> tuple[torch.Tensor, torch.Tensor]: + # State is 7D: [pos(3), rot_axisangle(3), gripper_width(1)] + state = np.asarray([row[_STATE_FEATURE] for row in observation_rows], dtype=np.float32) + poses_abs = build_abs_pose_from_components(state[:, 0:3], state[:, 3:6], "axisangle") + + initial_pose = torch.from_numpy(poses_abs[0].copy()).float() + poses_rel = pose_abs_to_rel(poses_abs, rotation_format="rot6d", pose_convention=self._pose_convention) + + # Gripper width from commanded action (7th column) + gripper = np.asarray([row[_ACTION_FEATURE][6] for row in action_rows], dtype=np.float32).reshape(-1, 1) + action = np.concatenate([poses_rel[-self._chunk_length :], gripper[-self._chunk_length :]], axis=-1) + return torch.from_numpy(action).float(), initial_pose From d4475403ca2265495f9421df833b3fe7b3072b1d Mon Sep 17 00:00:00 2001 From: liangkaiz Date: Fri, 26 Jun 2026 01:37:09 -0700 Subject: [PATCH 2/3] domain name --- .../data/vfm/action/datasets/umi_lerobot_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cosmos_framework/data/vfm/action/datasets/umi_lerobot_dataset.py b/cosmos_framework/data/vfm/action/datasets/umi_lerobot_dataset.py index 52e3fc71..292b9bc3 100644 --- a/cosmos_framework/data/vfm/action/datasets/umi_lerobot_dataset.py +++ b/cosmos_framework/data/vfm/action/datasets/umi_lerobot_dataset.py @@ -24,7 +24,7 @@ Viewpoint = Literal["wrist_view"] # Default image key for wrist camera in UMI LeRobot datasets. 
-_IMAGE_FEATURE = "observation.images.image" +_IMAGE_FEATURE = "observation.images.camera0" _STATE_FEATURE = "observation.state" _ACTION_FEATURE = "action" @@ -64,7 +64,7 @@ def __init__( raise NotImplementedError("This UMI dataset only supports wrist_view.") super().__init__( root=root, - domain_name="umi_lerobot", + domain_name="umi", fps=fps, chunk_length=chunk_length, mode=mode, From 43e3179ecf8d6b2c0c77038cd511b699653472b5 Mon Sep 17 00:00:00 2001 From: liangkaiz Date: Fri, 26 Jun 2026 18:53:43 -0700 Subject: [PATCH 3/3] Modify UMI dataset --- .../action/datasets/umi_lerobot_dataset.py | 89 ++++++++++++++----- 1 file changed, 65 insertions(+), 24 deletions(-) diff --git a/cosmos_framework/data/vfm/action/datasets/umi_lerobot_dataset.py b/cosmos_framework/data/vfm/action/datasets/umi_lerobot_dataset.py index 292b9bc3..cec95876 100644 --- a/cosmos_framework/data/vfm/action/datasets/umi_lerobot_dataset.py +++ b/cosmos_framework/data/vfm/action/datasets/umi_lerobot_dataset.py @@ -13,6 +13,7 @@ import torch from lerobot.datasets.video_utils import decode_video_frames +from cosmos_framework.data.vfm.action.action_normalization import load_action_stats from cosmos_framework.data.vfm.action.action_spec import ActionSpec, Gripper, Pos, Rot, build_action_spec from cosmos_framework.data.vfm.action.datasets.base_dataset import ActionBaseDataset from cosmos_framework.data.vfm.action.pose_utils import ( @@ -23,13 +24,24 @@ PoseConvention = Literal["backward_framewise"] Viewpoint = Literal["wrist_view"] -# Default image key for wrist camera in UMI LeRobot datasets. -_IMAGE_FEATURE = "observation.images.camera0" -_STATE_FEATURE = "observation.state" -_ACTION_FEATURE = "action" +# Feature keys matching UMI LeRobot parquet columns. +# Trajectory: 7D [pos(3), quat_wxyz(4)] — the main-camera TCP pose. 
+_TRAJ_KEY = "observation.state.right_main_camera_trajectory_xyz_wxyz" +_GRIPPER_KEY = "observation.state.right_gripper_width_m" +_IMAGE_FEATURE = "observation.image.right_main_camera_rgb" + +# Default EEF-in-camera-frame offset (most UMI rigs). +# touch_in_the_wild / FastUMI use FORWARD_EEF_IN_CAMERA_FRAME_XYZ_WXYZ with z=0.056. +_DEFAULT_EEF_IN_CAMERA_FRAME_XYZ_WXYZ: tuple[float, ...] = (0.0, 0.086, 0.09, 1.0, 0.0, 0.0, 0.0) +FORWARD_EEF_IN_CAMERA_FRAME_XYZ_WXYZ: tuple[float, ...] = (0.0, 0.086, 0.056, 1.0, 0.0, 0.0, 0.0) +"""EEF offset for touch_in_the_wild / FastUMI rigs (camera mounted slightly forward).""" _NORMALIZER_PATH = Path(__file__).parent / "stats/umi_lerobot_stats.json" +# Action layout: single-arm is the first 10D of the 20D bimanual stats file +# (right_eef_poses(9) + right_eef_commands(1)). +_SINGLE_ARM_ACTION_DIM = 10 + class UMILeRobotDataset(ActionBaseDataset): """UMI dataset converted to LeRobot format with 10D cartesian actions: @@ -37,14 +49,17 @@ class UMILeRobotDataset(ActionBaseDataset): [pos_delta(3), rot6d_delta(6), gripper_width(1)] Expects a LeRobot v2 dataset with: - * ``observation.images.image``: wrist-mounted RGB video (configurable via - ``image_key``). - * ``observation.state``: 7D EEF state ``[pos(3), rot_axisangle(3), - gripper_width(1)]``. - * ``action``: 7D commanded state in the same format. - - Absolute axis-angle EEF poses are converted to backward-framewise rot6d - relative poses, and the gripper width is taken from the commanded action. + * ``observation.image.right_main_camera_rgb``: wrist-mounted RGB video + (configurable via ``image_key``). + * ``observation.state.right_main_camera_trajectory_xyz_wxyz``: 7D camera + TCP pose ``[pos(3), quat_wxyz(4)]`` for frames [0 .. chunk_length]. + * ``observation.state.right_gripper_width_m``: scalar gripper width for + frames [1 .. chunk_length] (commanded future widths). 
+ + Poses are transformed from the camera TCP frame to the EEF frame via + ``eef_in_camera_frame_xyz_wxyz``, then converted to backward-framewise + rot6d relative poses. The stats file stores 20D bimanual stats (right + + left arm); single-arm normalization uses only the first 10D (right arm). """ def __init__( @@ -54,11 +69,12 @@ def __init__( chunk_length: int = 16, mode: str = "joint", pose_convention: PoseConvention = "backward_framewise", - tolerance_s: float = 2e-4, + tolerance_s: float = 1e-4, viewpoint: Viewpoint = "wrist_view", action_normalization: str | None = "quantile", sample_stride: int = 1, image_key: str = _IMAGE_FEATURE, + eef_in_camera_frame_xyz_wxyz: tuple[float, ...] = _DEFAULT_EEF_IN_CAMERA_FRAME_XYZ_WXYZ, ) -> None: if viewpoint != "wrist_view": raise NotImplementedError("This UMI dataset only supports wrist_view.") @@ -76,9 +92,14 @@ def __init__( ) self._image_key = image_key + xyz_wxyz = np.asarray(eef_in_camera_frame_xyz_wxyz, dtype=np.float32).reshape(1, 7) + self._eef_in_camera_frame_mat: np.ndarray = build_abs_pose_from_components( + xyz_wxyz[:, :3], xyz_wxyz[:, 3:], "quat_wxyz" + )[0] # [4, 4] + @property def action_dim(self) -> int: - return 10 + return _SINGLE_ARM_ACTION_DIM def _action_spec(self) -> ActionSpec: return build_action_spec(Pos(), Rot("rot6d"), Gripper()) @@ -87,16 +108,26 @@ def _action_spec(self) -> ActionSpec: def _stats_path(cls) -> Path: return _NORMALIZER_PATH + @classmethod + def load_action_stats(cls) -> dict[str, torch.Tensor]: + # Stats file stores 20D bimanual layout (right + left arm). + # Single-arm normalization uses only the first 10D (right arm). 
+ raw = { + key: torch.from_numpy(value).float() + for key, value in load_action_stats(str(cls._stats_path())).items() + } + return {key: tensor[:_SINGLE_ARM_ACTION_DIM] for key, tensor in raw.items()} + def __getitem__(self, idx: int) -> dict[str, Any]: mode = self._choose_mode() idx = int(idx) row_idx = idx * self._sample_stride + # T+1 rows: current frame + T future frames observation_rows = self._rows[row_idx : row_idx + self._chunk_length + 1] - action_rows = observation_rows[: self._chunk_length] episode = self._episodes[int(observation_rows[0]["episode_index"])] video = self._load_video(episode, observation_rows) - raw_action, initial_pose = self._build_raw_action(observation_rows, action_rows) + raw_action, initial_pose = self._build_raw_action(observation_rows) task = self._tasks[int(observation_rows[0]["task_index"])] ai_caption = random.choice([part.strip() for part in task.split(" | ") if part.strip()] or [task]) @@ -119,16 +150,26 @@ def _load_video(self, episode: dict[str, Any], observation_rows: list[dict[str, def _build_raw_action( self, observation_rows: list[dict[str, Any]], - action_rows: list[dict[str, Any]], ) -> tuple[torch.Tensor, torch.Tensor]: - # State is 7D: [pos(3), rot_axisangle(3), gripper_width(1)] - state = np.asarray([row[_STATE_FEATURE] for row in observation_rows], dtype=np.float32) - poses_abs = build_abs_pose_from_components(state[:, 0:3], state[:, 3:6], "axisangle") + # Trajectory: T+1 poses, [pos(3), quat_wxyz(4)] per frame. 
+ traj = np.asarray([row[_TRAJ_KEY] for row in observation_rows], dtype=np.float32) # [T+1, 7] + poses_abs = build_abs_pose_from_components(traj[:, :3], traj[:, 3:], "quat_wxyz") # [T+1, 4, 4] initial_pose = torch.from_numpy(poses_abs[0].copy()).float() - poses_rel = pose_abs_to_rel(poses_abs, rotation_format="rot6d", pose_convention=self._pose_convention) - # Gripper width from commanded action (7th column) - gripper = np.asarray([row[_ACTION_FEATURE][6] for row in action_rows], dtype=np.float32).reshape(-1, 1) - action = np.concatenate([poses_rel[-self._chunk_length :], gripper[-self._chunk_length :]], axis=-1) + # Transform from camera TCP frame to EEF frame, then compute relative poses. + eef_poses_abs = poses_abs @ self._eef_in_camera_frame_mat # [T+1, 4, 4] + eef_poses_rel = pose_abs_to_rel( + eef_poses_abs, rotation_format="rot6d", pose_convention=self._pose_convention + ) # [T, 9] + + # Gripper command: future frames only (rows[1:]), matching gripper_indices=[1..T]. + gripper_rows = observation_rows[1:] + gripper_vals = [row[_GRIPPER_KEY] for row in gripper_rows] + gripper = np.asarray( + [float(v) if np.isscalar(v) else float(v[0]) for v in gripper_vals], + dtype=np.float32, + ).reshape(-1, 1) # [T, 1] + + action = np.concatenate([eef_poses_rel, gripper], axis=-1) # [T, 10] return torch.from_numpy(action).float(), initial_pose