Off by one error in multi-document masking? #10

@andreasgrv

Hi,

Thanks for the masking code for training with packing!
I was playing around with the code earlier and noticed that the returned token type mask isn't what I would expect. Here is some code to reproduce it:

Code Snippet

import torch
import torch.nn.functional as F

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('EvaByte/EvaByte', trust_remote_code=True)

EOS_TOKEN_TYPE_ID = tokenizer.added_tokens_encoder['<file_sep>']
PAD_TOKEN_TYPE_ID = tokenizer.pad_token_type_id

input_ids = tokenizer('Hello!', return_tensors='pt')['input_ids']
# Overwrite the final token ('!') with <file_sep> so the sequence ends in a separator
input_ids[:, -1] = EOS_TOKEN_TYPE_ID

# Pack two copies: Hello<file_sep>Hello<file_sep> (each copy keeps its leading token, id 1), then pad with 5 tokens
target_token_type_ids = torch.concat([input_ids, input_ids], dim=1)
target_token_type_ids = F.pad(target_token_type_ids, (0, 5), mode='constant', value=PAD_TOKEN_TYPE_ID)
print('Token ids:', target_token_type_ids.tolist())

batch_size, num_tokens = target_token_type_ids.shape

# Taken from https://github.com/OpenEvaByte/evabyte/blob/432331ad1feb017fd3b8a51a54fee637d6389900/training_utils.py#L129
##### step 1: mark each document with a unique id
end_token_ids = {EOS_TOKEN_TYPE_ID, PAD_TOKEN_TYPE_ID}
token_types = torch.zeros(batch_size, num_tokens)
for sequence_idx, sequence in enumerate(target_token_type_ids):
    num_articles = 0
    start_index = 0
    # for each sample in the batch, the collapsed attention mask looks like:
    # [1, 1, .... 1, 0, 2, 2, ... 2, 0, ... n, n ..... n], assuming there are n articles in the sequence.
    # Each of the n articles are separated by 0.
    for token_idx, token_type_id in enumerate(sequence):
        if start_index is not None and token_type_id.item() in end_token_ids:
            num_articles += 1
            end_index = token_idx if token_type_id == PAD_TOKEN_TYPE_ID else token_idx + 1
            token_types[sequence_idx][start_index:end_index] = num_articles
            start_index = None
        elif start_index is None and token_type_id not in end_token_ids:
            start_index = token_idx + 1

print('token  type')
print('-----------')
for token, ttype in zip(target_token_type_ids.squeeze().tolist(), token_types.squeeze().tolist()):
    print(f'{token:>6} {int(ttype):>4}')

Output

Token ids: [[1, 136, 165, 172, 172, 175, 6, 1, 136, 165, 172, 172, 175, 6, 0, 0, 0, 0, 0]]
token  type
-----------
     1    1
   136    1
   165    1
   172    1
   172    1
   175    1
     6    1               <-------- I expected the 0 here (on the separator)
     1    0               <-------- but the 0 lands here (first token of the next document)
   136    2
   165    2
   172    2
   172    2
   175    2
     6    2
     0    0
     0    0
     0    0
     0    0
     0    0

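Shouldn't the token type be 0 for the separator (<file_sep>, id 6) rather than for the token that follows it? In the loop, the separator is included in the first document (end_index = token_idx + 1) and the next document only starts one token later (start_index = token_idx + 1). Is it off by one, or am I missing something?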
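To make the comparison easy to run without downloading the tokenizer, here is a minimal sketch on plain Python lists. It assumes the ids from the output above (separator 6, padding 0). `label_as_in_snippet` is my plain-list reading of the loop in the reproduction, and `label_with_zero_separators` is a hypothetical alternative that follows the layout described in the code comment ("each of the n articles are separated by 0"). Neither function is taken from the repository, and the second one is only one possible reading of the intended behaviour.

```python
SEP_ID = 6  # <file_sep> id, as seen in the output above
PAD_ID = 0  # padding id, as seen in the output above


def label_as_in_snippet(token_ids, sep_id=SEP_ID, pad_id=PAD_ID):
    """Plain-list port of the loop in the reproduction (hypothetical helper)."""
    end_ids = {sep_id, pad_id}
    token_types = [0] * len(token_ids)
    num_articles = 0
    start_index = 0
    for token_idx, token_id in enumerate(token_ids):
        if start_index is not None and token_id in end_ids:
            num_articles += 1
            # A separator is included in the document, padding is not.
            end_index = token_idx if token_id == pad_id else token_idx + 1
            token_types[start_index:end_index] = [num_articles] * (end_index - start_index)
            start_index = None
        elif start_index is None and token_id not in end_ids:
            # The next document starts one token after the current one.
            start_index = token_idx + 1
    return token_types


def label_with_zero_separators(token_ids, sep_id=SEP_ID, pad_id=PAD_ID):
    """Hypothetical alternative: separators and padding get 0, documents get 1..n."""
    token_types = [0] * len(token_ids)
    doc_id = 1
    doc_has_tokens = False
    for token_idx, token_id in enumerate(token_ids):
        if token_id == pad_id:
            continue  # padding keeps type 0
        if token_id == sep_id:
            # The separator keeps type 0 and closes the current document.
            if doc_has_tokens:
                doc_id += 1
                doc_has_tokens = False
            continue
        token_types[token_idx] = doc_id
        doc_has_tokens = True
    return token_types


token_ids = [1, 136, 165, 172, 172, 175, 6, 1, 136, 165, 172, 172, 175, 6, 0, 0, 0, 0, 0]
observed = label_as_in_snippet(token_ids)
expected = label_with_zero_separators(token_ids)

print("token  observed  expected")
for token_id, obs, exp in zip(token_ids, observed, expected):
    print(f"{token_id:>5} {obs:>9} {exp:>9}")
```

With these ids, `observed` matches the table above (the 0 sits on the token after the separator), while `expected` puts the 0 on the separator itself and labels the following token as the start of document 2.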

Thanks!
