AI/TLDR

DALLE2-pytorch

PyTorch implementation of OpenAI's DALL-E 2 text-to-image model

Overview

DALLE2-pytorch is a from-scratch PyTorch implementation of DALL-E 2, OpenAI's text-to-image model. It lets researchers and engineers build and train the full pipeline themselves instead of relying on a closed API.

The key idea it reproduces is the diffusion prior: a network that takes a CLIP text embedding and predicts the matching CLIP image embedding. A separate diffusion decoder then turns that image embedding into an actual picture, with extra upsampling stages for higher resolution.
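In the DALL-E 2 paper, the prior is trained in three moves. It adds noise to a CLIP image embedding, asks the network (given the text) to recover the clean embedding, and scores that prediction with a mean-squared-error loss. The plain-Python sketch below covers only the noising and the loss, with no network. The cosine noise schedule and the helper names (`cosine_alpha_bar`, `noise_embedding`, `prior_loss`) are illustrative assumptions, not DALLE2-pytorch's internals.

```python
import math
import random

def cosine_alpha_bar(t, timesteps, s=0.008):
    """Hypothetical helper: fraction of signal kept at step t under an assumed
    cosine schedule (t = 0 is clean, t = timesteps is almost pure noise)."""
    def f(u):
        return math.cos((u / timesteps + s) / (1 + s) * math.pi / 2) ** 2
    return f(t) / f(0)

def noise_embedding(x0, t, timesteps, rng):
    """Forward diffusion on one embedding: x_t = sqrt(a) * x0 + sqrt(1 - a) * eps."""
    a = cosine_alpha_bar(t, timesteps)
    return [math.sqrt(a) * x + math.sqrt(1 - a) * rng.gauss(0.0, 1.0) for x in x0]

def prior_loss(predicted_x0, x0):
    """Mean squared error between the predicted and the clean image embedding."""
    return sum((p - x) ** 2 for p, x in zip(predicted_x0, x0)) / len(x0)

rng = random.Random(0)
image_embed = [rng.gauss(0.0, 1.0) for _ in range(512)]  # stand-in for a CLIP image embedding
x_t = noise_embedding(image_embed, t=50, timesteps=100, rng=rng)
# A real prior network would predict image_embed from (x_t, t, text embedding);
# here we only score the noisy input itself as a naive baseline.
baseline = prior_loss(x_t, image_embed)
```

The network's task is to do better than this baseline by using the text embedding to infer which clean image embedding the noise was added to.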

The project was central to community efforts (with the LAION group) to openly replicate DALL-E 2. It is built to scale, with distributed training that contributors have run across hundreds of GPUs.

What it does

  • Full DALL-E 2 pipeline in PyTorch: CLIP encoder, diffusion prior, and image decoder, wrapped in a single DALLE2 class for text-to-image generation
  • Diffusion prior network that maps CLIP text embeddings to CLIP image embeddings, the main contribution of the DALL-E 2 paper
  • Cascading diffusion decoder: stack multiple U-Nets at increasing resolutions (for example 128 then 256 or 512) and train each stage separately
  • Classifier-free guidance via a cond_scale argument to strengthen how closely an image follows its text prompt
  • Train directly on preprocessed CLIP embeddings (image_embed, text_embed, text_encodings) to save compute when scaling up
  • Integrates with the x-clip package so you can train or bring your own CLIP model
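In a cascading decoder, each U-Net after the first runs at a higher resolution and is conditioned on the image produced by the stage before it. The plain-Python sketch below shows that hand-off. `upsample_nearest` and `plan_cascade` are hypothetical helpers, and nearest-neighbour resizing is an assumption. The library's own resizing, and any noise augmentation it applies to the low-resolution image, may differ.

```python
def upsample_nearest(image, factor):
    """Hypothetical helper: nearest-neighbour upsampling of a 2D grid
    (a list of rows) by an integer factor."""
    return [
        [pixel for pixel in row for _col_rep in range(factor)]
        for row in image
        for _row_rep in range(factor)
    ]

def plan_cascade(image_sizes):
    """Hypothetical helper: list (stage, conditioning_size, output_size) per U-Net.

    Stage 1 has no low-resolution conditioning image (None); every later stage
    is conditioned on the previous stage's output, upsampled to its own size.
    """
    plan = []
    previous = None
    for stage, size in enumerate(image_sizes, start=1):
        plan.append((stage, previous, size))
        previous = size
    return plan

# Toy 2x2 "output" of stage 1, upsampled as conditioning for a stage at twice the resolution.
low_res = [[0.1, 0.9],
           [0.5, 0.3]]
conditioning = upsample_nearest(low_res, 2)
stages = plan_cascade([128, 256])
```

Because each stage only needs its own resolution and the previous stage's output, the stages can be trained separately, as the bullet above notes.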

Getting started

Install the package from PyPI, then train the three pieces in order: CLIP, the decoder, and the diffusion prior. Once they are trained, wrap the prior and decoder in a DALLE2 object to generate images from text. The examples below assume a CUDA GPU, since every module and tensor is moved there with .cuda().

Install

Install the library with pip.

```bash
$ pip install dalle2-pytorch
```

Train CLIP

CLIP is trained first and is the most important step. Feed it batches of text token ids and images, and set return_loss = True to get the contrastive loss to backpropagate.

```python
import torch
from dalle2_pytorch import CLIP

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 6,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
).cuda()

# mock data: 4 token sequences of length 256 and 4 RGB images at 256x256
text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# return_loss = True is needed to get the contrastive loss back
loss = clip(text, images, return_loss = True)
loss.backward()
```
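The contrastive loss that return_loss = True returns pulls each matching text/image pair together and pushes mismatched pairs apart. The plain-Python sketch below implements the standard symmetric CLIP objective: cosine similarities divided by a temperature, then cross-entropy in both the text-to-image and image-to-text directions. The fixed temperature of 0.07 is an assumption (CLIP starts there and then learns it), and x-clip's actual loss may differ in its details.

```python
import math

def normalize(v):
    """Scale a vector to unit length (assumes it is non-zero)."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def cross_entropy(logits, target):
    """-log softmax(logits)[target], using the max-subtraction trick for stability."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum - logits[target]

def clip_loss(text_latents, image_latents, temperature=0.07):
    """Symmetric InfoNCE: row i of each batch is the matching text/image pair."""
    t = [normalize(v) for v in text_latents]
    im = [normalize(v) for v in image_latents]
    n = len(t)
    sim = [[sum(a * b for a, b in zip(t[i], im[j])) / temperature for j in range(n)]
           for i in range(n)]
    text_to_image = sum(cross_entropy(sim[i], i) for i in range(n)) / n
    image_to_text = sum(cross_entropy([sim[i][j] for i in range(n)], j) for j in range(n)) / n
    return (text_to_image + image_to_text) / 2

matched = [[1.0, 0.0], [0.0, 1.0]]
swapped = [[0.0, 1.0], [1.0, 0.0]]
low = clip_loss(matched, matched)   # pairs line up: loss close to zero
high = clip_loss(matched, swapped)  # pairs crossed: loss is large
```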

Train the diffusion prior and decoder

The diffusion prior learns to generate CLIP image embeddings from text embeddings. The decoder, which wraps one or more U-Nets, learns to generate images conditioned on those CLIP image embeddings. Train both over many steps.

```python
# continues the CLIP example: reuses clip, text and images
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder

# diffusion prior: maps CLIP text embeddings to CLIP image embeddings
prior_network = DiffusionPriorNetwork(dim = 512, depth = 6, dim_head = 64, heads = 8).cuda()

diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = clip,
    timesteps = 100,
    cond_drop_prob = 0.2
).cuda()

loss = diffusion_prior(text, images)
loss.backward()

# decoder: a U-Net that generates images conditioned on CLIP image embeddings
unet = Unet(dim = 128, image_embed_dim = 512, cond_dim = 128, channels = 3, dim_mults = (1, 2, 4, 8)).cuda()

decoder = Decoder(unet = unet, clip = clip, timesteps = 100, image_cond_drop_prob = 0.1, text_cond_drop_prob = 0.5).cuda()

loss = decoder(images)
loss.backward()
```
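The cond_drop_prob argument (prior) and the image_cond_drop_prob and text_cond_drop_prob arguments (decoder) set how often the conditioning is dropped during training. This teaches the same network to make unconditional predictions, which classifier-free guidance needs at sampling time. Below is a minimal plain-Python sketch of per-sample dropping. `drop_conditioning` is a hypothetical helper, and using an all-zeros null vector is an assumption, since a model may instead learn its null embedding.

```python
import random

def drop_conditioning(embeds, drop_prob, rng):
    """Hypothetical helper: copy the batch, replacing each row with a null
    vector (zeros here) with probability drop_prob."""
    null = [0.0] * len(embeds[0])
    return [list(null) if rng.random() < drop_prob else list(e) for e in embeds]

rng = random.Random(0)
batch = [[1.0] * 8 for _ in range(1000)]  # 1000 toy embeddings of width 8
dropped = drop_conditioning(batch, drop_prob=0.2, rng=rng)
fraction_dropped = sum(row == [0.0] * 8 for row in dropped) / len(dropped)  # roughly 0.2
```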

Generate images from text

Combine the trained prior and decoder in a DALLE2 object, then pass a list of text prompts to get back images.

```python
from dalle2_pytorch import DALLE2

# wrap the trained prior and decoder from the previous step
dalle2 = DALLE2(prior = diffusion_prior, decoder = decoder)

images = dalle2(
    ['cute puppy chasing after a squirrel'],
    cond_scale = 2.  # classifier-free guidance strength (> 1 strengthens the text condition)
)
```
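At sampling time, cond_scale combines two predictions from the same network, one made with the text condition and one made with it dropped. The plain-Python sketch below assumes the usual classifier-free guidance formula, uncond + cond_scale * (cond - uncond). `guided_prediction` is a hypothetical helper. In standard classifier-free guidance, this combination is applied to the model output at every denoising step.

```python
def guided_prediction(cond_pred, uncond_pred, cond_scale):
    """Hypothetical helper: start from the unconditional prediction and move
    cond_scale times the distance toward the conditional one."""
    return [u + cond_scale * (c - u) for c, u in zip(cond_pred, uncond_pred)]

cond = [1.0, 0.5]    # toy prediction with the text condition
uncond = [0.0, 0.5]  # toy prediction with the condition dropped
print(guided_prediction(cond, uncond, 1.0))  # [1.0, 0.5]: plain conditional prediction
print(guided_prediction(cond, uncond, 2.0))  # [2.0, 0.5]: pushed further along the text direction
```

A cond_scale of 1 reproduces the conditional prediction. Larger values extrapolate past it, which tightens prompt adherence at some cost to diversity.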

Commands and code are distilled from the project's own documentation. Always check the official repo for the latest version.

When to use it

  • Reproducing and studying the DALL-E 2 architecture, especially how the diffusion prior bridges CLIP text and image embeddings
  • Training a custom open text-to-image model on your own dataset instead of using a closed, paid API
  • Experimenting with cascading diffusion decoders to generate higher-resolution images from a learned image embedding
  • Plugging a trained diffusion prior into other CLIP-conditioned pipelines, such as CLIP-to-StyleGAN text-to-image applications

How DALLE2-pytorch compares

DALLE2-pytorch alongside other open-source image generation tools AI/TLDR tracks, ranked by GitHub stars.

| Tool | GitHub stars | What it does |
|---|---|---|
| Stable Diffusion web UI (AUTOMATIC1111) | ★ 164k | A browser interface for running Stable Diffusion image generation locally with extensions and fine-grained controls. |
| ComfyUI | ★ 118k | A node-based visual editor for building and running image and video generation pipelines like Stable Diffusion and FLUX locally. |
| Fooocus | ★ 50.4k | A simplified image generation app built on Stable Diffusion that hides technical settings for easy prompting. |
| InvokeAI | ★ 27.5k | A self-hosted creative tool and canvas for generating and editing images with open diffusion models. |
| Stability-AI generative-models | ★ 27.2k | Stability AI's official code for its Stable Diffusion family of image and video generation models. |
| FLUX | ★ 25.6k | Black Forest Labs' open-weight diffusion models and inference code for generating and editing images from text prompts. |
| Z-Image | ★ 11.6k | Alibaba Tongyi's 6B-parameter open image model that produces photorealistic images quickly on a single GPU. |
| DALLE2-pytorch | ★ 11.3k | PyTorch implementation of OpenAI's DALL-E 2 text-to-image model. |