Optimize cinema-ot #713

Closed · wants to merge 1 commit
144 changes: 39 additions & 105 deletions pertpy/tools/_cinemaot.py
@@ -1,7 +1,9 @@
from __future__ import annotations

+from functools import cached_property
from typing import TYPE_CHECKING

+import anndata as ad
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
@@ -12,7 +14,6 @@
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn, sinkhorn_lr
-from scanpy.plotting import _utils
from scipy.sparse import issparse
from sklearn.decomposition import FastICA
from sklearn.linear_model import LinearRegression
@@ -30,9 +31,6 @@
class Cinemaot:
    """CINEMA-OT is a causal framework for perturbation effect analysis to identify individual treatment effects and synergy."""

-    def __init__(self):
-        pass
-
def causaleffect(
self,
adata: AnnData,
@@ -103,12 +101,8 @@ def causaleffect(
        cf = np.array(X_transformed[:, xi < thres], np.float32)
        cf1 = np.array(cf[adata.obs[pert_key] == control, :], np.float32)
        cf2 = np.array(cf[adata.obs[pert_key] != control, :], np.float32)
-        if sum(xi < thres) == 1:
-            sklearn.metrics.pairwise_distances(cf1.reshape(-1, 1), cf2.reshape(-1, 1))
-        elif sum(xi < thres) == 0:
+        if sum(xi < thres) == 0:
            raise ValueError("No confounder components identified. Please try a higher threshold.")
-        else:
-            sklearn.metrics.pairwise_distances(cf1, cf2)

        e = smoothness * sum(xi < thres)
        geom = pointcloud.PointCloud(cf1, cf2, epsilon=e, batch_size=batch_size)
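
For context, the epsilon passed to the point cloud scales with the number of retained confounder components, and the geometry then feeds a Sinkhorn solve. A minimal standalone sketch of that step, assuming ott-jax's linear Sinkhorn API and toy arrays in place of pertpy's confounder matrices:

```python
import numpy as np
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

rng = np.random.default_rng(0)
cf1 = rng.normal(size=(50, 3)).astype(np.float32)   # control cells, 3 confounder components
cf2 = rng.normal(size=(60, 3)).astype(np.float32)   # treated cells

smoothness = 1e-2
e = smoothness * cf1.shape[1]                       # epsilon grows with component count
geom = pointcloud.PointCloud(cf1, cf2, epsilon=e)
out = sinkhorn.Sinkhorn()(linear_problem.LinearProblem(geom))
ot_matrix = np.asarray(out.matrix)                  # (50, 60) entropic coupling
```
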
@@ -189,7 +183,7 @@ def causaleffect(
ot_matrix / np.sum(ot_matrix, axis=1)[:, None], adata.obsm[cf_rep][adata.obs[pert_key] == control, :]
)

-        TE = sc.AnnData(np.array(te2), obs=adata[adata.obs[pert_key] != control, :].obs.copy(), var=adata.var.copy())
+        TE = ad.AnnData(np.array(te2), obs=adata[adata.obs[pert_key] != control, :].obs.copy(), var=adata.var.copy())
TE.obsm["X_embedding"] = embedding

if return_matching:
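
The row-normalized coupling above acts as a barycentric projection: each treated cell's counterfactual coordinates are a weighted average of matched control profiles, and the result is wrapped in an AnnData built directly from anndata rather than scanpy's re-export. A hedged toy sketch of the same pattern, with random arrays standing in for pertpy's matrices:

```python
import anndata as ad
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
ot_matrix = rng.random((60, 50))                           # toy treated-by-control coupling
control_rep = rng.normal(size=(50, 2))                     # control embedding

weights = ot_matrix / np.sum(ot_matrix, axis=1)[:, None]   # each row sums to 1
embedding = weights @ control_rep                          # barycentric projection

te2 = rng.normal(size=(60, 5)).astype(np.float32)          # toy treatment-effect matrix
TE = ad.AnnData(np.array(te2), obs=pd.DataFrame(index=[f"cell{i}" for i in range(60)]))
TE.obsm["X_embedding"] = embedding
```
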
@@ -346,7 +340,7 @@ def generate_pseudobulk(
df = df.groupby(label_list).sum()
new_index = df.index.map(lambda x: "_".join(map(str, x)))
df_ = df.set_index(new_index)
-        adata_pb = sc.AnnData(df_)
+        adata_pb = ad.AnnData(df_)
adata_pb.obs = pd.DataFrame(
df.index.to_frame().values,
index=adata_pb.obs_names,
@@ -579,7 +573,7 @@ def synergy(
**kwargs,
)
ot0 = de0.obsm["ot"]
-        syn = sc.AnnData(
+        syn = ad.AnnData(
np.array(-((ot0 / np.sum(ot0, axis=1)[:, None]) @ de2.X - de1.X)), obs=de1.obs.copy(), var=de1.var.copy()
)
syn.obsm["X_embedding"] = (ot0 / np.sum(ot0, axis=1)[:, None]) @ de2.obsm["X_embedding"] - de1.obsm[
@@ -723,139 +717,79 @@ def plot_vis_matching(


class Xi:
"""
A fast implementation of cross-rank dependence metric used in CINEMA-OT.

"""

    def __init__(self, x, y):
-        self.x = x
-        self.y = y
+        self.x = np.asarray(x)
+        self.y = np.asarray(y)
+        self._sample_size = len(x)

    @property
    def sample_size(self):
-        return len(self.x)
+        return self._sample_size
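
The recurring `@property` → `@cached_property` swap below is the core of the optimization: each derived array is computed once per `Xi` instance and memoized, instead of being recomputed on every attribute access (previously, chained properties such as `correlation` re-ranked the data repeatedly). A minimal illustration with a hypothetical class:

```python
from functools import cached_property

class Demo:
    @cached_property
    def ranks(self):
        print("computing once")   # executed only on first access
        return sorted(range(5))

d = Demo()
d.ranks   # prints "computing once"
d.ranks   # served from the per-instance cache, no recomputation
```
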

-    @property
+    @cached_property
    def x_ordered_rank(self):
        # PI is the rank vector for x, with ties broken at random
        # Not mine: source (https://stackoverflow.com/a/47430384/1628971)
-        # random shuffling of the data - reason to use random.choice is that
-        # pd.sample(frac=1) uses the same randomizing algorithm
-        len_x = len(self.x)
        rng = np.random.default_rng()
-        randomized_indices = rng.choice(np.arange(len_x), len_x, replace=False)
-        randomized = [self.x[idx] for idx in randomized_indices]
+        # same as pandas rank method 'first'
+        randomized_indices = rng.permutation(self._sample_size)
+        randomized = self.x[randomized_indices]
        rankdata = ss.rankdata(randomized, method="ordinal")
-        # Reindexing based on pairs of indices before and after
-        unrandomized = [rankdata[j] for i, j in sorted(zip(randomized_indices, range(len_x), strict=False))]
-        return unrandomized
+        return rankdata[np.argsort(randomized_indices)]
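
The rewritten `x_ordered_rank` leans on the inverse-permutation identity: if values are shuffled by `idx`, indexing the shuffled result with `np.argsort(idx)` restores the original order. A small self-check of that idiom on toy data:

```python
import numpy as np
import scipy.stats as ss

rng = np.random.default_rng(0)
x = np.array([3.0, 1.0, 2.0, 2.0])

idx = rng.permutation(len(x))                    # random shuffle breaks ties at random
shuffled_ranks = ss.rankdata(x[idx], method="ordinal")
ranks = shuffled_ranks[np.argsort(idx)]          # undo the shuffle
assert sorted(ranks) == [1, 2, 3, 4]             # a valid ordinal ranking of x
```
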

-    @property
+    @cached_property
    def y_rank_max(self):
        # f[i] is number of j s.t. y[j] <= y[i], divided by n.
-        return ss.rankdata(self.y, method="max") / self.sample_size
+        return ss.rankdata(self.y, method="max") / self._sample_size

-    @property
+    @cached_property
    def g(self):
        # g[i] is number of j s.t. y[j] >= y[i], divided by n.
-        return ss.rankdata([-i for i in self.y], method="max") / self.sample_size
+        return ss.rankdata(-self.y, method="max") / self._sample_size
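
Negating `y` before ranking with `method="max"` gives, for each i, the count of j with y[j] >= y[i]; the vectorized `-self.y` is equivalent to the old list comprehension. A toy check:

```python
import numpy as np
import scipy.stats as ss

y = np.array([2.0, 1.0, 2.0, 3.0])
g = ss.rankdata(-y, method="max") / len(y)
# for y[0] == 2, the values >= 2 are {2, 2, 3}, so g[0] == 3 / 4
assert g[0] == 0.75
```
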

-    @property
+    @cached_property
    def x_ordered(self):
        # order of the x's, ties broken at random.
        return np.argsort(self.x_ordered_rank)

-    @property
+    @cached_property
    def x_rank_max_ordered(self):
-        x_ordered_result = self.x_ordered
-        y_rank_max_result = self.y_rank_max
        # Rearrange f according to ord.
-        return [y_rank_max_result[i] for i in x_ordered_result]
+        return self.y_rank_max[self.x_ordered]

-    @property
+    @cached_property
    def mean_absolute(self):
-        x1 = self.x_rank_max_ordered[0 : (self.sample_size - 1)]
-        x2 = self.x_rank_max_ordered[1 : self.sample_size]
-
-        return (
-            np.mean(
-                np.abs(
-                    [
-                        x - y
-                        for x, y in zip(
-                            x1,
-                            x2,
-                            strict=False,
-                        )
-                    ]
-                )
-            )
-            * (self.sample_size - 1)
-            / (2 * self.sample_size)
-        )
+        x1 = self.x_rank_max_ordered[:-1]
+        x2 = self.x_rank_max_ordered[1:]
+        return np.mean(np.abs(x1 - x2)) * (self._sample_size - 1) / (2 * self._sample_size)
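
Because `__init__` now coerces inputs with `np.asarray`, `x_rank_max_ordered` is a NumPy array and the adjacent-difference mean vectorizes; a quick equivalence check against the old zip-based form, using toy values:

```python
import numpy as np

f = np.array([0.25, 0.75, 0.5, 1.0])   # stand-in for x_rank_max_ordered
n = len(f)

old = np.mean(np.abs([a - b for a, b in zip(f[:-1], f[1:])])) * (n - 1) / (2 * n)
new = np.mean(np.abs(f[:-1] - f[1:])) * (n - 1) / (2 * n)
assert np.isclose(old, new)
```
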

-    @property
+    @cached_property
    def inverse_g_mean(self):
-        gvalue = self.g
-        return np.mean(gvalue * (1 - gvalue))
+        return np.mean(self.g * (1 - self.g))

-    @property
+    @cached_property
    def correlation(self):
        """xi correlation"""
        return 1 - self.mean_absolute / self.inverse_g_mean

-    @classmethod
-    def xi(cls, x, y):
-        return cls(x, y)

-    def pval_asymptotic(self, ties: bool = False):
-        """Returns p values of the correlation.
-
-        Args:
-            ties: boolean
-                If ties is true, the algorithm assumes that the data has ties
-                and employs the more elaborate theory for calculating
-                the P-value. Otherwise, it uses the simpler theory. There is
-                no harm in setting ties True, even if there are no ties.
-
-        Returns:
-            p value
-        """
-        # If there are no ties, return xi and theoretical P-value:
+    def pval_asymptotic(self, ties=False):
        if ties:
-            return 1 - ss.norm.cdf(np.sqrt(self.sample_size) * self.correlation / np.sqrt(2 / 5))
-
-        # If there are ties, and the theoretical method is to be used for calculating P-values:
-        # The following steps calculate the theoretical variance in the presence of ties:
-        sorted_ordered_x_rank = sorted(self.x_rank_max_ordered)
-
-        ind = [i + 1 for i in range(self.sample_size)]
-        ind2 = [2 * self.sample_size - 2 * ind[i - 1] + 1 for i in ind]
+            return 1 - ss.norm.cdf(np.sqrt(self._sample_size) * self.correlation / np.sqrt(0.4))

-        a = np.mean([i * j * j for i, j in zip(ind2, sorted_ordered_x_rank, strict=False)]) / self.sample_size
+        sorted_ordered_x_rank = np.sort(self.x_rank_max_ordered)
+        ind = np.arange(1, self._sample_size + 1)
+        ind2 = 2 * self._sample_size - 2 * ind + 1

-        c = np.mean([i * j for i, j in zip(ind2, sorted_ordered_x_rank, strict=False)]) / self.sample_size
+        a = np.mean(ind2 * sorted_ordered_x_rank * sorted_ordered_x_rank) / self._sample_size
+        c = np.mean(ind2 * sorted_ordered_x_rank) / self._sample_size

        cq = np.cumsum(sorted_ordered_x_rank)
+        m = (cq + (self._sample_size - ind) * sorted_ordered_x_rank) / self._sample_size

-        m = [
-            (i + (self.sample_size - j) * k) / self.sample_size
-            for i, j, k in zip(cq, ind, sorted_ordered_x_rank, strict=False)
-        ]
-
-        b = np.mean([np.square(i) for i in m])
+        b = np.mean(np.square(m))
        v = (a - 2 * b + np.square(c)) / np.square(self.inverse_g_mean)

-        return 1 - ss.norm.cdf(np.sqrt(self.sample_size) * self.correlation / np.sqrt(v))
+        return 1 - ss.norm.cdf(np.sqrt(self._sample_size) * self.correlation / np.sqrt(v))
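
Under independence, √n·ξ is asymptotically N(0, 2/5), which is why the no-ties branch divides by √0.4. A hedged usage sketch of the rewritten class on synthetic data (assumes `scipy.stats` is imported as `ss`, as in this module):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=500)
y = x**2 + 0.1 * rng.normal(size=500)    # nonlinear dependence, near-zero Pearson r

xi = Xi(x, y)
print(xi.correlation)                    # clearly positive for y = f(x) + noise
print(xi.pval_asymptotic(ties=False))    # small p-value against independence
```
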


class SinkhornKnopp:
"""
An simple implementation of Sinkhorn iteration used in the biwhitening approach.

"""
"""An simple implementation of Sinkhorn iteration used in the biwhitening approach."""

def __init__(self, max_iter: float = 1000, setr: int = 0, setc: float = 0, epsilon: float = 1e-3):
if max_iter < 0:
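
For orientation, Sinkhorn–Knopp alternately rescales the rows and columns of a nonnegative matrix until both sets of marginals match their targets. A toy standalone sketch of the doubly stochastic case (not pertpy's implementation, which also supports the `setr`/`setc` targets above):

```python
import numpy as np

def sinkhorn_knopp(A, max_iter=1000, epsilon=1e-3):
    """Rescale a positive matrix A toward doubly stochastic form."""
    r = np.ones(A.shape[0])
    for _ in range(max_iter):
        c = 1.0 / (A.T @ r)                    # push column sums toward 1
        r = 1.0 / (A @ c)                      # push row sums toward 1
        P = r[:, None] * A * c[None, :]
        if np.abs(P.sum(axis=0) - 1).max() < epsilon:
            break
    return P

P = sinkhorn_knopp(np.random.default_rng(0).random((5, 5)) + 0.1)
```
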