Suboptimal default initialization of q/k/v projections in nn.MultiHeadDotProductAttention
#4027
Labels
Priority: P2 - no schedule
Best effort response and resolution. We have no plan to work on this at the moment.
Initialization of the q/k/v projections in `linen.MultiHeadDotProductAttention` is not forward-backward normalized. This implementation does not face optimization issues when a pre-LN variant of the transformer is used, but it faces convergence issues in the vanilla post-LN variant from "Attention is All You Need".

System information
flax==0.8.4
Problem you have encountered:
Most papers over the past 4 years still use the vanilla post-LN transformer; one such is facebookresearch/detr. Inputs to the self-attention block of the first decoder layer are all zeros.

The default `linen.MultiHeadDotProductAttention` does not converge at all for the first 10 epochs. The reason is the initialization and gradient-norm behavior of the q/k/v matrices, particularly when inputs to the MHDPA block are all zeros. I have a diff here which converges from the very first epoch.

Please note: I have isolated the problem to the kernel initializer alone, which is what this proposal is about. A custom implementation of MHDPA has no impact on this proposal.
This convergence issue does not exist in PyTorch, because the q/k/v projections are:

- `xavier_uniform` initialized, and
- fused into a single `(3 * embed_dim, embed_dim)` matrix, which changes the effective `fan_in` and `fan_out` values.
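Paraphrasing what `torch.nn.MultiheadAttention` does at reset time (the fan arithmetic below is mine, written out in NumPy for clarity):

```python
import numpy as np

embed_dim = 256
# PyTorch keeps q/k/v in one fused in_proj_weight of shape
# (3 * embed_dim, embed_dim) and applies xavier_uniform_ to it:
fan_in, fan_out = embed_dim, 3 * embed_dim
bound = np.sqrt(6.0 / (fan_in + fan_out))  # = sqrt(6 / (4 * embed_dim))
in_proj_weight = np.random.uniform(-bound, bound, size=(3 * embed_dim, embed_dim))

# A per-projection Xavier kernel would instead see fan_in = fan_out = embed_dim,
# i.e. the larger bound sqrt(6 / (2 * embed_dim)); Flax's default lecun_normal
# differs further by ignoring fan_out entirely.
```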
What you expected to happen:
Expected `linen.MultiHeadDotProductAttention` to converge from the first epoch, as is the case in facebookresearch/detr.

Fix:
- `xavier_uniform` initializer for the projections (versus `default_kernel_init`, which is `lecun_normal`). This is also standard best practice (t-fixup paper).
- `fan_in` value for the initializer that accounts for the fused projection (for the same embedding dimensions of q/k/v, `fan_in` should be `3 * embed_dim`): ref1, ref2.

Here is an approximate diff needed for the change in Flax:
MasterSkepticista/detr@995f335
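In the meantime, something along these lines works as a user-side workaround (a sketch, not the proposed patch; it assumes Flax's `(embed_dim, num_heads, head_dim)` kernel layout for q/k/v, and the `out_kernel_init` argument available in recent Flax versions):

```python
import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn

def fused_xavier_uniform(key, shape, dtype=jnp.float32):
    """Xavier-uniform, with fan computed as if q/k/v were one fused matrix.

    Assumes the Flax q/k/v kernel layout (embed_dim, num_heads, head_dim).
    """
    fan_in = shape[0]
    fan_out = int(np.prod(shape[1:]))
    # Effective fan of the fused (3 * embed_dim, embed_dim) PyTorch matrix:
    bound = float(np.sqrt(6.0 / (fan_in + 3 * fan_out)))
    return jax.random.uniform(key, shape, dtype, -bound, bound)

attn = nn.MultiHeadDotProductAttention(
    num_heads=8,
    qkv_features=256,
    kernel_init=fused_xavier_uniform,
    # The output kernel has layout (num_heads, head_dim, out_dim), so keep a
    # plain Xavier there rather than reusing the q/k/v fan arithmetic.
    out_kernel_init=nn.initializers.xavier_uniform(),
)
```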
There could be multiple ways to go about calculating the correct `fan_in`, or one could use a single giant `Dense` layer as PyTorch does (a rough sketch of that alternative is below); the latter has performance implications. Let me know if the fix in my code (barring the hardcoded values) is a reasonable approach?
Happy to do a PR.