diff --git a/stable_audio_tools/training/arc.py b/stable_audio_tools/training/arc.py index e3e71fd3..e7ca5af1 100644 --- a/stable_audio_tools/training/arc.py +++ b/stable_audio_tools/training/arc.py @@ -1319,15 +1319,10 @@ def training_step(self, batch, batch_idx): logit_center_penalty = 0 if self.do_contrastive_disc: - - rolled_metadata = [] - - for i in range(reals.shape[0]): - rolled_keys = ["prompt"] - rolled_metadata.append(metadata[i]) - for rolled_key in rolled_keys: - rolled_metadata[i][rolled_key] = metadata[(i + 1) % reals.shape[0]][rolled_key] - + + n = reals.shape[0] + rolled_metadata = [{**md, "prompt": metadata[(i + 1) % n]["prompt"]} for i, md in enumerate(metadata)] + rolled_conditioning = self.discriminator.conditioner(rolled_metadata, self.device) if self.inpainting_config is not None: