Introduction

I have always found a disconnect between the theoretical understanding of attention mechanisms, the illustrations that usually accompany it, and the practical code implementation in models like GPT, and I wasn't able to merge all of this knowledge while trying to understand the multi-head attention mechanism in depth. Since transformers work on batched input, things change up a bit, and it is always confusing to relate the illustrations and theory to the code implementation. In this post, I aim to bridge this gap by providing a simple and intuitive explanation of the vanilla as well as the causal attention mechanism in GPT models.

Let’s break it down step-by-step, starting from the input sequence and moving through the entire process.

Mathematical Representation of Attention Mechanism

We all know the attention mechanism in transformers can be mathematically represented as follows:

\[Y = \text{softmax}\left( \frac{QK^T}{\sqrt{d_k}} \right) V\]

Here \(Q\), \(K\), and \(V\) are the query, key, and value matrices, respectively. The softmax function is applied to the dot-product of \(Q\) and \(K^T\), scaled by the square root of the key dimension \(d_k\). The result is then multiplied by the value matrix to obtain the final output \(Y\). In this post I am interested in the dimensions of the input data as well as those of \(Q\), \(K\), and \(V\), and along which dimensions the operations in the attention mechanism are performed.
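To make the shapes concrete before we add batches and heads, here is a minimal PyTorch sketch of the equation above on random tensors (the shapes are only illustrative and are not tied to any particular model):

import torch
import torch.nn.functional as F

seq_length, d_k = 12, 6                 # 12 tokens, query/key/value dimension 6
Q = torch.randn(seq_length, d_k)        # queries, [12, 6]
K = torch.randn(seq_length, d_k)        # keys,    [12, 6]
V = torch.randn(seq_length, d_k)        # values,  [12, 6]

scores = (Q @ K.T) / (d_k ** 0.5)       # scaled dot-product, [12, 12]
weights = F.softmax(scores, dim=-1)     # attention weights,  [12, 12]
Y = weights @ V                         # final output,       [12, 6]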

Input Data Representation

As we know, the input data to the transformer model is represented as a sequence of tokens. Each token is embedded into a vector of dimension n_embd, so the input is represented as a matrix \(X\) of shape (batch_size, seq_length, n_embd). This matrix is then passed through linear transformations to obtain the query, key, and value matrices \(Q\), \(K\), and \(V\), respectively.

Given a trained tokenizer, we tokenize the input sequence and pass it through an input embedding layer (nn.Embedding()), which gives us the input data as a matrix \(X\) of shape (batch_size, seq_length, n_embd). This matrix contains the embeddings of the input tokens, where each row corresponds to a token and each column represents an embedding dimension.

Suppose we have an input X with batch size 4 and 12 tokens per sequence (seq_length). After passing it through our tokenizer and embedding layer, we end up with a final input of shape [4, 12, 6] (batch_size, seq_length, n_embd). The input data matrix \(X\) is represented as shown in Figure 01 below.

Figure 01: Input Data Matrix X: where each row corresponds to a token and each column represents the embedding dimension

In Figure 01 above, the \(12 \times 6\) matrix represents a single batch element of the input data matrix \(X\), where each row corresponds to a token and each column represents an embedding dimension. We have 4 such matrices in total, one for each sequence in the batch.
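As a rough sketch of how such an input is produced (the vocabulary size here is just a placeholder, and the token IDs are random stand-ins for a real tokenizer's output):

import torch
import torch.nn as nn

batch_size, seq_length, n_embd = 4, 12, 6
vocab_size = 1000                                 # placeholder vocabulary size

token_ids = torch.randint(0, vocab_size, (batch_size, seq_length))  # stand-in for tokenizer output
embedding = nn.Embedding(vocab_size, n_embd)      # input embedding layer

X = embedding(token_ids)
print(X.shape)                                    # torch.Size([4, 12, 6])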

Multi-Head Self-Attention

Many blogs and articles mention that the input X is projected into queries, keys, and values of equal dimension, which are then used to compute attention based on a given equation. They often state that this process is repeated in parallel through multi-head attention.

However, in practice, multi-head attention is implemented more cleverly. The primary goal is to allow the model to attend to information from different representation subspaces simultaneously at different positions. This approach enables the model to learn richer and more diverse representations of the input data.

This concept will become clearer once we delve into the implementation details of multi-head attention.

Since each token has an embedding of dimension n_embd, which is 6 in our case, what is done is that each batch element of shape \(12 \times 6\) (seq_length, n_embd) is further divided into [seq_length, num_h, h_size]: the embedding dimension is split into num_h heads, where each head has a dimension of h_size. This is done by reshaping the input data matrix \(X\) into a tensor of shape (batch_size, seq_length, num_h, h_size), and for each batch element we then rearrange it into [num_h, seq_length, h_size].

So for multi-head attention to work, the embedding dimension n_embd must be divisible by the number of heads num_h. In our case the embedding dimension is 6 and we have decided on 3 heads, so we end up with a tensor of shape [3, 12, 2] (num_h, seq_length, h_size) for each batch element, as shown in Figure 02 below.

Figure 02: The embedding dimension for each batch is divided into multiple heads, allowing the model to attend to different aspects of the tokens simultaneously

Each batch element of the input data matrix X of shape [12, 6] is passed through linear transformations to obtain the query, key, and value matrices Q, K, and V, each of shape [12, 6]. \(Q\), \(K\), and \(V\) are then reshaped into tensors of shape [3, 12, 2] (num_h, seq_length, h_size).


Info: Did you notice something? Instead of projecting the input data \(X\) into separate \(QKV\) matrices for each head, we do the projection once and divide the embedding dimension into multiple heads. The reason for this is that it allows the model to attend to information from different representation subspaces simultaneously at different positions, which enables the model to learn richer and more diverse representations of the input data.

Since different parts of a token's embedding represent different aspects of the token, dividing the embedding dimension into multiple heads allows the model to attend to different aspects (sub-embedding spaces) of the token simultaneously.

This way, one aspect of a token might be more attentive to a certain aspect of another token in one head, while a different aspect of the same token might learn to pay attention to a certain aspect of a different token in another head. This is the essence of multi-head attention.

We project each batch into \(QKV\) matrices, each of shape [12,6] (seq_length, n_embd). The \(QKV\) matrices are then reshaped into a tensor of shape [3,12,2] (num_h, seq_length, h_size) for each batch.
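Here is a minimal sketch of this projection and reshaping step. For clarity it uses three separate linear layers; the implementation at the end of the post fuses them into a single layer.

import torch
import torch.nn as nn

batch_size, seq_length, n_embd = 4, 12, 6
num_h = 3
h_size = n_embd // num_h                # 2

X = torch.randn(batch_size, seq_length, n_embd)

# three separate projections for illustration; the implementation below fuses them into one linear layer
W_q = nn.Linear(n_embd, n_embd)
W_k = nn.Linear(n_embd, n_embd)
W_v = nn.Linear(n_embd, n_embd)

Q = W_q(X)                              # [4, 12, 6]
K = W_k(X)                              # [4, 12, 6]
V = W_v(X)                              # [4, 12, 6]

# split the embedding dimension into heads and bring the head dimension forward
Q = Q.view(batch_size, seq_length, num_h, h_size).transpose(1, 2)   # [4, 3, 12, 2]
K = K.view(batch_size, seq_length, num_h, h_size).transpose(1, 2)   # [4, 3, 12, 2]
V = V.view(batch_size, seq_length, num_h, h_size).transpose(1, 2)   # [4, 3, 12, 2]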

Inside each head, we first compute the scaled dot-product attention between \(Q\) [12, 2] and \(K\) [12, 2] to get the attention scores of shape [12, 12], as shown in Figure 03 below.

Figure 03: Scaled Dot-Product Attention Q@K for Each Head
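In code, the per-head score computation looks roughly like this (using random tensors as stand-ins for the projected \(Q\) and \(K\) of a single batch element):

import torch

num_h, seq_length, h_size = 3, 12, 2
Q = torch.randn(num_h, seq_length, h_size)   # per-head queries of one batch element
K = torch.randn(num_h, seq_length, h_size)   # per-head keys of one batch element

att_scores = (Q @ K.transpose(-2, -1)) / (h_size ** 0.5)   # [3, 12, 12]: one 12x12 score matrix per head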

Basically, because we chopped up the original embedding of each token into multiple heads of size h_size, each head produces different attention scores, since we are attending to different aspects of each token in each head, as highlighted in Figure 04 below.

Figure 04: Different aspect of the tokens give rise to unique attention score matrices

The attention scores are then passed through the softmax function to obtain the attention weights, which are then multiplied by the value matrix \(V\) [12, 2] to obtain the output of the head \(y\) of shape [12, 2] (Figure 05).

Figure 05: The attention weights are multiplied with V to aggregate context between concepts (the sub-embedding of each token in that head)
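A small sketch of this step, again with random stand-ins for the per-head scores and values of one batch element:

import torch
import torch.nn.functional as F

num_h, seq_length, h_size = 3, 12, 2
att_scores = torch.randn(num_h, seq_length, seq_length)   # stand-in for the per-head scores
V = torch.randn(num_h, seq_length, h_size)                # stand-in for the per-head values

att_weights = F.softmax(att_scores, dim=-1)   # [3, 12, 12], each row sums to 1
y = att_weights @ V                           # [3, 12, 12] @ [3, 12, 2] -> [3, 12, 2] per-head output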

This process is done for each of the 3 heads, and the outputs are concatenated and reshaped to get the final output of the multi-head attention mechanism, as shown in Figure 06 below.

Figure 06: Partial context aggregations at the sub-embedding level of each head are combined to get the final context-aware embeddings of the tokens
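The recombination is just a transpose and a reshape; a sketch with random per-head outputs:

import torch

batch_size, num_h, seq_length, h_size = 4, 3, 12, 2
y = torch.randn(batch_size, num_h, seq_length, h_size)    # stand-in for the per-head outputs

# [4, 3, 12, 2] -> transpose -> [4, 12, 3, 2] -> view -> [4, 12, 6]
y = y.transpose(1, 2).contiguous().view(batch_size, seq_length, num_h * h_size)
print(y.shape)                                            # torch.Size([4, 12, 6])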

Causal Attention Mechanism

The causal attention mechanism is a variant of the multi-head attention mechanism that restricts the model from attending to tokens that come after the current token. This is achieved by masking the attention scores of the tokens that come after the current token (Figure 07). The masking is done by setting those attention scores to negative infinity (-inf) before passing them through the softmax function and multiplying the result with the \(V\) matrix.

Figure 07: Each token (in rows) only has attention scores for the tokens before it, to prevent looking into the future
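A minimal sketch of the masking step, using torch.tril to build the mask and random stand-ins for the scores and values of one batch element:

import torch
import torch.nn.functional as F

num_h, seq_length, h_size = 3, 12, 2
att_scores = torch.randn(num_h, seq_length, seq_length)   # stand-in for the raw attention scores
V = torch.randn(num_h, seq_length, h_size)                # stand-in for the per-head values

mask = torch.tril(torch.ones(seq_length, seq_length))             # lower triangular matrix of ones
att_scores = att_scores.masked_fill(mask == 0, float('-inf'))     # hide tokens that come after each position
att_weights = F.softmax(att_scores, dim=-1)                       # masked positions get weight 0
y = att_weights @ V                                               # [3, 12, 2]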

Code Implementation

The code implementation of the multi-head attention mechanism in PyTorch is shown below:

import torch
import torch.nn as nn

class CausalSelfAttention(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.c_attn = nn.Linear(config.n_embd, config.n_embd * 3)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.register_buffer('bias',
                             torch.tril(torch.ones(config.block_size, config.block_size)).view(
                                    1, 1,
                                    config.block_size,
                                    config.block_size)) # lower triangular matrix of ones used as the causal mask

    def forward(self, x):
        B, T, C = x.size()   # B: batch size, T: sequence length, C: n_embd [4, 12, 6]

        qkv = self.c_attn(x)  # [4, 12, 6] -> [4, 12, 6*3] -> [4, 12, 18]
        q, k, v = qkv.split(C, dim=2) # [4, 12, 6], [4, 12, 6], [4, 12, 6]
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # [4, 12, 3, 2] transpose-> [4, 3, 12, 2]
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # [4, 12, 3, 2] transpose-> [4, 3, 12, 2]
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # [4, 12, 3, 2] transpose-> [4, 3, 12, 2]
        att = (q @ k.transpose(-2, -1)) * (k.size(-1) ** -0.5) # [4, 3, 12, 12]
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf')) # [4, 3, 12, 12] replace zeros with -inf
        att = nn.functional.softmax(att, dim=-1) # [4, 3, 12, 12] softmax over the last dimension
        y = att @ v # [4, 3, 12, 12] @ [4, 3, 12, 2] -> [4, 3, 12, 2]
        y = y.transpose(1, 2).contiguous().view(B, T, C) # [4, 3, 12, 2] transpose-> [4, 12, 3, 2] view-> [4, 12, 6]
        y = self.c_proj(y) # [4, 12, 6] -> [4, 12, 6] learnable output projection
        return y   # [4, 12, 6]

Initialization

class CausalSelfAttention(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.c_attn = nn.Linear(config.n_embd, config.n_embd * 3)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.register_buffer('bias',
                             torch.tril(torch.ones(config.block_size, config.block_size)).view(
                                    1, 1, config.block_size, config.block_size))

Asserting Divisibility:

Ensure that the embedding dimension (n_embd) is divisible by the number of heads (n_head). This is important for splitting the embeddings into multiple heads.

Linear Layers for Projections:

self.c_attn:

A linear layer to project the input into queries, keys, and values. The output dimension is three times the embedding dimension to accommodate Q, K, and V.

self.c_proj:

A linear layer to project the concatenated outputs of the multi-head attention back to the original embedding dimension.

Registering the Causal Mask:

self.register_buffer('bias', torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size))

Creates a lower triangular matrix of ones using torch.tril(torch.ones(config.block_size, config.block_size)) to serve as a causal mask. The name is set to 'bias' to match the naming scheme of GPT so that pre-trained weights can be loaded ^_~

The mask is reshaped and registered as a buffer, which means it won't be updated during training but is persistent in the model's state (this matters when loading GPT weights; more on this in the next post).
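A tiny sketch of what register_buffer actually buys us (MaskDemo is a made-up module, only for illustration): the mask shows up in the state dict, so it is saved and loaded with the model, but it is not among the trainable parameters.

import torch
import torch.nn as nn

class MaskDemo(nn.Module):                 # made-up module, only to illustrate register_buffer
    def __init__(self, block_size=4):
        super().__init__()
        self.register_buffer('bias',
                             torch.tril(torch.ones(block_size, block_size)).view(
                                 1, 1, block_size, block_size))

m = MaskDemo()
print(list(m.parameters()))                # [] -> the mask is not a trainable parameter
print('bias' in m.state_dict())            # True -> but it is saved and loaded with the model
print(m.bias[0, 0])                        # the lower triangular matrix of ones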

Forward Pass

def forward(self, x):
    B, T, C = x.size()
    
    qkv = self.c_attn(x)
    q, k, v = qkv.split(C, dim=2)
    k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
    q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
    v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
    att = (q @ k.transpose(-2, -1)) * (k.size(-1) ** -0.5)
    att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
    att = nn.functional.softmax(att, dim=-1)
    y = att @ v
    y = y.transpose(1, 2).contiguous().view(B, T, C)
    y = self.c_proj(y)
    return y

Input Dimensions:

B, T, C = x.size(): Extract the batch size (B), sequence length (T), and embedding dimension (C) from the input tensor x.

Linear Projection to Q, K, V:

qkv = self.c_attn(x): Apply the linear layer to project the input into queries (q), keys (k), and values (v).

q, k, v = qkv.split(C, dim=2): Split the concatenated qkv tensor into separate q, k, and v tensors.

Reshaping and Transposing for Multi-Head Attention:

k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2): Reshape k to [B, T, nh, hs] and then transpose to [B, nh, T, hs].

Similar operations are performed for q and v.

Scaled Dot-Product Attention:

att = (q @ k.transpose(-2, -1)) * (k.size(-1) ** -0.5): Compute the attention scores using the dot product of q and k, scaled by the square root of the head size.

att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf')): Apply the causal mask so that each position can only attend to previous positions.

att = nn.functional.softmax(att, dim=-1): Apply softmax to obtain the attention weights.

Apply Attention Weights to Values:

y = att @ v: Compute the weighted sum of the values using the attention weights.

Combining Heads:

y = y.transpose(1, 2).contiguous().view(B, T, C): Transpose and reshape the output to combine the heads back into the original embedding dimension.

Final Linear Projection:

y = self.c_proj(y): Apply the final linear projection to produce the output of the attention mechanism.
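To tie it all together, here is a minimal usage sketch. The GPTConfig below is a made-up stand-in containing only the fields this module needs; in a real script it would be defined before CausalSelfAttention, since the class's type annotation refers to it.

from dataclasses import dataclass
import torch

@dataclass
class GPTConfig:          # made-up minimal config with only the fields this module uses
    n_embd: int = 6
    n_head: int = 3
    block_size: int = 12

attn = CausalSelfAttention(GPTConfig())
x = torch.randn(4, 12, 6)                 # (batch_size, seq_length, n_embd)
y = attn(x)
print(y.shape)                            # torch.Size([4, 12, 6])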