-
Notifications
You must be signed in to change notification settings - Fork 1
Feat/shard from audio #40
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
03dd37a
d90d924
c07c2ce
a9998de
11bc42b
7e290f2
818cc03
0917788
173c5da
e9dbbf6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,262 @@ | ||||||||||||||||||||||||||||||||||||||||||
| import struct | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| import pyarrow as pa | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| from wsds.ws_tools import shard_from_audio_dir | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def make_wav(path, num_samples=100, sample_rate=16000, num_channels=1): | ||||||||||||||||||||||||||||||||||||||||||
| """Write a minimal valid WAV file.""" | ||||||||||||||||||||||||||||||||||||||||||
| bits_per_sample = 16 | ||||||||||||||||||||||||||||||||||||||||||
| data_size = num_samples * num_channels * (bits_per_sample // 8) | ||||||||||||||||||||||||||||||||||||||||||
| header = struct.pack( | ||||||||||||||||||||||||||||||||||||||||||
| "<4sI4s4sIHHIIHH4sI", | ||||||||||||||||||||||||||||||||||||||||||
| b"RIFF", | ||||||||||||||||||||||||||||||||||||||||||
| 36 + data_size, | ||||||||||||||||||||||||||||||||||||||||||
| b"WAVE", | ||||||||||||||||||||||||||||||||||||||||||
| b"fmt ", | ||||||||||||||||||||||||||||||||||||||||||
| 16, | ||||||||||||||||||||||||||||||||||||||||||
| 1, # PCM | ||||||||||||||||||||||||||||||||||||||||||
| num_channels, | ||||||||||||||||||||||||||||||||||||||||||
| sample_rate, | ||||||||||||||||||||||||||||||||||||||||||
| sample_rate * num_channels * bits_per_sample // 8, | ||||||||||||||||||||||||||||||||||||||||||
| num_channels * bits_per_sample // 8, | ||||||||||||||||||||||||||||||||||||||||||
| bits_per_sample, | ||||||||||||||||||||||||||||||||||||||||||
| b"data", | ||||||||||||||||||||||||||||||||||||||||||
| data_size, | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
| pcm = b"\x00\x01" * num_samples * num_channels | ||||||||||||||||||||||||||||||||||||||||||
| path.write_bytes(header + pcm) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def _collect_shards(output_dir): | ||||||||||||||||||||||||||||||||||||||||||
| """Read all .wsds shards in output_dir, return list of (keys, audio_bytes, audio_types) per shard.""" | ||||||||||||||||||||||||||||||||||||||||||
| shards = [] | ||||||||||||||||||||||||||||||||||||||||||
| for shard_path in sorted(output_dir.glob("*.wsds")): | ||||||||||||||||||||||||||||||||||||||||||
| reader = pa.ipc.open_file(str(shard_path)) | ||||||||||||||||||||||||||||||||||||||||||
| table = reader.read_all() | ||||||||||||||||||||||||||||||||||||||||||
| keys = table.column("__key__").to_pylist() | ||||||||||||||||||||||||||||||||||||||||||
| audio = [v.as_py() for v in table.column("audio")] | ||||||||||||||||||||||||||||||||||||||||||
| audio_types = table.column("audio_type").to_pylist() | ||||||||||||||||||||||||||||||||||||||||||
| shards.append((keys, audio, audio_types)) | ||||||||||||||||||||||||||||||||||||||||||
| return shards | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| class TestShardFromAudioDir: | ||||||||||||||||||||||||||||||||||||||||||
| def test_basic_sharding(self, tmp_path): | ||||||||||||||||||||||||||||||||||||||||||
| """Files are split into correct number of shards and content matches.""" | ||||||||||||||||||||||||||||||||||||||||||
| input_dir = tmp_path / "audio_in" | ||||||||||||||||||||||||||||||||||||||||||
| output_dir = tmp_path / "audio_out" | ||||||||||||||||||||||||||||||||||||||||||
| input_dir.mkdir() | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| stems = [f"clip_{i:03d}" for i in range(5)] | ||||||||||||||||||||||||||||||||||||||||||
| original_bytes = {} | ||||||||||||||||||||||||||||||||||||||||||
| for stem in stems: | ||||||||||||||||||||||||||||||||||||||||||
| p = input_dir / f"{stem}.wav" | ||||||||||||||||||||||||||||||||||||||||||
| make_wav(p, num_samples=50 + len(stem)) | ||||||||||||||||||||||||||||||||||||||||||
| original_bytes[stem] = p.read_bytes() | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| shard_from_audio_dir(str(input_dir), str(output_dir), max_files_per_shard=2) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| shards = _collect_shards(output_dir) | ||||||||||||||||||||||||||||||||||||||||||
| # 5 files / 2 per shard = 3 shards | ||||||||||||||||||||||||||||||||||||||||||
| assert len(shards) == 3 | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| all_keys = [] | ||||||||||||||||||||||||||||||||||||||||||
| all_audio = {} | ||||||||||||||||||||||||||||||||||||||||||
| all_types = [] | ||||||||||||||||||||||||||||||||||||||||||
| for keys, audio, audio_types in shards: | ||||||||||||||||||||||||||||||||||||||||||
| all_keys.extend(keys) | ||||||||||||||||||||||||||||||||||||||||||
| for k, a in zip(keys, audio): | ||||||||||||||||||||||||||||||||||||||||||
| all_audio[k] = a | ||||||||||||||||||||||||||||||||||||||||||
| all_types.extend(audio_types) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| assert sorted(all_keys) == sorted(stems) | ||||||||||||||||||||||||||||||||||||||||||
| for stem in stems: | ||||||||||||||||||||||||||||||||||||||||||
| assert all_audio[stem] == original_bytes[stem] | ||||||||||||||||||||||||||||||||||||||||||
| assert all(t == "wav" for t in all_types) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def test_key_prefix(self, tmp_path): | ||||||||||||||||||||||||||||||||||||||||||
| """key_prefix is prepended to each key.""" | ||||||||||||||||||||||||||||||||||||||||||
| input_dir = tmp_path / "in" | ||||||||||||||||||||||||||||||||||||||||||
| output_dir = tmp_path / "out" | ||||||||||||||||||||||||||||||||||||||||||
| input_dir.mkdir() | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| make_wav(input_dir / "hello.wav") | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| shard_from_audio_dir(str(input_dir), str(output_dir), key_prefix="dataset1") | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| shards = _collect_shards(output_dir) | ||||||||||||||||||||||||||||||||||||||||||
| keys = shards[0][0] | ||||||||||||||||||||||||||||||||||||||||||
| assert keys == ["dataset1/hello"] | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def test_key_fn(self, tmp_path): | ||||||||||||||||||||||||||||||||||||||||||
| """key_fn transforms the key.""" | ||||||||||||||||||||||||||||||||||||||||||
| input_dir = tmp_path / "in" | ||||||||||||||||||||||||||||||||||||||||||
| output_dir = tmp_path / "out" | ||||||||||||||||||||||||||||||||||||||||||
| input_dir.mkdir() | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| make_wav(input_dir / "original.wav") | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| shard_from_audio_dir( | ||||||||||||||||||||||||||||||||||||||||||
| str(input_dir), str(output_dir), key_fn=lambda s: s.upper() | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| shards = _collect_shards(output_dir) | ||||||||||||||||||||||||||||||||||||||||||
| keys = shards[0][0] | ||||||||||||||||||||||||||||||||||||||||||
| assert keys == ["ORIGINAL"] | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def test_key_fn_with_prefix(self, tmp_path): | ||||||||||||||||||||||||||||||||||||||||||
| """key_fn receives the prefixed stem.""" | ||||||||||||||||||||||||||||||||||||||||||
| input_dir = tmp_path / "in" | ||||||||||||||||||||||||||||||||||||||||||
| output_dir = tmp_path / "out" | ||||||||||||||||||||||||||||||||||||||||||
| input_dir.mkdir() | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| make_wav(input_dir / "file.wav") | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| shard_from_audio_dir( | ||||||||||||||||||||||||||||||||||||||||||
| str(input_dir), | ||||||||||||||||||||||||||||||||||||||||||
| str(output_dir), | ||||||||||||||||||||||||||||||||||||||||||
| key_prefix="pfx", | ||||||||||||||||||||||||||||||||||||||||||
| key_fn=lambda s: s.replace("/", "__"), | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| shards = _collect_shards(output_dir) | ||||||||||||||||||||||||||||||||||||||||||
| keys = shards[0][0] | ||||||||||||||||||||||||||||||||||||||||||
| assert keys == ["pfx__file"] | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| def test_oversized_files_skipped(self, tmp_path, monkeypatch): | ||||||||||||||||||||||||||||||||||||||||||
| """Files exceeding the Arrow byte limit are skipped.""" | ||||||||||||||||||||||||||||||||||||||||||
| input_dir = tmp_path / "in" | ||||||||||||||||||||||||||||||||||||||||||
| output_dir = tmp_path / "out" | ||||||||||||||||||||||||||||||||||||||||||
| input_dir.mkdir() | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| make_wav(input_dir / "small.wav", num_samples=10) | ||||||||||||||||||||||||||||||||||||||||||
| make_wav(input_dir / "big.wav", num_samples=100) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| small_size = (input_dir / "small.wav").stat().st_size | ||||||||||||||||||||||||||||||||||||||||||
| large_size = (input_dir / "big.wav").stat().st_size | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| # Pick a fake limit between the two file sizes so big.wav gets skipped | ||||||||||||||||||||||||||||||||||||||||||
| fake_limit = (small_size + large_size) // 2 | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| # Patch the read_bytes to attach a fake size, then patch len check via | ||||||||||||||||||||||||||||||||||||||||||
| # a wrapper around shard_from_audio_dir that lowers MAX_ARROW_BYTES. | ||||||||||||||||||||||||||||||||||||||||||
| # Since MAX_ARROW_BYTES is a local, we instead wrap the whole function | ||||||||||||||||||||||||||||||||||||||||||
| # by replacing it with one that sets a lower limit. | ||||||||||||||||||||||||||||||||||||||||||
| import wsds.ws_tools as mod | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| orig_code = mod.shard_from_audio_dir.__code__ | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| # Replace the constant in the code object's co_consts | ||||||||||||||||||||||||||||||||||||||||||
| new_consts = tuple( | ||||||||||||||||||||||||||||||||||||||||||
| fake_limit if c == 2_140_000_000 else c for c in orig_code.co_consts | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
| new_code = orig_code.replace(co_consts=new_consts) | ||||||||||||||||||||||||||||||||||||||||||
| monkeypatch.setattr(mod.shard_from_audio_dir, "__code__", new_code) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+143
to
+157
|
||||||||||||||||||||||||||||||||||||||||||
| # Patch the read_bytes to attach a fake size, then patch len check via | |
| # a wrapper around shard_from_audio_dir that lowers MAX_ARROW_BYTES. | |
| # Since MAX_ARROW_BYTES is a local, we instead wrap the whole function | |
| # by replacing it with one that sets a lower limit. | |
| import wsds.ws_tools as mod | |
| orig_code = mod.shard_from_audio_dir.__code__ | |
| # Replace the constant in the code object's co_consts | |
| new_consts = tuple( | |
| fake_limit if c == 2_140_000_000 else c for c in orig_code.co_consts | |
| ) | |
| new_code = orig_code.replace(co_consts=new_consts) | |
| monkeypatch.setattr(mod.shard_from_audio_dir, "__code__", new_code) | |
| # Patch the Arrow byte-limit via a module-level constant so that | |
| # files larger than fake_limit are skipped. | |
| import wsds.ws_tools as mod | |
| # Override the max-bytes limit used by shard_from_audio_dir. | |
| monkeypatch.setattr(mod, "MAX_ARROW_BYTES", fake_limit) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -127,7 +127,7 @@ def get_reader(self, sample_rate=None): | |
| if self.reader is None or sample_rate_switch: | ||
| try: | ||
| from torchcodec.decoders import AudioDecoder | ||
| except ImportError: | ||
| except Exception: | ||
| AudioDecoder = CompatAudioDecoder | ||
|
|
||
|
Comment on lines
127
to
132
|
||
| reader = AudioDecoder(to_filelike(self.src), sample_rate=sample_rate) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Module 'wsds.ws_tools' is imported with both 'import' and 'import from'.