flashinfer.cascade.merge_states
- flashinfer.cascade.merge_states(v: torch.Tensor, s: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor]
Merge multiple attention states (v, s).
- Parameters:
v (torch.Tensor) – The attention output from the KV segments, shape: [seq_len, num_states, num_heads, head_dim].
s (torch.Tensor) – The logsumexp value from the KV segments, shape: [seq_len, num_states, num_heads], expected to be a float32 tensor.
- Returns:
V (torch.Tensor) – The merged attention output, shape: [seq_len, num_heads, head_dim].
S (torch.Tensor) – The logsumexp value from the merged KV segments, shape: [seq_len, num_heads].
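To make the shapes above concrete, here is a minimal pure-PyTorch sketch of what merging attention states can look like mathematically. It is not the library's kernel: the helper name `merge_states_reference` is hypothetical, and the sketch assumes `s` holds natural-log logsumexp values. If the library uses a different logarithm base, the weights would need the matching exponential.

```python
import torch

def merge_states_reference(v: torch.Tensor, s: torch.Tensor):
    """Hypothetical reference for illustration only (not the FlashInfer kernel).

    v: [seq_len, num_states, num_heads, head_dim]
    s: [seq_len, num_states, num_heads], assumed to be natural-log logsumexp
    """
    s = s.float()
    # Merged logsumexp across the num_states axis: [seq_len, num_heads]
    s_merged = torch.logsumexp(s, dim=1)
    # Weight of each state is exp(s_i - s_merged); weights sum to 1 over states.
    w = torch.exp(s - s_merged.unsqueeze(1))  # [seq_len, num_states, num_heads]
    # Weighted sum of the per-state outputs: [seq_len, num_heads, head_dim]
    v_merged = (w.unsqueeze(-1) * v.float()).sum(dim=1)
    return v_merged.to(v.dtype), s_merged

# Small CPU example so the sketch runs without a GPU.
torch.manual_seed(0)
v = torch.randn(4, 3, 2, 8)   # seq_len=4, num_states=3, num_heads=2, head_dim=8
s = torch.randn(4, 3, 2)
v_merged, s_merged = merge_states_reference(v, s)
print(v_merged.shape)  # torch.Size([4, 2, 8])
print(s_merged.shape)  # torch.Size([4, 2])
```

The num_states axis is reduced away in both outputs, which matches the documented return shapes.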
Example
>>> import torch
>>> import flashinfer
>>> seq_len = 2048
>>> num_heads = 32
>>> head_dim = 128
>>> num_states = 100
>>> v = torch.randn(seq_len, num_states, num_heads, head_dim).half().to("cuda:0")
>>> s = torch.randn(seq_len, num_states, num_heads, dtype=torch.float32).to("cuda:0")
>>> v_merged, s_merged = flashinfer.merge_states(v, s)
>>> v_merged.shape
torch.Size([2048, 32, 128])
>>> s_merged.shape
torch.Size([2048, 32])
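The example above uses random inputs, so it only demonstrates shapes. As a rough illustration of why (v, s) states are useful, the following CPU-only sketch computes attention separately over three KV segments, merges the resulting states, and compares against attention over the full KV. The helpers `segment_attention` and `merge_reference` are hypothetical, and the sketch assumes unscaled dot-product scores, no masking, and natural-log logsumexp. It does not call FlashInfer.

```python
import torch

def segment_attention(q, k, v):
    # q: [seq_len, num_heads, head_dim]; k, v: [kv_len, num_heads, head_dim]
    scores = torch.einsum("qhd,khd->qhk", q, k)   # [seq_len, num_heads, kv_len]
    s = torch.logsumexp(scores, dim=-1)           # [seq_len, num_heads]
    out = torch.einsum("qhk,khd->qhd", torch.softmax(scores, dim=-1), v)
    return out, s

def merge_reference(v, s):
    # Same math as a softmax over the num_states axis, applied to the outputs.
    s_merged = torch.logsumexp(s, dim=1)
    w = torch.exp(s - s_merged.unsqueeze(1))
    return (w.unsqueeze(-1) * v).sum(dim=1), s_merged

torch.manual_seed(0)
seq_len, num_heads, head_dim, kv_len = 4, 2, 8, 30
q = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float64)
k = torch.randn(kv_len, num_heads, head_dim, dtype=torch.float64)
val = torch.randn(kv_len, num_heads, head_dim, dtype=torch.float64)

# Split the KV into three segments of unequal length.
bounds = [(0, 5), (5, 17), (17, 30)]
states = [segment_attention(q, k[a:b], val[a:b]) for a, b in bounds]
v_states = torch.stack([o for o, _ in states], dim=1)  # [seq_len, num_states, num_heads, head_dim]
s_states = torch.stack([s for _, s in states], dim=1)  # [seq_len, num_states, num_heads]

v_merged, s_merged = merge_reference(v_states, s_states)
v_full, s_full = segment_attention(q, k, val)

# Largest elementwise difference between merged and full attention output.
max_err = (v_merged - v_full).abs().max().item()
print(max_err)
```

Under these assumptions the merged output and the full-KV output agree up to floating-point rounding, which is the property that lets attention over a long KV be computed in independent pieces.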