85 changes: 85 additions & 0 deletions src/llamafactory/data/collator_tokenized.py
@@ -0,0 +1,85 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

import torch
from transformers import DataCollatorForSeq2Seq

from ..extras.constants import IGNORE_INDEX


if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer


def _resolve_pad_token_id(tokenizer: "PreTrainedTokenizer", model: "PreTrainedModel") -> int:
r"""Resolve the padding token ID from tokenizer or model config."""
pad_id = getattr(getattr(model, "config", None), "pad_token_id", None)
if pad_id is None and tokenizer is not None:
pad_id = getattr(tokenizer, "pad_token_id", None)
if pad_id is None:
pad_id = getattr(getattr(model, "config", None), "eos_token_id", None)
return 0 if pad_id is None else int(pad_id)


@dataclass
class TokenizedIdsCollator(DataCollatorForSeq2Seq):
r"""Collator for pre-tokenized LM data.

Expects features that contain `input_ids` and, optionally, `attention_mask`.
Pads each row to the batch max length with the resolved `pad_token_id`, builds a default
`attention_mask` when one is missing, and sets `labels` to `input_ids` with padded positions
replaced by `IGNORE_INDEX`.
"""

strict: bool = True

def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
pad_id = _resolve_pad_token_id(self.tokenizer, self.model)

# Validate and compute max length
max_len = 0
for f in features:
if "input_ids" not in f or not isinstance(f["input_ids"], list):
if self.strict:
raise ValueError("Each feature must contain list[int] `input_ids`.")
else:
f["input_ids"] = f.get("input_ids", []) or []
max_len = max(max_len, len(f["input_ids"]))

input_ids = []
attention_mask = []
labels = []
for f in features:
ids = f["input_ids"]
pad_amt = max_len - len(ids)
row_ids = ids + [pad_id] * pad_amt
input_ids.append(row_ids)

if "attention_mask" in f and isinstance(f["attention_mask"], list):
if self.strict and len(f["attention_mask"]) != len(ids):
raise ValueError("attention_mask length must match input_ids length.")
mask = f["attention_mask"] + [0] * pad_amt
else:
mask = [1] * len(ids) + [0] * pad_amt
attention_mask.append(mask)

labels.append(ids + [IGNORE_INDEX] * pad_amt)

batch = {
"input_ids": torch.tensor(input_ids, dtype=torch.long),
"attention_mask": torch.tensor(attention_mask, dtype=torch.long),
"labels": torch.tensor(labels, dtype=torch.long),
}
return batch
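For quick reference, a minimal usage sketch of the collator above; the tokenizer checkpoint and the hand-made feature dicts are illustrative assumptions, not part of this diff:

```python
# Illustrative sketch: pad two pre-tokenized features with TokenizedIdsCollator.
from transformers import AutoTokenizer

from llamafactory.data.collator_tokenized import TokenizedIdsCollator

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed checkpoint; pad_token_id = 0
collator = TokenizedIdsCollator(tokenizer=tokenizer, model=None)

features = [
    {"input_ids": [101, 2009, 2003, 102]},                         # attention_mask is derived as all ones
    {"input_ids": [101, 2023, 102], "attention_mask": [1, 1, 1]},  # shorter row gets padded
]
batch = collator(features)
print(batch["input_ids"].shape)  # torch.Size([2, 4])
print(batch["labels"][1])        # tensor([ 101, 2023,  102, -100]) -- padding masked with IGNORE_INDEX
```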
46 changes: 36 additions & 10 deletions src/llamafactory/data/loader.py
@@ -21,6 +21,7 @@
from ..extras import logging
from ..extras.constants import FILEEXT2TYPE
from ..extras.misc import check_version, has_tokenized_data
from .collator_tokenized import TokenizedIdsCollator
from .converter import align_dataset
from .data_utils import get_dataset_module, merge_dataset, read_cloud_json, split_dataset
from .parser import get_dataset_list
@@ -32,6 +33,7 @@
SupervisedDatasetProcessor,
UnsupervisedDatasetProcessor,
)
from .tokenized_parquet import load_tokenized_parquet_dataset


if TYPE_CHECKING:
@@ -241,6 +243,10 @@ def _get_preprocessed_dataset(
if dataset is None:
return None

# Bypass tokenizer for pre-tokenized pathway
if data_args.dataset_format == "tokenized_ids":
return dataset

dataset_processor = _get_dataset_processor(
data_args, stage, template, tokenizer, processor, do_generate=(training_args.predict_with_generate and is_eval)
)
@@ -301,15 +307,30 @@ def get_dataset(

# Load and preprocess dataset
with training_args.main_process_first(desc="load dataset", local=(not data_args.data_shared_file_system)):
dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
eval_dataset = _get_merged_dataset(
data_args.eval_dataset,
model_args,
data_args,
training_args,
stage,
return_dict=data_args.eval_on_each_dataset,
)
if data_args.dataset_format == "tokenized_ids":
# Load pre-tokenized parquet files
cols = data_args.dataset_columns or {}
ids_key = cols.get("ids", "input_ids")
mask_key = cols.get("mask", "attention_mask")
files = data_args.data_files
if isinstance(files, dict):
files = files.get("train", [])
if not isinstance(files, list) or len(files) == 0:
raise ValueError(
"For dataset_format=tokenized_ids, provide non-empty data_files list (parquet paths)."
)
dataset = load_tokenized_parquet_dataset(files, ids_key=ids_key, mask_key=mask_key)
eval_dataset = None
else:
dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
eval_dataset = _get_merged_dataset(
data_args.eval_dataset,
model_args,
data_args,
training_args,
stage,
return_dict=data_args.eval_on_each_dataset,
)

with training_args.main_process_first(desc="pre-process dataset", local=(not data_args.data_shared_file_system)):
dataset = _get_preprocessed_dataset(
@@ -332,4 +353,9 @@ def get_dataset(
logger.info_rank0(f"Tokenized dataset is saved at {data_args.tokenized_path}.")
logger.info_rank0(f"Please launch the training with `tokenized_path: {data_args.tokenized_path}`.")

return get_dataset_module(dataset_dict)
module = get_dataset_module(dataset_dict)
# Swap in the pre-tokenized collator when the dataset is already tokenized
if data_args.dataset_format == "tokenized_ids":
collator = TokenizedIdsCollator(tokenizer=tokenizer, model=None)  # pad id falls back to the tokenizer when model is None
module["data_collator"] = collator
return module
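Outside the framework, the branch above boils down to the following hedged sketch; the shard paths are placeholders and the tokenizer is only needed for pad-token resolution:

```python
# Hedged, standalone sketch of what the tokenized_ids branch wires together.
import itertools

from transformers import AutoTokenizer

from llamafactory.data.collator_tokenized import TokenizedIdsCollator
from llamafactory.data.tokenized_parquet import load_tokenized_parquet_dataset

shards = ["shard-00000.parquet", "shard-00001.parquet"]  # placeholder paths
dataset = load_tokenized_parquet_dataset(shards, ids_key="input_ids", mask_key="attention_mask")

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed checkpoint
collator = TokenizedIdsCollator(tokenizer=tokenizer, model=None)

rows = list(itertools.islice(iter(dataset), 4))  # pull a few streamed samples
batch = collator(rows)                           # padded tensors ready for a causal-LM step
```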
84 changes: 84 additions & 0 deletions src/llamafactory/data/tokenized_parquet.py
@@ -0,0 +1,84 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Optional

import pyarrow as pa
import pyarrow.parquet as pq

from ..extras import logging


if TYPE_CHECKING:
from datasets import IterableDataset


logger = logging.get_logger(__name__)


def _iter_parquet_rows(paths: list[str], ids_key: str, mask_key: Optional[str]) -> Iterable[dict[str, Any]]:
r"""Iterate over rows from multiple Parquet files, yielding pre-tokenized samples."""
for path in paths:
try:
pf = pq.ParquetFile(path)
except FileNotFoundError:
logger.warning(f"Parquet file not found, skipping: {path}")
continue

with pf:
for i in range(pf.num_row_groups):
table: pa.Table = pf.read_row_group(i)
ids_col = table[ids_key]
mask_col = table[mask_key] if mask_key and mask_key in table.column_names else None
ids_py = ids_col.to_pylist()
mask_py = mask_col.to_pylist() if mask_col is not None else itertools.repeat(None)
for ids, mask in zip(ids_py, mask_py):
sample = {"input_ids": list(ids) if isinstance(ids, (list, tuple)) else ids}
if mask is not None:
sample["attention_mask"] = list(mask) if isinstance(mask, (list, tuple)) else mask
yield sample


def load_tokenized_parquet_dataset(
data_files: list[str],
ids_key: str = "input_ids",
mask_key: Optional[str] = "attention_mask",
) -> "IterableDataset":
r"""Create a streaming HF IterableDataset over pre-tokenized Parquet samples.

Args:
data_files: List of local Parquet file paths.
ids_key: Column name for input token IDs.
mask_key: Column name for attention mask (optional).

Returns:
IterableDataset yielding dictionaries with `input_ids` and optionally `attention_mask`.

Note:
Always streams row groups to avoid materializing large corpora in memory.
"""
from datasets import IterableDataset

if not data_files:
raise ValueError("data_files must be a non-empty list of Parquet paths")

logger.info_rank0(f"Building streaming dataset from {len(data_files)} parquet file(s)")
return IterableDataset.from_generator(  # type: ignore
_iter_parquet_rows,
gen_kwargs={"paths": data_files, "ids_key": ids_key, "mask_key": mask_key},
)
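To see the expected file layout end to end, the following hedged sketch writes a tiny compatible Parquet file and streams it back; the file name and sample values are made up:

```python
# Hedged sketch: build a tiny Parquet file with list<int> columns and stream it back.
import pyarrow as pa
import pyarrow.parquet as pq

from llamafactory.data.tokenized_parquet import load_tokenized_parquet_dataset

table = pa.table({
    "input_ids": [[1, 2, 3, 4], [5, 6, 7]],
    "attention_mask": [[1, 1, 1, 1], [1, 1, 1]],
})
pq.write_table(table, "tiny_pretokenized.parquet")

ds = load_tokenized_parquet_dataset(["tiny_pretokenized.parquet"])
for sample in ds:
    print(sample["input_ids"], sample.get("attention_mask"))
```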
36 changes: 35 additions & 1 deletion src/llamafactory/hparams/data_args.py
@@ -16,7 +16,7 @@
# limitations under the License.

from dataclasses import asdict, dataclass, field
from typing import Any, Literal, Optional
from typing import Any, Literal, Optional, Union


@dataclass
@@ -137,6 +137,34 @@ class DataArguments:
default=False,
metadata={"help": "Whether or not to use a shared file system for the datasets."},
)
dataset_format: Optional[Literal["default", "tokenized_ids"]] = field(
default="default",
metadata={
"help": (
"Format of the input dataset. Use 'tokenized_ids' for pre-tokenized parquet files "
"containing token IDs. This bypasses the tokenization step during training."
)
},
)
data_files: Optional[Union[str, list[str]]] = field(
default=None,
metadata={
"help": (
"Path(s) to data files for tokenized_ids format. "
"Can be a single path, comma-separated paths, or a list of paths."
)
},
)
dataset_columns: Optional[dict[str, str]] = field(
default=None,
metadata={
"help": (
"Column name mapping for tokenized datasets. "
"Example: {'ids': 'token_ids', 'mask': 'attn_mask'}. "
"Defaults to {'ids': 'input_ids', 'mask': 'attention_mask'}."
)
},
)

def __post_init__(self):
def split_arg(arg):
@@ -147,6 +175,12 @@ def split_arg(arg):
self.dataset = split_arg(self.dataset)
self.eval_dataset = split_arg(self.eval_dataset)

# Handle data_files for tokenized_ids format
if self.dataset_format == "tokenized_ids":
if self.data_files is None:
raise ValueError("data_files must be specified when using dataset_format='tokenized_ids'.")
self.data_files = split_arg(self.data_files)

if self.media_dir is None:
self.media_dir = self.dataset_dir

Expand Down