The single most important idea behind modern LLMs. Learn how models decide what to focus on — with analogies, math, and code.
Before Attention was invented, sequence models (RNNs) had a serious flaw: they forgot things. The longer the sentence, the worse they performed.
Imagine reading a whole book, but you can only remember the last sentence you read. That's an RNN. Now imagine you can flip back to any page whenever you need to — that's Attention!
An RNN is like looking at a whole room through a tiny keyhole: you only see what's directly in front of you. Attention gives you a wide-open door: you can see the entire room at once and choose where to look.
Attention works through three vectors for every word: a Query, a Key, and a Value. Together they answer one question: "How much should I pay attention to every other word?"
You walk into a library with a question (Query). Every book has a title (Key). You compare your question to each title — the better the match, the more you read that book's contents (Value).
Query = "I want to learn about cats" (your question)
Keys = ["Animal Behavior", "Quantum Physics", "Cat Care Guide"] (book titles)
Values = [actual content of each book]
Your query matches "Cat Care Guide" best → you read mostly that book's content.
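To make the analogy concrete, here's a toy sketch in PyTorch. The vectors are invented purely for illustration; the point is that the query's dot product with each key measures how well your question matches each title, and softmax turns those matches into "how much of each book to read":

```python
import torch

query = torch.tensor([0.9, 0.1, 0.0, 0.8])   # "I want to learn about cats"
keys = torch.tensor([
    [0.6, 0.2, 0.1, 0.3],                    # "Animal Behavior"
    [0.0, 0.9, 0.8, 0.0],                    # "Quantum Physics"
    [0.9, 0.0, 0.1, 0.9],                    # "Cat Care Guide"
])
values = torch.eye(3)                        # one-hot stand-ins for each book's content

scores = keys @ query                        # how well the question matches each title
weights = torch.softmax(scores, dim=0)       # turn matches into reading proportions
print(weights)                               # highest weight lands on "Cat Care Guide"
print(weights @ values)                      # the blended "content" you walk away with
```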
Now let's see the actual math. The formula is surprisingly elegant:
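Attention(Q, K, V) = softmax(QKᵀ / √dk) · V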
Multiply Q by Kᵀ to get "how similar are these two words?" scores. Divide by √dk so the numbers don't get too big. Run softmax to turn the scores into percentages. Multiply by V to get the final answer.
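Those four steps map one-to-one onto code. Here's a minimal single-head sketch (the tensor shapes are assumed, e.g. (seq_len, d_k)):

```python
import torch, math

def attention(Q, K, V):
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1)         # 1. how similar is each pair of words?
    scores = scores / math.sqrt(d_k)         # 2. scale so the numbers stay comfortable
    weights = torch.softmax(scores, dim=-1)  # 3. turn scores into percentages
    return weights @ V                       # 4. blend the values by those percentages
```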
The √dk division is like adjusting a thermostat. Without it, dot products get very large in high dimensions, making softmax "too confident" — it would pick one word and ignore everything else. Scaling keeps the temperature comfortable.
For example, take raw scores of [2, 1, 2]. Softmax turns them into roughly [0.42, 0.15, 0.42], a soft mix of all three words. Multiply those same scores by 10 (which is what unscaled dot products look like in high dimensions) and softmax collapses to roughly [0.50, 0.00, 0.50]: the middle word is effectively ignored.
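A quick sketch of that effect; multiplying the raw scores by 10 stands in for the oversized dot products you get in high dimensions:

```python
import torch

raw = torch.tensor([2.0, 1.0, 2.0])
print(torch.softmax(raw, dim=0))       # ~[0.42, 0.15, 0.42]: a soft mix of all three
print(torch.softmax(raw * 10, dim=0))  # ~[0.50, 0.00, 0.50]: the middle word vanishes
```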
One attention head finds one type of pattern. But language has many patterns simultaneously — syntax, semantics, coreference. The solution? Run multiple attention heads in parallel.
Instead of sending one detective to investigate a crime scene, you send 8 detectives — each looking for something different. One checks fingerprints, another interviews witnesses, another studies the floor. Then they all share their findings.
Head 1 might learn "which words are grammatically related." Head 2 might learn "which words refer to the same entity." Head 3 might learn "which words are nearby." Each head has its own Q, K, V projections — its own way of asking questions.
```python
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model=512, n_heads=8):
        super().__init__()
        self.n_heads, self.d_k = n_heads, d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, Q, K, V):
        B, L_q, _ = Q.shape          # query length
        L_kv = K.shape[1]            # key/value length (differs in cross-attention)
        # Project & reshape to (B, heads, L, d_k)
        q = self.W_q(Q).view(B, L_q, self.n_heads, self.d_k).transpose(1, 2)
        k = self.W_k(K).view(B, L_kv, self.n_heads, self.d_k).transpose(1, 2)
        v = self.W_v(V).view(B, L_kv, self.n_heads, self.d_k).transpose(1, 2)
        # Scaled dot-product attention
        scores = (q @ k.transpose(-2, -1)) / self.d_k ** 0.5
        attn = torch.softmax(scores, dim=-1)
        out = (attn @ v).transpose(1, 2).contiguous().view(B, L_q, -1)
        return self.W_o(out)
```
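A quick sanity check of the block above, with random data (self-attention, so the same tensor is passed as Q, K, and V):

```python
x = torch.randn(2, 10, 512)           # batch=2, seq_len=10, d_model=512
mha = MultiHeadAttention(d_model=512, n_heads=8)
print(mha(x, x, x).shape)             # torch.Size([2, 10, 512]): same shape in, same shape out
```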
Self-attention = talking to yourself: "which of my own words relate to each other?"
Cross-attention = asking someone else: "which of YOUR words help ME understand?"
Self-attention is like re-reading your own essay to find connections between paragraphs. Cross-attention is like reading someone else's notes while writing your essay — Q comes from you, but K and V come from them.
Self-attention: Q, K, V all come from the same sequence. Used in: the encoder, and the decoder (masked).
Cross-attention: Q from one sequence, K & V from another. Used in: the decoder attending to the encoder.
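Using the MultiHeadAttention class defined above, the only difference between the two is which tensors you pass in. A sketch with made-up encoder and decoder activations (random tensors standing in for real hidden states):

```python
mha = MultiHeadAttention(d_model=512, n_heads=8)

encoder_out = torch.randn(1, 12, 512)   # stand-in for an encoded 12-token source sentence
decoder_x = torch.randn(1, 7, 512)      # stand-in for the 7 target tokens generated so far

self_out = mha(decoder_x, decoder_x, decoder_x)        # self-attention: Q, K, V from one sequence
cross_out = mha(decoder_x, encoder_out, encoder_out)   # cross-attention: Q from decoder, K & V from encoder
print(self_out.shape, cross_out.shape)                 # both torch.Size([1, 7, 512])
```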
Here's a full working implementation of scaled dot-product attention, wrapped in a multi-head module, with a test you can run.
We're building the attention mechanism from scratch — first the basic function, then wrapping it in a class, then running it on fake data to see it actually work.
Think of it like building a car engine: first we build one piston (single-head attention), then put 8 pistons together (multi-head), then start the engine (test with sample data).
```python
import torch, torch.nn as nn, math

def scaled_dot_product_attention(Q, K, V, mask=None):
    d_k = Q.size(-1)
    scores = (Q @ K.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    weights = torch.softmax(scores, dim=-1)
    return weights @ V, weights

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model=64, n_heads=4):
        super().__init__()
        self.n_heads, self.d_k = n_heads, d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        B, L, D = x.shape
        qkv = self.qkv(x).view(B, L, 3, self.n_heads, self.d_k)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        Q, K, V = qkv[0], qkv[1], qkv[2]
        attn_out, weights = scaled_dot_product_attention(Q, K, V, mask)
        out = attn_out.transpose(1, 2).contiguous().view(B, L, D)
        return self.out(out), weights

# --- Test it! ---
x = torch.randn(1, 6, 64)   # batch=1, seq_len=6, d_model=64
mha = MultiHeadAttention(d_model=64, n_heads=4)
output, attn_weights = mha(x)
print(f"Input: {x.shape}")
print(f"Output: {output.shape}")
print(f"Weights: {attn_weights.shape}")
print("Attn weights (head 0, first token):")
print(attn_weights[0, 0, 0].data.numpy().round(3))
```
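The mask argument never gets exercised in that test. As a small follow-up, reusing mha and x from above: a lower-triangular (causal) mask stops each token from attending to anything that comes after it, which is how a decoder behaves during generation.

```python
causal_mask = torch.tril(torch.ones(6, 6))       # (L, L): 1 = may attend, 0 = blocked
output, attn_weights = mha(x, mask=causal_mask)
print(attn_weights[0, 0].data.numpy().round(3))  # head 0: entries above the diagonal are ~0
```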