Category:
ML Coding
Companies:
Google
DeepMind
OpenAI

Google ML Coding: Hand-code Multi-Head Attention in NumPy

Topics:
Matrix Multiplication
Tensor Reshaping
Attention Mechanisms
Roles:
Machine Learning Engineer
ML Engineer
Research Engineer
Experience:
Entry Level
Mid Level
Senior

Question Description

Implement from scratch the multi-head attention forward pass used in Transformer models. You are given batched queries, keys, and values (each of shape (batch_size, seq_len, embed_dim)); apply linear projections to form Q, K, V, split them into num_heads heads (head dimension d_k = embed_dim / num_heads), compute scaled dot-product attention per head, concatenate the heads, and apply an output projection. The attention formula to follow is:

\[ \text{Attention}(Q,K,V)=\text{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V \]

You should also accept an optional mask broadcastable to (batch_size, num_heads, seq_len, seq_len) so that masked positions are excluded from the softmax (use large negative values before softmax). The function must return the final output (batch_size, seq_len, embed_dim) and the per-head attention weights (batch_size, num_heads, seq_len, seq_len).
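The masking convention described above (large negative fill before softmax) could look like the following sketch; `apply_mask` is a hypothetical helper name, and the convention assumed here is that `True` means "attend":

```python
import numpy as np

def apply_mask(scores, mask):
    # mask: boolean, True where attending is allowed, broadcastable to
    # scores' shape (batch, num_heads, seq_len, seq_len). Disallowed
    # positions get a large negative value so softmax drives them to ~0.
    return np.where(mask, scores, -1e9)
```

Using a finite large negative (rather than `-np.inf`) avoids NaNs when an entire row happens to be masked.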

Flow you can expect in an interview:

  • Project inputs with W_q, W_k, W_v (embed_dim -> embed_dim), reshape and transpose to (batch, num_heads, seq_len, d_k).
  • Compute scaled dot-products, apply mask, softmax, and multiply by V per head.
  • Concatenate heads, project with W_o, and return outputs and attention maps.
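The three steps above can be sketched in NumPy as follows. This is one possible reference implementation, not the official solution; weight matrices are passed in explicitly, and the mask uses the True-means-attend convention:

```python
import numpy as np

def softmax(x, axis=-1):
    # subtract the per-row max for numerical stability
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(x_q, x_k, x_v, W_q, W_k, W_v, W_o, num_heads, mask=None):
    B, S_q, E = x_q.shape
    S_k = x_k.shape[1]
    assert E % num_heads == 0, "num_heads must divide embed_dim"
    d_k = E // num_heads

    def project_and_split(x, W, S):
        # (B, S, E) -> (B, num_heads, S, d_k)
        return (x @ W).reshape(B, S, num_heads, d_k).transpose(0, 2, 1, 3)

    Q = project_and_split(x_q, W_q, S_q)
    K = project_and_split(x_k, W_k, S_k)
    V = project_and_split(x_v, W_v, S_k)

    # scaled dot-product scores: (B, H, S_q, S_k)
    scores = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(d_k)
    if mask is not None:
        scores = np.where(mask, scores, -1e9)  # mask: True = attend
    attn = softmax(scores, axis=-1)

    # weighted values, then concatenate heads: (B, S_q, E)
    out = (attn @ V).transpose(0, 2, 1, 3).reshape(B, S_q, E)
    return out @ W_o, attn
```

Note that `@` on 4-D arrays batch-multiplies over the leading (batch, head) dimensions, which keeps the per-head computation fully vectorized.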

Skill signals the interviewer will look for: correct tensor reshaping and transposes, efficient batched matrix multiplies, numeric stability in softmax, correct masking and broadcasting, and clear handling of dimensions and edge cases (e.g., num_heads divides embed_dim).
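The numeric-stability signal usually comes down to the max-shifted softmax, a standard trick worth having at your fingertips:

```python
import numpy as np

def softmax(x, axis=-1):
    # shifting by the per-row max leaves the result unchanged mathematically
    # but keeps np.exp from overflowing on large logits
    shifted = x - x.max(axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)
```

A naive `np.exp(x) / np.exp(x).sum()` returns NaN for logits around 1000; the shifted version does not.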

Common Follow-up Questions

  • How would you add causal (autoregressive) masking so each position only attends to previous tokens, and what changes in indices/shape handling are required?
  • Modify the implementation to support attention dropout and layer normalization: where would you apply dropout and why, and how would you maintain numerical stability?
  • How would you implement efficient key-value caching for incremental decoding (serve one token at a time) while preserving batched dimensions and projection matrices?
  • Describe how you would extend this to support multi-query attention (shared K/V across heads) and what tensor reshapes or broadcasts would change.
  • How does backpropagation flow through your implementation? Point out where gradient shapes matter and common pitfalls (e.g., broadcasting into softmax or large negative mask values).
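For the causal-masking follow-up, one common approach is to build a lower-triangular boolean mask with leading singleton dimensions so it broadcasts over batch and heads (a sketch; the helper name is illustrative):

```python
import numpy as np

def causal_mask(seq_len):
    # (1, 1, S, S) lower-triangular mask: position i may attend to j <= i.
    # The leading singleton axes broadcast over (batch, num_heads).
    return np.tril(np.ones((seq_len, seq_len), dtype=bool))[None, None, :, :]
```

Because the mask broadcasts, no change to the attention code itself is needed beyond passing it in; for incremental decoding, the query side shrinks to a single row while the key side grows with the cache.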

Related Questions

1. Implement scaled dot-product attention (single-head) with masking and numerical stability
2. Write a batched self-attention forward and backward pass in NumPy for small transformers
3. Optimize multi-head attention for memory and runtime: fused projections and attention kernels
4. Implement causal multi-head attention for autoregressive generation with KV caching
5. Explain and implement positional encodings and integrate them into attention inputs


Multi-Head Attention Implementation — Google (NumPy) | Voker