
Scaled Self-Attention Implementation — Meta

Topics:
Self-Attention
Matrix Multiplication
Softmax
Roles:
Machine Learning Engineer
ML Engineer
Data Scientist
Experience:
Entry Level
Mid Level
Senior

Question Description

Implementing scaled dot-product self-attention is a common ML coding interview task that tests your understanding of linear algebra, tensor shapes, masking, and numerical stability in Transformer-style models.

You are asked to compute attention outputs and attention weights from Queries (Q), Keys (K), and Values (V). The core operation is to form the batched dot-product QK^T, scale by sqrt(d_k), apply an optional mask that blocks specific key positions, run a numerically stable softmax over the key dimension, and multiply the resulting attention distribution by V. The function must accept batched inputs Q of shape (batch, seq_len_q, d_k), K of shape (batch, seq_len_k, d_k), and V of shape (batch, seq_len_k, d_v), and return attention_output of shape (batch, seq_len_q, d_v) and attention_weights of shape (batch, seq_len_q, seq_len_k).
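The computation described above can be sketched in NumPy as follows. This is a minimal reference, not the expected interview solution; the function name and the -1e9 mask fill value are illustrative choices.

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V, mask=None):
    """Sketch of batched scaled dot-product attention.

    Q: (batch, seq_len_q, d_k)
    K: (batch, seq_len_k, d_k)
    V: (batch, seq_len_k, d_v)
    mask: optional boolean array broadcastable to (batch, seq_len_q, seq_len_k);
          True marks key positions to block.
    Returns (output, weights) with shapes
    (batch, seq_len_q, d_v) and (batch, seq_len_q, seq_len_k).
    """
    d_k = Q.shape[-1]
    # Batched QK^T, scaled by sqrt(d_k) to keep logit variance near 1.
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)
    if mask is not None:
        # Blocked positions get a large negative logit so softmax -> ~0 there.
        scores = np.where(mask, scores.dtype.type(-1e9), scores)
    # Numerically stable softmax over the key axis (subtract per-row max).
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ V, weights
```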

Interview flow / stages:

  • Clarify shapes and mask semantics (True/1 = masked out).
  • Describe scaling and why you divide by sqrt(d_k).
  • Explain numerical-stability steps (subtract max per row before softmax) and dtype considerations.
  • Implement batched matmul, mask application (set masked logits to -inf or a large negative), softmax along keys, then multiply by V.
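The stability step in the flow above (subtract the per-row max before softmax) matters because softmax is shift-invariant but exp overflows for large logits. A small demonstration, with logit values chosen only to force the overflow:

```python
import numpy as np

logits = np.array([1000.0, 1001.0, 1002.0])

# Naive softmax overflows: exp(1000) is inf even in float64,
# and inf / inf yields nan.
with np.errstate(over="ignore", invalid="ignore"):
    naive = np.exp(logits) / np.exp(logits).sum()

# Subtracting the max leaves the result mathematically unchanged
# (softmax is shift-invariant) but keeps every exponent <= 0.
shifted = logits - logits.max()
stable = np.exp(shifted) / np.exp(shifted).sum()
```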

Skill signals the interviewer expects:

  • Correct tensor broadcasting and axis choices
  • Preservation of input dtypes where practical
  • Proper mask handling and numerical stability for softmax
  • Awareness of runtime and memory trade-offs for batched attention
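On the broadcasting and dtype signals: a padding mask does not need the full (batch, seq_len_q, seq_len_k) shape; storing it as (batch, 1, seq_len_k) and letting broadcasting expand it across query positions is both idiomatic and memory-friendly. A sketch with made-up shapes (the padded positions here are purely illustrative):

```python
import numpy as np

batch, seq_q, seq_k = 2, 4, 5
scores = np.zeros((batch, seq_q, seq_k), dtype=np.float32)

# A per-sequence padding mask only needs shape (batch, 1, seq_len_k);
# broadcasting repeats it across every query position.
pad_mask = np.zeros((batch, 1, seq_k), dtype=bool)
pad_mask[0, 0, 3:] = True  # hypothetical: last two keys of sequence 0 are padding

# Cast the fill value to the scores dtype so the result is not
# silently promoted to float64.
fill = scores.dtype.type(-1e9)
masked = np.where(pad_mask, fill, scores)
```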

Common Follow-up Questions

  • How would you extend this to implement multi-head attention and combine heads efficiently?
  • How do you implement a causal (autoregressive) mask so queries cannot attend to future keys?
  • What changes would you make to preserve numerical stability and dtype when inputs are float16?
  • How does the attention computational and memory complexity scale with sequence length, and how would you optimize for long sequences?
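For the causal-mask follow-up, one common construction (under the True = masked-out convention used above) is a strictly upper-triangular boolean matrix:

```python
import numpy as np

seq_len = 4
# causal_mask[i, j] is True exactly when j > i, so query i cannot
# attend to any future key j. Broadcast it as (1, seq_len, seq_len)
# over the batch dimension when applying it to the scores.
causal_mask = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
```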

Related Questions

1. Implement multi-head attention given a scaled dot-product attention primitive
2. Write a masked attention function for decoder-only Transformers (causal masking)
3. Explain and implement an efficient attention variant (sparse or chunked attention) for long sequences
4. Derive gradients for the scaled dot-product attention (backprop through softmax and matmul)

