diff --git a/allennlp/commands/train.py b/allennlp/commands/train.py
index dd67f907404..13a18af9a06 100644
--- a/allennlp/commands/train.py
+++ b/allennlp/commands/train.py
@@ -253,7 +253,12 @@ def train_model(params: Params,
     create_serialization_dir(params, serialization_dir, recover)
     prepare_global_logging(serialization_dir, file_friendly_logging)
 
-    check_for_gpu(params.params.get('trainer').get('cuda_device', -1))
+    cuda_device = params.params.get('trainer').get('cuda_device', -1)
+    if isinstance(cuda_device, list):
+        for device in cuda_device:
+            check_for_gpu(device)
+    else:
+        check_for_gpu(cuda_device)
 
     serialization_params = deepcopy(params).as_dict(quiet=True)
     with open(os.path.join(serialization_dir, CONFIG_NAME), "w") as param_file:
diff --git a/allennlp/training/trainer.py b/allennlp/training/trainer.py
index 8dcbd7e3d91..7d5ff1194e0 100644
--- a/allennlp/training/trainer.py
+++ b/allennlp/training/trainer.py
@@ -378,7 +378,20 @@ def _data_parallel(self, batch):
         of torch.nn.parallel.data_parallel to support the allennlp model
         interface.
         """
+        metadata_batch_size = len(batch['metadata']) if 'metadata' in batch and isinstance(batch['metadata'], list) else None
+
         inputs, module_kwargs = scatter_kwargs((), batch, self._cuda_devices, 0)
+
+        if metadata_batch_size is not None:
+            # PyTorch's scatter does not split non-tensor values such as metadata, so chunk the list by hand,
+            # using the same chunk size as the tensor chunking in ATen/native/TensorShape (torch.chunk).
+            chunk_size = 1 + (metadata_batch_size - 1) // len(self._cuda_devices)
+            chunk_offset = 0
+            for instance in module_kwargs:
+                if 'metadata' in instance:
+                    instance['metadata'] = instance['metadata'][chunk_offset:chunk_offset + chunk_size]
+                    chunk_offset += chunk_size
+
         used_device_ids = self._cuda_devices[:len(inputs)]
         replicas = replicate(self._model, used_device_ids)
         outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids)
@@ -926,13 +939,11 @@ def from_params(cls,
         patience = params.pop_int("patience", None)
         validation_metric = params.pop("validation_metric", "-loss")
         num_epochs = params.pop_int("num_epochs", 20)
-        cuda_device = params.pop_int("cuda_device", -1)
+        cuda_device = params.pop("cuda_device", -1)
         grad_norm = params.pop_float("grad_norm", None)
         grad_clipping = params.pop_float("grad_clipping", None)
         lr_scheduler_params = params.pop("learning_rate_scheduler", None)
 
-        if cuda_device >= 0:
-            model = model.cuda(cuda_device)
         parameters = [[n, p] for n, p in model.named_parameters() if p.requires_grad]
         optimizer = Optimizer.from_params(parameters, params.pop("optimizer"))
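
The ceil-division in `_data_parallel` (`chunk_size = 1 + (n - 1) // num_devices`) is meant to reproduce how the batch tensors themselves are split across devices, per the diff's reference to ATen/native/TensorShape. Below is a minimal standalone sketch of that correspondence; the helper name `chunk_metadata` and the batch sizes are hypothetical and chosen only for illustration.

# Standalone sketch (not part of the PR): checks that the manual metadata chunking in the
# diff produces the same per-device sizes as torch.chunk, the splitting its comment refers to.
# The helper name `chunk_metadata` is hypothetical and introduced only for this illustration.
from typing import Any, List

import torch


def chunk_metadata(metadata: List[Any], num_devices: int) -> List[List[Any]]:
    """Split a metadata list with the same ceil-division rule used in _data_parallel."""
    chunk_size = 1 + (len(metadata) - 1) // num_devices  # ceil(len(metadata) / num_devices)
    return [metadata[offset:offset + chunk_size]
            for offset in range(0, len(metadata), chunk_size)]


if __name__ == "__main__":
    batch_size, num_devices = 10, 4
    metadata = [{"id": i} for i in range(batch_size)]
    tensor = torch.arange(batch_size)

    metadata_chunks = chunk_metadata(metadata, num_devices)
    tensor_chunks = torch.chunk(tensor, num_devices, dim=0)

    # Both splits give chunk sizes [3, 3, 3, 1] for 10 instances on 4 devices.
    assert [len(c) for c in metadata_chunks] == [c.size(0) for c in tensor_chunks]

When the batch has fewer instances than devices, both splits yield fewer chunks than devices, which is consistent with the diff keeping `used_device_ids = self._cuda_devices[:len(inputs)]`.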