A from-scratch C++ training framework for large-scale models with multi-dimensional distributed parallelism.
- Recommended: NVIDIA Ampere-class GPUs (A100/A800) or newer
- CUDA / NCCL: Latest stable versions
- gcc / g++: Version 13+
- CMake: Version 3.13+
```bash
mkdir build
cd build
cmake .. -DUSE_CUDA=ON -DUSE_NCCL=ON
make -j
```

Build Options:

- `USE_CUDA=ON`: Enable CUDA backend support.
- `USE_NCCL=ON`: Enable NCCL-based distributed communication.
Both options are optional. For a CPU-only build, set both to `OFF` (`-DUSE_CUDA=OFF -DUSE_NCCL=OFF`).
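Both options are ordinary CMake cache variables, so choosing them can be scripted. The sketch below is a hypothetical helper, not part of InfiniTrain. It emits only the two documented flags. In `auto` mode it treats an `nvcc` on `PATH` as a rough sign that a CUDA toolkit is installed, which is an assumption. It does not check for NCCL.

```bash
# Hypothetical helper, not shipped with InfiniTrain: print the CMake flags
# for the two documented build options.
#   $1: "cuda", "cpu", or "auto" (default)
infinitrain_cmake_flags() {
  local mode="${1:-auto}"
  if [ "$mode" = "auto" ]; then
    # Rough heuristic (assumption): nvcc on PATH => CUDA toolkit available.
    # NCCL is NOT detected; it is simply enabled together with CUDA.
    if command -v nvcc >/dev/null 2>&1; then
      mode="cuda"
    else
      mode="cpu"
    fi
  fi
  case "$mode" in
    cuda) printf '%s\n' "-DUSE_CUDA=ON -DUSE_NCCL=ON" ;;
    cpu)  printf '%s\n' "-DUSE_CUDA=OFF -DUSE_NCCL=OFF" ;;
    *)    printf 'error: unknown mode "%s"\n' "$mode" >&2; return 1 ;;
  esac
}

# Usage (from the repository root):
#   mkdir -p build && cd build
#   cmake .. $(infinitrain_cmake_flags auto)   # unquoted on purpose: two separate -D arguments
#   make -j
```

Forcing `cpu` mode is useful on machines that have a CUDA toolkit installed but where only a CPU build is needed.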
| Category | Feature | Description | Status |
|---|---|---|---|
| Model Support | GPT-2 | Decoder-only Transformer language model | ✅ Supported |
| | LLaMA 3 | Modern LLaMA-family Transformer architecture | ✅ Supported |
| | Qwen3-8B | Qwen3 8B language model | 🚧 Planned |
| | DeepSeek-V3 | Large-scale MoE-based language model | 🚧 Planned |
| Precision | Multiple Data Types | FP32, BF16 | ✅ Supported |
| | Mixed Precision | Autocast-based BF16 compute with FP32 accumulation | ✅ Supported |
| Distributed Training | Data Parallel (DP) | Parameter-server-style data parallelism | ✅ Supported |
| | Distributed Data Parallel (DDP) | Collective-based data parallelism | ✅ Supported |
| | Tensor Parallelism (TP) | Intra-layer tensor sharding | ✅ Supported |
| | Sequence Parallelism (SP) | Sequence-dimension sharding | ✅ Supported |
| | Pipeline Parallelism (PP) | GPipe, 1F1B scheduling, Virtual Pipeline (vPP) | ✅ Supported |
| | Hybrid Parallelism | Arbitrary combinations of DDP + TP + SP + PP | ✅ Supported |
| Core Components | Multi-backend | CPU and CUDA execution backends | ✅ Supported |
| | Multi-node Distributed Training | Distributed execution across multiple nodes | ✅ Supported |
| | Transformer Abstraction | Generic Transformer structure abstraction | ✅ Supported |
| | Backend Registries | Device / CCL / dtype abstraction and registration | ✅ Supported |
| | Kernel Dispatcher | Kernel registration and dynamic dispatch mechanism | ✅ Supported |
| | Autograd | Automatic differentiation engine | ✅ Supported |
| | Autocast | Automatic mixed precision runtime | ✅ Supported |
| | Checkpointing | Training checkpoint save and restore | 🚧 Planned |
| Fine-tuning | LoRA | Memory-efficient fine-tuning with merge / unmerge | ✅ Supported |
| Memory Optimizations | ZeRO Stage-1 | Sharded optimizer states for DDP | ✅ Supported |
| | ZeRO Stage-2 | Sharded gradients across DDP ranks | ✅ Supported |
| | Activation Recomputation | Recompute activations to reduce memory usage | 🚧 Planned |
| Performance Optimizations | Compute-Communication Overlap | Explicit scheduling to hide communication latency | ✅ Supported |
| | DDP Gradient Bucketing | Deferred and bucketed gradient synchronization | ✅ Supported |
| Execution Mode | Training Mode | Full forward and backward training with autograd | ✅ Supported |
| | `no_grad` Inference | Forward-only execution without gradient tracking | ✅ Supported |
| Debugging & Tooling | Built-in Profiler | Kernel-level performance profiling | ✅ Supported |
| | Precision Alignment Checker | Function / Module precision checks and end-to-end (E2E) loss diff | ✅ Supported |
| | CTest + GTest Infrastructure | Automated unit tests with CTest integration | ✅ Supported |
| | Automated Benchmarking | One-click execution, log analysis, and Feishu export | ✅ Supported |
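For context on the LoRA row, the standard LoRA formulation is summarized below. It is the textbook form, not taken from InfiniTrain's source, and the scaling factor $\alpha / r$ (LoRA alpha over rank) is an assumption about how InfiniTrain parameterizes its adapters. A frozen weight $W_0 \in \mathbb{R}^{d \times k}$ is adapted by a trainable low-rank product $BA$ with $B \in \mathbb{R}^{d \times r}$, $A \in \mathbb{R}^{r \times k}$ and $r \ll \min(d, k)$:

$$
h = W_0 x + \frac{\alpha}{r} B A x,
\qquad
\text{merge: } W \leftarrow W_0 + \frac{\alpha}{r} B A,
\qquad
\text{unmerge: } W_0 \leftarrow W - \frac{\alpha}{r} B A
$$

Merging folds the adapter into a single dense weight, so inference needs no extra matmul. Unmerging recovers $W_0$ (up to floating-point rounding, which matters in BF16) so that $A$ and $B$ can be trained further.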
Each model in the `example/` directory is compiled into an independent executable.
For example, the `llama3` example produces a binary named `llama3`.
To view available runtime options:

```bash
./llama3 --help
```

The following examples demonstrate LLaMA 3 supervised fine-tuning (SFT) using InfiniTrain.
Single-node run:

```bash
./llama3 \
  --device cuda \
  --input_bin [training_data_path] \
  --llmc_filepath [model_path] \
  --num_iteration 10
```
Multi-node run (2 nodes; run the command on each node with that node's `--node_rank`):

```bash
./infini_run \
  --nnodes=2 \
  --nproc_per_node=1 \
  --node_rank=[rank_id] \
  -- ./llama3 \
  --device cuda \
  --input_bin [training_data_path] \
  --llmc_filepath [model_path] \
  --num_iteration 10 \
  --nthread_per_process 8 \
  --batch_size 40 \
  --total_batch_size 10240 \
  --tensor_parallel 2 \
  --pipeline_parallel 2 \
  --sequence_parallel
```

Parallelism flags (each line is an independent per-strategy example, not one combined configuration):

```bash
--nthread_per_process 8        # ddp_size = nthread_per_process / (tensor_parallel × pipeline_parallel)
--tensor_parallel 4            # 4-way tensor parallelism
--sequence_parallel            # Enable sequence parallelism (requires TP > 1)
--pipeline_parallel 8          # 8 pipeline stages
--virtual_pipeline_parallel 4  # Virtual pipeline for better load balancing
```

Multiple parallelism strategies (DDP, TP, SP, PP) can be freely combined to scale training across devices and nodes.
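The DDP degree follows from the `ddp_size` relation given in the flag comment. The hypothetical helper below is not part of InfiniTrain. It computes that value and rejects layouts that cannot be split evenly. The divisibility check and the "SP requires TP > 1" check are assumptions inferred from that formula and from the `--sequence_parallel` comment. They are not confirmed InfiniTrain validation logic.

```bash
# Hypothetical helper, not part of InfiniTrain: compute the DDP degree from
#   ddp_size = nthread_per_process / (tensor_parallel * pipeline_parallel)
# Args: nthread_per_process tensor_parallel pipeline_parallel [sequence_parallel(0|1)]
ddp_size() {
  local nthread="$1" tp="$2" pp="$3" sp="${4:-0}"  # sp=1 when --sequence_parallel is passed
  if [ "$nthread" -lt 1 ] || [ "$tp" -lt 1 ] || [ "$pp" -lt 1 ]; then
    echo "error: all parallel degrees must be >= 1" >&2
    return 1
  fi
  # Assumed constraint, taken from the "(requires TP > 1)" comment.
  if [ "$sp" -eq 1 ] && [ "$tp" -le 1 ]; then
    echo "error: --sequence_parallel requires --tensor_parallel > 1" >&2
    return 1
  fi
  local model_parallel=$(( tp * pp ))
  # Assumed constraint: devices must split evenly into TP x PP groups.
  if [ $(( nthread % model_parallel )) -ne 0 ]; then
    echo "error: nthread_per_process=$nthread is not divisible by tensor_parallel*pipeline_parallel=$model_parallel" >&2
    return 1
  fi
  echo $(( nthread / model_parallel ))
}
```

For the multi-node run above, `ddp_size 8 2 2 1` prints `2`. `ddp_size 8 3 1` fails because 8 devices cannot be split into 3-way TP groups.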
- 2025/03/10: InfiniTrain v0.1.0
  - Initial framework prototype with MNIST CPU training.
- 2025/04/30: InfiniTrain v0.3.0
  - Added Autograd support and GPT-2 training on CPU/CUDA.
- 2025/07/09: InfiniTrain v0.4.0
  - Introduced kernel registration, LLaMA training on CPU/CUDA, BF16 precision, and Data Parallelism.
- 2025/12/31: InfiniTrain v0.5.0
  - Added Autocast, multi-dimensional distributed parallelism (DDP, TP, SP, PP with GPipe / 1F1B / vPP), multi-node training, `no_grad` mode, and communication-computation overlap with bucketed gradient synchronization.
- 2026/06/08: InfiniTrain v0.6.0
  - Added loss alignment tooling for Function- and Module-level precision checks and end-to-end loss comparison, with a unified hook mechanism.
  - Added memory optimizations for DDP training and Autograd execution. ZeRO Stage-1 shards optimizer states across DDP ranks, while ZeRO Stage-2 further shards gradients. Autograd tensor release timing was also optimized to reduce peak memory usage.
  - Introduced LoRA fine-tuning with `merge`/`unmerge` support for efficient training and inference-time weight merging.
  - Refactored core backend abstractions around device, communication, and low-precision dtype registration. The framework layer now uses `DeviceGuard`, `CclGroupGuard`, and backend-registered FP16 / BF16 native types to avoid hardware-specialized framework code.
  - Introduced a generic Transformer structure abstraction backed by `TransformerConfig`, providing a common foundation for GPT-2 and LLaMA 3 style model construction.
  - Improved BF16 training performance through autocast and elementwise kernel optimizations.
  - Integrated a CTest + GTest based testing infrastructure to strengthen the framework's automated test workflow.