@@ -631,10 +631,12 @@ def get_sampler(self):
631631 return get_sampler (self )
632632
633633 def set_sampler (self , sampler ):
634- if isinstance (sampler , BatchSampler ):
635- self .sampler .batch_sampler = sampler
636- elif isinstance (sampler , Sampler ):
634+ if isinstance (self .sampler , BatchSampler ):
637635 self .sampler .sampler = sampler
636+ else :
637+ self .batch_sampler .sampler = sampler
638+ if hasattr (self .batch_sampler , "batch_sampler" ):
639+ self .batch_sampler .batch_sampler .sampler = sampler
638640
639641
640642if is_torch_xla_available ():
@@ -955,12 +957,12 @@ def get_sampler(self):
955957 return get_sampler (self )
956958
957959 def set_sampler (self , sampler ):
958- if isinstance (sampler , BatchSampler ):
959- self .sampler .batch_sampler = sampler
960- elif isinstance (sampler , Sampler ):
960+ if isinstance (self .sampler , BatchSampler ):
961961 self .sampler .sampler = sampler
962962 else :
963- raise ValueError (f"{ sampler } must be of type torch.utills.data.Sampler or torch.utils.data.BatchSampler" )
963+ self .batch_sampler .sampler = sampler
964+ if hasattr (self .batch_sampler , "batch_sampler" ):
965+ self .batch_sampler .batch_sampler .sampler = sampler
964966
965967
966968def get_sampler (dataloader ):
@@ -1150,8 +1152,10 @@ def prepare_data_loader(
11501152 new_dataset = dataloader .dataset
11511153 # Iterable dataset doesn't like batch_sampler, but data_loader creates a default one for it
11521154 if isinstance (dataloader .sampler , BatchSampler ):
1153- raise ValueError (
1154- "Should not pass a BatchSampler do dataloader sampler argument. As per pytorch>2.1.0 documentation, please pass this to sampler instead"
1155+ logger .warning (
1156+ "BatchSampler was passed to sampler argument."
1157+ "If you have a custom Sampler that yields a list of batch indices at a time, please pass it as the batch_sampler argument instead."
1158+ "For more information, see https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader"
11551159 )
11561160
11571161 new_batch_sampler = dataloader .batch_sampler if not isinstance (new_dataset , IterableDataset ) else None
@@ -1347,8 +1351,10 @@ def skip_first_batches(dataloader, num_batches=0):
13471351
13481352 dataset = dataloader .dataset
13491353 if isinstance (dataloader .sampler , BatchSampler ):
1350- raise ValueError (
1351- "Should not pass a BatchSampler do dataloader sampler argument. As per the latest pytorch documentation, please pass this to sampler instead"
1354+ logger .warning (
1355+ "BatchSampler was passed to sampler argument."
1356+ "If you have a custom Sampler that yields a list of batch indices at a time, please pass it as the batch_sampler argument instead."
1357+ "For more information, see https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader"
13521358 )
13531359
13541360 if isinstance (dataset , IterableDataset ):
0 commit comments