From 48d881fb38939a1802ec286d7182ed37810da5d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gonzalo=20Pe=C3=B1a-Castellanos?= Date: Mon, 4 May 2026 18:30:33 -0500 Subject: [PATCH] Replace 'aws s3 ls' shell-out in dataset.py with boto3 --- .gitignore | 5 +- README.md | 15 ++ pyproject.toml | 1 + stable_audio_tools/data/dataset.py | 117 +-------- stable_audio_tools/data/s3_utils.py | 237 ++++++++++++++++++ tests/test_dataset_s3.py | 368 ++++++++++++++++++++++++++++ uv.lock | 52 ++++ 7 files changed, 688 insertions(+), 107 deletions(-) create mode 100644 stable_audio_tools/data/s3_utils.py create mode 100644 tests/test_dataset_s3.py diff --git a/.gitignore b/.gitignore index fd9f3fe6..4264d5ac 100644 --- a/.gitignore +++ b/.gitignore @@ -163,4 +163,7 @@ cython_debug/ *.wav wandb/* *.out -test_* \ No newline at end of file +test_* +!tests/test_*.py +# macOS +.DS_Store diff --git a/README.md b/README.md index 32576225..a16f743f 100644 --- a/README.md +++ b/README.md @@ -171,6 +171,21 @@ The following properties are defined in the top level of the model configuration ## Dataset config `stable-audio-tools` currently supports two kinds of data sources: local directories of audio files, and WebDataset datasets stored in Amazon S3. More information can be found in [the dataset config documentation](docs/datasets.md) +## S3-compatible storage +The S3 dataset loader uses `boto3`, which ships in the `train` extra. If you installed without that extra, add it with `pip install boto3` (or `pip install "stable-audio-tools[train]"`). + +The loader honors the `AWS_ENDPOINT_URL` environment variable, so you can point it at S3-compatible storage providers (for example Backblaze B2, MinIO, Cloudflare R2, or other compatible endpoints) without changing the dataset config. + +Example: +```bash +export AWS_ENDPOINT_URL= +export AWS_DEFAULT_REGION= +export AWS_ACCESS_KEY_ID= +export AWS_SECRET_ACCESS_KEY= +``` + +When `AWS_ENDPOINT_URL` is unset, the loader uses default AWS S3, so existing setups are unaffected. + # Todo - [ ] Add troubleshooting section - [ ] Add contribution guidelines diff --git a/pyproject.toml b/pyproject.toml index ec0e4821..0d8721a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ dependencies = [ [project.optional-dependencies] train = [ "auraloss==0.4.0", + "boto3>=1.26,<2", "descript-audio-codec==1.0.0", "encodec==0.1.1", "inf-cl", diff --git a/stable_audio_tools/data/dataset.py b/stable_audio_tools/data/dataset.py index 0baf1244..2f057d66 100644 --- a/stable_audio_tools/data/dataset.py +++ b/stable_audio_tools/data/dataset.py @@ -1,24 +1,21 @@ import importlib -import numpy as np import io import json import os -import dill -import posixpath import random -import re -import subprocess import time +from os import path +from typing import Callable, List, Optional + +import dill +import numpy as np import torch import torchaudio import webdataset as wds - -from os import path from torch import nn from torchaudio import transforms as T -from typing import Optional, Callable, List -from .utils import Stereo, Mono, PhaseFlipper, PadCrop_Normalized_T, VolumeNorm, strip_trailing_silence +from .utils import Mono, PadCrop_Normalized_T, PhaseFlipper, Stereo, VolumeNorm, strip_trailing_silence AUDIO_KEYS = ("flac", "wav", "mp3", "m4a", "ogg", "opus") @@ -481,105 +478,13 @@ def __getitem__(self, idx): print(f'Couldn\'t load file {latent_filename}: {e}') return self[random.randrange(len(self))] -# S3 code and WDS preprocessing code based on implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py - -def get_s3_contents(dataset_path, s3_url_prefix=None, filter='', recursive=True, debug=False, profile=None): - """ - Returns a list of full S3 paths to files in a given S3 bucket and directory path. - """ - # Ensure dataset_path ends with a trailing slash - if dataset_path != '' and not dataset_path.endswith('/'): - dataset_path += '/' - # Use posixpath to construct the S3 URL path - bucket_path = posixpath.join(s3_url_prefix or '', dataset_path) - # Construct the `aws s3 ls` command - cmd = ['aws', 's3', 'ls', bucket_path] - - if profile is not None: - cmd.extend(['--profile', profile]) - - if recursive: - # Add the --recursive flag if requested - cmd.append('--recursive') - - # Run the `aws s3 ls` command and capture the output - run_ls = subprocess.run(cmd, capture_output=True, check=True) - # Split the output into lines and strip whitespace from each line - contents = run_ls.stdout.decode('utf-8').split('\n') - contents = [x.strip() for x in contents if x] - # Remove the timestamp from lines that begin with a timestamp - contents = [re.sub(r'^\S+\s+\S+\s+\d+\s+', '', x) - if re.match(r'^\S+\s+\S+\s+\d+\s+', x) else x for x in contents] - # Construct a full S3 path for each file in the contents list - contents = [posixpath.join(s3_url_prefix or '', x) - for x in contents if not x.endswith('/')] - # Apply the filter, if specified - if filter: - contents = [x for x in contents if filter in x] - # Remove redundant directory names in the S3 URL - if recursive: - # Get the main directory name from the S3 URL - main_dir = "/".join(bucket_path.split('/')[3:]) - # Remove the redundant directory names from each file path - contents = [x.replace(f'{main_dir}', '').replace( - '//', '/') for x in contents] - # Print debugging information, if requested - if debug: - print("contents = \n", contents) - # Return the list of S3 paths to files - return contents - - -def get_all_s3_urls( - names=[], # list of all valid [LAION AudioDataset] dataset names - # list of subsets you want from those datasets, e.g. ['train','valid'] - subsets=[''], - s3_url_prefix=None, # prefix for those dataset names - recursive=True, # recursively list all tar files in all subdirs - filter_str='tar', # only grab files with this substring - # print debugging info -- note: info displayed likely to change at dev's whims - debug=False, - profiles={}, # dictionary of profiles for each item in names, e.g. {'dataset1': 'profile1', 'dataset2': 'profile2'} -): - "get urls of shards (tar files) for multiple datasets in one s3 bucket" - urls = [] - for name in names: - # If s3_url_prefix is not specified, assume the full S3 path is included in each element of the names list - if s3_url_prefix is None: - contents_str = name - else: - # Construct the S3 path using the s3_url_prefix and the current name value - contents_str = posixpath.join(s3_url_prefix, name) - if debug: - print(f"get_all_s3_urls: {contents_str}:") - for subset in subsets: - subset_str = posixpath.join(contents_str, subset) - if debug: - print(f"subset_str = {subset_str}") - # Get the list of tar files in the current subset directory - profile = profiles.get(name, None) - tar_list = get_s3_contents( - subset_str, s3_url_prefix=None, recursive=recursive, filter=filter_str, debug=debug, profile=profile) - for tar in tar_list: - # Escape spaces and parentheses in the tar filename for use in the shell command - tar = tar.replace(" ", "\ ").replace( - "(", "\(").replace(")", "\)") - # Construct the S3 path to the current tar file - s3_path = posixpath.join(name, subset, tar) + " -" - # Construct the AWS CLI command to download the current tar file - if s3_url_prefix is None: - request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {s3_path}" - else: - request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {posixpath.join(s3_url_prefix, s3_path)}" - if profiles.get(name): - request_str += f" --profile {profiles.get(name)}" - if debug: - print("request_str = ", request_str) - # Add the constructed URL to the list of URLs - urls.append(request_str) - return urls +# S3 helpers live in the import-light s3_utils module (no torch) so it can also +# run as a `pipe:` subprocess that streams individual shards. The previously +# public functions are re-exported here for backwards compatibility. +from .s3_utils import get_all_s3_urls, get_s3_contents # noqa: E402,F401 +# WDS preprocessing code based on implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py def log_and_continue(exn): """Call in an exception handler to ignore any exception, isssue a warning, and continue.""" print(f"Handling webdataset error ({repr(exn)}). Ignoring.") diff --git a/stable_audio_tools/data/s3_utils.py b/stable_audio_tools/data/s3_utils.py new file mode 100644 index 00000000..b78d7874 --- /dev/null +++ b/stable_audio_tools/data/s3_utils.py @@ -0,0 +1,237 @@ +"""S3 / S3-compatible (e.g. Backblaze B2) helpers for the WebDataset loaders. + +Deliberately free of heavy imports (torch, webdataset). This module doubles as a +``pipe:`` subprocess entry point (``python -m stable_audio_tools.data.s3_utils +``) that streams a single shard via boto3, so spawning it per shard open +stays cheap and shard downloads: + +* authenticate on every open (no presigned-URL expiry mid-run), and +* never place credentials on the command line (only the ``s3://`` path). + +S3 code originally based on the implementation by Scott Hawley in +https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py +""" +import os +import posixpath +import shlex +import sys +from functools import lru_cache +from importlib.metadata import PackageNotFoundError, version + + +@lru_cache(maxsize=1) +def _user_agent(): + "Base ``stable-audio-tools/`` token, looked up once on first use." + try: + ver = version("stable-audio-tools") + except PackageNotFoundError: # source/editable checkout without dist metadata + ver = "dev" + return f"stable-audio-tools/{ver}" + + +def _build_user_agent_extra(user_agent_extra=None): + """``stable-audio-tools/`` with any caller- or env-provided + (``STABLE_AUDIO_TOOLS_USER_AGENT_EXTRA``) value appended, not replacing it. + Pass ``user_agent_extra=""`` to suppress the env value and use the base only.""" + base = _user_agent() + if user_agent_extra is None: + user_agent_extra = os.environ.get("STABLE_AUDIO_TOOLS_USER_AGENT_EXTRA") + return f"{base} {user_agent_extra}" if user_agent_extra else base + + +@lru_cache(maxsize=32) +def _build_s3_client(profile, endpoint_url, user_agent_extra): + try: + import boto3 # local import so boto3 is only required when S3 is used + from botocore.config import Config + except ModuleNotFoundError as e: + raise ImportError( + "S3 dataset access requires boto3. Install it with " + "'pip install boto3' or 'pip install stable-audio-tools[train]'." + ) from e + + session = boto3.Session(profile_name=profile) if profile else boto3.Session() + return session.client( + "s3", + endpoint_url=endpoint_url, + # Backblaze B2's S3 API only accepts SigV4; force it. retries restore the + # transient-failure resilience the previous `aws s3` CLI path provided. + config=Config( + signature_version="s3v4", + user_agent_extra=user_agent_extra, + retries={"max_attempts": 5, "mode": "standard"}, + ), + ) + + +def _get_s3_client(profile=None, user_agent_extra=None): + """ + Build (and reuse) a boto3 S3 client. Honors AWS_ENDPOINT_URL when set so the + same code path works against any S3-compatible endpoint (AWS S3 by default; + set AWS_ENDPOINT_URL to a Backblaze B2 endpoint to point it at B2). When the + env var is unset, behavior matches the default AWS client. + + Clients are cached per (profile, endpoint, user-agent) so listing and + streaming share one client instead of building a new one on each call. + """ + endpoint_url = os.environ.get("AWS_ENDPOINT_URL") or None + return _build_s3_client( + profile, endpoint_url, _build_user_agent_extra(user_agent_extra) + ) + + +def _parse_s3_url(url): + "Split an ``s3://bucket/key`` URL into (bucket, key). Raises ValueError otherwise." + if not url.startswith("s3://"): + raise ValueError(f"expected an s3:// URL, got: {url!r}") + bucket, _, key = url[len("s3://"):].partition("/") + if not bucket.strip(): + raise ValueError(f"s3:// URL is missing a bucket name: {url!r}") + return bucket, key + + +def get_s3_contents( + dataset_path, + s3_url_prefix=None, + filter_str='', # only keep keys containing this substring + recursive=True, + debug=False, + profile=None, +): + """ + Returns a list of S3 object keys relative to ``dataset_path``, matching the + output shape of the previous ``aws s3 ls`` based implementation. Uses boto3 + directly so it works against any S3-compatible endpoint when + ``AWS_ENDPOINT_URL`` is set. + """ + # Ensure dataset_path ends with a trailing slash + if dataset_path != '' and not dataset_path.endswith('/'): + dataset_path += '/' + # Use posixpath to construct the S3 URL path (e.g. "s3://bucket/prefix/") + bucket_path = posixpath.join(s3_url_prefix or '', dataset_path) + + bucket, prefix = _parse_s3_url(bucket_path) + + s3 = _get_s3_client(profile=profile) + paginator = s3.get_paginator("list_objects_v2") + list_kwargs = {"Bucket": bucket, "Prefix": prefix} + if not recursive: + list_kwargs["Delimiter"] = "/" + + keys = [] + for page in paginator.paginate(**list_kwargs): + for obj in page.get("Contents", []) or []: + key = obj.get("Key", "") + if not key or key.endswith("/"): + continue + keys.append(key) + + # Apply the filter, if specified + if filter_str: + keys = [k for k in keys if filter_str in k] + + # Strip the listed prefix so keys are relative to dataset_path (matches the + # previous implementation's output shape). + if prefix: + keys = [k[len(prefix):] if k.startswith(prefix) else k for k in keys] + keys = [k.lstrip('/') for k in keys] + + if debug: + print("contents = \n", keys) + + return keys + + +def shard_pipe_command(url, profile=None): + """Build the WebDataset ``pipe:`` command that streams ``url`` via boto3. + + Only the ``s3://`` path (never credentials) lands on the command line, and the + object is fetched with fresh credentials on each open, so shards never expire. + """ + _parse_s3_url(url) # validate early + cmd = f"{shlex.quote(sys.executable)} -m stable_audio_tools.data.s3_utils {shlex.quote(url)}" + if profile: + cmd += f" --profile {shlex.quote(profile)}" + return f"pipe:{cmd}" + + +def get_all_s3_urls( + names=None, # list of [LAION AudioDataset] dataset names; None -> [] + subsets=None, # list of subsets, e.g. ['train','valid']; None -> [''] + s3_url_prefix=None, # prefix for those dataset names + recursive=True, # recursively list all tar files in all subdirs + filter_str='tar', # only grab files with this substring + debug=False, # print debugging info + profiles=None, # dict of profiles per name, e.g. {'dataset1': 'profile1'}; None -> {} +): + """Get WebDataset ``pipe:`` commands that stream shards (tar files) for + multiple datasets in one S3 bucket. + + Each command streams its shard with boto3 at open time (see + ``shard_pipe_command`` / ``stream_object``), so shards re-authenticate on + every open, never expire mid-run, and never expose credentials on the + command line. + """ + names = [] if names is None else names + subsets = [''] if subsets is None else subsets + profiles = profiles or {} + urls = [] + for name in names: + # If s3_url_prefix is not specified, assume the full S3 path is included in each element of the names list + if s3_url_prefix is None: + contents_str = name + else: + contents_str = posixpath.join(s3_url_prefix, name) + if debug: + print(f"get_all_s3_urls: {contents_str}:") + for subset in subsets: + subset_str = posixpath.join(contents_str, subset) + if debug: + print(f"subset_str = {subset_str}") + profile = profiles.get(name, None) + tar_list = get_s3_contents( + subset_str, s3_url_prefix=None, recursive=recursive, + filter_str=filter_str, debug=debug, profile=profile) + for tar in tar_list: + if s3_url_prefix is None: + full_s3_url = posixpath.join(name, subset, tar) + else: + full_s3_url = posixpath.join(s3_url_prefix, name, subset, tar) + request_str = shard_pipe_command(full_s3_url, profile=profile) + if debug: + print("request_str =", request_str) + urls.append(request_str) + return urls + + +def stream_object(url, profile=None, out=None): + """Stream the bytes of an S3 object to ``out`` (default: stdout's binary buffer).""" + bucket, key = _parse_s3_url(url) + client = _get_s3_client(profile=profile) + body = client.get_object(Bucket=bucket, Key=key)["Body"] + # Binary destination required; sys.stdout.buffer is the binary handle of a + # real process stdout (the pipe: use case). Tests pass an explicit `out`. + out = out if out is not None else sys.stdout.buffer + try: + for chunk in body.iter_chunks(chunk_size=1024 * 1024): + out.write(chunk) + finally: + body.close() + + +def main(argv=None): + import argparse + + parser = argparse.ArgumentParser( + prog="stable_audio_tools.data.s3_utils", + description="Stream an S3 object to stdout (used as a WebDataset pipe: source).", + ) + parser.add_argument("url", help="s3://bucket/key to stream") + parser.add_argument("--profile", default=None, help="optional AWS/B2 profile name") + args = parser.parse_args(argv) + stream_object(args.url, profile=args.profile) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/test_dataset_s3.py b/tests/test_dataset_s3.py new file mode 100644 index 00000000..576d9d3f --- /dev/null +++ b/tests/test_dataset_s3.py @@ -0,0 +1,368 @@ +import importlib.abc +import io +import os +import shlex +import sys +import types +from unittest import mock + +import pytest + +from stable_audio_tools.data import s3_utils as s3 + + +@pytest.fixture(autouse=True) +def _clear_s3_client_cache(): + "_build_s3_client caches clients via lru_cache; clear it so tests don't share state." + s3._build_s3_client.cache_clear() + yield + s3._build_s3_client.cache_clear() + + +class _FakeConfig: + "Minimal stand-in for botocore.config.Config that records its kwargs." + + def __init__(self, **kwargs): + self.user_agent_extra = kwargs.get("user_agent_extra") + self.kwargs = kwargs + + +def _patch_boto3(fake_boto3): + "Patch boto3 + botocore.config so _get_s3_client runs without them installed." + botocore = types.ModuleType("botocore") + botocore.__path__ = [] # mark as a package so submodule import resolves + botocore_config = types.ModuleType("botocore.config") + botocore_config.Config = _FakeConfig + botocore.config = botocore_config + return mock.patch.dict( + "sys.modules", + { + "boto3": fake_boto3, + "botocore": botocore, + "botocore.config": botocore_config, + }, + ) + + +def _fake_session(): + fake_boto3 = mock.MagicMock() + fake_session = mock.MagicMock() + fake_boto3.Session.return_value = fake_session + return fake_boto3, fake_session + + +def _fake_paginator(pages): + "Paginator-like mock; records last paginate(**kwargs) on .last_kwargs." + pag = mock.MagicMock() + pag.last_kwargs = {} + + def paginate(**kwargs): + pag.last_kwargs = kwargs + return iter(pages) + + pag.paginate.side_effect = paginate + return pag + + +def _fake_client(pages=None): + client = mock.MagicMock() + client.get_paginator.return_value = _fake_paginator(pages or []) + return client + + +# ---- client construction ------------------------------------------------- + +def test_get_s3_client_uses_aws_endpoint_url_env(): + fake_boto3, fake_session = _fake_session() + + with mock.patch.dict(os.environ, {"AWS_ENDPOINT_URL": "https://s3.us-west-004.backblazeb2.com"}, clear=False): + with _patch_boto3(fake_boto3): + s3._get_s3_client() + + fake_session.client.assert_called_once() + args, kwargs = fake_session.client.call_args + assert args == ("s3",) + assert kwargs["endpoint_url"] == "https://s3.us-west-004.backblazeb2.com" + assert kwargs["config"].user_agent_extra.startswith("stable-audio-tools/") + + +def test_get_s3_client_default_when_env_unset(): + fake_boto3, fake_session = _fake_session() + + env = {k: v for k, v in os.environ.items() if k != "AWS_ENDPOINT_URL"} + with mock.patch.dict(os.environ, env, clear=True): + with _patch_boto3(fake_boto3): + s3._get_s3_client() + + fake_session.client.assert_called_once() + args, kwargs = fake_session.client.call_args + assert args == ("s3",) + assert kwargs["endpoint_url"] is None + assert kwargs["config"].user_agent_extra.startswith("stable-audio-tools/") + + +def test_get_s3_client_uses_profile_when_given(): + fake_boto3, fake_session = _fake_session() + + with _patch_boto3(fake_boto3): + s3._get_s3_client(profile="myprofile") + + fake_boto3.Session.assert_called_once_with(profile_name="myprofile") + + +def test_s3_client_configured_for_sigv4(): + fake_boto3, fake_session = _fake_session() + + with _patch_boto3(fake_boto3): + s3._get_s3_client() + + _, kwargs = fake_session.client.call_args + # B2 only supports SigV4; the client must be configured for it. + assert kwargs["config"].kwargs.get("signature_version") == "s3v4" + + +def test_get_s3_client_presigns_with_sigv4(): + # Real boto3: the client our code builds must sign with SigV4 (B2 rejects v2). + pytest.importorskip("boto3") + with mock.patch.dict(os.environ, { + "AWS_ENDPOINT_URL": "https://s3.us-west-004.backblazeb2.com", + "AWS_ACCESS_KEY_ID": "test-key-id", + "AWS_SECRET_ACCESS_KEY": "test-secret", + "AWS_DEFAULT_REGION": "us-west-004", + }, clear=False): + client = s3._get_s3_client() + url = client.generate_presigned_url( + "get_object", Params={"Bucket": "b", "Key": "k"}, ExpiresIn=3600) + + assert "X-Amz-Algorithm=AWS4-HMAC-SHA256" in url + assert "AWSAccessKeyId=" not in url # SigV2 marker must be absent + + +def test_get_s3_client_missing_boto3_raises_actionable_error(monkeypatch): + class _Blocker(importlib.abc.MetaPathFinder): + def find_spec(self, name, path, target=None): + if name == "boto3" or name.startswith("botocore"): + raise ModuleNotFoundError(f"No module named {name!r}") + return None + + for mod in ("boto3", "botocore", "botocore.config"): + monkeypatch.delitem(sys.modules, mod, raising=False) + monkeypatch.setattr(sys, "meta_path", [_Blocker(), *sys.meta_path]) + + with pytest.raises(ImportError, match="boto3"): + s3._get_s3_client() + + +# ---- user agent ---------------------------------------------------------- + +def test_user_agent_is_versioned_product_form(): + fake_boto3, fake_session = _fake_session() + + with mock.patch.dict(os.environ, {}, clear=True): + with _patch_boto3(fake_boto3): + s3._get_s3_client() + + _, kwargs = fake_session.client.call_args + ua = kwargs["config"].user_agent_extra + assert ua == s3._user_agent() + assert ua.startswith("stable-audio-tools/") + + +def test_user_agent_appends_extra_argument(): + fake_boto3, fake_session = _fake_session() + + with mock.patch.dict(os.environ, {}, clear=True): + with _patch_boto3(fake_boto3): + s3._get_s3_client(user_agent_extra="myapp/1.0") + + _, kwargs = fake_session.client.call_args + assert kwargs["config"].user_agent_extra == f"{s3._user_agent()} myapp/1.0" + + +def test_user_agent_empty_extra_suppresses_env(): + fake_boto3, fake_session = _fake_session() + + with mock.patch.dict(os.environ, {"STABLE_AUDIO_TOOLS_USER_AGENT_EXTRA": "fromenv/2"}, clear=True): + with _patch_boto3(fake_boto3): + s3._get_s3_client(user_agent_extra="") + + _, kwargs = fake_session.client.call_args + assert kwargs["config"].user_agent_extra == s3._user_agent() + + +def test_user_agent_appends_extra_from_env(): + fake_boto3, fake_session = _fake_session() + + with mock.patch.dict(os.environ, {"STABLE_AUDIO_TOOLS_USER_AGENT_EXTRA": "fromenv/2"}, clear=True): + with _patch_boto3(fake_boto3): + s3._get_s3_client() + + _, kwargs = fake_session.client.call_args + assert kwargs["config"].user_agent_extra == f"{s3._user_agent()} fromenv/2" + + +# ---- _parse_s3_url ------------------------------------------------------- + +def test_parse_s3_url(): + assert s3._parse_s3_url("s3://bucket/a/b.tar") == ("bucket", "a/b.tar") + assert s3._parse_s3_url("s3://bucket") == ("bucket", "") + with pytest.raises(ValueError): + s3._parse_s3_url("https://not-s3/x") + with pytest.raises(ValueError): + s3._parse_s3_url("s3://") # empty bucket + + +# ---- get_s3_contents ----------------------------------------------------- + +def test_get_s3_contents_returns_keys_relative_to_dataset_path(): + # Matches the legacy aws s3 ls output shape: keys relative to dataset_path. + pages = [ + {"Contents": [ + {"Key": "prefix/a.tar"}, + {"Key": "prefix/sub/b.tar"}, + {"Key": "prefix/"}, # directory marker -> skipped + ]}, + ] + client = _fake_client(pages=pages) + + with mock.patch.object(s3, "_get_s3_client", return_value=client): + keys = s3.get_s3_contents("s3://bucket/prefix/", recursive=True) + + client.get_paginator.assert_called_once_with("list_objects_v2") + pag = client.get_paginator.return_value + assert pag.last_kwargs == {"Bucket": "bucket", "Prefix": "prefix/"} + assert keys == ["a.tar", "sub/b.tar"] + + +def test_get_s3_contents_non_recursive_adds_delimiter(): + client = _fake_client(pages=[{"Contents": []}]) + + with mock.patch.object(s3, "_get_s3_client", return_value=client): + s3.get_s3_contents("s3://bucket/prefix/", recursive=False) + + pag = client.get_paginator.return_value + assert pag.last_kwargs == {"Bucket": "bucket", "Prefix": "prefix/", "Delimiter": "/"} + + +def test_get_s3_contents_non_recursive_strips_prefix_from_keys(): + pages = [{"Contents": [{"Key": "prefix/a.tar"}, {"Key": "prefix/b.tar"}]}] + client = _fake_client(pages=pages) + + with mock.patch.object(s3, "_get_s3_client", return_value=client): + keys = s3.get_s3_contents("s3://bucket/prefix/", recursive=False) + + assert keys == ["a.tar", "b.tar"] + + +def test_get_s3_contents_applies_filter(): + pages = [{"Contents": [ + {"Key": "prefix/a.tar"}, + {"Key": "prefix/b.txt"}, + {"Key": "prefix/c.tar"}, + ]}] + client = _fake_client(pages=pages) + + with mock.patch.object(s3, "_get_s3_client", return_value=client): + keys = s3.get_s3_contents("s3://bucket/prefix/", filter_str="tar", recursive=True) + + assert keys == ["a.tar", "c.tar"] + + +def test_get_s3_contents_rejects_non_s3_url(): + with pytest.raises(ValueError): + s3.get_s3_contents("not-an-s3-url/") + + +# ---- shard_pipe_command + streaming ------------------------------------- + +def test_shard_pipe_command_builds_streaming_pipe(): + cmd = s3.shard_pipe_command("s3://bucket/key.tar") + assert cmd.startswith(f"pipe:{shlex.quote(sys.executable)} -m stable_audio_tools.data.s3_utils ") + assert "s3://bucket/key.tar" in cmd + assert "--profile" not in cmd + # No credentials or signatures are ever placed on the command line. + assert "X-Amz-" not in cmd and "Signature" not in cmd + + +def test_shard_pipe_command_includes_profile(): + cmd = s3.shard_pipe_command("s3://bucket/key.tar", profile="myprof") + assert f"--profile {shlex.quote('myprof')}" in cmd + + +def test_shard_pipe_command_rejects_non_s3(): + with pytest.raises(ValueError): + s3.shard_pipe_command("https://nope/x") + + +def test_get_all_s3_urls_emits_streaming_pipe_command(): + pages = [{"Contents": [{"Key": "name/train/shard-000.tar"}]}] + client = _fake_client(pages=pages) + + with mock.patch.object(s3, "_get_s3_client", return_value=client): + urls = s3.get_all_s3_urls( + names=["name"], + subsets=["train"], + s3_url_prefix="s3://bucket", + recursive=True, + filter_str="tar", + ) + + assert urls == [s3.shard_pipe_command("s3://bucket/name/train/shard-000.tar")] + assert urls[0].startswith("pipe:") + assert "s3://bucket/name/train/shard-000.tar" in urls[0] + + +def test_get_all_s3_urls_with_full_url_names_and_no_prefix(): + # s3_url_prefix=None: each name must already be a full s3:// URL. + pages = [{"Contents": [{"Key": "prefix/name/train/shard-000.tar"}]}] + client = _fake_client(pages=pages) + + with mock.patch.object(s3, "_get_s3_client", return_value=client): + urls = s3.get_all_s3_urls( + names=["s3://bucket/prefix/name"], + subsets=["train"], + s3_url_prefix=None, + recursive=True, + filter_str="tar", + ) + + assert urls == [s3.shard_pipe_command("s3://bucket/prefix/name/train/shard-000.tar")] + + +def test_get_all_s3_urls_passes_profile_into_command(): + pages = [{"Contents": [{"Key": "name/train/shard.tar"}]}] + client = _fake_client(pages=pages) + + with mock.patch.object(s3, "_get_s3_client", return_value=client): + urls = s3.get_all_s3_urls( + names=["name"], subsets=["train"], s3_url_prefix="s3://bucket", + profiles={"name": "myprof"}, + ) + + assert f"--profile {shlex.quote('myprof')}" in urls[0] + + +def test_get_all_s3_urls_defaults_no_names_returns_empty(): + assert s3.get_all_s3_urls() == [] + + +def test_stream_object_writes_object_bytes(): + body = mock.MagicMock() + body.iter_chunks.return_value = [b"abc", b"def"] + client = mock.MagicMock() + client.get_object.return_value = {"Body": body} + out = io.BytesIO() + + with mock.patch.object(s3, "_get_s3_client", return_value=client): + s3.stream_object("s3://bucket/key", out=out) + + assert out.getvalue() == b"abcdef" + client.get_object.assert_called_once_with(Bucket="bucket", Key="key") + + +def test_main_invokes_stream_object_with_profile(): + with mock.patch.object(s3, "stream_object") as m: + rc = s3.main(["s3://bucket/key", "--profile", "myprof"]) + + assert rc == 0 + m.assert_called_once_with("s3://bucket/key", profile="myprof") diff --git a/uv.lock b/uv.lock index 5769707c..aa6c77ec 100644 --- a/uv.lock +++ b/uv.lock @@ -200,6 +200,34 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a6/ab/8df927d3f0951cf67ca5973d89b35bcbda1777a4c78bf90a853d02d91285/auraloss-0.4.0-py3-none-any.whl", hash = "sha256:7ca1cfff0d04db9c1269038a1c527fc38bc4756dd33bfff115889a3461d97d37", size = 16743, upload-time = "2023-04-21T09:21:44.905Z" }, ] +[[package]] +name = "boto3" +version = "1.43.19" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "botocore" }, + { name = "jmespath" }, + { name = "s3transfer" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/97/da/229987ebb70daf5928f959aa8f4dd77dfcf425e6b0e7ff03aaef61ccc333/boto3-1.43.19.tar.gz", hash = "sha256:8b84704719dd3960ac12a8f37d9ff5adb853715baa9742f84fdbe2de0305c4cb", size = 113225, upload-time = "2026-06-01T19:33:06.514Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/be/cc/77097be39d83068f864767b710cee0d8f9cd61331de816dd2675a596c328/boto3-1.43.19-py3-none-any.whl", hash = "sha256:ec6825193b75fbb6bfbf12181e4960d00ad2f404343586765394ce620e63783c", size = 140535, upload-time = "2026-06-01T19:33:03.758Z" }, +] + +[[package]] +name = "botocore" +version = "1.43.19" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jmespath" }, + { name = "python-dateutil" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e2/a7/298986789785b74a954e2347114993be7e6b070417159125a6865f2687b6/botocore-1.43.19.tar.gz", hash = "sha256:18ac2fdd76c89b940707eb10493ff58678adad337d03215caec2d408ccd43cc0", size = 15435441, upload-time = "2026-06-01T19:32:53.126Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/99/75/fe4d45bdd08afd66f3d5273db58f3d8a29365e52ce3a0851f7f5e5900943/botocore-1.43.19-py3-none-any.whl", hash = "sha256:99dbdccbf748974750601e805cecc9362a85d11fee89d6d58cd3f4ff302e6ff9", size = 15117709, upload-time = "2026-06-01T19:32:47.871Z" }, +] + [[package]] name = "braceexpand" version = "0.1.7" @@ -952,6 +980,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899, upload-time = "2025-03-05T20:05:00.369Z" }, ] +[[package]] +name = "jmespath" +version = "1.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d3/59/322338183ecda247fb5d1763a6cbe46eff7222eaeebafd9fa65d4bf5cb11/jmespath-1.1.0.tar.gz", hash = "sha256:472c87d80f36026ae83c6ddd0f1d05d4e510134ed462851fd5f754c8c3cbb88d", size = 27377, upload-time = "2026-01-22T16:35:26.279Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/14/2f/967ba146e6d58cf6a652da73885f52fc68001525b4197effc174321d70b4/jmespath-1.1.0-py3-none-any.whl", hash = "sha256:a5663118de4908c91729bea0acadca56526eb2698e83de10cd116ae0f4e97c64", size = 20419, upload-time = "2026-01-22T16:35:24.919Z" }, +] + [[package]] name = "joblib" version = "1.5.3" @@ -1906,6 +1943,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/14/25/b208c5683343959b670dc001595f2f3737e051da617f66c31f7c4fa93abc/rich-14.3.3-py3-none-any.whl", hash = "sha256:793431c1f8619afa7d3b52b2cdec859562b950ea0d4b6b505397612db8d5362d", size = 310458, upload-time = "2026-02-19T17:23:13.732Z" }, ] +[[package]] +name = "s3transfer" +version = "0.18.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "botocore" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e0/1f/12417f7f493fc45e1f9fd5d4a9b6c125cf8d2cf3f8ddbdfab3e76406e9d6/s3transfer-0.18.0.tar.gz", hash = "sha256:3760b8b7ec1315da54048b2d626276732bee4300d054d492d4e1d43e20d4ecbd", size = 160560, upload-time = "2026-05-28T19:39:09.124Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2b/58/a58fc997655386daa2e25784e30c288aa3e3819e401f77029ee4899fb55a/s3transfer-0.18.0-py3-none-any.whl", hash = "sha256:239c13b09e65ad0346e1be7348b8a202dcad44ac7ea7c6eb858fc881dce739b6", size = 88572, upload-time = "2026-05-28T19:39:07.999Z" }, +] + [[package]] name = "safehttpx" version = "0.1.7" @@ -2150,6 +2199,7 @@ dependencies = [ [package.optional-dependencies] all = [ { name = "auraloss" }, + { name = "boto3" }, { name = "descript-audio-codec" }, { name = "encodec" }, { name = "gradio" }, @@ -2164,6 +2214,7 @@ all = [ ] train = [ { name = "auraloss" }, + { name = "boto3" }, { name = "descript-audio-codec" }, { name = "encodec" }, { name = "inf-cl" }, @@ -2183,6 +2234,7 @@ ui = [ requires-dist = [ { name = "alias-free-torch", specifier = "==0.0.6" }, { name = "auraloss", marker = "extra == 'train'", specifier = "==0.4.0" }, + { name = "boto3", marker = "extra == 'train'", specifier = ">=1.26,<2" }, { name = "descript-audio-codec", marker = "extra == 'train'", specifier = "==1.0.0" }, { name = "dill" }, { name = "einops" },