diff --git a/examples/retrieval/bi_encoder/llama3_2_1b.yaml b/examples/retrieval/bi_encoder/llama3_2_1b.yaml index c51841ce70..7da8e0f57c 100644 --- a/examples/retrieval/bi_encoder/llama3_2_1b.yaml +++ b/examples/retrieval/bi_encoder/llama3_2_1b.yaml @@ -41,6 +41,7 @@ model: use_liger_kernel: true use_sdpa_patching: true torch_dtype: bfloat16 + do_distributed_inbatch_negative: false tokenizer: _target_: nemo_automodel.NeMoAutoTokenizer.from_pretrained diff --git a/nemo_automodel/_transformers/auto_model.py b/nemo_automodel/_transformers/auto_model.py index d23c8acdd3..7c561ef776 100644 --- a/nemo_automodel/_transformers/auto_model.py +++ b/nemo_automodel/_transformers/auto_model.py @@ -1074,6 +1074,8 @@ def from_pretrained( pretrained_model_name_or_path: str, pooling: str = "avg", l2_normalize: bool = True, + do_distributed_inbatch_negative: bool = False, + detach_distributed_inbatch_negatives: bool = True, **kwargs, ) -> PreTrainedModel: """Load a bi-encoder model with infrastructure. @@ -1085,6 +1087,10 @@ def from_pretrained( pretrained_model_name_or_path: Path to pretrained model or model identifier. pooling: Pooling strategy (``'avg'``, ``'cls'``, ``'last'``, etc.). l2_normalize: Whether to L2-normalize embeddings. + do_distributed_inbatch_negative: Whether to gather passages across ranks for distributed in-batch + negatives during training. + detach_distributed_inbatch_negatives: Whether to detach remote passage embeddings in distributed + in-batch-negative losses. Set to false for full cross-rank gradient flow. **kwargs: Forwarded to ``_NeMoAutoModelForRetrievalBase.from_pretrained``. Returns: @@ -1094,6 +1100,8 @@ def from_pretrained( pretrained_model_name_or_path, pooling=pooling, l2_normalize=l2_normalize, + do_distributed_inbatch_negative=do_distributed_inbatch_negative, + detach_distributed_inbatch_negatives=detach_distributed_inbatch_negatives, **kwargs, ) diff --git a/nemo_automodel/_transformers/retrieval.py b/nemo_automodel/_transformers/retrieval.py index 7287fed41d..22b622d104 100644 --- a/nemo_automodel/_transformers/retrieval.py +++ b/nemo_automodel/_transformers/retrieval.py @@ -167,7 +167,7 @@ def pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor, pool_ty sequence_lengths = attention_mask.sum(dim=1) - 1 batch_size = last_hidden.shape[0] emb = last_hidden[torch.arange(batch_size, device=last_hidden.device), sequence_lengths] - elif pool_type == "colbert": + elif pool_type in {"colbert", "multi_vector"}: emb = last_hidden else: raise ValueError(f"pool_type {pool_type} not supported") @@ -359,12 +359,14 @@ def __init__( pooling: str = "avg", l2_normalize: bool = True, do_distributed_inbatch_negative: bool = False, + detach_distributed_inbatch_negatives: bool = True, ): super().__init__() _init_encoder_common(self, model) self.pooling = pooling self.l2_normalize = l2_normalize self.do_distributed_inbatch_negative = do_distributed_inbatch_negative + self.detach_distributed_inbatch_negatives = detach_distributed_inbatch_negatives @classmethod def build( @@ -374,6 +376,7 @@ def build( pooling: str = "avg", l2_normalize: bool = True, do_distributed_inbatch_negative: bool = False, + detach_distributed_inbatch_negatives: bool = True, trust_remote_code: bool = False, **hf_kwargs, ): @@ -393,6 +396,7 @@ def build( pooling=pooling, l2_normalize=l2_normalize, do_distributed_inbatch_negative=do_distributed_inbatch_negative, + detach_distributed_inbatch_negatives=detach_distributed_inbatch_negatives, ) def save_pretrained(self, save_directory: str, **kwargs): diff --git a/nemo_automodel/components/models/common/inbatch_neg_utils.py b/nemo_automodel/components/models/common/inbatch_neg_utils.py index 92129c85e5..f2912a557c 100644 --- a/nemo_automodel/components/models/common/inbatch_neg_utils.py +++ b/nemo_automodel/components/models/common/inbatch_neg_utils.py @@ -24,25 +24,61 @@ import torch import torch.distributed as dist +import torch.distributed.nn.functional as dist_nn_func -def dist_gather_tensor(t: Optional[torch.Tensor]) -> Optional[torch.Tensor]: +def _all_gather_tensor(t: torch.Tensor, preserve_grad: bool = False) -> torch.Tensor: + """All-gather ``t`` along dim 0, preserving autograd only when needed.""" + if preserve_grad and t.requires_grad: + return torch.cat(dist_nn_func.all_gather(t), dim=0) + + gathered = [torch.empty_like(t) for _ in range(dist.get_world_size())] + dist.all_gather(gathered, t) + if t.requires_grad: + gathered[dist.get_rank()] = t + return torch.cat(gathered, dim=0) + + +def dist_gather_tensor(t: Optional[torch.Tensor], preserve_grad: bool = False) -> Optional[torch.Tensor]: """All-gather ``t`` along dim 0 across the default process group. - The local-rank slice is replaced with the original ``t`` so that gradients - flow back only to the local portion of the gathered tensor (other ranks' - slices are detached). Returns ``t`` unchanged when distributed is not - available, not initialized, or world size is 1. + When ``preserve_grad`` is true, tensors that require gradients use an + autograd-aware gather so distributed in-batch-negative losses can send + passage gradients back to the owning rank. Otherwise, remote slices are + detached and only the local slice keeps gradient flow. Non-gradient tensors, + such as masks or IDs, always use a regular detached gather. + Returns ``t`` unchanged when distributed is not available, not initialized, + or world size is 1. """ if t is None: return None if not (dist.is_available() and dist.is_initialized()) or dist.get_world_size() <= 1: return t t = t.contiguous() - gathered = [torch.empty_like(t) for _ in range(dist.get_world_size())] - dist.all_gather(gathered, t) - gathered[dist.get_rank()] = t - return torch.cat(gathered, dim=0) + return _all_gather_tensor(t, preserve_grad=preserve_grad) + + +def dist_gather_tensor_with_dim1_padding( + t: Optional[torch.Tensor], + padding_value: int | float | bool = 0, + preserve_grad: bool = False, +) -> Optional[torch.Tensor]: + """All-gather ``t`` after padding dim 1 to the maximum length across ranks.""" + if t is None: + return None + if not (dist.is_available() and dist.is_initialized()) or dist.get_world_size() <= 1: + return t + local_shape = torch.tensor(t.shape, device=t.device, dtype=torch.long) + shapes = [torch.empty_like(local_shape) for _ in range(dist.get_world_size())] + dist.all_gather(shapes, local_shape) + max_dim1 = max(int(shape[1].item()) for shape in shapes) + if t.shape[1] < max_dim1: + pad_shape = list(t.shape) + pad_shape[1] = max_dim1 - t.shape[1] + padding = t.new_full(pad_shape, padding_value) + t = torch.cat([t, padding], dim=1) + t = t.contiguous() + return _all_gather_tensor(t, preserve_grad=preserve_grad) def mask_gathered_passages_same_doc_as_positive( diff --git a/nemo_automodel/recipes/retrieval/train_bi_encoder.py b/nemo_automodel/recipes/retrieval/train_bi_encoder.py index 5d5181e6a5..c43918c0be 100644 --- a/nemo_automodel/recipes/retrieval/train_bi_encoder.py +++ b/nemo_automodel/recipes/retrieval/train_bi_encoder.py @@ -48,6 +48,11 @@ logger = logging.getLogger(__name__) +def _uses_multi_vector_scoring(model) -> bool: + """Return whether the model emits token-level embeddings for MaxSim scoring.""" + return getattr(model, "pooling", None) in {"colbert", "multi_vector"} + + def contrastive_scores_and_labels( query: torch.Tensor, key: torch.Tensor, current_train_n_passages: int ) -> tuple[torch.Tensor, torch.Tensor]: @@ -72,6 +77,61 @@ def contrastive_scores_and_labels( return qk, labels +def maxsim_scores_and_labels( + query: torch.Tensor, + key: torch.Tensor, + current_train_n_passages: int, + key_attention_mask: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """Compute local multi-vector MaxSim scores and labels without in-batch negatives.""" + assert key.shape[0] == query.shape[0] * current_train_n_passages, "{} != {} * {}".format( + key.shape[0], query.shape[0], current_train_n_passages + ) + assert key_attention_mask.shape == key.shape[:2], "{} != {}".format(key_attention_mask.shape, key.shape[:2]) + + key = key.reshape(query.shape[0], current_train_n_passages, key.shape[1], key.shape[2]) + key_attention_mask = key_attention_mask.reshape(query.shape[0], current_train_n_passages, key.shape[2]) + + token_scores = torch.einsum("bqd,bnpd->bnqp", query, key) + token_scores.masked_fill_(~key_attention_mask[:, :, None, :].bool(), torch.finfo(token_scores.dtype).min) + maxsim = token_scores.max(dim=3).values + scores = maxsim.sum(dim=2) + labels = torch.zeros(query.shape[0], dtype=torch.long, device=query.device) + return scores, labels + + +def distributed_maxsim_scores_and_labels( + query: torch.Tensor, + key: torch.Tensor, + current_train_n_passages: int, + key_attention_mask: torch.Tensor, + rank: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """Compute local-query multi-vector MaxSim scores against globally gathered passages.""" + assert key.shape[0] % current_train_n_passages == 0, "{} % {} > 0".format( + key.shape[0], current_train_n_passages + ) + assert key_attention_mask.shape == key.shape[:2], "{} != {}".format(key_attention_mask.shape, key.shape[:2]) + + global_batch_size = key.shape[0] // current_train_n_passages + key = key.reshape(global_batch_size, current_train_n_passages, key.shape[1], key.shape[2]) + key_attention_mask = key_attention_mask.reshape(global_batch_size, current_train_n_passages, key.shape[2]) + + scores_by_passage = [] + for passage_idx in range(current_train_n_passages): + token_scores = torch.einsum("bqd,gpd->bgqp", query, key[:, passage_idx]) + token_scores.masked_fill_( + ~key_attention_mask[None, :, passage_idx, None, :].bool(), + torch.finfo(token_scores.dtype).min, + ) + scores_by_passage.append(token_scores.max(dim=3).values.sum(dim=2)) + + scores = torch.stack(scores_by_passage, dim=2).reshape(query.shape[0], key.shape[0] * current_train_n_passages) + labels = torch.arange(query.shape[0], dtype=torch.long, device=query.device) + rank * query.shape[0] + labels = labels * current_train_n_passages + return scores, labels + + def _unpack_qp(inputs: dict[str, torch.Tensor]) -> tuple: """Unpack query and passage inputs from batch dictionary. @@ -347,11 +407,11 @@ def _forward_backward_step(self, idx, batch, *, loss_buffer, num_batches, is_tra p_reps = model(passage) n_passages = self.train_n_passages + use_multi_vector_scoring = _uses_multi_vector_scoring(model) if is_train and getattr(model, "do_distributed_inbatch_negative", False): - if getattr(model, "pooling", None) == "colbert": - raise NotImplementedError("Distributed in-batch negatives are not implemented for ColBERT pooling.") from nemo_automodel.components.models.common.inbatch_neg_utils import ( dist_gather_tensor, + dist_gather_tensor_with_dim1_padding, mask_gathered_passages_same_doc_as_positive, ) @@ -359,11 +419,30 @@ def _forward_backward_step(self, idx, batch, *, loss_buffer, num_batches, is_tra dist_initialized = torch.distributed.is_available() and torch.distributed.is_initialized() rank = torch.distributed.get_rank() if dist_initialized else 0 world_size = torch.distributed.get_world_size() if dist_initialized else 1 - all_p = dist_gather_tensor(p_reps) - expected_p = world_size * local_bs * n_passages - assert all_p.shape[0] == expected_p, f"Gathered passage count {all_p.shape[0]} != expected {expected_p}" - scores = torch.mm(q_reps, all_p.t()) - labels = (torch.arange(local_bs, device=q_reps.device) + rank * local_bs) * n_passages + preserve_gather_grad = not getattr(model, "detach_distributed_inbatch_negatives", True) + + if use_multi_vector_scoring: + all_p = dist_gather_tensor_with_dim1_padding(p_reps, preserve_grad=preserve_gather_grad) + all_p_mask = dist_gather_tensor_with_dim1_padding(passage["attention_mask"], padding_value=False) + expected_p = world_size * local_bs * n_passages + assert ( + all_p.shape[0] == expected_p + ), f"Gathered passage count {all_p.shape[0]} != expected {expected_p}" + scores, labels = distributed_maxsim_scores_and_labels( + q_reps, + all_p, + n_passages, + all_p_mask, + rank, + ) + else: + all_p = dist_gather_tensor(p_reps, preserve_grad=preserve_gather_grad) + expected_p = world_size * local_bs * n_passages + assert ( + all_p.shape[0] == expected_p + ), f"Gathered passage count {all_p.shape[0]} != expected {expected_p}" + scores = torch.mm(q_reps, all_p.t()) + labels = (torch.arange(local_bs, device=q_reps.device) + rank * local_bs) * n_passages if model.l2_normalize: scores = scores / self.temperature passage_doc_ids = batch.get("passage_doc_ids") @@ -377,7 +456,15 @@ def _forward_backward_step(self, idx, batch, *, loss_buffer, num_batches, is_tra local_batch_size=local_bs, ) else: - scores, labels = contrastive_scores_and_labels(q_reps, p_reps, n_passages) + if use_multi_vector_scoring: + scores, labels = maxsim_scores_and_labels( + q_reps, + p_reps, + n_passages, + passage["attention_mask"], + ) + else: + scores, labels = contrastive_scores_and_labels(q_reps, p_reps, n_passages) if model.l2_normalize: scores = scores / self.temperature loss = F.cross_entropy(scores, labels) @@ -465,7 +552,15 @@ def _run_validation_epoch(self, val_dataloader): q_reps = model(query) p_reps = model(passage) - scores, labels = contrastive_scores_and_labels(q_reps, p_reps, self.val_n_passages) + if _uses_multi_vector_scoring(model): + scores, labels = maxsim_scores_and_labels( + q_reps, + p_reps, + self.val_n_passages, + passage["attention_mask"], + ) + else: + scores, labels = contrastive_scores_and_labels(q_reps, p_reps, self.val_n_passages) if model.l2_normalize: scores = scores / self.temperature loss = F.cross_entropy(scores, labels) diff --git a/tests/unit_tests/models/bi_encoder/test_bi_encoder_model.py b/tests/unit_tests/models/bi_encoder/test_bi_encoder_model.py index 29f6c0bf6e..b4c3aaff58 100644 --- a/tests/unit_tests/models/bi_encoder/test_bi_encoder_model.py +++ b/tests/unit_tests/models/bi_encoder/test_bi_encoder_model.py @@ -12,9 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +from types import SimpleNamespace + +import pytest +import torch import nemo_automodel._transformers.auto_model as am +import nemo_automodel.recipes.retrieval.train_bi_encoder as tbe from nemo_automodel._transformers.retrieval import BiEncoderModel, CrossEncoderModel +from nemo_automodel.recipes.retrieval.train_bi_encoder import ( + TrainBiEncoderRecipe, + distributed_maxsim_scores_and_labels, + maxsim_scores_and_labels, +) class DummyModel: @@ -27,6 +37,19 @@ class DummyMesh: pass +class _ToyMultiVectorBiEncoder(torch.nn.Module): + do_distributed_inbatch_negative = False + l2_normalize = False + pooling = "multi_vector" + + def __init__(self): + super().__init__() + self.scale = torch.nn.Parameter(torch.tensor(1.0)) + + def forward(self, batch): + return batch["input_ids"].float() * self.scale + + def _apply_common_mocks(monkeypatch): """Mock CUDA-dependent infrastructure so tests run without a GPU.""" monkeypatch.setattr(am, "instantiate_infrastructure", lambda **kwargs: (None, None, None, None)) @@ -67,6 +90,8 @@ def fake_apply_infrastructure(model, **kwargs): pretrained_model_name_or_path="some/path", pooling="avg", l2_normalize=True, + do_distributed_inbatch_negative=True, + detach_distributed_inbatch_negatives=False, use_liger_kernel=True, use_sdpa_patching=True, sdpa_method=None, @@ -77,6 +102,8 @@ def fake_apply_infrastructure(model, **kwargs): assert "liger" in model.marker and "sdpa" in model.marker # Ensure HF kwargs injected + passthrough of parameters to build assert last_kwargs["attn_implementation"] == "flash_attention_2" + assert last_kwargs["do_distributed_inbatch_negative"] is True + assert last_kwargs["detach_distributed_inbatch_negatives"] is False assert last_kwargs["some_other_kwarg"] == "x" @@ -180,3 +207,239 @@ def test_cross_encoder_retries_without_liger(monkeypatch): def test_cross_encoder_retries_without_sdpa(monkeypatch): _assert_retries_without_sdpa(monkeypatch, CrossEncoderModel, am.NeMoAutoModelCrossEncoder) + + +def test_maxsim_scores_and_labels_masks_padding_before_maxsim(): + query = torch.tensor( + [ + [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]], + [[0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], + ] + ) + key = torch.tensor( + [ + [[-0.4, -0.4, 0.0, 0.0], [-0.6, -0.6, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], + [[0.8, 0.0, 0.0, 0.0], [0.0, 0.7, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], + [[0.2, 0.0, 0.0, 0.0], [0.0, 0.1, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], + [[0.0, -0.2, 0.0, 0.0], [0.0, -0.5, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], + [[0.0, 0.6, 0.0, 0.0], [0.1, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], + [[0.0, -0.9, 0.0, 0.0], [0.0, -0.4, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]], + ] + ) + key_attention_mask = torch.tensor([[1, 1, 0], [1, 1, 0], [1, 1, 0], [1, 1, 0], [1, 1, 0], [1, 1, 0]]) + + scores, labels = maxsim_scores_and_labels( + query, + key, + current_train_n_passages=3, + key_attention_mask=key_attention_mask, + ) + + assert torch.allclose(scores, torch.tensor([[-0.8, 1.5, 0.3], [-0.2, 0.6, -0.4]])) + assert torch.equal(labels, torch.tensor([0, 0])) + + +def test_distributed_maxsim_scores_and_labels_matches_all_at_once_scoring(): + torch.manual_seed(0) + query = torch.randn(2, 3, 4, requires_grad=True) + key = torch.randn(8, 5, 4, requires_grad=True) + key_attention_mask = torch.tensor( + [ + [1, 1, 1, 0, 0], + [1, 0, 0, 0, 0], + [1, 1, 0, 0, 0], + [1, 1, 1, 1, 0], + [1, 0, 0, 0, 0], + [1, 1, 1, 0, 0], + [1, 1, 0, 0, 0], + [1, 1, 1, 1, 1], + ], + dtype=torch.long, + ) + query_ref = query.detach().clone().requires_grad_() + key_ref = key.detach().clone().requires_grad_() + + scores, labels = distributed_maxsim_scores_and_labels( + query, + key, + current_train_n_passages=2, + key_attention_mask=key_attention_mask, + rank=1, + ) + + ref_token_scores = torch.einsum("bqd,kpd->bkqp", query_ref, key_ref) + ref_token_scores.masked_fill_( + ~key_attention_mask[None, :, None, :].bool(), + torch.finfo(ref_token_scores.dtype).min, + ) + ref_scores = ref_token_scores.max(dim=3).values.sum(dim=2) + ref_labels = torch.tensor([4, 6]) + + assert scores.shape == (2, 8) + assert torch.allclose(scores, ref_scores) + assert torch.equal(labels, ref_labels) + + scores.sum().backward() + ref_scores.sum().backward() + assert torch.allclose(query.grad, query_ref.grad) + assert torch.allclose(key.grad, key_ref.grad) + + +def test_forward_backward_step_supports_local_multi_vector_pooling(): + recipe = TrainBiEncoderRecipe.__new__(TrainBiEncoderRecipe) + recipe.dist_env = SimpleNamespace(device="cpu") + recipe.distributed_config = SimpleNamespace(defer_fsdp_grad_sync=True) + recipe.model_parts = [_ToyMultiVectorBiEncoder()] + recipe.temperature = 1.0 + recipe.train_n_passages = 2 + + batch = { + "q_input_ids": torch.tensor( + [ + [[1.0, 0.0], [0.0, 1.0]], + [[1.0, 1.0], [0.0, 0.0]], + ] + ), + "q_attention_mask": torch.tensor([[1, 1], [1, 0]]), + "d_input_ids": torch.tensor( + [ + [[1.0, 0.0], [0.0, 1.0]], + [[0.0, 1.0], [0.0, 0.0]], + [[1.0, 0.0], [0.0, 0.0]], + [[0.0, 1.0], [1.0, 1.0]], + ] + ), + "d_attention_mask": torch.tensor([[1, 1], [1, 0], [1, 0], [1, 1]]), + } + loss_buffer = [] + + recipe._forward_backward_step(0, batch, loss_buffer=loss_buffer, num_batches=1, is_train=True) + + assert len(loss_buffer) == 1 + assert torch.isfinite(loss_buffer[0]) + assert recipe.model_parts[0].scale.grad is not None + + +def test_validation_epoch_supports_multi_vector_pooling(): + recipe = TrainBiEncoderRecipe.__new__(TrainBiEncoderRecipe) + recipe.dist_env = SimpleNamespace(device="cpu") + recipe.model_parts = [_ToyMultiVectorBiEncoder()] + recipe.temperature = 1.0 + recipe.val_n_passages = 2 + recipe.step_scheduler = SimpleNamespace(step=3, epoch=1) + + val_dataloader = [ + { + "q_input_ids": torch.tensor( + [ + [[1.0, 0.0], [0.0, 1.0]], + [[1.0, 1.0], [0.0, 0.0]], + ] + ), + "q_attention_mask": torch.tensor([[1, 1], [1, 0]]), + "d_input_ids": torch.tensor( + [ + [[1.0, 0.0], [0.0, 1.0]], + [[0.0, 1.0], [0.0, 0.0]], + [[1.0, 0.0], [0.0, 0.0]], + [[0.0, 1.0], [1.0, 1.0]], + ] + ), + "d_attention_mask": torch.tensor([[1, 1], [1, 0], [1, 0], [1, 1]]), + } + ] + + metrics = recipe._run_validation_epoch(val_dataloader) + + assert metrics.step == 3 + assert metrics.epoch == 1 + assert torch.isfinite(torch.tensor(metrics.metrics["val_loss"])) + assert 0.0 <= metrics.metrics["val_acc1"] <= 1.0 + assert 0.0 <= metrics.metrics["val_mrr"] <= 1.0 + assert recipe.model_parts[0].scale.grad is None + + +@pytest.mark.parametrize("detach_distributed_inbatch_negatives", [True, False]) +def test_forward_backward_step_supports_distributed_multi_vector_inbatch_negatives( + monkeypatch, + detach_distributed_inbatch_negatives, +): + """Exercise the trainer branch that gathers token embeddings across ranks.""" + import nemo_automodel.components.models.common.inbatch_neg_utils as inbatch_neg_utils + + recipe = TrainBiEncoderRecipe.__new__(TrainBiEncoderRecipe) + recipe.dist_env = SimpleNamespace(device="cpu") + recipe.distributed_config = SimpleNamespace(defer_fsdp_grad_sync=True) + model = _ToyMultiVectorBiEncoder() + model.do_distributed_inbatch_negative = True + model.detach_distributed_inbatch_negatives = detach_distributed_inbatch_negatives + recipe.model_parts = [model] + recipe.temperature = 1.0 + recipe.train_n_passages = 2 + + monkeypatch.setattr(torch.distributed, "is_available", lambda: True) + monkeypatch.setattr(torch.distributed, "is_initialized", lambda: True) + monkeypatch.setattr(torch.distributed, "get_rank", lambda: 1) + monkeypatch.setattr(torch.distributed, "get_world_size", lambda: 2) + + gather_with_padding_calls = [] + + def fake_gather_with_dim1_padding(tensor, padding_value=0, preserve_grad=False): + gather_with_padding_calls.append((tuple(tensor.shape), padding_value, preserve_grad)) + return torch.cat([tensor.detach().clone(), tensor], dim=0) + + gather_tensor_calls = [] + + def fake_gather_tensor(tensor, preserve_grad=False): + gather_tensor_calls.append((tuple(tensor.shape), preserve_grad)) + remote_doc_ids = torch.tensor([500, 999, 600, 998], dtype=tensor.dtype, device=tensor.device) + return torch.cat([remote_doc_ids, tensor], dim=0) + + captured = {} + + def fake_cross_entropy(scores, labels): + captured["scores"] = scores.detach().clone() + captured["labels"] = labels.detach().clone() + return -scores.gather(1, labels.unsqueeze(1)).mean() + + monkeypatch.setattr(inbatch_neg_utils, "dist_gather_tensor_with_dim1_padding", fake_gather_with_dim1_padding) + monkeypatch.setattr(inbatch_neg_utils, "dist_gather_tensor", fake_gather_tensor) + monkeypatch.setattr(tbe.F, "cross_entropy", fake_cross_entropy) + + batch = { + "q_input_ids": torch.tensor( + [ + [[1.0, 0.0], [0.0, 1.0]], + [[0.0, 1.0], [1.0, 0.0]], + ] + ), + "q_attention_mask": torch.tensor([[1, 1], [1, 1]]), + "d_input_ids": torch.tensor( + [ + [[1.0, 0.0], [0.0, 1.0]], + [[0.0, 1.0], [0.0, 0.0]], + [[0.0, 1.0], [1.0, 0.0]], + [[1.0, 0.0], [0.0, 0.0]], + ] + ), + "d_attention_mask": torch.tensor([[1, 1], [1, 0], [1, 1], [1, 0]]), + "passage_doc_ids": torch.tensor([500, 501, 600, 601], dtype=torch.long), + } + loss_buffer = [] + + recipe._forward_backward_step(0, batch, loss_buffer=loss_buffer, num_batches=1, is_train=True) + + assert gather_with_padding_calls == [ + ((4, 2, 2), 0, not detach_distributed_inbatch_negatives), + ((4, 2), False, False), + ] + assert gather_tensor_calls == [((4,), False)] + assert torch.equal(captured["labels"], torch.tensor([4, 6])) + assert captured["scores"].shape == (2, 8) + assert captured["scores"][0, 0].item() == torch.finfo(captured["scores"].dtype).min + assert captured["scores"][1, 2].item() == torch.finfo(captured["scores"].dtype).min + assert captured["scores"][0, 4].item() > torch.finfo(captured["scores"].dtype).min + assert captured["scores"][1, 6].item() > torch.finfo(captured["scores"].dtype).min + assert len(loss_buffer) == 1 + assert torch.isfinite(loss_buffer[0]) + assert model.scale.grad is not None diff --git a/tests/unit_tests/models/bi_encoder/test_llama_bidirectional_model.py b/tests/unit_tests/models/bi_encoder/test_llama_bidirectional_model.py index 365856408e..e1f4e5a5f4 100644 --- a/tests/unit_tests/models/bi_encoder/test_llama_bidirectional_model.py +++ b/tests/unit_tests/models/bi_encoder/test_llama_bidirectional_model.py @@ -43,7 +43,7 @@ def test_contrastive_scores_and_labels_shapes_and_labels(): assert torch.all(labels == 0) and labels.shape == (2,) -@pytest.mark.parametrize("pool_type", ["avg", "weighted_avg", "cls", "colbert"]) +@pytest.mark.parametrize("pool_type", ["avg", "weighted_avg", "cls", "colbert", "multi_vector"]) def test_pool_basic_modes(pool_type): last_hidden = torch.tensor( [ @@ -61,7 +61,7 @@ def test_pool_basic_modes(pool_type): assert torch.allclose(out[0], torch.tensor([1.0 + 3.0, 2.0 + 4.0])) elif pool_type == "cls": assert torch.allclose(out[:, :], last_hidden[:, 0]) - elif pool_type == "colbert": + elif pool_type in {"colbert", "multi_vector"}: assert out.shape == last_hidden.shape diff --git a/tests/unit_tests/models/common/test_inbatch_neg_utils.py b/tests/unit_tests/models/common/test_inbatch_neg_utils.py index acf05681b5..4795767ef0 100644 --- a/tests/unit_tests/models/common/test_inbatch_neg_utils.py +++ b/tests/unit_tests/models/common/test_inbatch_neg_utils.py @@ -17,11 +17,14 @@ import pytest import torch +import nemo_automodel.components.models.common.inbatch_neg_utils as inbatch_neg_utils from nemo_automodel.components.models.common.inbatch_neg_utils import ( dist_gather_tensor, + dist_gather_tensor_with_dim1_padding, mask_gathered_passages_same_doc_as_positive, ) + def _is_masked(x: torch.Tensor) -> bool: """True when ``x`` is the dtype's ``-inf`` marker or its ``finfo.min``. @@ -47,6 +50,142 @@ def test_dist_gather_tensor_none_returns_none(): assert dist_gather_tensor(None) is None +def test_dist_gather_tensor_uses_autograd_gather_for_grad_tensors(monkeypatch): + monkeypatch.setattr(inbatch_neg_utils.dist, "is_available", lambda: True) + monkeypatch.setattr(inbatch_neg_utils.dist, "is_initialized", lambda: True) + monkeypatch.setattr(inbatch_neg_utils.dist, "get_world_size", lambda: 2) + + def fail_regular_all_gather(*args, **kwargs): + raise AssertionError("regular all_gather should not handle grad tensors") + + def fake_autograd_all_gather(tensor): + return (tensor.detach().clone(), tensor) + + monkeypatch.setattr(inbatch_neg_utils.dist, "all_gather", fail_regular_all_gather) + monkeypatch.setattr(inbatch_neg_utils.dist_nn_func, "all_gather", fake_autograd_all_gather) + + t = torch.tensor([[1.0], [2.0]], requires_grad=True) + gathered = dist_gather_tensor(t, preserve_grad=True) + + assert gathered.shape == (4, 1) + gathered.sum().backward() + assert torch.equal(t.grad, torch.ones_like(t)) + + +def test_dist_gather_tensor_detaches_remote_grad_tensors_by_default(monkeypatch): + monkeypatch.setattr(inbatch_neg_utils.dist, "is_available", lambda: True) + monkeypatch.setattr(inbatch_neg_utils.dist, "is_initialized", lambda: True) + monkeypatch.setattr(inbatch_neg_utils.dist, "get_world_size", lambda: 2) + monkeypatch.setattr(inbatch_neg_utils.dist, "get_rank", lambda: 1) + + def fail_autograd_all_gather(*args, **kwargs): + raise AssertionError("autograd all_gather should not handle detached mode") + + def fake_regular_all_gather(gathered, tensor): + gathered[0].copy_(tensor.detach() + 10) + gathered[1].copy_(tensor.detach() + 20) + + monkeypatch.setattr(inbatch_neg_utils.dist, "all_gather", fake_regular_all_gather) + monkeypatch.setattr(inbatch_neg_utils.dist_nn_func, "all_gather", fail_autograd_all_gather) + + t = torch.tensor([[1.0], [2.0]], requires_grad=True) + gathered = dist_gather_tensor(t) + + expected = torch.tensor([[11.0], [12.0], [1.0], [2.0]]) + assert torch.equal(gathered, expected) + gathered.sum().backward() + assert torch.equal(t.grad, torch.ones_like(t)) + + +def test_dist_gather_tensor_uses_regular_gather_for_non_grad_tensors(monkeypatch): + monkeypatch.setattr(inbatch_neg_utils.dist, "is_available", lambda: True) + monkeypatch.setattr(inbatch_neg_utils.dist, "is_initialized", lambda: True) + monkeypatch.setattr(inbatch_neg_utils.dist, "get_world_size", lambda: 2) + + def fail_autograd_all_gather(*args, **kwargs): + raise AssertionError("autograd all_gather should not handle metadata tensors") + + def fake_regular_all_gather(gathered, tensor): + gathered[0].copy_(tensor + 10) + gathered[1].copy_(tensor + 20) + + monkeypatch.setattr(inbatch_neg_utils.dist, "all_gather", fake_regular_all_gather) + monkeypatch.setattr(inbatch_neg_utils.dist_nn_func, "all_gather", fail_autograd_all_gather) + + t = torch.tensor([[1], [2]], dtype=torch.long) + gathered = dist_gather_tensor(t) + + expected = torch.tensor([[11], [12], [21], [22]], dtype=torch.long) + assert torch.equal(gathered, expected) + + +def test_dist_gather_tensor_with_dim1_padding_single_rank_is_noop(): + t = torch.randn(4, 3, 8) + assert dist_gather_tensor_with_dim1_padding(t) is t + + +def test_dist_gather_tensor_with_dim1_padding_none_returns_none(): + assert dist_gather_tensor_with_dim1_padding(None) is None + + +def test_dist_gather_tensor_with_dim1_padding_preserves_grad_through_padding(monkeypatch): + monkeypatch.setattr(inbatch_neg_utils.dist, "is_available", lambda: True) + monkeypatch.setattr(inbatch_neg_utils.dist, "is_initialized", lambda: True) + monkeypatch.setattr(inbatch_neg_utils.dist, "get_world_size", lambda: 2) + + def fake_regular_all_gather(gathered, tensor): + if tensor.dtype != torch.long: + raise AssertionError("regular all_gather should only gather shapes here") + gathered[0].copy_(torch.tensor([2, 4, 3], device=tensor.device)) + gathered[1].copy_(torch.tensor([2, 2, 3], device=tensor.device)) + + def fake_autograd_all_gather(tensor): + assert tensor.shape == (2, 4, 3) + return (tensor.detach().clone(), tensor) + + monkeypatch.setattr(inbatch_neg_utils.dist, "all_gather", fake_regular_all_gather) + monkeypatch.setattr(inbatch_neg_utils.dist_nn_func, "all_gather", fake_autograd_all_gather) + + t = torch.randn(2, 2, 3, requires_grad=True) + gathered = dist_gather_tensor_with_dim1_padding(t, preserve_grad=True) + + assert gathered.shape == (4, 4, 3) + gathered.sum().backward() + assert torch.equal(t.grad, torch.ones_like(t)) + + +def test_dist_gather_tensor_with_dim1_padding_detaches_remote_grad_tensors_by_default(monkeypatch): + monkeypatch.setattr(inbatch_neg_utils.dist, "is_available", lambda: True) + monkeypatch.setattr(inbatch_neg_utils.dist, "is_initialized", lambda: True) + monkeypatch.setattr(inbatch_neg_utils.dist, "get_world_size", lambda: 2) + monkeypatch.setattr(inbatch_neg_utils.dist, "get_rank", lambda: 1) + + def fake_regular_all_gather(gathered, tensor): + if tensor.dtype == torch.long: + gathered[0].copy_(torch.tensor([2, 4, 3], device=tensor.device)) + gathered[1].copy_(torch.tensor([2, 2, 3], device=tensor.device)) + else: + assert tensor.shape == (2, 4, 3) + gathered[0].copy_(tensor.detach() + 10) + gathered[1].copy_(tensor.detach() + 20) + + def fail_autograd_all_gather(*args, **kwargs): + raise AssertionError("autograd all_gather should not handle detached mode") + + monkeypatch.setattr(inbatch_neg_utils.dist, "all_gather", fake_regular_all_gather) + monkeypatch.setattr(inbatch_neg_utils.dist_nn_func, "all_gather", fail_autograd_all_gather) + + t = torch.randn(2, 2, 3, requires_grad=True) + gathered = dist_gather_tensor_with_dim1_padding(t) + + assert gathered.shape == (4, 4, 3) + assert torch.allclose(gathered[:2], torch.cat([t.detach(), torch.zeros_like(t)], dim=1) + 10) + assert torch.allclose(gathered[2:, :2], t) + assert torch.equal(gathered[2:, 2:], torch.zeros_like(t)) + gathered.sum().backward() + assert torch.equal(t.grad, torch.ones_like(t)) + + @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) def test_mask_same_doc_basic(dtype): """Duplicate of q0's positive doc id elsewhere in the batch must be masked