Skip to content

Commit 1896f34

Browse files
authored
Merge pull request #120 from OpenBioSim/feature_primary_context
Update GPU memory footprint estimation for shared primary context
2 parents bc1c2ab + b597404 commit 1896f34

File tree

3 files changed

+71
-38
lines changed

3 files changed

+71
-38
lines changed

src/somd2/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,9 @@
3737
# Store the sire version.
3838
from sire import __version__ as _sire_version
3939
from sire import __revisionid__ as _sire_revisionid
40+
41+
# Store the ghostly version.
42+
from ghostly import __version__ as _ghostly_version
43+
44+
# Store the loch version.
45+
from loch import __version__ as _loch_version

src/somd2/runner/_base.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,10 +117,20 @@ def __init__(self, system, config):
117117
self._perturbed_box = None
118118

119119
# Log the versions of somd2 and sire.
120-
from somd2 import __version__, _sire_version, _sire_revisionid
120+
from somd2 import (
121+
__version__,
122+
_sire_version,
123+
_sire_revisionid,
124+
_ghostly_version,
125+
_loch_version,
126+
)
121127

122128
_logger.info(f"somd2 version: {__version__}")
123129
_logger.info(f"sire version: {_sire_version}+{_sire_revisionid}")
130+
if self._config.ghost_modifications:
131+
_logger.info(f"ghostly version: {_ghostly_version}")
132+
if self._config.gcmc:
133+
_logger.info(f"loch version: {_loch_version}")
124134

125135
# Flag whether frames are being saved.
126136
if (

src/somd2/runner/_repex.py

Lines changed: 54 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -215,30 +215,31 @@ def _create_dynamics(
215215
# Initialise the dynamics object list.
216216
self._dynamics = []
217217

218-
# A set of visited device indices.
219-
devices = set()
218+
# Per-device memory tracking for estimation.
219+
device_mem = {}
220220

221-
# Determine whether there is a remainder in the number of replicas.
221+
# Work out how many replicas are assigned to each device.
222+
# Replicas are assigned round-robin, so the first (num_replicas % num_gpus)
223+
# devices get one extra replica.
224+
base = floor(num_replicas / num_gpus)
222225
remainder = num_replicas % num_gpus
223-
224-
# Store the number of contexts for each device. The last device will
225-
# have remainder contexts, while all others have floor(num_replicas / num_gpus) contexts.
226-
contexts_per_device = num_replicas * [floor(num_replicas / num_gpus)]
227-
228-
# Set the last device to have the remainder contexts.
229-
contexts_per_device[-1] = remainder
226+
contexts_per_device = [
227+
base + (1 if i < remainder else 0) for i in range(num_gpus)
228+
]
230229

231230
# Create the dynamics objects in serial.
232231
for i, (lam, scale) in enumerate(zip(lambdas, rest2_scale_factors)):
233232
# Work out the device index.
234233
device = i % num_gpus
235234

236-
# If we've not seen this device before then get the memory statistics
237-
# prior to creating the dynamics object and GCMC sampler.
238-
if device not in devices:
239-
used_mem_before, free_mem_before, total_mem = self._check_device_memory(
240-
device
241-
)
235+
# Record baseline memory before the first replica on this device.
236+
if device not in device_mem:
237+
used_before, _, total_mem = self._check_device_memory(device)
238+
device_mem[device] = {
239+
"before": used_before,
240+
"total": total_mem,
241+
"count": 0,
242+
}
242243

243244
# This is a restart, get the system for this replica.
244245
if isinstance(system, list):
@@ -321,19 +322,43 @@ def _create_dynamics(
321322
# Append the dynamics object.
322323
self._dynamics.append(dynamics)
323324

324-
# Check the memory footprint for this device.
325-
if not device in devices:
326-
# Add the device to the set of visited devices.
327-
devices.add(device)
325+
# Track memory footprint for this device.
326+
info = device_mem[device]
327+
info["count"] += 1
328+
num_contexts = contexts_per_device[device]
328329

329-
# Get the current memory usage.
330-
used_mem, free_mem, total_mem = self._check_device_memory(device)
330+
# Estimate memory after the first or second replica.
331+
if info["count"] == 1:
332+
used_mem, _, _ = self._check_device_memory(device)
333+
info["after_first"] = used_mem
331334

332-
# Work out the memory used by this dynamics object and GCMC sampler.
333-
mem_used = used_mem - used_mem_before
335+
if num_contexts == 1:
336+
# Only one replica on this device, use actual measurement.
337+
est_total = used_mem
338+
else:
339+
# Wait for the second replica to get the marginal cost.
340+
est_total = None
341+
342+
elif info["count"] == 2:
343+
used_mem, _, _ = self._check_device_memory(device)
344+
# The first replica includes one-time context overhead.
345+
# The marginal cost of subsequent replicas is the difference
346+
# between the second and first.
347+
first_cost = info["after_first"] - info["before"]
348+
marginal_cost = used_mem - info["after_first"]
349+
est_total = (
350+
info["before"] + first_cost + marginal_cost * (num_contexts - 1)
351+
)
352+
_logger.info(
353+
f"Memory per replica on device {device}: "
354+
f"first = {first_cost / (1024**2):.0f} MiB, "
355+
f"marginal = {marginal_cost / (1024**2):.0f} MiB"
356+
)
357+
else:
358+
est_total = None
334359

335-
# Work out the estimated total after all replicas have been created.
336-
est_total = mem_used * contexts_per_device[device] + used_mem_before
360+
if est_total is not None:
361+
total_mem = info["total"]
337362

338363
# If this exceeds the total memory, raise an error.
339364
if est_total > total_mem:
@@ -562,18 +587,10 @@ def _check_device_memory(device_index=0):
562587

563588
pynvml.nvmlInit()
564589

565-
# Find matching device by name
566-
device_count = pynvml.nvmlDeviceGetCount()
567-
for i in range(device_count):
568-
handle = pynvml.nvmlDeviceGetHandleByIndex(i)
569-
name = pynvml.nvmlDeviceGetName(handle)
570-
571-
if name in device.name or device.name in name:
572-
memory = pynvml.nvmlDeviceGetMemoryInfo(handle)
573-
pynvml.nvmlShutdown()
574-
return (memory.used, memory.free, memory.total)
575-
590+
handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
591+
memory = pynvml.nvmlDeviceGetMemoryInfo(handle)
576592
pynvml.nvmlShutdown()
593+
return (memory.used, memory.free, memory.total)
577594
except Exception as e:
578595
msg = f"Could not get NVIDIA GPU memory info for device {device_index}: {e}"
579596
_logger.error(msg)

0 commit comments

Comments
 (0)