Fast Attention For Short Sequences

Nowadays, the spotlight is on large transformer decoder models with long context windows and billions (even trillions) of parameters.

But machine learning/deep learning is not limited to LLMs; transformers have plenty of other applications too.

And sometimes we need to run attention on short sequences: 32-128 tokens.
For some tasks, like classification or representation learning, we also want a Transformer encoder with a single CLS token.

If we only need the embedding of the CLS token from the transformer, a good optimization is to skip computing self-attention for the non-CLS tokens in the last layer.
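As a sketch of that optimization (illustrative code, not the author's model; `last_layer_cls_attention` and `w_qkv` are hypothetical names), the last layer keeps only the CLS row of Q, shrinking the attention matmuls from (seq x seq) to (1 x seq):

```python
import torch
import torch.nn.functional as F

def last_layer_cls_attention(x, w_qkv, n_heads):
    """Attend only the CLS token (position 0) over the full sequence.

    x: (batch, seq, dim). Returns (batch, 1, dim) -- the CLS embedding.
    """
    b, s, d = x.shape
    qkv = x @ w_qkv                      # (b, s, 3*d)
    q, k, v = qkv.chunk(3, dim=-1)
    q = q[:, :1]                         # keep only the CLS query -> len(q) = 1
    hd = d // n_heads
    # reshape to (b, heads, seq, head_dim)
    q = q.view(b, 1, n_heads, hd).transpose(1, 2)
    k = k.view(b, s, n_heads, hd).transpose(1, 2)
    v = v.view(b, s, n_heads, hd).transpose(1, 2)
    o = F.scaled_dot_product_attention(q, k, v)   # (b, heads, 1, head_dim)
    return o.transpose(1, 2).reshape(b, 1, d)
```

Because attention rows are independent, the CLS output is identical to computing full self-attention and then taking row 0; we simply never pay for the other rows.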
First, let's benchmark the different PyTorch SDPA backends (MemEff, FlashAttention2, cuDNN) on a full self-attention setup.
All times below are in ms, measured end-to-end in torch.
Setup: H100, torch 2.11.0+cu128, bf16, non-causal, no dropout, batch=16000, heads=8, head_dim=64

seq_len   FlashAttention2   MemEff          cuDNN
          fwd     bwd       fwd     bwd     fwd     bwd
32        2.83    14.11     1.42    6.02    2.15    5.20
64        2.93    16.39     1.60    9.56    2.16    7.98
128       3.34    20.67     5.02    23.04   2.73    13.62
And the case where q is a single element (the CLS token):
kv_len    FlashAttention2   MemEff          cuDNN
          fwd     bwd       fwd     bwd     fwd     bwd
32        2.76    11.76     1.41    4.42    0.70    3.13
64        2.80    12.36     1.48    5.96    1.00    3.54
128       2.81    12.85     2.46    11.82   1.52    4.35

Interestingly, the more advanced algorithm (FlashAttention2) is slower. But why? What is the difference between MemEfficient and FlashAttention2? First, let's examine which kernels are actually being run at seq=32.

MemEfficient

fmha_cutlassF_bf16_aligned_64x64_rf_sm80(PyTorchMemEffAttention::AttentionKernel<cutlass::bfloat16_t, cutlass::arch::Sm80, true, 64, 64, 64, true, true>::Params)
fmha_cutlassB_bf16_aligned_64x64_k64_sm80(PyTorchMemEffAttention::AttentionBackwardKernel<cutlass::arch::Sm80, cutlass::bfloat16_t, true, false, true, 64, 64, 64, false>::Params)

FlashAttention 2

void pytorch_flash::flash_fwd_kernel<Flash_fwd_kernel_traits<64, 128, 128, 4, false, false, cutlass::bfloat16_t, Flash_kernel_traits<64, 128, 128, 4, cutlass::bfloat16_t> >, false, false, false, false, false, true, false, false>(pytorch_flash::Flash_fwd_params)  
void pytorch_flash::flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Flash_bwd_kernel_traits<64, 128, 128, 8, 4, 4, 4, false, false, cutlass::bfloat16_t, Flash_kernel_traits<64, 128, 128, 8, cutlass::bfloat16_t> >, false, false, false, false, false, true, false>(pytorch_flash::Flash_bwd_params)

Well, that explains why going from seq=32 to 64 increased fwd/bwd time less than going to 128: MemEff uses a local QK^T tile of 64 query rows by 64 key rows, while FlashAttention2 uses 128 query rows by 128 key rows. Both are underutilized at seq=32.
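The utilization gap is simple tile arithmetic (tile shapes read off the kernel template parameters above):

```python
# Fraction of a kernel's QK^T tile that holds real work, per (tile, seq_len).
# With seq padded up to the tile: useful = seq*seq, computed = tile_q*tile_k.
def tile_utilization(tile_q, tile_k, seq):
    return (min(seq, tile_q) * min(seq, tile_k)) / (tile_q * tile_k)

for seq in (32, 64, 128):
    print(seq,
          f"MemEff 64x64: {tile_utilization(64, 64, seq):.0%}",
          f"FA2 128x128: {tile_utilization(128, 128, seq):.0%}")
# seq=32: MemEff 25%, FA2 6%; seq=64: MemEff 100%, FA2 25%; seq=128: both 100%
```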

If seq=32 were already using the tile efficiently, we would expect kernel time to grow much more from 32 -> 64. It does not. What actually changes sharply is the memory-side throughput reported as Memory %. Next, let's profile only the main attention kernel with ncu, rather than the full end-to-end time (sdpa also launches kernels besides the main attention kernel).

FlashAttention2 main attention kernel

seq    fwd                        bwd
       time    SM %    Memory %   time    SM %    Memory %
32     3.42    53.15   39.51      10.11   33.91   50.42
64     3.51    52.25   38.36      10.67   32.62   47.03
128    3.61    40.74   70.64      10.96   25.41   57.77

MemEff main attention kernel

seq    fwd                        bwd
       time    SM %    Memory %   time    SM %    Memory %
32     1.67    69.98   49.18      3.84    50.39   58.47
64     1.81    66.28   69.17      4.26    45.48   61.36
128    6.05    67.41   49.68      13.88   45.49   59.16

So seq=32 underutilizes both kernels, and FlashAttention2 far more so. The useful work quadruples from 32 -> 64, yet the attention kernel barely gets slower.

At seq=64, MemEff's forward is almost a perfect one-tile 64x64 problem, so memory throughput gets much higher. At seq=128, the same kernel has to iterate over multiple K/V tiles and merge online-softmax state, which adds more on-chip work and lowers the reported Memory %. Looking at these results, is it possible to write a much simpler Triton kernel that outperforms MemEfficient/FlashAttention2 on short sequences?

For seq=32 or 64, the whole attention problem is small enough that a single kernel instance can handle the entire sequence directly. In that regime we do not need the extra complexity, though the resulting kernel will only be suitable for short sequences.
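For contrast, here is the complexity a multi-tile kernel cannot avoid: merging per-tile online-softmax state (running max, denominator, and a rescaled output accumulator), sketched in PyTorch:

```python
import torch

def online_softmax_attention(q, k, v, tile=64):
    """Process K/V in tiles, keeping running (max, sum, accumulator) per query row."""
    s, d = q.shape
    m = torch.full((s,), float("-inf"))   # running row max
    l = torch.zeros(s)                    # running softmax denominator
    acc = torch.zeros(s, d)               # unnormalized output accumulator
    for start in range(0, k.shape[0], tile):
        kt, vt = k[start:start + tile], v[start:start + tile]
        scores = q @ kt.T / d ** 0.5
        m_new = torch.maximum(m, scores.max(dim=1).values)
        correction = torch.exp(m - m_new)         # rescale old state to the new max
        p = torch.exp(scores - m_new[:, None])
        l = l * correction + p.sum(dim=1)
        acc = acc * correction[:, None] + p @ vt
        m = m_new
    return acc / l[:, None]
```

A single-tile kernel skips the correction factors entirely and does one plain softmax over the whole row.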

Here is the forward pass of our new Triton kernel:

    # one program instance per (batch, head); the whole sequence fits in one tile
    pid_b = tl.program_id(0).to(tl.int64)
    pid_h = tl.program_id(1).to(tl.int64)
    offs_s = tl.arange(0, block_s)
    row_mask = offs_s < seq_len  # hides padding when seq_len < block_s
    # q, k, v are packed along the `c` dimension of the qkv tensor (0=q, 1=k, 2=v)
    qkv_base = qkv_ptr + pid_b * stride_qkv_b + pid_h * stride_qkv_h

    q_blk = tl.make_block_ptr(
        qkv_base + 0 * stride_qkv_c,
        shape=[seq_len, head_dim],
        strides=[stride_qkv_s, stride_qkv_d],
        offsets=[0, 0],
        block_shape=[block_s, block_d],
        order=(1, 0),
    )
    k_blk = tl.make_block_ptr(
        qkv_base + 1 * stride_qkv_c,
        shape=[seq_len, head_dim],
        strides=[stride_qkv_s, stride_qkv_d],
        offsets=[0, 0],
        block_shape=[block_s, block_d],
        order=(1, 0),
    )
    v_blk = tl.make_block_ptr(
        qkv_base + 2 * stride_qkv_c,
        shape=[seq_len, head_dim],
        strides=[stride_qkv_s, stride_qkv_d],
        offsets=[0, 0],
        block_shape=[block_s, block_d],
        order=(1, 0),
    )

    q = tl.load(q_blk, boundary_check=(0, 1), padding_option="zero")
    k = tl.load(k_blk, boundary_check=(0, 1), padding_option="zero")
    v = tl.load(v_blk, boundary_check=(0, 1), padding_option="zero")

    # full (block_s x block_s) score matrix in a single dot; padded keys get -inf
    scores = tl.dot(q, tl.trans(k), out_dtype=tl.float32) * scale
    scores = tl.where(row_mask[None, :], scores, float("-inf"))
    scores_max = tl.max(scores, axis=1)

    # plain one-pass softmax: no online rescaling, since every key is resident
    probs = tl.exp(scores - scores_max[:, None])
    probs = tl.where(row_mask[:, None], probs, 0.0)
    probs_sum = tl.sum(probs, axis=1)
    probs_sum_safe = tl.where(row_mask, probs_sum, 1.0)  # avoid log(0) on padded rows
    lse = tl.where(row_mask, tl.log(probs_sum_safe) + scores_max, 0.0)
    probs = probs / probs_sum_safe[:, None]

    out = tl.dot(probs.to(q.dtype), v, out_dtype=tl.float32)
    o_blk = tl.make_block_ptr(
        o_ptr + pid_b * stride_o_b + pid_h * stride_o_h,
        shape=[seq_len, head_dim],
        strides=[stride_o_s, stride_o_d],
        offsets=[0, 0],
        block_shape=[block_s, block_d],
        order=(1, 0),
    )
    tl.store(o_blk, out.to(q.dtype), boundary_check=(0, 1))
    tl.store(lse_ptr + pid_b * stride_lse_b + pid_h * stride_lse_h + offs_s, lse, mask=row_mask)
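
A handy way to validate the kernel is a plain PyTorch reference that returns the same (out, lse) pair (a sketch assuming a packed (batch, heads, 3, seq, head_dim) QKV layout; the real layout is whatever the strides above describe):

```python
import torch

def attn_ref(qkv, scale):
    """Reference forward. qkv: (batch, heads, 3, seq, head_dim) -- assumed layout.
    Returns out (batch, heads, seq, head_dim) and lse (batch, heads, seq)."""
    q, k, v = qkv.unbind(dim=2)
    scores = (q @ k.transpose(-1, -2)).float() * scale
    lse = torch.logsumexp(scores, dim=-1)
    out = torch.softmax(scores, dim=-1).to(v.dtype) @ v
    return out, lse
```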

The backward follows the same idea: recompute the scores from the saved lse and process the whole sequence in one go.
There are also separate fwd/bwd kernels for the case where len(q) = 1 (i.e., just a CLS token); they work basically the same way.
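The recompute-from-lse idea behind the backward, written out in PyTorch for a single head (illustrative math, not the actual kernel):

```python
import torch

def attn_bwd_ref(q, k, v, out, lse, d_out, scale):
    """Recompute probs from the saved log-sum-exp, then form dq/dk/dv."""
    # probs rebuilt exactly: p_ij = exp(s_ij - lse_i), no stored attention matrix
    p = torch.exp((q @ k.T) * scale - lse[:, None])
    dv = p.T @ d_out
    dp = d_out @ v.T
    delta = (out * d_out).sum(dim=-1, keepdim=True)   # rowwise dot(out, d_out)
    ds = p * (dp - delta) * scale                     # softmax backward
    dq = ds @ k
    dk = ds.T @ q
    return dq, dk, dv
```

Because the probabilities are rebuilt exactly from lse, nothing besides (out, lse) has to be stored between forward and backward.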

Let's benchmark our new kernels. In the "Triton vs best" column, values are best non-Triton time / Triton time, so >1x favors Triton and <1x favors the best non-Triton backend.

seq_len  head_dim      Triton         FlashAttention2  MemEff         cuDNN          Triton vs best
                       fwd / bwd      fwd / bwd        fwd / bwd      fwd / bwd      fwd / bwd / total
32       32            0.56 / 1.01    2.33 / 8.48      1.35 / 4.28    2.17 / 4.11    2.41x / 4.07x / 3.59x
32       64            0.77 / 1.31    2.83 / 14.11     1.42 / 6.02    2.15 / 5.20    1.84x / 3.97x / 3.53x
32       128           1.43 / 2.49    2.81 / 15.47     3.38 / 22.40   2.75 / 8.51    1.92x / 3.42x / 2.87x
32       256 (bs=8k)   1.42 / 2.62    1.55 / 13.25     1.44 / 55.25   2.58 / 9.67    1.01x / 3.69x / 3.03x
64       32            1.07 / 1.96    2.54 / 9.86      1.45 / 5.84    2.15 / 5.47    1.36x / 2.79x / 2.41x
64       64            1.49 / 2.63    2.93 / 16.39     1.60 / 9.56    2.16 / 7.98    1.07x / 3.03x / 2.46x
64       128           2.79 / 5.50    3.28 / 20.31     3.68 / 31.18   2.79 / 13.67   1.00x / 2.49x / 1.99x
64       256 (bs=8k)   2.97 / 7.75    2.77 / 18.54     2.96 / 64.36   2.91 / 14.51   0.93x / 1.87x / 1.62x
128      32            2.89 / 6.16    2.50 / 12.24     4.63 / 15.28   1.97 / 8.63    0.68x / 1.40x / 1.17x
128      64            3.94 / 10.26   3.34 / 20.67     5.02 / 23.04   2.73 / 13.62   0.69x / 1.33x / 1.15x
128      128           6.85 / 24.28   6.09 / 37.20     8.22 / 47.53   5.41 / 24.80   0.79x / 1.02x / 0.97x
128      256 (bs=8k)   9.40 / 331.75  5.61 / 39.50     6.58 / 95.48   5.40 / 26.26   0.57x / 0.08x / 0.09x

q=1

kv_len   head_dim      Triton         FlashAttention2  MemEff         cuDNN          Triton vs best
                       fwd / bwd      fwd / bwd        fwd / bwd      fwd / bwd      fwd / bwd / total
32       32            0.34 / 0.58    2.28 / 7.43      1.36 / 3.52    0.69 / 2.92    2.03x / 5.03x / 3.92x
32       64            0.45 / 0.77    2.76 / 11.76     1.41 / 4.42    0.70 / 3.13    1.56x / 4.06x / 3.14x
32       128           0.80 / 1.47    2.68 / 11.69     3.34 / 18.94   1.36 / 4.64    1.70x / 3.16x / 2.64x
32       256 (bs=8k)   0.79 / 1.78    1.28 / 9.05      1.26 / 49.56   1.57 / 5.60    1.59x / 3.15x / 2.79x
64       32            0.57 / 1.16    2.30 / 7.58      1.41 / 4.32    0.99 / 2.92    1.74x / 2.52x / 2.26x
64       64            0.80 / 1.54    2.80 / 12.36     1.48 / 5.96    1.00 / 3.54    1.25x / 2.30x / 1.94x
64       128           1.46 / 2.88    2.70 / 12.22     3.59 / 22.37   1.77 / 5.47    1.21x / 1.90x / 1.67x
64       256 (bs=8k)   1.45 / 3.82    1.43 / 9.87      1.44 / 55.30   2.46 / 5.29    0.99x / 1.38x / 1.47x
128      32            1.06 / 2.45    2.35 / 8.16      2.31 / 8.13    1.41 / 3.54    1.33x / 1.44x / 1.41x
128      64            1.47 / 3.06    2.81 / 12.85     2.46 / 11.82   1.52 / 4.35    1.03x / 1.42x / 1.30x
128      128           2.80 / 6.21    4.31 / 13.61     3.97 / 30.41   2.80 / 6.69    1.00x / 1.08x / 1.05x
128      256 (bs=8k)   3.09 / 5.89    2.73 / 14.27     2.74 / 73.62   4.41 / 6.90    0.88x / 1.17x / 1.26x

Well, it seems our kernels do pretty well at small sequence lengths and head dims.

At seq_len=128, head_dim=256, the current backward kernel runs into SM resource limits: shared memory usage becomes very high, and many intermediates spill to local memory (which is private to each thread but backed by global memory), making the kernel extremely slow. The table below compares seq_len=64 and 128 at head_dim=256 and shows how shared memory usage, warp occupancy, and local-memory traffic change (measured with ncu).

seq_len   bwd time (ms)   shared mem / block   warps active   local loads   local stores
64        8.71            137.0 KB             12.42%         208.9M        124.9M
128       344.39          225.0 KB             6.25%          14.0B         13.1B
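Back-of-the-envelope tile sizes show why: at head_dim=256 a single bf16 tile is already huge (rough arithmetic; the exact resident buffers depend on what the Triton compiler keeps in shared memory):

```python
def tile_kb(rows, cols, bytes_per_el=2):
    """Size of one bf16 tile in KB."""
    return rows * cols * bytes_per_el / 1024

# The one-shot backward wants at least q, k, v resident at once (plus more).
for block_s in (64, 128):
    per = tile_kb(block_s, 256)
    print(block_s, f"{per:.0f} KB per tile, {3 * per:.0f} KB just for q/k/v")
# 64 -> 32 KB per tile (96 KB for q/k/v); 128 -> 64 KB per tile (192 KB for q/k/v),
# leaving almost nothing of an H100 SM's 228 KB of shared memory for everything else.
```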

To fix this, we would have to design a better kernel, and it would not be as simple as the current one.
In this blog post we showed that a simple Triton attention kernel can be faster than the default torch attention backends when seq_len and head_dim are sufficiently small.
Code is available at https://github.com/qwertyforce/fast_attn_short_seq/tree/main