
Masks and Cross-Attention: Extending Self-Attention

· Reading time: 7 min

This post is a follow-up to From Kernel Regression to Self-Attention — same notation, same setup. There we ended with full self-attention, where every position in a sequence attends to every other position using learned projections of itself. That is the mechanism. In practice, two extensions are layered on top before you get to a real Transformer:

  • Masks, which restrict which positions a query is allowed to read from. The most important is the causal mask used in autoregressive language modeling, but padding masks come up routinely in batched training.
  • Cross-attention, which relaxes the self part: queries come from one sequence, keys and values from another. This is what lets one sequence be conditioned on another.

Both are small modifications to equation (13) of the previous post. The mechanism — scaled dot product, softmax, weighted average — is unchanged in both cases. What changes is which positions are allowed to attend to which, and which sequence each input comes from. This post covers both.


1. The leakage problem

A decoder-only language model is trained on next-token prediction: given tokens $\boldsymbol{x}_1, \ldots, \boldsymbol{x}_n$, it should predict $\boldsymbol{x}_{i+1}$ from the output $\boldsymbol{y}_i$ at position $i$, for every $i$.

At inference time, this constraint is automatic — when generating token $i+1$, only tokens $1, \ldots, i$ exist. There is nothing to leak. At training time, the situation is different. The point of a Transformer is parallelism: we want to compute $\boldsymbol{y}_1, \ldots, \boldsymbol{y}_n$ in a single forward pass, with all targets visible. But equation (16) from the previous post says

$$\boldsymbol{y}_i = \operatorname{Attn}\!\left(W^{(q)} \boldsymbol{x}_i,\; \left\{\left(W^{(k)} \boldsymbol{x}_j,\; W^{(v)} \boldsymbol{x}_j\right)\right\}_{j=1}^{n}\right)$$

Position $i$ attends to all positions $j$, including those with $j > i$. The model can read the future and copy it into $\boldsymbol{y}_i$. Loss goes to zero; nothing has been learned.

We need a way to keep the parallelism but block any flow from positions $j > i$ to position $i$.


2. The fix: $-\infty$ before softmax

The mechanism is small. In the score for query $i$ against key $j$, replace $\tilde{a}(\boldsymbol{q}_i, \boldsymbol{k}_j)$ with $-\infty$ whenever $j > i$. After softmax, $e^{-\infty} = 0$, so those positions receive exactly zero attention weight. The remaining weights — over keys $1, \ldots, i$ — renormalize to sum to one.

Concretely, this is achieved by adding a mask matrix $M \in \mathbb{R}^{n \times n}$ to the score matrix before softmax:

$$M_{ij} = \begin{cases} 0 & \text{if } j \leq i \\ -\infty & \text{if } j > i \end{cases}$$

Equation (13) from the previous post becomes

$$\operatorname{Attn}_{\text{masked}}(Q, K, V) = \operatorname{softmax}\!\left(\frac{QK^\top}{\sqrt{d}} + M\right) V \tag{17}$$

The softmax is applied row-wise. After masking, the attention matrix is lower-triangular: row $i$ has nonzero entries only in columns $1, \ldots, i$.

In code, $-\infty$ is typically replaced by a large negative constant like $-10^9$, which is numerically equivalent for a floating-point softmax: the exponential of such a score underflows to exactly zero.
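A minimal NumPy sketch of equation (17); the helper names (`causal_mask`, `masked_attention`) are illustrative, not from the previous post or any library:

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max first for numerical stability.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def causal_mask(n):
    # M_ij = 0 where j <= i, a large negative number where j > i.
    # -1e9 stands in for -inf: it underflows to exactly 0 under softmax.
    return np.where(np.tril(np.ones((n, n))) == 1, 0.0, -1e9)

def masked_attention(Q, K, V, M):
    # Equation (17): softmax(QK^T / sqrt(d) + M) V, softmax applied row-wise.
    d = Q.shape[-1]
    weights = softmax(Q @ K.T / np.sqrt(d) + M)
    return weights @ V, weights

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 8))          # toy sequence: n = 5 positions, d = 8
Y, A = masked_attention(X, X, X, causal_mask(5))
print(np.triu(A, k=1).max())         # strictly-upper entries are all zero
print(A.sum(axis=1))                 # each row still sums to 1
```

The attention matrix `A` comes out lower-triangular with rows that sum to one, exactly the renormalization described above.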


3. What changes, what doesn’t

The mechanism is unchanged. Queries still come from $W^{(q)} \boldsymbol{x}$, keys from $W^{(k)} \boldsymbol{x}$, values from $W^{(v)} \boldsymbol{x}$; the dot-product score and the $\sqrt{d}$ rescaling are exactly as in §5 of the previous post. All that has changed is which scores survive the softmax.

The asymmetry is worth dwelling on. Position $i$ no longer sees positions $j > i$, but position $j$ (with $j > i$) still sees position $i$. The mask makes information flow strictly forward through the sequence. Each output $\boldsymbol{y}_i$ is now a weighted average over only its own past — exactly the constraint autoregressive training requires.

Causal-mask asymmetry: position 3 reads only 3 past positions, but is read by 6 current and future positions. Same matrix, same position, asymmetric reach.

Parallelism is preserved. All $n$ outputs are still computed in a single matrix multiplication, just with a triangular attention pattern instead of a full one. This is what lets decoder-only Transformers train at the same speed as bidirectional ones, despite the stricter information constraint.


4. Padding masks

A second, more mundane use of masking comes from batching. Real training batches contain sequences of different lengths. To pack them into a single tensor, shorter sequences are padded with a special token up to the length of the longest. These padding tokens carry no meaning and should not contribute to any output.

The same trick handles them. Build a padding mask $M^{\text{pad}}$ where $M^{\text{pad}}_{ij} = -\infty$ whenever key $j$ is a padding token, and $0$ otherwise. Add it to the scores before softmax. In a decoder, both masks combine: $M^{\text{total}} = M^{\text{causal}} + M^{\text{pad}}$, which is still $-\infty$ wherever either is.
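A sketch of how the two masks might be built and combined, assuming padding sits at the end of the sequence; `padding_mask` and `valid_length` are illustrative names:

```python
import numpy as np

def causal_mask(n):
    # 0 where j <= i, -1e9 (a stand-in for -inf) where j > i.
    return np.where(np.tril(np.ones((n, n))) == 1, 0.0, -1e9)

def padding_mask(valid_length, n):
    # -1e9 in every column j that corresponds to a padding key, 0 elsewhere.
    # The mask depends only on the key index, so every row is identical.
    key_is_pad = np.arange(n) >= valid_length          # (n,) boolean
    return np.where(key_is_pad, -1e9, 0.0)[None, :] * np.ones((n, 1))

n = 6
M_total = causal_mask(n) + padding_mask(4, n)  # real tokens at positions 0..3
print(M_total)
```

Row 2 of `M_total` allows keys 0..2 only; the padding columns 4 and 5 are blocked in every row. Where both masks fire the entries sum to $-2 \times 10^9$, which is just as effectively $-\infty$.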


5. Cross-attention

The second extension is cross-attention. In self-attention, the query, key, and value all come from the same sequence — that is what makes it *self*. Cross-attention relaxes this: queries come from one sequence, keys and values from another.

Concretely, given two sequences — call them $X^A \in \mathbb{R}^{n \times d}$ with rows $\boldsymbol{x}^A_1, \ldots, \boldsymbol{x}^A_n$ and $X^B \in \mathbb{R}^{m \times d}$ with rows $\boldsymbol{x}^B_1, \ldots, \boldsymbol{x}^B_m$ — cross-attention lets every position in $A$ attend to every position in $B$. The learned projections from §8 of the previous post split as follows:

$$\boldsymbol{q}_i = W^{(q)} \boldsymbol{x}^A_i, \quad \boldsymbol{k}_j = W^{(k)} \boldsymbol{x}^B_j, \quad \boldsymbol{v}_j = W^{(v)} \boldsymbol{x}^B_j$$

Three things are worth noting:

  • The mechanism is unchanged. Scaled dot product, softmax, weighted average — exactly equation (13). Only the inputs differ.
  • The score matrix changes shape. Stacking $Q = X^A W^{(q)\top}$ ($n$ rows) and $K = X^B W^{(k)\top}$ ($m$ rows), the score matrix $QK^\top$ is $n \times m$ instead of the $n \times n$ we have seen so far. Each row $i$ holds the scores from query position $i$ in $A$ against every key position in $B$.
  • Self-attention is the special case where $A = B$. When the two sequences are identical, queries, keys, and values all come from the same source, and we are back to (13) plus the learned projections.

In matrix form:

$$\operatorname{Attn}_{\text{cross}}(X^A, X^B) = \operatorname{softmax}\!\left(\frac{QK^\top}{\sqrt{d}}\right) V \in \mathbb{R}^{n \times d} \tag{18}$$

— one output vector per position in $A$, where each output is a weighted blend of the value rows derived from $B$. Most of the previous post carries through unchanged; the only difference is the interpretation of where each input comes from.
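A shape-focused NumPy sketch of equation (18); the function and variable names are illustrative:

```python
import numpy as np

def softmax(x, axis=-1):
    # Row-wise softmax with max subtraction for numerical stability.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention(XA, XB, Wq, Wk, Wv):
    # Equation (18): queries from sequence A, keys and values from sequence B.
    Q = XA @ Wq.T                               # (n, d)
    K = XB @ Wk.T                               # (m, d)
    V = XB @ Wv.T                               # (m, d)
    d = Q.shape[-1]
    weights = softmax(Q @ K.T / np.sqrt(d))     # (n, m) score matrix
    return weights @ V                          # (n, d): one output per A-position

rng = np.random.default_rng(0)
n, m, d = 4, 7, 16
XA, XB = rng.normal(size=(n, d)), rng.normal(size=(m, d))
Wq, Wk, Wv = (rng.normal(size=(d, d)) for _ in range(3))
Y = cross_attention(XA, XB, Wq, Wk, Wv)
print(Y.shape)   # (4, 16): n outputs, each a blend of B's value rows
```

Passing the same sequence twice, `cross_attention(XA, XA, ...)`, recovers self-attention with learned projections, the special case noted above.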

The most prominent use is in encoder–decoder Transformers, where decoder positions attend to the encoder’s output — that is what lets a translation model align “die Katze saß” with “the cat sat.” But the mechanism is general: it shows up wherever a model needs to condition one sequence on another.


Summary

Self-attention as derived in the previous post lets every position attend to every other position in the same sequence. Two extensions cover most of what real Transformers do:

  • Masking sets entries of the score matrix to $-\infty$ before softmax, blocking flow from masked positions. Causal masks block the future; padding masks block padding tokens. Both rely on $e^{-\infty} = 0$.
  • Cross-attention lets queries come from one sequence and keys/values from another. The mechanism is otherwise identical; only the score matrix changes shape from $n \times n$ to $n \times m$.

A single formula subsumes all three cases — vanilla self-attention, masked self-attention, and cross-attention:

$$\operatorname{Attn}(Q, K, V; M) = \operatorname{softmax}\!\left(\frac{QK^\top}{\sqrt{d}} + M\right) V$$

When $Q$, $K$, $V$ come from the same sequence and $M = 0$, this is the original (13). When $M$ encodes a causal or padding mask, this is masked self-attention. When $Q$ comes from one sequence and $K$, $V$ from another, this is cross-attention. The next post shows how Transformers stack these into encoder-only, decoder-only, and encoder–decoder architectures.
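The three cases can be sketched as one NumPy function (illustrative names; `M=None` stands in for $M = 0$):

```python
import numpy as np

def attn(Q, K, V, M=None):
    # The unified formula: softmax(QK^T / sqrt(d) + M) V, row-wise softmax.
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)
    if M is not None:
        scores = scores + M
    scores = scores - scores.max(axis=-1, keepdims=True)   # stability
    W = np.exp(scores)
    return (W / W.sum(axis=-1, keepdims=True)) @ V

rng = np.random.default_rng(0)
XA, XB = rng.normal(size=(5, 8)), rng.normal(size=(3, 8))
causal = np.where(np.tril(np.ones((5, 5))) == 1, 0.0, -1e9)

y_plain  = attn(XA, XA, XA)            # vanilla self-attention
y_masked = attn(XA, XA, XA, causal)    # masked self-attention
y_cross  = attn(XA, XB, XB)            # cross-attention, (5,3) score matrix
```

All three calls return one output row per query position; only the mask and the provenance of `K`, `V` differ.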


Built on the notation and derivation in From Kernel Regression to Self-Attention. Causal masking and cross-attention both originate with Vaswani et al., “Attention Is All You Need”, NeurIPS 2017, §3.2.
