RuntimeError: FlashAttention forward only supports head dimension at most 256

#30 · opened by badhon1512

Getting this error when using Gemma 4 32B with flash_attention_2:

RuntimeError: FlashAttention forward only supports head dimension at most 256

Seems like the model has head_dim > 256. Switching to attn_implementation="sdpa" works.
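
For reference, a minimal sketch of that workaround (the checkpoint id, dtype, and device map below are placeholders, not details from this thread):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "<gemma-checkpoint>"  # placeholder: whichever checkpoint you are loading

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",   # instead of "flash_attention_2"
    device_map="auto",
)
```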

Is FlashAttention not supported for this model?

Sounds like it isn't, to me.

I went through the torture of trying to get FA to work, so that you don’t have to:

FA4 (SM100/Blackwell-native): max head_dim 128. Dead on arrival for Gemma 4.

FA3 (Hopper-native): max head_dim 256. Covers sliding layers, not global (512). May underperform on B200 since it targets SM90.

FA2: max head_dim 256. Same story, sliding only.

SDPA (cuDNN backend): handles both 256 and 512. Works everywhere.
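
If you want to check before loading, here is a rough sketch that reads the head dimension from the config and falls back to SDPA when it exceeds the FlashAttention limit. The attribute names (`text_config`, `head_dim`) are assumptions and can vary between config versions, and this picks a single backend for the whole model rather than handling the sliding/global split per layer:

```python
from transformers import AutoConfig

model_id = "<gemma-checkpoint>"  # placeholder

config = AutoConfig.from_pretrained(model_id)
# Multimodal configs usually nest the language-model config under text_config.
text_config = getattr(config, "text_config", config)
head_dim = getattr(text_config, "head_dim", None)
if head_dim is None:
    head_dim = text_config.hidden_size // text_config.num_attention_heads

# FlashAttention-2/3 kernels cap head_dim at 256, so fall back to SDPA above that.
attn_impl = "flash_attention_2" if head_dim <= 256 else "sdpa"
print(f"head_dim={head_dim} -> attn_implementation={attn_impl}")
```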
