flashinfer.norm.fused_rmsnorm_silu
- flashinfer.norm.fused_rmsnorm_silu(input: Tensor, weight: Tensor, eps: float = 1e-06, out: Tensor | None = None, block_scale: Tensor | None = None) → Tensor | tuple
Fused RMSNorm + SiLU activation. Computes

    out[i] = SiLU(RMSNorm(input[i], weight, eps))

where

    SiLU(x) = x / (1 + exp(-x))

Optimized for SM100 (B200) for WAN VAE decoder problem sizes. Other shapes and architectures (SM80+) use conservative fallback heuristics.
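The fused semantics can be checked against a minimal pure-Python reference (a sketch of the math above operating on one row, not the CUDA kernel; function names here are illustrative only):

```python
import math

def rmsnorm(x, weight, eps=1e-6):
    """RMSNorm: x / sqrt(mean(x^2) + eps), scaled elementwise by weight."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms * w for v, w in zip(x, weight)]

def silu(v):
    """SiLU(x) = x / (1 + exp(-x))."""
    return v / (1.0 + math.exp(-v))

def fused_rmsnorm_silu_ref(x, weight, eps=1e-6):
    """One row of SiLU(RMSNorm(x, weight, eps))."""
    return [silu(v) for v in rmsnorm(x, weight, eps)]

row = [1.0, -2.0, 3.0, 0.5]
gamma = [1.0, 1.0, 1.0, 1.0]
print(fused_rmsnorm_silu_ref(row, gamma))
```

The fused kernel applies this per token (per row of the `(num_tokens, hidden_size)` input), avoiding a round trip to global memory between the normalization and the activation.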
- Parameters:
  - input (torch.Tensor) – Input tensor, shape (num_tokens, hidden_size), dtype bfloat16.
  - weight (torch.Tensor) – Scale (gamma) tensor, shape (hidden_size,), dtype bfloat16.
  - eps (float) – Epsilon for numerical stability.
  - out (Optional[torch.Tensor]) – Output tensor. If None, allocated as bfloat16 matching input. The dtype of out selects the output format:
    - torch.bfloat16: shape (num_tokens, hidden_size).
    - torch.float8_e4m3fn: FP8 E4M3 output, shape (num_tokens, hidden_size). Requires SM89+ (Ada/Hopper).
    - torch.float4_e2m1fn_x2: NVFP4 block-scaled output, shape (num_tokens, hidden_size // 2). Requires SM100+ (Blackwell) and hidden_size divisible by 16.
  - block_scale (Optional[torch.Tensor]) – Pre-allocated output tensor for per-block scale factors (NVFP4 only). Shape (num_tokens, hidden_size // 16), dtype torch.float8_e4m3fn. If None, allocated automatically when out is NVFP4. Ignored for bf16/fp8 output.
- Returns:
  output – For bf16/fp8: the normalized and SiLU-activated tensor, shape (num_tokens, hidden_size). For NVFP4: a tuple (y_fp4, block_scale) following the same convention as rmsnorm_fp4quant(). y_fp4 has shape (num_tokens, hidden_size // 2) with dtype float4_e2m1fn_x2, and block_scale has shape (num_tokens, hidden_size // 16) with dtype float8_e4m3fn (one E4M3 scale per 16-element block).
- Return type:
  torch.Tensor or Tuple[torch.Tensor, torch.Tensor]
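The shape conventions for each supported output dtype can be summarized with a small helper (hypothetical, for illustration only; the dtype strings mirror the torch identifiers above):

```python
def expected_output_shapes(num_tokens, hidden_size, out_dtype="bfloat16"):
    """Return (output_shape, block_scale_shape_or_None) per the conventions above."""
    if out_dtype in ("bfloat16", "float8_e4m3fn"):
        # bf16 and FP8 outputs keep the input shape; no block scales.
        return (num_tokens, hidden_size), None
    if out_dtype == "float4_e2m1fn_x2":
        assert hidden_size % 16 == 0, "NVFP4 requires hidden_size divisible by 16"
        # Two FP4 values are packed per element; one E4M3 scale per 16-element block.
        return (num_tokens, hidden_size // 2), (num_tokens, hidden_size // 16)
    raise ValueError(f"unsupported out dtype: {out_dtype}")

print(expected_output_shapes(1560, 640, "float4_e2m1fn_x2"))
# → ((1560, 320), (1560, 40))
```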
Notes
Kernel tuning knobs are sweep-optimized on B200 (SM100) for WAN VAE decoder problem sizes: hidden_size in {64, 128, 160, 256, 320, 512, 640, 1024} and num_tokens in {1560, 6240, 24960, 99840, 399360}. Other problem sizes use conservative fallback heuristics that are functionally correct but may not achieve peak throughput. Non-SM100 architectures use the same fallback path.