
Commit b90f9a6

(Fix) data_loader: Change ValueError to Warning for BatchSampler in sampler argument, and Fix Typos
1 parent ac777e9 commit b90f9a6

1 file changed: +17 −11 lines changed

src/accelerate/data_loader.py

@@ -631,10 +631,12 @@ def get_sampler(self):
         return get_sampler(self)
 
     def set_sampler(self, sampler):
-        if isinstance(sampler, BatchSampler):
-            self.sampler.batch_sampler = sampler
-        elif isinstance(sampler, Sampler):
+        if isinstance(self.sampler, BatchSampler):
             self.sampler.sampler = sampler
+        else:
+            self.batch_sampler.sampler = sampler
+            if hasattr(self.batch_sampler, "batch_sampler"):
+                self.batch_sampler.batch_sampler.sampler = sampler
 
 
 if is_torch_xla_available():
@@ -955,12 +957,12 @@ def get_sampler(self):
         return get_sampler(self)
 
     def set_sampler(self, sampler):
-        if isinstance(sampler, BatchSampler):
-            self.sampler.batch_sampler = sampler
-        elif isinstance(sampler, Sampler):
+        if isinstance(self.sampler, BatchSampler):
             self.sampler.sampler = sampler
         else:
-            raise ValueError(f"{sampler} must be of type torch.utills.data.Sampler or torch.utils.data.BatchSampler")
+            self.batch_sampler.sampler = sampler
+            if hasattr(self.batch_sampler, "batch_sampler"):
+                self.batch_sampler.batch_sampler.sampler = sampler
 
 
 def get_sampler(dataloader):
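
As a rough illustration of the rewritten set_sampler logic, here is a minimal sketch against a plain torch.utils.data.DataLoader rather than Accelerate's wrapped loaders (so the nested batch_sampler.batch_sampler case handled above is omitted); the dataset, loader, and new_sampler names are illustrative only:

import torch
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, TensorDataset

dataset = TensorDataset(torch.arange(10).float())
loader = DataLoader(dataset, batch_size=4)  # DataLoader builds a SequentialSampler wrapped in a BatchSampler

new_sampler = RandomSampler(dataset)
if isinstance(loader.sampler, BatchSampler):
    # sampler= was itself a BatchSampler, so the index sampler sits one level deeper
    loader.sampler.sampler = new_sampler
else:
    # usual case: re-point the inner sampler of the auto-created BatchSampler
    loader.batch_sampler.sampler = new_sampler

print([batch[0].tolist() for batch in loader])  # batches now arrive in random order

The hasattr(self.batch_sampler, "batch_sampler") branch in the diff covers the case where the loader's batch sampler itself wraps another BatchSampler, as Accelerate's sharded samplers do.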
@@ -1150,8 +1152,10 @@ def prepare_data_loader(
     new_dataset = dataloader.dataset
     # Iterable dataset doesn't like batch_sampler, but data_loader creates a default one for it
     if isinstance(dataloader.sampler, BatchSampler):
-        raise ValueError(
-            "Should not pass a BatchSampler do dataloader sampler argument. As per pytorch>2.1.0 documentation, please pass this to sampler instead"
+        logger.warning(
+            "BatchSampler was passed to sampler argument. "
+            "If you have a custom Sampler that yields a list of batch indices at a time, please pass it as the batch_sampler argument instead. "
+            "For more information, see https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader"
         )
 
     new_batch_sampler = dataloader.batch_sampler if not isinstance(new_dataset, IterableDataset) else None
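
The guard on the final context line above reflects PyTorch behavior: DataLoader refuses sampler and batch_sampler options for iterable-style datasets, which is why no batch sampler is carried over for an IterableDataset. A small sketch (the Stream class is a made-up stand-in):

from torch.utils.data import BatchSampler, DataLoader, IterableDataset, SequentialSampler

class Stream(IterableDataset):
    def __iter__(self):
        return iter(range(8))

try:
    # PyTorch raises at construction: iterable datasets control their own iteration order
    DataLoader(Stream(), batch_sampler=BatchSampler(SequentialSampler(range(8)), batch_size=4, drop_last=False))
except ValueError as err:
    print(err)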
@@ -1347,8 +1351,10 @@ def skip_first_batches(dataloader, num_batches=0):
 
     dataset = dataloader.dataset
     if isinstance(dataloader.sampler, BatchSampler):
-        raise ValueError(
-            "Should not pass a BatchSampler do dataloader sampler argument. As per the latest pytorch documentation, please pass this to sampler instead"
+        logger.warning(
+            "BatchSampler was passed to sampler argument. "
+            "If you have a custom Sampler that yields a list of batch indices at a time, please pass it as the batch_sampler argument instead. "
+            "For more information, see https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader"
         )
 
     if isinstance(dataset, IterableDataset):
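
For the scenario the new warning targets, a hedged sketch of both call styles from the linked PyTorch documentation (tensor contents and sizes here are arbitrary):

import torch
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, TensorDataset

dataset = TensorDataset(torch.arange(8).float())
bs = BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=False)

# Documented form: the BatchSampler goes to batch_sampler=.
loader_documented = DataLoader(dataset, batch_sampler=bs)

# The form this commit stops rejecting: the BatchSampler arrives via sampler=
# with batch_size=None, so each yielded "index" is already a list of indices.
loader_warned = DataLoader(dataset, sampler=bs, batch_size=None)

print(next(iter(loader_documented)))  # first batch of 4 items
print(next(iter(loader_warned)))      # the same 4 items via the sampler= route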
