flashinfer.quantization.segment_packbits

flashinfer.quantization.segment_packbits(x: torch.Tensor, indptr: torch.Tensor, bitorder: str = 'big') → Tuple[torch.Tensor, torch.Tensor]

Pack a batch of binary-valued segments into bits in a uint8 array.

Each segment is packed with the same semantics as numpy.packbits applied to that segment on its own.
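As a rough illustration of those per-segment semantics, the pure-Python sketch below packs a single segment the way numpy.packbits packs a 1D array. The segment is zero-padded to a multiple of 8 bits, and bitorder decides whether the first element lands in the most significant bit ("big") or the least significant bit ("little"). The helper name pack_segment is hypothetical and not part of FlashInfer. It is a CPU reference for reading this page, not the CUDA kernel.

```python
from typing import List, Sequence


def pack_segment(bits: Sequence[int], bitorder: str = "big") -> List[int]:
    """Hypothetical CPU reference: pack one binary segment into uint8 values,
    mirroring numpy.packbits on a 1D array."""
    if bitorder not in ("big", "little"):
        raise ValueError("bitorder must be 'big' or 'little'")
    packed = []
    # Walk the segment 8 bits at a time; the last chunk may be shorter,
    # and its missing positions stay 0 (implicit zero padding).
    for start in range(0, len(bits), 8):
        byte = 0
        for offset, bit in enumerate(bits[start:start + 8]):
            if bit:
                # "big": element 0 -> bit 7 (MSB); "little": element 0 -> bit 0 (LSB).
                shift = 7 - offset if bitorder == "big" else offset
                byte |= 1 << shift
        packed.append(byte)
    return packed


print([bin(b) for b in pack_segment([1, 0, 1, 1], "big")])     # ['0b10110000']
print([bin(b) for b in pack_segment([1, 0, 1, 1], "little")])  # ['0b1101']
```

An empty segment yields an empty list, which matches the (n + 7) // 8 byte count for n = 0.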

Parameters:
  • x (torch.Tensor) – The 1D binary-valued array to pack, shape (indptr[-1],).

  • indptr (torch.Tensor) – The index pointer of each segment in x, shape (batch_size + 1,). The i-th segment in x is x[indptr[i]:indptr[i+1]].

  • bitorder (str) – The bit order (“big” or “little”) of the output. Default is “big”.
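The indptr layout is a CSR-style offset array: segment i occupies x[indptr[i]:indptr[i+1]], so indptr is an exclusive prefix sum of the segment lengths with indptr[0] == 0. The plain-Python sketch below builds one from a list of lengths and slices the segments back out. indptr_from_lengths is a hypothetical helper. On the GPU a similar array could presumably be built with torch.cumsum plus a leading zero, but that is an assumption about how a caller might prepare inputs, not something this page specifies.

```python
from itertools import accumulate
from typing import List, Sequence


def indptr_from_lengths(lengths: Sequence[int]) -> List[int]:
    """Hypothetical helper: exclusive prefix sum of segment lengths."""
    # accumulate(..., initial=0) yields 0, l0, l0+l1, ... (Python 3.8+).
    return list(accumulate(lengths, initial=0))


x = [1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1]
indptr = indptr_from_lengths([4, 3, 4])
print(indptr)  # [0, 4, 7, 11]
# Segment i is x[indptr[i]:indptr[i + 1]]; indptr[-1] must equal len(x).
segments = [x[indptr[i]:indptr[i + 1]] for i in range(len(indptr) - 1)]
print(segments)  # [[1, 0, 1, 1], [0, 0, 1], [1, 1, 0, 1]]
```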

Returns:

  • y (torch.Tensor) – A uint8 packed array of shape (new_indptr[-1],). y[new_indptr[i]:new_indptr[i+1]] contains the packed bits of x[indptr[i]:indptr[i+1]].

  • new_indptr (torch.Tensor) – The new index pointer of each packed segment in y, shape (batch_size + 1,). It’s guaranteed that new_indptr[i+1] - new_indptr[i] == (indptr[i+1] - indptr[i] + 7) // 8.
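The guarantee on new_indptr follows from the padding rule: a segment of n bits occupies ceil(n / 8) = (n + 7) // 8 bytes, so new_indptr is the prefix sum of those byte counts. A plain-Python sketch of that bookkeeping (the function name packed_indptr is hypothetical):

```python
from typing import List, Sequence


def packed_indptr(indptr: Sequence[int]) -> List[int]:
    """Hypothetical helper: byte offsets of each packed segment in y."""
    new_indptr = [0]
    for i in range(len(indptr) - 1):
        n_bits = indptr[i + 1] - indptr[i]
        # Each segment is zero-padded to a whole number of bytes.
        new_indptr.append(new_indptr[-1] + (n_bits + 7) // 8)
    return new_indptr


print(packed_indptr([0, 4, 7, 11]))  # [0, 1, 2, 3]
print(packed_indptr([0, 9, 9, 25]))  # [0, 2, 2, 4]
```

The total output length new_indptr[-1] depends on the values stored in indptr, not just on its shape.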

Examples

>>> import torch
>>> from flashinfer import segment_packbits
>>> x = torch.tensor([1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1], dtype=torch.bool, device="cuda")
>>> x_packed, new_indptr = segment_packbits(x, torch.tensor([0, 4, 7, 11], device="cuda"), bitorder="big")
>>> list(map(bin, x_packed.tolist()))
['0b10110000', '0b100000', '0b11010000']
>>> new_indptr
tensor([0, 1, 2, 3], device='cuda:0')
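To check the example output by hand, the sketch below is a batched CPU reference built only from the documented semantics: slice each segment with indptr, zero-pad it to a multiple of 8 bits, and read each 8-bit chunk as a binary number, reversing the chunk first for bitorder="little". segment_packbits_ref is a hypothetical name for illustration. It works on Python lists rather than CUDA tensors and is not FlashInfer's implementation.

```python
from typing import List, Sequence, Tuple


def segment_packbits_ref(
    x: Sequence[int], indptr: Sequence[int], bitorder: str = "big"
) -> Tuple[List[int], List[int]]:
    """Hypothetical list-based reference for segment_packbits."""
    if bitorder not in ("big", "little"):
        raise ValueError("bitorder must be 'big' or 'little'")
    y: List[int] = []
    new_indptr = [0]
    for i in range(len(indptr) - 1):
        seg = "".join("1" if b else "0" for b in x[indptr[i]:indptr[i + 1]])
        # Zero-pad the segment on the right to a multiple of 8 bits.
        seg += "0" * (-len(seg) % 8)
        for k in range(0, len(seg), 8):
            chunk = seg[k:k + 8]
            # "little" puts the first element in the least significant bit.
            y.append(int(chunk if bitorder == "big" else chunk[::-1], 2))
        new_indptr.append(len(y))
    return y, new_indptr


x = [1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1]
y, new_indptr = segment_packbits_ref(x, [0, 4, 7, 11], bitorder="big")
print(list(map(bin, y)))  # ['0b10110000', '0b100000', '0b11010000']
print(new_indptr)         # [0, 1, 2, 3]
```

The second packed value prints as '0b100000' rather than '0b00100000' only because bin() drops leading zeros; the byte is still 00100000.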

Note

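torch.compile is not supported for this function because its output is data-dependent: the length of y (new_indptr[-1]) depends on the values in indptr, not only on the input shapes.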

See also

packbits