diff --git a/pf2rnaseq/Data/DN11Doublets.csv.gz b/pf2rnaseq/Data/DN11Doublets.csv.gz new file mode 100644 index 0000000..fc234c3 Binary files /dev/null and b/pf2rnaseq/Data/DN11Doublets.csv.gz differ diff --git a/pf2rnaseq/factorization.py b/pf2rnaseq/factorization.py index dd81f10..0aff1b5 100644 --- a/pf2rnaseq/factorization.py +++ b/pf2rnaseq/factorization.py @@ -16,10 +16,15 @@ def correct_conditions(X: anndata.AnnData): """Correct the conditions factors by overall read depth. Ensures that weighting is not affected by cell count difference""" - # sgIndex = X.obs["condition_unique_idxs"] - sgIndex = X.obs["condition_unique_idxs"].cat.codes + sgIndex = X.obs["condition_unique_idxs"] + #sgIndex = X.obs["condition_unique_idxs"].cat.codes counts = np.zeros((np.amax(sgIndex) + 1, 1)) - + min_val = np.min(X.uns["Pf2_A"]) + if min_val < 0: + # Add the absolute value of the minimum (plus a small epsilon) to make all values positive + X.uns["Pf2_A"] = X.uns["Pf2_A"] + abs(min_val) + 1e-10 + print(f"Warning: Found negative values in Pf2_A (min: {min_val:.6f}). Added {abs(min_val) + 1e-10:.6f} to all values.") + cond_mean = gmean(X.uns["Pf2_A"], axis=1) x_count = X.X.sum(axis=1) @@ -43,7 +48,7 @@ def pf2( tolerance=1e-9, r2x=False, ): - cupy.cuda.Device(1).use() + cupy.cuda.Device(0).use() pf_out, R2X = parafac2_nd( X, @@ -72,7 +77,7 @@ def pf2_pca_r2x(X: anndata.AnnData, ranks): r2x_pf2 = np.zeros(len(ranks)) for i in tqdm(range(len(r2x_pf2)), total=len(r2x_pf2)): - _, R2X = parafac2_nd(X, rank=i + 1) + _, R2X = parafac2_nd(X, rank=ranks[i]) r2x_pf2[i] = R2X pca = PCA(n_components=ranks[-1], svd_solver="arpack") diff --git a/pf2rnaseq/figures/commonFuncs/plotFactors.py b/pf2rnaseq/figures/commonFuncs/plotFactors.py index 31604df..ccacc0d 100644 --- a/pf2rnaseq/figures/commonFuncs/plotFactors.py +++ b/pf2rnaseq/figures/commonFuncs/plotFactors.py @@ -18,13 +18,14 @@ def plot_condition_factors( cond_group_labels: pd.Series | None = None, groupConditions=False, cond="Condition", + log_scale=True, ): """Plots Pf2 condition factors""" pd.set_option("display.max_rows", None) yt = pd.Series(np.unique(data.obs[cond])) X = np.array(data.uns["Pf2_A"]) - - X = np.log10(X) + if log_scale: + X = np.log10(X) X -= np.median(X, axis=0) X /= np.std(X, axis=0) @@ -41,6 +42,7 @@ def plot_condition_factors( X = X[ind] yt = yt.iloc[ind] ax.tick_params(axis="y", which="major", pad=20, length=0) + # extra padding to leave room for the row colors # get list of colors for each label: colors = sns.color_palette( @@ -68,6 +70,7 @@ def plot_condition_factors( ax.legend(handles=legend_elements, bbox_to_anchor=(0, 1.3)) xticks = np.arange(1, X.shape[1] + 1) + sns.heatmap( data=X, xticklabels=xticks, @@ -170,7 +173,7 @@ def plot_condition_factors_groups( ax.tick_params( axis="y", which="major", - pad=40 if subgroup_labels is not None else 20, + pad=30 if subgroup_labels is not None else 30, length=0, ) @@ -190,8 +193,8 @@ def plot_condition_factors_groups( for i, color in enumerate(main_row_colors): ax.add_patch( plt.Rectangle( - xy=(-0.05, i), - width=0.05, + xy=(-0.02, i), + width=0.02, height=1, color=color, lw=0, @@ -218,8 +221,8 @@ def plot_condition_factors_groups( for i, color in enumerate(sub_row_colors): ax.add_patch( plt.Rectangle( - xy=(-0.10, i), # Position to left of main group colors - width=0.05, + xy=(-0.04, i), # Position to left of main group colors + width=0.02, height=1, color=color, lw=0, @@ -317,6 +320,7 @@ def plot_gene_factors( vmin=-1, vmax=1, ) + ax.set(xlabel="Component") @@ -332,11 +336,8 @@ def plot_geneSet_factors( X = X[kept_idxs] yt = yt[kept_idxs] - # index for genes - # ind = reorder_table(X) - # X = X[ind] X = X / np.max(np.abs(X)) - # yt = [yt[ii] for ii in ind] + xticks = np.arange(1, rank + 1) sns.heatmap( @@ -422,6 +423,17 @@ def plot_geneSetScore( # Calculate the sum of X values for each component component_sums = np.sum(X, axis=0) + # Find the top 3 components with highest absolute scores + top_3_indices = np.argsort(np.abs(component_sums))[-3:] + + # Create colors array - highlight the top 3 components + colors = [] + for i in range(len(component_sums)): + if i in top_3_indices: + colors.append("darkred") + else: + colors.append("steelblue") + # Create the bar plot xticks = np.arange(1, rank + 1) sns.barplot(x=xticks, y=component_sums, ax=ax) @@ -429,93 +441,28 @@ def plot_geneSetScore( ax.set_xlabel("Component", fontsize=12) ax.set_ylabel("Sum of Weights", fontsize=12) ax.set_title("Sum of Gene Factors per Component", fontsize=12) - - -def plot_geneSetScoreDot( - data: AnnData, - ax: Axes, - genes: np.array, - trim=True, - size_scale=500, - cmap="coolwarm", - size_norm=None, -): - """ - Plots Pf2 gene set score as a dot plot. - - Parameters: - data: AnnData object containing the Pf2 results - ax: Matplotlib axes to plot on - genes: Array of gene names to include in the score - trim: Whether to trim genes (not used in this version) - size_scale: Scaling factor for dot sizes - cmap: Colormap for the dot colors - size_norm: Optional normalization for dot sizes - """ - # Get dimensions and data - rank = data.varm["Pf2_C"].shape[1] - X = np.array(data.varm["Pf2_C"]) - yt = data.var.index.values - - # Filter the genes - kept_idxs = np.where(np.in1d(yt, genes))[0] - if len(kept_idxs) == 0: + # Add labels to the top 3 components + for idx in top_3_indices: + component_num = idx + 1 + y_pos = component_sums[idx] + 0.01 * np.sign(component_sums[idx]) * np.max( + np.abs(component_sums) + ) ax.text( - 0.5, - 0.5, - "No matching genes found", + component_num, + y_pos, + f"Comp. {component_num}", ha="center", - va="center", - transform=ax.transAxes, + va="bottom" if component_sums[idx] > 0 else "top", + fontsize=12, + fontweight="bold", + color="darkred", ) - return - X = X[kept_idxs] - - # Calculate the sum of X values for each component - component_sums = np.sum(X, axis=0) - - # Create DataFrame for plotting - df = pd.DataFrame( - { - "Component": np.arange(1, rank + 1), - "Sum": component_sums, - "Magnitude": np.abs(component_sums), - } - ) - - # Create a diverging colormap based on value - min_val = df["Sum"].min() - max_val = df["Sum"].max() - abs_max = max(abs(min_val), abs(max_val)) - - # Create the dot plot - scatter = sns.scatterplot( - data=df, - x="Component", - y=[0] * rank, # All dots on same horizontal line - size="Magnitude", - hue="Sum", - palette=cmap, - sizes=(20, size_scale), - size_norm=size_norm, - ax=ax, - legend="brief", - ) - - # Customize the plot - ax.set(xlabel="Component", title="Gene Set Score by Component") - ax.set_yticks([]) # Remove y-axis ticks - ax.set_ylim([-1, 1]) # Set y-axis limits - - # Set x-axis ticks to integers - ax.set_xticks(np.arange(1, rank + 1)) - - # Move the legend outside the plot - plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left") - - # Add a horizontal line at y=0 - ax.axhline(y=0, color="gray", linestyle="-", lw=0.5, alpha=0.7) + ax.set_xlabel("Component", fontsize=20) + ax.set_ylabel("Sum of Weights", fontsize=20) + ax.set_title("Signature Score", fontsize=25) + ax.tick_params(axis="x", rotation=90, labelsize=16) + ax.tick_params(axis="y", labelsize=16) def plot_ttest(X: AnnData, ax: Axes): @@ -527,7 +474,7 @@ def plot_ttest(X: AnnData, ax: Axes): # Get all cytokines directly without separate function all_cytokines = X.obs["cyt"].unique() - #Get highly weighted cytokines per component + # Get highly weighted cytokines per component results_df = highly_weighted_cytokines(X) # Create pivot table for all components @@ -565,3 +512,84 @@ def plot_ttest(X: AnnData, ax: Axes): ax.set_title("Highly weighted Cytokines Across Components (ANOVA + Post-hoc)") ax.set_xlabel("Component") ax.set_ylabel("Cytokine") + + +def plot_comp_weights( + data: AnnData, + ax: Axes, + comp: int, + cond="Condition", + sort_bars=True, + top_n=3, + include_lowest=True, +): + """Plots component weights for each condition as a bar chart""" + + # Get condition names and factor matrix + cond_df = ( + data.obs[[cond, "condition_unique_idxs"]] + .drop_duplicates() + .sort_values("condition_unique_idxs") + ) + yt = cond_df[cond].to_numpy() + X = np.array(data.uns["Pf2_A"]) + cond_mapping = data.obs.groupby("condition_unique_idxs", sort=True)[cond].first() + + # Extract condition names and indices + condition_indices = cond_mapping.index.to_numpy() + yt = cond_mapping.values + + # Extract weights + component_weights = X[condition_indices, comp - 1] + # Create DataFrame for plotting + df = pd.DataFrame({"Condition": yt, "Weight": component_weights}) + + # Get top N highest weighted conditions + top_n_highest = df.nlargest(top_n, "Weight") + + # Conditionally get lowest weighted conditions + if include_lowest: + top_n_lowest = df.nsmallest(top_n, "Weight") + # Combine and keep only top conditions + df_filtered = pd.concat([top_n_highest, top_n_lowest]).drop_duplicates() + else: + df_filtered = top_n_highest + + # Sort by weight if requested + if sort_bars: + df_filtered = df_filtered.sort_values("Weight", ascending=False) + + # Create color mapping - highest in red, lowest in blue (if included) + colors = [] + for condition in df_filtered["Condition"]: + if condition in top_n_highest["Condition"].values: + colors.append("darkred") # Highest weights + else: + colors.append("darkblue") # Lowest weights + + # Create bar plot with custom colors + bars = ax.bar(df_filtered["Condition"], df_filtered["Weight"], color=colors) + + # Customize the plot title based on whether lowest are included + if include_lowest: + title = f"Component {comp} Weights by Condition (Top {top_n} Highest/Lowest)" + else: + title = f"Component {comp} Weights by Condition (Top {top_n} Highest)" + + ax.set_title(title, fontsize=20) + ax.set_xlabel("Condition", fontsize=18) + ax.set_ylabel("Weight", fontsize=18) + ax.tick_params(axis="x", rotation=90, labelsize=18) + ax.tick_params(axis="y", labelsize=18) + + # Add legend for color coding (only if lowest are included) + if include_lowest: + + legend_elements = [ + Patch(facecolor="darkred", label=f"Top {top_n} Highest"), + Patch(facecolor="darkblue", label=f"Top {top_n} Lowest"), + ] + ax.legend(handles=legend_elements, loc="upper right") + + # Add horizontal line at y=0 for reference + ax.axhline(y=0, color="gray", linestyle="-", alpha=0.3) diff --git a/pf2rnaseq/figures/commonFuncs/plotGeneral.py b/pf2rnaseq/figures/commonFuncs/plotGeneral.py index db0f833..a9da204 100644 --- a/pf2rnaseq/figures/commonFuncs/plotGeneral.py +++ b/pf2rnaseq/figures/commonFuncs/plotGeneral.py @@ -6,7 +6,7 @@ import seaborn as sns from matplotlib.axes import Axes -from ...factorization import fms_percent_drop, pf2_pca_r2x, fms_diff_ranks +from ...factorization import fms_diff_ranks, fms_percent_drop, pf2_pca_r2x def plot_r2x(data, rank_vec, ax: Axes): @@ -32,65 +32,120 @@ def plot_r2x(data, rank_vec, ax: Axes): 0, np.max(np.append(r2xError[0], r2xError[1])) + 0.01, num=5 ), ) + # Increase font sizes + ax.set_xlabel("Number of Components", fontsize=18) + ax.set_ylabel("Variance Explained", fontsize=18) + ax.tick_params(axis="both", which="major", labelsize=16) ax.legend() -def plot_avegene_per_celltype(adata, genes, ax, cellType="Cell Type"): +def plot_avegene_per_celltype( + adata, + genes, + ax, + cellType="Cell Type", + condition="cytokine", + mean=False, + center_data=False, +): """Plots average gene expression across cell types for all conditions""" genesV = adata[:, genes] dataDF = genesV.to_df() - dataDF = dataDF.subtract(genesV.var["means"].values) - dataDF["Condition"] = genesV.obs["Condition"].values + if center_data: + dataDF = dataDF.subtract(genesV.var["means"].values) + gene_std = dataDF.std(axis=0) # Standard deviation for each gene + # Z-score: data is already mean-centered, just divide by std + dataDF = dataDF / gene_std + dataDF["Condition"] = genesV.obs[condition].values dataDF["Cell Type"] = genesV.obs[cellType].values + # Melt the data data = pd.melt(dataDF, id_vars=["Condition", "Cell Type"], value_vars=genes).rename( - columns={"variable": "Gene", "value": "Value"} + columns={"variable": "Gene", "value": "Gene Expression"} ) - df = data.groupby(["Condition", "Cell Type", "Gene"], observed=False).mean() - df = df.rename(columns={"Value": "Average Gene Expression"}) + + # Apply grouping if mean=True + if mean is True: + data = ( + data.groupby(["Condition", "Cell Type", "Gene"], observed=False) + .mean() + .reset_index() + ) + sns.boxplot( - data=df, - x="Gene", - y="Average Gene Expression", - hue="Cell Type", + data=data, + x="Cell Type", + y="Gene Expression", + hue="Condition", ax=ax, fliersize=0, ) - - -def plot_avegene_per_category(conds, gene, adata, ax, mean=True, cellType="Cell Type"): - """Plots average gene expression across cell types for a category of drugs""" + ax.set_title(genes, fontsize=25) + ax.set_xlabel("Cell Type", fontsize=22) + ax.set_ylabel("Gene Expression", fontsize=22) + ax.tick_params(axis="x", rotation=45, labelsize=20) + ax.tick_params(axis="y", labelsize=20) + + +def plot_avegene_per_category( + conds, + gene, + adata, + ax, + mean=True, + cellType="Cell Type", + swarm=False, + center_data=False, + condition="cytokine", +): + """Plots average gene expression across cell types for a specified condition""" genesV = adata[:, gene] dataDF = genesV.to_df() - dataDF = dataDF.subtract(genesV.var["means"].values) - dataDF["Condition"] = genesV.obs["Condition"].values + + if center_data: + gene_means = dataDF.mean(axis=0) + dataDF = dataDF - gene_means + gene_std = dataDF.std(axis=0) # Standard deviation for each gene + # Z-score: data is already mean-centered, just divide by std + dataDF = dataDF / gene_std + + dataDF["Condition"] = genesV.obs[condition].values dataDF["Cell Type"] = genesV.obs[cellType].values df = pd.melt(dataDF, id_vars=["Condition", "Cell Type"], value_vars=gene).rename( - columns={"variable": "Gene", "value": "Value"} + columns={"variable": "Gene", "value": "Gene Expression"} ) + if mean is True: - df = df.groupby(["Condition", "Cell Type", "Gene"], observed=False).mean() + df = ( + df.groupby(["Condition", "Cell Type", "Gene"], observed=False) + .mean() + .reset_index() + ) - df = df.rename(columns={"Value": "Average Gene Expression For Drugs"}).reset_index() - df = df[df["Condition"].isin(conds)] + df["Condition"] = np.where(df["Condition"].isin(conds), df["Condition"], "Other") - # df["Condition"] = np.where(df["Condition"].isin(conds), df["Condition"], "Other") - # df["Condition"] = df[df["Condition"]==conds] - # for i in conds: - # df = df.replace({"Condition": {i: categoryCond}}) + if swarm is False: + sns.boxplot( + data=df, + x="Cell Type", + y="Gene Expression", + hue="Condition", + ax=ax, + showfliers=False, + ) + else: + sns.stripplot( + data=df, + x="Condition", + y="Gene Expression", + hue="Condition", + ax=ax, + alpha=0.6, + ) - sns.boxplot( - data=df.loc[df["Gene"] == gene], - x="Cell Type", - y="Average Gene Expression For Drugs", - hue="Condition", - ax=ax, - showfliers=False, - ) ax.set(title=gene) - ax.set_xticks(ax.get_xticks()) - ax.set_xticklabels(labels=ax.get_xticklabels(), rotation=45) + ax.tick_params(axis="x", rotation=45) def heatmapGeneFactors( @@ -308,29 +363,38 @@ def plot_cell_gene_corr( ) -def cell_count_perc_df(X, celltype="Cell Type", status=False, grouping="Condition"): +def cell_count_perc_df(X, celltype="Cell Type", condition="cytokine"): """Returns DF with cell counts and percentages for experiment""" - if status is False: - grouping = [celltype, grouping] - else: - grouping = [celltype, "Condition", "SLE_status"] + grouping_all = [ + celltype, + condition, + "condition_unique_idxs", + ] + grouping = [celltype, condition] - df = X.obs[grouping].reset_index(drop=True) + df = X.obs[grouping_all].reset_index(drop=True) - dfCond = df.groupby([grouping], observed=True).size().reset_index(name="Cell Count") + idx_mapping = X.obs.groupby(condition, observed=False)[ + "condition_unique_idxs" + ].first() + + dfCond = ( + df.groupby([condition], observed=True).size().reset_index(name="Cell Count") + ) dfCellType = ( df.groupby(grouping, observed=True).size().reset_index(name="Cell Count") ) dfCellType["Cell Count"] = dfCellType["Cell Count"].astype("float") dfCellType["Cell Type Percentage"] = 0.0 - for cond in np.unique(df[grouping]): - dfCellType.loc[dfCellType[grouping] == cond, "Cell Type Percentage"] = ( + for cond in np.unique(df[condition]): + dfCellType.loc[dfCellType[condition] == cond, "Cell Type Percentage"] = ( 100 - * dfCellType.loc[dfCellType[grouping] == cond, "Cell Count"].to_numpy() - / dfCond.loc[dfCond[grouping] == cond]["Cell Count"].to_numpy() + * dfCellType.loc[dfCellType[condition] == cond, "Cell Count"].to_numpy() + / dfCond.loc[dfCond[condition] == cond]["Cell Count"].to_numpy() ) + dfCellType["condition_unique_idxs"] = dfCellType[condition].map(idx_mapping) dfCellType.rename(columns={celltype: "Cell Type"}, inplace=True) return dfCellType diff --git a/pf2rnaseq/figures/commonFuncs/plotPaCMAP.py b/pf2rnaseq/figures/commonFuncs/plotPaCMAP.py index e69846c..56e9813 100644 --- a/pf2rnaseq/figures/commonFuncs/plotPaCMAP.py +++ b/pf2rnaseq/figures/commonFuncs/plotPaCMAP.py @@ -86,8 +86,7 @@ def plot_wp_pacmap(X: anndata.AnnData, cmp: int, ax: Axes, cbarMax: float = 1.0) projections for a component and eigenstate""" values = X.obsm["weighted_projections"][:, cmp - 1] points = X.obsm["X_pf2_PaCMAP"] - - cmap = sns.diverging_palette(250, 30, l=65, center="dark", as_cmap=True) + cmap = sns.diverging_palette(240, 10, as_cmap=True) canvas = _get_canvas(points) data = pd.DataFrame(points, columns=("x", "y")) @@ -101,15 +100,17 @@ def plot_wp_pacmap(X: anndata.AnnData, cmp: int, ax: Axes, cbarMax: float = 1.0) cmap=cmap, span=(-cbarMax, cbarMax), how="linear", - alpha=255, - min_alpha=255, + alpha=220, + min_alpha=220, ) ds_show(result, ax) psm = plt.pcolormesh([[-cbarMax, cbarMax], [-cbarMax, cbarMax]], cmap=cmap) - plt.colorbar(psm, ax=ax) - ax.set(title="Cmp. " + str(cmp)) + + cbar = plt.colorbar(psm, ax=ax) + + ax.set_title("Cmp. " + str(cmp)) ax = assign_labels(ax) @@ -118,14 +119,14 @@ def plot_labels_pacmap( labelType: str, ax: Axes, condition=None, - cmap="tab20", + cmap: str = "tab20", color_key=None, ): """Scatterplot of UMAP visualization weighted by condition or cell type""" labels = X.obs[labelType] if condition is not None: - labels = pd.Series([c if c in condition else "Z Other" for c in labels]) + labels = pd.Series([c if c in condition else "Other" for c in labels]) if labels.dtype == "category": labels = labels.cat.set_categories( np.sort(labels.cat.categories.values), ordered=True @@ -156,7 +157,10 @@ def plot_labels_pacmap( ) ds_show(result, ax) - ax.legend(handles=legend_elements) + + ax.legend( + handles=legend_elements, fontsize=25, bbox_to_anchor=(1.05, 1), loc="upper left" + ) ax = assign_labels(ax) @@ -180,7 +184,7 @@ def plot_wp_per_celltype( ax.set( xticks=np.linspace(-maxvalue, maxvalue, num=5), xlabel="Cell Specific Weight" ) - ax.set_title(cmpName) + ax.set_title(cmpName, fontsize=15) def assign_labels(ax): diff --git a/pf2rnaseq/imports.py b/pf2rnaseq/imports.py index 6b4a204..9eaaf81 100644 --- a/pf2rnaseq/imports.py +++ b/pf2rnaseq/imports.py @@ -1,8 +1,12 @@ from concurrent.futures import ProcessPoolExecutor import anndata +import pandas as pd import scanpy as sc from parafac2.normalize import prepare_dataset +from pathlib import Path + +path_here = Path(__file__).parent.parent def import_citeseq() -> anndata.AnnData: @@ -35,7 +39,7 @@ def import_cytokine() -> anndata.AnnData: X = anndata.read_h5ad("/opt/extra-storage/Treg_h5ads/Treg_raw.h5ad") # Remove multiplexing identifiers - X = X[:, ~X.var_names.str.match("^CMO3[0-9]{2}$")] # type: ignore + X = X[:, ~X.var_names.str.match("^CMO3[0-9]{2}$")].copy() # type: ignore return prepare_dataset(X, "Condition", geneThreshold=0.002) # 0.1 @@ -50,21 +54,17 @@ def import_pf2Cytokine30() -> anndata.AnnData: return X -def import_Heiser(deviance=False) -> anndata.AnnData: +def import_Heiser() -> anndata.AnnData: """Import Heiser C3TAg dataset. anndata.X is the raw counts """ data = anndata.read_h5ad("/home/nicoleb/C3TAg.h5ad") - if deviance: - # Apply deviance transformation - return prepare_dataset(data, "sample_id", geneThreshold=0.1, deviance=True) - else: - # Apply standard normalization and scaling - return prepare_dataset(data, "sample_id", geneThreshold=0.1) + + return prepare_dataset(data, "sample_id", geneThreshold=0.1) -def import_MouseImmune() -> anndata.AnnData: +def import_MouseImmune(geneThreshold=0.1) -> anndata.AnnData: """Import Mouse Immune Dictionary cytokine data. -- columns from observation data: {'biosample_id': cytokine and replicate info, @@ -78,8 +78,26 @@ def import_MouseImmune() -> anndata.AnnData: ...}""" X = anndata.read_h5ad("/home/nicoleb/MouseCytok.h5ad") # Filter out doublets - X = X[X.obs["celltype"] != "doublet", :] + X = X[X.obs["celltype"] != "doublet", :].copy() - return prepare_dataset(X, "biosample_id", geneThreshold=0.1) # 0.01 + return prepare_dataset(X, "biosample_id", geneThreshold=geneThreshold) # 0.01 +def import_Parse(geneThreshold=0.1, doublet=False) -> anndata.AnnData: + """Import Parse data . + cytokine: cytokine treatment + donor: donor identifier + + """ + X = anndata.read_h5ad("/home/nicoleb/Pf2-scRNAseq-1/pf2rnaseq/Parse_Donor11.h5ad") + if doublet: + doubletDF = pd.read_csv( + path_here / "pf2rnaseq/Data/DN11Doublets.csv.gz", + index_col=0 + ) + X.obs = X.obs.join(doubletDF.reindex(X.obs.index)) + singlet_mask = X.obs["doublet"] == 0 + X = X[singlet_mask, :].copy() + print(f"Kept {X.n_obs} singlet cells, removed {(~singlet_mask).sum()} doublets") + + return prepare_dataset(X, "cytokine", geneThreshold=geneThreshold) diff --git a/pyproject.toml b/pyproject.toml index e5e80ac..db02f1b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ "ipykernel>=6.29.5", "parafac2 @ git+https://github.com/meyer-lab/parafac2.git", "wandb>=0.19.9", + "doubletdetection>=4.3.0.post1", ] diff --git a/requirements-dev.lock b/requirements-dev.lock index 7137c48..6382896 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -11,6 +11,7 @@ -e file:. anndata==0.11.3 + # via doubletdetection # via parafac2 # via pf2rnaseq # via scanpy @@ -54,6 +55,8 @@ decorator==5.1.1 # via ipython docker-pycreds==0.4.0 # via wandb +doubletdetection==4.3.0.post1 + # via pf2rnaseq executing==2.1.0 # via stack-data fast-array-utils==1.2.1 @@ -103,12 +106,14 @@ legacy-api-wrap==1.4 # via scanpy leidenalg==0.10.2 # via pf2rnaseq + # via phenograph llvmlite==0.44.0 # via numba # via pynndescent locket==1.0.0 # via partd matplotlib==3.9.2 + # via doubletdetection # via gseapy # via scanpy # via seaborn @@ -140,6 +145,7 @@ numpy==2.2.6 # via cupy-cuda12x # via dask # via datashader + # via doubletdetection # via fast-array-utils # via gseapy # via h5py @@ -150,6 +156,7 @@ numpy==2.2.6 # via parafac2 # via patsy # via pf2rnaseq + # via phenograph # via pyarrow # via scanpy # via scikit-learn @@ -184,7 +191,7 @@ pandas==2.2.2 # via statsmodels # via tlviz # via xarray -parafac2 @ git+https://github.com/meyer-lab/parafac2.git@ccf708fe71c78a106151643e893ea006bba865d9 +parafac2 @ git+https://github.com/meyer-lab/parafac2.git@1c214868031d61297505d95bdf056dfce33262a5 # via pf2rnaseq param==2.1.1 # via datashader @@ -198,6 +205,8 @@ patsy==0.5.6 # via statsmodels pexpect==4.9.0 # via ipython +phenograph==1.5.7 + # via doubletdetection pillow==10.4.0 # via matplotlib platformdirs==4.3.6 @@ -211,6 +220,7 @@ protobuf==5.29.4 # via wandb psutil==6.0.0 # via ipykernel + # via phenograph # via wandb ptyprocess==0.7.0 # via pexpect @@ -253,20 +263,24 @@ requests==2.32.3 # via tlviz # via wandb scanpy @ git+https://github.com/scverse/scanpy.git@c2a7a4b7ec3203121a8d75aa05fbeb602ceecbd4 + # via doubletdetection # via pf2rnaseq scikit-learn==1.6.1 # via pacmap # via pf2rnaseq + # via phenograph # via pynndescent # via scanpy # via umap-learn scipy==1.15.2 # via anndata # via datashader + # via doubletdetection # via fast-array-utils # via gseapy # via parafac2 # via pf2rnaseq + # via phenograph # via pynndescent # via scanpy # via scikit-learn @@ -284,6 +298,7 @@ session-info2==0.1.2 setproctitle==1.3.5 # via wandb setuptools==74.1.2 + # via phenograph # via wandb six==1.16.0 # via asttokens @@ -315,6 +330,7 @@ tornado==6.4.1 # via ipykernel # via jupyter-client tqdm==4.66.5 + # via doubletdetection # via parafac2 # via pf2rnaseq # via scanpy diff --git a/requirements.lock b/requirements.lock index 3c19ef6..ff16107 100644 --- a/requirements.lock +++ b/requirements.lock @@ -11,6 +11,7 @@ -e file:. anndata==0.11.3 + # via doubletdetection # via parafac2 # via pf2rnaseq # via scanpy @@ -52,6 +53,8 @@ decorator==5.1.1 # via ipython docker-pycreds==0.4.0 # via wandb +doubletdetection==4.3.0.post1 + # via pf2rnaseq executing==2.1.0 # via stack-data fast-array-utils==1.2.1 @@ -99,12 +102,14 @@ legacy-api-wrap==1.4 # via scanpy leidenalg==0.10.2 # via pf2rnaseq + # via phenograph llvmlite==0.44.0 # via numba # via pynndescent locket==1.0.0 # via partd matplotlib==3.9.2 + # via doubletdetection # via gseapy # via scanpy # via seaborn @@ -134,6 +139,7 @@ numpy==2.2.6 # via cupy-cuda12x # via dask # via datashader + # via doubletdetection # via fast-array-utils # via gseapy # via h5py @@ -144,6 +150,7 @@ numpy==2.2.6 # via parafac2 # via patsy # via pf2rnaseq + # via phenograph # via pyarrow # via scanpy # via scikit-learn @@ -177,7 +184,7 @@ pandas==2.2.2 # via statsmodels # via tlviz # via xarray -parafac2 @ git+https://github.com/meyer-lab/parafac2.git@ccf708fe71c78a106151643e893ea006bba865d9 +parafac2 @ git+https://github.com/meyer-lab/parafac2.git@1c214868031d61297505d95bdf056dfce33262a5 # via pf2rnaseq param==2.1.1 # via datashader @@ -191,6 +198,8 @@ patsy==0.5.6 # via statsmodels pexpect==4.9.0 # via ipython +phenograph==1.5.7 + # via doubletdetection pillow==10.4.0 # via matplotlib platformdirs==4.3.6 @@ -202,6 +211,7 @@ protobuf==5.29.4 # via wandb psutil==6.0.0 # via ipykernel + # via phenograph # via wandb ptyprocess==0.7.0 # via pexpect @@ -240,20 +250,24 @@ requests==2.32.3 # via tlviz # via wandb scanpy @ git+https://github.com/scverse/scanpy.git@c2a7a4b7ec3203121a8d75aa05fbeb602ceecbd4 + # via doubletdetection # via pf2rnaseq scikit-learn==1.6.1 # via pacmap # via pf2rnaseq + # via phenograph # via pynndescent # via scanpy # via umap-learn scipy==1.15.2 # via anndata # via datashader + # via doubletdetection # via fast-array-utils # via gseapy # via parafac2 # via pf2rnaseq + # via phenograph # via pynndescent # via scanpy # via scikit-learn @@ -271,6 +285,7 @@ session-info2==0.1.2 setproctitle==1.3.5 # via wandb setuptools==74.1.2 + # via phenograph # via wandb six==1.16.0 # via asttokens @@ -302,6 +317,7 @@ tornado==6.4.1 # via ipykernel # via jupyter-client tqdm==4.66.5 + # via doubletdetection # via parafac2 # via pf2rnaseq # via scanpy