flashinfer.norm.fused_rmsnorm_silu

flashinfer.norm.fused_rmsnorm_silu(input: Tensor, weight: Tensor, eps: float = 1e-06, out: Tensor | None = None, block_scale: Tensor | None = None) → Tensor | tuple

Fused RMSNorm + SiLU activation.

out[i] = SiLU(RMSNorm(input[i], weight, eps))

where SiLU(x) = x / (1 + exp(-x))

Optimized for SM100 (B200) for WAN VAE decoder problem sizes. Other shapes and architectures (SM80+) use conservative fallback heuristics.
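The fused kernel computes the same result as applying RMSNorm and SiLU separately. A pure-PyTorch reference of that math (a sketch for clarity, not the optimized kernel; accumulation in float32 is an assumption about the kernel's internals):

```python
import torch

def ref_rmsnorm_silu(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Unfused reference: RMSNorm over the last dim, then SiLU."""
    x32 = x.float()
    # RMSNorm: scale each row by 1 / sqrt(mean(x^2) + eps), then by gamma
    inv_rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + eps)
    y = x32 * inv_rms * w.float()
    # SiLU(y) = y * sigmoid(y) = y / (1 + exp(-y))
    return (y * torch.sigmoid(y)).to(x.dtype)
```

This reference is useful for correctness checks against the fused kernel on small inputs.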

Parameters:
  • input (torch.Tensor) – Input tensor, shape (num_tokens, hidden_size), dtype bfloat16.

  • weight (torch.Tensor) – Scale (gamma) tensor, shape (hidden_size,), dtype bfloat16.

  • eps (float) – Epsilon added to the mean of squares for numerical stability. Defaults to 1e-6.

  • out (Optional[torch.Tensor]) –

    Output tensor. If None, allocated as bfloat16 matching input. The dtype of out selects the output format:

    • torch.bfloat16: shape (num_tokens, hidden_size).

    • torch.float8_e4m3fn: FP8 E4M3 output, shape (num_tokens, hidden_size). Requires SM89+ (Ada/Hopper).

    • torch.float4_e2m1fn_x2: NVFP4 block-scaled output, shape (num_tokens, hidden_size // 2). Requires SM100+ (Blackwell) and hidden_size divisible by 16.

  • block_scale (Optional[torch.Tensor]) – Pre-allocated output tensor for per-block scale factors (NVFP4 only). Shape (num_tokens, hidden_size // 16), dtype torch.float8_e4m3fn. If None, allocated automatically when out is NVFP4. Ignored for bf16/fp8 output.

Returns:

output – For bf16/fp8: the normalized, SiLU-activated tensor, shape (num_tokens, hidden_size).

For NVFP4: a tuple (y_fp4, block_scale) following the same convention as rmsnorm_fp4quant(). y_fp4 has shape (num_tokens, hidden_size // 2) with dtype float4_e2m1fn_x2, and block_scale has shape (num_tokens, hidden_size // 16) with dtype float8_e4m3fn (one E4M3 scale per 16-element block).

Return type:

torch.Tensor or Tuple[torch.Tensor, torch.Tensor]
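For the NVFP4 tuple, the block_scale layout follows the one-E4M3-scale-per-16-element-block convention described above. A hedged sketch of one plausible way such scales are derived (amax over each block divided by 6.0, the largest FP4 E2M1 magnitude; the kernel's exact scaling scheme may differ):

```python
import torch

def ref_fp4_block_scales(y: torch.Tensor, block: int = 16) -> torch.Tensor:
    """One scale per `block` contiguous elements along the hidden dim.

    Assumed convention: scale = amax(block) / 6.0, since 6.0 is the
    maximum representable magnitude in FP4 E2M1.
    """
    n, h = y.shape
    assert h % block == 0  # mirrors the hidden_size % 16 == 0 requirement
    blocks = y.float().reshape(n, h // block, block)
    return blocks.abs().amax(dim=-1) / 6.0  # shape (n, h // block)
```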

Notes

Kernel tuning knobs are sweep-optimized on B200 (SM100) for WAN VAE decoder problem sizes: hidden_size in {64, 128, 160, 256, 320, 512, 640, 1024} and num_tokens in {1560, 6240, 24960, 99840, 399360}. Other problem sizes use conservative fallback heuristics that are functionally correct but may not achieve peak throughput. Performance on non-SM100 architectures uses the same fallback path.