diff --git a/docs/source/reference/modules_models.rst b/docs/source/reference/modules_models.rst index 68891fe4e67..0467e80dc61 100644 --- a/docs/source/reference/modules_models.rst +++ b/docs/source/reference/modules_models.rst @@ -14,8 +14,11 @@ Modules for model-based reinforcement learning, including world models and dynam ObsEncoder ObsDecoder RSSMPosterior + RSSMPosteriorV3 RSSMPrior + RSSMPriorV3 RSSMRollout + RSSMRolloutV3 PILCO ----- diff --git a/docs/source/reference/objectives_other.rst b/docs/source/reference/objectives_other.rst index b97d3efea50..c448eda9fc0 100644 --- a/docs/source/reference/objectives_other.rst +++ b/docs/source/reference/objectives_other.rst @@ -16,3 +16,28 @@ Additional loss modules for specialized algorithms. DreamerModelLoss DreamerValueLoss ExponentialQuadraticCost + +DreamerV3 +--------- + +Loss modules for DreamerV3 (`Mastering Diverse Domains in World Models, Hafner et al. 2023 `_). +Key differences from V1: discrete categorical latent state, KL balancing, symlog transforms, and two-hot value distributions. + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + DreamerV3ActorLoss + DreamerV3ModelLoss + DreamerV3ValueLoss + +DreamerV3 Utilities +~~~~~~~~~~~~~~~~~~~ + +.. autosummary:: + :toctree: generated/ + + symlog + symexp + two_hot_encode + two_hot_decode diff --git a/test/test_dreamer_v3.py b/test/test_dreamer_v3.py new file mode 100644 index 00000000000..585ea93a013 --- /dev/null +++ b/test/test_dreamer_v3.py @@ -0,0 +1,467 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""Tests for DreamerV3 loss modules and RSSM components. 
+ +Reference: https://arxiv.org/abs/2301.04104 +""" +from __future__ import annotations + +import pytest +import torch +from tensordict import TensorDict +from tensordict.nn import ( + InteractionType, + ProbabilisticTensorDictModule, + ProbabilisticTensorDictSequential, + TensorDictModule, +) +from test_objectives import LossModuleTestBase +from torch import nn + +from torchrl.data import Unbounded +from torchrl.envs.model_based.dreamer import DreamerEnv +from torchrl.envs.transforms import TensorDictPrimer, TransformedEnv +from torchrl.modules import SafeSequential, WorldModelWrapper +from torchrl.modules.distributions.continuous import TanhNormal +from torchrl.modules.models.model_based import DreamerActor +from torchrl.modules.models.model_based_v3 import RSSMPriorV3 +from torchrl.modules.models.models import MLP +from torchrl.objectives import ( + DreamerV3ActorLoss, + DreamerV3ModelLoss, + DreamerV3ValueLoss, +) +from torchrl.objectives.dreamer_v3 import ( + _default_bins, + categorical_kl_balanced, + symexp, + symlog, + two_hot_decode, + two_hot_encode, +) +from torchrl.objectives.utils import ValueEstimators +from torchrl.testing import get_default_devices +from torchrl.testing.mocking_classes import ContinuousActionConvMockEnv + + +@pytest.mark.parametrize("device", get_default_devices()) +class TestDreamerV3(LossModuleTestBase): # type: ignore[misc] + img_size = (64, 64) + # Compact sizes to keep tests fast + num_cats = 4 + num_classes = 4 + state_dim = num_cats * num_classes # 16 + rnn_hidden_dim = 8 + action_dim = 3 + num_reward_bins = 16 # small for tests; paper uses 255 + + def _create_world_model_data(self): + B, T = 2, 3 + return TensorDict( + { + "state": torch.zeros(B, T, self.state_dim), + "belief": torch.zeros(B, T, self.rnn_hidden_dim), + "pixels": torch.rand(B, T, 3, *self.img_size), + "action": torch.randn(B, T, self.action_dim), + "next": { + "pixels": torch.rand(B, T, 3, *self.img_size), + "reward": torch.randn(B, T, 1), + "done": 
torch.zeros(B, T, dtype=torch.bool), + "terminated": torch.zeros(B, T, dtype=torch.bool), + }, + }, + [B, T], + ) + + def _create_actor_data(self): + B, T = 2, 3 + return TensorDict( + { + "state": torch.randn(B, T, self.state_dim), + "belief": torch.randn(B, T, self.rnn_hidden_dim), + "reward": torch.randn(B, T, 1), + }, + [B, T], + ) + + def _create_value_data(self): + N = 6 # 2 * 3 + return TensorDict( + { + "state": torch.randn(N, self.state_dim), + "belief": torch.randn(N, self.rnn_hidden_dim), + "lambda_target": torch.randn(N, 1), + }, + [N], + ) + + def _create_world_model(self, reward_two_hot=True): + """Minimal stub world model that produces all keys DreamerV3ModelLoss expects.""" + + class _StubWorldModel(nn.Module): + def __init__( + self_, + num_cats, + num_classes, + rnn_hidden_dim, + num_reward_bins, + reward_two_hot, + ): + super().__init__() + state_dim = num_cats * num_classes + # pixel encoder → reco + self_.encoder = nn.LazyConv2d(8, 4, stride=2) + self_.decoder = nn.LazyConvTranspose2d(3, 4, stride=2) + # prior / posterior MLP stubs + self_.prior_net = nn.Linear( + state_dim + rnn_hidden_dim, num_cats * num_classes + ) + self_.posterior_net = nn.LazyLinear(num_cats * num_classes) + # reward head + out_r = num_reward_bins if reward_two_hot else 1 + self_.reward_net = nn.LazyLinear(out_r) + self_.num_cats = num_cats + self_.num_classes = num_classes + self_.reward_two_hot = reward_two_hot + + def forward(self_, tensordict): + B, T = tensordict.shape + state = tensordict["state"] # [B, T, state_dim] + belief = tensordict["belief"] # [B, T, rnn_hidden] + + # prior logits + prior_in = torch.cat([state, belief], dim=-1) + prior_flat = self_.prior_net(prior_in) + prior_logits = prior_flat.view(B, T, self_.num_cats, self_.num_classes) + + # posterior logits (lazy — accepts anything) + post_flat = self_.posterior_net(prior_in) + posterior_logits = post_flat.view( + B, T, self_.num_cats, self_.num_classes + ) + + # reco pixels (tiny decode — just needs 
right shape) + next_pixels = tensordict["next", "pixels"] # [B, T, 3, H, W] + flat_pix = next_pixels.flatten(0, 1) # [B*T, 3, H, W] + enc = torch.relu(self_.encoder(flat_pix)) + reco_flat = torch.sigmoid(self_.decoder(enc)) + _, C, H, W = reco_flat.shape + reco_pixels = reco_flat.view(B, T, C, H, W) + + # reward prediction + reward_in = torch.cat([state, belief], dim=-1) + reward_pred = self_.reward_net(reward_in) # [B, T, out_r] + + tensordict.set(("next", "prior_logits"), prior_logits) + tensordict.set(("next", "posterior_logits"), posterior_logits) + tensordict.set(("next", "reco_pixels"), reco_pixels) + tensordict.set(("next", "reward"), reward_pred) + return tensordict + + stub = _StubWorldModel( + self.num_cats, + self.num_classes, + self.rnn_hidden_dim, + self.num_reward_bins, + reward_two_hot, + ) + # warm-up lazy layers + with torch.no_grad(): + stub(self._create_world_model_data()) + return stub + + def _create_mb_env(self): + mock_env = TransformedEnv( + ContinuousActionConvMockEnv(pixel_shape=[3, *self.img_size]) + ) + default_dict = { + "state": Unbounded(self.state_dim), + "belief": Unbounded(self.rnn_hidden_dim), + } + mock_env.append_transform( + TensorDictPrimer(random=False, default_value=0, **default_dict) + ) + rssm_prior = RSSMPriorV3( + action_spec=mock_env.action_spec, + hidden_dim=self.rnn_hidden_dim, + rnn_hidden_dim=self.rnn_hidden_dim, + num_categoricals=self.num_cats, + num_classes=self.num_classes, + action_dim=mock_env.action_spec.shape[0], + ) + transition_model = SafeSequential( + TensorDictModule( + rssm_prior, + in_keys=["state", "belief", "action"], + out_keys=["_", "state", "belief"], + ) + ) + reward_model = TensorDictModule( + MLP(out_features=1, depth=1, num_cells=8), + in_keys=["state", "belief"], + out_keys=["reward"], + ) + model_based_env = DreamerEnv( + world_model=WorldModelWrapper(transition_model, reward_model), + prior_shape=torch.Size([self.state_dim]), + belief_shape=torch.Size([self.rnn_hidden_dim]), + ) + 
model_based_env.set_specs_from_env(mock_env) + with torch.no_grad(): + model_based_env.rollout(3) + return model_based_env + + def _create_actor_model(self): + mock_env = TransformedEnv( + ContinuousActionConvMockEnv(pixel_shape=[3, *self.img_size]) + ) + actor_module = DreamerActor( + out_features=mock_env.action_spec.shape[0], + depth=1, + num_cells=8, + ) + actor_model = ProbabilisticTensorDictSequential( + TensorDictModule( + actor_module, + in_keys=["state", "belief"], + out_keys=["loc", "scale"], + ), + ProbabilisticTensorDictModule( + in_keys=["loc", "scale"], + out_keys=["action"], + default_interaction_type=InteractionType.RANDOM, + distribution_class=TanhNormal, + ), + ) + with torch.no_grad(): + td = TensorDict( + { + "state": torch.randn(1, 2, self.state_dim), + "belief": torch.randn(1, 2, self.rnn_hidden_dim), + }, + batch_size=[1], + ) + actor_model(td) + return actor_model + + def _create_value_model(self, out_features=1): + value_model = TensorDictModule( + MLP(out_features=out_features, depth=1, num_cells=8), + in_keys=["state", "belief"], + out_keys=["state_value"], + ) + with torch.no_grad(): + td = TensorDict( + { + "state": torch.randn(1, 2, self.state_dim), + "belief": torch.randn(1, 2, self.rnn_hidden_dim), + }, + batch_size=[1], + ) + value_model(td) + return value_model + + # ------------------------------------------------------------------ # + # Required by LossModuleTestBase + # ------------------------------------------------------------------ # + + def test_reset_parameters_recursive(self, device): + world_model = self._create_world_model(reward_two_hot=True).to(device) + loss_fn = DreamerV3ModelLoss(world_model, num_reward_bins=self.num_reward_bins) + self.reset_parameters_recursive_test(loss_fn) + + # ------------------------------------------------------------------ # + # Utility tests + # ------------------------------------------------------------------ # + + def test_dreamer_v3_symlog_invertibility(self, device): + x = 
torch.tensor([-1000.0, -10.0, -1.0, 0.0, 1.0, 10.0, 1000.0], device=device) + reconstructed = symexp(symlog(x)) + assert torch.allclose( + reconstructed, x, atol=1e-4 + ), f"symexp(symlog(x)) ≠ x: {reconstructed}" + + def test_dreamer_v3_two_hot_roundtrip(self, device): + bins = _default_bins(self.num_reward_bins).to(device) + vals = torch.linspace(-15.0, 15.0, 9, device=device) + encoded = two_hot_encode(vals, bins) + # Each row must be a valid probability distribution + assert torch.allclose(encoded.sum(-1), torch.ones(9, device=device), atol=1e-5) + decoded = two_hot_decode(torch.log(encoded + 1e-8), bins) + assert torch.allclose( + decoded, vals, atol=0.5 + ), f"two_hot round-trip error too large: {(decoded - vals).abs().max()}" + + # ------------------------------------------------------------------ # + # World model loss tests + # ------------------------------------------------------------------ # + + @pytest.mark.parametrize("reward_two_hot", [True, False]) + @pytest.mark.parametrize( + "lambda_kl,lambda_reco,lambda_reward", [(1.0, 1.0, 1.0), (0.0, 0.0, 0.0)] + ) + def test_dreamer_v3_model_loss_output_keys( + self, device, reward_two_hot, lambda_kl, lambda_reco, lambda_reward + ): + tensordict = self._create_world_model_data().to(device) + world_model = self._create_world_model(reward_two_hot=reward_two_hot).to(device) + loss_module = DreamerV3ModelLoss( + world_model, + lambda_kl=lambda_kl, + lambda_reco=lambda_reco, + lambda_reward=lambda_reward, + reward_two_hot=reward_two_hot, + num_reward_bins=self.num_reward_bins, + ) + loss_td, _ = loss_module(tensordict) + for key in ("loss_model_kl", "loss_model_reco", "loss_model_reward"): + assert key in loss_td.keys(), f"Missing {key}" + assert loss_td[key].shape == torch.Size([1]) + + def test_dreamer_v3_model_loss_backward(self, device): + tensordict = self._create_world_model_data().to(device) + world_model = self._create_world_model(reward_two_hot=True).to(device) + loss_module = DreamerV3ModelLoss( + 
world_model, + num_reward_bins=self.num_reward_bins, + ) + loss_td, _ = loss_module(tensordict) + total_loss = sum( + loss_td[k] + for k in ("loss_model_kl", "loss_model_reco", "loss_model_reward") + ) + total_loss.backward() + grad_total = sum( + p.grad.pow(2).sum().item() + for p in loss_module.parameters() + if p.grad is not None + ) + assert grad_total > 0, "All gradients are zero after backward" + for name, p in loss_module.named_parameters(): + if p.grad is not None: + assert not torch.isnan(p.grad).any(), f"NaN grad in {name}" + assert not torch.isinf(p.grad).any(), f"Inf grad in {name}" + + def test_dreamer_v3_kl_balanced_gradients(self, device): + """Both prior_logits and posterior_logits must receive gradients (KL balancing).""" + prior_logits = torch.randn( + 2, 3, self.num_cats, self.num_classes, requires_grad=True, device=device + ) + posterior_logits = torch.randn( + 2, 3, self.num_cats, self.num_classes, requires_grad=True, device=device + ) + kl = categorical_kl_balanced( + posterior_logits, prior_logits, alpha=0.8, free_bits=1.0 + ) + kl.backward() + assert ( + prior_logits.grad is not None and prior_logits.grad.norm() > 0 + ), "prior_logits has no gradient — KL balancing broken" + assert ( + posterior_logits.grad is not None and posterior_logits.grad.norm() > 0 + ), "posterior_logits has no gradient — KL balancing broken" + + def test_dreamer_v3_model_tensor_keys(self, device): + world_model = self._create_world_model() + loss_fn = DreamerV3ModelLoss(world_model, num_reward_bins=self.num_reward_bins) + default_keys = { + "reward": "reward", + "true_reward": "true_reward", + "prior_logits": "prior_logits", + "posterior_logits": "posterior_logits", + "pixels": "pixels", + "reco_pixels": "reco_pixels", + } + self.tensordict_keys_test(loss_fn, default_keys=default_keys) + + # ------------------------------------------------------------------ # + # Actor loss tests + # ------------------------------------------------------------------ # + + 
@pytest.mark.parametrize("imagination_horizon", [3, 5]) + @pytest.mark.parametrize("discount_loss", [True, False]) + @pytest.mark.parametrize( + "td_est", + [ValueEstimators.TD0, ValueEstimators.TD1, ValueEstimators.TDLambda, None], + ) + def test_dreamer_v3_actor_loss( + self, device, imagination_horizon, discount_loss, td_est + ): + tensordict = self._create_actor_data().to(device) + mb_env = self._create_mb_env().to(device) + actor_model = self._create_actor_model().to(device) + value_model = self._create_value_model().to(device) + loss_module = DreamerV3ActorLoss( + actor_model, + value_model, + mb_env, + imagination_horizon=imagination_horizon, + discount_loss=discount_loss, + ) + if td_est is not None: + loss_module.make_value_estimator(td_est) + loss_td, fake_data = loss_module(tensordict.reshape(-1)) + assert "loss_actor" in loss_td.keys() + assert loss_td["loss_actor"].ndim == 0 or loss_td["loss_actor"].numel() == 1 + loss_td["loss_actor"].backward() + grad_total = sum( + p.grad.pow(2).sum().item() + for p in loss_module.parameters() + if p.grad is not None + ) + assert grad_total > 0, "All gradients are zero after actor backward" + + # ------------------------------------------------------------------ # + # Value loss tests + # ------------------------------------------------------------------ # + + @pytest.mark.parametrize("discount_loss", [True, False]) + def test_dreamer_v3_value_loss_symlog_mse(self, device, discount_loss): + tensordict = self._create_value_data().to(device) + value_model = self._create_value_model(out_features=1).to(device) + loss_module = DreamerV3ValueLoss( + value_model, + value_loss="symlog_mse", + discount_loss=discount_loss, + ) + loss_td, _ = loss_module(tensordict) + assert "loss_value" in loss_td.keys() + loss_td["loss_value"].backward() + grad_total = sum( + p.grad.pow(2).sum().item() + for p in loss_module.parameters() + if p.grad is not None + ) + assert ( + grad_total > 0 + ), "All gradients are zero after value 
(symlog_mse) backward" + + @pytest.mark.parametrize("discount_loss", [True, False]) + def test_dreamer_v3_value_loss_two_hot(self, device, discount_loss): + tensordict = self._create_value_data().to(device) + # Value model must output logits over bins + value_model = self._create_value_model(out_features=self.num_reward_bins).to( + device + ) + loss_module = DreamerV3ValueLoss( + value_model, + value_loss="two_hot", + discount_loss=discount_loss, + num_value_bins=self.num_reward_bins, + ) + loss_td, _ = loss_module(tensordict) + assert "loss_value" in loss_td.keys() + loss_td["loss_value"].backward() + grad_total = sum( + p.grad.pow(2).sum().item() + for p in loss_module.parameters() + if p.grad is not None + ) + assert grad_total > 0, "All gradients are zero after value (two_hot) backward" + + def test_dreamer_v3_value_invalid_loss_type(self, device): + value_model = self._create_value_model() + with pytest.raises(ValueError, match="symlog_mse.*two_hot"): + DreamerV3ValueLoss(value_model, value_loss="bad_loss_type") diff --git a/torchrl/modules/models/__init__.py b/torchrl/modules/models/__init__.py index b8d85025a44..243e6ef53ce 100644 --- a/torchrl/modules/models/__init__.py +++ b/torchrl/modules/models/__init__.py @@ -26,6 +26,7 @@ RSSMPrior, RSSMRollout, ) +from .model_based_v3 import RSSMPosteriorV3, RSSMPriorV3, RSSMRolloutV3 from .models import ( Conv2dNet, Conv3dNet, @@ -81,8 +82,11 @@ "QMixer", "RBFController", "RSSMPosterior", + "RSSMPosteriorV3", "RSSMPrior", + "RSSMPriorV3", "RSSMRollout", + "RSSMRolloutV3", "Squeeze2dLayer", "SqueezeLayer", "VDNMixer", diff --git a/torchrl/modules/models/model_based_v3.py b/torchrl/modules/models/model_based_v3.py new file mode 100644 index 00000000000..98cd2c900c8 --- /dev/null +++ b/torchrl/modules/models/model_based_v3.py @@ -0,0 +1,277 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import torch +from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictSequential +from torch import nn +from torch.nn import GRUCell + + +class RSSMPriorV3(nn.Module): + """DreamerV3 prior network with discrete categorical latent state. + + Implements the sequence model and dynamics predictor from DreamerV3. + The GRU updates the deterministic hidden state: + h_t = GRU(h_{t-1}, [z_{t-1}, a_{t-1}]) + Then the prior predicts a distribution over the stochastic latent: + ẑ_t ~ Cat(MLP(h_t)) + + Reference: https://arxiv.org/abs/2301.04104 + + Args: + action_spec: Action spec (used only for shape validation). + hidden_dim (int, optional): Hidden dimension of the linear projector. + Defaults to 512. + rnn_hidden_dim (int, optional): GRU hidden state dimension (belief size). + Defaults to 512. + num_categoricals (int, optional): Number of categorical variables in the + discrete latent. Defaults to 32. + num_classes (int, optional): Number of classes per categorical variable. + Defaults to 32. + action_dim (int, optional): Action dimension. If provided (along with + ``num_categoricals * num_classes``), uses explicit ``nn.Linear`` instead + of ``nn.LazyLinear``. Defaults to None. + device (torch.device, optional): Device. Defaults to None. 
+ """ + + def __init__( + self, + action_spec, + hidden_dim: int = 512, + rnn_hidden_dim: int = 512, + num_categoricals: int = 32, + num_classes: int = 32, + action_dim: int | None = None, + device=None, + ): + super().__init__() + self.num_categoricals = num_categoricals + self.num_classes = num_classes + self.rnn_hidden_dim = rnn_hidden_dim + state_dim = num_categoricals * num_classes + + self.rnn = GRUCell(hidden_dim, rnn_hidden_dim, device=device) + + if action_dim is not None: + projector_in = state_dim + action_dim + first_linear = nn.Linear(projector_in, hidden_dim, device=device) + else: + first_linear = nn.LazyLinear(hidden_dim, device=device) + self.action_state_projector = nn.Sequential(first_linear, nn.SiLU()) + + self.rnn_to_prior_projector = nn.Sequential( + nn.Linear(rnn_hidden_dim, hidden_dim, device=device), + nn.SiLU(), + nn.Linear(hidden_dim, num_categoricals * num_classes, device=device), + ) + + self.action_shape = action_spec.shape + + def forward( + self, + state: torch.Tensor, + belief: torch.Tensor, + action: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute prior distribution and update GRU belief. + + Args: + state: Previous stochastic state, shape ``[..., num_categoricals * num_classes]``. + belief: Previous GRU hidden state, shape ``[..., rnn_hidden_dim]``. + action: Current action, shape ``[..., action_dim]``. + + Returns: + prior_logits (torch.Tensor): Raw logits, shape + ``[..., num_categoricals, num_classes]``. + state (torch.Tensor): Sampled state (straight-through), shape + ``[..., num_categoricals * num_classes]``. + belief (torch.Tensor): Updated GRU hidden state, shape + ``[..., rnn_hidden_dim]``. 
+ """ + projector_input = torch.cat([state, action], dim=-1) + action_state = self.action_state_projector(projector_input) + + # Run GRU in full precision to avoid cuBLAS issues under autocast + dtype = action_state.dtype + device_type = action_state.device.type + with torch.amp.autocast(device_type=device_type, enabled=False): + belief = self.rnn( + action_state.float(), + belief.float() if belief is not None else None, + ) + belief = belief.to(dtype) + + prior_logits_flat = self.rnn_to_prior_projector(belief) + prior_logits = prior_logits_flat.view( + *prior_logits_flat.shape[:-1], self.num_categoricals, self.num_classes + ) + + state = _straight_through_categorical(prior_logits) + state = state.view(*state.shape[:-2], self.num_categoricals * self.num_classes) + + return prior_logits, state, belief + + +class RSSMPosteriorV3(nn.Module): + """DreamerV3 posterior (representation model) with discrete categorical latent. + + Given the deterministic hidden state ``h_t`` and an observation embedding ``e_t``, + produces the posterior distribution over the stochastic latent: + z_t ~ Cat(MLP([h_t, e_t])) + + Reference: https://arxiv.org/abs/2301.04104 + + Args: + hidden_dim (int, optional): Hidden dimension of the projector MLP. + Defaults to 512. + num_categoricals (int, optional): Number of categorical variables. + Defaults to 32. + num_classes (int, optional): Number of classes per categorical variable. + Defaults to 32. + rnn_hidden_dim (int, optional): Belief dimension. If provided along with + ``obs_embed_dim``, uses explicit ``nn.Linear``. Defaults to None. + obs_embed_dim (int, optional): Observation embedding dimension. If provided + along with ``rnn_hidden_dim``, uses explicit ``nn.Linear``. Defaults to None. + device (torch.device, optional): Device. Defaults to None. 
+ """ + + def __init__( + self, + hidden_dim: int = 512, + num_categoricals: int = 32, + num_classes: int = 32, + rnn_hidden_dim: int | None = None, + obs_embed_dim: int | None = None, + device=None, + ): + super().__init__() + self.num_categoricals = num_categoricals + self.num_classes = num_classes + + if rnn_hidden_dim is not None and obs_embed_dim is not None: + projector_in = rnn_hidden_dim + obs_embed_dim + first_linear = nn.Linear(projector_in, hidden_dim, device=device) + else: + first_linear = nn.LazyLinear(hidden_dim, device=device) + + self.obs_rnn_to_post_projector = nn.Sequential( + first_linear, + nn.SiLU(), + nn.Linear(hidden_dim, num_categoricals * num_classes, device=device), + ) + + def forward( + self, + belief: torch.Tensor, + obs_embedding: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Compute posterior distribution given belief and observation embedding. + + Args: + belief: Deterministic GRU hidden state from prior, shape + ``[..., rnn_hidden_dim]``. + obs_embedding: Encoded observation, shape ``[..., obs_embed_dim]``. + + Returns: + posterior_logits (torch.Tensor): Raw logits, shape + ``[..., num_categoricals, num_classes]``. + state (torch.Tensor): Sampled state (straight-through), shape + ``[..., num_categoricals * num_classes]``. + """ + post_logits_flat = self.obs_rnn_to_post_projector( + torch.cat([belief, obs_embedding], dim=-1) + ) + posterior_logits = post_logits_flat.view( + *post_logits_flat.shape[:-1], self.num_categoricals, self.num_classes + ) + state = _straight_through_categorical(posterior_logits) + state = state.view(*state.shape[:-2], self.num_categoricals * self.num_classes) + return posterior_logits, state + + +class RSSMRolloutV3(TensorDictModuleBase): + """Roll out the DreamerV3 RSSM over a sequence. 
+ + Given encoded observations and actions for ``T`` time steps, this module + runs the prior (GRU + categorical) then the posterior (categorical) at each + step and returns a stacked TensorDict of all intermediate states. + + The previous posterior state ``z_t`` is used as the prior input for step + ``t+1``, matching the recurrent structure of DreamerV3. + + Reference: https://arxiv.org/abs/2301.04104 + + Args: + rssm_prior (TensorDictModule): Prior module wrapping :class:`RSSMPriorV3`. + rssm_posterior (TensorDictModule): Posterior module wrapping + :class:`RSSMPosteriorV3`. + """ + + def __init__( + self, + rssm_prior: TensorDictModule, + rssm_posterior: TensorDictModule, + ): + super().__init__() + _module = TensorDictSequential(rssm_prior, rssm_posterior) + self.in_keys = _module.in_keys + self.out_keys = _module.out_keys + self.rssm_prior = rssm_prior + self.rssm_posterior = rssm_posterior + + def forward(self, tensordict): + """Roll out the RSSM for one episode chunk. + + Args: + tensordict (TensorDictBase): Input with shape ``[*batch, T]`` containing + actions, encoded observations, and initial state/belief. + + Returns: + TensorDictBase: Stacked outputs with shape ``[*batch, T]``. 
+ """ + tensordict_out = [] + *batch, time_steps = tensordict.shape + + update_values = tensordict.exclude(*self.out_keys).unbind(-1) + _tensordict = update_values[0] + + output_keys = list( + update_values[0].keys(include_nested=True, leaves_only=True) + ) + list(self.out_keys) + + for t in range(time_steps): + self.rssm_prior(_tensordict) + self.rssm_posterior(_tensordict) + + tensordict_out.append(_tensordict.select(*output_keys, strict=False)) + if t < time_steps - 1: + next_state = _tensordict.get(("next", "state")) + next_belief = _tensordict.get(("next", "belief")) + _tensordict = update_values[t + 1] + _tensordict.set("state", next_state) + _tensordict.set("belief", next_belief) + + return torch.stack(tensordict_out, tensordict.ndim - 1) + + +def _straight_through_categorical(logits: torch.Tensor) -> torch.Tensor: + """Sample from categorical with straight-through gradient estimator. + + Forward: hard one-hot sample. + Backward: gradients flow through the soft probabilities. + + Args: + logits: ``[..., num_categoricals, num_classes]`` + + Returns: + one_hot tensor with same shape, gradients through softmax. 
+ """ + probs = torch.softmax(logits, dim=-1) + indices = torch.distributions.Categorical(probs=probs).sample() + one_hot = torch.zeros_like(probs) + one_hot.scatter_(-1, indices.unsqueeze(-1), 1.0) + # straight-through: forward uses hard, backward uses soft + return one_hot + (probs - probs.detach()) diff --git a/torchrl/objectives/__init__.py b/torchrl/objectives/__init__.py index f8e47d73519..7873354ec0c 100644 --- a/torchrl/objectives/__init__.py +++ b/torchrl/objectives/__init__.py @@ -15,6 +15,15 @@ DreamerModelLoss, DreamerValueLoss, ) +from torchrl.objectives.dreamer_v3 import ( + DreamerV3ActorLoss, + DreamerV3ModelLoss, + DreamerV3ValueLoss, + symexp, + symlog, + two_hot_decode, + two_hot_encode, +) from torchrl.objectives.gail import GAILLoss from torchrl.objectives.iql import DiscreteIQLLoss, IQLLoss from torchrl.objectives.multiagent import QMixerLoss @@ -52,6 +61,9 @@ "DistributionalDQNLoss", "DreamerActorLoss", "DreamerModelLoss", + "DreamerV3ActorLoss", + "DreamerV3ModelLoss", + "DreamerV3ValueLoss", "DreamerValueLoss", "ExponentialQuadraticCost", "GAILLoss", @@ -77,4 +89,8 @@ "hold_out_net", "hold_out_params", "next_state_value", + "symexp", + "symlog", + "two_hot_decode", + "two_hot_encode", ] diff --git a/torchrl/objectives/dreamer_v3.py b/torchrl/objectives/dreamer_v3.py new file mode 100644 index 00000000000..d952d2c2859 --- /dev/null +++ b/torchrl/objectives/dreamer_v3.py @@ -0,0 +1,677 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""DreamerV3 loss modules. + +Implements the three loss modules from DreamerV3 (Mastering Diverse Domains in +World Models, Hafner et al. 
2023):

- :class:`DreamerV3ModelLoss` — world model (KL balancing + symlog reconstruction)
- :class:`DreamerV3ActorLoss` — actor (REINFORCE + entropy bonus)
- :class:`DreamerV3ValueLoss` — value function (symlog MSE or two-hot CE)

Utility functions :func:`symlog`, :func:`symexp`, :func:`two_hot_encode` and
:func:`two_hot_decode` are also exported for use in custom models.

Reference: https://arxiv.org/abs/2301.04104
"""
from __future__ import annotations

from dataclasses import dataclass

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from tensordict.utils import NestedKey

from torchrl._utils import _maybe_record_function_decorator
from torchrl.envs.model_based.dreamer import DreamerEnv
from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp
from torchrl.objectives.common import LossModule
from torchrl.objectives.utils import (
    _GAMMA_LMBDA_DEPREC_ERROR,
    default_value_kwargs,
    hold_out_net,
    ValueEstimators,
)
from torchrl.objectives.value import (
    TD0Estimator,
    TD1Estimator,
    TDLambdaEstimator,
    ValueEstimatorBase,
)

# ---------------------------------------------------------------------------
# Symlog / symexp transforms
# ---------------------------------------------------------------------------


def symlog(x: torch.Tensor) -> torch.Tensor:
    """Symmetric logarithm: ``sign(x) * log(|x| + 1)``.

    Used by DreamerV3 to compress the dynamic range of targets and
    predictions before computing reconstruction losses.  Odd function:
    ``symlog(-x) == -symlog(x)`` and ``symlog(0) == 0``.
    """
    return x.sign() * (x.abs() + 1).log()


def symexp(x: torch.Tensor) -> torch.Tensor:
    """Symmetric exponential: ``sign(x) * (exp(|x|) - 1)``.

    Inverse of :func:`symlog`.
    """
    return x.sign() * (x.abs().exp() - 1)


# ---------------------------------------------------------------------------
# Two-hot encoding (for reward / value distributions)
# ---------------------------------------------------------------------------

# Default 255-bin linspace in symlog space: roughly covers [-20, 20] raw scale
_DEFAULT_NUM_BINS: int = 255
_DEFAULT_BIN_RANGE: float = 20.0


def _default_bins(num_bins: int = _DEFAULT_NUM_BINS, device=None) -> torch.Tensor:
    """Return evenly spaced, strictly increasing bin centers.

    Shared default grid for the reward and value two-hot heads.
    """
    return torch.linspace(
        -_DEFAULT_BIN_RANGE, _DEFAULT_BIN_RANGE, num_bins, device=device
    )


def two_hot_encode(
    x: torch.Tensor,
    bins: torch.Tensor,
) -> torch.Tensor:
    """Encode a scalar tensor as a two-hot distribution over ``bins``.

    The scalar is split between the two nearest bin centers proportionally so
    that ``E[bins] = x``.  Inputs outside ``[bins[0], bins[-1]]`` are clamped
    to the bin range first, so the expectation identity only holds for values
    inside the range.

    Args:
        x: Values to encode, shape ``[...]``.
        bins: Sorted bin centers, shape ``[num_bins]``.  Assumed strictly
            increasing with ``num_bins >= 2`` — TODO confirm callers never
            pass duplicated centers; equal neighbors degrade to a one-hot via
            the ``span > 0`` guard below.

    Returns:
        Two-hot vectors, shape ``[..., num_bins]``.
    """
    bins = bins.to(x.device)
    x_clamped = x.clamp(bins[0], bins[-1])

    # Index of the lower bin: count of bin edges <= x, minus one.
    lower_idx = (x_clamped.unsqueeze(-1) >= bins).sum(-1) - 1
    # Clamp so upper_idx = lower_idx + 1 is always a valid index, including
    # for x exactly at bins[-1].
    lower_idx = lower_idx.clamp(0, bins.shape[0] - 2)
    upper_idx = lower_idx + 1

    lower_val = bins[lower_idx]
    upper_val = bins[upper_idx]
    span = upper_val - lower_val
    upper_weight = torch.where(
        span > 0, (x_clamped - lower_val) / span, torch.zeros_like(x_clamped)
    )
    lower_weight = 1.0 - upper_weight

    # lower_idx and upper_idx always differ, so the two scatters never collide.
    two_hot = torch.zeros(*x.shape, bins.shape[0], device=x.device, dtype=x.dtype)
    two_hot.scatter_(-1, lower_idx.unsqueeze(-1), lower_weight.unsqueeze(-1))
    two_hot.scatter_(-1, upper_idx.unsqueeze(-1), upper_weight.unsqueeze(-1))
    return two_hot


def two_hot_decode(logits: torch.Tensor, bins: torch.Tensor) -> torch.Tensor:
    """Decode a distribution over ``bins`` to a scalar expectation.

    Computes ``sum(softmax(logits) * bins)`` along the last dim; the inverse
    (in expectation) of :func:`two_hot_encode`.

    Args:
        logits: Raw logits, shape ``[..., num_bins]``.
        bins: Sorted bin centers, shape ``[num_bins]``.

    Returns:
        Scalar expected values, shape ``[...]``.
    """
    bins = bins.to(logits.device)
    probs = torch.softmax(logits, dim=-1)
    return (probs * bins).sum(-1)


# ---------------------------------------------------------------------------
# KL balancing for categorical distributions (DreamerV3 §3)
# ---------------------------------------------------------------------------


def categorical_kl_balanced(
    posterior_logits: torch.Tensor,
    prior_logits: torch.Tensor,
    alpha: float = 0.8,
    free_bits: float = 1.0,
) -> torch.Tensor:
    """KL divergence with balancing between posterior and prior.

    Computes:
        loss = α * KL(sg(posterior) ‖ prior) + (1 - α) * KL(posterior ‖ sg(prior))

    The first term trains only the *prior*; the second trains only the
    *posterior*. Free bits are applied per categorical before averaging.

    Args:
        posterior_logits: Shape ``[..., num_categoricals, num_classes]``.
        prior_logits: Shape ``[..., num_categoricals, num_classes]``.
        alpha (float): Balancing weight (0.8 in the paper). Default: 0.8.
            NOTE(review): α=0.8 matches DreamerV2's KL balancing; the
            DreamerV3 paper instead weights the two stop-gradient terms with
            separate β_dyn/β_rep coefficients — confirm which convention is
            intended here.
        free_bits (float): Minimum KL per categorical in nats. Default: 1.0.

    Returns:
        Scalar KL loss.
+ """ + posterior = torch.softmax(posterior_logits, dim=-1) + prior = torch.softmax(prior_logits, dim=-1) + + # Numerical stability: add small epsilon + eps = 1e-8 + posterior = posterior.clamp(min=eps) + prior = prior.clamp(min=eps) + + # KL(sg(posterior) || prior): only prior gets gradients + post_sg = posterior.detach() + kl_term1 = (post_sg * (post_sg.log() - prior.log())).sum(-1) # [..., num_cats] + + # KL(posterior || sg(prior)): only posterior gets gradients + prior_sg = prior.detach() + kl_term2 = (posterior * (posterior.log() - prior_sg.log())).sum( + -1 + ) # [..., num_cats] + + # Free bits per categorical + kl_term1 = kl_term1.clamp_min(free_bits) + kl_term2 = kl_term2.clamp_min(free_bits) + + # Average over categoricals and batch + return (alpha * kl_term1 + (1.0 - alpha) * kl_term2).mean() + + +# --------------------------------------------------------------------------- +# DreamerV3ModelLoss +# --------------------------------------------------------------------------- + + +class DreamerV3ModelLoss(LossModule): + """DreamerV3 World Model Loss. + + Computes three terms: + + 1. **KL loss** — balanced KL between prior and posterior categorical + distributions (see :func:`categorical_kl_balanced`). + 2. **Reconstruction loss** — symlog MSE between predicted and true + observations. + 3. **Reward loss** — two-hot cross-entropy or symlog MSE for the predicted + reward. + + Optionally a **continue loss** (binary cross-entropy) can be enabled + when the world model outputs a continue predictor. + + Reference: https://arxiv.org/abs/2301.04104 + + Args: + world_model (TensorDictModule): World model that takes a tensordict with + observations/actions and writes predicted observations, rewards, and + RSSM prior/posterior logits. + lambda_kl (float, optional): KL loss weight. Default: 1.0. + lambda_reco (float, optional): Reconstruction loss weight. Default: 1.0. + lambda_reward (float, optional): Reward prediction loss weight. Default: 1.0. 
+ lambda_continue (float, optional): Continue prediction loss weight. + Default: 0.0 (disabled). + kl_alpha (float, optional): KL balancing factor (α in the paper). + Default: 0.8. + free_bits (float, optional): Minimum KL per categorical in nats. + Default: 1.0. + reco_loss (str, optional): Reconstruction loss type (``"l2"`` or + ``"l1"``). Default: ``"l2"``. + reward_two_hot (bool, optional): If ``True``, uses two-hot cross-entropy + for the reward loss; otherwise uses symlog MSE. Default: ``True``. + num_reward_bins (int, optional): Number of bins for the two-hot reward + distribution. Default: 255. + global_average (bool, optional): If ``True``, averages losses over all + dimensions. Otherwise sums over non-batch/time dims first. Default: + ``False``. + """ + + @dataclass + class _AcceptedKeys: + """Configurable tensordict keys. + + Attributes: + reward (NestedKey): Predicted reward. Defaults to ``"reward"``. + true_reward (NestedKey): Ground-truth reward (stored temporarily). + Defaults to ``"true_reward"``. + prior_logits (NestedKey): Prior categorical logits from the prior + RSSM. Defaults to ``"prior_logits"``. + posterior_logits (NestedKey): Posterior categorical logits. + Defaults to ``"posterior_logits"``. + pixels (NestedKey): Ground-truth pixel observation. + Defaults to ``"pixels"``. + reco_pixels (NestedKey): Predicted pixel observation. + Defaults to ``"reco_pixels"``. + continue_pred (NestedKey): Predicted continue logit (optional). + Defaults to ``"continue_pred"``. + done (NestedKey): Ground-truth done flag (optional). + Defaults to ``"done"``. 
+ """ + + reward: NestedKey = "reward" + true_reward: NestedKey = "true_reward" + prior_logits: NestedKey = "prior_logits" + posterior_logits: NestedKey = "posterior_logits" + pixels: NestedKey = "pixels" + reco_pixels: NestedKey = "reco_pixels" + continue_pred: NestedKey = "continue_pred" + done: NestedKey = "done" + + tensor_keys: _AcceptedKeys + default_keys = _AcceptedKeys + + def __init__( + self, + world_model: TensorDictModule, + *, + lambda_kl: float = 1.0, + lambda_reco: float = 1.0, + lambda_reward: float = 1.0, + lambda_continue: float = 0.0, + kl_alpha: float = 0.8, + free_bits: float = 1.0, + reco_loss: str = "l2", + reward_two_hot: bool = True, + num_reward_bins: int = _DEFAULT_NUM_BINS, + global_average: bool = False, + ): + super().__init__() + self.world_model = world_model + self.lambda_kl = lambda_kl + self.lambda_reco = lambda_reco + self.lambda_reward = lambda_reward + self.lambda_continue = lambda_continue + self.kl_alpha = kl_alpha + self.free_bits = free_bits + self.reco_loss = reco_loss + self.reward_two_hot = reward_two_hot + self.global_average = global_average + self.register_buffer( + "reward_bins", + _default_bins(num_reward_bins), + ) + + def _forward_value_estimator_keys(self, **kwargs) -> None: + pass + + @_maybe_record_function_decorator("dreamer_v3/world_model_loss") + def forward(self, tensordict: TensorDict) -> tuple[TensorDict, TensorDict]: + tensordict = tensordict.copy() + tensordict.rename_key_( + ("next", self.tensor_keys.reward), + ("next", self.tensor_keys.true_reward), + ) + + tensordict = self.world_model(tensordict) + + # ---- KL loss ---- + prior_logits = tensordict.get(("next", self.tensor_keys.prior_logits)) + posterior_logits = tensordict.get(("next", self.tensor_keys.posterior_logits)) + kl_loss = categorical_kl_balanced( + posterior_logits, + prior_logits, + alpha=self.kl_alpha, + free_bits=self.free_bits, + ).unsqueeze(-1) + + # ---- Reconstruction loss ---- + pixels = tensordict.get(("next", 
self.tensor_keys.pixels)).contiguous() + reco_pixels = tensordict.get( + ("next", self.tensor_keys.reco_pixels) + ).contiguous() + # Apply symlog before computing distance + if self.reco_loss == "l2": + reco_loss = (symlog(pixels) - symlog(reco_pixels)).pow(2) + else: + reco_loss = (symlog(pixels) - symlog(reco_pixels)).abs() + if not self.global_average: + reco_loss = reco_loss.sum((-3, -2, -1)) + reco_loss = reco_loss.mean().unsqueeze(-1) + + # ---- Reward loss ---- + true_reward = tensordict.get(("next", self.tensor_keys.true_reward)) + pred_reward = tensordict.get(("next", self.tensor_keys.reward)) + + if self.reward_two_hot: + # pred_reward should be logits over reward_bins + targets = two_hot_encode(symlog(true_reward.squeeze(-1)), self.reward_bins) + reward_loss = -(targets * torch.log_softmax(pred_reward, dim=-1)).sum(-1) + else: + reward_loss = (symlog(true_reward) - symlog(pred_reward)).pow(2).squeeze(-1) + reward_loss = reward_loss.mean().unsqueeze(-1) + + td_out = TensorDict( + loss_model_kl=self.lambda_kl * kl_loss, + loss_model_reco=self.lambda_reco * reco_loss, + loss_model_reward=self.lambda_reward * reward_loss, + ) + + # ---- Optional continue loss ---- + if self.lambda_continue > 0: + continue_pred = tensordict.get( + ("next", self.tensor_keys.continue_pred), None + ) + done = tensordict.get(("next", self.tensor_keys.done), None) + if continue_pred is not None and done is not None: + # continue = 1 - done; BCE with logits + continue_target = (~done).float() + continue_loss = torch.nn.functional.binary_cross_entropy_with_logits( + continue_pred.squeeze(-1), continue_target.squeeze(-1) + ).unsqueeze(-1) + td_out.set("loss_model_continue", self.lambda_continue * continue_loss) + + self._clear_weakrefs(tensordict, td_out) + return td_out, tensordict.data + + +# --------------------------------------------------------------------------- +# DreamerV3ActorLoss +# --------------------------------------------------------------------------- + + +class 
DreamerV3ActorLoss(LossModule):
    """DreamerV3 Actor Loss.

    Rolls out imagined trajectories in latent space using the world model
    environment, then computes:

    .. code-block:: text

        loss_actor = -E[log π(a_t | z_t) * sg(A_t)] - η * H[π(· | z_t)]

    where ``A_t = V_λ(z_t) - v(z_t)`` is the advantage (normalized lambda
    return minus baseline) and ``η`` is the entropy bonus weight.

    When the actor is a reparameterizable (continuous) policy the
    reparameterization gradient is used directly instead of REINFORCE.

    Reference: https://arxiv.org/abs/2301.04104

    Args:
        actor_model (TensorDictModule): The actor / policy network.
        value_model (TensorDictModule): The value network.
        model_based_env (DreamerEnv): The imagination environment.
        imagination_horizon (int, optional): Rollout length inside imagination.
            Default: 15.
        discount_loss (bool, optional): If ``True``, discount the actor loss
            with a cumulative gamma factor. Default: ``True``.
        entropy_bonus (float, optional): Weight for the entropy regularisation
            term ``η``. Default: ``3e-4``.
        use_reinforce (bool, optional): If ``True``, uses REINFORCE (log-prob
            * stop-gradient advantage). If ``False``, uses the straight
            reparameterization gradient (suitable for continuous Gaussian
            actors). Default: ``False``.
    """

    @dataclass
    class _AcceptedKeys:
        """Configurable tensordict keys.

        Attributes:
            state (NestedKey): Stochastic latent state. Defaults to ``"state"``.
            belief (NestedKey): Deterministic GRU hidden state. Defaults to ``"belief"``.
            reward (NestedKey): Imagined reward. Defaults to ``"reward"``.
            value (NestedKey): State value. Defaults to ``"state_value"``.
            action_log_prob (NestedKey): Log-prob of the taken action.
                Defaults to ``"action_log_prob"``.
            done (NestedKey): Done flag. Defaults to ``"done"``.
            terminated (NestedKey): Terminated flag. Defaults to ``"terminated"``.
        """

        state: NestedKey = "state"
        belief: NestedKey = "belief"
        reward: NestedKey = "reward"
        value: NestedKey = "state_value"
        action_log_prob: NestedKey = "action_log_prob"
        done: NestedKey = "done"
        terminated: NestedKey = "terminated"

    tensor_keys: _AcceptedKeys
    default_keys = _AcceptedKeys
    default_value_estimator = ValueEstimators.TDLambda

    value_model: TensorDictModule
    actor_model: TensorDictModule

    def __init__(
        self,
        actor_model: TensorDictModule,
        value_model: TensorDictModule,
        model_based_env: DreamerEnv,
        *,
        imagination_horizon: int = 15,
        discount_loss: bool = True,
        entropy_bonus: float = 3e-4,
        use_reinforce: bool = False,
        gamma: int | None = None,
        lmbda: int | None = None,
    ):
        super().__init__()
        self.actor_model = actor_model
        # Writing through __dict__ bypasses nn.Module.__setattr__, so the
        # critic is NOT registered as a submodule and its parameters are not
        # exposed via this loss's .parameters(); presumably so the actor
        # optimizer never updates the critic — confirm against the trainer.
        self.__dict__["value_model"] = value_model
        self.model_based_env = model_based_env
        self.imagination_horizon = imagination_horizon
        self.discount_loss = discount_loss
        self.entropy_bonus = entropy_bonus
        self.use_reinforce = use_reinforce
        # gamma/lmbda are configured via make_value_estimator, not here.
        if gamma is not None:
            raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
        if lmbda is not None:
            raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)

    def _forward_value_estimator_keys(self, **kwargs) -> None:
        # Keep the estimator's value key in sync when set_keys() is called.
        if self._value_estimator is not None:
            self._value_estimator.set_keys(value=self._tensor_keys.value)

    @_maybe_record_function_decorator("dreamer_v3/actor_loss")
    def forward(self, tensordict: TensorDict) -> tuple[TensorDict, TensorDict]:
        # Start imagination from the (detached) latent states of real steps.
        tensordict = tensordict.select(
            self.tensor_keys.state, self.tensor_keys.belief
        ).data

        # hold_out_net: roll through the world model without accumulating
        # gradients into its parameters (actions still receive gradients).
        with hold_out_net(self.model_based_env), set_exploration_type(
            ExplorationType.RANDOM
        ):
            tensordict = self.model_based_env.reset(tensordict.copy())
            fake_data = self.model_based_env.rollout(
                max_steps=self.imagination_horizon,
                policy=self.actor_model,
                auto_reset=False,
                tensordict=tensordict,
            )
            next_tensordict = step_mdp(fake_data, keep_other=True)
            with hold_out_net(self.value_model):
                next_tensordict = self.value_model(next_tensordict)

        reward = fake_data.get(("next", self.tensor_keys.reward))
        next_value = next_tensordict.get(self.tensor_keys.value)
        lambda_target = self.lambda_target(reward, next_value)
        # Stored so DreamerV3ValueLoss can reuse the same target.
        fake_data.set("lambda_target", lambda_target)

        if self.discount_loss:
            # Cumulative discount γ^t along the time dim; first step weight 1.
            gamma = self.value_estimator.gamma.to(tensordict.device)
            discount = gamma.expand(lambda_target.shape).clone()
            discount[..., 0, :] = 1
            discount = discount.cumprod(dim=-2)
        else:
            discount = torch.ones_like(lambda_target)

        if self.use_reinforce:
            # REINFORCE: log π(a|z) * sg(A_t)
            log_prob = fake_data.get(self.tensor_keys.action_log_prob)
            with hold_out_net(self.value_model):
                baseline_td = fake_data.select(*self.value_model.in_keys, strict=False)
                self.value_model(baseline_td)
                baseline = baseline_td.get(self.tensor_keys.value)
            advantage = (lambda_target - baseline).detach()
            actor_loss = -(discount * log_prob * advantage).sum((-2, -1)).mean()
        else:
            # Reparameterization gradient
            actor_loss = -(discount * lambda_target).sum((-2, -1)).mean()

        # Entropy bonus (if actor provides log_prob)
        log_prob_for_entropy = fake_data.get(self.tensor_keys.action_log_prob, None)
        if log_prob_for_entropy is not None and self.entropy_bonus > 0:
            # Monte-Carlo entropy estimate H ≈ E[-log π]; subtracting the
            # bonus means minimizing the loss maximizes entropy.
            entropy = -(discount * log_prob_for_entropy).sum((-2, -1)).mean()
            actor_loss = actor_loss - self.entropy_bonus * entropy

        loss_tensordict = TensorDict({"loss_actor": actor_loss}, [])
        self._clear_weakrefs(tensordict, loss_tensordict)
        return loss_tensordict, fake_data.data

    def lambda_target(self, reward: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
        """Compute the TD(λ) return for an imagined (never-terminating) rollout."""
        # Imagined trajectories carry no termination signal: all-False flags.
        done = torch.zeros(reward.shape, dtype=torch.bool, device=reward.device)
        terminated = torch.zeros(reward.shape, dtype=torch.bool, device=reward.device)
        input_tensordict = TensorDict(
            {
                ("next", self.tensor_keys.reward): reward,
                ("next", self.tensor_keys.value): value,
                ("next", self.tensor_keys.done): done,
                ("next", self.tensor_keys.terminated): terminated,
            },
            [],
        )
        return self.value_estimator.value_estimate(input_tensordict)

    def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
        if value_type is None:
            value_type = self.default_value_estimator

        # A ready-made estimator instance/class is delegated to the base class.
        if isinstance(value_type, ValueEstimatorBase) or (
            isinstance(value_type, type) and issubclass(value_type, ValueEstimatorBase)
        ):
            return LossModule.make_value_estimator(self, value_type, **hyperparams)

        self.value_type = value_type
        hp = dict(default_value_kwargs(value_type))
        if hasattr(self, "gamma"):
            hp["gamma"] = self.gamma
        hp.update(hyperparams)
        # No value network is attached: lambda_target() supplies the values
        # directly via value_estimate().
        value_net = None
        if value_type is ValueEstimators.TD1:
            self._value_estimator = TD1Estimator(**hp, value_network=value_net)
        elif value_type is ValueEstimators.TD0:
            self._value_estimator = TD0Estimator(**hp, value_network=value_net)
        elif value_type is ValueEstimators.GAE:
            raise NotImplementedError(
                f"Value type {value_type} is not implemented for {type(self)}."
            )
        elif value_type is ValueEstimators.TDLambda:
            if hasattr(self, "lmbda"):
                hp["lmbda"] = self.lmbda
            self._value_estimator = TDLambdaEstimator(
                **hp, value_network=value_net, vectorized=True
            )
        else:
            raise NotImplementedError(f"Unknown value type {value_type}")

        self._value_estimator.set_keys(
            value=self.tensor_keys.value,
            value_target="value_target",
        )


# ---------------------------------------------------------------------------
# DreamerV3ValueLoss
# ---------------------------------------------------------------------------


class DreamerV3ValueLoss(LossModule):
    """DreamerV3 Value Loss.

    Trains the value network to predict the lambda-target computed by
    :class:`DreamerV3ActorLoss`. Supports two loss modes:

    - ``"symlog_mse"`` (default): ``(symlog(v_pred) - symlog(target))^2``
    - ``"two_hot"``: Two-hot cross-entropy over a fixed bin grid (matches the
      full DreamerV3 distribution-valued critic).
+ + Reference: https://arxiv.org/abs/2301.04104 + + Args: + value_model (TensorDictModule): The value network. + value_loss (str, optional): Loss type — ``"symlog_mse"`` or ``"two_hot"``. + Default: ``"symlog_mse"``. + discount_loss (bool, optional): If ``True``, discounts the loss with + a cumulative gamma factor. Default: ``True``. + gamma (float, optional): Discount factor used when ``discount_loss=True``. + Default: ``0.99``. + num_value_bins (int, optional): Number of bins for ``"two_hot"`` loss. + Default: 255. + """ + + @dataclass + class _AcceptedKeys: + """Configurable tensordict keys. + + Attributes: + value (NestedKey): Predicted value key. Defaults to ``"state_value"``. + """ + + value: NestedKey = "state_value" + + tensor_keys: _AcceptedKeys + default_keys = _AcceptedKeys + + value_model: TensorDictModule + + def __init__( + self, + value_model: TensorDictModule, + value_loss: str = "symlog_mse", + discount_loss: bool = True, + gamma: float = 0.99, + num_value_bins: int = _DEFAULT_NUM_BINS, + ): + super().__init__() + self.value_model = value_model + self.value_loss = value_loss + self.gamma = gamma + self.discount_loss = discount_loss + if value_loss not in ("symlog_mse", "two_hot"): + raise ValueError( + f"value_loss must be 'symlog_mse' or 'two_hot', got '{value_loss}'" + ) + self.register_buffer("value_bins", _default_bins(num_value_bins)) + + def _forward_value_estimator_keys(self, **kwargs) -> None: + pass + + @_maybe_record_function_decorator("dreamer_v3/value_loss") + def forward(self, fake_data) -> tuple[TensorDict, TensorDict]: + lambda_target = fake_data.get("lambda_target") + + tensordict_select = fake_data.select(*self.value_model.in_keys, strict=False) + self.value_model(tensordict_select) + value_pred = tensordict_select.get(self.tensor_keys.value) + + # lambda_target shape: [N, 1] (flat) or [B, T, 1] (batch × time) + # Squeeze the trailing 1 for loss computation + target_sq = lambda_target.squeeze(-1) # [N] or [B, T] + + if 
self.discount_loss and target_sq.ndim >= 2: + discount = self.gamma * torch.ones_like(target_sq) + discount[..., 0] = 1 + discount = discount.cumprod(dim=-1) + else: + discount = torch.ones_like(target_sq) + + if self.value_loss == "two_hot": + # value_pred: logits over value_bins [..., num_bins] + targets = two_hot_encode(symlog(target_sq), self.value_bins) + loss = -(targets * torch.log_softmax(value_pred, dim=-1)).sum(-1) + else: + # symlog MSE + loss = (symlog(value_pred.squeeze(-1)) - symlog(target_sq)).pow(2) + + value_loss = (discount * loss).mean() + + loss_tensordict = TensorDict({"loss_value": value_loss}) + self._clear_weakrefs(fake_data, loss_tensordict) + return loss_tensordict, fake_data