62 changes: 57 additions & 5 deletions init2winit/dataset_lib/fastmri_dataset.py
@@ -19,6 +19,7 @@
 import itertools
 import os
 
+from absl import logging
 import h5py
 from init2winit.dataset_lib import data_utils
 import jax
@@ -212,9 +213,11 @@ def load_split(per_host_batch_size, split, hps, shuffle_rng=None):
   # entirely to the end of it on the last host, because otherwise we will drop
   # the last `{train,valid}_size % split_size` elements.
   if jax.process_index() == jax.process_count() - 1:
-    if split == 'val':
+    if split in ['train', 'eval_train']:
+      end = hps.num_train_h5_files
+    elif split == 'val':
       end = hps.num_valid_h5_files
-    else:
+    else:  # split == 'test'
       end = hps.num_test_h5_files + hps.num_valid_h5_files
 
   data_dir = hps.data_dir
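Note (not part of the diff): the hunk above only overrides `end` on the last host; the surrounding sharding logic is implied by the comment about dropping the last `{train,valid}_size % split_size` elements. A minimal sketch of how such a per-host file range is presumably computed; the helper name `host_file_range` is hypothetical:

# Minimal sketch, not from the PR: even split of num_files across hosts,
# with the last host absorbing the remainder so no files are dropped.
def host_file_range(process_index, process_count, num_files):
  split_size = num_files // process_count
  start = process_index * split_size
  end = start + split_size
  if process_index == process_count - 1:
    end = num_files  # take the `num_files % process_count` leftovers
  return start, end

# 10 files over 4 hosts -> (0, 2), (2, 4), (4, 6), (6, 10)
assert [host_file_range(i, 4, 10) for i in range(4)] == [
    (0, 2), (2, 4), (4, 6), (6, 10)]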
@@ -229,9 +232,58 @@ def load_split(per_host_batch_size, split, hps, shuffle_rng=None):
   else:  # split == 'val'
     data_dir = os.path.join(data_dir, hps.val_dir)
 
-  h5_paths = [
-      os.path.join(data_dir, path) for path in listdir(data_dir)
-  ][start:end]
+  try:
+    all_files = listdir(data_dir)
+  except tf.errors.NotFoundError as e:
+    raise FileNotFoundError(
+        f'FastMRI data directory not found: {data_dir}.'
+    ) from e
+
+  h5_paths = [os.path.join(data_dir, path) for path in all_files][start:end]
+
+  if not h5_paths:
+    raise ValueError(
+        f'No h5 files found for split={split} in {data_dir} '
+        f'(start={start}, end={end}, total files={len(all_files)}).'
+    )
+  logging.info(
+      'FastMRI %s split: loaded %d h5 paths from %s (files %d-%d of %d).',
+      split,
+      len(h5_paths),
+      data_dir,
+      start,
+      end,
+      len(all_files),
+  )
+
+  # Probe-read the first h5 file to catch ACL / connectivity errors early,
+  # before they are silently swallowed by tf.data.from_generator.
+  probe_path = h5_paths[0]
+  try:
+    with gfile.GFile(probe_path, 'rb') as gf:
+      with h5py.File(gf, 'r') as hf:
+        if 'kspace' not in hf:
+          raise ValueError(
+              f'FastMRI h5 file {probe_path} is missing the "kspace" dataset. '
+              'The file may be corrupt or incomplete.'
+          )
+        logging.info(
+            'FastMRI probe read of %s succeeded: %d slices.',
+            probe_path,
+            hf['kspace'].shape[0],
+        )
+  except PermissionError as e:
+    raise PermissionError(
+        f'Cannot read FastMRI h5 file {probe_path}: permission denied. '
+        'Check that the Borg job gfs_user has access to the FastMRI Placer '
+        'fileset.'
+    ) from e
+  except (OSError, IOError) as e:
+    raise OSError(
+        f'Cannot read FastMRI h5 file {probe_path}: {e}. '
+        'This may indicate a timeout, network issue, or that the Placer '
+        'fileset is not replicated to a nearby cell. See b/329522685.'
+    ) from e
 
   ds = tf.data.Dataset.from_tensor_slices(h5_paths)
   ds = ds.interleave(
       _create_generator,
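Aside (not part of the diff): the probe read exists because tf.data.Dataset.from_generator defers generator execution, so file-access failures surface only at iteration time, often wrapped in a tf.errors exception that obscures the original cause. A self-contained sketch of that behavior:

# Minimal sketch, not from the PR: dataset construction succeeds even
# though the generator will fail; the error only appears on iteration,
# which is why an eager probe read reports problems earlier and clearer.
import tensorflow as tf

def failing_generator():
  raise PermissionError('no access to example.h5')  # simulated ACL failure
  yield 0  # unreachable; makes this function a generator

ds = tf.data.Dataset.from_generator(
    failing_generator,
    output_signature=tf.TensorSpec(shape=(), dtype=tf.int32))

try:
  next(iter(ds))  # the failure surfaces here, not at construction
except Exception as e:
  print(type(e).__name__, ':', e)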