In plain English
FlashAttention is a faster way to run the attention step inside a transformer. It produces the exact same numbers as the textbook version — not an approximation — but it gets there in a fraction of the time by being clever about how it moves data around the GPU.
Here's the analogy. Imagine a chef working with a huge walk-in fridge (lots of room, but a long walk away) and a tiny countertop right next to the stove (small, but instant to reach). The slow part of cooking isn't the chopping — it's all the trips back and forth to the fridge. A clumsy chef hauls out every ingredient, lays them all on a giant table, then walks each one back. A smart chef grabs one tray of ingredients, does everything possible with it on the countertop, stores the result, and only then fetches the next tray. Same dish, far fewer trips.
On a GPU, the walk-in fridge is HBM (high-bandwidth memory — big but relatively slow) and the countertop is SRAM (tiny on-chip cache — small but blazing fast). Standard attention is the clumsy chef: it writes a giant table of intermediate scores out to HBM and reads it back. FlashAttention is the smart chef: it streams the data through the fast on-chip cache in small tiles and never writes the giant table at all. If you're new to attention itself, start with how attention works and come back.
Why it matters
Attention is the heart of every modern transformer, and it's expensive in a very specific way. For a sequence of N tokens, attention compares every token with every other token — that's an N × N grid of scores. Double the context length and that grid quadruples. This quadratic blow-up is the reason long context is hard and why models need serious GPUs.
The surprising part: most of that cost was never the math. Modern GPUs can multiply matrices astonishingly fast — far faster than they can shuttle data in and out of memory. So the real bottleneck for attention was memory bandwidth, the time spent reading and writing that giant N × N score grid. The chips were sitting idle waiting on memory. FlashAttention's whole insight is to attack the memory traffic, not the arithmetic.
Who cares? Just about everyone who touches an LLM:
- Anyone using long context. FlashAttention turns the memory needed for attention from quadratic to roughly linear in sequence length. That's a big reason today's million-token context windows are even feasible.
- Anyone training models. Faster, more memory-efficient attention means bigger batches, longer sequences, and lower training bills — a direct lever on the scaling laws that drive model quality.
- Anyone serving models. Faster attention means lower latency and higher throughput per GPU, which is most of what inference cost comes down to.
- You, indirectly. You almost certainly used FlashAttention today without knowing it — it's the default attention kernel behind PyTorch, vLLM, and the major inference stacks.
How it works
Standard attention runs in distinct, memory-heavy passes. It computes the full score grid S = Q·Kᵀ, writes it to HBM, reads it back to apply softmax, writes the softmax result back, then reads it again to multiply by the values V. Every one of those round-trips drags a quadratic-sized matrix across the slow memory bus.
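Here is a minimal pure-Python sketch of the textbook version that makes those passes concrete. It assumes a single head, lists of rows instead of GPU tensors, and the usual 1/√d scaling on the scores, which the formula above leaves out. The function name `naive_attention` is illustrative. The point is that S and P are both full N × N grids held in memory at once.

```python
import math

def naive_attention(Q, K, V):
    """Textbook attention that materializes the full N x N grids S and P.

    Q, K, V are lists of rows (N x d). S and P are both N x N: this is the
    quadratic scratch memory that standard attention writes out to HBM.
    """
    d = len(Q[0])
    scale = 1.0 / math.sqrt(d)
    # Pass 1: full score grid S = Q K^T / sqrt(d)  (N x N)
    S = [[scale * sum(q[t] * k[t] for t in range(d)) for k in K] for q in Q]
    # Pass 2: row-wise softmax, subtracting each row's max for stability (N x N)
    P = []
    for row in S:
        m = max(row)
        e = [math.exp(s - m) for s in row]
        total = sum(e)
        P.append([x / total for x in e])
    # Pass 3: O = P V  (N x d_v)
    dv = len(V[0])
    O = [[sum(p[j] * V[j][c] for j in range(len(V))) for c in range(dv)] for p in P]
    return O, S, P
```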
| Standard attention | FlashAttention |
|---|---|
| Build full N×N score grid | Split Q, K, V into tiles |
| Write grid to slow HBM | Stream tiles through fast SRAM |
| Read it back for softmax | Fuse score + softmax + ×V per tile |
| Write + read again for ×V | Never write the full grid |
| Memory grows with N² | Memory grows with N |
FlashAttention does it in one fused pass using three ideas working together: tiling, online softmax, and recomputation.
1. Tiling — work in small blocks
Instead of building the whole grid, FlashAttention chops the queries, keys, and values into small blocks that fit in fast on-chip SRAM. It loads a block of queries, then walks through the key/value blocks one at a time, computing scores and accumulating a partial output — all without ever writing the full grid to HBM. This is a fused kernel: scoring, softmax, and the value multiply happen back-to-back on-chip rather than as three separate trips to memory.
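A quick back-of-the-envelope sketch shows why a tile fits on-chip while the full grid does not. It assumes one head, 2-byte (fp16/bf16) elements throughout, and 128 × 128 tiles. Real kernels tune block sizes per GPU and head dimension and keep accumulators in fp32, so treat the numbers as order-of-magnitude only. `attention_scratch_bytes` is a hypothetical helper.

```python
def attention_scratch_bytes(n, d, block_q, block_k, bytes_per_elem=2):
    """Compare standard attention's materialized grid to one tile's working set.

    Assumes a single head and one element size for everything (fp16/bf16);
    real kernels keep accumulators in fp32, so this is a rough estimate.
    """
    full_grid = n * n * bytes_per_elem                # S (or P) for one head
    q_tile = block_q * d * bytes_per_elem             # block of queries
    kv_tiles = 2 * block_k * d * bytes_per_elem       # one K block + one V block
    score_tile = block_q * block_k * bytes_per_elem   # scores for this tile only
    out_tile = block_q * d * bytes_per_elem           # partial output accumulator
    tile_total = q_tile + kv_tiles + score_tile + out_tile
    return full_grid, tile_total

full, tile = attention_scratch_bytes(n=4096, d=64, block_q=128, block_k=128)
print(f"full grid per head: {full / 2**20:.1f} MiB")      # 32.0 MiB
print(f"one tile working set: {tile / 2**10:.1f} KiB")    # 96.0 KiB
```

The tile's working set does not depend on N at all, while the full grid grows with N².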
2. Online softmax — the hard part
Softmax is the catch. Normally you need all the scores in a row before you can normalize them, because softmax divides by a sum over the whole row and subtracts the row's maximum for numerical stability. But if you're processing one block at a time, you don't have the whole row yet. Online softmax solves this by keeping two running statistics per query — the max score seen so far and the running sum — and rescaling the partial result each time a new block arrives. When the last block is done, the answer is mathematically identical to softmaxing the full row at once.
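Tiling and online softmax combine into the forward pass below, sketched in pure Python for a single head. Assumptions: queries are handled one at a time rather than in blocks, `block_k` is a hypothetical parameter for the key/value tile size, and everything runs in float64. The running max `m`, running sum `l`, and accumulator `acc` are the per-query statistics described above. The `correction` factor is the rescaling step applied whenever a new block raises the max.

```python
import math

def flash_attention_forward(Q, K, V, block_k=2):
    """Exact attention computed one key/value block at a time (single head).

    Online softmax: per query keep the running max m, running sum l, and an
    unnormalized output accumulator acc. The N x N grid is never stored.
    Queries are processed one at a time for clarity; real kernels also tile
    Q into blocks and process those blocks in parallel.
    """
    d, dv = len(Q[0]), len(V[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for q in Q:
        m = -math.inf          # max score seen so far
        l = 0.0                # sum of exp(score - m) so far
        acc = [0.0] * dv       # sum of exp(score - m) * v so far
        for start in range(0, len(K), block_k):
            k_blk = K[start:start + block_k]
            v_blk = V[start:start + block_k]
            scores = [scale * sum(a * b for a, b in zip(q, k)) for k in k_blk]
            m_new = max(m, max(scores))
            # Rescale the old statistics to the new max before adding this block.
            correction = math.exp(m - m_new)
            l *= correction
            acc = [a * correction for a in acc]
            for s, v in zip(scores, v_blk):
                p = math.exp(s - m_new)
                l += p
                acc = [a + p * vc for a, vc in zip(acc, v)]
            m = m_new
        out.append([a / l for a in acc])   # normalize once, at the very end
    return out
```

Because the rescaling is exact algebra, any `block_k` gives the same output up to floating-point rounding, which is what "exact attention" means in practice.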
3. Recomputation — cheaper than storing
Training needs a backward pass, which traditionally requires the stored attention grid to compute gradients. FlashAttention refuses to store it. Instead it keeps only the tiny running statistics from the forward pass and recomputes the needed blocks on the fly during the backward pass. Counterintuitively, this is faster overall: redoing a bit of cheap matrix math beats hauling a quadratic grid back across slow memory. It's the same memory-versus-compute trade the whole algorithm is built on.
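The recomputation trick hinges on one saved number per query: the log-sum-exp Lᵢ of its scores. With it, any probability can be rebuilt as exp(sᵢⱼ − Lᵢ) without a stored grid. The sketch below uses this to compute only the value gradient dV = Pᵀ·dO, tile by tile. It is a simplification under stated assumptions: a single head, pure Python, the log-sum-exp computed in place rather than saved by a real forward pass, and the dQ/dK gradients omitted. Function names are hypothetical.

```python
import math

def row_logsumexp(q, K, scale):
    """Per-query statistic: log(sum_j exp(s_j)), computed stably via the max."""
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
    m = max(scores)
    return m + math.log(sum(math.exp(s - m) for s in scores))

def grad_v_recompute(Q, K, dO, block_k=2):
    """dV = P^T dO, rebuilding each block of P from Q, K and the log-sum-exp
    instead of reading a stored N x N probability grid."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    # Stands in for the O(N) statistics a forward pass would have saved.
    lse = [row_logsumexp(q, K, scale) for q in Q]
    cols = len(dO[0])
    dV = [[0.0] * cols for _ in K]
    for start in range(0, len(K), block_k):           # one key block at a time
        for i, q in enumerate(Q):
            for j in range(start, min(start + block_k, len(K))):
                s = scale * sum(a * b for a, b in zip(q, K[j]))
                p = math.exp(s - lse[i])              # recomputed P[i][j]
                for c in range(cols):
                    dV[j][c] += p * dO[i][c]
    return dV
```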
The GPU memory hierarchy that makes it work
To see why FlashAttention is a memory trick and not a math trick, you have to picture the GPU's memory as a pyramid. The chips have a tiny pool of extremely fast on-chip SRAM and a large pool of slower off-chip HBM. The bandwidth gap between them is roughly an order of magnitude — and crucially, the compute units are fast enough that they're usually starved, waiting on HBM.
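Standard attention generates a quadratic amount of HBM traffic: on the order of Nd + N² reads and writes for sequence length N and head dimension d, because the full grid crosses the slow bus. FlashAttention keeps the working set on the top layer, so the extra memory it needs grows only linearly with N, and its HBM traffic drops to the order of N²d²/M for an on-chip SRAM of size M. That is still quadratic in N, but for typical head dimensions d² is many times smaller than M, so far fewer bytes cross the bus. That's the difference between a kernel that's memory-bound (idle cores) and one that's compute-bound (busy cores). FlashAttention's authors reported up to 20x lower attention memory use than standard exact attention, along with large cuts in HBM reads and writes, which is where the dramatic speedups come from. Note this is about the attention scratch memory and bandwidth; the context window you can actually fit also depends on the KV cache, which is a separate budget.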
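The FlashAttention paper's IO analysis can be turned into quick arithmetic. This sketch drops all constant factors, so only ratios and trends mean anything. It assumes d = 64 and an on-chip budget of M = 50,000 elements (about 100 KB in fp16, a ballpark rather than a spec for any particular GPU). `hbm_accesses` is a hypothetical helper.

```python
def hbm_accesses(n, d, sram_elems):
    """Asymptotic HBM access counts with constants dropped:
    standard attention ~ N*d + N^2, FlashAttention ~ N^2 * d^2 / M."""
    standard = n * d + n * n
    flash = n * n * d * d / sram_elems
    return standard, flash

for n in (1024, 4096, 16384):
    std, fa = hbm_accesses(n, d=64, sram_elems=50_000)
    print(f"N={n:>6}: FlashAttention touches HBM ~{std / fa:.1f}x less")
```

Under these assumptions the ratio stays near M/d² for every N. Both counts still grow quadratically, but FlashAttention's is roughly an order of magnitude smaller.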
How you actually use it
The good news: you almost never call FlashAttention by hand. It's baked into the frameworks. In PyTorch, scaled_dot_product_attention (SDPA) automatically dispatches to a FlashAttention backend when your tensors qualify (right dtype, supported GPU, no exotic mask). Inference servers like vLLM and SGLang use it by default. If you write a transformer the normal way, you get it for free.
```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# q, k, v: (batch, heads, seq_len, head_dim) on a CUDA GPU
q = torch.randn(2, 16, 4096, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# On its own, SDPA picks a fused FlashAttention kernel automatically when the
# inputs qualify (GPU + dtype + simple mask). The sdpa_kernel context manager
# restricts dispatch to the FlashAttention backend only, so the call raises an
# error instead of silently falling back to a slower kernel.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out.shape)  # torch.Size([2, 16, 4096, 64]): exact attention, fused & fast
```

If you need the bleeding edge (the newest hardware, FP8, or custom masks), you can install the standalone package (flash-attn) from the Dao-AILab repo and call its kernels directly. Most people don't need to. The headline is simply: prefer bfloat16 or float16, keep your masks simple (causal masks are the well-trodden fast path), and let the framework route you to the fused kernel.
The current landscape (mid-2026)
FlashAttention isn't one frozen thing — it's a line of kernels, each rewritten to exploit a specific NVIDIA GPU generation. As of mid-2026, here's the family. (Numbers are vendor-reported peak figures, not promises about your workload.)
| Version | Year | Target GPU | Headline result |
|---|---|---|---|
| FlashAttention (v1) | 2022 | Ampere (A100) | First IO-aware exact attention; up to ~20x less memory than standard exact attention |
| FlashAttention-2 | 2023 | Ampere/Ada | Better work partitioning; ~2x faster than v1, 50–73% of A100 peak FLOPs/s |
| FlashAttention-3 | 2024 | Hopper (H100) | Async + FP8; up to ~740 TFLOPs/s FP16 (75% util), ~1.2 PFLOPs/s FP8, 1.5–2.0x over v2 |
| FlashAttention-4 | 2026 | Blackwell (B200) | Up to ~1605 TFLOPs/s BF16 (71% util); up to 1.3x over cuDNN, 2.7x over Triton |
The pattern is telling: each new GPU adds hardware tricks — asynchronous tensor cores, the TMA data-movement engine on Hopper, FP8, then Blackwell's fully asynchronous matrix units and larger tiles — and FlashAttention is rewritten to ride them. FlashAttention-3 (2024) leaned on Hopper's asynchrony, overlapping the matrix multiply with the softmax and adding FP8 with a neat outlier-taming step (Hadamard 'incoherent processing') that cut FP8 error roughly 2.6x versus a naive FP8 baseline. FlashAttention-4, published in March 2026, is a ground-up redesign for Blackwell B200 with new forward/backward pipelines and a software-emulated exponential.
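The "incoherent processing" idea rests on a simple property. Multiplying both Q and K by the same orthogonal matrix leaves every dot product in QKᵀ unchanged, while spreading a large outlier coordinate across many coordinates, which makes low-precision quantization less lossy. Below is a minimal pure-Python sketch of that property using a normalized Hadamard matrix. It does no FP8 quantization, and the real kernel applies a fast transform to whole tiles. The helper names are illustrative.

```python
import math

def hadamard(n):
    """Sylvester construction of an n x n Hadamard matrix, scaled by 1/sqrt(n)
    so that it is orthogonal (H @ H^T = I). n must be a power of two."""
    H = [[1.0]]
    while len(H) < n:
        H = [row + row for row in H] + [row + [-x for x in row] for row in H]
    s = 1.0 / math.sqrt(n)
    return [[s * x for x in row] for row in H]

def rotate(vec, H):
    """Row vector times matrix: vec @ H."""
    return [sum(vec[i] * H[i][j] for i in range(len(vec))) for j in range(len(H[0]))]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

H = hadamard(4)
q = [8.0, 0.0, 0.0, 0.0]    # one large outlier coordinate
k = [1.0, 2.0, -1.0, 0.5]
q_rot, k_rot = rotate(q, H), rotate(k, H)
print(dot(q, k), dot(q_rot, k_rot))              # 8.0 8.0  (same attention score)
print(max(map(abs, q)), max(map(abs, q_rot)))    # 8.0 4.0  (outlier spread out)
```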
Alongside the core line sits a small ecosystem: PyTorch's FlexAttention lets you express custom mask/score patterns and still get a fused kernel; vendor libraries like cuDNN ship their own competitive attention kernels. But FlashAttention remains the reference implementation everyone benchmarks against — the work largely came out of Tri Dao's research (now at Together AI and Princeton), with the code openly maintained at the Dao-AILab repository.
Going deeper
A few subtleties worth internalizing once the basics click.
It changes memory complexity, not compute complexity
This trips people up. FlashAttention still does O(N²) arithmetic: every token still attends to every token, so the FLOP count is unchanged. What drops from quadratic to linear is the extra memory footprint, because the full N × N grid is never stored. HBM traffic also shrinks by a large factor, but it still grows quadratically with N. So FlashAttention does not break the quadratic wall of attention; it makes the constant factor and the memory cost dramatically better. Truly sub-quadratic compute needs a different attack, such as sparse attention, linear attention, or state-space architectures, which approximate or restructure the math.
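A toy cost model makes the distinction concrete. It counts forward FLOPs for the two matrix multiplies (QKᵀ and PV, at 2 FLOPs per multiply-add, ignoring the softmax), the N × N scratch grid of standard attention, and the two running statistics per query that FlashAttention keeps instead. The head dimension of 128 and the helper name are assumptions.

```python
def attention_costs(n, d):
    """Per-head forward cost model (constants approximate).

    FLOPs: two N x N x d matmuls (Q K^T and P V), 2 FLOPs per multiply-add.
    Standard scratch: the N x N score/probability grid.
    FlashAttention scratch: running max and running sum per query.
    """
    flops = 2 * (2 * n * n * d)
    standard_scratch = n * n
    flash_scratch = 2 * n
    return flops, standard_scratch, flash_scratch

f1, s1, a1 = attention_costs(4096, 128)
f2, s2, a2 = attention_costs(8192, 128)
print(f2 / f1, s2 / s1, a2 / a1)   # 4.0 4.0 2.0
```

Doubling N quadruples the FLOPs for both methods; only the scratch memory changes from quadratic to linear growth.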
Training vs inference are different problems
FlashAttention's original win was the prefill / training phase, where you process many tokens at once and the score grid is genuinely large. Single-token decoding (generating one token at a time) is a different regime — it's dominated by reading the KV cache, not by the attention grid — which is why decode-time variants like FlashDecoding and paged-KV schemes (the trick behind vLLM) exist. When you read that a serving stack is 'fast', it's usually FlashAttention for prefill plus a paged KV cache for decode, working together.
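FlashDecoding's core move for the single-query case is to split the KV cache into chunks, attend to each chunk independently (so many GPU units can work at once even with one query), and then merge the partial results using each chunk's max and exp-sum. The sketch below shows only that split-and-merge reduction in pure Python. The real kernel, the paged KV layout, and chunk sizing are not modeled, and the function names are hypothetical.

```python
import math

def partial_attention(q, K, V):
    """Attention of one query over one chunk of the KV cache.
    Returns the chunk's unnormalized output, its max score, and its exp-sum."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    acc = [sum(w * v[c] for w, v in zip(weights, V)) for c in range(len(V[0]))]
    return acc, m, sum(weights)

def split_kv_decode(q, K, V, num_splits=3):
    """Decode one token: attend to each KV chunk independently (these calls
    could run in parallel), then merge by rescaling to the global max."""
    chunk = math.ceil(len(K) / num_splits)
    parts = [partial_attention(q, K[i:i + chunk], V[i:i + chunk])
             for i in range(0, len(K), chunk)]
    m_global = max(m for _, m, _ in parts)
    total = sum(l * math.exp(m - m_global) for _, m, l in parts)
    out = [0.0] * len(V[0])
    for acc, m, _ in parts:
        w = math.exp(m - m_global)
        out = [o + w * a for o, a in zip(out, acc)]
    return [o / total for o in out]
```

The merge is the same rescaling that online softmax uses, applied across chunks in parallel instead of sequentially.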
Why 'exact' is a selling point
Plenty of pre-2022 research chased cheaper attention by approximating it — sparse patterns, low-rank projections, kernelized 'linear' attention. The problem was that approximations often dented quality in ways that were hard to predict. FlashAttention's quiet genius is that it asked a different question: what if we don't change the math at all, and just stop wasting memory bandwidth? Because the output is exact, it's a free drop-in — no accuracy trade-off, no retraining, no eval surprises. That's exactly why it swept the field and became the silent default behind nearly every model you use, from frontier APIs to a local model on your laptop.
FAQ
Is FlashAttention an approximation of attention?
No. FlashAttention computes exact attention — the output matches a careful standard implementation to within normal floating-point rounding. It only changes how data moves through GPU memory (tiling, online softmax, recomputation), not the underlying math. Approximate methods like sparse or linear attention do change the math; FlashAttention does not.
Why is FlashAttention faster if it does the same amount of math?
Because attention was never bottlenecked by math; it was bottlenecked by memory bandwidth. Standard attention writes and reads a giant N×N score grid to slow HBM repeatedly. FlashAttention keeps the work in fast on-chip SRAM and never materializes that grid, cutting attention's memory use by up to ~20x and HBM traffic by a large factor. The GPU's compute units stop sitting idle waiting on memory, so wall-clock time drops sharply.
Does FlashAttention let me use longer context windows?
Partly. It changes attention's scratch-memory footprint from quadratic to roughly linear in sequence length, which removes one major obstacle to long context. But the context length you can actually run also depends on the KV cache, model size, and total GPU memory — those are separate budgets. FlashAttention is necessary but not sufficient for million-token context.
Do I need to install anything to use FlashAttention?
Usually not. PyTorch's scaled_dot_product_attention automatically dispatches to a FlashAttention backend when your inputs qualify (supported GPU, float16/bfloat16, simple mask), and servers like vLLM and SGLang use it by default. You only need the standalone flash-attn package for the newest hardware, FP8, or custom kernels.
What is the latest version of FlashAttention as of 2026?
FlashAttention-4, published in March 2026, is the newest in the line. It's a ground-up redesign for NVIDIA Blackwell (B200) GPUs reaching up to ~1605 TFLOPs/s in BF16. Earlier generations are still widely used: FlashAttention-2 for Ampere/Ada and FlashAttention-3 for Hopper (H100), which added FP8 support.
What is the difference between FlashAttention and standard attention?
Standard attention builds the full N×N attention grid, writes it to slow memory, and reads it back multiple times. FlashAttention fuses scoring, softmax, and the value multiply into one pass over small tiles in fast on-chip memory, never storing the full grid. Same exact output, far less memory movement, much faster — especially as sequences get longer.