From 718aae6a8747388bffd6a659a3c7f01fe3257f35 Mon Sep 17 00:00:00 2001 From: Tigran Soghbatyan Date: Thu, 9 Apr 2026 20:06:54 +0200 Subject: [PATCH] ws_dataset: add parallel_collect option to SQL query execution Adds a `parallel_collect` parameter to `_parse_sql_queries_polars` and `_execute_sql_queries_polars` that uses `pl.collect_all()` to collect all lazy frames in parallel before concatenation. --- wsds/ws_dataset.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/wsds/ws_dataset.py b/wsds/ws_dataset.py index 35dd5cd..278a0ef 100644 --- a/wsds/ws_dataset.py +++ b/wsds/ws_dataset.py @@ -234,7 +234,7 @@ def __len__(self): # # SQL support, using Polars # - def _parse_sql_queries_polars(self, *queries, shard_subsample=1, rng=None, shard_pipe=None): + def _parse_sql_queries_polars(self, *queries, shard_subsample=1, rng=None, shard_pipe=None, parallel_collect=False): """Parses SQL queries via Polars to: - extract the Polars expressions for each query - use the expressions to build a list of column dirs to load shards from""" @@ -336,6 +336,8 @@ def _parse_sql_queries_polars(self, *queries, shard_subsample=1, rng=None, shard f"No usable shards found (columns: {', '.join(column_dirs)}) for dataset in: {str(self.dataset_root)}" ) + if parallel_collect: + return exprs, pl.concat(pl.collect_all(row_merge)).lazy() return exprs, pl.concat(row_merge) def _check_for_subsampling(self, shard_subsample): @@ -375,6 +377,7 @@ def sql_select( shard_subsample=None, rng=42, shard_pipe=None, + parallel_collect: bool = False, ) -> pl.DataFrame | pl.LazyFrame: """Given a list of SQL expressions, returns a Polars DataFrame/ LazyFrame with the results.""" if isinstance(rng, int): @@ -384,6 +387,7 @@ def sql_select( shard_subsample=self._check_for_subsampling(shard_subsample), rng=rng, shard_pipe=shard_pipe, + parallel_collect=parallel_collect, ) if return_as_lazyframe: