diff --git a/README.md b/README.md index 1895568..a1febc5 100644 --- a/README.md +++ b/README.md @@ -66,7 +66,7 @@ shape: (5_271_939, 3) ```pycon >>> x['audio'] -WSAudio(audio_reader=AudioReader(src=, sample_rate=None), tstart=614.46246, tend=627.3976) +WSAudioSegment(episode=WSAudioEpisode(src=, sample_rate=None), tstart=614.46246, tend=627.3976) ``` diff --git a/requirements.txt b/requirements.txt index 4fe9e43..64cde6c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,5 +4,4 @@ numpy polars>=1.36.1 pyarrow>=20 torch -torchaudio # torchcodec – optional, it causes serious performance regressions diff --git a/scripts/test_audio_backends.sh b/scripts/test_audio_backends.sh new file mode 100755 index 0000000..48bf02f --- /dev/null +++ b/scripts/test_audio_backends.sh @@ -0,0 +1,14 @@ +#!/usr/bin/env bash +# Run the full test suite in isolated environments, one per decoder backend. +# +# Usage: +# ./scripts/test_audio_backends.sh +# +# Each environment installs only one decoder backend so we verify that the +# interfaces work correctly regardless of which backend is present. + +TEST_COMMAND=${1:-"python -m tests"} + +parallel --tag --lb \ + "uv run --isolated --with {} $TEST_COMMAND" \ + ::: humecodec "torchaudio<2.9" torchcodec diff --git a/tests.py b/tests.py index 39a9acb..41ffc57 100644 --- a/tests.py +++ b/tests.py @@ -2,7 +2,7 @@ import unittest import wsds -from wsds import ws_dataset, ws_shard, ws_sink +from wsds import ws_dataset, ws_shard, ws_sink, ws_audio, audio_codec def load_tests(loader, tests, ignore): @@ -10,6 +10,8 @@ def load_tests(loader, tests, ignore): tests.addTests(doctest.DocTestSuite(ws_dataset)) tests.addTests(doctest.DocTestSuite(ws_shard)) tests.addTests(doctest.DocTestSuite(ws_sink)) + tests.addTests(doctest.DocTestSuite(ws_audio)) + tests.addTests(doctest.DocTestSuite(audio_codec)) tests.addTests(doctest.DocFileSuite("README.md")) return tests diff --git a/wsds/audio_codec.py b/wsds/audio_codec.py index 2c64c24..d546e49 100644 --- a/wsds/audio_codec.py +++ b/wsds/audio_codec.py @@ -2,236 +2,299 @@ This module contains all audio encoding/decoding logic, separated from the data model layer in ws_audio.py. It provides: -- Decoder backends (TorchFFmpegAudioDecoder, CompatAudioDecoder) -- A factory for creating decoders with automatic backend selection -- MP3 encoding with multi-backend fallback +- AudioDecoder: unified decoder with automatic backend selection (humecodec or torchaudio) +- encode_audio(): multi-backend encoder (humecodec -> torchcodec -> torchaudio) - HTML audio rendering utility """ from __future__ import annotations import io +import traceback import typing import pyarrow as pa -def to_filelike(src: typing.Any) -> typing.BinaryIO: - """Coerces files, byte-strings and PyArrow binary buffers into file-like objects.""" - if hasattr(src, "read"): # an open file - return src - # if not an open file then we assume some kind of binary data in memory - if hasattr(src, "as_buffer"): # PyArrow binary data - return pa.BufferReader(src.as_buffer()) - return io.BytesIO(src) - - -class TorchFFmpegAudioDecoder: - def __init__(self, src, sample_rate): - from torchffmpeg import MediaDecoder - - if hasattr(src, "_optimal_read_size"): - buffer_size = src._optimal_read_size - else: - buffer_size = 128 * 1024 - self.src = src - self.reader = MediaDecoder(to_filelike(self.src), buffer_size=buffer_size) - self.metadata = self.reader.get_src_stream_info(self.reader.default_audio_stream) - - if sample_rate is None: - sample_rate = int(self.metadata.sample_rate) +class AudioDecoder: + """Unified audio decoder that works with humecodec or torchaudio backends.""" + def __init__(self, reader, metadata, sample_rate, codec_delay=0): + self.reader = reader + self.metadata = metadata self.sample_rate = sample_rate + self.debug = False + self.codec_delay = codec_delay + self.init_skip_samples = getattr(metadata, 'start_skip_samples', 0) or 0 + # Codecs where flush produces unreliable output (wrong skip_samples, + # wrong frame sizes). For these, always read from the start and trim. + codec_name = getattr(metadata, 'codec', '') or '' + self._seek_unreliable = codec_name in ('wmav2', 'wmapro', 'vorbis') + # Raw MPEG audio formats: timestamp seek does sequential scan, + # byte-offset seek with our own index is much faster. + self._use_byte_index = codec_name in ('mp3', 'mp2', 'mp1') + self._packet_index = None + + def _build_index(self): + """Build a sparse packet index for byte-offset seeking.""" + if self._packet_index is not None: + return + try: + idx = self.reader.build_packet_index( + self.reader.default_audio_stream, 128 * 1024) + if idx and len(idx) > 1: + self._packet_index = idx + except Exception: + self._packet_index = [] + + def _indexed_seek(self, target_time): + """Seek via byte offset using the packet index. Returns the index entry's PTS or None.""" + self._build_index() + if not self._packet_index: + return None + # Find last entry with pts <= target_time + best = self._packet_index[0] + for entry in self._packet_index: + if entry.pts_seconds <= target_time: + best = entry + else: + break + self.reader.seek_to_byte_offset(best.pos) + return best.pts_seconds + + def get_samples_played_in_range(self, tstart=0, tend=None, margin=.25): + import torch - self.reader.add_basic_audio_stream( - frames_per_chunk=int(32 * sample_rate), - sample_rate=sample_rate, - decoder_option={"threads": "4", "thread_type": "frame"}, - ) + chunk = True + while chunk is not None: + (chunk,) = self.reader.pop_chunks() - def get_samples_played_in_range(self, tstart=0, tend=None): - import torch + # For short seeks and unreliable codecs, read from the start. + # This avoids seek accuracy issues for tstart < 5s (tiny cost) and + # codec flush bugs for wmav2/wmapro/vorbis. + read_from_start = self._seek_unreliable or tstart < 5.0 + + # Only adjust for start_skip_samples when actually seeking — when + # reading from start, the decoder applies skip_samples automatically. + seek_adj = 0.0 + index_pts = None + if not read_from_start: + # For raw MPEG formats, use indexed byte seek (fast, avoids sequential scan). + # No seek_adj needed: the index PTS and decoded audio are both in + # the raw timeline (skip_samples is not applied after byte seek). + if self._use_byte_index: + index_pts = self._indexed_seek(tstart - margin) + else: + # Timestamp seek: the demuxer applies start_skip_samples at + # pts=0 but not after seeking, so adjust tstart to compensate. + seek_adj = self.init_skip_samples / self.metadata.sample_rate + tstart += seek_adj + if tend is not None: + tend += seek_adj + + if index_pts is None: + # Fall back to timestamp seek (or read from start) + seek_target = 0.0 if read_from_start else max(0, tstart - margin) + self.reader.seek(seek_target, "key") + + chunks = [] + more_data = True + while more_data: + if self.reader.fill_buffer() == 1: + more_data = False + (chunk,) = self.reader.pop_chunks() + chunks.append(chunk) + if tend is not None: + chunk_end_pts = chunk.pts + chunk.shape[0] / self.sample_rate + if index_pts is not None: + # PTS not updated by demuxer after byte seek — estimate from index + elapsed = sum(c.shape[0] for c in chunks) / self.sample_rate + chunk_end_pts = index_pts + elapsed + if chunk_end_pts > tend + margin: + break + + # Determine the reference PTS for trimming + if read_from_start: + chunk0_pts = 0.0 + elif index_pts is not None: + # Byte seek: demuxer PTS is stale, use our index entry + chunk0_pts = index_pts + else: + chunk0_pts = chunks[0].pts + prefix = round(tstart * self.sample_rate) - round(chunk0_pts * self.sample_rate) + + if self.debug: + import torch as _t + total_samples = sum(c.shape[0] for c in chunks) + print(f" [decode] codec={self.metadata.codec} sr={self.sample_rate} " + f"tstart_orig={tstart - seek_adj:.4f} tstart_adj={tstart:.4f} " + f"seek_adj={seek_adj:.6f} (init_skip={self.init_skip_samples} codec_delay={self.codec_delay}) " + f"chunk0.pts={chunks[0].pts:.6f} chunk0_pts_used={chunk0_pts:.6f} " + f"n_chunks={len(chunks)} total_samples={total_samples} prefix={prefix}", flush=True) - self.reader.seek(max(0, tstart - 1), "key") - - if tend is None: - chunks = [] - more_data = True - while more_data: - if self.reader.fill_buffer() == 1: - more_data = False - (chunk,) = self.reader.pop_chunks() - if chunk is not None: - chunks.append(chunk) - prefix = int((tstart - chunks[0].pts) * self.sample_rate) - if prefix < 0: - prefix = 0 - return torch.cat(chunks)[prefix:].mT - - self.reader.fill_buffer() - (chunk,) = self.reader.pop_chunks() - prefix = int((tstart - chunk.pts) * self.sample_rate) if prefix < 0: + if self.debug: + print(f" [trim] negative prefix {prefix}, clamping to 0", flush=True) prefix = 0 - if tend: - samples = chunk[prefix : prefix + int((tend - tstart) * self.sample_rate)].mT + samples = torch.cat(chunks) + if tend is not None: + return samples[prefix : prefix + round(tend * self.sample_rate) - round(tstart * self.sample_rate)].mT else: - samples = chunk[prefix:].mT - while chunk is not None: - (chunk,) = self.reader.pop_chunks() - return samples + return samples[prefix:].mT +def _create_reader_humecodec(src, buffer_size): + from humecodec import MediaDecoder -class CompatAudioDecoder: - def __init__(self, src, sample_rate): - import torchaudio + reader = MediaDecoder(src=src, buffer_size=buffer_size) + metadata = reader.get_src_stream_info(reader.default_audio_stream) + return reader, metadata - if not hasattr(torchaudio, "io"): - raise ImportError("You need either torchaudio<2.9 or torchcodec installed") - self.src = src - if hasattr(src, "_optimal_read_size"): - buffer_size = src._optimal_read_size - else: - buffer_size = 128 * 1024 - self.reader = torchaudio.io.StreamReader(src=to_filelike(self.src), buffer_size=buffer_size) - self.metadata = self.reader.get_src_stream_info(0) - if sample_rate is None: - sample_rate = self.metadata.sample_rate +def _create_reader_torchaudio(src, buffer_size): + from torchaudio.io import StreamReader - self.sample_rate = sample_rate + reader = StreamReader(src=src, buffer_size=buffer_size) + metadata = reader.get_src_stream_info(reader.default_audio_stream) + return reader, metadata - # fetch 32 seconds because we likely need 30s at maximum but the seeking may be imprecise (and we seek 1s early) - # FIXME: check if we can get away with some better settings here (-1, maybe 10s + concatenate the chunks in a loop) - self.reader.add_basic_audio_stream( - frames_per_chunk=int(32 * sample_rate), - sample_rate=sample_rate, - decoder_option={"threads": "4", "thread_type": "frame"}, - ) - - def get_samples_played_in_range(self, tstart=0, tend=None): - # rought seek - self.reader.seek(max(0, tstart - 1), "key") - - if tend is None: - import torch - - chunks = [] - more_data = True - while more_data: - if self.reader.fill_buffer() == 1: - more_data = False - (chunk,) = self.reader.pop_chunks() - chunks.append(chunk) - prefix = int((tstart - chunks[0].pts) * self.sample_rate) - if prefix < 0: - prefix = 0 - return torch.cat(chunks)[prefix:].mT - - self.reader.fill_buffer() - (chunk,) = self.reader.pop_chunks() - # tight crop (seems accurate down to 1 sample in my tests) - prefix = int((tstart - chunk.pts) * self.sample_rate) - if prefix < 0: - prefix = 0 - if tend: - samples = chunk[prefix : prefix + int((tend - tstart) * self.sample_rate)].mT - else: - samples = chunk[prefix:].mT - # clear out any remaining data - while chunk is not None: - (chunk,) = self.reader.pop_chunks() - return samples +def _create_decoder_torchcodec(src, sample_rate): + """Create a torchcodec-backed decoder that matches the AudioDecoder interface.""" + from types import SimpleNamespace -def create_decoder(src, sample_rate=None): - """Factory: tries torchffmpeg -> torchcodec -> torchaudio, returns a decoder instance. + from torchcodec.decoders import AudioDecoder as TorchcodecDecoder - Args: - src: A file-like object or bytes-like source for audio data. - sample_rate: Optional target sample rate for resampling. + # torchcodec accepts bytes but not BytesIO + decoder = TorchcodecDecoder(src, sample_rate=sample_rate) + metadata = decoder.metadata - Returns: - A decoder instance with .metadata, .sample_rate, and .get_samples_played_in_range() interface. - """ - try: - from torchffmpeg import MediaDecoder as _ # noqa: F401 + class TorchcodecAdapter: + def __init__(self): + self.metadata = metadata + self.sample_rate = sample_rate if sample_rate is not None else int(metadata.sample_rate) - AudioDecoder = TorchFFmpegAudioDecoder - except ImportError: - try: - from torchcodec.decoders import AudioDecoder - except ImportError: - AudioDecoder = CompatAudioDecoder + def get_samples_played_in_range(self, tstart=0, tend=None): + return decoder.get_samples_played_in_range(tstart, tend) + + return TorchcodecAdapter() - return AudioDecoder(src, sample_rate=sample_rate) +_STREAMING_BACKENDS = [ + (_create_reader_humecodec, "humecodec"), + (_create_reader_torchaudio, "torchaudio.io"), +] -def decode_segment(src, start=0, end=None, sample_rate=None): - """One-shot decode: creates decoder, reads segment, returns tensor with .sample_rate attr. +_chosen_backend = None - Handles MP3 skip_samples compensation automatically. + +def create_decoder(src, sample_rate=None): + """Factory: tries humecodec -> torchaudio -> torchcodec, returns a decoder instance. Args: - src: Audio source (file-like, bytes, or PyArrow buffer). - start: Start time in seconds. - end: End time in seconds (None for rest of file). - sample_rate: Optional target sample rate. + src: A file-like object for audio data. + sample_rate: Optional target sample rate for resampling. Returns: - A torch.Tensor with a .sample_rate attribute. + A decoder with .metadata, .sample_rate, and .get_samples_played_in_range(). """ - filelike = to_filelike(src) - decoder = create_decoder(filelike, sample_rate) - - skip_samples = 0 - if decoder.metadata.codec == "mp3": - skip_samples = 1105 + global _chosen_backend + + buffer_size = getattr(src, "_optimal_read_size", 128 * 1024) + + if _chosen_backend is not None: + if _chosen_backend == "torchcodec": + return _create_decoder_torchcodec(src, sample_rate) + reader, metadata = _chosen_backend(src, buffer_size) + else: + for factory, module in _STREAMING_BACKENDS: + try: + reader, metadata = factory(src, buffer_size) + _chosen_backend = factory + break + except ImportError: + continue + else: + # Fall back to torchcodec (different API, no streaming reader) + try: + decoder = _create_decoder_torchcodec(src, sample_rate) + _chosen_backend = "torchcodec" + return decoder + except ImportError: + raise ImportError("Neither humecodec, torchaudio, nor torchcodec is installed.") if sample_rate is None: - sample_rate = decoder.metadata.sample_rate + sample_rate = int(metadata.sample_rate) - seek_adjustment = skip_samples / sample_rate if start > 0 else 0 - samples = decoder.get_samples_played_in_range( - start + seek_adjustment, end + seek_adjustment if end is not None else None + reader.add_basic_audio_stream( + frames_per_chunk=int(1 * sample_rate), + sample_rate=sample_rate, + decoder_option={"threads": "4", "thread_type": "frame"}, ) - if hasattr(samples, "data"): - samples = samples.data - samples.sample_rate = sample_rate - return samples + + # Get codec_delay from the decoder (available after add_audio_stream opens the codec) + codec_delay = 0 + try: + out_info = reader.get_out_stream_info(0) + codec_delay = getattr(out_info, 'codec_delay', 0) or 0 + except Exception: + pass + + return AudioDecoder(reader, metadata, sample_rate, codec_delay=codec_delay) + -def encode_mp3(samples) -> bytes: - """Encode a torch tensor to MP3 bytes. +def encode_audio(samples, format="mp3", sample_rate=None, bitrate=None) -> bytes: + """Encode a torch tensor to audio bytes. - Tries torchffmpeg -> torchcodec -> torchaudio as encoder backends. + Tries humecodec -> torchcodec -> torchaudio as encoder backends. + + >>> from wsds import WSDataset + >>> audio = WSDataset("librilight/source")[0].get_audio() + >>> samples = audio.read_segment(start=0, end=2.0, sample_rate=16000) + >>> mp3 = encode_audio(samples, format="mp3") + >>> mp3[:3] == b"ID3" or mp3[:2] in (b"\\xff\\xfb", b"\\xff\\xf3") + True + >>> ogg = encode_audio(samples, format="ogg") # doctest: +SKIP + >>> ogg[:4] == b"OggS" # doctest: +SKIP + True Args: samples: A torch.Tensor with a .sample_rate attribute. Shape: (channels, frames). + format: Output format, e.g. "mp3", "ogg" (Opus). Default: "mp3". + sample_rate: Target sample rate (defaults to samples.sample_rate). + bitrate: Bitrate in bps. Only used for formats that support it (e.g. Opus). Returns: - MP3-encoded bytes. + Encoded audio bytes. """ + if sample_rate is None: + sample_rate = int(samples.sample_rate) + out = io.BytesIO() try: - from torchffmpeg import MediaEncoder + from humecodec import MediaEncoder - sample_rate = int(samples.sample_rate) - # samples is (channels, frames), write_audio_chunk expects (frames, channels) waveform = samples.mT.float().contiguous() - enc = MediaEncoder(out, "mp3") - enc.add_audio_stream(sample_rate=sample_rate, num_channels=waveform.size(1), format="flt") + enc = MediaEncoder(out, format) + stream_kwargs = dict(sample_rate=sample_rate, num_channels=waveform.size(1), format="flt") + if format == "ogg": + from humecodec import CodecConfig + + stream_kwargs.update(encoder="libopus", encoder_format="flt") + if bitrate: + stream_kwargs["codec_config"] = CodecConfig(bit_rate=bitrate) + enc.add_audio_stream(**stream_kwargs) with enc.open(): enc.write_audio_chunk(0, waveform) except ImportError: try: from torchcodec.encoders import AudioEncoder - AudioEncoder(samples, sample_rate=int(samples.sample_rate)).to_file_like(out, "mp3") + AudioEncoder(samples, sample_rate=sample_rate).to_file_like(out, format) except ImportError: import torchaudio - torchaudio.save(out, samples, int(samples.sample_rate), format="mp3") + torchaudio.save(out, samples, sample_rate, format=format) return out.getvalue() @@ -247,5 +310,5 @@ def audio_to_html(samples) -> str: """ import base64 - mp3_data = base64.b64encode(encode_mp3(samples)).decode("ascii") + mp3_data = base64.b64encode(encode_audio(samples, format="mp3")).decode("ascii") return f'' diff --git a/wsds/ws_audio.py b/wsds/ws_audio.py index 1dd7351..2d59078 100644 --- a/wsds/ws_audio.py +++ b/wsds/ws_audio.py @@ -3,31 +3,31 @@ import typing from dataclasses import dataclass -from .audio_codec import audio_to_html, create_decoder, encode_mp3, to_filelike +from .audio_codec import audio_to_html, create_decoder, encode_audio +from .pupyarrow import pupyarrow -def load_segment(src, start, end, sample_rate=None): - """Efficiently loads an audio segment from `src` (see below) `tstart` to `tend` seconds while - optionally resampling it to `sample_rate`. - - `src` can be one of: - - a file-like object - - a byte string - - a PyArrow binary buffer in memory""" - return AudioReader(src).read_segment(start, end, sample_rate=sample_rate) - @dataclass() -class AudioReader: - """A lazy seeking-capable audio reader for random-access to recordings stored in wsds shards.""" +class WSAudioEpisode: + """A lazy seeking-capable audio reader for random-access to recordings stored in wsds shards. + + >>> from wsds import WSDataset + >>> ds = WSDataset("librilight/source") + >>> audio = ds[0].get_audio() + >>> audio.load().shape + torch.Size([1, 17884909]) + >>> audio.read_segment(start=2, end=5).shape + torch.Size([1, 48000]) + >>> audio.read_segment(start=2, end=5, sample_rate=8000).shape + torch.Size([1, 24000]) + """ src: typing.Any _decoder: typing.Any = None _sample_rate: int | None = None - skip_samples: int = 0 - def __repr__(self): - return f"AudioReader(src={type(self.src)}, sample_rate={self._sample_rate})" + return f"WSAudioEpisode(src={type(self.src)}, sample_rate={self._sample_rate})" def unwrap(self): """Return the raw audio bytes""" @@ -35,27 +35,20 @@ def unwrap(self): return self.src.as_buffer().to_pybytes() elif isinstance(self.src, (bytes, bytearray)): return self.src + elif isinstance(self.src, pupyarrow.LazyBuffer): + return self.src.read() else: - raise TypeError(f"Unsupported AudioReader src type: {type(self.src)}") + raise TypeError(f"Unsupported src type: {type(self.src)}") + + to_bytes = unwrap def get_decoder(self, sample_rate=None): """Lazily creates/caches decoder via audio_codec.create_decoder().""" - sample_rate_switch = False - if self._sample_rate is not None: - sample_rate_switch = self._sample_rate != sample_rate - - if self._decoder is None or sample_rate_switch: - decoder = create_decoder(to_filelike(self.src), sample_rate=sample_rate) - # mp3 has encoder delays that are not handled well when seeking - if decoder.metadata.codec == "mp3": - self.skip_samples = 1105 - - if sample_rate is None: - sample_rate = decoder.metadata.sample_rate - - self._decoder = decoder - self._sample_rate = sample_rate - + requested_sr = sample_rate or (self._decoder and self._decoder.metadata.sample_rate) + if self._decoder is None or requested_sr != self._sample_rate: + self.src.seek(0) + self._decoder = create_decoder(self.src, sample_rate=sample_rate) + self._sample_rate = sample_rate or self._decoder.metadata.sample_rate return self._decoder, self._sample_rate @property @@ -70,14 +63,9 @@ def sample_rate(self): def read_segment(self, start=0, end=None, sample_rate=None): decoder, sample_rate = self.get_decoder(sample_rate) - seek_adjustment = self.skip_samples / sample_rate if start > 0 else 0 - _samples = decoder.get_samples_played_in_range( - start + seek_adjustment, end + seek_adjustment if end is not None else None - ) - if hasattr(_samples, "data"): - samples = _samples.data - else: - samples = _samples + samples = decoder.get_samples_played_in_range(start, end) + if hasattr(samples, "data"): + samples = samples.data samples.sample_rate = sample_rate return samples @@ -91,56 +79,60 @@ def _repr_html_(self): def _display_(self): import marimo - return marimo.audio(encode_mp3(self.read_segment())) + return marimo.audio(encode_audio(self.read_segment())) @dataclass(frozen=True) -class WSAudio: - """A lazy reference to a single sample from a segmented audio file.""" +class WSAudioSegment: + """A lazy reference to a single sample from a segmented audio file. + """ - audio_reader: AudioReader + episode: WSAudioEpisode tstart: float tend: float + def __repr__(self) -> str: + return f"WSAudioSegment(episode={self.episode}, tstart={self.tstart!s}, tend={self.tend!s})" + @property def duration(self) -> float: """Duration of the audio segment in seconds.""" return self.tend - self.tstart - def with_context(self, before: float = 0, after: float = 0) -> "WSAudio": - """Return a new WSAudio with expanded timestamps to include surrounding context. + def with_context(self, before: float = 0, after: float = 0) -> "WSAudioSegment": + """Return a new WSAudioSegment with expanded timestamps to include surrounding context. Args: before: Seconds of context to add before the segment start (will not go below 0) after: Seconds of context to add after the segment end Returns: - A new WSAudio instance with adjusted timestamps + A new WSAudioSegment instance with adjusted timestamps """ - return WSAudio( - audio_reader=self.audio_reader, + return WSAudioSegment( + episode=self.episode, tstart=max(0, self.tstart - before), tend=self.tend + after, ) - def with_timestamps(self, tstart: float | None = None, tend: float | None = None) -> "WSAudio": - """Return a new WSAudio with modified timestamps. + def with_timestamps(self, tstart: float | None = None, tend: float | None = None) -> "WSAudioSegment": + """Return a new WSAudioSegment with modified timestamps. Args: tstart: New start time in seconds (None to keep current) tend: New end time in seconds (None to keep current) Returns: - A new WSAudio instance with the specified timestamps + A new WSAudioSegment instance with the specified timestamps """ - return WSAudio( - audio_reader=self.audio_reader, + return WSAudioSegment( + episode=self.episode, tstart=tstart if tstart is not None else self.tstart, tend=tend if tend is not None else self.tend, ) def load(self, sample_rate=None, pad_to_seconds=None): - samples = self.audio_reader.read_segment(self.tstart, self.tend, sample_rate) + samples = self.episode.read_segment(self.tstart, self.tend, sample_rate) sample_rate = samples.sample_rate if pad_to_seconds is not None: import torch @@ -152,7 +144,7 @@ def load(self, sample_rate=None, pad_to_seconds=None): @property def metadata(self): - return self.audio_reader.metadata + return self.episode.metadata def _repr_html_(self): return audio_to_html(self.load()) @@ -160,4 +152,4 @@ def _repr_html_(self): def _display_(self): import marimo - return marimo.audio(encode_mp3(self.load())) + return marimo.audio(encode_audio(self.load())) diff --git a/wsds/ws_dataset.py b/wsds/ws_dataset.py index c8fd72e..cf96590 100644 --- a/wsds/ws_dataset.py +++ b/wsds/ws_dataset.py @@ -39,8 +39,8 @@ class WSDataset: >>> sample = dataset["large/5304/the_tinted_venus_1408_librivox_64kb_mp3/tintedvenus_05_anstey_64kb_090"] >>> print(repr(sample["transcription_wslang_raw.txt"])) ' I will accompany you," she said.' - >>> sample['audio'] - WSAudio(audio_reader=AudioReader(src=, sample_rate=None), tstart=1040.2133, tend=1042.8413) + >>> sample['audio'].load().shape + torch.Size([1, 42049]) """ dataset_root: Path diff --git a/wsds/ws_decode.py b/wsds/ws_decode.py index 7d89979..4dc4aa3 100644 --- a/wsds/ws_decode.py +++ b/wsds/ws_decode.py @@ -5,7 +5,7 @@ import numpy as np import pyarrow as pa -from .ws_audio import AudioReader +from .ws_audio import WSAudioEpisode AUDIO_FILE_KEYS = frozenset( [ @@ -50,7 +50,7 @@ def decode_sample(column: str, data): import json return json.load(fd) elif ext in AUDIO_FILE_KEYS: - return AudioReader(fd) + return WSAudioEpisode(fd) else: return fd.read() @@ -90,7 +90,7 @@ def get_audio(sample, audio_columns=None): audio_columns: Optional list of column names to try. Defaults to AUDIO_FILE_KEYS. Returns: - The audio value (typically an AudioReader or WSAudio). + The audio value (typically a WSAudioEpisode or WSAudioSegment). Raises: KeyError: If no audio column is found in the sample. diff --git a/wsds/ws_shard.py b/wsds/ws_shard.py index 39639cc..a465b0d 100644 --- a/wsds/ws_shard.py +++ b/wsds/ws_shard.py @@ -7,7 +7,7 @@ import pyarrow as pa from .utils import WSShardMissingError -from .ws_audio import AudioReader, WSAudio +from .ws_audio import WSAudioEpisode, WSAudioSegment from .ws_decode import decode_sample from .ws_sample import WSSample @@ -125,7 +125,7 @@ class WSSourceAudioShard(WSShardInterface): # cache _source_file_name: str = None _source_sample: WSSample = None - _source_reader: AudioReader = None + _source_reader: WSAudioEpisode = None @classmethod def from_link(cls, link, dataset, shard_ref): @@ -149,7 +149,7 @@ def get_sample(self, _column, offset): self._source_file_name = file_name tstart, tend = self.get_timestamps(segment_offset) - return WSAudio(self._source_reader, tstart, tend) + return WSAudioSegment(self._source_reader, tstart, tend) class WSYoutubeVideoShard(WSSourceAudioShard):