Week 2 Day 2-3: Quantized Matmul
In this chapter, we will implement the quantized matrix multiplication. Quantization compresses model weights from 16-bit floating point to 4-bit integers, which is critical for efficient LLM serving on devices with limited memory bandwidth.
Readings
- Model Compression and Quantization
- MLX Extensions Development Guide
- Quantized Matmul on CPU (Video)
- Quantized Matmul on GPU (Video)
Why Quantization?
As we learned in the KV Cache chapter, the decode phase of LLM inference is memory-bandwidth bound. Let's revisit the arithmetic intensity calculation for the Qwen3-0.6B model:
Per-token linear layers in decode phase:
- Input: 1 token × 1024 dimensions = 1024 bfloat16 values = 2 KB
- MLP weights: 1024 × 3072 × 3 matrices × 2 bytes = ~19 MB per layer
- Attention weights:
  - q_proj / o_proj: 1024 × 2048 × 2 matrices × 2 bytes = ~8 MB per layer
  - k_proj / v_proj: 1024 × 1024 × 2 matrices × 2 bytes = ~4 MB per layer
- Total weights per layer: ~31 MB
- Total for 28 layers: ~880 MB
FLOPs (2 per multiply-accumulate):
- MLP per layer: 2 × 3 × 1024 × 3072 ≈ 19M
- Attention projections per layer: 2 × (1024 × 2048 × 2 + 1024 × 1024 × 2) ≈ 13M
- 28 layers: ~880 million per token
Memory access: ~880 MB
Arithmetic intensity: 880M FLOPs / 880 MB ≈ 1.0 FLOPs/Byte
With M3 Max's 400 GB/s memory bandwidth and ~10 TFLOPS compute:
Memory-bound throughput: 400 GB/s × 1.0 FLOPs/Byte = 400 GFLOPS
Compute-bound throughput: 10 TFLOPS
We're using only ~4% of available compute!
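The back-of-the-envelope numbers above can be reproduced with a few lines of Python. The layer dimensions are the Qwen3-0.6B values quoted in this section, and the 400 GB/s and 10 TFLOPS figures are the same M3 Max assumptions as in the text, not measurements.

```python
# Decode-phase roofline estimate for Qwen3-0.6B (dimensions taken from the text).
HIDDEN, MLP_DIM, Q_DIM, KV_DIM, LAYERS = 1024, 3072, 2048, 1024, 28
BYTES_PER_WEIGHT = 2  # float16 / bfloat16

# Linear-layer parameters in one transformer layer.
mlp_params = 3 * HIDDEN * MLP_DIM                       # gate, up, down
attn_params = 2 * HIDDEN * Q_DIM + 2 * HIDDEN * KV_DIM  # q/o and k/v projections
params_per_layer = mlp_params + attn_params

weight_bytes = params_per_layer * LAYERS * BYTES_PER_WEIGHT
flops = 2 * params_per_layer * LAYERS  # one multiply-accumulate per weight per token

intensity = flops / weight_bytes        # FLOPs per byte of weights read
bandwidth, peak_flops = 400e9, 10e12    # assumed M3 Max figures
attainable = min(peak_flops, bandwidth * intensity)

print(f"weights: {weight_bytes / 1e6:.0f} MB, FLOPs: {flops / 1e6:.0f} M")
print(f"intensity: {intensity:.2f} FLOPs/byte")
print(f"attainable: {attainable / 1e9:.0f} GFLOPS ({attainable / peak_flops:.0%} of peak)")
```

It prints 881 MB, 881 M FLOPs, 1.00 FLOPs/byte, and 400 GFLOPS (4% of peak), matching the rounded figures above.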
The Solution: Quantization
By compressing weights from 16-bit floating point (float16 or bfloat16) to
4-bit integers (int4), we:
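- Reduce weight memory traffic by ~4×: 880 MB → ~220 MB per token (plus a small overhead for the per-group scales and biases)
- Improve arithmetic intensity by ~4×: 1.0 → ~4.0 FLOPs/Byte
- Increase attainable throughput by up to ~4×: 400 GFLOPS → ~1.6 TFLOPS, since decode is still memory-bound at this intensity
The tradeoff is a small accuracy loss, which careful quantization techniques keep minimal.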
Group-wise Quantization
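Instead of quantizing the whole weight matrix with a single scale and bias, we divide the weights into groups and quantize each group independently. Each group gets a range fitted to its own values, so an outlier only coarsens the quantization of its own group; this preserves more information about the weight distribution.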
For a weight matrix of shape $K \times N$, we divide each row into groups of size $G$. In this course we use Qwen3 MLX 4-bit weights, whose group size is fixed at 128:
Original weight matrix W: K × N (float16 or bfloat16)
Group size: G = 128
Number of groups per row = N / G
For each group of G consecutive values in a row:
1. Find min and max values
2. Compute scale and bias to map [min, max] → [0, 15] (4-bit range)
3. Quantize each value using: quantized = round((value - bias) / scale)
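The steps above can be sketched in pure Python for a single weight row. `quantize_row` is a hypothetical helper for illustration, not the course API; it uses scale = (max - min) / 15 and bias = min exactly as described, whereas MLX's own quantizer works on whole tensors and its exact formulas may differ in details.

```python
def quantize_row(row, group_size=128, bits=4):
    """Group-wise affine quantization of one weight row (illustrative sketch)."""
    assert len(row) % group_size == 0, "row length must be a multiple of group_size"
    levels = (1 << bits) - 1  # 15 for 4-bit
    codes, scales, biases = [], [], []
    for start in range(0, len(row), group_size):
        group = row[start:start + group_size]
        lo, hi = min(group), max(group)
        scale = (hi - lo) / levels
        if scale == 0:  # constant group: any nonzero scale reproduces it exactly
            scale = 1.0
        scales.append(scale)
        biases.append(lo)  # bias = group minimum
        for value in group:
            q = round((value - lo) / scale)
            codes.append(min(max(q, 0), levels))  # clamp to [0, levels]
    return codes, scales, biases


row = [((i * 37) % 101) / 50 - 1 for i in range(256)]  # 256 values -> 2 groups
codes, scales, biases = quantize_row(row)

# Dequantize: each value should be within half a step of the original.
for n, value in enumerate(row):
    g = n // 128
    assert abs(codes[n] * scales[g] + biases[g] - value) <= scales[g] / 2 + 1e-9
print(len(codes), len(scales))  # 256 2
```

Every group of 128 values carries its own (scale, bias) pair, so a 256-value row ends up with two of each.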
All quantized matmul tests use group_size = 128, matching the Qwen3 MLX
4-bit weights used by the rest of the course. The tests cover both float16
and bfloat16 because different MLX checkpoints store their scales, biases,
and activations in different 16-bit data types.
Affine Quantization
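We use affine (asymmetric) quantization, which maps a floating-point range to the full integer range:
$$q = \operatorname{round}\left(\frac{w - \beta}{s}\right), \qquad w \approx s \cdot q + \beta$$
For 4-bit quantization, the quantized values are in the range $[0, 2^4 - 1] = [0, 15]$.
Given a group with minimum value $w_{\min}$ and maximum value $w_{\max}$:
$$s = \frac{w_{\max} - w_{\min}}{15}, \qquad \beta = w_{\min}$$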
Example:
Group values: [-0.5, -0.3, 0.1, 0.4, 0.8]
min = -0.5, max = 0.8
scale = (0.8 - (-0.5)) / 15 = 1.3 / 15 ≈ 0.0867
bias = -0.5
Quantization:
-0.5 → round((-0.5 - (-0.5)) / 0.0867) = 0
-0.3 → round((-0.3 - (-0.5)) / 0.0867) = 2
0.1 → round((0.1 - (-0.5)) / 0.0867) = 7
0.4 → round((0.4 - (-0.5)) / 0.0867) = 10
0.8 → round((0.8 - (-0.5)) / 0.0867) = 15
Quantized: [0, 2, 7, 10, 15] (4 bits each)
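The worked example can be checked directly. This short sketch applies the same scale and bias to the five values and also dequantizes them, showing that every reconstructed value lies within half a quantization step (scale / 2) of the original.

```python
values = [-0.5, -0.3, 0.1, 0.4, 0.8]
w_min, w_max = min(values), max(values)
scale = (w_max - w_min) / 15  # 4-bit: 16 levels, 15 steps
bias = w_min

quantized = [round((v - bias) / scale) for v in values]
dequantized = [q * scale + bias for q in quantized]

print(quantized)  # [0, 2, 7, 10, 15]
for v, d in zip(values, dequantized):
    print(f"{v:+.2f} -> {d:+.4f} (error {abs(v - d):.4f})")
```

The largest error here comes from 0.4, which is reconstructed as about 0.367; this rounding error is the accuracy cost of 4-bit storage.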
Storage Format
For efficient storage and computation, quantized weights are packed:
Original: K × N float16/bfloat16 (2 bytes each) = 2KN bytes
Quantized: K × N int4 (0.5 bytes each) = 0.5KN bytes
Packing: 8 × 4-bit values fit in one uint32 (32 bits)
Weight matrix shape: K × N
Quantized storage shape: K × (N / 8) uint32
Scales shape: K × (N / G) float16/bfloat16
Biases shape: K × (N / G) float16/bfloat16
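To make these shapes concrete, this sketch computes them for one Qwen3-0.6B MLP up/gate weight (K = 3072 output features, N = 1024 input features, G = 128). `storage_layout` is a hypothetical helper; byte sizes assume 4-byte uint32 packs and 2-byte scales and biases, as listed above.

```python
def storage_layout(K, N, group_size=128, bits=4):
    """Shapes and byte sizes of a group-wise quantized K x N weight (sketch)."""
    values_per_pack = 32 // bits  # 8 four-bit values per uint32
    assert N % group_size == 0 and group_size % values_per_pack == 0
    weight_shape = (K, N // values_per_pack)  # packed uint32 words
    meta_shape = (K, N // group_size)         # scales and biases, 16-bit each
    return {
        "weight": weight_shape,
        "scales": meta_shape,
        "biases": meta_shape,
        "original_bytes": K * N * 2,
        "quantized_bytes": K * (N // values_per_pack) * 4 + 2 * K * (N // group_size) * 2,
    }


layout = storage_layout(K=3072, N=1024)
print(layout["weight"], layout["scales"])  # (3072, 128) (3072, 8)
ratio = layout["original_bytes"] / layout["quantized_bytes"]
print(f"compression ratio: {ratio:.2f}x")
```

The ratio comes out at about 3.76× rather than exactly 4× because each group of 128 weights also carries a 16-bit scale and a 16-bit bias (4 extra bytes per 128 weights).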
Example packing for 8 consecutive 4-bit values [a, b, c, d, e, f, g, h]:
uint32_value = (h << 28) | (g << 24) | (f << 20) | (e << 16) |
(d << 12) | (c << 8) | (b << 4) | a
Unpacking:
a = (uint32_value >> 0) & 0xF
b = (uint32_value >> 4) & 0xF
c = (uint32_value >> 8) & 0xF
...
h = (uint32_value >> 28) & 0xF
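A small Python sketch of this layout, with value a in the lowest four bits. `pack_int4` and `unpack_int4` are illustrative names, not course functions; Python integers are unbounded, so each word fits in 32 bits only because it holds exactly eight 4-bit fields.

```python
def pack_int4(values):
    """Pack 4-bit values into 32-bit words, lowest nibble first (a = bits 0..3)."""
    assert len(values) % 8 == 0
    words = []
    for start in range(0, len(values), 8):
        word = 0
        for j, v in enumerate(values[start:start + 8]):
            assert 0 <= v <= 0xF, "each value must fit in 4 bits"
            word |= v << (4 * j)
        words.append(word)
    return words


def unpack_int4(words):
    """Inverse of pack_int4: extract the 8 nibbles of each word in order."""
    return [(word >> (4 * j)) & 0xF for word in words for j in range(8)]


vals = [1, 2, 3, 4, 5, 6, 7, 8]  # a..h
packed = pack_int4(vals)
print(hex(packed[0]))  # 0x87654321
assert unpack_int4(packed) == vals
```

Reading the hex digits right to left gives a, b, c, ..., h, which is the same order as the shift table above.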
Quantized Matrix Multiplication
Mathematical Formulation
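For the matrix multiplication $C = A B^\top$, where:
- $A$: shape $M \times N$, float16 or bfloat16 (activations)
- $B$: shape $K \times N$, quantized to int4 (weights)
- $C$: shape $M \times K$, same 16-bit dtype as $A$ (output)
Each element is computed as:
$$C_{ik} = \sum_{n=0}^{N-1} A_{in} B_{kn}$$
With quantization, $B_{kn}$ is represented as:
$$B_{kn} = s_{kg} \, q_{kn} + \beta_{kg}, \qquad g = \lfloor n / G \rfloor$$
where $g$ is the group index, $q_{kn} \in [0, 15]$ is the 4-bit code, and $s_{kg}$, $\beta_{kg}$ are the scale and bias of group $g$ in row $k$.
Substituting:
$$C_{ik} = \sum_{g=0}^{N/G-1} \sum_{j=0}^{G-1} A_{i,\,gG+j} \left( s_{kg} \, q_{k,\,gG+j} + \beta_{kg} \right)$$
Rearranging:
$$C_{ik} = \sum_{g=0}^{N/G-1} \left( s_{kg} \sum_{j=0}^{G-1} A_{i,\,gG+j} \, q_{k,\,gG+j} + \beta_{kg} \sum_{j=0}^{G-1} A_{i,\,gG+j} \right)$$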
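This shows we can factor out the scale and bias per group: each group needs only one multiplication by its scale and one by its bias instead of one of each per element, reducing the number of floating-point operations.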
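A quick numerical check of the rearrangement in plain Python with random data: computing one output element by dequantizing every weight gives the same result, up to floating-point rounding, as the factored per-group form. The group size is shrunk to 4 purely to keep the example small.

```python
import random

random.seed(0)
N, G = 16, 4                                      # small sizes for illustration
a = [random.uniform(-1, 1) for _ in range(N)]     # one activation row A[i, :]
q = [random.randrange(16) for _ in range(N)]      # 4-bit codes of weight row B[k, :]
s = [random.uniform(0.01, 0.1) for _ in range(N // G)]    # per-group scales
beta = [random.uniform(-0.5, 0.5) for _ in range(N // G)]  # per-group biases

# Naive: dequantize each weight, then take the dot product.
naive = sum(a[n] * (s[n // G] * q[n] + beta[n // G]) for n in range(N))

# Factored: per group, one multiply by the scale and one by the bias.
factored = 0.0
for g in range(N // G):
    chunk = range(g * G, (g + 1) * G)
    factored += s[g] * sum(a[n] * q[n] for n in chunk) + beta[g] * sum(a[n] for n in chunk)

assert abs(naive - factored) < 1e-9
print(naive, factored)
```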
Computation Flow
Input:
    A: M × N (float16 or bfloat16, activations)
    B_quantized: K × (N/8) (uint32, packed weights)
    scales: K × (N/G) (float16/bfloat16)
    biases: K × (N/G) (float16/bfloat16)
Output:
    C: M × K (float16/bfloat16)
For each output element C[i, k]:
    sum = 0  # float accumulator
    for each group g in 0..(N/G - 1):
        scale = scales[k, g]
        bias = biases[k, g]
        # Process G values in the group (G/8 uint32 packs)
        for each pack p in 0..(G/8 - 1):
            packed_value = B_quantized[k, g*(G/8) + p]
            # Unpack 8 × 4-bit values
            for bit_offset in [0, 4, 8, 12, 16, 20, 24, 28]:
                quantized = (packed_value >> bit_offset) & 0xF
                b_value = quantized * scale + bias
                a_value = A[i, g*G + p*8 + bit_offset/4]
                sum = sum + a_value * b_value
    C[i, k] = float16/bfloat16(sum)
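The pseudocode above translates almost line for line into a runnable pure-Python reference, which can be handy for checking a kernel on small inputs. The function name, the use of nested lists instead of arrays, and the test-data setup are illustrative assumptions rather than the course API; Python floats stand in for the float accumulator, and no float16/bfloat16 rounding is applied to the output.

```python
import random


def quantized_matmul_ref(A, B_quantized, scales, biases, group_size=128):
    """C[i][k] = sum_n A[i][n] * dequant(B)[k][n], following the pseudocode."""
    M, N, K = len(A), len(A[0]), len(B_quantized)
    packs_per_group = group_size // 8
    C = [[0.0] * K for _ in range(M)]
    for i in range(M):
        for k in range(K):
            acc = 0.0  # float accumulator
            for g in range(N // group_size):
                scale, bias = scales[k][g], biases[k][g]
                for p in range(packs_per_group):
                    packed = B_quantized[k][g * packs_per_group + p]
                    for bit_offset in range(0, 32, 4):
                        q = (packed >> bit_offset) & 0xF
                        a = A[i][g * group_size + p * 8 + bit_offset // 4]
                        acc += a * (q * scale + bias)
            C[i][k] = acc
    return C


# Random test data: 4-bit codes packed 8 per word, lowest nibble first.
random.seed(1)
M, N, K, G = 2, 256, 3, 128
codes = [[random.randrange(16) for _ in range(N)] for _ in range(K)]
packed = [[sum(row[j + t] << (4 * t) for t in range(8)) for j in range(0, N, 8)] for row in codes]
scales = [[random.uniform(0.01, 0.1) for _ in range(N // G)] for _ in range(K)]
biases = [[random.uniform(-0.5, 0.5) for _ in range(N // G)] for _ in range(K)]
A = [[random.uniform(-1, 1) for _ in range(N)] for _ in range(M)]

C = quantized_matmul_ref(A, packed, scales, biases, G)

# Compare against an explicit dequantize-then-matmul.
for i in range(M):
    for k in range(K):
        expected = sum(
            A[i][n] * (codes[k][n] * scales[k][n // G] + biases[k][n // G]) for n in range(N)
        )
        assert abs(C[i][k] - expected) < 1e-9
```

Note the `bit_offset // 4` integer division: the pseudocode's `bit_offset/4` is meant as an integer index.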
Task 1: Implement QuantizedWeights
src/tiny_llm/quantize.py
First, familiarize yourself with the QuantizedWeights class, which stores quantized weight information:
| Field | Shape | Dtype | Description |
|---|---|---|---|
| weight | $K \times N/8$ | uint32 | Packed quantized weights. Each uint32 stores 8 consecutive 4-bit values. The original weight matrix has shape $K \times N$, and after packing, it becomes $K \times N/8$. |
| scales | $K \times N/G$ | float16/bfloat16 | Per-group scale factors for dequantization. Each group of $G$ consecutive values shares one scale. Recall: $s = (w_{\max} - w_{\min}) / 15$ |
| biases | $K \times N/G$ | float16/bfloat16 | Per-group bias (offset) for dequantization. Recall: $\beta = w_{\min}$ |
| group_size | scalar | int | Number of consecutive values that share the same scale/bias. For the Qwen3 MLX 4-bit weights used here, this is 128. |
| bits | scalar | int | Quantization bit width (typically 4, meaning values are in range $[0, 15]$) |
The from_mlx_layer static method extracts these fields from MLX's quantized linear layers when loading the model.
Next, implement the quantized_linear function, a wrapper around quantized_matmul that mimics the standard linear function interface. You will implement quantized_matmul itself in the next task.
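To illustrate what the wrapper has to do, here is a pure-Python sketch. `QuantizedWeightsSketch`, `_dot_row`, and `quantized_linear_sketch` are hypothetical names that only mirror the fields in the table; the real class holds mx.array values, and the real quantized_linear should delegate to your quantized_matmul instead of looping in Python. What the sketch shows is the linear-style interface: activations of shape (..., N) go in, outputs of shape (..., K) come out, with an optional bias as in a standard linear layer.

```python
from dataclasses import dataclass


@dataclass
class QuantizedWeightsSketch:
    """Hypothetical stand-in for QuantizedWeights using nested lists."""
    weight: list  # K x (N / 8) packed 32-bit words
    scales: list  # K x (N / group_size)
    biases: list  # K x (N / group_size)
    group_size: int = 128
    bits: int = 4

    def shape(self):
        """Logical (K, N) shape implied by the packed weight."""
        K = len(self.weight)
        N = len(self.weight[0]) * (32 // self.bits)
        assert len(self.scales) == len(self.biases) == K
        assert len(self.scales[0]) == N // self.group_size
        return K, N


def _dot_row(x_row, w):
    """One activation row times the transposed dequantized weight: K outputs."""
    assert w.bits == 4, "sketch only handles 4-bit weights"
    K, N = w.shape()
    out = []
    for k in range(K):
        acc = 0.0
        for n in range(N):
            q = (w.weight[k][n // 8] >> (4 * (n % 8))) & 0xF
            g = n // w.group_size
            acc += x_row[n] * (q * w.scales[k][g] + w.biases[k][g])
        out.append(acc)
    return out


def quantized_linear_sketch(x, w, bias=None):
    """Like linear(x, w, bias): x has shape (..., N), the result (..., K)."""
    if isinstance(x[0], list):  # recurse over leading batch dimensions
        return [quantized_linear_sketch(row, w, bias) for row in x]
    y = _dot_row(x, w)
    if bias is not None:
        y = [v + b for v, b in zip(y, bias)]
    return y


# K = 2 outputs, N = 128 inputs (one group per row); every 4-bit code is 1.
w = QuantizedWeightsSketch(
    weight=[[0x11111111] * 16, [0x11111111] * 16],
    scales=[[0.5], [0.25]],
    biases=[[0.0], [-0.25]],
)
x = [[[1.0] * 128]]               # shape (1, 1, 128)
y = quantized_linear_sketch(x, w)  # shape (1, 1, 2)
print(y)  # [[[64.0, 0.0]]]
```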
Task 2: Implement quantized_matmul (CPU version)
In this task, we will implement the quantized matmul as an MLX C++ extension. The pattern is identical to the existing axpby example in the codebase: read through axpby.h, axpby.cpp, and the corresponding binding in bindings.cpp first as your reference.
src/extensions/src/tiny_llm_ext.h
src/extensions/bindings.cpp
src/extensions/src/quantized_matmul.cpp
src/extensions/CMakeLists.txt
You need to write code in three files, all within the tiny_llm_ext namespace (plus the CMakeLists.txt change noted below):
- tiny_llm_ext.h: Declare the quantized_matmul(...) function signature and define a QuantizedMatmul primitive class (inheriting mx::Primitive). Store group_size and bits as private members.
- bindings.cpp: Add an m.def(...) call to expose the function to Python.
- quantized_matmul.cpp: Implement the quantized_matmul(...) function (validate inputs, compute output shape, return a lazy mx::array) and the eval_cpu method (allocate output, register arrays with the CPU encoder, dispatch the compute kernel).
The eval_cpu implementation follows the same CPU encoder pattern as axpby: allocate output memory with out.set_data(mx::allocator::malloc(out.nbytes())), register input/output arrays with the encoder, then dispatch a lambda that performs the actual computation. Inside the lambda, implement the nested loop from the Computation Flow section above: iterate over each output element (i, k), dequantize each packed value, accumulate the products in float, and write the result back as either float16 or bfloat16, matching the input dtype.
Follow the axpby dtype-dispatch pattern here: write the CPU implementation as
a template, then dispatch with mx::float16_t or mx::bfloat16_t based on the
output dtype.
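Why accumulate in float rather than in the 16-bit output type? This Python sketch emulates float16 rounding with the standard `struct` module's `'e'` format and bfloat16 rounding by round-to-nearest-even truncation of float32 bits (an emulation for illustration, not MLX's code). Adding 1.0 4096 times stalls at 2048 in float16 and at 256 in bfloat16, while a wide accumulator reaches 4096 and only the final result is rounded to the output dtype.

```python
import struct


def to_float16(x):
    """Round a Python float to the nearest IEEE float16 value (ties to even)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bfloat16(x):
    """Emulate bfloat16: keep the top 16 bits of float32, rounding to nearest even."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


# Mimic the template dispatch: choose the rounding function from the output dtype.
ROUND_TO = {"float16": to_float16, "bfloat16": to_bfloat16}

results = {}
for dtype, cast in ROUND_TO.items():
    narrow = 0.0  # accumulator rounded to the 16-bit type after every add
    wide = 0.0    # Python float, standing in for the C++ float accumulator
    for _ in range(4096):
        narrow = cast(narrow + 1.0)
        wide += 1.0
    results[dtype] = (narrow, cast(wide))  # round once, at the end

print(results)  # {'float16': (2048.0, 4096.0), 'bfloat16': (256.0, 4096.0)}
```

Real dot products add values of mixed sign and magnitude, so the effect is less dramatic, but the same loss of low-order bits applies.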
Don't forget to add src/quantized_matmul.cpp to target_sources in CMakeLists.txt.
You can test your implementation by running:
pdm run build-ext
pdm run test --week 2 --day 2 -- -k task_2
Task 3: Implement quantized_matmul (GPU version)
src/extensions/src/quantized_matmul.metal
src/extensions/src/quantized_matmul.cpp
In this task, you will write the Metal kernel for quantized matmul and wire up the eval_gpu method to dispatch it. Keep the math exactly the same as Task 2 (CPU); only the execution model changes.
Metal Kernel
You need to implement one kernel entry in quantized_matmul.metal:
- Use a one-thread-per-output-element mapping: each thread computes out[i, k].
- The kernel should support both half and bfloat16_t inputs and outputs.
- Apply the same group-wise dequantization loop as the CPU version:
  - Iterate over groups of 128 values
  - Unpack int4 values from packed uint32
  - Dequantize with q * scale + bias
  - Accumulate the products in float and cast the final output back to the kernel dtype
- Add boundary checks (i < M, k < K) before writing output.
The custom kernel only needs to handle bits = 4 and group_size = 128. Use that group size to compute groups_per_row and the packed weight offsets.
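Before writing Metal, it can help to prototype the per-thread body against flat buffers, since that is how the kernel sees its inputs. In this hypothetical Python emulation, `kernel_thread` plays the role of the GPU thread at position (i, k), indexes row-major flat lists using M, N, K, and hard-codes bits = 4 and group_size = 128. The argument order and names are assumptions; match them to your own kernel signature.

```python
GROUP_SIZE = 128
PACKS_PER_GROUP = GROUP_SIZE // 8  # 16 packed words per group


def kernel_thread(i, k, scales, biases, a, b, out, M, N, K):
    """Body run by the thread that owns out[i, k] (row-major flat buffers)."""
    if i >= M or k >= K:  # threads outside the output do nothing
        return
    groups_per_row = N // GROUP_SIZE
    packs_per_row = N // 8
    acc = 0.0  # float accumulator
    for g in range(groups_per_row):
        scale = scales[k * groups_per_row + g]
        bias = biases[k * groups_per_row + g]
        for p in range(PACKS_PER_GROUP):
            packed = b[k * packs_per_row + g * PACKS_PER_GROUP + p]
            for t in range(8):
                q = (packed >> (4 * t)) & 0xF
                acc += a[i * N + g * GROUP_SIZE + p * 8 + t] * (q * scale + bias)
    out[i * K + k] = acc  # a real kernel casts to half / bfloat16_t here


# Emulate a launch larger than the output to exercise the boundary check.
M, N, K = 3, 128, 2
a = [1.0] * (M * N)
b = [0x33333333] * (K * N // 8)  # every 4-bit code is 3
scales = [0.5, 1.0]              # one group per row
biases = [0.0, -1.0]
out = [None] * (M * K)
for i in range(4):
    for k in range(4):
        kernel_thread(i, k, scales, biases, a, b, out, M, N, K)
print(out)  # [192.0, 256.0, 192.0, 256.0, 192.0, 256.0]
```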
Instantiate the same templated Metal kernel twice, once for half and once for
bfloat16_t, and select the matching kernel name in eval_gpu.
GPU Dispatch
Complete the eval_gpu method in quantized_matmul.cpp to dispatch your Metal kernel. Follow the same pattern as axpby's GPU dispatch:
- Get the Metal device and command encoder from the stream.
- Load the quantized matmul kernel matching the output dtype from the Metal library.
- Set input/output buffers and dimension constants (M, N, K) on the encoder; make sure the buffer order matches your kernel signature.
- Calculate a 2D threadgroup configuration: use kernel->maxTotalThreadsPerThreadgroup() to determine the total threads, then split them between the M and K dimensions (e.g., 32 threads for M, the rest for K).
- Dispatch with dispatchThreadgroups.
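The threadgroup arithmetic can be sketched in Python. `max_threads` stands in for the value of kernel->maxTotalThreadsPerThreadgroup() (1024 is only an example; query it at runtime), and the 32-threads-for-M split follows the suggestion above. Rounding the threadgroup counts up is why some threads land outside the output and must skip the write.

```python
def threadgroup_config(M, K, max_threads, threads_m=32):
    """Return (threads per group along M, K) and (threadgroups along M, K)."""
    threads_m = min(threads_m, max_threads)
    threads_k = max_threads // threads_m
    groups_m = -(-M // threads_m)  # ceil(M / threads_m)
    groups_k = -(-K // threads_k)  # ceil(K / threads_k)
    return (threads_m, threads_k), (groups_m, groups_k)


# Example: a 100-token prompt through an MLP up projection with K = 3072 outputs.
M, K = 100, 3072
group_dims, grid = threadgroup_config(M, K, max_threads=1024)
print(group_dims, grid)  # (32, 32) (4, 96)

launched = (group_dims[0] * grid[0]) * (group_dims[1] * grid[1])
idle = launched - M * K  # threads rejected by the i < M, k < K check
print(launched, idle)    # 393216 86016
```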
You can test your implementation by running:
pdm run build-ext
pdm run test --week 2 --day 2 -- -k task_3
Task 4: Model Integration
src/tiny_llm/qwen3_week2.py
Integrate your quantized matmul into the Week 2 Qwen3 model so that inference runs on quantized weights end-to-end.
Change the weight type from mx.array to QuantizedWeights for all linear layers in attention (wq/wk/wv/wo) and MLP (w_gate/w_up/w_down). Replace every linear(x, w) call with quantized_linear(x, w). In the model loading code, use QuantizedWeights.from_mlx_layer(...) to extract quantized weight information from each MLX linear layer, instead of calling mx.dequantize to get a full 16-bit matrix. Make sure the Week 1 loader still dequantizes (since Week 1 layers expect plain mx.array), while the Week 2 loader does not dequantize.
Qwen3 MLX quantized layers may use float16 or bfloat16 for the tensors involved in dequantization. Your kernel should accept scales, biases, and activations in either dtype, require them to match, and return the same dtype. If you see nan or garbage output, a dtype mismatch is the most likely cause.
Also preserve the quantized layer's parameters: the model code should pass through w.group_size and w.bits, and the extension should validate that they match the Qwen3 course assumptions (group_size = 128 and bits = 4).
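The checks described above can be summarized in a short Python sketch. `validate_quantized_args` is a hypothetical helper that takes dtypes as plain strings; in the extension, the same checks would run in C++ inside quantized_matmul(...) before the output array is built, so a mismatch raises an error instead of silently producing nan.

```python
SUPPORTED_DTYPES = {"float16", "bfloat16"}


def validate_quantized_args(a_dtype, scales_dtype, biases_dtype, group_size, bits):
    """Reject inputs outside the course assumptions; return the output dtype."""
    if group_size != 128 or bits != 4:
        raise ValueError(f"expected group_size=128 and bits=4, got {group_size} and {bits}")
    if a_dtype not in SUPPORTED_DTYPES:
        raise ValueError(f"unsupported activation dtype {a_dtype}")
    if not (a_dtype == scales_dtype == biases_dtype):
        raise ValueError(
            f"dtype mismatch: a={a_dtype}, scales={scales_dtype}, biases={biases_dtype}"
        )
    return a_dtype  # the output uses the same 16-bit dtype


print(validate_quantized_args("bfloat16", "bfloat16", "bfloat16", 128, 4))  # bfloat16
try:
    validate_quantized_args("float16", "bfloat16", "bfloat16", 128, 4)
except ValueError as err:
    print(err)  # dtype mismatch: a=float16, scales=bfloat16, biases=bfloat16
```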
You can test your implementation by running:
pdm run main --solution tiny_llm --loader week2 --model qwen3-0.6b
You can also benchmark throughput and compare your implementation with the reference solution:
pdm bench --solution tiny_llm --loader week2 --model qwen3-0.6b
pdm bench --solution tiny_llm_ref --loader week2 --model qwen3-0.6b
Your feedback is greatly appreciated. Welcome to join our Discord Community.
Found an issue? Create an issue / pull request on github.com/skyzh/tiny-llm.
tiny-llm-book © 2025 by Alex Chi Z is licensed under CC BY-NC-SA 4.0.