Skip to content

Commit 210fa4e

Browse files
Junbum KimJunbum Kim
Junbum Kim
authored and
Junbum Kim
committed
domain colocalization code tidied up
1 parent 1a744f7 commit 210fa4e

File tree

2 files changed

+51
-317
lines changed

2 files changed

+51
-317
lines changed

documentation/IMC Lung Infection.ipynb

Lines changed: 4 additions & 317 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,11 @@
4545
"import seaborn as sns\n",
4646
"from utag import utag\n",
4747
"\n",
48-
"from utag.utils import celltype_connectivity, domain_connectivity\n",
48+
"from utag.utils import celltype_connectivity\n",
4949
"from utag.visualize import (\n",
5050
" add_scale_box_to_fig,\n",
5151
" add_spatial_image,\n",
5252
" adj2chord,\n",
53-
" draw_network,\n",
5453
")\n",
5554
"\n",
5655
"sc.settings.set_figure_params(dpi=200, dpi_save=300, fontsize=6)"
@@ -2715,72 +2714,6 @@
27152714
"adata"
27162715
]
27172716
},
2718-
{
2719-
"cell_type": "code",
2720-
"execution_count": 59,
2721-
"id": "2b8e1ba2-df46-466d-b1c7-8c43b4b72722",
2722-
"metadata": {},
2723-
"outputs": [],
2724-
"source": [
2725-
"import matplotlib.pyplot as plt\n",
2726-
"import networkx as nx\n",
2727-
"import numpy as np\n",
2728-
"import pandas as pd\n",
2729-
"import squidpy as sq\n",
2730-
"from anndata import AnnData\n",
2731-
"from tqdm import tqdm\n",
2732-
"\n",
2733-
"FIG_KWS = dict(bbox_inches=\"tight\", dpi=300)\n",
2734-
"MAX_BETWEEN_CELL_DIST = 4\n",
2735-
"\n",
2736-
"\n",
2737-
"def measure_cell_type_adjacency(\n",
2738-
" adata: AnnData,\n",
2739-
" utag_key: str = \"UTAG Label\",\n",
2740-
" max_dist: int = 40,\n",
2741-
" n_iterations: int = 100,\n",
2742-
"):\n",
2743-
" a_ = adata.copy()\n",
2744-
" sq.gr.spatial_neighbors(a_, radius=max_dist, coord_type=\"generic\")\n",
2745-
"\n",
2746-
" G = nx.from_scipy_sparse_matrix(a_.obsp[\"spatial_connectivities\"])\n",
2747-
"\n",
2748-
" utag_map = {i: x for i, x in enumerate(adata.obs[utag_key])}\n",
2749-
" nx.set_node_attributes(G, utag_map, name=utag_key)\n",
2750-
"\n",
2751-
" adj, order = nx.linalg.attrmatrix.attr_matrix(G, node_attr=utag_key)\n",
2752-
" order = pd.Series(order).astype(adata.obs[utag_key].dtype)\n",
2753-
" freqs = pd.DataFrame(adj, order, order).fillna(0) + 1\n",
2754-
"\n",
2755-
" norm_freqs = correct_interaction_background_random(G, freqs, utag_key, n_iterations)\n",
2756-
" return norm_freqs\n",
2757-
"\n",
2758-
"\n",
2759-
"def correct_interaction_background_random(\n",
2760-
" graph: nx.Graph, freqs: pd.DataFrame, attribute: str, n_iterations: int = 100\n",
2761-
"):\n",
2762-
" values = {x: graph.nodes[x][attribute] for x in graph.nodes}\n",
2763-
" shuffled_freqs = list()\n",
2764-
" for _ in range(n_iterations):\n",
2765-
" g2 = graph.copy()\n",
2766-
" shuffled_attr = pd.Series(values).sample(frac=1)\n",
2767-
" shuffled_attr.index = values\n",
2768-
" nx.set_node_attributes(g2, shuffled_attr.to_dict(), name=attribute)\n",
2769-
" rf, rl = nx.linalg.attrmatrix.attr_matrix(g2, node_attr=attribute)\n",
2770-
" rl = pd.Series(rl, dtype=freqs.index.dtype)\n",
2771-
" shuffled_freqs.append(pd.DataFrame(rf, index=rl, columns=rl))#.fillna(0) + 1)\n",
2772-
" shuffled_freq = pd.concat(shuffled_freqs)\n",
2773-
" shuffled_freq = shuffled_freq.groupby(level=0).sum()\n",
2774-
" shuffled_freq = shuffled_freq.fillna(0) + 1\n",
2775-
" \n",
2776-
" fl = np.log((freqs / freqs.values.sum()))\n",
2777-
" sl = np.log((shuffled_freq / shuffled_freq.values.sum()))\n",
2778-
" # make sure both contain all edges/nodes\n",
2779-
" fl = fl.reindex(sl.index, axis=0).reindex(sl.index, axis=1)\n",
2780-
" sl = sl.reindex(fl.index, axis=0).reindex(fl.index, axis=1)\n",
2781-
" return fl - sl"
2782-
]
2783-
},
27842717
{
27852718
"cell_type": "code",
27862719
"execution_count": 107,
@@ -2796,6 +2729,8 @@
27962729
}
27972730
],
27982731
"source": [
2732+
"from utag.utils import measure_per_domain_cell_type_colocalization\n",
2733+
"\n",
27992734
"image_key = \"roi\"\n",
28002735
"utag_key = \"UTAG Label\"\n",
28012736
"\n",
@@ -2834,107 +2769,6 @@
28342769
"adata.obs[\"phenotypes\"].cat.categories"
28352770
]
28362771
},
2837-
{
2838-
"cell_type": "code",
2839-
"execution_count": 341,
2840-
"id": "d55fb487-8ab9-45e7-9bb3-306404b3dc9e",
2841-
"metadata": {},
2842-
"outputs": [],
2843-
"source": [
2844-
"def draw_network(\n",
2845-
" adata: AnnData,\n",
2846-
" node_key: str = \"UTAG Label\",\n",
2847-
" adjacency_matrix_key: str = \"UTAG Label_domain_adjacency_matrix\",\n",
2848-
" figsize: tuple = (11, 11),\n",
2849-
" dpi: int = 200,\n",
2850-
" font_size: int = 12,\n",
2851-
" node_size_min: int = 1000,\n",
2852-
" node_size_max: int = 3000,\n",
2853-
" edge_weight: float = 5,\n",
2854-
" edge_weight_baseline: float = 1,\n",
2855-
" log_transform: bool = True,\n",
2856-
" ax=None,\n",
2857-
"):\n",
2858-
" import networkx as nx\n",
2859-
"\n",
2860-
" s1 = adata.obs.groupby(node_key).count()\n",
2861-
" s1 = s1[s1.columns[0]]\n",
2862-
" node_size = s1.values\n",
2863-
" node_size = (node_size - node_size.min()) / (node_size.max() - node_size.min()) * (\n",
2864-
" node_size_max - node_size_min\n",
2865-
" ) + node_size_min\n",
2866-
"\n",
2867-
" if ax == None:\n",
2868-
" fig = plt.figure(figsize=figsize, dpi=dpi)\n",
2869-
" G = nx.from_numpy_matrix(\n",
2870-
" np.matrix(adata.uns[adjacency_matrix_key]), create_using=nx.Graph\n",
2871-
" )\n",
2872-
" G = nx.relabel.relabel_nodes(\n",
2873-
" G, {i: label for i, label in enumerate(adata.uns[adjacency_matrix_key].index)}\n",
2874-
" )\n",
2875-
"\n",
2876-
" edges, weights = zip(*nx.get_edge_attributes(G, \"weight\").items())\n",
2877-
"\n",
2878-
" weights = np.array(list(weights))\n",
2879-
" weights = (weights - weights.min()) / (\n",
2880-
" weights.max() - weights.min()\n",
2881-
" ) * edge_weight + edge_weight_baseline\n",
2882-
"\n",
2883-
" if log_transform:\n",
2884-
" weights = np.log(np.array(list(weights)) + 1)\n",
2885-
" else:\n",
2886-
" weights = np.array(list(weights))\n",
2887-
"\n",
2888-
" weights = tuple(weights.tolist())\n",
2889-
"\n",
2890-
" #pos = nx.spectral_layout(G, weight = 'weight')\n",
2891-
" pos = nx.spring_layout(G, weight=\"weight\", seed=42, k=1)\n",
2892-
"\n",
2893-
" if ax:\n",
2894-
" nx.draw(\n",
2895-
" G,\n",
2896-
" pos,\n",
2897-
" node_color=\"w\",\n",
2898-
" edgelist=edges,\n",
2899-
" edge_color=weights,\n",
2900-
" width=weights,\n",
2901-
" edge_cmap=plt.cm.coolwarm,\n",
2902-
" with_labels=True,\n",
2903-
" font_size=font_size,\n",
2904-
" node_size=node_size,\n",
2905-
" ax=ax,\n",
2906-
" )\n",
2907-
" else:\n",
2908-
" nx.draw(\n",
2909-
" G,\n",
2910-
" pos,\n",
2911-
" node_color=\"w\",\n",
2912-
" edgelist=edges,\n",
2913-
" edge_color=weights,\n",
2914-
" width=weights,\n",
2915-
" edge_cmap=plt.cm.coolwarm,\n",
2916-
" with_labels=True,\n",
2917-
" font_size=font_size,\n",
2918-
" node_size=node_size,\n",
2919-
" )\n",
2920-
"\n",
2921-
" if ax == None:\n",
2922-
" ax = plt.gca()\n",
2923-
"\n",
2924-
" color_key = node_key + \"_colors\"\n",
2925-
" if color_key in adata.uns:\n",
2926-
" ax.collections[0].set_edgecolor(adata.uns[color_key])\n",
2927-
" ax.collections[0].set_facecolor(adata.uns[color_key])\n",
2928-
" else:\n",
2929-
" ax.collections[0].set_edgecolor(\"lightgray\")\n",
2930-
" ax.collections[0].set_linewidth(3)\n",
2931-
" ax.set_xlim([1.3 * x for x in ax.get_xlim()])\n",
2932-
" ax.set_ylim([1 * y for y in ax.get_ylim()])\n",
2933-
"\n",
2934-
" if ax == None:\n",
2935-
" return fig"
2936-
]
2937-
},
29382772
{
29392773
"cell_type": "code",
29402774
"execution_count": 154,
@@ -3016,26 +2850,6 @@
30162850
"plt.show()"
30172851
]
30182852
},
3019-
{
3020-
"cell_type": "code",
3021-
"execution_count": 122,
3022-
"id": "14a6556e-e597-459d-ac3f-5d5c432e2db2",
3023-
"metadata": {},
3024-
"outputs": [
3025-
{
3026-
"data": {
3027-
"text/plain": [
3028-
"array(['Epithelial (Airway / AT)', 'Fibroblasts (Airway Wall)',\n",
3029-
" 'Alveolar', 'Immune', 'Vessel', 'Neutrophils'], dtype=object)"
3030-
]
3031-
},
3032-
"execution_count": 122,
3033-
"metadata": {},
3034-
"output_type": "execute_result"
3035-
}
3036-
],
3037-
"source": []
3038-
},
30392853
{
30402854
"cell_type": "code",
30412855
"execution_count": 174,
@@ -3754,133 +3568,6 @@
37543568
"pval_df.sort_values('adj p-val').head(50)"
37553569
]
37563570
},
3757-
{
3758-
"cell_type": "code",
3759-
"execution_count": 67,
3760-
"id": "9aa8c11f-ef17-4332-ba16-ae4c662cb20b",
3761-
"metadata": {},
3762-
"outputs": [],
3763-
"source": [
3764-
"import networkx as nx\n",
3765-
"from anndata import AnnData\n",
3766-
"\n",
3767-
"\n",
3768-
"def domain_connectivity(\n",
3769-
" adata: AnnData,\n",
3770-
" max_dist: int = 40,\n",
3771-
" slide_key: str = \"Slide\",\n",
3772-
" domain_key: str = \"UTAG Label\",\n",
3773-
") -> AnnData:\n",
3774-
" import numpy as np\n",
3775-
" import squidpy as sq\n",
3776-
" from tqdm import tqdm\n",
3777-
"\n",
3778-
" order = pd.Series(adata.obs[domain_key].unique()).astype(\n",
3779-
" adata.obs[domain_key].dtype\n",
3780-
" )\n",
3781-
" global_pairwise_connection = (\n",
3782-
" pd.DataFrame(\n",
3783-
" np.zeros(shape=(len(order), len(order))), index=order, columns=order\n",
3784-
" )\n",
3785-
" .astype(int)\n",
3786-
" .fillna(0)\n",
3787-
" )\n",
3788-
"\n",
3789-
" for slide in tqdm(adata.obs[slide_key].unique()):\n",
3790-
" adata_batch = adata[adata.obs[slide_key] == slide].copy()\n",
3791-
"\n",
3792-
" sq.gr.spatial_neighbors(adata_batch, radius=max_dist, coord_type=\"generic\")\n",
3793-
" G = nx.from_scipy_sparse_matrix(adata_batch.obsp[\"spatial_connectivities\"])\n",
3794-
"\n",
3795-
" utag_map = {i: x for i, x in enumerate(adata.obs[domain_key])}\n",
3796-
" nx.set_node_attributes(G, utag_map, name=domain_key)\n",
3797-
"\n",
3798-
" adj, order = nx.linalg.attrmatrix.attr_matrix(G, node_attr=domain_key)\n",
3799-
" order = pd.Series(order).astype(adata.obs[domain_key].dtype)\n",
3800-
" freqs = pd.DataFrame(adj, order, order).astype(int)\n",
3801-
"\n",
3802-
" freqs = freqs.fillna(0)\n",
3803-
" global_pairwise_connection = global_pairwise_connection + freqs\n",
3804-
" adata.uns[f\"{domain_key}_domain_adjacency_matrix\"] = global_pairwise_connection\n",
3805-
" return adata"
3806-
]
3807-
},
3808-
{
3809-
"cell_type": "code",
3810-
"execution_count": 103,
3811-
"id": "c46298e1-07e3-457a-8428-4457ff449498",
3812-
"metadata": {},
3813-
"outputs": [
3814-
{
3815-
"name": "stderr",
3816-
"output_type": "stream",
3817-
"text": [
3818-
"100%|██████████| 31/31 [00:13<00:00, 2.27it/s]\n",
3819-
"100%|██████████| 45/45 [00:13<00:00, 3.32it/s]\n",
3820-
"100%|██████████| 75/75 [00:31<00:00, 2.39it/s]\n",
3821-
"100%|██████████| 16/16 [00:05<00:00, 2.94it/s]\n",
3822-
"100%|██████████| 37/37 [00:13<00:00, 2.81it/s]\n",
3823-
"100%|██████████| 33/33 [00:15<00:00, 2.18it/s]\n"
3824-
]
3825-
}
3826-
],
3827-
"source": [
3828-
"disease = dict()\n",
3829-
"for disease_ in results.obs[\"phenotypes\"].unique():\n",
3830-
" result_ = results[results.obs[\"phenotypes\"] == disease_].copy()\n",
3831-
" disease[disease_] = domain_connectivity(\n",
3832-
" result_, slide_key=\"roi\", domain_key=\"UTAG Label\"\n",
3833-
" )"
3834-
]
3835-
},
3836-
{
3837-
"cell_type": "code",
3838-
"execution_count": 65,
3839-
"id": "458de81f-5322-42d3-8abd-b5916b9bb635",
3840-
"metadata": {},
3841-
"outputs": [
3842-
{
3843-
"data": {
3844-
"text/plain": [
3845-
"'COVID19_late'"
3846-
]
3847-
},
3848-
"execution_count": 65,
3849-
"metadata": {},
3850-
"output_type": "execute_result"
3851-
}
3852-
],
3853-
"source": [
3854-
"d_"
3855-
]
3856-
},
3857-
{
3858-
"cell_type": "code",
3859-
"execution_count": null,
3860-
"id": "94bc31a4-a0f6-4fc7-a3e3-f20fcb2b6b68",
3861-
"metadata": {},
3862-
"outputs": [],
3863-
"source": [
3864-
"fig, ax = plt.subplots(1, 6, figsize=(50, 8))\n",
3865-
"for i, d_ in enumerate(\n",
3866-
" [\"Healthy\", \"Flu\", \"ARDS\", \"Pneumonia\", \"COVID19_early\", \"COVID19_late\"]\n",
3867-
"):\n",
3868-
"\n",
3869-
" fig = draw_network(\n",
3870-
" adata=disease[d_],\n",
3871-
" font_size=30,\n",
3872-
" edge_weight=2,\n",
3873-
" edge_weight_baseline=1,\n",
3874-
" dpi=50,\n",
3875-
" node_size_max=5000,\n",
3876-
" node_size_min=500,\n",
3877-
" ax=ax[i],\n",
3878-
" )\n",
3879-
" ax[i].set_title(d_, fontsize=50)\n",
3880-
"plt.tight_layout()\n",
3881-
"# plt.savefig(\"figures/infected_lung_Domain_network.pdf\")"
3882-
]
3883-
},
38843571
{
38853572
"cell_type": "markdown",
38863573
"id": "c2457178-d8f8-4663-8a3c-5766994d0931",
@@ -4870,7 +4557,7 @@
48704557
"name": "python",
48714558
"nbconvert_exporter": "python",
48724559
"pygments_lexer": "ipython3",
4873-
"version": "3.10.0"
4560+
"version": "3.7.11"
48744561
},
48754562
"toc-autonumbering": true
48764563
},

0 commit comments

Comments
 (0)