Multi-headed Attention


  • Dot Product Attention
  • Self-Attention
  • Multi-Headed Scaled Dot-Product Attention

Implement multi-headed scaled dot-product attention according to the following specs:

class MultiHeadAttention(nn.Module):
    """
    Multi-headed scaled dot-product attention.

    Queries, keys and values are each passed through their own linear map,
    split into `num_heads` heads of size `embed_dim // num_heads`, attended
    independently, then concatenated and passed through an output projection.

    Usage:
      attn = MultiHeadAttention(embed_dim, num_heads=2)

      # self-attention
      data = torch.randn(batch_size, sequence_length, embed_dim)
      self_attn_output = attn(query=data, key=data, value=data)

      # attention using two inputs
      other_data = torch.randn(batch_size, sequence_length, embed_dim)
      attn_output = attn(query=data, key=other_data, value=other_data)
    """

    def __init__(self, embed_dim, num_heads, dropout=0.1):
        """
        Construct a new MultiHeadAttention layer.

        Inputs:
         - embed_dim: Dimension of the token embedding
         - num_heads: Number of attention heads (must divide embed_dim)
         - dropout: Dropout probability applied to the attention weights
        """
        # nn.Module must be initialised before any submodule is assigned.
        super().__init__()
        assert embed_dim % num_heads == 0

        self.key = nn.Linear(embed_dim, embed_dim)
        self.query = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.proj = nn.Linear(embed_dim, embed_dim)

        self.num_heads = num_heads
        self.dropout = nn.Dropout(p=dropout)
        # Scores are divided by sqrt(head_dim).
        self.scale = math.sqrt(embed_dim / num_heads)

    def forward(self, query, key, value, attn_mask=None):
        """
        Compute masked multi-head attention.

        Inputs:
         - query: tensor of shape (N, S, D)
         - key: tensor of shape (N, T, D)
         - value: tensor of shape (N, T, D)
         - attn_mask: optional tensor broadcastable to (N, H, S, T); score
           entries where `attn_mask == 0` are excluded from attention.

        Returns:
         - output: tensor of shape (N, S, D)

        Note: a query row whose keys are all masked yields NaN, because the
        softmax is taken over a row consisting only of -inf.
        """
        N, S, D = query.shape
        N, T, D = value.shape
        H = self.num_heads

        # Project the inputs and split the embedding into heads:
        # (N, L, D) -> (N, L, H, D/H) -> (N, H, L, D/H)
        K = self.key(key).view(N, T, H, D // H).moveaxis(1, 2)
        Q = self.query(query).view(N, S, H, D // H).moveaxis(1, 2)
        V = self.value(value).view(N, T, H, D // H).moveaxis(1, 2)

        # (N,H,S,D/H) @ (N,H,D/H,T) -> (N,H,S,T)
        Y = Q @ K.transpose(2, 3) / self.scale

        if attn_mask is not None:
            # -inf scores receive exactly zero weight after the softmax.
            Y = Y.masked_fill(attn_mask == 0, float("-inf"))

        # NOTE: The assignment says to apply dropout after the attention
        # output. That does not work, so dropout is applied to the attention
        # weights right after the softmax instead.
        # (N,H,S,T) @ (N,H,T,D/H) -> (N,H,S,D/H)
        Y = self.dropout(F.softmax(Y, dim=-1)) @ V

        # Merge the heads back: (N,H,S,D/H) -> (N,S,H,D/H) -> (N,S,D)
        output = self.proj(Y.moveaxis(1, 2).reshape(N, S, D))
        return output

Implement the Positional Encoding layer according to the following specs:

class PositionalEncoding(nn.Module):
    """
    Encodes information about the positions of the tokens in the sequence. In
    this case, the layer has no learnable parameters, since it is a simple
    function of sines and cosines.
    """

    def __init__(self, embed_dim, dropout=0.1, max_len=5000):
        """
        Construct the PositionalEncoding layer.

        Inputs:
         - embed_dim: the size of the embed dimension (must be even)
         - dropout: the dropout value
         - max_len: the maximum possible length of the incoming sequence
        """
        # nn.Module must be initialised before submodules/buffers are added.
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        assert embed_dim % 2 == 0

        # Create an array with a "batch dimension" of 1 (which will broadcast
        # across all examples in the batch).
        pe = torch.zeros(1, max_len, embed_dim)

        # Token positions as a column vector of shape (max_len, 1), and one
        # frequency 10000^(-2k/embed_dim) per sin/cos pair, shape (embed_dim/2,).
        i = torch.arange(max_len)[:, None]
        pows = torch.pow(10000, -torch.arange(0, embed_dim, 2) / embed_dim)

        # Even feature indices hold the sine, odd ones the matching cosine.
        pe[0, :, 0::2] = torch.sin(i * pows)
        pe[0, :, 1::2] = torch.cos(i * pows)

        # Make sure the positional encodings will be saved with the model
        # parameters (mostly for completeness).
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Element-wise add positional embeddings to the input sequence.

        Inputs:
         - x: the sequence fed to the positional encoder model, of shape
              (N, S, D), where N is the batch size, S is the sequence length and
              D is embed dim

        Returns:
         - output: the input sequence + positional encodings (followed by
           dropout), of shape (N, S, D)
        """
        # Unpacking also enforces that x is a 3-D tensor.
        _, S, _ = x.shape
        # Slice the table to the sequence length; the leading dimension of 1
        # broadcasts over the batch.
        output = x + self.pe[:, :S]
        output = self.dropout(output)
        return output

Captioning Transformer

import numpy as np
import copy
import torch
import torch.nn as nn
from ..transformer_layers import *
class CaptioningTransformer(nn.Module):
    """
    A CaptioningTransformer produces captions from image features using a
    Transformer decoder.

    The Transformer receives input vectors of size D, has a vocab size of V,
    works on sequences of length T, uses word vectors of dimension W, and
    operates on minibatches of size N.
    """

    def __init__(self, word_to_idx, input_dim, wordvec_dim, num_heads=4,
                 num_layers=2, max_length=50):
        """
        Construct a new CaptioningTransformer instance.

        Inputs:
        - word_to_idx: A dictionary giving the vocabulary. It contains V entries
          and maps each string to a unique integer in the range [0, V).
          It must contain "<NULL>"; "<START>" and "<END>" are optional here,
          but `sample` needs "<START>".
        - input_dim: Dimension D of input image feature vectors.
        - wordvec_dim: Dimension W of word vectors.
        - num_heads: Number of attention heads.
        - num_layers: Number of transformer layers.
        - max_length: Max possible sequence length.
        """
        # nn.Module must be initialised before any submodule is assigned.
        super().__init__()

        vocab_size = len(word_to_idx)
        self._null = word_to_idx["<NULL>"]
        self._start = word_to_idx.get("<START>", None)
        self._end = word_to_idx.get("<END>", None)

        self.visual_projection = nn.Linear(input_dim, wordvec_dim)
        self.embedding = nn.Embedding(vocab_size, wordvec_dim, padding_idx=self._null)
        self.positional_encoding = PositionalEncoding(wordvec_dim, max_len=max_length)

        decoder_layer = TransformerDecoderLayer(input_dim=wordvec_dim, num_heads=num_heads)
        self.transformer = TransformerDecoder(decoder_layer, num_layers=num_layers)

        # Re-initialise every submodule registered so far. This runs before the
        # output head is created, so `self.output` keeps PyTorch's default
        # initialisation.
        self.apply(self._init_weights)

        self.output = nn.Linear(wordvec_dim, vocab_size)

    def _init_weights(self, module):
        """
        Initialize the weights of the network.

        Linear and Embedding weights are drawn from N(0, 0.02^2); Linear biases
        are zeroed; an Embedding's padding row is reset to zero; LayerNorm is
        set to the identity transform (weight 1, bias 0). Other module types
        are left untouched.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding) and module.padding_idx is not None:
                # normal_ overwrote the all-zero padding vector; restore it.
                with torch.no_grad():
                    module.weight[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            if module.weight is not None:
                nn.init.ones_(module.weight)

    def forward(self, features, captions):
        """
        Given image features and caption tokens, return a distribution over the
        possible tokens for each timestep. Note that since the entire sequence
        of captions is provided all at once, we mask out future timesteps.

        Inputs:
         - features: image features, of shape (N, D)
         - captions: ground truth captions, of shape (N, T)

        Returns:
         - scores: score for each token at each timestep, of shape (N, T, V)
        """
        N, T = captions.shape

        # Embed the captions.
        # shape: [N, T] -> [N, T, W]
        caption_embeddings = self.embedding(captions)
        caption_embeddings = self.positional_encoding(caption_embeddings)

        # Project image features into the same dimension as the text embeddings.
        # shape: [N, D] -> [N, W] -> [N, 1, W]
        projected_features = self.visual_projection(features).unsqueeze(1)

        # Lower-triangular mask of ones: row t is non-zero only for columns
        # <= t, so the future is hidden (one direction).
        # shape: [T, T]
        tgt_mask = torch.tril(torch.ones(T, T,
                                         device=caption_embeddings.device,
                                         dtype=caption_embeddings.dtype))

        # Apply the Transformer decoder to the caption, allowing it to also
        # attend to image features.
        features = self.transformer(tgt=caption_embeddings,
                                    memory=projected_features,
                                    tgt_mask=tgt_mask)

        # Project to scores per token.
        # shape: [N, T, W] -> [N, T, V]
        scores = self.output(features)
        return scores

    def sample(self, features, max_length=30):
        """
        Given image features, use greedy decoding to predict the image caption.

        Inputs:
         - features: image features, of shape (N, D); array-like or tensor
         - max_length: maximum possible caption length

        Returns:
         - captions: captions for each example, a numpy int32 array of shape
           (N, max_length)
        """
        with torch.no_grad():
            # Follow the model's own device/dtype so sampling also works when
            # the module has been moved off the CPU.
            ref = self.output.weight
            features = torch.as_tensor(features, dtype=ref.dtype, device=ref.device)
            N = features.shape[0]

            # Create an empty captions tensor (where all tokens are NULL).
            captions = self._null * np.ones((N, max_length), dtype=np.int32)

            # Create a partial caption, with only the start token.
            # shape: [N, 1]
            partial_caption = torch.full((N, 1), self._start,
                                         dtype=torch.long, device=ref.device)

            for t in range(max_length):
                # Predict the next token (ignoring all other time steps).
                output_logits = self.forward(features, partial_caption)
                output_logits = output_logits[:, -1, :]

                # Choose the most likely word ID from the vocabulary.
                # [N, V] -> [N]
                word = torch.argmax(output_logits, dim=1)

                # Update our overall caption and our current partial caption.
                captions[:, t] = word.cpu().numpy()
                word = word.unsqueeze(1)
                partial_caption = torch.cat([partial_caption, word], dim=1)

            return captions
class TransformerDecoderLayer(nn.Module):
    """
    A single layer of a Transformer decoder, to be used with TransformerDecoder.
    """

    def __init__(self, input_dim, num_heads, dim_feedforward=2048, dropout=0.1):
        """
        Construct a TransformerDecoderLayer instance.

        Inputs:
         - input_dim: Number of expected features in the input.
         - num_heads: Number of attention heads
         - dim_feedforward: Dimension of the feedforward network model.
         - dropout: The dropout value.
        """
        # nn.Module must be initialised before any submodule is assigned.
        super().__init__()
        self.self_attn = MultiHeadAttention(input_dim, num_heads, dropout)
        self.multihead_attn = MultiHeadAttention(input_dim, num_heads, dropout)
        self.linear1 = nn.Linear(input_dim, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, input_dim)

        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)
        self.norm3 = nn.LayerNorm(input_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = nn.ReLU()

    def forward(self, tgt, memory, tgt_mask=None):
        """
        Pass the inputs (and mask) through the decoder layer.

        Each of the three sub-blocks is residual: the sub-block output goes
        through dropout, is added to its input, and the sum is layer-normed.

        Inputs:
        - tgt: the sequence to the decoder layer, of shape (N, T, W)
        - memory: the sequence from the last layer of the encoder, of shape
          (N, S, W)
        - tgt_mask: the parts of the target sequence to mask, of shape (T, T)

        Returns:
        - out: the Transformer features, of shape (N, T, W)
        """
        # Perform self-attention on the target sequence (along with dropout and
        # layer norm).
        tgt2 = self.self_attn(query=tgt, key=tgt, value=tgt, attn_mask=tgt_mask)
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)

        # Attend to both the target sequence and the sequence from the last
        # encoder layer. No mask is passed here.
        tgt2 = self.multihead_attn(query=tgt, key=memory, value=memory)
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)

        # Pass through the feedforward network:
        # linear -> ReLU -> dropout -> linear.
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt
def clones(module, N):
    """Return an nn.ModuleList holding N independent deep copies of `module`."""
    replicas = nn.ModuleList()
    for _ in range(N):
        # deepcopy so that no parameters are shared between the copies.
        replicas.append(copy.deepcopy(module))
    return replicas
class TransformerDecoder(nn.Module):
    """
    A stack of `num_layers` decoder layers applied one after another.
    """

    def __init__(self, decoder_layer, num_layers):
        """
        Construct the decoder stack.

        Inputs:
         - decoder_layer: the layer module to replicate; each layer in the
           stack is an independent copy produced by `clones`.
         - num_layers: number of copies to stack.
        """
        # nn.Module must be initialised before any submodule is assigned.
        super().__init__()
        self.layers = clones(decoder_layer, num_layers)
        self.num_layers = num_layers

    def forward(self, tgt, memory, tgt_mask=None):
        """
        Feed `tgt` through every layer in order.

        Inputs:
         - tgt: the input to the first layer; each later layer receives the
           previous layer's output.
         - memory: passed unchanged to every layer.
         - tgt_mask: optional mask, passed unchanged to every layer.

        Returns:
         - output: the output of the last layer (`tgt` itself when the stack is
           empty).
        """
        output = tgt

        for mod in self.layers:
            output = mod(output, memory, tgt_mask=tgt_mask)

        return output