Paper: SiT: Exploring Flow and Diffusion-based Generative Models with Scalable Interpolant Transformers (ECCV 2024)
Official Repo: https://github.com/willisma/SiT
SiT is a family of generative models built on Diffusion Transformers (DiT) that uses an interpolant framework to flexibly connect two distributions.
Key features:
- Transformer-based architecture with VAE latent space encoding
- Flexible interpolant paths (Linear, GVP, VP)
- Multiple prediction targets (velocity, noise, score)
- Classifier-Free Guidance (CFG) support
- Exponential Moving Average (EMA) for stable generation
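As a rough illustration of the interpolant idea (a minimal sketch, not the repo's API; `linear_interpolant` is a hypothetical helper): under the Linear path, a noise sample `x0` and a data latent `x1` are connected by x_t = (1 − t)·x0 + t·x1, and the model learns to predict the path's constant velocity x1 − x0.

```python
import torch

def linear_interpolant(x0, x1, t):
    """Linear path x_t = (1 - t) * x0 + t * x1 between noise x0 and data x1.
    Its time derivative, the velocity target, is x1 - x0."""
    t = t.view(-1, 1, 1, 1)          # broadcast t over (B, C, H, W)
    xt = (1 - t) * x0 + t * x1
    velocity = x1 - x0               # dx_t/dt for the Linear path
    return xt, velocity

# Toy usage on random 4-channel latents
x0 = torch.randn(2, 4, 32, 32)       # noise endpoint
x1 = torch.randn(2, 4, 32, 32)       # data latent endpoint
xt, v = linear_interpolant(x0, x1, torch.rand(2))
```

At t = 0 the path returns the noise sample, at t = 1 the data latent; GVP and VP paths replace the linear coefficients with time-dependent schedules.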
```bash
# Install SiT dependencies
uv pip install -e ".[sit]"

# Optional performance optimizations
uv pip install flash-attn --no-build-isolation
uv pip install liger-kernel
```

Your dataset should contain:
- `image`: Image data (PIL Image or tensor)
- `label`: Class label (integer, e.g., 0-999 for ImageNet)
Supported formats: Arrow, Parquet, HuggingFace Dataset
Edit sit_xl_2.yaml and set your dataset path:

```yaml
dataset_config:
  dataset_type: "sit"
  dataset_format: "hf_dataset"
  dataset_path: ILSVRC/imagenet-1k
```

```bash
export HF_TOKEN="your_huggingface_token"   # Required for VAE download
export HF_HOME="$HOME/.cache/huggingface"  # Optional: cache directory
```

Single Node (8 GPUs):

```bash
bash examples/scalable_interpolant_transformer/run.sh
```

Custom GPU Count:

```bash
NPROC_PER_NODE=4 bash examples/scalable_interpolant_transformer/run.sh
```

Multi-Node Training (Master Node):

```bash
NNODES=4 NODE_RANK=0 MASTER_ADDR=192.168.1.100 bash examples/scalable_interpolant_transformer/run.sh
```

Multi-Node Training (Worker Nodes):

```bash
NNODES=4 NODE_RANK=1 MASTER_ADDR=192.168.1.100 bash examples/scalable_interpolant_transformer/run.sh
```

The default configuration uses SiT-XL/2 (~675M parameters). To use other variants, modify the model configuration in sit_xl_2.yaml:
| Model | hidden_size | depth | num_heads | patch_size | Params | GFLOPs |
|---|---|---|---|---|---|---|
| SiT-S/2 | 384 | 12 | 6 | 2 | ~33M | ~6 |
| SiT-B/2 | 768 | 12 | 12 | 2 | ~130M | ~23 |
| SiT-L/2 | 1024 | 24 | 16 | 2 | ~458M | ~80 |
| SiT-XL/2 | 1152 | 28 | 16 | 2 | ~675M | ~119 |
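The parameter counts in the table can be sanity-checked with a back-of-the-envelope estimate (a sketch, not the repo's code): each DiT/SiT-style block carries roughly 4h² attention weights, 8h² MLP weights (ratio 4.0), and about 6h² for the adaLN modulation, i.e. about 18h² per block.

```python
def approx_params(hidden_size, depth):
    """Rough per-block estimate for DiT/SiT-style transformer blocks:
    attention ~4h^2 + MLP (ratio 4) ~8h^2 + adaLN modulation ~6h^2."""
    return depth * 18 * hidden_size ** 2

variants = {"S/2": (384, 12), "B/2": (768, 12), "L/2": (1024, 24), "XL/2": (1152, 28)}
for name, (h, d) in variants.items():
    print(f"SiT-{name}: ~{approx_params(h, d) / 1e6:.0f}M")
```

This lands within a few percent of the table (patch embedding, label embedding, and the final layer account for the remainder).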
SiT-S/2 Configuration

```yaml
model_config:
  load_from_config:
    model_type: "sit"
    input_size: 32
    patch_size: 2
    in_channels: 4
    hidden_size: 384  # S model
    depth: 12         # S model
    num_heads: 6      # S model
    mlp_ratio: 4.0
    class_dropout_prob: 0.1
    num_classes: 1000
    learn_sigma: true
    vae_path: "stabilityai/sd-vae-ft-ema"
    path_type: "Linear"
    prediction: "velocity"
    cfg_scale: 1.0
```

SiT-B/2 Configuration
```yaml
model_config:
  load_from_config:
    model_type: "sit"
    input_size: 32
    patch_size: 2
    in_channels: 4
    hidden_size: 768  # B model
    depth: 12         # B model
    num_heads: 12     # B model
    mlp_ratio: 4.0
    class_dropout_prob: 0.1
    num_classes: 1000
    learn_sigma: true
    vae_path: "stabilityai/sd-vae-ft-ema"
    path_type: "Linear"
    prediction: "velocity"
    cfg_scale: 1.0
```

SiT-L/2 Configuration
```yaml
model_config:
  load_from_config:
    model_type: "sit"
    input_size: 32
    patch_size: 2
    in_channels: 4
    hidden_size: 1024  # L model
    depth: 24          # L model
    num_heads: 16      # L model
    mlp_ratio: 4.0
    class_dropout_prob: 0.1
    num_classes: 1000
    learn_sigma: true
    vae_path: "stabilityai/sd-vae-ft-ema"
    path_type: "Linear"
    prediction: "velocity"
    cfg_scale: 1.0
```

Adjust `per_device_train_batch_size` based on your GPU memory:
| GPU | SiT-S/2 | SiT-B/2 | SiT-L/2 | SiT-XL/2 |
|---|---|---|---|---|
| A100 40GB | 32 | 16 | 8 | 4 |
| A100 80GB | 64 | 32 | 16 | 8 |
| H100 80GB | 128 | 64 | 32 | 16 |
If you encounter OOM errors, reduce the batch size and increase `gradient_accumulation_steps` proportionally to maintain the effective batch size.
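The arithmetic behind that advice can be sketched as follows (`effective_batch_size` is an illustrative helper, not a function in the repo):

```python
def effective_batch_size(per_device, grad_accum_steps, num_gpus):
    """Effective (global) batch size seen by the optimizer per update."""
    return per_device * grad_accum_steps * num_gpus

# Example: OOM at 32 samples/GPU on 8 GPUs? Halve the per-device batch
# and double gradient accumulation to keep the same effective batch.
baseline = effective_batch_size(32, 1, 8)  # 256
fallback = effective_batch_size(16, 2, 8)  # still 256
```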
- `input_size`: Latent space size (image_size = input_size × 8; default: 32 → 256px)
- `patch_size`: Patch size for tokenization (2, 4, or 8). Smaller = finer detail but more compute
- `hidden_size`: Transformer hidden dimension (384/768/1024/1152 for S/B/L/XL)
- `depth`: Number of transformer blocks (12/12/24/28 for S/B/L/XL)
- `num_heads`: Number of attention heads
- `mlp_ratio`: MLP expansion ratio (default: 4.0)
- `path_type`: Interpolant path type
  - `Linear`: Linear interpolation (simplest, recommended, used in the paper)
  - `GVP`: Generalized Variance Preserving
  - `VP`: Variance Preserving
- `prediction`: Model prediction target
  - `velocity`: Velocity field (recommended, default)
  - `noise`: Noise prediction
  - `score`: Score function
- `loss_weight`: Optional loss weighting scheme
- `train_eps`: Training epsilon for numerical stability (optional)
- `sample_eps`: Sampling epsilon for numerical stability (optional)
- `num_classes`: Number of classes (1000 for ImageNet-1K)
- `class_dropout_prob`: Class dropout probability for CFG training (default: 0.1)
- `cfg_scale`: Classifier-Free Guidance scale during inference (>1.0 enables CFG; default: 1.0)
- `vae_path`: Pre-trained VAE model (default: "stabilityai/sd-vae-ft-ema"); images are encoded to a 4-channel latent space (32×32 for 256×256 images)
- EMA Model: Automatically maintains exponential moving average model (decay=0.9999) for improved sample quality
- VAE Encoding: Images are encoded to 4-channel latent space using Stable Diffusion VAE
- CFG Support: Classifier-Free Guidance with configurable dropout and scale
- FSDP2: Fully Sharded Data Parallel for large-scale distributed training
- Transport Framework: Flexible interpolant paths and prediction targets
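The EMA update amounts to the standard exponential moving average over parameters (a sketch with a toy module, not the trainer's actual implementation):

```python
import copy
import torch
import torch.nn as nn

@torch.no_grad()
def ema_update(ema_model, model, decay=0.9999):
    """EMA step: ema_p <- decay * ema_p + (1 - decay) * p."""
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.mul_(decay).add_(p, alpha=1 - decay)

model = nn.Linear(8, 8)
ema_model = copy.deepcopy(model).requires_grad_(False)
# ... after each optimizer step:
ema_update(ema_model, model, decay=0.9999)
```

With decay = 0.9999 the EMA weights lag the raw weights by roughly 10k steps' worth of averaging, which is what smooths out sample quality.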
Checkpoints are saved to `{output_dir}/checkpoint-{step}/`:

```
output/sit_xl_2_training/
├── checkpoint-1000/
│   ├── model.safetensors   # Main model weights
│   ├── ema.pt              # EMA model weights
│   └── trainer_state.json  # Training state
├── checkpoint-2000/
└── ...
```
Both main and EMA models are saved for proper resumption.
- Reduce `per_device_train_batch_size`
- Increase `gradient_accumulation_steps` to maintain the effective batch size
- Use a smaller model variant (e.g., SiT-L/2, SiT-B/2, or SiT-S/2)
- Use a larger `patch_size` (4 or 8 instead of 2)

- Increase `dataloader_num_workers` (default: 4; try 8 or 16)
- Enable `tf32: true` on A100+ GPUs (faster compute)
- Use `bf16: true` for mixed-precision training
- Consider a larger `patch_size` for faster training (tradeoff: lower quality)
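The `tf32` option corresponds to PyTorch's global TF32 switches; setting them manually looks like this (the flags can be set even without a GPU present):

```python
import torch

# Allow TF32 for matmul and cuDNN kernels (Ampere or newer GPUs).
# Faster than full FP32 with a small precision loss; ignored on older hardware.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
```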
```bash
# Install all SiT dependencies
pip install lmms_engine[sit]

# Or install individually
pip install timm diffusers torchdiffeq
```

If you see an "Install with: pip install lmms_engine[sit]" error, the SiT optional dependencies are missing. Run:

```bash
uv pip install -e ".[sit]"
```

Results from the paper on ImageNet 256×256:
| Model | FID-50K ↓ | sFID ↓ | Inception Score ↑ | Precision ↑ |
|---|---|---|---|---|
| DiT-XL(cfg = 1.5) | 2.27 | 4.60 | 278.24 | 0.83 |
| SiT-XL(cfg = 1.5, ODE) | 2.15 | 4.60 | 258.09 | 0.81 |
| SiT-XL(cfg = 1.5, SDE) | 2.06 | 4.49 | 277.50 | 0.83 |
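The ODE variant in the table integrates the learned velocity field from noise (t = 0) to data (t = 1). A fixed-step Euler sketch of that idea (the repo uses torchdiffeq-based solvers; `velocity_fn` stands in for the trained model):

```python
import torch

def euler_ode_sample(velocity_fn, x0, num_steps=50):
    """Integrate dx/dt = v(x, t) from t=0 (noise) to t=1 (data)
    with fixed-step Euler. velocity_fn stands in for the trained SiT."""
    x = x0
    dt = 1.0 / num_steps
    for i in range(num_steps):
        t = torch.full((x.shape[0],), i * dt)
        x = x + dt * velocity_fn(x, t)
    return x

# Toy check with a known field v(x, t) = -x, whose exact solution
# at t = 1 is x0 * exp(-1).
x0 = torch.randn(2, 4, 32, 32)
x1 = euler_ode_sample(lambda x, t: -x, x0, num_steps=200)
```

The SDE variant adds a stochastic term at each step, which the table shows trading a slightly different FID/IS balance.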
- Paper: SiT: Exploring Flow and Diffusion-based Generative Models with Scalable Interpolant Transformers
- Official GitHub: https://github.com/willisma/SiT
- Project Website: https://scalable-interpolant.github.io/
- VAE: Stable Diffusion VAE
- Base Architecture: Built on Diffusion Transformers (DiT)
```bibtex
@inproceedings{ma2024sit,
  title={SiT: Exploring Flow and Diffusion-based Generative Models with Scalable Interpolant Transformers},
  author={Ma, Nanye and Goldstein, Mark and Albergo, Michael and Boffi, Nicholas and Vanden-Eijnden, Eric and Xie, Saining},
  booktitle={European Conference on Computer Vision (ECCV)},
  year={2024}
}
```