Google ML Coding: Hand-code Multi-Head Attention in NumPy
Question Description
Implement a from-scratch multi-head attention forward pass as used in Transformer models. You will take batched queries, keys, and values (shapes: (batch_size, seq_len, embed_dim)), apply linear projections into Q, K, V, split into num_heads heads (head dim d_k = embed_dim / num_heads), compute scaled dot-product attention per head, concatenate heads, and apply an output projection. The attention formula you should follow is:

Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
You should also accept an optional mask broadcastable to (batch_size, num_heads, seq_len, seq_len) so that masked positions are excluded from the softmax (use large negative values before softmax). The function must return the final output (batch_size, seq_len, embed_dim) and the per-head attention weights (batch_size, num_heads, seq_len, seq_len).
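The per-head core of the formula above can be sketched as follows. This is a minimal illustration, not a reference solution; it assumes a boolean mask convention where True means "keep" and False means "mask out" (the convention is up to you in an interview, as long as you state it):

```python
import numpy as np

def scaled_dot_product_attention(q, k, v, mask=None):
    """q, k: (..., seq_len, d_k); v: (..., seq_len, d_v).
    Returns (output, attention_weights)."""
    d_k = q.shape[-1]
    # (..., seq_len, seq_len) score matrix, scaled by sqrt(d_k)
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(d_k)
    if mask is not None:
        # Masked positions (mask == False) get a large negative score
        scores = np.where(mask, scores, -1e9)
    # Numerically stable softmax over the last axis
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights
```

Because the matmuls only touch the last two axes, the same function works unchanged for 3-D batched inputs or 4-D per-head inputs.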
Flow you can expect in an interview:
- Project inputs with W_q, W_k, W_v (embed_dim -> embed_dim), reshape and transpose to (batch, num_heads, seq_len, d_k).
- Compute scaled dot-products, apply mask, softmax, and multiply by V per head.
- Concatenate heads, project with W_o, and return outputs and attention maps.
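The three steps above can be sketched end-to-end in NumPy. This is one possible shape of an answer, assuming separate x_q/x_k/x_v inputs, square (embed_dim, embed_dim) projection matrices, and a boolean mask where False marks positions to exclude:

```python
import numpy as np

def multi_head_attention(x_q, x_k, x_v, W_q, W_k, W_v, W_o,
                         num_heads, mask=None):
    """Batched multi-head attention forward pass.
    x_*: (batch, seq_len, embed_dim); W_*: (embed_dim, embed_dim).
    Returns (output, attention_weights)."""
    batch, seq_len, embed_dim = x_q.shape
    assert embed_dim % num_heads == 0, "num_heads must divide embed_dim"
    d_k = embed_dim // num_heads

    def split_heads(x):
        # (batch, seq, embed) -> (batch, heads, seq, d_k)
        return x.reshape(batch, -1, num_heads, d_k).transpose(0, 2, 1, 3)

    q = split_heads(x_q @ W_q)
    k = split_heads(x_k @ W_k)
    v = split_heads(x_v @ W_v)

    # Scaled dot-product scores: (batch, heads, seq, seq)
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d_k)
    if mask is not None:
        scores = np.where(mask, scores, -1e9)
    scores = scores - scores.max(axis=-1, keepdims=True)  # stability
    attn = np.exp(scores)
    attn = attn / attn.sum(axis=-1, keepdims=True)

    out = attn @ v                                  # (batch, heads, seq, d_k)
    # Concatenate heads back to (batch, seq, embed_dim), then project
    out = out.transpose(0, 2, 1, 3).reshape(batch, seq_len, embed_dim)
    return out @ W_o, attn
```

Note the transpose before the final reshape: reshaping (batch, heads, seq, d_k) directly would interleave positions across heads instead of concatenating head outputs per position.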
Skill signals the interviewer will look for: correct tensor reshaping and transposes, efficient batched matrix multiplies, numeric stability in softmax, correct masking and broadcasting, and clear handling of dimensions and edge cases (e.g., num_heads divides embed_dim).
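On the numeric-stability point: exp of large logits overflows in float arithmetic, so the standard fix is to subtract the row maximum before exponentiating, which leaves the softmax output unchanged. A minimal sketch:

```python
import numpy as np

def stable_softmax(x, axis=-1):
    # Subtracting the max makes the largest exponent exp(0) = 1,
    # so np.exp never overflows; the result is mathematically identical.
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)
```

Without the subtraction, np.exp(1000.0) overflows to inf and the division yields NaNs.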
Common Follow-up Questions
- How would you add causal (autoregressive) masking so each position only attends to previous tokens, and what changes in indices/shape handling are required?
- Modify the implementation to support attention dropout and layer normalization: where would you apply dropout and why, and how would you maintain numerical stability?
- How would you implement efficient key-value caching for incremental decoding (serving one token at a time) while preserving batched dimensions and projection matrices?
- Describe how you would extend this to support multi-query attention (shared K/V across heads) and what tensor reshapes or broadcasts would change.
- How does backpropagation flow through your implementation? Point out where gradient shapes matter and common pitfalls (e.g., broadcasting into softmax or large negative mask values).
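For the first follow-up, a causal mask is just a lower-triangular boolean matrix broadcast across batch and head axes. A sketch, assuming the True-means-keep convention:

```python
import numpy as np

seq_len = 4
# Lower-triangular mask: position i may attend to positions j <= i.
causal = np.tril(np.ones((seq_len, seq_len), dtype=bool))
# Add leading singleton axes so it broadcasts over
# (batch_size, num_heads, seq_len, seq_len) score tensors.
causal = causal[None, None, :, :]
```

No index arithmetic changes in the attention code itself; broadcasting handles the batch and head dimensions.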