flashinfer.cascade.merge_state
flashinfer.cascade.merge_state(v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor]
Merge the attention output ``V`` and the logsumexp value ``S`` from two KV segments. See our tutorial for the mathematical details.

Parameters:
- v_a (torch.Tensor) – The attention output from KV segment ``A``, shape: ``[seq_len, num_heads, head_dim]``.
- s_a (torch.Tensor) – The logsumexp value from KV segment ``A``, expected to be a float32 tensor, shape: ``[seq_len, num_heads]``.
- v_b (torch.Tensor) – The attention output from KV segment ``B``, shape: ``[seq_len, num_heads, head_dim]``.
- s_b (torch.Tensor) – The logsumexp value from KV segment ``B``, expected to be a float32 tensor, shape: ``[seq_len, num_heads]``.
Returns:
- V (torch.Tensor) – The merged attention output (equivalent to attention over the merged KV segment ``[A : B]``), shape: ``[seq_len, num_heads, head_dim]``.
- S (torch.Tensor) – The logsumexp value from the merged KV segment ``[A : B]``, shape: ``[seq_len, num_heads]``.
Example
>>> import torch
>>> import flashinfer
>>> seq_len = 2048
>>> num_heads = 32
>>> head_dim = 128
>>> va = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0")
>>> sa = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0")
>>> vb = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0")
>>> sb = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0")
>>> v_merged, s_merged = flashinfer.merge_state(va, sa, vb, sb)
>>> v_merged.shape
torch.Size([2048, 32, 128])
>>> s_merged.shape
torch.Size([2048, 32])
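For intuition, the merge can be sketched in pure PyTorch using the standard log-sum-exp recombination of partial attention states (the same identity that underpins online softmax). This is a minimal reference sketch, not flashinfer's kernel; the helper name ``merge_state_reference`` is hypothetical.

```python
import torch

def merge_state_reference(v_a, s_a, v_b, s_b):
    """Hypothetical pure-PyTorch sketch of the merge (not part of flashinfer).

    s_a / s_b hold per-(position, head) logsumexp values; exponentiating the
    difference to the merged logsumexp yields each segment's softmax weight.
    """
    # Merged logsumexp: log(exp(s_a) + exp(s_b)), computed stably.
    s_merged = torch.logaddexp(s_a, s_b)            # [seq_len, num_heads]
    # Per-segment renormalization weights, broadcast over head_dim.
    w_a = torch.exp(s_a - s_merged).unsqueeze(-1)   # [seq_len, num_heads, 1]
    w_b = torch.exp(s_b - s_merged).unsqueeze(-1)
    # Weighted combination of the two partial attention outputs.
    v_merged = v_a.float() * w_a + v_b.float() * w_b
    return v_merged.to(v_a.dtype), s_merged
```

Because the recombination is associative, outputs from any number of KV segments can be folded together pairwise in this way.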