flashinfer.cascade.merge_state

flashinfer.cascade.merge_state(v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor)

Merge the attention outputs V and the logsumexp values S from two KV segments, A and B. See our tutorial for the mathematical details.

Parameters:
  • v_a (torch.Tensor) – The attention output from the KV segment A, shape: [seq_len, num_heads, head_dim].

  • s_a (torch.Tensor) – The logsumexp value from the KV segment A, expected to be a float32 tensor, shape: [seq_len, num_heads].

  • v_b (torch.Tensor) – The attention output from the KV segment B, shape: [seq_len, num_heads, head_dim].

  • s_b (torch.Tensor) – The logsumexp value from the KV segment B, expected to be a float32 tensor, shape: [seq_len, num_heads].

Returns:

  • V (torch.Tensor) – The merged attention output (equivalent to attention with merged KV-segment [A: B]), shape: [seq_len, num_heads, head_dim].

  • S (torch.Tensor) – The logsumexp value from the merged KV-segment [A: B], shape: [seq_len, num_heads].
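
The merge rule behind these return values can be sketched in plain Python for a single (position, head) pair. This is an illustrative reference, not the library's kernel: it assumes S is a natural-log logsumexp of the attention scores (the base used internally by the library may differ), and the helper names `attend` and `merge_state_ref` are hypothetical. The sketch merges the states of two segments and compares the result with attention computed over the concatenated segment.

```python
import math


def attend(scores, values):
    """Softmax attention for one query over one KV segment.

    Returns (output vector, logsumexp of the scores).
    """
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(x - m) for x in scores]
    denom = sum(weights)
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) / denom
           for d in range(dim)]
    return out, m + math.log(denom)


def merge_state_ref(v_a, s_a, v_b, s_b):
    """Hypothetical reference merge of two partial attention states.

    Assumes s_a and s_b are natural-log logsumexp values.
    """
    m = max(s_a, s_b)  # subtract the max for numerical stability
    w_a = math.exp(s_a - m)
    w_b = math.exp(s_b - m)
    total = w_a + w_b
    v = [(w_a * x + w_b * y) / total for x, y in zip(v_a, v_b)]
    return v, m + math.log(total)


# Toy pre-softmax scores and value vectors for segments A and B.
scores_a, values_a = [0.5, -1.0, 2.0], [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
scores_b, values_b = [1.5, 0.0], [[-1.0, 3.0], [0.5, 0.5]]

v_a, s_a = attend(scores_a, values_a)
v_b, s_b = attend(scores_b, values_b)

# Merge the two partial states.
v_merged, s_merged = merge_state_ref(v_a, s_a, v_b, s_b)

# Attention over the concatenated segment [A: B] for comparison.
v_full, s_full = attend(scores_a + scores_b, values_a + values_b)
```

Under these assumptions, `v_merged` and `s_merged` agree with `v_full` and `s_full` up to floating-point rounding, which is the sense in which the merged output is equivalent to attention over [A: B].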

Example

>>> import torch
>>> import flashinfer
>>> seq_len = 2048
>>> num_heads = 32
>>> head_dim = 128
>>> va = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda:0")
>>> sa = torch.randn(seq_len, num_heads, dtype=torch.float32, device="cuda:0")
>>> vb = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda:0")
>>> sb = torch.randn(seq_len, num_heads, dtype=torch.float32, device="cuda:0")
>>> v_merged, s_merged = flashinfer.merge_state(va, sa, vb, sb)
>>> v_merged.shape
torch.Size([2048, 32, 128])
>>> s_merged.shape
torch.Size([2048, 32])