Keyboard shortcuts

Press ← or β†’ to navigate between chapters

Press S or / to search in the book

Press ? to show this help

Press Esc to hide this help

Week 2 Day 4-5: Flash Attention 2

In this chapter, we will implement Flash Attention 2 for the Week 2 Qwen2 serving pipeline. The goal is to replace the regular attention path with a tiled implementation to reduce memory bandwidth and increase throughput, especially for long contexts.

πŸ“š Readings

Why Flash Attention?

The key idea from the FlashAttention papers is that attention is often IO-bound, not FLOP-bound.

In the standard implementation, we compute:

  1. S = QK^T
  2. P = softmax(S + mask)
  3. O = PV

This path materializes large L x S tensors (S and often P) in global memory. For long contexts, repeatedly writing and reading these tensors dominates runtime.

For example, if L = S = 4096:

One L x S matrix: 4096 x 4096 = 16,777,216 elements
float32 storage: ~64 MB per matrix per head
Scores + probabilities: ~128 MB temporary memory per head

So even before counting Q/K/V and output tensors, memory traffic is already huge.

IO-Aware Exact Attention

FlashAttention avoids this bottleneck by tiling Q/K/V into on-chip memory (cache / shared memory), and combining each tile with online softmax updates. Instead of storing the full attention matrix, it keeps only per-row running statistics (m, l) and partial output (o).

This gives three practical benefits:

  • Exactness: same result as standard softmax attention (not an approximation).
  • Lower memory: activation memory scales linearly with sequence length instead of quadratically.
  • Higher throughput: fewer high-bandwidth-memory accesses, which is usually the real bottleneck.

Online Softmax Recap

For one query row, split keys/values into tiles j = 1..T:

At the end:

This is the core numerical trick used by both the CPU and GPU kernels in this chapter, and the rest of the implementation is mostly about mapping this update rule to CPU loops and Metal threadgroups.

Task 1: Implement flash_attention Wrapper

src/tiny_llm/attention.py

Implement flash_attention(query, key, value, scale=None, mask=None) so it matches the extension API in tiny_llm_ext.

Follow the same shape convention as Week 1 and Week 2 attention:

query: B..., H_q, L, E
key:   B..., H,   S, E
value: B..., H,   S, E
mask:  B..., H_q, L, S
out:   B..., H_q, L, E

The wrapper should compute factor using mx.rsqrt when scale is None, flatten batch and head dimensions before calling into C++, and reshape the output back to the original layout. Make sure query, key, and value are contiguous before calling the extension. For mask, always broadcast to B..., H_q, L, S, reshape to (N, L, S), and cast to float32 so that CPU and GPU kernels receive exactly the same dtype.

Task 2: Implement flash_attention (CPU version)

src/extensions/src/tiny_llm_ext.h
src/extensions/bindings.cpp
src/extensions/src/flash_attention.cpp
src/extensions/CMakeLists.txt

In this task, add the new MLX primitive and its CPU implementation. The structure is the same as the quantized matmul chapter: declare the primitive in tiny_llm_ext.h, expose it in bindings.cpp, and register flash_attention.cpp in CMakeLists.txt.

Before creating the lazy output array, validate all shape and dtype constraints in C++: inputs should be 3D float32 tensors, num_heads must be divisible by num_kv_heads, and head mapping between Q and KV batches must be consistent.

Then implement FlashAttention::eval_cpu(...) with tiled online softmax. Use Br = 32 and Bc = 32, and the rationale for this choice will be explained in the GPU section. Iterate over (n, i, j) tiles, map query heads to KV heads with q_kv_heads_ratio = num_heads / num_kv_heads, and accumulate in float32. Mask values should be applied in each tile before updating m_i and l_i.

When mask == "causal", treat it as a block-level optimization opportunity: if a tile is fully invalid, skip that tile entirely; if a tile is fully valid, skip mask read/add for that tile and continue with matmul + online softmax. Also note that L and S are not always equal in causal attention, so do not hardcode logic that assumes L == S.

You can test your implementation by running:

pdm run build-ext
pdm run test --week 2 --day 4 -- -k task_2

Task 3: Implement flash_attention (GPU version)

src/extensions/src/flash_attention.metal
src/extensions/src/flash_attention.cpp
src/extensions/CMakeLists.txt

Now implement the GPU path for the same algorithm.

GPU Parallelization Strategy

The key to an efficient GPU implementation is understanding how to map the tiled algorithm to Metal’s execution model.

Why Br = 32 and Bc = 32?

The tile sizes are not arbitraryβ€”they are constrained by Apple GPU hardware:

ConstraintSourceValue
SIMD widthApple GPU fixed32
Max threads per threadgroupHardware limit1024
Bc= SIMD width (for efficient simd_sum/simd_max)32
Br= 1024 / 3232
Threadgroup memory32KB limitFits q_local[32][128] + o_i[32][128]

With Br=32 and Bc=32, we get 32Γ—32 = 1024 threads per threadgroup, which exactly fills the hardware limit.

Grid and Threadgroup Layout

Grid (num_threadgroups):
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ TG(0, 0)              β”‚ TG(1, 0)              β”‚ TG(2, 0)              β”‚
β”‚ head=0, qtile=0       β”‚ head=1, qtile=0       β”‚ head=2, qtile=0       β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ TG(0, 1)              β”‚ TG(1, 1)              β”‚ TG(2, 1)              β”‚
β”‚ head=0, qtile=1       β”‚ head=1, qtile=1       β”‚ head=2, qtile=1       β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ ...                   β”‚ ...                   β”‚ ...                   β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
     X: N (heads)         Y: Tr (query blocks)

Each threadgroup is responsible for one (head, Q-tile) output block.

Thread Mapping Within a Threadgroup

Each threadgroup handles one Q block (size BrΓ—E) for one head. Inside the threadgroup:

Threadgroup = 32 SIMD groups Γ— 32 threads/group = 1024 threads

β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ SIMD group 0  β†’ Q[0, :]  (handles row 0)       β”‚ ← 32 threads
β”‚ SIMD group 1  β†’ Q[1, :]  (handles row 1)       β”‚ ← 32 threads
β”‚ SIMD group 2  β†’ Q[2, :]  (handles row 2)       β”‚ ← 32 threads
β”‚ ...                                             β”‚
β”‚ SIMD group 31 β†’ Q[31, :] (handles row 31)      β”‚ ← 32 threads
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

Inside that single threadgroup, the kernel runs a serial loop over all K/V tiles j = 0..Tc-1.

Computing S = Q @ K^T

Each thread computes one element of the 32Γ—32 score matrix. Here’s how the matrix multiplication maps to threads:

Q block [Br=32, E=128]              K^T [E=128, Bc=32]
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”           β”Œβ”€β”€β”€β”¬β”€β”€β”€β”¬β”€β”€β”€β”¬β”€...─┬───┐
β”‚ Q[0,:]  (128 elements)β”‚           β”‚   β”‚   β”‚   β”‚     β”‚   β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€           β”‚ K β”‚ K β”‚ K β”‚     β”‚ K β”‚
β”‚ Q[1,:]                β”‚           β”‚[0]β”‚[1]β”‚[2]β”‚ ... β”‚[31]β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€     @     β”‚ T β”‚ T β”‚ T β”‚     β”‚ T β”‚
β”‚ Q[2,:]                β”‚           β”‚   β”‚   β”‚   β”‚     β”‚   β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€           β”‚128β”‚128β”‚128β”‚     β”‚128β”‚
β”‚ ...                   β”‚           β”‚   β”‚   β”‚   β”‚     β”‚   β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€           β”‚   β”‚   β”‚   β”‚     β”‚   β”‚
β”‚ Q[31,:]               β”‚           β”‚   β”‚   β”‚   β”‚     β”‚   β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜           └───┴───┴───┴─...β”€β”΄β”€β”€β”€β”˜
        ↑                                 ↑
   simd_gid = a                      simd_lid = b
   (which row)                       (which column)

Result: S block [Br=32, Bc=32], each element computed by one thread:

                    simd_lid (b)
              0     1     2    ...   31
            β”Œβ”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”¬β”€...─┬─────┐
          0 β”‚S0,0 β”‚S0,1 β”‚S0,2 β”‚     β”‚S0,31β”‚  ← SIMD group 0 (32 threads)
            β”œβ”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”Όβ”€...─┼──────
simd_gid  1 β”‚S1,0 β”‚S1,1 β”‚S1,2 β”‚     β”‚S1,31β”‚  ← SIMD group 1
  (a)       β”œβ”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”Όβ”€...─┼──────
          2 β”‚S2,0 β”‚S2,1 β”‚S2,2 β”‚     β”‚S2,31β”‚  ← SIMD group 2
            β”œβ”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”Όβ”€...─┼──────
        ... β”‚ ... β”‚ ... β”‚ ... β”‚     β”‚ ... β”‚
            β”œβ”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”Όβ”€...─┼──────
         31 β”‚S31,0β”‚S31,1β”‚S31,2β”‚     β”‚S31,31β”‚ ← SIMD group 31
            └─────┴─────┴─────┴─...β”€β”΄β”€β”€β”€β”€β”€β”˜

Thread (a=2, b=5) computes:
  S[2,5] = Q[2,0]*K[5,0] + Q[2,1]*K[5,1] + ... + Q[2,127]*K[5,127]
         = dot product of Q row 2 with K row 5 (128 multiply-adds)

After computing S[a,b], each thread holds one attention score. Row-wise reductions use SIMD intrinsicsβ€”all 32 threads in the same SIMD group cooperate:

SIMD group 2 (threads with simd_gid=2):
  Thread b=0 has S[2,0]
  Thread b=1 has S[2,1]
  ...
  Thread b=31 has S[2,31]

  simd_max(s_a_b) β†’ all 32 threads get max(S[2,0], S[2,1], ..., S[2,31])
  simd_sum(p_a_b) β†’ all 32 threads get sum(P[2,0], P[2,1], ..., P[2,31])
float rowmax = simd_max(s_a_b);  // max across 32 threads in same SIMD group
float rowsum = simd_sum(p_a_b);  // sum across 32 threads in same SIMD group

Computing O = P @ V inside a SIMD group

After softmax, we need to accumulate the output tile. A natural first thought is: β€œCan we assign threads to output elements the same way we did for S = Q @ K^T?” The answer is no, because the output dimensions don’t match:

Q @ K^T:                         P @ V:
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”   β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”       β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”   β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ Q       β”‚   β”‚ K^T     β”‚       β”‚ P       β”‚   β”‚ V               β”‚
β”‚[Br, E]  β”‚ @ β”‚[E, Bc]  β”‚       β”‚[Br, Bc] β”‚ @ β”‚[Bc, E]          β”‚
β”‚[32,128] β”‚   β”‚[128,32] β”‚       β”‚[32, 32] β”‚   β”‚[32, 128]        β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜   β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜       β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜   β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
         ↓                               ↓
   S [Br, Bc]                      O [Br, E]
   [32, 32]                        [32, 128]
   = 1024 elements                 = 4096 elements
        ↓                               ↓
   1024 threads βœ“                  1024 threads βœ—
   (one per element)               (not enough!)

For S = Q @ K^T, we have 1024 output elements and 1024 threadsβ€”perfect one-to-one mapping. But for O = P @ V, we have 4096 output elements but only 1024 threads. The mismatch comes from the embedding dimension: E = 128 β‰  Bc = 32.

So we use a different strategy: instead of assigning threads to output columns, we loop over the 128 output columns and use SIMD reduction for each:

For each output element O[a, c]:
  
  O[a, c] = sum over b: P[a, b] * V[b, c]
            β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
                   32 terms (Bc = 32)
                         ↓
            simd_sum can handle this!

  Thread assignment:
    - simd_gid = a (which output row)
    - simd_lid = b (which term in the sum)
    
  Code:
    for c in 0..E-1:                      // loop 128 times
        val = P[a, b] * V[b, c]           // each lane computes one term
        result = simd_sum(val)            // reduce 32 terms β†’ 1 result
        if simd_lid == 0:
            o_i[a, c] += result           // only lane 0 writes

The key insight: even though we can’t parallelize over the E dimension (because E > SIMD width), we can parallelize the reduction over Bc = 32, which matches SIMD width exactly.

Memory Hierarchy

β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ Global Memory (HBM)                                     β”‚
β”‚ Q[N, L, E], K[N_kv, S, E], V[N_kv, S, E]               β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
                    ↓ load once per Q block
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ Threadgroup Memory (SRAM, 32KB)                         β”‚
β”‚ q_local[Br][E]  ← Q block, reused for all Tc iterations β”‚
β”‚ o_i[Br][E]      ← accumulated output                    β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
                    ↓
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ Registers (per thread)                                  β”‚
β”‚ m_i, l_i, s_a_b, p_a_b                                  β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

K and V blocks are streamed from global memory in the inner loop over Tc. The Q block is loaded once into threadgroup memory and reused across all K/V tiles.

Implementation

In flash_attention.metal, write flash_attention_f32_e128 with one threadgroup per (n, i) tile, where n is the flattened head batch and i is the query tile index. Use threadgroup memory for local Q and partial O, and use SIMD reductions (simd_max, simd_sum) for row-wise max/sum updates.

In eval_gpu(...), load the kernel from the extension, bind inputs/outputs and scalar constants (N, L, S, E, head counts, scale, tile sizes), and dispatch over (N, Tr, 1). Keep the same contiguous checks as CPU path. Also remember to add src/flash_attention.metal into mlx_build_metallib(...) in CMakeLists.txt.

You can test your implementation by running:

pdm run build-ext
pdm run test --week 2 --day 4 -- -k task_3

Task 4: Model Integration

src/tiny_llm/qwen2_week2.py

Finally, wire the kernel into model execution. Keep the existing grouped attention path as fallback, add the use_flash_attention switch in Qwen2MultiHeadAttention, and propagate enable_flash_attn from model initialization into each block. After KV cache update, build the correct causal mask for L x S, run attention in float32, and cast back to activation dtype.

You can run generation with Flash Attention enabled:

pdm run main --solution tiny_llm --loader week2 --model qwen2-0.5b --enable-flash-attn

You can also benchmark throughput with and without Flash Attention:

pdm bench --solution tiny_llm --loader week2 --model qwen2-0.5b
pdm bench --solution tiny_llm --loader week2 --model qwen2-0.5b --enable-flash-attn

Your feedback is greatly appreciated. Welcome to join our Discord Community.
Found an issue? Create an issue / pull request on github.com/skyzh/tiny-llm.
tiny-llm-book Β© 2025 by Alex Chi Z is licensed under CC BY-NC-SA 4.0.