Skip to content

Fix WandB model config logging for JAX agents#312

Open
andregonz wants to merge 1 commit into
Toni-SM:mainfrom
andregonz:fix-jax-wandb-model-logging
Open

Fix WandB model config logging for JAX agents#312
andregonz wants to merge 1 commit into
Toni-SM:mainfrom
andregonz:fix-jax-wandb-model-logging

Conversation

@andregonz

Copy link
Copy Markdown

When using the JAX backend, enabling wandb: True causes a crash due to PyTorch-specific model inspection (.net._modules). This PR adds a fallback for JAX/Flax models that logs their type instead:

models_cfg = {k: {"type": type(v).__name__} for k, v in self.models.items()}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant