Enabling FlashInfer GPU Kernels on JAX with the JAX TVM FFI Bridge¶
Overview¶
JAX’s XLA compiler is excellent for training and general tensor computation, but LLM inference has a distinct performance profile: every decode step attends over an ever-growing KV-cache, placing the bottleneck squarely on memory bandwidth rather than raw compute. FlashInfer is a library of hand-tuned CUDA kernels built precisely for this regime.
FlashInfer ships every CUDA kernel as a shared library - an .so-file compiled with a cross-language binary interface defined by Apache TVM’s Foreign Function Interface. Any language with a TVM FFI binding can load these .so-files and call the functions inside. This tutorial shows how to do it from JAX via jax-tvm-ffi, a bridge library that adapts TVM FFI functions to XLA custom calls.
What you’ll build¶
Three FlashInfer kernels:
| Kernel | What it computes | New concept |
|---|---|---|
| silu_and_mul | Gated FFN activation: silu(gate) * up | The minimal bridge: load -> register -> call |
| apply_rope | Rotary positional embeddings on packed batches | Multiple outputs; argument reordering |
| decode_attention | Attention over a KV-cache (single request, GQA) | Type-specialized JIT; scratch buffers; optional-argument sentinels |
At the end, all three run together inside a single @jax.jit region - exactly as they would in a real LLM decode loop.
Preliminaries¶
Hardware and software requirements¶
| GPU | NVIDIA, SM 7.5+ (Turing or later) |
| Python packages | jax, flashinfer-python, jax-tvm-ffi |
| CUDA | 12.6+ |
Setting the environment¶
The easiest way to get a working JAX environment is the NVIDIA NGC JAX container. To install JAX manually instead:
Recommended (CUDA 13):
pip install 'jax[cuda13]'
For CUDA 12.x:
pip install 'jax[cuda12]'
Two packages are required beyond a standard JAX environment:
| Package | Role |
|---|---|
| flashinfer-python | FlashInfer CUDA kernels + JIT compilation system |
| jax-tvm-ffi | Bridge: adapts TVM FFI functions to XLA custom calls |
flashinfer-python ships pre-built wheels for each CUDA/Python combination. The --extra-index-url below selects the CUDA 13.0 wheel; replace cu130 with the tag that matches your CUDA Toolkit release:
CUDA 13 -> cu130
CUDA 12.x -> cu12x (e.g., cu126)
Install the tutorial dependencies once per environment, before running the notebook or script:
pip install -U flashinfer-python jax-tvm-ffi \
    --no-build-isolation \
    --extra-index-url https://flashinfer.ai/whl/cu130/
Loading dependencies¶
Run the cell below to load the dependencies.
import os
import time
import math
import jinja2
import numpy as np
import subprocess
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # suppress TF/XLA info & warnings
if "CUDA_HOME" not in os.environ:
    try:
        nvcc = subprocess.check_output(["which", "nvcc"], text=True).strip()
        os.environ["CUDA_HOME"] = str(os.path.dirname(os.path.dirname(nvcc)))
    except subprocess.CalledProcessError:
        os.environ["CUDA_HOME"] = "/usr/local/cuda"
if "--xla_gpu_cuda_data_dir=" not in os.environ.get("XLA_FLAGS", ""):
    os.environ["XLA_FLAGS"] = (
        f"{os.environ.get('XLA_FLAGS', '')} "
        f"--xla_gpu_cuda_data_dir={os.environ['CUDA_HOME']}"
    ).strip()
import jax
import jax.numpy as jnp
import jax_tvm_ffi # Bridge adapter: TVM FFI -> XLA custom call
from flashinfer.jit import gen_act_and_mul_module, gen_jit_spec, env as jit_env
from flashinfer.jit.rope import gen_rope_module
from flashinfer.jit.attention.utils import generate_additional_params
from flashinfer.jit.utils import write_if_different
print(f"JAX: {jax.__version__}")
print(f"Devices: {jax.devices()}")
print(f"CUDA home: {os.environ.get('CUDA_HOME')}")
print(f"JIT cache: {jit_env.FLASHINFER_GEN_SRC_DIR.parent}")
The JAX TVM FFI bridge¶
Every FlashInfer kernel lives in a compiled .so-file. Getting that kernel into a @jax.jit computation graph takes three steps - the same three steps for every kernel in this tutorial:
Step 1 BUILD & LOAD jit_spec.build_and_load() -> tvm_ffi.Module
Step 2 REGISTER jax_tvm_ffi.register_ffi_target(name, wrapper, arg_spec)
Step 3 CALL jax.ffi.ffi_call(name, output_shapes)(*inputs, **scalar_attrs)
Step 1 - Compile and load the TVM FFI module¶
FlashInfer compiles kernels on demand using its JIT system. flashinfer.jit contains framework-agnostic helpers - one per kernel family - that generate CUDA source code, compile it with nvcc + ninja, and produce a shared library (.so-file). The result is cached in ~/.cache/flashinfer/ so compilation only happens once per configuration. You’ll see the specific helper for each kernel when we get to the examples.
Each helper returns a JitSpec - a recipe describing what to compile. Calling .build_and_load() on a JitSpec runs the compilation pipeline and returns a tvm_ffi.Module:
module = some_jit_spec.build_and_load()
module.my_kernel_function # tvm_ffi.Function, callable from Python
module.my_kernel_function(a, b, c) # call it directly
Under the hood, .build_and_load() writes the generated .so to the cache directory and calls tvm_ffi.load_module() to open it. If the .so already exists (cache hit), the compilation step is skipped and tvm_ffi.load_module() loads the cached binary directly.
In other words: tvm_ffi.load_module(path) is the low-level primitive that opens any .so-file, while .build_and_load() is the high-level entry point that handles source generation, compilation, caching, and loading in one call. In this tutorial we always use .build_and_load() because FlashInfer’s JIT system manages the full pipeline for us.
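The cache-hit behaviour described above can be pictured with a small, self-contained sketch. It is not FlashInfer's implementation: `build_or_reuse`, the `fake_compile` callback, and the `<uri>/<uri>.so` directory layout are hypothetical stand-ins, chosen only to illustrate "compile once per configuration, then reuse the file".

```python
import tempfile
from pathlib import Path


def build_or_reuse(cache_dir, uri, compile_fn):
    """Return the cached library path for `uri`, compiling only on a miss.

    Hypothetical helper: the directory layout is an assumption made for
    this sketch, not FlashInfer's real cache layout.
    """
    so_path = Path(cache_dir) / uri / f"{uri}.so"
    if not so_path.exists():  # cache miss: run the (expensive) build step
        so_path.parent.mkdir(parents=True, exist_ok=True)
        so_path.write_bytes(compile_fn())
    return so_path  # cache hit or fresh build: same path either way


compile_calls = []


def fake_compile():
    """Stand-in for nvcc + ninja; records each invocation."""
    compile_calls.append(1)
    return b"not a real shared library"


with tempfile.TemporaryDirectory() as cache_dir:
    first = build_or_reuse(cache_dir, "silu_and_mul", fake_compile)
    second = build_or_reuse(cache_dir, "silu_and_mul", fake_compile)
    same_path = first == second

print(len(compile_calls), same_path)
```

The second call finds the file already on disk, so the compile callback runs only once.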
Step 2 - Register as a JAX FFI target¶
jax_tvm_ffi.register_ffi_target(name, wrapper, arg_spec) teaches XLA about the kernel. It wraps a Python callable into an XLA custom call node. Once registered, XLA can schedule the kernel inside @jax.jit computations.
The wrapper is a Python function whose job is to call the TVM function. It may need to reorder arguments: JAX delivers output tensors first and input tensors second, but TVM functions are compiled with their own parameter order that you must match.
def _wrapper(out, x, scalar):  # positional order set by arg_spec
    module.my_kernel_function(x, out, scalar)  # reorder here to match TVM signature
jax_tvm_ffi.register_ffi_target(
    'registered_name', _wrapper,
    arg_spec=['rets', 'args', 'attrs.scalar'],
    platform='gpu', allow_cuda_graph=True, pass_owned_tensor=True,
)
Note
Concept: arg_spec routing
arg_spec is a list that tells the bridge how to route JAX’s three categories of call-time data into the wrapper’s positional arguments:
| Entry | What the wrapper receives |
|---|---|
| 'rets' | All output tensors, pre-allocated by XLA |
| 'args' | All input tensors, in call order |
| 'attrs.NAME' | A scalar keyword argument named NAME |
For example, arg_spec=['rets', 'args', 'attrs.top_k'] means the wrapper is called as wrapper(*outputs, *inputs, top_k).
With pass_owned_tensor=True (required for GPU kernels), the tensors inside the wrapper are tvm_ffi.Tensor objects - they have a .shape attribute and are passed directly to TVM functions.
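To make the routing rule concrete, here is a minimal pure-Python model of it. The `route` function is a hypothetical illustration, not the bridge's actual code, and it only handles the three entry kinds used in this tutorial.

```python
def route(arg_spec, rets, args, attrs):
    """Flatten outputs, inputs, and scalar attributes in arg_spec order.

    Illustrative model only; the real bridge performs this unpacking
    internally and may accept entry kinds not modelled here.
    """
    positional = []
    for entry in arg_spec:
        if entry == "rets":
            positional.extend(rets)  # all output tensors
        elif entry == "args":
            positional.extend(args)  # all input tensors, in call order
        elif entry.startswith("attrs."):
            positional.append(attrs[entry[len("attrs."):]])  # one named scalar
        else:
            raise ValueError(f"unsupported arg_spec entry: {entry!r}")
    return positional


calls = []


def wrapper(*positional):
    """Records what it was called with, in place of invoking a kernel."""
    calls.append(positional)


# Strings stand in for tensors so the ordering is easy to read.
wrapper(*route(["rets", "args", "attrs.top_k"],
               rets=["out"], args=["x", "y"], attrs={"top_k": 4}))
print(calls[0])
```

Changing the order of entries in `arg_spec` changes the positional order the wrapper sees, which is why the wrapper signature and `arg_spec` must be kept in sync.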
Step 3 - Call as a regular JAX expression¶
result = jax.ffi.ffi_call(
    'registered_name',
    jax.ShapeDtypeStruct(output_shape, output_dtype),  # tells XLA what to pre-allocate
    vmap_method='broadcast_all',  # how jax.vmap should batch this call
)(*input_tensors, scalar=value)
The name 'registered_name' must match what was passed to register_ffi_target in Step 2, and scalar=value must match the 'attrs.scalar' entry in arg_spec.
Output shapes must be statically known - XLA needs them at trace time to pre-allocate the output buffers that get passed as 'rets' in the wrapper. The call site is a regular JAX expression and composes naturally with @jax.jit.
Note
Concept: vmap_method
vmap_method tells JAX how to handle jax.vmap over this call:
| Value | Behaviour |
|---|---|
| 'broadcast_all' | Treat all inputs as sharing the same batch axis (unbatched inputs are broadcast) and pass the batched arrays to the kernel in a single call. Use this when the operation is independent across a batch dimension, such as an element-wise activation or a per-token positional embedding. |
| omitted | The call raises an error when used under jax.vmap. |
See the JAX documentation on jax.vmap to learn more.
The full execution path¶
User JAX code
| jax.ffi.ffi_call(name, output_shapes)(*inputs, **attrs)
v
XLA Compiler
- emits a custom call node; records output shapes and scalar attrs at trace time
|
XLA Runtime
- looks up the registered target; pre-allocates output buffers
|
JAX-TVM-FFI Bridge
- unpacks call frame -> positional args according to arg_spec
|
Python wrapper (user-defined)
- reorders arguments to match the TVM function signature
|
tvm_ffi.Function -> CUDA kernel
Example 1: Gated SiLU (a.k.a. SwiGLU-style)¶
In this example, you should learn how minimal the JAX-TVM bridge can be when the function signature and arg_spec are aligned: we have one kernel and one thin wrapper, and JAX handles the rest.
Modern LLM feed-forward layers use a gated activation instead of a plain nonlinearity. A common form is:
FFN(x) = SiLU(W1 x) ⊙ (W2 x)
where ⊙ is elementwise multiplication.
In practice, implementations often compute both linear projections in one matmul by using a weight with twice the hidden width, producing a tensor of shape [..., 2H]. The hidden size H is the width of the model's internal feature vectors; a gated FFN temporarily doubles it to 2H to hold two parallel projections, then combines them back to size H:
a = input[..., :H] (the "gate" half)
b = input[..., H:] (the "value / up-projection" half)
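The split-and-gate computation can be written out in a few lines of plain Python. This is a reference sketch for a single fused row, not the FlashInfer kernel; `silu_and_mul_reference` is a name chosen for this illustration.

```python
import math


def silu(z):
    """SiLU (also called swish): z * sigmoid(z)."""
    return z / (1.0 + math.exp(-z))


def silu_and_mul_reference(row):
    """Split a fused [2H] row into gate/up halves and return silu(gate) * up."""
    h = len(row) // 2
    gate, up = row[:h], row[h:]
    return [silu(a) * b for a, b in zip(gate, up)]


fused = [0.0, 1.0, 2.0, 3.0]  # H = 2: gate = [0, 1], up = [2, 3]
result = silu_and_mul_reference(fused)
print(result)
```

The output has H elements, half the width of the input, which is why the JAX-facing function later halves the last dimension when it declares the output shape.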
The TVM function signature is:
silu_and_mul(out, input, enable_pdl)
out: a pre-allocated output buffer that the kernel writes into.
input: the fused [..., 2H] activation tensor (contains both halves concatenated).
enable_pdl: a boolean that toggles Programmatic Dependent Launch (PDL) - an SM90+ feature that can help chain GPU work with less host involvement. For this tutorial, we keep it False.
The same compiled binary supports float16 and bfloat16, selecting the right path via runtime dispatch based on the input dtype.
This is the simplest bridge case: one input tensor -> one output tensor -> one scalar attribute. And because the TVM function’s parameter order (out, input, enable_pdl) matches the order JAX delivers them (rets, then args, then attrs), the wrapper needs no reordering.
Compile and load¶
gen_act_and_mul_module('silu') returns a JitSpec - a recipe describing what to compile.
.build_and_load() runs nvcc + ninja, writes the .so to the FlashInfer cache directory, and returns a tvm_ffi.Module.
print("Compiling silu_and_mul (first run may take ~30 s)...")
silu_module = gen_act_and_mul_module("silu").build_and_load()
print(f" Module type: {type(silu_module).__name__}")
print(f" Function: {silu_module.silu_and_mul}")
Register as a JAX FFI target and validate¶
In this section, we connect the compiled TVM kernel to JAX by registering it as an FFI target and exposing a clean Python wrapper. The key idea is that JAX groups parameters as outputs (rets) -> inputs (args) -> scalar attributes (attrs), and because the TVM function already expects (out, input, enable_pdl) in that exact order, the wrapper simply forwards the call with no argument reshuffling. The silu_and_mul function then uses jax.ffi.ffi_call to allocate the output buffer, pass the fused [..., 2H] tensor to the kernel, and return the computed [..., H] result, while remaining fully batch-compatible via vmap_method="broadcast_all".
The validation step should confirm correctness against a pure JAX reference.
# -- Register as a JAX FFI target ----------------------------------------------
# TVM function: silu_and_mul(out, input, enable_pdl)
# arg_spec: ['rets', 'args', 'attrs.enable_pdl']
# JAX delivers (out, input, enable_pdl) in the same order as the TVM function
# expects, so no reordering is needed in this wrapper.
def _silu_and_mul_wrapper(out, x, enable_pdl):
    silu_module.silu_and_mul(out, x, enable_pdl)
jax_tvm_ffi.register_ffi_target(
    "flashinfer.silu_and_mul",
    _silu_and_mul_wrapper,
    arg_spec=["rets", "args", "attrs.enable_pdl"],
    platform="gpu",
    allow_cuda_graph=True,
    pass_owned_tensor=True,
)
# -- JAX-facing function -----------------------------------------------
def silu_and_mul(x: jax.Array) -> jax.Array:
    """Fused silu(gate) * up. Input: [..., 2H] Output: [..., H]"""
    out_shape = x.shape[:-1] + (x.shape[-1] // 2,)
    return jax.ffi.ffi_call(
        "flashinfer.silu_and_mul",
        jax.ShapeDtypeStruct(out_shape, x.dtype),
        vmap_method="broadcast_all",  # element-wise op: independent across any batch dim
    )(x, enable_pdl=False)
# -- Validate ------------------------------------------------------------------
TOKENS, HIDDEN = 32, 256
gate_up = jax.random.normal(jax.random.key(0), (TOKENS, 2 * HIDDEN), dtype=jnp.float16)
out = silu_and_mul(gate_up)
# Compute reference to test our function
gate_ref = gate_up[..., :HIDDEN].astype(jnp.float32)
up_ref = gate_up[..., HIDDEN:].astype(jnp.float32)
ref = (jax.nn.silu(gate_ref) * up_ref).astype(jnp.float16)
np.testing.assert_allclose(
    np.array(out.astype(jnp.float32)),
    np.array(ref.astype(jnp.float32)),
    rtol=1e-2,
    atol=1e-2,
)
print("silu_and_mul: PASSED")
print(f" {gate_up.shape} -> {out.shape}")
print(f" max error: {float(jnp.max(jnp.abs(out.astype(jnp.float32) - ref))):.5f}")
Example 2: Rotary Positional Embeddings¶
In this example, you should learn how to handle two new complications that were absent in Example 1: a TVM function that produces two outputs and whose argument order does not match JAX’s convention. Both are resolved entirely inside the Python wrapper - registration and the call site otherwise follow the same pattern.
RoPE encodes each token’s position by rotating its query and key vectors in pairs. For position p and dimension index i:
theta_i = rope_theta^(-2i / head_dim)
q_rot[2i] = q[2i] * cos(p*theta_i) - q[2i+1] * sin(p*theta_i)
q_rot[2i+1] = q[2i] * sin(p*theta_i) + q[2i+1] * cos(p*theta_i)
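The same rotation is applied to both q and k, so the kernel produces two rotated outputs. The formulas above pair adjacent elements (2i, 2i+1); with interleave=False, as used in this tutorial, the kernel instead pairs element i with element i + head_dim/2, which is the layout the reference implementation in the validation code uses.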
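A scalar sketch of the rotation helps show why RoPE encodes relative position: the dot product between a rotated query pair and a rotated key pair depends only on the difference of their positions. The helper names below are chosen for this illustration, and the code works on a single (2i, 2i+1) pair rather than whole tensors.

```python
import math


def rope_angle(position, i, head_dim, rope_theta=1e4):
    """Angle p * theta_i, with theta_i = rope_theta^(-2i / head_dim)."""
    return position * rope_theta ** (-2.0 * i / head_dim)


def rotate_pair(x_even, x_odd, angle):
    """Rotate one (x[2i], x[2i+1]) pair by `angle`, following the formulas above."""
    c, s = math.cos(angle), math.sin(angle)
    return x_even * c - x_odd * s, x_even * s + x_odd * c


def score(q_pair, k_pair, q_pos, k_pos, i=1, head_dim=64):
    """Dot product of a rotated query pair with a rotated key pair."""
    q = rotate_pair(*q_pair, rope_angle(q_pos, i, head_dim))
    k = rotate_pair(*k_pair, rope_angle(k_pos, i, head_dim))
    return q[0] * k[0] + q[1] * k[1]


q_pair, k_pair = (0.3, -1.2), (0.7, 0.5)
near = score(q_pair, k_pair, q_pos=5, k_pos=2)
far = score(q_pair, k_pair, q_pos=105, k_pos=102)  # same relative offset of 3
print(near, far)
```

The two scores agree up to floating-point rounding because both query/key pairs are three positions apart; the rotation also leaves each pair's length unchanged.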
Ragged (packed) batches¶
Rather than a padded [batch, seq_len, heads, dim] tensor, apply_rope takes a flat [total_tokens, heads, dim] tensor with two auxiliary arrays:
indptr - a CSR-style pointer array where indptr[i]:indptr[i+1] gives the token range of sequence i
offsets - the KV-cache position of the first token of each sequence
This lets the kernel handle variable-length sequences without padding overhead.
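As a sketch of how the two descriptors combine, the helper below expands them into one absolute position per packed token. `token_positions` is a hypothetical name for this illustration; the kernel does the equivalent indexing on the GPU.

```python
def token_positions(indptr, offsets):
    """Expand CSR-style indptr plus per-sequence offsets into per-token positions."""
    positions = []
    for seq, offset in enumerate(offsets):
        start, end = indptr[seq], indptr[seq + 1]  # token range of this sequence
        positions.extend(offset + t for t in range(end - start))
    return positions


# Sequence 0 has 3 tokens starting at position 0;
# sequence 1 has 2 tokens starting at position 100.
positions = token_positions(indptr=[0, 3, 5], offsets=[0, 100])
print(positions)  # [0, 1, 2, 100, 101]
```

No padding tokens appear anywhere: the flat tensor holds exactly `indptr[-1]` rows.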
The TVM function signature¶
apply_rope(q, k, q_rope, k_rope, indptr, offsets, rotary_dim, interleave, rope_scale, rope_theta)
q, k: input query and key tensors, shape [total_tokens, num_heads, head_dim]
q_rope, k_rope: pre-allocated output buffers - they appear in the middle of the argument list, after the inputs but before the index arrays
indptr, offsets: ragged-batch descriptors
rotary_dim, interleave, rope_scale, rope_theta: scalar parameters
Because the outputs sit between the input tensors and the index arrays, the JAX convention (all outputs first) does not match the TVM signature. The wrapper must move them back into place.
Note
Concept: argument reordering
TVM functions are compiled with whatever parameter order the kernel author chose. JAX’s FFI convention always delivers (outputs, inputs, scalars). The wrapper bridges the two:
# JAX delivers (via arg_spec): q_rope, k_rope, q, k, indptr, offsets, *scalars
# TVM function expects: q, k, q_rope, k_rope, indptr, offsets, *scalars
def _wrapper(q_rope, k_rope, q, k, indptr, offsets, *scalars):
    tvm_fn(q, k, q_rope, k_rope, indptr, offsets, *scalars)  # reorder here
Compile and load¶
gen_rope_module() returns a JitSpec for the RoPE kernel. Like the silu_and_mul module, this kernel dispatches over dtypes at runtime from a single binary, so no dtype specialisation is needed at compile time.
print("Compiling apply_rope...")
rope_module = gen_rope_module().build_and_load()
print(f" Function: {rope_module.apply_rope}")
Register as a JAX FFI target and validate¶
The apply_rope wrapper introduces the argument-reordering pattern described above. Because the TVM function places q_rope and k_rope after the input tensors - rather than at the beginning - the wrapper receives them first (via 'rets') and must swap them back before calling the TVM function. The kernel also produces two output tensors, so jax.ffi.ffi_call receives a tuple of two ShapeDtypeStruct descriptors and returns a tuple of two JAX arrays. Because RoPE rotates each token independently, the operation decomposes cleanly over any outer batch dimension: vmap_method='broadcast_all' enables jax.vmap to map over it without any changes to the kernel.
The validation uses two sequences with different starting positions - sequence 0 begins at position 0 and sequence 1 at position 100 - simulating a prefill where the second sequence already has 100 tokens in the KV-cache. Each token is rotated by angles matching its absolute position, so the test exercises both the CSR indexing and the offset arithmetic.
# -- Register as a JAX FFI target ----------------------------------------------
# TVM function: apply_rope(q, k, q_rope, k_rope, indptr, offsets,
#                          rotary_dim, interleave, rope_scale, rope_theta)
# JAX delivers: q_rope, k_rope first (rets), then q, k, indptr, offsets (args).
# TVM expects:  q, k, q_rope, k_rope, indptr, offsets, ... - the wrapper swaps them.
def _rope_wrapper(
    q_rope,
    k_rope,
    q,
    k,
    indptr,
    offsets,
    rotary_dim,
    interleave,
    rope_scale,
    rope_theta,
):
    rope_module.apply_rope(
        q,
        k,
        q_rope,
        k_rope,
        indptr,
        offsets,
        rotary_dim,
        interleave,
        rope_scale,
        rope_theta,
    )
jax_tvm_ffi.register_ffi_target(
    "flashinfer.apply_rope",
    _rope_wrapper,
    arg_spec=[
        "rets",
        "args",
        "attrs.rotary_dim",
        "attrs.interleave",
        "attrs.rope_scale",
        "attrs.rope_theta",
    ],
    platform="gpu",
    allow_cuda_graph=True,
    pass_owned_tensor=True,
)
# -- JAX-facing function -------------------------------------------------------
def apply_rope(q, k, indptr, offsets, *, rope_theta=1e4):
    """Apply rotary positional embeddings to packed query and key tensors.
    q, k: [total_tokens, num_heads, head_dim]
    indptr: [num_seqs + 1] CSR-style token range per sequence
    offsets: [num_seqs] absolute position of the first token of each sequence
    Returns: (q_rotated, k_rotated), same shapes as inputs
    """
    head_dim = q.shape[-1]
    return jax.ffi.ffi_call(
        "flashinfer.apply_rope",
        (
            jax.ShapeDtypeStruct(q.shape, q.dtype),
            jax.ShapeDtypeStruct(k.shape, k.dtype),
        ),
        vmap_method="broadcast_all",  # each packed batch is independent; safe to vmap
    )(
        q,
        k,
        indptr,
        offsets,
        rotary_dim=head_dim,
        interleave=False,
        rope_scale=1.0,
        rope_theta=float(rope_theta),
    )
# -- Validate ------------------------------------------------------------------
def _reference_rope(x, positions, theta=1e4):
    """Non-interleaved RoPE reference. x: [tokens, heads, dim]"""
    x32 = x.astype(jnp.float32)
    d = x32.shape[-1] // 2
    freqs = 1.0 / (theta ** (2.0 * jnp.arange(d, dtype=jnp.float32) / x32.shape[-1]))
    angles = positions[:, None].astype(jnp.float32) * freqs[None, :]  # [T, d]
    cos_a = jnp.cos(angles)[:, None, :]  # [T, 1, d]
    sin_a = jnp.sin(angles)[:, None, :]
    x1, x2 = x32[..., :d], x32[..., d:]
    return jnp.concatenate(
        [x1 * cos_a - x2 * sin_a, x1 * sin_a + x2 * cos_a], axis=-1
    ).astype(x.dtype)
# Two sequences of 8 tokens: first starts at position 0, second at position 100
NUM_HEADS, HEAD_DIM, SEQ_LEN, NUM_SEQ = 8, 64, 8, 2
ROPE_THETA = 1e4
q_in = jax.random.normal(
    jax.random.key(1), (NUM_SEQ * SEQ_LEN, NUM_HEADS, HEAD_DIM), dtype=jnp.bfloat16
)
k_in = jax.random.normal(
    jax.random.key(2), (NUM_SEQ * SEQ_LEN, NUM_HEADS, HEAD_DIM), dtype=jnp.bfloat16
)
indptr = jnp.array([0, SEQ_LEN, 2 * SEQ_LEN], dtype=jnp.int32)
offsets = jnp.array([0, 100], dtype=jnp.int32)
q_rot, k_rot = apply_rope(q_in, k_in, indptr, offsets, rope_theta=ROPE_THETA)
positions = jnp.concatenate(
    [jnp.arange(SEQ_LEN, dtype=jnp.int32) + off for off in [0, 100]]
)
q_ref = _reference_rope(q_in, positions, theta=ROPE_THETA)
k_ref = _reference_rope(k_in, positions, theta=ROPE_THETA)
for name, got, ref in [("q", q_rot, q_ref), ("k", k_rot, k_ref)]:
    np.testing.assert_allclose(
        np.array(got.astype(jnp.float32)),
        np.array(ref.astype(jnp.float32)),
        rtol=1e-2,
        atol=1e-2,
    )
    max_err = float(jnp.max(jnp.abs(got.astype(jnp.float32) - ref.astype(jnp.float32))))
    print(f"apply_rope {name}: PASSED max_err={max_err:.5f}")
print(
    f" Input: {q_in.shape} ({NUM_SEQ} seqs x {SEQ_LEN} tokens, offsets {offsets.tolist()})"
)
Example 3: Single-request decode attention¶
In this example, you should learn about three patterns that appear together for the first time: a kernel compiled separately per dtype and head dimension, a scratch buffer the caller must allocate as an output, and optional arguments signalled as absent with empty tensors.
In autoregressive generation, each new token attends over all previously generated tokens stored in the KV-cache:
Attention(q, K, V) = softmax(q K^T / sqrt(d)) V
where q is a single query (or a small group in Grouped Query Attention), and K, V are the full cached sequences. FlashInfer’s single-request kernel is tuned for exactly this shape: one query token, many keys, memory-bound.
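For intuition, the formula can be evaluated directly for one query vector and one head. This is an unoptimised pure-Python sketch, not how the FlashInfer kernel is structured; `decode_attention_one_head` is a name chosen for this illustration.

```python
import math


def decode_attention_one_head(q, keys, values):
    """softmax(q K^T / sqrt(d)) V for a single query vector and one head."""
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
    peak = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    return [
        sum(w * v[j] for w, v in zip(weights, values))
        for j in range(len(values[0]))
    ]


# A zero query scores every key equally, so the output is the mean of V.
out = decode_attention_one_head(
    q=[0.0, 0.0],
    keys=[[1.0, 0.0], [0.0, 1.0]],
    values=[[1.0, 2.0], [3.0, 4.0]],
)
print(out)
```

Every cached key and value is read once per decode step while only a handful of multiply-adds are done per element, which is the memory-bound pattern the section describes.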
Type-specialized compilation¶
Unlike silu_and_mul and apply_rope - which dispatch over dtypes at runtime from a single binary - the decode attention kernel is compiled separately for each (dtype, head_dim) combination. This lets the compiler choose tile sizes and memory access patterns for the exact configuration, with no runtime branching.
FlashInfer’s JIT system renders Jinja2 templates with concrete type names to produce configuration-specific CUDA code. A separate .so is compiled for each combination and identified by a URI that doubles as the on-disk cache key. gen_decode_jit_spec, defined in the cell below, builds that URI and invokes the JIT system.
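The idea of "one rendered source and one cache key per configuration" can be sketched with the standard library. Note the substitutions: `string.Template` stands in for Jinja2 here, and the template text, the `toy_decode_...` URI format, and `render_config` are all hypothetical, not FlashInfer's actual templates or naming.

```python
from string import Template

# Toy config template; the real ones are Jinja2 files shipped with FlashInfer.
CONFIG_TEMPLATE = Template(
    "using DTypeQ = $dtype_q;\n"
    "constexpr int HEAD_DIM_QK = $head_dim;\n"
)


def render_config(dtype_cpp, head_dim):
    """Return (uri, source): a per-configuration cache key and rendered code."""
    uri = f"toy_decode_dtype_{dtype_cpp}_head_dim_{head_dim}"
    source = CONFIG_TEMPLATE.substitute(dtype_q=dtype_cpp, head_dim=head_dim)
    return uri, source


uri_a, src_a = render_config("half", 64)
uri_b, src_b = render_config("nv_bfloat16", 128)
print(uri_a)
print(src_a)
```

Because the concrete type and head dimension are baked into the source text, two configurations produce two different sources and therefore need two different cache entries.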
The TVM function signature¶
run(q, k, v,
    tmp, out,             # outputs: scratch buffer first, result second
    maybe_lse,            # optional output: log-sum-exp values
    kv_layout_code, window_left,
    maybe_alibi_slopes,   # optional input: ALiBi position biases
    logits_soft_cap, sm_scale, rope_rcp_scale, rope_rcp_theta)
q, k, v: query and KV-cache tensors
tmp: a scratch buffer for split-K partial sums - must be provided by the caller as a pre-allocated output
out: the attention result
maybe_lse: optional log-sum-exp output; None to skip
kv_layout_code: 0 for NHD layout, 1 for HND
window_left: sliding-window cutoff; -1 for full attention
maybe_alibi_slopes: optional ALiBi position biases; None to skip
logits_soft_cap, sm_scale, rope_rcp_scale, rope_rcp_theta: scalar parameters
Two patterns are new here:
Note
Concept: scratch buffer
The decode kernel uses an internal buffer for split-K partial results. We declare it as an output tensor so XLA pre-allocates it, then discard it after the call:
out, _, _ = jax.ffi.ffi_call(target, (result_struct, tmp_struct, lse_struct))(...)
#    ^  ^
#    |  LSE sentinel (discarded)
#    32 MB scratch (discarded)
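Since output buffers are declared by shape and dtype rather than by byte count, the 32 MB budget has to be converted into an element count. A small worked example, using a hand-written itemsize table as a stand-in for a dtype lookup:

```python
SCRATCH_BYTES = 32 * 1024 * 1024  # 32 MB, the scratch size used in this tutorial

# Bytes per element for the dtypes that appear in this tutorial.
ITEMSIZE = {"float16": 2, "bfloat16": 2, "float32": 4}


def scratch_elems(dtype):
    """Number of elements of `dtype` that fill the scratch budget."""
    return SCRATCH_BYTES // ITEMSIZE[dtype]


print(scratch_elems("float16"))  # 16777216
print(scratch_elems("float32"))  # 8388608
```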
Note
Concept: empty-tensor sentinel
Optional arguments are signalled as absent with a tensor whose first dimension is zero. Inside the wrapper, tensor.shape[0] == 0 maps to None:
# As an output (not computed):
jax.ShapeDtypeStruct((0,), jnp.float32)
# As an input (not provided):
jnp.empty((0,), dtype=jnp.float32)
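The wrapper-side half of the sentinel rule is a one-line mapping. The sketch below uses `types.SimpleNamespace` objects as stand-ins for tensors, since only the `.shape` attribute matters; `none_if_empty` is a name chosen for this illustration.

```python
from types import SimpleNamespace


def none_if_empty(tensor):
    """Map the empty-tensor sentinel (first dimension 0) to None."""
    return None if tensor.shape[0] == 0 else tensor


absent = SimpleNamespace(shape=(0,))    # stand-in for an empty sentinel tensor
present = SimpleNamespace(shape=(16,))  # stand-in for real ALiBi slopes

print(none_if_empty(absent), none_if_empty(present) is present)
```

A zero-length tensor is used because every positional slot in the call must still hold a tensor with a static shape, and a zero-length one costs no memory.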
Compile and load¶
gen_decode_jit_spec assembles the build recipe: it renders two Jinja2 templates with the concrete dtype and head-dimension values, copies the .cu source files into the generated directory, and returns a JitSpec that .build_and_load() compiles and caches.
# -- Type mappings for Jinja template rendering --------------------------------
DTYPE_CPP = {"float16": "half", "bfloat16": "nv_bfloat16", "float32": "float"}
DTYPE_SAFE = {"float16": "f16", "bfloat16": "bf16", "float32": "f32"}
POS_ENC = {
    0: "PosEncodingMode::kNone",
    1: "PosEncodingMode::kRoPELlama",
    2: "PosEncodingMode::kALiBi",
}
def gen_decode_jit_spec(dtype: str = "float16", head_dim: int = 64):
    """Return a JitSpec for type-specialized single-request decode attention."""
    s = DTYPE_SAFE[dtype]
    uri = (
        f"single_decode_with_kv_cache_dtype_q_{s}_dtype_kv_{s}_dtype_o_{s}_"
        f"head_dim_qk_{head_dim}_head_dim_vo_{head_dim}_"
        f"posenc_0_use_swa_False_use_logits_cap_False"
    )
    gen_dir = jit_env.FLASHINFER_GEN_SRC_DIR / uri
    os.makedirs(gen_dir, exist_ok=True)
    # generate_additional_params produces the C++ boilerplate strings
    # for the optional alibi-slopes tensor and four scalar parameters.
    params_decl, func_params, params_setter = generate_additional_params(
        additional_tensor_names=["maybe_alibi_slopes"],
        additional_tensor_dtypes=["float"],
        additional_scalar_names=[
            "logits_soft_cap",
            "sm_scale",
            "rope_rcp_scale",
            "rope_rcp_theta",
        ],
        additional_scalar_dtypes=["double", "double", "double", "double"],
    )
    kwargs = dict(
        additional_func_params=func_params,
        additional_params_decl=params_decl,
        additional_params_setter=params_setter,
        variant_decl="#include<flashinfer/attention/variants.cuh>",
        variant_name="DefaultAttention<false, false, false, false>",
        dtype_q=DTYPE_CPP[dtype],
        dtype_kv=DTYPE_CPP[dtype],
        dtype_o=DTYPE_CPP[dtype],
        head_dim_qk=head_dim,
        head_dim_vo=head_dim,
        pos_encoding_mode=POS_ENC[0],
        use_sliding_window="false",
        use_logits_soft_cap="false",
    )
    csrc = jit_env.FLASHINFER_CSRC_DIR
    # Render Jinja2 templates with the type-specific values
    for tmpl, out in [
        ("single_decode_customize_config.jinja", "single_decode_config.inc"),
        ("single_decode_kernel_inst.jinja", "single_decode_kernel.cu"),
    ]:
        rendered = jinja2.Template((csrc / tmpl).read_text()).render(**kwargs)
        write_if_different(gen_dir / out, rendered)
    # Copy the .cu source files that #include the rendered headers
    sources = [gen_dir / "single_decode_kernel.cu"]
    for fname in ["single_decode.cu", "single_decode_jit_binding.cu"]:
        dest = gen_dir / fname
        write_if_different(dest, (csrc / fname).read_text())
        sources.append(dest)
    return gen_jit_spec(uri, sources)
# -- Compile and load ----------------------------------------------------------
DTYPE, HEAD_DIM = "float16", 64
print(f"Compiling decode attention ({DTYPE}, head_dim={HEAD_DIM})...")
decode_module = gen_decode_jit_spec(DTYPE, HEAD_DIM).build_and_load()
print(f" run function: {decode_module.run}")
Register as a JAX FFI target¶
The decode_attention wrapper handles all three patterns from this example. The optional lse and alibi_slopes arguments use the empty-tensor sentinel: when tensor.shape[0] == 0, the wrapper passes None to the TVM function, telling the kernel to skip that output or ignore that input. The scratch buffer tmp is declared as a pre-allocated output - XLA reserves 32 MB for it - and the caller discards it after the call. Finally, vmap_method is intentionally omitted: the scratch buffer is a flat array with no batch dimension, and GQA head-grouping does not decompose cleanly over an added outer batch axis. Callers that need batching should loop explicitly or use a batch-aware kernel variant.
# -- Register ------------------------------------------------------------------
# TVM function signature:
#   run(q, k, v, tmp, out, maybe_lse, kv_layout_code, window_left,
#       maybe_alibi_slopes, logits_soft_cap, sm_scale, rope_rcp_scale, rope_rcp_theta)
# Sentinel rule: tensor.shape[0] == 0 => pass None (argument is absent)
_run = decode_module.run
def _decode_wrapper(
    out,
    tmp,
    lse_or_empty,  # <- rets
    q,
    k,
    v,
    alibi_or_empty,  # <- args
    kv_layout_code,
    window_left,  # <- attrs
    logits_soft_cap,
    sm_scale,
    rope_scale,
    rope_theta,
):
    lse = None if lse_or_empty.shape[0] == 0 else lse_or_empty
    alibi = None if alibi_or_empty.shape[0] == 0 else alibi_or_empty
    # Reorder to match TVM function signature:
    # kv_layout_code and window_left come before maybe_alibi_slopes
    _run(
        q,
        k,
        v,
        tmp,
        out,
        lse,
        kv_layout_code,
        window_left,
        alibi,
        logits_soft_cap,
        sm_scale,
        rope_scale,
        rope_theta,
    )
# Embed dtype + head_dim in the name: each compiled binary is a distinct target
DECODE_TARGET = f"flashinfer.single_decode_{DTYPE}_h{HEAD_DIM}"
jax_tvm_ffi.register_ffi_target(
    DECODE_TARGET,
    _decode_wrapper,
    arg_spec=[
        "rets",
        "args",
        "attrs.kv_layout_code",
        "attrs.window_left",
        "attrs.logits_soft_cap",
        "attrs.sm_scale",
        "attrs.rope_scale",
        "attrs.rope_theta",
    ],
    platform="gpu",
    allow_cuda_graph=True,
    pass_owned_tensor=True,
)
# -- JAX-facing function -------------------------------------------------------
def decode_attention(q, k, v):
    """Single-request GQA decode attention.
    q: [num_qo_heads, head_dim] float16 (query for one new token)
    k: [kv_len, num_kv_heads, head_dim] float16 (full KV-cache, NHD layout)
    v: [kv_len, num_kv_heads, head_dim] float16
    Returns: [num_qo_heads, head_dim]
    """
    sm_scale = 1.0 / math.sqrt(q.shape[-1])
    tmp_elems = 32 * 1024 * 1024 // np.dtype(q.dtype).itemsize
    out, _, _ = jax.ffi.ffi_call(
        DECODE_TARGET,
        (
            jax.ShapeDtypeStruct(q.shape, q.dtype),  # out
            jax.ShapeDtypeStruct((tmp_elems,), q.dtype),  # tmp (scratch, discarded)
            jax.ShapeDtypeStruct((0,), jnp.float32),  # lse (sentinel: not computed)
        ),
    )(
        q,
        k,
        v,
        jnp.empty((0,), dtype=jnp.float32),  # alibi slopes: sentinel (not used)
        kv_layout_code=0,  # NHD=0, HND=1
        window_left=-1,  # full attention window
        logits_soft_cap=0.0,
        sm_scale=sm_scale,
        rope_scale=1.0,
        rope_theta=1e4,
    )
    return out
print(f"Registered '{DECODE_TARGET}'.")
The validation uses a GQA configuration: 16 query heads and 4 KV heads (4x grouping) attending over a 512-token KV-cache. A pure-JAX reference computes the grouped softmax attention in float32 for comparison.
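The 4x grouping means four consecutive query heads read the same KV head. A minimal sketch of that mapping, assuming the contiguous grouping implied by the reference implementation's reshape to [H_kv, group, D] (`kv_head_for` is a name chosen for this illustration):

```python
def kv_head_for(q_head, num_qo_heads, num_kv_heads):
    """KV head that serves a given query head under contiguous GQA grouping."""
    group_size = num_qo_heads // num_kv_heads
    return q_head // group_size


# 16 query heads sharing 4 KV heads: heads 0-3 -> KV 0, 4-7 -> KV 1, and so on.
mapping = [kv_head_for(h, 16, 4) for h in range(16)]
print(mapping)
```

With this layout the KV-cache stores a quarter as many heads as a plain multi-head configuration with the same 16 query heads would.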
# -- Validate ------------------------------------------------------------------
def _reference_gqa_decode(q, k, v):
    """Reference GQA decode. q: [H_q, D] k, v: [S, H_kv, D] (NHD)"""
    H_q, H_kv = q.shape[0], k.shape[1]
    scale = q.shape[-1] ** -0.5
    q32 = q.astype(jnp.float32).reshape(H_kv, H_q // H_kv, -1)  # [H_kv, group, D]
    scores = jnp.einsum("hgd,shd->hgs", q32, k.astype(jnp.float32)) * scale
    weights = jax.nn.softmax(scores, axis=-1)
    out = jnp.einsum("hgs,shd->hgd", weights, v.astype(jnp.float32))
    return out.reshape(H_q, -1)
# GQA: 16 query heads, 4 KV heads (4x grouping), 512-token KV-cache
NUM_QO, NUM_KV, KV_LEN = 16, 4, 512
q = jax.random.normal(jax.random.key(10), (NUM_QO, HEAD_DIM), dtype=jnp.float16)
k = jax.random.normal(jax.random.key(11), (KV_LEN, NUM_KV, HEAD_DIM), dtype=jnp.float16)
v = jax.random.normal(jax.random.key(12), (KV_LEN, NUM_KV, HEAD_DIM), dtype=jnp.float16)
out_raw = decode_attention(q, k, v)
out_ref = _reference_gqa_decode(q, k, v)
np.testing.assert_allclose(
    np.array(out_raw.astype(jnp.float32)),
    np.array(out_ref),
    rtol=1e-2,
    atol=1e-2,
)
print("decode_attention: PASSED")
print(
    f" GQA: {NUM_QO} query / {NUM_KV} KV heads ({NUM_QO // NUM_KV}x groups), kv_len={KV_LEN}"
)
print(f" Output: {out_raw.shape}")
print(
    f" Max error: {float(jnp.max(jnp.abs(out_raw.astype(jnp.float32) - out_ref))):.4f}"
)
Composing kernels in @jax.jit¶
In this final section, you should learn that registered TVM FFI targets are plain XLA custom call nodes and compose naturally inside a single @jax.jit-decorated function alongside any other JAX operations - no special handling is needed.
A realistic LLM decode step brings all three kernels together:
silu_and_mul - gated FFN activation
apply_rope - rotary embeddings applied to the new query and key tokens
decode_attention - attention over the full KV-cache with the RoPE'd query
XLA compiles the entire function once. All three kernels become custom call nodes in the same HLO computation graph.
Note
Where to put @jax.jit
jax.ffi.ffi_call works with or without @jax.jit. Two patterns are common:
Per-kernel JIT - decorate each helper individually:
@jax.jit
def silu_and_mul(x): ...
@jax.jit
def decode_attention(q, k, v): ...
Each kernel compiles independently and runs fast when called in isolation. Good for standalone use, exploratory work, and benchmarking individual kernels.
Outer JIT only - helpers are plain functions; only the composition function is decorated:
def silu_and_mul(x): ... # plain Python
def decode_attention(q, k, v): ...
@jax.jit
def decode_step(...):
    ffn_out = silu_and_mul(...)
    attn_out = decode_attention(...)
    return ffn_out, attn_out
XLA traces the entire decode step as a single computation graph, seeing all custom call nodes at once. This is the right choice for production LLM inference, where all kernels run together every step anyway.
JAX handles nested @jax.jit correctly - inner jits are inlined during tracing - so there is no penalty to mixing patterns later. This tutorial uses outer JIT only to keep the focus on the composition.
# -- Inputs for the decode step ------------------------------------------------
# FFN input: 4 tokens, each with a 2H fused gate+up projection
gate_up = jax.random.normal(jax.random.key(20), (4, 2 * HEAD_DIM), dtype=jnp.float16)
# GQA dimensions and KV-cache (same config as the validate cell)
NUM_QO, NUM_KV, KV_LEN = 16, 4, 512
q = jax.random.normal(jax.random.key(10), (NUM_QO, HEAD_DIM), dtype=jnp.float16)
k = jax.random.normal(jax.random.key(11), (KV_LEN, NUM_KV, HEAD_DIM), dtype=jnp.float16)
v = jax.random.normal(jax.random.key(12), (KV_LEN, NUM_KV, HEAD_DIM), dtype=jnp.float16)
# New-token Q and K as packed batches for apply_rope: [tokens, heads, dim]
q_new = q.reshape(1, NUM_QO, HEAD_DIM) # [1, 16, 64] <- new query token
k_new = k[:1] # [1, 4, 64] <- new key token
indptr = jnp.array([0, 1], dtype=jnp.int32) # one sequence of length 1
offsets = jnp.array([KV_LEN], dtype=jnp.int32) # new token sits at position KV_LEN
# -- @jax.jit composition ------------------------------------------------------
@jax.jit
def decode_step(gate_up, q_new, k_new, k_cache, v_cache, indptr, offsets):
"""One LLM decode step compiled into a single XLA computation."""
# 1. Gated FFN activation
ffn_out = silu_and_mul(gate_up)
# 2. Rotary embeddings for the new query and key tokens
q_r, k_r = apply_rope(q_new, k_new, indptr, offsets)
# 3. Decode attention over the full KV-cache with the RoPE'd query
attn_out = decode_attention(q_r.reshape(NUM_QO, HEAD_DIM), k_cache, v_cache)
return ffn_out, attn_out
ffn_out, attn_out = decode_step(gate_up, q_new, k_new, k, v, indptr, offsets)
# Validate against calling each kernel individually (outside @jax.jit)
ffn_ref = silu_and_mul(gate_up)
q_r, k_r = apply_rope(q_new, k_new, indptr, offsets)
attn_ref = decode_attention(q_r.reshape(NUM_QO, HEAD_DIM), k, v)
np.testing.assert_allclose(
np.array(ffn_out.astype(jnp.float32)),
np.array(ffn_ref.astype(jnp.float32)),
rtol=1e-2,
atol=1e-2,
)
np.testing.assert_allclose(
np.array(attn_out.astype(jnp.float32)),
np.array(attn_ref.astype(jnp.float32)),
rtol=1e-2,
atol=1e-2,
)
print("@jax.jit composition: PASSED")
print(f" gate_up {gate_up.shape} -> ffn_out {ffn_out.shape}")
print(f" q_new {q_new.shape} -> attn_out {attn_out.shape}")
# -- Latency benchmark ----------------------------------------------------------
_ = decode_attention(q, k, v).block_until_ready() # warm-up (triggers XLA compilation)
N = 100
t0 = time.perf_counter()
for _ in range(N):
decode_attention(q, k, v).block_until_ready()
us = (time.perf_counter() - t0) / N * 1e6
print(
f"\ndecode_attention kv_len={KV_LEN}, {NUM_QO}/{NUM_KV} GQA heads -> {us:.1f} us"
)
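The mean over a single timed loop can be skewed by occasional slow iterations. A slightly more robust variant records one sample per call and reports the median and minimum. The helper below is a sketch under that assumption; `benchmark_us` and its `sync` parameter are hypothetical names, and the stand-in workload is plain Python so the block runs anywhere.

```python
import statistics
import time


def benchmark_us(fn, *args, warmup=3, iters=100, sync=lambda out: out):
    """Return (median, minimum) latency of fn(*args) in microseconds."""
    for _ in range(warmup):
        sync(fn(*args))  # warm-up calls are not timed (compilation, caches)
    samples = []
    for _ in range(iters):
        t0 = time.perf_counter()
        sync(fn(*args))  # wait for the result before stopping the clock
        samples.append((time.perf_counter() - t0) * 1e6)
    return statistics.median(samples), min(samples)


# Stand-in workload: summing a range in pure Python
median_us, min_us = benchmark_us(sum, range(10_000), iters=20)
print(f"median {median_us:.1f} us, min {min_us:.1f} us")
```

For JAX functions, pass `sync=jax.block_until_ready` so that asynchronous dispatch does not end the measurement early, for example `benchmark_us(decode_attention, q, k, v, sync=jax.block_until_ready)`.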
Summary¶
You have applied the JAX-TVM FFI bridge to three real LLM inference kernels, each revealing a new layer of the pattern.
Step 1 BUILD & LOAD jit_spec.build_and_load() -> tvm_ffi.Module
Step 2 REGISTER jax_tvm_ffi.register_ffi_target(name, wrapper, arg_spec)
Step 3 CALL jax.ffi.ffi_call(name, output_shapes)(*inputs, **scalar_attrs)
FlashInfer exposes the same TVM FFI interface for batch decode with paged KV-cache, variable-length prefill attention, fused mixture-of-experts, quantized GEMM, and more. Every kernel uses the same three-step recipe summarized above. The main variables are:
How many output tensors to declare, and whether any are scratch buffers to discard after the call
Whether any inputs or outputs are optional (use the empty-tensor sentinel)
Whether the kernel needs type-specialized compilation (Jinja template rendering) or dispatches at runtime over dtypes
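The first two variables come down to how the output declarations and placeholder inputs are built before the call. The sketch below shows one way to express them with `jax.ShapeDtypeStruct`. The helper names, the scratch size, and the zero-element form of the sentinel are assumptions made for illustration, not FlashInfer's actual requirements.

```python
import jax
import jax.numpy as jnp


def silu_and_mul_out_spec(x):
    """One output: same leading dims as x, last dim halved."""
    return jax.ShapeDtypeStruct(x.shape[:-1] + (x.shape[-1] // 2,), x.dtype)


def decode_out_specs(q, scratch_bytes=1 << 20):
    """Two outputs: the attention result plus a byte scratch buffer.

    scratch_bytes is a hypothetical size chosen for this sketch.
    """
    out = jax.ShapeDtypeStruct(q.shape, q.dtype)
    scratch = jax.ShapeDtypeStruct((scratch_bytes,), jnp.uint8)
    return out, scratch


x = jnp.zeros((4, 128), dtype=jnp.float16)  # fused gate+up projection
q = jnp.zeros((16, 64), dtype=jnp.float16)  # [num_qo_heads, head_dim]

ffn_spec = silu_and_mul_out_spec(x)
attn_spec, scratch_spec = decode_out_specs(q)

# Assumed form of an absent optional tensor: a zero-element array
empty_sentinel = jnp.zeros((0,), dtype=jnp.float16)

print(ffn_spec, attn_spec, scratch_spec, empty_sentinel.shape)
```

With a registered target, passing both specs to `jax.ffi.ffi_call` yields two results; the caller keeps the first and drops the scratch buffer.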
Beyond the examples in this tutorial¶
This tutorial demonstrated silu_and_mul, apply_rope, and single_decode as representative examples, but FlashInfer’s strength lies in its broader library of high-performance kernels - including Multi-head Latent Attention (MLA), sparse attention, TensorRT-LLM generative batch attention, and fused Mixture-of-Experts (MoE). The same three-step jax-tvm-ffi pattern shown here applies directly to all of them: compile the kernel, register the wrapper, and call it from JAX. No changes to the bridge are needed - only the TVM function signature and arg_spec vary per kernel.
For more details on jax-tvm-ffi itself - including how to wrap your own C++ or Python TVM functions - see the jax-tvm-ffi documentation and examples.