Overview
DALLE2-pytorch is a from-scratch PyTorch implementation of DALL-E 2, OpenAI's text-to-image model. It lets researchers and engineers build and train the full pipeline themselves instead of relying on a closed API.
The key idea it reproduces is the diffusion prior: a network that takes a CLIP text embedding and predicts the matching CLIP image embedding. A separate diffusion decoder then turns that image embedding into an actual picture, with extra upsampling stages for higher resolution.
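To make the data flow concrete, here is a stdlib-only sketch of the generation path just described: text embedding to image embedding (prior), image embedding to a low-resolution image (base decoder), then upsampling stages. The functions `toy_prior`, `toy_decoder` and `upsample_2x` are hypothetical stand-ins, not DALLE2-pytorch APIs. They do no learning or denoising. In DALL-E 2 each stage is a diffusion model, and here plain nearest-neighbour resizing replaces the upsamplers. The embedding size (8) and base resolution (4) are arbitrary assumptions.

```python
from typing import List

Vector = List[float]
Image = List[List[float]]  # single-channel H x W grid, kept tiny for illustration


def toy_prior(text_embed: Vector) -> Vector:
    """Hypothetical stand-in for the diffusion prior: text embedding -> image embedding.

    A real prior runs a reverse diffusion process; this stub only returns a
    vector of the same dimensionality to show the interface between stages.
    """
    return [0.5 * x for x in text_embed]


def toy_decoder(image_embed: Vector, size: int) -> Image:
    """Hypothetical stand-in for the base decoder: image embedding -> low-res image."""
    mean = sum(image_embed) / len(image_embed)
    return [[mean for _ in range(size)] for _ in range(size)]


def upsample_2x(image: Image) -> Image:
    """Nearest-neighbour 2x upsampling, standing in for one upsampler stage."""
    out: Image = []
    for row in image:
        wide = [v for v in row for _ in range(2)]  # repeat each pixel horizontally
        out.append(wide)
        out.append(list(wide))                     # repeat each row vertically
    return out


def generate(text_embed: Vector, base_size: int = 4, num_upsamplers: int = 2) -> Image:
    image_embed = toy_prior(text_embed)          # stage 1: prior
    image = toy_decoder(image_embed, base_size)  # stage 2: base decoder
    for _ in range(num_upsamplers):              # stage 3: cascade of upsamplers
        image = upsample_2x(image)
    return image


if __name__ == "__main__":
    image = generate([0.1 * i for i in range(8)])
    print(len(image), len(image[0]))  # 16 16
```

Each upsampler doubles the side length, so a 4x4 base output becomes 16x16 after two stages. This is the same reason the real decoder trains one U-Net per resolution.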
The project was central to community efforts (with the LAION group) to openly replicate DALL-E 2. It is built to scale, with distributed training that contributors have run across hundreds of GPUs.
What it does
- Full DALL-E 2 pipeline in PyTorch: CLIP encoder, diffusion prior, and image decoder, wrapped in a single DALLE2 class for text-to-image generation
- Diffusion prior network that maps CLIP text embeddings to CLIP image embeddings, the main contribution of the DALL-E 2 paper
- Cascading diffusion decoder: stack multiple U-Nets at increasing resolutions (for example 128 then 256 or 512) and train each stage separately
- Classifier-free guidance via a cond_scale argument to strengthen how closely an image follows its text prompt
- Train directly on preprocessed CLIP embeddings (image_embed, text_embed, text_encodings) to save compute when scaling up
- Integrates with the x-clip package so you can train or bring your own CLIP model
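Classifier-free guidance combines two predictions from the same network. One is made with the text condition and one with it dropped, and the output extrapolates away from the unconditional prediction. The sketch below shows that standard combination on plain lists. The function name `guided_prediction` is hypothetical, and treating `cond_scale` as exactly this guidance weight is an assumption rather than a quote of the library source.

```python
from typing import List


def guided_prediction(cond: List[float], uncond: List[float], cond_scale: float) -> List[float]:
    """Classifier-free guidance: uncond + cond_scale * (cond - uncond).

    cond_scale = 1.0 returns the conditional prediction unchanged,
    values above 1.0 push further in the direction the text suggests,
    and 0.0 returns the unconditional prediction.
    """
    if len(cond) != len(uncond):
        raise ValueError("predictions must have the same length")
    return [u + cond_scale * (c - u) for c, u in zip(cond, uncond)]


if __name__ == "__main__":
    cond = [1.0, 0.5]    # prediction with the text prompt
    uncond = [0.0, 0.5]  # prediction with the prompt dropped
    print(guided_prediction(cond, uncond, 2.0))  # [2.0, 0.5]
```

The unconditional prediction only exists if the network saw dropped conditions during training. That is the role of `cond_drop_prob` on the prior and of `image_cond_drop_prob` and `text_cond_drop_prob` on the decoder.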
Getting started
Install the package from PyPI, then train the three components in order: CLIP, the decoder, and the diffusion prior. Once they are trained, wrap the prior and decoder in a DALLE2 object to generate images from text. The examples below call .cuda(), so as written they need a CUDA-capable GPU.
Install
Install the library with pip.
```bash
pip install dalle2-pytorch
```

Train CLIP
CLIP is trained first and is the most important step. Feed it batches of text token ids and images, and set return_loss = True to get the contrastive loss to backpropagate.
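Conceptually, that contrastive loss treats each matching text and image in a batch as a positive pair and every other pairing as a negative. Below is a plain-Python sketch of the standard symmetric InfoNCE formulation applied to already-computed embeddings. The function names and the fixed `temperature` of 0.07 are illustrative assumptions. The original CLIP learns its temperature, and the library's CLIP class may differ in detail.

```python
import math
from typing import List

Vector = List[float]


def normalize(v: Vector) -> Vector:
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def cross_entropy(logits: Vector, target: int) -> float:
    """-log softmax(logits)[target], computed with the log-sum-exp trick."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum - logits[target]


def clip_contrastive_loss(text_embeds: List[Vector], image_embeds: List[Vector],
                          temperature: float = 0.07) -> float:
    """Symmetric InfoNCE loss: row i of each list is a matching text/image pair."""
    texts = [normalize(t) for t in text_embeds]
    images = [normalize(im) for im in image_embeds]
    n = len(texts)
    # sim[i][j] = cosine similarity of text i and image j, scaled by 1 / temperature
    sim = [[sum(a * b for a, b in zip(t, im)) / temperature for im in images] for t in texts]
    text_to_image = sum(cross_entropy(sim[i], i) for i in range(n)) / n
    image_to_text = sum(cross_entropy([sim[j][i] for j in range(n)], i) for i in range(n)) / n
    return (text_to_image + image_to_text) / 2


if __name__ == "__main__":
    texts = [[1.0, 0.0], [0.0, 1.0]]
    aligned = [[1.0, 0.1], [0.1, 1.0]]  # each image close to its own caption
    swapped = [[0.1, 1.0], [1.0, 0.1]]  # images paired with the wrong captions
    print(clip_contrastive_loss(texts, aligned) < clip_contrastive_loss(texts, swapped))  # True
```

In DALLE2-pytorch you do not write this yourself. Calling the CLIP model with return_loss = True returns its contrastive loss, as in the training step that follows.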
```python
import torch
from dalle2_pytorch import CLIP

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 6,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
).cuda()

# mock batch: random tensors standing in for 4 token sequences and 4 RGB 256x256 images
text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# compute the contrastive loss and backpropagate (a real loop also steps an optimizer)
loss = clip(text, images, return_loss = True)
loss.backward()
```

Train the diffusion prior and decoder
The diffusion prior learns to generate CLIP image embeddings from CLIP text embeddings. The decoder, which wraps one or more U-Nets, learns to turn those CLIP image embeddings into pictures. The code shows a single training step for each model; in practice you repeat it over many batches with an optimizer.
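Both models are trained as diffusion models. A clean target (an image embedding for the prior, an image for the decoder) is corrupted with Gaussian noise at a random timestep, and the network learns to undo the corruption. Below is a stdlib-only sketch of that forward noising step over `timesteps = 100`. The cosine noise schedule (Nichol and Dhariwal, 2021) is an assumption about the configuration, and the helper names are hypothetical. The sketch returns the clean embedding as the training target because the DALL-E 2 paper trains its prior to predict the unnoised embedding directly. Decoders often predict the added noise instead.

```python
import math
import random
from typing import List, Tuple

Vector = List[float]


def cosine_alpha_bars(timesteps: int, s: float = 0.008) -> List[float]:
    """alpha_bar for steps 1..timesteps under a cosine schedule (list index 0 = step 1)."""
    def f(t: float) -> float:
        return math.cos((t / timesteps + s) / (1 + s) * math.pi / 2) ** 2
    return [f(t) / f(0) for t in range(1, timesteps + 1)]


def q_sample(x0: Vector, t: int, alpha_bars: List[float], rng: random.Random) -> Tuple[Vector, Vector]:
    """Forward diffusion: x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise."""
    a = alpha_bars[t]
    noise = [rng.gauss(0.0, 1.0) for _ in x0]
    x_t = [math.sqrt(a) * x + math.sqrt(1.0 - a) * n for x, n in zip(x0, noise)]
    return x_t, noise


def prior_training_example(image_embed: Vector, timesteps: int,
                           rng: random.Random) -> Tuple[Vector, int, Vector]:
    """One (noisy input, timestep index, target) triple for a prior that predicts x0."""
    alpha_bars = cosine_alpha_bars(timesteps)
    t = rng.randrange(timesteps)  # random 0-based timestep index
    noisy, _ = q_sample(image_embed, t, alpha_bars, rng)
    # the network sees (noisy, t, text condition) and is trained to output image_embed
    return noisy, t, image_embed


if __name__ == "__main__":
    rng = random.Random(0)
    embed = [0.2, -0.1, 0.4, 0.0]  # stand-in for a CLIP image embedding
    noisy, t, target = prior_training_example(embed, timesteps=100, rng=rng)
    print(t, len(noisy), target == embed)
```

In the library, DiffusionPrior and Decoder perform this sampling internally, so a training step for each model reduces to a forward call and loss.backward():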
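```python
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder

# continues from the CLIP example: reuses `clip`, `text` and `images`

# diffusion prior: maps CLIP text embeddings to CLIP image embeddings
prior_network = DiffusionPriorNetwork(dim = 512, depth = 6, dim_head = 64, heads = 8).cuda()
diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = clip,
    timesteps = 100,
    cond_drop_prob = 0.2  # randomly drop the text condition so classifier-free guidance works at sampling time
).cuda()
loss = diffusion_prior(text, images)
loss.backward()

# decoder: a U-Net that generates images conditioned on the CLIP image embedding
unet = Unet(dim = 128, image_embed_dim = 512, cond_dim = 128, channels = 3, dim_mults = (1, 2, 4, 8)).cuda()
decoder = Decoder(unet = unet, clip = clip, timesteps = 100, image_cond_drop_prob = 0.1, text_cond_drop_prob = 0.5).cuda()
loss = decoder(images)
loss.backward()
```

Generate images from text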
Combine the trained prior and decoder in a DALLE2 object, then pass a list of text prompts to get back images.
```python
from dalle2_pytorch import DALLE2

# combine the trained prior and decoder from the previous step
dalle2 = DALLE2(prior = diffusion_prior, decoder = decoder)

images = dalle2(
    ['cute puppy chasing after a squirrel'],
    cond_scale = 2.  # classifier-free guidance strength (> 1 strengthens the text conditioning)
)
```

Commands and code are distilled from the project's own documentation; always check the official repo for the latest.
When to use it
- Reproducing and studying the DALL-E 2 architecture, especially how the diffusion prior bridges CLIP text and image embeddings
- Training a custom open text-to-image model on your own dataset instead of using a closed, paid API
- Experimenting with cascading diffusion decoders to generate higher-resolution images from a learned image embedding
- Plugging a trained diffusion prior into other CLIP-conditioned pipelines, such as CLIP-to-StyleGAN text-to-image applications
How DALLE2-pytorch compares
DALLE2-pytorch listed alongside other open-source image generation tools tracked by AI/TLDR, ranked by GitHub stars.
| Tool | Stars | What it does |
|---|---|---|
| Stable Diffusion web UI (AUTOMATIC1111) | ★ 164k | A browser interface for running Stable Diffusion image generation locally with extensions and fine-grained controls. |
| ComfyUI | ★ 118k | A node-based visual editor for building and running image and video generation pipelines like Stable Diffusion and FLUX locally. |
| Fooocus | ★ 50.4k | A simplified image generation app built on Stable Diffusion that hides technical settings for easy prompting. |
| InvokeAI | ★ 27.5k | A self-hosted creative tool and canvas for generating and editing images with open diffusion models. |
| Stability-AI generative-models | ★ 27.2k | Stability AI's official code for its Stable Diffusion family of image and video generation models. |
| FLUX | ★ 25.6k | Black Forest Labs' open-weight diffusion models and inference code for generating and editing images from text prompts. |
| Z-Image | ★ 11.6k | Alibaba Tongyi's 6B-parameter open image model that produces photorealistic images quickly on a single GPU. |
| DALLE2-pytorch | ★ 11.3k | PyTorch implementation of OpenAI's DALL-E 2 text-to-image model. |