@@ -215,30 +215,31 @@ def _create_dynamics(
215215 # Initialise the dynamics object list.
216216 self ._dynamics = []
217217
218- # A set of visited device indices .
219- devices = set ()
218+ # Per-device memory tracking for estimation .
219+ device_mem = {}
220220
221- # Determine whether there is a remainder in the number of replicas.
221+ # Work out how many replicas are assigned to each device.
222+ # Replicas are assigned round-robin, so the first (num_replicas % num_gpus)
223+ # devices get one extra replica.
224+ base = floor (num_replicas / num_gpus )
222225 remainder = num_replicas % num_gpus
223-
224- # Store the number of contexts for each device. The last device will
225- # have remainder contexts, while all others have
226- contexts_per_device = num_replicas * [floor (num_replicas / num_gpus )]
227-
228- # Set the last device to have the remainder contexts.
229- contexts_per_device [- 1 ] = remainder
226+ contexts_per_device = [
227+ base + (1 if i < remainder else 0 ) for i in range (num_gpus )
228+ ]
230229
231230 # Create the dynamics objects in serial.
232231 for i , (lam , scale ) in enumerate (zip (lambdas , rest2_scale_factors )):
233232 # Work out the device index.
234233 device = i % num_gpus
235234
236- # If we've not seen this device before then get the memory statistics
237- # prior to creating the dynamics object and GCMC sampler.
238- if device not in devices :
239- used_mem_before , free_mem_before , total_mem = self ._check_device_memory (
240- device
241- )
235+ # Record baseline memory before the first replica on this device.
236+ if device not in device_mem :
237+ used_before , _ , total_mem = self ._check_device_memory (device )
238+ device_mem [device ] = {
239+ "before" : used_before ,
240+ "total" : total_mem ,
241+ "count" : 0 ,
242+ }
242243
243244 # This is a restart, get the system for this replica.
244245 if isinstance (system , list ):
@@ -321,19 +322,43 @@ def _create_dynamics(
321322 # Append the dynamics object.
322323 self ._dynamics .append (dynamics )
323324
324- # Check the memory footprint for this device.
325- if not device in devices :
326- # Add the device to the set of visited devices.
327- devices . add ( device )
325+ # Track memory footprint for this device.
326+ info = device_mem [ device ]
327+ info [ "count" ] += 1
328+ num_contexts = contexts_per_device [ device ]
328329
329- # Get the current memory usage.
330- used_mem , free_mem , total_mem = self ._check_device_memory (device )
330+ # Estimate memory after the first or second replica.
331+ if info ["count" ] == 1 :
332+ used_mem , _ , _ = self ._check_device_memory (device )
333+ info ["after_first" ] = used_mem
331334
332- # Work out the memory used by this dynamics object and GCMC sampler.
333- mem_used = used_mem - used_mem_before
335+ if num_contexts == 1 :
336+ # Only one replica on this device, use actual measurement.
337+ est_total = used_mem
338+ else :
339+ # Wait for the second replica to get the marginal cost.
340+ est_total = None
341+
342+ elif info ["count" ] == 2 :
343+ used_mem , _ , _ = self ._check_device_memory (device )
344+ # The first replica includes one-time context overhead.
345+ # The marginal cost of subsequent replicas is the difference
346+ # between the second and first.
347+ first_cost = info ["after_first" ] - info ["before" ]
348+ marginal_cost = used_mem - info ["after_first" ]
349+ est_total = (
350+ info ["before" ] + first_cost + marginal_cost * (num_contexts - 1 )
351+ )
352+ _logger .info (
353+ f"Memory per replica on device { device } : "
354+ f"first = { first_cost / (1024 ** 2 ):.0f} MiB, "
355+ f"marginal = { marginal_cost / (1024 ** 2 ):.0f} MiB"
356+ )
357+ else :
358+ est_total = None
334359
335- # Work out the estimated total after all replicas have been created.
336- est_total = mem_used * contexts_per_device [ device ] + used_mem_before
360+ if est_total is not None :
361+ total_mem = info [ "total" ]
337362
338363 # If this exceeds the total memory, raise an error.
339364 if est_total > total_mem :
@@ -562,18 +587,10 @@ def _check_device_memory(device_index=0):
562587
563588 pynvml .nvmlInit ()
564589
565- # Find matching device by name
566- device_count = pynvml .nvmlDeviceGetCount ()
567- for i in range (device_count ):
568- handle = pynvml .nvmlDeviceGetHandleByIndex (i )
569- name = pynvml .nvmlDeviceGetName (handle )
570-
571- if name in device .name or device .name in name :
572- memory = pynvml .nvmlDeviceGetMemoryInfo (handle )
573- pynvml .nvmlShutdown ()
574- return (memory .used , memory .free , memory .total )
575-
590+ handle = pynvml .nvmlDeviceGetHandleByIndex (device_index )
591+ memory = pynvml .nvmlDeviceGetMemoryInfo (handle )
576592 pynvml .nvmlShutdown ()
593+ return (memory .used , memory .free , memory .total )
577594 except Exception as e :
578595 msg = f"Could not get NVIDIA GPU memory info for device { device_index } : { e } "
579596 _logger .error (msg )
0 commit comments