
NVIDIA ML Coding: Decaying Attention Implementation

Topics:
Attention Mechanisms
Linear Algebra / Matrix Multiplication
NumPy Array Indexing & Broadcasting
Roles:
Machine Learning Engineer
Experience:
Entry Level
Mid Level
Senior

Question Description

You are asked to implement a variant of dot‑product attention where a pairwise positional bias grows with the absolute difference in token indices. Concretely, you must compute

A = softmax(Q K^T + B) V

with B_{ij} = |i - j| applied across the query/key sequence axes. The implementation must accept both unbatched (Q: (L_q,d), K: (L_k,d), V: (L_k,d_v) -> output (L_q,d_v)) and batched inputs (Q: (B,L_q,d), K: (B,L_k,d), V: (B,L_k,d_v) -> output (B,L_q,d_v)). You should preserve batch semantics and ensure numeric dtype consistency across matmuls, bias addition, and softmax.

Focus areas and flow

  • Compute similarity S = Q @ K^T (vectorized via matmul/einsum). Build B once from arange indices (abs(i-j)) and broadcast to (L_q,L_k) or (B,L_q,L_k).
  • Add B to S, apply row-wise softmax over keys (subtract the row max for numeric stability), then multiply by V to get output of shape (L_q,d_v) or (B,L_q,d_v).
  • Validate shapes: allow L_q != L_k, require matching feature dims d and batch dims when present.
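The flow above can be sketched as follows. This is one possible NumPy implementation under the stated requirements (the function name and validation messages are my own choices, not prescribed by the question):

```python
import numpy as np

def decaying_attention(Q, K, V):
    """Dot-product attention with additive positional bias B[i, j] = |i - j|.

    Accepts unbatched inputs Q:(L_q,d), K:(L_k,d), V:(L_k,d_v) or batched
    inputs with a leading batch axis on all three arrays.
    """
    if Q.ndim not in (2, 3) or K.ndim != Q.ndim or V.ndim != Q.ndim:
        raise ValueError("Q, K, V must all be 2-D or all be 3-D")
    if Q.shape[-1] != K.shape[-1]:
        raise ValueError("feature dims of Q and K must match")
    if K.shape[-2] != V.shape[-2]:
        raise ValueError("K and V must share the key-length axis")
    if Q.ndim == 3 and not (Q.shape[0] == K.shape[0] == V.shape[0]):
        raise ValueError("batch dims of Q, K, V must match")

    L_q, L_k = Q.shape[-2], K.shape[-2]
    # Similarity S = Q K^T, shape (..., L_q, L_k).
    S = Q @ np.swapaxes(K, -1, -2)
    # Positional bias |i - j|, built once from arange indices; the 2-D
    # bias broadcasts over any leading batch axis. Cast to S's dtype.
    i = np.arange(L_q)[:, None]
    j = np.arange(L_k)[None, :]
    S = S + np.abs(i - j).astype(S.dtype)
    # Row-wise softmax over keys, shifted by the row max for stability.
    S = S - S.max(axis=-1, keepdims=True)
    W = np.exp(S)
    W /= W.sum(axis=-1, keepdims=True)
    return W @ V  # (..., L_q, d_v)
```

A quick consistency check: running the batched call and slicing out one batch element should agree with the unbatched call on that element's Q, K, V.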

Skill signals

You should demonstrate: efficient NumPy vectorization, correct broadcasting of the positional bias over the batch axis, numeric-stability best practices for softmax, careful shape and dtype validation, and awareness of complexity: O(B * L_q * L_k * d) time and O(B * L_q * L_k) memory for the score matrix. Include tests for edge cases (single token, unequal lengths, float32 vs float64).
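The softmax stability point is worth demonstrating concretely. The small example below (illustrative values only) shows why subtracting the row max matters: exponentiating large float32 logits directly overflows, while the shifted version is exact because softmax is invariant to a constant shift:

```python
import numpy as np

# Large float32 logits, as might arise from unscaled Q K^T scores.
x = np.array([800.0, 801.0, 802.0], dtype=np.float32)

# Naive softmax numerator overflows to inf in float32.
with np.errstate(over="ignore"):
    naive = np.exp(x)
assert np.isinf(naive).any()

# Shift-by-max softmax: mathematically identical, numerically finite.
shifted = np.exp(x - x.max())
stable = shifted / shifted.sum()
assert np.isfinite(stable).all()
assert np.isclose(stable.sum(), 1.0)
```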

Common Follow-up Questions

  • How would you modify the bias to use an exponential decay B_{ij} = -alpha * |i-j| (learnable alpha) and what numerical/stability concerns arise?
  • Add scaling by 1/sqrt(d) (scaled dot-product). How does that change outputs and empirical gradient magnitudes, and when is it necessary?
  • How would you combine this decaying bias with a causal mask (prevent attention to future positions) while preserving performance and dtype correctness?
  • Design an efficient implementation for very long sequences (L in tens of thousands). What approximations or algorithms reduce memory/time while keeping the decaying behavior?
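Three of these follow-ups (1/sqrt(d) scaling, exponential-decay bias, causal masking) compose naturally in one function. The sketch below is a hypothetical combination, not a prescribed answer; `alpha` stands in for what would be a learnable scalar in an autodiff framework:

```python
import numpy as np

def decayed_causal_attention(Q, K, V, alpha=0.1):
    """Scaled attention with bias B[i, j] = -alpha * |i - j| and a causal
    mask. alpha is a fixed placeholder for a learnable decay rate."""
    d = Q.shape[-1]
    L_q, L_k = Q.shape[-2], K.shape[-2]
    # Scaled similarity, keeping logit magnitudes roughly O(1).
    S = (Q @ np.swapaxes(K, -1, -2)) / np.sqrt(d)
    i = np.arange(L_q)[:, None]
    j = np.arange(L_k)[None, :]
    S = S + (-alpha * np.abs(i - j)).astype(S.dtype)
    # Additive causal mask: -inf logits drop out of the softmax without
    # changing S's dtype; the 2-D mask broadcasts over any batch axis.
    S = np.where(j <= i, S, -np.inf)
    S = S - S.max(axis=-1, keepdims=True)
    W = np.exp(S)
    W /= W.sum(axis=-1, keepdims=True)
    return W @ V
```

One sanity check on the masking: with aligned sequences, query 0 may only attend to key 0, so the first output row must equal the first row of V exactly.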

Related Questions

1. Implement scaled dot‑product attention (softmax(QK^T / sqrt(d)) V) with batched inputs
2. Implement attention with relative positional encodings (learnable or sinusoidal) and explain differences vs absolute-index bias
3. Implement causal (autoregressive) masked attention and explain broadcasting of masks over batch heads
4. Approaches to efficient attention for long sequences: sparse attention, kernelized attention, or locality-sensitive hashing
