
Commit

change to SpatialData
SarahOuologuem committed Feb 4, 2025
1 parent 7e8bcf1 commit 463ba6f
Showing 7 changed files with 216 additions and 112 deletions.
113 changes: 78 additions & 35 deletions panpipes/python_scripts/collate_mdata.py
@@ -34,8 +34,19 @@

L.info("Running with params: %s", args)

L.info("Reading in MuData from '%s'" % args.input_mudata)
mdata = mu.read(args.input_mudata)
#L.info("Reading in MuData from '%s'" % args.input_mudata)
#mdata = mu.read(args.input_mudata)
L.info("Reading in data from '%s'" % args.input_mudata)
if ".zarr" in args.input_mudata:
import spatialdata as sd
L.info("Reading in SpatialData from '%s'" % args.input_mudata)
mdata = sd.read_zarr(args.input_mudata)
else:
L.info("Reading in MuData from '%s'" % args.input_mudata)
mdata = mu.read(args.input_mudata)
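
The same extension-based dispatch is used throughout the changed scripts. As a standalone sketch (the input path is hypothetical; spatialdata and muon are assumed to be installed):

    import muon as mu

    def read_input(path):
        """Read a SpatialData store from .zarr, otherwise a MuData file."""
        if ".zarr" in path:
            # imported lazily so non-spatial runs do not need spatialdata
            import spatialdata as sd
            return sd.read_zarr(path)
        return mu.read(path)

    # hypothetical path; any .zarr store or .h5mu file written by the pipeline would do
    obj = read_input("qc.data/sample_spatial.zarr")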




L.info("Reading in cluster information")
cf = pd.read_csv(args.clusters_files_csv)
@@ -55,46 +66,78 @@

# add in the clusters

if isinstance(mdata, MuData):
L.info("Adding cluster information to MuData")
for i in range(cf.shape[0]):
cf_df = pd.read_csv(cf['fpath'][i], sep='\t', index_col=0)
cf_df['clusters'] = cf_df['clusters'].astype('str').astype('category')
cf_df = cf_df.rename(columns={"clusters":cf['new_key'][i]})

if cf['mod'][i] != "multimodal":
mdata[cf['mod'][i]].obs = mdata[cf['mod'][i]].obs.merge(cf_df, left_index=True, right_index=True)
else:
mdata.obs = mdata.obs.merge(cf_df, left_index=True, right_index=True)
elif isinstance(mdata, sd.SpatialData):
L.info("Adding cluster information to SpatialData")
for i in range(cf.shape[0]):
cf_df = pd.read_csv(cf['fpath'][i], sep='\t', index_col=0)
cf_df['clusters'] = cf_df['clusters'].astype('str').astype('category')
cf_df = cf_df.rename(columns={"clusters":cf['new_key'][i]})
mdata["table"].obs = mdata["table"].obs.merge(cf_df, left_index=True, right_index=True)

L.info("Adding cluster information to MuData")
for i in range(cf.shape[0]):
cf_df = pd.read_csv(cf['fpath'][i], sep='\t', index_col=0)
cf_df['clusters'] = cf_df['clusters'].astype('str').astype('category')
cf_df = cf_df.rename(columns={"clusters":cf['new_key'][i]})

if cf['mod'][i] != "multimodal":
mdata[cf['mod'][i]].obs = mdata[cf['mod'][i]].obs.merge(cf_df, left_index=True, right_index=True)
else:
mdata.obs = mdata.obs.merge(cf_df, left_index=True, right_index=True)
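
Both branches perform the same join: each clustering result is read as a one-column, tab-separated table indexed by cell, cast to a categorical, renamed to the configured key, and merged into the relevant .obs. A compact sketch of a single merge (column names mirror the csv layout used above):

    import pandas as pd

    def add_clusters(obs: pd.DataFrame, fpath: str, new_key: str) -> pd.DataFrame:
        """Merge one clustering assignment into an obs table on the cell index."""
        cf_df = pd.read_csv(fpath, sep="\t", index_col=0)
        cf_df["clusters"] = cf_df["clusters"].astype("str").astype("category")
        cf_df = cf_df.rename(columns={"clusters": new_key})
        return obs.merge(cf_df, left_index=True, right_index=True)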


L.info("Adding UMAP coordinates to MuData")
uf = pd.read_csv(args.umap_files_csv)

for i in range(uf.shape[0]):
uf_df = pd.read_csv(uf['fpath'][i], sep='\t', index_col=0)
mod = uf['mod'][i]
new_key = uf['new_key'][i]
if uf['mod'][i] != "multimodal":
if all(mdata[mod].obs_names == uf_df.index):
mdata[mod].obsm[new_key] = uf_df.to_numpy()
if isinstance(mdata, MuData):
for i in range(uf.shape[0]):
uf_df = pd.read_csv(uf['fpath'][i], sep='\t', index_col=0)
mod = uf['mod'][i]
new_key = uf['new_key'][i]
if uf['mod'][i] != "multimodal":
if all(mdata[mod].obs_names == uf_df.index):
mdata[mod].obsm[new_key] = uf_df.to_numpy()
else:
L.warn("Cannot integrate %s into mdata as obs_names mismatch" % uf.iloc[i,:] )
else:
L.warn("Cannot integrate %s into mdata as obs_names mismatch" % uf.iloc[i,:] )
else:
# check the observations are the same
if set(mdata.obs_names).difference(uf_df.index) == set():
# put the observations in the same order
uf_df = uf_df.loc[mdata.obs_names,:]
mdata.obsm[new_key] = uf_df.to_numpy()
# check the observations are the same
if set(mdata.obs_names).difference(uf_df.index) == set():
# put the observations in the same order
uf_df = uf_df.loc[mdata.obs_names,:]
mdata.obsm[new_key] = uf_df.to_numpy()
else:
L.warning("Cannot integrate %s into mdata as obs_names mismatch" % uf.iloc[i,:] )
elif isinstance(mdata, sd.SpatialData):
for i in range(uf.shape[0]):
uf_df = pd.read_csv(uf['fpath'][i], sep='\t', index_col=0)
mod = uf['mod'][i]
new_key = uf['new_key'][i]
if uf['mod'][i] != "multimodal":
if all(mdata["table"].obs_names == uf_df.index):
mdata["table"].obsm[new_key] = uf_df.to_numpy()
else:
L.warn("Cannot integrate %s into adata as obs_names mismatch" % uf.iloc[i,:] )
else:
L.warning("Cannot integrate %s into mdata as obs_names mismatch" % uf.iloc[i,:] )


L.info("Saving updated MuData to '%s'" % args.output_mudata)
mdata.write(args.output_mudata)

output_csv = re.sub(".h5mu", "_cell_metdata.tsv", args.output_mudata)
L.info("Saving metadata to '%s'" % output_csv)
mdata.obs.to_csv(output_csv, sep='\t')
# check the observations are the same
if set(mdata["table"].obs_names).difference(uf_df.index) == set():
# put the observations in the same order
uf_df = uf_df.loc[mdata["table"].obs_names,:]
mdata["table"].obsm[new_key] = uf_df.to_numpy()
else:
L.warning("Cannot integrate %s into adata as obs_names mismatch" % uf.iloc[i,:] )

if isinstance(mdata, MuData):
L.info("Saving updated MuData to '%s'" % args.output_mudata)
mdata.write(args.output_mudata)
output_csv = re.sub(".h5mu", "_cell_metdata.tsv", args.output_mudata)
L.info("Saving metadata to '%s'" % output_csv)
mdata.obs.to_csv(output_csv, sep='\t')
elif isinstance(mdata, sd.SpatialData):
L.info("Saving updated SpatialData to '%s'" % args.output_mudata)
mdata.write(args.output_mudata)
output_csv = re.sub(".zarr", "_cell_metdata.tsv", args.output_mudata)
L.info("Saving metadata to '%s'" % output_csv)
mdata.obs.to_csv(output_csv, sep='\t')
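
Saving follows the same split: MuData goes back to .h5mu, SpatialData back to .zarr, and in both cases a cell metadata table is exported with the container extension replaced. A sketch of that logic (the metadata suffix and the assumption that the SpatialData table is named "table" are illustrative):

    import re

    def save_with_metadata(obj, outfile: str):
        """Write the object and a tab-separated cell metadata table alongside it."""
        obj.write(outfile)
        meta_csv = re.sub(r"\.(h5mu|zarr)$", "_cell_metadata.tsv", outfile)
        # for SpatialData the per-cell annotations live on the "table" AnnData
        obs = obj["table"].obs if outfile.endswith(".zarr") else obj.obs
        obs.to_csv(meta_csv, sep="\t")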

L.info("Done")
26 changes: 18 additions & 8 deletions panpipes/python_scripts/plot_cluster_umaps.py
@@ -90,9 +90,13 @@ def plot_spatial(adata,figdir):
fig.savefig(os.path.join(figdir, ok + "_clusters.png"))



L.info("Reading in MuData from '%s'" % args.infile)
mdata = read(args.infile)
if ".zarr" in args.infile:
import spatialdata as sd
L.info("Reading in SpatialData from '%s'" % args.infile)
data = sd.read_zarr(args.infile)
else:
L.info("Reading in MuData from '%s'" % args.infile)
data = read(args.infile)

mods = args.modalities.split(',')
# determine initial figure directory based on object type
@@ -102,21 +106,27 @@ def plot_spatial(adata,figdir):
if os.path.exists("multimodal/figures") is False:
os.makedirs("multimodal/figures")
L.info("Plotting multimodal figures")
main(mdata, figdir="multimodal/figures")
main(data, figdir="multimodal/figures")


# we also need to plot per modality
if type(mdata) is MuData:
for mod in mdata.mod.keys():
if type(data) is MuData:
for mod in data.mod.keys():
if mod in mods:
L.info("Plotting for modality: %s" % mod)
figdir = os.path.join(mod, "figures")
if os.path.exists(figdir) is False:
os.makedirs(figdir)
if mod == "spatial": # added separate function for spatial
plot_spatial(mdata[mod], figdir)
plot_spatial(data[mod], figdir)
else:
main(mdata[mod], figdir)
main(data[mod], figdir)
elif isinstance(data, sd.SpatialData):
L.info("Plotting for modality: spatial")
figdir = os.path.join("spatial", "figures")
if os.path.exists(figdir) is False:
os.makedirs(figdir)
plot_spatial(data["table"], figdir)
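
Condensed, the plotting dispatch this change introduces is: a MuData object is plotted per requested modality, while a SpatialData object contributes only its "table" AnnData. A sketch assuming the main() and plot_spatial() helpers defined earlier in the script:

    import os
    from muon import MuData

    def plot_all(data, modalities, main, plot_spatial):
        """Route each modality to the matching plotting helper and figure directory."""
        if isinstance(data, MuData):
            for mod in data.mod:
                if mod not in modalities:
                    continue
                figdir = os.path.join(mod, "figures")
                os.makedirs(figdir, exist_ok=True)
                if mod == "spatial":
                    plot_spatial(data[mod], figdir)
                else:
                    main(data[mod], figdir)
        else:  # assumed to be a SpatialData object
            figdir = os.path.join("spatial", "figures")
            os.makedirs(figdir, exist_ok=True)
            plot_spatial(data["table"], figdir)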



28 changes: 17 additions & 11 deletions panpipes/python_scripts/plot_scanpy_markers.py
@@ -115,17 +115,23 @@ def do_plots(adata, mod, group_col, mf, n=10, layer=None):


# read data
L.info("Reading in MuData from '%s'" % args.infile)
mdata = mu.read(args.infile)

if type(mdata) is AnnData:
adata = mdata
# main function only does rank_gene_groups on X, so
elif type(mdata) is mu.MuData and args.modality is not None:
adata = mdata[args.modality]
else:
L.error("If the input is a MuData object, a modality needs to be specified")
sys.exit('If the input is a MuData object, a modality needs to be specified')
if args.modality != "spatial":
L.info("Reading in MuData from '%s'" % args.infile)
mdata = mu.read(args.infile)

if type(mdata) is AnnData:
adata = mdata
# main function only does rank_gene_groups on X, so
elif type(mdata) is mu.MuData and args.modality is not None:
adata = mdata[args.modality]
else:
L.error("If the input is a MuData object, a modality needs to be specified")
sys.exit('If the input is a MuData object, a modality needs to be specified')
else:
import spatialdata as sd
L.info("Reading in SpatialData from '%s'" % args.infile)
adata = sd.read_zarr(args.infile)["table"]


L.info("Loading marker information from '%s'" % args.marker_file)
mf = pd.read_csv(args.marker_file, sep='\t' )
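
The spatial branch extracts a plain AnnData from the zarr store, so downstream marker detection can stay unchanged. A minimal sketch (path and cluster column name are hypothetical):

    import scanpy as sc
    import spatialdata as sd

    sdata = sd.read_zarr("clustered.data/sample_spatial.zarr")  # hypothetical path
    adata = sdata["table"]  # AnnData holding counts and obs annotations

    # marker detection runs on the extracted AnnData as for any other modality
    sc.tl.rank_genes_groups(adata, groupby="clusters", method="wilcoxon")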
90 changes: 59 additions & 31 deletions panpipes/python_scripts/rerun_find_neighbors_for_clustering.py
@@ -4,6 +4,7 @@
import logging
import scanpy as sc
from muon import MuData, read

from panpipes.funcs.scmethods import run_neighbors_method_choice
from panpipes.funcs.io import read_yaml
from panpipes.funcs.scmethods import lsi
@@ -37,53 +38,80 @@
sc.settings.n_jobs = int(args.n_threads)

# read data
L.info("Reading in MuData from '%s'" % args.infile)
mdata = read(args.infile)
if ".zarr" in args.infile:
import spatialdata as sd
L.info("Reading in SpatialData from '%s'" % args.infile)
sdata = sd.read_zarr(args.infile)
else:
L.info("Reading in MuData from '%s'" % args.infile)
mdata = read(args.infile)



for mod in neighbor_dict.keys():
if mod in mdata.mod.keys():
if mod != "spatial":
if mod in mdata.mod.keys():
if neighbor_dict[mod]['use_existing']:
L.info('Using existing neighbors graph for %s' % mod)
pass
else:
L.info("Computing new neighbors for modality %s on %s" % (mod, neighbor_dict[mod]['dim_red']))
if type(mdata) is MuData:
adata=mdata[mod]
if (neighbor_dict[mod]['dim_red'] == "X_pca") and ("X_pca" not in adata.obsm.keys()):
L.info("X_pca not found, computing it using default parameters")
sc.tl.pca(adata)
if (mod == "atac") and (neighbor_dict[mod]['dim_remove'] is not None):
dimrem = int(neighbor_dict[mod]['dim_remove'])
adata.obsm['X_pca'] = adata.obsm['X_pca'][:, dimrem:]
adata.varm["PCs"] = adata.varm["PCs"][:, dimrem:]
if mod == "atac":
if (neighbor_dict[mod]['dim_red'] == "X_lsi") and ("X_lsi" not in adata.obsm.keys()):
L.info("X_lsi not found, computing it using default parameters")
lsi(adata=adata, num_components=50)
if neighbor_dict[mod]['dim_remove'] is not None:
L.info("Removing dimension %s from X_lsi" % neighbor_dict[mod]['dim_remove'])
dimrem = int(neighbor_dict[mod]['dim_remove'])
adata.obsm['X_lsi'] = adata.obsm['X_lsi'][:, dimrem:]
adata.varm["LSI"] = adata.varm["LSI"][:, dimrem:]
adata.uns["lsi"]["stdev"] = adata.uns["lsi"]["stdev"][dimrem:]

# run command
opts = dict(method=neighbor_dict[mod]['method'],
n_neighbors=int(neighbor_dict[mod]['k']),
n_pcs=int(neighbor_dict[mod]['n_dim_red']),
metric=neighbor_dict[mod]['metric'],
nthreads=args.n_threads,
use_rep=neighbor_dict[mod]['dim_red'])


run_neighbors_method_choice(adata,**opts)
mdata.mod[mod] = adata
mdata.update()
else:
if neighbor_dict[mod]['use_existing']:
L.info('Using existing neighbors graph for %s' % mod)
pass
else:
L.info("Computing new neighbors for modality %s on %s" % (mod, neighbor_dict[mod]['dim_red']))
if type(mdata) is MuData:
adata=mdata[mod]
if (neighbor_dict[mod]['dim_red'] == "X_pca") and ("X_pca" not in adata.obsm.keys()):
if (neighbor_dict[mod]['dim_red'] == "X_pca") and ("X_pca" not in sdata["table"].obsm.keys()):
L.info("X_pca not found, computing it using default parameters")
sc.tl.pca(adata)
if (mod == "atac") and (neighbor_dict[mod]['dim_remove'] is not None):
dimrem = int(neighbor_dict[mod]['dim_remove'])
adata.obsm['X_pca'] = adata.obsm['X_pca'][:, dimrem:]
adata.varm["PCs"] = adata.varm["PCs"][:, dimrem:]
if mod == "atac":
if (neighbor_dict[mod]['dim_red'] == "X_lsi") and ("X_lsi" not in adata.obsm.keys()):
L.info("X_lsi not found, computing it using default parameters")
lsi(adata=adata, num_components=50)
if neighbor_dict[mod]['dim_remove'] is not None:
L.info("Removing dimension %s from X_lsi" % neighbor_dict[mod]['dim_remove'])
dimrem = int(neighbor_dict[mod]['dim_remove'])
adata.obsm['X_lsi'] = adata.obsm['X_lsi'][:, dimrem:]
adata.varm["LSI"] = adata.varm["LSI"][:, dimrem:]
adata.uns["lsi"]["stdev"] = adata.uns["lsi"]["stdev"][dimrem:]

# run command
sc.tl.pca(sdata["table"])
opts = dict(method=neighbor_dict[mod]['method'],
n_neighbors=int(neighbor_dict[mod]['k']),
n_pcs=int(neighbor_dict[mod]['n_dim_red']),
metric=neighbor_dict[mod]['metric'],
nthreads=args.n_threads,
use_rep=neighbor_dict[mod]['dim_red'])
# run command
run_neighbors_method_choice(sdata["table"],**opts)


run_neighbors_method_choice(adata,**opts)
mdata.mod[mod] = adata
mdata.update()


if ".zarr" in args.infile:
L.info("Saving updated SpatialData to '%s'" % args.outfile)
sdata.write(args.outfile)
else:
L.info("Saving updated MuData to '%s'" % args.outfile)
mdata.write(args.outfile)

L.info("Saving updated MuData to '%s'" % args.outfile)
mdata.write(args.outfile)
L.info("Done")
L.info("Done")
21 changes: 14 additions & 7 deletions panpipes/python_scripts/run_clustering.py
@@ -34,13 +34,20 @@

# read data
L.info("Reading in data from '%s'" % args.infile)
mdata = mu.read(args.infile)
if type(mdata) is AnnData:
adata = mdata
elif args.modality is not None:
adata = mdata[args.modality]
else:
adata = mdata
if ".zarr" in args.infile:
import spatialdata as sd
L.info("Reading in SpatialData from '%s'" % args.infile)
sdata = sd.read_zarr(args.infile)
adata = sdata["table"]
else:
mdata = mu.read(args.infile)
if type(mdata) is AnnData:
adata = mdata
elif args.modality is not None:
adata = mdata[args.modality]
else:
adata = mdata


uns_key=args.neighbors_key
# check sc.pp.neighbors has been run
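
Once the object type has been resolved to an AnnData with a neighbour graph, clustering itself is modality-agnostic. A hedged sketch of the check-and-cluster step (the key name and resolution are illustrative, not the pipeline's configured values):

    import scanpy as sc

    def cluster(adata, neighbors_key="neighbors", resolution=1.0):
        """Run Leiden clustering, computing a default graph if the expected key is missing."""
        if neighbors_key not in adata.uns:
            sc.pp.neighbors(adata, key_added=neighbors_key)
        sc.tl.leiden(adata, resolution=resolution, neighbors_key=neighbors_key,
                     key_added=f"leiden_res{resolution}")
        return adata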