Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ allow-direct-references = true
files = ["requirements.txt"]

[tool.hatch.build.targets.wheel]
packages = ["."]
packages = ["wsds"]

[tool.ruff]
line-length = 120
Expand All @@ -24,6 +24,12 @@ indent-width = 4
ignore = ["E203", "E501", "E731"]
extend-select = ["I"]

[project.optional-dependencies]
test = ["pytest", "tqdm"]

[tool.pytest.ini_options]
testpaths = ["tests"]

# --- build-data --- #
[build-system]
requires = ["hatchling", "hatch-requirements-txt"]
Expand Down
Empty file added tests/__init__.py
Empty file.
262 changes: 262 additions & 0 deletions tests/test_shard_from_audio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,262 @@
import struct

import pyarrow as pa

from wsds.ws_tools import shard_from_audio_dir


def make_wav(path, num_samples=100, sample_rate=16000, num_channels=1):
"""Write a minimal valid WAV file."""
bits_per_sample = 16
data_size = num_samples * num_channels * (bits_per_sample // 8)
header = struct.pack(
"<4sI4s4sIHHIIHH4sI",
b"RIFF",
36 + data_size,
b"WAVE",
b"fmt ",
16,
1, # PCM
num_channels,
sample_rate,
sample_rate * num_channels * bits_per_sample // 8,
num_channels * bits_per_sample // 8,
bits_per_sample,
b"data",
data_size,
)
pcm = b"\x00\x01" * num_samples * num_channels
path.write_bytes(header + pcm)


def _collect_shards(output_dir):
"""Read all .wsds shards in output_dir, return list of (keys, audio_bytes, audio_types) per shard."""
shards = []
for shard_path in sorted(output_dir.glob("*.wsds")):
reader = pa.ipc.open_file(str(shard_path))
table = reader.read_all()
keys = table.column("__key__").to_pylist()
audio = [v.as_py() for v in table.column("audio")]
audio_types = table.column("audio_type").to_pylist()
shards.append((keys, audio, audio_types))
return shards


class TestShardFromAudioDir:
def test_basic_sharding(self, tmp_path):
"""Files are split into correct number of shards and content matches."""
input_dir = tmp_path / "audio_in"
output_dir = tmp_path / "audio_out"
input_dir.mkdir()

stems = [f"clip_{i:03d}" for i in range(5)]
original_bytes = {}
for stem in stems:
p = input_dir / f"{stem}.wav"
make_wav(p, num_samples=50 + len(stem))
original_bytes[stem] = p.read_bytes()

shard_from_audio_dir(str(input_dir), str(output_dir), max_files_per_shard=2)

shards = _collect_shards(output_dir)
# 5 files / 2 per shard = 3 shards
assert len(shards) == 3

all_keys = []
all_audio = {}
all_types = []
for keys, audio, audio_types in shards:
all_keys.extend(keys)
for k, a in zip(keys, audio):
all_audio[k] = a
all_types.extend(audio_types)

assert sorted(all_keys) == sorted(stems)
for stem in stems:
assert all_audio[stem] == original_bytes[stem]
assert all(t == "wav" for t in all_types)

def test_key_prefix(self, tmp_path):
"""key_prefix is prepended to each key."""
input_dir = tmp_path / "in"
output_dir = tmp_path / "out"
input_dir.mkdir()

make_wav(input_dir / "hello.wav")

shard_from_audio_dir(str(input_dir), str(output_dir), key_prefix="dataset1")

shards = _collect_shards(output_dir)
keys = shards[0][0]
assert keys == ["dataset1/hello"]

def test_key_fn(self, tmp_path):
"""key_fn transforms the key."""
input_dir = tmp_path / "in"
output_dir = tmp_path / "out"
input_dir.mkdir()

make_wav(input_dir / "original.wav")

shard_from_audio_dir(
str(input_dir), str(output_dir), key_fn=lambda s: s.upper()
)

shards = _collect_shards(output_dir)
keys = shards[0][0]
assert keys == ["ORIGINAL"]

def test_key_fn_with_prefix(self, tmp_path):
"""key_fn receives the prefixed stem."""
input_dir = tmp_path / "in"
output_dir = tmp_path / "out"
input_dir.mkdir()

make_wav(input_dir / "file.wav")

shard_from_audio_dir(
str(input_dir),
str(output_dir),
key_prefix="pfx",
key_fn=lambda s: s.replace("/", "__"),
)

shards = _collect_shards(output_dir)
keys = shards[0][0]
assert keys == ["pfx__file"]

def test_oversized_files_skipped(self, tmp_path, monkeypatch):
"""Files exceeding the Arrow byte limit are skipped."""
input_dir = tmp_path / "in"
output_dir = tmp_path / "out"
input_dir.mkdir()

make_wav(input_dir / "small.wav", num_samples=10)
make_wav(input_dir / "big.wav", num_samples=100)

small_size = (input_dir / "small.wav").stat().st_size
large_size = (input_dir / "big.wav").stat().st_size

# Pick a fake limit between the two file sizes so big.wav gets skipped
fake_limit = (small_size + large_size) // 2

# Patch the read_bytes to attach a fake size, then patch len check via
# a wrapper around shard_from_audio_dir that lowers MAX_ARROW_BYTES.
# Since MAX_ARROW_BYTES is a local, we instead wrap the whole function
# by replacing it with one that sets a lower limit.
import wsds.ws_tools as mod

Copilot AI Feb 10, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Module 'wsds.ws_tools' is imported with both 'import' and 'import from'.

Copilot uses AI. Check for mistakes.

orig_code = mod.shard_from_audio_dir.__code__

# Replace the constant in the code object's co_consts
new_consts = tuple(
fake_limit if c == 2_140_000_000 else c for c in orig_code.co_consts
)
new_code = orig_code.replace(co_consts=new_consts)
monkeypatch.setattr(mod.shard_from_audio_dir, "__code__", new_code)

Comment on lines +143 to +157

Copilot AI Feb 10, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test mutates shard_from_audio_dir.__code__.co_consts to change a local constant. This is brittle (depends on CPython implementation details and the constant appearing exactly once) and can break with small refactors. Prefer making the max-arrow-bytes limit injectable (e.g., a parameter or a module-level constant) so the test can monkeypatch it safely.

Suggested change
# Patch the read_bytes to attach a fake size, then patch len check via
# a wrapper around shard_from_audio_dir that lowers MAX_ARROW_BYTES.
# Since MAX_ARROW_BYTES is a local, we instead wrap the whole function
# by replacing it with one that sets a lower limit.
import wsds.ws_tools as mod
orig_code = mod.shard_from_audio_dir.__code__
# Replace the constant in the code object's co_consts
new_consts = tuple(
fake_limit if c == 2_140_000_000 else c for c in orig_code.co_consts
)
new_code = orig_code.replace(co_consts=new_consts)
monkeypatch.setattr(mod.shard_from_audio_dir, "__code__", new_code)
# Patch the Arrow byte-limit via a module-level constant so that
# files larger than fake_limit are skipped.
import wsds.ws_tools as mod
# Override the max-bytes limit used by shard_from_audio_dir.
monkeypatch.setattr(mod, "MAX_ARROW_BYTES", fake_limit)

Copilot uses AI. Check for mistakes.
shard_from_audio_dir(str(input_dir), str(output_dir))

shards = _collect_shards(output_dir)
all_keys = [k for keys, _, _ in shards for k in keys]
assert "small" in all_keys
assert "big" not in all_keys

def test_empty_input_dir(self, tmp_path):
"""Empty input directory produces no shards."""
input_dir = tmp_path / "in"
output_dir = tmp_path / "out"
input_dir.mkdir()

shard_from_audio_dir(str(input_dir), str(output_dir))

assert list(output_dir.glob("*.wsds")) == []

def test_subdirectory_files(self, tmp_path):
"""Audio files in subdirectories use relative path as key."""
input_dir = tmp_path / "in"
output_dir = tmp_path / "out"
sub = input_dir / "speaker1"
sub.mkdir(parents=True)

make_wav(sub / "utt.wav")

shard_from_audio_dir(str(input_dir), str(output_dir))

shards = _collect_shards(output_dir)
keys = shards[0][0]
assert keys == ["speaker1/utt"]

def test_shard_naming(self, tmp_path):
"""Shard files are named audio-NNNNN.wsds."""
input_dir = tmp_path / "in"
output_dir = tmp_path / "out"
input_dir.mkdir()

for i in range(4):
make_wav(input_dir / f"f{i}.wav")

shard_from_audio_dir(str(input_dir), str(output_dir), max_files_per_shard=2)

shard_names = sorted(p.name for p in output_dir.glob("*.wsds"))
assert shard_names == ["audio-00000.wsds", "audio-00001.wsds"]

def test_init_index_creates_audio_subdir(self, tmp_path):
"""When init_index=True, shards are written to audio/ subdirectory and index is created."""
input_dir = tmp_path / "in"
output_dir = tmp_path / "dataset"
input_dir.mkdir()

# Create some test audio files
for i in range(3):
make_wav(input_dir / f"file{i}.wav")

shard_from_audio_dir(
str(input_dir),
str(output_dir),
max_files_per_shard=2,
init_index=True,
require_audio_duration=False, # Skip audio duration requirement for test
)

# Shards should be in output_dir/audio/
audio_dir = output_dir / "audio"
assert audio_dir.exists()
assert audio_dir.is_dir()

# Check shards are in the audio subdirectory
shard_files = sorted(audio_dir.glob("*.wsds"))
assert len(shard_files) == 2 # 3 files / 2 per shard = 2 shards

# Check index was created at dataset root
index_file = output_dir / "index.sqlite3"
assert index_file.exists()

def test_init_index_with_audio_named_output(self, tmp_path):
"""When init_index=True and output_dir is already named 'audio', don't create nested audio/audio/."""
input_dir = tmp_path / "in"
dataset_root = tmp_path / "dataset"
output_dir = dataset_root / "audio"
input_dir.mkdir()

make_wav(input_dir / "test.wav")

shard_from_audio_dir(
str(input_dir),
str(output_dir),
init_index=True,
require_audio_duration=False,
)

# Shards should be in output_dir (which is already named 'audio')
shard_files = sorted(output_dir.glob("*.wsds"))
assert len(shard_files) == 1

# Check we didn't create audio/audio/
nested_audio = output_dir / "audio"
assert not nested_audio.exists()

# Index should be at dataset_root (parent of audio/)
index_file = dataset_root / "index.sqlite3"
assert index_file.exists()

2 changes: 1 addition & 1 deletion wsds/ws_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def get_reader(self, sample_rate=None):
if self.reader is None or sample_rate_switch:
try:
from torchcodec.decoders import AudioDecoder
except ImportError:
except Exception:
AudioDecoder = CompatAudioDecoder

Comment on lines 127 to 132

Copilot AI Feb 10, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Catching Exception around the torchcodec import will also swallow unrelated issues (e.g., internal bugs in torchcodec or environment problems) and silently fall back to the torchaudio-based decoder, making failures harder to diagnose. Consider catching a narrower set of exceptions (e.g., ImportError, OSError, RuntimeError) and/or logging the exception when falling back.

Copilot uses AI. Check for mistakes.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot open a new pull request to apply changes based on this feedback

reader = AudioDecoder(to_filelike(self.src), sample_rate=sample_rate)
Expand Down
Loading