flashinfer.cascade.merge_state
- flashinfer.cascade.merge_state(v_a: Tensor, s_a: Tensor, v_b: Tensor, s_b: Tensor) → Tuple[Tensor, Tensor]
Merge the attention output V and the logsumexp value S from two KV segments. Check our tutorial for the mathematical details.
- Parameters:
v_a (torch.Tensor) – The attention output from KV segment A, shape: [seq_len, num_heads, head_dim].
s_a (torch.Tensor) – The logsumexp value from KV segment A, expected to be a float32 tensor, shape: [seq_len, num_heads].
v_b (torch.Tensor) – The attention output from KV segment B, shape: [seq_len, num_heads, head_dim].
s_b (torch.Tensor) – The logsumexp value from KV segment B, expected to be a float32 tensor, shape: [seq_len, num_heads].
- Returns:
V (torch.Tensor) – The merged attention output (equivalent to attention over the merged KV segment [A: B]), shape: [seq_len, num_heads, head_dim].
S (torch.Tensor) – The logsumexp value of the merged KV segment [A: B], shape: [seq_len, num_heads].
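To make the semantics concrete, the merge can be sketched in plain PyTorch. This is a reference sketch assuming the standard logsumexp-weighted combination (each segment's output is reweighted by its share of the combined softmax normalizer); `merge_state_reference` is a hypothetical helper name, not part of the flashinfer API, and the actual kernel is a fused CUDA implementation.

```python
import torch

def merge_state_reference(v_a, s_a, v_b, s_b):
    """Sketch of merging two attention states (v, s) from disjoint
    KV segments. Assumes s_a/s_b are per-(position, head) logsumexp
    values, shape [seq_len, num_heads]; v_a/v_b have an extra
    head_dim axis, shape [seq_len, num_heads, head_dim]."""
    s_a = s_a.to(torch.float32)
    s_b = s_b.to(torch.float32)
    # Merged logsumexp: S = log(exp(s_a) + exp(s_b)), computed stably.
    s = torch.logaddexp(s_a, s_b)
    # Weight each segment's output by its normalizer share.
    w_a = torch.exp(s_a - s).unsqueeze(-1)  # [seq_len, num_heads, 1]
    w_b = torch.exp(s_b - s).unsqueeze(-1)
    v = w_a * v_a.to(torch.float32) + w_b * v_b.to(torch.float32)
    return v.to(v_a.dtype), s
```

Note that when one segment's logsumexp is far smaller than the other's, its weight vanishes and the merged state reduces to the dominant segment's state, which is what makes the operation safe to apply repeatedly in a cascade.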
Example
>>> import torch
>>> import flashinfer
>>> seq_len = 2048
>>> num_heads = 32
>>> head_dim = 128
>>> va = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0")
>>> sa = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0")
>>> vb = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0")
>>> sb = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0")
>>> v_merged, s_merged = flashinfer.merge_state(va, sa, vb, sb)
>>> v_merged.shape
torch.Size([2048, 32, 128])
>>> s_merged.shape
torch.Size([2048, 32])