
Week 2 Day 2-3: Quantized Matmul

In this chapter, we will implement the quantized matrix multiplication. Quantization compresses model weights from 16-bit floating point to 4-bit integers, which is critical for efficient LLM serving on devices with limited memory bandwidth.

📚 Readings

Why Quantization?

As we learned in the KV Cache chapter, the decode phase of LLM inference is memory-bandwidth bound. Let’s revisit the arithmetic intensity calculation for the Qwen3-0.6B model:

Per-token linear layers in decode phase:
- Input: 1 token × 1024 dimensions = 1024 bfloat16 values = 2 KB
- MLP weights: 1024 × 3072 × 3 matrices × 2 bytes = ~19 MB per layer
- Attention weights:
  - q_proj / o_proj: 1024 × 2048 × 2 matrices × 2 bytes = ~8 MB per layer
  - k_proj / v_proj: 1024 × 1024 × 2 matrices × 2 bytes = ~4 MB per layer
- Total weights per layer: ~31 MB
- Total for 28 layers: ~880 MB

FLOPs (2 per multiply-accumulate):
- MLP per layer: 2 × 3 × 1024 × 3072 ≈ 19M
- Attention projections per layer: 2 × (1024 × 2048 × 2 + 1024 × 1024 × 2) ≈ 13M
- Total for 28 layers: ~880M per token

Memory access: ~880 MB
Arithmetic intensity: 880M FLOPs / 880 MB β‰ˆ 1.0 FLOPs/Byte

With M3 Max’s 400 GB/s memory bandwidth and ~10 TFLOPS compute:

Memory-bound throughput: 400 GB/s × 1.0 FLOPs/Byte = 400 GFLOPS
Compute-bound throughput: 10 TFLOPS

We're using only ~4% of the available compute (400 GFLOPS out of ~10 TFLOPS)!
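The numbers above can be reproduced with a few lines of Python. This is a back-of-envelope sketch: it counts only the linear-layer weights (ignoring embeddings, norms, and the KV cache) and uses the M3 Max figures quoted above.

```python
# Back-of-envelope decode cost for Qwen3-0.6B linear layers (shapes from the text above).
HIDDEN, INTER, Q_OUT, KV_OUT, LAYERS = 1024, 3072, 2048, 1024, 28

# (input_dim, output_dim, number of matrices) per decoder layer
LINEARS = [
    (HIDDEN, INTER, 3),   # gate / up / down (down is 3072 x 1024, same element count)
    (HIDDEN, Q_OUT, 2),   # q_proj / o_proj
    (HIDDEN, KV_OUT, 2),  # k_proj / v_proj
]


def decode_cost(bytes_per_weight=2):
    """Return (weights, FLOPs, bytes read, FLOPs per byte) for one decoded token."""
    weights = LAYERS * sum(i * o * count for i, o, count in LINEARS)
    flops = 2 * weights  # one multiply-accumulate per weight when M = 1
    bytes_read = weights * bytes_per_weight
    return weights, flops, bytes_read, flops / bytes_read


weights, flops, bytes_read, intensity = decode_cost()
BANDWIDTH, PEAK = 400e9, 10e12  # bytes/s and FLOP/s (M3 Max figures quoted above)
achievable = min(PEAK, BANDWIDTH * intensity)

print(f"{flops / 1e6:.0f}M FLOPs, {bytes_read / 1e6:.0f} MB, {intensity:.1f} FLOPs/Byte")
print(f"{achievable / 1e9:.0f} GFLOPS = {achievable / PEAK:.0%} of peak")
```

This prints about 881M FLOPs and 881 MB (the ~880 figures above), an intensity of 1.0, and 400 GFLOPS, which is 4% of peak.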

The Solution: Quantization

By compressing weights from 16-bit floating point (float16 or bfloat16) to 4-bit integers (int4), we:

  • Reduce weight memory traffic by ~4×: 880 MB → ~220 MB per token (slightly more once the per-group scales and biases are included)
  • Improve arithmetic intensity by ~4×: 1.0 → ~4.0 FLOPs/Byte
  • Raise the memory-bound throughput ceiling by ~4×: 400 GFLOPS → ~1.6 TFLOPS

The cost is a small loss of accuracy, which careful quantization schemes such as the group-wise approach below keep minimal.
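The ~4× figures treat each int4 weight as exactly half a byte. A slightly more careful estimate, assuming one 16-bit scale and one 16-bit bias per group of 128 weights (matching the scales/biases layout described in Storage Format), looks like this:

```python
# Linear-layer weight count for Qwen3-0.6B (28 layers, shapes from the text above)
WEIGHTS = 28 * (1024 * 3072 * 3 + 1024 * 2048 * 2 + 1024 * 1024 * 2)


def bytes_per_weight(bits=4, group_size=128, meta_bytes=2):
    """Packed payload plus one scale and one bias (meta_bytes each) shared by each group."""
    return bits / 8 + 2 * meta_bytes / group_size


bf16_bytes = WEIGHTS * 2
int4_bytes = WEIGHTS * bytes_per_weight()
print(f"bf16: {bf16_bytes / 1e6:.0f} MB")
print(f"int4 + scales/biases: {int4_bytes / 1e6:.0f} MB")
print(f"reduction: {bf16_bytes / int4_bytes:.2f}x, "
      f"intensity: {2 * WEIGHTS / int4_bytes:.2f} FLOPs/Byte")
```

Under these assumptions the metadata adds 1/32 of a byte per weight, so per-token weight traffic is about 234 MB: a ~3.76× reduction rather than exactly 4×.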

Group-wise Quantization

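Instead of quantizing the entire matrix with a single scale and bias, we divide the weights into small groups and quantize each group independently. Each group gets its own value range, so an outlier only costs precision within its own group instead of across the whole matrix.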

For a weight matrix of shape $K \times N$, we divide each row into groups of $G$ consecutive values. In this course we use Qwen3 MLX 4-bit weights, whose group size is fixed at $G = 128$:

Original weight matrix W: K × N (float16 or bfloat16)

Group size: G = 128
Number of groups per row = N / G

For each group of G consecutive values in a row:
  1. Find min and max values
  2. Compute scale and bias to map [min, max] → [0, 15] (4-bit range)
  3. Quantize each value using: quantized = round((value - bias) / scale)
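To see why grouping helps, the sketch below quantizes one synthetic row containing a single outlier, once as a single group and once in groups of 4 (a tiny group size chosen only to keep the example readable; the course uses 128). The helper names quantize_group and mean_abs_error are illustrative, not part of tiny_llm.

```python
def quantize_group(values, bits=4):
    """Affine-quantize one group; returns (codes, scale, bias) with bias = group minimum."""
    lo, hi = min(values), max(values)
    # Guard against constant groups, where max - min would be zero.
    scale = (hi - lo) / ((1 << bits) - 1) if hi > lo else 1.0
    codes = [round((v - lo) / scale) for v in values]  # Python rounds ties to even
    return codes, scale, lo


def mean_abs_error(row, group_size, bits=4):
    """Quantize `row` in independent groups, dequantize, and average the absolute error."""
    total = 0.0
    for start in range(0, len(row), group_size):
        group = row[start:start + group_size]
        codes, scale, bias = quantize_group(group, bits)
        total += sum(abs(v - (q * scale + bias)) for v, q in zip(group, codes))
    return total / len(row)


row = [0.01 * (i % 8) for i in range(32)]  # small, well-behaved weights
row[5] = 5.0                               # a single outlier
print("one group per row:", mean_abs_error(row, group_size=len(row)))
print("groups of 4:      ", mean_abs_error(row, group_size=4))
```

With one group, the outlier stretches the range so far that every small weight collapses to code 0; with groups of 4, only the outlier's own group loses precision, and the mean error drops by more than an order of magnitude.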

All quantized matmul tests use group_size = 128, matching the Qwen3 MLX 4-bit weights used by the rest of the course. The tests cover both float16 and bfloat16 because different MLX checkpoints store their scales, biases, and activations in different 16-bit data types.

Affine Quantization

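We use affine (asymmetric) quantization, which maps a floating-point range onto the full integer range:

$$q = \mathrm{round}\left(\frac{w - \text{bias}}{\text{scale}}\right), \qquad w \approx q \cdot \text{scale} + \text{bias}$$

For 4-bit quantization, the quantized values are in the range $[0, 2^4 - 1] = [0, 15]$.

Given a group with minimum value $w_{\min}$ and maximum value $w_{\max}$:

$$\text{scale} = \frac{w_{\max} - w_{\min}}{2^{\text{bits}} - 1} = \frac{w_{\max} - w_{\min}}{15}, \qquad \text{bias} = w_{\min}$$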

Example:

Group values: [-0.5, -0.3, 0.1, 0.4, 0.8]
min = -0.5, max = 0.8

scale = (0.8 - (-0.5)) / 15 = 1.3 / 15 ≈ 0.0867
bias = -0.5

Quantization:
  -0.5 → round((-0.5 - (-0.5)) / 0.0867) = 0
  -0.3 → round((-0.3 - (-0.5)) / 0.0867) = 2
   0.1 → round((0.1 - (-0.5)) / 0.0867) = 7
   0.4 → round((0.4 - (-0.5)) / 0.0867) = 10
   0.8 → round((0.8 - (-0.5)) / 0.0867) = 15

Quantized: [0, 2, 7, 10, 15] (4 bits each)
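The worked example can be checked in plain Python, together with the reverse step: dequantizing with q · scale + bias recovers each value to within scale / 2. (Python's round() rounds exact halves to even; none of these values lands on a tie.)

```python
values = [-0.5, -0.3, 0.1, 0.4, 0.8]
bits = 4

w_min, w_max = min(values), max(values)
scale = (w_max - w_min) / (2**bits - 1)  # 1.3 / 15 ≈ 0.0867
bias = w_min                             # -0.5

codes = [round((v - bias) / scale) for v in values]  # quantize
restored = [q * scale + bias for q in codes]         # dequantize

print(codes)                             # [0, 2, 7, 10, 15]
print([round(r, 4) for r in restored])   # [-0.5, -0.3267, 0.1067, 0.3667, 0.8]
```

The endpoints come back exactly; the interior values pick up a rounding error of at most half a quantization step.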

Storage Format

For efficient storage and computation, quantized weights are packed:

Original: K × N float16/bfloat16 (2 bytes each) = 2KN bytes
Quantized: K × N int4 (0.5 bytes each) = 0.5KN bytes

Packing: 8 × 4-bit values fit in one uint32 (32 bits)

Weight matrix shape: K × N
Quantized storage shape: K × (N / 8) uint32
Scales shape: K × (N / G) float16/bfloat16
Biases shape: K × (N / G) float16/bfloat16

Example packing for 8 consecutive 4-bit values [a, b, c, d, e, f, g, h]:

uint32_value = (h << 28) | (g << 24) | (f << 20) | (e << 16) |
               (d << 12) | (c << 8)  | (b << 4)  | a

Unpacking:
  a = (uint32_value >> 0)  & 0xF
  b = (uint32_value >> 4)  & 0xF
  c = (uint32_value >> 8)  & 0xF
  ...
  h = (uint32_value >> 28) & 0xF
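A minimal Python version of this packing scheme, with element j of each 8-value chunk stored in bits 4j through 4j + 3 (the first value in the lowest bits, as in the formula above). The function names are illustrative.

```python
def pack8(nibbles):
    """Pack 8 values in [0, 15] into one uint32; element j goes to bits 4*j .. 4*j + 3."""
    assert len(nibbles) == 8 and all(0 <= v <= 15 for v in nibbles)
    word = 0
    for j, v in enumerate(nibbles):
        word |= v << (4 * j)
    return word


def unpack8(word):
    """Inverse of pack8: extract the 8 nibbles, lowest bits first."""
    return [(word >> (4 * j)) & 0xF for j in range(8)]


def pack_row(codes):
    """Pack a row of N codes (N divisible by 8) into N // 8 uint32 words."""
    return [pack8(codes[p:p + 8]) for p in range(0, len(codes), 8)]


word = pack8([1, 2, 3, 4, 5, 6, 7, 8])
print(hex(word))       # 0x87654321
print(unpack8(word))   # [1, 2, 3, 4, 5, 6, 7, 8]
```

Reading the hex digits right to left gives the values in order, which is a handy way to eyeball packed weights while debugging.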

Quantized Matrix Multiplication

Mathematical Formulation

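For the standard matrix multiplication $C = A B^\top$, where:

  • $A$: shape $M \times N$, float16 or bfloat16 (activations)
  • $B$: shape $K \times N$, quantized to int4 (weights)
  • $C$: shape $M \times K$, same 16-bit dtype as $A$ (output)

Each element is computed as:

$$C_{ik} = \sum_{n=0}^{N-1} A_{in} B_{kn}$$

With quantization, $B$ is represented as:

$$B_{kn} = q_{kn} \cdot s_{k,g} + b_{k,g}, \qquad g = \lfloor n / G \rfloor$$

where $g$ is the group index, $q_{kn} \in [0, 15]$ is the stored 4-bit value, and $s_{k,g}$ and $b_{k,g}$ are scales[k, g] and biases[k, g].

Substituting:

$$C_{ik} = \sum_{n=0}^{N-1} A_{in} \left( q_{kn} \cdot s_{k,\lfloor n/G \rfloor} + b_{k,\lfloor n/G \rfloor} \right)$$

Rearranging by group:

$$C_{ik} = \sum_{g=0}^{N/G-1} \left( s_{k,g} \sum_{n=gG}^{(g+1)G-1} A_{in} \, q_{kn} + b_{k,g} \sum_{n=gG}^{(g+1)G-1} A_{in} \right)$$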

This shows we can factor out the scale and bias per group: each is applied once per group of G values instead of once per weight, which reduces the number of floating-point operations.
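A quick numerical check that the rearranged form matches dequantize-then-multiply for one output element C[i, k], where a stands for row i of A and q for the codes in row k of B. The data is random and synthetic; the function names are illustrative.

```python
import random


def dot_dequantized(a, q, scales, biases, G):
    """Direct form: dequantize every weight, then accumulate a[n] * w[n]."""
    return sum(a[n] * (q[n] * scales[n // G] + biases[n // G]) for n in range(len(a)))


def dot_factored(a, q, scales, biases, G):
    """Rearranged form: per group, scale * sum(a * q) + bias * sum(a)."""
    total = 0.0
    for g in range(len(a) // G):
        group = range(g * G, (g + 1) * G)
        total += scales[g] * sum(a[n] * q[n] for n in group) + biases[g] * sum(a[n] for n in group)
    return total


random.seed(0)
N, G = 256, 128
a = [random.uniform(-1, 1) for _ in range(N)]   # row i of A
q = [random.randrange(16) for _ in range(N)]    # 4-bit codes of row k of B
scales = [random.uniform(0.01, 0.1) for _ in range(N // G)]
biases = [random.uniform(-0.5, 0.5) for _ in range(N // G)]
print(dot_dequantized(a, q, scales, biases, G), dot_factored(a, q, scales, biases, G))
```

The two results agree up to floating-point rounding. Note that the per-group sum of A does not depend on k, so an implementation could in principle compute it once per row of A and reuse it for every output column; the kernels in this chapter do not need to.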

Computation Flow

Input:
  A: M × N (float16 or bfloat16, activations)
  B_quantized: K × (N/8) (uint32, packed weights)
  scales: K × (N/G) (float16/bfloat16)
  biases: K × (N/G) (float16/bfloat16)

Output:
  C: M × K (float16/bfloat16)

For each output element C[i, k]:
  sum = 0  # float accumulator
  for each group g in 0..(N/G - 1):
    scale = scales[k, g]
    bias = biases[k, g]
    
    # Process G values in the group (G/8 uint32 packs)
    for each pack p in 0..(G/8 - 1):
      packed_value = B_quantized[k, g*(G/8) + p]
      
      # Unpack 8 × 4-bit values
      for bit_offset in [0, 4, 8, 12, 16, 20, 24, 28]:
        quantized = (packed_value >> bit_offset) & 0xF
        b_value = quantized * scale + bias
        a_value = A[i, g*G + p*8 + bit_offset/4]
        sum = sum + a_value * b_value
  
  C[i, k] = float16/bfloat16(sum)
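The pseudocode translates almost line for line into a pure-Python reference, which is handy for checking a kernel's output on small inputs. It is a sketch under stated assumptions: nested lists stand in for arrays, packing is 4-bit with the lowest nibble first, quantized_matmul_ref is a hypothetical name, and the final cast to float16/bfloat16 is left out.

```python
def quantized_matmul_ref(A, B_packed, scales, biases, group_size=128, bits=4):
    """Pure-Python reference for C = A @ dequant(B).T with 4-bit packed weights.

    A: M x N floats; B_packed: K x (N // 8) uint32 words;
    scales, biases: K x (N // group_size) floats. Returns C: M x K Python floats.
    """
    assert bits == 4, "this sketch only handles 4-bit packing"
    M, N, K = len(A), len(A[0]), len(B_packed)
    packs_per_group = group_size // 8
    C = [[0.0] * K for _ in range(M)]
    for i in range(M):
        for k in range(K):
            acc = 0.0  # float accumulator, as in the pseudocode
            for g in range(N // group_size):
                scale, bias = scales[k][g], biases[k][g]
                for p in range(packs_per_group):
                    word = B_packed[k][g * packs_per_group + p]
                    for bit_offset in range(0, 32, 4):
                        q = (word >> bit_offset) & 0xF
                        n = g * group_size + p * 8 + bit_offset // 4
                        acc += A[i][n] * (q * scale + bias)
            C[i][k] = acc
    return C
```

The function accepts any group_size that is a multiple of 8, so tiny test cases (for example N = 16 with group_size = 8) are easy to check by hand before comparing against the C++ and Metal kernels at group_size = 128.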

Task 1: Implement QuantizedWeights

src/tiny_llm/quantize.py

First, familiarize yourself with the QuantizedWeights class, which stores quantized weight information:

| Field | Shape / type | Description |
|---|---|---|
| weight | $K \times N/8$, uint32 | Packed quantized weights. Each uint32 stores 8 consecutive 4-bit values. The original weight matrix has shape $K \times N$; after packing, it becomes $K \times N/8$. |
| scales | $K \times N/G$, float16/bfloat16 | Per-group scale factors for dequantization. Each group of $G$ consecutive values shares one scale. Recall: $\text{scale} = \frac{w_{\max} - w_{\min}}{2^{\text{bits}} - 1}$ |
| biases | $K \times N/G$, float16/bfloat16 | Per-group bias (offset) for dequantization. Recall: $\text{bias} = w_{\min}$ |
| group_size | int | Number of consecutive values that share the same scale/bias. For the Qwen3 MLX 4-bit weights used here, this is 128. |
| bits | int | Quantization bit width (typically 4, meaning values are in the range $[0, 15]$) |
The from_mlx_layer static method extracts these fields from MLX’s quantized linear layers when loading the model.

Next, implement the quantized_linear function, a wrapper around quantized_matmul that mimics the interface of the standard linear function. You will implement quantized_matmul itself in the next task.
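To make the shape relationships in the table concrete, here is a hypothetical shape-only stand-in (QuantizedWeightsShapes is not a tiny_llm class) that recovers the logical K × N shape from the packed fields and checks that they agree. It assumes weights are stored as (output, input) rows, as MLX linear layers do.

```python
from dataclasses import dataclass


@dataclass
class QuantizedWeightsShapes:
    """Hypothetical shape-only stand-in for QuantizedWeights (not the tiny_llm class)."""
    weight: tuple   # (K, N * bits // 32), uint32 words
    scales: tuple   # (K, N // group_size)
    biases: tuple   # (K, N // group_size)
    group_size: int = 128
    bits: int = 4

    def logical_shape(self):
        """Recover the unpacked (K, N) shape and check that all fields agree."""
        K, packed_cols = self.weight
        N = packed_cols * 32 // self.bits  # 8 values per uint32 when bits = 4
        expected = (K, N // self.group_size)
        if N % self.group_size or self.scales != expected or self.biases != expected:
            raise ValueError(f"inconsistent quantized shapes for K x N = {K} x {N}")
        return K, N


# A 1024 -> 3072 projection, stored as (output, input) rows
gate = QuantizedWeightsShapes(weight=(3072, 128), scales=(3072, 8), biases=(3072, 8))
print(gate.logical_shape())  # (3072, 1024)
```

A check like this can catch a transposed or mismatched weight early, before it turns into garbage output deep inside the model.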

Task 2: Implement quantized_matmul (CPU version)

In this task, we will implement the quantized matmul as an MLX C++ extension. The pattern is identical to the existing axpby example in the codebase, so read through axpby.h, axpby.cpp, and the corresponding binding in bindings.cpp first as your reference.

src/extensions/src/tiny_llm_ext.h
src/extensions/bindings.cpp
src/extensions/src/quantized_matmul.cpp
src/extensions/CMakeLists.txt

Apart from CMakeLists.txt (see below), the work spans three files, all within the tiny_llm_ext namespace:

  • tiny_llm_ext.h: declare the quantized_matmul(...) function signature and define a QuantizedMatmul primitive class (inheriting from mx::Primitive). Store group_size and bits as private members.
  • bindings.cpp: add an m.def(...) call to expose the function to Python.
  • quantized_matmul.cpp: implement the quantized_matmul(...) function (validate inputs, compute the output shape, return a lazy mx::array) and the eval_cpu method (allocate the output, register arrays with the CPU encoder, dispatch the compute kernel).

The eval_cpu implementation follows the same CPU encoder pattern as axpby: allocate output memory with out.set_data(mx::allocator::malloc(out.nbytes())), register input/output arrays with the encoder, then dispatch a lambda that performs the actual computation. Inside the lambda, implement the nested loop from the Computation Flow section above β€” iterate over each output element (i, k), dequantize each packed value, accumulate the products in float, and write the result back as either float16 or bfloat16, matching the input dtype.
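Inside the lambda you work with raw data pointers rather than 2D indices. Below is a small Python model of the row-major offset arithmetic, assuming every input is row-contiguous (if an input might be a strided view, the kernel has to honor its strides or work on a contiguous copy). flat_offsets is a hypothetical helper.

```python
def flat_offsets(i, k, n, N, group_size=128):
    """Row-major offsets for contiguous buffers.

    Returns (a_idx, b_word_idx, shift, sg_idx): the offset of A[i, n], the packed
    uint32 holding B[k, n], that nibble's bit shift, and the offset of
    scales/biases[k, n // group_size].
    """
    a_idx = i * N + n
    b_word_idx = k * (N // 8) + n // 8
    shift = 4 * (n % 8)
    sg_idx = k * (N // group_size) + n // group_size
    return a_idx, b_word_idx, shift, sg_idx


# Sanity check against the pseudocode's indexing: n = g*G + p*8 + bit_offset // 4
G, N = 128, 1024
for g in range(N // G):
    for p in range(G // 8):
        for bit_offset in range(0, 32, 4):
            n = g * G + p * 8 + bit_offset // 4
            _, word, shift, sg = flat_offsets(0, 0, n, N, G)
            assert (word, shift, sg) == (g * (G // 8) + p, bit_offset, g)
```

The same four expressions carry over directly to pointer arithmetic in the C++ lambda and, later, in the Metal kernel.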

Follow the axpby dtype-dispatch pattern here: write the CPU implementation as a template, then dispatch with mx::float16_t or mx::bfloat16_t based on the output dtype.

Don’t forget to add src/quantized_matmul.cpp to target_sources in CMakeLists.txt.

You can test your implementation by running:

pdm run build-ext
pdm run test --week 2 --day 2 -- -k task_2

Task 3: Implement quantized_matmul (GPU version)

src/extensions/src/quantized_matmul.metal
src/extensions/src/quantized_matmul.cpp

In this task, you will write the Metal kernel for quantized matmul and wire up the eval_gpu method to dispatch it. Keep the math exactly the same as Task 2 (CPU); only the execution model changes.

Metal Kernel

You need to implement one templated kernel in quantized_matmul.metal:

  • Use a one-thread-per-output-element mapping: each thread computes out[i, k].
  • The kernel should support both half and bfloat16_t inputs and outputs.
  • Apply the same group-wise dequantization loop as the CPU version:
    • Iterate over groups of 128 values
    • Unpack int4 values from packed uint32
    • Dequantize with q * scale + bias
    • Accumulate the products in float and cast the final output back to the kernel dtype
  • Add boundary checks (i < M, k < K) before writing output.

The custom kernel only needs to handle bits = 4 and group_size = 128. Use that group size to compute groups_per_row and the packed weight offsets. Instantiate the same templated Metal kernel twice, once for half and once for bfloat16_t, and select the matching kernel name in eval_gpu.

GPU Dispatch

Complete the eval_gpu method in quantized_matmul.cpp to dispatch your Metal kernel. Follow the same pattern as axpby’s GPU dispatch:

  1. Get the Metal device and command encoder from the stream.
  2. Load the quantized matmul kernel matching the output dtype from the Metal library.
  3. Set the input/output buffers and the dimension constants (M, N, K) on the encoder; make sure the buffer order matches your kernel signature.
  4. Calculate a 2D threadgroup configuration: use kernel->maxTotalThreadsPerThreadgroup() to determine the total number of threads per group, split them between the M and K dimensions (e.g., 32 threads for M and the rest for K), and compute enough threadgroups to cover the whole M × K output.
  5. Dispatch with dispatchThreadgroups.
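The launch geometry from steps 4 and 5 can be sketched in Python. This assumes 1024 as the value of maxTotalThreadsPerThreadgroup (the real value depends on the compiled kernel and should be queried at runtime) and orders the axes as (M, K); which Metal grid axis maps to M is a choice your kernel makes. dispatch_dims is a hypothetical helper.

```python
import math


def dispatch_dims(M, K, max_threads_per_group=1024, threads_m=32):
    """Return (threads per threadgroup, threadgroups in the grid) as (M-axis, K-axis) pairs."""
    threads_k = max_threads_per_group // threads_m  # "the rest" go to the K axis
    threadgroups = (math.ceil(M / threads_m), math.ceil(K / threads_k))
    return (threads_m, threads_k), threadgroups


# One decode token (M = 1) through a 1024 -> 3072 projection (K = 3072)
print(dispatch_dims(1, 3072))  # ((32, 32), (1, 96))
```

With M = 1, 31 of every 32 threads along M fall outside the matrix, which is exactly why the kernel needs its i < M bounds check before writing.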

You can test your implementation by running:

pdm run build-ext
pdm run test --week 2 --day 2 -- -k task_3

Task 4: Model Integration

src/tiny_llm/qwen3_week2.py

Integrate your quantized matmul into the Week 2 Qwen3 model so that inference runs on quantized weights end-to-end.

Change the weight type from mx.array to QuantizedWeights for all linear layers in attention (wq/wk/wv/wo) and MLP (w_gate/w_up/w_down). Replace every linear(x, w) call with quantized_linear(x, w). In the model loading code, use QuantizedWeights.from_mlx_layer(...) to extract quantized weight information from each MLX linear layer, instead of calling mx.dequantize to get a full 16-bit matrix. Make sure the Week 1 loader still dequantizes (since Week 1 layers expect plain mx.array), while the Week 2 loader does not dequantize.

Qwen3 MLX quantized layers may use float16 or bfloat16 for the tensors involved in dequantization. Your kernel should accept scales, biases, and activations in either dtype, require them to match, and return the same dtype. If you see nan or garbage output, a dtype mismatch is the most likely cause.

Also preserve the quantized layer’s parameters: the model code should pass w.group_size and w.bits through to the extension, which should validate that they match the Qwen3 course assumptions (group_size = 128 and bits = 4).
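A hypothetical Python-side pre-flight check that encodes the constraints above (group_size = 128, bits = 4, one shared 16-bit dtype). Dtypes are plain strings here purely for illustration; check_quantized_args is not a tiny_llm function, and the real validation belongs in the C++ extension, which compares MLX dtypes.

```python
SUPPORTED_DTYPES = {"float16", "bfloat16"}


def check_quantized_args(x_dtype, scales_dtype, biases_dtype, group_size, bits):
    """Hypothetical pre-flight check; returns the output dtype or raises."""
    if group_size != 128 or bits != 4:
        raise ValueError(f"expected group_size=128 and bits=4, got {group_size} and {bits}")
    if x_dtype not in SUPPORTED_DTYPES:
        raise TypeError(f"unsupported activation dtype: {x_dtype}")
    if not (x_dtype == scales_dtype == biases_dtype):
        raise TypeError(f"dtype mismatch: x={x_dtype}, scales={scales_dtype}, biases={biases_dtype}")
    return x_dtype  # the output uses the same 16-bit dtype


print(check_quantized_args("bfloat16", "bfloat16", "bfloat16", 128, 4))  # bfloat16
```

Failing loudly on a mismatch is usually easier to debug than the nan or garbage output that a silent dtype mix produces.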

You can test your implementation by running:

pdm run main --solution tiny_llm --loader week2 --model qwen3-0.6b

You can also benchmark throughput and compare your implementation with the reference solution:

pdm bench --solution tiny_llm --loader week2 --model qwen3-0.6b
pdm bench --solution tiny_llm_ref --loader week2 --model qwen3-0.6b

Your feedback is greatly appreciated. You are welcome to join our Discord community.
Found an issue? Create an issue / pull request on github.com/skyzh/tiny-llm.
tiny-llm-book © 2025 by Alex Chi Z is licensed under CC BY-NC-SA 4.0.