Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion scripts/imagine_win.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,24 @@ def _cpu_input_hook(module, args, kwargs):
log(f'STATUS:Loading Flux transformer ({flux_key})...')
transformer = FluxTransformer2DModel.from_single_file(
FLUX_FILES[flux_key], torch_dtype=dtype
).to('cuda')
)
# The flux1-*-fp8 checkpoints are fp8 on disk, but from_single_file upcasts
# them to the bf16 compute dtype (~23GB). On a 24GB card that overflows once
# the desktop/browser baseline (~5GB) plus activations are accounted for, so
# the driver spills to shared system RAM and each diffusion step takes
# minutes — long enough to trip the media-job idle watchdog (300s). Layerwise
# casting keeps the weights stored as fp8 (~15GB resident — fp8 weight bytes
# plus the norm/embedding modules that stay bf16) and upcasts each layer to
# bf16 only during its forward pass, so the model fits in VRAM while compute
# stays bf16. (Ampere/3090 has no fp8 tensor cores, so we must NOT compute in
# fp8 — storage-only casting is the correct strategy here.) Set
# IMAGINE_WIN_FP8=0 (or false/never) to fall back to the old full-bf16 load.
if os.environ.get('IMAGINE_WIN_FP8', '1').strip().lower() not in ('0', 'false', 'never'):
transformer.enable_layerwise_casting(
storage_dtype=torch.float8_e4m3fn, compute_dtype=dtype
)
log('STATUS:Transformer weights stored as fp8 (bf16 compute)...')
transformer = transformer.to('cuda')

log('STATUS:Assembling pipeline...')
pipe = FluxPipeline(
Expand Down
Loading