diff --git a/README.md b/README.md index 1895568..a1febc5 100644 --- a/README.md +++ b/README.md @@ -66,7 +66,7 @@ shape: (5_271_939, 3) ```pycon >>> x['audio'] -WSAudio(audio_reader=AudioReader(src=, sample_rate=None), tstart=614.46246, tend=627.3976) +WSAudioSegment(episode=WSAudioEpisode(src=, sample_rate=None), tstart=614.46246, tend=627.3976) ``` diff --git a/docs/internal/RecordBatchFileWriter-fd-lifecycle.md b/docs/internal/RecordBatchFileWriter-fd-lifecycle.md new file mode 100644 index 0000000..2498e46 --- /dev/null +++ b/docs/internal/RecordBatchFileWriter-fd-lifecycle.md @@ -0,0 +1,152 @@ +# RecordBatchFileWriter.close() and File Descriptor Lifecycle + +This document traces what happens when `RecordBatchFileWriter.close()` is called +and when the underlying file descriptor is actually closed. + +## Python Layer + +`RecordBatchFileWriter` (in `python/pyarrow/ipc.py`) inherits from +`_RecordBatchFileWriter` (Cython, `python/pyarrow/ipc.pxi:1106`), which inherits +from `_RecordBatchStreamWriter`, which inherits from `_CRecordBatchWriter`. + +The `close()` method lives on `_CRecordBatchWriter` (`ipc.pxi:619`): + +```python +def close(self): + with nogil: + check_status(self.writer.get().Close()) +``` + +This calls straight into the C++ `RecordBatchWriter::Close()`. + +## C++ Layer + +### IpcFormatWriter::Close() (`cpp/src/arrow/ipc/writer.cc:1246`) + +`MakeFileWriter()` constructs an `IpcFormatWriter` wrapping a `PayloadFileWriter`. +`IpcFormatWriter::Close()` delegates to the payload writer: + +```cpp +Status Close() override { + RETURN_NOT_OK(CheckStarted()); + RETURN_NOT_OK(payload_writer_->Close()); + closed_ = true; + return Status::OK(); +} +``` + +### PayloadFileWriter::Close() (`cpp/src/arrow/ipc/writer.cc:1502`) + +This finalizes the IPC file format on the stream but **does not close the +underlying OutputStream**: + +```cpp +Status Close() override { + // Write 0 EOS message for compatibility with sequential readers + RETURN_NOT_OK(WriteEOS()); + + // Write file footer + RETURN_NOT_OK(UpdatePosition()); + int64_t initial_position = position_; + RETURN_NOT_OK( + WriteFileFooter(*schema_, dictionaries_, record_batches_, metadata_, sink_)); + + // Write footer length (4 bytes, little-endian) + RETURN_NOT_OK(UpdatePosition()); + int32_t footer_length = static_cast(position_ - initial_position); + if (footer_length <= 0) { + return Status::Invalid("Invalid file footer"); + } + footer_length = bit_util::ToLittleEndian(footer_length); + RETURN_NOT_OK(Write(&footer_length, sizeof(int32_t))); + + // Write magic bytes to end file + return Write(kArrowMagicBytes, strlen(kArrowMagicBytes)); +} +``` + +The `sink_` pointer comes from `StreamBookKeeper` (`writer.cc:1367`), which +stores both a raw pointer (`sink_`) and optionally an owning shared pointer +(`owned_sink_`). Neither `PayloadFileWriter::Close()` nor +`StreamBookKeeper` ever call `sink_->Close()`. + +The header `cpp/src/arrow/ipc/writer.h:136` is explicit about this contract: + +> "User is responsible for closing the actual OutputStream." + +## When Does the File Descriptor Actually Close? + +### Case 1: Sink created from a file path string + +When a path string is passed to `RecordBatchFileWriter(sink, schema)`, the +Cython `get_writer()` function (`python/pyarrow/io.pxi:2195`) creates a +temporary `OSFile` wrapping a C++ `FileOutputStream`: + +```python +cdef get_writer(object source, shared_ptr[COutputStream]* writer): + # ... + source = OSFile(source_path, mode='w') + # ... + nf = source + writer[0] = nf.get_output_stream() +``` + +The `OSFile` Python object is ephemeral -- it goes out of scope immediately. +However, the `shared_ptr` that was extracted from it is kept alive +inside `StreamBookKeeper::owned_sink_` for the lifetime of the writer. + +The fd is closed when: + +1. All Python references to the writer are dropped. +2. The `IpcFormatWriter` and its `PayloadFileWriter` are destroyed. +3. The `shared_ptr` ref count drops to zero. +4. `FileOutputStream::~FileOutputStream()` calls `internal::CloseFromDestructor(this)` + (`cpp/src/arrow/io/file.cc:357`). +5. That calls `OSFile::Close()` -> `FileDescriptor::Close()`, which closes the fd. + +**There can be a window between `writer.close()` and the actual fd close**, +depending on when garbage collection runs. + +### Case 2: Sink is a user-owned NativeFile / OSFile + +If you pass an already-opened `NativeFile` to the writer, the writer holds a +`shared_ptr` to the same underlying `COutputStream`. The fd remains open as long +as the Python `NativeFile` object is alive. You must close it yourself (or use it +as a context manager). + +### Case 3: C++ API with raw pointer overload + +`MakeFileWriter(io::OutputStream* sink, ...)` stores only the raw pointer (no +`owned_sink_`). The caller owns the stream entirely and must close it after +calling `RecordBatchWriter::Close()`. + +## Summary + +| Event | What happens | +|--------------------------------------|-----------------------------------------------------------------| +| `writer.close()` | Writes EOS, footer, footer length, magic bytes. **fd stays open.** | +| Writer object is garbage collected | `shared_ptr` refcount -> 0, `FileOutputStream` destructor closes fd. | +| User closes the sink explicitly | fd closed immediately. | + +## Recommended Pattern + +Use context managers for both the writer and any explicitly opened sink to ensure +deterministic cleanup: + +```python +import pyarrow as pa +from pyarrow import ipc + +# Option A: pass a path (sink lifecycle is tied to the writer) +with ipc.RecordBatchFileWriter(sink="output.arrow", schema=schema) as writer: + writer.write_batch(batch) +# At exit: writer.close() is called (footer written). +# fd closes when the writer object is destroyed (usually immediately). + +# Option B: explicit sink control +with pa.OSFile("output.arrow", mode="w") as sink: + with ipc.RecordBatchFileWriter(sink, schema=schema) as writer: + writer.write_batch(batch) + # writer.close() writes footer +# sink.close() closes the fd +``` diff --git a/requirements.txt b/requirements.txt index 4fe9e43..64cde6c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,5 +4,4 @@ numpy polars>=1.36.1 pyarrow>=20 torch -torchaudio # torchcodec – optional, it causes serious performance regressions diff --git a/scripts/test_audio_backends.sh b/scripts/test_audio_backends.sh new file mode 100755 index 0000000..48bf02f --- /dev/null +++ b/scripts/test_audio_backends.sh @@ -0,0 +1,14 @@ +#!/usr/bin/env bash +# Run the full test suite in isolated environments, one per decoder backend. +# +# Usage: +# ./scripts/test_audio_backends.sh +# +# Each environment installs only one decoder backend so we verify that the +# interfaces work correctly regardless of which backend is present. + +TEST_COMMAND=${1:-"python -m tests"} + +parallel --tag --lb \ + "uv run --isolated --with {} $TEST_COMMAND" \ + ::: humecodec "torchaudio<2.9" torchcodec diff --git a/tests.py b/tests.py index 39a9acb..41ffc57 100644 --- a/tests.py +++ b/tests.py @@ -2,7 +2,7 @@ import unittest import wsds -from wsds import ws_dataset, ws_shard, ws_sink +from wsds import ws_dataset, ws_shard, ws_sink, ws_audio, audio_codec def load_tests(loader, tests, ignore): @@ -10,6 +10,8 @@ def load_tests(loader, tests, ignore): tests.addTests(doctest.DocTestSuite(ws_dataset)) tests.addTests(doctest.DocTestSuite(ws_shard)) tests.addTests(doctest.DocTestSuite(ws_sink)) + tests.addTests(doctest.DocTestSuite(ws_audio)) + tests.addTests(doctest.DocTestSuite(audio_codec)) tests.addTests(doctest.DocFileSuite("README.md")) return tests diff --git a/wsds/audio_codec.py b/wsds/audio_codec.py index 2c64c24..d546e49 100644 --- a/wsds/audio_codec.py +++ b/wsds/audio_codec.py @@ -2,236 +2,299 @@ This module contains all audio encoding/decoding logic, separated from the data model layer in ws_audio.py. It provides: -- Decoder backends (TorchFFmpegAudioDecoder, CompatAudioDecoder) -- A factory for creating decoders with automatic backend selection -- MP3 encoding with multi-backend fallback +- AudioDecoder: unified decoder with automatic backend selection (humecodec or torchaudio) +- encode_audio(): multi-backend encoder (humecodec -> torchcodec -> torchaudio) - HTML audio rendering utility """ from __future__ import annotations import io +import traceback import typing import pyarrow as pa -def to_filelike(src: typing.Any) -> typing.BinaryIO: - """Coerces files, byte-strings and PyArrow binary buffers into file-like objects.""" - if hasattr(src, "read"): # an open file - return src - # if not an open file then we assume some kind of binary data in memory - if hasattr(src, "as_buffer"): # PyArrow binary data - return pa.BufferReader(src.as_buffer()) - return io.BytesIO(src) - - -class TorchFFmpegAudioDecoder: - def __init__(self, src, sample_rate): - from torchffmpeg import MediaDecoder - - if hasattr(src, "_optimal_read_size"): - buffer_size = src._optimal_read_size - else: - buffer_size = 128 * 1024 - self.src = src - self.reader = MediaDecoder(to_filelike(self.src), buffer_size=buffer_size) - self.metadata = self.reader.get_src_stream_info(self.reader.default_audio_stream) - - if sample_rate is None: - sample_rate = int(self.metadata.sample_rate) +class AudioDecoder: + """Unified audio decoder that works with humecodec or torchaudio backends.""" + def __init__(self, reader, metadata, sample_rate, codec_delay=0): + self.reader = reader + self.metadata = metadata self.sample_rate = sample_rate + self.debug = False + self.codec_delay = codec_delay + self.init_skip_samples = getattr(metadata, 'start_skip_samples', 0) or 0 + # Codecs where flush produces unreliable output (wrong skip_samples, + # wrong frame sizes). For these, always read from the start and trim. + codec_name = getattr(metadata, 'codec', '') or '' + self._seek_unreliable = codec_name in ('wmav2', 'wmapro', 'vorbis') + # Raw MPEG audio formats: timestamp seek does sequential scan, + # byte-offset seek with our own index is much faster. + self._use_byte_index = codec_name in ('mp3', 'mp2', 'mp1') + self._packet_index = None + + def _build_index(self): + """Build a sparse packet index for byte-offset seeking.""" + if self._packet_index is not None: + return + try: + idx = self.reader.build_packet_index( + self.reader.default_audio_stream, 128 * 1024) + if idx and len(idx) > 1: + self._packet_index = idx + except Exception: + self._packet_index = [] + + def _indexed_seek(self, target_time): + """Seek via byte offset using the packet index. Returns the index entry's PTS or None.""" + self._build_index() + if not self._packet_index: + return None + # Find last entry with pts <= target_time + best = self._packet_index[0] + for entry in self._packet_index: + if entry.pts_seconds <= target_time: + best = entry + else: + break + self.reader.seek_to_byte_offset(best.pos) + return best.pts_seconds + + def get_samples_played_in_range(self, tstart=0, tend=None, margin=.25): + import torch - self.reader.add_basic_audio_stream( - frames_per_chunk=int(32 * sample_rate), - sample_rate=sample_rate, - decoder_option={"threads": "4", "thread_type": "frame"}, - ) + chunk = True + while chunk is not None: + (chunk,) = self.reader.pop_chunks() - def get_samples_played_in_range(self, tstart=0, tend=None): - import torch + # For short seeks and unreliable codecs, read from the start. + # This avoids seek accuracy issues for tstart < 5s (tiny cost) and + # codec flush bugs for wmav2/wmapro/vorbis. + read_from_start = self._seek_unreliable or tstart < 5.0 + + # Only adjust for start_skip_samples when actually seeking — when + # reading from start, the decoder applies skip_samples automatically. + seek_adj = 0.0 + index_pts = None + if not read_from_start: + # For raw MPEG formats, use indexed byte seek (fast, avoids sequential scan). + # No seek_adj needed: the index PTS and decoded audio are both in + # the raw timeline (skip_samples is not applied after byte seek). + if self._use_byte_index: + index_pts = self._indexed_seek(tstart - margin) + else: + # Timestamp seek: the demuxer applies start_skip_samples at + # pts=0 but not after seeking, so adjust tstart to compensate. + seek_adj = self.init_skip_samples / self.metadata.sample_rate + tstart += seek_adj + if tend is not None: + tend += seek_adj + + if index_pts is None: + # Fall back to timestamp seek (or read from start) + seek_target = 0.0 if read_from_start else max(0, tstart - margin) + self.reader.seek(seek_target, "key") + + chunks = [] + more_data = True + while more_data: + if self.reader.fill_buffer() == 1: + more_data = False + (chunk,) = self.reader.pop_chunks() + chunks.append(chunk) + if tend is not None: + chunk_end_pts = chunk.pts + chunk.shape[0] / self.sample_rate + if index_pts is not None: + # PTS not updated by demuxer after byte seek — estimate from index + elapsed = sum(c.shape[0] for c in chunks) / self.sample_rate + chunk_end_pts = index_pts + elapsed + if chunk_end_pts > tend + margin: + break + + # Determine the reference PTS for trimming + if read_from_start: + chunk0_pts = 0.0 + elif index_pts is not None: + # Byte seek: demuxer PTS is stale, use our index entry + chunk0_pts = index_pts + else: + chunk0_pts = chunks[0].pts + prefix = round(tstart * self.sample_rate) - round(chunk0_pts * self.sample_rate) + + if self.debug: + import torch as _t + total_samples = sum(c.shape[0] for c in chunks) + print(f" [decode] codec={self.metadata.codec} sr={self.sample_rate} " + f"tstart_orig={tstart - seek_adj:.4f} tstart_adj={tstart:.4f} " + f"seek_adj={seek_adj:.6f} (init_skip={self.init_skip_samples} codec_delay={self.codec_delay}) " + f"chunk0.pts={chunks[0].pts:.6f} chunk0_pts_used={chunk0_pts:.6f} " + f"n_chunks={len(chunks)} total_samples={total_samples} prefix={prefix}", flush=True) - self.reader.seek(max(0, tstart - 1), "key") - - if tend is None: - chunks = [] - more_data = True - while more_data: - if self.reader.fill_buffer() == 1: - more_data = False - (chunk,) = self.reader.pop_chunks() - if chunk is not None: - chunks.append(chunk) - prefix = int((tstart - chunks[0].pts) * self.sample_rate) - if prefix < 0: - prefix = 0 - return torch.cat(chunks)[prefix:].mT - - self.reader.fill_buffer() - (chunk,) = self.reader.pop_chunks() - prefix = int((tstart - chunk.pts) * self.sample_rate) if prefix < 0: + if self.debug: + print(f" [trim] negative prefix {prefix}, clamping to 0", flush=True) prefix = 0 - if tend: - samples = chunk[prefix : prefix + int((tend - tstart) * self.sample_rate)].mT + samples = torch.cat(chunks) + if tend is not None: + return samples[prefix : prefix + round(tend * self.sample_rate) - round(tstart * self.sample_rate)].mT else: - samples = chunk[prefix:].mT - while chunk is not None: - (chunk,) = self.reader.pop_chunks() - return samples + return samples[prefix:].mT +def _create_reader_humecodec(src, buffer_size): + from humecodec import MediaDecoder -class CompatAudioDecoder: - def __init__(self, src, sample_rate): - import torchaudio + reader = MediaDecoder(src=src, buffer_size=buffer_size) + metadata = reader.get_src_stream_info(reader.default_audio_stream) + return reader, metadata - if not hasattr(torchaudio, "io"): - raise ImportError("You need either torchaudio<2.9 or torchcodec installed") - self.src = src - if hasattr(src, "_optimal_read_size"): - buffer_size = src._optimal_read_size - else: - buffer_size = 128 * 1024 - self.reader = torchaudio.io.StreamReader(src=to_filelike(self.src), buffer_size=buffer_size) - self.metadata = self.reader.get_src_stream_info(0) - if sample_rate is None: - sample_rate = self.metadata.sample_rate +def _create_reader_torchaudio(src, buffer_size): + from torchaudio.io import StreamReader - self.sample_rate = sample_rate + reader = StreamReader(src=src, buffer_size=buffer_size) + metadata = reader.get_src_stream_info(reader.default_audio_stream) + return reader, metadata - # fetch 32 seconds because we likely need 30s at maximum but the seeking may be imprecise (and we seek 1s early) - # FIXME: check if we can get away with some better settings here (-1, maybe 10s + concatenate the chunks in a loop) - self.reader.add_basic_audio_stream( - frames_per_chunk=int(32 * sample_rate), - sample_rate=sample_rate, - decoder_option={"threads": "4", "thread_type": "frame"}, - ) - - def get_samples_played_in_range(self, tstart=0, tend=None): - # rought seek - self.reader.seek(max(0, tstart - 1), "key") - - if tend is None: - import torch - - chunks = [] - more_data = True - while more_data: - if self.reader.fill_buffer() == 1: - more_data = False - (chunk,) = self.reader.pop_chunks() - chunks.append(chunk) - prefix = int((tstart - chunks[0].pts) * self.sample_rate) - if prefix < 0: - prefix = 0 - return torch.cat(chunks)[prefix:].mT - - self.reader.fill_buffer() - (chunk,) = self.reader.pop_chunks() - # tight crop (seems accurate down to 1 sample in my tests) - prefix = int((tstart - chunk.pts) * self.sample_rate) - if prefix < 0: - prefix = 0 - if tend: - samples = chunk[prefix : prefix + int((tend - tstart) * self.sample_rate)].mT - else: - samples = chunk[prefix:].mT - # clear out any remaining data - while chunk is not None: - (chunk,) = self.reader.pop_chunks() - return samples +def _create_decoder_torchcodec(src, sample_rate): + """Create a torchcodec-backed decoder that matches the AudioDecoder interface.""" + from types import SimpleNamespace -def create_decoder(src, sample_rate=None): - """Factory: tries torchffmpeg -> torchcodec -> torchaudio, returns a decoder instance. + from torchcodec.decoders import AudioDecoder as TorchcodecDecoder - Args: - src: A file-like object or bytes-like source for audio data. - sample_rate: Optional target sample rate for resampling. + # torchcodec accepts bytes but not BytesIO + decoder = TorchcodecDecoder(src, sample_rate=sample_rate) + metadata = decoder.metadata - Returns: - A decoder instance with .metadata, .sample_rate, and .get_samples_played_in_range() interface. - """ - try: - from torchffmpeg import MediaDecoder as _ # noqa: F401 + class TorchcodecAdapter: + def __init__(self): + self.metadata = metadata + self.sample_rate = sample_rate if sample_rate is not None else int(metadata.sample_rate) - AudioDecoder = TorchFFmpegAudioDecoder - except ImportError: - try: - from torchcodec.decoders import AudioDecoder - except ImportError: - AudioDecoder = CompatAudioDecoder + def get_samples_played_in_range(self, tstart=0, tend=None): + return decoder.get_samples_played_in_range(tstart, tend) + + return TorchcodecAdapter() - return AudioDecoder(src, sample_rate=sample_rate) +_STREAMING_BACKENDS = [ + (_create_reader_humecodec, "humecodec"), + (_create_reader_torchaudio, "torchaudio.io"), +] -def decode_segment(src, start=0, end=None, sample_rate=None): - """One-shot decode: creates decoder, reads segment, returns tensor with .sample_rate attr. +_chosen_backend = None - Handles MP3 skip_samples compensation automatically. + +def create_decoder(src, sample_rate=None): + """Factory: tries humecodec -> torchaudio -> torchcodec, returns a decoder instance. Args: - src: Audio source (file-like, bytes, or PyArrow buffer). - start: Start time in seconds. - end: End time in seconds (None for rest of file). - sample_rate: Optional target sample rate. + src: A file-like object for audio data. + sample_rate: Optional target sample rate for resampling. Returns: - A torch.Tensor with a .sample_rate attribute. + A decoder with .metadata, .sample_rate, and .get_samples_played_in_range(). """ - filelike = to_filelike(src) - decoder = create_decoder(filelike, sample_rate) - - skip_samples = 0 - if decoder.metadata.codec == "mp3": - skip_samples = 1105 + global _chosen_backend + + buffer_size = getattr(src, "_optimal_read_size", 128 * 1024) + + if _chosen_backend is not None: + if _chosen_backend == "torchcodec": + return _create_decoder_torchcodec(src, sample_rate) + reader, metadata = _chosen_backend(src, buffer_size) + else: + for factory, module in _STREAMING_BACKENDS: + try: + reader, metadata = factory(src, buffer_size) + _chosen_backend = factory + break + except ImportError: + continue + else: + # Fall back to torchcodec (different API, no streaming reader) + try: + decoder = _create_decoder_torchcodec(src, sample_rate) + _chosen_backend = "torchcodec" + return decoder + except ImportError: + raise ImportError("Neither humecodec, torchaudio, nor torchcodec is installed.") if sample_rate is None: - sample_rate = decoder.metadata.sample_rate + sample_rate = int(metadata.sample_rate) - seek_adjustment = skip_samples / sample_rate if start > 0 else 0 - samples = decoder.get_samples_played_in_range( - start + seek_adjustment, end + seek_adjustment if end is not None else None + reader.add_basic_audio_stream( + frames_per_chunk=int(1 * sample_rate), + sample_rate=sample_rate, + decoder_option={"threads": "4", "thread_type": "frame"}, ) - if hasattr(samples, "data"): - samples = samples.data - samples.sample_rate = sample_rate - return samples + + # Get codec_delay from the decoder (available after add_audio_stream opens the codec) + codec_delay = 0 + try: + out_info = reader.get_out_stream_info(0) + codec_delay = getattr(out_info, 'codec_delay', 0) or 0 + except Exception: + pass + + return AudioDecoder(reader, metadata, sample_rate, codec_delay=codec_delay) + -def encode_mp3(samples) -> bytes: - """Encode a torch tensor to MP3 bytes. +def encode_audio(samples, format="mp3", sample_rate=None, bitrate=None) -> bytes: + """Encode a torch tensor to audio bytes. - Tries torchffmpeg -> torchcodec -> torchaudio as encoder backends. + Tries humecodec -> torchcodec -> torchaudio as encoder backends. + + >>> from wsds import WSDataset + >>> audio = WSDataset("librilight/source")[0].get_audio() + >>> samples = audio.read_segment(start=0, end=2.0, sample_rate=16000) + >>> mp3 = encode_audio(samples, format="mp3") + >>> mp3[:3] == b"ID3" or mp3[:2] in (b"\\xff\\xfb", b"\\xff\\xf3") + True + >>> ogg = encode_audio(samples, format="ogg") # doctest: +SKIP + >>> ogg[:4] == b"OggS" # doctest: +SKIP + True Args: samples: A torch.Tensor with a .sample_rate attribute. Shape: (channels, frames). + format: Output format, e.g. "mp3", "ogg" (Opus). Default: "mp3". + sample_rate: Target sample rate (defaults to samples.sample_rate). + bitrate: Bitrate in bps. Only used for formats that support it (e.g. Opus). Returns: - MP3-encoded bytes. + Encoded audio bytes. """ + if sample_rate is None: + sample_rate = int(samples.sample_rate) + out = io.BytesIO() try: - from torchffmpeg import MediaEncoder + from humecodec import MediaEncoder - sample_rate = int(samples.sample_rate) - # samples is (channels, frames), write_audio_chunk expects (frames, channels) waveform = samples.mT.float().contiguous() - enc = MediaEncoder(out, "mp3") - enc.add_audio_stream(sample_rate=sample_rate, num_channels=waveform.size(1), format="flt") + enc = MediaEncoder(out, format) + stream_kwargs = dict(sample_rate=sample_rate, num_channels=waveform.size(1), format="flt") + if format == "ogg": + from humecodec import CodecConfig + + stream_kwargs.update(encoder="libopus", encoder_format="flt") + if bitrate: + stream_kwargs["codec_config"] = CodecConfig(bit_rate=bitrate) + enc.add_audio_stream(**stream_kwargs) with enc.open(): enc.write_audio_chunk(0, waveform) except ImportError: try: from torchcodec.encoders import AudioEncoder - AudioEncoder(samples, sample_rate=int(samples.sample_rate)).to_file_like(out, "mp3") + AudioEncoder(samples, sample_rate=sample_rate).to_file_like(out, format) except ImportError: import torchaudio - torchaudio.save(out, samples, int(samples.sample_rate), format="mp3") + torchaudio.save(out, samples, sample_rate, format=format) return out.getvalue() @@ -247,5 +310,5 @@ def audio_to_html(samples) -> str: """ import base64 - mp3_data = base64.b64encode(encode_mp3(samples)).decode("ascii") + mp3_data = base64.b64encode(encode_audio(samples, format="mp3")).decode("ascii") return f'' diff --git a/wsds/convplayer.py b/wsds/convplayer.py index 5a1ae7a..7ddc831 100644 --- a/wsds/convplayer.py +++ b/wsds/convplayer.py @@ -25,7 +25,7 @@ background: #98E; padding: 4px 7px; box-sizing: border-box; - width: 560px; + max-width: 560px; margin: 5px auto; border-radius: 7px; } @@ -47,7 +47,6 @@ } .middle-box { margin: auto; - width: 600px; display: flex; } .col-separator { @@ -370,11 +369,12 @@ def __str__(self): class ConvPlayer: - def __init__(self, path, snd, sr, rmdir=False, pixels_per_second=50): + def __init__(self, path, snd, sr, rmdir=False, pixels_per_second=50, mel_width=80): if isinstance(path, str): path = Path(path) self.path = path self.pixels_per_second = pixels_per_second + self.mel_width = mel_width self.left = ColumnList(self, "left") self.right = ColumnList(self, "right") @@ -397,8 +397,9 @@ def _add_audio(self, snd, sr): self.right.append_img("ticks", ticks[::-1], scaley=2, repeat_y=True) for i, snd in enumerate(torch.split(snd, 5 * 60 * sr)): mels = mel_img(snd, sr) - for c, i in zip([self.left, self.right], mels): - c.append_img("mel", i, scaley=100 / self.pixels_per_second) + mel_scalex = mels.shape[1] / self.mel_width + for c, m in zip([self.left, self.right], mels): + c.append_img("mel", m, scalex=mel_scalex, scaley=100 / self.pixels_per_second) def close(self, zip=False, show=False): self.html.write('
\n') diff --git a/wsds/pupyarrow/file_reader.py b/wsds/pupyarrow/file_reader.py index c9cf562..cc9eb9f 100644 --- a/wsds/pupyarrow/file_reader.py +++ b/wsds/pupyarrow/file_reader.py @@ -1,38 +1,89 @@ from __future__ import annotations +import asyncio +import contextvars import os import threading import time +from dataclasses import dataclass from pathlib import Path -from typing import BinaryIO -BLOCK_SIZE = 16384 # 16kB minimum read size +BLOCK_SIZE = 8192 # 8kB minimum sync read size +MIN_ASYNC_READ = 4096 # 4kB minimum async read size +VERBOSE = False + + +@dataclass +class _Region: + """A coalesced read region covering one or more buffer descriptors.""" + + offset: int + length: int + members: list[tuple[int, int, int]] # list of (abs_offset, start_in_region, end_in_region) + + +def _coalesce_regions(items: list[tuple[int, int]], gap_threshold: int = 64 * 1024) -> list[_Region]: + """Merge nearby reads into larger contiguous fetches. + + items: list of (absolute_offset, length) pairs. + Returns _Region objects with members referencing back to the original offsets. + """ + if not items: + return [] + + sorted_items = sorted(items, key=lambda x: x[0]) + + regions: list[_Region] = [] + cur_offset = sorted_items[0][0] + cur_end = cur_offset + sorted_items[0][1] + cur_members: list[tuple[int, int, int]] = [(sorted_items[0][0], 0, sorted_items[0][1])] + + for abs_offset, length in sorted_items[1:]: + item_end = abs_offset + length + if abs_offset <= cur_end + gap_threshold: + member_start = abs_offset - cur_offset + cur_members.append((abs_offset, member_start, member_start + length)) + cur_end = max(cur_end, item_end) + else: + regions.append(_Region(offset=cur_offset, length=cur_end - cur_offset, members=cur_members)) + cur_offset = abs_offset + cur_end = item_end + cur_members = [(abs_offset, 0, length)] + + regions.append(_Region(offset=cur_offset, length=cur_end - cur_offset, members=cur_members)) + return regions class FileReader: - """Base class for reading bytes from a file with two cache slots. + """Base class for reading bytes from a file. - Forward cache: caches reads from absolute offsets (for sequential access). - Tail cache: caches reads from the end of the file (for footer parsing). + Sync reads use a two-slot cache (forward + tail). + Async reads use a range cache with plan mode for coalesced IO. - Subclasses implement _raw_read and _raw_read_end only. + Subclasses implement _raw_read, _raw_read_end, and optionally _async_read_impl. IO stats (io_time, io_count, io_bytes, cache_hits) are always tracked. - Pass verbose=True to print per-request details to stderr. """ - def __init__(self, *, verbose: bool = False): + def __init__(self): + # Sync caches self._fwd_start: int = 0 self._fwd_data: bytes = b"" self._tail_data: bytes = b"" - self._verbose = verbose + self._verbose = VERBOSE + self._async_first = False # subclasses that track IO in _async_read_impl set this self.io_time: float = 0.0 self.io_count: int = 0 self.io_bytes: int = 0 self.cache_hits: int = 0 + # Async plan mode (per-task via contextvar) + self._planned: contextvars.ContextVar[bool] = contextvars.ContextVar("_planned", default=False) + self._pending: list[tuple[int, int, int, asyncio.Future]] = [] # (offset, actual, length, future) + self._cache: list[tuple[int, bytes]] = [] # (offset, data) ranges + def read(self, offset: int, length: int) -> bytes: - """Read length bytes at absolute offset, using the forward cache.""" + """Read length bytes at absolute offset, using forward cache.""" fwd_end = self._fwd_start + len(self._fwd_data) if self._fwd_data and offset >= self._fwd_start and offset + length <= fwd_end: self.cache_hits += 1 @@ -42,11 +93,14 @@ def read(self, offset: int, length: int) -> bytes: t0 = time.monotonic() self._fwd_data = self._raw_read(offset, actual_length) dt = time.monotonic() - t0 - self.io_time += dt - self.io_count += 1 - self.io_bytes += len(self._fwd_data) + if not self._async_first: + self.io_time += dt + self.io_count += 1 + self.io_bytes += len(self._fwd_data) if self._verbose: - print(f"[IO] read offset={offset} len={actual_length} got={len(self._fwd_data)} {dt * 1000:.1f}ms") + print( + f"[IO] read offset={offset} reqn={length} len={actual_length} got={len(self._fwd_data)} {dt * 1000:.1f}ms" + ) self._fwd_start = offset return self._fwd_data[:length] @@ -62,14 +116,14 @@ def read_end(self, offset: int, length: int) -> bytes: return self._tail_data[start : start + length] actual_n = max(needed, BLOCK_SIZE) t0 = time.monotonic() - print(self, self._raw_read_end) self._tail_data = self._raw_read_end(actual_n) dt = time.monotonic() - t0 - self.io_time += dt - self.io_count += 1 - self.io_bytes += len(self._tail_data) + if not self._async_first: + self.io_time += dt + self.io_count += 1 + self.io_bytes += len(self._tail_data) if self._verbose: - print(f"[IO] read_end n={actual_n} got={len(self._tail_data)} {dt * 1000:.1f}ms") + print(f"[IO] read_end reqn={length} n={actual_n} got={len(self._tail_data)} {dt * 1000:.1f}ms") start = len(self._tail_data) + offset return self._tail_data[start : start + length] @@ -81,73 +135,154 @@ def _raw_read_end(self, n: int) -> bytes: """Read the last n bytes of the file. May return fewer if file is smaller.""" raise NotImplementedError + # -- Async IO with range cache and plan mode -------------------------------- + + async def async_read(self, offset: int, length: int) -> bytes: + """Async read with range cache and optional plan mode. + + Every read fetches at least MIN_ASYNC_READ bytes and caches the result. + In plan mode, reads are deferred and coalesced on flush(). + """ + # Check range cache + for c_off, c_data in self._cache: + if offset >= c_off and offset + length <= c_off + len(c_data): + self.cache_hits += 1 + s = offset - c_off + return c_data[s : s + length] + + actual = max(length, MIN_ASYNC_READ) + + if not self._planned.get(): + # Eager mode: read directly + data = await self._async_read_impl(offset, actual) + self._cache.append((offset, data)) + return data[:length] + + # Plan mode: submit and await future + fut = asyncio.get_running_loop().create_future() + self._pending.append((offset, actual, length, fut)) + return await fut + + async def flush(self): + """Coalesce pending reads, execute, resolve futures.""" + if not self._pending: + return + items = [(off, actual) for off, actual, _, _ in self._pending] + regions = _coalesce_regions(items) + fetched = await asyncio.gather(*[self._async_read_impl(r.offset, r.length) for r in regions]) + + data_map: dict[int, bytes] = {} + for region, data in zip(regions, fetched): + self._cache.append((region.offset, data)) + for abs_offset, start, end in region.members: + data_map[abs_offset] = data[start:end] + + for offset, _actual, length, fut in self._pending: + fut.set_result(data_map[offset][:length]) + self._pending.clear() + + def clear_cache(self): + """Clear the async range cache.""" + self._cache.clear() + + @property + def has_pending(self) -> bool: + return len(self._pending) > 0 + + async def _async_read_impl(self, offset: int, length: int) -> bytes: + """Actual async IO. Default runs _raw_read in a thread executor. + + Subclasses with native async IO (S3, Modal) override this. + """ + return await asyncio.get_event_loop().run_in_executor(None, self._raw_read, offset, length) + def close(self): pass class LocalFileReader(FileReader): - """FileReader backed by a local file.""" + """FileReader backed by a local file via os.pread.""" - def __init__(self, path_or_file: str | Path | BinaryIO, *, verbose: bool = False): - super().__init__(verbose=verbose) - if isinstance(path_or_file, (str, Path)): - self._file: BinaryIO = open(path_or_file, "rb") - self._owns_file = True - else: - self._file = path_or_file - self._owns_file = False + def __init__(self, path: str | Path): + super().__init__() + self._fd = os.open(str(path), os.O_RDONLY) def _raw_read(self, offset: int, length: int) -> bytes: - self._file.seek(offset, os.SEEK_SET) - return self._file.read(length) + return os.pread(self._fd, length, offset) def _raw_read_end(self, n: int) -> bytes: - self._file.seek(-n, os.SEEK_END) - return self._file.read(n) + size = os.fstat(self._fd).st_size + return os.pread(self._fd, n, max(size - n, 0)) def close(self): - if self._owns_file: - self._file.close() + os.close(self._fd) class S3FileReader(FileReader): - """FileReader backed by S3 range requests via boto3.""" + """FileReader backed by S3 range requests via aiobotocore. + + Async-first: _async_read_impl is the canonical implementation. + Sync _raw_read calls into _async_read_impl via the background event loop. + + Takes a pre-created aiobotocore S3 client (not a session) so that + SSL context and connection pool setup is amortized across readers. + """ - def __init__(self, s3_client, bucket: str, key: str, *, verbose: bool = False): - super().__init__(verbose=verbose) - self._client = s3_client + def __init__(self, client, bucket: str, key: str): + super().__init__() + self._async_first = True + self._client = client # aiobotocore S3 client (already entered) self._bucket = bucket self._key = key - def _raw_read(self, offset: int, length: int) -> bytes: + async def _async_read_impl(self, offset: int, length: int) -> bytes: range_header = f"bytes={offset}-{offset + length - 1}" - resp = self._client.get_object(Bucket=self._bucket, Key=self._key, Range=range_header) - return resp["Body"].read() + t0 = time.monotonic() + resp = await self._client.get_object(Bucket=self._bucket, Key=self._key, Range=range_header) + async with resp["Body"] as stream: + data = await stream.read() + dt = time.monotonic() - t0 + self.io_time += dt + self.io_count += 1 + self.io_bytes += len(data) + if self._verbose: + print(f"[S3] async_read offset={offset} req={length} got={len(data)} {dt * 1000:.1f}ms") + return data + + async def _async_read_end(self, n: int) -> bytes: + t0 = time.monotonic() + resp = await self._client.get_object(Bucket=self._bucket, Key=self._key, Range=f"bytes=-{n}") + async with resp["Body"] as stream: + data = await stream.read() + dt = time.monotonic() - t0 + self.io_time += dt + self.io_count += 1 + self.io_bytes += len(data) + if self._verbose: + print(f"[S3] async_read_end req={n} got={len(data)} {dt * 1000:.1f}ms") + return data + + def _raw_read(self, offset: int, length: int) -> bytes: + return _get_io_loop().run(self._async_read_impl(offset, length)) def _raw_read_end(self, n: int) -> bytes: - range_header = f"bytes=-{n}" - resp = self._client.get_object(Bucket=self._bucket, Key=self._key, Range=range_header) - return resp["Body"].read() + return _get_io_loop().run(self._async_read_end(n)) -class _ModalEventLoop: +class _IOLoop: """A persistent event loop running on a dedicated daemon thread. - All Modal gRPC work is dispatched here so the client's channel stays - bound to a single loop, and callers on the main thread (or Jupyter, - or another loop) are never blocked by "loop already running" errors.""" + Used for async-first readers (S3, Modal). Callers on the main thread + (or Jupyter, or another loop) are never blocked by "loop already running" + errors.""" def __init__(self): - import asyncio - self._loop = asyncio.new_event_loop() self._thread = threading.Thread(target=self._loop.run_forever, daemon=True) self._thread.start() def run(self, coro): """Submit *coro* to the background loop and block until it completes.""" - import asyncio - future = asyncio.run_coroutine_threadsafe(coro, self._loop) return future.result() @@ -157,43 +292,45 @@ def close(self): # Module-level singleton — created on first use. -_modal_loop: _ModalEventLoop | None = None -_modal_loop_lock = threading.Lock() +_io_loop: _IOLoop | None = None +_io_loop_lock = threading.Lock() -def _get_modal_loop() -> _ModalEventLoop: - global _modal_loop - if _modal_loop is None: - with _modal_loop_lock: - if _modal_loop is None: - _modal_loop = _ModalEventLoop() - return _modal_loop +def _get_io_loop() -> _IOLoop: + global _io_loop + if _io_loop is None: + with _io_loop_lock: + if _io_loop is None: + _io_loop = _IOLoop() + return _io_loop class ModalFileReader(FileReader): - """FileReader backed by Modal Volume range requests via gRPC. + """FileReader backed by Modal Volume range requests via gRPC + aiohttp. - Uses the undocumented ``start``/``len`` fields on ``VolumeGetFile2Request`` - to fetch only the needed byte ranges. Presigned block URLs returned by the - gRPC call are downloaded with ``urllib``. + Async-first: _async_read_impl is the canonical implementation using native + async gRPC for metadata and aiohttp for presigned URL downloads. + Sync _raw_read calls into _async_read_impl via the background event loop. - All async gRPC work runs on a shared daemon-thread event loop (see - ``_ModalEventLoop``) so it works regardless of whether the caller already + All async work runs on a shared daemon-thread event loop (see + ``_IOLoop``) so it works regardless of whether the caller already has a running loop (Jupyter, Modal synchronizer, etc.).""" - def __init__(self, vol, path: str, *, verbose: bool = False): - super().__init__(verbose=verbose) + def __init__(self, vol, path: str): + super().__init__() + self._async_first = True self._vol = vol self._path = path self._size: int | None = None - self._loop = _get_modal_loop() + self._loop = _get_io_loop() + self._aiohttp_session = None @classmethod - def from_name(cls, volume_name: str, path: str, *, verbose: bool = False) -> "ModalFileReader": + def from_name(cls, volume_name: str, path: str) -> "ModalFileReader": """Create a reader for *path* inside the named Modal Volume.""" - loop = _get_modal_loop() + loop = _get_io_loop() vol = loop.run(cls._hydrate(volume_name)) - reader = cls(vol, path, verbose=verbose) + reader = cls(vol, path) return reader @staticmethod @@ -215,17 +352,26 @@ async def _get_range(self, start: int, length: int): ) return await self._vol._client.stub.VolumeGetFile2(req) - def _fetch_urls(self, resp) -> bytes: - """Download presigned block URLs and concatenate the bytes.""" - import requests + async def _get_aiohttp_session(self): + if self._aiohttp_session is None: + import aiohttp - chunks = [] - for url in resp.get_urls: - r = requests.get(url) - r.raise_for_status() - chunks.append(r.content) + self._aiohttp_session = aiohttp.ClientSession() + return self._aiohttp_session + + async def _async_fetch_urls(self, resp) -> bytes: + """Download presigned block URLs concurrently via aiohttp.""" + session = await self._get_aiohttp_session() + tasks = [self._fetch_one(session, url) for url in resp.get_urls] + chunks = await asyncio.gather(*tasks) return b"".join(chunks) + @staticmethod + async def _fetch_one(session, url: str) -> bytes: + async with session.get(url) as r: + r.raise_for_status() + return await r.read() + def _ensure_size(self) -> int: """Fetch the total file size (cached after first call).""" if self._size is None: @@ -233,13 +379,26 @@ def _ensure_size(self) -> int: self._size = resp.size return self._size - def _raw_read(self, offset: int, length: int) -> bytes: - resp = self._loop.run(self._get_range(offset, length)) + async def _async_read_impl(self, offset: int, length: int) -> bytes: + """Native async: gRPC for range metadata, aiohttp for URL downloads.""" + resp = await self._get_range(offset, length) if self._size is None: self._size = resp.size - return self._fetch_urls(resp) + return await self._async_fetch_urls(resp) + + def _raw_read(self, offset: int, length: int) -> bytes: + return self._loop.run(self._async_read_impl(offset, length)) def _raw_read_end(self, n: int) -> bytes: size = self._ensure_size() offset = max(size - n, 0) return self._raw_read(offset, size - offset) + + async def _async_close(self): + if self._aiohttp_session is not None: + await self._aiohttp_session.close() + self._aiohttp_session = None + + def close(self): + if self._aiohttp_session is not None: + self._loop.run(self._async_close()) diff --git a/wsds/pupyarrow/pupyarrow.py b/wsds/pupyarrow/pupyarrow.py index df1c3b2..f04a405 100644 --- a/wsds/pupyarrow/pupyarrow.py +++ b/wsds/pupyarrow/pupyarrow.py @@ -12,11 +12,12 @@ from __future__ import annotations +import asyncio import struct from dataclasses import dataclass from enum import IntEnum from pathlib import Path -from typing import Any, BinaryIO, Iterator +from typing import Any, Iterator import numpy as np @@ -368,6 +369,83 @@ def _decompress_buffer(raw_data: bytes, compression: str | None) -> bytes: raise ValueError(f"Unknown compression codec: {compression}") +class BlockCache: + """Sorted interval cache for byte ranges. + + Stores non-overlapping ``(start, end, data)`` intervals, merging on insert. + Designed for caching S3 range-read results so that ffmpeg's AVIO reads + hit local memory instead of issuing new HTTP requests. + + >>> cache = BlockCache() + >>> cache.put(100, b'hello') + >>> cache.put(105, b'world') + >>> cache.get(100, 10) + b'helloworld' + >>> cache.get(103, 4) + b'lowo' + >>> cache.get(100, 11) is None # extends past cached range + True + + Adjacent/overlapping ranges are merged: + + >>> cache2 = BlockCache() + >>> cache2.put(0, b'AAAA') + >>> cache2.put(10, b'BBBB') + >>> cache2.put(4, b'CCCCCC') + >>> cache2.get(0, 14) + b'AAAACCCCCCBBBB' + """ + + __slots__ = ("_ranges",) + + def __init__(self): + self._ranges: list[tuple[int, int, bytes]] = [] + + def get(self, offset: int, length: int) -> bytes | None: + """Return data if ``[offset, offset+length)`` is fully cached, else None.""" + end = offset + length + for start, rend, data in self._ranges: + if start <= offset and rend >= end: + return data[offset - start : offset - start + length] + return None + + def put(self, offset: int, data: bytes) -> None: + """Insert a range, merging with any overlapping/adjacent intervals.""" + if not data: + return + new_start = offset + new_end = offset + len(data) + new_data = bytearray(data) + + merged = [] + for start, end, rdata in self._ranges: + if end < new_start or start > new_end: + # No overlap — keep as-is + merged.append((start, end, rdata)) + else: + # Overlap or adjacent — merge into new range + if start < new_start: + prefix = rdata[: new_start - start] + new_data = bytearray(prefix) + new_data + new_start = start + if end > new_end: + suffix = rdata[new_end - start :] + new_data = new_data + bytearray(suffix) + new_end = end + + merged.append((new_start, new_end, bytes(new_data))) + merged.sort(key=lambda r: r[0]) + self._ranges = merged + + @property + def total_bytes(self) -> int: + return sum(end - start for start, end, _ in self._ranges) + + def __repr__(self) -> str: + parts = [f"[{s}:{e}]" for s, e, _ in self._ranges] + return f"BlockCache({', '.join(parts)}, total={self.total_bytes})" + + class LazyBuffer: """ A lazy buffer that reads data on demand and implements the file-like interface. @@ -383,9 +461,14 @@ class LazyBuffer: same reader with adjusted offset/length. For uncompressed buffers the slice goes directly to the reader; for compressed buffers the parent data is decompressed once and the slice is pre-populated. + + For audio seeking, call ``enable_cache()`` to activate a block cache with + readahead. Then ``prepopulate(ranges)`` pre-fetches byte ranges in + parallel. ffmpeg/humecodec reads hit the cache and only fetch from S3 on + miss. """ - # __slots__ = ("_reader", "_offset", "_length", "_data", "_compression", "_pos") + READAHEAD = 256 * 1024 # readahead on cache miss def __init__(self, reader: FileReader, offset: int, length: int, compression: str | None = None): self._reader = reader @@ -394,6 +477,28 @@ def __init__(self, reader: FileReader, offset: int, length: int, compression: st self._data: bytes | None = None self._compression = compression self._pos = 0 + self._cache: BlockCache | None = None + + def enable_cache(self, readahead: int = 256 * 1024) -> "LazyBuffer": + """Activate block cache with readahead for seeking workloads.""" + self._cache = BlockCache() + self.READAHEAD = readahead + return self + + def prepopulate(self, ranges: list[tuple[int, int]]) -> None: + """Pre-fetch byte ranges (relative to buffer start) into the cache. + + Each range is ``(offset, length)``. Ranges are fetched via the + underlying reader (which may coalesce nearby reads). + """ + if self._cache is None: + self.enable_cache() + for rel_offset, length in ranges: + length = min(length, self._length - rel_offset) + if length <= 0: + continue + data = self._reader.read(self._offset + rel_offset, length) + self._cache.put(rel_offset, data) @property def offset(self) -> int: @@ -417,7 +522,6 @@ def read(self, size: int = -1) -> bytes: Advances the position by the number of bytes returned. """ remaining = self._length - self._pos - print("read:", self._pos, size, remaining) if size < 0: size = remaining else: @@ -452,14 +556,26 @@ def read_range(self, start: int, end: int) -> bytes: """Read a byte range from the (decompressed) buffer. If the buffer is already cached, slices it. If uncompressed, - reads directly from the range without reading the whole buffer. - If compressed, falls back to a full read and slices. + reads directly from the range (via block cache if enabled, or + the reader directly). If compressed, falls back to a full read. """ if self._data is not None: return self._data[start:end] - if self._compression is None: - return self._reader.read(self._offset + start, end - start) - return self._read_all()[start:end] + if self._compression is not None: + return self._read_all()[start:end] + # Uncompressed path — use block cache if enabled + length = end - start + if self._cache is not None: + cached = self._cache.get(start, length) + if cached is not None: + return cached + # Cache miss — fetch with readahead + fetch_len = max(length, self.READAHEAD) + fetch_len = min(fetch_len, self._length - start) + data = self._reader.read(self._offset + start, fetch_len) + self._cache.put(start, data) + return data[:length] + return self._reader.read(self._offset + start, length) def slice(self, start: int, end: int) -> "LazyBuffer": """Create a sub-buffer over the byte range [start, end). @@ -481,12 +597,51 @@ def as_numpy(self, dtype: np.dtype) -> np.ndarray: data = self._read_all() return np.frombuffer(data, dtype=dtype) + # -- Async API (IO plan aware) ------------------------------------------ + + async def async_read_all(self) -> bytes: + """Async read via reader (uses cache/plan mode if active).""" + if self._data is not None: + return self._data + raw = await self._reader.async_read(self._offset, self._length) + self._data = _decompress_buffer(raw, self._compression) + return self._data + + async def async_as_numpy(self, dtype: np.dtype) -> np.ndarray: + """Async version of as_numpy.""" + data = await self.async_read_all() + return np.frombuffer(data, dtype=dtype) + + async def async_prepopulate(self, ranges: list[tuple[int, int]]) -> None: + """Async pre-fetch of byte ranges into the cache.""" + if self._cache is None: + self.enable_cache() + coros = [] + for rel_offset, length in ranges: + length = min(length, self._length - rel_offset) + if length <= 0: + continue + coros.append(self._async_fetch_range(rel_offset, length)) + if coros: + await asyncio.gather(*coros) + + async def _async_fetch_range(self, rel_offset: int, length: int) -> None: + data = await self._reader.async_read(self._offset + rel_offset, length) + self._cache.put(rel_offset, data) + + def seekable(self) -> bool: + return True + + def readable(self) -> bool: + return True + def __len__(self) -> int: return self._length def __repr__(self) -> str: + cached = f", cache={self._cache}" if self._cache else "" loaded = "loaded" if self._data is not None else "not loaded" - return f"LazyBuffer(offset={self._offset}, length={self._length}, {loaded})" + return f"LazyBuffer(offset={self._offset}, length={self._length}, {loaded}{cached})" class LazyArray: @@ -539,6 +694,32 @@ def validity_mask(self) -> np.ndarray | None: return self._validity + async def async_resolve(self) -> None: + """Prefetch all buffers for this array (coalesced if IO plan active).""" + futs = [buf.async_read_all() for buf in self._buffers if buf._length > 0] + if futs: + await asyncio.gather(*futs) + + def to_numpy(self) -> np.ndarray: + raise NotImplementedError(f"{type(self).__name__} does not support to_numpy()") + + def to_masked_array(self) -> np.ma.MaskedArray: + """Return values as a masked array with nulls masked.""" + values = self.to_numpy() + mask = self.validity_mask() + if mask is None: + return np.ma.array(values, mask=False) + return np.ma.array(values, mask=~mask) + + def __getitem__(self, idx: int | slice) -> Any: + return self.to_numpy()[idx] + + async def async_to_numpy(self) -> np.ndarray: + await self.async_resolve() + return self.to_numpy() + + async_to_py = async_to_numpy + class LazyIntArray(LazyArray): """Lazy integer array with numpy access.""" @@ -571,17 +752,6 @@ def to_numpy(self) -> np.ndarray: self._values = self._buffers[1].as_numpy(self._dtype)[: self.length] return self._values - def to_masked_array(self) -> np.ma.MaskedArray: - """Return values as a masked array with nulls masked.""" - values = self.to_numpy() - mask = self.validity_mask() - if mask is None: - return np.ma.array(values, mask=False) - return np.ma.array(values, mask=~mask) - - def __getitem__(self, idx: int | slice) -> Any: - return self.to_numpy()[idx] - def __repr__(self) -> str: return f"LazyIntArray(dtype={self._dtype.__name__}, length={self.length}, nulls={self.null_count})" @@ -607,17 +777,6 @@ def to_numpy(self) -> np.ndarray: self._values = self._buffers[1].as_numpy(self._dtype)[: self.length] return self._values - def to_masked_array(self) -> np.ma.MaskedArray: - """Return values as a masked array with nulls masked.""" - values = self.to_numpy() - mask = self.validity_mask() - if mask is None: - return np.ma.array(values, mask=False) - return np.ma.array(values, mask=~mask) - - def __getitem__(self, idx: int | slice) -> Any: - return self.to_numpy()[idx] - def __repr__(self) -> str: return f"LazyFloatArray(dtype={self._dtype.__name__}, length={self.length}, nulls={self.null_count})" @@ -641,120 +800,15 @@ def to_numpy(self) -> np.ndarray: self._values = np.unpackbits(packed, bitorder="little")[: self.length].astype(bool) return self._values - def to_masked_array(self) -> np.ma.MaskedArray: - """Return values as a masked array with nulls masked.""" - values = self.to_numpy() - mask = self.validity_mask() - if mask is None: - return np.ma.array(values, mask=False) - return np.ma.array(values, mask=~mask) - - def __getitem__(self, idx: int | slice) -> Any: - return self.to_numpy()[idx] - def __repr__(self) -> str: return f"LazyBoolArray(length={self.length}, nulls={self.null_count})" -class LazyStringArray(LazyArray): - """ - Lazy string (Utf8/LargeUtf8) array. - - Offsets are loaded eagerly for efficient slicing. - String data is loaded lazily. - """ - - def __init__( - self, - field: Field, - node: FieldNode, - buffers: list[LazyBuffer], - large: bool = False, - ): - super().__init__(field, node, buffers) - self._large = large - self._offsets: np.ndarray | None = None - self._data_buffer = buffers[2] - self._data: bytes | None = None - # Eagerly load offsets (metadata) - self._load_offsets() - - def _load_offsets(self) -> None: - """Eagerly load offset array.""" - offset_dtype = np.int64 if self._large else np.int32 - self._offsets = self._buffers[1].as_numpy(offset_dtype)[: self.length + 1] - - @property - def offsets(self) -> np.ndarray: - """Return the offset array (eagerly loaded).""" - return self._offsets # type: ignore - - def _ensure_data(self) -> bytes: - """Lazily load string data buffer.""" - if self._data is None: - self._data = self._data_buffer._read_all() - return self._data - - def __getitem__(self, idx: int | slice) -> str | None | list[str | None]: - """Get string(s) by index or slice.""" - if isinstance(idx, slice): - indices = range(*idx.indices(self.length)) - return [self._get_single(i) for i in indices] - - if idx < 0: - idx += self.length - if idx < 0 or idx >= self.length: - raise IndexError(f"Index {idx} out of range for array of length {self.length}") - - return self._get_single(idx) - - def _get_single(self, idx: int) -> str | None: - """Get a single string by index.""" - # Check validity - mask = self.validity_mask() - if mask is not None and not mask[idx]: - return None - - start = int(self._offsets[idx]) - end = int(self._offsets[idx + 1]) - return self._data_buffer.read_range(start, end).decode("utf-8") - - def to_list(self) -> list[str | None]: - """Convert to a Python list of strings.""" - data = self._ensure_data() - mask = self.validity_mask() - result: list[str | None] = [] - - for i in range(self.length): - if mask is not None and not mask[i]: - result.append(None) - else: - start = int(self._offsets[i]) - end = int(self._offsets[i + 1]) - result.append(data[start:end].decode("utf-8")) - - return result - - def to_numpy(self) -> np.ndarray: - """Return as numpy object array of strings.""" - return np.array(self.to_list(), dtype=object) - - def byte_sizes(self) -> np.ndarray: - """Return array of byte sizes for each string (without loading data).""" - return np.diff(self._offsets) - - def __repr__(self) -> str: - type_name = "LargeUtf8" if self._large else "Utf8" - data_loaded = "loaded" if self._data is not None else "not loaded" - return f"LazyStringArray({type_name}, length={self.length}, nulls={self.null_count}, data={data_loaded})" - - class LazyBinaryArray(LazyArray): """ Lazy binary (Binary/LargeBinary) array. - Offsets are loaded eagerly for efficient slicing. - Binary data is loaded lazily. + Offsets and binary data are both loaded lazily on first access. """ def __init__( @@ -769,18 +823,18 @@ def __init__( self._offsets: np.ndarray | None = None self._data_buffer = buffers[2] self._data: bytes | None = None - # Eagerly load offsets (metadata) - self._load_offsets() - def _load_offsets(self) -> None: - """Eagerly load offset array.""" - offset_dtype = np.int64 if self._large else np.int32 - self._offsets = self._buffers[1].as_numpy(offset_dtype)[: self.length + 1] + def _ensure_offsets(self) -> np.ndarray: + """Load offset array on first access.""" + if self._offsets is None: + offset_dtype = np.int64 if self._large else np.int32 + self._offsets = self._buffers[1].as_numpy(offset_dtype)[: self.length + 1] + return self._offsets @property def offsets(self) -> np.ndarray: - """Return the offset array (eagerly loaded).""" - return self._offsets # type: ignore + """Return the offset array (loaded on first access).""" + return self._ensure_offsets() def _ensure_data(self) -> bytes: """Lazily load binary data buffer.""" @@ -788,8 +842,8 @@ def _ensure_data(self) -> bytes: self._data = self._data_buffer._read_all() return self._data - def __getitem__(self, idx: int | slice) -> LazyBuffer | None | list[LazyBuffer | None]: - """Get binary data as a file-like LazyBuffer by index or slice.""" + def __getitem__(self, idx: int | slice): + """Get element(s) by index or slice.""" if isinstance(idx, slice): indices = range(*idx.indices(self.length)) return [self._get_single(i) for i in indices] @@ -807,8 +861,9 @@ def _get_single(self, idx: int) -> LazyBuffer | None: if mask is not None and not mask[idx]: return None - start = int(self._offsets[idx]) - end = int(self._offsets[idx + 1]) + offsets = self._ensure_offsets() + start = int(offsets[idx]) + end = int(offsets[idx + 1]) return self._data_buffer.slice(start, end) def read_range(self, idx: int, start: int, end: int) -> bytes: @@ -822,8 +877,9 @@ def read_range(self, idx: int, start: int, end: int) -> bytes: if idx < 0 or idx >= self.length: raise IndexError(f"Index {idx} out of range") - elem_start = int(self._offsets[idx]) - elem_end = int(self._offsets[idx + 1]) + offsets = self._ensure_offsets() + elem_start = int(offsets[idx]) + elem_end = int(offsets[idx + 1]) # Clamp range to element bounds read_start = elem_start + max(0, start) @@ -836,6 +892,7 @@ def read_range(self, idx: int, start: int, end: int) -> bytes: def to_list(self) -> list[LazyBuffer | None]: """Convert to a Python list of file-like LazyBuffers.""" + offsets = self._ensure_offsets() mask = self.validity_mask() result: list[LazyBuffer | None] = [] @@ -843,15 +900,27 @@ def to_list(self) -> list[LazyBuffer | None]: if mask is not None and not mask[i]: result.append(None) else: - start = int(self._offsets[i]) - end = int(self._offsets[i + 1]) + start = int(offsets[i]) + end = int(offsets[i + 1]) result.append(self._data_buffer.slice(start, end)) return result def byte_sizes(self) -> np.ndarray: """Return array of byte sizes for each element (without loading data).""" - return np.diff(self._offsets) + return np.diff(self._ensure_offsets()) + + async def async_to_bytes_list(self) -> list[bytes | None]: + await self.async_resolve() + data = self._ensure_data() + offsets = self.offsets + mask = self.validity_mask() + return [ + None if (mask is not None and not mask[i]) else data[int(offsets[i]) : int(offsets[i + 1])] + for i in range(self.length) + ] + + async_to_py = async_to_bytes_list def __repr__(self) -> str: type_name = "LargeBinary" if self._large else "Binary" @@ -859,6 +928,53 @@ def __repr__(self) -> str: return f"LazyBinaryArray({type_name}, length={self.length}, nulls={self.null_count}, data={data_loaded})" +class LazyStringArray(LazyBinaryArray): + """ + Lazy string (Utf8/LargeUtf8) array. + + Subclass of LazyBinaryArray that decodes values as UTF-8 strings. + """ + + def _get_single(self, idx: int) -> str | None: + """Get a single string by index, decoding the LazyBuffer from super().""" + buf = super()._get_single(idx) + if buf is None: + return None + return buf._read_all().decode("utf-8") + + def to_list(self) -> list[str | None]: + """Convert to a Python list of strings.""" + offsets = self._ensure_offsets() + data = self._ensure_data() + mask = self.validity_mask() + result: list[str | None] = [] + + for i in range(self.length): + if mask is not None and not mask[i]: + result.append(None) + else: + start = int(offsets[i]) + end = int(offsets[i + 1]) + result.append(data[start:end].decode("utf-8")) + + return result + + def to_numpy(self) -> np.ndarray: + """Return as numpy object array of strings.""" + return np.array(self.to_list(), dtype=object) + + async def async_to_list(self) -> list[str | None]: + await self.async_resolve() + return self.to_list() + + async_to_py = async_to_list + + def __repr__(self) -> str: + type_name = "LargeUtf8" if self._large else "Utf8" + data_loaded = "loaded" if self._data is not None else "not loaded" + return f"LazyStringArray({type_name}, length={self.length}, nulls={self.null_count}, data={data_loaded})" + + class LazyFixedSizeBinaryArray(LazyArray): """Lazy fixed-size binary array with numpy access.""" @@ -878,9 +994,6 @@ def to_numpy(self) -> np.ndarray: self._values = np.frombuffer(data, dtype=np.uint8).reshape(-1, self._byte_width)[: self.length] return self._values - def __getitem__(self, idx: int | slice) -> np.ndarray: - return self.to_numpy()[idx] - def __repr__(self) -> str: return f"LazyFixedSizeBinaryArray(byte_width={self._byte_width}, length={self.length}, nulls={self.null_count})" @@ -936,6 +1049,8 @@ def _compute_buffer_indices(self) -> list[tuple[int, int]]: @staticmethod def _get_num_buffers_for_type(type_id: ArrowType) -> int: """Return the number of buffers used by a type.""" + if type_id == ArrowType.Null: + return 0 if type_id in (ArrowType.Utf8, ArrowType.Binary): return 3 if type_id in (ArrowType.LargeUtf8, ArrowType.LargeBinary): @@ -1026,6 +1141,48 @@ def __repr__(self) -> str: return f"RecordBatch(rows={self.num_rows}, columns={self.num_columns})" +def _parse_record_batch_message(message_bytes: bytes) -> RecordBatchInfo: + """Parse a record batch message into RecordBatchInfo (no IO).""" + continuation = struct.unpack(" RecordBatchInfo: return self._record_batch_infos[index] # type: ignore block = self._record_batch_blocks[index] - message_bytes = self._reader.read(block.offset, block.metadata_length) - - continuation = struct.unpack(" Iterator[RecordBatch]: for i in range(self.num_record_batches): yield self.record_batch(i) + async def async_record_batch(self, index: int) -> RecordBatch: + """Async version of record_batch — reads metadata via async_read if not cached.""" + if self._record_batch_infos[index] is None: + block = self._record_batch_blocks[index] + raw = await self._reader.async_read(block.offset, block.metadata_length) + self._record_batch_infos[index] = _parse_record_batch_message(raw) + return self.record_batch(index) + + def __getitem__(self, key: str | tuple[str, ...] | list[str]) -> np.ndarray | list | dict[str, np.ndarray | list]: + """ + Read column data across all batches with concurrent IO. + + Single column returns the resolved array/list directly. + Multiple columns return a dict mapping column name to resolved data. + + Numeric/bool/fixed-size-binary columns return np.ndarray. + String columns return list[str | None]. + Binary columns return list[bytes | None]. + + Uses the reader's plan mode to coalesce reads across all batches. + Per-batch async tasks submit reads; the execution loop flushes + them in coalesced rounds. + + Examples: + values = f["score"] # np.ndarray + texts = f["text"] # list[str] + cols = f["score", "text"] # dict + cols = f[["score", "text"]] # dict (also works) + """ + if isinstance(key, str): + names = [key] + elif isinstance(key, (tuple, list)): + names = list(key) + else: + raise TypeError(f"Key must be str, tuple[str, ...], or list[str], got {type(key).__name__}") + + for name in names: + if self._schema.field(name) is None: + raise KeyError(f"Column {name!r} not in schema") + + from wsds.pupyarrow.file_reader import _get_io_loop + + batch_results = _get_io_loop().run(self._async_getitem(names)) + + # Concatenate across batches + output: dict[str, np.ndarray | list] = {} + for name in names: + field = self._schema.field(name) + is_numpy = field.type_id in ( + ArrowType.Int, + ArrowType.FloatingPoint, + ArrowType.Bool, + ArrowType.FixedSizeBinary, + ) + chunks = [br[name] for br in batch_results] + if is_numpy: + output[name] = np.concatenate(chunks) if len(chunks) > 1 else chunks[0] if chunks else np.array([]) + else: + output[name] = [item for chunk in chunks for item in chunk] + + return output[names[0]] if isinstance(key, str) else output + + async def _async_getitem(self, names: list[str]) -> list[dict[str, np.ndarray | list]]: + """Async entry point: resolve all batches with coalesced IO.""" + + async def resolve_batch(batch_idx: int) -> dict[str, np.ndarray | list]: + batch = await self.async_record_batch(batch_idx) + return {name: await batch.column(name).async_to_py() for name in names} + + self._reader._planned.set(True) + try: + tasks = [asyncio.ensure_future(resolve_batch(i)) for i in range(self.num_record_batches)] + while not all(t.done() for t in tasks): + await asyncio.sleep(0) + await self._reader.flush() + return [t.result() for t in tasks] + finally: + self._reader._planned.set(False) + self._reader.clear_cache() + + # -- Context manager & lifecycle ------------------------------------------- + def __enter__(self) -> FeatherFile: return self diff --git a/wsds/ws_audio.py b/wsds/ws_audio.py index 1dd7351..2d59078 100644 --- a/wsds/ws_audio.py +++ b/wsds/ws_audio.py @@ -3,31 +3,31 @@ import typing from dataclasses import dataclass -from .audio_codec import audio_to_html, create_decoder, encode_mp3, to_filelike +from .audio_codec import audio_to_html, create_decoder, encode_audio +from .pupyarrow import pupyarrow -def load_segment(src, start, end, sample_rate=None): - """Efficiently loads an audio segment from `src` (see below) `tstart` to `tend` seconds while - optionally resampling it to `sample_rate`. - - `src` can be one of: - - a file-like object - - a byte string - - a PyArrow binary buffer in memory""" - return AudioReader(src).read_segment(start, end, sample_rate=sample_rate) - @dataclass() -class AudioReader: - """A lazy seeking-capable audio reader for random-access to recordings stored in wsds shards.""" +class WSAudioEpisode: + """A lazy seeking-capable audio reader for random-access to recordings stored in wsds shards. + + >>> from wsds import WSDataset + >>> ds = WSDataset("librilight/source") + >>> audio = ds[0].get_audio() + >>> audio.load().shape + torch.Size([1, 17884909]) + >>> audio.read_segment(start=2, end=5).shape + torch.Size([1, 48000]) + >>> audio.read_segment(start=2, end=5, sample_rate=8000).shape + torch.Size([1, 24000]) + """ src: typing.Any _decoder: typing.Any = None _sample_rate: int | None = None - skip_samples: int = 0 - def __repr__(self): - return f"AudioReader(src={type(self.src)}, sample_rate={self._sample_rate})" + return f"WSAudioEpisode(src={type(self.src)}, sample_rate={self._sample_rate})" def unwrap(self): """Return the raw audio bytes""" @@ -35,27 +35,20 @@ def unwrap(self): return self.src.as_buffer().to_pybytes() elif isinstance(self.src, (bytes, bytearray)): return self.src + elif isinstance(self.src, pupyarrow.LazyBuffer): + return self.src.read() else: - raise TypeError(f"Unsupported AudioReader src type: {type(self.src)}") + raise TypeError(f"Unsupported src type: {type(self.src)}") + + to_bytes = unwrap def get_decoder(self, sample_rate=None): """Lazily creates/caches decoder via audio_codec.create_decoder().""" - sample_rate_switch = False - if self._sample_rate is not None: - sample_rate_switch = self._sample_rate != sample_rate - - if self._decoder is None or sample_rate_switch: - decoder = create_decoder(to_filelike(self.src), sample_rate=sample_rate) - # mp3 has encoder delays that are not handled well when seeking - if decoder.metadata.codec == "mp3": - self.skip_samples = 1105 - - if sample_rate is None: - sample_rate = decoder.metadata.sample_rate - - self._decoder = decoder - self._sample_rate = sample_rate - + requested_sr = sample_rate or (self._decoder and self._decoder.metadata.sample_rate) + if self._decoder is None or requested_sr != self._sample_rate: + self.src.seek(0) + self._decoder = create_decoder(self.src, sample_rate=sample_rate) + self._sample_rate = sample_rate or self._decoder.metadata.sample_rate return self._decoder, self._sample_rate @property @@ -70,14 +63,9 @@ def sample_rate(self): def read_segment(self, start=0, end=None, sample_rate=None): decoder, sample_rate = self.get_decoder(sample_rate) - seek_adjustment = self.skip_samples / sample_rate if start > 0 else 0 - _samples = decoder.get_samples_played_in_range( - start + seek_adjustment, end + seek_adjustment if end is not None else None - ) - if hasattr(_samples, "data"): - samples = _samples.data - else: - samples = _samples + samples = decoder.get_samples_played_in_range(start, end) + if hasattr(samples, "data"): + samples = samples.data samples.sample_rate = sample_rate return samples @@ -91,56 +79,60 @@ def _repr_html_(self): def _display_(self): import marimo - return marimo.audio(encode_mp3(self.read_segment())) + return marimo.audio(encode_audio(self.read_segment())) @dataclass(frozen=True) -class WSAudio: - """A lazy reference to a single sample from a segmented audio file.""" +class WSAudioSegment: + """A lazy reference to a single sample from a segmented audio file. + """ - audio_reader: AudioReader + episode: WSAudioEpisode tstart: float tend: float + def __repr__(self) -> str: + return f"WSAudioSegment(episode={self.episode}, tstart={self.tstart!s}, tend={self.tend!s})" + @property def duration(self) -> float: """Duration of the audio segment in seconds.""" return self.tend - self.tstart - def with_context(self, before: float = 0, after: float = 0) -> "WSAudio": - """Return a new WSAudio with expanded timestamps to include surrounding context. + def with_context(self, before: float = 0, after: float = 0) -> "WSAudioSegment": + """Return a new WSAudioSegment with expanded timestamps to include surrounding context. Args: before: Seconds of context to add before the segment start (will not go below 0) after: Seconds of context to add after the segment end Returns: - A new WSAudio instance with adjusted timestamps + A new WSAudioSegment instance with adjusted timestamps """ - return WSAudio( - audio_reader=self.audio_reader, + return WSAudioSegment( + episode=self.episode, tstart=max(0, self.tstart - before), tend=self.tend + after, ) - def with_timestamps(self, tstart: float | None = None, tend: float | None = None) -> "WSAudio": - """Return a new WSAudio with modified timestamps. + def with_timestamps(self, tstart: float | None = None, tend: float | None = None) -> "WSAudioSegment": + """Return a new WSAudioSegment with modified timestamps. Args: tstart: New start time in seconds (None to keep current) tend: New end time in seconds (None to keep current) Returns: - A new WSAudio instance with the specified timestamps + A new WSAudioSegment instance with the specified timestamps """ - return WSAudio( - audio_reader=self.audio_reader, + return WSAudioSegment( + episode=self.episode, tstart=tstart if tstart is not None else self.tstart, tend=tend if tend is not None else self.tend, ) def load(self, sample_rate=None, pad_to_seconds=None): - samples = self.audio_reader.read_segment(self.tstart, self.tend, sample_rate) + samples = self.episode.read_segment(self.tstart, self.tend, sample_rate) sample_rate = samples.sample_rate if pad_to_seconds is not None: import torch @@ -152,7 +144,7 @@ def load(self, sample_rate=None, pad_to_seconds=None): @property def metadata(self): - return self.audio_reader.metadata + return self.episode.metadata def _repr_html_(self): return audio_to_html(self.load()) @@ -160,4 +152,4 @@ def _repr_html_(self): def _display_(self): import marimo - return marimo.audio(encode_mp3(self.load())) + return marimo.audio(encode_audio(self.load())) diff --git a/wsds/ws_dataset.py b/wsds/ws_dataset.py index 580fb57..6785efd 100644 --- a/wsds/ws_dataset.py +++ b/wsds/ws_dataset.py @@ -39,8 +39,8 @@ class WSDataset: >>> sample = dataset["large/5304/the_tinted_venus_1408_librivox_64kb_mp3/tintedvenus_05_anstey_64kb_090"] >>> print(repr(sample["transcription_wslang_raw.txt"])) ' I will accompany you," she said.' - >>> sample['audio'] - WSAudio(audio_reader=AudioReader(src=, sample_rate=None), tstart=1040.2133, tend=1042.8413) + >>> sample['audio'].load().shape + torch.Size([1, 42049]) """ dataset_root: Path diff --git a/wsds/ws_decode.py b/wsds/ws_decode.py index 7d89979..4dc4aa3 100644 --- a/wsds/ws_decode.py +++ b/wsds/ws_decode.py @@ -5,7 +5,7 @@ import numpy as np import pyarrow as pa -from .ws_audio import AudioReader +from .ws_audio import WSAudioEpisode AUDIO_FILE_KEYS = frozenset( [ @@ -50,7 +50,7 @@ def decode_sample(column: str, data): import json return json.load(fd) elif ext in AUDIO_FILE_KEYS: - return AudioReader(fd) + return WSAudioEpisode(fd) else: return fd.read() @@ -90,7 +90,7 @@ def get_audio(sample, audio_columns=None): audio_columns: Optional list of column names to try. Defaults to AUDIO_FILE_KEYS. Returns: - The audio value (typically an AudioReader or WSAudio). + The audio value (typically a WSAudioEpisode or WSAudioSegment). Raises: KeyError: If no audio column is found in the sample. diff --git a/wsds/ws_s3_shard.py b/wsds/ws_s3_shard.py index 6e30dab..75b85e2 100644 --- a/wsds/ws_s3_shard.py +++ b/wsds/ws_s3_shard.py @@ -13,10 +13,31 @@ from .ws_dataset import WSDataset +def create_s3_client(endpoint_url: str | None = None): + """Create a shared aiobotocore S3 client. + + Returns the entered client and its context manager (for cleanup). + The client should be shared across all S3FileReader instances. + """ + from aiobotocore.session import AioSession + from botocore.config import Config + + from .pupyarrow.file_reader import _get_io_loop + + endpoint_url = endpoint_url or os.environ.get("WSDS_S3_ENDPOINT_URL") + session = AioSession() + kwargs = {"config": Config(max_pool_connections=50)} + if endpoint_url: + kwargs["endpoint_url"] = endpoint_url + ctx = session.create_client("s3", **kwargs) + client = _get_io_loop().run(ctx.__aenter__()) + return client, ctx + + class WSS3Shard(WSShardInterface): - """A shard reader that loads data from S3 via boto3 range requests. + """A shard reader that loads data from S3 via aiobotocore range requests. - Uses pupyarrow's FeatherFile with an S3File wrapper so that only the + Uses pupyarrow's FeatherFile with an S3FileReader so that only the IPC footer and the specific batch(es) needed are fetched, rather than downloading the entire shard file.""" @@ -27,15 +48,13 @@ def __init__(self, dataset: "WSDataset", bucket: str, key: str, shard_ref: Optio self.key = key if s3_client is None: - import boto3 - - s3_client = boto3.client("s3") + s3_client, _ = create_s3_client() self._reader = S3FileReader(s3_client, bucket, key) try: self._feather = FeatherFile(self._reader) - except s3_client.exceptions.ClientError as err: - raise WSShardMissingError.from_s3(s3_client, bucket, key, err) + except Exception as err: + raise WSShardMissingError(f"Failed to open s3://{bucket}/{key}: {err}") self.batch_size = int(self._feather.schema.custom_metadata["batch_size"]) # cache @@ -67,32 +86,33 @@ def from_link(cls, link, dataset, shard_ref): partition, shard = shard_ref prefix = link["prefix"] key = f"{prefix}/{partition}/{shard}.wsds" if partition else f"{prefix}/{shard}.wsds" - s3_client = cls._make_s3_client(link.get("endpoint_url")) + s3_client, _ = create_s3_client(link.get("endpoint_url")) return cls(dataset, link["bucket"], os.path.normpath(key), shard_ref=shard_ref, s3_client=s3_client) - @classmethod - def _make_s3_client(cls, endpoint_url=None): - import boto3 - - endpoint_url = endpoint_url or os.environ.get("WSDS_S3_ENDPOINT_URL") - kwargs = {} - if endpoint_url: - kwargs["endpoint_url"] = endpoint_url - return boto3.client("s3", **kwargs) - @classmethod def _discover_columns_from_s3(cls, link): """Read one shard's footer from S3 to discover column names.""" - s3_client = cls._make_s3_client(link.get("endpoint_url")) + from .pupyarrow.file_reader import _get_io_loop + + endpoint_url = link.get("endpoint_url") or os.environ.get("WSDS_S3_ENDPOINT_URL") bucket = link["bucket"] prefix = link["prefix"] - response = s3_client.list_objects_v2(Bucket=bucket, Prefix=prefix, MaxKeys=10) - for obj in response.get("Contents", []): - if obj["Key"].endswith(".wsds"): - reader = S3FileReader(s3_client, bucket, obj["Key"]) - feather = FeatherFile(reader) - return feather.schema.names - raise ValueError(f"No .wsds files found in s3://{bucket}/{prefix}") + + async def _discover(): + from aiobotocore.session import AioSession + + session = AioSession() + kwargs = {"endpoint_url": endpoint_url} if endpoint_url else {} + async with session.create_client("s3", **kwargs) as client: + response = await client.list_objects_v2(Bucket=bucket, Prefix=prefix, MaxKeys=10) + for obj in response.get("Contents", []): + if obj["Key"].endswith(".wsds"): + reader = S3FileReader(client, bucket, obj["Key"]) + feather = FeatherFile(reader) + return feather.schema.names + raise ValueError(f"No .wsds files found in s3://{bucket}/{prefix}") + + return _get_io_loop().run(_discover()) def _s3_path(self) -> str: return f"s3://{self.bucket}/{self.key}" diff --git a/wsds/ws_shard.py b/wsds/ws_shard.py index 39639cc..8470d18 100644 --- a/wsds/ws_shard.py +++ b/wsds/ws_shard.py @@ -6,8 +6,9 @@ import pyarrow as pa +from .pupyarrow.file_reader import FileReader, LocalFileReader from .utils import WSShardMissingError -from .ws_audio import AudioReader, WSAudio +from .ws_audio import WSAudioEpisode, WSAudioSegment from .ws_decode import decode_sample from .ws_sample import WSSample @@ -31,6 +32,10 @@ def get_columns(cls, link: dict, dataset: "WSDataset") -> dict[str, str] | None: def get_sample(self, column: str, offset: int) -> typing.Any: raise NotImplementedError + def get_reader(self) -> FileReader: + """Return a pupyarrow FileReader for the underlying shard file.""" + raise NotImplementedError + class WSShard(WSShardInterface): """Represents a single open data shard (`.wsds` file). @@ -104,6 +109,9 @@ def close(self): pass self._source_file = None + def get_reader(self): + return LocalFileReader(self.fname) + def __repr__(self): r = f"WSShard({repr(self.fname)})" if self._data: @@ -125,7 +133,7 @@ class WSSourceAudioShard(WSShardInterface): # cache _source_file_name: str = None _source_sample: WSSample = None - _source_reader: AudioReader = None + _source_reader: WSAudioEpisode = None @classmethod def from_link(cls, link, dataset, shard_ref): @@ -149,7 +157,7 @@ def get_sample(self, _column, offset): self._source_file_name = file_name tstart, tend = self.get_timestamps(segment_offset) - return WSAudio(self._source_reader, tstart, tend) + return WSAudioSegment(self._source_reader, tstart, tend) class WSYoutubeVideoShard(WSSourceAudioShard): diff --git a/wsds/ws_sink.py b/wsds/ws_sink.py index cfad44e..e6c7d56 100644 --- a/wsds/ws_sink.py +++ b/wsds/ws_sink.py @@ -39,13 +39,40 @@ def __str__(self): ) +def _cast_batch(batch, target_schema): + """Cast a RecordBatch to target_schema, adding null columns for missing fields.""" + arrays = [] + for field in target_schema: + if batch.schema.get_field_index(field.name) >= 0: + arrays.append(batch.column(field.name).cast(field.type)) + else: + arrays.append(pyarrow.nulls(batch.num_rows, type=field.type)) + return pyarrow.RecordBatch.from_arrays(arrays, schema=target_schema) + + class WSBatchedSink: """A helper for writing data to a PyArrow feather file. Automatically batches data and infers the schema from the first batch. + If the schema changes (new columns, type promotions, null -> concrete type), + the file is transparently rewritten with a unified schema. Example: >>> with WSBatchedSink('output.feather', throwaway=True) as sink: sink.write({'a': 1, 'b': 'x'}) + + Schema evolution -- int to float, null to string, new column: + >>> import tempfile, pyarrow as pa + >>> f = tempfile.NamedTemporaryFile(suffix='.wsds') + >>> sink = WSBatchedSink(f.name, min_batch_size_bytes=0); sink.__enter__() # doctest: +ELLIPSIS + <...WSBatchedSink...> + >>> sink.write({'x': 1, 'y': None}) + >>> sink.write({'x': 2.5, 'y': 'hello', 'z': True}) + >>> sink.close() + >>> r = pa.ipc.open_file(f.name) + >>> r.get_batch(0).to_pydict() + {'x': [1.0], 'y': [None], 'z': [None]} + >>> r.get_batch(1).to_pydict() + {'x': [2.5], 'y': ['hello'], 'z': [True]} """ def __init__( @@ -71,6 +98,7 @@ def __init__( self._sink_schema = schema self._key_iter = key_iter self._last_key = None + self._fixed_schema = schema is not None def write(self, x): if self._key_iter is not None: @@ -87,12 +115,61 @@ def write(self, x): if len(self._buffer) >= self.batch_size: self.write_batch(self._buffer) + def _rewrite_with_new_schema(self, new_record): + """Rewrite the file with a unified schema when a schema conflict is detected. + + Checks schema compatibility first, then streams old batches through an unlinked + file handle to avoid loading everything into memory at once. + Raises SampleFormatChanged if unification fails. + """ + # Check compatibility before doing any I/O + old_schema_no_meta = self._sink_schema.remove_metadata() + new_schema_no_meta = new_record.schema.remove_metadata() + try: + unified_no_meta = pyarrow.unify_schemas([old_schema_no_meta, new_schema_no_meta], promote_options="permissive") + except pyarrow.ArrowInvalid: + raise SampleFormatChanged(self._sink_schema, new_record.schema) + unified = unified_no_meta.with_metadata(self._sink_schema.metadata) + + # Close writer, open reader on the old file, then unlink it. + # The open file handle keeps the data accessible while we overwrite the path. + self._sink.close() + old_file = pyarrow.OSFile(self.fname) + reader = pyarrow.RecordBatchFileReader(old_file) + os.unlink(self.fname) + + self._native_file = pyarrow.output_stream(self.fname) + self._sink = pyarrow.RecordBatchFileWriter( + self._native_file, unified, options=pyarrow.ipc.IpcWriteOptions(compression=self.compression) + ) + self._sink_schema = unified + + for i in range(reader.num_record_batches): + self._sink.write(_cast_batch(reader.get_batch(i), unified)) + self._sink.write(_cast_batch(new_record, unified)) + # TODO: test writing batches of data straight from a PyTorch batched processing loop def write_batch(self, b, flush=False): import pyarrow try: + if self._sink is not None and not self._fixed_schema: + # Subsequent batch: use natural inference to detect schema evolution + # (from_pylist with an explicit schema silently drops new columns and coerces types) + record = pyarrow.RecordBatch.from_pylist(b) + if record.schema != self._sink_schema: + self._rewrite_with_new_schema(record) + self._buffer.clear() + return + self._sink.write(record) + self._buffer.clear() + return record = pyarrow.RecordBatch.from_pylist(b, self._sink_schema) + except (pyarrow.ArrowInvalid, pyarrow.ArrowTypeError): + if self._fixed_schema: + actual = pyarrow.RecordBatch.from_pylist(b).schema + raise SampleFormatChanged(self._sink_schema.remove_metadata(), actual) from None + raise except Exception: def _truncate(v, limit=200): r = repr(v) @@ -106,8 +183,9 @@ def _truncate(v, limit=200): self.batch_size *= 2 return schema = record.schema.with_metadata({"batch_size": str(len(b))}) + self._native_file = pyarrow.output_stream(self.fname) self._sink = pyarrow.RecordBatchFileWriter( - self.fname, schema, options=pyarrow.ipc.IpcWriteOptions(compression=self.compression) + self._native_file, schema, options=pyarrow.ipc.IpcWriteOptions(compression=self.compression) ) self._sink_schema = schema if record.schema != self._sink_schema: @@ -127,10 +205,8 @@ def close(self): ) assert self._sink is not None, "closing a WSSink that was never written to" self._sink.close() - # pyarrow RecordBatchFileWriter.close() does NOT release the fd — only GC does. - # Drop the reference and force collection so volume.reload() won't be blocked. + self._native_file.close() self._sink = None - import gc; gc.collect() def __enter__(self): assert self._sink is None, "WSSink is not re-entrant"