FSDP + DPO fails with "Expected all tensors to be on the same device" #9560

@UsernameFull

Description

Reminder

  • I have read the above rules and searched the existing issues.

System Info

  • llamafactory: 0.9.4
  • python: 3.11.14
  • torch: 2.7.1
  • transformers: 4.57.1
  • accelerate: 1.11
  • trl: 0.96
  • cuda: 12.6

Reproduction

Description

When running DPO training with FSDP enabled (FULL_SHARD), a RuntimeError occurs at the embedding layer because the reference model resides on the CPU while inputs are on CUDA.

Error message:

return  torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:x and cpu!
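The mismatch can be confirmed before the forward pass with a quick device check. This is a hypothetical diagnostic sketch, not part of LLaMA-Factory; `first_param_device` and `ref_like` are illustrative names:

```python
import torch

def first_param_device(model: torch.nn.Module) -> torch.device:
    # Report where a model's parameters actually live.
    return next(model.parameters()).device

# A freshly built module defaults to CPU, just like the un-moved ref model;
# on an affected run it reports cpu while the batch tensors are on cuda.
ref_like = torch.nn.Embedding(32, 8)
print(first_param_device(ref_like))
```

On an affected run, comparing this against the device of the input ids pinpoints which module FSDP left behind on the CPU.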

fsdp.yaml

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_forward_prefetch: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: bf16  # or fp16
num_machines: 1  # the number of nodes
num_processes: 8  # the number of GPUs in all nodes
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

Solution

Manually moving the ref_model to the current CUDA device in src/llamafactory/train/dpo/workflow.py works as a workaround:

# File: src/llamafactory/train/dpo/workflow.py

    # ... inside run_dpo function ...
    
    if finetuning_args.use_ref_model:
        if finetuning_args.ref_model is None and (not training_args.do_train):
            ref_model = model
        else:
            ref_model = create_ref_model(model_args, finetuning_args)
            # ++++++ INSERTED CODE ++++++
            import torch
            ref_model = ref_model.to(torch.cuda.current_device())
            # +++++++++++++++++++++++++++
    else:
        ref_model = None
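A slightly more defensive variant of the same workaround, factored into a helper so CPU-only runs are not broken. The helper name is hypothetical (not a LLaMA-Factory API); it only assumes the standard `torch.cuda` calls:

```python
import torch

def move_ref_model_to_device(ref_model):
    # Hypothetical helper: place the reference model on the local CUDA
    # device when one is available, and leave it on CPU otherwise so
    # CPU-only runs keep working.
    if ref_model is not None and torch.cuda.is_available():
        ref_model = ref_model.to(torch.cuda.current_device())
    return ref_model

# Stand-in module for illustration; on a GPU node its parameters
# would end up on the local cuda device after the call.
ref_like = torch.nn.Linear(4, 4)
ref_like = move_ref_model_to_device(ref_like)
print(next(ref_like.parameters()).device)
```

Calling the helper right after `create_ref_model(...)` keeps the inserted code in one place and avoids the bare `import torch` inside the function body.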

Others

similar issues:

#7641
#4608


    Labels

    bug (Something isn't working), pending (This problem is yet to be addressed)
