This is the official SpatialFormer codebase. SpatialFormer is the first single-cell spatial foundation model that learns universal representations of subcellular molecular and cellular spatial proximity through multi-task learning.
Spatial transcriptomics quantifies gene expression within its spatial context and has enabled significant advances in biomedical research. Understanding the spatial expression of genes and how multicellular systems are organised is vital for diagnosing diseases and studying biological processes. However, existing models often struggle to integrate gene expression data effectively with cellular spatial information. In this study, we introduce SpatialFormer, a hybrid framework that combines convolutional networks and transformers to learn multi-scale single-cell information within a niche context, including expression data and the subcellular spatial distribution of genes. Pre-trained on 700 million cell pairs from 17 million spatially resolved single cells across 71 Xenium slides, SpatialFormer merges gene spatial expression profiles with cell niche information via a pairwise training strategy. Our findings demonstrate that SpatialFormer can distil biological signals across various tasks, including single-cell batch correction, cell-type annotation, co-localisation detection and the identification of gene pairs critical to immune cell-cell interactions involved in the regulation of lung fibrosis. These advances enhance our understanding of cellular dynamics and open up new avenues for applications in biomedical research.
- Transcripts: 3.3B → 4.5B
- Cells: 13M → 17M
- Slides: 61 → 71
- Gene vocabulary: 1,922 → 6,036
- Added a new edge-based dataloader, with anchor cells drawn from a preselected index, supporting:
  - distance-aware sampling
  - hard negative pairs
  - easy negative pairs (cache-pairs)
  - FAISS-based nearest-neighbor search (cache-faiss)
  - index-based storage for positive/negative pairs, which saves a large amount of memory
- Upgraded to GraphSAGE v2, supporting 6,036 spatial embeddings
- Integrated FlashAttention v2 for efficient long-sequence processing
- Unified all prediction paths under the `sp.tl.embed_data` function and updated `sp.tl.embed_data` to process variable-length inputs
- Embeddings can now be extracted more efficiently with larger batch sizes and representative sequence lengths
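
The memory saving from index-based pair storage comes from keeping one shared table of cells and storing each positive/negative pair only as two integer indices plus a label, rather than copying both cells' data for every pair. A minimal sketch of that idea, assuming a dense per-cell feature table; the array names and sizes are hypothetical, and this is not the SpatialFormer dataloader itself:

```python
import numpy as np

n_cells, n_genes = 10_000, 500
rng = np.random.default_rng(0)

# One shared table of per-cell features (a stand-in for tokenized expression).
cells = rng.random((n_cells, n_genes), dtype=np.float32)

# Hypothetical pair list stored as indices: (anchor, partner, label).
n_pairs = 50_000
anchors = rng.integers(0, n_cells, size=n_pairs, dtype=np.int32)
partners = rng.integers(0, n_cells, size=n_pairs, dtype=np.int32)
labels = rng.integers(0, 2, size=n_pairs, dtype=np.int8)  # 1 = positive, 0 = negative


def get_pair(i):
    """Materialize pair i on demand by looking both cells up in the shared table."""
    return cells[anchors[i]], cells[partners[i]], labels[i]


# Bytes needed for the index arrays vs. copying both cells for every pair.
index_bytes = anchors.nbytes + partners.nbytes + labels.nbytes
copied_bytes = 2 * n_pairs * n_genes * cells.itemsize
print(f"index storage:  {index_bytes / 1e6:.2f} MB")
print(f"copied storage: {copied_bytes / 1e6:.2f} MB")
```

The saving grows with the number of pairs per cell, which is large when every cell in a niche forms a positive pair with the anchor.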
For detailed instructions on using SpatialFormer, please refer to our Jupyter notebook tutorials (some provided as .py files) in the `downstream/` directory.
| Tutorial | Colab |
|---|---|
| Dataset Integration | |
| Attention Analysis | |
| Perturbation Analysis | |
| Visualisation of Perturbation Results | |
Fine-tuning tutorials:
- Cell-cell colocalization prediction: fine-tuning for other platforms (example 1)
- Cell-cell colocalization prediction: fine-tuning for other platforms (example 2)
We provide both GPU and CPU versions to suit different hardware. However, a GPU is required to process large numbers of cells efficiently. Both AMD and NVIDIA GPUs are supported.
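Because small jobs can run on the CPU while large ones need a GPU, it can help to choose the device at runtime. A minimal sketch, assuming a standard PyTorch install: ROCm builds of PyTorch expose AMD GPUs through the same `torch.cuda` API, so one check covers both AMD and NVIDIA. `pick_device` is a hypothetical helper, not part of the spatialformer package; on macOS it returns "cpu", consistent with only CPUs being supported there.

```python
import importlib.util


def pick_device() -> str:
    """Return "cuda" if PyTorch can see a GPU (NVIDIA CUDA or AMD ROCm), else "cpu"."""
    if importlib.util.find_spec("torch") is None:
        return "cpu"  # PyTorch is not installed yet
    import torch

    # ROCm builds of PyTorch report AMD GPUs through torch.cuda as well.
    return "cuda" if torch.cuda.is_available() else "cpu"


device = pick_device()
print(f"Running on: {device}")
```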
This package supports macOS and Linux and has been tested on the following systems:
- macOS: Sequoia (15.3.1)
- Linux: Ubuntu 16.04; SUSE 15.6
Create the spatialformer environment with Anaconda (Python >= 3.10 required):

```bash
conda create -n spatialformer python=3.10
```

Then, activate the spatialformer environment:

```bash
source activate spatialformer
```

PyTorch must be installed before spatialformer to ensure compatibility with your operating system and GPU.

AMD GPUs (ROCm 6.0):

```bash
pip install torch==2.3.1+rocm6.0 torchvision==0.18.1+rocm6.0 torchaudio==2.3.1+rocm6.0 --index-url https://download.pytorch.org/whl/rocm6.0
```

NVIDIA GPUs (CUDA 12.1):

```bash
pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu121
```

CPU only (e.g., macOS):

```bash
conda install -c conda-forge numba llvmlite
pip install torch torchvision torchaudio
pip install "numpy<2"  # numpy should match the PyTorch version
```

Note: On macOS, only CPUs are currently supported. If you encounter an OMP error, try setting:

```bash
export KMP_DUPLICATE_LIB_OK=TRUE
```

Make sure CMake is installed; if it is not, install it with:

```bash
conda install cmake
```

Then install SpatialFormer:

```bash
pip install spatialformer
```

FlashAttention is strongly recommended to accelerate training and inference while maintaining accuracy.
Before installing it, make sure the CUDA compiler (nvcc) is available on your system. nvcc can be installed via:

```bash
conda install -c "nvidia/label/cuda-12.4.0" cuda-toolkit
# check the installation of nvcc
nvcc --version
```

Once nvcc is available, install FlashAttention.
To get started with the Triton backend on AMD GPUs, follow the steps below. According to the upstream FlashAttention documentation, the FlashAttention-2 ROCm CK backend currently supports:
- MI200x, MI250x, MI300x, and MI355x GPUs
- fp16 and bf16 datatypes
- Head dimensions up to 256 for both the forward and backward passes
```bash
pip install triton==3.2.0
```

Then install FlashAttention (2.x) from GitHub:

```bash
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention
git checkout 35e5f00
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
python setup.py install
pip install einops
```
Finally, test whether it works:

```bash
pytest tests/test_flash_attn.py
```

Or run a quicker check:

```bash
python -c "
import torch
from flash_attn import flash_attn_func
q = torch.randn(2, 128, 8, 64, dtype=torch.float16, device='cuda')
k = torch.randn(2, 128, 8, 64, dtype=torch.float16, device='cuda')
v = torch.randn(2, 128, 8, 64, dtype=torch.float16, device='cuda')
out = flash_attn_func(q, k, v)
print(f'✅ Flash Attention on {torch.cuda.get_device_name(0)}: {out.shape}')
"
```

Alternatively, if you are using an NVIDIA GPU (e.g., A100), you can install FlashAttention (2.x) directly:
```bash
pip install flash-attn --no-build-isolation
```

If this fails, try the pre-built wheel:

```bash
wget https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.8/flash_attn-2.5.8+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
pip install ./flash_attn-2.5.8+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
```

Our code uses FlashAttention 2.x, which is a complete rewrite of FlashAttention 1.x and about 2x faster.
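The `use_flash_attn` argument of `sp.tl.embed_data` should be True only when FlashAttention is installed. A minimal sketch that sets it automatically; `flash_attn_usable` is a hypothetical helper, and it assumes FlashAttention also needs a GPU visible to PyTorch:

```python
import importlib.util


def flash_attn_usable() -> bool:
    """True if the flash_attn package is importable and PyTorch sees a GPU."""
    if importlib.util.find_spec("flash_attn") is None:
        return False  # FlashAttention is not installed
    import torch  # flash_attn depends on torch, so it is available here

    return torch.cuda.is_available()


use_flash_attn = flash_attn_usable()
print(f"use_flash_attn = {use_flash_attn}")
```

The resulting flag can then be passed straight through as `use_flash_attn=use_flash_attn`.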
The model accepts both individual cells and cell pairs (doublets) as input. It was originally pretrained on a large-scale dataset of positive and negative cell pairs. Specifically, the positive pairs consist of all cells located within the niche of a given query cell, whereas the negative pairs can be any cells located far away from the query cell.
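A minimal sketch of niche-based pair construction, assuming 2D cell centroids, a niche radius that defines positives and a minimum distance that defines negatives. The `build_pairs` helper and the radius values are hypothetical; the actual pipeline in /data_preprocess/ may differ (the changelog above also mentions hard negatives and FAISS-based search):

```python
import numpy as np


def build_pairs(coords, niche_radius, neg_min_dist, n_neg_per_cell, seed=0):
    """Return (anchor, partner, label) index arrays; label 1 = same niche, 0 = distant."""
    rng = np.random.default_rng(seed)
    n = len(coords)
    # Dense pairwise Euclidean distances: fine for a toy example, use a KD-tree or FAISS at scale.
    diff = coords[:, None, :] - coords[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))
    anchors, partners, labels = [], [], []
    for i in range(n):
        # Positives: all other cells inside the query cell's niche.
        pos = np.flatnonzero((dist[i] <= niche_radius) & (np.arange(n) != i))
        # Negatives: a random subset of cells far away from the query cell.
        far = np.flatnonzero(dist[i] >= neg_min_dist)
        neg = rng.choice(far, size=min(n_neg_per_cell, len(far)), replace=False)
        anchors += [i] * (len(pos) + len(neg))
        partners += pos.tolist() + neg.tolist()
        labels += [1] * len(pos) + [0] * len(neg)
    return np.array(anchors), np.array(partners), np.array(labels)


# Toy slide: two small clusters of cells far apart from each other.
coords = np.array([[0.0, 0.0], [10.0, 0.0], [0.0, 12.0], [200.0, 200.0], [210.0, 200.0]])
anchor, partner, label = build_pairs(coords, niche_radius=30.0, neg_min_dist=100.0, n_neg_per_cell=2)
```

Cells that fall between `niche_radius` and `neg_min_dist` are deliberately left unpaired in this sketch, so ambiguous neighbors are labeled neither positive nor negative.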
The processed individual cell dataset can be retrieved from the Hugging Face dataset repository at SpatialCC-17M. The pairwise data can be generated by following the instructions provided in /data_preprocess/.
You can download the dataset in Python as follows:

```python
from datasets import load_dataset

spatialcc = load_dataset("TerminatorJ/xenium_5k_pandavid_dataset_v2", cache_dir="your_cache_dir")
```

SpatialFormer provides a simple function to extract embeddings. The `sp.tl.embed_data()` function integrates seamlessly with AnnData objects: the generated embeddings are stored in `obsm` under the key "X_SpaF".
SpatialFormer supports two methods for generating embeddings: 1) single input mode and 2) pairwise input mode. AnnData examples for both modes are given below.
Download the checkpoint that matches your use case:
| Input type | Tissue types | Size (number of slides) | Links |
|---|---|---|---|
| Paired | lung | 1 | ckp_pair_lung_1 |
| Paired | 13 types | 61 | ckp_pair_13tissues_61 |
| Paired | 13 types | 71 | ckp_pair_13tissues_71 |
| Paired | lung | 25 | ckp_pair_lung_25 |
| Single | 13 types | 62 | ckp_single_13tissues_62 |
| Single | 13 types | 71 | ckp_single_13tissues_71 |
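
The embedding examples below recommend `ckp_single_13tissues_71` for single mode and `ckp_pair_13tissues_71` for pair mode. A small hypothetical helper that encodes this choice; the `./<name>.ckpt` file naming follows those examples and is an assumption about where you saved the downloads:

```python
import os

# Recommended checkpoint per input mode (taken from the embedding examples in this README).
RECOMMENDED_CKP = {
    "single": "ckp_single_13tissues_71",
    "pair": "ckp_pair_13tissues_71",
}


def checkpoint_path(mode: str, ckp_dir: str = ".") -> str:
    """Return the local path of the recommended checkpoint for "single" or "pair" mode."""
    if mode not in RECOMMENDED_CKP:
        raise ValueError(f'mode must be "single" or "pair", got {mode!r}')
    return os.path.join(ckp_dir, f"{RECOMMENDED_CKP[mode]}.ckpt")


print(checkpoint_path("single"))
```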
The LoRA fine-tuned checkpoints can be downloaded as below:
| Input type | Tissue types | Size (number of slides) | Cell Number | Links |
|---|---|---|---|---|
| Paired | lung | 1 | 10k | ckp_pair_lung_LoRA_10K |
| Paired | breast | 1 | 10k | ckp_pair_breast_LoRA_10K |
| Paired | colon | 1 | 10k | ckp_pair_colon_LoRA_10K |
| Paired | lung | 1 | 100k | ckp_pair_lung_LoRA_100K |
| Paired | breast | 1 | 100k | ckp_pair_breast_LoRA_100K |
| Paired | colon | 1 | 100k | ckp_pair_colon_LoRA_100K |
SpatialFormer focuses mainly on zero-shot learning for single-cell spatial omics data, so extracting embeddings is the most common step in downstream tasks. Cell embeddings can be extracted from several input formats: ".h5ad" files or Hugging Face datasets.
The simplest route is to load an ".h5ad" file and extract its embeddings with the code below.
We also provide a Google Colab notebook for practical use (stable and strongly recommended).
A simple example AnnData file can be downloaded here.
```python
import scanpy as sc

adata = sc.read_h5ad("./downstream/cell_cell_communication/data/covid_subsampled.h5ad")
```

Make sure that "gene_name" is one of the column names in `adata.var`.
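If your gene symbols are stored in `adata.var_names` (or another column) rather than in a "gene_name" column, a minimal sketch for adding that column; `ensure_gene_name` is a hypothetical helper that works on the `adata.var` DataFrame, and it assumes the index or chosen column holds gene symbols:

```python
from typing import Optional

import pandas as pd


def ensure_gene_name(var: pd.DataFrame, source_col: Optional[str] = None) -> pd.DataFrame:
    """Return a copy of `var` (e.g. adata.var) that contains a "gene_name" column."""
    var = var.copy()
    if "gene_name" in var.columns:
        return var  # already present, nothing to do
    if source_col is not None:
        var["gene_name"] = var[source_col].astype(str)
    else:
        # Fall back to the index, which for AnnData is adata.var_names.
        var["gene_name"] = var.index.astype(str)
    return var


# Usage with AnnData: adata.var = ensure_gene_name(adata.var)
```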
Single input mode:

```python
import spatialformer as sp

method = "cls"
tissue = "Lung"
condition = "Disease"
assay = "Xenium"
model_ckp_path = "./ckp_single_13tissues_71.ckpt"  # "ckp_single_13tissues_71" is recommended
use_flash_attn = True  # True if FlashAttention is installed, otherwise False
batch_size = 16
embed_adata = sp.tl.embed_data(
    adata=adata,
    tissue=tissue,
    condition=condition,
    assay=assay,
    method=method,
    model_ckp_path=model_ckp_path,
    batch_size=batch_size,
    mode="single",
    use_flash_attn=use_flash_attn,
    num_workers=32,
)
```

Pairwise input mode (each ID in `left_cell` is paired with the ID at the same index in `right_cell`):

```python
import spatialformer as sp

method = "cls"
tissue = "Lung"
condition = "Disease"
assay = "Xenium"
model_ckp_path = "./ckp_pair_13tissues_71.ckpt"  # "ckp_pair_13tissues_71" is recommended
batch_size = 16
embed_adata = sp.tl.embed_data(
    adata=adata,
    tissue=tissue,
    condition=condition,
    assay=assay,
    method=method,
    model_ckp_path=model_ckp_path,
    batch_size=batch_size,
    mode="pair",
    left_cell=["20532-0-1-0-1", "222101-0-0-1"],
    right_cell=["483188-0-0-1", "513429-0-0-1"],
    num_workers=16,
)
```

| Arguments | dtype | Description |
|---|---|---|
| adata | object | An AnnData object storing the cell-by-gene expression matrix. |
| tissue | string | The tissue type (e.g., Breast, Lung). |
| condition | string | Metadata for the sample condition (e.g., Disease, Healthy). |
| assay | string | The assay used to generate the data (e.g., MERFISH, Xenium). |
| method | string | Embedding extraction method. "cls": use the CLS token embedding as the cell representation; "gene": use the mean of the gene token embeddings. |
| mode | string | The input mode of the embed function: "single" or "pair". In "single" mode, only individual cells are collated as model input. In "pair" mode, data is prepared for pairwise input and both left_cell and right_cell must be provided; each cell ID in left_cell is paired with the cell ID at the same index in right_cell. |
| model_ckp_path | string | The path to the SpatialFormer model checkpoint. |
| batch_size | integer | The batch size for the data loader. |
| threshold | float | The threshold for deciding whether two genes are paired, which helps identify confidently paired genes at subcellular resolution. Applies only in "single" mode; it has no effect in "pair" mode. |
| left_cell | array_like | A list of cell IDs for the query cells. |
| right_cell | array_like | A list of cell IDs for the key cells. |
| num_workers | integer | The number of CPU workers used to load the data. This value should match the number of workers specified in the data loader. |
| resume_before_5k | bool | Whether to resume from a checkpoint trained on the small panel. Set to True to use the small-panel checkpoint, or False to use the checkpoint trained on the 5k Xenium panel. |
| max_len | integer | Maximum number of genes per sequence. Default is None, meaning all genes are used. If FlashAttention is not installed and you have many pairwise sequences, we strongly recommend setting this to 500 per sequence, which significantly improves runtime. |
| use_flash_attn | bool | Whether to use FlashAttention. Set to True only if FlashAttention is installed. |
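
The two `method` options correspond to two common pooling strategies over a transformer's token outputs. A minimal NumPy sketch, assuming hidden states of shape (batch, seq_len, dim) with the CLS token at position 0 and an attention mask marking real tokens; `pool_embeddings` is hypothetical, and whether SpatialFormer excludes the CLS or other special tokens from the gene mean is an assumption here:

```python
import numpy as np


def pool_embeddings(hidden, attention_mask, method="cls"):
    """hidden: (B, L, D) token states; attention_mask: (B, L), 1 = real token, 0 = padding."""
    if method == "cls":
        # CLS token embedding as the cell representation.
        return hidden[:, 0, :]
    if method == "gene":
        # Mean of gene-token embeddings: skip position 0 (CLS) and ignore padding.
        mask = attention_mask[:, 1:, None].astype(hidden.dtype)
        summed = (hidden[:, 1:, :] * mask).sum(axis=1)
        counts = np.clip(mask.sum(axis=1), 1, None)  # avoid division by zero
        return summed / counts
    raise ValueError('method must be "cls" or "gene"')


hidden = np.arange(2 * 4 * 3, dtype=float).reshape(2, 4, 3)
attention_mask = np.array([[1, 1, 1, 0], [1, 1, 1, 1]])  # first cell has one padding token
cls_emb = pool_embeddings(hidden, attention_mask, "cls")
gene_emb = pool_embeddings(hidden, attention_mask, "gene")
```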
If the input is a Hugging Face dataset, use our dedicated Hugging Face dataloader (inference only):

```python
import json
import os

import numpy as np
import torch
from datasets import load_from_disk, concatenate_datasets, load_dataset
from tqdm import tqdm

# manual_train_fm and create_single_data_loaders are defined in the SpatialFormer
# codebase; import them from there before running this snippet.

SPATIALFORMER_ROOT = "/scratch/project_465001820/Spatialformer"  # change to your local clone


def load_model(model_ckp_path, device):
    config_path = os.path.join(SPATIALFORMER_ROOT, "config", "_config_train_large_pair.json")
    with open(config_path, "r") as json_file:
        config = json.load(json_file)
    model = manual_train_fm(config=config)
    ckp = torch.load(model_ckp_path, map_location=torch.device(device))
    model.load_state_dict(ckp["state_dict"])
    model.eval()
    model.to(device)
    return model


model_ckp_path = "/scratch/project_465001027/Spatialformer/output/checkpoints/step=0104000-train_total_loss=-2.3064-val_total_loss=0.0000.ckpt"
model = load_model(model_ckp_path, "cuda")

batch_size = 16  # adjust to your GPU memory
dataloader = create_single_data_loaders(lung_dataset,  # define your own dataset here
                                        cls_token=1,
                                        padding_idx=0,
                                        sep_token=1949,
                                        batch_size=batch_size,
                                        context_length=500,
                                        special_token_num=4,
                                        split_num=1,
                                        num_workers=64,
                                        mode="eval")

all_embeds = []
with torch.no_grad():
    for batch in tqdm(dataloader):
        # Per-cell metadata carried by the batch, useful for downstream analysis
        tissues = batch["Tissues"]
        conditions = batch["Conditions"]
        anns = batch["Annotations"]
        embeddings, _ = model.get_embeddings(batch, [-1], True, False)  # normal prob
        # Keep the first (CLS) token of each sequence as the cell embedding
        embeddings = embeddings[0][:, 0, :].detach().cpu().numpy()
        all_embeds.append(embeddings)

all_embeds = np.concatenate(all_embeds, axis=0)  # one row per cell
```

The model can be further pretrained with `script/train.py`, as shown below.
For the configuration parameters, refer to the table.
Pretrain the single-input model:

```bash
python ./script/train.py --config /scratch/project_465001820/Spatialformer/config/_config_train_large_single.json
```

Pretrain the doublet (pairwise) input model:

```bash
python ./script/train.py --config /scratch/project_465001820/Spatialformer/config/_config_train_large_pair.json
```

For each slide, accurate prediction of molecular features relies largely on cell-cell colocalization. We use LoRA to fine-tune the SpatialFormer model on a single slide.
We also provide a Google Colab notebook for practice (stable and strongly recommended).

```bash
python cell_cell_communication_zero_shot_multi_platform.py \
    --radius 30 \
    --fine_tune_mode lora \
    --rank 64 \
    --lora_alpha 128 \
    --cell_by_gene_path /scratch/project_465001820/Spatialformer_main_practice/data/MERFISH_Lung/HumanLungCancerPatient1_cell_by_gene.csv \
    --cell_meta_path /scratch/project_465001820/Spatialformer_main_practice/data/MERFISH_Lung/HumanLungCancerPatient1_cell_metadata.csv \
    --sample_name MERFISH_Lung \
    --zero_shot_cell_size 500 \
    --tissue Lung \
    --condition Disease \
    --config_path /scratch/project_465001820/Spatialformer/spatialformer/config/_config_fine_tune_probe.json \
    --batch_size 32 \
    --max_cells 10000
```

All code for reproducing the results in the manuscript is in the `./downstream` directory.
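The `--rank 64 --lora_alpha 128` flags in the fine-tuning command follow the standard LoRA formulation: a frozen weight W is adapted as W + (alpha / r) * B @ A, where A has shape (r, d_in), B has shape (d_out, r), and B starts at zero so training begins exactly from the pretrained model. A minimal NumPy sketch of one LoRA-adapted linear layer, illustrating the general technique rather than SpatialFormer's own implementation; the layer size is hypothetical:

```python
import numpy as np


class LoRALinear:
    """Frozen linear layer W plus a trainable low-rank update scaled by alpha / r."""

    def __init__(self, W, rank=64, lora_alpha=128, seed=0):
        rng = np.random.default_rng(seed)
        d_out, d_in = W.shape
        self.W = W  # pretrained weight, kept frozen
        self.A = rng.normal(0.0, 0.01, size=(rank, d_in))  # trainable down-projection
        self.B = np.zeros((d_out, rank))  # trainable up-projection, zero-initialized
        self.scaling = lora_alpha / rank  # 128 / 64 = 2.0 for the flags above

    def __call__(self, x):
        # x: (batch, d_in) -> (batch, d_out)
        return x @ self.W.T + self.scaling * (x @ self.A.T) @ self.B.T

    def trainable_params(self):
        return self.A.size + self.B.size


rng = np.random.default_rng(1)
W = rng.normal(size=(256, 256))
layer = LoRALinear(W, rank=64, lora_alpha=128)
x = rng.normal(size=(4, 256))
y = layer(x)  # equals x @ W.T until B is trained
```

Only A and B are updated during fine-tuning, which is why a single slide is enough data to adapt the model without overwriting the pretrained weights.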
Wang J, Huang Y, Winther O. SpatialFormer: Universal Spatial Representation Learning from Subcellular Molecular to Multicellular Landscapes[J]. bioRxiv, 2025: 2025.01.18.633701.

