Part 1: The Core Idea
Attention is like a spotlight - it helps models focus on what's important.
import torch
import torch.nn.functional as F

# Simple example: Which word is most important?
sentence = ["I", "love", "pizza"]
importance = torch.tensor([0.1, 0.3, 0.6])  # pizza is most important
Intuition: Instead of treating all words equally, attention assigns different weights to focus on what matters most.
Part 2: Basic Attention Weights
# Raw attention scores (how much to focus on each word)
scores = torch.tensor([2.0, 1.0, 3.0])  # [I, love, pizza]

# Convert to probabilities (softmax)
weights = F.softmax(scores, dim=0)
print(weights)  # [0.24, 0.09, 0.67] - pizza gets most attention
What happened: Softmax converts raw scores to probabilities that sum to 1.
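Worked arithmetic (optional sanity check): softmax just exponentiates each score and divides by the total, which you can reproduce by hand:

# Softmax by hand: exponentiate, then normalize so the weights sum to 1
raw = torch.tensor([2.0, 1.0, 3.0])
exps = torch.exp(raw)               # [7.39, 2.72, 20.09]
manual_weights = exps / exps.sum()  # [0.24, 0.09, 0.67] - same as F.softmax
print(manual_weights.sum())         # 1.0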
Part 3: Weighted Combination
# Word representations (simplified vectors)
words = torch.tensor([[1.0, 0.0],   # "I"
                      [0.0, 1.0],   # "love"
                      [1.0, 1.0]])  # "pizza"

# Apply attention weights
attended = torch.sum(weights.unsqueeze(1) * words, dim=0)
print(attended)  # Mostly "pizza" representation
Intuition: We combine all word vectors, but "pizza" contributes most because it has the highest attention weight.
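Worked arithmetic: the same result falls out if you scale each word vector by its weight and add them up (weights rounded for readability):

manual = 0.24 * words[0] + 0.09 * words[1] + 0.67 * words[2]
# = [0.24, 0.00] + [0.00, 0.09] + [0.67, 0.67]
print(manual)  # ≈ [0.91, 0.76] - dominated by the "pizza" vector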
Part 4: Computing Attention Scores
# How similar are words? (dot product)
query = torch.tensor([1.0, 1.0])  # What we're looking for
key1 = torch.tensor([1.0, 0.0])   # "I"
key2 = torch.tensor([0.0, 1.0])   # "love"

score1 = torch.dot(query, key1)  # 1.0
score2 = torch.dot(query, key2)  # 1.0
Theory: Attention scores measure how well a query matches each key.
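For contrast, here is a hypothetical third key that points in the same direction as the query; because it lines up with what we are looking for, it gets the highest score:

key3 = torch.tensor([1.0, 1.0])  # hypothetical "pizza" key, aligned with the query
score3 = torch.dot(query, key3)  # 2.0 - the best match of the three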
Part 5: The Q, K, V Concept
# Three roles for each word:
# Q (Query): "What am I looking for?"
# K (Key):   "What do I represent?"
# V (Value): "What information do I carry?"

query = torch.tensor([1.0, 0.0])     # Looking for subject
keys = torch.tensor([[1.0, 0.0],     # "I" - matches query well
                     [0.0, 1.0]])    # "love" - doesn't match
values = torch.tensor([[2.0, 3.0],   # "I" carries this info
                       [1.0, 4.0]])  # "love" carries this info
Intuition: Query asks "what do I need?", Keys answer "what do I offer?", Values provide the actual information.
Part 6: One-Line Attention
# Complete attention in one line
attention_output = torch.sum(F.softmax(torch.mv(keys, query), dim=0).unsqueeze(1) * values, dim=0)
What it does: Computes scores (query·keys), applies softmax, weights the values.
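If the one-liner is hard to parse, this sketch unpacks it into named steps using the keys, values, and query from Part 5:

scores = torch.mv(keys, query)      # query · each key: [1.0, 0.0]
weights = F.softmax(scores, dim=0)  # [0.73, 0.27] - "I" wins
attention_output = torch.sum(weights.unsqueeze(1) * values, dim=0)
print(attention_output)             # ≈ 0.73 * [2, 3] + 0.27 * [1, 4] ≈ [1.73, 3.27]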
Part 7: Self-Attention Intuition
# In self-attention, each word can attend to every other word
sentence = ["The", "cat", "sat"]

# "cat" might attend to "sat" (what did the cat do?)
# "sat" might attend to "cat" (who sat?)
Key insight: Words can look at each other to understand relationships and context.
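As a concrete illustration (the numbers below are made up, not learned), an attention matrix for this three-word sentence might look like this:

# Hypothetical attention weights - each row is one word's view of the sentence
#                      The   cat   sat
attn = torch.tensor([[0.60, 0.20, 0.20],   # "The" spreads its attention
                     [0.10, 0.30, 0.60],   # "cat" looks mostly at "sat"
                     [0.10, 0.70, 0.20]])  # "sat" looks mostly at "cat"
print(attn.sum(dim=1))  # every row sums to 1.0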
Part 8: Multi-Head Attention (Simple)
# Multiple "attention heads" look for different things
head1_query = torch.tensor([1.0, 0.0])  # Looking for subjects
head2_query = torch.tensor([0.0, 1.0])  # Looking for actions

# Each head focuses on different aspects
Why multiple heads: Different heads can specialize in different types of relationships (subject-verb, adjective-noun, etc.).
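A rough sketch of that idea with toy numbers (these are not learned weights): both heads score the same keys, but because their queries differ, each head weights the same values differently:

keys = torch.tensor([[1.0, 0.0],    # "cat" - a subject-like key
                     [0.0, 1.0]])   # "sat" - an action-like key
values = torch.tensor([[5.0, 0.0],
                       [0.0, 5.0]])

for name, q in [("head1 (subjects)", head1_query), ("head2 (actions)", head2_query)]:
    w = F.softmax(torch.mv(keys, q), dim=0)
    print(name, w, torch.sum(w.unsqueeze(1) * values, dim=0))
# head1 puts more weight on "cat", head2 puts more weight on "sat"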
Part 9: Scaling Up
# Real sentences have many words
seq_len = 10   # 10 words in the sentence
d_model = 64   # Each word is a 64-dimensional vector

# Q, K, V matrices transform word vectors
Q = torch.randn(seq_len, d_model)  # Queries for each word
K = torch.randn(seq_len, d_model)  # Keys for each word
V = torch.randn(seq_len, d_model)  # Values for each word
Scale: Real models use hundreds or thousands of dimensions per token and process contexts of thousands of tokens.
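To get a feel for why this matters (the sizes below are only illustrative), note that the score matrix in the next part grows quadratically with sentence length:

# Illustrative sizes only; real models vary widely
long_seq = 512              # tokens in the context window
print(long_seq * long_seq)  # 262,144 pairwise attention scores per head, per layer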
Part 10: Attention Matrix
# Attention scores between all word pairs
attention_scores = torch.mm(Q, K.transpose(0, 1))       # [10, 10] matrix
attention_weights = F.softmax(attention_scores, dim=1)  # Each row sums to 1

# Row i, column j = how much word i attends to word j
Visualization: Each row shows where one word "looks" in the sentence.
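A quick sanity check on that claim: each row of the weight matrix is a probability distribution, so it sums to 1:

print(attention_weights.sum(dim=1))  # ten values, all 1.0 - one per word
print(attention_weights[0])          # where word 0 "looks" across all 10 words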
Part 11: Why Attention Works
# Traditional RNN: Information flows sequentially
# Word 1 → Word 2 → Word 3 → Word 4

# Attention: All words can interact directly
# Word 1 ↔ Word 2 ↔ Word 3 ↔ Word 4
Advantage: Information doesn't have to survive a long chain of sequential steps, and all positions can be processed in parallel.
Part 12: Putting It All Together
# Complete self-attention step by step
def simple_attention(X):
    Q = X  # Queries (simplified)
    K = X  # Keys
    V = X  # Values

    scores = torch.mm(Q, K.transpose(0, 1))  # Compute similarities
    weights = F.softmax(scores, dim=1)       # Convert to probabilities
    output = torch.mm(weights, V)            # Weighted combination
    return output

# Usage
word_vectors = torch.randn(5, 8)  # 5 words, 8 dimensions each
attended_vectors = simple_attention(word_vectors)
Result: Each word vector is now updated with information from all other words, weighted by attention.
Key Takeaways
- Attention = Weighted Average: Focus more on important parts
- Q·K = Similarity: How well query matches key
- Softmax = Probability: Convert scores to weights that sum to 1
- Weighted V = Output: Combine values using attention weights
- Self-Attention = Words talking to each other: Every word can attend to every other word
This foundation prepares you for transformer models, which are built around the attention mechanism!
Understanding Attention: From Words to Vectors
1. Word Embeddings - The Foundation
import torch
import torch.nn as nn
import torch.nn.functional as F

# Sample sentence: "The cat sat on the mat"
vocab = {"<pad>": 0, "the": 1, "cat": 2, "sat": 3, "on": 4, "mat": 5}
sentence = [1, 2, 3, 4, 1, 5]  # token IDs

# Create embeddings
vocab_size = len(vocab)
embed_dim = 64
embedding = nn.Embedding(vocab_size, embed_dim)

# Convert tokens to vectors
tokens = torch.tensor(sentence)
embeddings = embedding(tokens)
print(f"Shape: {embeddings.shape}")             # [6, 64]
print(f"'cat' vector: {embeddings[1][:8]}...")  # First 8 dimensions
Each word becomes a 64-dimensional vector; during training these vectors learn to capture semantic meaning (right after initialization they are just random numbers).
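Under the hood, nn.Embedding is just a lookup table: row i of its weight matrix is the vector for token ID i. A quick check:

cat_id = vocab["cat"]  # 2
# "cat" sits at position 1 in the sentence, and its vector is simply row 2 of the table
print(torch.equal(embeddings[1], embedding.weight[cat_id]))  # True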
2. The Q, K, V Matrices - Core of Attention
# Attention dimensions
d_model = 64
num_heads = 8
d_k = d_model // num_heads  # 8

# Linear transformations to create Q, K, V
W_q = nn.Linear(d_model, d_model, bias=False)
W_k = nn.Linear(d_model, d_model, bias=False)
W_v = nn.Linear(d_model, d_model, bias=False)

# Transform embeddings
Q = W_q(embeddings)  # Queries: "What am I looking for?"
K = W_k(embeddings)  # Keys: "What do I represent?"
V = W_v(embeddings)  # Values: "What information do I carry?"

print(f"Q shape: {Q.shape}")  # [6, 64]
print(f"K shape: {K.shape}")  # [6, 64]
print(f"V shape: {V.shape}")  # [6, 64]
Intuition:
- Q (Query): "What information does this word need?"
- K (Key): "What kind of information does this word offer?"
- V (Value): "What actual information does this word contain?"
3. Computing Attention Scores
# Reshape for multi-head attention
batch_size, seq_len = 1, 6
Q = Q.view(batch_size, seq_len, num_heads, d_k).transpose(1, 2)  # [1, 8, 6, 8]
K = K.view(batch_size, seq_len, num_heads, d_k).transpose(1, 2)  # [1, 8, 6, 8]
V = V.view(batch_size, seq_len, num_heads, d_k).transpose(1, 2)  # [1, 8, 6, 8]

# Attention scores: How much should each word pay attention to the others?
scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)
print(f"Attention scores shape: {scores.shape}")  # [1, 8, 6, 6]

# Example: How much does "cat" attend to each word?
cat_attention = scores[0, 0, 1, :]  # First head, "cat" position
words = ["the", "cat", "sat", "on", "the", "mat"]
for i, word in enumerate(words):
    print(f"cat -> {word}: {cat_attention[i]:.3f}")
4. Softmax and Weighted Values
# Convert scores to probabilities
attention_weights = F.softmax(scores, dim=-1)
print(f"Attention weights shape: {attention_weights.shape}")  # [1, 8, 6, 6]

# Apply attention to values
attended_values = torch.matmul(attention_weights, V)  # [1, 8, 6, 8]

# Concatenate the heads back into one vector per word
# (the output projection W_o is added in the full module below)
attended_values = attended_values.transpose(1, 2).contiguous().view(
    batch_size, seq_len, d_model)  # [1, 6, 64]
print(f"Final attended values shape: {attended_values.shape}")

# Show attention pattern for "cat"
print("\nAttention pattern for 'cat':")
cat_weights = attention_weights[0, 0, 1, :]  # First head
for i, word in enumerate(words):
    print(f"  {word}: {cat_weights[i]:.3f}")
5. Complete Self-Attention Implementation
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x):
        batch_size, seq_len, d_model = x.size()

        # Linear transformations
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
        attention_weights = F.softmax(scores, dim=-1)
        attended_values = torch.matmul(attention_weights, V)

        # Concatenate heads
        attended_values = attended_values.transpose(1, 2).contiguous().view(
            batch_size, seq_len, d_model)

        # Final projection
        output = self.W_o(attended_values)
        return output, attention_weights


# Usage
attention = MultiHeadAttention(d_model=64, num_heads=8)
output, weights = attention(embeddings.unsqueeze(0))
print(f"Output shape: {output.shape}")  # [1, 6, 64]
6. Visualizing Attention Patterns
# Extract attention weights for visualization
attention_matrix = weights[0, 0].detach().numpy()  # First head
words = ["the", "cat", "sat", "on", "the", "mat"]

print("Attention Matrix (first head):")
print("From -> To:")
for i, from_word in enumerate(words):
    print(f"{from_word:>4}: ", end="")
    for j, to_word in enumerate(words):
        print(f"{attention_matrix[i, j]:.2f} ", end="")
    print()
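For a graphical view, the same matrix can be rendered as a heatmap. A minimal sketch, assuming matplotlib is available in your environment:

import matplotlib.pyplot as plt

plt.imshow(attention_matrix, cmap="viridis")
plt.xticks(range(len(words)), words)
plt.yticks(range(len(words)), words)
plt.xlabel("attended to (key)")
plt.ylabel("attending from (query)")
plt.colorbar(label="attention weight")
plt.title("Attention weights (head 0)")
plt.show()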
7. Key Insights
What happens in attention?
- Each word creates a query (what it's looking for)
- Each word creates a key (what it represents)
- We compute similarity between queries and keys
- Higher similarity = more attention
- We use attention weights to combine values (actual information)
Example: When processing "cat", the model might:
- Query: "I need information about animals"
- Look at all keys: "the" (determiner), "sat" (action), "mat" (object)
- Pay most attention to "sat" because it's the relevant action
- Combine information weighted by attention scores
8. Practical Example with Real Meaning
# Sentence: "The cat chased the mouse"
sentence = "The cat chased the mouse"
words = sentence.lower().split()

# Simulate what attention might learn
print("Attention patterns the model might learn:")
print("- 'cat' attends to 'chased' (subject-verb relationship)")
print("- 'chased' attends to 'cat' and 'mouse' (verb-subject-object)")
print("- 'mouse' attends to 'chased' (object-verb relationship)")
print("- 'the' attends to following nouns ('cat', 'mouse')")

# This allows the model to understand:
# - Who did what to whom
# - Grammatical relationships
# - Semantic dependencies
Summary
Attention mechanism allows models to:
- Focus on relevant parts of the input
- Relate different words to each other
- Combine information based on relevance
- Understand long-range dependencies
The magic is in the learned Q, K, V matrices that transform word embeddings into queries, keys, and values that can interact meaningfully.