From-Scratch PyTorch Transformer — Apple Interview
Question Description
Overview
You are asked to implement a runnable, from-scratch Transformer in PyTorch suitable for sequence-to-sequence tasks. The deliverable must include a fully implemented Multi-Head Attention module and a complete Transformer (encoder–decoder) assembly. The Multi-Head Attention should accept query, key, value tensors of shapes (B, S_q, D), (B, S_k, D), (B, S_k, D) respectively, support H heads with d_head = D / H, and return (output, attn_weights) where output has shape (B, S_q, D) and attn_weights has shape (B, H, S_q, S_k).
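The interface above can be sketched as a minimal from-scratch module. This is one possible shape of a solution, not the prescribed one; the class and projection names are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    """Minimal sketch matching the stated interface: (B, S_q, D) queries,
    (B, S_k, D) keys/values, H heads with d_head = D / H."""

    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0, "D must be divisible by H"
        self.h = num_heads
        self.d_head = d_model // num_heads
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        B, S_q, D = q.shape
        S_k = k.shape[1]
        # Project, then split into heads: (B, S, D) -> (B, H, S, d_head)
        q = self.q_proj(q).view(B, S_q, self.h, self.d_head).transpose(1, 2)
        k = self.k_proj(k).view(B, S_k, self.h, self.d_head).transpose(1, 2)
        v = self.v_proj(v).view(B, S_k, self.h, self.d_head).transpose(1, 2)
        # Scaled dot-product: scores have shape (B, H, S_q, S_k)
        scores = q @ k.transpose(-2, -1) / self.d_head ** 0.5
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        attn = F.softmax(scores, dim=-1)
        out = attn @ v                              # (B, H, S_q, d_head)
        # Merge heads back: (B, H, S_q, d_head) -> (B, S_q, D)
        out = out.transpose(1, 2).reshape(B, S_q, D)
        return self.out_proj(out), attn

mha = MultiHeadAttention(d_model=16, num_heads=4)
out, w = mha(torch.randn(2, 5, 16), torch.randn(2, 7, 16), torch.randn(2, 7, 16))
print(out.shape, w.shape)  # torch.Size([2, 5, 16]) torch.Size([2, 4, 5, 7])
```

Note that the attention weights are returned per head, which satisfies the inspection requirement below.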
High-level flow
You will typically build in stages: implement and validate the Multi-Head Attention (including query/key/value projections, head splitting/reshaping, scaled dot-product, masking, and final projection). Next, implement encoder and decoder layers that compose attention, residual connections, LayerNorm, and position-wise feed-forward networks. Finally, assemble positional encoding, stacked encoder/decoder layers, and a forward pass that accepts src, tgt, src_mask, tgt_mask and returns decoder representations of shape (B, S_tgt, D).
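The encoder-layer stage of that flow can be sketched as follows. For brevity this sketch substitutes PyTorch's built-in nn.MultiheadAttention for the from-scratch module you would actually submit; the sublayer ordering (attention, residual add, LayerNorm, feed-forward) is the point being illustrated.

```python
import torch
import torch.nn as nn

class EncoderLayer(nn.Module):
    """Illustrative encoder layer: self-attention + residual/LayerNorm + FFN.
    Post-norm ordering ("Attention Is All You Need" style): sublayer -> add -> norm."""

    def __init__(self, d_model: int = 16, num_heads: int = 4, d_ff: int = 64):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model)
        )

    def forward(self, x, key_padding_mask=None):
        a, _ = self.attn(x, x, x, key_padding_mask=key_padding_mask)
        x = self.norm1(x + a)          # residual around attention, then norm
        x = self.norm2(x + self.ffn(x))  # residual around FFN, then norm
        return x

layer = EncoderLayer()
x = torch.randn(2, 5, 16)
print(layer(x).shape)  # torch.Size([2, 5, 16])
```

A decoder layer adds a causally masked self-attention sublayer plus a cross-attention sublayer over the encoder output, with the same residual-and-norm pattern.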
What you must show / Skill signals
You must demonstrate correct tensor reshaping (split/merge heads), mask alignment and broadcasting, numerical stability in the attention softmax (scaling logits by 1/sqrt(d_head)), residual connections with normalization, and clean PyTorch module structure (init and forward). Knowledge of positional encoding (sinusoidal or learned), causal masking for decoder self-attention, and debugging shape/attention-weight outputs will be evaluated. Include unit-like shape checks and ensure attention weights are exposed per head for inspection.
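The causal mask and the accompanying shape checks can be written as a short self-contained sketch; the helper name is illustrative.

```python
import torch

def causal_mask(s: int) -> torch.Tensor:
    # Lower-triangular boolean mask: position i may attend to positions <= i.
    # Returned as (1, 1, s, s) so it broadcasts over (B, H, S_q, S_k) scores.
    return torch.tril(torch.ones(s, s, dtype=torch.bool)).view(1, 1, s, s)

m = causal_mask(4)
# Unit-like shape checks of the kind the question asks for:
assert m.shape == (1, 1, 4, 4)
assert not m[0, 0, 0, 1]  # first position cannot see the second
assert m[0, 0, 3, 0]      # last position sees the first
print(m[0, 0].int())
```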
Common Follow-up Questions
- How would you modify the model to use relative positional encodings instead of sinusoidal or absolute learned embeddings?
- Explain how you would implement efficient masking for variable-length inputs and how src_mask / tgt_mask should be shaped and broadcast to (B, H, S_q, S_k).
- How can you optimize memory and compute for large D and many heads (e.g., gradient checkpointing, fused projections, or mixed precision)?
- If attention weights are numerically unstable, what changes would you make to the scaled dot-product or masking strategy to improve stability?
- How would you extend the decoder to perform incremental (autoregressive) decoding for inference while reusing cached key/value tensors?
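For the variable-length masking follow-up, the usual trick is to build a (B, S_k) boolean padding mask from sequence lengths and add singleton dimensions so it broadcasts against the (B, H, S_q, S_k) score tensor. A minimal sketch, with an illustrative helper name:

```python
import torch

def expand_padding_mask(pad_mask: torch.Tensor) -> torch.Tensor:
    """(B, S_k) boolean mask (True = real token, False = padding) ->
    (B, 1, 1, S_k), which broadcasts over (B, H, S_q, S_k) attention scores."""
    return pad_mask[:, None, None, :]

lengths = torch.tensor([3, 5])                          # actual lengths per batch item
pad_mask = torch.arange(5)[None, :] < lengths[:, None]  # (2, 5): True where token is real
m = expand_padding_mask(pad_mask)
print(m.shape)  # torch.Size([2, 1, 1, 5])
```

Applying `masked_fill(~m, -inf)` to the scores before the softmax then zeroes out attention to padded positions in every head and every query row at once.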