Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added pf2rnaseq/Data/DN11Doublets.csv.gz
Binary file not shown.
15 changes: 10 additions & 5 deletions pf2rnaseq/factorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,15 @@

def correct_conditions(X: anndata.AnnData):
"""Correct the conditions factors by overall read depth. Ensures that weighting is not affected by cell count difference"""
# sgIndex = X.obs["condition_unique_idxs"]
sgIndex = X.obs["condition_unique_idxs"].cat.codes
sgIndex = X.obs["condition_unique_idxs"]
#sgIndex = X.obs["condition_unique_idxs"].cat.codes
counts = np.zeros((np.amax(sgIndex) + 1, 1))

min_val = np.min(X.uns["Pf2_A"])
if min_val < 0:
# Add the absolute value of the minimum (plus a small epsilon) to make all values positive
X.uns["Pf2_A"] = X.uns["Pf2_A"] + abs(min_val) + 1e-10
print(f"Warning: Found negative values in Pf2_A (min: {min_val:.6f}). Added {abs(min_val) + 1e-10:.6f} to all values.")

cond_mean = gmean(X.uns["Pf2_A"], axis=1)

x_count = X.X.sum(axis=1)
Expand All @@ -43,7 +48,7 @@ def pf2(
tolerance=1e-9,
r2x=False,
):
cupy.cuda.Device(1).use()
cupy.cuda.Device(0).use()
pf_out, R2X = parafac2_nd(

X,
Expand Down Expand Up @@ -72,7 +77,7 @@ def pf2_pca_r2x(X: anndata.AnnData, ranks):
r2x_pf2 = np.zeros(len(ranks))

for i in tqdm(range(len(r2x_pf2)), total=len(r2x_pf2)):
_, R2X = parafac2_nd(X, rank=i + 1)
_, R2X = parafac2_nd(X, rank=ranks[i])
r2x_pf2[i] = R2X

pca = PCA(n_components=ranks[-1], svd_solver="arpack")
Expand Down
218 changes: 123 additions & 95 deletions pf2rnaseq/figures/commonFuncs/plotFactors.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@ def plot_condition_factors(
cond_group_labels: pd.Series | None = None,
groupConditions=False,
cond="Condition",
log_scale=True,
):
"""Plots Pf2 condition factors"""
pd.set_option("display.max_rows", None)
yt = pd.Series(np.unique(data.obs[cond]))
X = np.array(data.uns["Pf2_A"])

X = np.log10(X)
if log_scale:
X = np.log10(X)

X -= np.median(X, axis=0)
X /= np.std(X, axis=0)
Expand All @@ -41,6 +42,7 @@ def plot_condition_factors(
X = X[ind]
yt = yt.iloc[ind]
ax.tick_params(axis="y", which="major", pad=20, length=0)

# extra padding to leave room for the row colors
# get list of colors for each label:
colors = sns.color_palette(
Expand Down Expand Up @@ -68,6 +70,7 @@ def plot_condition_factors(
ax.legend(handles=legend_elements, bbox_to_anchor=(0, 1.3))

xticks = np.arange(1, X.shape[1] + 1)

sns.heatmap(
data=X,
xticklabels=xticks,
Expand Down Expand Up @@ -170,7 +173,7 @@ def plot_condition_factors_groups(
ax.tick_params(
axis="y",
which="major",
pad=40 if subgroup_labels is not None else 20,
pad=30 if subgroup_labels is not None else 30,
length=0,
)

Expand All @@ -190,8 +193,8 @@ def plot_condition_factors_groups(
for i, color in enumerate(main_row_colors):
ax.add_patch(
plt.Rectangle(
xy=(-0.05, i),
width=0.05,
xy=(-0.02, i),
width=0.02,
height=1,
color=color,
lw=0,
Expand All @@ -218,8 +221,8 @@ def plot_condition_factors_groups(
for i, color in enumerate(sub_row_colors):
ax.add_patch(
plt.Rectangle(
xy=(-0.10, i), # Position to left of main group colors
width=0.05,
xy=(-0.04, i), # Position to left of main group colors
width=0.02,
height=1,
color=color,
lw=0,
Expand Down Expand Up @@ -317,6 +320,7 @@ def plot_gene_factors(
vmin=-1,
vmax=1,
)

ax.set(xlabel="Component")


Expand All @@ -332,11 +336,8 @@ def plot_geneSet_factors(
X = X[kept_idxs]
yt = yt[kept_idxs]

# index for genes
# ind = reorder_table(X)
# X = X[ind]
X = X / np.max(np.abs(X))
# yt = [yt[ii] for ii in ind]

xticks = np.arange(1, rank + 1)

sns.heatmap(
Expand Down Expand Up @@ -422,100 +423,46 @@ def plot_geneSetScore(
# Calculate the sum of X values for each component
component_sums = np.sum(X, axis=0)

# Find the top 3 components with highest absolute scores
top_3_indices = np.argsort(np.abs(component_sums))[-3:]

# Create colors array - highlight the top 3 components
colors = []
for i in range(len(component_sums)):
if i in top_3_indices:
colors.append("darkred")
else:
colors.append("steelblue")

# Create the bar plot
xticks = np.arange(1, rank + 1)
sns.barplot(x=xticks, y=component_sums, ax=ax)

ax.set_xlabel("Component", fontsize=12)
ax.set_ylabel("Sum of Weights", fontsize=12)
ax.set_title("Sum of Gene Factors per Component", fontsize=12)


def plot_geneSetScoreDot(
data: AnnData,
ax: Axes,
genes: np.array,
trim=True,
size_scale=500,
cmap="coolwarm",
size_norm=None,
):
"""
Plots Pf2 gene set score as a dot plot.

Parameters:
data: AnnData object containing the Pf2 results
ax: Matplotlib axes to plot on
genes: Array of gene names to include in the score
trim: Whether to trim genes (not used in this version)
size_scale: Scaling factor for dot sizes
cmap: Colormap for the dot colors
size_norm: Optional normalization for dot sizes
"""
# Get dimensions and data
rank = data.varm["Pf2_C"].shape[1]
X = np.array(data.varm["Pf2_C"])
yt = data.var.index.values

# Filter the genes
kept_idxs = np.where(np.in1d(yt, genes))[0]
if len(kept_idxs) == 0:
# Add labels to the top 3 components
for idx in top_3_indices:
component_num = idx + 1
y_pos = component_sums[idx] + 0.01 * np.sign(component_sums[idx]) * np.max(
np.abs(component_sums)
)
ax.text(
0.5,
0.5,
"No matching genes found",
component_num,
y_pos,
f"Comp. {component_num}",
ha="center",
va="center",
transform=ax.transAxes,
va="bottom" if component_sums[idx] > 0 else "top",
fontsize=12,
fontweight="bold",
color="darkred",
)
return

X = X[kept_idxs]

# Calculate the sum of X values for each component
component_sums = np.sum(X, axis=0)

# Create DataFrame for plotting
df = pd.DataFrame(
{
"Component": np.arange(1, rank + 1),
"Sum": component_sums,
"Magnitude": np.abs(component_sums),
}
)

# Create a diverging colormap based on value
min_val = df["Sum"].min()
max_val = df["Sum"].max()
abs_max = max(abs(min_val), abs(max_val))

# Create the dot plot
scatter = sns.scatterplot(
data=df,
x="Component",
y=[0] * rank, # All dots on same horizontal line
size="Magnitude",
hue="Sum",
palette=cmap,
sizes=(20, size_scale),
size_norm=size_norm,
ax=ax,
legend="brief",
)

# Customize the plot
ax.set(xlabel="Component", title="Gene Set Score by Component")
ax.set_yticks([]) # Remove y-axis ticks
ax.set_ylim([-1, 1]) # Set y-axis limits

# Set x-axis ticks to integers
ax.set_xticks(np.arange(1, rank + 1))

# Move the legend outside the plot
plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left")

# Add a horizontal line at y=0
ax.axhline(y=0, color="gray", linestyle="-", lw=0.5, alpha=0.7)
ax.set_xlabel("Component", fontsize=20)
ax.set_ylabel("Sum of Weights", fontsize=20)
ax.set_title("Signature Score", fontsize=25)
ax.tick_params(axis="x", rotation=90, labelsize=16)
ax.tick_params(axis="y", labelsize=16)


def plot_ttest(X: AnnData, ax: Axes):
Expand All @@ -527,7 +474,7 @@ def plot_ttest(X: AnnData, ax: Axes):
# Get all cytokines directly without separate function
all_cytokines = X.obs["cyt"].unique()

#Get highly weighted cytokines per component
# Get highly weighted cytokines per component
results_df = highly_weighted_cytokines(X)

# Create pivot table for all components
Expand Down Expand Up @@ -565,3 +512,84 @@ def plot_ttest(X: AnnData, ax: Axes):
ax.set_title("Highly weighted Cytokines Across Components (ANOVA + Post-hoc)")
ax.set_xlabel("Component")
ax.set_ylabel("Cytokine")


def plot_comp_weights(
data: AnnData,
ax: Axes,
comp: int,
cond="Condition",
sort_bars=True,
top_n=3,
include_lowest=True,
):
"""Plots component weights for each condition as a bar chart"""

# Get condition names and factor matrix
cond_df = (
data.obs[[cond, "condition_unique_idxs"]]
.drop_duplicates()
.sort_values("condition_unique_idxs")
)
yt = cond_df[cond].to_numpy()
X = np.array(data.uns["Pf2_A"])
cond_mapping = data.obs.groupby("condition_unique_idxs", sort=True)[cond].first()

# Extract condition names and indices
condition_indices = cond_mapping.index.to_numpy()
yt = cond_mapping.values

# Extract weights
component_weights = X[condition_indices, comp - 1]
# Create DataFrame for plotting
df = pd.DataFrame({"Condition": yt, "Weight": component_weights})

# Get top N highest weighted conditions
top_n_highest = df.nlargest(top_n, "Weight")

# Conditionally get lowest weighted conditions
if include_lowest:
top_n_lowest = df.nsmallest(top_n, "Weight")
# Combine and keep only top conditions
df_filtered = pd.concat([top_n_highest, top_n_lowest]).drop_duplicates()
else:
df_filtered = top_n_highest

# Sort by weight if requested
if sort_bars:
df_filtered = df_filtered.sort_values("Weight", ascending=False)

# Create color mapping - highest in red, lowest in blue (if included)
colors = []
for condition in df_filtered["Condition"]:
if condition in top_n_highest["Condition"].values:
colors.append("darkred") # Highest weights
else:
colors.append("darkblue") # Lowest weights

# Create bar plot with custom colors
bars = ax.bar(df_filtered["Condition"], df_filtered["Weight"], color=colors)

# Customize the plot title based on whether lowest are included
if include_lowest:
title = f"Component {comp} Weights by Condition (Top {top_n} Highest/Lowest)"
else:
title = f"Component {comp} Weights by Condition (Top {top_n} Highest)"

ax.set_title(title, fontsize=20)
ax.set_xlabel("Condition", fontsize=18)
ax.set_ylabel("Weight", fontsize=18)
ax.tick_params(axis="x", rotation=90, labelsize=18)
ax.tick_params(axis="y", labelsize=18)

# Add legend for color coding (only if lowest are included)
if include_lowest:

legend_elements = [
Patch(facecolor="darkred", label=f"Top {top_n} Highest"),
Patch(facecolor="darkblue", label=f"Top {top_n} Lowest"),
]
ax.legend(handles=legend_elements, loc="upper right")

# Add horizontal line at y=0 for reference
ax.axhline(y=0, color="gray", linestyle="-", alpha=0.3)
Loading
Loading