fi_trace — Operation Schema Extraction¶
fi_trace is FlashInfer’s operation schema extraction system. Every
@flashinfer_api-decorated function automatically grows a .fi_trace()
method that captures the shape, dtype, and axis structure of a call as a
portable JSON file — without running the GPU kernel.
These JSON files are the input format for flashinfer-bench, the companion benchmark toolkit. Collecting them while running your production workload gives you a precise benchmark suite that reflects your actual model and serving scenario.
Quick Start¶
Set two environment variables before importing FlashInfer:
```bash
export FLASHINFER_TRACE_DUMP=1
export FLASHINFER_TRACE_DUMP_DIR=./fi_trace_out   # default: ./fi_trace_out
python my_inference_script.py
```
FlashInfer writes one .json file per unique (op, shape) combination.
Subsequent calls with the same shapes are deduplicated — no duplicate files.
```text
fi_trace_out/
├── rmsnorm_h7168.json
├── gqa_paged_decode_h32_kv8_d128_ps16.json
├── moe_fp8_block_scale_ds_routing_topk8_ng8_kg4_e32_h7168_i2048.json
└── ...
```
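The dedup behaviour can be sketched as a small helper keyed on the generated trace name; this is an illustrative stand-in, not FlashInfer's actual implementation:

```python
import json
import os
import tempfile

def dump_trace_once(schema, out_dir, _seen=set()):
    """Write `<name>.json` the first time a trace name is seen and skip
    repeats -- a sketch of the dedup behaviour, not FlashInfer's code."""
    name = schema["name"]  # e.g. "rmsnorm_h7168" encodes op + const axes
    if name in _seen:
        return None
    _seen.add(name)
    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, name + ".json")
    with open(path, "w") as f:
        json.dump(schema, f, indent=2)
    return path

out = tempfile.mkdtemp()
first = dump_trace_once({"name": "rmsnorm_h7168"}, out)  # writes a file
again = dump_trace_once({"name": "rmsnorm_h7168"}, out)  # deduplicated -> None
```

Because the key is the schema name, which encodes the op and its constant axes, a second call with the same shapes produces no second file.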
Environment Variables¶
| Variable | Type | Default | Description |
|---|---|---|---|
| `FLASHINFER_TRACE_DUMP` | int | `0` | Set to `1` to enable trace dumping. |
| `FLASHINFER_TRACE_DUMP_DIR` | str | `./fi_trace_out` | Directory where JSON files are written. |
Both variables are read lazily at call time, so they can be set after
import flashinfer (e.g. when using python -m).
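Because the variables are read lazily, tracing can also be switched on from inside Python rather than the shell; a minimal sketch:

```python
import os

# Read lazily at call time, so setting these after `import flashinfer`
# has already run still works.
os.environ["FLASHINFER_TRACE_DUMP"] = "1"
os.environ["FLASHINFER_TRACE_DUMP_DIR"] = "./fi_trace_out"

# Any traced FlashInfer call made after this point dumps its JSON schema.
```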
JSON File Format¶
Each file describes one operation instance. Here is an annotated example for
rmsnorm with hidden_size=7168:
```json
{
  "name": "rmsnorm_h7168",
  "description": "Root Mean Square Normalization. Epsilon is fixed at 1e-6.",
  "op_type": "rmsnorm",
  "tags": [
    "fi_api:flashinfer.norm.rmsnorm",
    "status:verified"
  ],
  "axes": {
    "batch_size": { "type": "var" },
    "hidden_size": { "type": "const", "value": 7168 }
  },
  "inputs": {
    "hidden_states": { "shape": ["batch_size", "hidden_size"], "dtype": "bfloat16" },
    "weight": { "shape": ["hidden_size"], "dtype": "bfloat16" }
  },
  "outputs": {
    "output": { "shape": ["batch_size", "hidden_size"], "dtype": "bfloat16" }
  },
  "reference": "..."
}
```
Key fields:
| Field | Meaning |
|---|---|
| `name` | Auto-generated from the template's `name_prefix` plus abbreviated `Const` axis values (e.g. `rmsnorm` + `h7168`); also used as the filename. |
| `op_type` | Identifies the kernel class (`rmsnorm`, `gqa_paged_decode`, …). |
| `tags` | List of key:value tags. Always includes `fi_api:<module path>` pointing at the decorated function. |
| `axes` | Symbolic dimensions. `var` axes stay free; `const` axes carry a concrete `value`. |
| `inputs` / `outputs` | Each entry has a `shape` (a list of axis names) and a `dtype`. |
| `reference` | Source of a pure-PyTorch reference implementation for correctness checking (present on `status:verified` traces). |
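To make the axis mechanics concrete, here is an illustrative helper (not part of FlashInfer) that resolves the symbolic shapes in the rmsnorm trace above into integers, binding the `var` axis `batch_size` to 32:

```python
import json

# The rmsnorm trace from the example above, inlined for illustration.
trace = json.loads("""
{
  "name": "rmsnorm_h7168",
  "axes": {
    "batch_size": {"type": "var"},
    "hidden_size": {"type": "const", "value": 7168}
  },
  "inputs": {
    "hidden_states": {"shape": ["batch_size", "hidden_size"], "dtype": "bfloat16"},
    "weight": {"shape": ["hidden_size"], "dtype": "bfloat16"}
  }
}
""")

def resolve_shape(shape, axes, var_bindings):
    """Turn symbolic axis names into integers: const axes come from the
    trace itself, var axes from caller-supplied bindings."""
    out = []
    for axis in shape:
        spec = axes[axis]
        out.append(spec["value"] if spec["type"] == "const" else var_bindings[axis])
    return out

shapes = {name: resolve_shape(t["shape"], trace["axes"], {"batch_size": 32})
          for name, t in trace["inputs"].items()}
print(shapes)  # {'hidden_states': [32, 7168], 'weight': [7168]}
```

This is the substitution a benchmark consumer performs when it materializes concrete tensors from a trace.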
Calling .fi_trace() Directly¶
Every decorated function exposes a .fi_trace() method.
You can call it without running the kernel:
```python
import torch
import flashinfer

schema = flashinfer.norm.rmsnorm.fi_trace(
    hidden_states=torch.zeros(32, 7168, dtype=torch.bfloat16),
    weight=torch.ones(7168, dtype=torch.bfloat16),
)

print(schema["name"])  # rmsnorm_h7168
print(schema["axes"])  # {'batch_size': {'type': 'var'}, 'hidden_size': {'type': 'const', 'value': 7168}}
```
To write to a specific directory, pass save_dir:
```python
schema = flashinfer.norm.rmsnorm.fi_trace(
    hidden_states=...,
    weight=...,
    save_dir="./my_traces",
)
```
Covered Operations¶
Operations with trace templates emit JSON files when `FLASHINFER_TRACE_DUMP=1`. They include `flashinfer.norm.rmsnorm`, paged GQA decode attention, and the FP8 block-scale MoE kernels shown in the example filenames above, among others.
MoE Routing Types¶
MoE operations dispatch to per-routing-type templates. The output filename encodes the routing method:
| Value | Name | Filename pattern (FP8 example) |
|---|---|---|
| 0 | Default (Softmax → TopK) | |
| 1 | Renormalize (TopK → Softmax) | |
| 2 | DeepSeekV3 (Sigmoid + group selection) | `moe_fp8_block_scale_ds_routing_*.json` |
| 3 | Llama4 (Top1 → Sigmoid) | |
| 4 | RenormalizeNaive (Softmax → TopK → Renormalize) | |
| 5 | TopK (no normalisation) | |
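For use in code, the mapping can be captured as a small enum; the class and member names below are illustrative stand-ins, only the integer values and descriptions come from the table above:

```python
from enum import IntEnum

class RoutingMethod(IntEnum):
    # Hypothetical names; the integer values mirror the routing table.
    DEFAULT = 0            # Softmax -> TopK
    RENORMALIZE = 1        # TopK -> Softmax
    DEEPSEEK_V3 = 2        # Sigmoid + group selection
    LLAMA4 = 3             # Top1 -> Sigmoid
    RENORMALIZE_NAIVE = 4  # Softmax -> TopK -> Renormalize
    TOPK = 5               # no normalisation

print(RoutingMethod(2).name)  # DEEPSEEK_V3
```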
Example: Collecting Traces from a Real Workload¶
The script below runs a representative set of FlashInfer ops and collects all trace JSON files in one pass. It covers the shapes used in DeepSeek-V3-style models with expert-parallel MoE serving.
```bash
python tests/trace/example.py
```
The generated files can be passed directly to flashinfer-bench:
```bash
flashinfer-bench --trace-dir fi_trace_out/ --backends fa2 cudnn cutlass
```
Adding Trace Support to a New Kernel¶
When adding a new kernel (see CLAUDE.md and .claude/skills/add-cuda-kernel/SKILL.md
for the full tutorial), attach a TraceTemplate to the @flashinfer_api decorator:
```python
from flashinfer.trace.template import Const, Tensor, TraceTemplate, Var
from flashinfer.api_logging import flashinfer_api

rmsnorm_trace = TraceTemplate(
    op_type="rmsnorm",
    name_prefix="rmsnorm",
    description="Root Mean Square Normalization.",
    axes={
        "batch_size": Var(),
        "hidden_size": Const(abbrev="h"),
    },
    inputs={
        "hidden_states": Tensor(["batch_size", "hidden_size"]),
        "weight": Tensor(["hidden_size"]),
    },
    outputs={
        "output": Tensor(["batch_size", "hidden_size"], dtype_from="hidden_states"),
    },
    tags=["status:verified"],
)

@flashinfer_api(trace=rmsnorm_trace)
def rmsnorm(hidden_states, weight, eps=1e-6):
    ...
```
The template is registered automatically in _TRACE_REGISTRY at decoration
time and picked up by the consistency tests without any manual registration.
For operations whose template depends on a runtime parameter (e.g.
routing_method_type for MoE), write a dispatch callable and attach a
.templates attribute so the registry discovers all variants:
```python
_TEMPLATES = {0: default_trace, 1: renorm_trace, ...}

def my_dispatch(**kwargs):
    return _TEMPLATES.get(int(kwargs.get("routing_method_type", 0)))

my_dispatch.templates = list(_TEMPLATES.values())

@flashinfer_api(trace=my_dispatch)
def my_moe_op(...):
    ...
```
Consistency Tests¶
FlashInfer ships automated linter-style tests that validate every trace template without running GPU kernels:
```bash
pytest tests/trace/test_fi_trace_template_consistency.py -v
```
The tests check three properties for every registered template:

1. **Signature consistency** — every `param=reference` in the template matches a real parameter of the decorated function.
2. **Axes coverage** — every `Const` axis can be resolved from at least one tensor's shape or from a scalar kwarg.
3. **End-to-end completeness** — calling `.fi_trace()` with auto-generated minimal tensors returns a dict where all `Const` axes have values and no input/output has `dtype == "unknown"`.
When you add a template, these tests run automatically with no manual registration required.
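As an illustration of the first property, a signature check along these lines can be written with `inspect`; the helper and the way template parameter references are represented here are assumptions for illustration, not FlashInfer's actual test code:

```python
import inspect

def check_signature_consistency(func, template_param_refs):
    """Return the template parameter references that do NOT exist in the
    decorated function's signature (empty list means consistent).
    `template_param_refs` is assumed to be a plain list of names."""
    params = set(inspect.signature(func).parameters)
    return [p for p in template_param_refs if p not in params]

def rmsnorm(hidden_states, weight, eps=1e-6):
    ...

print(check_signature_consistency(rmsnorm, ["hidden_states", "weight"]))  # []
print(check_signature_consistency(rmsnorm, ["hidden_stats"]))  # ['hidden_stats']
```

A typo in a template reference (`hidden_stats` above) is exactly the kind of error this property catches before any GPU kernel runs.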