Suboptimal default initialization of q/k/v projections in nn.MultiHeadDotProductAttention #4027

Open
MasterSkepticista opened this issue Jun 25, 2024 · 2 comments
Labels
Priority: P2 - no schedule (best-effort response and resolution; we have no plan to work on this at the moment)

Comments


MasterSkepticista commented Jun 25, 2024

Initialization of the q/k/v projections is not forward-backward normalized in linen.MultiHeadDotProductAttention. This implementation does not face optimization issues when a pre-LN variant of the transformer is used, but it faces convergence issues in the vanilla post-LN variant from "Attention is All You Need".

System information

  • Flax version: flax==0.8.4

Problem you have encountered:

Many papers from the past four years still use the vanilla post-LN transformer; one such example is facebookresearch/detr. The input to the self-attention block of the first decoder layer is all zeros, as shown below:

```python
# src: https://github.com/facebookresearch/detr/blob/29901c51d7fe8712168b8d0d64351170bc0f83e0/models/transformer.py#L55
...
tgt = torch.zeros_like(query_embed)
memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
                  pos=pos_embed, query_pos=query_embed)
...
```

With the default linen.MultiHeadDotProductAttention, the model does not converge at all for the first 10 epochs. The cause is the initialization and the resulting gradient-norm behavior of the q/k/v matrices, particularly when the inputs to the MHDPA block are all zeros. I have a diff here with which the model converges from the very first epoch.
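
To make the setup concrete in Flax terms, here is a minimal sketch of the situation described above (the shapes are my own illustrative assumptions, not DETR's exact configuration): the first decoder self-attention operates on an all-zero query stream while the projections carry the default kernel initialization.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

# Illustrative shapes only; DETR's real config differs.
embed_dim, num_queries, num_heads = 256, 100, 8

attn = nn.MultiHeadDotProductAttention(num_heads=num_heads, qkv_features=embed_dim)

tgt = jnp.zeros((1, num_queries, embed_dim))     # mirrors `tgt = torch.zeros_like(query_embed)`
params = attn.init(jax.random.PRNGKey(0), tgt)   # default kernel_init is lecun_normal
out = attn.apply(params, tgt)                    # self-attention over all-zero inputs
```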

Please note: I have isolated the problem to the kernel initializer alone, which is what this proposal is about. My custom MHDPA implementation has no bearing on this proposal.

This convergence issue does not exist in PyTorch, because the q/k/v projections are:

  1. initialized with xavier_uniform, and
  2. given appropriate fan_in and fan_out values (see the comparison sketch after this list).
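
For comparison, a small sketch of the resulting scales (to my understanding, linen's DenseGeneral flattens each q/k/v kernel to roughly an (embed_dim, embed_dim) matrix before calling the initializer; the shapes here are illustrative):

```python
import jax

embed_dim = 256
key = jax.random.PRNGKey(0)

# Flax default: each of q/k/v is its own projection, initialized with lecun_normal
# -> std = sqrt(1 / fan_in) = sqrt(1 / embed_dim) ~= 0.0625
flax_like = jax.nn.initializers.lecun_normal()(key, (embed_dim, embed_dim))

# PyTorch: xavier_uniform_ over the fused (3 * embed_dim, embed_dim) in_proj_weight
# -> limit = sqrt(6 / (fan_in + fan_out)) = sqrt(6 / (4 * embed_dim)),
#    i.e. std = limit / sqrt(3) ~= 0.0442
torch_like = jax.nn.initializers.xavier_uniform()(key, (3 * embed_dim, embed_dim))

print(float(flax_like.std()), float(torch_like.std()))
```

In other words, the default Flax scheme starts the q/k/v kernels roughly sqrt(2) larger than the PyTorch scheme does.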

What you expected to happen:

I expected a model built with linen.MultiHeadDotProductAttention to converge from the first epoch, as is the case with facebookresearch/detr.

Fix:

  1. Switch to the xavier_uniform initializer for the projections (versus default_kernel_init, which is lecun_normal). This is also standard practice; see the T-Fixup paper.
  2. Use the correct fan_in value for the initializer (for identical q/k/v embedding dimensions, fan_in should be 3 * embed_dim); see ref1, ref2. A call-site sketch follows below.
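
A minimal sketch of what this could look like at the call site (qkv_xavier_uniform below is a hypothetical helper, not a Flax API; the linked diff changes the layer itself rather than the call site):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

def qkv_xavier_uniform(embed_dim):
  """Hypothetical initializer: Xavier-uniform computed as if q/k/v were one
  fused projection, i.e. fan_in = 3 * embed_dim and fan_out = embed_dim."""
  limit = (6.0 / (3 * embed_dim + embed_dim)) ** 0.5
  def init(key, shape, dtype=jnp.float32):
    return jax.random.uniform(key, shape, dtype, -limit, limit)
  return init

embed_dim, num_heads = 256, 8
attn = nn.MultiHeadDotProductAttention(
    num_heads=num_heads,
    qkv_features=embed_dim,
    # Overrides the default lecun_normal; note that, by default, the same
    # kernel_init is also used for the output projection.
    kernel_init=qkv_xavier_uniform(embed_dim),
)
```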

Here is an approximate diff of the change needed in Flax:

MasterSkepticista/detr@995f335

There are multiple ways to go about computing the correct fan_in, or one could use a single large dense layer as PyTorch does (sketched below); the options have different performance implications.
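
For completeness, a rough sketch of the fused-projection alternative (the module name and shapes are my own illustration, not how linen structures the layer): a single Dense produces q, k, and v in one matmul, so a plain xavier_uniform on its (embed_dim, 3 * embed_dim) kernel sees the fused fan values directly.

```python
import jax.numpy as jnp
import flax.linen as nn

class FusedQKVProjection(nn.Module):
  """Hypothetical fused q/k/v projection, similar in spirit to PyTorch's in_proj_weight."""
  embed_dim: int
  num_heads: int

  @nn.compact
  def __call__(self, x):
    # One (embed_dim, 3 * embed_dim) kernel, so xavier_uniform gets the fused fans for free.
    qkv = nn.Dense(3 * self.embed_dim,
                   kernel_init=nn.initializers.xavier_uniform())(x)
    q, k, v = jnp.split(qkv, 3, axis=-1)
    head_dim = self.embed_dim // self.num_heads
    # Split the heads: (..., embed_dim) -> (..., num_heads, head_dim).
    split_heads = lambda t: t.reshape(*t.shape[:-1], self.num_heads, head_dim)
    return split_heads(q), split_heads(k), split_heads(v)
```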

Let me know whether the fix in my code (barring the hardcoded values) is a reasonable approach.

Happy to do a PR.

@cgarciae
Collaborator

Hey, thanks for looking into this! I think your analysis is correct, and I would favor using the xavier_uniform initializer with the scaled fan_in to emulate a bigger matmul. However, I worked on #3893 internally for a bit and learned that changing the initialization logic is very hard: it breaks tons of tests that rely on the current defaults, so merging this is a challenge given finite resources.

In practice, we encourage users to fork our layers and adapt them to their needs, which is why this hasn't surfaced as a pressing issue. It is still worth thinking about, though, and I'll try to discuss it internally.

@cgarciae added the Priority: P2 - no schedule label on Jun 25, 2024
@MasterSkepticista
Author

Hi @cgarciae. Thanks for responding.

Finding this bug cost me three weeks of dry spells while replicating the PyTorch version. I think that, at a minimum, a mention of this initialization scheme in the MHDPA docstring would go a long way :)

On the potential test failures: as I understand it, this change would be localized to MultiHeadDotProductAttention, wouldn't it? What is the nature of those failures? Are they training baselines? Just curious.
