From-Scratch PyTorch Transformer — Apple Interview

Topics:
Multi-Head Attention
Residual Connections
Tensor Reshaping
Roles:
Software Engineer
ML Engineer
Research Engineer
Experience:
Mid Level
Senior
Staff

Question Description

Overview

You are asked to implement a runnable, from-scratch Transformer in PyTorch suitable for sequence-to-sequence tasks. The deliverable must include a fully implemented Multi-Head Attention module and a complete Transformer (encoder–decoder) assembly. The Multi-Head Attention should accept query, key, value tensors of shapes (B, S_q, D), (B, S_k, D), (B, S_k, D) respectively, support H heads with d_head = D / H, and return (output, attn_weights) where output has shape (B, S_q, D) and attn_weights has shape (B, H, S_q, S_k).
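A minimal sketch of such a module is below. The class and parameter names (`MultiHeadAttention`, `d_model`, `num_heads`) are illustrative assumptions, not a required API; the shapes match the specification above.

```python
import math
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    """Sketch of multi-head attention; names and structure are illustrative."""

    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0, "D must be divisible by H"
        self.h = num_heads
        self.d_head = d_model // num_heads
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        B, S_q, D = q.shape
        S_k = k.shape[1]
        # Project, then split into heads: (B, S, D) -> (B, H, S, d_head)
        q = self.q_proj(q).view(B, S_q, self.h, self.d_head).transpose(1, 2)
        k = self.k_proj(k).view(B, S_k, self.h, self.d_head).transpose(1, 2)
        v = self.v_proj(v).view(B, S_k, self.h, self.d_head).transpose(1, 2)
        # Scaled dot-product scores: (B, H, S_q, S_k)
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_head)
        if mask is not None:
            # mask broadcasts against (B, H, S_q, S_k); 0/False = masked out
            scores = scores.masked_fill(mask == 0, float("-inf"))
        attn = scores.softmax(dim=-1)
        # Merge heads back: (B, H, S_q, d_head) -> (B, S_q, D)
        out = (attn @ v).transpose(1, 2).contiguous().view(B, S_q, D)
        return self.out_proj(out), attn
```

Note that the attention weights are returned per head, as required, so they can be inspected or visualized.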

High-level flow

You will typically build in stages: implement and validate the Multi-Head Attention (including query/key/value projections, head splitting/reshaping, scaled dot-product, masking, and final projection). Next, implement encoder and decoder layers that compose attention, residual connections, LayerNorm, and position-wise feed-forward networks. Finally, assemble positional encoding, stacked encoder/decoder layers, and a forward pass that accepts src, tgt, src_mask, tgt_mask and returns decoder representations of shape (B, S_tgt, D).
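The positional-encoding stage mentioned above can be sketched as a standalone function; this is one common sinusoidal formulation (the function name and signature are assumptions for illustration):

```python
import math
import torch

def sinusoidal_positional_encoding(max_len: int, d_model: int) -> torch.Tensor:
    """Return a (max_len, d_model) table where
    pe[pos, 2i]   = sin(pos / 10000^(2i / d_model))
    pe[pos, 2i+1] = cos(pos / 10000^(2i / d_model))."""
    pos = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)
    div = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float32)
        * (-math.log(10000.0) / d_model)
    )  # (d_model / 2,)
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(pos * div)
    pe[:, 1::2] = torch.cos(pos * div)
    return pe
```

In the full model this table is typically added to the token embeddings before the first encoder/decoder layer.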

What you must show / Skill signals

You must demonstrate correct tensor reshaping (split/merge heads), mask alignment and broadcasting, numerical stability in the softmax (scale the logits by 1/sqrt(d_head)), residual connections with normalization, and clean PyTorch module structure (init and forward). Knowledge of positional encoding (sinusoidal or learned), causal masking for decoder self-attention, and debugging shape/attention-weight outputs will be evaluated. Include unit-like shape checks and ensure attention weights are exposed per head for inspection.
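The two kinds of masks mentioned here can be built as small helpers; a sketch, with assumed names, that produces boolean masks broadcastable to (B, H, S_q, S_k):

```python
import torch

def causal_mask(seq_len: int) -> torch.Tensor:
    """(1, 1, S, S) lower-triangular mask for decoder self-attention.
    True = attend, False = masked; broadcasts over batch and heads."""
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).view(
        1, 1, seq_len, seq_len
    )

def padding_mask(lengths: torch.Tensor, seq_len: int) -> torch.Tensor:
    """(B, 1, 1, S_k) mask that is True where the key position is a real
    token, given per-example lengths; broadcasts over heads and queries."""
    positions = torch.arange(seq_len).unsqueeze(0)       # (1, S_k)
    return (positions < lengths.unsqueeze(1)).view(-1, 1, 1, seq_len)
```

Passing either mask (or their logical AND) into the attention's `masked_fill` step gives the alignment and broadcasting behavior described above.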

Common Follow-up Questions

  • How would you modify the model to use relative positional encodings instead of sinusoidal or absolute learned embeddings?
  • Explain how you would implement efficient masking for variable-length inputs and how src_mask / tgt_mask should be shaped and broadcast to (B, H, S_q, S_k).
  • How can you optimize memory and compute for large D and many heads (e.g., gradient checkpointing, fused projections, or mixed precision)?
  • If attention weights are numerically unstable, what changes would you make to the scaled dot-product or masking strategy to improve stability?
  • How would you extend the decoder to perform incremental (autoregressive) decoding for inference while reusing cached key/value tensors?
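For the last follow-up, the core of a key/value cache is just appending each step's projected keys and values along the sequence dimension. A minimal sketch under assumed names (`step_with_cache`, a plain dict cache):

```python
import torch

def step_with_cache(k_new: torch.Tensor, v_new: torch.Tensor, cache: dict):
    """Append one decoding step's projected key/value, each of shape
    (B, H, 1, d_head), to the running cache and return the full tensors.
    `cache` starts empty and holds 'k'/'v' of shape (B, H, S_past, d_head)."""
    if "k" in cache:
        cache["k"] = torch.cat([cache["k"], k_new], dim=2)
        cache["v"] = torch.cat([cache["v"], v_new], dim=2)
    else:
        cache["k"], cache["v"] = k_new, v_new
    return cache["k"], cache["v"]
```

During inference the decoder then attends the single new query against the cached keys/values, so each step costs O(S) attention instead of recomputing the full O(S^2) self-attention.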

Related Questions

1. Implement Scaled Dot-Product Attention in PyTorch and return attention maps
2. Build a Transformer encoder-only model (BERT-style) with positional encodings and masking
3. Implement position-wise feed-forward layers, residuals, and LayerNorm for Transformer blocks

PyTorch Transformer Interview Question — Apple (Seq2Seq) | Voker