01 - Understanding Attention: The Core of Transformers
In 2017, a paper from Google Brain titled "Attention Is All You Need" permanently reshaped the deep learning landscape. Vaswani and colleagues proposed an architecture built entirely on a mechanism called attention, discarding the recurrent and convolutional networks that had dominated until then. The result was the Transformer architecture, now powering GPT-4, Claude, Llama 3, BERT, T5, Vision Transformers and virtually every frontier model in production today.
Understanding attention is not an academic exercise: it is the foundation on which techniques like LoRA fine-tuning, quantization, pruning and edge deployment are built, all topics we will cover in this series. Without a solid grasp of how attention works, every downstream optimization remains a black box.
In this first article of the Advanced Deep Learning and Edge Deployment series, we will explore attention in depth: from the initial intuition to the mathematical formula, from a PyTorch implementation to modern variants like Flash Attention 3 and Grouped-Query Attention.
Series Overview
| # | Article | Focus |
|---|---|---|
| 1 | You are here - Attention Mechanism in Transformers | Self-attention, multi-head, full architecture |
| 2 | Fine-tuning with LoRA, QLoRA and Adapters | Parameter-efficient fine-tuning |
| 3 | Model Quantization | INT8, INT4, GPTQ, AWQ |
| 4 | Pruning and Compression | Parameter reduction, distillation |
| 5 | Knowledge Distillation | Teacher-student, knowledge transfer |
| 6 | Ollama and Local LLMs | Local inference, optimization |
| 7 | Vision Transformers | ViT, DINO, image classification |
| 8 | Edge Deployment | ONNX, TensorRT, mobile devices |
| 9 | NAS and AutoML | Neural Architecture Search |
| 10 | Benchmarks and Optimization | Profiling, metrics, tuning |
What You Will Learn
- Why RNNs and LSTMs were not enough for long sequences
- The intuition behind the attention mechanism: Query, Key and Value
- The full Scaled Dot-Product Attention formula
- How Multi-Head Attention works and why multiple heads matter
- The difference between Self-Attention and Cross-Attention
- How Positional Encoding solves the ordering problem
- The complete Transformer architecture: encoder and decoder
- Hands-on PyTorch implementation, line by line
- Modern variants: Flash Attention 3, GQA, Sliding Window Attention
- Real architectures: GPT (decoder-only), BERT (encoder-only), T5 (encoder-decoder)
1. The Sequence Problem: Before Attention
To appreciate why attention was revolutionary, we need to understand the models that came before it. Deep learning for sequential data (text, audio, time series) was dominated by two architectures: RNNs (Recurrent Neural Networks) and LSTMs (Long Short-Term Memory).
1.1 RNNs and the Sequential Bottleneck
RNNs process sequences one token at a time, passing a hidden state from one timestep to the next. Each token updates the hidden state, which serves as the network's "memory" of the sequence seen so far.
Input: x1 -----> x2 -----> x3 -----> x4 -----> x5
| | | | |
v v v v v
Hidden: h1 -----> h2 -----> h3 -----> h4 -----> h5
| |
v v
Output: y1 y5
Problem: h5 must "remember" x1 through 4 intermediate steps.
With sequences of 1000+ tokens, information from x1 fades away.
This is the long-range dependency problem. In a sentence like "The cat, which had been adopted from the shelter three years ago and lived happily with the family, slept on the couch", the RNN must connect "cat" to "slept" across dozens of intervening tokens. The hidden state, compressed into a fixed-size vector, inevitably loses older information.
1.2 LSTM: An Improvement, Not a Solution
LSTMs introduced a gating mechanism (input gate, forget gate, output gate) to control which information to retain and which to discard. This improved the situation but did not fully solve it. LSTMs still suffer from two fundamental limitations:
RNN/LSTM Limitations
| Problem | Description | Impact |
|---|---|---|
| Sequential processing | Each token depends on the previous one: no parallelization | Extremely slow training on long sequences |
| Information bottleneck | All information flows through a single fixed-size vector | Information loss with sequences > 100-200 tokens |
| Vanishing gradients | Gradients shrink exponentially during backpropagation | Model cannot learn distant relationships |
What was needed was a mechanism allowing every token to access any other token in the sequence directly, without passing through intermediate states. That mechanism is attention.
2. What Is Attention: The Intuition
Attention is a mechanism that lets a model focus on the most relevant parts of its input when producing each output. Instead of compressing the entire sequence into a single vector, attention creates a direct connection between every output position and all input positions.
Analogy: Searching a Library
Imagine you are in a library looking for information on "the history of Transformers". You have a question in mind (Query). Every book has a title (Key) describing its content. When a title matches your question, you extract the content (Value) from that book. Attention works exactly this way:
- Query (Q): "What am I looking for?" - the question the current token asks
- Key (K): "What does this element contain?" - the label for each token in the sequence
- Value (V): "Here is the information" - the actual content of each token
The mechanism computes a compatibility score between the Query and every Key. This score determines how much attention to pay to the corresponding Value. Scores are normalized via softmax to produce weights that sum to 1, and the final result is a weighted average of the Values.
Current token: "slept"
Query for "slept": "Who is performing this action?"
Key Score Weight (softmax)
"The" -----> 0.1 0.02
"cat" -----> 4.8 0.65 <-- High attention!
"which" -----> 0.3 0.03
"had" -----> 0.2 0.02
"been" -----> 0.1 0.02
"adopted" -----> 1.2 0.08
"..." -----> ... ...
"on" -----> 2.1 0.12
"couch" -----> 0.8 0.06
Output = 0.02 * V("The") + 0.65 * V("cat") + 0.03 * V("which") + ...
The model learned that "cat" is the subject of "slept",
even though they are separated by many tokens.
3. Scaled Dot-Product Attention: The Formula
The mathematical formulation of attention used in Transformers is Scaled Dot-Product Attention. It is elegant in its simplicity and computationally efficient thanks to matrix operations.
The Attention Formula
Attention(Q, K, V) = softmax(Q * K^T / sqrt(d_k)) * V
Where:
- Q (Query): matrix of shape (n x d_k), where n is the number of tokens and d_k is the query/key dimension
- K (Key): matrix of shape (n x d_k)
- V (Value): matrix of shape (n x d_v), where d_v is the value dimension
- d_k: the key dimension, used as a scaling factor
- Q * K^T: dot product between queries and keys (n x n score matrix)
- / sqrt(d_k): scaling factor to stabilize gradients
- softmax: normalizes scores into weights that sum to 1
3.1 Why Scaling Is Necessary
Without the sqrt(d_k) factor, the dot product between Q and K produces values
that grow proportionally with d_k. With d_k = 512, dot products can reach very large
values. When fed into softmax, these produce near one-hot distributions (one weight
close to 1, all others near 0), resulting in extremely small gradients. Scaling
prevents this problem.
Without scaling (d_k = 512):
Raw scores: [120.3, 115.8, 2.1, -5.4]
Softmax: [0.989, 0.011, 0.000, 0.000] <-- Near one-hot, gradients ~0
With scaling (/ sqrt(512) = / 22.6):
Scaled scores: [5.32, 5.12, 0.09, -0.24]
Softmax: [0.55, 0.45, 0.00, 0.00] <-- No longer one-hot: weight is shared across the top scores, healthier gradients
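The effect is easy to check numerically. A minimal sketch comparing the entropy of softmax rows with and without scaling (random Q and K, d_k = 512; higher entropy means a smoother distribution):

```python
import torch

torch.manual_seed(0)
d_k = 512

# Queries/keys with unit-variance components: their dot product has
# variance ~d_k, so the raw scores are large.
q = torch.randn(4, d_k)
k = torch.randn(4, d_k)
raw = q @ k.T

unscaled = torch.softmax(raw, dim=-1)
scaled = torch.softmax(raw / d_k**0.5, dim=-1)

def entropy(p):
    # Mean entropy of the softmax rows (higher = smoother distribution).
    return -(p * (p + 1e-12).log()).sum(dim=-1).mean()

print(entropy(unscaled).item())  # close to 0: near one-hot rows
print(entropy(scaled).item())    # noticeably higher: smoother rows
```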
3.2 Step by Step: Computing Attention
Let us walk through a concrete numerical example with a 3-token sequence and d_k = 4:
Sequence: ["The", "cat", "sat"]
Step 1: Generate Q, K, V via linear projections
Q = X * W_Q K = X * W_K V = X * W_V
Q = [[1.0, 0.5, 0.3, 0.2], (The)
[0.8, 1.2, 0.1, 0.9], (cat)
[0.3, 0.4, 1.1, 0.6]] (sat)
K = [[0.9, 0.6, 0.4, 0.1],
[0.7, 1.1, 0.2, 0.8],
[0.4, 0.3, 1.0, 0.5]]
V = [[0.2, 0.8, 0.1, 0.5],
[0.9, 0.3, 0.7, 0.2],
[0.4, 0.6, 0.5, 0.8]]
Step 2: Compute Q * K^T (3x3 score matrix)
Score[i][j] = dot(Q[i], K[j])
Scores = [[1.34, 1.47, 0.95],
[1.57, 2.62, 1.23],
[1.01, 1.35, 1.64]]
Step 3: Scale by sqrt(d_k) = sqrt(4) = 2
Scaled = [[0.67, 0.74, 0.48],
[0.79, 1.31, 0.62],
[0.51, 0.68, 0.82]]
Step 4: Apply softmax row-wise
Weights = [[0.35, 0.37, 0.28], (The attending to The, cat, sat)
[0.28, 0.48, 0.24], (cat attending to The, cat, sat)
[0.28, 0.33, 0.38]] (sat attending to The, cat, sat)
Step 5: Multiply weights by V
Output[0] = 0.35*V[0] + 0.37*V[1] + 0.28*V[2]
= [0.52, 0.56, 0.43, 0.47]
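The five steps can be verified in a few lines of PyTorch, using the same Q, K and V matrices:

```python
import torch

Q = torch.tensor([[1.0, 0.5, 0.3, 0.2],
                  [0.8, 1.2, 0.1, 0.9],
                  [0.3, 0.4, 1.1, 0.6]])
K = torch.tensor([[0.9, 0.6, 0.4, 0.1],
                  [0.7, 1.1, 0.2, 0.8],
                  [0.4, 0.3, 1.0, 0.5]])
V = torch.tensor([[0.2, 0.8, 0.1, 0.5],
                  [0.9, 0.3, 0.7, 0.2],
                  [0.4, 0.6, 0.5, 0.8]])

d_k = Q.size(-1)                          # 4
scores = Q @ K.T                          # Step 2: 3x3 score matrix
scaled = scores / d_k**0.5                # Step 3: divide by sqrt(4) = 2
weights = torch.softmax(scaled, dim=-1)   # Step 4: each row sums to 1
output = weights @ V                      # Step 5: weighted average of Values

print(weights[0])  # ~[0.35, 0.37, 0.28]
print(output[0])   # ~[0.52, 0.56, 0.43, 0.47]
```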
Complexity Warning
The Q * K^T matrix has shape n x n, where n is the sequence length. With n = 1,000, the matrix has 1,000,000 elements. With n = 100,000, it has 10 billion elements. This quadratic O(n^2) complexity is the primary bottleneck of Transformers and the reason variants like Flash Attention and Sliding Window Attention were developed.
4. Multi-Head Attention: Looking from Multiple Angles
A single attention operation captures one type of relationship between tokens. But relationships in a sequence are multifaceted: syntactic (subject-verb), semantic (synonyms, context), positional (adjacent tokens) and many others. Multi-Head Attention addresses this by running attention in parallel with different projections.
Input X (shape: n x d_model, e.g. n x 512)
|
+---> Head 1: Q1=X*Wq1, K1=X*Wk1, V1=X*Wv1 --> Attention(Q1,K1,V1) --> Z1
| (d_k = d_model/h = 64)
+---> Head 2: Q2=X*Wq2, K2=X*Wk2, V2=X*Wv2 --> Attention(Q2,K2,V2) --> Z2
|
+---> Head 3: Q3=X*Wq3, K3=X*Wk3, V3=X*Wv3 --> Attention(Q3,K3,V3) --> Z3
|
+---> ...
|
+---> Head 8: Q8=X*Wq8, K8=X*Wk8, V8=X*Wv8 --> Attention(Q8,K8,V8) --> Z8
|
v
Concatenate: [Z1; Z2; Z3; ... Z8] (shape: n x d_model)
|
v
Final projection: Concat * W_O (shape: n x d_model)
With h = 8 heads and d_model = 512, each head operates on a
subspace of dimension d_k = d_v = 512 / 8 = 64. The total computational
cost is similar to a single full-dimension attention, because the heads operate in
parallel on smaller subspaces.
What Each Head Learns
Empirical research has shown that different heads specialize in different patterns:
- Head 1: May learn subject-verb relationships
- Head 2: May learn coreference patterns (pronouns and their antecedents)
- Head 3: May focus on adjacent tokens (local n-grams)
- Head 4: May capture long-range dependencies across clauses
- Other heads: Syntactic patterns, entities, discourse structure
Multi-Head Attention Formula
MultiHead(Q, K, V) = Concat(head_1, ..., head_h) * W_O
where head_i = Attention(Q * W_Qi, K * W_Ki, V * W_Vi)
Typical parameters in the original paper: d_model = 512, h = 8, d_k = d_v = 64. In modern models: d_model = 4096-8192, h = 32-128.
5. Self-Attention: A Token Looking at All Others
Self-attention is the specific case where Query, Key and Value all come from the same sequence. Every token generates its own Query, Key and Value, then uses its Query to "interrogate" the Keys of all other tokens (including itself).
Sentence: "The cat sat on the mat"
Attention Matrix (each row sums to 1.0):
The cat sat on the mat
The [0.15 0.25 0.10 0.05 0.15 0.30]
cat [0.10 0.20 0.35 0.05 0.05 0.25]
sat [0.05 0.40 0.15 0.20 0.05 0.15]
on [0.05 0.10 0.30 0.10 0.15 0.30]
the [0.20 0.15 0.05 0.10 0.10 0.40]
mat [0.10 0.15 0.15 0.25 0.15 0.20]
Observations:
- "sat" pays high attention to "cat" (0.40) --> subject-verb
- "on" attends to "sat" (0.30) and "mat" (0.30) --> spatial relation
- "the" (second) pays high attention to "mat" (0.40) --> determiner-noun
Self-attention is the heart of Transformers. It is what allows the model to build contextual representations: the representation of every token incorporates information from the entire sequence, weighted by relevance. The word "bank" will have a different representation in "river bank" and "bank account" because surrounding tokens influence its representation through attention.
Masked Self-Attention in Decoders
In generative models (decoders), self-attention is masked: each token can only see previous tokens, not future ones. This is implemented by setting future token scores to negative infinity before softmax, producing zero weights. This is the causal attention used in GPT, Llama and all autoregressive models.
Mask for a 5-token sequence (0 = visible, -inf = masked):
t1 t2 t3 t4 t5
t1 [ 0 -inf -inf -inf -inf ]
t2 [ 0 0 -inf -inf -inf ]
t3 [ 0 0 0 -inf -inf ]
t4 [ 0 0 0 0 -inf ]
t5 [ 0 0 0 0 0 ]
After softmax:
t1 sees only [t1]
t2 sees only [t1, t2]
t3 sees only [t1, t2, t3]
...and so on
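A minimal sketch of the masking step itself: scores on future positions are filled with -inf, so softmax assigns them exactly zero weight.

```python
import torch

seq_len = 5
scores = torch.randn(seq_len, seq_len)  # stand-in for Q @ K^T / sqrt(d_k)

# Lower-triangular matrix: 1 where a position may attend, 0 on the future.
visible = torch.tril(torch.ones(seq_len, seq_len))
masked_scores = scores.masked_fill(visible == 0, float('-inf'))

weights = torch.softmax(masked_scores, dim=-1)
print(weights)
# Row i has non-zero weights only in columns 0..i, and each row sums to 1:
# t1 attends to [t1], t2 to [t1, t2], and so on.
```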
6. Cross-Attention: When Encoder and Decoder Communicate
Cross-attention (or encoder-decoder attention) is the mechanism that allows the decoder to "look at" the encoder output. Unlike self-attention where Q, K and V come from the same sequence, in cross-attention the Queries come from the decoder while Keys and Values come from the encoder.
ENCODER (processes input, e.g. Italian sentence):
"Il gatto dorme" --> Encoder --> Encoder representations (K_enc, V_enc)
DECODER (generates output, e.g. English translation):
"The cat" --> Masked Self-Attention --> Q_dec
CROSS-ATTENTION:
Q = Q_dec (from decoder: "what do I need to generate the next token?")
K = K_enc (from encoder: "what does each input token contain?")
V = V_enc (from encoder: "here is the input information")
The decoder can "look at" the entire encoder sequence
to decide which token to generate next.
Cross-attention is fundamental in encoder-decoder architectures used for machine translation (T5, mBART), text summarization and conditional generation. In T5, for instance, the encoder processes the input text and the decoder generates output text, using cross-attention to consult the encoder at each generation step.
The Three Types of Attention in Transformers
| Type | Q Source | K, V Source | Where Used |
|---|---|---|---|
| Self-Attention (encoder) | Encoder input | Encoder input | BERT encoder, T5 encoder |
| Masked Self-Attention | Decoder input | Decoder input | GPT, Llama, T5 decoder |
| Cross-Attention | Decoder | Encoder output | T5 decoder, mBART |
7. Positional Encoding: How Transformers Know Token Order
Unlike RNNs, which process tokens sequentially, self-attention is order-invariant: the result does not change if you permute the input tokens. "The cat eats the fish" and "fish the cat the eats" would produce the same output without an additional mechanism. Positional encoding solves this by injecting position information into each token.
7.1 Sinusoidal Positional Encoding (Original Paper)
The original paper uses sinusoidal functions to generate positional encodings:
Sinusoidal Positional Encoding Formulas
PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
Where pos is the token position in the sequence and i indexes the dimension pairs. Even dimension indices use sine, odd ones use cosine. Because each pair oscillates at a different frequency, the model can learn relative positional relationships.
Example with d_model = 4, where the second sin/cos pair has frequency pos/100:
Position 0: [sin(0), cos(0), sin(0), cos(0)] = [0.00, 1.00, 0.00, 1.00]
Position 1: [sin(1), cos(1), sin(0.01), cos(0.01)] = [0.84, 0.54, 0.01, 1.00]
Position 2: [sin(2), cos(2), sin(0.02), cos(0.02)] = [0.91, -0.42, 0.02, 1.00]
The final embedding for each token is:
token_embedding = word_embedding + positional_encoding
Lower frequencies (higher dimensions) capture global positions.
Higher frequencies (lower dimensions) capture local positions.
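The two formulas translate directly to code. A minimal sketch that builds the full encoding matrix:

```python
import torch

def sinusoidal_positional_encoding(max_len: int, d_model: int) -> torch.Tensor:
    """Return a (max_len, d_model) matrix of sinusoidal encodings."""
    pos = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)
    i = torch.arange(0, d_model, 2, dtype=torch.float32)           # even dims
    angle = pos / (10000 ** (i / d_model))                         # (max_len, d_model/2)
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(angle)   # even dimension indices: sine
    pe[:, 1::2] = torch.cos(angle)   # odd dimension indices: cosine
    return pe

pe = sinusoidal_positional_encoding(max_len=128, d_model=512)
print(pe[0, :4])   # position 0: [0., 1., 0., 1.]
```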
7.2 Learned Positional Encoding
An alternative to sinusoidal encoding is learned positional embeddings: a trainable parameter matrix with one row per position. This approach is used in BERT and GPT-2. The advantage is that the model can learn optimal positional patterns for the specific task. The drawback is that maximum sequence length is fixed at training time.
Positional Encoding Comparison
| Type | Advantages | Disadvantages | Used In |
|---|---|---|---|
| Sinusoidal | No extra parameters, generalizes to longer sequences | Fixed patterns, not optimized for the task | Original Transformer |
| Learned | Optimized for the specific task | Fixed maximum length, more parameters | BERT, GPT-2 |
| RoPE (Rotary) | Captures relative positions, extensible | More complex implementation | Llama, Mistral, GPT-NeoX |
| ALiBi | No parameters, good extrapolation | Linear bias can be limiting | BLOOM, MPT |
8. The Complete Transformer Architecture
With all the puzzle pieces in hand, we can now assemble the full Transformer architecture. The original Transformer consists of an encoder stack and a decoder stack, each made of N identical layers (N = 6 in the original paper).
INPUT EMBEDDING + POSITIONAL ENCODING
|
+---------v-----------+
| ENCODER STACK | x N (6 in the original paper)
| |
| +--Multi-Head-------+
| | Self-Attention |
| +------|------------+
| v
| +--Add & Norm-------+ (residual connection + layer norm)
| +------|------------+
| v
| +--Feed-Forward-----+ (2 linear layers with ReLU/GELU)
| | Network | (d_model -> d_ff -> d_model)
| +------|------------+ (d_ff = 4 * d_model = 2048)
| v
| +--Add & Norm-------+
| +------|------------+
+---------|-----------+
|
| (K, V for cross-attention)
|
OUTPUT EMBEDDING + POSITIONAL ENCODING
|
+---------v-----------+
| DECODER STACK | x N
| |
| +--Masked Multi-----+
| | Head Self-Attn | (causal mask: sees only the past)
| +------|------------+
| v
| +--Add & Norm-------+
| +------|------------+
| v
| +--Cross-Attention--+ (Q from decoder, K/V from encoder)
| +------|------------+
| v
| +--Add & Norm-------+
| +------|------------+
| v
| +--Feed-Forward-----+
| +------|------------+
| v
| +--Add & Norm-------+
| +------|------------+
+---------|-----------+
|
v
Linear + Softmax
|
v
Output Probabilities (vocabulary)
8.1 Residual Connections
Every sub-layer (attention or feed-forward) has a residual connection:
the sub-layer output is added to the input. The formula is
output = LayerNorm(x + SubLayer(x)). Residual connections solve the
vanishing gradient problem in deep networks by allowing gradients to flow directly
through shortcut connections.
8.2 Feed-Forward Network
After attention, each token passes through a feed-forward network applied independently at every position. It consists of two linear transformations with a nonlinear activation (ReLU in the original paper, GELU or SwiGLU in modern models):
FFN(x) = W2 * activation(W1 * x + b1) + b2
The inner dimension (d_ff) is typically 4 times d_model. With d_model = 512, d_ff = 2048. In modern models like Llama 3 8B, d_ff reaches 14,336 with d_model = 4096.
8.3 Layer Normalization
Layer Normalization normalizes activations along the feature dimension (not the batch dimension). It stabilizes training and accelerates convergence. The original Transformer uses Post-LN (normalization after the residual connection), but most modern models use Pre-LN (normalization before the sub-layer), which is more stable during training.
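Putting 8.1-8.3 together, a single Pre-LN encoder layer might look like the following sketch. It uses PyTorch's built-in nn.MultiheadAttention for brevity (the from-scratch version comes in section 9), with the original paper's hyperparameters:

```python
import torch
import torch.nn as nn

class EncoderLayer(nn.Module):
    """One Pre-LN encoder layer: norm -> sub-layer -> residual, twice."""
    def __init__(self, d_model: int = 512, num_heads: int = 8, d_ff: int = 2048):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(            # d_model -> d_ff -> d_model
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.norm1(x)                    # Pre-LN: normalize before the sub-layer
        attn_out, _ = self.attn(h, h, h)     # self-attention: Q = K = V
        x = x + attn_out                     # residual connection
        x = x + self.ffn(self.norm2(x))      # residual around the FFN
        return x

layer = EncoderLayer()
x = torch.randn(2, 10, 512)
print(layer(x).shape)  # torch.Size([2, 10, 512])
```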
9. PyTorch Implementation: Self-Attention from Scratch
Let us move from theory to code. We will implement Scaled Dot-Product Attention and Multi-Head Attention from scratch in PyTorch, without using pre-built modules.
9.1 Scaled Dot-Product Attention
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
def scaled_dot_product_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
    mask: torch.Tensor | None = None,
    dropout: nn.Dropout | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Scaled Dot-Product Attention.
Args:
query: (batch, heads, seq_len, d_k)
key: (batch, heads, seq_len, d_k)
value: (batch, heads, seq_len, d_v)
mask: (batch, 1, 1, seq_len) or (batch, 1, seq_len, seq_len)
dropout: optional dropout module
Returns:
output: (batch, heads, seq_len, d_v)
attention_weights: (batch, heads, seq_len, seq_len)
"""
d_k = query.size(-1)
# Step 1: Compute scores Q * K^T / sqrt(d_k)
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
# Step 2: Apply mask (optional)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
# Step 3: Softmax to obtain attention weights
attention_weights = F.softmax(scores, dim=-1)
# Step 4: Optional dropout on weights
if dropout is not None:
attention_weights = dropout(attention_weights)
# Step 5: Multiply weights by Values
output = torch.matmul(attention_weights, value)
return output, attention_weights
9.2 Multi-Head Attention
class MultiHeadAttention(nn.Module):
"""
Multi-Head Attention implemented from scratch.
Parameters:
d_model: model dimension (e.g. 512)
num_heads: number of attention heads (e.g. 8)
dropout: dropout rate (e.g. 0.1)
"""
def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
super().__init__()
assert d_model % num_heads == 0, \
f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads # per-head dimension
# Linear projections for Q, K, V and output
self.w_q = nn.Linear(d_model, d_model, bias=False)
self.w_k = nn.Linear(d_model, d_model, bias=False)
self.w_v = nn.Linear(d_model, d_model, bias=False)
self.w_o = nn.Linear(d_model, d_model, bias=False)
self.dropout = nn.Dropout(dropout)
def split_heads(self, x: torch.Tensor) -> torch.Tensor:
"""
Reshape tensor from (batch, seq_len, d_model)
to (batch, num_heads, seq_len, d_k).
"""
batch_size, seq_len, _ = x.size()
x = x.view(batch_size, seq_len, self.num_heads, self.d_k)
return x.transpose(1, 2) # (batch, heads, seq_len, d_k)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
        mask: torch.Tensor | None = None
) -> torch.Tensor:
"""
Forward pass.
For Self-Attention: query = key = value = X
For Cross-Attention: query = decoder, key = value = encoder
"""
batch_size = query.size(0)
# 1. Linear projections
q = self.w_q(query) # (batch, seq_len, d_model)
k = self.w_k(key)
v = self.w_v(value)
# 2. Split into heads
q = self.split_heads(q) # (batch, heads, seq_len, d_k)
k = self.split_heads(k)
v = self.split_heads(v)
# 3. Scaled Dot-Product Attention
attn_output, attn_weights = scaled_dot_product_attention(
q, k, v, mask=mask, dropout=self.dropout
)
# 4. Concatenate heads
# (batch, heads, seq_len, d_k) -> (batch, seq_len, d_model)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, -1, self.d_model)
# 5. Final projection
output = self.w_o(attn_output)
return output
9.3 Usage Example
# Configuration
batch_size = 2
seq_len = 10
d_model = 512
num_heads = 8
# Create the module
mha = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
# Random input (simulates a sequence of token embeddings)
x = torch.randn(batch_size, seq_len, d_model)
# Self-Attention (query = key = value)
output = mha(query=x, key=x, value=x)
print(f"Input shape: {x.shape}") # torch.Size([2, 10, 512])
print(f"Output shape: {output.shape}") # torch.Size([2, 10, 512])
# Causal mask for decoder (lower triangular)
causal_mask = torch.tril(torch.ones(seq_len, seq_len))
causal_mask = causal_mask.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, seq_len)
# Masked Self-Attention
output_masked = mha(query=x, key=x, value=x, mask=causal_mask)
print(f"Masked output shape: {output_masked.shape}")
# Cross-Attention (query from decoder, key/value from encoder)
encoder_output = torch.randn(batch_size, 20, d_model) # longer encoder sequence
decoder_input = torch.randn(batch_size, seq_len, d_model)
cross_attn_output = mha(
query=decoder_input,
key=encoder_output,
value=encoder_output
)
print(f"Cross-attention shape: {cross_attn_output.shape}") # [2, 10, 512]
10. Modern Attention Variants
The quadratic O(n^2) complexity of standard attention has motivated the development of numerous optimized variants. These are essential for modern models handling contexts of 100K to over 1 million tokens.
10.1 Flash Attention (v1, v2, v3)
Flash Attention, developed by Tri Dao and colleagues, does not change the mathematics of attention but radically optimizes its hardware-level implementation. The key idea is to avoid materializing the full n x n attention score matrix in GPU memory (HBM), instead using a tiled approach that works entirely in SRAM (fast on-chip memory).
Flash Attention Evolution
| Version | Year | Key Innovation | Performance |
|---|---|---|---|
| Flash Attention 1 | 2022 | Tiling + fused kernel, IO-awareness | 2-4x speedup vs standard |
| Flash Attention 2 | 2023 | Improved parallelism, less communication | 2x additional over v1 |
| Flash Attention 3 | 2024 | Asynchrony on Hopper GPUs, FP8, warp specialization | Up to 740 TFLOPS (FP16) on H100, 1.2 PFLOPS with FP8 |
Flash Attention 3 leverages specific capabilities of NVIDIA Hopper GPUs (H100/H200): asynchrony between Tensor Cores and TMA (Tensor Memory Accelerator) to overlap compute and data movement, warp specialization for optimal interleaving of matmul and softmax operations, and block FP8 quantization with 2.6x lower numerical error than a naive FP8 implementation. Flash Attention is now integrated into PyTorch, Hugging Face Transformers, vLLM and TensorRT-LLM.
10.2 Multi-Query Attention (MQA)
Proposed by Shazeer in 2019, Multi-Query Attention dramatically reduces the memory required for the KV cache during inference. Instead of having a separate set of Keys and Values for each head, MQA shares a single set of K and V across all heads, while maintaining separate Queries.
Multi-Head Attention (MHA) - Standard:
Head 1: Q1, K1, V1 | KV Cache per head: d_k * seq_len * 2
Head 2: Q2, K2, V2 | Total KV Cache: h * d_k * seq_len * 2
... | With h=32, d_k=128, seq=4096:
Head h: Qh, Kh, Vh | = 32 * 128 * 4096 * 2 = 33.5M values (67 MB in FP16) per layer
Multi-Query Attention (MQA):
Head 1: Q1 \
Head 2: Q2 |--- K_shared, V_shared
... | Total KV Cache: d_k * seq_len * 2
Head h: Qh / = 128 * 4096 * 2 = 1.05M values (2.1 MB in FP16) per layer (32x less!)
10.3 Grouped-Query Attention (GQA)
GQA, introduced by Ainslie et al. in 2023, is a compromise between MHA and MQA. Instead of sharing one K/V set across all heads (MQA) or having one per head (MHA), GQA groups heads into g groups, with each group sharing a K/V set. With g = 1 you get MQA, with g = h you get MHA.
Example: 8 query heads, 2 KV groups (g=2)
Group 1: Q1, Q2, Q3, Q4 share K1, V1
Group 2: Q5, Q6, Q7, Q8 share K2, V2
KV Cache: g * d_k * seq_len * 2 = 2 * 128 * 4096 * 2 = 2.1M values (4.2 MB in FP16)
(16x less than MHA, but only 2x more than MQA)
Models using GQA:
- Llama 2 (70B): 8 KV heads, 64 query heads
- Llama 3: GQA with 8:1 ratio
- Mistral 7B: 8 KV heads, 32 query heads
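The core trick is simple: the g K/V heads are repeated along the head dimension so that each group of query heads attends against its shared K/V set. A minimal sketch with illustrative shapes (not any particular model's dimensions):

```python
import torch

batch, seq_len, d_k = 1, 16, 64
num_q_heads, num_kv_heads = 8, 2        # g = 2 groups, 4 query heads each

q = torch.randn(batch, num_q_heads, seq_len, d_k)
k = torch.randn(batch, num_kv_heads, seq_len, d_k)
v = torch.randn(batch, num_kv_heads, seq_len, d_k)

# Each K/V head serves num_q_heads // num_kv_heads query heads:
# repeat it along the head dimension so shapes line up with Q.
group_size = num_q_heads // num_kv_heads
k = k.repeat_interleave(group_size, dim=1)   # (1, 8, 16, 64)
v = v.repeat_interleave(group_size, dim=1)

scores = q @ k.transpose(-2, -1) / d_k**0.5
weights = torch.softmax(scores, dim=-1)
output = weights @ v
print(output.shape)  # torch.Size([1, 8, 16, 64])
```

Only the small (batch, num_kv_heads, seq_len, d_k) K and V tensors need to live in the KV cache; the expansion happens on the fly at attention time.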
Attention Variants Comparison
| Variant | KV Heads | KV Cache Memory | Quality | Models |
|---|---|---|---|---|
| MHA | h (all) | Maximum | Best | BERT, GPT-2, GPT-3 |
| GQA | g (groups) | h/g reduction | Near MHA | Llama 2/3, Mistral |
| MQA | 1 | Minimum | Slight drop | PaLM, Falcon |
10.4 Sliding Window Attention
Sliding Window Attention, used in Mistral and Longformer, restricts attention to a local window of w tokens per position. Instead of computing attention over the full sequence (O(n^2)), each token sees only the w most recent tokens (itself and the w-1 before it), reducing complexity to O(n * w).
Sequence: t1 t2 t3 t4 t5 t6 t7 t8
Attention for t5 (window=3): sees only [t3, t4, t5]
Attention for t8 (window=3): sees only [t6, t7, t8]
Attention Matrix (1 = visible, 0 = masked):
t1 t2 t3 t4 t5 t6 t7 t8
t1 [ 1 0 0 0 0 0 0 0 ]
t2 [ 1 1 0 0 0 0 0 0 ]
t3 [ 1 1 1 0 0 0 0 0 ]
t4 [ 0 1 1 1 0 0 0 0 ]
t5 [ 0 0 1 1 1 0 0 0 ]
t6 [ 0 0 0 1 1 1 0 0 ]
t7 [ 0 0 0 0 1 1 1 0 ]
t8 [ 0 0 0 0 0 1 1 1 ]
Information is NOT lost: across multiple stacked layers,
information from t1 can reach t8 through propagation.
With L layers and window w, the effective receptive field is L * w.
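The banded mask in the diagram can be built in a couple of lines. A minimal sketch with window = 3, as above:

```python
import torch

seq_len, window = 8, 3

rows = torch.arange(seq_len).unsqueeze(1)   # query positions, shape (8, 1)
cols = torch.arange(seq_len).unsqueeze(0)   # key positions, shape (1, 8)

# A key is visible iff it is in the past (causal) and within the window.
visible = (cols <= rows) & (cols > rows - window)
print(visible.int())
# Row t5 (index 4) has 1s only in columns t3..t5, matching the diagram above.
```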
10.5 Ring Attention and PagedAttention
For extremely long contexts (over 1 million tokens), additional innovations have emerged:
- Ring Attention: distributes attention computation across multiple GPUs arranged in a ring. Each GPU computes attention on a segment of the sequence and passes results to the next GPU. RingX (2025) achieves 94% scaling efficiency up to 4,096 GPUs with 1-million-token sequences.
- PagedAttention: inspired by virtual memory management in operating systems, it allocates the KV cache in non-contiguous blocks (pages), eliminating memory fragmentation. It powers vLLM and enables batch sizes up to 76 times larger.
- FlexAttention (PyTorch): a unified API supporting diverse attention variants (GQA, causal, sliding window, PagedAttention) with less than 5% overhead compared to dedicated implementations.
11. Applications: Transformer Architectures in Practice
The Transformer architecture has spawned three main model families, each using attention differently.
11.1 Encoder-Only: BERT and Derivatives
Encoder-only models use bidirectional self-attention: every token can see all other tokens in the sequence, both preceding and following. This makes them ideal for language understanding tasks.
BERT (Bidirectional Encoder Representations from Transformers)
- Pre-training: Masked Language Model (MLM) + Next Sentence Prediction
- Attention: Bidirectional self-attention (sees the entire sequence)
- Tasks: Classification, Named Entity Recognition, Question Answering
- Variants: RoBERTa, ALBERT, DeBERTa, DistilBERT
11.2 Decoder-Only: GPT and the LLM Family
Decoder-only models use masked self-attention (causal): each token sees only preceding tokens. They are optimized for autoregressive text generation.
Decoder-Only Models
| Model | Parameters | Attention Variant | Context Window |
|---|---|---|---|
| GPT-3 | 175B | Standard MHA | 2K-4K tokens |
| GPT-4 | ~1.8T (MoE) | GQA (estimated) | 128K tokens |
| Llama 3 405B | 405B | GQA + RoPE | 128K tokens |
| Mistral 7B | 7.3B | GQA + Sliding Window | 32K tokens |
| Claude (Anthropic) | Not published | Not published | 200K tokens |
11.3 Encoder-Decoder: T5 and Seq2Seq Models
Encoder-decoder models use all three types of attention: bidirectional self-attention in the encoder, masked self-attention in the decoder and cross-attention between decoder and encoder. They are ideal for tasks that transform an input into an output (translation, summarization, question answering).
Encoder-Decoder Models
- T5: "Text-to-Text Transfer Transformer" - every task is framed as text-in-text-out
- BART: Denoising autoencoder for generation and understanding
- mBART: Multilingual BART for translation
- Flan-T5: T5 trained with instruction tuning
11.4 Vision Transformer (ViT)
Attention is not limited to text. Vision Transformers apply self-attention to images by dividing the image into patches (e.g. 16x16 pixels) and treating each patch as a "token". This demonstrated that attention is a general mechanism applicable to any data that can be expressed as a sequence of tokens.
Image 224x224 pixels
|
v
Split into 16x16 patches: (224/16)^2 = 196 patches
|
v
Each patch -> flatten -> linear projection -> patch embedding
|
v
[CLS] + 196 patch embeddings + positional encoding
|
v
Transformer Encoder (self-attention over 197 tokens)
|
v
[CLS] token -> image classification
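The patch-embedding pipeline above can be sketched in a few lines. A stride-16 convolution is the standard idiom for "split into patches -> flatten -> linear projection"; the zero [CLS] token here stands in for what is a learnable parameter in a real ViT (shapes follow ViT-Base, d_model = 768):

```python
import torch
import torch.nn as nn

img = torch.randn(1, 3, 224, 224)            # one RGB image

patch_size, d_model = 16, 768
# A conv with 16x16 kernels and stride 16 projects each non-overlapping
# patch independently: equivalent to flatten + linear per patch.
to_patches = nn.Conv2d(3, d_model, kernel_size=patch_size, stride=patch_size)

x = to_patches(img)                          # (1, 768, 14, 14)
x = x.flatten(2).transpose(1, 2)             # (1, 196, 768): 196 patch tokens
cls = torch.zeros(1, 1, d_model)             # learnable parameter in practice
tokens = torch.cat([cls, x], dim=1)
print(tokens.shape)                          # torch.Size([1, 197, 768])
```

These 197 tokens (plus positional encoding) then feed the standard Transformer encoder exactly as text tokens would.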
Conclusions and Next Steps
In this article we have traced the entire arc of the attention mechanism: from the long-range dependency problem in RNNs, to the Query-Key-Value intuition, to the Scaled Dot-Product Attention formula, to Multi-Head Attention, all the way to the complete Transformer architecture. We implemented self-attention from scratch in PyTorch and explored modern variants that make million-token context windows possible.
Attention is the fundamental building block of all modern deep learning. Understanding how it works empowers you to reason about why certain optimizations succeed, why some models are faster than others and how to choose the right architecture for your use case.
Key Takeaways
- Attention enables direct connections between any pair of tokens, without bottlenecks
- Scaling (sqrt(d_k)) prevents unstable gradients in softmax
- Multi-Head captures diverse relationships in parallel at no extra cost
- Self-Attention creates contextual representations; Cross-Attention links encoder and decoder
- Positional Encoding provides order information (sinusoidal, learned, RoPE)
- Flash Attention optimizes the hardware implementation without changing the math
- GQA is the optimal tradeoff between quality (MHA) and efficiency (MQA)
In the next article in this series, we will explore Transformer fine-tuning with LoRA, QLoRA and Adapters: how to adapt pre-trained models to specific tasks by modifying only a small fraction of parameters, drastically reducing GPU and memory costs.
Additional Resources
- Original paper: "Attention Is All You Need" (Vaswani et al., 2017)
- Flash Attention 3: "Fast and Accurate Attention with Asynchrony and Low-precision" (Dao et al., 2024)
- GQA Paper: "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints" (Ainslie et al., 2023)
- The Illustrated Transformer: Visual guide by Jay Alammar
- PyTorch Documentation: torch.nn.MultiheadAttention for optimized implementations
- Hugging Face: Transformer documentation with practical examples