Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/retrieval/bi_encoder/llama3_2_1b.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ model:
use_liger_kernel: true
use_sdpa_patching: true
torch_dtype: bfloat16
do_distributed_inbatch_negative: false

tokenizer:
_target_: nemo_automodel.NeMoAutoTokenizer.from_pretrained
Expand Down
8 changes: 8 additions & 0 deletions nemo_automodel/_transformers/auto_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1074,6 +1074,8 @@ def from_pretrained(
pretrained_model_name_or_path: str,
pooling: str = "avg",
l2_normalize: bool = True,
do_distributed_inbatch_negative: bool = False,
detach_distributed_inbatch_negatives: bool = True,
**kwargs,
) -> PreTrainedModel:
"""Load a bi-encoder model with infrastructure.
Expand All @@ -1085,6 +1087,10 @@ def from_pretrained(
pretrained_model_name_or_path: Path to pretrained model or model identifier.
pooling: Pooling strategy (``'avg'``, ``'cls'``, ``'last'``, etc.).
l2_normalize: Whether to L2-normalize embeddings.
do_distributed_inbatch_negative: Whether to gather passages across ranks for distributed in-batch
negatives during training.
detach_distributed_inbatch_negatives: Whether to detach remote passage embeddings in distributed
in-batch-negative losses. Set to false for full cross-rank gradient flow.
**kwargs: Forwarded to ``_NeMoAutoModelForRetrievalBase.from_pretrained``.

Returns:
Expand All @@ -1094,6 +1100,8 @@ def from_pretrained(
pretrained_model_name_or_path,
pooling=pooling,
l2_normalize=l2_normalize,
do_distributed_inbatch_negative=do_distributed_inbatch_negative,
detach_distributed_inbatch_negatives=detach_distributed_inbatch_negatives,
**kwargs,
)

Expand Down
6 changes: 5 additions & 1 deletion nemo_automodel/_transformers/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor, pool_ty
sequence_lengths = attention_mask.sum(dim=1) - 1
batch_size = last_hidden.shape[0]
emb = last_hidden[torch.arange(batch_size, device=last_hidden.device), sequence_lengths]
elif pool_type == "colbert":
elif pool_type in {"colbert", "multi_vector"}:
emb = last_hidden
else:
raise ValueError(f"pool_type {pool_type} not supported")
Expand Down Expand Up @@ -359,12 +359,14 @@ def __init__(
pooling: str = "avg",
l2_normalize: bool = True,
do_distributed_inbatch_negative: bool = False,
detach_distributed_inbatch_negatives: bool = True,
):
super().__init__()
_init_encoder_common(self, model)
self.pooling = pooling
self.l2_normalize = l2_normalize
self.do_distributed_inbatch_negative = do_distributed_inbatch_negative
self.detach_distributed_inbatch_negatives = detach_distributed_inbatch_negatives

@classmethod
def build(
Expand All @@ -374,6 +376,7 @@ def build(
pooling: str = "avg",
l2_normalize: bool = True,
do_distributed_inbatch_negative: bool = False,
detach_distributed_inbatch_negatives: bool = True,
trust_remote_code: bool = False,
**hf_kwargs,
):
Expand All @@ -393,6 +396,7 @@ def build(
pooling=pooling,
l2_normalize=l2_normalize,
do_distributed_inbatch_negative=do_distributed_inbatch_negative,
detach_distributed_inbatch_negatives=detach_distributed_inbatch_negatives,
)

def save_pretrained(self, save_directory: str, **kwargs):
Expand Down
54 changes: 45 additions & 9 deletions nemo_automodel/components/models/common/inbatch_neg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,61 @@

import torch
import torch.distributed as dist
import torch.distributed.nn.functional as dist_nn_func


def dist_gather_tensor(t: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
def _all_gather_tensor(t: torch.Tensor, preserve_grad: bool = False) -> torch.Tensor:
"""All-gather ``t`` along dim 0, preserving autograd only when needed."""
if preserve_grad and t.requires_grad:
return torch.cat(dist_nn_func.all_gather(t), dim=0)

gathered = [torch.empty_like(t) for _ in range(dist.get_world_size())]
dist.all_gather(gathered, t)
if t.requires_grad:
gathered[dist.get_rank()] = t
return torch.cat(gathered, dim=0)


def dist_gather_tensor(t: Optional[torch.Tensor], preserve_grad: bool = False) -> Optional[torch.Tensor]:
"""All-gather ``t`` along dim 0 across the default process group.

The local-rank slice is replaced with the original ``t`` so that gradients
flow back only to the local portion of the gathered tensor (other ranks'
slices are detached). Returns ``t`` unchanged when distributed is not
available, not initialized, or world size is 1.
When ``preserve_grad`` is true, tensors that require gradients use an
autograd-aware gather so distributed in-batch-negative losses can send
passage gradients back to the owning rank. Otherwise, remote slices are
detached and only the local slice keeps gradient flow. Non-gradient tensors,
such as masks or IDs, always use a regular detached gather.
Returns ``t`` unchanged when distributed is not available, not initialized,
or world size is 1.
"""
if t is None:
return None
if not (dist.is_available() and dist.is_initialized()) or dist.get_world_size() <= 1:
return t
t = t.contiguous()
gathered = [torch.empty_like(t) for _ in range(dist.get_world_size())]
dist.all_gather(gathered, t)
gathered[dist.get_rank()] = t
return torch.cat(gathered, dim=0)
return _all_gather_tensor(t, preserve_grad=preserve_grad)


def dist_gather_tensor_with_dim1_padding(
Comment thread
rnyak marked this conversation as resolved.
t: Optional[torch.Tensor],
padding_value: int | float | bool = 0,
preserve_grad: bool = False,
) -> Optional[torch.Tensor]:
"""All-gather ``t`` after padding dim 1 to the maximum length across ranks."""
if t is None:
return None
if not (dist.is_available() and dist.is_initialized()) or dist.get_world_size() <= 1:
return t
local_shape = torch.tensor(t.shape, device=t.device, dtype=torch.long)
shapes = [torch.empty_like(local_shape) for _ in range(dist.get_world_size())]
dist.all_gather(shapes, local_shape)
max_dim1 = max(int(shape[1].item()) for shape in shapes)
if t.shape[1] < max_dim1:
pad_shape = list(t.shape)
pad_shape[1] = max_dim1 - t.shape[1]
padding = t.new_full(pad_shape, padding_value)
t = torch.cat([t, padding], dim=1)
t = t.contiguous()
return _all_gather_tensor(t, preserve_grad=preserve_grad)


def mask_gathered_passages_same_doc_as_positive(
Expand Down
113 changes: 104 additions & 9 deletions nemo_automodel/recipes/retrieval/train_bi_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@
logger = logging.getLogger(__name__)


def _uses_multi_vector_scoring(model) -> bool:
"""Return whether the model emits token-level embeddings for MaxSim scoring."""
return getattr(model, "pooling", None) in {"colbert", "multi_vector"}


def contrastive_scores_and_labels(
query: torch.Tensor, key: torch.Tensor, current_train_n_passages: int
) -> tuple[torch.Tensor, torch.Tensor]:
Expand All @@ -72,6 +77,61 @@ def contrastive_scores_and_labels(
return qk, labels


def maxsim_scores_and_labels(
query: torch.Tensor,
key: torch.Tensor,
current_train_n_passages: int,
key_attention_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute local multi-vector MaxSim scores and labels without in-batch negatives."""
assert key.shape[0] == query.shape[0] * current_train_n_passages, "{} != {} * {}".format(
key.shape[0], query.shape[0], current_train_n_passages
)
assert key_attention_mask.shape == key.shape[:2], "{} != {}".format(key_attention_mask.shape, key.shape[:2])

key = key.reshape(query.shape[0], current_train_n_passages, key.shape[1], key.shape[2])
key_attention_mask = key_attention_mask.reshape(query.shape[0], current_train_n_passages, key.shape[2])

token_scores = torch.einsum("bqd,bnpd->bnqp", query, key)
token_scores.masked_fill_(~key_attention_mask[:, :, None, :].bool(), torch.finfo(token_scores.dtype).min)
maxsim = token_scores.max(dim=3).values
scores = maxsim.sum(dim=2)
labels = torch.zeros(query.shape[0], dtype=torch.long, device=query.device)
return scores, labels


def distributed_maxsim_scores_and_labels(
query: torch.Tensor,
key: torch.Tensor,
current_train_n_passages: int,
key_attention_mask: torch.Tensor,
rank: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute local-query multi-vector MaxSim scores against globally gathered passages."""
assert key.shape[0] % current_train_n_passages == 0, "{} % {} > 0".format(
key.shape[0], current_train_n_passages
)
assert key_attention_mask.shape == key.shape[:2], "{} != {}".format(key_attention_mask.shape, key.shape[:2])

global_batch_size = key.shape[0] // current_train_n_passages
key = key.reshape(global_batch_size, current_train_n_passages, key.shape[1], key.shape[2])
key_attention_mask = key_attention_mask.reshape(global_batch_size, current_train_n_passages, key.shape[2])

scores_by_passage = []
for passage_idx in range(current_train_n_passages):
token_scores = torch.einsum("bqd,gpd->bgqp", query, key[:, passage_idx])
token_scores.masked_fill_(
~key_attention_mask[None, :, passage_idx, None, :].bool(),
torch.finfo(token_scores.dtype).min,
)
scores_by_passage.append(token_scores.max(dim=3).values.sum(dim=2))

scores = torch.stack(scores_by_passage, dim=2).reshape(query.shape[0], key.shape[0] * current_train_n_passages)
labels = torch.arange(query.shape[0], dtype=torch.long, device=query.device) + rank * query.shape[0]
labels = labels * current_train_n_passages
return scores, labels


def _unpack_qp(inputs: dict[str, torch.Tensor]) -> tuple:
"""Unpack query and passage inputs from batch dictionary.

Expand Down Expand Up @@ -347,23 +407,42 @@ def _forward_backward_step(self, idx, batch, *, loss_buffer, num_batches, is_tra
p_reps = model(passage)

n_passages = self.train_n_passages
use_multi_vector_scoring = _uses_multi_vector_scoring(model)
if is_train and getattr(model, "do_distributed_inbatch_negative", False):
if getattr(model, "pooling", None) == "colbert":
raise NotImplementedError("Distributed in-batch negatives are not implemented for ColBERT pooling.")
from nemo_automodel.components.models.common.inbatch_neg_utils import (
dist_gather_tensor,
dist_gather_tensor_with_dim1_padding,
mask_gathered_passages_same_doc_as_positive,
)

local_bs = q_reps.shape[0]
dist_initialized = torch.distributed.is_available() and torch.distributed.is_initialized()
rank = torch.distributed.get_rank() if dist_initialized else 0
world_size = torch.distributed.get_world_size() if dist_initialized else 1
all_p = dist_gather_tensor(p_reps)
expected_p = world_size * local_bs * n_passages
assert all_p.shape[0] == expected_p, f"Gathered passage count {all_p.shape[0]} != expected {expected_p}"
scores = torch.mm(q_reps, all_p.t())
labels = (torch.arange(local_bs, device=q_reps.device) + rank * local_bs) * n_passages
preserve_gather_grad = not getattr(model, "detach_distributed_inbatch_negatives", True)

if use_multi_vector_scoring:
all_p = dist_gather_tensor_with_dim1_padding(p_reps, preserve_grad=preserve_gather_grad)
all_p_mask = dist_gather_tensor_with_dim1_padding(passage["attention_mask"], padding_value=False)
expected_p = world_size * local_bs * n_passages
assert (
all_p.shape[0] == expected_p
), f"Gathered passage count {all_p.shape[0]} != expected {expected_p}"
scores, labels = distributed_maxsim_scores_and_labels(
q_reps,
all_p,
n_passages,
all_p_mask,
rank,
)
else:
all_p = dist_gather_tensor(p_reps, preserve_grad=preserve_gather_grad)
expected_p = world_size * local_bs * n_passages
assert (
all_p.shape[0] == expected_p
), f"Gathered passage count {all_p.shape[0]} != expected {expected_p}"
scores = torch.mm(q_reps, all_p.t())
labels = (torch.arange(local_bs, device=q_reps.device) + rank * local_bs) * n_passages
if model.l2_normalize:
scores = scores / self.temperature
passage_doc_ids = batch.get("passage_doc_ids")
Expand All @@ -377,7 +456,15 @@ def _forward_backward_step(self, idx, batch, *, loss_buffer, num_batches, is_tra
local_batch_size=local_bs,
)
else:
scores, labels = contrastive_scores_and_labels(q_reps, p_reps, n_passages)
if use_multi_vector_scoring:
scores, labels = maxsim_scores_and_labels(
q_reps,
p_reps,
n_passages,
passage["attention_mask"],
)
else:
scores, labels = contrastive_scores_and_labels(q_reps, p_reps, n_passages)
if model.l2_normalize:
scores = scores / self.temperature
loss = F.cross_entropy(scores, labels)
Expand Down Expand Up @@ -465,7 +552,15 @@ def _run_validation_epoch(self, val_dataloader):
q_reps = model(query)
p_reps = model(passage)

scores, labels = contrastive_scores_and_labels(q_reps, p_reps, self.val_n_passages)
if _uses_multi_vector_scoring(model):
scores, labels = maxsim_scores_and_labels(
q_reps,
p_reps,
self.val_n_passages,
passage["attention_mask"],
)
else:
scores, labels = contrastive_scores_and_labels(q_reps, p_reps, self.val_n_passages)
if model.l2_normalize:
scores = scores / self.temperature
loss = F.cross_entropy(scores, labels)
Expand Down
Loading
Loading