flashinfer.quantization.shuffle_matrix_a

flashinfer.quantization.shuffle_matrix_a(input_tensor: Tensor, epilogue_tile_m: int) → Tensor

PyTorch equivalent of the TRT-LLM-gen shuffleMatrixA routine.

Parameters:
  • input_tensor (torch.Tensor) – Row-major matrix to shuffle.

  • epilogue_tile_m (int) – Epilogue tile size along the M dimension; determines the shuffle permutation.

Returns:
  Row-shuffled copy of input_tensor.

Return type:
  torch.Tensor
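
Example:

The following is a minimal sketch of what "row-shuffled copy" means in PyTorch terms. It does not reproduce the TRT-LLM-gen permutation: the helper toy_row_shuffle and its tile_m argument are hypothetical, and the permutation it uses (reversing the rows inside each tile) is chosen only so the result is easy to check by eye. It assumes the row count is a multiple of the tile size.

```python
import torch


def toy_row_shuffle(input_tensor: torch.Tensor, tile_m: int) -> torch.Tensor:
    """Hypothetical stand-in for illustration; NOT the real shuffle_matrix_a layout."""
    m = input_tensor.shape[0]
    if m % tile_m != 0:
        raise ValueError("row count must be a multiple of tile_m in this sketch")
    # Illustrative permutation only: reverse the row order inside each tile.
    row_indices = (
        torch.arange(m, device=input_tensor.device)
        .view(m // tile_m, tile_m)
        .flip(1)
        .reshape(-1)
    )
    # Indexing with an index tensor gathers rows into a new tensor (a copy).
    return input_tensor[row_indices]


a = torch.arange(8 * 4, dtype=torch.float32).view(8, 4)  # row-major 8 x 4 matrix
shuffled = toy_row_shuffle(a, tile_m=4)

# With flashinfer installed, the documented call has the same shape:
#   shuffled = flashinfer.quantization.shuffle_matrix_a(a, epilogue_tile_m)
```

In this sketch the output keeps the shape and dtype of the input, holds the same rows in a different order, and leaves the input tensor unmodified.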