PyTorch implementation of scCello, a cell-ontology guided transcriptome foundation model (TFM) for single-cell RNA-seq data. Authored by Xinyu Yuan and Zhihao Zhan.
scCello enhances transcriptome foundation models (TFMs) by integrating cell ontology graphs into pre-training, addressing the limitation of existing TFMs that treat cells as independent entities. By incorporating two cell-level objectives, a cell-type coherence loss and an ontology alignment loss, scCello demonstrates superior or competitive generalization and transferability compared with existing TFMs on biologically important tasks, including identifying novel cell types among unseen cells, predicting cell-type-specific marker genes, and predicting cancer drug responses.
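As a rough intuition for the cell-type coherence objective, the sketch below scores one cell embedding against one prototype vector per cell type by cosine similarity and applies a softmax cross-entropy toward the cell's annotated type. This is a minimal illustration that assumes a prototype-based formulation (suggested by the PrototypeContrastive model class name); the prototypes, the temperature `tau`, and the function names are hypothetical and are not taken from the scCello code.

```python
import math

def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv)

def coherence_loss(cell_emb, prototypes, label, tau=0.1):
    """Hypothetical prototype-based cross-entropy for a single cell.

    cell_emb:   embedding of one cell
    prototypes: dict mapping cell-type name -> prototype vector (assumed learnable)
    label:      annotated cell type of this cell
    tau:        temperature (assumed hyperparameter)
    """
    logits = {ct: cosine(cell_emb, p) / tau for ct, p in prototypes.items()}
    # Numerically stable log-sum-exp over all cell types.
    m = max(logits.values())
    log_z = m + math.log(sum(math.exp(l - m) for l in logits.values()))
    # Negative log-probability of the annotated cell type.
    return log_z - logits[label]

prototypes = {"T cell": [1.0, 0.0], "B cell": [0.0, 1.0]}
cell = [0.9, 0.1]
print(coherence_loss(cell, prototypes, "T cell"))  # small: cell is close to its own prototype
print(coherence_loss(cell, prototypes, "B cell"))  # large: wrong prototype
```

Minimizing a loss of this shape pulls each cell toward the prototype of its own type and away from the others, which is one way to make representations of the same cell type coherent.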
This repository is based on PyTorch 2.0 and Python 3.9.
Key features:
- Cell-type Specific Learning: Uses a cell-type coherence loss to learn gene expression patterns specific to each cell type.
- Ontology-aware Modeling: Uses an ontology alignment loss to preserve the hierarchical relationships among cell types.
- Large-scale Pre-training: Pre-trained on over 22 million cells from the CellxGene database to support robust, generalizable representations.
- Advanced Generalization and Transferability: Demonstrates superior or competitive performance on biologically significant tasks such as identifying novel cell types and predicting cell-type-specific marker genes.
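To make "hierarchical relationships among cell types" concrete, the sketch below measures how close two cell types are in a toy ontology by counting hops along is_a edges (treated as undirected) with breadth-first search. An ontology-aware objective can then encourage ontologically closer types to have more similar representations. The toy graph, the hop-count measure, and the 1/(1+d) similarity are illustrative assumptions; scCello's actual ontology similarity may be computed differently.

```python
from collections import deque

# Toy is_a edges (child -> parent); hypothetical, not the real Cell Ontology.
IS_A = {
    "CD4 T cell": "T cell",
    "CD8 T cell": "T cell",
    "T cell": "lymphocyte",
    "B cell": "lymphocyte",
    "lymphocyte": "leukocyte",
    "monocyte": "leukocyte",
}

def build_undirected(edges):
    """Adjacency sets with each is_a edge stored in both directions."""
    graph = {}
    for child, parent in edges.items():
        graph.setdefault(child, set()).add(parent)
        graph.setdefault(parent, set()).add(child)
    return graph

def hop_distance(graph, src, dst):
    """Shortest number of edges between two cell types, via BFS."""
    seen = {src}
    queue = deque([(src, 0)])
    while queue:
        node, d = queue.popleft()
        if node == dst:
            return d
        for nxt in graph[node]:
            if nxt not in seen:
                seen.add(nxt)
                queue.append((nxt, d + 1))
    return None  # the two types are not connected

def ontology_similarity(graph, a, b):
    """Hypothetical similarity: 1.0 for identical types, decaying with distance."""
    d = hop_distance(graph, a, b)
    return 0.0 if d is None else 1.0 / (1.0 + d)

g = build_undirected(IS_A)
print(hop_distance(g, "CD4 T cell", "CD8 T cell"))  # → 2 (siblings under "T cell")
print(hop_distance(g, "CD4 T cell", "monocyte"))    # → 4
```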
News:
- Feb 5th, 2025: scCello code released!
- Oct 1st, 2024: scCello was accepted at NeurIPS 2024!
- Aug 22nd, 2024: scCello preprint released on arXiv!
You may install the dependencies via the following bash commands.
conda install pytorch==2.0.1 pytorch-cuda=11.7 -c pytorch -c nvidia
pip install "transformers[torch]"
pip install easydict
pip install psutil
pip install wandb
pip install pytz
pip install ipdb
pip install pandas
pip install datasets
pip install torchmetrics
pip install rdflib
pip install hickle
pip install anndata
pip install scikit-learn
pip install scanpy
pip install scib
conda install -c conda-forge cupy
conda install rapidsai::cuml
conda install -c rapidsai -c conda-forge -c nvidia cugraph
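After installing, a quick way to confirm that the main Python packages are importable is to look them up with the standard library's importlib. The module list below is an assumption mapped from the install commands above (some import names differ from their pip names, e.g. scikit-learn is imported as sklearn); GPU packages such as cupy, cuml, and cugraph will only be found in an environment where they were installed with a compatible CUDA setup.

```python
import importlib.util

# Import names (not pip names) for the packages installed above.
REQUIRED = [
    "torch", "transformers", "easydict", "psutil", "wandb", "pytz", "ipdb",
    "pandas", "datasets", "torchmetrics", "rdflib", "hickle", "anndata",
    "sklearn", "scanpy", "scib", "cupy", "cuml", "cugraph",
]

def missing_modules(names):
    """Return the subset of top-level module names that cannot be found."""
    return [name for name in names if importlib.util.find_spec(name) is None]

if __name__ == "__main__":
    missing = missing_modules(REQUIRED)
    print("All packages found." if not missing else f"Missing: {', '.join(missing)}")
```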
Quick start guide to loading a scCello checkpoint:
- For zero-shot inference tasks:
from sccello.src.model_prototype_contrastive import PrototypeContrastiveForMaskedLM

model = PrototypeContrastiveForMaskedLM.from_pretrained("katarinayuan/scCello-zeroshot", output_hidden_states=True)
- For linear probing tasks (see sccello/script/run_cell_type_classification.py for details):
from sccello.src.model_prototype_contrastive import PrototypeContrastiveForSequenceClassification

# NUM_LABELS is the number of classes; training_cfg and training_args come from
# the training setup in run_cell_type_classification.py
model_kwargs = {
    "num_labels": NUM_LABELS,  # number of labels for classification
    "total_logging_steps": training_cfg["logging_steps"] * training_args.gradient_accumulation_steps,
}
model = PrototypeContrastiveForSequenceClassification.from_pretrained("katarinayuan/scCello-zeroshot", **model_kwargs)
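With output_hidden_states=True, the zero-shot model returns per-token hidden states, which have to be reduced to one vector per cell. A common choice, assumed here rather than confirmed as what the scCello scripts do, is a mean over the non-padding tokens of the last layer. The sketch below shows that masked mean on plain nested lists so the shapes are explicit: `hidden` is [batch][tokens][dim] and `attention_mask` is [batch][tokens].

```python
def masked_mean_pool(hidden, attention_mask):
    """Average token vectors per cell, ignoring padded positions.

    hidden:         [batch][tokens][dim] nested lists of floats
    attention_mask: [batch][tokens] with 1 for real tokens, 0 for padding
    returns:        [batch][dim] cell embeddings
    """
    pooled = []
    for tokens, mask in zip(hidden, attention_mask):
        dim = len(tokens[0])
        total = [0.0] * dim
        count = 0
        for vec, keep in zip(tokens, mask):
            if keep:
                total = [t + v for t, v in zip(total, vec)]
                count += 1
        # Guard against an all-padding row.
        pooled.append([t / max(count, 1) for t in total])
    return pooled

# Two cells, three token positions, 2-d hidden states; the second cell has one padded token.
hidden = [
    [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
    [[2.0, 0.0], [4.0, 2.0], [9.0, 9.0]],
]
mask = [[1, 1, 1], [1, 1, 0]]
print(masked_mean_pool(hidden, mask))  # → [[3.0, 4.0], [3.0, 1.0]]
```

Assuming the model returns a standard Hugging Face output object, the same reduction in PyTorch is `h = outputs.hidden_states[-1]` followed by `(h * mask.unsqueeze(-1)).sum(dim=1) / mask.sum(dim=1, keepdim=True)`, where `mask` is the batch's attention mask.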
For downstream tasks, the in-distribution (ID) data D^{id} is a held-out split of the pre-training data, and the datasets D_1 and D_2 are grouped by cell type (ct), tissue (ts), and donor (dn). Some datasets are extremely large; to change the data caching directory (default: "~/.cache/huggingface/datasets/"), run the following in your shell before starting Python:
export HF_HOME="/path/to/another/directory/datasets"
Then load the data in Python:
from datasets import load_dataset
from sccello.src.utils import data_loading

# pre-training data & D^{id}
train_dataset = load_dataset("katarinayuan/scCello_pretrain_unsplitted")["train"]
# train_test_split returns a DatasetDict, so index its "train" and "test" splits
splits = train_dataset.train_test_split(test_size=0.001, seed=237)  # seed used in scCello
train_dataset, indist_test_data = splits["train"], splits["test"]
# D_1^{ct} & D_2^{ct}
d1_ct, d2_ct = data_loading.get_fracdata("celltype", "frac100", False, False)
# D_1^{ts} & D_2^{ts}
d1_ts, d2_ts = data_loading.get_fracdata("tissue", "frac100", False, False)
# D_1^{dn} & D_2^{dn}
d1_dn, d2_dn = data_loading.get_fracdata("donor", "frac100", False, False)
To build the pre-training and downstream datasets, we downloaded a series of human h5ad files from CellxGene. Example data for converting h5ad files to Hugging Face format can be downloaded as follows:
pip install gdown
cd ./data/example_h5ad/
gdown "https://drive.google.com/uc?id=1UsbkhmZwSDWTgY4die60fHvzL_FnXtWE"
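Converting h5ad files turns raw expression counts into token sequences. Assuming a Geneformer-style rank-value encoding (an assumption here; see sccello/script/run_data_transformation.py for what the repository actually does), the sketch below tokenizes a single cell: depth-normalize counts, divide each gene by a per-gene median, drop zeros, and order gene tokens from highest to lowest normalized value. The gene names, medians, token ids, the 10,000 scaling factor, and `max_len` are all hypothetical.

```python
def rank_value_encode(counts, gene_medians, token_ids, max_len=2048):
    """Hypothetical rank-value tokenization of one cell.

    counts:       dict gene -> raw count for this cell
    gene_medians: dict gene -> median expression of that gene across a reference corpus
    token_ids:    dict gene -> integer token id
    max_len:      maximum sequence length (assumed)
    """
    total = sum(counts.values())
    scored = []
    for gene, c in counts.items():
        if c > 0 and gene in gene_medians and gene in token_ids:
            # Normalize to counts per 10k (assumed factor), then by the gene's median.
            value = (c / total) * 10_000 / gene_medians[gene]
            scored.append((value, gene))
    # Highest normalized expression first; ties broken by gene name for determinism.
    scored.sort(key=lambda x: (-x[0], x[1]))
    return [token_ids[g] for _, g in scored[:max_len]]

counts = {"GENE_A": 10, "GENE_B": 30, "GENE_C": 0, "GENE_D": 60}
medians = {"GENE_A": 50.0, "GENE_B": 3000.0, "GENE_C": 1.0, "GENE_D": 2000.0}
tokens = {"GENE_A": 101, "GENE_B": 102, "GENE_C": 103, "GENE_D": 104}
print(rank_value_encode(counts, medians, tokens))  # → [101, 104, 102]
```

Dividing by a per-gene median keeps ubiquitously high genes (for example, housekeeping genes) from dominating the top of every cell's ranking, so the order reflects what is unusually expressed in that cell.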
The sccello/script folder contains all executable files.
General configurations:
pretrained_ckpt=katarinayuan/scCello-zeroshot
# set output_dir to a writable directory of your own
output_dir=/home/xinyu402/single_cell_output/
wandb_run_name=test
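Since most scripts take the same general options, it can be convenient to assemble the commands in one place. The following is a minimal, hypothetical driver (not part of the repository) that builds argument lists from the configuration values above and prints them in dry-run mode; it uses only flags that appear in this README's commands, and the output directory is a placeholder.

```python
import shlex
import subprocess

CONFIG = {
    "pretrained_ckpt": "katarinayuan/scCello-zeroshot",
    "output_dir": "/path/to/output_dir/",  # placeholder: replace with your own directory
    "wandb_run_name": "test",
}

def build_command(script, extra=None, use_output_dir=True):
    """Return an argv list for one of the scripts in sccello/script."""
    cmd = ["python", f"./sccello/script/{script}",
           "--pretrained_ckpt", CONFIG["pretrained_ckpt"],
           "--wandb_run_name", CONFIG["wandb_run_name"]]
    if use_output_dir:
        cmd += ["--output_dir", CONFIG["output_dir"]]
    cmd += list(extra or [])
    return cmd

def run(cmd, dry_run=True):
    """Print the shell-quoted command; execute it only when dry_run is False."""
    print(shlex.join(cmd))
    if not dry_run:
        subprocess.run(cmd, check=True)

run(build_command("run_cell_type_clustering.py"))
run(build_command("run_marker_gene_prediction.py"))
# The cancer drug response command in this README does not pass --output_dir.
run(build_command("run_cancer_drug_response.py", use_output_dir=False))
```

Passing an argument list (rather than one shell string) to `subprocess.run` avoids quoting problems with paths that contain spaces.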
Data transformation (h5ad to Hugging Face format):
python ./sccello/script/run_data_transformation.py
Cell type clustering:
python ./sccello/script/run_cell_type_clustering.py --pretrained_ckpt $pretrained_ckpt --wandb_run_name $wandb_run_name --output_dir $output_dir
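Clustering quality is typically scored by comparing cluster assignments with annotated cell types using metrics such as the adjusted Rand index (ARI); the dependencies above include scib, a benchmarking library that provides clustering metrics such as ARI and NMI. Which exact metrics the clustering script reports is not stated here, so treat the following as an illustration only: a dependency-free ARI computed from the contingency table of two labelings.

```python
from collections import Counter
from math import comb

def adjusted_rand_index(labels_true, labels_pred):
    """ARI between two labelings of the same cells (1.0 = identical partitions)."""
    n = len(labels_true)
    # Pairs of cells placed together in both labelings.
    pairs = Counter(zip(labels_true, labels_pred))
    index = sum(comb(c, 2) for c in pairs.values())
    sum_a = sum(comb(c, 2) for c in Counter(labels_true).values())
    sum_b = sum(comb(c, 2) for c in Counter(labels_pred).values())
    # Expected index under random labelings with the same cluster sizes.
    expected = sum_a * sum_b / comb(n, 2)
    max_index = (sum_a + sum_b) / 2
    if max_index == expected:  # degenerate case, e.g. a single cluster in both
        return 1.0
    return (index - expected) / (max_index - expected)

cell_types = ["T", "T", "B", "B", "NK", "NK"]
clusters = [0, 0, 1, 1, 2, 2]
print(adjusted_rand_index(cell_types, clusters))  # → 1.0
```

ARI ignores the actual cluster ids, so relabeling clusters (e.g. 0, 1, 2 to 5, 7, 9) leaves the score unchanged.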
Cell type classification:
# Linear probing
training_type=linear_probing
# or train from scratch without loading the pre-trained model
# training_type=from_scratch_linear
torchrun ./sccello/script/run_cell_type_classification.py --pretrained_ckpt $pretrained_ckpt --training_type $training_type --wandb_run_name $wandb_run_name --further_downsample 0.01 --output_dir $output_dir
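Linear probing keeps the pre-trained encoder fixed and trains only a linear classifier on top of its cell representations. The sketch below illustrates just that classifier: softmax regression trained by full-batch gradient descent on fixed embeddings. The toy embeddings, learning rate, and epoch count are hypothetical; the actual script trains through PrototypeContrastiveForSequenceClassification.

```python
import math

def train_linear_probe(embs, labels, num_labels, lr=0.5, epochs=200):
    """Softmax regression on frozen cell embeddings (full-batch gradient descent)."""
    dim = len(embs[0])
    W = [[0.0] * dim for _ in range(num_labels)]
    b = [0.0] * num_labels
    n = len(embs)
    for _ in range(epochs):
        gW = [[0.0] * dim for _ in range(num_labels)]
        gb = [0.0] * num_labels
        for x, y in zip(embs, labels):
            logits = [sum(w * xi for w, xi in zip(W[k], x)) + b[k] for k in range(num_labels)]
            m = max(logits)
            exps = [math.exp(l - m) for l in logits]
            z = sum(exps)
            for k in range(num_labels):
                # Gradient of cross-entropy w.r.t. logit k: p_k - 1[k == y]
                g = exps[k] / z - (1.0 if k == y else 0.0)
                gb[k] += g / n
                for j in range(dim):
                    gW[k][j] += g * x[j] / n
        # Only the linear head is updated; the embeddings stay fixed.
        for k in range(num_labels):
            b[k] -= lr * gb[k]
            for j in range(dim):
                W[k][j] -= lr * gW[k][j]
    return W, b

def predict(W, b, x):
    """Index of the highest-scoring class."""
    scores = [sum(w * xi for w, xi in zip(Wk, x)) + bk for Wk, bk in zip(W, b)]
    return max(range(len(scores)), key=scores.__getitem__)

# Toy "frozen embeddings" for two linearly separable cell types.
embs = [[1.0, 0.1], [0.9, 0.0], [0.1, 1.0], [0.0, 0.8]]
labels = [0, 0, 1, 1]
W, b = train_linear_probe(embs, labels, num_labels=2)
print([predict(W, b, x) for x in embs])  # → [0, 0, 1, 1]
```

Because only the head is trained, probe accuracy mainly reflects how linearly separable the cell types already are in the representation space.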
Novel cell type classification:
python ./sccello/script/run_novel_cell_type_classification.py --pretrained_ckpt $pretrained_ckpt --wandb_run_name $wandb_run_name --indist_repr_path ./embedding_storage/cellreprs_indist_frac_celltype_data1.pkl --output_dir $output_dir
Marker gene prediction:
python ./sccello/script/run_marker_gene_prediction.py --pretrained_ckpt $pretrained_ckpt --wandb_run_name $wandb_run_name --output_dir $output_dir
Cancer drug response:
python ./sccello/script/run_cancer_drug_response.py --pretrained_ckpt $pretrained_ckpt --wandb_run_name $wandb_run_name
Pre-training:
python -m torch.distributed.run --nproc_per_node=1 ./sccello/script/run_pretrain_prototype_contrastive.py --wandb_run_name pretrain_test
If you find this codebase useful in your research, please cite the original papers.
The main scCello paper:
@inproceedings{yuancell,
  title={Cell ontology guided transcriptome foundation model},
  author={Yuan, Xinyu and Zhan, Zhihao and Zhang, Zuobai and Zhou, Manqi and Zhao, Jianan and Han, Boyu and Li, Yue and Tang, Jian},
  booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
  year={2024}
}