From 0a2697506e70004ceb92228a7155555ae4ab8a47 Mon Sep 17 00:00:00 2001 From: AbdulmalikDS Date: Sat, 25 Oct 2025 13:20:25 +0300 Subject: [PATCH 1/9] support pre-tokenized parquet datasets --- src/llamafactory/data/collator_tokenized.py | 88 +++++++++++++++++++++ src/llamafactory/data/loader.py | 46 ++++++++--- src/llamafactory/data/tokenized_parquet.py | 79 ++++++++++++++++++ src/llamafactory/hparams/data_args.py | 36 ++++++++- 4 files changed, 238 insertions(+), 11 deletions(-) create mode 100644 src/llamafactory/data/collator_tokenized.py create mode 100644 src/llamafactory/data/tokenized_parquet.py diff --git a/src/llamafactory/data/collator_tokenized.py b/src/llamafactory/data/collator_tokenized.py new file mode 100644 index 0000000000..61d98286a4 --- /dev/null +++ b/src/llamafactory/data/collator_tokenized.py @@ -0,0 +1,88 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +import torch + +from ..extras.constants import IGNORE_INDEX +from .collator import MultiModalDataCollatorForSeq2Seq + + +if TYPE_CHECKING: + from transformers import PreTrainedModel, PreTrainedTokenizer + + +def _resolve_pad_token_id(tokenizer: "PreTrainedTokenizer", model: "PreTrainedModel") -> int: + r"""Resolve the padding token ID from tokenizer or model config.""" + pad_id = getattr(getattr(model, "config", None), "pad_token_id", None) + if pad_id is None and tokenizer is not None: + pad_id = getattr(tokenizer, "pad_token_id", None) + if pad_id is None: + pad_id = getattr(getattr(model, "config", None), "eos_token_id", None) + return 0 if pad_id is None else int(pad_id) + + +@dataclass +class TokenizedIdsCollator(MultiModalDataCollatorForSeq2Seq): + r"""Collator for pre-tokenized LM data. + + Expects features containing `input_ids` and optionally `attention_mask`. + Pads to batch max length with `pad_token_id`, generates labels and masks missing fields when needed. 
+ """ + + strict: bool = True + + def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]: + pad_id = _resolve_pad_token_id(self.tokenizer, self.model) + + # Validate and compute max length + max_len = 0 + for f in features: + if "input_ids" not in f or not isinstance(f["input_ids"], list): + if self.strict: + raise ValueError("Each feature must contain list[int] `input_ids`.") + else: + f["input_ids"] = f.get("input_ids", []) or [] + max_len = max(max_len, len(f["input_ids"])) + + input_ids = [] + attention_mask = [] + labels = [] + for f in features: + ids = f["input_ids"] + pad_amt = max_len - len(ids) + row_ids = ids + [pad_id] * pad_amt + input_ids.append(row_ids) + + if "attention_mask" in f and isinstance(f["attention_mask"], list): + if self.strict and len(f["attention_mask"]) != len(ids): + raise ValueError("attention_mask length must match input_ids length.") + mask = f["attention_mask"] + [0] * pad_amt + else: + mask = [1] * len(ids) + [0] * pad_amt + attention_mask.append(mask) + + row_labels = row_ids.copy() + for i in range(len(ids), max_len): + row_labels[i] = IGNORE_INDEX + labels.append(row_labels) + + batch = { + "input_ids": torch.tensor(input_ids, dtype=torch.long), + "attention_mask": torch.tensor(attention_mask, dtype=torch.long), + "labels": torch.tensor(labels, dtype=torch.long), + } + return batch diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py index b5adc139e9..10aca16380 100644 --- a/src/llamafactory/data/loader.py +++ b/src/llamafactory/data/loader.py @@ -21,6 +21,7 @@ from ..extras import logging from ..extras.constants import FILEEXT2TYPE from ..extras.misc import check_version, has_tokenized_data +from .collator_tokenized import TokenizedIdsCollator from .converter import align_dataset from .data_utils import get_dataset_module, merge_dataset, read_cloud_json, split_dataset from .parser import get_dataset_list @@ -32,6 +33,7 @@ SupervisedDatasetProcessor, UnsupervisedDatasetProcessor, ) +from .tokenized_parquet import load_tokenized_parquet_dataset if TYPE_CHECKING: @@ -241,6 +243,10 @@ def _get_preprocessed_dataset( if dataset is None: return None + # Bypass tokenizer for pre-tokenized pathway + if data_args.dataset_format == "tokenized_ids": + return dataset + dataset_processor = _get_dataset_processor( data_args, stage, template, tokenizer, processor, do_generate=(training_args.predict_with_generate and is_eval) ) @@ -301,15 +307,30 @@ def get_dataset( # Load and preprocess dataset with training_args.main_process_first(desc="load dataset", local=(not data_args.data_shared_file_system)): - dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage) - eval_dataset = _get_merged_dataset( - data_args.eval_dataset, - model_args, - data_args, - training_args, - stage, - return_dict=data_args.eval_on_each_dataset, - ) + if data_args.dataset_format == "tokenized_ids": + # Load pre-tokenized parquet files + cols = data_args.dataset_columns or {} + ids_key = cols.get("ids", "input_ids") + mask_key = cols.get("mask", "attention_mask") + files = data_args.data_files + if isinstance(files, dict): + files = files.get("train", []) + if not isinstance(files, list) or len(files) == 0: + raise ValueError( + "For dataset_format=tokenized_ids, provide non-empty data_files list (parquet paths)." 
+ ) + dataset = load_tokenized_parquet_dataset(files, ids_key=ids_key, mask_key=mask_key) + eval_dataset = None + else: + dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage) + eval_dataset = _get_merged_dataset( + data_args.eval_dataset, + model_args, + data_args, + training_args, + stage, + return_dict=data_args.eval_on_each_dataset, + ) with training_args.main_process_first(desc="pre-process dataset", local=(not data_args.data_shared_file_system)): dataset = _get_preprocessed_dataset( @@ -332,4 +353,9 @@ def get_dataset( logger.info_rank0(f"Tokenized dataset is saved at {data_args.tokenized_path}.") logger.info_rank0(f"Please launch the training with `tokenized_path: {data_args.tokenized_path}`.") - return get_dataset_module(dataset_dict) + module = get_dataset_module(dataset_dict) + # Replace collator for tokenized_ids + if data_args.dataset_format == "tokenized_ids": + collator = TokenizedIdsCollator(tokenizer=tokenizer, model=None) # model attached later by trainer + module["data_collator"] = collator + return module diff --git a/src/llamafactory/data/tokenized_parquet.py b/src/llamafactory/data/tokenized_parquet.py new file mode 100644 index 0000000000..c80cf0cd83 --- /dev/null +++ b/src/llamafactory/data/tokenized_parquet.py @@ -0,0 +1,79 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any, Optional + +import pyarrow as pa +import pyarrow.parquet as pq + +from ..extras import logging + + +if TYPE_CHECKING: + from datasets import IterableDataset + + +logger = logging.get_logger(__name__) + + +def _iter_parquet_rows(paths: list[str], ids_key: str, mask_key: Optional[str]) -> Iterable[dict[str, Any]]: + r"""Iterate over rows from multiple Parquet files, yielding pre-tokenized samples.""" + for path in paths: + with open(path, "rb") as f: + pf = pq.ParquetFile(f) + for i in range(pf.num_row_groups): + table: pa.Table = pf.read_row_group(i) + ids_col = table[ids_key] + mask_col = table[mask_key] if mask_key and mask_key in table.column_names else None + ids_py = ids_col.to_pylist() + mask_py = mask_col.to_pylist() if mask_col is not None else itertools.repeat(None) + for ids, mask in zip(ids_py, mask_py): + yield { + "input_ids": list(ids) if isinstance(ids, (list, tuple)) else ids, + **( + {"attention_mask": (list(mask) if isinstance(mask, (list, tuple)) else mask)} + if mask is not None + else {} + ), + } + + +def load_tokenized_parquet_dataset( + data_files: list[str], + ids_key: str = "input_ids", + mask_key: Optional[str] = "attention_mask", +) -> "IterableDataset": + r"""Create a streaming HF IterableDataset over pre-tokenized Parquet samples. + + Args: + data_files: List of local Parquet file paths. + ids_key: Column name for input token IDs. + mask_key: Column name for attention mask (optional). + + Returns: + IterableDataset yielding dictionaries with `input_ids` and optionally `attention_mask`. 
+ + Note: + Always streams row groups to avoid materializing large corpora in memory. + """ + from datasets import IterableDataset + + if not data_files: + raise ValueError("data_files must be a non-empty list of Parquet paths") + + logger.info_rank0(f"Building streaming dataset from {len(data_files)} parquet file(s)") + gen = lambda: _iter_parquet_rows(data_files, ids_key, mask_key) + return IterableDataset.from_generator(gen) # type: ignore diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py index e6844733e5..ebc6dca5d3 100644 --- a/src/llamafactory/hparams/data_args.py +++ b/src/llamafactory/hparams/data_args.py @@ -16,7 +16,7 @@ # limitations under the License. from dataclasses import asdict, dataclass, field -from typing import Any, Literal, Optional +from typing import Any, Literal, Optional, Union @dataclass @@ -137,6 +137,34 @@ class DataArguments: default=False, metadata={"help": "Whether or not to use a shared file system for the datasets."}, ) + dataset_format: Optional[Literal["default", "tokenized_ids"]] = field( + default="default", + metadata={ + "help": ( + "Format of the input dataset. Use 'tokenized_ids' for pre-tokenized parquet files " + "containing token IDs. This bypasses the tokenization step during training." + ) + }, + ) + data_files: Optional[Union[str, list[str]]] = field( + default=None, + metadata={ + "help": ( + "Path(s) to data files for tokenized_ids format. " + "Can be a single path, comma-separated paths, or a list of paths." + ) + }, + ) + dataset_columns: Optional[dict[str, str]] = field( + default=None, + metadata={ + "help": ( + "Column name mapping for tokenized datasets. " + "Example: {'ids': 'token_ids', 'mask': 'attn_mask'}. " + "Defaults to {'ids': 'input_ids', 'mask': 'attention_mask'}." + ) + }, + ) def __post_init__(self): def split_arg(arg): @@ -147,6 +175,12 @@ def split_arg(arg): self.dataset = split_arg(self.dataset) self.eval_dataset = split_arg(self.eval_dataset) + # Handle data_files for tokenized_ids format + if self.dataset_format == "tokenized_ids": + if self.data_files is None: + raise ValueError("data_files must be specified when using dataset_format='tokenized_ids'.") + self.data_files = split_arg(self.data_files) + if self.media_dir is None: self.media_dir = self.dataset_dir From fc02353732b569c18453364109b8a770b714f59a Mon Sep 17 00:00:00 2001 From: Abdulmalik Alquwayfili Date: Sat, 25 Oct 2025 13:52:30 +0300 Subject: [PATCH 2/9] Update src/llamafactory/data/collator_tokenized.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- src/llamafactory/data/collator_tokenized.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llamafactory/data/collator_tokenized.py b/src/llamafactory/data/collator_tokenized.py index 61d98286a4..39ce9b67c8 100644 --- a/src/llamafactory/data/collator_tokenized.py +++ b/src/llamafactory/data/collator_tokenized.py @@ -36,7 +36,7 @@ def _resolve_pad_token_id(tokenizer: "PreTrainedTokenizer", model: "PreTrainedMo @dataclass -class TokenizedIdsCollator(MultiModalDataCollatorForSeq2Seq): +class TokenizedIdsCollator(DataCollatorForSeq2Seq): r"""Collator for pre-tokenized LM data. Expects features containing `input_ids` and optionally `attention_mask`. 
From bcc701659cfacbc904ad33d909aa9fb04db1618d Mon Sep 17 00:00:00 2001 From: Abdulmalik Alquwayfili Date: Sun, 26 Oct 2025 16:17:08 +0300 Subject: [PATCH 3/9] Update src/llamafactory/data/collator_tokenized.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- src/llamafactory/data/collator_tokenized.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llamafactory/data/collator_tokenized.py b/src/llamafactory/data/collator_tokenized.py index 39ce9b67c8..cf0908a7e6 100644 --- a/src/llamafactory/data/collator_tokenized.py +++ b/src/llamafactory/data/collator_tokenized.py @@ -18,7 +18,7 @@ import torch from ..extras.constants import IGNORE_INDEX -from .collator import MultiModalDataCollatorForSeq2Seq +from transformers import DataCollatorForSeq2Seq if TYPE_CHECKING: From 112647003af4c74889ad803fda40542c1d9c921e Mon Sep 17 00:00:00 2001 From: Abdulmalik Alquwayfili Date: Sun, 26 Oct 2025 16:18:30 +0300 Subject: [PATCH 4/9] Update src/llamafactory/data/tokenized_parquet.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- src/llamafactory/data/tokenized_parquet.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/llamafactory/data/tokenized_parquet.py b/src/llamafactory/data/tokenized_parquet.py index c80cf0cd83..92fa3a4478 100644 --- a/src/llamafactory/data/tokenized_parquet.py +++ b/src/llamafactory/data/tokenized_parquet.py @@ -32,8 +32,7 @@ def _iter_parquet_rows(paths: list[str], ids_key: str, mask_key: Optional[str]) -> Iterable[dict[str, Any]]: r"""Iterate over rows from multiple Parquet files, yielding pre-tokenized samples.""" for path in paths: - with open(path, "rb") as f: - pf = pq.ParquetFile(f) + with pq.ParquetFile(path) as pf: for i in range(pf.num_row_groups): table: pa.Table = pf.read_row_group(i) ids_col = table[ids_key] From 93b667b6f5458a01666f98afbe7beeb3dd6e748b Mon Sep 17 00:00:00 2001 From: Abdulmalik Alquwayfili Date: Sun, 26 Oct 2025 16:19:20 +0300 Subject: [PATCH 5/9] Update src/llamafactory/data/tokenized_parquet.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- src/llamafactory/data/tokenized_parquet.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/llamafactory/data/tokenized_parquet.py b/src/llamafactory/data/tokenized_parquet.py index 92fa3a4478..cb079eaf48 100644 --- a/src/llamafactory/data/tokenized_parquet.py +++ b/src/llamafactory/data/tokenized_parquet.py @@ -74,5 +74,4 @@ def load_tokenized_parquet_dataset( raise ValueError("data_files must be a non-empty list of Parquet paths") logger.info_rank0(f"Building streaming dataset from {len(data_files)} parquet file(s)") - gen = lambda: _iter_parquet_rows(data_files, ids_key, mask_key) - return IterableDataset.from_generator(gen) # type: ignore + return IterableDataset.from_generator(_iter_parquet_rows, gen_kwargs={"paths": data_files, "ids_key": ids_key, "mask_key": mask_key}) # type: ignore From 5bdcea1700e50c119c9ff673845471e09385d0fa Mon Sep 17 00:00:00 2001 From: Abdulmalik Alquwayfili Date: Sun, 26 Oct 2025 16:20:08 +0300 Subject: [PATCH 6/9] Update src/llamafactory/hparams/data_args.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- src/llamafactory/hparams/data_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py index 
ebc6dca5d3..bfa76744d2 100644
--- a/src/llamafactory/hparams/data_args.py
+++ b/src/llamafactory/hparams/data_args.py
@@ -146,7 +146,7 @@ class DataArguments:
             )
         },
     )
-    data_files: Optional[Union[str, list[str]]] = field(
+    data_files: Optional[Union[str, list[str], dict]] = field(
         default=None,
         metadata={
             "help": (

From eabd8554499f4224c0ca12c153542f4b706b059e Mon Sep 17 00:00:00 2001
From: Abdulmalik Alquwayfili
Date: Wed, 29 Oct 2025 13:45:36 +0300
Subject: [PATCH 7/9] Update src/llamafactory/data/collator_tokenized.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 src/llamafactory/data/collator_tokenized.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/src/llamafactory/data/collator_tokenized.py b/src/llamafactory/data/collator_tokenized.py
index cf0908a7e6..f1db1a25cf 100644
--- a/src/llamafactory/data/collator_tokenized.py
+++ b/src/llamafactory/data/collator_tokenized.py
@@ -75,10 +75,7 @@ def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
                 mask = [1] * len(ids) + [0] * pad_amt
             attention_mask.append(mask)
 
-            row_labels = row_ids.copy()
-            for i in range(len(ids), max_len):
-                row_labels[i] = IGNORE_INDEX
-            labels.append(row_labels)
+            labels.append(ids + [IGNORE_INDEX] * pad_amt)
 
         batch = {
             "input_ids": torch.tensor(input_ids, dtype=torch.long),

From f03fa2adaadbc6fa5acc3b909faaae813dc03a80 Mon Sep 17 00:00:00 2001
From: Abdulmalik Alquwayfili
Date: Wed, 29 Oct 2025 13:46:07 +0300
Subject: [PATCH 8/9] Update src/llamafactory/data/tokenized_parquet.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 src/llamafactory/data/tokenized_parquet.py | 8 ++++++--
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/src/llamafactory/data/tokenized_parquet.py b/src/llamafactory/data/tokenized_parquet.py
index cb079eaf48..b20c5d9cc4 100644
--- a/src/llamafactory/data/tokenized_parquet.py
+++ b/src/llamafactory/data/tokenized_parquet.py
@@ -32,6 +32,12 @@ def _iter_parquet_rows(paths: list[str], ids_key: str, mask_key: Optional[str]) -> Iterable[dict[str, Any]]:
     r"""Iterate over rows from multiple Parquet files, yielding pre-tokenized samples."""
     for path in paths:
-        with pq.ParquetFile(path) as pf:
+        try:
+            pf = pq.ParquetFile(path)
+        except FileNotFoundError:
+            logger.warning(f"Parquet file not found, skipping: {path}")
+            continue
+
+        with pf:
             for i in range(pf.num_row_groups):
                 table: pa.Table = pf.read_row_group(i)

From dc763236ac715b92837c197ff9600d319aefea8e Mon Sep 17 00:00:00 2001
From: Abdulmalik Alquwayfili
Date: Wed, 29 Oct 2025 13:47:22 +0300
Subject: [PATCH 9/9] Update src/llamafactory/hparams/data_args.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 src/llamafactory/hparams/data_args.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py
index bfa76744d2..6f8ea5ce80 100644
--- a/src/llamafactory/hparams/data_args.py
+++ b/src/llamafactory/hparams/data_args.py
@@ -146,7 +146,7 @@ class DataArguments:
             )
         },
     )
-    data_files: Optional[Union[str, list[str], dict]] = field(
+    data_files: Optional[Any] = field(
         default=None,
         metadata={
             "help": (
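
Usage sketch (not part of the patch series): a minimal, self-contained example of how the new pieces are intended to fit together outside the trainer, streaming rows from a pre-tokenized Parquet shard with `load_tokenized_parquet_dataset` and batching them with `TokenizedIdsCollator`. The parquet path and tokenizer name below are placeholders; in an actual training run the same path is selected through the new `dataset_format: tokenized_ids`, `data_files`, and `dataset_columns` arguments in `DataArguments`.

import itertools

from transformers import AutoTokenizer

from llamafactory.data.collator_tokenized import TokenizedIdsCollator
from llamafactory.data.tokenized_parquet import load_tokenized_parquet_dataset

# Stream rows from a pre-tokenized shard (placeholder path).
dataset = load_tokenized_parquet_dataset(
    ["data/pretokenized/shard-00000.parquet"],
    ids_key="input_ids",
    mask_key="attention_mask",
)

# Placeholder tokenizer; the collator only reads pad_token_id (falling back to
# eos_token_id or 0 when the tokenizer defines no pad token).
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")
collator = TokenizedIdsCollator(tokenizer=tokenizer, model=None)

# Take a few streamed rows and build a padded batch with labels.
features = list(itertools.islice(iter(dataset), 4))
batch = collator(features)
print(batch["input_ids"].shape, batch["attention_mask"].shape, batch["labels"].shape)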