flashinfer.cascade.merge_state_in_place
flashinfer.cascade.merge_state_in_place(v: torch.Tensor, s: torch.Tensor, v_other: torch.Tensor, s_other: torch.Tensor, mask: torch.Tensor | None = None) → None
Merge the self-attention state (v, s) with another state (v_other, s_other) in-place.
Parameters:
v (torch.Tensor) – The partial attention output to be updated in-place, shape: (seq_len, num_heads, head_dim).
s (torch.Tensor) – The partial logsumexp value to be updated in-place, expected to be a float32 tensor, shape: (seq_len, num_heads).
v_other (torch.Tensor) – The other attention output to be merged, shape: (seq_len, num_heads, head_dim).
s_other (torch.Tensor) – The other logsumexp value to be merged, expected to be a float32 tensor, shape: (seq_len, num_heads).
mask (Optional[torch.Tensor]) – Boolean mask tensor indicating, for each sequence, whether its state should be merged. Useful for CUDA graphs. If not specified (default), states are merged for all sequences. Shape: (seq_len,).
Example
>>> import torch
>>> import flashinfer
>>> seq_len = 2048
>>> num_heads = 32
>>> head_dim = 128
>>> v = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0")
>>> s = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0")
>>> v_other = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0")
>>> s_other = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0")
>>> flashinfer.merge_state_in_place(v, s, v_other, s_other)
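To make the merge rule concrete, here is a minimal pure-Python sketch of what merging two attention states means mathematically. It is a reference illustration, not the library's kernel: the helper names (merge_state_ref, attention_state_ref, merge_state_in_place_ref) are hypothetical, it handles a single head with plain lists, and it assumes the logsumexp values use the natural logarithm, which may differ from the base used inside FlashInfer.

```python
import math


def merge_state_ref(v, s, v_other, s_other):
    """Merge two partial attention states for one (position, head) pair.

    v, v_other: lists of floats, the partial attention outputs (length head_dim).
    s, s_other: floats, the logsumexp of the logits behind each output.
    Assumption: natural-log logsumexp (the real kernel may use another base).
    """
    s_max = max(s, s_other)  # subtract the max for numerical stability
    w = math.exp(s - s_max)
    w_other = math.exp(s_other - s_max)
    total = w + w_other
    v_new = [(w * a + w_other * b) / total for a, b in zip(v, v_other)]
    s_new = s_max + math.log(total)
    return v_new, s_new


def attention_state_ref(logits, values):
    """Softmax-weighted average of `values` and the logsumexp of `logits`."""
    m = max(logits)
    weights = [math.exp(x - m) for x in logits]
    denom = sum(weights)
    head_dim = len(values[0])
    out = [
        sum(w * val[d] for w, val in zip(weights, values)) / denom
        for d in range(head_dim)
    ]
    return out, m + math.log(denom)


def merge_state_in_place_ref(v, s, v_other, s_other, mask=None):
    """Update the lists v and s in place; rows where mask[i] is False are skipped."""
    for i in range(len(v)):
        if mask is not None and not mask[i]:
            continue
        v[i], s[i] = merge_state_ref(v[i], s[i], v_other[i], s_other[i])


# Attention over all keys equals the merge of attention over two key subsets.
logits = [0.5, -1.0, 2.0, 0.1]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 3.0]]
full_v, full_s = attention_state_ref(logits, values)
part_a = attention_state_ref(logits[:2], values[:2])
part_b = attention_state_ref(logits[2:], values[2:])
merged_v, merged_s = merge_state_ref(part_a[0], part_a[1], part_b[0], part_b[1])

# Masked in-place merge: only the first row is updated.
v_rows, s_rows = [[1.0], [2.0]], [0.0, 0.0]
merge_state_in_place_ref(v_rows, s_rows, [[3.0], [4.0]], [0.0, 0.0], mask=[True, False])
```

In this sketch, merged_v and merged_s match full_v and full_s up to floating-point error, which is why partial attention results computed over disjoint key ranges can be combined. The masked call leaves the second row unchanged, mirroring the documented role of mask.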