flashinfer.testing.attention_flops_with_actual_seq_lens

flashinfer.testing.attention_flops_with_actual_seq_lens(actual_seq_lens_q, actual_seq_lens_kv, head_dim_qk, head_dim_vo, num_qo_heads, causal)

Calculate the FLOPs of an attention layer from per-sequence actual sequence lengths, which are provided as 1D tensors.
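The sketch below is a hypothetical re-derivation of this kind of count, not FlashInfer's source. `attention_flops_sketch` is an illustrative name. It takes plain Python lists (a 1D tensor can be converted with `.tolist()`), and it assumes the common convention of 2 FLOPs per multiply-add, counting only the QK^T and PV matmuls. The library's exact convention may differ.

```python
from typing import Sequence


def attention_flops_sketch(
    seq_lens_q: Sequence[int],
    seq_lens_kv: Sequence[int],
    head_dim_qk: int,
    head_dim_vo: int,
    num_qo_heads: int,
    causal: bool,
) -> int:
    """Hypothetical attention FLOP count (illustrative, not the library code).

    Assumes 2 FLOPs per multiply-add and counts only S = Q K^T and O = P V;
    softmax, scaling and masking overheads are ignored.
    """
    if len(seq_lens_q) != len(seq_lens_kv):
        raise ValueError("seq_lens_q and seq_lens_kv must have the same length")
    total = 0
    for lq, lkv in zip(seq_lens_q, seq_lens_kv):
        if causal:
            # Assumption carried over from the docstring note: causal
            # counting is only defined when query and KV lengths match.
            if lq != lkv:
                raise ValueError("causal counting assumes qo_seqlen == kv_seqlen")
            # Lower-triangular score matrix has lq * (lq + 1) / 2 entries.
            scores = lq * (lq + 1) // 2
        else:
            # Full lq x lkv score matrix.
            scores = lq * lkv
        # Each score costs a head_dim_qk dot product (QK^T) and weights a
        # head_dim_vo value vector (PV); x2 converts multiply-adds to FLOPs.
        total += 2 * scores * num_qo_heads * (head_dim_qk + head_dim_vo)
    return total
```

For example, two sequences of lengths 4 and 2 with 8 heads and head dimension 128 give 81920 FLOPs without masking and 53248 with causal masking, since the triangular score matrices hold 10 + 3 entries instead of 16 + 4.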

Parameters:
  • actual_seq_lens_q (torch.Tensor) – 1D tensor of actual query sequence lengths, one entry per sequence in the batch.

  • actual_seq_lens_kv (torch.Tensor) – 1D tensor of actual key/value sequence lengths, one entry per sequence in the batch.

  • head_dim_qk (int) – Head dimension of the query and key.

  • head_dim_vo (int) – Head dimension of the value (and therefore of the output).

  • num_qo_heads (int) – Number of query heads.

  • causal (bool) – Whether to use causal masking. Note: causal counting assumes qo_seqlen == kv_seqlen for every sequence, so causal must be False for decode, where the query is shorter than the key/value sequence.
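To illustrate the restriction on causal, the hypothetical helper `score_entries` below counts the QK^T entries a causal mask keeps for one sequence and one head. It assumes a bottom-right aligned mask (query i may attend to keys 0 .. lkv - lq + i), which is the usual convention for decode; this alignment is an assumption, not something stated on this page.

```python
def score_entries(lq: int, lkv: int, causal: bool) -> int:
    """Count computed QK^T entries for one sequence and one head (illustrative).

    With causal=True, assumes a bottom-right aligned mask: query i sees keys
    0 .. (lkv - lq + i), clamped to the valid range [0, lkv].
    """
    if not causal:
        return lq * lkv
    offset = lkv - lq
    # Number of visible keys for each query row, clamped at 0 and lkv.
    return sum(max(0, min(lkv, offset + i + 1)) for i in range(lq))


# Prefill (lq == lkv): the mask keeps a triangle, lq * (lq + 1) / 2 entries.
prefill = score_entries(4, 4, causal=True)
# Decode (lq == 1): the single query already sees every key.
decode_causal = score_entries(1, 1024, causal=True)
decode_full = score_entries(1, 1024, causal=False)
```

For decode, the masked count equals the unmasked 1 * 1024, which is exactly what causal=False yields, whereas a triangular lq * (lq + 1) / 2 formula applied with lq = 1 would report a single entry and badly undercount.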

Returns:

total_flops – Total FLOPs for the layer.

Return type:

int
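A FLOP count like this is typically divided by a measured kernel time to report throughput. The helper below is a hypothetical sketch of that conversion (the name `tflops_per_sec` and the millisecond timing unit are assumptions about your benchmarking workflow, not part of this function).

```python
def tflops_per_sec(total_flops: float, time_ms: float) -> float:
    """Convert a FLOP count and a kernel time in milliseconds to TFLOP/s.

    Illustrative helper: total_flops would come from the FLOP-count function,
    time_ms from whatever timer the benchmark uses.
    """
    if time_ms <= 0:
        raise ValueError("time_ms must be positive")
    seconds = time_ms * 1e-3
    return total_flops / seconds / 1e12


# 2e12 FLOPs executed in 10 ms corresponds to roughly 200 TFLOP/s.
throughput = tflops_per_sec(2e12, 10.0)
```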