diff --git a/tests.py b/tests.py
index 39a9acb..e7c117a 100644
--- a/tests.py
+++ b/tests.py
@@ -1,8 +1,17 @@
 import doctest
+import tempfile
 import unittest
+from pathlib import Path
 
 import wsds
 from wsds import ws_dataset, ws_shard, ws_sink
+from wsds.ws_sink import (
+    KeyMismatchError,
+    SampleCountMismatchError,
+    WSSink,
+    _find_reference_shard,
+    _read_shard_keys,
+)
 
 
 def load_tests(loader, tests, ignore):
@@ -14,5 +23,201 @@ def load_tests(loader, tests, ignore):
     return tests
 
 
+def _make_samples(keys: list[str]) -> list[dict]:
+    return [{"__key__": k, "value": i} for i, k in enumerate(keys)]
+
+
+def _write_reference_shard(shard_path: Path, keys: list[str]) -> None:
+    """Write a small reference shard with the given __key__ values."""
+    with WSSink(str(shard_path)) as sink:
+        for sample in _make_samples(keys):
+            sink.write(sample)
+
+
+class TestHelpers(unittest.TestCase):
+    """Tests for _find_reference_shard and _read_shard_keys."""
+
+    def test_find_reference_shard_picks_smallest(self):
+        """Should pick the sibling shard with the smallest file size."""
+        keys = ["a", "b", "c"]
+        with tempfile.TemporaryDirectory() as tmp:
+            dataset = Path(tmp)
+            # Write a small artifact
+            (dataset / "small_artifact").mkdir()
+            _write_reference_shard(dataset / "small_artifact" / "shard.wsds", keys)
+            # Write a larger artifact (more columns = bigger file)
+            (dataset / "large_artifact").mkdir()
+            with WSSink(str(dataset / "large_artifact" / "shard.wsds")) as sink:
+                for k in keys:
+                    sink.write({"__key__": k, "v1": 0, "v2": "x" * 1000, "v3": 1.0})
+
+            target = dataset / "new_artifact" / "shard.wsds"
+            (dataset / "new_artifact").mkdir()
+            ref = _find_reference_shard(target)
+            self.assertIsNotNone(ref)
+            self.assertEqual(ref.parent.name, "small_artifact")
+
+    def test_find_reference_shard_skips_current_dir(self):
+        keys = ["a", "b"]
+        with tempfile.TemporaryDirectory() as tmp:
+            dataset = Path(tmp)
+            (dataset / "artifact_a").mkdir()
+            _write_reference_shard(dataset / "artifact_a" / "shard.wsds", keys)
+
+            # Target is in artifact_a itself — should not find itself
+            ref = _find_reference_shard(dataset / "artifact_a" / "shard.wsds")
+            self.assertIsNone(ref)
+
+    def test_find_reference_shard_skips_link_and_computed(self):
+        keys = ["a"]
+        with tempfile.TemporaryDirectory() as tmp:
+            dataset = Path(tmp)
+            # Create a .wsds-link file and .wsds-computed dir
+            (dataset / "audio.wsds-link").touch()
+            (dataset / "audio.wsds-computed").mkdir()
+            (dataset / "audio.wsds-computed" / "shard.wsds").touch()
+
+            target = dataset / "new_artifact" / "shard.wsds"
+            (dataset / "new_artifact").mkdir()
+            ref = _find_reference_shard(target)
+            self.assertIsNone(ref)
+
+    def test_find_reference_shard_no_siblings(self):
+        with tempfile.TemporaryDirectory() as tmp:
+            dataset = Path(tmp)
+            (dataset / "lonely_artifact").mkdir()
+            ref = _find_reference_shard(dataset / "lonely_artifact" / "shard.wsds")
+            self.assertIsNone(ref)
+
+    def test_read_shard_keys(self):
+        keys = ["x", "y", "z"]
+        with tempfile.TemporaryDirectory() as tmp:
+            shard_path = Path(tmp) / "shard.wsds"
+            _write_reference_shard(shard_path, keys)
+            result = _read_shard_keys(shard_path)
+            self.assertEqual(result, keys)
+
+
+class TestValidateKeys(unittest.TestCase):
+    """Tests for validate_keys auto-discovery in WSSink."""
+
+    def _make_dataset(self, tmp: str, keys: list[str]) -> Path:
+        """Create a dataset directory with a reference artifact already written."""
+        dataset = Path(tmp)
+        (dataset / "existing_artifact").mkdir()
+        _write_reference_shard(dataset / "existing_artifact" / "shard.wsds", keys)
+        (dataset / "new_artifact").mkdir()
+        return dataset
+
+    def test_matching_keys_succeeds(self):
+        keys = ["a", "b", "c"]
+        with tempfile.TemporaryDirectory() as tmp:
+            dataset = self._make_dataset(tmp, keys)
+            fname = str(dataset / "new_artifact" / "shard.wsds")
+            with WSSink(fname, validate_keys=True) as sink:
+                for s in _make_samples(keys):
+                    sink.write(s)
+            self.assertTrue(Path(fname).exists())
+
+    def test_key_mismatch_raises(self):
+        keys = ["a", "b", "c"]
+        with tempfile.TemporaryDirectory() as tmp:
+            dataset = self._make_dataset(tmp, keys)
+            fname = str(dataset / "new_artifact" / "shard.wsds")
+            with self.assertRaises(KeyMismatchError) as ctx:
+                with WSSink(fname, validate_keys=True) as sink:
+                    for s in _make_samples(["a", "WRONG", "c"]):
+                        sink.write(s)
+            self.assertEqual(ctx.exception.offset, 1)
+            self.assertEqual(ctx.exception.expected_key, "b")
+            self.assertEqual(ctx.exception.actual_key, "WRONG")
+            self.assertFalse(Path(fname).exists())
+
+    def test_missing_key_field_raises(self):
+        keys = ["a", "b"]
+        with tempfile.TemporaryDirectory() as tmp:
+            dataset = self._make_dataset(tmp, keys)
+            fname = str(dataset / "new_artifact" / "shard.wsds")
+            with self.assertRaises(KeyMismatchError) as ctx:
+                with WSSink(fname, validate_keys=True) as sink:
+                    sink.write({"__key__": "a", "value": 0})
+                    sink.write({"value": 1})  # missing __key__
+            self.assertEqual(ctx.exception.offset, 1)
+            self.assertIsNone(ctx.exception.actual_key)
+
+    def test_too_many_samples_raises(self):
+        keys = ["a"]
+        with tempfile.TemporaryDirectory() as tmp:
+            dataset = self._make_dataset(tmp, keys)
+            fname = str(dataset / "new_artifact" / "shard.wsds")
+            with self.assertRaises(KeyMismatchError) as ctx:
+                with WSSink(fname, validate_keys=True) as sink:
+                    sink.write({"__key__": "a", "value": 0})
+                    sink.write({"__key__": "extra", "value": 1})
+            self.assertEqual(ctx.exception.offset, 1)
+            self.assertIsNone(ctx.exception.expected_key)
+
+    def test_too_few_samples_raises(self):
+        keys = ["a", "b", "c"]
+        with tempfile.TemporaryDirectory() as tmp:
+            dataset = self._make_dataset(tmp, keys)
+            fname = str(dataset / "new_artifact" / "shard.wsds")
+            with self.assertRaises(SampleCountMismatchError) as ctx:
+                with WSSink(fname, validate_keys=True) as sink:
+                    sink.write({"__key__": "a", "value": 0})
+                    sink.write({"__key__": "b", "value": 1})
+            self.assertEqual(ctx.exception.expected_count, 3)
+            self.assertEqual(ctx.exception.actual_count, 2)
+
+    def test_no_siblings_warns_and_skips(self):
+        """When no sibling artifacts exist, prints warning and skips validation."""
+        with tempfile.TemporaryDirectory() as tmp:
+            dataset = Path(tmp)
+            (dataset / "only_artifact").mkdir()
+            fname = str(dataset / "only_artifact" / "shard.wsds")
+            # Should succeed without validation (no siblings to compare against)
+            with WSSink(fname, validate_keys=True) as sink:
+                sink.write({"__key__": "a", "value": 0})
+            self.assertTrue(Path(fname).exists())
+
+    def test_validate_keys_false_no_validation(self):
+        """Default behavior: no validation even with siblings present."""
+        keys = ["a", "b"]
+        with tempfile.TemporaryDirectory() as tmp:
+            dataset = self._make_dataset(tmp, keys)
+            fname = str(dataset / "new_artifact" / "shard.wsds")
+            # Write mismatched keys — should succeed because validate_keys=False
+            with WSSink(fname) as sink:
+                for s in _make_samples(["x", "y"]):
+                    sink.write(s)
+            self.assertTrue(Path(fname).exists())
+
+
+class TestExceptions(unittest.TestCase):
+    """Tests for exception classes."""
+
+    def test_is_base_exception(self):
+        self.assertTrue(issubclass(KeyMismatchError, BaseException))
+        self.assertFalse(issubclass(KeyMismatchError, Exception))
+        self.assertTrue(issubclass(SampleCountMismatchError, BaseException))
+        self.assertFalse(issubclass(SampleCountMismatchError, Exception))
+
+    def test_error_messages(self):
+        err = KeyMismatchError("shard.wsds", 5, "expected_k", "actual_k")
+        self.assertIn("offset 5", str(err))
+        self.assertIn("expected_k", str(err))
+        self.assertIn("actual_k", str(err))
+
+        err_missing = KeyMismatchError("shard.wsds", 3, "expected_k", None)
+        self.assertIn("missing", str(err_missing))
+
+        err_overflow = KeyMismatchError("shard.wsds", 10, None, "extra_k")
+        self.assertIn("Too many", str(err_overflow))
+
+        err_count = SampleCountMismatchError("shard.wsds", 5, 3)
+        self.assertIn("expected 5", str(err_count))
+        self.assertIn("wrote 3", str(err_count))
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/wsds/__init__.py b/wsds/__init__.py
index b959670..692e950 100644
--- a/wsds/__init__.py
+++ b/wsds/__init__.py
@@ -8,13 +8,15 @@
 from .ws_dataset import WSDataset
 from .ws_sample import WSSample
 from .ws_shard import WSSourceAudioShard
-from .ws_sink import AtomicFile, SampleFormatChanged, WSSink
+from .ws_sink import AtomicFile, KeyMismatchError, SampleCountMismatchError, SampleFormatChanged, WSSink
 
 __all__ = [
-    WSDataset,
-    WSSample,
-    WSSourceAudioShard,
-    AtomicFile,
-    SampleFormatChanged,
-    WSSink,
+    "WSDataset",
+    "WSSample",
+    "WSSourceAudioShard",
+    "AtomicFile",
+    "KeyMismatchError",
+    "SampleCountMismatchError",
+    "SampleFormatChanged",
+    "WSSink",
 ]
diff --git a/wsds/ws_sink.py b/wsds/ws_sink.py
index c353c7a..24a7bd5 100644
--- a/wsds/ws_sink.py
+++ b/wsds/ws_sink.py
@@ -5,9 +5,13 @@
 from contextlib import contextmanager
 from dataclasses import dataclass
 from pathlib import Path
+from typing import TYPE_CHECKING
 
 import pyarrow
 
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
 from .ws_decode import encode_value
 
 
@@ -29,6 +33,89 @@ def __str__(self):
         )
 
 
+@dataclass
+class KeyMismatchError(BaseException):
+    fname: str
+    offset: int
+    expected_key: str | None
+    actual_key: str | None
+
+    def __str__(self):
+        if self.expected_key is None:
+            return (
+                f"Too many samples in {self.fname}: "
+                f"unexpected sample at offset {self.offset} with key '{self.actual_key}'"
+            )
+        if self.actual_key is None:
+            return (
+                f"Sample at offset {self.offset} in {self.fname} is missing '__key__' "
+                f"(expected '{self.expected_key}')"
+            )
+        return (
+            f"Key mismatch at offset {self.offset} in {self.fname}: "
+            f"expected '{self.expected_key}' but got '{self.actual_key}'"
+        )
+
+
+@dataclass
+class SampleCountMismatchError(BaseException):
+    fname: str
+    expected_count: int
+    actual_count: int
+
+    def __str__(self):
+        return (
+            f"Sample count mismatch in {self.fname}: "
+            f"expected {self.expected_count} samples but wrote {self.actual_count}"
+        )
+
+
+def _find_reference_shard(shard_path: Path) -> Path | None:
+    """Find the smallest sibling shard to use as __key__ reference.
+
+    Mirrors the read-side logic in list_all_columns() (utils.py) which sorts
+    __key__ sources by ascending shard file size to avoid loading heavy artifacts.
+    """
+    column_dir = shard_path.parent
+    dataset_dir = column_dir.parent
+    shard_name = shard_path.name
+
+    candidates: list[tuple[int, Path]] = []
+    for sibling in dataset_dir.iterdir():
+        if sibling == column_dir:
+            continue
+        if sibling.suffix in (".wsds-link", ".wsds-computed"):
+            continue
+        if not sibling.is_dir():
+            continue
+        sibling_shard = sibling / shard_name
+        if sibling_shard.exists():
+            candidates.append((sibling_shard.stat().st_size, sibling_shard))
+
+    if not candidates:
+        return None
+    candidates.sort()
+    return candidates[0][1]
+
+
+def _read_shard_keys(shard_path: Path) -> list[str]:
+    """Read all __key__ values from a shard, in order (the memory map is closed before returning)."""
+    with pyarrow.memory_map(str(shard_path)) as source:
+        return pyarrow.ipc.open_file(source).read_all().column("__key__").to_pylist()
+
+
+def _resolve_reference_keys(shard_path: Path) -> list[str] | None:
+    """Find the smallest sibling artifact and read its __key__ column as reference.
+
+    Returns None (with a printed warning) if no sibling artifacts exist.
+    """
+    ref_shard = _find_reference_shard(shard_path)
+    if ref_shard is None:
+        print(f"Warning: no sibling artifacts found for {shard_path}, skipping key validation")
+        return None
+    return _read_shard_keys(ref_shard)
+
+
 class WSBatchedSink:
     """A helper for writing data to a PyArrow feather file.
 
@@ -45,6 +132,7 @@ def __init__(
         compression: str | None = "zstd",
         throwaway=False,  # discard the temp file, useful for testing and benchmarking
         schema: pyarrow.Schema | dict | None = None,  # optional schema to enforce type coercion
+        reference_keys: Sequence[str] | None = None,  # expected __key__ values, validated per-sample
     ):
         self.fname = fname
         self.batch_size = 1
@@ -52,14 +140,26 @@ def __init__(
         self.max_batch_size = 16384
         self.compression = compression
         self.throwaway = throwaway
+        self.reference_keys = reference_keys
 
         self._buffer = []
         self._sink = None
+        self._offset = 0
         if isinstance(schema, dict):
             schema = pyarrow.schema(list(schema.items()))
         self._sink_schema = schema
 
     def write(self, x):
+        if self.reference_keys is not None:
+            actual_key = x.get("__key__") if isinstance(x, dict) else None
+            if self._offset >= len(self.reference_keys):
+                raise KeyMismatchError(self.fname, self._offset, None, actual_key)
+            expected_key = self.reference_keys[self._offset]
+            if actual_key is None:
+                raise KeyMismatchError(self.fname, self._offset, expected_key, None)
+            if actual_key != expected_key:
+                raise KeyMismatchError(self.fname, self._offset, expected_key, actual_key)
+        self._offset += 1
         self._buffer.append({k: encode_value(k, v) for k, v in x.items()})
         if len(self._buffer) >= self.batch_size:
             self.write_batch(self._buffer)
@@ -95,6 +195,10 @@ def _truncate(v, limit=200):
     def close(self):
         if self._buffer:
             self.write_batch(self._buffer, flush=True)  # flush the last batch
+        if self.reference_keys is not None and self._offset != len(self.reference_keys):
+            if self._sink is not None:
+                self._sink.close()  # release the writer before reporting the mismatch
+            raise SampleCountMismatchError(self.fname, len(self.reference_keys), self._offset)
         assert self._sink is not None, "closing a WSSink that was never written to"
         self._sink.close()
 
@@ -143,6 +247,7 @@ def WSSink(
     min_batch_size_bytes: int = 4 * 1024 * 1024,  # auto-increase the batch size until it's at least this size in bytes
     ephemeral: bool = False,  # discard the temp file, useful for testing and benchmarking
     schema: pyarrow.Schema | dict | None = None,  # optional schema to enforce type coercion
+    validate_keys: bool = False,  # validate __key__ values against the smallest sibling artifact
 ):
     """Context manager to atomically create a `.wsds` shard.
 
@@ -152,8 +257,10 @@ def WSSink(
         sink.write({quality_metric: 5, transcript: "Hello, World!"})
     ```
     """
-    with AtomicFile(fname, ephemeral) as fname:
+    reference_keys = _resolve_reference_keys(Path(fname)) if validate_keys else None
+    with AtomicFile(fname, ephemeral) as tmp_fname:
         with WSBatchedSink(
-            fname, min_batch_size_bytes=min_batch_size_bytes, compression=compression, schema=schema
+            tmp_fname, min_batch_size_bytes=min_batch_size_bytes, compression=compression,
+            schema=schema, reference_keys=reference_keys,
         ) as sink:
             yield sink