Week 2 Day 2-3: Quantized Matmul
In this chapter, we will implement quantized matrix multiplication. Quantization compresses model weights from 16-bit floating point to 4-bit integers, which is critical for efficient LLM serving on devices with limited memory bandwidth.
Readings
- Model Compression and Quantization
- MLX Extensions Development Guide
- Quantized Matmul on CPU (Video)
- Quantized Matmul on GPU (Video)
Why Quantization?
As we learned in the KV Cache chapter, the decode phase of LLM inference is memory-bandwidth bound. Let's revisit the arithmetic intensity calculation for the Qwen2-0.5B model:
Per-token computation in decode phase:
- Input: 1 token × 896 dimensions = 896 float16 values = 1.792 KB
- MLP weights: 896 × 4864 × 3 matrices × 2 bytes = ~25 MB per layer
- Attention weights: 896 × 896 × 4 matrices × 2 bytes = ~6 MB per layer
- Total weights per layer: ~31 MB
- Total for 24 layers: ~750 MB
FLOPs (2 per multiply-accumulate):
- MLP per layer: 2 × 3 × 896 × 4864 ≈ 26M
- Attention per layer: 2 × 4 × 896 × 896 ≈ 6.4M
- 24 layers: ~780 million per token
Memory access: ~750 MB
Arithmetic intensity: 780M FLOPs / 750 MB ≈ 1.0 FLOPs/Byte
With M3 Max's 400 GB/s memory bandwidth and ~10 TFLOPS compute:
Memory-bound throughput: 400 GB/s × 1.0 FLOPs/Byte = 400 GFLOPS
Compute-bound throughput: 10 TFLOPS
We're using only ~4% of available compute!
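These back-of-the-envelope numbers can be reproduced with a short script (the dimensions are the Qwen2-0.5B values quoted above):

```python
# Arithmetic intensity of the Qwen2-0.5B decode phase (weights only),
# using the per-layer dimensions quoted above.
hidden, intermediate, layers = 896, 4864, 24

mlp_bytes = hidden * intermediate * 3 * 2         # 3 MLP matrices, 2 bytes/weight
attn_bytes = hidden * hidden * 4 * 2              # wq/wk/wv/wo
weight_bytes = (mlp_bytes + attn_bytes) * layers  # ~750 MB

# 2 FLOPs (multiply + accumulate) per weight read in decode
flops = 2 * (3 * hidden * intermediate + 4 * hidden * hidden) * layers

print(f"weights: {weight_bytes / 2**20:.0f} MiB")
print(f"FLOPs:   {flops / 1e6:.0f}M per token")
print(f"arithmetic intensity: {flops / weight_bytes:.2f} FLOPs/Byte")
```

The intensity comes out to exactly 1.0 FLOPs/Byte: each 2-byte weight is read once and feeds one multiply-accumulate (2 FLOPs).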
The Solution: Quantization
By compressing weights from 16 bits (float16/bfloat16) to 4 bits (int4), we:
- Reduce memory bandwidth by 4×: 750 MB → ~190 MB per token
- Improve arithmetic intensity by 4×: 1.0 → ~4.0 FLOPs/Byte
- Increase throughput by ~4×: 400 GFLOPS → ~1.6 TFLOPS
The tradeoff is minimal accuracy loss with proper quantization techniques.
Group-wise Quantization
Instead of quantizing all weights uniformly, we divide them into groups and quantize each group independently. This preserves more information about the weight distribution.
For a weight matrix of shape K × N, we divide each row into groups of size G (typically 64 or 128):
Original weight matrix W: K × N (float16/bfloat16)
Group size G = 64
Number of groups per row = N / G
For each group of 64 consecutive values in a row:
1. Find min and max values
2. Compute scale and bias to map [min, max] → [0, 15] (4-bit range)
3. Quantize each value using: quantized = round((value - bias) / scale)
Affine Quantization
We use affine (asymmetric) quantization, which maps a floating-point range to the full integer range:
For 4-bit quantization, the quantized values are in the range [0, 15].
Given a group with minimum value min and maximum value max:
scale = (max - min) / 15
bias = min
Example:
Group values: [-0.5, -0.3, 0.1, 0.4, 0.8]
min = -0.5, max = 0.8
scale = (0.8 - (-0.5)) / 15 = 1.3 / 15 ≈ 0.0867
bias = -0.5
Quantization:
-0.5 → round((-0.5 - (-0.5)) / 0.0867) = 0
-0.3 → round((-0.3 - (-0.5)) / 0.0867) = 2
0.1 → round((0.1 - (-0.5)) / 0.0867) = 7
0.4 → round((0.4 - (-0.5)) / 0.0867) = 10
0.8 → round((0.8 - (-0.5)) / 0.0867) = 15
Quantized: [0, 2, 7, 10, 15] (4 bits each)
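The worked example can be checked with a short sketch of the per-group quantize/dequantize round trip (pure Python, no MLX):

```python
# Per-group affine quantization, following the formulas above
# (sketch only; assumes the group is not constant, i.e. max > min).
def quantize_group(values, bits=4):
    lo, hi = min(values), max(values)
    scale = (hi - lo) / (2**bits - 1)   # maps [min, max] -> [0, 15] for 4 bits
    bias = lo
    q = [round((v - bias) / scale) for v in values]
    return q, scale, bias

def dequantize_group(q, scale, bias):
    return [x * scale + bias for x in q]

group = [-0.5, -0.3, 0.1, 0.4, 0.8]
q, scale, bias = quantize_group(group)
print(q)   # [0, 2, 7, 10, 15], matching the worked example
```

Dequantizing recovers each value to within half a quantization step (scale / 2 ≈ 0.043 here), which is the accuracy cost the chapter refers to.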
Storage Format
For efficient storage and computation, quantized weights are packed:
Original: K × N float16 (2 bytes each) = 2KN bytes
Quantized: K × N int4 (0.5 bytes each) = 0.5KN bytes
Packing: 8 × 4-bit values fit in one uint32 (32 bits)
Weight matrix shape: K × N
Quantized storage shape: K × (N / 8) uint32
Scales shape: K × (N / 64) float16
Biases shape: K × (N / 64) float16
Example packing for 8 consecutive 4-bit values [a, b, c, d, e, f, g, h]:
uint32_value = (h << 28) | (g << 24) | (f << 20) | (e << 16) |
(d << 12) | (c << 8) | (b << 4) | a
Unpacking:
a = (uint32_value >> 0) & 0xF
b = (uint32_value >> 4) & 0xF
c = (uint32_value >> 8) & 0xF
...
h = (uint32_value >> 28) & 0xF
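The packing scheme above is easy to express in a few lines (value a occupies the lowest nibble, h the highest):

```python
# Nibble packing/unpacking as described above: value vals[0] ("a") occupies
# the lowest 4 bits of the uint32, vals[7] ("h") the highest.
def pack8(vals):
    assert len(vals) == 8 and all(0 <= v <= 0xF for v in vals)
    word = 0
    for i, v in enumerate(vals):
        word |= v << (4 * i)
    return word

def unpack8(word):
    return [(word >> (4 * i)) & 0xF for i in range(8)]

vals = [0, 2, 7, 10, 15, 1, 3, 9]
word = pack8(vals)
print(hex(word))              # 0x931fa720
assert unpack8(word) == vals  # the round trip is lossless
```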
Quantized Matrix Multiplication
Mathematical Formulation
For standard matrix multiplication C = A · Bᵀ where:
- A: shape M × N, float16/bfloat16 (activations)
- B: shape K × N, quantized to int4 (weights)
- C: shape M × K, float16/bfloat16 (output)
Each element is computed as:
C[i, k] = Σ_n A[i, n] · B[k, n]
With quantization, B[k, n] is represented as:
B[k, n] = q[k, n] · scale[k, g] + bias[k, g]
where g = n / group_size is the group index.
Substituting:
C[i, k] = Σ_n A[i, n] · (q[k, n] · scale[k, g] + bias[k, g])
Rearranging by group:
C[i, k] = Σ_g ( scale[k, g] · Σ_{n in group g} A[i, n] · q[k, n] + bias[k, g] · Σ_{n in group g} A[i, n] )
This shows we can factor out the scale and bias per group: within each group, only the A · q dot product needs per-element work, reducing the number of floating-point operations.
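The rearrangement is easy to verify numerically. The following pure-Python sketch (made-up values, two groups of 64) checks that the per-group factored form matches the direct dequantize-then-dot form for a single output element:

```python
import random

# Numeric check of the rearrangement above for one output element C[i, k]:
# the per-group factored form must equal the direct dequantize-then-dot form.
random.seed(0)
N, G = 128, 64                                   # row length, group size
a = [random.uniform(-1, 1) for _ in range(N)]    # A[i, :]
q = [random.randrange(16) for _ in range(N)]     # int4 weights of B[k, :]
scale = [0.1, 0.05]                              # one (scale, bias) per group
bias = [-0.8, -0.3]

# Direct: dequantize every weight, then take the dot product.
direct = sum(a[n] * (q[n] * scale[n // G] + bias[n // G]) for n in range(N))

# Factored: per group, scale * (A . q) + bias * sum(A).
factored = sum(
    scale[g] * sum(a[n] * q[n] for n in range(g * G, (g + 1) * G))
    + bias[g] * sum(a[n] for n in range(g * G, (g + 1) * G))
    for g in range(N // G)
)

assert abs(direct - factored) < 1e-9
```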
Computation Flow
Input:
A: M × N (float16, activations)
B_quantized: K × (N/8) (uint32, packed weights)
scales: K × (N/64) (float16)
biases: K × (N/64) (float16)
Output:
C: M × K (float16)
For each output element C[i, k]:
sum = 0
for each group g in 0..(N/64 - 1):
scale = scales[k, g]
bias = biases[k, g]
# Process 64 values in the group (8 uint32 packs)
for each pack p in 0..7:
packed_value = B_quantized[k, g*8 + p]
# Unpack 8 × 4-bit values
for bit_offset in [0, 4, 8, 12, 16, 20, 24, 28]:
quantized = (packed_value >> bit_offset) & 0xF
b_value = quantized * scale + bias
a_value = A[i, g*64 + p*8 + bit_offset/4]
sum += a_value * b_value
C[i, k] = sum
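The pseudocode translates almost line-for-line into a runnable reference. The sketch below is illustrative only (a tiny 2×64 activation matrix with one group per row); it cross-checks the packed loop against plain dequantize-then-matmul:

```python
import random

# Pure-Python reference for the pseudocode above: C = A @ dequant(B)^T,
# where B is stored packed (each uint32 holds 8 consecutive 4-bit values).
# Illustrative only; the extension performs the same loop in C++.
GROUP, PACK = 64, 8

def quantized_matmul_ref(A, B_packed, scales, biases, M, N, K):
    C = [[0.0] * K for _ in range(M)]
    for i in range(M):
        for k in range(K):
            acc = 0.0
            for g in range(N // GROUP):
                s, b = scales[k][g], biases[k][g]
                for p in range(GROUP // PACK):        # 8 uint32 packs per group
                    word = B_packed[k][g * (GROUP // PACK) + p]
                    for j in range(PACK):             # 8 nibbles per uint32
                        q = (word >> (4 * j)) & 0xF
                        acc += A[i][g * GROUP + p * PACK + j] * (q * s + b)
            C[i][k] = acc
    return C

# Cross-check against dequantize-then-matmul on random data.
random.seed(1)
M, N, K = 2, 64, 3
A = [[random.uniform(-1, 1) for _ in range(N)] for _ in range(M)]
Q = [[random.randrange(16) for _ in range(N)] for _ in range(K)]
scales = [[0.07] for _ in range(K)]                   # one group per row here
biases = [[-0.5] for _ in range(K)]
packed = [[sum(Q[k][p * 8 + j] << (4 * j) for j in range(8))
           for p in range(N // 8)] for k in range(K)]

C = quantized_matmul_ref(A, packed, scales, biases, M, N, K)
for i in range(M):
    for k in range(K):
        ref = sum(A[i][n] * (Q[k][n] * 0.07 - 0.5) for n in range(N))
        assert abs(C[i][k] - ref) < 1e-9
```

Note the accumulator is a Python float (double precision); the C++ and Metal versions accumulate in fp32 for the same reason, then cast to float16 on the final store.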
Task 1: Implement QuantizedWeights
src/tiny_llm/quantize.py
First, familiarize yourself with the QuantizedWeights class, which stores quantized weight information:
| Field | Shape | Description |
|---|---|---|
| weight | K × (N / 8), uint32 | Packed quantized weights. Each uint32 stores 8 consecutive 4-bit values. The original weight matrix has shape K × N; after packing, it becomes K × (N / 8). |
| scales | K × (N / group_size), float16 | Per-group scale factors for dequantization. Each group of group_size consecutive values shares one scale. Recall: scale = (max - min) / (2^bits - 1) |
| biases | K × (N / group_size), float16 | Per-group bias (offset) for dequantization. Recall: bias = min |
| group_size | int | Number of consecutive values that share the same scale/bias (typically 64) |
| bits | int | Quantization bit width (typically 4, meaning values are in the range [0, 15]) |
The from_mlx_layer static method extracts these fields from MLX's quantized linear layers when loading the model.
Next, implement the quantized_linear function, which is a wrapper around quantized_matmul that mimics the standard linear function interface. We'll implement quantized_matmul itself in the next task.
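In spirit, the wrapper is only a few lines. The sketch below is hypothetical: the class fields mirror the task description (not necessarily the starter code), and matmul_fn stands in for the quantized_matmul extension you will build in Tasks 2 and 3; the real code operates on mx.array and uses broadcasting for the bias.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional

# Hypothetical sketch of QuantizedWeights / quantized_linear; names mirror
# the task description, not necessarily the actual starter code.
@dataclass
class QuantizedWeights:
    weight: Any          # K × (N / 8) packed uint32
    scales: Any          # K × (N / group_size)
    biases: Any          # K × (N / group_size)
    group_size: int = 64
    bits: int = 4

def quantized_linear(x, w: QuantizedWeights, bias=None,
                     matmul_fn: Optional[Callable] = None):
    # Same interface as linear(x, w, bias): y = x @ dequant(w)^T (+ bias),
    # but dispatched to the quantized matmul kernel instead of a dense matmul.
    y = matmul_fn(x, w.weight, w.scales, w.biases)
    if bias is not None:
        y = [[yv + bv for yv, bv in zip(row, bias)] for row in y]
    return y
```

The only real work happens in matmul_fn; the wrapper just unpacks the QuantizedWeights fields and applies the optional bias.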
Task 2: Implement quantized_matmul (CPU version)
In this task, we will implement the quantized matmul as an MLX C++ extension. The pattern is identical to the existing axpby example in the codebase, so read through axpby.h, axpby.cpp, and the corresponding binding in bindings.cpp first as your reference.
src/extensions/src/tiny_llm_ext.h
src/extensions/bindings.cpp
src/extensions/src/quantized_matmul.cpp
src/extensions/CMakeLists.txt
You need to touch three files, all within the tiny_llm_ext namespace:
- tiny_llm_ext.h: Declare the quantized_matmul(...) function signature and define a QuantizedMatmul primitive class (inheriting mx::Primitive). Store group_size and bits as private members.
- bindings.cpp: Add an m.def(...) call to expose the function to Python.
- quantized_matmul.cpp: Implement the quantized_matmul(...) function (validate inputs, compute output shape, return a lazy mx::array) and the eval_cpu method (allocate output, register arrays with the CPU encoder, dispatch the compute kernel).
The eval_cpu implementation follows the same CPU encoder pattern as axpby: allocate output memory with out.set_data(mx::allocator::malloc(out.nbytes())), register input/output arrays with the encoder, then dispatch a lambda that performs the actual computation. Inside the lambda, implement the nested loop from the Computation Flow section above: iterate over each output element (i, k), accumulate in float (fp32) to avoid precision loss, and cast the result back to float16 when writing to the output.
Donβt forget to add src/quantized_matmul.cpp to target_sources in CMakeLists.txt.
You can test your implementation by running:
pdm run build-ext
pdm run test --week 2 --day 2 -- -k task_2
Task 3: Implement quantized_matmul (GPU version)
src/extensions/src/quantized_matmul.metal
src/extensions/src/quantized_matmul.cpp
In this task, you will write the Metal kernel for quantized matmul and wire up the eval_gpu method to dispatch it. Keep the math exactly the same as Task 2 (CPU); only the execution model changes.
Metal Kernel
You need to implement one kernel entry in quantized_matmul.metal:
- Use a one-thread-per-output-element mapping: each thread computes out[i, k].
- The kernel should be templated on the data type (to support both half and bfloat16_t).
- Apply the same group-wise dequantization loop as the CPU version:
  - Iterate over groups (group_size = 64)
  - Unpack int4 values from the packed uint32
  - Dequantize with q * scale + bias
  - Accumulate in float and cast to the output dtype at the end
- Add boundary checks (i < M, k < K) before writing the output.
GPU Dispatch
Complete the eval_gpu method in quantized_matmul.cpp to dispatch your Metal kernel. Follow the same pattern as axpby's GPU dispatch:
- Get the Metal device and command encoder from the stream.
- Select the correct kernel name based on the activation dtype (float16 → half, bfloat16 → bfloat16_t).
- Set input/output buffers and dimension constants (M, N, K) on the encoder; make sure the buffer order matches your kernel signature.
- Calculate a 2D thread group configuration: use kernel->maxTotalThreadsPerThreadgroup() to determine the total threads, then split between the M and K dimensions (e.g., 32 threads for M, the rest for K).
- Dispatch with dispatchThreadgroups.
You can test your implementation by running:
pdm run build-ext
pdm run test --week 2 --day 2 -- -k task_3
Task 4: Model Integration
src/tiny_llm/qwen2_week2.py
Integrate your quantized matmul into the Week 2 Qwen2 model so that inference runs on quantized weights end-to-end.
Change the weight type from mx.array to QuantizedWeights for all linear layers in attention (wq/wk/wv/wo) and MLP (w_gate/w_up/w_down). Replace every linear(x, w) call with quantized_linear(x, w). In the model loading code, use QuantizedWeights.from_mlx_layer(...) to extract quantized weight information from each MLX linear layer, instead of calling mx.dequantize to get a full float16 matrix. Make sure the Week 1 loader still dequantizes (since Week 1 layers expect plain mx.array), while the Week 2 loader does not dequantize.
Note that MLX loads quantized models with scales and biases stored in bfloat16 by default, while the activation tensors are typically float16. Since we have not implemented bfloat16 support in our kernel, you will need to convert the scales and biases to float16 with mx.astype before calling the kernel. If you see nan or garbage output, a dtype mismatch is the most likely cause.
You can test your implementation by running:
pdm run main --solution tiny_llm --loader week2 --model qwen2-0.5b
You can also benchmark throughput and compare your implementation with the reference solution:
pdm run bench --solution tiny_llm --loader week2 --model qwen2-0.5b
pdm run bench --solution tiny_llm_ref --loader week2 --model qwen2-0.5b
Your feedback is greatly appreciated. You are welcome to join our Discord Community.
Found an issue? Create an issue / pull request on github.com/skyzh/tiny-llm.
tiny-llm-book Β© 2025 by Alex Chi Z is licensed under CC BY-NC-SA 4.0.