Fast Attention For Short Sequences
Nowadays, the spotlight is on large transformer decoder models with long context windows and billions (even trillions) of parameters.
But deep learning is not limited to LLMs; transformers have plenty of other applications.
And sometimes we need to run attention on short sequences: 32-128 tokens.
For some tasks, like classification or representation learning, we also want a Transformer Encoder with a single CLS token.
If we only need the embedding of the CLS token from the transformer, a good optimization is to skip computing self-attention for the non-CLS tokens at the last layer.
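A quick sketch of why that optimization is exact (illustrative numpy, not the post's code): each output row of attention depends only on its own query row, so computing attention for just the CLS query reproduces the full-attention CLS output.

```python
# Sketch: attention for only the CLS query equals the CLS row of full attention.
# All names here are illustrative.
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(q, k, v, scale):
    # q: [s_q, d], k/v: [s_kv, d]
    return softmax(q @ k.T * scale) @ v

rng = np.random.default_rng(0)
s, d = 32, 64
q = rng.standard_normal((s, d))
k = rng.standard_normal((s, d))
v = rng.standard_normal((s, d))
scale = d ** -0.5

full = attention(q, k, v, scale)          # all 32 query rows
cls_only = attention(q[:1], k, v, scale)  # only the CLS query row (token 0)

assert np.allclose(full[:1], cls_only)    # identical CLS embedding
```

At the last layer this shrinks the query side of attention from seq_len rows to one.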
First, let's benchmark the different PyTorch SDPA backends (MemEfficient, FlashAttention2, cuDNN)
on a full self-attention setup.
All times below are in ms, measured end-to-end in torch.
Setup: H100, torch 2.11.0+cu128, bf16, non-causal, no dropout, batch=16000, heads=8, head_dim=64
| seq_len | FlashAttention2 fwd | FlashAttention2 bwd | MemEff fwd | MemEff bwd | cuDNN fwd | cuDNN bwd |
|---|---|---|---|---|---|---|
| 32 | 2.83 | 14.11 | 1.42 | 6.02 | 2.15 | 5.20 |
| 64 | 2.93 | 16.39 | 1.60 | 9.56 | 2.16 | 7.98 |
| 128 | 3.34 | 20.67 | 5.02 | 23.04 | 2.73 | 13.62 |
And with a single query token (q=1):

| kv_len | FlashAttention2 fwd | FlashAttention2 bwd | MemEff fwd | MemEff bwd | cuDNN fwd | cuDNN bwd |
|---|---|---|---|---|---|---|
| 32 | 2.76 | 11.76 | 1.41 | 4.42 | 0.70 | 3.13 |
| 64 | 2.80 | 12.36 | 1.48 | 5.96 | 1.00 | 3.54 |
| 128 | 2.81 | 12.85 | 2.46 | 11.82 | 1.52 | 4.35 |
Interestingly, the more advanced algorithm (FlashAttention2) is slower. But why? What is the difference between MemEfficient and FlashAttention2? First, let's examine which kernels actually run at seq=32.
MemEfficient:
```
fmha_cutlassF_bf16_aligned_64x64_rf_sm80(PyTorchMemEffAttention::AttentionKernel<cutlass::bfloat16_t, cutlass::arch::Sm80, true, 64, 64, 64, true, true>::Params)
fmha_cutlassB_bf16_aligned_64x64_k64_sm80(PyTorchMemEffAttention::AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, true, 64, 64, 64, false>::Params)
```
FlashAttention2:
```
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kernel_traits<64, 128, 128, 4, false, false, cutlass::bfloat16_t, Flash_kernel_traits<64, 128, 128, 4, cutlass::bfloat16_t> >, false, false, false, false, false, true, false, false>(pytorch_flash::Flash_fwd_params)
void pytorch_flash::flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Flash_bwd_kernel_traits<64, 128, 128, 8, 4, 4, 4, false, false, cutlass::bfloat16_t, Flash_kernel_traits<64, 128, 128, 8, cutlass::bfloat16_t> >, false, false, false, false, false, true, false>(pytorch_flash::Flash_bwd_params)
```
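For completeness, kernel listings like these can be obtained with the PyTorch profiler, roughly like the sketch below (CPU activity here so it runs anywhere; on a GPU run, `ProfilerActivity.CUDA` shows the `fmha_`/`flash_` kernel names above):

```python
# Sketch: listing the kernels/ops launched by sdpa with the PyTorch profiler.
import torch
import torch.nn.functional as F
from torch.profiler import profile, ProfilerActivity

q = k = v = torch.randn(2, 8, 32, 64)
with profile(activities=[ProfilerActivity.CPU]) as prof:
    F.scaled_dot_product_attention(q, k, v)

names = [e.key for e in prof.key_averages()]
print(names)
```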
Well, that explains why going from seq=32 to seq=64 increased fwd/bwd time less than going to seq=128. MemEff computes a local QK^T tile of 64 query rows by 64 key rows, while FlashAttention2 uses 128 query rows by 128 key rows, so at seq=32 the tile is heavily underutilized.
If seq=32 were already using the tile efficiently, we would expect kernel time to grow much more from 32 -> 64. It does not. What actually grows sharply is the memory-side throughput, reported as Memory % below.
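A back-of-the-envelope view of that (my arithmetic, derived from the 128-row tile size in the kernel name above): seq=32 and seq=64 both launch the same single 128-row tile, just with different fractions of useful rows, which is why their kernel times are so close.

```python
# Fraction of FlashAttention2's 128-row query tile that holds real data.
BLOCK_M = 128  # query-tile rows, from the Flash_fwd_kernel_traits<64, 128, ...> name
utilization = {seq: seq / BLOCK_M for seq in (32, 64, 128)}
print(utilization)  # {32: 0.25, 64: 0.5, 128: 1.0}
```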
Now let's profile only the main attention kernel with ncu, rather than the full end-to-end time (sdpa also launches auxiliary kernels besides the main attention kernel).
FlashAttention2 main attention kernel

| seq | fwd time (ms) | fwd SM % | fwd Memory % | bwd time (ms) | bwd SM % | bwd Memory % |
|---|---|---|---|---|---|---|
| 32 | 3.42 | 53.15 | 39.51 | 10.11 | 33.91 | 50.42 |
| 64 | 3.51 | 52.25 | 38.36 | 10.67 | 32.62 | 47.03 |
| 128 | 3.61 | 40.74 | 70.64 | 10.96 | 25.41 | 57.77 |
MemEff main attention kernel

| seq | fwd time (ms) | fwd SM % | fwd Memory % | bwd time (ms) | bwd SM % | bwd Memory % |
|---|---|---|---|---|---|---|
| 32 | 1.67 | 69.98 | 49.18 | 3.84 | 50.39 | 58.47 |
| 64 | 1.81 | 66.28 | 69.17 | 4.26 | 45.48 | 61.36 |
| 128 | 6.05 | 67.41 | 49.68 | 13.88 | 45.49 | 59.16 |
So seq=32 is underutilized in both kernels, and much more so in FlashAttention2. The useful work grows fast from 32 -> 64, yet the attention kernel barely gets slower.
At seq=64, the MemEff forward is almost a perfect single 64x64 tile problem, so memory throughput gets much higher. At seq=128, the same kernel has to iterate over multiple K/V tiles and merge online-softmax state, which adds more on-chip work and lowers the reported Memory %.
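For reference, merging two online-softmax partial states $(m_1, \ell_1, o_1)$ and $(m_2, \ell_2, o_2)$ (running max, sum of exponentials, unnormalized output) looks roughly like:

$$
m = \max(m_1, m_2), \qquad
\ell = e^{m_1 - m}\,\ell_1 + e^{m_2 - m}\,\ell_2, \qquad
o = e^{m_1 - m}\,o_1 + e^{m_2 - m}\,o_2
$$

This rescaling per K/V tile is exactly the on-chip work a single-tile kernel never has to do.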
Looking at these results, can we write a much simpler Triton kernel that outperforms MemEfficient/FlashAttention2 on short sequences?
For seq=32 or seq=64, the whole attention problem is small enough that one kernel instance can handle the entire sequence in a single tile. In that regime we don't need the extra complexity, but such a kernel is only suitable for short sequences.
Here is the forward pass of our new Triton kernel (the signature below is reconstructed from the parameter names used in the body; see the repo for the exact version):

```python
@triton.jit
def attn_fwd_kernel(
    qkv_ptr, o_ptr, lse_ptr, scale,
    stride_qkv_b, stride_qkv_c, stride_qkv_h, stride_qkv_s, stride_qkv_d,
    stride_o_b, stride_o_h, stride_o_s, stride_o_d,
    stride_lse_b, stride_lse_h,
    seq_len, head_dim,
    block_s: tl.constexpr, block_d: tl.constexpr,
):
    # One program per (batch, head): the whole sequence fits in a single tile.
    pid_b = tl.program_id(0).to(tl.int64)
    pid_h = tl.program_id(1).to(tl.int64)
    offs_s = tl.arange(0, block_s)
    row_mask = offs_s < seq_len  # real tokens vs padding of the block_s tile

    # Q, K, V are packed in one tensor; stride_qkv_c selects the component.
    qkv_base = qkv_ptr + pid_b * stride_qkv_b + pid_h * stride_qkv_h
    q_blk = tl.make_block_ptr(
        qkv_base + 0 * stride_qkv_c,
        shape=[seq_len, head_dim],
        strides=[stride_qkv_s, stride_qkv_d],
        offsets=[0, 0],
        block_shape=[block_s, block_d],
        order=(1, 0),
    )
    k_blk = tl.make_block_ptr(
        qkv_base + 1 * stride_qkv_c,
        shape=[seq_len, head_dim],
        strides=[stride_qkv_s, stride_qkv_d],
        offsets=[0, 0],
        block_shape=[block_s, block_d],
        order=(1, 0),
    )
    v_blk = tl.make_block_ptr(
        qkv_base + 2 * stride_qkv_c,
        shape=[seq_len, head_dim],
        strides=[stride_qkv_s, stride_qkv_d],
        offsets=[0, 0],
        block_shape=[block_s, block_d],
        order=(1, 0),
    )
    q = tl.load(q_blk, boundary_check=(0, 1), padding_option="zero")
    k = tl.load(k_blk, boundary_check=(0, 1), padding_option="zero")
    v = tl.load(v_blk, boundary_check=(0, 1), padding_option="zero")

    # Full [block_s, block_s] score matrix in one dot; no K/V loop needed.
    scores = tl.dot(q, tl.trans(k), out_dtype=tl.float32) * scale
    scores = tl.where(row_mask[None, :], scores, float("-inf"))  # mask padded keys

    # Plain (non-online) softmax over the whole row at once.
    scores_max = tl.max(scores, axis=1)
    probs = tl.exp(scores - scores_max[:, None])
    probs = tl.where(row_mask[:, None], probs, 0.0)  # zero padded query rows
    probs_sum = tl.sum(probs, axis=1)
    probs_sum_safe = tl.where(row_mask, probs_sum, 1.0)  # avoid log(0) / div by 0
    lse = tl.where(row_mask, tl.log(probs_sum_safe) + scores_max, 0.0)
    probs = probs / probs_sum_safe[:, None]

    out = tl.dot(probs.to(q.dtype), v, out_dtype=tl.float32)
    o_blk = tl.make_block_ptr(
        o_ptr + pid_b * stride_o_b + pid_h * stride_o_h,
        shape=[seq_len, head_dim],
        strides=[stride_o_s, stride_o_d],
        offsets=[0, 0],
        block_shape=[block_s, block_d],
        order=(1, 0),
    )
    tl.store(o_blk, out.to(q.dtype), boundary_check=(0, 1))
    # lse is saved so the backward can recompute probs without redoing softmax.
    tl.store(lse_ptr + pid_b * stride_lse_b + pid_h * stride_lse_h + offs_s, lse, mask=row_mask)
```
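For testing, the kernel's math can be mirrored in a few lines of numpy (function and variable names here are mine, not the repo's): the forward is plain softmax attention plus the per-row logsumexp that the backward will reuse.

```python
import numpy as np

def attn_fwd_reference(q, k, v, scale):
    # q, k, v: [seq, head_dim]; mirrors the Triton kernel's math in fp32
    scores = (q @ k.T) * scale
    m = scores.max(axis=1, keepdims=True)    # scores_max
    p = np.exp(scores - m)                   # unnormalized probs
    s = p.sum(axis=1, keepdims=True)         # probs_sum
    out = (p / s) @ v
    lse = np.log(s[:, 0]) + m[:, 0]          # saved for the backward pass
    return out, lse

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((32, 64)) for _ in range(3))
out, lse = attn_fwd_reference(q, k, v, scale=64 ** -0.5)

# The softmax can be recomputed from lse alone: p_ij = exp(s_ij - lse_i)
p = np.exp((q @ k.T) * 64 ** -0.5 - lse[:, None])
assert np.allclose(p.sum(axis=1), 1.0)
```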
The backward has the same idea: recompute the attention probabilities from Q, K, and the saved lse, then calculate all gradients for the whole sequence in one go.
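Concretely, with $S = \mathrm{scale}\cdot QK^\top$ and the probabilities recomputed row-wise as $P_{ij} = e^{S_{ij} - \mathrm{lse}_i}$, the standard attention gradients are:

$$
dV = P^\top\, dO, \qquad dP = dO\, V^\top, \qquad D_i = \sum_j dO_{ij}\, O_{ij},
$$
$$
dS = P \odot (dP - D), \qquad dQ = \mathrm{scale}\cdot dS\, K, \qquad dK = \mathrm{scale}\cdot dS^\top Q.
$$

All of these fit in one tile for short sequences, just like the forward.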
There are also separate fwd/bwd kernels for the case where len(q)=1, e.g. just a CLS token; they work basically the same way.
Let's benchmark our new kernels.
In the "Triton vs best" columns, values are the best non-Triton time divided by the Triton time; >1x favors Triton, <1x favors the best non-Triton backend.
| seq_len | head_dim | Triton fwd | Triton bwd | FlashAttention2 fwd | FlashAttention2 bwd | MemEff fwd | MemEff bwd | cuDNN fwd | cuDNN bwd | vs best fwd | vs best bwd | vs best total |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 32 | 32 | 0.56 | 1.01 | 2.33 | 8.48 | 1.35 | 4.28 | 2.17 | 4.11 | 2.41x | 4.07x | 3.59x |
| 32 | 64 | 0.77 | 1.31 | 2.83 | 14.11 | 1.42 | 6.02 | 2.15 | 5.20 | 1.84x | 3.97x | 3.53x |
| 32 | 128 | 1.43 | 2.49 | 2.81 | 15.47 | 3.38 | 22.40 | 2.75 | 8.51 | 1.92x | 3.42x | 2.87x |
| 32 | 256 (bs=8k) | 1.42 | 2.62 | 1.55 | 13.25 | 1.44 | 55.25 | 2.58 | 9.67 | 1.01x | 3.69x | 3.03x |
| 64 | 32 | 1.07 | 1.96 | 2.54 | 9.86 | 1.45 | 5.84 | 2.15 | 5.47 | 1.36x | 2.79x | 2.41x |
| 64 | 64 | 1.49 | 2.63 | 2.93 | 16.39 | 1.60 | 9.56 | 2.16 | 7.98 | 1.07x | 3.03x | 2.46x |
| 64 | 128 | 2.79 | 5.50 | 3.28 | 20.31 | 3.68 | 31.18 | 2.79 | 13.67 | 1.00x | 2.49x | 1.99x |
| 64 | 256 (bs=8k) | 2.97 | 7.75 | 2.77 | 18.54 | 2.96 | 64.36 | 2.91 | 14.51 | 0.93x | 1.87x | 1.62x |
| 128 | 32 | 2.89 | 6.16 | 2.50 | 12.24 | 4.63 | 15.28 | 1.97 | 8.63 | 0.68x | 1.40x | 1.17x |
| 128 | 64 | 3.94 | 10.26 | 3.34 | 20.67 | 5.02 | 23.04 | 2.73 | 13.62 | 0.69x | 1.33x | 1.15x |
| 128 | 128 | 6.85 | 24.28 | 6.09 | 37.20 | 8.22 | 47.53 | 5.41 | 24.80 | 0.79x | 1.02x | 0.97x |
| 128 | 256 (bs=8k) | 9.40 | 331.75 | 5.61 | 39.50 | 6.58 | 95.48 | 5.40 | 26.26 | 0.57x | 0.08x | 0.09x |
q=1

| kv_len | head_dim | Triton fwd | Triton bwd | FlashAttention2 fwd | FlashAttention2 bwd | MemEff fwd | MemEff bwd | cuDNN fwd | cuDNN bwd | vs best fwd | vs best bwd | vs best total |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 32 | 32 | 0.34 | 0.58 | 2.28 | 7.43 | 1.36 | 3.52 | 0.69 | 2.92 | 2.03x | 5.03x | 3.92x |
| 32 | 64 | 0.45 | 0.77 | 2.76 | 11.76 | 1.41 | 4.42 | 0.70 | 3.13 | 1.56x | 4.06x | 3.14x |
| 32 | 128 | 0.80 | 1.47 | 2.68 | 11.69 | 3.34 | 18.94 | 1.36 | 4.64 | 1.70x | 3.16x | 2.64x |
| 32 | 256 (bs=8k) | 0.79 | 1.78 | 1.28 | 9.05 | 1.26 | 49.56 | 1.57 | 5.60 | 1.59x | 3.15x | 2.79x |
| 64 | 32 | 0.57 | 1.16 | 2.30 | 7.58 | 1.41 | 4.32 | 0.99 | 2.92 | 1.74x | 2.52x | 2.26x |
| 64 | 64 | 0.80 | 1.54 | 2.80 | 12.36 | 1.48 | 5.96 | 1.00 | 3.54 | 1.25x | 2.30x | 1.94x |
| 64 | 128 | 1.46 | 2.88 | 2.70 | 12.22 | 3.59 | 22.37 | 1.77 | 5.47 | 1.21x | 1.90x | 1.67x |
| 64 | 256 (bs=8k) | 1.45 | 3.82 | 1.43 | 9.87 | 1.44 | 55.30 | 2.46 | 5.29 | 0.99x | 1.38x | 1.47x |
| 128 | 32 | 1.06 | 2.45 | 2.35 | 8.16 | 2.31 | 8.13 | 1.41 | 3.54 | 1.33x | 1.44x | 1.41x |
| 128 | 64 | 1.47 | 3.06 | 2.81 | 12.85 | 2.46 | 11.82 | 1.52 | 4.35 | 1.03x | 1.42x | 1.30x |
| 128 | 128 | 2.80 | 6.21 | 4.31 | 13.61 | 3.97 | 30.41 | 2.80 | 6.69 | 1.00x | 1.08x | 1.05x |
| 128 | 256 (bs=8k) | 3.09 | 5.89 | 2.73 | 14.27 | 2.74 | 73.62 | 4.41 | 6.90 | 0.88x | 1.17x | 1.26x |
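To make the "Triton vs best" arithmetic concrete, here is the seq_len=32, head_dim=32 row recomputed (the numbers match if the fwd/bwd ratios take the best non-Triton time per column, while the total ratio takes the best single backend's fwd+bwd):

```python
# seq_len=32, head_dim=32 row from the table above
triton = {"fwd": 0.56, "bwd": 1.01}
others = {
    "FlashAttention2": {"fwd": 2.33, "bwd": 8.48},
    "MemEff": {"fwd": 1.35, "bwd": 4.28},
    "cuDNN": {"fwd": 2.17, "bwd": 4.11},
}

vs_fwd = min(b["fwd"] for b in others.values()) / triton["fwd"]
vs_bwd = min(b["bwd"] for b in others.values()) / triton["bwd"]
vs_total = min(b["fwd"] + b["bwd"] for b in others.values()) / (triton["fwd"] + triton["bwd"])

print(round(vs_fwd, 2), round(vs_bwd, 2), round(vs_total, 2))  # 2.41 4.07 3.59
```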
Well, it seems that our kernels do pretty well at small seq_len and head_dim.
At seq_len=128, head_dim=256, the current backward kernel gets close to the SM resource limits. Shared memory usage becomes very high, and many intermediate values have to go through local memory, which is private to each thread but backed by global memory. That makes the kernel extremely slow. The table below compares seq_len=64 and 128 at head_dim=256 and shows how shared memory usage, warp activity, and local-memory traffic change (measured using ncu).
| seq_len | bwd time (ms) | shared mem / block | warps active | local loads | local stores |
|---|---|---|---|---|---|
| 64 | 8.71 | 137.0 KB | 12.42% | 208.9M | 124.9M |
| 128 | 344.39 | 225.0 KB | 6.25% | 14.0B | 13.1B |
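A rough estimate (my arithmetic, assuming bf16 tiles are kept resident on-chip) shows why this configuration breaks down: each 128x256 bf16 tile is 64 KB, and the one-tile backward wants several of them at once (Q, K, V, dO, plus fp32 score/grad tiles), which collides with the H100's ~228 KB of shared memory per SM, so the compiler spills to local memory.

```python
# Rough shared-memory budget for the one-tile backward (my estimate, not ncu data)
seq_len, head_dim, bf16_bytes = 128, 256, 2
tile_kb = seq_len * head_dim * bf16_bytes / 1024
print(tile_kb)      # 64.0 KB per bf16 [128, 256] tile
print(4 * tile_kb)  # 256.0 KB just for Q, K, V, dO -- already over ~228 KB/SM
```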
To fix this, we would need to design a better kernel, but it would not be as simple as the current one.
In this blog post we showed that a simple Triton attention kernel can be faster than the default torch attention backends when seq_len and head_dim are sufficiently small.
Code: https://github.com/qwertyforce/fast_attn_short_seq/tree/main