Attention is O(n²) in memory. Explain concretely what that means at 100k tokens, why it makes naïve attention infeasible, and how Flash Attention solves it without changing the math. Then explain what other approaches exist for truly long contexts.
formulate your answer, then —
tldr
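Naive attention materializes the full n×n score matrix in GPU HBM. At 100k tokens that is 10¹⁰ scores, or about 20 GB per head in fp16 (40 GB in fp32), before multiplying by the number of heads and layers. Flash Attention avoids this by tiling the computation so each block fits in on-chip SRAM and by using online softmax to combine the blocks: same math, exact same output, no n×n materialization. Memory for attention drops from O(n²) to O(n), with commonly cited figures of ~10× memory savings and a 2-4× speedup, the speedup coming mostly from fewer HBM reads and writes. But it is still O(n²) in FLOPs, so at 1M tokens compute remains the bottleneck. For truly long context: sparse/sliding-window attention, Ring Attention across GPUs, or SSM architectures like Mamba.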
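The memory figure and the "same math" claim can both be checked in a few lines. What follows is a minimal NumPy sketch, not the Flash Attention kernel itself: it tiles only over keys/values, runs on the CPU, and ignores masking and multiple heads. The function names (`score_matrix_bytes`, `naive_attention`, `tiled_attention`) and the `tile` parameter are made up for illustration.

```python
import numpy as np


def score_matrix_bytes(n_tokens: int, bytes_per_element: int = 2) -> int:
    """Bytes needed for one head's full n x n score matrix (fp16 by default)."""
    return n_tokens * n_tokens * bytes_per_element


def naive_attention(q, k, v):
    """Reference version: materializes the full (n, n) score matrix."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v


def tiled_attention(q, k, v, tile=64):
    """Same output, but only an (n, tile) block of scores exists at any time."""
    n, d = q.shape
    running_max = np.full((n, 1), -np.inf)  # max score seen so far, per query
    running_sum = np.zeros((n, 1))          # softmax denominator so far
    acc = np.zeros((n, v.shape[-1]))        # unnormalized output so far
    for start in range(0, k.shape[0], tile):
        k_t = k[start:start + tile]
        v_t = v[start:start + tile]
        s = q @ k_t.T / np.sqrt(d)
        new_max = np.maximum(running_max, s.max(axis=-1, keepdims=True))
        # Online softmax: rescale earlier partial results to the new max.
        scale = np.exp(running_max - new_max)
        p = np.exp(s - new_max)
        running_sum = running_sum * scale + p.sum(axis=-1, keepdims=True)
        acc = acc * scale + p @ v_t
        running_max = new_max
    return acc / running_sum


if __name__ == "__main__":
    print(score_matrix_bytes(100_000) / 1e9, "GB per head in fp16")
    rng = np.random.default_rng(0)
    q, k, v = (rng.standard_normal((256, 32)) for _ in range(3))
    print(np.allclose(naive_attention(q, k, v), tiled_attention(q, k, v)))
```

The rescaling by `scale` is what lets each tile be processed without seeing all scores first: whenever a larger maximum appears, the earlier sum and accumulator are corrected to the new reference point, so the final division gives the exact softmax-weighted result.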
follow-up
- How does online softmax work — why can you compute softmax over tiles without seeing all scores at once?
- What is Ring Attention and how does it distribute long-sequence processing across GPUs?
- When would you choose a hybrid SSM-attention architecture over a pure transformer for a long-context application?