
Week 2 Day 2-3: Quantized Matmul

In this chapter, we will implement quantized matrix multiplication. Quantization compresses model weights from 16-bit floating point to 4-bit integers, which is critical for efficient LLM serving on devices with limited memory bandwidth.

📚 Readings

Why Quantization?

As we learned in the KV Cache chapter, the decode phase of LLM inference is memory-bandwidth bound. Let’s revisit the arithmetic intensity calculation for the Qwen2-0.5B model:

Per-token computation in decode phase:
- Input: 1 token × 896 dimensions = 896 float16 values = 1.792 KB
- MLP weights: 896 × 4864 × 3 matrices × 2 bytes = ~25 MB per layer
- Attention weights: 896 × 896 × 4 matrices × 2 bytes = ~6 MB per layer
- Total weights per layer: ~31 MB
- Total for 24 layers: ~750 MB

FLOPs (2 per multiply-accumulate):
- MLP per layer: 2 × 3 × 896 × 4864 ≈ 26M
- Attention per layer: 2 × 4 × 896 × 896 ≈ 6.4M
- 24 layers: ~780 million per token

Memory access: ~750 MB
Arithmetic intensity: 780M FLOPs / 750 MB ≈ 1.0 FLOPs/Byte

With M3 Max’s 400 GB/s memory bandwidth and ~10 TFLOPS compute:

Memory-bound throughput: 400 GB/s × 1.0 FLOPs/Byte = 400 GFLOPS
Compute-bound throughput: 10 TFLOPS

We're using only ~4% of available compute!
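These figures can be reproduced with a few lines of Python. The model dimensions are Qwen2-0.5B’s, and attention is counted as four 896 × 896 matrices as above (ignoring grouped-query KV head sizes); the exact totals come out slightly above the rounded numbers in the text.

```python
# Back-of-envelope check of the decode-phase arithmetic intensity.
hidden, intermediate, layers = 896, 4864, 24
bytes_per_weight = 2  # float16

mlp_weights = 3 * hidden * intermediate   # gate/up/down projections
attn_weights = 4 * hidden * hidden        # wq/wk/wv/wo
weights_per_layer = mlp_weights + attn_weights

total_bytes = weights_per_layer * bytes_per_weight * layers  # ~780 MB
flops_per_token = 2 * weights_per_layer * layers             # 2 FLOPs per MAC

intensity = flops_per_token / total_bytes  # ~1.0 FLOPs/byte
print(f"{total_bytes / 1e6:.0f} MB, {flops_per_token / 1e6:.0f} MFLOPs, "
      f"{intensity:.2f} FLOPs/byte")
```

Since each float16 weight is read once (2 bytes) and contributes one multiply-accumulate (2 FLOPs), the intensity is exactly 1 FLOP per byte, regardless of the model size.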

The Solution: Quantization

By compressing weights from 16 bits (float16/bfloat16) to 4 bits (int4), we:

  • Reduce memory bandwidth by 4×: 750 MB → ~190 MB per token
  • Improve arithmetic intensity by 4×: 1.0 → ~4.0 FLOPs/Byte
  • Increase throughput by ~4×: 400 GFLOPS → ~1.6 TFLOPS

The tradeoff is minimal accuracy loss with proper quantization techniques.

Group-wise Quantization

Instead of quantizing all weights uniformly, we divide them into groups and quantize each group independently. This preserves more information about the weight distribution.

For a weight matrix of shape K × N, we divide each row into groups of size G (typically 64 or 128):

Original weight matrix W: K × N (float16/bfloat16)

Group size G = 64
Number of groups per row = N / G

For each group of 64 consecutive values in a row:
  1. Find min and max values
  2. Compute scale and bias to map [min, max] → [0, 15] (4-bit range)
  3. Quantize each value using: quantized = round((value - bias) / scale)

Affine Quantization

We use affine (asymmetric) quantization which maps a floating-point range to the full integer range:

For 4-bit quantization, the quantized values are in the range [0, 15].

Given a group with minimum value w_min and maximum value w_max:

scale = (w_max - w_min) / (2^bits - 1)
bias = w_min

Example:

Group values: [-0.5, -0.3, 0.1, 0.4, 0.8]
min = -0.5, max = 0.8

scale = (0.8 - (-0.5)) / 15 = 1.3 / 15 ≈ 0.0867
bias = -0.5

Quantization:
  -0.5 → round((-0.5 - (-0.5)) / 0.0867) = 0
  -0.3 → round((-0.3 - (-0.5)) / 0.0867) = 2
   0.1 → round((0.1 - (-0.5)) / 0.0867) = 7
   0.4 → round((0.4 - (-0.5)) / 0.0867) = 10
   0.8 → round((0.8 - (-0.5)) / 0.0867) = 15

Quantized: [0, 2, 7, 10, 15] (4 bits each)
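The worked example above can be reproduced with a few lines of Python. This is a sketch for intuition (the real model weights are quantized by MLX, not by this code); quantize_group and dequantize_group are illustrative names.

```python
def quantize_group(values, bits=4):
    """Affine-quantize one group: map [min, max] onto [0, 2**bits - 1]."""
    vmin, vmax = min(values), max(values)
    scale = (vmax - vmin) / (2**bits - 1)
    bias = vmin
    quantized = [round((v - bias) / scale) for v in values]
    return quantized, scale, bias

def dequantize_group(quantized, scale, bias):
    """Recover approximate floats: q * scale + bias."""
    return [q * scale + bias for q in quantized]

q, scale, bias = quantize_group([-0.5, -0.3, 0.1, 0.4, 0.8])
print(q)  # [0, 2, 7, 10, 15]
```

Note that the dequantized values differ from the originals by at most half a quantization step, i.e. scale / 2 ≈ 0.043 here.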

Storage Format

For efficient storage and computation, quantized weights are packed:

Original: K × N float16 (2 bytes each) = 2KN bytes
Quantized: K × N int4 (0.5 bytes each) = 0.5KN bytes

Packing: 8 × 4-bit values fit in one uint32 (32 bits)

Weight matrix shape: K × N
Quantized storage shape: K × (N / 8) uint32
Scales shape: K × (N / 64) float16
Biases shape: K × (N / 64) float16

Example packing for 8 consecutive 4-bit values [a, b, c, d, e, f, g, h]:

uint32_value = (h << 28) | (g << 24) | (f << 20) | (e << 16) |
               (d << 12) | (c << 8)  | (b << 4)  | a

Unpacking:
  a = (uint32_value >> 0)  & 0xF
  b = (uint32_value >> 4)  & 0xF
  c = (uint32_value >> 8)  & 0xF
  ...
  h = (uint32_value >> 28) & 0xF
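The packing and unpacking rules translate directly into Python (pack8 and unpack8 are illustrative names, not part of the codebase):

```python
def pack8(nibbles):
    """Pack 8 4-bit values [a..h] into one uint32; `a` goes in the lowest nibble."""
    assert len(nibbles) == 8 and all(0 <= v <= 0xF for v in nibbles)
    word = 0
    for i, v in enumerate(nibbles):
        word |= v << (4 * i)
    return word

def unpack8(word):
    """Inverse of pack8: extract 8 4-bit values, lowest nibble first."""
    return [(word >> (4 * i)) & 0xF for i in range(8)]

word = pack8([0, 2, 7, 10, 15, 1, 3, 5])
assert unpack8(word) == [0, 2, 7, 10, 15, 1, 3, 5]
```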

Quantized Matrix Multiplication

Mathematical Formulation

For the matrix multiplication C = A · Bᵀ, where:

  • A: shape M × N, float16/bfloat16 (activations)
  • B: shape K × N, quantized to int4 (weights)
  • C: shape M × K, float16/bfloat16 (output)

Each element C[i, k] is computed as:

C[i, k] = Σₙ A[i, n] · B[k, n]

With quantization, B[k, n] is represented as:

B[k, n] ≈ q[k, n] · scale[k, g] + bias[k, g]

where g = ⌊n / group_size⌋ is the group index.

Substituting:

C[i, k] ≈ Σₙ A[i, n] · (q[k, n] · scale[k, g] + bias[k, g])

Rearranging:

C[i, k] ≈ Σ_g ( scale[k, g] · Σ_{n ∈ group g} A[i, n] · q[k, n] + bias[k, g] · Σ_{n ∈ group g} A[i, n] )

This shows we can factor the scale and bias out of each group's inner sum, reducing the number of floating-point operations.
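The rearrangement can be checked numerically for a single output element. This sketch uses one row of A and one quantized row of B, with G as the group size:

```python
import numpy as np

rng = np.random.default_rng(0)
N, G = 128, 64                       # row length and group size
a = rng.standard_normal(N)           # one row of A
q = rng.integers(0, 16, N)           # int4 codes for one row of B
scale = rng.standard_normal(N // G)  # one scale per group
bias = rng.standard_normal(N // G)   # one bias per group

# Direct: dequantize every element, then take the dot product
b = q * np.repeat(scale, G) + np.repeat(bias, G)
direct = a @ b

# Factored: per group, scale * (a . q) + bias * sum(a)
factored = sum(
    scale[g] * (a[g*G:(g+1)*G] @ q[g*G:(g+1)*G]) + bias[g] * a[g*G:(g+1)*G].sum()
    for g in range(N // G)
)
assert np.isclose(direct, factored)
```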

Computation Flow

Input:
  A: M × N (float16, activations)
  B_quantized: K × (N/8) (uint32, packed weights)
  scales: K × (N/64) (float16)
  biases: K × (N/64) (float16)

Output:
  C: M × K (float16)

For each output element C[i, k]:
  sum = 0
  for each group g in 0..(N/64 - 1):
    scale = scales[k, g]
    bias = biases[k, g]
    
    # Process 64 values in the group (8 uint32 packs)
    for each pack p in 0..7:
      packed_value = B_quantized[k, g*8 + p]
      
      # Unpack 8 × 4-bit values
      for bit_offset in [0, 4, 8, 12, 16, 20, 24, 28]:
        quantized = (packed_value >> bit_offset) & 0xF
        b_value = quantized * scale + bias
        a_value = A[i, g*64 + p*8 + bit_offset/4]
        sum += a_value * b_value
  
  C[i, k] = sum
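The flow above translates almost line-for-line into a naive NumPy reference, which is handy for testing your C++ and Metal kernels later. This is an illustration only; quantized_matmul_ref and pack_rows are not names from the codebase.

```python
import numpy as np

def pack_rows(q):
    """Pack int4 codes (K x N) into uint32 words (K x N/8), lowest nibble first."""
    K, N = q.shape
    packed = np.zeros((K, N // 8), dtype=np.uint32)
    for j in range(8):
        packed |= (q[:, j::8].astype(np.uint32) & 0xF) << np.uint32(4 * j)
    return packed

def quantized_matmul_ref(A, B_packed, scales, biases, group_size=64):
    """C[i, k] = sum_n A[i, n] * (q[k, n] * scale + bias), per the flow above."""
    M, N = A.shape
    K = B_packed.shape[0]
    packs_per_group = group_size // 8
    C = np.zeros((M, K), dtype=np.float32)
    for i in range(M):
        for k in range(K):
            acc = 0.0  # accumulate in high precision, cast at the end
            for g in range(N // group_size):
                scale, bias = scales[k, g], biases[k, g]
                for p in range(packs_per_group):
                    packed = int(B_packed[k, g * packs_per_group + p])
                    for j in range(8):
                        q = (packed >> (4 * j)) & 0xF
                        a = A[i, g * group_size + p * 8 + j]
                        acc += a * (q * scale + bias)
            C[i, k] = acc
    return C
```

A quick correctness check is to compare against dequantizing B to a dense matrix and computing A @ B.T directly.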

Task 1: Implement QuantizedWeights

src/tiny_llm/quantize.py

First, familiarize yourself with the QuantizedWeights class, which stores quantized weight information:

| Field | Shape | Description |
|-------|-------|-------------|
| weight | K × (N/8), uint32 | Packed quantized weights. Each uint32 stores 8 consecutive 4-bit values. The original weight matrix has shape K × N, and after packing, it becomes K × (N/8). |
| scales | K × (N/group_size), float16 | Per-group scale factors for dequantization. Each group of group_size consecutive values shares one scale. Recall: scale = (w_max - w_min) / (2^bits - 1). |
| biases | K × (N/group_size), float16 | Per-group bias (offset) for dequantization. Recall: bias = w_min. |
| group_size | int | Number of consecutive values that share the same scale/bias (typically 64) |
| bits | int | Quantization bit width (typically 4, meaning values are in range [0, 15]) |

The from_mlx_layer static method extracts these fields from MLX’s quantized linear layers when loading the model.

Next, implement the quantized_linear function, a wrapper around quantized_matmul that mimics the standard linear function interface. We’ll implement quantized_matmul itself in the next task.
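As a mental model, here is a pure-NumPy sketch of the interface. The field names follow the table above, but this is an illustration only: the real class stores mx.array fields, and the real quantized_linear calls your quantized_matmul kernel rather than dequantizing to a dense matrix.

```python
from dataclasses import dataclass
import numpy as np

@dataclass
class QuantizedWeights:
    weight: np.ndarray   # K x (N/8) uint32, packed 4-bit codes
    scales: np.ndarray   # K x (N/group_size) per-group scales
    biases: np.ndarray   # K x (N/group_size) per-group biases
    group_size: int = 64
    bits: int = 4

def dequantize(w: QuantizedWeights) -> np.ndarray:
    """Unpack and dequantize to a dense K x N matrix (mask assumes bits == 4)."""
    K = w.weight.shape[0]
    shifts = np.arange(0, 32, w.bits, dtype=np.uint32)
    q = (w.weight[:, :, None] >> shifts) & 0xF  # K x (N/8) x 8
    q = q.reshape(K, -1)                        # K x N, nibble order matches n
    scale = np.repeat(w.scales, w.group_size, axis=1)
    bias = np.repeat(w.biases, w.group_size, axis=1)
    return q * scale + bias

def quantized_linear(x, w: QuantizedWeights):
    """linear-style interface: x @ W^T, with W stored quantized."""
    return x @ dequantize(w).T
```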

Task 2: Implement quantized_matmul (CPU version)

In this task, we will implement the quantized matmul as an MLX C++ extension. The pattern is identical to the existing axpby example in the codebase — read through axpby.h, axpby.cpp, and the corresponding binding in bindings.cpp first as your reference.

src/extensions/src/tiny_llm_ext.h
src/extensions/bindings.cpp
src/extensions/src/quantized_matmul.cpp
src/extensions/CMakeLists.txt

You need to touch three files, all within the tiny_llm_ext namespace:

  • tiny_llm_ext.h — Declare the quantized_matmul(...) function signature and define a QuantizedMatmul primitive class (inheriting mx::Primitive). Store group_size and bits as private members.
  • bindings.cpp — Add an m.def(...) call to expose the function to Python.
  • quantized_matmul.cpp — Implement the quantized_matmul(...) function (validate inputs, compute output shape, return a lazy mx::array) and the eval_cpu method (allocate output, register arrays with the CPU encoder, dispatch the compute kernel).

The eval_cpu implementation follows the same CPU encoder pattern as axpby: allocate output memory with out.set_data(mx::allocator::malloc(out.nbytes())), register input/output arrays with the encoder, then dispatch a lambda that performs the actual computation. Inside the lambda, implement the nested loop from the Computation Flow section above — iterate over each output element (i, k), accumulate in float (fp32) to avoid precision loss, and cast the result back to float16 when writing to the output.

Don’t forget to add src/quantized_matmul.cpp to target_sources in CMakeLists.txt.

You can test your implementation by running:

pdm run build-ext
pdm run test --week 2 --day 2 -- -k task_2

Task 3: Implement quantized_matmul (GPU version)

src/extensions/src/quantized_matmul.metal
src/extensions/src/quantized_matmul.cpp

In this task, you will write the Metal kernel for quantized matmul and wire up the eval_gpu method to dispatch it. Keep the math exactly the same as Task 2 (CPU); only the execution model changes.

Metal Kernel

You need to implement one kernel entry in quantized_matmul.metal:

  • Use a one-thread-per-output-element mapping: each thread computes out[i, k].
  • The kernel should be templated on the data type (to support both half and bfloat16_t).
  • Apply the same group-wise dequantization loop as the CPU version:
    • Iterate over groups (group_size = 64)
    • Unpack int4 values from packed uint32
    • Dequantize with q * scale + bias
    • Accumulate in float and cast to the output dtype at the end
  • Add boundary checks (i < M, k < K) before writing output.

GPU Dispatch

Complete the eval_gpu method in quantized_matmul.cpp to dispatch your Metal kernel. Follow the same pattern as axpby’s GPU dispatch:

  1. Get the Metal device and command encoder from the stream.
  2. Select the correct kernel name based on the activation dtype (float16 → half, bfloat16 → bfloat16_t).
  3. Set input/output buffers and dimension constants (M, N, K) on the encoder — make sure the buffer order matches your kernel signature.
  4. Calculate a 2D thread group configuration: use kernel->maxTotalThreadsPerThreadgroup() to determine the total threads, then split between the M and K dimensions (e.g., 32 threads for M, the rest for K).
  5. Dispatch with dispatchThreadgroups.

You can test your implementation by running:

pdm run build-ext
pdm run test --week 2 --day 2 -- -k task_3

Task 4: Model Integration

src/tiny_llm/qwen2_week2.py

Integrate your quantized matmul into the Week 2 Qwen2 model so that inference runs on quantized weights end-to-end.

Change the weight type from mx.array to QuantizedWeights for all linear layers in attention (wq/wk/wv/wo) and MLP (w_gate/w_up/w_down). Replace every linear(x, w) call with quantized_linear(x, w). In the model loading code, use QuantizedWeights.from_mlx_layer(...) to extract quantized weight information from each MLX linear layer, instead of calling mx.dequantize to get a full float16 matrix. Make sure the Week 1 loader still dequantizes (since Week 1 layers expect plain mx.array), while the Week 2 loader does not dequantize.

Note that MLX loads quantized models with scales and biases stored in bfloat16 by default, while the activation tensors are typically float16. Since we have not implemented bfloat16 support in our kernel, you will need to convert the scales and biases to float16 with mx.astype before calling the kernel. If you see nan or garbage output, a dtype mismatch is the most likely cause.

You can test your implementation by running:

pdm run main --solution tiny_llm --loader week2 --model qwen2-0.5b

You can also benchmark throughput and compare your implementation with the reference solution:

pdm run bench --solution tiny_llm --loader week2 --model qwen2-0.5b
pdm run bench --solution tiny_llm_ref --loader week2 --model qwen2-0.5b

Your feedback is greatly appreciated. You are welcome to join our Discord Community.
Found an issue? Create an issue / pull request on github.com/skyzh/tiny-llm.
tiny-llm-book © 2025 by Alex Chi Z is licensed under CC BY-NC-SA 4.0.