flashinfer.testing.attention_flops
- flashinfer.testing.attention_flops(batch_size, qo_seqlen, kv_seqlen, head_dim_qk, head_dim_vo, num_qo_heads, causal)
Calculate the FLOPs for a given attention layer. Assumes all sequence lengths are the same within the batch.
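As a rough guide, here is the conventional count for the two matrix products in attention (the scores QKᵀ and the weighted sum of V). It assumes each multiply-add counts as 2 FLOPs and ignores softmax and masking costs. The symbols are shorthand introduced here: B = batch_size, H = num_qo_heads, L_q = qo_seqlen, L_kv = kv_seqlen, d_qk = head_dim_qk, d_vo = head_dim_vo. This is a sketch of the standard accounting, not necessarily the exact expression flashinfer evaluates.

```latex
\mathrm{FLOPs}_{\text{full}}
  = \underbrace{2\,B\,H\,L_q\,L_{kv}\,d_{qk}}_{QK^\top}
  + \underbrace{2\,B\,H\,L_q\,L_{kv}\,d_{vo}}_{PV}
  = 2\,B\,H\,L_q\,L_{kv}\,(d_{qk} + d_{vo}),
\qquad
\mathrm{FLOPs}_{\text{causal}} = \tfrac{1}{2}\,\mathrm{FLOPs}_{\text{full}}

% Worked example: B = 1, H = 32, L_q = L_{kv} = 1024, d_{qk} = d_{vo} = 128
\mathrm{FLOPs}_{\text{full}} = 2 \cdot 1 \cdot 32 \cdot 1024 \cdot 1024 \cdot 256 = 2^{34} = 17{,}179{,}869{,}184,
\qquad
\mathrm{FLOPs}_{\text{causal}} = 2^{33} = 8{,}589{,}934{,}592
```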
- Parameters:
batch_size (int) – Batch size.
qo_seqlen (int) – Sequence length of the query. Assumed same within the batch.
kv_seqlen (int) – Sequence length of the key and value. Assumed same within the batch.
head_dim_qk (int) – Head dimension of the query and key.
head_dim_vo (int) – Head dimension of the value (and therefore of the output).
num_qo_heads (int) – Number of query heads.
causal (bool) – Whether to use causal masking. The FLOP count is halved when causal masking is enabled.
- Returns:
Total FLOPs for the layer.
- Return type:
total_flops (int)
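The following is a minimal, self-contained sketch of the counting described above. It assumes 2 FLOPs per multiply-add, ignores softmax and masking overhead, and halves the total under causal masking as the parameter description states. `attention_flops_sketch` is a hypothetical name used only for illustration; it is not the flashinfer source, and its causal term may differ from the library's for unequal query and key/value lengths.

```python
def attention_flops_sketch(
    batch_size: int,
    qo_seqlen: int,
    kv_seqlen: int,
    head_dim_qk: int,
    head_dim_vo: int,
    num_qo_heads: int,
    causal: bool,
) -> int:
    """Hypothetical attention FLOP count (not the flashinfer implementation).

    Assumes 2 FLOPs per multiply-add and ignores softmax/masking cost.
    """
    # S = Q @ K^T: qo_seqlen x kv_seqlen scores per head, each a dot product
    # of length head_dim_qk (head_dim_qk multiply-adds = 2 * head_dim_qk FLOPs).
    qk_flops = 2 * batch_size * num_qo_heads * qo_seqlen * kv_seqlen * head_dim_qk
    # O = P @ V: qo_seqlen x head_dim_vo outputs per head, each a sum of
    # kv_seqlen products (2 * kv_seqlen FLOPs per output element).
    pv_flops = 2 * batch_size * num_qo_heads * qo_seqlen * kv_seqlen * head_dim_vo
    total_flops = qk_flops + pv_flops
    if causal:
        # Halve the count for causal masking; total_flops is always even,
        # so integer division is exact and the result stays an int.
        total_flops //= 2
    return total_flops


if __name__ == "__main__":
    full = attention_flops_sketch(1, 1024, 1024, 128, 128, 32, causal=False)
    masked = attention_flops_sketch(1, 1024, 1024, 128, 128, 32, causal=True)
    print(full, masked)  # 17179869184 8589934592
```

Returning an `int` matches the documented return type; dividing the count by a measured kernel runtime in seconds gives an achieved FLOP/s figure.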