support pre-tokenized parquet datasets #9351
Open
AbdulmalikDS wants to merge 9 commits into hiyouga:main from AbdulmalikDS:feature/tokenized-parquet-v2
+240 −11
Commits (9):
0a26975 support pre-tokenized parquet datasets (AbdulmalikDS)
fc02353 Update src/llamafactory/data/collator_tokenized.py (AbdulmalikDS)
bcc7016 Update src/llamafactory/data/collator_tokenized.py (AbdulmalikDS)
1126470 Update src/llamafactory/data/tokenized_parquet.py (AbdulmalikDS)
93b667b Update src/llamafactory/data/tokenized_parquet.py (AbdulmalikDS)
5bdcea1 Update src/llamafactory/hparams/data_args.py (AbdulmalikDS)
eabd855 Update src/llamafactory/data/collator_tokenized.py (AbdulmalikDS)
f03fa2a Update src/llamafactory/data/tokenized_parquet.py (AbdulmalikDS)
dc76323 Update src/llamafactory/hparams/data_args.py (AbdulmalikDS)
src/llamafactory/data/collator_tokenized.py (new file, +85 lines)

```python
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

import torch
from transformers import DataCollatorForSeq2Seq

from ..extras.constants import IGNORE_INDEX


if TYPE_CHECKING:
    from transformers import PreTrainedModel, PreTrainedTokenizer


def _resolve_pad_token_id(tokenizer: "PreTrainedTokenizer", model: "PreTrainedModel") -> int:
    r"""Resolve the padding token ID from the tokenizer or the model config."""
    pad_id = getattr(getattr(model, "config", None), "pad_token_id", None)
    if pad_id is None and tokenizer is not None:
        pad_id = getattr(tokenizer, "pad_token_id", None)
    if pad_id is None:
        pad_id = getattr(getattr(model, "config", None), "eos_token_id", None)
    return 0 if pad_id is None else int(pad_id)


@dataclass
class TokenizedIdsCollator(DataCollatorForSeq2Seq):
    r"""Collator for pre-tokenized LM data.

    Expects features containing `input_ids` and optionally `attention_mask`.
    Pads to the batch max length with `pad_token_id`, and generates labels
    and attention masks for features that lack them.
    """

    strict: bool = True

    def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
        pad_id = _resolve_pad_token_id(self.tokenizer, self.model)

        # Validate features and compute the batch max length
        max_len = 0
        for f in features:
            if "input_ids" not in f or not isinstance(f["input_ids"], list):
                if self.strict:
                    raise ValueError("Each feature must contain list[int] `input_ids`.")
                else:
                    f["input_ids"] = f.get("input_ids", []) or []
            max_len = max(max_len, len(f["input_ids"]))

        input_ids = []
        attention_mask = []
        labels = []
        for f in features:
            ids = f["input_ids"]
            pad_amt = max_len - len(ids)
            input_ids.append(ids + [pad_id] * pad_amt)

            if "attention_mask" in f and isinstance(f["attention_mask"], list):
                if self.strict and len(f["attention_mask"]) != len(ids):
                    raise ValueError("attention_mask length must match input_ids length.")
                mask = f["attention_mask"] + [0] * pad_amt
            else:
                mask = [1] * len(ids) + [0] * pad_amt
            attention_mask.append(mask)

            # Padded label positions use IGNORE_INDEX so they are excluded from the loss
            labels.append(ids + [IGNORE_INDEX] * pad_amt)

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }
```
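A minimal usage sketch of the collator, for reviewers who want to try it locally; the tokenizer choice and the sample features are illustrative assumptions, not part of this diff:

```python
# Hypothetical usage of TokenizedIdsCollator (illustrative, not from the PR).
from transformers import AutoTokenizer

from llamafactory.data.collator_tokenized import TokenizedIdsCollator

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative model choice
collator = TokenizedIdsCollator(tokenizer=tokenizer)  # strict=True by default

# Pre-tokenized features; the second one carries its own attention mask
features = [
    {"input_ids": [1, 2, 3, 4, 5]},
    {"input_ids": [6, 7], "attention_mask": [1, 1]},
]
batch = collator(features)

print(batch["input_ids"].shape)  # torch.Size([2, 5]) -- padded to the batch max
print(batch["labels"][1])        # tensor([6, 7, -100, -100, -100]); pads use IGNORE_INDEX
```

Shorter rows are right-padded with the resolved pad ID, and their padded label positions are masked so padding never contributes to the loss.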
src/llamafactory/data/tokenized_parquet.py (new file, +84 lines)

```python
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Optional

import pyarrow as pa
import pyarrow.parquet as pq

from ..extras import logging


if TYPE_CHECKING:
    from datasets import IterableDataset


logger = logging.get_logger(__name__)


def _iter_parquet_rows(paths: list[str], ids_key: str, mask_key: Optional[str]) -> Iterable[dict[str, Any]]:
    r"""Iterate over rows from multiple Parquet files, yielding pre-tokenized samples."""
    for path in paths:
        try:
            pf = pq.ParquetFile(path)
        except FileNotFoundError:
            logger.warning(f"Parquet file not found, skipping: {path}")
            continue

        # Read one row group at a time so large files are never fully materialized
        with pf:
            for i in range(pf.num_row_groups):
                table: pa.Table = pf.read_row_group(i)
                ids_col = table[ids_key]
                mask_col = table[mask_key] if mask_key and mask_key in table.column_names else None
                ids_py = ids_col.to_pylist()
                mask_py = mask_col.to_pylist() if mask_col is not None else itertools.repeat(None)
                for ids, mask in zip(ids_py, mask_py):
                    yield {
                        "input_ids": list(ids) if isinstance(ids, (list, tuple)) else ids,
                        **(
                            {"attention_mask": (list(mask) if isinstance(mask, (list, tuple)) else mask)}
                            if mask is not None
                            else {}
                        ),
                    }


def load_tokenized_parquet_dataset(
    data_files: list[str],
    ids_key: str = "input_ids",
    mask_key: Optional[str] = "attention_mask",
) -> "IterableDataset":
    r"""Create a streaming HF IterableDataset over pre-tokenized Parquet samples.

    Args:
        data_files: List of local Parquet file paths.
        ids_key: Column name for input token IDs.
        mask_key: Column name for attention mask (optional).

    Returns:
        IterableDataset yielding dictionaries with `input_ids` and optionally `attention_mask`.

    Note:
        Always streams row groups to avoid materializing large corpora in memory.
    """
    from datasets import IterableDataset

    if not data_files:
        raise ValueError("data_files must be a non-empty list of Parquet paths")

    logger.info_rank0(f"Building streaming dataset from {len(data_files)} parquet file(s)")
    return IterableDataset.from_generator(
        _iter_parquet_rows, gen_kwargs={"paths": data_files, "ids_key": ids_key, "mask_key": mask_key}
    )  # type: ignore
```
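And a round-trip sketch for the loader; the file path and toy rows below are invented for illustration:

```python
# Hypothetical round trip: write a tiny pre-tokenized corpus, then stream it back.
import pyarrow as pa
import pyarrow.parquet as pq

from llamafactory.data.tokenized_parquet import load_tokenized_parquet_dataset

# Each column is a list<int64> array, matching the schema the loader expects
table = pa.table(
    {
        "input_ids": [[1, 2, 3], [4, 5]],
        "attention_mask": [[1, 1, 1], [1, 1]],
    }
)
pq.write_table(table, "/tmp/pretokenized.parquet")

dataset = load_tokenized_parquet_dataset(["/tmp/pretokenized.parquet"])
for sample in dataset:
    print(sample)
# {'input_ids': [1, 2, 3], 'attention_mask': [1, 1, 1]}
# {'input_ids': [4, 5], 'attention_mask': [1, 1]}
```

Because the generator yields one row group at a time, this pattern should scale to corpora far larger than memory.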