Trace Apply
Trace Apply substitutes selected FlashInfer API calls with custom solutions at runtime, without any changes to the calling code (SGLang, vLLM, or other engines). It is the consumer side of the FlashInfer Trace: you provide a mapping from each definition name to the solution to run for it, and Trace Apply dispatches that solution whenever the corresponding API is called with a matching shape. Each definition has exactly one solution: the one you registered.
Enabling
Register solutions explicitly with enable_apply(). The argument is a
mapping {definition_name: solution}, where a solution is either a Python
callable or a first-class Solution:
import torch
import flashinfer
import flashinfer.trace_apply as fi_trace_apply

def my_rmsnorm(hidden_states, weight, eps=1e-6):
    # Reference RMSNorm: normalize in float32, scale by weight, cast back.
    x = hidden_states.float()
    y = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return (y * weight.float()).to(hidden_states.dtype)

fi_trace_apply.enable_apply({"rmsnorm_h4096": my_rmsnorm})  # returns the number of wrapped APIs
...
flashinfer.rmsnorm(x, w)        # routed to my_rmsnorm when hidden_size == 4096
fi_trace_apply.stats()          # per-API dispatch counts (hit / fallback / error)
fi_trace_apply.disable_apply()  # restore the original FlashInfer APIs
enable_apply is idempotent: calling it again replaces the previous mapping.
With no argument (enable_apply()), it uses the environment configuration
below; if nothing is configured, it is a no-op.
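The module-level wrapping behind enable_apply() and disable_apply() can be pictured with a small sketch. It is not the FlashInfer implementation: enable, disable, and fake_flashinfer are hypothetical names, and to stay short the mapping is keyed by API name rather than by definition name. What it shows is the idempotency pattern: originals are saved once, a second enable only swaps the mapping instead of wrapping the wrapper, and disable restores the saved functions.

```python
import types

# Hypothetical stand-in for the flashinfer module with a single API.
fake_flashinfer = types.SimpleNamespace(rmsnorm=lambda x, w: "original")

_originals = {}  # API name -> original function, saved only on first wrap
_mapping = {}    # key -> solution; replaced wholesale on every enable()


def enable(mapping, module=fake_flashinfer, api_names=("rmsnorm",)):
    """Install wrappers once; later calls only replace the mapping."""
    _mapping.clear()
    _mapping.update(mapping)
    wrapped = 0
    for name in api_names:
        if name not in _originals:  # never wrap an existing wrapper
            original = getattr(module, name)
            _originals[name] = original

            def wrapper(*args, _name=name, _orig=original, **kwargs):
                # Sketch only: keyed by API name; real routing uses the definition name.
                solution = _mapping.get(_name)
                if solution is not None:
                    return solution(*args, **kwargs)
                return _orig(*args, **kwargs)

            setattr(module, name, wrapper)
        wrapped += 1
    return wrapped


def disable(module=fake_flashinfer):
    """Restore every saved original and forget the mapping."""
    for name, original in _originals.items():
        setattr(module, name, original)
    _originals.clear()
    _mapping.clear()
```

Saving the original in a separate table, rather than reading it back from the module, is what keeps a second enable from capturing the wrapper as if it were the original.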
There is also an import-time hook controlled by environment variables:
| Variable | Meaning |
|---|---|
| … | Set to … |
| FLASHINFER_TRACE_APPLY_PATH | Directory the deployment-configured solutions are loaded from. |
If the configuration is missing or invalid, Trace Apply stays disabled and FlashInfer continues to work normally, with a warning describing the problem.
FLASHINFER_TRACE_APPLY_PATH must point at a curated solutions folder: its
solutions/ subtree is scanned recursively for exactly one solution per
definition, and a duplicate definition is an error. It is not the raw
extraction bundle, which also contains baseline solutions and several backends
per definition.
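The recursive scan with its one-solution-per-definition rule can be sketched as below. This is an illustration under a stated assumption, not the loader FlashInfer ships: it supposes each solution is a *.json file whose top-level "definition" field names the definition it implements, and load_curated_solutions is a hypothetical helper.

```python
import json
import os


def load_curated_solutions(root):
    """Map definition name -> solution file path under root/solutions.

    Assumes (hypothetically) each solution is a *.json file whose top-level
    "definition" field names the definition it implements.
    """
    solutions_dir = os.path.join(root, "solutions")
    found = {}
    for dirpath, _dirnames, filenames in os.walk(solutions_dir):
        for filename in sorted(filenames):
            if not filename.endswith(".json"):
                continue
            path = os.path.join(dirpath, filename)
            with open(path, encoding="utf-8") as f:
                definition = json.load(f)["definition"]
            if definition in found:
                # A curated folder holds one solution per definition; two is an error.
                raise ValueError(
                    f"duplicate solution for {definition!r}: {found[definition]} and {path}"
                )
            found[definition] = path
    return found
```

Raising on a duplicate, instead of silently picking one, is why pointing this at a raw extraction bundle with several backends per definition fails loudly.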
Routing
A call is routed by definition name. On each call the wrapper extracts the
call’s const axes and recomputes the definition name from the live
TraceTemplate, using the same name_prefix + const-axis convention as the trace
collector (e.g. rmsnorm at hidden_size=4096 → "rmsnorm_h4096").
If that name is in the registered mapping, the call dispatches to its solution;
otherwise it falls back to the original FlashInfer kernel.
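The name recomputation and the match-or-fall-back decision can be sketched as follows. definition_name and pick_target are hypothetical helpers, and the per-axis tag table (hidden_size → "h") is an assumption standing in for whatever the live TraceTemplate supplies; only the "rmsnorm_h4096" result is taken from the text.

```python
def definition_name(name_prefix, const_axes, axis_tags):
    """Build a definition name from a prefix and the call's const axes.

    Hypothetical sketch: axis_tags maps each const axis to a short tag
    (e.g. "hidden_size" -> "h"); the real tags come from the TraceTemplate.
    """
    parts = [name_prefix]
    for axis, value in const_axes.items():  # dicts keep insertion order
        parts.append(f"{axis_tags[axis]}{value}")
    return "_".join(parts)


def pick_target(registered, name, original):
    """A registered name selects its solution; any other name falls back."""
    return registered.get(name, original)


# Const axes as extracted from a call like flashinfer.rmsnorm(x, w) with x.shape[-1] == 4096.
name = definition_name("rmsnorm", {"hidden_size": 4096}, {"hidden_size": "h"})
```

Because the collector and the wrapper share one naming convention, a solution registered under a name taken from the trace is found again at call time without any extra bookkeeping.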
Const axes are the compile-time shape a definition is specialized for
(hidden_size, head_dim, …) and form the name. Variable axes (batch size, sequence length, …) are not part of the name, so a single solution serves all of their values.

The decision is cached per name, so steady-state dispatch is a dictionary lookup. During CUDA-graph capture only already-resolved shapes are applied, so warm up eagerly before capturing.
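A per-name cache with a capture guard might look like the sketch below. CachedRouter is hypothetical; its capturing flag stands in for a real capture check (in PyTorch, torch.cuda.is_current_stream_capturing() reports whether the current stream is capturing). The assumption that an unresolved name falls back to the original kernel during capture, without being cached, is this sketch's own choice.

```python
class CachedRouter:
    """Hypothetical sketch: resolve each definition name once, then look it up."""

    def __init__(self, registered, original):
        self.registered = registered
        self.original = original
        self.cache = {}  # definition name -> resolved callable
        self.capturing = False  # stands in for "a CUDA graph is being captured"

    def resolve(self, name):
        target = self.cache.get(name)
        if target is not None:
            return target  # steady state: a single dict lookup
        if self.capturing:
            # Assumed behavior: unresolved names are not applied during capture.
            return self.original
        target = self.registered.get(name, self.original)
        self.cache[name] = target
        return target
```

Running each shape once eagerly fills the cache, so the capture pass only ever hits the fast lookup path.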
Solutions
A solution value is either:
- a Python callable, invoked directly with the definition’s inputs by keyword (value-returning); or
- a first-class Solution, loaded by language family: the Python family (python, triton, cutedsl, cutile, tilelang) is imported and called by keyword; the C++/CUDA family (cpp, cuda, cutlass) is built via flashinfer.jit and called positionally.
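The two calling conventions can be sketched as a small dispatcher. invoke is a hypothetical helper; entry_point stands for an already imported or already built function (the flashinfer.jit build step is not modeled), and passing positional arguments in the definition's declared input order is an assumption of this sketch.

```python
PYTHON_FAMILY = {"python", "triton", "cutedsl", "cutile", "tilelang"}
CPP_FAMILY = {"cpp", "cuda", "cutlass"}


def invoke(language, entry_point, inputs, input_order):
    """Hypothetical sketch of the keyword vs positional calling conventions.

    inputs:      definition inputs by name, e.g. {"hidden_states": x, "weight": w}
    input_order: the definition's declared input order (assumed for positional calls)
    """
    if language in PYTHON_FAMILY:
        return entry_point(**inputs)  # keyword call
    if language in CPP_FAMILY:
        return entry_point(*(inputs[n] for n in input_order))  # positional call
    raise ValueError(f"unsupported solution language: {language!r}")
```

Calling the Python family by keyword makes it insensitive to argument order, while compiled entry points only see positions, so the order must come from the definition.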
Trace Apply reconciles a solution’s outputs with the calling API’s convention:
value-returning outputs, caller-provided out= / lse= buffers, in-place
writes (e.g. fused_add_rmsnorm writing back into its input/residual buffers),
and data-dependent arity (e.g. return_lse). If a caller passes out= (e.g.
rmsnorm(x, w, out=buf)), the substituted solution’s result is written into
buf and buf is returned, exactly like the original kernel.
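The out= reconciliation amounts to "compute, copy into the caller's buffer, return that same buffer". The sketch below is hypothetical (call_with_out is not a FlashInfer name) and uses plain Python lists in place of tensors; with torch tensors the copy would be out.copy_(result) instead of slice assignment.

```python
def call_with_out(solution, args, out=None):
    """Hypothetical sketch: honor a caller-provided out= buffer.

    Plain lists stand in for tensors; with torch tensors the copy
    would be out.copy_(result) instead of slice assignment.
    """
    result = solution(*args)
    if out is None:
        return result  # value-returning convention
    out[:] = result  # write into the caller's buffer in place
    return out  # return the same object, as the original kernel does
```

Returning the caller's own buffer, not the solution's fresh result, matters for callers that rely on identity, such as code that preallocates buffers for a captured graph.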
The out= / lse= bindings are auto-derived from the live API signature
(a uniform FlashInfer convention), so they are not recorded in the trace. Only
the non-derivable in-place bindings (which input a result is written back into)
are declared in the trace via the output param.
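Deriving the out= / lse= bindings from the live signature can be sketched with the standard inspect module. derive_buffer_bindings is a hypothetical helper and the two dummy signatures in the usage are illustrative; the sketch only shows why these bindings need no trace entry, and does not model the in-place bindings declared via the output param.

```python
import inspect

BUFFER_PARAMS = ("out", "lse")  # the uniform caller-buffer names described above


def derive_buffer_bindings(api):
    """Return which caller-buffer parameters the live API signature accepts."""
    params = inspect.signature(api).parameters
    return [name for name in BUFFER_PARAMS if name in params]


# Illustrative signatures only; real FlashInfer APIs take more parameters.
def rmsnorm(input, weight, eps=1e-6, out=None):
    pass


def attention(q, k, v, out=None, lse=None, return_lse=False):
    pass
```

Since the parameter names follow one convention, reading them off the signature at wrap time is enough; an in-place write-back target cannot be read off a signature, which is why it is recorded in the trace.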
Error policy
Trace Apply is strict. A matched solution that fails to build or run raises, so a broken solution is reported immediately rather than masked. A call with no matching registered name falls back to the original FlashInfer API.
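The strict policy, together with the hit / fallback / error counts that stats() reports, can be sketched as below. dispatch is a hypothetical helper and the counter layout is an assumption; the point is that an error is counted and then re-raised, never converted into a fallback.

```python
from collections import Counter


def dispatch(registered, name, original, args, stats):
    """Hypothetical sketch of the strict policy with hit/fallback/error counts."""
    solution = registered.get(name)
    if solution is None:
        stats["fallback"] += 1
        return original(*args)  # no matching name: use the original API
    try:
        result = solution(*args)
    except Exception:
        stats["error"] += 1
        raise  # strict: a broken solution is reported, never masked by a fallback
    stats["hit"] += 1
    return result
```

Falling back on error would hide a broken solution behind correct-looking output; re-raising surfaces it on the first call that reaches it.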