diff --git a/CHANGELOG.md b/CHANGELOG.md index 0a01d1f..c29aa08 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,10 +4,22 @@ All notable changes to this project will be documented in this file. ## [Unreleased] +## [0.8.0] - 2026-05-28 + +### Behavior changes + +These three changes correct discrepancies between the implementation and the manuscript that describes Kompot's statistics. Two are numerical scale shifts that preserve relative rankings; the third is a default-value harmonization. Re-tune any absolute thresholds calibrated against 0.7.0. + + - **Mahalanobis denominator now sums covariances**: the gene-wise Mahalanobis distance used by `DifferentialExpression.predict(compute_mahalanobis=True)` now computes the posterior combined covariance as `Σ_a + Σ_b` (the variance of the difference of two independent posterior estimators) instead of `(Σ_a + Σ_b) / 2`. Matches the manuscript definition `D(a,b) = sqrt((μ_a − μ_b)^T (Σ_a + Σ_b)^(-1) (μ_a − μ_b))`. **Effect**: absolute Mahalanobis distances in the GP-only regime contract by a factor of `√2` (and `D²` by 2). Relative rankings of genes are unchanged because the same scale factor applies everywhere, and FDR thresholds re-calibrate against the null. The sample-variance branch was already correctly summed. + - **Differential-abundance posterior tail probability is now one-sided**: `DifferentialAbundance.predict()` returns `PTP = Φ(−|z|)` (one-sided), matching the manuscript definition `PTP(x_i) = Φ(−|Δ(x_i)|/√(σ_a² + σ_b²))`. Previous releases returned `2·Φ(−|z|)` (two-sided). **Effect**: numeric PTP values are halved relative to 0.7.0; equivalently, `neg_log10_fold_change_ptp` increases by `log10(2) ≈ 0.301`. The threshold `PTP < 1e-3` previously corresponded to `|z| ≥ 3.29` and now corresponds to `|z| ≥ 3.09`. Re-tune any hard-coded `ptp_threshold` chosen against 0.7.0 if you want to preserve the old call-rate. + - **`use_empirical_variance` default is now `False` everywhere**: harmonized across `kompot.de()` (already False), `DifferentialExpression.__init__`, `ExpressionModel.__init__`, `kompot.smooth_expression()`, the deprecated `compute_differential_expression()` and `compute_smoothed_expression()` wrappers, and the CLI `smooth_config_template.yaml`. Previously these four entry points defaulted to `True`, inconsistent with both the recommended `kompot.de()` path and the manuscript's "empirical variance is disabled by default" statement. Code that relies on empirical variance must now pass `use_empirical_variance=True` explicitly. + - **Differential-expression posterior tail probability is now stored in log space**: `DifferentialExpression.predict(compute_mahalanobis=True)` and `kompot.de(..., store_additional_stats=True)` now compute the Mahalanobis posterior tail probability with `scipy.stats.chi2.logsf` in float64 and store `-log10(PTP)` in a renamed field `__to__neg_log10_ptp` (previously the linear tail probability `chi2.sf` in `__to__ptp`). The PTP is a strictly monotone transform of the Mahalanobis distance, but for an embedding with df on the order of tens the linear `chi2.sf` evaluates to values numerically indistinguishable from `1.0` for the majority of genes (every gene with `D²` below the chi-squared mean), saturating the stored statistic and destroying gene-ranking resolution at the head of the distribution. Computing `logsf` directly (never forming `1 − cdf`) keeps every value distinct and recovers the Mahalanobis ranking exactly. This mirrors the differential-abundance path, whose `neg_log10_lfc_ptp` field is already stored this way. **Effect**: the stored values change scale and orientation — larger now means more significant (a probability `p` becomes `−log10(p) ∈ [0, ∞)`), and the field is renamed. `volcano_de(y_axis_type="ptp")` reads the column directly with no additional `-log10` transform, and its `significance_threshold` is still supplied as a probability. Update any code that reads the old `_ptp` column. + ### New features - **`--dry-run` flag for `kompot de` CLI**: estimates memory, disk, and output field requirements without running the analysis. Outputs machine-parseable JSON to stdout and a human-readable report to stderr. Exit code reflects feasibility. - **`kompot.configure_logging(stream)`**: reconfigure the kompot logger output stream. The CLI now logs to stderr by default, keeping stdout clean for machine-parseable output (dry-run JSON, table output). + - **`kompot.plot.lollipop`**: ax-embeddable gene-set-enrichment lollipop plot. One row per enriched term; a stem runs to a dot whose x-position encodes significance (`-log10(FDR)` by default, or any score column such as StringDB `signal` / enrichr `Combined Score`) and whose area encodes the matched-gene count, with a dashed `FDR = 0.05` guide and an in-axes aesthetic key. The headline feature is input flexibility: pass a `kompot.plot.StringDBReport` (its `get_functional_enrichment()` is called for you), the `signal`-sorted DataFrame that method returns, **or** a generic enrichment table from another tool (gseapy/enrichr, GOATOOLS, clusterProfiler). Column-name mapping params (`term_col`, `score_col`, `count_col`, `fdr_col`) with case-insensitive autodetection — including the gseapy `"k/K"` `Overlap`-string parser — bridge the schema differences. The fig-3 manuscript specifics (direction-red accent, reserved title band, GO-Process category) are now parameters with manuscript-matching defaults. Like `dotplot`, it composes into an externally-provided `ax=` instead of building its own `GridSpec`. - **`kompot.plot.dotplot`**: ax-embeddable fold-change-per-group dotplot. Color = mean of a per-cell LFC layer within each `groupby` category; size = fraction of cells expressing. Gene selection is either an explicit list or auto-picked top-N by Mahalanobis from run history (with optional `filter_key`, e.g. restricting to `is_de=True`). Pass `axes=(main, cbar, size_legend)` to compose into a larger figure, or leave `axes=None` for a standalone figure. Unlike `scanpy.pl.DotPlot`, this function does not build its own `GridSpec` and does not fight externally-provided axes, which is the whole reason it exists. Shares gene-selection, layer-fetch, and colormap-normalization primitives with `kompot.plot.heatmap` via the existing `heatmap.utils` helpers. ### Improvements diff --git a/docs/source/plotting.rst b/docs/source/plotting.rst index 5cec4f2..376ffb8 100644 --- a/docs/source/plotting.rst +++ b/docs/source/plotting.rst @@ -15,6 +15,16 @@ Expression Plots .. autofunction:: kompot.plot.plot_gene_expression +Dotplots +-------- + +.. autofunction:: kompot.plot.dotplot + +Enrichment Lollipop +------------------- + +.. autofunction:: kompot.plot.lollipop + Heatmaps -------- diff --git a/kompot/anndata/_de_helpers.py b/kompot/anndata/_de_helpers.py index 88b633b..042c3b0 100644 --- a/kompot/anndata/_de_helpers.py +++ b/kompot/anndata/_de_helpers.py @@ -771,8 +771,10 @@ def _compute_fdr( "mahalanobis": internal_null_mahalanobis, "mean_lfc": expression_results["mean_log_fold_change"][n_real:], } - if "ptp" in expression_results: - null_table_data["ptp"] = expression_results["ptp"][n_real:] + if "neg_log10_ptp" in expression_results: + null_table_data["neg_log10_ptp"] = expression_results["neg_log10_ptp"][ + n_real: + ] null_data["table"] = pd.DataFrame( null_table_data, index=null_gene_names, @@ -823,8 +825,10 @@ def _compute_fdr( if "mahalanobis_distances" in expression_results: expression_results["mahalanobis_distances"] = real_mahalanobis - if "ptp" in expression_results: - expression_results["ptp"] = expression_results["ptp"][:n_real] + if "neg_log10_ptp" in expression_results: + expression_results["neg_log10_ptp"] = expression_results["neg_log10_ptp"][ + :n_real + ] return fdr_results @@ -909,17 +913,21 @@ def _store_de_results( adata, ) - if compute_mahalanobis and "ptp" in expression_results and store_additional_stats: - ptp = _ensure_1d( - expression_results["ptp"], - "ptp", + if ( + compute_mahalanobis + and "neg_log10_ptp" in expression_results + and store_additional_stats + ): + neg_log10_ptp = _ensure_1d( + expression_results["neg_log10_ptp"], + "neg_log10_ptp", n_selected, logger, ) _add_var_column( new_var_columns, field_names["ptp_key"], - ptp, + neg_log10_ptp, selected_genes, adata, ) @@ -1426,9 +1434,13 @@ def _compute_group_results( f"significantly DE at FDR < {fdr_threshold}" ) - # Group-wise ptp - if compute_mahalanobis and store_additional_stats and "ptp" in subset_results: - sub_ptp = subset_results["ptp"] + # Group-wise neg_log10_ptp + if ( + compute_mahalanobis + and store_additional_stats + and "neg_log10_ptp" in subset_results + ): + sub_ptp = subset_results["neg_log10_ptp"] if len(sub_ptp) == len(expanded_genes): sub_ptp = sub_ptp[:n_real] elif len(sub_ptp) != n_real: @@ -1641,7 +1653,8 @@ def _build_field_mapping( "location": "var", "type": "ptp", "description": ( - "Posterior tail probability from chi-squared distribution" + "Negative log10 posterior tail probability (-log10 PTP) from " + "the chi-squared distribution, computed in log space" ), } @@ -1728,7 +1741,10 @@ def _add_group_field_mapping( field_mapping[ptp_k] = { "location": "varm", "type": "ptp", - "description": "Peak-to-peak values for all subsets", + "description": ( + "Negative log10 posterior tail probability (-log10 PTP) for " + "all subsets" + ), "contains_subsets": subset_names, } diff --git a/kompot/anndata/cleanup.py b/kompot/anndata/cleanup.py index 7d4e3b1..efd4ad8 100644 --- a/kompot/anndata/cleanup.py +++ b/kompot/anndata/cleanup.py @@ -89,7 +89,7 @@ def cleanup( - ``'mean_log_fold_change'``: Mean log fold change values - ``'mahalanobis'``: Mahalanobis distances - - ``'ptp'``: Posterior tail probability + - ``'ptp'``: Negative log10 posterior tail probability (-log10 PTP) - ``'mahalanobis_pvalue'``: P-values from empirical null - ``'mahalanobis_local_fdr'``: Local FDR values - ``'mahalanobis_tail_fdr'``: Tail-based FDR values diff --git a/kompot/anndata/differential_expression.py b/kompot/anndata/differential_expression.py index 50d6fc8..4315a0d 100644 --- a/kompot/anndata/differential_expression.py +++ b/kompot/anndata/differential_expression.py @@ -530,8 +530,8 @@ def de( if compute_mahalanobis and "mahalanobis_distances" in expression_results: results_data["mahalanobis"] = expression_results["mahalanobis_distances"] - if "ptp" in expression_results: - results_data["ptp"] = expression_results["ptp"] + if "neg_log10_ptp" in expression_results: + results_data["neg_log10_ptp"] = expression_results["neg_log10_ptp"] result_dict["table"] = pd.DataFrame(results_data, index=selected_genes) result_dict["underrepresentation"] = underrep @@ -751,7 +751,7 @@ def compute_differential_expression( return_full_results: bool = False, store_posterior_covariance: bool = False, allow_single_condition_variance: bool = False, - use_empirical_variance: bool = True, + use_empirical_variance: bool = False, progress: bool = True, null_genes="auto", null_seed=42, diff --git a/kompot/anndata/smooth.py b/kompot/anndata/smooth.py index 4e52ae3..8e07946 100644 --- a/kompot/anndata/smooth.py +++ b/kompot/anndata/smooth.py @@ -109,7 +109,7 @@ def smooth_expression( ls = gp.ls if gp is not None else None ls_factor = gp.ls_factor if gp is not None else 10.0 n_landmarks = gp.n_landmarks if gp is not None else 5000 - use_empirical_variance = gp.use_empirical_variance if gp is not None else True + use_empirical_variance = gp.use_empirical_variance if gp is not None else False eps = gp.eps if gp is not None else 1e-8 random_state = gp.random_state if gp is not None else None batch_size = gp.batch_size if gp is not None else 500 @@ -393,7 +393,7 @@ def compute_smoothed_expression( sigma: float = 1.0, ls: Optional[float] = None, ls_factor: float = 10.0, - use_empirical_variance: bool = True, + use_empirical_variance: bool = False, eps: float = 1e-8, random_state: Optional[int] = None, batch_size: int = 500, diff --git a/kompot/anndata/utils/field_tracking.py b/kompot/anndata/utils/field_tracking.py index 8b2e4ec..0ee0af2 100644 --- a/kompot/anndata/utils/field_tracking.py +++ b/kompot/anndata/utils/field_tracking.py @@ -226,7 +226,7 @@ def generate_output_field_names( field_names.update( { "mahalanobis_key": f"{result_key}_{cond1_safe}_to_{cond2_safe}_mahalanobis{suffix}", - "ptp_key": f"{result_key}_{cond1_safe}_to_{cond2_safe}_ptp{suffix}", + "ptp_key": f"{result_key}_{cond1_safe}_to_{cond2_safe}_neg_log10_ptp{suffix}", "mean_lfc_key": f"{result_key}_{cond1_safe}_to_{cond2_safe}_mean_lfc", "smoothed_key_1": f"{result_key}_{cond1_safe}_smoothed", "smoothed_key_2": f"{result_key}_{cond2_safe}_smoothed", diff --git a/kompot/cli/templates/smooth_config_template.yaml b/kompot/cli/templates/smooth_config_template.yaml index 7561499..a9cfd53 100644 --- a/kompot/cli/templates/smooth_config_template.yaml +++ b/kompot/cli/templates/smooth_config_template.yaml @@ -26,7 +26,7 @@ n_landmarks: 5000 # Number of landmarks for Nystrom approximation sample_col: null # Column in adata.obs with sample IDs # Empirical variance (heteroscedastic noise): -use_empirical_variance: true # Estimate per-gene noise from GP residuals +use_empirical_variance: false # Estimate per-gene noise from GP residuals # GP kernel parameters: sigma: 1.0 # Noise level for function estimator diff --git a/kompot/differential/differential_abundance.py b/kompot/differential/differential_abundance.py index bd9e0fe..17ab1ec 100644 --- a/kompot/differential/differential_abundance.py +++ b/kompot/differential/differential_abundance.py @@ -603,11 +603,12 @@ def compute_sample_variance2(X_batch): sd = np.sqrt(log_fold_change_uncertainty + self.eps) log_fold_change_zscore = log_fold_change / sd - # Compute PTP (Posterior Tail Probability) in natural log (base e) + # Compute PTP (Posterior Tail Probability) in natural log (base e). + # One-sided per manuscript: PTP = Φ(−|z|) = min(Φ(z), Φ(−z)) for real z. ln_ptp = np.minimum( normal.logcdf(log_fold_change_zscore), normal.logcdf(-log_fold_change_zscore), - ) + np.log(2) + ) # Convert from natural log to negative log10 (for better volcano plot visualization) # ln_ptp is a log of a small value (typically < 1), so it's negative diff --git a/kompot/differential/differential_expression.py b/kompot/differential/differential_expression.py index 6e20b21..6af2bee 100644 --- a/kompot/differential/differential_expression.py +++ b/kompot/differential/differential_expression.py @@ -2,8 +2,7 @@ import numpy as np import jax -import jax.numpy as jnp -import jax.scipy.stats as jax_stats +from scipy.stats import chi2 as scipy_chi2 from typing import Optional, Dict, Any import logging from mellon.parameters import compute_landmarks @@ -42,7 +41,7 @@ def __init__( self, n_landmarks: Optional[int] = None, use_sample_variance: Optional[bool] = None, - use_empirical_variance: bool = True, + use_empirical_variance: bool = False, eps: float = 1e-8, # Increased default epsilon for better numerical stability jit_compile: bool = False, function_predictor1: Optional[Any] = None, @@ -625,8 +624,10 @@ def compute_mahalanobis_distances( # Points for sample variance computation variance_points = X - # Average the covariance matrices - combined_cov = (cov1 + cov2) / 2 + # Sum the covariance matrices: Σ_a + Σ_b is the variance of the + # difference of independent posterior estimators, matching the + # Mahalanobis denominator defined in the manuscript. + combined_cov = cov1 + cov2 del cov1, cov2 # For sample variance, use diag=False to get full covariance matrices @@ -976,13 +977,38 @@ def predict( if hasattr(self, "_last_mahalanobis_dof"): logger.debug( - f"Computing ptp with {self._last_mahalanobis_dof} degrees of freedom..." + f"Computing neg_log10_ptp with {self._last_mahalanobis_dof} " + "degrees of freedom..." ) - mahalanobis_squared = jnp.array(mahalanobis_distances) ** 2 - ptp = jax_stats.chi2.sf( + # Posterior tail probability (PTP) of the Mahalanobis distance under + # the chi-squared null. Computed in LOG space and stored as + # -log10(PTP), mirroring the DA path's neg_log10_lfc_ptp convention. + # + # The PTP is chi2.sf(D^2, df). For an embedding with df on the order + # of tens, the vast majority of genes have D^2 well below the chi2 + # mean, where the linear sf evaluates to values numerically + # indistinguishable from 1.0 in float64 (1 - epsilon rounds to 1.0). + # Storing the linear sf therefore collapses most genes onto a single + # saturated value and destroys gene-ranking resolution at the head of + # the distribution. chi2.logsf returns the log of the same quantity + # directly (never forming 1 - cdf), so log(1 - epsilon) ~= -epsilon + # remains representable and every gene keeps a distinct value. + # + # scipy in float64 is used deliberately: jax runs in float32 unless + # x64 is explicitly enabled, and float32 logsf re-collapses the + # dynamic range (the precision is in the mantissa we are trying to + # preserve). + mahalanobis_squared = np.asarray( + mahalanobis_distances, dtype=np.float64 + ) ** 2 + ln_ptp = scipy_chi2.logsf( mahalanobis_squared, df=self._last_mahalanobis_dof ) - result["ptp"] = np.array(ptp) + # Convert natural-log tail probability to -log10(PTP): positive and + # larger for more significant genes, matching the DA convention. + result["neg_log10_ptp"] = np.asarray( + -(ln_ptp / np.log(10)), dtype=np.float64 + ) return result diff --git a/kompot/differential/expression_model.py b/kompot/differential/expression_model.py index af8e061..b25d16d 100644 --- a/kompot/differential/expression_model.py +++ b/kompot/differential/expression_model.py @@ -112,6 +112,7 @@ class ExpressionModel: Number of landmarks for Nystrom approximation. use_empirical_variance : bool Whether to estimate per-gene empirical variance from GP residuals. + By default False. eps : float Small constant for numerical stability. random_state : int, optional @@ -135,7 +136,7 @@ class ExpressionModel: def __init__( self, n_landmarks: Optional[int] = None, - use_empirical_variance: bool = True, + use_empirical_variance: bool = False, eps: float = 1e-8, random_state: Optional[int] = None, batch_size: int = 500, diff --git a/kompot/plot/__init__.py b/kompot/plot/__init__.py index d099c66..39e1124 100644 --- a/kompot/plot/__init__.py +++ b/kompot/plot/__init__.py @@ -166,6 +166,20 @@ def dotplot(*args, **kwargs): ) +try: + from .lollipop import lollipop + + __all__.append("lollipop") +except ImportError as e: + logger.warning(f"Could not import lollipop function due to: {e}") + + def lollipop(*args, **kwargs): + raise ImportError( + "Lollipop plot unavailable due to missing dependencies. " + "matplotlib is required." + ) + + # Import StringDB report class try: from .stringdb import StringDBReport diff --git a/kompot/plot/field_inference.py b/kompot/plot/field_inference.py index 1bd33a9..f32f5ae 100644 --- a/kompot/plot/field_inference.py +++ b/kompot/plot/field_inference.py @@ -230,7 +230,7 @@ def _fallback_field_inference( "direction_key": ["direction"], "mean_lfc_key": ["mean_lfc", "lfc", "log_fold_change", "fold_change"], "mahalanobis_key": ["mahalanobis", "score"], - "ptp_key": ["ptp"], + "ptp_key": ["neg_log10_ptp", "ptp"], "is_de_key": ["is_de", "significant"], "zscore_key": ["zscore", "z_score"], "density_key_1": ["log_density"], diff --git a/kompot/plot/lollipop.py b/kompot/plot/lollipop.py new file mode 100644 index 0000000..f6ea363 --- /dev/null +++ b/kompot/plot/lollipop.py @@ -0,0 +1,697 @@ +"""Gene-set-enrichment lollipop plot. + +Ax-embeddable lollipop for functional-enrichment tables: one row per +enriched term, a stem from the axis baseline to a dot whose x-position +encodes significance (``-log10(FDR)`` by default, or any score column) +and whose area encodes the gene count. A dashed ``FDR = 0.05`` guide and +an optional in-axes aesthetic key round out the figure-grade rendering. + +The renderer was built for the kompot manuscript (Fig 3 panels G/L) and +is generalized here so it feeds any enrichment result with minimal fuss: + +* a :class:`kompot.plot.StringDBReport` instance (its + :meth:`~kompot.plot.StringDBReport.get_functional_enrichment` is + called for you), **or** +* the ``signal``-sorted DataFrame that method already returns, **or** +* a generic enrichment table from another tool (gseapy / enrichr, + GOATOOLS, clusterProfiler exports, …). Column-name mapping params plus + autodetection bridge the schema differences — see :func:`lollipop`. + +Like :func:`kompot.plot.dotplot`, this composes cleanly into an +externally-provided ``ax`` instead of building its own ``GridSpec``, so +it drops into a composite figure without fighting the surrounding layout. +""" + +from __future__ import annotations + +import logging +from typing import Mapping, Optional, Sequence, Tuple, Union + +import numpy as np +import pandas as pd + +logger = logging.getLogger("kompot") + +try: + import matplotlib.pyplot as plt + from matplotlib.axes import Axes + from matplotlib.colors import Normalize, to_rgb + from matplotlib.figure import Figure + from matplotlib.lines import Line2D +except ImportError as e: # pragma: no cover - exercised via import facade + raise ImportError( + "matplotlib is required for plotting: pip install matplotlib" + ) from e + + +# --------------------------------------------------------------------------- +# Column autodetection +# +# Ordered candidate lists per logical field. The first column present in +# the frame wins. Names cover StringDB (kompot's StringDBReport), gseapy / +# enrichr, GOATOOLS, and clusterProfiler export conventions. Matching is +# case-insensitive on a normalized (stripped) header. +# --------------------------------------------------------------------------- +_TERM_CANDIDATES = ( + "description", # StringDB, GOATOOLS go-term name + "term", # generic / StringDB term id + "name", # gseapy "Name" + "go_name", + "pathway", + "term_name", + "go_term", + "id", # last-ditch identifier +) +_SCORE_CANDIDATES = ( + "signal", # StringDB balanced metric (its default sort key) + "combined score", # enrichr / gseapy + "nes", # GSEA normalized enrichment score + "enrichment_score", + "score", + "strength", # StringDB log10(obs/exp) + "odds ratio", +) +_COUNT_CANDIDATES = ( + "number_of_genes", # StringDB + "count", # clusterProfiler + "intersection_size", # g:Profiler + "overlap", # gseapy / enrichr ("k/K" string — numerator is taken) + "n_genes", + "gene_count", + "genes", # sometimes a list/str of members + "study_count", # GOATOOLS +) +_FDR_CANDIDATES = ( + "fdr", # StringDB + "adjusted p-value", # enrichr / gseapy + "p.adjust", # clusterProfiler + "padj", + "p_fdr_bh", # GOATOOLS + "qvalue", + "q_value", + "q-value", + "fdr_q-val", # gseapy preranked + "benjamini", + "p_value", # fall back to raw p if no adjusted column exists + "pvalue", + "p-value", +) + + +def _find_column( + df: pd.DataFrame, candidates: Sequence[str], explicit: Optional[str] +) -> Optional[str]: + """Resolve a logical field to an actual column name. + + ``explicit`` (if given) is honored verbatim and validated. Otherwise + the first candidate present in ``df`` (case-insensitive) is returned, + or ``None`` when nothing matches. + """ + if explicit is not None: + if explicit not in df.columns: + raise KeyError( + f"column '{explicit}' not found in enrichment frame " + f"(available: {list(df.columns)})" + ) + return explicit + lower = {str(c).strip().lower(): c for c in df.columns} + for cand in candidates: + if cand in lower: + return lower[cand] + return None + + +def _coerce_counts(values: pd.Series) -> np.ndarray: + """Coerce a gene-count column to a float array. + + Handles the three shapes seen in the wild: a plain numeric column + (StringDB ``number_of_genes``), a ``"k/K"`` overlap string (gseapy / + enrichr ``Overlap`` — the numerator is the observed count), and a + delimited gene list (the member count is its length). Unparseable + entries become ``NaN``. + """ + if pd.api.types.is_numeric_dtype(values): + return values.to_numpy(dtype=float) + + out = np.full(len(values), np.nan, dtype=float) + for i, v in enumerate(values.to_numpy()): + if v is None or (isinstance(v, float) and np.isnan(v)): + continue + s = str(v).strip() + if not s: + continue + if "/" in s: # "12/350" overlap → observed numerator + head = s.split("/", 1)[0].strip() + try: + out[i] = float(head) + continue + except ValueError: + pass + try: # plain number stored as text + out[i] = float(s) + continue + except ValueError: + pass + # delimited member list → count the members + for sep in (",", ";", " "): + if sep in s: + out[i] = float(len([t for t in s.split(sep) if t.strip()])) + break + return out + + +def _darken(color: str, factor: float = 0.55) -> Tuple[float, float, float]: + """Return a darker companion of ``color`` for dot outlines.""" + r, g, b = to_rgb(color) + return (r * factor, g * factor, b * factor) + + +def _wrap(text: str, width: int) -> str: + """Soft-wrap a long term description across at most two lines.""" + text = str(text) + if width <= 0 or len(text) <= width: + return text + cut = text.rfind(" ", 0, width + 1) + if cut <= 0: + cut = width + head = text[:cut].rstrip() + tail = text[cut:].lstrip() + if len(tail) > width: + tail = tail[: max(width - 1, 0)].rstrip() + "…" + return f"{head}\n{tail}" + + +def _resolve_enrichment( + data, + *, + category: str, + fdr_threshold: float, +) -> pd.DataFrame: + """Coerce the ``data`` argument into an enrichment DataFrame. + + Accepts a :class:`~kompot.plot.StringDBReport` (its + ``get_functional_enrichment`` is called), a DataFrame (used as-is), or + anything DataFrame-constructible (e.g. a list of record dicts). + """ + if isinstance(data, pd.DataFrame): + return data + # Duck-type the StringDBReport so we don't hard-import it (keeps the + # optional-dependency story intact) and so any work-alike with the + # same method works too. + if hasattr(data, "get_functional_enrichment"): + enr = data.get_functional_enrichment( + category=category, fdr_threshold=fdr_threshold + ) + if enr is None or len(enr) == 0: + raise ValueError( + "StringDBReport.get_functional_enrichment returned no terms " + f"for category '{category}' at FDR ≤ {fdr_threshold}. " + "The StringDB service may be unavailable, or the gene set " + "may carry no enrichment; pass a precomputed DataFrame to " + "plot offline." + ) + return enr + try: + return pd.DataFrame(data) + except Exception as exc: # pragma: no cover - defensive + raise TypeError( + "`data` must be a StringDBReport, a pandas DataFrame, or a " + f"DataFrame-constructible object; got {type(data)!r}" + ) from exc + + +def lollipop( + data: Union["pd.DataFrame", object], + *, + n_terms: int = 12, + term_col: Optional[str] = None, + score_col: Optional[str] = None, + count_col: Optional[str] = None, + fdr_col: Optional[str] = None, + x_metric: str = "neg_log10_fdr", + sort_by: Optional[str] = "x", + ascending: Optional[bool] = None, + category: str = "Process", + fdr_threshold: float = 0.05, + color: str = "#d73027", + edge_color: Optional[str] = None, + cmap: Optional[str] = None, + color_by: Optional[str] = None, + stem_lw: float = 1.8, + stem_alpha: float = 0.65, + dot_min: float = 40.0, + dot_max: float = 320.0, + dot_scale: float = 22.0, + dot_const: float = 80.0, + fdr_line: Optional[float] = 0.05, + annotate: bool = True, + annotate_fmt: Optional[str] = None, + legend: bool = True, + legend_label: str = "gene set", + label_width: int = 55, + label_fontsize: float = 6.5, + annotate_fontsize: float = 6.0, + fdr_floor: float = 1e-50, + title: Optional[str] = None, + subtitle: Optional[str] = None, + title_space: float = 0.18, + xlabel: Optional[str] = None, + ax: Optional[Axes] = None, + figsize: Tuple[float, float] = (7.0, 5.0), + return_fig: bool = False, + save: Optional[str] = None, + **kwargs, +) -> Optional[Union[Figure, Axes]]: + r"""Gene-set-enrichment lollipop plot. + + Each row is an enriched term. A stem runs from the x-axis baseline to + a dot whose x-position encodes significance (``x_metric``) and whose + area encodes the matched-gene count. + + Parameters + ---------- + data : StringDBReport, DataFrame, or records + Enrichment source. Three forms are accepted: + + * a :class:`kompot.plot.StringDBReport` instance — its + :meth:`~kompot.plot.StringDBReport.get_functional_enrichment` + is called with ``category`` / ``fdr_threshold``; + * the ``signal``-sorted DataFrame that method returns; + * any other enrichment-result DataFrame (gseapy / enrichr, + GOATOOLS, clusterProfiler, …). Use the ``*_col`` params to map + its columns, or rely on autodetection. + + **Expected schema** (logical field → autodetected column names): + + =============== ==================================================== + Field Candidate columns (case-insensitive) + =============== ==================================================== + term label ``description``, ``term``, ``name``, ``pathway``, … + score ``signal``, ``Combined Score``, ``NES``, ``score``, … + gene count ``number_of_genes``, ``Count``, ``Overlap`` (``k/K``), … + FDR ``fdr``, ``Adjusted P-value``, ``p.adjust``, ``padj``, … + =============== ==================================================== + + n_terms : int, default 12 + Number of top terms to display (after sorting). + term_col, score_col, count_col, fdr_col : str, optional + Explicit column names overriding autodetection for the term + label, the score (used when ``x_metric="score"``), the gene count + (dot size), and the FDR (x-axis when ``x_metric="neg_log10_fdr"``, + plus the guide line and annotations). + x_metric : {"neg_log10_fdr", "score"} or column name, default "neg_log10_fdr" + What the dot's x-position encodes. ``"neg_log10_fdr"`` plots + ``-log10(FDR)`` (manuscript default); ``"score"`` plots + ``score_col`` directly; any other value is treated as a literal + column name to plot. + sort_by : str or None, default "x" + How to order rows before taking the top ``n_terms``. ``"x"`` sorts + by the plotted value (most significant / highest score on top); + any column name sorts by that column; ``None`` preserves input + order (StringDB frames already arrive ``signal``-sorted). + ascending : bool, optional + Sort direction override. By default sorting is descending for + ``"x"`` / score columns and ascending for FDR-like columns. + category : str, default "Process" + StringDB enrichment category, used only when ``data`` is a + StringDBReport. See + :meth:`~kompot.plot.StringDBReport.get_functional_enrichment`. + fdr_threshold : float, default 0.05 + FDR cutoff passed through to StringDBReport (StringDBReport path + only). + color : str, default ``"#d73027"`` + Lollipop fill color. The default is kompot's "up" direction red + (:data:`kompot.utils.KOMPOT_COLORS`), matching the manuscript. + edge_color : str, optional + Dot outline / stem color. Defaults to a darkened ``color``. + cmap : str, optional + If given, dots are colored by ``color_by`` through this colormap + (a colorbar is added on standalone figures) instead of the solid + ``color``. + color_by : str, optional + Column whose values drive the ``cmap`` coloring. Defaults to the + resolved score column when ``cmap`` is set. + stem_lw, stem_alpha : float + Line width and alpha of the lollipop stems. + dot_min, dot_max : float, default 40, 320 + Clip bounds (area in pt²) for the gene-count dot sizer + ``clip(dot_min + dot_scale * sqrt(count), dot_min, dot_max)``. + dot_scale : float, default 22 + Multiplier in the dot sizer above. + dot_const : float, default 80 + Constant dot area used when no gene-count column is available. + fdr_line : float or None, default 0.05 + Draw a dashed vertical guide at this FDR (rendered at + ``-log10(fdr_line)`` when ``x_metric="neg_log10_fdr"``). ``None`` + disables it. Ignored for non-FDR x metrics. + annotate : bool, default True + Annotate each dot with ``n= FDR=`` to its right. + annotate_fmt : str, optional + Custom format string for the annotation, receiving ``count`` and + ``fdr`` as keyword fields, e.g. ``"{count} genes (q={fdr:.1e})"``. + legend : bool, default True + Draw the aesthetic key (set swatch, dot-size cue, FDR guide). + legend_label : str, default ``"gene set"`` + Label for the set swatch in the legend. + label_width : int, default 55 + Soft-wrap width for term descriptions (two lines max). ``0`` + disables wrapping. + label_fontsize, annotate_fontsize : float + Font sizes for the y-axis term labels and the per-dot annotation. + fdr_floor : float, default 1e-50 + FDRs are clipped to this floor before ``-log10`` to keep the + x-axis finite. + title, subtitle : str, optional + Title (bold) and subtitle. On a standalone figure these sit in a + reserved band above the axes (so the top row is never covered); + when embedding into ``ax`` the title becomes the axes title. + title_space : float, default 0.18 + Fraction of the standalone figure height reserved at the top for + the title / subtitle / legend band. + xlabel : str, optional + X-axis label. Defaults to ``$-\log_{10}(\mathrm{FDR})$`` for the + FDR metric, otherwise the score/column name. + ax : matplotlib.axes.Axes, optional + Embed into this axis. If ``None`` a standalone figure is built. + figsize : tuple, default ``(7.0, 5.0)`` + Standalone figure size; ignored when ``ax`` is given. + return_fig : bool, default False + If ``True``, return the ``Figure`` (standalone) or the ``Axes`` + (embedded) instead of ``None``. + save : str, optional + If given, ``fig.savefig(save, bbox_inches="tight")`` is called. + **kwargs + Forwarded to the dot :meth:`~matplotlib.axes.Axes.scatter` call. + + Returns + ------- + matplotlib.figure.Figure or matplotlib.axes.Axes or None + The figure (standalone) or axis (embedded) when ``return_fig`` is + ``True``, else ``None``. + + Examples + -------- + From a StringDBReport (queries StringDB live):: + + import kompot + report = kompot.plot.StringDBReport( + ["TP53", "BRCA1", "KRAS", "EGFR", "PTEN"], species_id=9606, + ) + kompot.plot.lollipop(report, category="Process", n_terms=10, + return_fig=True) + + From a precomputed enrichment table (offline, any tool):: + + import pandas as pd + df = pd.DataFrame({ + "description": ["immune response", "cell cycle", "apoptosis"], + "fdr": [1e-8, 3e-5, 2e-3], + "number_of_genes": [42, 18, 9], + "signal": [3.1, 2.0, 1.2], + }) + kompot.plot.lollipop(df, n_terms=3, return_fig=True) + + A gseapy/enrichr frame, scored by Combined Score, mapped explicitly:: + + kompot.plot.lollipop( + enrichr_df, x_metric="score", + term_col="Term", score_col="Combined Score", + count_col="Overlap", fdr_col="Adjusted P-value", + ) + + Embed into a composite figure:: + + fig, axes = plt.subplots(1, 2, figsize=(12, 5)) + kompot.plot.lollipop(df_a, ax=axes[0], title="Condition A") + kompot.plot.lollipop(df_b, ax=axes[1], title="Condition B") + """ + df = _resolve_enrichment( + data, category=category, fdr_threshold=fdr_threshold + ) + if not isinstance(df, pd.DataFrame): # pragma: no cover - defensive + df = pd.DataFrame(df) + if len(df) == 0: + raise ValueError("enrichment frame is empty; nothing to plot") + df = df.reset_index(drop=True) + + # ---- resolve columns ---------------------------------------- + term_c = _find_column(df, _TERM_CANDIDATES, term_col) + if term_c is None: + raise KeyError( + "could not find a term/description column; pass `term_col=` " + f"(looked for {list(_TERM_CANDIDATES)}; have {list(df.columns)})" + ) + fdr_c = _find_column(df, _FDR_CANDIDATES, fdr_col) + score_c = _find_column(df, _SCORE_CANDIDATES, score_col) + count_c = _find_column(df, _COUNT_CANDIDATES, count_col) + + # ---- x values ----------------------------------------------- + if x_metric == "neg_log10_fdr": + if fdr_c is None: + raise KeyError( + "x_metric='neg_log10_fdr' needs an FDR column but none was " + "found; pass `fdr_col=` or choose x_metric='score'." + ) + fdr_vals = pd.to_numeric(df[fdr_c], errors="coerce") + x = -np.log10(fdr_vals.clip(lower=fdr_floor).astype(float)) + default_xlabel = r"$-\log_{10}(\mathrm{FDR})$" + elif x_metric == "score": + if score_c is None: + raise KeyError( + "x_metric='score' needs a score column but none was found; " + "pass `score_col=` or choose x_metric='neg_log10_fdr'." + ) + x = pd.to_numeric(df[score_c], errors="coerce") + default_xlabel = str(score_c) + else: + # literal column name + if x_metric not in df.columns: + raise KeyError( + f"x_metric '{x_metric}' is neither a known metric " + "('neg_log10_fdr', 'score') nor a column in the frame " + f"({list(df.columns)})" + ) + x = pd.to_numeric(df[x_metric], errors="coerce") + default_xlabel = str(x_metric) + x = np.asarray(x, dtype=float) + df = df.assign(_x=x) + df = df[np.isfinite(df["_x"])] + if len(df) == 0: + raise ValueError( + f"no finite x values from x_metric='{x_metric}'; check the " + "selected column for non-numeric / missing entries" + ) + + # ---- sort + top-N ------------------------------------------- + if sort_by is not None: + if sort_by == "x": + sort_key, asc_default = "_x", False + elif sort_by in df.columns: + sort_key = sort_by + # FDR-like → ascending (smaller is better); else descending. + asc_default = sort_by == fdr_c + else: + raise KeyError( + f"sort_by '{sort_by}' is not 'x' and not a column " + f"({list(df.columns)})" + ) + asc = asc_default if ascending is None else bool(ascending) + df = df.sort_values(sort_key, ascending=asc, kind="stable") + df = df.head(n_terms).reset_index(drop=True) + if len(df) == 0: + raise ValueError("no rows remain after sorting / top-N selection") + + xv = df["_x"].to_numpy(dtype=float) + + # ---- dot sizes ---------------------------------------------- + if count_c is not None: + counts = _coerce_counts(df[count_c]) + with np.errstate(invalid="ignore"): + sizes = np.clip( + dot_min + dot_scale * np.sqrt(np.clip(counts, 0, None)), + dot_min, + dot_max, + ) + sizes = np.where(np.isfinite(sizes), sizes, dot_const) + else: + counts = np.full(len(df), np.nan) + sizes = np.full(len(df), dot_const, dtype=float) + + # ---- fdr (for guide + annotation) --------------------------- + if fdr_c is not None: + fdr_series = pd.to_numeric(df[fdr_c], errors="coerce").to_numpy(dtype=float) + else: + fdr_series = np.full(len(df), np.nan) + + # ---- colors ------------------------------------------------- + if edge_color is None: + edge_color = _darken(color) + use_cmap = cmap is not None + if use_cmap: + cb_col = color_by if color_by is not None else score_c + if cb_col is None or cb_col not in df.columns: + raise KeyError( + "cmap was given but no `color_by` column is available; pass " + "`color_by=` (a numeric column) explicitly." + ) + cvals = pd.to_numeric(df[cb_col], errors="coerce").to_numpy(dtype=float) + norm = Normalize( + vmin=float(np.nanmin(cvals)), vmax=float(np.nanmax(cvals)) + ) + + # ---- axes --------------------------------------------------- + standalone = ax is None + if standalone: + fig = plt.figure(figsize=figsize) + top = max(0.55, 1.0 - float(title_space)) + ax = fig.add_axes([0.48, 0.14, 0.49, top - 0.14]) + else: + fig = ax.figure + + # ---- stems + dots ------------------------------------------- + y = np.arange(len(df)) + for yi, xi in zip(y, xv): + ax.plot( + [0, xi], [yi, yi], + color=edge_color if not use_cmap else color, + lw=stem_lw, alpha=stem_alpha, zorder=2, + ) + scatter_kwargs = dict( + s=sizes, edgecolors=edge_color, linewidths=0.6, zorder=3, + ) + scatter_kwargs.update(kwargs) + if use_cmap: + sc = ax.scatter(xv, y, c=cvals, cmap=cmap, norm=norm, **scatter_kwargs) + else: + sc = ax.scatter(xv, y, color=color, **scatter_kwargs) + + # ---- annotations -------------------------------------------- + if annotate: + for yi, xi in zip(y, xv): + cnt = counts[yi] + fdr_val = fdr_series[yi] + if annotate_fmt is not None: + txt = annotate_fmt.format( + count=int(cnt) if np.isfinite(cnt) else "?", + fdr=fdr_val if np.isfinite(fdr_val) else float("nan"), + ) + else: + parts = [] + if np.isfinite(cnt): + parts.append(f"n={int(cnt)}") + if np.isfinite(fdr_val): + if fdr_val < 1e-6: + parts.append(f"FDR={fdr_val:.1e}") + else: + parts.append(f"FDR={fdr_val:.2g}") + txt = " ".join(parts) + if txt: + ax.annotate( + txt, xy=(xi, yi), xytext=(8, 0), + textcoords="offset points", + fontsize=annotate_fontsize, color=edge_color, + ha="left", va="center", zorder=4, + ) + + # ---- term labels + axes cosmetics --------------------------- + labels = [_wrap(t, label_width) for t in df[term_c].tolist()] + ax.set_yticks(y) + ax.set_yticklabels(labels, fontsize=label_fontsize) + ax.set_xlabel(xlabel if xlabel is not None else default_xlabel) + + draw_guide = ( + fdr_line is not None and x_metric == "neg_log10_fdr" and fdr_line > 0 + ) + if draw_guide: + ax.axvline( + -np.log10(fdr_line), color="0.4", ls="--", lw=0.7, + alpha=0.8, zorder=1, + ) + + # Right-pad so annotations never run off the axes. + xmax = float(np.nanmax(xv)) if len(xv) else 1.0 + if not np.isfinite(xmax) or xmax <= 0: + xmax = 1.0 + x_hi = max(xmax * 1.35, xmax + 4) + x_lo = min(0.0, float(np.nanmin(xv))) + ax.set_xlim(x_lo, x_hi) + ax.set_ylim(len(df) - 0.5, -0.8) + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + ax.tick_params(axis="y", length=0) + + # ---- colorbar (cmap path, standalone only) ------------------ + if use_cmap and standalone: + cb = fig.colorbar(sc, ax=ax, fraction=0.046, pad=0.02) + cb.set_label(str(cb_col), fontsize=annotate_fontsize) + cb.ax.tick_params(labelsize=annotate_fontsize) + + # ---- legend ------------------------------------------------- + if legend: + handles = [] + if not use_cmap: + handles.append( + Line2D( + [0], [0], marker="o", color="w", + markerfacecolor=color, markeredgecolor=edge_color, + markersize=7, label=legend_label, + ) + ) + if count_c is not None: + handles.append( + Line2D( + [0], [0], marker="o", color="w", + markerfacecolor=color if not use_cmap else "0.5", + markeredgecolor=edge_color, markersize=4, + label="dot size ∝ gene count", + ) + ) + if draw_guide: + handles.append( + Line2D( + [0], [0], color="0.4", ls="--", lw=0.8, + label=f"FDR = {fdr_line:g}", + ) + ) + if handles: + if standalone: + fig.legend( + handles=handles, loc="upper right", + bbox_to_anchor=(0.97, 0.97), fontsize=5.5, + frameon=True, framealpha=0.9, borderpad=0.4, + handlelength=1.8, + ) + else: + ax.legend( + handles=handles, loc="lower right", fontsize=5.5, + frameon=True, framealpha=0.9, + ) + + # ---- title / subtitle --------------------------------------- + if standalone: + if title: + fig.text( + 0.02, 0.955, title, fontsize=10, fontweight="bold", + ha="left", va="top", + ) + if subtitle: + fig.text( + 0.02, 1.0 - title_space * 0.5, subtitle, fontsize=6.5, + color="0.35", ha="left", va="top", + ) + else: + if title: + ax.set_title(title, fontsize=10, fontweight="bold", pad=4) + if subtitle: + ax.text( + 0.0, 1.01, subtitle, transform=ax.transAxes, fontsize=6.5, + color="0.35", ha="left", va="bottom", + ) + + # ---- save / return ------------------------------------------ + if save is not None: + fig.savefig(save, bbox_inches="tight") + + if return_fig: + return fig if standalone else ax + return None diff --git a/kompot/plot/volcano/de.py b/kompot/plot/volcano/de.py index eec9d1b..6e6b427 100644 --- a/kompot/plot/volcano/de.py +++ b/kompot/plot/volcano/de.py @@ -73,7 +73,7 @@ def volcano_de( legend_ncol: Optional[int] = None, group: Optional[str] = None, # New significance-related parameters - y_axis_type: str = "mahalanobis", # "mahalanobis", "local_fdr", "tail_fdr", "log10_ptp", or custom column name + y_axis_type: str = "mahalanobis", # "mahalanobis", "local_fdr", "tail_fdr", "ptp", or custom column name significance_threshold: Optional[Union[float, Dict[str, float]]] = None, update_de_classification: bool = False, direction_column: Optional[str] = None, @@ -188,8 +188,9 @@ def volcano_de( adata.var for Mahalanobis distances, and mean fold changes. y_axis_type : str, optional Type of values to use for the y-axis: "mahalanobis" (default), "local_fdr", "tail_fdr", - "ptp", or a custom column name from adata.var. When using FDR or ptp values, they are - -log10 transformed for display. + "ptp", or a custom column name from adata.var. FDR values are -log10 transformed for + display; the "ptp" column is already stored as -log10(PTP) (the neg_log10_ptp field) + and is plotted directly. In both cases higher on the axis means more significant. significance_threshold : float or dict, optional Significance threshold for the y-axis values. A float sets a single threshold shown as a horizontal line. A dict maps y-axis types to @@ -298,14 +299,20 @@ def fdr_y_transform(y): score_key = original_score_key elif y_axis_type == "ptp": - # Posterior tail probability (will be -log10 transformed for display) + # Posterior tail probability, stored as -log10(PTP) in the neg_log10_ptp + # field. The column is already log-transformed (higher = more + # significant), so NO additional -log10 transform is applied for display. + # This mirrors the DA path, whose neg_log10_lfc_ptp field is likewise + # pre-transformed. if run_info and "ptp_key" in run_info and run_info["ptp_key"]: significance_key = run_info["ptp_key"] if significance_key and significance_key in adata.var.columns: score_key = significance_key - y_transform = fdr_y_transform # Same -log10 transform as FDR - logger.info(f"Using ptp values for y-axis: {significance_key}") + y_transform = None # already -log10(PTP); plot directly + logger.info( + f"Using neg_log10_ptp values for y-axis: {significance_key}" + ) else: logger.warning( f"ptp key '{significance_key}' from run info not found in adata.var" @@ -317,15 +324,17 @@ def fdr_y_transform(y): logger.warning( "No ptp key in run_info; attempting fallback ptp key inference from score key..." ) - fallback_key = score_key.replace("mahalanobis", "ptp") + fallback_key = score_key.replace("mahalanobis", "neg_log10_ptp") if fallback_key in adata.var.columns: score_key = fallback_key significance_key = fallback_key - y_transform = fdr_y_transform # Same -log10 transform as FDR - logger.warning(f"Using fallback ptp key: {fallback_key}") + y_transform = None # already -log10(PTP); plot directly + logger.warning(f"Using fallback neg_log10_ptp key: {fallback_key}") else: - logger.warning(f"Fallback ptp key '{fallback_key}' not found either") + logger.warning( + f"Fallback neg_log10_ptp key '{fallback_key}' not found either" + ) # Final fallback to original score key if nothing worked if significance_key is None: @@ -351,7 +360,8 @@ def fdr_y_transform(y): ylabel = "-log10(Local FDR)" elif y_axis_type == "tail_fdr" and y_transform is not None: ylabel = "-log10(Tail FDR)" - elif y_axis_type == "ptp" and y_transform is not None: + elif y_axis_type == "ptp" and significance_key is not None: + # neg_log10_ptp column is already -log10(PTP); no transform applied ylabel = "-log10(Posterior Tail Probability)" elif y_axis_type == "mahalanobis" or ( score_key and "mahalanobis" in score_key.lower() @@ -942,7 +952,8 @@ def fdr_y_transform(y): comparison = "<" elif axis_type == "ptp": col_key = run_info.get("ptp_key") if run_info else None - comparison = "<" + # neg_log10_ptp column: higher = more significant. + comparison = ">" elif axis_type == "mahalanobis": col_key = ( score_key @@ -959,6 +970,12 @@ def fdr_y_transform(y): if col_key and col_key in adata.var.columns: col_values = adata.var[col_key] + # ptp threshold is a probability; the neg_log10_ptp + # column is on the -log10 scale, so convert. + if axis_type == "ptp": + threshold_val = fdr_y_transform( + np.array([threshold_val]) + )[0] if comparison == "<": axis_mask = col_values < threshold_val else: @@ -1010,7 +1027,9 @@ def fdr_y_transform(y): significance_values_key = ( run_info.get("ptp_key") if run_info else None ) - threshold_comparison = "<" + # neg_log10_ptp column: higher = more significant, so compare + # the -log10 of the (probability) threshold with '>'. + threshold_comparison = ">" elif y_axis_type == "mahalanobis": significance_values_key = score_key # Use the current score key threshold_comparison = ">" @@ -1035,17 +1054,25 @@ def fdr_y_transform(y): ): # Select genes based on significance threshold sig_values = adata.var[significance_values_key] + # ptp threshold is a probability (max PTP); the stored + # column is -log10(PTP), so convert the threshold to the + # same scale for comparison. + effective_threshold = ( + fdr_y_transform(np.array([significance_threshold]))[0] + if y_axis_type == "ptp" + else significance_threshold + ) logger.info( - f"Significance threshold selection: using column '{significance_values_key}' with threshold {threshold_comparison} {significance_threshold}" + f"Significance threshold selection: using column '{significance_values_key}' with threshold {threshold_comparison} {effective_threshold}" ) logger.info( f"Values range: {sig_values.min():.6f} - {sig_values.max():.6f}" ) if threshold_comparison == "<": - significant_mask = sig_values < significance_threshold + significant_mask = sig_values < effective_threshold else: # '>' - significant_mask = sig_values > significance_threshold + significant_mask = sig_values > effective_threshold significant_genes = adata.var_names[significant_mask].tolist() logger.info( @@ -1469,9 +1496,20 @@ def fdr_y_transform(y): and significance_threshold is not None and not isinstance(significance_threshold, dict) ): - if y_axis_type in ["local_fdr", "tail_fdr", "ptp"] and y_transform is not None: + # ptp stores -log10(PTP) directly (y_transform is None), but the user + # passes a probability threshold, so it must still be mapped onto the + # -log10 axis. FDR axes carry an explicit y_transform that does the same. + threshold_axis_transform = None + if y_axis_type == "ptp": + threshold_axis_transform = fdr_y_transform + elif y_axis_type in ["local_fdr", "tail_fdr"] and y_transform is not None: + threshold_axis_transform = y_transform + + if threshold_axis_transform is not None: # Transform the threshold for display - threshold_y = y_transform(np.array([significance_threshold]))[0] + threshold_y = threshold_axis_transform( + np.array([significance_threshold]) + )[0] ax.axhline( y=threshold_y, color="red", diff --git a/kompot/resource_estimation.py b/kompot/resource_estimation.py index f4a36fa..e955010 100644 --- a/kompot/resource_estimation.py +++ b/kompot/resource_estimation.py @@ -898,10 +898,10 @@ def estimate_differential_expression_resources( shape=cov_matrix_shape, ) - # Combined covariance matrix (averaged) + # Combined covariance matrix (sum: Σ_a + Σ_b) plan.add_requirement( "Combined covariance matrix", - cov_size, # (cov1 + cov2) / 2 + cov_size, # cov1 + cov2 "memory", shape=cov_matrix_shape, ) diff --git a/kompot/version.py b/kompot/version.py index 26a803c..028b8b0 100644 --- a/kompot/version.py +++ b/kompot/version.py @@ -1,3 +1,3 @@ """Version information.""" -__version__ = "0.7.0" +__version__ = "0.8.0" diff --git a/pyproject.toml b/pyproject.toml index fc94426..fc8d333 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,7 @@ ignore = ["E203", "W503"] [project] name = "kompot" -version = "0.7.0" +version = "0.8.0" description = "Differential abundance and gene expression analysis using Mahalanobis distance with JAX backend" readme = "README.md" authors = [ diff --git a/tests/test_audit_fixes.py b/tests/test_audit_fixes.py new file mode 100644 index 0000000..eb74a25 --- /dev/null +++ b/tests/test_audit_fixes.py @@ -0,0 +1,275 @@ +"""Regression tests pinning manuscript-aligned statistical behavior. + +These tests guard the v0.8.0 corrections so the implementation cannot +silently drift away from the manuscript's definitions: + +* Mahalanobis denominator must SUM covariances (not average) so that + ``D(a,b) = sqrt((mu_a - mu_b)^T (Sigma_a + Sigma_b)^(-1) (mu_a - mu_b))``. +* DA posterior tail probability must be ONE-sided: ``Phi(-|z|)``. +* ``use_empirical_variance`` must default to ``False`` at every + publicly-exposed entry point (the manuscript states that empirical + variance is disabled by default). +""" + +import inspect + +import numpy as np +import pytest + + +# ----------------------------------------------------------------------------- +# Mahalanobis denominator is Σ_a + Σ_b (sum), not the average +# ----------------------------------------------------------------------------- + + +class TestMahalanobisDenominatorIsSum: + """The covariance denominator in the gene-wise Mahalanobis distance + is the *sum* of the two posterior covariance matrices. + """ + + def test_combined_cov_equals_sum_via_compute_mahalanobis_distances( + self, monkeypatch + ): + """Capture the ``combined_cov`` argument that + ``DifferentialExpression.compute_mahalanobis_distances`` passes + into the underlying ``compute_mahalanobis_distances`` utility + and assert it equals ``cov1 + cov2`` (not ``(cov1+cov2)/2``). + """ + from kompot.differential import DifferentialExpression + from kompot.differential import differential_expression as de_module + + captured = {} + + def fake_compute( + diff_values, + covariance=None, + batch_size=500, + jit_compile=False, + progress=False, + eps=1e-10, + diagonal_variance=None, + **_kwargs, + ): + captured["combined_cov"] = np.asarray(covariance) + n_genes = np.asarray(diff_values).shape[0] + return np.zeros(n_genes, dtype=float) + + monkeypatch.setattr( + de_module, "compute_mahalanobis_distances", fake_compute + ) + + # Synthetic predictors with controllable covariance kernels: + # cov1 returns 2*I, cov2 returns 3*I, so cov1+cov2 = 5*I and the + # (buggy) average would be 2.5*I. + class _Pred: + def __init__(self, scale): + self.scale = scale + + def covariance(self, X, diag=False): + k = X.shape[0] + return self.scale * np.eye(k) + + def __call__(self, X): + # Return an (n_cells, n_genes) zero-mean expression so + # downstream `fold_change_subset` is well-defined. + return np.zeros((X.shape[0], 3), dtype=float) + + de = DifferentialExpression( + n_landmarks=None, + use_sample_variance=False, + use_empirical_variance=False, + function_predictor1=_Pred(2.0), + function_predictor2=_Pred(3.0), + ) + + X_new = np.random.RandomState(0).randn(8, 4) + de.compute_mahalanobis_distances(X_new, use_landmarks=False, progress=False) + + combined_cov = captured["combined_cov"] + expected_sum = 5.0 * np.eye(X_new.shape[0]) + np.testing.assert_allclose( + combined_cov, + expected_sum, + rtol=1e-12, + atol=0, + err_msg=( + "Regression: combined posterior covariance should be " + "cov1 + cov2 (= 5*I here), got something else. The pre-" + "0.8.0 (buggy) value would have been 2.5*I (= " + "(cov1 + cov2) / 2)." + ), + ) + + +# ----------------------------------------------------------------------------- +# DA PTP is one-sided: Phi(-|z|), not 2*Phi(-|z|) +# ----------------------------------------------------------------------------- + + +class TestDifferentialAbundancePTPOneSided: + """The differential-abundance posterior tail probability matches + the one-sided manuscript definition ``PTP = Phi(-|z|)``. + """ + + def test_ptp_one_sided_synthetic_z(self): + from scipy.stats import norm + + # Replicate the exact ln_ptp computation from + # kompot.differential.differential_abundance, fed with controlled + # z-scores so we can compare against the closed-form one-sided + # tail probability. + import jax.scipy.stats.norm as normal + + z = np.array([-3.0, -1.5, -0.5, 0.0, 0.5, 1.5, 3.0]) + + ln_ptp = np.minimum( + np.asarray(normal.logcdf(z)), + np.asarray(normal.logcdf(-z)), + ) + ptp = np.exp(ln_ptp) + + expected_one_sided = norm.cdf(-np.abs(z)) + np.testing.assert_allclose( + ptp, + expected_one_sided, + rtol=1e-10, + atol=1e-12, + err_msg=( + "Regression: PTP should be the one-sided tail Phi(-|z|). " + "Pre-0.8.0 code emitted 2*Phi(-|z|) (two-sided)." + ), + ) + + # And explicitly that it is NOT the two-sided variant + two_sided = 2.0 * norm.cdf(-np.abs(z)) + # Allow the symmetric `z == 0` boundary case (where both sides + # collapse to 0.5 and 1.0 respectively) by checking the strict + # off-axis values. + nonzero = z != 0 + assert np.all( + np.abs(ptp[nonzero] - two_sided[nonzero]) > 1e-3 + ), "PTP unexpectedly equals 2*Phi(-|z|) (two-sided)." + + def test_da_predict_emits_one_sided_ptp(self): + """End-to-end: fit DA on a clearly-separated synthetic pair and + verify the recovered PTP at each evaluation point equals + ``Phi(-|z|)`` computed from the same fit's z-score, not twice + that value. + """ + from scipy.stats import norm + from kompot.differential import DifferentialAbundance + + rng = np.random.RandomState(42) + X1 = rng.randn(80, 3) + X2 = rng.randn(80, 3) + 0.4 + + da = DifferentialAbundance() + da.fit(X1, X2) + + X_eval = np.vstack([X1[:20], X2[:20]]) + out = da.predict(X_eval, progress=False) + + z = np.asarray(out["log_fold_change_zscore"]) + neg_log10_ptp = np.asarray(out["neg_log10_fold_change_ptp"]) + ptp = 10.0 ** (-neg_log10_ptp) + + expected = norm.cdf(-np.abs(z)) + np.testing.assert_allclose( + ptp, + expected, + rtol=1e-4, + atol=1e-6, + err_msg=( + "Regression: PTP returned by DifferentialAbundance." + "predict() does not match the one-sided Phi(-|z|)." + ), + ) + + +# ----------------------------------------------------------------------------- +# use_empirical_variance defaults to False at every public entry point +# ----------------------------------------------------------------------------- + + +class TestUseEmpiricalVarianceDefaultIsFalse: + """Every publicly-exposed entry point that accepts + ``use_empirical_variance`` must default to ``False`` (matching the + manuscript's "empirical variance is disabled by default" statement). + """ + + def _default_for(self, callable_obj, param_name="use_empirical_variance"): + sig = inspect.signature(callable_obj) + assert param_name in sig.parameters, ( + f"{callable_obj.__qualname__} does not expose {param_name}" + ) + param = sig.parameters[param_name] + assert param.default is not inspect.Parameter.empty, ( + f"{callable_obj.__qualname__} parameter {param_name} has " + f"no default value" + ) + return param.default + + def test_gpsettings_default_is_false(self): + from kompot.settings import GPSettings + + assert GPSettings().use_empirical_variance is False + + def test_differential_expression_init_default_is_false(self): + from kompot.differential import DifferentialExpression + + assert ( + self._default_for(DifferentialExpression.__init__) is False + ) + + def test_expression_model_init_default_is_false(self): + from kompot.differential.expression_model import ExpressionModel + + assert self._default_for(ExpressionModel.__init__) is False + + def test_deprecated_compute_differential_expression_default_is_false(self): + from kompot.anndata.differential_expression import ( + compute_differential_expression, + ) + + assert self._default_for(compute_differential_expression) is False + + def test_deprecated_compute_smoothed_expression_default_is_false(self): + from kompot.anndata.smooth import compute_smoothed_expression + + assert self._default_for(compute_smoothed_expression) is False + + def test_smooth_config_template_default_is_false(self): + import pathlib + + import yaml + + import kompot + + template = ( + pathlib.Path(kompot.__file__).parent + / "cli" + / "templates" + / "smooth_config_template.yaml" + ) + cfg = yaml.safe_load(template.read_text()) + assert cfg["use_empirical_variance"] is False + + def test_de_config_template_default_is_false(self): + import pathlib + + import yaml + + import kompot + + template = ( + pathlib.Path(kompot.__file__).parent + / "cli" + / "templates" + / "de_config_template.yaml" + ) + cfg = yaml.safe_load(template.read_text()) + assert cfg["use_empirical_variance"] is False + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_cleanup.py b/tests/test_cleanup.py index 6fc4f1e..5ac2297 100644 --- a/tests/test_cleanup.py +++ b/tests/test_cleanup.py @@ -224,7 +224,7 @@ def test_cleanup_keep_specific_var_fields(self): # Additional stats should be removed assert "test_keep_var_A_to_B_mahalanobis_pvalue" not in adata.var.columns assert "test_keep_var_A_to_B_mahalanobis_tail_fdr" not in adata.var.columns - assert "test_keep_var_A_to_B_ptp" not in adata.var.columns + assert "test_keep_var_A_to_B_neg_log10_ptp" not in adata.var.columns def test_cleanup_not_inplace(self): """Test that cleanup returns a copy when inplace=False.""" diff --git a/tests/test_differential_expression_core.py b/tests/test_differential_expression_core.py index 260134b..b3d6827 100644 --- a/tests/test_differential_expression_core.py +++ b/tests/test_differential_expression_core.py @@ -333,11 +333,13 @@ def test_differential_expression_predict_basic(self): "kompot.differential.differential_expression.compute_mahalanobis_distances" ) as mock_mahal: with patch( - "kompot.differential.differential_expression.jax_stats.chi2.sf" - ) as mock_chi2: + "kompot.differential.differential_expression.scipy_chi2.logsf" + ) as mock_logsf: mock_batch.side_effect = lambda func, X, **kwargs: func(X) mock_mahal.return_value = np.array([0.5, 0.8]) # 2 genes - mock_chi2.return_value = np.array([0.3, 0.1]) # Mock PTP values + # PTP is now computed in log space via scipy chi2.logsf; mock + # the natural-log tail probability for the 2 genes. + mock_logsf.return_value = np.array([-0.3, -0.1]) results = de.predict(X_test, compute_mahalanobis=True) @@ -349,12 +351,12 @@ def test_differential_expression_predict_basic(self): assert "fold_change_zscores" in results assert "mean_log_fold_change" in results assert "mahalanobis_distances" in results - assert "ptp" in results # New PTP column + assert "neg_log10_ptp" in results # -log10(PTP) column # Check shapes assert results["fold_change"].shape == (3, 2) # 3 cells, 2 genes assert results["mahalanobis_distances"].shape == (2,) # 2 genes - assert results["ptp"].shape == (2,) # 2 genes + assert results["neg_log10_ptp"].shape == (2,) # 2 genes def test_differential_expression_predict_with_sample_variance(self): """Test DifferentialExpression prediction with sample variance.""" @@ -650,13 +652,13 @@ def side_effect_unc2(X, diag=False): "kompot.differential.differential_expression.compute_mahalanobis_distances" ) as mock_mahal: with patch( - "kompot.differential.differential_expression.jax_stats.chi2.sf" - ) as mock_chi2: + "kompot.differential.differential_expression.scipy_chi2.logsf" + ) as mock_logsf: mock_batch.side_effect = lambda func, X, **kwargs: func(X) mock_mahal.return_value = np.array([0.2, 0.4, 0.6]) # 3 genes - mock_chi2.return_value = np.array( - [0.4, 0.2, 0.1] - ) # Mock PTP values + # PTP is now computed in log space via scipy chi2.logsf; mock + # the natural-log tail probability for the 3 genes. + mock_logsf.return_value = np.array([-0.4, -0.2, -0.1]) results = de.predict(X_test, progress=False) @@ -674,7 +676,7 @@ def side_effect_unc2(X, diag=False): "mahalanobis_distances" not in results ) # Should not be present when compute_mahalanobis=False assert ( - "ptp" not in results + "neg_log10_ptp" not in results ) # Should not be present when compute_mahalanobis=False # Test with mahalanobis computation enabled @@ -682,8 +684,10 @@ def side_effect_unc2(X, diag=False): X_test, compute_mahalanobis=True, progress=False ) assert "mahalanobis_distances" in results_with_mahal - assert "ptp" in results_with_mahal # PTP should be present + assert ( + "neg_log10_ptp" in results_with_mahal + ) # -log10(PTP) should be present assert results_with_mahal["mahalanobis_distances"].shape == ( 3, ) # 3 genes - assert results_with_mahal["ptp"].shape == (3,) # 3 genes + assert results_with_mahal["neg_log10_ptp"].shape == (3,) # 3 genes diff --git a/tests/test_empirical_variance.py b/tests/test_empirical_variance.py index f8738ab..903e131 100644 --- a/tests/test_empirical_variance.py +++ b/tests/test_empirical_variance.py @@ -589,8 +589,15 @@ def test_results_with_empirical_variance(self, small_adata, fast_de_params): assert model.empirical_variance_predictor1 is not None assert model.empirical_variance_predictor2 is not None - def test_default_is_on(self, tiny_adata, fast_de_params): - """Default should be use_empirical_variance=True.""" + def test_default_is_off(self, tiny_adata, fast_de_params): + """Default should be ``use_empirical_variance=False`` everywhere. + + The deprecated ``compute_differential_expression`` wrapper + previously defaulted to ``True``, disagreeing with the + manuscript ("empirical variance is disabled by default") and + with the recommended ``kompot.de()`` path. v0.8.0 harmonizes + all public entry points to ``False``. + """ from kompot.anndata.differential_expression import ( compute_differential_expression, ) @@ -606,8 +613,8 @@ def test_default_is_on(self, tiny_adata, fast_de_params): ) model = result["model"] - assert model.use_empirical_variance is True - assert model.empirical_variance_predictor1 is not None + assert model.use_empirical_variance is False + assert model.empirical_variance_predictor1 is None # ===== Leverage correction ===== diff --git a/tests/test_fdr_integration.py b/tests/test_fdr_integration.py index 0630170..111b31e 100644 --- a/tests/test_fdr_integration.py +++ b/tests/test_fdr_integration.py @@ -92,13 +92,13 @@ def test_fdr_enabled_basic(self): assert col in results["table"].columns, f"Missing column: {col}" assert len(results["table"][col]) == adata.n_vars - # Check AnnData columns (including new ptp column) + # Check AnnData columns (including the neg_log10_ptp column) fdr_columns = [ "test_fdr_Ctrl_to_Treat_mahalanobis_pvalue", "test_fdr_Ctrl_to_Treat_mahalanobis_local_fdr", "test_fdr_Ctrl_to_Treat_mahalanobis_tail_fdr", "test_fdr_Ctrl_to_Treat_is_de", - "test_fdr_Ctrl_to_Treat_ptp", + "test_fdr_Ctrl_to_Treat_neg_log10_ptp", ] for col in fdr_columns: assert col in adata.var.columns, f"Missing column: {col}" @@ -110,9 +110,9 @@ def test_fdr_enabled_basic(self): assert np.all(adata.var["test_fdr_Ctrl_to_Treat_mahalanobis_local_fdr"] <= 1) assert adata.var["test_fdr_Ctrl_to_Treat_is_de"].dtype == bool - # Check ptp values (should be probabilities between 0 and 1) - assert np.all(adata.var["test_fdr_Ctrl_to_Treat_ptp"] >= 0) - assert np.all(adata.var["test_fdr_Ctrl_to_Treat_ptp"] <= 1) + # Check neg_log10_ptp values: -log10(PTP) is non-negative and unbounded + # above (a probability PTP <= 1 maps to -log10(PTP) >= 0). + assert np.all(adata.var["test_fdr_Ctrl_to_Treat_neg_log10_ptp"] >= 0) # FDR pipeline should run without error and produce valid results n_significant = np.sum(adata.var["test_fdr_Ctrl_to_Treat_is_de"]) diff --git a/tests/test_plot_lollipop.py b/tests/test_plot_lollipop.py new file mode 100644 index 0000000..fd92315 --- /dev/null +++ b/tests/test_plot_lollipop.py @@ -0,0 +1,328 @@ +"""Tests for kompot.plot.lollipop.""" + +from __future__ import annotations + +import matplotlib + +matplotlib.use("Agg") + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import pytest + + +# --------------------------------------------------------------------------- +# Synthetic enrichment frames +# --------------------------------------------------------------------------- +def _stringdb_frame(n: int = 8, seed: int = 0) -> pd.DataFrame: + """A frame shaped like StringDBReport.get_functional_enrichment output: + ``term``, ``description``, ``signal``, ``strength``, ``fdr``, + ``number_of_genes`` — sorted by ``signal`` descending.""" + rng = np.random.default_rng(seed) + fdr = np.sort(rng.uniform(1e-9, 4e-2, size=n)) + df = pd.DataFrame( + { + "term": [f"GO:{i:07d}" for i in range(n)], + "description": [f"biological process number {i}" for i in range(n)], + "signal": np.sort(rng.uniform(0.5, 3.5, size=n))[::-1], + "strength": rng.uniform(0.3, 1.5, size=n), + "fdr": fdr, + "number_of_genes": rng.integers(3, 60, size=n), + } + ) + return df.sort_values("signal", ascending=False).reset_index(drop=True) + + +def _generic_frame(n: int = 6, seed: int = 1) -> pd.DataFrame: + """A gseapy/enrichr-style frame: ``Term``, ``Overlap`` ("k/K" string), + ``Adjusted P-value``, ``Combined Score`` — different header names so + autodetection + the ``Overlap`` string parser are exercised.""" + rng = np.random.default_rng(seed) + return pd.DataFrame( + { + "Term": [f"pathway {i}" for i in range(n)], + "Overlap": [f"{k}/300" for k in rng.integers(4, 40, size=n)], + "Adjusted P-value": rng.uniform(1e-7, 3e-2, size=n), + "Combined Score": rng.uniform(5, 120, size=n), + } + ) + + +# --------------------------------------------------------------------------- +# Import / export +# --------------------------------------------------------------------------- +def test_lollipop_import_and_exported(): + import kompot.plot as kp + + assert hasattr(kp, "lollipop") + assert "lollipop" in kp.__all__ + + +# --------------------------------------------------------------------------- +# StringDBReport-DataFrame path +# --------------------------------------------------------------------------- +def test_lollipop_stringdb_frame_standalone(): + from kompot.plot import lollipop + + df = _stringdb_frame() + fig = lollipop(df, n_terms=5, return_fig=True) + assert isinstance(fig, plt.Figure) + ax = fig.axes[0] + # 5 dots in the scatter collection + assert ax.collections[-1].get_offsets().shape[0] == 5 + # x label is the -log10(FDR) default + assert "log" in ax.get_xlabel().lower() + plt.close(fig) + + +def test_lollipop_default_returns_none(): + from kompot.plot import lollipop + + out = lollipop(_stringdb_frame(), n_terms=4) + assert out is None + plt.close("all") + + +def test_lollipop_top_n_sorted_by_significance(): + from kompot.plot import lollipop + + df = _stringdb_frame(n=10) + fig = lollipop(df, n_terms=3, return_fig=True) + ax = fig.axes[0] + labels = [t.get_text() for t in ax.get_yticklabels()] + # the three smallest-FDR terms, most significant on top + expected = list( + df.sort_values("fdr").head(3)["description"] + ) + # labels may be soft-wrapped (newline) — compare on the un-wrapped text + got = [lbl.replace("\n", " ") for lbl in labels] + assert got == expected + plt.close(fig) + + +# --------------------------------------------------------------------------- +# Generic-format path + autodetection + Overlap parsing +# --------------------------------------------------------------------------- +def test_lollipop_generic_frame_autodetect(): + from kompot.plot import lollipop + + df = _generic_frame() + fig = lollipop(df, n_terms=5, return_fig=True) + assert isinstance(fig, plt.Figure) + ax = fig.axes[0] + assert ax.collections[-1].get_offsets().shape[0] == 5 + plt.close(fig) + + +def test_lollipop_explicit_column_mapping_and_score_metric(): + from kompot.plot import lollipop + + df = _generic_frame() + fig = lollipop( + df, + x_metric="score", + term_col="Term", + score_col="Combined Score", + count_col="Overlap", + fdr_col="Adjusted P-value", + n_terms=4, + return_fig=True, + ) + ax = fig.axes[0] + assert ax.get_xlabel() == "Combined Score" + # score-metric → sorted by score descending; top dot has the max score + xs = ax.collections[-1].get_offsets()[:, 0] + assert xs[0] == pytest.approx(df["Combined Score"].max()) + plt.close(fig) + + +def test_lollipop_overlap_string_drives_dot_size(): + from kompot.plot import lollipop + from kompot.plot.lollipop import _coerce_counts + + counts = _coerce_counts(pd.Series(["12/300", "4/300", "40/300"])) + assert list(counts) == [12.0, 4.0, 40.0] + + +# --------------------------------------------------------------------------- +# StringDBReport instance path (mocked — no network) +# --------------------------------------------------------------------------- +def test_lollipop_accepts_stringdb_report_instance(): + from kompot.plot import lollipop + + frame = _stringdb_frame() + + class _FakeReport: + def get_functional_enrichment(self, category="Process", fdr_threshold=0.05): + assert category == "Process" + return frame + + fig = lollipop(_FakeReport(), n_terms=4, return_fig=True) + assert isinstance(fig, plt.Figure) + plt.close(fig) + + +def test_lollipop_stringdb_report_empty_raises(): + from kompot.plot import lollipop + + class _EmptyReport: + def get_functional_enrichment(self, category="Process", fdr_threshold=0.05): + return None + + with pytest.raises(ValueError, match="no terms"): + lollipop(_EmptyReport()) + plt.close("all") + + +# --------------------------------------------------------------------------- +# Embedding into a provided ax +# --------------------------------------------------------------------------- +def test_lollipop_embeds_into_provided_ax(): + from kompot.plot import lollipop + + fig, axes = plt.subplots(1, 2, figsize=(12, 5)) + figs_before = list(plt.get_fignums()) + out = lollipop(_stringdb_frame(), ax=axes[0], n_terms=4) + assert out is None + # no new figure created + assert plt.get_fignums() == figs_before + assert len(axes[0].collections) >= 1 + plt.close(fig) + + +def test_lollipop_embedded_return_fig_returns_ax(): + from kompot.plot import lollipop + + fig, ax = plt.subplots() + out = lollipop(_stringdb_frame(), ax=ax, n_terms=3, return_fig=True) + assert out is ax + plt.close(fig) + + +# --------------------------------------------------------------------------- +# Toggles: FDR guide, legend, annotations +# --------------------------------------------------------------------------- +def test_lollipop_fdr_guide_toggle(): + from kompot.plot import lollipop + + df = _stringdb_frame() + fig_on = lollipop(df, fdr_line=0.05, return_fig=True) + n_on = len(fig_on.axes[0].lines) + plt.close(fig_on) + + fig_off = lollipop(df, fdr_line=None, return_fig=True) + n_off = len(fig_off.axes[0].lines) + plt.close(fig_off) + + # one fewer line (the guide) when disabled + assert n_on == n_off + 1 + + +def test_lollipop_legend_toggle(): + from kompot.plot import lollipop + + fig = lollipop(_stringdb_frame(), legend=False, return_fig=True) + assert fig.legends == [] and fig.axes[0].get_legend() is None + plt.close(fig) + + +def test_lollipop_annotate_toggle(): + from kompot.plot import lollipop + + df = _stringdb_frame(n=4) + fig_on = lollipop(df, annotate=True, return_fig=True) + ann_on = [c for c in fig_on.axes[0].texts if "FDR" in c.get_text()] + assert len(ann_on) == 4 + plt.close(fig_on) + + fig_off = lollipop(df, annotate=False, return_fig=True) + ann_off = [c for c in fig_off.axes[0].texts if "FDR" in c.get_text()] + assert ann_off == [] + plt.close(fig_off) + + +def test_lollipop_custom_annotate_fmt(): + from kompot.plot import lollipop + + fig = lollipop( + _stringdb_frame(n=3), + annotate_fmt="{count} genes", + return_fig=True, + ) + texts = [t.get_text() for t in fig.axes[0].texts] + assert any("genes" in t for t in texts) + plt.close(fig) + + +# --------------------------------------------------------------------------- +# cmap coloring path +# --------------------------------------------------------------------------- +def test_lollipop_cmap_colors_by_score(): + from kompot.plot import lollipop + + df = _stringdb_frame() + fig = lollipop(df, cmap="viridis", color_by="signal", return_fig=True) + # a colorbar axis is added on the standalone figure + assert len(fig.axes) >= 2 + plt.close(fig) + + +# --------------------------------------------------------------------------- +# Error handling +# --------------------------------------------------------------------------- +def test_lollipop_missing_term_column_raises(): + from kompot.plot import lollipop + + df = pd.DataFrame({"fdr": [0.01, 0.02], "number_of_genes": [5, 8]}) + with pytest.raises(KeyError, match="term"): + lollipop(df) + plt.close("all") + + +def test_lollipop_neg_log10_without_fdr_raises(): + from kompot.plot import lollipop + + df = pd.DataFrame( + {"description": ["a", "b"], "signal": [1.0, 2.0]} + ) + with pytest.raises(KeyError, match="FDR"): + lollipop(df, x_metric="neg_log10_fdr") + plt.close("all") + + +def test_lollipop_score_metric_without_score_raises(): + from kompot.plot import lollipop + + df = pd.DataFrame( + {"description": ["a", "b"], "fdr": [0.01, 0.02]} + ) + with pytest.raises(KeyError, match="score"): + lollipop(df, x_metric="score") + plt.close("all") + + +def test_lollipop_empty_frame_raises(): + from kompot.plot import lollipop + + with pytest.raises(ValueError, match="empty"): + lollipop(pd.DataFrame({"description": [], "fdr": []})) + plt.close("all") + + +def test_lollipop_explicit_bad_column_raises(): + from kompot.plot import lollipop + + df = _stringdb_frame() + with pytest.raises(KeyError, match="not found"): + lollipop(df, term_col="NoSuchColumn") + plt.close("all") + + +def test_lollipop_literal_column_x_metric(): + from kompot.plot import lollipop + + df = _stringdb_frame() + fig = lollipop(df, x_metric="strength", n_terms=4, return_fig=True) + assert fig.axes[0].get_xlabel() == "strength" + plt.close(fig) diff --git a/tests/test_ptp_functionality.py b/tests/test_ptp_functionality.py index 91b9550..80c06aa 100644 --- a/tests/test_ptp_functionality.py +++ b/tests/test_ptp_functionality.py @@ -8,6 +8,7 @@ matplotlib.use("Agg") # Use non-interactive backend import matplotlib.pyplot as plt import jax.scipy.stats as jax_stats +from scipy.stats import chi2 as scipy_chi2 import anndata @@ -27,14 +28,20 @@ def create_test_adata_with_ptp(n_cells=60, n_genes=50): # Create realistic Mahalanobis distances mahalanobis_distances = np.abs(np.random.gamma(2, 1, n_genes)) # Positive values - # Compute ptp from Mahalanobis distances using chi2 distribution + # Compute the posterior tail probability (PTP) from Mahalanobis distances + # using the chi2 distribution, stored as -log10(PTP) in log space — the + # convention kompot now uses (mirrors the DA neg_log10_lfc_ptp field). The + # log-space form avoids the linear-space saturation to 1.0 that collapses + # gene-ranking resolution at the head of the distribution. degrees_of_freedom = 10 # Typical number of dimensions - mahalanobis_squared = mahalanobis_distances**2 - ptp_values = np.array(jax_stats.chi2.sf(mahalanobis_squared, df=degrees_of_freedom)) + mahalanobis_squared = mahalanobis_distances.astype(np.float64) ** 2 + neg_log10_ptp_values = -scipy_chi2.logsf( + mahalanobis_squared, df=degrees_of_freedom + ) / np.log(10) # Add differential expression metrics adata.var["kompot_de_mahalanobis_A_to_B"] = mahalanobis_distances - adata.var["kompot_de_ptp_A_to_B"] = ptp_values + adata.var["kompot_de_neg_log10_ptp_A_to_B"] = neg_log10_ptp_values adata.var["kompot_de_mean_lfc_A_to_B"] = np.random.normal(0, 2, n_genes) # Add FDR values for comparison @@ -44,7 +51,8 @@ def create_test_adata_with_ptp(n_cells=60, n_genes=50): adata.var["kompot_de_mahalanobis_tail_fdr_A_to_B"] = np.random.uniform( 0, 0.5, n_genes ) - adata.var["kompot_de_is_de_A_to_B"] = ptp_values < 0.05 # Significant at p < 0.05 + # Significant at PTP < 0.05, i.e. -log10(PTP) > -log10(0.05) + adata.var["kompot_de_is_de_A_to_B"] = neg_log10_ptp_values > -np.log10(0.05) # Add run history for proper testing adata.uns["kompot_de_run_history"] = [ @@ -52,7 +60,7 @@ def create_test_adata_with_ptp(n_cells=60, n_genes=50): "params": {"condition1": "A", "condition2": "B"}, "field_names": { "mahalanobis_key": "kompot_de_mahalanobis_A_to_B", - "ptp_key": "kompot_de_ptp_A_to_B", + "ptp_key": "kompot_de_neg_log10_ptp_A_to_B", "mean_lfc_key": "kompot_de_mean_lfc_A_to_B", }, "fdr_keys": { @@ -60,7 +68,7 @@ def create_test_adata_with_ptp(n_cells=60, n_genes=50): "tail_fdr_key": "kompot_de_mahalanobis_tail_fdr_A_to_B", "is_de_key": "kompot_de_is_de_A_to_B", }, - "ptp_key": "kompot_de_ptp_A_to_B", # This should be in field_names + "ptp_key": "kompot_de_neg_log10_ptp_A_to_B", # in field_names } ] @@ -150,10 +158,17 @@ def test_volcano_de_ptp_gene_selection(self): adata = create_test_adata_with_ptp() - # Set some genes to be clearly significant - adata.var.loc["gene_0", "kompot_de_ptp_A_to_B"] = 0.001 # Very significant - adata.var.loc["gene_1", "kompot_de_ptp_A_to_B"] = 0.005 # Significant - adata.var.loc["gene_2", "kompot_de_ptp_A_to_B"] = 0.1 # Not significant + # Set some genes to be clearly significant. Column stores -log10(PTP), + # so larger = more significant. + adata.var.loc["gene_0", "kompot_de_neg_log10_ptp_A_to_B"] = -np.log10( + 0.001 + ) # Very significant (PTP=0.001) + adata.var.loc["gene_1", "kompot_de_neg_log10_ptp_A_to_B"] = -np.log10( + 0.005 + ) # Significant (PTP=0.005) + adata.var.loc["gene_2", "kompot_de_neg_log10_ptp_A_to_B"] = -np.log10( + 0.1 + ) # Not significant (PTP=0.1) fig = volcano_de( adata, @@ -164,7 +179,7 @@ def test_volcano_de_ptp_gene_selection(self): return_fig=True, ) - # Should highlight genes with ptp < 0.01 + # Should highlight genes with PTP < 0.01, i.e. -log10(PTP) > 2 plt.close(fig) def test_ptp_column_inference(self): @@ -233,7 +248,7 @@ def test_custom_column_name(self): adata, lfc_key="kompot_de_mean_lfc_A_to_B", score_key="kompot_de_mahalanobis_A_to_B", - y_axis_type="kompot_de_ptp_A_to_B", # Custom column name + y_axis_type="kompot_de_neg_log10_ptp_A_to_B", # Custom column name return_fig=True, ) @@ -247,7 +262,7 @@ def test_ptp_error_handling(self): adata = create_test_adata_with_ptp() # Remove ptp column - del adata.var["kompot_de_ptp_A_to_B"] + del adata.var["kompot_de_neg_log10_ptp_A_to_B"] # Should fall back to mahalanobis when ptp not found fig = volcano_de( @@ -292,5 +307,116 @@ def test_significance_threshold_parameter(self): plt.close(fig) +def _create_de_data(n_cells=80, n_genes=60, n_dims=10, seed=0): + """AnnData with a clear DE signal and a moderate embedding dimension so the + chi-squared df (= embedding dim) is large enough for the linear-space + saturation to bite.""" + rng = np.random.RandomState(seed) + n1 = n_cells // 2 + n2 = n_cells - n1 + # Embedding with a real shift between conditions in a few dimensions + shift = np.zeros(n_dims) + shift[:4] = 1.2 + X = np.vstack( + [rng.normal(0, 1, (n1, n_dims)), rng.normal(shift, 1, (n2, n_dims))] + ) + expr = rng.negative_binomial(10, 0.3, (n_cells, n_genes)).astype(float) + gene_names = [f"Gene_{i:04d}" for i in range(n_genes)] + cell_names = [f"Cell_{i:04d}" for i in range(n_cells)] + adata = anndata.AnnData( + expr, + obs=pd.DataFrame( + {"condition": ["A"] * n1 + ["B"] * n2}, index=cell_names + ), + var=pd.DataFrame(index=gene_names), + ) + adata.obsm["X_pca"] = X + return adata + + +class TestNegLog10PTPRegression: + """Regression guards for the linear-space PTP saturation bug. + + The DE posterior tail probability is a strictly monotone transform of the + Mahalanobis distance, so it must preserve the gene ranking. Storing it in + linear space (``chi2.sf``) collapses every gene below the chi-squared mean + onto values numerically indistinguishable from 1.0, destroying that ranking + at the head of the distribution. Storing ``-log10(PTP)`` from ``chi2.logsf`` + in float64 keeps every value distinct. These tests would have failed against + the old linear-space storage. + """ + + def test_linear_space_saturates_log_space_does_not(self): + """Pure-math guard at a realistic df: linear ``sf`` saturates to 1.0 and + loses distinct values; ``-log10(PTP)`` from ``logsf`` does not.""" + from scipy.stats import spearmanr + + rng = np.random.RandomState(0) + df = 40 # realistic embedding dimension + # Most genes are near-null -> D^2 well below the chi2 mean (= df). + d2 = np.r_[rng.chisquare(5, 2000), rng.chisquare(60, 100)] + + linear_sf = scipy_chi2.sf(d2, df=df) + neg_log10 = -scipy_chi2.logsf(d2, df=df) / np.log(10) + + # Linear space: a substantial fraction collapse to EXACTLY 1.0 ... + assert np.mean(linear_sf == 1.0) > 0.1 + # ... so the distinct-value count is destroyed. + assert len(np.unique(linear_sf)) < len(d2) + + # Log space: every gene keeps a distinct value. + assert len(np.unique(neg_log10)) == len(d2) + # And the ranking is exactly the Mahalanobis ranking. + assert spearmanr(neg_log10, d2).correlation == pytest.approx(1.0) + + def test_stored_field_preserves_mahalanobis_ranking(self): + """End-to-end: the stored ``neg_log10_ptp`` field ranks genes identically + to the Mahalanobis distance, has dynamic range beyond [0, 1] (impossible + for the old linear ``sf`` field, whose max was <= 1), and shows no mass + saturation onto a single value.""" + try: + from kompot.anndata import compute_differential_expression + except ImportError: + pytest.skip("anndata not installed") + from scipy.stats import spearmanr + + adata = _create_de_data() + compute_differential_expression( + adata, + groupby="condition", + condition1="A", + condition2="B", + obsm_key="X_pca", + result_key="reg", + null_genes=10, + null_seed=0, + store_additional_stats=True, + progress=False, + n_landmarks=10, + ) + + mahal = adata.var["reg_A_to_B_mahalanobis"].values + ptp = adata.var["reg_A_to_B_neg_log10_ptp"].values + + # Strictly monotone transform of the distance -> identical ranking. + finite = np.isfinite(mahal) & np.isfinite(ptp) + assert finite.sum() >= 2 + assert spearmanr(ptp[finite], mahal[finite]).correlation == pytest.approx( + 1.0 + ) + + # -log10(PTP) is always non-negative. + assert np.all(ptp[finite] >= 0) + + # Dynamic range the old linear-space field could not represent: at least + # one gene exceeds 1.0 (i.e. PTP < 0.1). The old field stored sf in + # [0, 1], so its maximum was structurally <= 1. + assert np.nanmax(ptp) > 1.0 + + # No mass saturation: no single stored value dominates the field. + _, counts = np.unique(ptp[finite], return_counts=True) + assert counts.max() / finite.sum() < 0.5 + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/tests/test_store_additional_stats.py b/tests/test_store_additional_stats.py index 2251942..9777fb1 100644 --- a/tests/test_store_additional_stats.py +++ b/tests/test_store_additional_stats.py @@ -77,7 +77,7 @@ def test_default_behavior_stores_minimal_fields(self): # Check that additional measures are NOT stored assert "test_default_A_to_B_mahalanobis_pvalue" not in adata.var.columns assert "test_default_A_to_B_mahalanobis_tail_fdr" not in adata.var.columns - assert "test_default_A_to_B_ptp" not in adata.var.columns + assert "test_default_A_to_B_neg_log10_ptp" not in adata.var.columns assert "test_default_A_to_B_fold_change_zscores" not in adata.layers def test_store_additional_stats_true_stores_all_fields(self): @@ -112,7 +112,7 @@ def test_store_additional_stats_true_stores_all_fields(self): # Additional stats should be stored assert "test_all_stats_A_to_B_mahalanobis_pvalue" in adata.var.columns assert "test_all_stats_A_to_B_mahalanobis_tail_fdr" in adata.var.columns - assert "test_all_stats_A_to_B_ptp" in adata.var.columns + assert "test_all_stats_A_to_B_neg_log10_ptp" in adata.var.columns assert "test_all_stats_A_to_B_fold_change_zscores" in adata.layers def test_pvalue_ranges_when_stored(self): @@ -246,7 +246,7 @@ def test_ptp_stored_conditionally(self): n_landmarks=5, ) - assert "test_no_ptp_A_to_B_ptp" not in adata1.var.columns + assert "test_no_ptp_A_to_B_neg_log10_ptp" not in adata1.var.columns # With store_additional_stats=True: SHOULD store PTP compute_differential_expression( @@ -262,9 +262,9 @@ def test_ptp_stored_conditionally(self): n_landmarks=5, ) - assert "test_with_ptp_A_to_B_ptp" in adata2.var.columns - # Check PTP values are non-negative - assert np.all(adata2.var["test_with_ptp_A_to_B_ptp"] >= 0) + assert "test_with_ptp_A_to_B_neg_log10_ptp" in adata2.var.columns + # -log10(PTP) is always non-negative since PTP <= 1 + assert np.all(adata2.var["test_with_ptp_A_to_B_neg_log10_ptp"] >= 0) def test_storage_consistency_between_adata_and_results(self): """Test that what's stored in adata matches what's in results dictionary.""" diff --git a/tests/test_volcano_de_rendering.py b/tests/test_volcano_de_rendering.py index 2574e97..87b3d72 100644 --- a/tests/test_volcano_de_rendering.py +++ b/tests/test_volcano_de_rendering.py @@ -262,10 +262,12 @@ def test_ptp_y_axis_with_key(self, de_adata): """y_axis_type='ptp' with ptp_key in run info (lines 293-301).""" from kompot.plot.volcano.de import volcano_de - # Add ptp data - de_adata.var["kompot_de_A_to_B_ptp"] = np.random.uniform(0, 1, de_adata.n_vars) + # Add ptp data, stored as -log10(PTP) (the neg_log10_ptp convention) + de_adata.var["kompot_de_A_to_B_neg_log10_ptp"] = np.random.uniform( + 0, 5, de_adata.n_vars + ) run_history = json.loads(de_adata.uns["kompot_de"]["run_history"]) - run_history[0]["ptp_key"] = "kompot_de_A_to_B_ptp" + run_history[0]["ptp_key"] = "kompot_de_A_to_B_neg_log10_ptp" de_adata.uns["kompot_de"]["run_history"] = json.dumps(run_history) fig = volcano_de( @@ -489,9 +491,11 @@ def test_dict_threshold_with_ptp(self, de_adata): """Dict threshold including ptp axis type.""" from kompot.plot.volcano.de import volcano_de - de_adata.var["kompot_de_A_to_B_ptp"] = np.random.uniform(0, 1, de_adata.n_vars) + de_adata.var["kompot_de_A_to_B_neg_log10_ptp"] = np.random.uniform( + 0, 5, de_adata.n_vars + ) run_history = json.loads(de_adata.uns["kompot_de"]["run_history"]) - run_history[0]["ptp_key"] = "kompot_de_A_to_B_ptp" + run_history[0]["ptp_key"] = "kompot_de_A_to_B_neg_log10_ptp" de_adata.uns["kompot_de"]["run_history"] = json.dumps(run_history) fig = volcano_de(