7 changes: 6 additions & 1 deletion allennlp/commands/train.py
@@ -253,7 +253,12 @@ def train_model(params: Params,
     create_serialization_dir(params, serialization_dir, recover)
     prepare_global_logging(serialization_dir, file_friendly_logging)
 
-    check_for_gpu(params.params.get('trainer').get('cuda_device', -1))
+    cuda_device = params.params.get('trainer').get('cuda_device', -1)
+    if isinstance(cuda_device, list):
+        for device in cuda_device:
+            check_for_gpu(device)
+    else:
+        check_for_gpu(cuda_device)
 
     serialization_params = deepcopy(params).as_dict(quiet=True)
     with open(os.path.join(serialization_dir, CONFIG_NAME), "w") as param_file:
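For readers following along, a minimal standalone sketch of the same normalization, assuming a trainer config whose cuda_device may be either an int or a list of ints. check_for_gpu is allennlp's existing helper from allennlp.common.checks; the wrapper function and example values are hypothetical:

    from allennlp.common.checks import check_for_gpu

    def check_devices(cuda_device):
        # A config may specify "cuda_device": 0, "cuda_device": -1 (CPU),
        # or "cuda_device": [0, 1] for multiple GPUs.
        if isinstance(cuda_device, list):
            for device in cuda_device:
                check_for_gpu(device)  # validate each device id individually
        else:
            check_for_gpu(cuda_device)

    check_devices([0, 1])  # multi-GPU trainer config
    check_devices(-1)      # CPU-only config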
17 changes: 14 additions & 3 deletions allennlp/training/trainer.py
@@ -378,7 +378,20 @@ def _data_parallel(self, batch):
         of torch.nn.parallel.data_parallel to support the allennlp model
         interface.
         """
+        metadata_batch_size = len(batch['metadata']) if 'metadata' in batch and isinstance(batch['metadata'], list) else None
+
         inputs, module_kwargs = scatter_kwargs((), batch, self._cuda_devices, 0)
+
+        if metadata_batch_size is not None:
+            # Metadata also has to be chunked manually, since PyTorch's scatter is
+            # unaware of non-tensor inputs; follows the chunking scheme of ATen's
+            # native TensorShape functions.
+            chunk_size = 1 + (metadata_batch_size - 1) // len(self._cuda_devices)
+            chunk_offset = 0
+            for instance in module_kwargs:
+                if 'metadata' in instance:
+                    instance['metadata'] = instance['metadata'][chunk_offset:chunk_offset + chunk_size]
+                    chunk_offset += chunk_size
+
         used_device_ids = self._cuda_devices[:len(inputs)]
         replicas = replicate(self._model, used_device_ids)
         outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids)
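For intuition, a small sketch of the ceil-division chunking used above, with hypothetical numbers (10 metadata entries across 4 devices). The first chunks are full-sized and the last one takes the remainder, matching how PyTorch's scatter splits tensors, so the metadata slices line up with the tensor chunks:

    metadata = list(range(10))   # stand-in for batch['metadata']
    num_devices = 4

    chunk_size = 1 + (len(metadata) - 1) // num_devices   # ceil(10 / 4) = 3
    chunks = [metadata[i:i + chunk_size] for i in range(0, len(metadata), chunk_size)]
    print(chunks)   # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]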
@@ -926,13 +939,11 @@ def from_params(cls,
         patience = params.pop_int("patience", None)
         validation_metric = params.pop("validation_metric", "-loss")
         num_epochs = params.pop_int("num_epochs", 20)
-        cuda_device = params.pop_int("cuda_device", -1)
+        cuda_device = params.pop("cuda_device", -1)
         grad_norm = params.pop_float("grad_norm", None)
         grad_clipping = params.pop_float("grad_clipping", None)
         lr_scheduler_params = params.pop("learning_rate_scheduler", None)
 
-        if cuda_device >= 0:
-            model = model.cuda(cuda_device)
         parameters = [[n, p] for n, p in model.named_parameters() if p.requires_grad]
         optimizer = Optimizer.from_params(parameters, params.pop("optimizer"))

murphp15 (Owner) commented on Jul 2, 2018, on the removed if cuda_device >= 0: line:

    If the model is always none now then maybe it should be removed as a parameter from the Trainer constructor?

Author replied:

    That wasn't in my changes, though, right? I'm pretty sure upstream modified trainer.py quite a bit, so you may want to diff against their master first.

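Note that the removed guard had to go along with pop_int: once cuda_device can be a list, the comparison cuda_device >= 0 raises a TypeError in Python 3. A minimal sketch of the new behavior, with a plain dict standing in for allennlp's Params and hypothetical config values:

    trainer_config = {"cuda_device": [0, 1]}   # hypothetical multi-GPU config

    cuda_device = trainer_config.pop("cuda_device", -1)   # keeps the old -1 default
    if isinstance(cuda_device, list):
        print("multi-GPU devices:", cuda_device)   # placement handled elsewhere now
    else:
        print("single device:", cuda_device)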