Vuk Rosić

Attention Mechanism Tutorial: From Simple to Advanced

Part 1: The Core Idea

Attention is like a spotlight - it helps models focus on what's important.

import torch
import torch.nn.functional as F

# Simple example: Which word is most important?
sentence = ["I", "love", "pizza"]
importance = torch.tensor([0.1, 0.3, 0.6])  # pizza is most important

Intuition: Instead of treating all words equally, attention assigns different weights to focus on what matters most.

Part 2: Basic Attention Weights

# Raw attention scores (how much to focus on each word)
scores = torch.tensor([2.0, 1.0, 3.0])  # [I, love, pizza]

# Convert to probabilities (softmax)
weights = F.softmax(scores, dim=0)
print(weights)  # [0.24, 0.09, 0.67] - pizza gets most attention

What happened: Softmax converts raw scores to probabilities that sum to 1.
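
To demystify the softmax step, here is the same calculation written out by hand with the scores tensor from above; the result matches F.softmax.

# Manual softmax: exponentiate, then normalize so the weights sum to 1
exp_scores = torch.exp(scores)                 # [7.39, 2.72, 20.09]
manual_weights = exp_scores / exp_scores.sum()
print(manual_weights)                          # [0.24, 0.09, 0.67] - same as F.softmax
print(manual_weights.sum())                    # 1.0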

Part 3: Weighted Combination

# Word representations (simplified vectors)
words = torch.tensor([[1.0, 0.0],   # "I"
                      [0.0, 1.0],   # "love"
                      [1.0, 1.0]])  # "pizza"

# Apply attention weights
attended = torch.sum(weights.unsqueeze(1) * words, dim=0)
print(attended)  # Mostly "pizza" representation

Intuition: We combine all word vectors, but "pizza" contributes most because it has the highest attention weight.
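
As a quick check, here is the same weighted combination written out term by term (using the rounded weights, so the numbers differ slightly from attended):

# Each word vector scaled by its attention weight, then summed
manual = 0.24 * words[0] + 0.09 * words[1] + 0.67 * words[2]
print(manual)  # ~[0.91, 0.76] - dominated by the "pizza" vector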

Part 4: Computing Attention Scores

# How similar are words? (dot product)
query = torch.tensor([1.0, 1.0])  # What we're looking for
key1 = torch.tensor([1.0, 0.0])   # "I"
key2 = torch.tensor([0.0, 1.0])   # "love"

score1 = torch.dot(query, key1)  # 1.0
score2 = torch.dot(query, key2)  # 1.0

Theory: Attention scores measure how well a query matches each key.
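
Both keys above happen to score 1.0, so neither stands out. As a small illustrative extension (key3 is made up for this example), a key that points in the same direction as the query produces a larger score:

key3 = torch.tensor([2.0, 2.0])   # hypothetical key aligned with the query
score3 = torch.dot(query, key3)   # 4.0 - a much stronger match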

Part 5: The Q, K, V Concept

# Three roles for each word:
# Q (Query): "What am I looking for?"
# K (Key):   "What do I represent?"
# V (Value): "What information do I carry?"

query = torch.tensor([1.0, 0.0])    # Looking for subject
keys = torch.tensor([[1.0, 0.0],    # "I" - matches query well
                     [0.0, 1.0]])   # "love" - doesn't match
values = torch.tensor([[2.0, 3.0],  # "I" carries this info
                       [1.0, 4.0]]) # "love" carries this info

Intuition: Query asks "what do I need?", Keys answer "what do I offer?", Values provide the actual information.

Part 6: One-Line Attention

# Complete attention in one line
attention_output = torch.sum(F.softmax(torch.mv(keys, query), dim=0).unsqueeze(1) * values, dim=0)

What it does: Computes scores (query·keys), applies softmax, weights the values.
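
For readability, here is the same one-liner unpacked into its three steps (same tensors, same result):

scores = torch.mv(keys, query)      # 1. query · each key -> [1.0, 0.0]
weights = F.softmax(scores, dim=0)  # 2. scores -> probabilities [0.73, 0.27]
attention_output = torch.sum(weights.unsqueeze(1) * values, dim=0)  # 3. weighted sum of values
print(attention_output)             # ~[1.73, 3.27] - mostly "I"'s value vector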

Part 7: Self-Attention Intuition

# In self-attention, each word can attend to every other word
sentence = ["The", "cat", "sat"]

# "cat" might attend to "sat" (what did the cat do?)
# "sat" might attend to "cat" (who sat?)

Key insight: Words can look at each other to understand relationships and context.
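
To make this concrete, here is a purely hypothetical attention pattern for these three words (the numbers are illustrative, not produced by a trained model):

# Rows: "The", "cat", "sat" - each row shows where that word attends
attn = torch.tensor([[0.6, 0.2, 0.2],   # "The" mostly attends to itself
                     [0.1, 0.3, 0.6],   # "cat" attends strongly to "sat"
                     [0.2, 0.6, 0.2]])  # "sat" attends strongly to "cat"
print(attn.sum(dim=1))  # every row sums to 1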

Part 8: Multi-Head Attention (Simple)

# Multiple "attention heads" look for different things
head1_query = torch.tensor([1.0, 0.0])  # Looking for subjects
head2_query = torch.tensor([0.0, 1.0])  # Looking for actions

# Each head focuses on different aspects

Why multiple heads: Different heads can specialize in different types of relationships (subject-verb, adjective-noun, etc.).
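
A minimal sketch of that idea, reusing the keys and values from Part 5: the two heads weight the same values differently and therefore extract different mixtures of information.

w1 = F.softmax(keys @ head1_query, dim=0)   # head 1: [0.73, 0.27] - favors "I"
w2 = F.softmax(keys @ head2_query, dim=0)   # head 2: [0.27, 0.73] - favors "love"
print(w1 @ values)                          # ~[1.73, 3.27] - mostly "I"'s information
print(w2 @ values)                          # ~[1.27, 3.73] - mostly "love"'s information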

Part 9: Scaling Up

# Real sentences have many words
seq_len = 10   # 10 words in sentence
d_model = 64   # Each word is a 64-dimensional vector

# Q, K, V matrices transform word vectors
Q = torch.randn(seq_len, d_model)  # Queries for each word
K = torch.randn(seq_len, d_model)  # Keys for each word
V = torch.randn(seq_len, d_model)  # Values for each word

Scale: Real models use embedding dimensions in the hundreds or thousands and process sequences of thousands of tokens.

Part 10: Attention Matrix

# Attention scores between all word pairs
attention_scores = torch.mm(Q, K.transpose(0, 1))       # [10, 10] matrix
attention_weights = F.softmax(attention_scores, dim=1)  # Each row sums to 1

# Row i, column j = how much word i attends to word j

Visualization: Each row shows where one word "looks" in the sentence.
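
A quick sanity check on the matrix built above: every row is a probability distribution, and each row can be read as one word's attention pattern.

print(attention_weights.sum(dim=1))  # ten values, each ~1.0
print(attention_weights[0])          # where word 0 "looks" across the sentence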

Part 11: Why Attention Works

# Traditional RNN: Information flows sequentially
# Word 1 → Word 2 → Word 3 → Word 4

# Attention: All words can interact directly
# Word 1 ↔ Word 2 ↔ Word 3 ↔ Word 4

Advantage: no information is lost over long distances, and all positions can be processed in parallel.

Part 12: Putting It All Together

# Complete self-attention step by step
def simple_attention(X):
    Q = X  # Queries (simplified)
    K = X  # Keys
    V = X  # Values

    scores = torch.mm(Q, K.transpose(0, 1))  # Compute similarities
    weights = F.softmax(scores, dim=1)       # Convert to probabilities
    output = torch.mm(weights, V)            # Weighted combination
    return output

# Usage
word_vectors = torch.randn(5, 8)  # 5 words, 8 dimensions each
attended_vectors = simple_attention(word_vectors)

Result: Each word vector is now updated with information from all other words, weighted by attention.
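
One refinement before the next section: transformer attention also divides the scores by the square root of the vector dimension so they do not grow as dimensions increase (the multi-head code later uses exactly this scaling). A minimal sketch of simple_attention with that scaling added:

def scaled_attention(X):
    d = X.size(-1)
    scores = torch.mm(X, X.transpose(0, 1)) / (d ** 0.5)  # scaled similarities
    weights = F.softmax(scores, dim=1)
    return torch.mm(weights, X)

scaled_vectors = scaled_attention(word_vectors)  # same shape as before: [5, 8]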

Key Takeaways

  1. Attention = Weighted Average: Focus more on important parts
  2. Q·K = Similarity: How well query matches key
  3. Softmax = Probability: Convert scores to weights that sum to 1
  4. Weighted V = Output: Combine values using attention weights
  5. Self-Attention = Words talking to each other: Every word can attend to every other word

This foundation prepares you for transformer models, which are built entirely on attention mechanisms!

Understanding Attention: From Words to Vectors

1. Word Embeddings - The Foundation

import torch
import torch.nn as nn
import torch.nn.functional as F

# Sample sentence: "The cat sat on the mat"
vocab = {"<pad>": 0, "the": 1, "cat": 2, "sat": 3, "on": 4, "mat": 5}
sentence = [1, 2, 3, 4, 1, 5]  # token IDs

# Create embeddings
vocab_size = len(vocab)
embed_dim = 64
embedding = nn.Embedding(vocab_size, embed_dim)

# Convert tokens to vectors
tokens = torch.tensor(sentence)
embeddings = embedding(tokens)
print(f"Shape: {embeddings.shape}")             # [6, 64]
print(f"'cat' vector: {embeddings[1][:8]}...")  # First 8 dimensions

Each word becomes a 64-dimensional vector; after training, these vectors capture semantic meaning (here the embedding is freshly initialized, so the values are random).
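
A quick check with the tensors above: both occurrences of "the" (positions 0 and 4) map to the same embedding row, which is exactly why the model needs attention to tell them apart by context.

print(torch.equal(embeddings[0], embeddings[4]))  # True - identical vectors for "the"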

2. The Q, K, V Matrices - Core of Attention

# Attention dimensions
d_model = 64
num_heads = 8
d_k = d_model // num_heads  # 8

# Linear transformations to create Q, K, V
W_q = nn.Linear(d_model, d_model, bias=False)
W_k = nn.Linear(d_model, d_model, bias=False)
W_v = nn.Linear(d_model, d_model, bias=False)

# Transform embeddings
Q = W_q(embeddings)  # Queries: "What am I looking for?"
K = W_k(embeddings)  # Keys: "What do I represent?"
V = W_v(embeddings)  # Values: "What information do I carry?"

print(f"Q shape: {Q.shape}")  # [6, 64]
print(f"K shape: {K.shape}")  # [6, 64]
print(f"V shape: {V.shape}")  # [6, 64]

Intuition:

  • Q (Query): "What information does this word need?"
  • K (Key): "What kind of information does this word offer?"
  • V (Value): "What actual information does this word contain?"

3. Computing Attention Scores

# Reshape for multi-head attention
batch_size, seq_len = 1, 6
Q = Q.view(batch_size, seq_len, num_heads, d_k).transpose(1, 2)  # [1, 8, 6, 8]
K = K.view(batch_size, seq_len, num_heads, d_k).transpose(1, 2)  # [1, 8, 6, 8]
V = V.view(batch_size, seq_len, num_heads, d_k).transpose(1, 2)  # [1, 8, 6, 8]

# Attention scores: How much should each word pay attention to others?
scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)
print(f"Attention scores shape: {scores.shape}")  # [1, 8, 6, 6]

# Example: How much does "cat" attend to each word?
cat_attention = scores[0, 0, 1, :]  # First head, "cat" position
words = ["the", "cat", "sat", "on", "the", "mat"]
for i, word in enumerate(words):
    print(f"cat -> {word}: {cat_attention[i]:.3f}")

4. Softmax and Weighted Values

# Convert scores to probabilities
attention_weights = F.softmax(scores, dim=-1)
print(f"Attention weights shape: {attention_weights.shape}")  # [1, 8, 6, 6]

# Apply attention to values
attended_values = torch.matmul(attention_weights, V)  # [1, 8, 6, 8]

# Concatenate heads and project back
attended_values = attended_values.transpose(1, 2).contiguous().view(
    batch_size, seq_len, d_model)  # [1, 6, 64]

print(f"Final attended values shape: {attended_values.shape}")

# Show attention pattern for "cat"
print("\nAttention pattern for 'cat':")
cat_weights = attention_weights[0, 0, 1, :]  # First head
for i, word in enumerate(words):
    print(f"  {word}: {cat_weights[i]:.3f}")

5. Complete Self-Attention Implementation

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x):
        batch_size, seq_len, d_model = x.size()

        # Linear transformations
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
        attention_weights = F.softmax(scores, dim=-1)
        attended_values = torch.matmul(attention_weights, V)

        # Concatenate heads
        attended_values = attended_values.transpose(1, 2).contiguous().view(
            batch_size, seq_len, d_model)

        # Final projection
        output = self.W_o(attended_values)
        return output, attention_weights

# Usage
attention = MultiHeadAttention(d_model=64, num_heads=8)
output, weights = attention(embeddings.unsqueeze(0))
print(f"Output shape: {output.shape}")  # [1, 6, 64]

6. Visualizing Attention Patterns

# Extract attention weights for visualization
attention_matrix = weights[0, 0].detach().numpy()  # First head
words = ["the", "cat", "sat", "on", "the", "mat"]

print("Attention Matrix (first head):")
print("From -> To:")
for i, from_word in enumerate(words):
    print(f"{from_word:>4}: ", end="")
    for j, to_word in enumerate(words):
        print(f"{attention_matrix[i, j]:.2f} ", end="")
    print()

7. Key Insights

What happens in attention?

  1. Each word creates a query (what it's looking for)
  2. Each word creates a key (what it represents)
  3. We compute similarity between queries and keys
  4. Higher similarity = more attention
  5. We use attention weights to combine values (actual information)

Example: When processing "cat", the model might:

  • Query: "I need information about animals"
  • Look at all keys: "the" (determiner), "sat" (action), "mat" (object)
  • Pay most attention to "sat" because it's the relevant action
  • Combine information weighted by attention scores

8. Practical Example with Real Meaning

# Sentence: "The cat chased the mouse"
sentence = "The cat chased the mouse"
words = sentence.lower().split()

# Simulate what attention might learn
print("Attention patterns the model might learn:")
print("- 'cat' attends to 'chased' (subject-verb relationship)")
print("- 'chased' attends to 'cat' and 'mouse' (verb-subject-object)")
print("- 'mouse' attends to 'chased' (object-verb relationship)")
print("- 'the' attends to following nouns ('cat', 'mouse')")

# This allows the model to understand:
# - Who did what to whom
# - Grammatical relationships
# - Semantic dependencies

Summary

The attention mechanism allows models to:

  • Focus on relevant parts of the input
  • Relate different words to each other
  • Combine information based on relevance
  • Understand long-range dependencies

The magic is in the learned Q, K, V matrices that transform word embeddings into queries, keys, and values that can interact meaningfully.
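
To see that these matrices really are learned, inspect the attention module built in section 5; the projections are ordinary trainable parameters updated by backpropagation.

for name, p in attention.named_parameters():
    print(name, tuple(p.shape), p.requires_grad)
# W_q.weight (64, 64) True
# W_q.bias   (64,)    True
# ... and likewise for W_k, W_v and the output projection W_o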
