@@ -861,11 +861,6 @@ def postprocess_qa_predictions(
return eval_metric


# Model Optimizer: Define a teacher factory for initializing the distillation model
def teacher_factory(model_name_or_path):
return AutoModelForQuestionAnswering.from_pretrained(model_name_or_path)


# Model Optimizer: Define a custom distillation loss function that uses start and end logits
class StartEndLogitsDistillationLoss(mtd.LogitsDistillationLoss):
def forward(self, outputs_s, outputs_t):
@@ -1199,7 +1194,9 @@ def forward_loop(model):
logger.info(f"Using distillation with teacher {args.model_name_or_path}")

kd_config = {
"teacher_model": (teacher_factory, (args.model_name_or_path,), {}),
"teacher_model": AutoModelForQuestionAnswering.from_pretrained(
args.model_name_or_path,
),
"criterion": StartEndLogitsDistillationLoss(args.temperature),
}
model = mtd.convert(model, mode=[("kd_loss", kd_config)])
1 change: 0 additions & 1 deletion examples/diffusers/README.md
@@ -201,7 +201,6 @@ def forward(input):
+ "teacher_model": teacher_model,
+ "criterion": distill_config["criterion"],
+ "loss_balancer": distill_config["loss_balancer"],
+ "expose_minimal_state_dict": False,
+ }
+ transformer = mtd.convert(transformer, mode=[("kd_loss", kd_config)])

15 changes: 7 additions & 8 deletions examples/llm_distill/main.py
@@ -26,7 +26,6 @@
from transformers import AutoTokenizer
from trl import SFTTrainer

import modelopt.torch.distill as mtd
import modelopt.torch.opt as mto
from modelopt.torch.distill.plugins.huggingface import KDTrainer, LMLogitsLoss

@@ -64,7 +63,7 @@ def _format_smoltalk_chat_template(sample, tokenizer):
return tokenizer.apply_chat_template(messages, tokenize=False)


class KDSFTTrainer(SFTTrainer, KDTrainer):
class KDSFTTrainer(KDTrainer, SFTTrainer):
pass


@@ -105,23 +104,22 @@ def train():
tokenizer.padding_side = "right"
logger.info("Tokenizer loaded.")

# Model
# Model(s)
logger.info("Loading student model...")
model = transformers.AutoModelForCausalLM.from_pretrained(
model_args.student_name_or_path, dtype=torch.bfloat16 if training_args.bf16 else None
)
logger.info("Student loaded.")
# Load checkpoint
logger.info("Loading teacher model and converting to Distillation model...")
logger.info("Loading teacher model...")
teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
model_args.teacher_name_or_path, dtype=torch.bfloat16 if training_args.bf16 else None
)

# Distillation configuration
kd_config = {
"teacher_model": teacher_model,
"criterion": LMLogitsLoss(),
}
model = mtd.convert(model, mode=[("kd_loss", kd_config)])
logger.info("Models converted.")

    # Fix problematic settings that cause excessive warnings
model.generation_config.temperature = None
@@ -131,6 +129,7 @@ def train():
trainer = KDSFTTrainer(
model,
training_args,
distill_config=kd_config,
train_dataset=dset_train,
eval_dataset=dset_eval,
formatting_func=lambda sample: _format_smoltalk_chat_template(sample, tokenizer),
@@ -153,7 +152,7 @@ def train():
# Save checkpoint
logger.info("Saving checkpoint...")
trainer.save_state()
trainer.save_model(trainer.args.output_dir, export_student=True)
trainer.save_model(trainer.args.output_dir)
logger.info("Checkpoint saved.")


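
As a companion sketch for the `llm_distill` changes above: the script no longer calls `mtd.convert` itself; it builds `kd_config` and hands it to the trainer as `distill_config`, and it saves without an `export_student` flag. Listing `KDTrainer` before `SFTTrainer` in the base-class list presumably lets the distillation trainer's overrides win in the method resolution order. The model name below is a placeholder, and `training_args`, the student model, and the datasets are assumed to come from the surrounding script.

```python
import transformers
from trl import SFTTrainer

from modelopt.torch.distill.plugins.huggingface import KDTrainer, LMLogitsLoss


class KDSFTTrainer(KDTrainer, SFTTrainer):
    """KDTrainer is listed first so its overrides take precedence in the MRO."""


teacher = transformers.AutoModelForCausalLM.from_pretrained(
    "HuggingFaceTB/SmolLM2-360M"  # placeholder teacher checkpoint
)
kd_config = {"teacher_model": teacher, "criterion": LMLogitsLoss()}

# The trainer performs the KD conversion internally from distill_config, so the
# script no longer calls mtd.convert and no longer passes export_student on save:
#   trainer = KDSFTTrainer(student, training_args, distill_config=kd_config, ...)
#   trainer.train()
#   trainer.save_model(trainer.args.output_dir)
```
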
16 changes: 3 additions & 13 deletions examples/llm_qat/README.md
@@ -123,16 +123,8 @@ from modelopt.torch.quantization.plugins.transformers_trainer import QADTrainer
# [Not shown] load model, tokenizer, data loaders etc
# Create the distillation config
distill_config = {
"teacher_model": (
_teacher_factory,
(
model_args.teacher_model,
training_args.cache_dir,
),
{},
),
"teacher_model": teacher_model,
"criterion": LMLogitsLoss(),
"expose_minimal_state_dict": False,
}

trainer = QADTrainer(
@@ -147,7 +139,7 @@ trainer = QADTrainer(
trainer.train() # Train the quantized model using distillation (i.e., QAD)

# Save the final student model weights; An example usage
trainer.save_model(export_student=True)
trainer.save_model()
```

### NeMo QAT/QAD Simplified Flow Example
@@ -245,8 +237,6 @@ You could also add your own customized quantization format to `CUSTOM_QUANT_CFG`

> **_NOTE:_** `launch.sh` defaults to using `LlamaDecoderLayer` as the transformer layer class. If your model uses a different class, pass `--fsdp_transformer_layer_cls_to_wrap <your_layer_class>` to the `launch.sh` script. For example, for `Qwen/Qwen3-8B`, specify `--fsdp_transformer_layer_cls_to_wrap Qwen3DecoderLayer` as an additional argument.

> **_NOTE:_** The script defaults to using FSDP1. To use FSDP2, pass "--use_fsdp2 True" to the `launch.sh` script. Note that FSDP2 is less stable than FSDP1 currently. Use it with caution.

### Results

Here is an example result from following the workflow above with slightly different hyper-parameters (we used an effective batch size of 128 by adjusting `--train_bs` and `--accum_steps` according to the available GPU memory).
@@ -279,7 +269,7 @@ To perform QAD with logits loss, run:
--distill True
```

> **_NOTE:_** QAD currently requires quantization to be applied before the FSDP wrapper. Training is not supported for models that exceed single GPU memory capacity.
> **_NOTE:_** QAD does not support the FSDP1 (<https://docs.pytorch.org/docs/stable/fsdp.html>) backend; only FSDP2 is supported.

## Testing QAT model with LLM benchmarks for accuracy evaluation

20 changes: 8 additions & 12 deletions examples/llm_qat/launch.sh
@@ -48,7 +48,6 @@ while [ $# -gt 0 ]; do
--fsdp_transformer_layer_cls_to_wrap*) FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
--max_seq_length*) MAX_SEQ_LENGTH=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
--backend*) BACKEND=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
--use_fsdp2*) USE_FSDP2=$(parse_value "$@"); [[ "$1" != *=* ]] && shift ;;
*)
>&2 printf "Error: Invalid argument ${1#*=}\n"
exit 1
@@ -83,7 +82,7 @@ COMPRESS=${COMPRESS:-"False"}
DISTILL=${DISTILL:-"False"}
TEACHER_MODEL=${TEACHER_MODEL:-$MODEL}
FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=${FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP:-"LlamaDecoderLayer"}
BACKEND=${BACKEND:-"fsdp1"}
BACKEND=${BACKEND:-"fsdp2"}

if [ -z $QUANT_CFG ]; then
QUANT_ARGS=""
@@ -96,12 +95,6 @@ if [ ! -z $MAX_STEPS ]; then
OPTIONAL_ARGS="$OPTIONAL_ARGS --max_steps $MAX_STEPS"
fi

# Set backend based on --backend parameter, with backward compatibility for --use_fsdp2
if [[ "${USE_FSDP2,,}" == "true" ]]; then
echo "Warning: --use_fsdp2 is deprecated. Use --backend=fsdp2 instead."
BACKEND="fsdp2"
fi

# if compress is true, set backend to ddp
if [[ "${COMPRESS,,}" == "true" ]]; then
BACKEND="ddp"
@@ -115,7 +108,7 @@ case "${BACKEND,,}" in
FSDP_ARGS="--fsdp_transformer_layer_cls_to_wrap $FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP"
;;
"fsdp2")
echo "Using FSDP2 instead of FSDP1. FSDP2 is not mature yet! Please use it with latest torch and transformers."
echo "Using FSDP2 instead of FSDP1."
CONFIG_FILE="fsdp2.yaml"
FSDP_ARGS="--fsdp_transformer_layer_cls_to_wrap $FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP"
;;
@@ -139,8 +132,11 @@ esac
DISTILLATION_ARGS=""
if [[ "${DISTILL,,}" == "true" ]]; then
DISTILLATION_ARGS="--distill $DISTILL --teacher_model $TEACHER_MODEL"
# Distillation does not work with memory efficient loading for FSDP
if [[ "${BACKEND,,}" == "fsdp1" || "${BACKEND,,}" == "fsdp2" ]]; then
if [[ "${BACKEND,,}" == "fsdp1" ]]; then
echo "Error: Distillation does not support FSDP1. Use FSDP2 instead."
exit 1
elif [[ "${BACKEND,,}" == "fsdp2" ]]; then
# Distillation does not work with memory efficient loading for FSDP
FSDP_ARGS="$FSDP_ARGS --fsdp_cpu_ram_efficient_loading False"
fi
fi
@@ -180,4 +176,4 @@ CMD="accelerate launch --config-file accelerate_config/$CONFIG_FILE $FSDP_ARGS \

start_time=$(date +%s)
sh -c "$CMD"
echo "Total time taken: $(( $(date +%s) - $start_time )) seconds"
echo "Total time taken: $(( $(date +%s) - $start_time )) seconds"
4 changes: 4 additions & 0 deletions examples/llm_qat/llama_factory/launch_llamafactory.sh
@@ -248,6 +248,10 @@ else

# Add teacher model specific FSDP args if needed
if [[ "${HAS_TEACHER_MODEL,,}" == "true" ]]; then
if [[ "${USE_FSDP2,,}" != "true" ]]; then
echo "Error: Quantization aware distillation is only supported with FSDP2."
exit 1
fi
FSDP_ARGS="$FSDP_ARGS --fsdp_cpu_ram_efficient_loading False"
fi

21 changes: 5 additions & 16 deletions examples/llm_qat/llama_factory/llama_factory.py
@@ -73,13 +73,6 @@ def _get_init_kwargs(model_args: ModelArguments) -> dict[str, Any]:
mto.enable_huggingface_checkpointing()


def _teacher_factory(model_name_or_path):
"""Function to create a teacher model."""
return transformers.AutoModelForCausalLM.from_pretrained(
model_name_or_path,
)


def parse_args():
"""Parse configuration file and extract ModelOpt quantization/distillation arguments.
@@ -221,14 +214,12 @@ def __init__(self, *args, **kwargs):
# Initialize parent classes
modelopt_trainer_args = {"quant_args": quant_args}
if distill_args and distill_args.distill:
teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
distill_args.teacher_model,
)
distill_config = {
"teacher_model": (
_teacher_factory,
(distill_args.teacher_model,),
{},
),
"teacher_model": teacher_model,
"criterion": LMLogitsLoss(),
"expose_minimal_state_dict": False, # FSDP requires this to be False
}
modelopt_trainer_args["distill_config"] = distill_config
super().__init__(*args, **modelopt_trainer_args, **kwargs)
@@ -249,11 +240,9 @@ def create_modelcard_and_push(
) -> None:
original_fn(trainer, *args, **kwargs)

# export the student model for quantization aware distillation
kwargs = {"export_student": True} if hasattr(trainer, "distill_config") else {}
# save the model in the output directory
trainer.save_state()
trainer.save_model(output_dir=trainer.args.output_dir, **kwargs)
trainer.save_model(output_dir=trainer.args.output_dir)

module.create_modelcard_and_push = create_modelcard_and_push

30 changes: 9 additions & 21 deletions examples/llm_qat/main.py
@@ -152,15 +152,6 @@ class QuantizationArguments:
)


def _teacher_factory(model_name_or_path, cache_dir=None):
"""Function to create a teacher model."""
return transformers.AutoModelForCausalLM.from_pretrained(
model_name_or_path,
cache_dir=cache_dir,
torch_dtype=torch.bfloat16,
)


def train():
parser = transformers.HfArgumentParser(
(ModelArguments, TrainingArguments, DataArguments, QuantizationArguments)
@@ -186,7 +177,7 @@ def train():
tokenizer.pad_token_id = tokenizer.eos_token_id

# We set model.config.use_cache to False for training when gradient_checkpointing=False.
# Currently useful for FSDP2 to allow for setting activation_checkpointing=True in the config file.åå
# Currently useful for FSDP2 to allow for setting activation_checkpointing=True in the config file.
model.config.use_cache = False

print_rank_0("Loading dataset...")
@@ -228,17 +219,15 @@ def train():
distill_kwargs = {}
if training_args.distill:
assert model_args.teacher_model is not None, "Teacher model is required for distillation."

teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
model_args.teacher_model,
cache_dir=training_args.cache_dir,
torch_dtype=torch.bfloat16,
)
distill_config = {
"teacher_model": (
_teacher_factory,
(
model_args.teacher_model,
training_args.cache_dir,
),
{},
),
"teacher_model": teacher_model,
"criterion": LMLogitsLoss(),
"expose_minimal_state_dict": False, # FSDP forces us to disable this
}
distill_kwargs["distill_config"] = distill_config
trainer_cls = QADTrainer if training_args.distill else QATTrainer
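
Condensed, the new QAD wiring in `main.py` reads roughly as the sketch below: the teacher is loaded directly in bf16, the `expose_minimal_state_dict` override is gone, and the config is forwarded to the trainer instead of being converted in the script. The checkpoint name and cache directory are placeholders; `model`, `training_args`, and the remaining trainer arguments are assumed from the surrounding script.

```python
import torch
import transformers

from modelopt.torch.distill.plugins.huggingface import LMLogitsLoss
from modelopt.torch.quantization.plugins.transformers_trainer import QADTrainer

teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",  # placeholder teacher checkpoint
    cache_dir=None,                      # placeholder cache directory
    torch_dtype=torch.bfloat16,
)
distill_config = {
    "teacher_model": teacher_model,  # model instance, not a factory tuple
    "criterion": LMLogitsLoss(),     # no expose_minimal_state_dict override needed
}

# The distillation config is forwarded to the trainer, which owns the conversion:
#   trainer = QADTrainer(model, training_args, distill_config=distill_config, ...)
#   trainer.train()
#   trainer.save_model(training_args.output_dir)  # export_student flag no longer required
```
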
@@ -270,8 +259,7 @@ if training_args.do_train or quant_args.quant_cfg is not None:
if training_args.do_train or quant_args.quant_cfg is not None:
print_rank_0("Saving the model...")
trainer.save_state()
kwargs = {"export_student": True} if training_args.distill else {}
trainer.save_model(training_args.output_dir, **kwargs)
trainer.save_model(training_args.output_dir)


if __name__ == "__main__":
5 changes: 0 additions & 5 deletions modelopt/torch/distill/mode.py
@@ -60,11 +60,6 @@ def config_class(self) -> type[ModeloptBaseConfig]:
"""Specifies the config class for the mode."""
return KDLossConfig

@property
def next_modes(self) -> set[str] | None:
"""Modes that must immediately follow this mode."""
return {"export_student"}

@property
def export_mode(self) -> str | None:
"""The mode that corresponds to the export mode of this mode."""