Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion wsds/ws_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def __len__(self):
#
# SQL support, using Polars
#
def _parse_sql_queries_polars(self, *queries, shard_subsample=1, rng=None, shard_pipe=None):
def _parse_sql_queries_polars(self, *queries, shard_subsample=1, rng=None, shard_pipe=None, parallel_collect=False):
"""Parses SQL queries via Polars to:
- extract the Polars expressions for each query
- use the expressions to build a list of column dirs to load shards from"""
Expand Down Expand Up @@ -336,6 +336,8 @@ def _parse_sql_queries_polars(self, *queries, shard_subsample=1, rng=None, shard
f"No usable shards found (columns: {', '.join(column_dirs)}) for dataset in: {str(self.dataset_root)}"
)

if parallel_collect:
return exprs, pl.concat(pl.collect_all(row_merge)).lazy()
return exprs, pl.concat(row_merge)

def _check_for_subsampling(self, shard_subsample):
Expand Down Expand Up @@ -375,6 +377,7 @@ def sql_select(
shard_subsample=None,
rng=42,
shard_pipe=None,
parallel_collect: bool = False,
) -> pl.DataFrame | pl.LazyFrame:
"""Given a list of SQL expressions, returns a Polars DataFrame/ LazyFrame with the results."""
if isinstance(rng, int):
Expand All @@ -384,6 +387,7 @@ def sql_select(
shard_subsample=self._check_for_subsampling(shard_subsample),
rng=rng,
shard_pipe=shard_pipe,
parallel_collect=parallel_collect,
)

if return_as_lazyframe:
Expand Down