Gemma 3 on JAX with FlashInfer and the JAX TVM FFI Bridge
Overview
Part 1 of this series, Enabling FlashInfer GPU Kernels on JAX with the JAX TVM FFI Bridge, built three FlashInfer kernels from scratch and wired them into JAX as XLA custom calls. This tutorial connects those same primitives to a real language model: Gemma 3 1B Instruct, Google’s open-weight instruction-tuned LLM.
Every Gemma 3 transformer layer uses the following kernels:
| Part 1 kernel | Role in Gemma 3 |
|---|---|
| `silu_and_mul` | Gated FFN activation (GeGLU variant of SiLU-GLU) |
| `apply_rope` | Query and key positional embeddings - with two different theta values |
| `decode_attention` | Attention over the growing KV-cache - two compiled variants |
Three things are new compared to Part 1:
- GeGLU instead of SiLU-GLU - Gemma 3 uses `gelu_tanh` for its gated FFN; FlashInfer ships this as a one-word change from `silu`.
- QK-norm - per-head RMSNorm applied to Q and K before computing dot products, replacing the logit soft-capping that Gemma 2 used.
- Dual RoPE theta - local-attention layers use theta = 10 000; global-attention layers use theta = 1 000 000. We select the right value per layer and pass it to `apply_rope`.
Preliminaries
Hardware and software requirements
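| Requirement | Details |
|---|---|
| GPU | NVIDIA, SM 7.5+ (Turing or later) |
| Python packages | JAX, FlashInfer, `jax_tvm_ffi`, `torch`, `safetensors`, `huggingface_hub`, `transformers` |
| HuggingFace | Account with Gemma 3 licence accepted - request access |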
Setting the environment
If you haven’t gone through Enabling FlashInfer GPU Kernels on JAX with the JAX TVM FFI Bridge, refer to it for the JAX and FlashInfer installation instructions.
Four additional packages are required:
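| Package | Role |
|---|---|
| `torch` | Provides the dtype objects (e.g. `torch.bfloat16`) passed to FlashInfer's JIT module generators |
| `safetensors` | Efficient loading of model weights from the HuggingFace format |
| `huggingface_hub` | Model download from the HuggingFace Hub |
| `transformers` | Tokenizer and chat-template formatting |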
Install the tutorial dependencies once per environment, before running the notebook or script:
pip install torch --index-url https://download.pytorch.org/whl/cpu
pip install safetensors huggingface_hub transformers
Loading dependencies
Run the cell below to load the dependencies.
import json
import math
import os
import time
import subprocess
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # suppress TF/XLA info & warnings
import importlib.util
IN_COLAB = importlib.util.find_spec("google.colab") is not None
if "CUDA_HOME" not in os.environ:
    try:
        nvcc = subprocess.check_output(["which", "nvcc"], text=True).strip()
        os.environ["CUDA_HOME"] = str(os.path.dirname(os.path.dirname(nvcc)))
    except subprocess.CalledProcessError:
        os.environ["CUDA_HOME"] = "/usr/local/cuda"
if "--xla_gpu_cuda_data_dir=" not in os.environ.get("XLA_FLAGS", ""):
    os.environ["XLA_FLAGS"] = (
        f"{os.environ.get('XLA_FLAGS', '')} "
        f"--xla_gpu_cuda_data_dir={os.environ['CUDA_HOME']}"
    ).strip()
import jax
import jax.numpy as jnp
from transformers import AutoTokenizer
from huggingface_hub import snapshot_download, HfApi
from safetensors import safe_open
import jax_tvm_ffi
print(f"JAX: {jax.__version__}")
print(f"Devices: {jax.devices()}")
print(f"CUDA home: {os.environ['CUDA_HOME']}")
HuggingFace Authentication
Gemma 3 is a gated model. Before downloading the weights, you need to accept the licence on the HuggingFace model page - visit google/gemma-3-1b-it and click Request access.
The cell below resolves your token in this order: on Colab it reads HF_TOKEN from the Colab Secrets API; elsewhere it reads the HF_TOKEN environment variable; if neither yields a token, it prompts interactively.
# -- HuggingFace Authentication ------------------------------------------------
# Accept the Gemma 3 license at: https://huggingface.co/google/gemma-3-1b
if IN_COLAB:
    from google.colab import userdata
    HF_TOKEN = userdata.get("HF_TOKEN")
else:
    HF_TOKEN = os.environ.get("HF_TOKEN", "")
if not HF_TOKEN:
    from getpass import getpass
    HF_TOKEN = getpass(
        "Hugging Face token not found in environment. Please enter it here: "
    )
if not HF_TOKEN:
    raise RuntimeError("Authentication failed: Hugging Face token is not set.")
# Ensure token is set in this process
os.environ["HF_TOKEN"] = HF_TOKEN
# Verify identity
api = HfApi()
user_info = api.whoami(token=HF_TOKEN)
username = user_info.get("name") or "Unknown user"
print(f"Authenticated with Hugging Face successfully as: {username}")
Downloading the model weights
The cell below downloads all model shards (~2 GB on first run) from the HuggingFace Hub, loads them as bfloat16 JAX arrays, and instantiates the tokenizer. Weights are cached locally; subsequent runs skip the download.
MODEL_ID = "google/gemma-3-1b-it"
HF_CACHE = os.environ.get("HF_HOME", os.path.expanduser("~/.cache/huggingface"))
# -- Tokenizer ------------------------------------------------------------------
print(f"Loading tokenizer from {MODEL_ID}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN, cache_dir=HF_CACHE)
# -- Weights --------------------------------------------------------------------
print("Downloading model weights (~2 GB on first run)...")
model_dir = snapshot_download(MODEL_ID, token=HF_TOKEN, cache_dir=f"{HF_CACHE}/hub")
# Weights are split across shards - discover the full list from the index file
index_path = os.path.join(model_dir, "model.safetensors.index.json")
if os.path.exists(index_path):
    with open(index_path) as f:
        shard_files = sorted(set(json.load(f)["weight_map"].values()))
else:
    shard_files = ["model.safetensors"]
print(f"Loading {len(shard_files)} shard(s) as JAX bfloat16 arrays...")
weights = {}
for shard in shard_files:
    with safe_open(os.path.join(model_dir, shard), framework="numpy") as f:
        for key in f.keys():
            # jnp.array handles any numpy dtype (float32, bfloat16, ...) -> bfloat16
            weights[key] = jnp.array(f.get_tensor(key), dtype=jnp.bfloat16)
n_params = sum(int(w.size) for w in weights.values())
print(f"Loaded {len(weights)} tensors ({n_params / 1e9:.2f} B parameters)")
print(f" embed_tokens: {weights['model.embed_tokens.weight'].shape}")
print(
f" layer 0 q_proj: {weights['model.layers.0.self_attn.q_proj.weight'].shape}"
)
print(
f" layer 0 q_norm: {weights['model.layers.0.self_attn.q_norm.weight'].shape}"
)
print(f" layer 0 gate_proj: {weights['model.layers.0.mlp.gate_proj.weight'].shape}")
Gemma 3 Transformer Layer
Each Gemma 3 1B layer has a sandwich-norm structure: RMSNorm before and after each sub-layer.
-- Prefill (prompt, T tokens in parallel) ------------------------------------
x --+-- RMSNorm (input_layernorm) ------------------------------------------+
| Q, K, V <- linear projections |
| Q, K <- QK-norm (per-head RMSNorm, new in Gemma 3) |
| Q, K <- `apply_rope` (theta = local / global, per layer) |
| out <- `prefill_attention` (causal, local or global) |
| RMSNorm (post_attention_layernorm) |
+-- + -------------------------------------------------------------------+
-- Decode (one new token at a time) ------------------------------------------
x --+-- RMSNorm (input_layernorm) ------------------------------------------+
| Q, K, V <- linear projections |
| Q, K <- QK-norm (per-head RMSNorm, new in Gemma 3) |
| Q, K <- `apply_rope` (theta = local / global, per layer) |
| K, V -> KV-cache append |
| out <- `decode_attention` (local sliding-window or global) |
| RMSNorm (post_attention_layernorm) |
+-- + -------------------------------------------------------------------+
-- Shared FFN (same code for prefill and decode) -----------------------------
x --+-- RMSNorm (pre_feedforward_layernorm) ----------------------------------+
| gate, up <- separate linear projections |
| hidden <- `gelu_tanh_and_mul`( concat(gate, up) ) |
| out <- down_proj(hidden) |
| RMSNorm (post_feedforward_layernorm) |
+-- + --------------------------------------------------------------------+
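The sandwich-norm wiring shown in the diagrams can be condensed into a single expression. The sketch below is illustrative only: `sandwich_block`, `pre_norm`, `sublayer`, and `post_norm` are hypothetical names, and plain Python floats stand in for tensors so the data flow can be checked by hand.

```python
def sandwich_block(h, pre_norm, sublayer, post_norm):
    """Residual update with a norm on both sides of the sub-layer.

    Mirrors the diagrams: h + post_norm(sublayer(pre_norm(h))).
    All three callables are hypothetical stand-ins.
    """
    return h + post_norm(sublayer(pre_norm(h)))


# Toy stand-ins so the result is easy to verify by hand.
halve = lambda x: x / 2.0      # plays the role of the pre-norm
double = lambda x: x * 2.0     # plays the role of attention or the FFN
plus_one = lambda x: x + 1.0   # plays the role of the post-norm

result = sandwich_block(3.0, halve, double, plus_one)
print(result)
```

Note that the residual branch carries the un-normalised `h`; only the sub-layer's input and output pass through a norm.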
Local vs global attention - the 5:1 pattern
Gemma 3 alternates attention span in a repeating 5:1 pattern:
| Layer type | Frequency | Attends to | RoPE theta |
|---|---|---|---|
| Local | 5 of every 6 | Last `sliding_window` tokens | 10 000 |
| Global | 1 of every 6 | Full KV-cache | 1 000 000 |
For Gemma 3 1B (26 layers): layers 5, 11, 17, 23 are global; the remaining 22 are local. Exact values for window size and theta are read from config.json in the next cell.
with open(os.path.join(model_dir, "config.json")) as _f:
    _raw = json.load(_f)
# Gemma 3 wraps the language model config under "text_config" in its multimodal JSON
cfg = _raw.get("text_config", _raw)
HIDDEN = cfg["hidden_size"]
INTERMEDIATE = cfg["intermediate_size"]
N_LAYERS = cfg["num_hidden_layers"]
N_Q = cfg["num_attention_heads"]
N_KV = cfg["num_key_value_heads"]
HEAD_DIM = cfg.get("head_dim", HIDDEN // N_Q)
VOCAB = cfg["vocab_size"]
RMS_EPS = cfg.get("rms_norm_eps", 1e-6)
SLIDING_WINDOW = cfg.get("sliding_window", 1024)
SM_SCALE = 1.0 / math.sqrt(HEAD_DIM)
# Dual RoPE theta: local layers use a small base, global layers use a large base.
# Gemma 3 stores these as rope_local_base_freq (local) and rope_theta (global).
ROPE_THETA_LOCAL = int(cfg.get("rope_local_base_freq", 10_000))
ROPE_THETA_GLOBAL = int(cfg.get("rope_theta", 1_000_000))
def is_global(layer_idx: int) -> bool:
    """True for global (full-attention) layers (Gemma 3 1B: 5, 11, 17, 23)."""
    return (layer_idx + 1) % 6 == 0
print("Architecture loaded from config.json:")
print(f" hidden={HIDDEN}, intermediate={INTERMEDIATE}, layers={N_LAYERS}")
print(f" N_Q={N_Q}, N_KV={N_KV}, head_dim={HEAD_DIM} (GQA ratio {N_Q // N_KV}x)")
print(f" vocab={VOCAB}, rms_eps={RMS_EPS}")
print(f" sliding_window={SLIDING_WINDOW}")
print(
f" rope_theta_local={ROPE_THETA_LOCAL:,}, rope_theta_global={ROPE_THETA_GLOBAL:,}"
)
print()
print(f"{'Layer':>5} {'Type':>8} {'RoPE theta':>12} {'Window':>8}")
print("-" * 42)
for i in range(N_LAYERS):
    kind = "global" if is_global(i) else "local"
    theta = ROPE_THETA_GLOBAL if is_global(i) else ROPE_THETA_LOCAL
    window = "full" if is_global(i) else f"{SLIDING_WINDOW:,}"
    print(f"{i:>5} {kind:>8} {theta:>12,} {window:>8}")
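To make the local/global distinction concrete, the sketch below lists which key positions a causal query can see in each layer type. `visible_keys` is a hypothetical helper, and it assumes the window counts the query's own position; the exact boundary convention of a particular kernel may differ by one.

```python
def visible_keys(query_pos, window, is_global_layer):
    """Return the key positions a causal query at `query_pos` can attend to.

    Assumption: a local layer's window includes the query's own position.
    """
    if is_global_layer:
        start = 0
    else:
        start = max(0, query_pos - window + 1)
    return list(range(start, query_pos + 1))


# Toy window of 4 tokens (the real value comes from config.json).
local_keys = visible_keys(query_pos=9, window=4, is_global_layer=False)
global_keys = visible_keys(query_pos=9, window=4, is_global_layer=True)
print(local_keys)   # keys seen by a local layer
print(global_keys)  # keys seen by a global layer
```

Early in a sequence the two layer types see the same keys; they only diverge once the position exceeds the window.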
Concept 1: GeGLU - gelu_tanh replaces silu
Part 1 used FlashInfer’s silu_and_mul kernel. Gemma 3 swaps the gate activation:
SiLU-GLU (Llama): out = silu(gate) * up
GeGLU (Gemma 3): out = gelu_tanh(gate) * up
where gelu_tanh is the tanh-approximated GELU, matching torch.nn.functional.gelu(x, approximate="tanh").
FlashInfer ships all three variants - silu, gelu, gelu_tanh - through the same gen_act_and_mul_module interface. Switching from Part 1 is a one-word change:
# Part 1
silu_module = gen_act_and_mul_module('silu').build_and_load()
# Part 2 - Gemma 3
gelu_module = gen_act_and_mul_module('gelu_tanh').build_and_load()
Everything else - the three-step bridge pattern, the wrapper, the ffi_call shape declaration - is identical.
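For checking the fused kernel's output on small inputs, a plain-Python reference can help. The sketch below is not FlashInfer's implementation; `gelu_tanh_and_mul_ref` is a hypothetical name, and it assumes the `[gate..., up...]` layout along the last axis that this tutorial's `ffn` function produces by concatenation.

```python
import math


def gelu_tanh(x):
    """Tanh-approximated GELU for a single float."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)
    return 0.5 * x * (1.0 + math.tanh(inner))


def gelu_tanh_and_mul_ref(gate_up):
    """Reference for the fused op on one row.

    gate_up: list of 2*H floats laid out as [gate..., up...].
    Returns H floats: gelu_tanh(gate[i]) * up[i].
    """
    if len(gate_up) % 2 != 0:
        raise ValueError("last dimension must be even")
    half = len(gate_up) // 2
    gate, up = gate_up[:half], gate_up[half:]
    return [gelu_tanh(g) * u for g, u in zip(gate, up)]


row = [0.0, 1.0, 5.0, 2.0]   # gate = [0.0, 1.0], up = [5.0, 2.0]
out = gelu_tanh_and_mul_ref(row)
print(out)
```

The output has half as many elements as the input, which is the shape declared to `ffi_call` in the wrapper.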
Compile and register all kernels
All four kernels - gelu_tanh_and_mul, apply_rope, and decode and prefill attention, the last two each in a local and a global variant - are compiled and registered in a single cell below. Concepts 2 (QK-norm) and 3 (dual RoPE theta) are explained in the sections that follow; they require no additional kernels beyond what Part 1 introduced.
import torch as _torch # used only for dtype spec in gen_single_*_module
from flashinfer.jit import (
gen_act_and_mul_module,
gen_single_decode_module,
gen_single_prefill_module,
)
from flashinfer.jit.rope import gen_rope_module
# -- 1. gelu_tanh_and_mul ------------------------------------------------------
print("Compiling gelu_tanh_and_mul...")
_gelu_mod = gen_act_and_mul_module("gelu_tanh").build_and_load()
def _gelu_wrapper(out, x, enable_pdl):
    _gelu_mod.gelu_tanh_and_mul(out, x, enable_pdl)

jax_tvm_ffi.register_ffi_target(
    "flashinfer.gelu_tanh_and_mul",
    _gelu_wrapper,
    arg_spec=["rets", "args", "attrs.enable_pdl"],
    platform="gpu",
    allow_cuda_graph=True,
    pass_owned_tensor=True,
)

def gelu_and_mul(x: jax.Array) -> jax.Array:
    """Fused gelu_tanh(gate) * up. Input: [..., 2H] Output: [..., H]"""
    out_shape = x.shape[:-1] + (x.shape[-1] // 2,)
    return jax.ffi.ffi_call(
        "flashinfer.gelu_tanh_and_mul",
        jax.ShapeDtypeStruct(out_shape, x.dtype),
        vmap_method="broadcast_all",  # element-wise op: independent across any batch dim
    )(x, enable_pdl=False)

# -- 2. apply_rope -------------------------------------------------------------
print("Compiling apply_rope...")
_rope_mod = gen_rope_module().build_and_load()

def _rope_wrapper(
    q_rope,
    k_rope,
    q,
    k,
    indptr,
    offsets,
    rotary_dim,
    interleave,
    rope_scale,
    rope_theta,
):
    # The bridge passes outputs ("rets") first; the kernel expects inputs first.
    _rope_mod.apply_rope(
        q,
        k,
        q_rope,
        k_rope,
        indptr,
        offsets,
        rotary_dim,
        interleave,
        rope_scale,
        rope_theta,
    )

jax_tvm_ffi.register_ffi_target(
    "flashinfer.apply_rope",
    _rope_wrapper,
    arg_spec=[
        "rets",
        "args",
        "attrs.rotary_dim",
        "attrs.interleave",
        "attrs.rope_scale",
        "attrs.rope_theta",
    ],
    platform="gpu",
    allow_cuda_graph=True,
    pass_owned_tensor=True,
)

def apply_rope(q, k, indptr, offsets, rope_theta=1e4):
    """Apply RoPE to packed batches. Returns (q_rope, k_rope)."""
    return jax.ffi.ffi_call(
        "flashinfer.apply_rope",
        (
            jax.ShapeDtypeStruct(q.shape, q.dtype),
            jax.ShapeDtypeStruct(k.shape, k.dtype),
        ),
        vmap_method="broadcast_all",  # each packed batch is independent; safe to vmap
    )(
        q,
        k,
        indptr,
        offsets,
        rotary_dim=q.shape[-1],
        interleave=False,
        rope_scale=1.0,
        rope_theta=float(rope_theta),
    )
# -- 3. decode_attention: local + global variants -------------------------------
_TMP_ELEMS = 32 * 1024 * 1024 // 2 # 32 MB scratch buffer in bfloat16 elements
print(f"Compiling decode attention (local, sliding-window={SLIDING_WINDOW})...")
_local_dec_mod = gen_single_decode_module(
    _torch.bfloat16,  # dtype_q
    _torch.bfloat16,  # dtype_kv
    _torch.bfloat16,  # dtype_o
    HEAD_DIM,  # head_dim_qk
    HEAD_DIM,  # head_dim_vo
    0,  # pos_encoding_mode: none (RoPE is applied separately by apply_rope)
    True,  # use_sliding_window
    False,  # use_logits_soft_cap
).build_and_load()
print("Compiling decode attention (global, full attention)...")
_global_dec_mod = gen_single_decode_module(
    _torch.bfloat16,  # dtype_q
    _torch.bfloat16,  # dtype_kv
    _torch.bfloat16,  # dtype_o
    HEAD_DIM,  # head_dim_qk
    HEAD_DIM,  # head_dim_vo
    0,  # pos_encoding_mode: none
    False,  # use_sliding_window
    False,  # use_logits_soft_cap
).build_and_load()

def _make_decode_wrapper(run_fn):
    def _w(
        out,
        tmp,
        lse_or_empty,
        q,
        k,
        v,
        alibi_or_empty,
        layout,
        window_left,
        logits_soft_cap,
        sm_scale,
        rope_rcp_scale,
        rope_rcp_theta,
    ):
        # Zero-length arrays act as "not provided" sentinels for optional tensors.
        lse = None if lse_or_empty.shape[0] == 0 else lse_or_empty
        alibi = None if alibi_or_empty.shape[0] == 0 else alibi_or_empty
        run_fn(
            q,
            k,
            v,
            tmp,
            out,
            lse,
            layout,
            window_left,
            alibi,
            logits_soft_cap,
            sm_scale,
            rope_rcp_scale,
            rope_rcp_theta,
        )
    return _w

_DEC_ARG_SPEC = [
    "rets",
    "args",
    "attrs.layout",
    "attrs.window_left",
    "attrs.logits_soft_cap",
    "attrs.sm_scale",
    "attrs.rope_rcp_scale",
    "attrs.rope_rcp_theta",
]
_KW = dict(platform="gpu", allow_cuda_graph=True, pass_owned_tensor=True)
jax_tvm_ffi.register_ffi_target(
    "flashinfer.decode_local",
    _make_decode_wrapper(_local_dec_mod.run),
    _DEC_ARG_SPEC,
    **_KW,
)
jax_tvm_ffi.register_ffi_target(
    "flashinfer.decode_global",
    _make_decode_wrapper(_global_dec_mod.run),
    _DEC_ARG_SPEC,
    **_KW,
)

def decode_attention(q, k_cache, v_cache, global_layer=False):
    """Single-request GQA decode attention.
    q: [N_Q, HEAD_DIM] bfloat16
    k_cache: [seq_len, N_KV, HEAD_DIM] bfloat16
    v_cache: [seq_len, N_KV, HEAD_DIM] bfloat16
    Returns: [N_Q, HEAD_DIM]
    """
    target = "flashinfer.decode_global" if global_layer else "flashinfer.decode_local"
    window = -1 if global_layer else SLIDING_WINDOW
    out, _, _ = jax.ffi.ffi_call(
        target,
        (
            jax.ShapeDtypeStruct(q.shape, jnp.bfloat16),  # output
            jax.ShapeDtypeStruct((_TMP_ELEMS,), jnp.bfloat16),  # tmp scratch
            jax.ShapeDtypeStruct((0,), jnp.float32),  # lse (discarded)
        ),
        # vmap_method intentionally omitted: the scratch buffer has no batch
        # dimension, and GQA head-grouping does not decompose over an outer batch axis.
    )(
        q,
        k_cache,
        v_cache,
        jnp.empty((0,), dtype=jnp.float32),  # alibi = empty sentinel
        layout=0,
        window_left=window,
        logits_soft_cap=0.0,
        sm_scale=SM_SCALE,
        rope_rcp_scale=1.0,
        rope_rcp_theta=1.0,
    )
    return out
# -- 4. prefill_attention: local + global variants ------------------------------
print(f"Compiling prefill attention (local, sliding-window={SLIDING_WINDOW})...")
_local_pre_mod = gen_single_prefill_module(
    "fa2",  # backend
    _torch.bfloat16,  # dtype_q
    _torch.bfloat16,  # dtype_kv
    _torch.bfloat16,  # dtype_o
    HEAD_DIM,  # head_dim_qk
    HEAD_DIM,  # head_dim_vo
    0,  # pos_encoding_mode: none (RoPE is applied separately by apply_rope)
    True,  # use_sliding_window
    False,  # use_logits_soft_cap
    False,  # use_fp16_qk_reduction
).build_and_load()
print("Compiling prefill attention (global, full attention)...")
_global_pre_mod = gen_single_prefill_module(
    "fa2",  # backend
    _torch.bfloat16,  # dtype_q
    _torch.bfloat16,  # dtype_kv
    _torch.bfloat16,  # dtype_o
    HEAD_DIM,  # head_dim_qk
    HEAD_DIM,  # head_dim_vo
    0,  # pos_encoding_mode: none
    False,  # use_sliding_window
    False,  # use_logits_soft_cap
    False,  # use_fp16_qk_reduction
).build_and_load()

def _make_prefill_wrapper(run_fn):
    def _w(
        out,
        tmp,
        lse_or_empty,
        q,
        k,
        v,
        alibi_or_empty,
        mask_mode_code,
        layout,
        window_left,
        logits_soft_cap,
        sm_scale,
        rope_rcp_scale,
        rope_rcp_theta,
    ):
        # Zero-length arrays act as "not provided" sentinels for optional tensors.
        lse = None if lse_or_empty.shape[0] == 0 else lse_or_empty
        alibi = None if alibi_or_empty.shape[0] == 0 else alibi_or_empty
        run_fn(
            q,
            k,
            v,
            tmp,
            out,
            lse,
            mask_mode_code,
            layout,
            window_left,
            None,
            alibi,
            logits_soft_cap,
            sm_scale,
            rope_rcp_scale,
            rope_rcp_theta,
        )
    return _w

_PRE_ARG_SPEC = [
    "rets",
    "args",
    "attrs.mask_mode_code",
    "attrs.layout",
    "attrs.window_left",
    "attrs.logits_soft_cap",
    "attrs.sm_scale",
    "attrs.rope_rcp_scale",
    "attrs.rope_rcp_theta",
]
jax_tvm_ffi.register_ffi_target(
    "flashinfer.prefill_local",
    _make_prefill_wrapper(_local_pre_mod.run),
    _PRE_ARG_SPEC,
    **_KW,
)
jax_tvm_ffi.register_ffi_target(
    "flashinfer.prefill_global",
    _make_prefill_wrapper(_global_pre_mod.run),
    _PRE_ARG_SPEC,
    **_KW,
)

def prefill_attention(q, k, v, layer_i):
    """FlashInfer causal GQA attention for processing a multi-token prompt.
    Uses the same kernel-bridge pattern as decode_attention: mask_mode_code=1
    selects causal masking; window_left controls the sliding-window cut-off.
    q, k, v: [T, n_heads, HEAD_DIM] bfloat16
    Returns: [T, N_Q, HEAD_DIM] bfloat16
    """
    glob = is_global(layer_i)
    target = "flashinfer.prefill_global" if glob else "flashinfer.prefill_local"
    window = -1 if glob else SLIDING_WINDOW
    out, _, _ = jax.ffi.ffi_call(
        target,
        (
            jax.ShapeDtypeStruct(q.shape, jnp.bfloat16),  # output
            jax.ShapeDtypeStruct((_TMP_ELEMS,), jnp.bfloat16),  # tmp scratch
            jax.ShapeDtypeStruct((0,), jnp.float32),  # lse (discarded)
        ),
        # vmap_method intentionally omitted: scratch buffer + GQA head-grouping
        # do not decompose over an outer batch axis.
    )(
        q,
        k,
        v,
        jnp.empty((0,), dtype=jnp.float32),  # alibi = empty sentinel
        mask_mode_code=1,
        layout=0,
        window_left=window,
        logits_soft_cap=0.0,
        sm_scale=SM_SCALE,
        rope_rcp_scale=1.0,
        rope_rcp_theta=1.0,
    )
    return out
print("All kernels compiled and registered.")
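When validating the registered decode kernel on tiny inputs, an unoptimized reference is useful to compare against. The sketch below is a plain-Python stand-in, not FlashInfer's algorithm: `gqa_decode_ref` is a hypothetical name, it covers only the full-attention (global) case, and it assumes consecutive query heads share a KV head.

```python
import math


def gqa_decode_ref(q, k_cache, v_cache, sm_scale):
    """Unoptimized single-token GQA attention over the whole cache.

    q:       [n_q][d]             nested lists of floats
    k_cache: [seq_len][n_kv][d]
    v_cache: [seq_len][n_kv][d]
    Returns: [n_q][d]
    """
    n_q, n_kv = len(q), len(k_cache[0])
    group = n_q // n_kv  # query heads that share one KV head (assumed grouping)
    seq_len = len(k_cache)
    out = []
    for head in range(n_q):
        kv_head = head // group
        scores = [
            sm_scale * sum(a * b for a, b in zip(q[head], k_cache[t][kv_head]))
            for t in range(seq_len)
        ]
        peak = max(scores)  # subtract the max for a numerically stable softmax
        exps = [math.exp(s - peak) for s in scores]
        total = sum(exps)
        probs = [e / total for e in exps]
        d = len(v_cache[0][kv_head])
        out.append([
            sum(probs[t] * v_cache[t][kv_head][j] for t in range(seq_len))
            for j in range(d)
        ])
    return out


# Two query heads sharing one KV head, two cached positions, head_dim = 2.
q = [[1.0, 0.0], [0.0, 1.0]]
k_cache = [[[1.0, 0.0]], [[0.0, 1.0]]]
v_cache = [[[10.0, 0.0]], [[0.0, 10.0]]]
out = gqa_decode_ref(q, k_cache, v_cache, sm_scale=1.0)
print(out)
```

Each query head puts most of its weight on the cached key it matches, so the two output rows mirror each other.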
Concept 2: QK-norm - per-head normalization
Gemma 2 bounded attention score magnitudes with logit soft-capping: scores = tanh(scores / 50) * 50. Gemma 3 replaces this with QK-norm: an RMSNorm applied independently to each query and key head after the linear projection and before the dot product.
# Gemma 2 (inside the attention kernel, via logits_soft_cap parameter)
scores = tanh(q @ k.T / sqrt(d) / 50) * 50
# Gemma 3 (in JAX, before calling decode_attention)
q = rms_norm_per_head(q, q_norm_weight) # [N_Q, head_dim]
k = rms_norm_per_head(k, k_norm_weight) # [N_KV, head_dim]
scores = q @ k.T / sqrt(d) # bounded by weight norms
The norm weights q_norm.weight and k_norm.weight each have shape [head_dim] - within a layer, the same weight vector is shared across all heads. In the model state dict they are model.layers.{i}.self_attn.q_norm.weight and model.layers.{i}.self_attn.k_norm.weight.
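The bounding effect of QK-norm can be checked numerically. The sketch below is a plain-Python illustration with hypothetical names (`rms_norm_head`, `zero_weight`); it uses the Gemma-style `(1 + weight)` scaling and sets the weights to zero so the scale is exactly 1, which is an assumption made purely to keep the bound easy to state.

```python
import math


def rms_norm_head(x, weight, eps=1e-6):
    """Gemma-style RMSNorm of one head vector: normalise, then scale by (1 + weight)."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * (1.0 + w) for v, w in zip(x, weight)]


head_dim = 4
zero_weight = [0.0] * head_dim  # hypothetical weights: scale of exactly 1

# Inputs of wildly different magnitude end up on the same scale.
q = rms_norm_head([100.0, -50.0, 25.0, 0.0], zero_weight)
k = rms_norm_head([0.001, 0.002, -0.003, 0.004], zero_weight)

dot = sum(a * b for a, b in zip(q, k))
# Each normalised vector has squared length at most head_dim,
# so by Cauchy-Schwarz |q . k| <= head_dim regardless of the raw inputs.
print(abs(dot) <= head_dim)
```

With learned weights the bound scales with the size of `(1 + weight)`, which is why the score magnitude is controlled by the norm weights rather than by the raw projections.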
Concept 3: Dual RoPE theta - one theta per attention scope
Standard RoPE uses a single base frequency theta. Gemma 3 uses two:
| Layer type | theta | Why |
|---|---|---|
| Local (5/6 of layers) | 10 000 | Standard positional bias within the sliding window (size read from config.json) |
| Global (1/6 of layers) | 1 000 000 | Slower-decaying frequencies for long-range context |
In code this is a single if in the layer function - we select the right theta and pass it to apply_rope as a scalar attribute. The kernel is compiled once; the theta value is a runtime parameter.
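The effect of the two theta values can be seen by computing the rotary wavelengths directly. The sketch below uses the standard RoPE frequency schedule; `rope_wavelengths` is a hypothetical helper and the head dimension of 256 is an assumed value for illustration, since the tutorial reads the real one from config.json.

```python
import math


def rope_wavelengths(head_dim, theta):
    """Wavelength (in tokens) of each rotary frequency pair.

    Pair i rotates by theta ** (-2 * i / head_dim) radians per position,
    so one full turn takes 2 * pi * theta ** (2 * i / head_dim) tokens.
    """
    return [
        2.0 * math.pi * theta ** (2.0 * i / head_dim)
        for i in range(head_dim // 2)
    ]


HEAD_DIM_EXAMPLE = 256  # assumed for illustration only
local = rope_wavelengths(HEAD_DIM_EXAMPLE, 10_000)
glob = rope_wavelengths(HEAD_DIM_EXAMPLE, 1_000_000)
print(f"longest local wavelength:  {local[-1]:,.0f} tokens")
print(f"longest global wavelength: {glob[-1]:,.0f} tokens")
```

The fastest-rotating pair is identical for both settings; the larger theta only stretches the slow pairs, which is what lets global layers distinguish positions that are far apart.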
# -- Pure-JAX building blocks --------------------------------------------------
@jax.jit
def rms_norm(x, weight, eps=RMS_EPS):
    """Gemma-style RMSNorm: normalise then scale by (1 + weight)."""
    x32 = x.astype(jnp.float32)
    y = x32 * jax.lax.rsqrt(jnp.mean(x32**2, axis=-1, keepdims=True) + eps)
    return y.astype(x.dtype) * (1.0 + weight)

@jax.jit
def qk_norm(x, weight):
    """Per-head RMSNorm for Q or K vectors. x: [..., head_dim]."""
    return rms_norm(x, weight)

def embed(token_ids):
    """Embedding lookup. Gemma multiplies by sqrt(hidden_size) to keep
    hidden-state norms stable through the first RMSNorm.
    token_ids: [T] -> [T, HIDDEN]
    """
    return weights["model.embed_tokens.weight"][token_ids] * math.sqrt(HIDDEN)

def lm_head(h):
    """Project hidden state to vocabulary logits. h: [HIDDEN] -> [VOCAB] float32."""
    # Gemma 3 ties the LM head to the embedding matrix
    lm_w = weights.get("lm_head.weight", weights["model.embed_tokens.weight"])
    return h.astype(jnp.float32) @ lm_w.astype(jnp.float32).T

def ffn(h, layer_i):
    """GeGLU feed-forward block. h: [..., HIDDEN] -> [..., HIDDEN].
    Handles both single-token decode (h: [HIDDEN]) and
    full-sequence prefill (h: [T, HIDDEN]) with the same code.
    """
    pre = rms_norm(
        h, weights[f"model.layers.{layer_i}.pre_feedforward_layernorm.weight"]
    )
    gate = (
        pre @ weights[f"model.layers.{layer_i}.mlp.gate_proj.weight"].T
    )  # [..., INTER]
    up = pre @ weights[f"model.layers.{layer_i}.mlp.up_proj.weight"].T  # [..., INTER]
    # Concatenate along the last axis: gelu_and_mul splits it back in two
    gate_up = jnp.concatenate([gate, up], axis=-1)  # [..., 2*INTER]
    hidden = gelu_and_mul(gate_up)  # [..., INTER] <- FlashInfer kernel
    out = hidden @ weights[f"model.layers.{layer_i}.mlp.down_proj.weight"].T
    out = rms_norm(
        out, weights[f"model.layers.{layer_i}.post_feedforward_layernorm.weight"]
    )
    return out

# Quick sanity check on the FFN
_x_test = jax.random.normal(jax.random.key(0), (HIDDEN,), dtype=jnp.bfloat16)
_out = ffn(_x_test, 0)
print(
    f"FFN layer 0: {_x_test.shape} -> {_out.shape} dtype={_out.dtype} ok={not jnp.any(jnp.isnan(_out))}"
)
Prefill layer and full forward pass
prefill_layer processes all prompt tokens in parallel through one transformer layer and builds the initial KV-cache. prefill chains it across all 26 layers, then applies rms_norm to the last token’s hidden state and returns the per-layer KV-caches that the decode loop will update.
# -- Prefill layer (processes all T prompt tokens in parallel) ------------------
def prefill_layer(h, layer_i):
    """Run one transformer layer over the full prompt.
    h: [T, HIDDEN] bfloat16
    Returns: (h: [T, HIDDEN], kv_cache: (k: [T, N_KV, D], v: [T, N_KV, D]))
    """
    T = h.shape[0]
    glob = is_global(layer_i)
    rope_theta = ROPE_THETA_GLOBAL if glob else ROPE_THETA_LOCAL
    # -- Attention -------------------------------------------------------------
    ln = rms_norm(h, weights[f"model.layers.{layer_i}.input_layernorm.weight"])
    q = (ln @ weights[f"model.layers.{layer_i}.self_attn.q_proj.weight"].T).reshape(
        T, N_Q, HEAD_DIM
    )
    k = (ln @ weights[f"model.layers.{layer_i}.self_attn.k_proj.weight"].T).reshape(
        T, N_KV, HEAD_DIM
    )
    v = (ln @ weights[f"model.layers.{layer_i}.self_attn.v_proj.weight"].T).reshape(
        T, N_KV, HEAD_DIM
    )
    # QK-norm (per head, same weight across all token positions)
    q = qk_norm(q, weights[f"model.layers.{layer_i}.self_attn.q_norm.weight"])
    k = qk_norm(k, weights[f"model.layers.{layer_i}.self_attn.k_norm.weight"])
    # RoPE over all T tokens at once: one sequence starting at offset 0
    indptr = jnp.array([0, T], dtype=jnp.int32)
    offsets = jnp.array([0], dtype=jnp.int32)
    q, k = apply_rope(q, k, indptr, offsets, rope_theta=rope_theta)
    # FlashInfer causal attention
    attn_out = prefill_attention(q, k, v, layer_i)
    attn_out = attn_out.reshape(T, N_Q * HEAD_DIM)
    attn_out = attn_out @ weights[f"model.layers.{layer_i}.self_attn.o_proj.weight"].T
    attn_out = rms_norm(
        attn_out, weights[f"model.layers.{layer_i}.post_attention_layernorm.weight"]
    )
    h = h + attn_out
    # -- FFN (works naturally for [T, HIDDEN]) --------------------------------
    h = h + ffn(h, layer_i)
    # KV-cache: store the RoPE-applied K and raw V for all prompt positions
    return h, (k, v)

# -- Full prefill pass ---------------------------------------------------------
def prefill(prompt_ids):
    """Process the full prompt. Returns (h_last: [HIDDEN], kv_caches)."""
    h = embed(jnp.array(prompt_ids))  # [T, HIDDEN]
    kv_caches = []
    for i in range(N_LAYERS):
        h, kv_cache = prefill_layer(h, i)
        kv_caches.append(kv_cache)
    # Final norm applied to the last token's hidden state
    h_last = rms_norm(h[-1], weights["model.norm.weight"])  # [HIDDEN]
    return h_last, kv_caches
Decode attention layer
decode_layer processes one newly generated token through a full transformer layer. It applies QK-norm, selects the layer’s RoPE theta, calls apply_rope and decode_attention, appends to the KV-cache, and returns the updated hidden state.
# -- Decode attention layer (one new token, growing KV-cache) -------------------
def decode_layer(h, layer_i, kv_cache, pos):
    """Process a single new token through one transformer layer.
    h: [HIDDEN] bfloat16
    kv_cache: (k: [pos, N_KV, D], v: [pos, N_KV, D])
    pos: current token's position in the full sequence (Python int)
    Returns: (h: [HIDDEN], updated_kv_cache)
    """
    glob = is_global(layer_i)
    rope_theta = ROPE_THETA_GLOBAL if glob else ROPE_THETA_LOCAL
    # -- Attention -------------------------------------------------------------
    ln = rms_norm(h, weights[f"model.layers.{layer_i}.input_layernorm.weight"])
    q = (ln @ weights[f"model.layers.{layer_i}.self_attn.q_proj.weight"].T).reshape(
        N_Q, HEAD_DIM
    )
    k = (ln @ weights[f"model.layers.{layer_i}.self_attn.k_proj.weight"].T).reshape(
        N_KV, HEAD_DIM
    )
    v = (ln @ weights[f"model.layers.{layer_i}.self_attn.v_proj.weight"].T).reshape(
        N_KV, HEAD_DIM
    )
    # QK-norm (Concept 2: Gemma 3 replaces soft-capping with per-head RMSNorm)
    q = qk_norm(q, weights[f"model.layers.{layer_i}.self_attn.q_norm.weight"])
    k = qk_norm(k, weights[f"model.layers.{layer_i}.self_attn.k_norm.weight"])
    # apply_rope with the layer's theta (Concept 3: different theta for local vs global)
    q_pack, k_pack = q[None], k[None]  # [1, heads, D] packed batch of 1 token
    indptr = jnp.array([0, 1], dtype=jnp.int32)
    offsets = jnp.array([pos], dtype=jnp.int32)
    q_r, k_r = apply_rope(q_pack, k_pack, indptr, offsets, rope_theta=rope_theta)
    q_r = q_r.squeeze(0)  # [N_Q, D]
    k_r = k_r.squeeze(0)  # [N_KV, D]
    # Append RoPE'd K and raw V to the KV-cache
    # NOTE: Using jnp.concatenate to grow KV cache is intentional.
    # In standard JAX this is inefficient (O(N^2)) and you'd normally preallocate
    # and use lax.dynamic_update_slice. However, FlashInfer's single-request
    # decode kernel infers sequence length from k_cache/v_cache.shape.
    # Therefore we must keep the cache length equal to the actual number of tokens.
    # Switching to a fixed-size buffer would require a different FlashInfer API
    # (e.g. paged KV cache) or an explicit length/mask.
    k_cache, v_cache = kv_cache
    k_cache = jnp.concatenate([k_cache, k_r[None]], axis=0)  # [pos+1, N_KV, D]
    v_cache = jnp.concatenate([v_cache, v[None]], axis=0)  # [pos+1, N_KV, D]
    # Decode attention over the full KV-cache (FlashInfer kernel)
    attn_out = decode_attention(q_r, k_cache, v_cache, global_layer=glob)  # [N_Q, D]
    attn_out = attn_out.reshape(N_Q * HEAD_DIM)
    attn_out = attn_out @ weights[f"model.layers.{layer_i}.self_attn.o_proj.weight"].T
    attn_out = rms_norm(
        attn_out, weights[f"model.layers.{layer_i}.post_attention_layernorm.weight"]
    )
    h = h + attn_out
    # -- FFN -------------------------------------------------------------------
    h = h + ffn(h, layer_i)
    return h, (k_cache, v_cache)
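The `indptr` and `offsets` arguments passed to `apply_rope` in the prefill and decode layers describe a packed batch. The sketch below expands them into one position per token, assuming the convention these two call sites rely on: `indptr` marks each sequence's row range and `offsets` gives the position of its first token. `rope_positions` is a hypothetical helper, not part of FlashInfer.

```python
def rope_positions(indptr, offsets):
    """Expand a packed-batch description into one position per token.

    indptr:  [batch + 1] start/end row of each sequence in the packed tensor
    offsets: [batch]     position of each sequence's first token
    """
    positions = []
    for seq, offset in enumerate(offsets):
        n_tokens = indptr[seq + 1] - indptr[seq]
        positions.extend(offset + j for j in range(n_tokens))
    return positions


prefill_positions = rope_positions(indptr=[0, 5], offsets=[0])  # prompt of T = 5 tokens
decode_positions = rope_positions(indptr=[0, 1], offsets=[5])   # first generated token
print(prefill_positions)
print(decode_positions)
```

Under this reading, the decode call continues exactly where the prefill positions stop, which keeps the cached keys and the new query in the same rotary frame.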
Decode step
Why there is no `@jax.jit` here
The FlashInfer kernels (`decode_attention`, `apply_rope`, …) are fully `@jax.jit`-compatible XLA custom calls. The obstacle is the KV-cache. Each decode step appends one new row:
    k_cache = jnp.concatenate([k_cache, k_r[None]], axis=0)  # shape grows every step
`@jax.jit` requires statically known output shapes. Because `k_cache.shape[0]` increments at every step, XLA would have to recompile `decode_step` on each call - far more expensive than running eagerly.
A production system fixes this by pre-allocating a maximum-length cache and writing into it with `jax.lax.dynamic_update_slice`, which keeps shapes static and allows the entire decode loop to be compiled with `jax.lax.scan`. That is the paged KV-cache direction described in the Summary.
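The pre-allocation idea can be illustrated without any JAX machinery. The sketch below is a pure-Python stand-in with a hypothetical `FixedKVCache` class: the buffer is sized once, rows are written in place, and a separate length counter records how many rows are valid. In JAX the in-place write would correspond to `jax.lax.dynamic_update_slice`, and the attention kernel would need that explicit length, since it could no longer infer it from the buffer's shape.

```python
class FixedKVCache:
    """Hypothetical fixed-capacity cache for one layer (pure-Python stand-in)."""

    def __init__(self, max_len):
        self.keys = [None] * max_len    # storage is allocated once, up front
        self.values = [None] * max_len
        self.length = 0                 # number of valid rows

    def append(self, k_row, v_row):
        """Write one row in place; the buffer size never changes."""
        if self.length >= len(self.keys):
            raise IndexError("KV-cache is full")
        self.keys[self.length] = k_row
        self.values[self.length] = v_row
        self.length += 1

    def valid(self):
        """Rows an attention kernel should read; the rest is padding."""
        return self.keys[: self.length], self.values[: self.length]


cache = FixedKVCache(max_len=8)
for pos in range(3):
    cache.append(k_row=[float(pos)], v_row=[float(pos) * 2.0])

print(len(cache.keys), cache.length)
```

The buffer length stays constant across steps, which is the property a compiled decode loop needs.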
# -- One decode step ------------------------------------------------------------
def decode_step(token_id, kv_caches, pos):
    """Process one newly generated token and predict the next.
    token_id: int the most recently produced token
    kv_caches: list one (k, v) tuple per layer
    pos: int this token's position in the full sequence
    Returns: (logits: [VOCAB] float32, updated_kv_caches)
    """
    h = embed(jnp.array([token_id])).squeeze(0)  # [HIDDEN]
    new_kv = []
    for i in range(N_LAYERS):
        h, kv = decode_layer(h, i, kv_caches[i], pos)
        new_kv.append(kv)
    h = rms_norm(h, weights["model.norm.weight"])
    logits = lm_head(h)
    return logits, new_kv

# -- Stop tokens ---------------------------------------------------------------
# Gemma instruct ends its turn with <end_of_turn>, not the generic <eos>.
# Collect all token IDs that should halt generation.
_STOP_IDS = {tokenizer.eos_token_id} if tokenizer.eos_token_id is not None else set()
for _tok in ["<end_of_turn>", "<eos>"]:
    _id = tokenizer.convert_tokens_to_ids(_tok)
    if _id is not None and _id != tokenizer.unk_token_id:
        _STOP_IDS.add(_id)

# -- Text generation -----------------------------------------------------------
def generate(prompt, max_new_tokens=200, temperature=0.7, seed=0):
    """Autoregressive generation with the Gemma 3 instruct chat template."""
    messages = [{"role": "user", "content": prompt}]
    # Render chat template to plain text first.
    rendered = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    # Then tokenize explicitly and extract only input_ids.
    enc = tokenizer(rendered, add_special_tokens=False)
    prompt_ids = enc["input_ids"]
    # Flatten batch dimension if present.
    if len(prompt_ids) > 0 and isinstance(prompt_ids[0], list):
        prompt_ids = prompt_ids[0]
    T = len(prompt_ids)
    key = jax.random.key(seed)
    print(f"Prompt ({T} tokens): {prompt!r}")
    print(f"Rendered prompt preview: {rendered[:120]!r}")
    print("Prefilling...", end=" ", flush=True)
    t0 = time.perf_counter()
    h_last, kv_caches = prefill(prompt_ids)
    jax.block_until_ready(h_last)
    print(f"{time.perf_counter() - t0:.1f}s")

    def _sample(logits, key):
        # temperature == 0 means greedy decoding; otherwise sample from the softmax.
        if temperature == 0.0:
            return int(jnp.argmax(logits)), key
        key, subkey = jax.random.split(key)
        return int(jax.random.categorical(subkey, logits / temperature)), key

    print("Response: ", end="", flush=True)
    generated = []
    for step in range(max_new_tokens):
        if step == 0:
            logits = lm_head(h_last)
        else:
            logits, kv_caches = decode_step(generated[-1], kv_caches, T + step - 1)
        next_tok, key = _sample(logits, key)
        generated.append(next_tok)
        if next_tok in _STOP_IDS:
            break
        print(tokenizer.decode([next_tok]), end="", flush=True)
    print()
    return tokenizer.decode(generated, skip_special_tokens=True)
Running inference
The cell below runs the model on three sample questions using the Gemma 3 instruct chat template. XLA compiles the kernels on the first call; subsequent prompts reuse the cached compilation.
questions = [
    "What is the capital of Germany?",
    "How does rotary positional embedding differ from learned positional embedding?",
    "What is grouped-query attention and why is it useful?",
]
for q in questions:
    generate(q, max_new_tokens=150, temperature=0.7, seed=0)
    print()
Summary
We have implemented end-to-end autoregressive inference for Gemma 3 1B Instruct using four FlashInfer kernels as the computational backbone - covering both the prompt (prefill) and generation (decode) phases.
The complete inference recipe
# 1. Compile kernels (once)
gelu_module = gen_act_and_mul_module('gelu_tanh').build_and_load()
rope_module = gen_rope_module().build_and_load()
decode_local_module = gen_single_decode_module(..., use_sliding_window=True, ...).build_and_load()
decode_global_module = gen_single_decode_module(..., use_sliding_window=False, ...).build_and_load()
prefill_local = gen_single_prefill_module('fa2', ..., use_sliding_window=True).build_and_load()
prefill_global = gen_single_prefill_module('fa2', ..., use_sliding_window=False).build_and_load()
# 2. Prefill: FlashInfer causal attention over all prompt tokens -> KV-cache
h_last, kv_caches = prefill(prompt_ids)
# 3. Decode: FlashInfer decode_attention, one token at a time
for step in range(max_new_tokens):
    if step == 0:
        logits = lm_head(h_last)
    else:
        logits, kv_caches = decode_step(generated[-1], kv_caches, T + step - 1)
Going further
- Paged KV-cache: Replace the growing list with a fixed-size paged cache and use FlashInfer’s `BatchDecodeWithPagedKVCacheWrapper` for batch inference with mixed sequence lengths.
- Sampling: Extend the sampler with top-p nucleus sampling or top-k filtering on the logits.
- Continuous batching: Process multiple requests simultaneously, filling the decode kernel’s batch dimension.
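As a starting point for the sampling extension, the sketch below filters a list of logits with top-k and top-p before sampling. It is a plain-Python illustration, not part of the tutorial's code: `top_k_top_p_filter` is a hypothetical name, and applying top-p to the distribution that remains after top-k is one possible design choice among several.

```python
import math


def top_k_top_p_filter(logits, top_k=0, top_p=1.0):
    """Return logits with every token outside the top-k / nucleus set masked to -inf."""
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    if top_k > 0:
        order = order[:top_k]
    # Softmax over the surviving tokens (max-subtracted for stability).
    peak = logits[order[0]]
    exps = [math.exp(logits[i] - peak) for i in order]
    total = sum(exps)
    keep, cumulative = set(), 0.0
    for idx, e in zip(order, exps):
        keep.add(idx)
        cumulative += e / total
        if cumulative >= top_p:  # smallest prefix whose mass reaches top_p
            break
    return [x if i in keep else -math.inf for i, x in enumerate(logits)]


logits = [2.0, 1.0, 0.0, -1.0]
only_top_two = top_k_top_p_filter(logits, top_k=2)
nucleus_half = top_k_top_p_filter(logits, top_p=0.5)
print(only_top_two)
print(nucleus_half)
```

Masked entries carry zero probability under a softmax, so the filtered logits can be divided by the temperature and handed to a categorical sampler in the same way the unfiltered logits are.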