diff --git a/src/phlower/app.py b/src/phlower/app.py index 4b64bf6..8d99407 100644 --- a/src/phlower/app.py +++ b/src/phlower/app.py @@ -14,7 +14,6 @@ from .config import Config from .events import CeleryEventConsumer -from .sqlite_store import SQLiteStore from .sse import SSEBroadcaster from .store import Store @@ -165,61 +164,55 @@ async def _background_recovery(store: Store, sqlite_store, config: Config) -> No logger.exception("Background recovery failed") -async def _purge_in_batches(loop, batch_fn, cutoff_ts: float) -> int: - """Repeat a batch purge until empty, yielding between batches. - - Each batch acquires the SQLite lock independently, so concurrent flushes - can interleave. Without this, a multi-million-row purge holds the lock - for tens of minutes and starves the flush loop until liveness probes fail. - """ - total = 0 - while True: - deleted = await loop.run_in_executor(None, batch_fn, cutoff_ts) - total += deleted - if deleted < SQLiteStore.PURGE_BATCH_SIZE: - break - await asyncio.sleep(0.05) - return total - - async def _sqlite_purge_loop( store: Store, sqlite_store, config: Config, consumer=None ) -> None: - """Purge detail rows after SQLITE_DETAIL_HOURS, core rows after SQLITE_INVOCATION_RETENTION_HOURS.""" + """Drop expired daily partitions; ensure tomorrow's partition exists. + + With per-day partitioned tables, retention enforcement is a series of + DROP TABLE statements — metadata operations that take milliseconds. + The previous DELETE-based purge held the SQLite write lock long enough + to starve the flush loop and OOM the pod under load spikes. + """ while True: await asyncio.sleep(3600) loop = asyncio.get_running_loop() now = time.time() - # Disk pressure: if usage exceeds cap, halve the retention window - # repeatedly until it fits or hits a 1-hour floor. + # Make sure today's and tomorrow's partitions exist proactively, so + # midnight-UTC rollover doesn't race with the first flush of the day. 
+ from .sqlite_store import _suffix_for_ts + for ts in (now, now + 86400): + await loop.run_in_executor( + None, sqlite_store.ensure_partition, _suffix_for_ts(ts) + ) + retention_hours = config.sqlite_invocation_retention_hours - detail_hours = config.sqlite_detail_hours + + # Disk pressure: shrink retention until it fits, halving each pass. disk_pct = await loop.run_in_executor(None, sqlite_store.disk_usage_pct) if disk_pct > config.sqlite_disk_usage_pct_cap: while retention_hours > 1 and disk_pct > config.sqlite_disk_usage_pct_cap: retention_hours = max(1, retention_hours // 2) - detail_hours = max(1, detail_hours // 2) logger.warning( - "Disk %.0f%% > %d%% cap — emergency purge with %dh retention, %dh details", - disk_pct, config.sqlite_disk_usage_pct_cap, retention_hours, detail_hours, + "Disk %.0f%% > %d%% cap — emergency purge with %dh retention", + disk_pct, config.sqlite_disk_usage_pct_cap, retention_hours, + ) + dropped = await loop.run_in_executor( + None, sqlite_store.purge_old_partitions, retention_hours ) - purge_cutoff = now - retention_hours * 3600 - await _purge_in_batches(loop, sqlite_store.purge_expired_batch, purge_cutoff) - detail_cutoff = now - detail_hours * 3600 - await _purge_in_batches(loop, sqlite_store.purge_details_batch, detail_cutoff) + if dropped: + logger.info("Emergency purge: dropped %d partitions", dropped) disk_pct = await loop.run_in_executor(None, sqlite_store.disk_usage_pct) else: - # Normal purge: details first (short retention), then core rows - detail_cutoff = now - detail_hours * 3600 - purged_details = await _purge_in_batches(loop, sqlite_store.purge_details_batch, detail_cutoff) - if purged_details: - logger.info("SQLite purge: deleted %d detail rows (>%dh)", purged_details, detail_hours) - - purge_cutoff = now - retention_hours * 3600 - deleted = await _purge_in_batches(loop, sqlite_store.purge_expired_batch, purge_cutoff) - if deleted: - logger.info("SQLite purge: deleted %d expired rows (>%dh)", deleted, 
retention_hours) + dropped = await loop.run_in_executor( + None, sqlite_store.purge_old_partitions, retention_hours + ) + if dropped: + logger.info( + "SQLite purge: dropped %d partitions (>%dh)", + dropped, retention_hours, + ) # Remove snapshots for tasks no longer tracked active = set(store.tasks.keys()) @@ -238,10 +231,13 @@ async def _sqlite_purge_loop( await loop.run_in_executor(None, sqlite_store.refresh_cached_stats) size_mb = await loop.run_in_executor(None, sqlite_store.db_size_mb) wal_mb = await loop.run_in_executor(None, sqlite_store.wal_size_mb) + dropped_invocations = store.snapshot_dropped_invocations() logger.info( - "SQLite: %.1f MB, %d rows (%d detail), WAL: %.1f MB, disk: %.0f%%", + "SQLite: %.1f MB, %d rows (%d detail), WAL: %.1f MB, disk: %.0f%%, " + "dropped-invocations: %d", size_mb, sqlite_store._cached_row_count, sqlite_store._cached_detail_row_count, wal_mb, disk_pct, + dropped_invocations, ) diff --git a/src/phlower/config.py b/src/phlower/config.py index 22af5ec..277f979 100644 --- a/src/phlower/config.py +++ b/src/phlower/config.py @@ -73,3 +73,8 @@ class Config: detail_rate_threshold: int = field( default_factory=lambda: int(os.environ.get("DETAIL_RATE_THRESHOLD", "500")) ) + sqlite_pending_buffer_cap: int = field( + default_factory=lambda: int( + os.environ.get("SQLITE_PENDING_BUFFER_CAP", "200000") + ) + ) diff --git a/src/phlower/sqlite_store.py b/src/phlower/sqlite_store.py index 684c8db..0a03071 100644 --- a/src/phlower/sqlite_store.py +++ b/src/phlower/sqlite_store.py @@ -1,13 +1,26 @@ -"""SQLite write-behind warm index for historical task ID lookups.""" +"""SQLite write-behind store with daily-partitioned invocation tables. + +Partitions are named ``invocations_YYYYMMDD`` and ``invocation_details_YYYYMMDD`` +(UTC date suffix). Purge becomes ``DROP TABLE`` — a metadata operation that +takes milliseconds, so the hourly purge loop never starves the flush loop the +way row-by-row ``DELETE`` did on multi-million-row tables. 
+ +The first startup against a pre-partition database renames the existing +single tables to ``invocations_legacy`` / ``invocation_details_legacy``; +they get unioned into reads until their data ages past retention, then +they're dropped. +""" from __future__ import annotations import functools import logging import os +import re import sqlite3 import threading import time +from datetime import datetime, timezone from pathlib import Path from typing import Iterator @@ -22,31 +35,11 @@ def wrapper(self, *args, **kwargs): return method(self, *args, **kwargs) return wrapper -logger = logging.getLogger(__name__) -SCHEMA = """ -CREATE TABLE IF NOT EXISTS invocations ( - task_id TEXT PRIMARY KEY, - task_name TEXT NOT NULL, - state TEXT NOT NULL, - received_at REAL, - started_at REAL, - finished_at REAL, - runtime_ms REAL, - worker TEXT, - queue TEXT, - exception_type TEXT -); -CREATE INDEX IF NOT EXISTS idx_inv_finished ON invocations (finished_at); -CREATE INDEX IF NOT EXISTS idx_inv_task_name ON invocations (task_name, finished_at); - -CREATE TABLE IF NOT EXISTS invocation_details ( - task_id TEXT PRIMARY KEY, - args_preview TEXT, - kwargs_preview TEXT, - traceback_snippet TEXT -); +logger = logging.getLogger(__name__) +# Singleton tables — never partitioned, never grow with invocation volume. +SINGLETON_SCHEMA = """ CREATE TABLE IF NOT EXISTS metadata ( key TEXT NOT NULL, value TEXT NOT NULL, @@ -60,18 +53,57 @@ def wrapper(self, *args, **kwargs): ); """ -UPSERT_SQL = """ -INSERT OR REPLACE INTO invocations - (task_id, task_name, state, received_at, started_at, finished_at, - runtime_ms, worker, queue, exception_type) -VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) -""" - -UPSERT_DETAILS_SQL = """ -INSERT OR REPLACE INTO invocation_details - (task_id, args_preview, kwargs_preview, traceback_snippet) -VALUES (?, ?, ?, ?) 
-""" +INV_COLUMNS = ( + "task_id TEXT PRIMARY KEY, " + "task_name TEXT NOT NULL, " + "state TEXT NOT NULL, " + "received_at REAL, " + "started_at REAL, " + "finished_at REAL, " + "runtime_ms REAL, " + "worker TEXT, " + "queue TEXT, " + "exception_type TEXT" +) + +DETAILS_COLUMNS = ( + "task_id TEXT PRIMARY KEY, " + "args_preview TEXT, " + "kwargs_preview TEXT, " + "traceback_snippet TEXT" +) + +UPSERT_INV_SQL = ( + "INSERT OR REPLACE INTO {tbl} " + "(task_id, task_name, state, received_at, started_at, finished_at, " + " runtime_ms, worker, queue, exception_type) " + "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)" +) + +UPSERT_DETAILS_SQL = ( + "INSERT OR REPLACE INTO {tbl} " + "(task_id, args_preview, kwargs_preview, traceback_snippet) " + "VALUES (?, ?, ?, ?)" +) + +LEGACY_INV = "invocations_legacy" +LEGACY_DETAILS = "invocation_details_legacy" + +# Validates partition suffixes against SQL injection — table names can't be +# parameterized in SQLite, so any name we splice into SQL must match this. +_PARTITION_SUFFIX_RE = re.compile(r"^\d{8}$") +_INV_TABLE_RE = re.compile(r"^invocations_(\d{8})$") +_DETAILS_TABLE_RE = re.compile(r"^invocation_details_(\d{8})$") + + +def _suffix_for_ts(ts: float) -> str: + """Return UTC date suffix YYYYMMDD for a unix timestamp.""" + return datetime.fromtimestamp(ts, tz=timezone.utc).strftime("%Y%m%d") + + +def _ts_for_suffix(suffix: str) -> float: + """Return UTC midnight unix timestamp for a YYYYMMDD suffix.""" + return datetime.strptime(suffix, "%Y%m%d").replace(tzinfo=timezone.utc).timestamp() class SQLiteStore: @@ -81,6 +113,11 @@ def __init__(self, db_path: str) -> None: self._cached_detail_row_count: int = 0 self._cached_oldest_at: float | None = None self._write_lock = threading.Lock() + # Suffixes we've verified exist this process — avoids issuing + # "CREATE TABLE IF NOT EXISTS" on every flush. 
+ self._ensured_partitions: set[str] = set() + self._has_legacy_inv: bool = False + self._has_legacy_details: bool = False Path(db_path).parent.mkdir(parents=True, exist_ok=True) self._conn = self._connect(db_path) @@ -95,23 +132,66 @@ def _connect(self, path: str) -> sqlite3.Connection: return conn def init_schema(self) -> None: - self._conn.executescript(SCHEMA) - self._migrate() - # Checkpoint any WAL inherited from a crash — this must happen before + self._conn.executescript(SINGLETON_SCHEMA) + self._migrate_to_partitions() + self._refresh_legacy_flags() + # Always make sure today's partition exists at startup, so the first + # flush after boot doesn't race with creation. + self.ensure_partition(_suffix_for_ts(time.time())) + # Checkpoint any WAL inherited from a crash — must happen before # recovery opens its read connection, otherwise the stale WAL blocks # checkpointing for the entire recovery duration. self.checkpoint(truncate=True) - def _migrate(self) -> None: - """Migrate from single-table to split-table schema if needed.""" - cols = { - row[1] - for row in self._conn.execute("PRAGMA table_info(invocations)").fetchall() - } - if "args_preview" not in cols: + # -- migration -------------------------------------------------------- + + def _table_exists(self, name: str) -> bool: + row = self._conn.execute( + "SELECT 1 FROM sqlite_master WHERE type='table' AND name=?", (name,) + ).fetchone() + return row is not None + + def _migrate_to_partitions(self) -> None: + """Rename pre-partition single tables to *_legacy on first boot. + + Rename is a SQLite metadata operation — instant even on multi-GB + tables. The legacy tables get unioned into reads until their data + ages past retention, then dropped. 
+ """ + if not self._table_exists("invocations") and not self._table_exists( + "invocation_details" + ): return - logger.info("Migrating to split-table schema (invocations + invocation_details)") + # Run the existing column-split migration (args/kwargs out of + # invocations) before renaming, so legacy data lands in the + # canonical layout — _split_legacy_columns() creates the + # invocation_details table itself, so re-check existence after. + if self._table_exists("invocations"): + cols = { + row[1] + for row in self._conn.execute("PRAGMA table_info(invocations)").fetchall() + } + if "args_preview" in cols: + self._split_legacy_columns() + + if self._table_exists("invocations") and not self._table_exists(LEGACY_INV): + logger.info("Migrating: ALTER TABLE invocations RENAME TO %s", LEGACY_INV) + self._conn.execute(f"ALTER TABLE invocations RENAME TO {LEGACY_INV}") + if self._table_exists("invocation_details") and not self._table_exists( + LEGACY_DETAILS + ): + logger.info( + "Migrating: ALTER TABLE invocation_details RENAME TO %s", LEGACY_DETAILS + ) + self._conn.execute( + f"ALTER TABLE invocation_details RENAME TO {LEGACY_DETAILS}" + ) + self._conn.commit() + + def _split_legacy_columns(self) -> None: + """Move args_preview/kwargs_preview/traceback_snippet out of invocations.""" + logger.info("Splitting legacy invocations table (extracting detail columns)") self._conn.execute( "CREATE TABLE IF NOT EXISTS invocation_details (" " task_id TEXT PRIMARY KEY," @@ -124,10 +204,12 @@ def _migrate(self) -> None: total = 0 while True: cur = self._conn.execute( - "INSERT OR IGNORE INTO invocation_details (task_id, args_preview, kwargs_preview, traceback_snippet) " + "INSERT OR IGNORE INTO invocation_details " + "(task_id, args_preview, kwargs_preview, traceback_snippet) " "SELECT task_id, args_preview, kwargs_preview, traceback_snippet " "FROM invocations " - "WHERE (args_preview IS NOT NULL OR kwargs_preview IS NOT NULL OR traceback_snippet IS NOT NULL) " + "WHERE 
(args_preview IS NOT NULL OR kwargs_preview IS NOT NULL " + " OR traceback_snippet IS NOT NULL) " "AND task_id NOT IN (SELECT task_id FROM invocation_details) " "LIMIT 50000" ) @@ -135,7 +217,7 @@ def _migrate(self) -> None: batch = cur.rowcount total += batch if batch > 0: - logger.info("Migration progress: %d rows copied", total) + logger.info("Split-table migration progress: %d rows copied", total) if batch < 50000: break for col in ("args_preview", "kwargs_preview", "traceback_snippet"): @@ -143,89 +225,225 @@ def _migrate(self) -> None: self._conn.commit() logger.info("Split-table migration complete — %d detail rows", total) - # -- writes ----------------------------------------------------------- + def _refresh_legacy_flags(self) -> None: + self._has_legacy_inv = self._table_exists(LEGACY_INV) + self._has_legacy_details = self._table_exists(LEGACY_DETAILS) + + # -- partition discovery ---------------------------------------------- + + def list_partition_suffixes(self) -> list[str]: + """Suffixes (YYYYMMDD) for all existing invocation partitions, newest first.""" + rows = self._conn.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 'invocations_%'" + ).fetchall() + suffixes: list[str] = [] + for (name,) in rows: + m = _INV_TABLE_RE.match(name) + if m: + suffixes.append(m.group(1)) + suffixes.sort(reverse=True) + return suffixes @_serialized + def ensure_partition(self, suffix: str) -> None: + """Idempotently create the invocations + details partition for a date.""" + if suffix in self._ensured_partitions: + return + if not _PARTITION_SUFFIX_RE.match(suffix): + raise ValueError(f"invalid partition suffix: {suffix!r}") + inv_tbl = f"invocations_{suffix}" + det_tbl = f"invocation_details_{suffix}" + self._conn.execute(f"CREATE TABLE IF NOT EXISTS {inv_tbl} ({INV_COLUMNS})") + self._conn.execute( + f"CREATE INDEX IF NOT EXISTS idx_{inv_tbl}_finished " + f"ON {inv_tbl}(finished_at)" + ) + self._conn.execute( + f"CREATE INDEX IF NOT EXISTS 
idx_{inv_tbl}_task_name " + f"ON {inv_tbl}(task_name, finished_at)" + ) + self._conn.execute( + f"CREATE TABLE IF NOT EXISTS {det_tbl} ({DETAILS_COLUMNS})" + ) + self._conn.commit() + self._ensured_partitions.add(suffix) + + # -- writes ----------------------------------------------------------- + def flush_batch(self, records: list) -> int: + """Persist completed records, grouping by their UTC date. + + Note: ``INSERT OR REPLACE`` only deduplicates within one partition, + so a task whose lifecycle straddles midnight UTC (e.g. RETRY at + 23:59, SUCCESS at 00:01) can have two rows in two partitions. Read + paths dedupe by task_id at query time. Aggregate recovery may + slightly inflate counts for these tasks — bounded by the fraction + of tasks crossing midnight, typically <1%. + """ if not records: return 0 - self._conn.executemany( - UPSERT_SQL, - [ - ( - r.task_id, r.task_name, r.state, - r.received_at, r.started_at, r.finished_at, - r.runtime_ms, r.worker, r.queue, r.exception_type, - ) - for r in records - ], - ) - details = [] - detail_deletes = [] + # Group by partition suffix (UTC date of finished_at, falling back + # to received_at/started_at for terminal-without-finished edge case). + by_suffix: dict[str, list] = {} for r in records: - if r.args_preview or r.kwargs_preview or r.traceback_snippet: - details.append((r.task_id, r.args_preview, r.kwargs_preview, r.traceback_snippet)) - else: - detail_deletes.append((r.task_id,)) - if details: - self._conn.executemany(UPSERT_DETAILS_SQL, details) - if detail_deletes: - self._conn.executemany("DELETE FROM invocation_details WHERE task_id = ?", detail_deletes) - self._conn.commit() + ts = r.finished_at or r.started_at or r.received_at or time.time() + suffix = _suffix_for_ts(ts) + by_suffix.setdefault(suffix, []).append(r) + + # Make sure each partition exists. Cheap: cached after first call. 
+ for suffix in by_suffix: + self.ensure_partition(suffix) + + with self._write_lock: + for suffix, group in by_suffix.items(): + inv_tbl = f"invocations_{suffix}" + det_tbl = f"invocation_details_{suffix}" + self._conn.executemany( + UPSERT_INV_SQL.format(tbl=inv_tbl), + [ + ( + r.task_id, r.task_name, r.state, + r.received_at, r.started_at, r.finished_at, + r.runtime_ms, r.worker, r.queue, r.exception_type, + ) + for r in group + ], + ) + details = [] + detail_deletes = [] + for r in group: + if r.args_preview or r.kwargs_preview or r.traceback_snippet: + details.append( + (r.task_id, r.args_preview, r.kwargs_preview, r.traceback_snippet) + ) + else: + detail_deletes.append((r.task_id,)) + if details: + self._conn.executemany(UPSERT_DETAILS_SQL.format(tbl=det_tbl), details) + if detail_deletes: + self._conn.executemany( + f"DELETE FROM {det_tbl} WHERE task_id = ?", detail_deletes + ) + self._conn.commit() return len(records) - PURGE_BATCH_SIZE = 50000 + # -- purge ------------------------------------------------------------ @_serialized - def purge_details_batch(self, cutoff_ts: float) -> int: - """Delete one batch of detail rows. Returns rows deleted in this batch.""" - cur = self._conn.execute( - "DELETE FROM invocation_details WHERE task_id IN (" - " SELECT i.task_id FROM invocations i" - " JOIN invocation_details d ON d.task_id = i.task_id" - " WHERE i.finished_at < ? LIMIT ?" - ")", - (cutoff_ts, self.PURGE_BATCH_SIZE), - ) - self._conn.commit() - affected = cur.rowcount - self._cached_detail_row_count = max(0, self._cached_detail_row_count - affected) - return affected + def purge_old_partitions(self, retention_hours: int) -> int: + """Drop partitions older than ``retention_hours``. Returns count dropped. - @_serialized - def purge_expired_batch(self, cutoff_ts: float) -> int: - """Delete one batch of expired invocations + their details. - Returns invocation rows deleted in this batch. 
+ Each DROP TABLE is a metadata operation (fast, predictable); the + whole purge replaces the multi-minute row-by-row DELETE that + previously starved the flush loop. """ - self._conn.execute( - "DELETE FROM invocation_details WHERE task_id IN (" - " SELECT task_id FROM invocations WHERE finished_at < ? LIMIT ?" - ")", - (cutoff_ts, self.PURGE_BATCH_SIZE), - ) - cur = self._conn.execute( - "DELETE FROM invocations WHERE rowid IN (" - " SELECT rowid FROM invocations WHERE finished_at < ? LIMIT ?" - ")", - (cutoff_ts, self.PURGE_BATCH_SIZE), - ) + cutoff_ts = time.time() - retention_hours * 3600 + cutoff_suffix = _suffix_for_ts(cutoff_ts) + dropped = 0 + for suffix in self.list_partition_suffixes(): + if suffix >= cutoff_suffix: + continue + inv_tbl = f"invocations_{suffix}" + det_tbl = f"invocation_details_{suffix}" + self._conn.execute(f"DROP TABLE IF EXISTS {inv_tbl}") + self._conn.execute(f"DROP TABLE IF EXISTS {det_tbl}") + self._ensured_partitions.discard(suffix) + dropped += 1 + logger.info("Dropped expired partition %s", suffix) + self._conn.commit() + # Legacy tables: drop wholesale once their newest row is past + # retention. Cheap to check — single MAX() per table. + self._maybe_drop_legacy(cutoff_ts) + return dropped + + def _maybe_drop_legacy(self, cutoff_ts: float) -> None: + if self._has_legacy_inv: + row = self._conn.execute( + f"SELECT MAX(finished_at) FROM {LEGACY_INV}" + ).fetchone() + newest = row[0] if row else None + if newest is None or newest < cutoff_ts: + logger.info("Dropping legacy table %s (newest=%s)", LEGACY_INV, newest) + self._conn.execute(f"DROP TABLE {LEGACY_INV}") + self._has_legacy_inv = False + if self._has_legacy_details and not self._has_legacy_inv: + # Details on its own carries no finished_at — drop it whenever + # the corresponding invocations table is gone. 
+ logger.info("Dropping legacy table %s", LEGACY_DETAILS) + self._conn.execute(f"DROP TABLE {LEGACY_DETAILS}") + self._has_legacy_details = False self._conn.commit() - affected = cur.rowcount - self._cached_row_count = max(0, self._cached_row_count - affected) - return affected + + # -- read helpers ----------------------------------------------------- + + def _read_tables(self) -> list[tuple[str, str]]: + """Return [(invocations_table, details_table_or_None), ...] newest first. + + Includes legacy as the oldest source. + """ + out: list[tuple[str, str]] = [] + for suffix in self.list_partition_suffixes(): + out.append((f"invocations_{suffix}", f"invocation_details_{suffix}")) + if self._has_legacy_inv: + out.append((LEGACY_INV, LEGACY_DETAILS if self._has_legacy_details else "")) + return out + + def _union_subqueries( + self, where_sql: str, params: list[object], *, with_details: bool = True + ) -> tuple[str, list[object]]: + """Build a UNION ALL across all read tables for a given WHERE clause. + + ``where_sql`` and ``params`` are duplicated for each branch. 
+ """ + tables = self._read_tables() + if not tables: + return "SELECT NULL WHERE 0", [] + all_params: list[object] = [] + branches: list[str] = [] + for inv_tbl, det_tbl in tables: + if with_details and det_tbl: + join = ( + f"FROM {inv_tbl} i LEFT JOIN {det_tbl} d ON i.task_id = d.task_id" + ) + cols = ( + "i.task_id, i.task_name, i.state, i.received_at, i.started_at, " + "i.finished_at, i.runtime_ms, i.worker, i.queue, i.exception_type, " + "d.args_preview, d.kwargs_preview, d.traceback_snippet" + ) + else: + join = f"FROM {inv_tbl} i" + cols = ( + "i.task_id, i.task_name, i.state, i.received_at, i.started_at, " + "i.finished_at, i.runtime_ms, i.worker, i.queue, i.exception_type, " + "NULL, NULL, NULL" + ) + branches.append(f"SELECT {cols} {join} WHERE {where_sql}") + all_params.extend(params) + return " UNION ALL ".join(branches), all_params # -- reads ------------------------------------------------------------ def lookup_task_id(self, task_id: str) -> InvocationRecord | None: - row = self._conn.execute( - "SELECT i.*, d.args_preview, d.kwargs_preview, d.traceback_snippet " - "FROM invocations i LEFT JOIN invocation_details d ON i.task_id = d.task_id " - "WHERE i.task_id = ?", - (task_id,), - ).fetchone() - if row is None: - return None - return self._row_to_record(row) + for inv_tbl, det_tbl in self._read_tables(): + if det_tbl: + sql = ( + "SELECT i.task_id, i.task_name, i.state, i.received_at, i.started_at, " + " i.finished_at, i.runtime_ms, i.worker, i.queue, i.exception_type, " + " d.args_preview, d.kwargs_preview, d.traceback_snippet " + f"FROM {inv_tbl} i LEFT JOIN {det_tbl} d ON i.task_id = d.task_id " + "WHERE i.task_id = ?" + ) + else: + sql = ( + "SELECT i.task_id, i.task_name, i.state, i.received_at, i.started_at, " + " i.finished_at, i.runtime_ms, i.worker, i.queue, i.exception_type, " + " NULL, NULL, NULL " + f"FROM {inv_tbl} i WHERE i.task_id = ?" 
+ ) + row = self._conn.execute(sql, (task_id,)).fetchone() + if row: + return self._row_to_record(row) + return None def list_by_task( self, @@ -236,7 +454,7 @@ def list_by_task( after_ts: float | None = None, exclude_ids: set[str] | None = None, ) -> list[InvocationRecord]: - """List invocations for a task, newest first. Uses idx_inv_task_name.""" + """List invocations for a task, newest first.""" clauses = ["i.task_name = ?"] params: list[object] = [task_name] if before_ts is not None: @@ -246,17 +464,26 @@ def list_by_task( clauses.append("i.finished_at > ?") params.append(after_ts) where = " AND ".join(clauses) + union_sql, union_params = self._union_subqueries(where, params) fetch_limit = limit + (len(exclude_ids) if exclude_ids else 0) sql = ( - "SELECT i.*, d.args_preview, d.kwargs_preview, d.traceback_snippet " - "FROM invocations i LEFT JOIN invocation_details d ON i.task_id = d.task_id " - f"WHERE {where} ORDER BY i.finished_at DESC LIMIT ?" + f"SELECT * FROM ({union_sql}) " + "ORDER BY finished_at DESC LIMIT ?" ) - params.append(fetch_limit) - rows = self._conn.execute(sql, params).fetchall() + # Over-fetch to leave room for dedup. Cross-partition duplicates + # are bounded by the fraction of tasks whose lifecycle straddles + # midnight UTC — rare in practice — but a fetch_limit of just + # ``limit`` could underfill the result if duplicates show up. 
+ union_params.append(fetch_limit * 2) + rows = self._conn.execute(sql, union_params).fetchall() + seen_ids: set[str] = set() results: list[InvocationRecord] = [] for row in rows: - if exclude_ids and row[0] in exclude_ids: + tid = row[0] + if tid in seen_ids: + continue # cross-partition dedup — rows arrive newest-first + seen_ids.add(tid) + if exclude_ids and tid in exclude_ids: continue results.append(self._row_to_record(row)) if len(results) >= limit: @@ -307,105 +534,160 @@ def search( params.extend([pattern] * 7) where = " AND ".join(clauses) if clauses else "1=1" + union_sql, union_params = self._union_subqueries(where, params) fetch_limit = limit + (len(exclude_ids) if exclude_ids else 0) sql = ( - "SELECT i.*, d.args_preview, d.kwargs_preview, d.traceback_snippet " - "FROM invocations i LEFT JOIN invocation_details d ON i.task_id = d.task_id " - f"WHERE {where} ORDER BY i.finished_at DESC LIMIT ? OFFSET ?" + f"SELECT * FROM ({union_sql}) " + "ORDER BY finished_at DESC LIMIT ? OFFSET ?" ) - params.extend([fetch_limit, offset]) - rows = self._conn.execute(sql, params).fetchall() + union_params.extend([fetch_limit * 2, offset]) + rows = self._conn.execute(sql, union_params).fetchall() + seen_ids: set[str] = set() results: list[InvocationRecord] = [] for row in rows: - if exclude_ids and row[0] in exclude_ids: + tid = row[0] + if tid in seen_ids: + continue # cross-partition dedup — rows arrive newest-first + seen_ids.add(tid) + if exclude_ids and tid in exclude_ids: continue results.append(self._row_to_record(row)) if len(results) >= limit: break return results + # -- recovery loaders ------------------------------------------------- + # + # Read from a separate connection (so the write lock isn't held) and + # iterate per-partition. Each partition's data is naturally bounded by + # one UTC day, which keeps the per-statement memory footprint stable. + def open_recovery_conn(self) -> sqlite3.Connection: - """Open a separate connection for recovery. 
Caller must close it.""" return self._connect(self.db_path) - def load_recovery_counts(self, conn: sqlite3.Connection, since_ts: float) -> Iterator[sqlite3.Row]: - """Aggregated counts per task/state/minute for fast recovery. - - Processes in 4-hour chunks so the read lock is released between - chunks, allowing WAL checkpointing to proceed. - """ - now = time.time() - chunk_start = since_ts - while chunk_start < now: - chunk_end = min(chunk_start + 14400, now + 1) # 4-hour windows - cur = conn.cursor() - cur.row_factory = sqlite3.Row - cur.execute( - "SELECT task_name, state, " - " (CAST(finished_at AS INTEGER) / 60 * 60) AS minute_ts, " - " COUNT(*) AS cnt, " - " worker, queue, exception_type " - "FROM invocations WHERE finished_at >= ? AND finished_at < ? " - "GROUP BY task_name, state, minute_ts, worker, queue, exception_type " - "ORDER BY task_name", - (chunk_start, chunk_end), - ) - yield from cur - cur.close() - # Explicit commit releases any read snapshot Python's sqlite3 - # module may hold, allowing WAL checkpointing to proceed. - conn.commit() - chunk_start = chunk_end - - def load_recovery_runtimes(self, conn: sqlite3.Connection, since_ts: float) -> Iterator[sqlite3.Row]: - """Stream individual runtime values for t-digest population. - - Chunked in 4-hour windows to release read locks periodically. + def _recovery_inv_tables(self) -> list[str]: + tables = [t for t, _ in self._read_tables()] + return tables + + def load_recovery_counts( + self, conn: sqlite3.Connection, since_ts: float + ) -> Iterator[sqlite3.Row]: + """Per-task/state/minute count rows. 
Chunked in 4-hour windows so + the read snapshot is released periodically — important on the + unbounded legacy table where a full scan can take minutes.""" + for inv_tbl in self._recovery_inv_tables(): + now = time.time() + chunk_start = since_ts + while chunk_start < now: + chunk_end = min(chunk_start + 14400, now + 1) + cur = conn.cursor() + cur.row_factory = sqlite3.Row + cur.execute( + f"SELECT task_name, state, " + f" (CAST(finished_at AS INTEGER) / 60 * 60) AS minute_ts, " + f" COUNT(*) AS cnt, " + f" worker, queue, exception_type " + f"FROM {inv_tbl} WHERE finished_at >= ? AND finished_at < ? " + f"GROUP BY task_name, state, minute_ts, worker, queue, exception_type " + f"ORDER BY task_name", + (chunk_start, chunk_end), + ) + yield from cur + cur.close() + conn.commit() + chunk_start = chunk_end + + def load_recovery_runtimes( + self, conn: sqlite3.Connection, since_ts: float + ) -> Iterator[sqlite3.Row]: + """Stream individual runtime values for t-digest population.""" + for inv_tbl in self._recovery_inv_tables(): + now = time.time() + chunk_start = since_ts + while chunk_start < now: + chunk_end = min(chunk_start + 14400, now + 1) + cur = conn.cursor() + cur.row_factory = sqlite3.Row + cur.execute( + f"SELECT task_name, " + f" (CAST(finished_at AS INTEGER) / 60 * 60) AS minute_ts, " + f" runtime_ms " + f"FROM {inv_tbl} " + f"WHERE finished_at >= ? AND finished_at < ? " + f" AND runtime_ms IS NOT NULL " + f"ORDER BY task_name", + (chunk_start, chunk_end), + ) + yield from cur + cur.close() + conn.commit() + chunk_start = chunk_end + + def load_recovery_pickup( + self, conn: sqlite3.Connection, since_ts: float + ) -> Iterator[sqlite3.Row]: + """Stream received_at/started_at pairs for pickup latency rebuild. + + Only reads the newest two partitions — pickup latency is a recent- + traffic signal, no value in pulling 5 days of history. 
""" - now = time.time() - chunk_start = since_ts - while chunk_start < now: - chunk_end = min(chunk_start + 14400, now + 1) + tables = self._recovery_inv_tables()[:2] + for inv_tbl in tables: cur = conn.cursor() cur.row_factory = sqlite3.Row cur.execute( - "SELECT task_name, " - " (CAST(finished_at AS INTEGER) / 60 * 60) AS minute_ts, " - " runtime_ms " - "FROM invocations " - "WHERE finished_at >= ? AND finished_at < ? AND runtime_ms IS NOT NULL " - "ORDER BY task_name", - (chunk_start, chunk_end), + f"SELECT queue, (started_at - received_at) * 1000 AS wait_ms " + f"FROM {inv_tbl} " + f"WHERE finished_at >= ? " + f" AND received_at IS NOT NULL AND started_at IS NOT NULL " + f" AND started_at > received_at " + f"ORDER BY finished_at DESC LIMIT 5000", + (since_ts,), ) yield from cur cur.close() - conn.commit() - chunk_start = chunk_end - - def load_recovery_pickup(self, conn: sqlite3.Connection, since_ts: float) -> Iterator[sqlite3.Row]: - """Stream received_at/started_at pairs for pickup latency rebuild.""" - cur = conn.cursor() - cur.row_factory = sqlite3.Row - cur.execute( - "SELECT queue, (started_at - received_at) * 1000 AS wait_ms " - "FROM invocations " - "WHERE finished_at >= ? " - " AND received_at IS NOT NULL AND started_at IS NOT NULL " - " AND started_at > received_at " - "ORDER BY finished_at DESC LIMIT 5000", - (since_ts,), - ) - yield from cur @_serialized def refresh_cached_stats(self) -> None: - """Update cached stats for healthz — called from purge loop.""" - row = self._conn.execute("SELECT count(*) FROM invocations").fetchone() - self._cached_row_count = row[0] if row else 0 - row = self._conn.execute("SELECT count(*) FROM invocation_details").fetchone() - self._cached_detail_row_count = row[0] if row else 0 - row = self._conn.execute("SELECT MIN(finished_at) FROM invocations").fetchone() - self._cached_oldest_at = row[0] if row and row[0] is not None else None + """Update cached stats for healthz — called from purge loop. 
@_serialized
def refresh_cached_stats(self) -> None:
    """Recompute the row-count / oldest-row figures cached for healthz.

    Walks every ``(invocations, details)`` table pair from
    ``self._read_tables()`` and sums a per-table row figure:

    * daily partitions are counted exactly with ``count(*)``; each holds
      roughly one day of data, so the scan stays cheap;
    * the legacy tables use ``MAX(rowid)`` instead. It is only an
      approximation, but avoids a full scan of a multi-GB table, and it
      goes away once the legacy tables are dropped.

    The oldest ``finished_at`` across all invocation tables is tracked as
    well. Results are stored in ``_cached_row_count``,
    ``_cached_detail_row_count`` and ``_cached_oldest_at`` (``None`` when no
    table has any row).
    """
    run = self._conn.execute

    def table_rows(table, legacy_name):
        # Legacy table: approximate via the highest rowid (may be NULL).
        if table == legacy_name:
            hit = run(f"SELECT MAX(rowid) FROM {table}").fetchone()
            return hit[0] if hit and hit[0] is not None else 0
        # Partition: exact count.
        hit = run(f"SELECT count(*) FROM {table}").fetchone()
        return hit[0] if hit else 0

    inv_rows = 0
    det_rows = 0
    earliest = None
    for inv_tbl, det_tbl in self._read_tables():
        inv_rows += table_rows(inv_tbl, LEGACY_INV)
        if det_tbl:
            det_rows += table_rows(det_tbl, LEGACY_DETAILS)
        # MIN(finished_at) uses idx_inv_finished — fast on legacy too.
        hit = run(f"SELECT MIN(finished_at) FROM {inv_tbl}").fetchone()
        if hit and hit[0] is not None:
            earliest = hit[0] if earliest is None else min(earliest, hit[0])

    self._cached_row_count = inv_rows
    self._cached_detail_row_count = det_rows
    self._cached_oldest_at = earliest
----------------------------------------------------- + # -- WAL management --------------------------------------------------- @_serialized def checkpoint(self, *, truncate: bool = False) -> None: @@ -535,7 +816,6 @@ def disk_usage_pct(self) -> float: return 0.0 def disk_free_mb(self) -> float: - """Free disk space in MB on the partition hosting the DB file.""" try: stat = os.statvfs(self.db_path) return (stat.f_bavail * stat.f_frsize) / (1024 * 1024) @@ -543,8 +823,6 @@ def disk_free_mb(self) -> float: return 0.0 def _row_to_record(self, row: tuple) -> InvocationRecord: - # Core columns: 0-9 (task_id..exception_type) - # Detail columns from LEFT JOIN: 10-12 (args_preview, kwargs_preview, traceback_snippet) return InvocationRecord( task_id=row[0], task_name=row[1], diff --git a/src/phlower/store.py b/src/phlower/store.py index 8f433c3..e021344 100644 --- a/src/phlower/store.py +++ b/src/phlower/store.py @@ -349,8 +349,18 @@ def __init__(self, config: Config, sqlite_store: SQLiteStore | None = None) -> N # rolling event counter for tasks/sec display self._event_timestamps: deque[float] = deque(maxlen=2000) - # SQLite write-behind buffer (CompletedRecords, snapshotted at completion) + # SQLite write-behind buffer (CompletedRecords, snapshotted at completion). + # Capped — if the flush loop falls behind (slow purge, big checkpoint, + # bursty ingest), we drop oldest. RAM staying bounded matters more than + # capturing every record during overload; the alternative is RSS + # runaway → liveness probe → SIGKILL, and ALL pending records lost. self._sqlite_pending: list[CompletedRecord] = [] + self._sqlite_pending_cap: int = config.sqlite_pending_buffer_cap + self._dropped_invocations_total: int = 0 + # Cap _new_invocation_ids at the same proportion — it's drained on a + # 600ms SSE cadence, far faster than SQLite flush, but still worth + # bounding so a stalled SSE loop can't grow it without bound. 
def _append_pending(self, snapshot: CompletedRecord, task_id: str) -> None:
    """Append to the write-behind buffer, dropping oldest if over cap.

    ``snapshot`` is queued on ``_sqlite_pending`` and ``task_id`` on
    ``_new_invocation_ids``. When either list exceeds its cap, the oldest
    10% of the cap (at least one entry) is discarded in a single slice.
    Dropping in chunks amortizes the copy cost: trimming one element per
    append would copy the whole tail every time.

    Every chunk dropped from the SQLite buffer is added to
    ``_dropped_invocations_total`` and logged as a warning. The chunking is
    itself the rate limit: at most one warning per ``drop_n`` appended
    records.

    Args:
        snapshot: Completed record to persist on the next flush.
        task_id: Id announced as a newly completed invocation.
    """
    self._sqlite_pending.append(snapshot)
    if len(self._sqlite_pending) > self._sqlite_pending_cap:
        drop_n = max(1, self._sqlite_pending_cap // 10)
        # Rebind rather than mutate in place, so a holder of the old list
        # keeps an intact copy.
        self._sqlite_pending = self._sqlite_pending[drop_n:]
        self._dropped_invocations_total += drop_n
        # No modulo guard: the counter only grows by drop_n, so such a
        # guard is either always true or silences the warning for good
        # once the counter is not a multiple of drop_n.
        logger.warning(
            "SQLite write-behind buffer over cap — dropped %d records "
            "(total dropped: %d). Flush loop is falling behind.",
            drop_n, self._dropped_invocations_total,
        )
    self._new_invocation_ids.append(task_id)
    if len(self._new_invocation_ids) > self._new_invocation_ids_cap:
        drop_n = max(1, self._new_invocation_ids_cap // 10)
        self._new_invocation_ids = self._new_invocation_ids[drop_n:]


def snapshot_dropped_invocations(self) -> int:
    """Counter of write-behind records dropped since startup."""
    return self._dropped_invocations_total
self.sqlite_store is not None: - self._sqlite_pending.append(self._snapshot(rec)) + self._append_pending(self._snapshot(rec), task_id) rec.traceback_snippet = None rec.exception_message = None + else: + self._new_invocation_ids.append(task_id) # -- periodic maintenance ---------------------------------------------