Skip to content

Commit 769a6a0

Browse files
authored
Merge pull request #9 from mbruhns/mps
Integrating MPS backend
2 parents 436e53c + cb3065b commit 769a6a0

File tree

7 files changed

+448
-178
lines changed

7 files changed

+448
-178
lines changed

.gitignore

Lines changed: 148 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,148 @@
1-
__pycache__
2-
*.pdf
3-
build
4-
dist
5-
*.egg-info
6-
.eggs
1+
# Byte-compiled / optimized / DLL files
2+
__pycache__/
3+
*.py[cod]
4+
*$py.class
5+
6+
# C extensions
7+
*.so
8+
9+
# Distribution / packaging
10+
.Python
11+
build/
12+
develop-eggs/
13+
dist/
14+
downloads/
15+
eggs/
16+
.eggs/
17+
lib/
18+
lib64/
19+
parts/
20+
sdist/
21+
var/
22+
wheels/
23+
pip-wheel-metadata/
24+
share/python-wheels/
25+
*.egg-info/
26+
.installed.cfg
27+
*.egg
28+
MANIFEST
29+
30+
# PyInstaller
31+
# Usually these files are written by a python script from a template
32+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
33+
*.manifest
34+
*.spec
35+
36+
# Installer logs
37+
pip-log.txt
38+
pip-delete-this-directory.txt
39+
40+
# Unit test / coverage reports
41+
htmlcov/
42+
.tox/
43+
.nox/
44+
.coverage
45+
.coverage.*
46+
.cache
47+
nosetests.xml
48+
coverage.xml
49+
*.cover
50+
*.py,cover
51+
.hypothesis/
52+
.pytest_cache/
53+
54+
# Translations
55+
*.mo
56+
*.pot
57+
58+
# Django stuff:
59+
*.log
60+
local_settings.py
61+
db.sqlite3
62+
db.sqlite3-journal
63+
64+
# Flask stuff:
65+
instance/
66+
.webassets-cache
67+
68+
# Scrapy stuff:
69+
.scrapy
70+
71+
# Sphinx documentation
72+
docs/_build/
73+
74+
# PyBuilder
75+
target/
76+
77+
# Jupyter Notebook
78+
.ipynb_checkpoints
79+
80+
# IPython
81+
profile_default/
82+
ipython_config.py
83+
84+
# pyenv
85+
.python-version
86+
87+
# pipenv
88+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
90+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
91+
# install all needed dependencies.
92+
#Pipfile.lock
93+
94+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
95+
__pypackages__/
96+
97+
# Celery stuff
98+
celerybeat-schedule
99+
celerybeat.pid
100+
101+
# SageMath parsed files
102+
*.sage.py
103+
104+
# Environments
105+
.env
106+
.venv
107+
env/
108+
venv/
109+
ENV/
110+
env.bak/
111+
venv.bak/
112+
113+
# Spyder project settings
114+
.spyderproject
115+
.spyproject
116+
117+
# Rope project settings
118+
.ropeproject
119+
120+
# mkdocs documentation
121+
/site
122+
123+
# mypy
124+
.mypy_cache/
125+
.dmypy.json
126+
dmypy.json
127+
128+
# Pyre type checker
129+
.pyre/
130+
131+
# Sublime workspace
132+
*.sublime-workspace
133+
.DS_Store
134+
135+
# Custom folders
136+
results/
137+
figures/
138+
139+
*.sublime-workspace
140+
*.sublime-project
141+
142+
# Jupyter notebooks
143+
*.ipynb
144+
145+
.idea/
146+
147+
*.h5ad
148+

harmony/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from importlib_metadata import version, PackageNotFoundError
77

88
try:
9-
__version__ = version('harmony-pytorch')
9+
__version__ = version("harmony-pytorch")
1010
del version
1111
except PackageNotFoundError:
1212
pass

harmony/harmony.py

Lines changed: 35 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from .utils import one_hot_tensor, get_batch_codes
1111

1212

13-
1413
def harmonize(
1514
X: np.array,
1615
batch_mat: pd.DataFrame,
@@ -105,13 +104,16 @@ def harmonize(
105104
>>> X_harmony = harmonize(adata.obsm['X_pca'], adata.obs, ['Channel', 'Lab'])
106105
"""
107106

108-
assert(isinstance(X, np.ndarray))
107+
assert isinstance(X, np.ndarray)
109108

110109
if n_jobs < 0:
111110
import psutil
112-
n_jobs = psutil.cpu_count(logical=False) # get physical cores
111+
112+
n_jobs = psutil.cpu_count(logical=False) # get physical cores
113113
if n_jobs is None:
114-
n_jobs = psutil.cpu_count(logical=True) # if undetermined, use logical cores instead
114+
n_jobs = psutil.cpu_count(
115+
logical=True
116+
) # if undetermined, use logical cores instead
115117
torch.set_num_threads(n_jobs)
116118

117119
device_type = "cpu"
@@ -120,9 +122,14 @@ def harmonize(
120122
device_type = "cuda"
121123
if verbose:
122124
print("Use GPU mode.")
123-
else:
125+
elif torch.backends.mps.is_available():
126+
device_type = "mps"
124127
if verbose:
125-
print("CUDA is not available on your machine. Use CPU mode instead.")
128+
print("Use Metal (MPS) mode.")
129+
elif verbose:
130+
print(
131+
"Neither CUDA nor MPS is available on your machine. Use CPU mode instead."
132+
)
126133

127134
(stride_0, stride_1) = X.strides
128135
if stride_0 < 0 or stride_1 < 0:
@@ -156,7 +163,7 @@ def harmonize(
156163
theta = theta.view(1, -1)
157164

158165
assert block_proportion > 0 and block_proportion <= 1
159-
assert correction_method in ["fast", "original"]
166+
assert correction_method in {"fast", "original"}
160167

161168
np.random.seed(random_state)
162169

@@ -206,13 +213,10 @@ def harmonize(
206213

207214
if is_convergent_harmony(objectives_harmony, tol=tol_harmony):
208215
if verbose:
209-
print("Reach convergence after {} iteration(s).".format(i + 1))
216+
print(f"Reach convergence after {i + 1} iteration(s).")
210217
break
211218

212-
if device_type == "cpu":
213-
return Z_hat.numpy()
214-
else:
215-
return Z_hat.cpu().numpy()
219+
return Z_hat.numpy() if device_type == "cpu" else Z_hat.cpu().numpy()
216220

217221

218222
def initialize_centroids(
@@ -229,17 +233,19 @@ def initialize_centroids(
229233
):
230234
n_cells = Z_norm.shape[0]
231235

232-
kmeans_params = {'n_clusters': n_clusters,
233-
'init': "k-means++",
234-
'n_init': n_init,
235-
'random_state': random_state,
236-
'max_iter': 25,
237-
}
236+
kmeans_params = {
237+
"n_clusters": n_clusters,
238+
"init": "k-means++",
239+
"n_init": n_init,
240+
"random_state": random_state,
241+
"max_iter": 25,
242+
}
238243

239244
kmeans = KMeans(**kmeans_params)
240245

241246
from threadpoolctl import threadpool_limits
242-
with threadpool_limits(limits = n_jobs):
247+
248+
with threadpool_limits(limits=n_jobs):
243249
if device_type == "cpu":
244250
kmeans.fit(Z_norm)
245251
else:
@@ -249,9 +255,7 @@ def initialize_centroids(
249255
Y_norm = normalize(Y, p=2, dim=1)
250256

251257
# Initialize R
252-
R = torch.exp(
253-
-2 / sigma * (1 - torch.matmul(Z_norm, Y_norm.t()))
254-
)
258+
R = torch.exp(-2 / sigma * (1 - torch.matmul(Z_norm, Y_norm.t())))
255259
R = normalize(R, p=1, dim=1)
256260

257261
E = torch.matmul(Pr_b, torch.sum(R, dim=0, keepdim=True))
@@ -282,12 +286,11 @@ def clustering(
282286
device_type,
283287
n_init=10,
284288
):
285-
286289
n_cells = Z_norm.shape[0]
287290

288291
objectives_clustering = []
289292

290-
for i in range(max_iter):
293+
for _ in range(max_iter):
291294
# Compute Cluster Centroids
292295
Y = torch.matmul(R.t(), Z_norm)
293296
Y_norm = normalize(Y, p=2, dim=1)
@@ -298,12 +301,8 @@ def clustering(
298301
pos = 0
299302
while pos < len(idx_list):
300303
idx_in = idx_list[pos : (pos + block_size)]
301-
R_in = R[
302-
idx_in,
303-
]
304-
Phi_in = Phi[
305-
idx_in,
306-
]
304+
R_in = R[idx_in,]
305+
Phi_in = Phi[idx_in,]
307306

308307
# Compute O and E on left-out data.
309308
O -= torch.matmul(Phi_in.t(), R_in)
@@ -347,14 +346,12 @@ def correction_original(X, R, Phi, ridge_lambda, device_type):
347346
Phi_1 = torch.cat((torch.ones(n_cells, 1, device=device_type), Phi), dim=1)
348347

349348
Z = X.clone()
350-
id_mat = torch.eye(n_batches + 1, n_batches + 1, device = device_type)
349+
id_mat = torch.eye(n_batches + 1, n_batches + 1, device=device_type)
351350
id_mat[0, 0] = 0
352351
Lambda = ridge_lambda * id_mat
353352
for k in range(n_clusters):
354353
Phi_t_diag_R = Phi_1.t() * R[:, k].view(1, -1)
355-
inv_mat = torch.inverse(
356-
torch.matmul(Phi_t_diag_R, Phi_1) + Lambda
357-
)
354+
inv_mat = torch.inverse(torch.matmul(Phi_t_diag_R, Phi_1) + Lambda)
358355
W = torch.matmul(inv_mat, torch.matmul(Phi_t_diag_R, X))
359356
W[0, :] = 0
360357
Z -= torch.matmul(Phi_t_diag_R.t(), W)
@@ -375,7 +372,7 @@ def correction_fast(X, R, Phi, O, ridge_lambda, device_type):
375372
N_k = torch.sum(O_k)
376373

377374
factor = 1 / (O_k + ridge_lambda)
378-
c = N_k + torch.sum(-factor * O_k ** 2)
375+
c = N_k + torch.sum(-factor * O_k**2)
379376
c_inv = 1 / c
380377

381378
P[0, 1:] = -factor * O_k
@@ -401,7 +398,9 @@ def compute_objective(
401398
Y_norm, Z_norm, R, theta, sigma, O, E, objective_arr, device_type
402399
):
403400
kmeans_error = torch.sum(R * 2 * (1 - torch.matmul(Z_norm, Y_norm.t())))
404-
entropy_term = sigma * torch.sum(-torch.distributions.Categorical(probs=R).entropy())
401+
entropy_term = sigma * torch.sum(
402+
-torch.distributions.Categorical(probs=R).entropy()
403+
)
405404
diversity_penalty = sigma * torch.sum(
406405
torch.matmul(theta, O * torch.log(torch.div(O + 1, E + 1)))
407406
)

harmony/utils.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,28 @@
22

33

44
def get_batch_codes(batch_mat, batch_key):
5-
if type(batch_key) is str or len(batch_key) == 1:
6-
if not type(batch_key) is str:
7-
batch_key = batch_key[0]
5+
if type(batch_key) is str:
6+
batch_vec = batch_mat[batch_key]
7+
8+
elif len(batch_key) == 1:
9+
batch_key = batch_key[0]
810

911
batch_vec = batch_mat[batch_key]
1012

1113
else:
12-
df = batch_mat[batch_key].astype('str')
13-
batch_vec = df.apply(lambda row: ','.join(row), axis = 1)
14+
df = batch_mat[batch_key].astype("str")
15+
batch_vec = df.apply(lambda row: ",".join(row), axis=1)
1416

1517
return batch_vec.astype("category")
1618

1719

1820
def one_hot_tensor(X, device_type):
19-
ids = torch.as_tensor(X.cat.codes.values.copy(), dtype = torch.long, device = device_type).view(-1, 1)
21+
ids = torch.as_tensor(
22+
X.cat.codes.values.copy(), dtype=torch.long, device=device_type
23+
).view(-1, 1)
2024
n_row = X.size
2125
n_col = X.cat.categories.size
22-
Phi = torch.zeros(n_row, n_col, dtype=torch.float, device = device_type)
26+
Phi = torch.zeros(n_row, n_col, dtype=torch.float, device=device_type)
2327
Phi.scatter_(dim=1, index=ids, value=1.0)
2428

2529
return Phi

setup.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,12 @@
77
long_description = f.read()
88

99
requires = [
10-
"torch",
10+
"torch>=1.12",
1111
"numpy",
1212
"pandas",
1313
"psutil",
1414
"threadpoolctl",
15-
"scikit-learn>=0.23",
16-
"importlib_metadata>=0.7; python_version < '3.8'",
15+
"scikit-learn>=0.23"
1716
]
1817

1918
setup(

0 commit comments

Comments
 (0)