flashinfer.cascade.merge_states

flashinfer.cascade.merge_states(v: torch.Tensor, s: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor]

Merge multiple attention states (v, s), each computed over a different KV segment, into a single attention state.

Parameters:
  • v (torch.Tensor) – The attention output from the KV segments, shape: [seq_len, num_states, num_heads, head_dim].

  • s (torch.Tensor) – The logsumexp value from the KV segments, shape: [seq_len, num_states, num_heads], expected to be a float32 tensor.

Returns:

  • V (torch.Tensor) – The merged attention output, shape: [seq_len, num_heads, head_dim].

  • S (torch.Tensor) – The logsumexp value of the merged KV segments, shape: [seq_len, num_heads].

Example

>>> import torch
>>> import flashinfer
>>> seq_len = 2048
>>> num_heads = 32
>>> head_dim = 128
>>> num_states = 100
>>> v = torch.randn(seq_len, num_states, num_heads, head_dim).half().to("cuda:0")
>>> s = torch.randn(seq_len, num_states, num_heads, dtype=torch.float32).to("cuda:0")
>>> v_merged, s_merged = flashinfer.merge_states(v, s)
>>> v_merged.shape
torch.Size([2048, 32, 128])
>>> s_merged.shape
torch.Size([2048, 32])
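
To make the merge semantics concrete, the following is a minimal pure-PyTorch sketch of what merging attention states amounts to, not the library's actual kernel. The helper name merge_states_reference is hypothetical, and the sketch assumes that s holds natural-log logsumexp values; if the library's kernels use a different logarithm base, the weights would need to be computed in that base. Each state is weighted by the softmax of its logsumexp across the num_states dimension, and the merged logsumexp is the logsumexp of the per-state values.

```python
import torch


def merge_states_reference(v: torch.Tensor, s: torch.Tensor):
    """Hypothetical reference merge, written for illustration only.

    v: [seq_len, num_states, num_heads, head_dim]
    s: [seq_len, num_states, num_heads], assumed to be natural-log logsumexp
    """
    s = s.float()
    # Softmax over the num_states dimension gives each state's share
    # of the total attention mass.
    weights = torch.softmax(s, dim=1)
    # Weighted sum of the per-state outputs, accumulated in float32.
    v_merged = (weights.unsqueeze(-1) * v.float()).sum(dim=1).to(v.dtype)
    # The merged logsumexp combines the per-state values.
    s_merged = torch.logsumexp(s, dim=1)
    return v_merged, s_merged


# Small CPU tensors are enough to check the shapes.
v = torch.randn(4, 3, 2, 8)
s = torch.randn(4, 3, 2)
v_ref, s_ref = merge_states_reference(v, s)
```

Under these assumptions, v_ref has shape [4, 2, 8] and s_ref has shape [4, 2], and merging a single state returns that state unchanged.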