diff --git a/scripts/create_data_fast_sample.py b/scripts/create_data_fast_sample.py index 76206227..798ec14a 100644 --- a/scripts/create_data_fast_sample.py +++ b/scripts/create_data_fast_sample.py @@ -7,16 +7,13 @@ import numpy as np from segger.data.parquet._utils import get_polygons_from_xy -xenium_data_dir = Path('data_raw/breast_cancer/Xenium_FFPE_Human_Breast_Cancer_Rep1/outs/') -segger_data_dir = Path('data_tidy/pyg_datasets/bc_rep1_emb_200_final') +xenium_data_dir = Path("data_raw/breast_cancer/Xenium_FFPE_Human_Breast_Cancer_Rep1/outs/") +segger_data_dir = Path("data_tidy/pyg_datasets/bc_rep1_emb_200_final") -scrnaseq_file = Path('/omics/groups/OE0606/internal/tangy/tasks/schier/data/atals_filtered.h5ad') -celltype_column = 'celltype_minor' -gene_celltype_abundance_embedding = calculate_gene_celltype_abundance_embedding( - sc.read(scrnaseq_file), - celltype_column -) +scrnaseq_file = Path("/omics/groups/OE0606/internal/tangy/tasks/schier/data/atals_filtered.h5ad") +celltype_column = "celltype_minor" +gene_celltype_abundance_embedding = calculate_gene_celltype_abundance_embedding(sc.read(scrnaseq_file), celltype_column) sample = STSampleParquet( base_dir=xenium_data_dir, @@ -43,30 +40,29 @@ sample.save( - data_dir=segger_data_dir, - k_bd=3, - dist_bd=15, - k_tx=3, - dist_tx=5, - tile_width=200, - tile_height=200, - neg_sampling_ratio=5.0, - frac=1.0, - val_prob=0.3, - test_prob=0, + data_dir=segger_data_dir, + k_bd=3, + dist_bd=15, + k_tx=3, + dist_tx=5, + tile_width=200, + tile_height=200, + neg_sampling_ratio=5.0, + frac=1.0, + val_prob=0.3, + test_prob=0, ) -xenium_data_dir = Path('data_tidy/bc_5k') -segger_data_dir = Path('data_tidy/pyg_datasets/bc_5k_emb_new') - +xenium_data_dir = Path("data_tidy/bc_5k") +segger_data_dir = Path("data_tidy/pyg_datasets/bc_5k_emb_new") sample = STSampleParquet( base_dir=xenium_data_dir, n_workers=8, - sample_type='xenium', - weights=gene_celltype_abundance_embedding, # uncomment if gene-celltype embeddings are available + sample_type="xenium", + weights=gene_celltype_abundance_embedding, # uncomment if gene-celltype embeddings are available ) @@ -88,16 +84,14 @@ sample.save( - data_dir=segger_data_dir, - k_bd=3, - dist_bd=15.0, - k_tx=15, - dist_tx=3, - tile_size=50_000, - neg_sampling_ratio=5.0, - frac=0.1, - val_prob=0.1, - test_prob=0.1, + data_dir=segger_data_dir, + k_bd=3, + dist_bd=15.0, + k_tx=15, + dist_tx=3, + tile_size=50_000, + neg_sampling_ratio=5.0, + frac=0.1, + val_prob=0.1, + test_prob=0.1, ) - - diff --git a/scripts/predict_model_sample.py b/scripts/predict_model_sample.py index f6df9b5a..b6969afa 100644 --- a/scripts/predict_model_sample.py +++ b/scripts/predict_model_sample.py @@ -22,8 +22,8 @@ seg_tag = "bc_fast_data_emb_major" model_version = 1 -segger_data_dir = Path('data_tidy/pyg_datasets') / seg_tag -models_dir = Path("./models") / seg_tag +segger_data_dir = Path("data_tidy/pyg_datasets") / seg_tag +models_dir = Path("./models") / seg_tag benchmarks_dir = Path("/dkfz/cluster/gpu/data/OE0606/elihei/segger_experiments/data_tidy/benchmarks/xe_rep1_bc") transcripts_file = "data_raw/xenium/Xenium_FFPE_Human_Breast_Cancer_Rep1/transcripts.parquet" # Initialize the Lightning data module diff --git a/scripts/train_model_sample.py b/scripts/train_model_sample.py index cc092ae6..aafe8053 100644 --- a/scripts/train_model_sample.py +++ b/scripts/train_model_sample.py @@ -15,7 +15,7 @@ import os -segger_data_dir = segger_data_dir = Path('data_tidy/pyg_datasets/bc_rep1_emb_final_200') +segger_data_dir = segger_data_dir = Path("data_tidy/pyg_datasets/bc_rep1_emb_final_200") models_dir = Path("./models/bc_rep1_emb_final_200") # Base directory to store Pytorch Lightning models @@ -35,37 +35,34 @@ # If you use custom gene embeddings, use the following two lines instead: is_token_based = False -num_tx_tokens = dm.train[0].x_dict["tx"].shape[1] # Set the number of tokens to the number of genes +num_tx_tokens = dm.train[0].x_dict["tx"].shape[1] # Set the number of tokens to the number of genes num_bd_features = dm.train[0].x_dict["bd"].shape[1] # Initialize the Lightning model ls = LitSegger( - is_token_based = is_token_based, - num_node_features = {"tx": num_tx_tokens, "bd": num_bd_features}, - init_emb=8, + is_token_based=is_token_based, + num_node_features={"tx": num_tx_tokens, "bd": num_bd_features}, + init_emb=8, hidden_channels=64, out_channels=16, heads=4, num_mid_layers=3, - aggr='sum', - learning_rate=1e-3 + aggr="sum", + learning_rate=1e-3, ) # Initialize the Lightning trainer trainer = Trainer( - accelerator='cuda', - strategy='auto', - precision='16-mixed', - devices=2, # set higher number if more gpus are available + accelerator="cuda", + strategy="auto", + precision="16-mixed", + devices=2, # set higher number if more gpus are available max_epochs=400, default_root_dir=models_dir, logger=CSVLogger(models_dir), ) -trainer.fit( - model=ls, - datamodule=dm -) \ No newline at end of file +trainer.fit(model=ls, datamodule=dm) diff --git a/src/segger/data/utils.py b/src/segger/data/utils.py index 32faca99..2921b4f4 100644 --- a/src/segger/data/utils.py +++ b/src/segger/data/utils.py @@ -43,7 +43,7 @@ def try_import(module_name): from datetime import timedelta -def filter_transcripts( #ONLY FOR XENIUM +def filter_transcripts( # ONLY FOR XENIUM transcripts_df: pd.DataFrame, min_qv: float = 20.0, ) -> pd.DataFrame: @@ -65,14 +65,14 @@ def filter_transcripts( #ONLY FOR XENIUM "DeprecatedCodeword_", "UnassignedCodeword_", ) - - transcripts_df['feature_name'] = transcripts_df['feature_name'].apply( + + transcripts_df["feature_name"] = transcripts_df["feature_name"].apply( lambda x: x.decode("utf-8") if isinstance(x, bytes) else x ) - mask_quality = transcripts_df['qv'] >= min_qv + mask_quality = transcripts_df["qv"] >= min_qv # Apply the filter for unwanted codewords using Dask string functions - mask_codewords = ~transcripts_df['feature_name'].str.startswith(filter_codewords) + mask_codewords = ~transcripts_df["feature_name"].str.startswith(filter_codewords) # Combine the filters and return the filtered Dask DataFrame mask = mask_quality & mask_codewords diff --git a/src/segger/prediction/predict_parquet.py b/src/segger/prediction/predict_parquet.py index f5be68c1..36326505 100644 --- a/src/segger/prediction/predict_parquet.py +++ b/src/segger/prediction/predict_parquet.py @@ -13,13 +13,7 @@ from pathlib import Path from torch_geometric.loader import DataLoader from torch_geometric.data import Batch -from segger.data.utils import ( - get_edge_index, - format_time, - create_anndata, - coo_to_dense_adj, - filter_transcripts -) +from segger.data.utils import get_edge_index, format_time, create_anndata, coo_to_dense_adj, filter_transcripts from segger.training.train import LitSegger from segger.training.segger_data_module import SeggerDataModule from segger.prediction.boundary import generate_boundaries @@ -36,7 +30,7 @@ from cupyx.scipy.sparse import coo_matrix from torch.utils.dlpack import to_dlpack, from_dlpack -from dask.distributed import Client, LocalCluster +from dask.distributed import Client, LocalCluster, Future import cupy as cp import numpy as np import pandas as pd @@ -286,6 +280,7 @@ def sparse_multiply(embeddings, edge_index, shape) -> coo_matrix: def predict_batch( + client: Client, lit_segger: torch.nn.Module, batch: Batch, score_cut: float, @@ -295,12 +290,13 @@ def predict_batch( edge_index_save_path: Union[str, Path] = None, output_ddf_save_path: Union[str, Path] = None, gpu_id: int = 0, # Added argument for GPU ID -): +) -> tuple[Future | None, Future | None]: """ Predict cell assignments for a batch of transcript data using a segmentation model. Writes both the assignments and edge_index directly into Parquet files incrementally. Args: + client (Client): The client to connect to and submit computation to a dask cluster. lit_segger (torch.nn.Module): The lightning module wrapping the segmentation model. batch (Batch): A batch of transcript and cell data. score_cut (float): The threshold for assigning transcripts to cells based on similarity scores. @@ -315,6 +311,8 @@ def predict_batch( gpu_id (int, optional): The GPU ID to use for the computations. Defaults to 0. """ + delayed_write_edge_index_future, delayed_write_output_ddf_future = None, None + def _get_id(): """Generate a random Xenium-style ID.""" return "".join(np.random.choice(list("abcdefghijklmnopqrstuvwxyz"), 8)) + "-nx" @@ -410,7 +408,7 @@ def _get_id(): delayed_write_edge_index = delayed(edge_index_ddf.to_parquet)( edge_index_save_path, append=True, ignore_divisions=True ) - delayed_write_edge_index.persist() # Schedule writing + delayed_write_edge_index_future = client.persist(delayed_write_edge_index) # Schedule writing assignments = { "transcript_id": assignments["transcript_id"].astype("str"), @@ -428,12 +426,14 @@ def _get_id(): delayed_write_output_ddf = delayed(batch_ddf.to_parquet)( output_ddf_save_path, append=True, ignore_divisions=True ) - delayed_write_output_ddf.persist() # Schedule writing + delayed_write_output_ddf_future = client.persist(delayed_write_output_ddf) # Schedule writing # Free memory after computation cp.get_default_memory_pool().free_all_blocks() # Free CuPy memory torch.cuda.empty_cache() + return delayed_write_edge_index_future, delayed_write_output_ddf_future + def segment( model: LitSegger, @@ -482,6 +482,7 @@ def segment( None. Saves the result to disk in various formats and logs the parameter choices. """ + client = Client() start_time = time() # Create a subdirectory with important parameter info (receptive field values) @@ -511,6 +512,9 @@ def segment( val_dataloader = dm.val_dataloader() test_dataloader = dm.test_dataloader() + delayed_write_edge_index_futures = [] + delayed_write_output_ddf_futures = [] + # Loop through the data loaders (train, val, and test) for loader_name, loader in zip( ["Train", "Validation", "Test"], [train_dataloader, val_dataloader, test_dataloader] @@ -522,7 +526,8 @@ def segment( for batch in tqdm(loader, desc=f"Processing {loader_name} batches"): gpu_id = random.choice(gpu_ids) # Call predict_batch for each batch - predict_batch( + delayed_write_edge_index_future, delayed_write_output_ddf_future = predict_batch( + client, model, batch, score_cut, @@ -534,23 +539,30 @@ def segment( gpu_id=gpu_id, ) + if delayed_write_edge_index_future is not None: + delayed_write_edge_index_futures.append(delayed_write_edge_index_future) + + if delayed_write_output_ddf_future is not None: + delayed_write_output_ddf_futures.append(delayed_write_output_ddf_future) + if verbose: elapsed_time = time() - step_start_time print(f"Batch processing completed in {elapsed_time:.2f} seconds.") + client.gather(delayed_write_output_ddf_futures) + assert os.path.exists(output_ddf_save_path) seg_final_dd = pd.read_parquet(output_ddf_save_path) step_start_time = time() if verbose: print(f"Applying max score selection logic...") output_ddf_save_path = save_dir / "transcripts_df.parquet" - - + seg_final_dd = pd.read_parquet(output_ddf_save_path) - - seg_final_filtered = seg_final_dd.sort_values( - "score", ascending=False - ).drop_duplicates(subset="transcript_id", keep="first") + + seg_final_filtered = seg_final_dd.sort_values("score", ascending=False).drop_duplicates( + subset="transcript_id", keep="first" + ) if verbose: elapsed_time = time() - step_start_time @@ -570,7 +582,7 @@ def segment( # Outer merge to include all transcripts, even those without assigned cell ids transcripts_df_filtered = transcripts_df.merge(seg_final_filtered, on="transcript_id", how="outer") - + if verbose: elapsed_time = time() - step_start_time print(f"Merged segmentation results with transcripts in {elapsed_time:.2f} seconds.") @@ -581,6 +593,8 @@ def segment( if verbose: print(f"Computing connected components for unassigned transcripts...") # Load edge indices from saved Parquet + client.gather(delayed_write_edge_index_futures) + assert os.path.exists(edge_index_save_path) edge_index_dd = pd.read_parquet(edge_index_save_path) # Step 2: Get unique transcript_ids from edge_index_dd and their positional indices