where $W_Q, W_K, W_V$ are learnable parameters, called projection matrices
$W_Q \in \mathbb{R}^{d_{\text{model}} \times d_k}$: what to look for
$W_K \in \mathbb{R}^{d_{\text{model}} \times d_k}$: what to compare with
$W_V \in \mathbb{R}^{d_{\text{model}} \times d_v}$: what information to extract
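As a concrete sketch (NumPy, with toy sizes assumed purely for illustration), the projections are just three matrix multiplications applied to the token embeddings $X$:

```python
import numpy as np

# Toy sizes, assumed for illustration only
n, d_model, d_k, d_v = 4, 8, 2, 2

rng = np.random.default_rng(0)
X = rng.normal(size=(n, d_model))      # context-free embeddings, one row per token

W_Q = rng.normal(size=(d_model, d_k))  # what to look for
W_K = rng.normal(size=(d_model, d_k))  # what to compare with
W_V = rng.normal(size=(d_model, d_v))  # what information to extract

Q = X @ W_Q                            # (n, d_k)
K = X @ W_K                            # (n, d_k)
V = X @ W_V                            # (n, d_v)
```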
Attention:
score: $S = QK^T$
scaled: $S_{\text{scaled}} = \frac{S}{\sqrt{d_k}}$
softmax: $A = \text{softmax}(S_{\text{scaled}})$
Final output: $\text{Attention}(Q, K, V) = AV$
the whole process lifts the context-free embedding into a contextualized representation
$d_k$ and $d_v$ are hyperparameters
often $d_k = d_v = d_{\text{model}} / h$
$h$: number of attention heads
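A minimal NumPy sketch of the attention steps above (single head, no batching; `Q`, `K`, `V` come from the projection sketch earlier, and the `softmax` helper is written out for self-containment):

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract row max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    d_k = Q.shape[-1]
    S = Q @ K.T                   # score: (n, n) matrix of dot products
    S_scaled = S / np.sqrt(d_k)   # scale by sqrt(d_k)
    A = softmax(S_scaled)         # each row sums to 1
    return A @ V                  # weighted sum of values, shape (n, d_v)

out = attention(Q, K, V)          # contextualized representations
```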
Causal Masking:
$\text{mask} = \begin{bmatrix} 0 & -\infty & -\infty \\ 0 & 0 & -\infty \\ 0 & 0 & 0 \end{bmatrix}$
$A = \text{softmax}(S_{\text{scaled}} + \text{mask})$
In decoder self-attention, we're generating from left to right. During training we have the full target sequence available, but we don't want any position to attend to future positions, because those tokens won't be available at inference time.
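Continuing the same sketch, the mask is an upper-triangular matrix of $-\infty$ added to the scaled scores before the softmax (reuses the `softmax` helper above; single head assumed):

```python
import numpy as np

def causal_mask(n):
    # -inf strictly above the diagonal, 0 on and below it
    return np.triu(np.full((n, n), -np.inf), k=1)

def masked_attention(Q, K, V):
    d_k = Q.shape[-1]
    S_scaled = (Q @ K.T) / np.sqrt(d_k)
    A = softmax(S_scaled + causal_mask(Q.shape[0]))  # future positions get weight 0
    return A @ V
```

After the softmax, every masked entry becomes exactly 0, so position $i$ only mixes information from positions $\le i$.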
Scalability:
Time complexity: $O(n^2 d_{\text{model}})$
Space complexity: $O(n^2)$
for storing the attention matrix ($S_{\text{scaled}}$)
Quadratic scaling in context length is the reason why transformers struggle with long contexts.
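A rough illustration of that quadratic memory cost: the attention matrix stores one score per token pair, so at float32 precision (assumed here, single head, no batching) it grows as follows:

```python
# Memory for a single (n, n) float32 attention matrix
for n in (1_024, 8_192, 65_536):
    mib = n * n * 4 / 2**20      # 4 bytes per float32 score
    print(f"n = {n:>6}: {mib:>10,.0f} MiB")
```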