Source code for genevector.metrics

"""genevector/metrics.py — co-expression target functions."""

import numpy as np
from scipy.sparse import issparse, csr_matrix
from scipy.stats import spearmanr
import collections
import itertools
import tqdm

try:
    from numba import njit, prange
    HAS_NUMBA = True
except ImportError:
    HAS_NUMBA = False

try:
    from ._rust import compute_mi_pairs as _rust_mi_pairs
    HAS_RUST = True
except ImportError:
    HAS_RUST = False


# ─── Discretization ────────────────────────────────────────────

[docs] def discretize_genes(X, n_bins=10): """ Discretize each gene's expression into integer bin indices. Parameters ---------- X : scipy.sparse.csr_matrix or np.ndarray Cells x genes expression matrix. n_bins : int Number of bins per gene (excluding the zero bin). Returns ------- X_disc : np.ndarray, shape (n_cells, n_genes), dtype=np.int32 Discretized expression. 0 = zero expression, 1..n_bins = quantile bins of nonzero expression. n_bins_per_gene : np.ndarray, shape (n_genes,), dtype=np.int32 Actual number of bins used per gene (may be < n_bins if a gene has fewer unique nonzero values). """ if issparse(X): X = np.asarray(X.todense()) n_cells, n_genes = X.shape X_disc = np.zeros((n_cells, n_genes), dtype=np.int32) n_bins_per_gene = np.zeros(n_genes, dtype=np.int32) for g in range(n_genes): col = X[:, g].ravel() nonzero_mask = col > 0 nonzero_vals = col[nonzero_mask] if len(nonzero_vals) == 0: n_bins_per_gene[g] = 1 # just the zero bin continue unique_vals = np.unique(nonzero_vals) actual_bins = min(n_bins, len(unique_vals)) if actual_bins <= 1: X_disc[nonzero_mask, g] = 1 n_bins_per_gene[g] = 2 # 0 and 1 else: quantiles = np.linspace(0, 100, actual_bins + 1)[1:-1] edges = np.percentile(nonzero_vals, quantiles) X_disc[nonzero_mask, g] = np.searchsorted(edges, nonzero_vals, side='right') + 1 n_bins_per_gene[g] = actual_bins + 1 # including zero bin return X_disc, n_bins_per_gene
# ─── MI helper ───────────────────────────────────────────────── def _mi_from_joint(pxy): """Compute MI from a joint probability distribution.""" pxy = pxy / pxy.sum() px = pxy.sum(axis=1) py = pxy.sum(axis=0) px_py = np.outer(px, py) nz = (pxy > 0) & (px_py > 0) return np.sum(pxy[nz] * np.log2(pxy[nz] / px_py[nz])) # ─── Tier A: Vectorized NumPy MI ───────────────────────────────
[docs] def compute_mi_vectorized(X, gene_names, n_bins=10, signed=False): """ Compute MI for all gene pairs using vectorized discretization. Parameters ---------- X : sparse or dense matrix, shape (n_cells, n_genes) gene_names : list of str n_bins : int signed : bool If True, multiply MI by sign of Pearson correlation. Returns ------- mi_scores : dict of dict mi_scores[gene_a][gene_b] = float """ X_disc, n_bins_per_gene = discretize_genes(X, n_bins=n_bins) n_genes = len(gene_names) if signed: if issparse(X): X_dense = np.asarray(X.todense()) else: X_dense = X corr_matrix = np.nan_to_num(np.corrcoef(X_dense.T), nan=0.0) mi_scores = collections.defaultdict(lambda: collections.defaultdict(float)) pairs = list(itertools.combinations(range(n_genes), 2)) for idx_a, idx_b in tqdm.tqdm(pairs, desc="Computing MI"): na = n_bins_per_gene[idx_a] nb = n_bins_per_gene[idx_b] if na <= 1 or nb <= 1: continue col_a = X_disc[:, idx_a] col_b = X_disc[:, idx_b] mask = (col_a > 0) | (col_b > 0) if mask.sum() == 0: continue joint = np.zeros((na, nb), dtype=np.float64) np.add.at(joint, (col_a[mask], col_b[mask]), 1) mi = _mi_from_joint(joint) if signed: sign = np.sign(corr_matrix[idx_a, idx_b]) mi = sign * mi gene_a = gene_names[idx_a] gene_b = gene_names[idx_b] mi_scores[gene_a][gene_b] = round(mi, 5) mi_scores[gene_b][gene_a] = round(mi, 5) return mi_scores
# ─── Tier B: Numba JIT MI ────────────────────────────────────── if HAS_NUMBA: @njit(parallel=True) def _compute_all_mi_numba(X_disc, n_bins_per_gene, n_genes): """ Returns a flat array of MI values for upper-triangle pairs. """ n_pairs = n_genes * (n_genes - 1) // 2 mi_values = np.zeros(n_pairs, dtype=np.float64) for flat_idx in prange(n_pairs): # recover (i, j) from flat index i = int(n_genes - 1 - int(np.sqrt(-8 * flat_idx + 4 * n_genes * (n_genes - 1) - 7) / 2 - 0.5)) j = flat_idx + i * (i + 1) // 2 - n_genes * i + n_genes na = n_bins_per_gene[i] nb = n_bins_per_gene[j] if na <= 1 or nb <= 1: continue n_cells = X_disc.shape[0] joint = np.zeros((na, nb), dtype=np.float64) count = 0 for c in range(n_cells): a = X_disc[c, i] b = X_disc[c, j] if a > 0 or b > 0: joint[a, b] += 1.0 count += 1 if count == 0: continue total = 0.0 for ai in range(na): for bi in range(nb): total += joint[ai, bi] px = np.zeros(na, dtype=np.float64) py = np.zeros(nb, dtype=np.float64) for ai in range(na): for bi in range(nb): joint[ai, bi] /= total px[ai] += joint[ai, bi] py[bi] += joint[ai, bi] mi = 0.0 for ai in range(na): for bi in range(nb): if joint[ai, bi] > 0 and px[ai] > 0 and py[bi] > 0: mi += joint[ai, bi] * np.log2(joint[ai, bi] / (px[ai] * py[bi])) mi_values[flat_idx] = mi return mi_values
[docs] def compute_mi_numba(X, gene_names, n_bins=10, signed=False): """Numba-accelerated MI computation.""" if not HAS_NUMBA: raise ImportError("numba is required for compute_mi_numba. " "Install with: pip install numba") X_disc, n_bins_per_gene = discretize_genes(X, n_bins=n_bins) n_genes = len(gene_names) mi_flat = _compute_all_mi_numba(X_disc, n_bins_per_gene, n_genes) if signed: if issparse(X): X_dense = np.asarray(X.todense()) else: X_dense = X corr_matrix = np.nan_to_num(np.corrcoef(X_dense.T), nan=0.0) mi_scores = collections.defaultdict(lambda: collections.defaultdict(float)) pair_idx = 0 for i in range(n_genes): for j in range(i + 1, n_genes): mi = mi_flat[pair_idx] if signed: mi = np.sign(corr_matrix[i, j]) * mi gene_a = gene_names[i] gene_b = gene_names[j] mi_scores[gene_a][gene_b] = round(float(mi), 5) mi_scores[gene_b][gene_a] = round(float(mi), 5) pair_idx += 1 return mi_scores
# ─── Tier C: GPU MI ────────────────────────────────────────────
[docs] def compute_mi_gpu(X, gene_names, n_bins=10, signed=False, device="cuda"): """ GPU-accelerated MI using PyTorch scatter_add for joint histograms. """ import torch X_disc, n_bins_per_gene = discretize_genes(X, n_bins=n_bins) n_cells, n_genes = X_disc.shape X_disc_t = torch.tensor(X_disc, dtype=torch.long, device=device) if signed: if issparse(X): X_dense = np.asarray(X.todense()) else: X_dense = X corr_matrix = np.nan_to_num(np.corrcoef(X_dense.T), nan=0.0) mi_scores = collections.defaultdict(lambda: collections.defaultdict(float)) for i in tqdm.tqdm(range(n_genes), desc="GPU MI"): na = int(n_bins_per_gene[i]) if na <= 1: continue col_a = X_disc_t[:, i] j_indices = list(range(i + 1, n_genes)) if not j_indices: continue cols_b = X_disc_t[:, j_indices] mask = (col_a.unsqueeze(1) > 0) | (cols_b > 0) for local_j, j in enumerate(j_indices): nb = int(n_bins_per_gene[j]) if nb <= 1: continue m = mask[:, local_j] a_vals = col_a[m] b_vals = cols_b[m, local_j] if a_vals.numel() == 0: continue flat_idx = a_vals * nb + b_vals joint = torch.zeros(na * nb, dtype=torch.float32, device=device) joint.scatter_add_(0, flat_idx.long(), torch.ones_like(flat_idx, dtype=torch.float32)) joint = joint.reshape(na, nb) total = joint.sum() if total == 0: continue joint /= total px = joint.sum(dim=1) py = joint.sum(dim=0) px_py = torch.outer(px, py) nz = (joint > 0) & (px_py > 0) mi = (joint[nz] * torch.log2(joint[nz] / px_py[nz])).sum().item() if signed: mi = float(np.sign(corr_matrix[i, j])) * mi gene_a = gene_names[i] gene_b = gene_names[j] mi_scores[gene_a][gene_b] = round(mi, 5) mi_scores[gene_b][gene_a] = round(mi, 5) return mi_scores
# ─── Tier D: Rust MI ───────────────────────────────────────────
[docs] def compute_mi_rust(X, gene_names, n_bins=10, signed=False): """Rust-accelerated MI computation via PyO3.""" if not HAS_RUST: raise ImportError("Rust backend not available. " "Build with: maturin develop --release") X_disc, n_bins_per_gene = discretize_genes(X, n_bins=n_bins) corr_signs = None if signed: if issparse(X): X_dense = np.asarray(X.todense()) else: X_dense = X corr_signs = np.nan_to_num(np.corrcoef(X_dense.T), nan=0.0).astype(np.float32) triples = _rust_mi_pairs(X_disc, n_bins_per_gene, corr_signs) mi_scores = collections.defaultdict(lambda: collections.defaultdict(float)) for i, j, mi in triples: gene_a = gene_names[i] gene_b = gene_names[j] mi_scores[gene_a][gene_b] = round(mi, 5) mi_scores[gene_b][gene_a] = round(mi, 5) return mi_scores
# ─── Target function registry ────────────────────────────────── TARGETS = {}
[docs] def register_target(name): """Decorator to register a target function.""" def wrapper(fn): TARGETS[name] = fn return fn return wrapper
[docs] def get_target_function(name): """Look up a registered target function by name. Parameters ---------- name : str Name of the registered target. Returns ------- callable The target function. Raises ------ ValueError If name is not registered. """ if name not in TARGETS: available = ", ".join(sorted(TARGETS.keys())) raise ValueError(f"Unknown target '{name}'. Available: {available}") return TARGETS[name]
# ─── Built-in targets ──────────────────────────────────────────
[docs] @register_target("mi") def target_mi(X, gene_names, signed=False, backend="auto", device="cpu", n_bins=10, **kwargs): """Mutual information (optionally signed by Pearson correlation).""" if backend == "auto": if device == "cuda": return compute_mi_gpu(X, gene_names, n_bins=n_bins, signed=signed, device=device) elif HAS_RUST: return compute_mi_rust(X, gene_names, n_bins=n_bins, signed=signed) elif HAS_NUMBA: return compute_mi_numba(X, gene_names, n_bins=n_bins, signed=signed) else: return compute_mi_vectorized(X, gene_names, n_bins=n_bins, signed=signed) elif backend == "rust": return compute_mi_rust(X, gene_names, n_bins=n_bins, signed=signed) elif backend == "numpy": return compute_mi_vectorized(X, gene_names, n_bins=n_bins, signed=signed) elif backend == "numba": return compute_mi_numba(X, gene_names, n_bins=n_bins, signed=signed) elif backend == "gpu": return compute_mi_gpu(X, gene_names, n_bins=n_bins, signed=signed, device=device) else: raise ValueError(f"Unknown MI backend: {backend}")
[docs] @register_target("pearson") def target_pearson(X, gene_names, **kwargs): """Pearson correlation between all gene pairs.""" if issparse(X): X = np.asarray(X.todense()) corr = np.corrcoef(X.T) return _matrix_to_score_dict(corr, gene_names)
[docs] @register_target("spearman") def target_spearman(X, gene_names, **kwargs): """Spearman rank correlation between all gene pairs.""" if issparse(X): X = np.asarray(X.todense()) corr, _ = spearmanr(X) if corr.ndim == 0: corr = np.array([[1.0, corr], [corr, 1.0]]) return _matrix_to_score_dict(corr, gene_names)
[docs] @register_target("jaccard") def target_jaccard(X, gene_names, **kwargs): """Jaccard index on binarized expression (gene detected / not detected).""" if issparse(X): binary = (X > 0).astype(np.float32) binary_dense = np.asarray(binary.todense()) else: binary_dense = (X > 0).astype(np.float32) intersection = binary_dense.T @ binary_dense sums = binary_dense.sum(axis=0) union = sums[:, None] + sums[None, :] - intersection union[union == 0] = 1 jaccard = intersection / union np.fill_diagonal(jaccard, 0) return _matrix_to_score_dict(np.array(jaccard), gene_names)
[docs] @register_target("cosine") def target_cosine(X, gene_names, **kwargs): """Cosine similarity between gene expression vectors (each gene is a vector across cells).""" from sklearn.metrics.pairwise import cosine_similarity as cos_sim if issparse(X): sim = cos_sim(X.T) else: sim = cos_sim(X.T) np.fill_diagonal(sim, 0) return _matrix_to_score_dict(sim, gene_names)
def _matrix_to_score_dict(matrix, gene_names): """Convert a symmetric score matrix to nested dict.""" scores = collections.defaultdict(lambda: collections.defaultdict(float)) n = len(gene_names) for i in range(n): for j in range(n): if i != j: scores[gene_names[i]][gene_names[j]] = round(float(matrix[i, j]), 5) return scores # ─── Import graph targets to trigger registration ────────────── from . import _graph_targets # noqa: F401, E402