RuntimeError: FlashAttention forward only supports head dimension at most 256
#30
by badhon1512 - opened
Getting this error when using Gemma 4 32B with flash_attention_2:
RuntimeError: FlashAttention forward only supports head dimension at most 256
Seems like the model has head_dim > 256. Switching to attn_implementation="sdpa" works.
Is FlashAttention not supported for this model?
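For anyone hitting this, a minimal sketch of the SDPA workaround mentioned above. The model id here is an assumption (not a verified Hub checkpoint name), so substitute whatever you are actually loading; the key part is `attn_implementation="sdpa"`, which routes attention through `torch.nn.functional.scaled_dot_product_attention` and avoids the FlashAttention head_dim cap:

```python
# Hypothetical model id -- replace with the actual checkpoint you use.
MODEL_ID = "google/gemma-4-32b"  # assumption, not a verified Hub id

# kwargs for AutoModelForCausalLM.from_pretrained. "sdpa" selects
# PyTorch's scaled_dot_product_attention backend, which does not
# enforce the head_dim <= 256 limit that flash_attention_2 does.
load_kwargs = {
    "attn_implementation": "sdpa",
    "torch_dtype": "bfloat16",
}

# Actual load (commented out so the sketch runs without downloading 32B of weights):
# from transformers import AutoModelForCausalLM
# model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **load_kwargs)

print(load_kwargs["attn_implementation"])
```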
Sounds like it isn't, at least not out of the box.
I went through the torture of trying to get FA to work, so that you don’t have to:
FA4 (SM100/Blackwell-native): max head_dim 128. Dead on arrival for Gemma 4.
FA3 (Hopper-native): max head_dim 256. Covers sliding layers, not global (512). May underperform on B200 since it targets SM90.
FA2: max head_dim 256. Same story, sliding only.
SDPA (cuDNN backend): handles both 256 and 512. Works everywhere.
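The rundown above can be sketched as a small backend picker. The ceilings are taken straight from this thread (FA4 ≤ 128, FA3/FA2 ≤ 256, SDPA unrestricted); the function name and preference order are illustrative, not part of any library API:

```python
# Per-backend head_dim ceilings, as described in the thread.
HEAD_DIM_LIMITS = {
    "flash_attention_4": 128,           # SM100/Blackwell-native
    "flash_attention_3": 256,           # Hopper-native
    "flash_attention_2": 256,
    "sdpa": float("inf"),               # cuDNN backend handles 256 and 512
}

def pick_backend(head_dims,
                 preference=("flash_attention_3", "flash_attention_2", "sdpa")):
    """Return the first preferred backend whose ceiling covers every head_dim."""
    worst = max(head_dims)
    for backend in preference:
        if worst <= HEAD_DIM_LIMITS[backend]:
            return backend
    return "sdpa"  # always-safe fallback

# Sliding-window layers only (head_dim 256) vs. a stack that also has
# global layers at head_dim 512:
print(pick_backend([256]))       # -> flash_attention_3
print(pick_backend([256, 512]))  # -> sdpa
```

The takeaway matches the thread: as soon as any layer's head_dim exceeds 256, every FlashAttention variant is out and SDPA is the only option.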