Validate INSPIRE’s gene imputation capability across technologies (seqFISH and Stereo-seq mouse embryo data)

Import packages

[1]:
import pandas as pd
import numpy as np
import scanpy as sc
import anndata as ad
import umap
import os
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.cm import get_cmap
from matplotlib.lines import Line2D

import INSPIRE
from sklearn.neighbors import NearestNeighbors
from scipy import stats

import warnings
warnings.filterwarnings("ignore")
[2]:
print("Load seqFISH data...")

data_dir = "data/seqFISH_mouse_embryo"
counts = pd.read_csv(os.path.join(data_dir, "counts.csv"), index_col=0)
metadata = pd.read_csv(os.path.join(data_dir, "metadata.csv"), index_col=0)
# Align metadata rows with the rows of the count matrix.
metadata = metadata.loc[counts.index, :]
adata_seqfish = ad.AnnData(np.array(counts.values))
adata_seqfish.var.index = counts.columns
adata_seqfish.obs = metadata

# Keep embryo 2 only and drop "Low quality" cells in one selection. `.copy()`
# turns the resulting AnnData view into a standalone object, so the `obsm`
# assignment below does not trigger an implicit view->copy conversion.
keep_cells = ((adata_seqfish.obs["embryo"] == "embryo2")
              & (adata_seqfish.obs["celltype_mapped_refined"] != "Low quality"))
adata_seqfish = adata_seqfish[keep_cells, :].copy()
adata_seqfish.obsm["spatial"] = np.array(adata_seqfish.obs[["x_global", "y_global"]])
adata_seqfish.var_names_make_unique()
Load seqFISH data...
[3]:
print("Load Stereo-seq data...")

data_dir = "data/Stereoseq_mouse_embryo"
adata_stereoseq = sc.read_h5ad(os.path.join(data_dir, "E9.5_E1S1.MOSTA.h5ad"))
# Use raw counts as X. Copy the layer: plain assignment makes X and
# layers['count'] the same object, so any in-place change to X downstream
# would silently overwrite the stored raw counts too.
adata_stereoseq.X = adata_stereoseq.layers['count'].copy()
adata_stereoseq.var_names_make_unique()
Load Stereo-seq data...

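Manually hold out six marker genes from the seqFISH dataset before training, so that they can later be imputed from Stereo-seq and compared with their measured seqFISH expression.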

[4]:
# Marker genes removed from seqFISH before training; they are imputed from
# Stereo-seq later and compared with the measured seqFISH values.
gene_impu_list = ["Ttn", "Popdc2",
                  "Six3", "Lhx2",
                  "Foxa1", "Cldn4"]
# Fail early if a gene is misspelled or absent from seqFISH, instead of at the
# validation step at the end of the notebook.
missing_genes = [g for g in gene_impu_list if g not in adata_seqfish.var_names]
assert not missing_genes, f"Held-out genes not found in seqFISH data: {missing_genes}"
# `.copy()` materializes the gene subset instead of keeping an AnnData view.
adata_seqfish = adata_seqfish[:, ~adata_seqfish.var_names.isin(gene_impu_list)].copy()

Run INSPIRE model

[5]:
# Order matters: later cells refer to seqFISH as slice "0" and Stereo-seq as slice "1".
adata_st_list = [
    adata_seqfish,    # slice 0
    adata_stereoseq,  # slice 1
]
[6]:
# Shared-gene / HVG selection, QC and normalization for every dataset. This also
# returns a concatenated AnnData (adata_full) used for joint visualization.
preprocess_kwargs = dict(
    num_hvgs=1000,
    min_genes_qc=2,
    min_cells_qc=2,
    spot_size=1,
    limit_num_genes=True,
)
adata_st_list, adata_full = INSPIRE.utils.preprocess(adata_st_list=adata_st_list,
                                                     **preprocess_kwargs)
Get shared genes among all datasets...
Find 341 shared genes among datasets.
Finding highly variable genes...
shape of adata 0 before quality control:  (14185, 341)
shape of adata 0 after quality control:  (14185, 341)
shape of adata 1 before quality control:  (5913, 341)
shape of adata 1 after quality control:  (5880, 338)
Find 338 shared highly variable genes among datasets.
Concatenate datasets as a full anndata for better visualization...
../../_images/tutorials_diff_tech_embryo_difftech_embryo_6genes_impu_rtd_9_1.png
Store counts and library sizes for Poisson modeling...
Normalize data...
[7]:
# One radius cutoff per dataset, in adata_st_list order (seqFISH, then Stereo-seq).
rad_cutoffs = [3, 1.6]
adata_st_list = INSPIRE.utils.build_graph_LGCN(
    adata_st_list=adata_st_list,
    rad_cutoff_list=rad_cutoffs,
)
Start building graphs...
Build graphs and prepare node features for LGCN networks
Radius for graph connection is 3.0000.
26.7748 neighbors per cell on average.
Node features for slice 0 : (14185, 676)
Radius for graph connection is 1.6000.
7.7946 neighbors per cell on average.
Node features for slice 1 : (5880, 676)
[9]:
# INSPIRE LGCN model: 40 spatial factors, 10,000 training steps, minibatches of
# 2048. different_platforms=True because the two slices come from different technologies.
model_config = {
    "n_spatial_factors": 40,
    "n_training_steps": 10000,
    "batch_size": 2048,
    "different_platforms": True,
}
model = INSPIRE.model.Model_LGCN(adata_st_list=adata_st_list, **model_config)
[10]:
# Train the model on the graph-augmented datasets. The loss components are
# logged every 500 steps (see the output below).
model.train(adata_st_list)
  0%|          | 6/10000 [00:00<07:39, 21.75it/s]
Step: 0, d_loss: 1.3795, Loss: 1321.7115, recon_loss: 542.4070, fe_loss: 44.9862, geom_loss: 90.0650, beta_loss: 731.7576, gan_loss: 0.7594
  5%|▌         | 506/10000 [00:11<03:35, 44.06it/s]
Step: 500, d_loss: 0.8025, Loss: 1156.2062, recon_loss: 421.9512, fe_loss: 27.9413, geom_loss: 87.9538, beta_loss: 702.2742, gan_loss: 2.2805
 10%|█         | 1006/10000 [00:23<03:25, 43.77it/s]
Step: 1000, d_loss: 0.4605, Loss: 1068.0687, recon_loss: 333.6890, fe_loss: 27.3233, geom_loss: 116.7123, beta_loss: 701.7334, gan_loss: 2.9888
 15%|█▌        | 1506/10000 [00:34<03:14, 43.77it/s]
Step: 1500, d_loss: 0.2192, Loss: 1008.2529, recon_loss: 273.7039, fe_loss: 26.8178, geom_loss: 114.6153, beta_loss: 701.7894, gan_loss: 3.6494
 20%|██        | 2006/10000 [00:46<03:05, 43.12it/s]
Step: 2000, d_loss: 0.1613, Loss: 967.1715, recon_loss: 232.4665, fe_loss: 26.5373, geom_loss: 112.0043, beta_loss: 701.9185, gan_loss: 4.0092
 25%|██▌       | 2506/10000 [00:57<02:51, 43.78it/s]
Step: 2500, d_loss: 0.1505, Loss: 944.4811, recon_loss: 209.6136, fe_loss: 26.3871, geom_loss: 108.8623, beta_loss: 702.3289, gan_loss: 3.9742
 30%|███       | 3006/10000 [01:08<02:39, 43.93it/s]
Step: 3000, d_loss: 0.1095, Loss: 923.7773, recon_loss: 189.4411, fe_loss: 26.2072, geom_loss: 102.6299, beta_loss: 701.8252, gan_loss: 4.2512
 35%|███▌      | 3506/10000 [01:20<02:28, 43.63it/s]
Step: 3500, d_loss: 0.1584, Loss: 917.2579, recon_loss: 183.0217, fe_loss: 26.1698, geom_loss: 101.2249, beta_loss: 702.1776, gan_loss: 3.8644
 40%|████      | 4006/10000 [01:31<02:18, 43.42it/s]
Step: 4000, d_loss: 0.2182, Loss: 907.8060, recon_loss: 173.1643, fe_loss: 25.9905, geom_loss: 97.4785, beta_loss: 702.0744, gan_loss: 4.6273
 45%|████▌     | 4506/10000 [01:43<02:06, 43.55it/s]
Step: 4500, d_loss: 0.1375, Loss: 908.1998, recon_loss: 174.0829, fe_loss: 26.0043, geom_loss: 96.0484, beta_loss: 701.9686, gan_loss: 4.2229
 50%|█████     | 5006/10000 [01:54<01:53, 43.91it/s]
Step: 5000, d_loss: 0.1430, Loss: 899.9315, recon_loss: 165.3609, fe_loss: 25.9882, geom_loss: 97.4602, beta_loss: 702.3676, gan_loss: 4.2655
 55%|█████▌    | 5506/10000 [02:06<01:43, 43.60it/s]
Step: 5500, d_loss: 0.1414, Loss: 902.2596, recon_loss: 168.4683, fe_loss: 26.0173, geom_loss: 95.1445, beta_loss: 702.0005, gan_loss: 3.8708
 60%|██████    | 6006/10000 [02:17<01:31, 43.57it/s]
Step: 6000, d_loss: 0.1390, Loss: 896.6946, recon_loss: 163.0165, fe_loss: 26.0981, geom_loss: 95.9430, beta_loss: 701.5847, gan_loss: 4.0764
 65%|██████▌   | 6506/10000 [02:28<01:20, 43.66it/s]
Step: 6500, d_loss: 0.1373, Loss: 900.0043, recon_loss: 166.5816, fe_loss: 26.0880, geom_loss: 90.4565, beta_loss: 701.6263, gan_loss: 3.8992
 70%|███████   | 7006/10000 [02:40<01:08, 43.71it/s]
Step: 7000, d_loss: 0.1537, Loss: 891.6824, recon_loss: 158.5278, fe_loss: 26.0307, geom_loss: 84.9429, beta_loss: 701.6106, gan_loss: 3.8145
 75%|███████▌  | 7506/10000 [02:51<00:56, 44.05it/s]
Step: 7500, d_loss: 0.1679, Loss: 890.1442, recon_loss: 157.1890, fe_loss: 26.0339, geom_loss: 79.0926, beta_loss: 701.7827, gan_loss: 3.5568
 80%|████████  | 8006/10000 [03:03<00:45, 43.62it/s]
Step: 8000, d_loss: 0.1658, Loss: 890.5060, recon_loss: 157.4647, fe_loss: 26.0632, geom_loss: 77.6348, beta_loss: 701.8259, gan_loss: 3.5995
 85%|████████▌ | 8506/10000 [03:14<00:34, 43.89it/s]
Step: 8500, d_loss: 0.1681, Loss: 884.7556, recon_loss: 152.1146, fe_loss: 26.0144, geom_loss: 74.2172, beta_loss: 701.6224, gan_loss: 3.5198
 90%|█████████ | 9006/10000 [03:26<00:22, 43.68it/s]
Step: 9000, d_loss: 0.1876, Loss: 883.9437, recon_loss: 151.1153, fe_loss: 26.0072, geom_loss: 71.4386, beta_loss: 701.7946, gan_loss: 3.5979
 95%|█████████▌| 9506/10000 [03:37<00:11, 43.29it/s]
Step: 9500, d_loss: 0.2141, Loss: 883.6108, recon_loss: 151.2731, fe_loss: 26.0157, geom_loss: 70.0425, beta_loss: 701.3918, gan_loss: 3.5293
100%|██████████| 10000/10000 [03:48<00:00, 43.68it/s]
[11]:
# Evaluate the trained model. This adds spatial-factor proportions to
# adata_full.obs and latent representations to adata_full.obsm['latent'].
# basis_df presumably holds the gene-by-factor basis (TODO confirm).
adata_full, basis_df = model.eval(adata_st_list, adata_full)
basis = basis_df.to_numpy(copy=True)
Add cell/spot proportions of spatial factors into adata_full.obs...
Add cell/spot latent representations into adata_full.obsm['latent']...

Analysis of cell representations and spatial factors

[12]:
# Plot every spatial-factor proportion column that model.eval added. The columns
# are derived from adata_full.obs rather than hard-coding range(40), so this stays
# correct if n_spatial_factors changes.
FACTOR_PREFIX = "Proportion of spatial factor "
factor_cols = sorted(
    (c for c in adata_full.obs.columns if c.startswith(FACTOR_PREFIX)),
    key=lambda c: int(c[len(FACTOR_PREFIX):]),
)
assert factor_cols, "No spatial-factor proportion columns found in adata_full.obs"
sc.pl.spatial(adata_full, color=factor_cols, spot_size=1.)
../../_images/tutorials_diff_tech_embryo_difftech_embryo_6genes_impu_rtd_15_0.png
[13]:
# 2-D UMAP of the INSPIRE latent space (correlation distance, fixed random_state),
# stored in adata_full.obsm["X_umap"].
umap_kwargs = {
    "n_neighbors": 30,
    "n_components": 2,
    "metric": "correlation",
    "n_epochs": None,
    "learning_rate": 1.0,
    "min_dist": 0.3,
    "spread": 1.0,
    "set_op_mix_ratio": 1.0,
    "local_connectivity": 1,
    "repulsion_strength": 1,
    "negative_sample_rate": 5,
    "a": None,
    "b": None,
    "random_state": 1234,
    "metric_kwds": None,
    "angular_rp_forest": False,
    "verbose": True,
}
reducer = umap.UMAP(**umap_kwargs)
embedding = reducer.fit_transform(adata_full.obsm['latent'])
adata_full.obsm["X_umap"] = embedding
# Plain string labels ("0", "1"), so later cells can compare them against str(i).
adata_full.obs["slice"] = adata_full.obs["slice"].values.astype(str)
UMAP(angular_rp_forest=True, local_connectivity=1, metric='correlation', min_dist=0.3, n_neighbors=30, random_state=1234, repulsion_strength=1, verbose=True)
Tue Aug 27 10:57:01 2024 Construct fuzzy simplicial set
Tue Aug 27 10:57:01 2024 Finding Nearest Neighbors
Tue Aug 27 10:57:01 2024 Building RP forest with 12 trees
Tue Aug 27 10:57:05 2024 NN descent for 14 iterations
         1  /  14
         2  /  14
         3  /  14
        Stopping threshold met -- exiting after 3 iterations
Tue Aug 27 10:57:13 2024 Finished Nearest Neighbor Search
Tue Aug 27 10:57:15 2024 Construct embedding
        completed  0  /  200 epochs
        completed  20  /  200 epochs
        completed  40  /  200 epochs
        completed  60  /  200 epochs
        completed  80  /  200 epochs
        completed  100  /  200 epochs
        completed  120  /  200 epochs
        completed  140  /  200 epochs
        completed  160  /  200 epochs
        completed  180  /  200 epochs
Tue Aug 27 10:57:35 2024 Finished embedding
[14]:
# Clustering: build a 30-nearest-neighbour graph on obsm["latent"], then run
# Louvain community detection on it.
sc.pp.neighbors(adata_full, n_neighbors=30, use_rep="latent")
sc.tl.louvain(adata_full, resolution=0.9)
[15]:
# Visualize the UMAP embedding, colored by (a) source slice and (b) Louvain cluster.
size = .1


def _cmap_hex(name, indices=None):
    """Return the listed colors of matplotlib colormap `name` as hex strings,
    optionally restricted to the given indices."""
    # `mpl.colormaps[...]` replaces the deprecated `matplotlib.cm.get_cmap`.
    rgb = mpl.colormaps[name].colors
    if indices is not None:
        rgb = [rgb[i] for i in indices]
    return [mpl.colors.rgb2hex(c) for c in rgb]


def _hide_ticks(ax):
    """Remove ticks and tick labels from `ax`."""
    ax.tick_params(axis='both', bottom=False, top=False, left=False, right=False,
                   labelleft=False, labelbottom=False, grid_alpha=0)


# Cluster palette (also used by the spatial cluster plots in later cells).
rgb2hex = (_cmap_hex('tab20')
           + _cmap_hex('tab20b', [0, 3, 4, 7, 8, 11, 12, 15, 16, 19])
           + _cmap_hex('Dark2')
           + _cmap_hex('Pastel1')
           + _cmap_hex('Accent'))

embedding = adata_full.obsm["X_umap"]
slice_labels = adata_full.obs["slice"].values.astype(str)
louvain_labels = adata_full.obs["louvain"].values.astype(str)

# umap, slice
f, ax = plt.subplots(figsize=(5.5, 5))
colors = ["tab:blue", "tab:orange"]
for i in range(len(set(slice_labels))):
    mask = slice_labels == str(i)
    ax.scatter(embedding[mask, 0], embedding[mask, 1],
               s=size, c=colors[i], label="slice "+str(i))
_hide_ticks(ax)
ax.legend(markerscale=3)
plt.show()

# umap, louvain
n_louvain = len(set(adata_full.obs["louvain"]))
# Fail with a clear message instead of an IndexError deep in the plotting loop.
assert n_louvain <= len(rgb2hex), (
    f"{n_louvain} clusters but only {len(rgb2hex)} palette colors; extend rgb2hex")
f, ax = plt.subplots(figsize=(5.5, 5))
colors = rgb2hex
for i in range(n_louvain):
    mask = louvain_labels == str(i)
    ax.scatter(embedding[mask, 0], embedding[mask, 1],
               s=size, c=colors[i], label="cluster "+str(i))
_hide_ticks(ax)
ax.legend(markerscale=3, ncol=3, bbox_to_anchor=(2, 1))
plt.show()
../../_images/tutorials_diff_tech_embryo_difftech_embryo_6genes_impu_rtd_18_0.png
../../_images/tutorials_diff_tech_embryo_difftech_embryo_6genes_impu_rtd_18_1.png
[16]:
# Spatial map of Louvain clusters on the seqFISH section (slice "0").
size = 1

f, ax = plt.subplots(figsize=(10, 5))
ax.axis('equal')
colors = rgb2hex
# Select by slice label rather than by row position, so this does not depend on
# the concatenation order used to build adata_full.
adata_tmp = adata_full[adata_full.obs["slice"].values.astype(str) == "0", :]
assert adata_tmp.shape[0] == adata_st_list[0].shape[0]
xy = adata_tmp.obsm["spatial"]
cluster_labels = adata_tmp.obs["louvain"].values.astype(str)
for i in range(n_louvain):
    mask = cluster_labels == str(i)
    # y is negated to flip the vertical axis of the plot.
    ax.scatter(xy[mask, 0], -xy[mask, 1], s=size, c=colors[i], label="cluster "+str(i))
ax.tick_params(axis='both', bottom=False, top=False, left=False, right=False,
               labelleft=False, labelbottom=False, grid_alpha=0)
plt.show()
../../_images/tutorials_diff_tech_embryo_difftech_embryo_6genes_impu_rtd_19_0.png
[17]:
# Spatial map of Louvain clusters on the Stereo-seq section (slice "1").
size = 2.

f, ax = plt.subplots(figsize=(10, 5))
ax.axis('equal')
colors = rgb2hex
# Select by slice label rather than by row offsets, so this does not depend on
# the concatenation order used to build adata_full.
adata_tmp = adata_full[adata_full.obs["slice"].values.astype(str) == "1", :]
assert adata_tmp.shape[0] == adata_st_list[1].shape[0]
xy = adata_tmp.obsm["spatial"]
cluster_labels = adata_tmp.obs["louvain"].values.astype(str)
for i in range(n_louvain):
    mask = cluster_labels == str(i)
    # y is negated to flip the vertical axis of the plot.
    ax.scatter(xy[mask, 0], -xy[mask, 1], s=size, c=colors[i], label="cluster "+str(i))
ax.tick_params(axis='both', bottom=False, top=False, left=False, right=False,
               labelleft=False, labelbottom=False, grid_alpha=0)
plt.show()
../../_images/tutorials_diff_tech_embryo_difftech_embryo_6genes_impu_rtd_20_0.png

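Perform gene imputation: for each seqFISH cell, find its nearest Stereo-seq spot in INSPIRE's shared latent space, and use that spot's measured expression as the imputed value.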

[18]:
# Map every seqFISH cell to its single nearest Stereo-seq spot in latent space.
row_slice = adata_full.obs.slice.values.astype(str)
ad_0 = adata_full[row_slice == "0", :]  # seqfish
ad_1 = adata_full[row_slice == "1", :]  # stereo-seq
z_0, z_1 = ad_0.obsm["latent"], ad_1.obsm["latent"]

neigh = NearestNeighbors(n_neighbors=1).fit(z_1)
# nn_idx[k] = row in ad_1 (and z_1) nearest to row k of ad_0.
nn_idx = neigh.kneighbors(z_0, 1, return_distance=False).ravel()
[19]:
# Reload both datasets from disk, so the held-out genes (removed from seqFISH
# before training) and raw counts are available. Then reorder their cells to
# match the rows of ad_0 / ad_1, and hence z_0 / z_1 and nn_idx.
print("Load Stereo-seq data...")
data_dir = "data/Stereoseq_mouse_embryo"
adata_stereoseq = sc.read_h5ad(os.path.join(data_dir, "E9.5_E1S1.MOSTA.h5ad"))
# Copy so that X does not alias layers['count'].
adata_stereoseq.X = adata_stereoseq.layers['count'].copy()
adata_stereoseq.var_names_make_unique()

adata_1 = adata_stereoseq.copy()
# Cell names in adata_full carry a "-<slice index>" suffix; the reindexing relies on it.
adata_1.obs.index = adata_1.obs.index + "-1"
# `.copy()` materializes the reindexed view; the next cell adds obs columns.
adata_1 = adata_1[ad_1.obs.index, :].copy()

print("Load seqFISH data...")

data_dir = "data/seqFISH_mouse_embryo"
counts = pd.read_csv(os.path.join(data_dir, "counts.csv"), index_col=0)
metadata = pd.read_csv(os.path.join(data_dir, "metadata.csv"), index_col=0)
metadata = metadata.loc[counts.index, :]
adata_seqfish = ad.AnnData(np.array(counts.values))
adata_seqfish.var.index = counts.columns
adata_seqfish.obs = metadata
keep_cells = ((adata_seqfish.obs["embryo"] == "embryo2")
              & (adata_seqfish.obs["celltype_mapped_refined"] != "Low quality"))
adata_seqfish = adata_seqfish[keep_cells, :].copy()
adata_seqfish.obsm["spatial"] = np.array(adata_seqfish.obs[["x_global", "y_global"]])
adata_seqfish.var_names_make_unique()

adata_0 = adata_seqfish.copy()
adata_0.obs.index = adata_0.obs.index + "-0"
adata_0 = adata_0[ad_0.obs.index, :].copy()

# nn_idx maps rows of ad_0 (queries) to rows of ad_1 (targets); verify that the
# reloaded objects line up with them.
assert adata_0.n_obs == ad_0.n_obs == len(nn_idx)
assert adata_1.n_obs == ad_1.n_obs
Load Stereo-seq data...
Load seqFISH data...

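Validate the gene imputation result by comparing the measured and imputed expression of each held-out gene, both spatially and via Pearson correlation.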

[20]:
def _as_dense_vector(x):
    """Flatten a single-gene expression slice (sparse or dense) to a 1-D ndarray."""
    # The seqFISH X is a dense ndarray and Stereo-seq X may be sparse. A plain
    # ndarray has no .toarray(), so densify only when the object provides it.
    x = x.toarray() if hasattr(x, "toarray") else np.asarray(x)
    return np.asarray(x).reshape(-1)


corr_list = []
for gene in gene_impu_list:
    print(gene)
    # Imputed value of each seqFISH cell = Stereo-seq expression of its nearest
    # latent-space neighbour (nn_idx from the nearest-neighbour cell).
    gene_val = _as_dense_vector(adata_1[:, [gene]].X)[nn_idx]
    adata_0.obs[gene+"_impu"] = gene_val
    sc.pl.spatial(adata_0, color=[gene, gene+"_impu"], spot_size=1.)
    res = stats.pearsonr(_as_dense_vector(adata_0[:, [gene]].X),
                         np.asarray(adata_0.obs[gene+"_impu"]).reshape(-1))
    print(res)
    corr_list.append(res[0])

# Per-gene summary of measured-vs-imputed Pearson correlations.
corr_summary = pd.Series(corr_list, index=gene_impu_list, name="pearson_r")
corr_summary
Ttn
../../_images/tutorials_diff_tech_embryo_difftech_embryo_6genes_impu_rtd_25_1.png
PearsonRResult(statistic=0.697646232689517, pvalue=0.0)
Popdc2
../../_images/tutorials_diff_tech_embryo_difftech_embryo_6genes_impu_rtd_25_3.png
PearsonRResult(statistic=0.49495325628699277, pvalue=0.0)
Six3
../../_images/tutorials_diff_tech_embryo_difftech_embryo_6genes_impu_rtd_25_5.png
PearsonRResult(statistic=0.45629528085079796, pvalue=0.0)
Lhx2
../../_images/tutorials_diff_tech_embryo_difftech_embryo_6genes_impu_rtd_25_7.png
PearsonRResult(statistic=0.40789148889303517, pvalue=0.0)
Foxa1
../../_images/tutorials_diff_tech_embryo_difftech_embryo_6genes_impu_rtd_25_9.png
PearsonRResult(statistic=0.40106778982268065, pvalue=0.0)
Cldn4
../../_images/tutorials_diff_tech_embryo_difftech_embryo_6genes_impu_rtd_25_11.png
PearsonRResult(statistic=0.3405698050622275, pvalue=0.0)
[ ]: