flashinfer.cascade.merge_state_in_place

flashinfer.cascade.merge_state_in_place(v: torch.Tensor, s: torch.Tensor, v_other: torch.Tensor, s_other: torch.Tensor, mask: torch.Tensor | None = None) → None

Merge the self-attention state (v, s) with another state (v_other, s_other), updating v and s in place.
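For reference, merging two attention states amounts to the standard logsumexp combination of two partial softmax results. The block below is a sketch written per sequence position and head, assuming s and s_other hold natural-log logsumexp values; the log base the kernel uses internally is not stated on this page, but the same form holds with exp and log replaced by 2^x and log2. Subtracting the maximum m keeps the exponentials from overflowing.

```latex
m = \max(s, s_{\text{other}}), \qquad
w = e^{s - m}, \qquad
w_{\text{other}} = e^{s_{\text{other}} - m}

v \leftarrow \frac{w\, v + w_{\text{other}}\, v_{\text{other}}}{w + w_{\text{other}}}, \qquad
s \leftarrow m + \log\left(w + w_{\text{other}}\right)
```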

Parameters:
  • v (torch.Tensor) – The partial attention output to be updated in-place, shape: (seq_len, num_heads, head_dim).

  • s (torch.Tensor) – The partial logsumexp value to be updated in-place, expected to be a float32 tensor, shape: (seq_len, num_heads).

  • v_other (torch.Tensor) – The other attention output to be merged, shape: (seq_len, num_heads, head_dim).

  • s_other (torch.Tensor) – The other logsumexp value to be merged, expected to be a float32 tensor, shape: (seq_len, num_heads).

  • mask (Optional[torch.Tensor]) – A boolean tensor indicating, for each index along seq_len, whether to merge the state at that index. Useful with CUDA graphs, where a fixed-shape call can skip some entries. If not specified (default), the states are merged for all indices, shape: (seq_len,).
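To make the semantics of the parameters concrete, here is a minimal pure-PyTorch sketch that runs on CPU. It is not FlashInfer's kernel and is not claimed to match it bit for bit (precision and handling of -inf values may differ). `merge_state_in_place_ref` and `partial_attention` are hypothetical helpers written for illustration, and the sketch assumes natural-log logsumexp values. It also shows why the merge is useful: attending to two chunks of a KV sequence separately and merging the results reproduces attention over the whole sequence.

```python
import torch


def merge_state_in_place_ref(v, s, v_other, s_other, mask=None):
    """Hypothetical pure-PyTorch reference: merges (v_other, s_other) into (v, s) in place.

    Assumes natural-log logsumexp values. v, v_other: (seq_len, num_heads, head_dim);
    s, s_other: (seq_len, num_heads), float32; mask: optional bool tensor of shape (seq_len,).
    """
    s_max = torch.maximum(s, s_other)      # subtract the max for numerical stability
    w = torch.exp(s - s_max)               # relative weight of the current state
    w_other = torch.exp(s_other - s_max)   # relative weight of the other state
    denom = w + w_other
    # Blend the outputs in float32, weighting each by its share of the total softmax mass.
    v_new = (v.float() * w[..., None] + v_other.float() * w_other[..., None]) / denom[..., None]
    s_new = s_max + torch.log(denom)
    if mask is not None:
        # Indices where mask is False keep their original state.
        v_new = torch.where(mask[:, None, None], v_new, v.float())
        s_new = torch.where(mask[:, None], s_new, s)
    v.copy_(v_new)  # copy_ casts back to v's dtype (e.g. float16)
    s.copy_(s_new)


def partial_attention(q, k, v):
    """Hypothetical helper: softmax attention of q over one KV chunk, returning (output, lse).

    q: (seq_len, num_heads, head_dim); k, v: (kv_len, num_heads, head_dim).
    """
    scores = torch.einsum("qhd,khd->qhk", q, k) / q.shape[-1] ** 0.5
    lse = torch.logsumexp(scores, dim=-1)  # (seq_len, num_heads)
    out = torch.einsum("qhk,khd->qhd", torch.softmax(scores, dim=-1), v)
    return out, lse


# Usage: split a KV sequence into two chunks, attend to each, then merge the states.
torch.manual_seed(0)
q = torch.randn(4, 2, 8)
k = torch.randn(10, 2, 8)
val = torch.randn(10, 2, 8)

full_out, full_lse = partial_attention(q, k, val)
out, lse = partial_attention(q, k[:6], val[:6])
out_b, lse_b = partial_attention(q, k[6:], val[6:])
merge_state_in_place_ref(out, lse, out_b, lse_b)  # out, lse now cover all 10 keys
print(torch.allclose(out, full_out, atol=1e-5), torch.allclose(lse, full_lse, atol=1e-5))
```

Because the merged logsumexp is itself a valid state, the same function can be applied repeatedly to fold in any number of KV chunks.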

Example

>>> import torch
>>> import flashinfer
>>> seq_len = 2048
>>> num_heads = 32
>>> head_dim = 128
>>> v = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda:0")
>>> s = torch.randn(seq_len, num_heads, dtype=torch.float32, device="cuda:0")
>>> v_other = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda:0")
>>> s_other = torch.randn(seq_len, num_heads, dtype=torch.float32, device="cuda:0")
>>> flashinfer.merge_state_in_place(v, s, v_other, s_other)