Skip to content

Commit 9fbbbeb

Browse files
scavallariandompesta
authored andcommitted
fix cpu usecase
1 parent c513c47 commit 9fbbbeb

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

modelopt/torch/quantization/model_calib.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1033,7 +1033,10 @@ def postprocess(module, name):
10331033
original_device = weight.device
10341034
original_dtype = weight.dtype
10351035
weight_f64 = weight.to(dtype=torch.float64, device=original_device)
1036-
u, s, vt = torch.linalg.svd(weight_f64, driver="gesvd", full_matrices=False)
1036+
if original_device.type == "cuda":
1037+
u, s, vt = torch.linalg.svd(weight_f64, driver="gesvd", full_matrices=False)
1038+
else:
1039+
u, s, vt = torch.linalg.svd(weight_f64, full_matrices=False)
10371040
if u.shape[1] < lowrank or vt.shape[0] < lowrank:
10381041
warnings.warn(
10391042
"The low-rank dimensions do not match the layer dimensions. "

0 commit comments

Comments
 (0)