TurboQuant Slice 4 — GPU Dequantization Scoping

Slices 1 through 1f got us from “TurboQuant doesn’t compile” to “ships 21× better quality at 4× compression as an opt-in flag.” The missing piece for --turboquant to be the default is performance. Right now our dequant lives on the CPU, and that’s fine for correctness validation but unshippable for production.
This page scopes Slice 4.
The current bottleneck
Every call to `CompressedKVCache::update_and_fetch`, inside the decode loop:

- step t: quantize new K,V → push to `Vec<CompressedHeadKV>`
- step t: `rebuild_f32` → dequantize ALL T stored tokens, build flat `Vec<f32>`, ship to Metal
- step t: `Array::from_slice` → fp32→bf16 cast on GPU

Three problems, in priority order:
- O(T²) CPU work across a generation. Each step rebuilds every previously-stored token. At T=1000 that’s 500,000 per-vector dequants on the CPU.
- CPU↔GPU bounce every step. The dequant output is an f32 `Vec` in host memory, then `Array::from_slice` ships it back to the GPU. Full round-trip per decode step.
- No kernel fusion. Even with GPU dequant, the attention path is: dequant → cast → scaled-dot-product-attention. Two separate Metal dispatches. Ideal: one fused kernel that does dequant + Q·Kᵀ + softmax + ·V in one shot.
On an M4 Max at realistic context lengths (1k–4k tokens), the CPU dequant cost is the dominant factor. That’s why --turboquant ships with a bold warning about “dev-grade CPU dequant, expect throughput regression” in the server startup log.
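For concreteness, here is a scalar sketch of the per-step rebuild that produces this behaviour; the struct and helper below are stand-ins, not the real cache.rs code:

```rust
// Illustrative only: the shape of today's CPU-side rebuild, not the real
// cache.rs code. The type and helper below are stand-ins.
struct CompressedHeadKV; // stand-in for the real per-token quantized record

fn dequantize_head(_token: &CompressedHeadKV, d: usize) -> Vec<f32> {
    vec![0.0; d] // stand-in for the real per-vector affine dequant
}

fn rebuild_f32(stored: &[CompressedHeadKV], d: usize) -> Vec<f32> {
    // Every decode step walks ALL T stored tokens again, so across a generation
    // the total is 1 + 2 + ... + T ≈ T²/2 per-vector dequants on the CPU.
    let mut flat = Vec::with_capacity(stored.len() * d);
    for token in stored {
        flat.extend(dequantize_head(token, d)); // CPU float loop per token
    }
    flat // Array::from_slice then ships this back to the GPU every step
}
```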
The three tiers
Slice 4a: MLX-ops dequant

The practical production unlock. Replace every CPU-side dequant step with MLX operations, which run as Metal kernels automatically.
Storage refactor: CompressedKVCache stores mlx_rs::Array instead of Vec<u8> / Vec<half::f16>. Quantize produces Arrays directly; concatenate across tokens with mx.concatenate; dequantize with mx.astype + mx.multiply + mx.add (all Metal-backed).
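A minimal sketch of what that Array-resident storage could look like after 4a; the type and field names here are assumptions, not the current code — the point is that nothing below is a host-side Vec:

```rust
// Illustrative only: Array-resident storage for the 4a refactor.
use mlx_rs::Array;

struct CompressedKVCacheGpu {
    k_codes: Array,  // quantized key codes, concatenated along the token axis
    k_scales: Array, // per-vector affine scales for the key path
    k_zeros: Array,  // per-vector affine zero-points for the key path
    v_codes: Array,  // quantized value codes (per-group affine)
    v_scales: Array, // per-group scales for the value path
    v_zeros: Array,  // per-group zero-points for the value path
    seq_len: usize,  // tokens currently stored (T)
}
```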
Key path (rotation + per-vector affine):
```
on write:
  normalize_and_rotate :: Array → Array            # matmul against H_d, O(d²) on GPU
  find scale/zero      :: Array → (Array, Array)   # mx.min, mx.max
  quantize             :: Array → Array            # (x - zero) / scale, round, clamp

on read:
  dequantize + unrotate :: Array → Array           # reverse of the above
```

No per-token CPU loops, no Vec allocation, no host round-trip.
Value path: mlx_quantize and mlx_dequantize already exist as C FFI → Metal kernels. They do exactly our per-group affine scheme. Drop-in replacement for quantize_value_group / dequantize_value_group.
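For reference, the per-group affine math the GPU port has to reproduce, written as scalar Rust. This is a sketch of the scheme described above, not the actual quantize_value_group / dequantize_value_group code, and the 4-bit code range is an assumption for the example:

```rust
// Scalar reference of the per-group affine scheme the GPU port must match.
fn quantize_group(group: &[f32]) -> (Vec<u8>, f32, f32) {
    let min = group.iter().cloned().fold(f32::INFINITY, f32::min);
    let max = group.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let scale = ((max - min) / 15.0).max(f32::EPSILON); // 15 = 2^4 - 1 levels
    let zero = min;
    let codes = group
        .iter()
        .map(|&x| ((x - zero) / scale).round().clamp(0.0, 15.0) as u8)
        .collect();
    (codes, scale, zero)
}

fn dequantize_group(codes: &[u8], scale: f32, zero: f32) -> Vec<f32> {
    // The inverse affine map: x̂ = code * scale + zero.
    codes.iter().map(|&c| c as f32 * scale + zero).collect()
}
```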
Estimated effort: 200–400 LOC net (replacement of cache.rs internals + removal of CPU vec storage + plumbing of Array↔Array in the KeyValueCache::update_and_fetch contract).
Risk: Medium-low. The algebra is unchanged; we’re porting CPU float loops to MLX Array ops that do exactly the same math. Main risks are (a) Array layout quirks we already burned ourselves on once (the flatten stride bug from Slice 1), and (b) dtype conversions introducing numerical drift vs the validated CPU path. Mitigation: re-run the 3-arm A/B and bench/ANALYSIS.md sanity checks before merging.
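A cheap extra guard against that drift, sketched as a tolerance check between the validated CPU dequant and the new path after copying the GPU result back to host; the helper inputs and the 1e-3 tolerance are assumptions, not project constants:

```rust
// Sketch of a drift check between the CPU and GPU dequant paths.
fn assert_dequant_paths_agree(cpu: &[f32], gpu: &[f32]) {
    assert_eq!(cpu.len(), gpu.len(), "length mismatch between paths");
    let max_abs_diff = cpu
        .iter()
        .zip(gpu)
        .map(|(a, b)| (a - b).abs())
        .fold(0.0_f32, f32::max);
    // Both paths compute the same algebra, so anything beyond small dtype noise
    // is a porting bug, not quantization error.
    assert!(max_abs_diff < 1e-3, "CPU vs GPU dequant drift: {max_abs_diff}");
}
```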
Wins:
- O(T) per step → O(1) amortized per step with proper concatenation.
- Zero CPU↔GPU round-trips in the decode loop.
- Expected throughput at 2k context: 3–6× the current `--turboquant` path; within spitting distance of the fp16 reference.
Blocking concerns: None. MLX has all the primitives. This is a straightforward engineering port.
Slice 4b: mx.compile fusion

The mid-tier polish. Once 4a lands, wrap the dequant + SDPA stages in mx.compile so MLX fuses them into a single kernel dispatch.
MLX’s compile takes a function that calls MLX ops and returns an Array, then traces it into a graph and generates a single Metal kernel. Real fusion, no shader-authoring required. We’d wrap:
```rust
mx.compile(|cache_state, q| {
    let k = dequant_keys(cache_state);
    let v = dequant_values(cache_state);
    fast::scaled_dot_product_attention(q, k, v, scale, mask)
})
```

The compiled function becomes one kernel call: dequant + SDPA, no intermediate materialization of the full K/V tensor.
Estimated effort: 50–150 LOC on top of 4a. Mostly dealing with mx.compile’s shape requirements and making sure cache state flows in as compile-stable inputs.
Risk: Medium. mx.compile has footguns around shape variation (decode increments T every step), and the compiled graph may need re-tracing when T hits new sizes. Mitigation: profile first — if 4a is already “fast enough” (<2× fp16 overhead), 4b is polish, not necessary.
Wins:
- Single Metal dispatch per decode step for the attention block.
- Potentially 20–40% faster than 4a at small T.
Slice 4c: custom fused Metal kernel

The maximum-effort play. Write a Metal compute shader that takes quantized K/V buffers + rotation metadata + scales + zeros, and computes the full scaled-dot-product-attention inline without ever materializing the dequantized K or V tensors.
This is the version that kills both the O(T²) compute AND the memory bandwidth cost of transferring dequantized K/V through GPU memory. The kernel reads quantized bytes, dequantizes lane-local, multiplies by Q, softmaxes, and writes the attention output. Dequantized tensors never leave registers.
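In scalar Rust, roughly the computation one kernel lane performs for a single query row. This is the math the shader must implement, not the Metal code; the layout and the `dequant` closure are assumptions:

```rust
// Scalar sketch of the fused computation: dequantize K and V lane-local, dot
// with q, softmax, accumulate the output — no dequantized K/V ever materialized.
fn fused_attention_row(
    q: &[f32],                           // one query vector, length d
    k_codes: &[Vec<u8>],                 // T quantized key vectors
    v_codes: &[Vec<u8>],                 // T quantized value vectors
    dequant: impl Fn(&[u8]) -> Vec<f32>, // per-vector/group affine dequant
    scale: f32,
) -> Vec<f32> {
    // Pass 1: scores = scale * (q · Kᵀ); each key is dequantized, used, discarded.
    let scores: Vec<f32> = k_codes
        .iter()
        .map(|kc| {
            let k = dequant(kc.as_slice());
            scale * q.iter().zip(&k).map(|(a, b)| a * b).sum::<f32>()
        })
        .collect();
    // Softmax over the scores (stabilized by subtracting the running max).
    let m = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scores.iter().map(|s| (s - m).exp()).collect();
    let denom: f32 = exps.iter().sum();
    // Pass 2: output = softmax(scores) · V; values dequantized on the fly too.
    let mut out = vec![0.0_f32; q.len()];
    for (w, vc) in exps.iter().zip(v_codes) {
        let v = dequant(vc.as_slice());
        for (o, vi) in out.iter_mut().zip(&v) {
            *o += (w / denom) * vi;
        }
    }
    out
}
```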
Estimated effort: 800–2000 LOC including:
- Metal shader source in `.metal` files inside mlx-sys (or a new crate)
- Build-system integration to compile the `.metal` into a `.metallib`
- Host-side Rust plumbing to dispatch the kernel with the right buffer bindings
- Correctness tests cross-verified against the MLX-ops 4a path
- Benchmarks
Risk: High. Metal shader authoring is an entirely new skill layer; tooling for debugging shaders on Apple Silicon is sparse; subtle correctness bugs can hide for weeks. Mitigation: don’t start 4c until 4a is shipped and we have a concrete profiling-derived reason the MLX-ops path isn’t enough.
Wins:
- Single Metal dispatch.
- Zero bandwidth spent on dequantized tensors.
- Plausibly 1.1–1.3× the fp16 reference speed (because the quantized K/V is smaller in DRAM than fp16 K/V, so memory-bound attention actually gets faster with good compression).
Blocking concerns:
- Apple hasn’t made Metal shader development a welcoming experience on M-series. Debugging tools are Xcode-bound and UI-heavy.
- MLX’s Metal kernels live in its internal C++ layer; adopting that approach (shader source embedded in Objective-C++, linked into the framework) is a build-system rabbit hole.
Recommended ordering
- Ship 4a first. It gives us real-world production throughput and removes the “dev-grade” warning from the server log. Measurable quickly on a wikitext decode loop.
- Measure before 4b or 4c. Profile at 2k and 4k contexts. If 4a is within 2× of fp16, 4b and 4c become nice-to-haves, not necessities. If 4a is still 5× slower than fp16, 4b is the right next move.
- 4c only if 4b hits a wall. `mx.compile` is Apple’s own answer to kernel fusion; building a bespoke Metal kernel to beat their compiler is a long bet with high carrying cost.
What 4a buys us (and when to stop)
The production gate on --turboquant becoming auto-on isn’t “fastest possible” — it’s “within 2× of fp16 throughput at acceptable context lengths.” If 4a clears that bar, the flag graduates to default-on for memory-constrained contexts (shared-hardware Mac Mini) and stays opt-in for throughput-dominated contexts (M4 Max with 128 GB).
Contracts we DON’T want to break
- Quality is already paid for. Slices 1a–1f landed at Δppl ≤ 0.01 on Qwen3.5, Δppl ≤ 0.17 on Qwen2. Any perf optimization that shifts those numbers by more than measurement noise is a bug, not a win.
- The `KeyValueCache` trait is a load-bearing boundary. Slice 4 changes the internals of `CompressedKVCache`; it must not change `update_and_fetch`’s signature or the `KeyValueCache + Default` bound that Qwen3.5 and Qwen2 model impls depend on.
- Bypass modes stay. `BypassMode::{IdentityPassthrough, BridgeOnly}` saved us from a stride bug on day one and from a mislabeled A/B on day two. Whatever storage refactor 4a imposes, the bypass short-circuits must keep working.
Rough timeline
- 4a implementation: 1 focused day. Much of the surgery is in `cache.rs::update_and_fetch` — rewriting the storage model from Vec-based to Array-based and teaching the two halves of the pipeline (write-side quantize, read-side dequantize) to speak `mlx_rs::Array` the whole way through.
- 4a validation: 2–3 hours. Re-run `turboquant_ppl` across all validated configs, diff against current results. Any Δppl change > 0.01 means a correctness bug somewhere in the port.
- 4a profiling + doc: half a day. Before/after latency + memory numbers, update `ANALYSIS.md`, add a “GPU dequant” section to `turboquant-kv-compression.mdx`.
Total: ~2 days of focused work for the change that unlocks --turboquant default.
Where the work lives
- `services/sanctum-mlx/src/turboquant/cache.rs` — the main target
- `services/sanctum-mlx/src/turboquant/quantizer.rs` — migrate `quantize_key_affine` / `quantize_value_group` to Array-based signatures
- `vendor/mlx-rs/mlx-rs/src/ops/` — may need to expose mlx_quantize / mlx_dequantize in the Rust wrapper (C FFI exists; Rust safe wrapper may not)
- New: `services/sanctum-mlx/bench/slice4_perf.jsonl` — throughput measurements per config
No sanctum-rs commits needed for mlx-rs FFI changes until we verify whether mx.quantize is already wrapped. Quick check: `grep -r 'pub fn quantize' vendor/mlx-rs/mlx-rs/src/ops/` — came up empty on the first look, so we’d need to add a thin safe wrapper around mlx_quantize/mlx_dequantize first.
References
- MLX fast.scaled_dot_product_attention — the optimized kernel we’re feeding
- MLX quantize — the GPU primitive Slice 4a leans on
- mx.compile — Apple’s kernel fusion path for 4b
- Slice 1a–1f writeup: `services/sanctum-mlx/bench/ANALYSIS.md`