Run INSPIRE on STARmap PLUS sections from mouse hippocampus

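In this tutorial, we demonstrate INSPIRE's ability to perform spatial registration between 2D slices. To create a misalignment for INSPIRE to correct, we rotate one slice by 0.5 radians and translate it by (5000, 5000) before running the model.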

The STARmap PLUS slices are publicly available at https://doi.org/10.5281/zenodo.7458952.

Import packages

[1]:
import pandas as pd
import numpy as np
import scanpy as sc
import anndata as ad
import umap
import os
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.cm import get_cmap
from matplotlib.lines import Line2D

import INSPIRE

import warnings
warnings.filterwarnings("ignore")

Load data

[2]:
# Read the two STARmap PLUS replicate sections (rep1 -> adata_1, rep2 -> adata_2)
# and make gene names unique within each AnnData object.
data_dir = "/gpfs/gibbs/pi/zhao/jz874/jiazhao/reference-free_spatial-integration/backup_project_folders/reference-free_spatial-integration/reference-free_integration_3d-alignment_v2/STARmapPLUS"
adata_1 = sc.read_h5ad(os.path.join(data_dir, "adata_8m_disease_rep1.h5ad"))
adata_2 = sc.read_h5ad(os.path.join(data_dir, "adata_8m_disease_rep2.h5ad"))
for adata_rep in (adata_1, adata_2):
    adata_rep.var_names_make_unique()
[3]:
# Apply a synthetic rigid transform to slice 1 so the two slices start out
# misaligned: rotate every (x, y) row by `theta` radians with the standard 2D
# rotation matrix R (row vectors, hence `@ R.T`), then shift by (5000, 5000).
theta = 0.5
cos_t, sin_t = np.cos(theta), np.sin(theta)
R = np.array([[cos_t, -sin_t],
              [sin_t, cos_t]])
shift = np.array([[5000, 5000]])
adata_1.obsm["spatial"] = adata_1.obsm["spatial"] @ R.T + shift
[4]:
# Slice order used by every downstream INSPIRE call:
# index 0 -> replicate 2 (untransformed),
# index 1 -> replicate 1 (rotated and shifted in the previous cell).
adata_st_list = [
    adata_2,
    adata_1,
]

Data preprocessing

[5]:
# Quality control, shared highly-variable-gene selection and normalization for
# both slices. Also returns `adata_full`, a concatenation of the slices used for
# joint visualization (see the log below). Argument meanings marked
# "presumably" are inferred from names; confirm against INSPIRE.utils.preprocess.
adata_st_list, adata_full = INSPIRE.utils.preprocess(
    adata_st_list=adata_st_list,
    num_hvgs=2500,           # HVGs requested; 2382 were shared across slices in this run
    min_genes_qc=1,          # presumably the per-cell minimum number of detected genes
    min_cells_qc=1,          # presumably the per-gene minimum number of expressing cells
    limit_num_genes=False,
    spot_size=100,           # presumably the marker size of the overview plot
)
Finding highly variable genes...
shape of adata 0 before quality control:  (8202, 2766)
shape of adata 0 after quality control:  (8202, 2766)
shape of adata 1 before quality control:  (8186, 2766)
shape of adata 1 after quality control:  (8186, 2766)
Find 2382 shared highly variable genes among datasets.
Concatenate datasets as a full anndata for better visualization...
../../_images/tutorials_registration_starmapplus_INSPIRE_registration_starmapplus_rtd_updated_10_1.png
Store counts and library sizes for Poisson modeling...
Normalize data...

Build spatial graph

[6]:
# Build a radius-based spatial neighbour graph per slice and the node features
# for the LGCN encoder. One radius per slice (300 coordinate units each, as the
# log below confirms).
adata_st_list = INSPIRE.utils.build_graph_LGCN(
    adata_st_list=adata_st_list,
    k_lgcn=1,                    # presumably the number of LGCN propagation hops; confirm
    rad_cutoff_list=[300, 300],
)
Start building graphs...
Build graphs and prepare node features for LGCN networks
Radius for graph connection is 300.0000.
7.5111 neighbors per cell on average.
Node features for slice 0 : (8202, 4764)
Radius for graph connection is 300.0000.
7.6316 neighbors per cell on average.
Node features for slice 1 : (8186, 4764)

Run INSPIRE model

[7]:
# Configure the INSPIRE LGCN model for the two preprocessed slices.
model = INSPIRE.model.Model_LGCN(
    adata_st_list=adata_st_list,
    different_platforms=False,   # both sections come from STARmap PLUS
    n_spatial_factors=15,
    batch_size=1024,
    n_training_steps=10000,
)
[8]:
# Train the model on the preprocessed slices; the progress bar below runs for
# the n_training_steps=10000 steps configured when the model was built.
model.train(adata_st_list)
  0%|          | 2/10000 [00:00<31:10,  5.34it/s]
Step: 0, d_loss: 1.3993, Loss: 1232.6716, recon_loss: 940.7932, fe_loss: 81.9788, geom_loss: 80.1937, beta_loss: 207.4786, gan_loss: 0.8172
  5%|▌         | 502/10000 [00:51<16:26,  9.63it/s]
Step: 500, d_loss: 1.0141, Loss: 1158.4047, recon_loss: 879.2302, fe_loss: 72.0036, geom_loss: 71.5310, beta_loss: 204.6874, gan_loss: 1.0529
 10%|█         | 1003/10000 [01:42<15:07,  9.91it/s]
Step: 1000, d_loss: 1.0486, Loss: 1121.3470, recon_loss: 841.4995, fe_loss: 71.7031, geom_loss: 66.7903, beta_loss: 205.4967, gan_loss: 1.3119
 15%|█▌        | 1503/10000 [02:33<13:57, 10.15it/s]
Step: 1500, d_loss: 1.0207, Loss: 1071.1522, recon_loss: 792.4841, fe_loss: 71.3514, geom_loss: 60.8091, beta_loss: 204.9298, gan_loss: 1.1706
 20%|██        | 2002/10000 [03:23<13:58,  9.53it/s]
Step: 2000, d_loss: 1.0378, Loss: 1040.6970, recon_loss: 762.6006, fe_loss: 70.7537, geom_loss: 55.9099, beta_loss: 205.0133, gan_loss: 1.2113
 25%|██▌       | 2502/10000 [04:14<12:17, 10.16it/s]
Step: 2500, d_loss: 0.9979, Loss: 1031.1418, recon_loss: 753.7010, fe_loss: 70.7868, geom_loss: 52.6111, beta_loss: 204.4102, gan_loss: 1.1915
 30%|███       | 3002/10000 [05:04<12:37,  9.24it/s]
Step: 3000, d_loss: 0.9697, Loss: 1010.8026, recon_loss: 733.3485, fe_loss: 70.2755, geom_loss: 51.0077, beta_loss: 204.9810, gan_loss: 1.1774
 35%|███▌      | 3503/10000 [05:55<11:19,  9.56it/s]
Step: 3500, d_loss: 0.9404, Loss: 999.8834, recon_loss: 723.0525, fe_loss: 70.1577, geom_loss: 50.1352, beta_loss: 204.4325, gan_loss: 1.2381
 40%|████      | 4002/10000 [06:46<10:20,  9.67it/s]
Step: 4000, d_loss: 0.8478, Loss: 1006.8831, recon_loss: 729.8905, fe_loss: 70.3404, geom_loss: 48.9135, beta_loss: 204.6774, gan_loss: 0.9965
 45%|████▌     | 4502/10000 [07:37<09:57,  9.20it/s]
Step: 4500, d_loss: 0.8601, Loss: 1007.2548, recon_loss: 730.0432, fe_loss: 70.3636, geom_loss: 48.5362, beta_loss: 204.3343, gan_loss: 1.5429
 50%|█████     | 5002/10000 [08:28<08:58,  9.28it/s]
Step: 5000, d_loss: 0.8159, Loss: 1002.8859, recon_loss: 725.6901, fe_loss: 70.2587, geom_loss: 48.5158, beta_loss: 204.5747, gan_loss: 1.3920
 55%|█████▌    | 5502/10000 [09:19<07:25, 10.10it/s]
Step: 5500, d_loss: 0.8297, Loss: 991.7452, recon_loss: 715.1700, fe_loss: 70.0473, geom_loss: 47.8334, beta_loss: 204.3306, gan_loss: 1.2407
 60%|██████    | 6003/10000 [10:09<06:33, 10.16it/s]
Step: 6000, d_loss: 0.8175, Loss: 995.9709, recon_loss: 719.3817, fe_loss: 70.1203, geom_loss: 47.4311, beta_loss: 204.3611, gan_loss: 1.1592
 65%|██████▌   | 6502/10000 [11:00<05:43, 10.19it/s]
Step: 6500, d_loss: 0.8358, Loss: 984.2735, recon_loss: 707.9119, fe_loss: 69.6064, geom_loss: 48.0052, beta_loss: 204.2618, gan_loss: 1.5333
 70%|███████   | 7003/10000 [11:51<05:04,  9.86it/s]
Step: 7000, d_loss: 0.8647, Loss: 977.4116, recon_loss: 701.4462, fe_loss: 69.4813, geom_loss: 47.5673, beta_loss: 204.3251, gan_loss: 1.2077
 75%|███████▌  | 7501/10000 [12:42<04:25,  9.41it/s]
Step: 7500, d_loss: 0.8766, Loss: 984.2867, recon_loss: 708.1206, fe_loss: 69.6184, geom_loss: 47.1220, beta_loss: 204.2310, gan_loss: 1.3743
 80%|████████  | 8002/10000 [13:33<03:25,  9.74it/s]
Step: 8000, d_loss: 0.8788, Loss: 987.8925, recon_loss: 711.8409, fe_loss: 69.6863, geom_loss: 47.2998, beta_loss: 204.2821, gan_loss: 1.1371
 85%|████████▌ | 8502/10000 [14:24<02:27, 10.13it/s]
Step: 8500, d_loss: 0.8667, Loss: 987.3364, recon_loss: 710.8536, fe_loss: 69.7705, geom_loss: 47.5104, beta_loss: 204.3777, gan_loss: 1.3844
 90%|█████████ | 9003/10000 [15:15<01:40,  9.88it/s]
Step: 9000, d_loss: 0.8904, Loss: 989.6608, recon_loss: 713.5899, fe_loss: 69.7394, geom_loss: 47.3052, beta_loss: 204.2901, gan_loss: 1.0952
 95%|█████████▌| 9502/10000 [16:05<00:48, 10.21it/s]
Step: 9500, d_loss: 0.8602, Loss: 983.3535, recon_loss: 707.1262, fe_loss: 69.5808, geom_loss: 47.4341, beta_loss: 204.2249, gan_loss: 1.4728
100%|██████████| 10000/10000 [16:56<00:00,  9.84it/s]

Access cell/spot representations, the proportions of spatial factors in each cell/spot, and the gene loading matrix

[9]:
# Evaluate the trained model. Per-cell factor proportions go into adata_full.obs,
# latent representations into adata_full.obsm['latent'], and the gene loading
# matrix is returned as a DataFrame.
adata_full, basis_df = model.eval(adata_st_list, adata_full)
basis = basis_df.to_numpy(copy=True)  # independent ndarray copy of the loadings
Add cell/spot proportions of spatial factors into adata_full.obs...
Add cell/spot latent representations into adata_full.obsm['latent']...

The gene loading matrix is stored in `basis_df` (a DataFrame) and `basis` (a NumPy array).

Spatial distributions of spatial factors in tissues

[10]:
# Plot the per-cell proportion of every spatial factor on the tissue.
# The factor columns are discovered from adata_full.obs (written by model.eval)
# instead of hard-coding range(15). This keeps the cell correct if
# n_spatial_factors is changed when the model is built.
factor_prefix = "Proportion of spatial factor "
factor_cols = sorted(
    (col for col in adata_full.obs.columns
     if col.startswith(factor_prefix) and col[len(factor_prefix):].isdigit()),
    key=lambda col: int(col[len(factor_prefix):]),  # numeric order: 1, 2, ..., 10, ...
)
sc.pl.spatial(adata_full, color=factor_cols, spot_size=200.)
../../_images/tutorials_registration_starmapplus_INSPIRE_registration_starmapplus_rtd_updated_20_0.png

Spot representations

[11]:
# Compute a 2D UMAP embedding of all cells from INSPIRE's learned latent
# representations (adata_full.obsm['latent'], written by model.eval).
# UMAP is imported by name because a later cell rebinds `umap` to a NumPy array.
# After that, `umap.UMAP` would raise AttributeError if this cell were re-run.
from umap import UMAP

reducer = UMAP(n_neighbors=30,
               n_components=2,
               metric="correlation",
               n_epochs=None,
               learning_rate=1.0,
               min_dist=0.3,
               spread=1.0,
               set_op_mix_ratio=1.0,
               local_connectivity=1,
               repulsion_strength=1,
               negative_sample_rate=5,
               a=None,
               b=None,
               random_state=1234,
               metric_kwds=None,
               # The repr printed below reports angular_rp_forest=True. umap-learn
               # appears to enable angular RP trees itself for angular metrics
               # such as "correlation".
               angular_rp_forest=False,
               verbose=True)
embedding = reducer.fit_transform(adata_full.obsm['latent'])
adata_full.obsm["X_umap"] = embedding
# Store the "label" annotation as plain strings (presumably so it is treated
# as categorical when plotting).
adata_full.obs["label"] = adata_full.obs["label"].values.astype(str)
UMAP(angular_rp_forest=True, local_connectivity=1, metric='correlation', min_dist=0.3, n_neighbors=30, random_state=1234, repulsion_strength=1, verbose=True)
Wed May 28 21:24:22 2025 Construct fuzzy simplicial set
Wed May 28 21:24:22 2025 Finding Nearest Neighbors
Wed May 28 21:24:22 2025 Building RP forest with 11 trees
Wed May 28 21:24:24 2025 NN descent for 14 iterations
         1  /  14
         2  /  14
         3  /  14
         4  /  14
         5  /  14
        Stopping threshold met -- exiting after 5 iterations
Wed May 28 21:24:33 2025 Finished Nearest Neighbor Search
Wed May 28 21:24:35 2025 Construct embedding
        completed  0  /  200 epochs
        completed  20  /  200 epochs
        completed  40  /  200 epochs
        completed  60  /  200 epochs
        completed  80  /  200 epochs
        completed  100  /  200 epochs
        completed  120  /  200 epochs
        completed  140  /  200 epochs
        completed  160  /  200 epochs
        completed  180  /  200 epochs
Wed May 28 21:24:50 2025 Finished embedding
[12]:
# Cluster INSPIRE's learned latent representations jointly across both slices
# to identify spatial domains. Labels land in adata_full.obs["louvain"].
sc.pp.neighbors(adata_full, n_neighbors=30, use_rep="latent")
sc.tl.louvain(adata_full, resolution=1.5, key_added="louvain")
[13]:
adata = adata_full

size = 0.05
# NOTE: this ndarray shadows the `umap` module imported at the top. The name is
# kept because later plotting cells read it.
umap = adata.obsm["X_umap"]
n_cells = umap.shape[0]
# Randomize the draw order (fixed seed) so that points of one slice are not all
# drawn on top of the other.
np.random.seed(1234)
order = np.arange(n_cells)
np.random.shuffle(order)

# One colour per slice. Use .loc rather than chained indexing
# (obs["col"][mask] = ...), which does not write back to adata.obs under pandas
# Copy-on-Write. The resulting warning would be hidden by
# warnings.filterwarnings("ignore").
slice_labels = adata.obs["slice"].values.astype(str)
adata.obs["slice_color"] = ""
adata.obs.loc[slice_labels == str(0), "slice_color"] = "tab:blue"
adata.obs.loc[slice_labels == str(1), "slice_color"] = "tab:orange"

f = plt.figure(figsize=(5,5))

ax3 = f.add_subplot(1,1,1)
# `order` holds positions, not obs_names labels, so index with .iloc.
ax3.scatter(umap[order, 0], umap[order, 1], s=size, c=adata.obs["slice_color"].iloc[order], rasterized=True, marker='o')
ax3.tick_params(axis='both',bottom=False, top=False, left=False, right=False, labelleft=False, labelbottom=False, grid_alpha=0)

legend_elements_slice = [Line2D([0], [0], marker='o', color="w", label='STARmap PLUS slice 1', markerfacecolor="tab:blue", markersize=10),
                         Line2D([0], [0], marker='o', color="w", label='STARmap PLUS slice 2', markerfacecolor="tab:orange", markersize=10)]
ax3.legend(handles=legend_elements_slice, loc="upper left", bbox_to_anchor=(1, 1.), frameon=False,
           markerscale=.8, fontsize=10, handletextpad=0., ncol=1)

f.subplots_adjust(hspace=0.02, wspace=0.1)
plt.show()
../../_images/tutorials_registration_starmapplus_INSPIRE_registration_starmapplus_rtd_updated_24_0.png
[14]:
# Build a long qualitative palette (hex strings) by concatenating several
# Matplotlib colormaps, then assign each cell the colour of its Louvain cluster.
# mpl.colormaps[name] replaces matplotlib.cm.get_cmap, which was deprecated in
# Matplotlib 3.7 and removed in 3.9. The top-level `from matplotlib.cm import
# get_cmap` needs the same update on those versions.
rgb_10 = list(mpl.colormaps['Set3'].colors)
rgb_20 = list(mpl.colormaps['tab20'].colors)
rgb_20b = list(mpl.colormaps['tab20b'].colors)
rgb_dark2 = list(mpl.colormaps['Dark2'].colors)
rgb_pst1 = list(mpl.colormaps['Pastel1'].colors)
rgb_acc = list(mpl.colormaps['Accent'].colors)
rgb2hex_10 = [mpl.colors.rgb2hex(color) for color in rgb_10]
rgb2hex_20 = [mpl.colors.rgb2hex(color) for color in rgb_20]
rgb2hex_20b = [mpl.colors.rgb2hex(color) for color in rgb_20b]
# Keep the first and last shade of each 4-colour tab20b group.
rgb2hex_20b_new = [rgb2hex_20b[i] for i in [0, 3, 4, 7, 8, 11, 12, 15, 16, 19]]
rgb2hex_dark2 = [mpl.colors.rgb2hex(color) for color in rgb_dark2]
rgb2hex_pst1 = [mpl.colors.rgb2hex(color) for color in rgb_pst1]
rgb2hex_acc = [mpl.colors.rgb2hex(color) for color in rgb_acc]
# 20 + 10 + 8 + 9 + 8 = 55 colours. More Louvain clusters than that would raise
# IndexError below.
rgb2hex = rgb2hex_20 + rgb2hex_20b_new + rgb2hex_dark2 + rgb2hex_pst1 + rgb2hex_acc

# Use .loc rather than chained indexing (obs["c_hex"][mask] = ...), which does
# not write back to adata.obs under pandas Copy-on-Write.
adata.obs["c_hex"] = ""
louvain_labels = adata.obs["louvain"].values.astype(str)
for i in range(len(set(adata.obs["louvain"]))):
    adata.obs.loc[louvain_labels == str(i), "c_hex"] = rgb2hex[i]
[15]:
# UMAP coloured by Louvain cluster. Reuses `umap`, `order` and `size` from the
# slice-coloured UMAP cell and obs["c_hex"] from the palette cell.
f = plt.figure(figsize=(5,5))
ax1 = f.add_subplot(1,1,1)
# `order` holds positions, not obs_names labels, so index with .iloc.
ax1.scatter(umap[order, 0], umap[order, 1], s=size, c=adata.obs["c_hex"].iloc[order], rasterized=True)
ax1.tick_params(axis='both',bottom=False, top=False, left=False, right=False, labelleft=False, labelbottom=False, grid_alpha=0)
plt.show()
../../_images/tutorials_registration_starmapplus_INSPIRE_registration_starmapplus_rtd_updated_26_0.png
[16]:
size = 1

# Spatial layout of the Louvain clusters across both slices, one scatter call
# per cluster. The y coordinate is negated, presumably to match image-style
# (downward) y orientation.
colors = rgb2hex
cluster_labels = adata.obs["louvain"].values.astype(str)
coords = adata.obsm["spatial"]

f = plt.figure(figsize=(12,5))
ax = f.add_subplot(1,1,1)
ax.axis('equal')
for i in range(len(set(adata.obs["louvain"]))):
    in_cluster = cluster_labels == str(i)
    ax.scatter(coords[in_cluster, 0],
               -coords[in_cluster, 1],
               s=size, c=colors[i], label="cluster " + str(i), rasterized=True)
ax.tick_params(axis='both', bottom=False, top=False, left=False, right=False,
               labelleft=False, labelbottom=False, grid_alpha=0)
plt.show()
../../_images/tutorials_registration_starmapplus_INSPIRE_registration_starmapplus_rtd_updated_27_0.png

Save results

[17]:
# Persist the integrated AnnData (with latent representations, factor
# proportions and cluster labels) and the gene loading matrix.
res_path = "/gpfs/gibbs/pi/zhao/jz874/project/jiazhao/inspire_revision/tutorials/spatial_alignment/STARmapPLUS"
# Writing into a missing directory raises, so create it first (no-op if present).
os.makedirs(res_path, exist_ok=True)
adata_full.write(os.path.join(res_path, "adata_inspire.h5ad"))
basis_df.to_csv(os.path.join(res_path, "basis_df_inspire.csv"))
[ ]: