Understanding Self-Attention in Transformers

Self-Attention (also called intra-attention) is a core mechanism in the Transformer architecture that enables a model to compute the relevance of every token in a sequence to every other token—in parallel. Unlike recurrent neural networks (RNNs) or LSTMs, which process sequences sequentially and struggle with long-term dependencies, self-attention directly models relationships between all pairs of tokens, regardless of their position in the sequence.

Self-attention is the foundation of modern natural language processing (NLP) models like BERT, GPT, and T5, and has been adapted for computer vision (Vision Transformer, ViT) and speech processing tasks.


I. Core Idea of Self-Attention

For a given sequence (e.g., a sentence of words), self-attention answers the question:

For each token in the sequence, how much should it “pay attention” to every other token (including itself)?

The output for each token is a weighted sum of all tokens in the sequence, where the weights reflect the importance of other tokens to the current token. This creates contextualized representations—tokens that have the same literal value but different contexts get different embeddings (e.g., “bank” in “river bank” vs. “bank account”).

Example: Sentence Context

Consider the sentence:

The animal didn’t cross the street because it was tired.

When processing the token it, self-attention will assign a high weight to the animal, because it refers to the animal. Conversely, if the sentence were …because it was wide, self-attention would link it to the street instead.


II. Mathematical Formulation of Scaled Dot-Product Self-Attention

The original Transformer uses scaled dot-product attention, the most common form of self-attention. It operates on three learned projections of the input sequence: Query (Q), Key (K), and Value (V).

Step 1: Define Query, Key, Value Vectors

Given an input sequence embedding matrix X of shape \((\text{seq_len}, d_{\text{model}})\), where:

  • \(\text{seq_len}\) = number of tokens in the sequence,
  • \(d_{\text{model}}\) = dimension of each token embedding,

we generate three matrices by multiplying X with three learnable weight matrices (\(W_Q, W_K, W_V\)):

\(Q = X W_Q, \quad K = X W_K, \quad V = X W_V\)

  • Q (Query): Shape \((\text{seq_len}, d_k)\) → Represents what the current token is “looking for” (its interest in other tokens).
  • K (Key): Shape \((\text{seq_len}, d_k)\) → Represents what other tokens can “offer” (their relevance to the query).
  • V (Value): Shape \((\text{seq_len}, d_v)\) → Represents the actual content of each token (used to compute the weighted output).

In the original Transformer, \(d_k = d_v = d_{\text{model}} / h\) (where h = number of attention heads in multi-head attention).
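
A minimal sketch of this projection step in PyTorch (the tensor names and sizes here are illustrative, not taken from any particular implementation):

import torch

seq_len, d_model, d_k = 5, 128, 64      # example sizes
X = torch.randn(seq_len, d_model)       # input token embeddings

# Learnable projection matrices (randomly initialized here for illustration)
W_Q = torch.randn(d_model, d_k)
W_K = torch.randn(d_model, d_k)
W_V = torch.randn(d_model, d_k)

Q = X @ W_Q   # (seq_len, d_k)
K = X @ W_K   # (seq_len, d_k)
V = X @ W_V   # (seq_len, d_k); d_v = d_k in this sketch for simplicity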

Step 2: Compute Attention Scores

The attention score between token i (query) and token j (key) is the dot product of \(Q_i\) and \(K_j\). This measures how well the key of token j matches the query of token i.

The score matrix S is computed as:

\(S = Q K^T\)

  • Shape of S: \((\text{seq_len}, \text{seq_len})\) → \(S_{i,j}\) = score between token i and token j.

Step 3: Scale the Scores

Dot products of high-dimensional vectors can grow large in magnitude, pushing the softmax into regions where its gradients are extremely small (the weights become nearly one-hot). To counteract this, we scale the scores by \(\sqrt{d_k}\) (the square root of the dimension of the query and key vectors):

\(S_{\text{scaled}} = \frac{Q K^T}{\sqrt{d_k}}\)

Step 4: Apply Softmax to Get Attention Weights

The scaled scores are passed through a softmax function to convert them into attention weights (values between 0 and 1 that sum to 1 for each row):

\(\text{Attention Weights} = \text{softmax}(S_{\text{scaled}}) = \text{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right)\)

  • A weight close to 1 means token i attends almost entirely to token j; a weight close to 0 means token j is effectively ignored.

Step 5: Compute Weighted Sum of Values

The final self-attention output is the weighted sum of the Value matrix using the attention weights:

\(\text{Self-Attention Output} = \text{Attention Weights} \times V\)

  • Shape of output: \((\text{seq_len}, d_v)\) → Each token’s output is a combination of all tokens, weighted by their relevance.

Full Formula for Scaled Dot-Product Self-Attention

\(\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right) V\)
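
The same formula fits in a few lines of PyTorch. This functional sketch mirrors the fuller class-based implementation shown in Section V (names and shapes are illustrative):

import torch
import torch.nn.functional as F
import math

def scaled_dot_product_attention(Q, K, V):
    # Q, K: (seq_len, d_k); V: (seq_len, d_v)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))  # (seq_len, seq_len)
    weights = F.softmax(scores, dim=-1)                       # each row sums to 1
    return weights @ V                                        # (seq_len, d_v)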


III. Key Variations of Self-Attention

1. Masked Self-Attention

Used in the decoder of the Transformer for autoregressive tasks (e.g., text generation, machine translation). It prevents the model from “cheating” by attending to future tokens (tokens that come after the current position in the sequence).

How It Works

A mask matrix (of shape \((\text{seq_len}, \text{seq_len})\)) is added to the scaled score matrix \(S_{\text{scaled}}\) before applying softmax. The mask sets the scores of future tokens to \(-\infty\), so their attention weights become 0 after softmax:

\(S_{\text{masked}} = S_{\text{scaled}} + \text{mask}\)

  • For token i, the mask sets \(S_{\text{masked}}[i, j] = -\infty\) for all \(j > i\).
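
A minimal sketch of building and applying such a causal mask in PyTorch (in practice \(-\infty\) is usually approximated by a large negative constant, as in the code in Section V):

import torch

seq_len = 5
# Lower-triangular causal mask: 1 = visible position, 0 = future position
mask = torch.tril(torch.ones(seq_len, seq_len))

scores = torch.randn(seq_len, seq_len)                 # stands in for S_scaled
masked = scores.masked_fill(mask == 0, float("-inf"))  # future positions -> -inf
weights = torch.softmax(masked, dim=-1)                # future positions get weight 0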

Use Cases

  • Autoregressive text generation (GPT, where each token is generated one at a time).
  • Decoder in machine translation (ensures the model only uses tokens generated so far to predict the next token).

2. Multi-Head Self-Attention

The original Transformer uses multi-head self-attention to capture multiple types of relationships between tokens (e.g., syntactic structure, semantic meaning, coreference) simultaneously. It splits \(Q, K, V\) into h smaller sub-vectors (heads), computes self-attention for each head independently, and concatenates the results.

Step-by-Step Process

  1. Split \(Q, K, V\) into heads (a shape-only sketch follows this list). For each head i (from 1 to h):
     \(Q_i = Q[:, (i-1) \cdot d_k : i \cdot d_k], \quad K_i = K[:, (i-1) \cdot d_k : i \cdot d_k], \quad V_i = V[:, (i-1) \cdot d_v : i \cdot d_v]\)
    • \(d_k = d_{\text{model}} / h\), \(d_v = d_{\text{model}} / h\).
  2. Compute attention for each head:
     \(\text{Head}_i = \text{Attention}(Q_i, K_i, V_i)\)
  3. Concatenate heads: combine the outputs of all heads into a single matrix:
     \(\text{Concat} = [\text{Head}_1; \text{Head}_2; \dots; \text{Head}_h]\)
    • Shape of \(\text{Concat}\): \((\text{seq_len}, d_{\text{model}})\).
  4. Apply linear projection: use a learnable weight matrix \(W_O\) to project the concatenated output to the final dimension:
     \(\text{MultiHead}(Q, K, V) = \text{Concat} \times W_O\)
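
A shape-only sketch of the split and concatenation, assuming \(d_{\text{model}} = 128\) and h = 4 (the full module, including the projections and \(W_O\), appears in Section V):

import torch

seq_len, d_model, h = 5, 128, 4
d_k = d_model // h                     # 32 dimensions per head

Q = torch.randn(seq_len, d_model)
# Split into heads: (seq_len, d_model) -> (h, seq_len, d_k)
Q_heads = Q.view(seq_len, h, d_k).transpose(0, 1)
print(Q_heads.shape)                   # torch.Size([4, 5, 32])

# Concatenate heads back: (h, seq_len, d_k) -> (seq_len, d_model)
concat = Q_heads.transpose(0, 1).contiguous().view(seq_len, h * d_k)
print(concat.shape)                    # torch.Size([5, 128])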

Why Multi-Head Attention Works

Each head learns a different attention pattern. For example:

  • One head might focus on subject-verb agreement (e.g., “The cat is black”).
  • Another head might focus on coreference (e.g., “The cat… it is black”).

3. Sparse Self-Attention

Standard self-attention has a time and memory complexity of \(O(n^2)\) (where \(n = \text{seq_len}\)), which is infeasible for very long sequences (e.g., 10,000+ tokens). Sparse self-attention reduces complexity by limiting attention to a subset of tokens instead of all pairs.

Common Sparse Variants

  • Local Attention: each token only attends to tokens in a fixed window around it (e.g., ±2 tokens). Complexity: \(O(n \cdot w)\) (w = window size). Use case: long sequences like documents or code.
  • Strided Attention: each token attends to tokens at fixed intervals (e.g., every 10 tokens). Complexity: \(O(n \cdot n/s)\) (s = stride). Use case: balances local and global context.
  • Longformer Attention: combines local window attention with global attention (a few tokens attend to all others). Complexity: \(O(n \cdot w)\). Use case: long documents, legal texts, or books.
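
As an illustration of the local-attention pattern, a band-shaped mask restricting each token to a ±w window can be built as follows. This is only a sketch of the attention pattern; a real sparse implementation also avoids materializing the full \(n \times n\) score matrix in order to reach \(O(n \cdot w)\) cost:

import torch

seq_len, w = 8, 2
idx = torch.arange(seq_len)
# 1 where |i - j| <= w (token i may attend to token j), 0 elsewhere
local_mask = ((idx.unsqueeze(0) - idx.unsqueeze(1)).abs() <= w).int()
print(local_mask)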

IV. Self-Attention vs. Other Attention Mechanisms

Self-attention is a type of intra-attention (attention within a single sequence). It is distinct from other attention mechanisms used in deep learning:

  • Self-Attention: attention between tokens in the same sequence. Use case: contextual embedding (BERT, GPT), sequence classification.
  • Encoder-Decoder Attention: attention between tokens in the encoder sequence (source) and the decoder sequence (target). Use case: machine translation (linking English words to French words), summarization.
  • Cross-Attention: attention between two different sequences (e.g., image features and text tokens). Use case: multimodal tasks (image captioning, visual question answering).
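
The only structural difference between these variants is where Q, K, and V come from. A minimal sketch of encoder-decoder (cross) attention, with queries taken from the target (decoder) sequence and keys/values from the source (encoder) sequence; all tensors here are dummy data:

import torch
import torch.nn.functional as F
import math

src_len, tgt_len, d_k = 7, 5, 64
Q = torch.randn(tgt_len, d_k)   # queries from the decoder (target) sequence
K = torch.randn(src_len, d_k)   # keys from the encoder (source) sequence
V = torch.randn(src_len, d_k)   # values from the encoder (source) sequence

weights = F.softmax(Q @ K.T / math.sqrt(d_k), dim=-1)   # (tgt_len, src_len)
output = weights @ V                                    # (tgt_len, d_k)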

V. Self-Attention Implementation (Python with PyTorch)

Below is a minimal implementation of scaled dot-product self-attention and multi-head self-attention.

1. Scaled Dot-Product Self-Attention

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, q, k, v, mask=None):
        # q, k, v shapes: (..., seq_len, d_k), e.g. (batch_size, seq_len, d_k) or (batch_size, n_heads, seq_len, d_k)
        d_k = q.size(-1)
        
        # Step 1: Compute scaled dot-product scores
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # (batch_size, seq_len, seq_len)
        
        # Step 2: Apply mask (if provided)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)  # Masked positions get a large negative score (≈ -inf), so their weight ≈ 0
        
        # Step 3: Compute attention weights
        attn_weights = F.softmax(scores, dim=-1)  # (batch_size, seq_len, seq_len)
        
        # Step 4: Compute weighted sum of values
        output = torch.matmul(attn_weights, v)  # (batch_size, seq_len, d_v)
        
        return output, attn_weights

2. Multi-Head Self-Attention

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads  # Dimension per head
        
        # Learnable weight matrices for Q, K, V projections
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        
        # Output projection matrix
        self.w_o = nn.Linear(d_model, d_model)
        
        # Scaled dot-product attention module
        self.attention = ScaledDotProductAttention()

    def split_heads(self, x):
        # Split x into n_heads: (batch_size, seq_len, d_model) → (batch_size, n_heads, seq_len, d_k)
        batch_size = x.size(0)
        return x.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        # Combine heads back to d_model: (batch_size, n_heads, seq_len, d_k) → (batch_size, seq_len, d_model)
        batch_size = x.size(0)
        return x.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_k)

    def forward(self, x, mask=None):
        # Step 1: Project input to Q, K, V
        q = self.w_q(x)  # (batch_size, seq_len, d_model)
        k = self.w_k(x)
        v = self.w_v(x)
        
        # Step 2: Split into heads
        q = self.split_heads(q)  # (batch_size, n_heads, seq_len, d_k)
        k = self.split_heads(k)
        v = self.split_heads(v)
        
        # Step 3: Compute scaled dot-product attention
        attn_output, attn_weights = self.attention(q, k, v, mask)  # (batch_size, n_heads, seq_len, d_k)
        
        # Step 4: Combine heads and project output
        output = self.combine_heads(attn_output)  # (batch_size, seq_len, d_model)
        output = self.w_o(output)  # (batch_size, seq_len, d_model)
        
        return output, attn_weights

3. Test the Implementation

# Hyperparameters
batch_size = 2
seq_len = 5
d_model = 128
n_heads = 4

# Dummy input sequence (batch_size, seq_len, d_model)
x = torch.randn(batch_size, seq_len, d_model)

# Initialize multi-head self-attention
mha = MultiHeadSelfAttention(d_model, n_heads)

# Forward pass (no mask for self-attention)
output, attn_weights = mha(x)

print(f"Input Shape: {x.shape}")  # (2, 5, 128)
print(f"Output Shape: {output.shape}")  # (2, 5, 128)
print(f"Attention Weights Shape: {attn_weights.shape}")  # (2, 4, 5, 5) → (batch, heads, seq_len, seq_len)


VI. Key Properties of Self-Attention

  1. Parallelization: Unlike RNNs/LSTMs, self-attention computes all token relationships in parallel, which speeds up training on GPUs.
  2. Long-Term Dependency Modeling: Self-attention has a constant path length between any two tokens (regardless of their distance in the sequence), whereas RNNs have a path length equal to the distance between tokens. This makes self-attention far better at capturing long-range relationships.
  3. Contextualization: Tokens get embeddings that depend on their context (e.g., “bank” has different embeddings in different sentences).
  4. Interpretability: Attention weights can be visualized to see which tokens a model focuses on (e.g., in translation, you can see which source words map to target words).
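
For the interpretability point, the attention weights returned by the code in Section V can be inspected directly. A minimal sketch, assuming matplotlib is installed and reusing attn_weights from the test above:

import matplotlib.pyplot as plt

# Heatmap of head 0 for the first batch element: rows = query tokens, columns = key tokens
plt.imshow(attn_weights[0, 0].detach().numpy(), cmap="viridis")
plt.xlabel("Key token position")
plt.ylabel("Query token position")
plt.colorbar(label="Attention weight")
plt.show()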

VII. Practical Applications of Self-Attention

  1. Natural Language Processing (NLP)
    • Text Classification: BERT uses bidirectional self-attention to create contextual embeddings for sentiment analysis, spam detection, etc.
    • Text Generation: GPT uses masked self-attention to generate coherent text token by token.
    • Machine Translation: The Transformer’s encoder uses self-attention to model source sentences, and the decoder uses encoder-decoder attention to link source and target tokens.
  2. Computer Vision
    • Image Classification: Vision Transformers (ViT) split images into patches, treat them as tokens, and use self-attention to model patch relationships.
    • Object Detection: DETR uses self-attention to detect objects without hand-designed anchors.
  3. Speech Processing
    • Speech Recognition: Self-attention models (e.g., Conformer) combine convolution and self-attention to model speech sequences.
  4. Time-Series Forecasting
    • Self-attention models capture long-term dependencies in time-series data (e.g., stock prices, weather data).

VIII. Summary

  • Self-Attention is a mechanism that models relationships between all pairs of tokens in a sequence, generating contextualized embeddings.
  • Scaled Dot-Product Attention is the core implementation: it uses query, key, and value vectors, scaled scores, and softmax weights.
  • Multi-Head Attention extends self-attention to capture multiple types of token relationships.
  • Masked Self-Attention is used in decoders for autoregressive tasks (prevents future token access).
  • Key Advantages: parallelization, long-term dependency modeling, contextualization, and interpretability.
  • Limitations: \(O(n^2)\) complexity for long sequences (mitigated by sparse attention variants).


