diff --git a/init2winit/dataset_lib/fastmri_dataset.py b/init2winit/dataset_lib/fastmri_dataset.py index bed2f6ef..c76bfa8f 100644 --- a/init2winit/dataset_lib/fastmri_dataset.py +++ b/init2winit/dataset_lib/fastmri_dataset.py @@ -19,6 +19,7 @@ import itertools import os +from absl import logging import h5py from init2winit.dataset_lib import data_utils import jax @@ -212,9 +213,11 @@ def load_split(per_host_batch_size, split, hps, shuffle_rng=None): # entirely to the end of it on the last host, because otherwise we will drop # the last `{train,valid}_size % split_size` elements. if jax.process_index() == jax.process_count() - 1: - if split == 'val': + if split in ['train', 'eval_train']: + end = hps.num_train_h5_files + elif split == 'val': end = hps.num_valid_h5_files - else: + else: # split == 'test' end = hps.num_test_h5_files + hps.num_valid_h5_files data_dir = hps.data_dir @@ -229,9 +232,58 @@ def load_split(per_host_batch_size, split, hps, shuffle_rng=None): else: # split == 'val' data_dir = os.path.join(data_dir, hps.val_dir) - h5_paths = [ - os.path.join(data_dir, path) for path in listdir(data_dir) - ][start:end] + try: + all_files = listdir(data_dir) + except tf.errors.NotFoundError as e: + raise FileNotFoundError( + f'FastMRI data directory not found: {data_dir}.' + ) from e + + h5_paths = [os.path.join(data_dir, path) for path in all_files][start:end] + + if not h5_paths: + raise ValueError( + f'No h5 files found for split={split} in {data_dir} ' + f'(start={start}, end={end}, total files={len(all_files)}).' + ) + logging.info( + 'FastMRI %s split: loaded %d h5 paths from %s (files %d-%d of %d).', + split, + len(h5_paths), + data_dir, + start, + end, + len(all_files), + ) + + # Probe-read the first h5 file to catch ACL / connectivity errors early, + # before they are silently swallowed by tf.data.from_generator. 
+ probe_path = h5_paths[0] + try: + with gfile.GFile(probe_path, 'rb') as gf: + with h5py.File(gf, 'r') as hf: + if 'kspace' not in hf: + raise ValueError( + f'FastMRI h5 file {probe_path} is missing the "kspace" dataset. ' + 'The file may be corrupt or incomplete.' + ) + logging.info( + 'FastMRI probe read of %s succeeded: %d slices.', + probe_path, + hf['kspace'].shape[0], + ) + except PermissionError as e: + raise PermissionError( + f'Cannot read FastMRI h5 file {probe_path}: permission denied. ' + 'Check that the Borg job gfs_user has access to the FastMRI Placer ' + 'fileset.' + ) from e + except (OSError, IOError) as e: + raise OSError( + f'Cannot read FastMRI h5 file {probe_path}: {e}. ' + 'This may indicate a timeout, network issue, or that the Placer ' + 'fileset is not replicated to a nearby cell. See b/329522685.' + ) from e + ds = tf.data.Dataset.from_tensor_slices(h5_paths) ds = ds.interleave( _create_generator,