Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -163,4 +163,7 @@ cython_debug/
*.wav
wandb/*
*.out
test_*
test_*
!tests/test_*.py
# macOS
.DS_Store
15 changes: 15 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,21 @@ The following properties are defined in the top level of the model configuration
## Dataset config
`stable-audio-tools` currently supports two kinds of data sources: local directories of audio files, and WebDataset datasets stored in Amazon S3. More information can be found in [the dataset config documentation](docs/datasets.md).

## S3-compatible storage
The S3 dataset loader uses `boto3`, which ships in the `train` extra. If you installed without that extra, add it with `pip install boto3` (or `pip install "stable-audio-tools[train]"`).

The loader honors the `AWS_ENDPOINT_URL` environment variable, so you can point it at any S3-compatible storage provider (for example Backblaze B2, MinIO, or Cloudflare R2) without changing the dataset config.

Example:
```bash
export AWS_ENDPOINT_URL=<s3-compatible-endpoint>
export AWS_DEFAULT_REGION=<endpoint-region>
export AWS_ACCESS_KEY_ID=<access-key-id>
export AWS_SECRET_ACCESS_KEY=<secret-access-key>
```

When `AWS_ENDPOINT_URL` is unset, the loader uses default AWS S3, so existing setups are unaffected.

# Todo
- [ ] Add troubleshooting section
- [ ] Add contribution guidelines
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ dependencies = [
[project.optional-dependencies]
train = [
"auraloss==0.4.0",
"boto3>=1.26,<2",
"descript-audio-codec==1.0.0",
"encodec==0.1.1",
"inf-cl",
Expand Down
117 changes: 11 additions & 106 deletions stable_audio_tools/data/dataset.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,21 @@
import importlib
import numpy as np
import io
import json
import os
import dill
import posixpath
import random
import re
import subprocess
import time
from os import path
from typing import Callable, List, Optional

import dill
import numpy as np
import torch
import torchaudio
import webdataset as wds

from os import path
from torch import nn
from torchaudio import transforms as T
from typing import Optional, Callable, List

from .utils import Stereo, Mono, PhaseFlipper, PadCrop_Normalized_T, VolumeNorm, strip_trailing_silence
from .utils import Mono, PadCrop_Normalized_T, PhaseFlipper, Stereo, VolumeNorm, strip_trailing_silence

AUDIO_KEYS = ("flac", "wav", "mp3", "m4a", "ogg", "opus")

Expand Down Expand Up @@ -481,105 +478,13 @@ def __getitem__(self, idx):
print(f'Couldn\'t load file {latent_filename}: {e}')
return self[random.randrange(len(self))]

# S3 code and WDS preprocessing code based on implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py

def get_s3_contents(dataset_path, s3_url_prefix=None, filter='', recursive=True, debug=False, profile=None):
    """
    Returns a list of S3 paths to files in a given S3 bucket and directory path.

    The listing is obtained by shelling out to ``aws s3 ls``.

    Args:
        dataset_path: Bucket path (or path below ``s3_url_prefix``) to list.
            A trailing slash is appended when missing.
        s3_url_prefix: Optional prefix joined in front of ``dataset_path`` and
            in front of every returned entry.
        filter: Substring an entry must contain to be kept; an empty string
            keeps everything. (The name shadows the builtin but is kept for
            backwards compatibility with keyword callers.)
        recursive: Pass ``--recursive`` to ``aws s3 ls`` and strip the listed
            directory from the front of each returned entry.
        debug: Print the resulting list before returning it.
        profile: AWS CLI profile name passed via ``--profile`` when not None.

    Returns:
        A list of path strings, one per listed file.

    Raises:
        subprocess.CalledProcessError: If ``aws s3 ls`` exits non-zero.
    """
    # Ensure dataset_path ends with a trailing slash
    if dataset_path != '' and not dataset_path.endswith('/'):
        dataset_path += '/'
    # Use posixpath to construct the S3 URL path
    bucket_path = posixpath.join(s3_url_prefix or '', dataset_path)

    # Construct the `aws s3 ls` command
    cmd = ['aws', 's3', 'ls', bucket_path]
    if profile is not None:
        cmd.extend(['--profile', profile])
    if recursive:
        cmd.append('--recursive')

    # Run the `aws s3 ls` command and capture the output
    run_ls = subprocess.run(cmd, capture_output=True, check=True)

    # Strip each line first, then drop blanks, so whitespace-only lines do not
    # survive as empty entries.
    lines = (raw.strip() for raw in run_ls.stdout.decode('utf-8').split('\n'))
    contents = [line for line in lines if line]

    # Remove the leading "<date> <time> <size> " columns. The pattern is
    # anchored, so lines without those columns pass through unchanged.
    contents = [re.sub(r'^\S+\s+\S+\s+\d+\s+', '', x) for x in contents]

    # Prepend the prefix to each file entry, skipping directory entries
    contents = [posixpath.join(s3_url_prefix or '', x)
                for x in contents if not x.endswith('/')]

    # Apply the filter, if specified
    if filter:
        contents = [x for x in contents if filter in x]

    # Remove the redundant directory name from each entry
    if recursive:
        # Everything after the third '/' of the listed path, i.e. the key
        # prefix when bucket_path looks like "s3://bucket/<prefix>/".
        main_dir = "/".join(bucket_path.split('/')[3:])
        # Only drop the first occurrence: replacing every occurrence corrupts
        # keys in which the prefix text reappears later
        # (e.g. prefix "a/" and key "a/data/b.tar").
        contents = [x.replace(main_dir, '', 1).replace('//', '/')
                    for x in contents]

    # Print debugging information, if requested
    if debug:
        print("contents = \n", contents)

    return contents


def get_all_s3_urls(
    names=None,  # list of all valid [LAION AudioDataset] dataset names
    # list of subsets you want from those datasets, e.g. ['train','valid']
    subsets=None,
    s3_url_prefix=None,  # prefix for those dataset names
    recursive=True,  # recursively list all tar files in all subdirs
    filter_str='tar',  # only grab files with this substring
    # print debugging info -- note: info displayed likely to change at dev's whims
    debug=False,
    profiles=None,  # dictionary of profiles for each item in names, e.g. {'dataset1': 'profile1', 'dataset2': 'profile2'}
):
    """Get urls of shards (tar files) for multiple datasets in one s3 bucket.

    Args:
        names: Dataset names (or full S3 paths when ``s3_url_prefix`` is
            None). Defaults to an empty list.
        subsets: Sub-directories to list under each name. Defaults to
            ``['']``, i.e. the dataset directory itself.
        s3_url_prefix: Optional prefix joined in front of every name.
        recursive: Forwarded to ``get_s3_contents``.
        filter_str: Only entries containing this substring are kept.
        debug: Print intermediate paths and the generated commands.
        profiles: Mapping from dataset name to AWS CLI profile. Defaults to
            an empty mapping.

    Returns:
        A list of ``pipe:aws s3 ... cp <path> -`` command strings, one per
        shard found.
    """
    # None sentinels replace the former mutable default arguments.
    names = [] if names is None else names
    subsets = [''] if subsets is None else subsets
    profiles = {} if profiles is None else profiles

    urls = []
    for name in names:
        # If s3_url_prefix is not specified, assume the full S3 path is included in each element of the names list
        if s3_url_prefix is None:
            contents_str = name
        else:
            # Construct the S3 path using the s3_url_prefix and the current name value
            contents_str = posixpath.join(s3_url_prefix, name)
        if debug:
            print(f"get_all_s3_urls: {contents_str}:")

        # The profile depends only on the dataset name, so look it up once.
        profile = profiles.get(name)

        for subset in subsets:
            subset_str = posixpath.join(contents_str, subset)
            if debug:
                print(f"subset_str = {subset_str}")
            # Get the list of tar files in the current subset directory
            tar_list = get_s3_contents(
                subset_str, s3_url_prefix=None, recursive=recursive, filter=filter_str, debug=debug, profile=profile)
            for tar in tar_list:
                # Escape spaces and parentheses in the tar filename for use in the shell command.
                # Backslashes are written explicitly; "\ " and "\(" are invalid
                # escape sequences and emit a warning on recent Python versions.
                # NOTE(review): other shell metacharacters are not escaped here.
                tar = tar.replace(" ", "\\ ").replace(
                    "(", "\\(").replace(")", "\\)")
                # Construct the S3 path to the current tar file
                s3_path = posixpath.join(name, subset, tar) + " -"
                # Construct the AWS CLI command to download the current tar file
                if s3_url_prefix is None:
                    request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {s3_path}"
                else:
                    request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {posixpath.join(s3_url_prefix, s3_path)}"
                if profile:
                    request_str += f" --profile {profile}"
                if debug:
                    print("request_str = ", request_str)
                # Add the constructed URL to the list of URLs
                urls.append(request_str)
    return urls
# S3 helpers live in the import-light s3_utils module (no torch) so it can also
# run as a `pipe:` subprocess that streams individual shards. The previously
# public functions are re-exported here for backwards compatibility.
from .s3_utils import get_all_s3_urls, get_s3_contents # noqa: E402,F401


# WDS preprocessing code based on implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py
def log_and_continue(exn):
"""Call in an exception handler to ignore any exception, isssue a warning, and continue."""
print(f"Handling webdataset error ({repr(exn)}). Ignoring.")
Expand Down
Loading