Skip to content

vllm.attention.ops.vit_attn_wrappers

This file wraps ViT attention ops as custom ops so that they are compatible with torch.compile, because they contain operations that torch.compile does not support (for instance, `.tolist()` in xformers attention, or `.item()` in flash attention).

Using these ops and wrapping vision blocks with torch.compile can increase throughput in vision models by ~5% (relative) on H100 and improve token latencies by ~7% (see qwen2_5_vl for example usage).

To use these ops, you must have a recent version of PyTorch installed (>= 2.4.0)

flash_attn_maxseqlen_wrapper

flash_attn_maxseqlen_wrapper(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    cu_seqlens: Tensor,
    max_seqlen: Tensor,
    batch_size: int,
    is_rocm_aiter: bool,
    use_upstream_fa: bool,
) -> Tensor
Source code in vllm/attention/ops/vit_attn_wrappers.py
def flash_attn_maxseqlen_wrapper(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
    max_seqlen: torch.Tensor,
    batch_size: int,
    is_rocm_aiter: bool,
    use_upstream_fa: bool,
) -> torch.Tensor:
    """Run non-causal variable-length flash attention over packed ViT sequences.

    Args:
        q, k, v: Query/key/value tensors. Their two leading dims are
            flattened together (``"b s ... -> (b s) ..."``) before the
            kernel call.
        cu_seqlens: Cumulative sequence boundaries, passed as both
            ``cu_seqlens_q`` and ``cu_seqlens_k``.
        max_seqlen: Single-element tensor holding the longest sequence
            length; its value is read on the host with ``.item()``.
        batch_size: The ``b`` used to un-flatten the kernel output.
        is_rocm_aiter: If True, use ``aiter.flash_attn_varlen_func``.
        use_upstream_fa: If ``is_rocm_aiter`` is False, choose upstream
            ``flash_attn`` when True, else vLLM's ``fa_utils`` kernel.

    Returns:
        Contiguous tensor laid out as ``(s, b, h * d)``.
    """
    if is_rocm_aiter:
        from aiter import flash_attn_varlen_func
    elif use_upstream_fa:
        from flash_attn import flash_attn_varlen_func
    else:
        from vllm.attention.utils.fa_utils import flash_attn_varlen_func

    # `.item()` copies the value to the host (a blocking sync for accelerator
    # tensors, and the reason this lives in an opaque custom op). Read it once
    # and reuse it for both the query and key limits.
    max_len = max_seqlen.item()

    q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
    output = flash_attn_varlen_func(
        q,
        k,
        v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_len,
        max_seqlen_k=max_len,
        dropout_p=0.0,
        causal=False,
    )
    # Restore the batch dim, move the sequence dim first, and merge heads.
    context_layer = einops.rearrange(
        output, "(b s) h d -> s b (h d)", b=batch_size
    ).contiguous()
    return context_layer

flash_attn_maxseqlen_wrapper_fake

flash_attn_maxseqlen_wrapper_fake(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    cu_seqlens: Tensor,
    max_seqlen: Tensor,
    batch_size: int,
    is_rocm_aiter: bool,
    use_upstream_fa: bool,
) -> Tensor
Source code in vllm/attention/ops/vit_attn_wrappers.py
def flash_attn_maxseqlen_wrapper_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
    max_seqlen: torch.Tensor,
    batch_size: int,
    is_rocm_aiter: bool,
    use_upstream_fa: bool,
) -> torch.Tensor:
    """Shape-only counterpart of ``flash_attn_maxseqlen_wrapper``.

    Presumably registered elsewhere as the op's fake implementation for
    torch.compile tracing. Only ``q`` is inspected: given a rank-4 ``q`` of
    shape ``(b, s, h, d)``, returns an uninitialized ``(s, b, h * d)`` tensor
    with ``q``'s dtype and device.
    """
    batch, seq_len, num_heads, head_dim = q.shape
    out_shape = (seq_len, batch, num_heads * head_dim)
    return torch.empty(out_shape, device=q.device, dtype=q.dtype)

torch_sdpa_wrapper

torch_sdpa_wrapper(
    q: Tensor, k: Tensor, v: Tensor, cu_seqlens: Tensor
) -> Tensor
Source code in vllm/attention/ops/vit_attn_wrappers.py
def torch_sdpa_wrapper(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
) -> torch.Tensor:
    """Block-diagonal attention via per-segment ``scaled_dot_product_attention``.

    Each window ``[cu_seqlens[i - 1], cu_seqlens[i])`` along dim 1 of
    ``q``/``k``/``v`` attends only to itself (non-causal, no dropout).

    Args:
        q, k, v: Tensors laid out as ``(b, s, h, d)``.
        cu_seqlens: 1-D tensor of cumulative segment boundaries.

    Returns:
        Contiguous tensor of shape ``(s, b, h * d)``, where ``s`` is the
        total length covered by the segments.
    """
    # Read the boundaries to the host once. Indexing a (possibly device)
    # tensor element by element inside the loop would force a separate
    # host read for every slice bound.
    bounds = cu_seqlens.tolist()

    outputs = []
    for start_idx, end_idx in zip(bounds[:-1], bounds[1:]):
        # (b, s_i, h, d) -> (b, h, s_i, d): the layout SDPA expects.
        q_i, k_i, v_i = (
            x[:, start_idx:end_idx].transpose(1, 2) for x in (q, k, v)
        )
        output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
        outputs.append(output_i.transpose(1, 2))  # back to (b, s_i, h, d)

    context_layer = torch.cat(outputs, dim=1)  # (b, s, h, d)
    # (b, s, h, d) -> (s, b, h * d)
    return context_layer.permute(1, 0, 2, 3).flatten(start_dim=2).contiguous()

torch_sdpa_wrapper_fake

torch_sdpa_wrapper_fake(
    q: Tensor, k: Tensor, v: Tensor, cu_seqlens: Tensor
) -> Tensor
Source code in vllm/attention/ops/vit_attn_wrappers.py
def torch_sdpa_wrapper_fake(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
) -> torch.Tensor:
    """Shape-only counterpart of ``torch_sdpa_wrapper``.

    Presumably registered elsewhere as the op's fake implementation. Given a
    rank-4 ``q`` of shape ``(b, s, h, d)``, returns an uninitialized
    ``(s, b, h * d)`` tensor with ``q``'s dtype and device.
    """
    batch, seq_len, num_heads, head_dim = q.shape
    out_shape = (seq_len, batch, num_heads * head_dim)
    return torch.empty(out_shape, device=q.device, dtype=q.dtype)

vit_flash_attn_wrapper

vit_flash_attn_wrapper(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    cu_seqlens: Tensor,
    max_seqlen: Tensor,
    batch_size: int,
    is_rocm_aiter: bool,
    use_upstream_fa: bool,
) -> Tensor
Source code in vllm/attention/ops/vit_attn_wrappers.py
def vit_flash_attn_wrapper(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
    max_seqlen: torch.Tensor,
    batch_size: int,
    is_rocm_aiter: bool,
    use_upstream_fa: bool,
) -> torch.Tensor:
    """Call the registered ``torch.ops.vllm.flash_attn_maxseqlen_wrapper`` op.

    Routing through ``torch.ops`` (rather than calling the Python function
    directly) keeps the ``.item()`` host read inside an opaque custom op, so
    the calling vision block can be compiled with ``torch.compile``. All
    arguments are forwarded positionally and unchanged.
    """
    flash_op = torch.ops.vllm.flash_attn_maxseqlen_wrapper
    return flash_op(
        q,
        k,
        v,
        cu_seqlens,
        max_seqlen,
        batch_size,
        is_rocm_aiter,
        use_upstream_fa,
    )

vit_torch_sdpa_wrapper

vit_torch_sdpa_wrapper(
    q: Tensor, k: Tensor, v: Tensor, cu_seqlens: Tensor
) -> Tensor
Source code in vllm/attention/ops/vit_attn_wrappers.py
def vit_torch_sdpa_wrapper(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
) -> torch.Tensor:
    """Call the registered ``torch.ops.vllm.torch_sdpa_wrapper`` op.

    Arguments are forwarded positionally and unchanged.
    """
    sdpa_op = torch.ops.vllm.torch_sdpa_wrapper
    return sdpa_op(q, k, v, cu_seqlens)

vit_xformers_attn_wrapper

vit_xformers_attn_wrapper(
    q: Tensor, k: Tensor, v: Tensor, seqlens: Tensor
) -> Tensor
Source code in vllm/attention/ops/vit_attn_wrappers.py
def vit_xformers_attn_wrapper(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
) -> torch.Tensor:
    """Call the registered ``torch.ops.vllm.xformers_attn_seqlens_wrapper`` op.

    Routing through ``torch.ops`` keeps the ``seqlens.tolist()`` host read
    inside an opaque custom op so the caller can be compiled. Arguments are
    forwarded positionally and unchanged.
    """
    xformers_op = torch.ops.vllm.xformers_attn_seqlens_wrapper
    return xformers_op(q, k, v, seqlens)

xformers_attn_seqlens_wrapper

xformers_attn_seqlens_wrapper(
    q: Tensor, k: Tensor, v: Tensor, seqlens: Tensor
) -> Tensor
Source code in vllm/attention/ops/vit_attn_wrappers.py
def xformers_attn_seqlens_wrapper(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
) -> torch.Tensor:
    """Block-diagonal memory-efficient attention via xformers.

    Args:
        q, k, v: Inputs forwarded to
            ``xops.memory_efficient_attention_forward``. The output is
            rearranged with the pattern ``"b s h d -> ..."``, so a rank-4
            ``(b, s, h, d)`` layout is assumed.
        seqlens: Per-segment lengths used to build a ``BlockDiagonalMask``
            (keys share the query segmentation, ``kv_seqlen=None``).

    Returns:
        Contiguous tensor of shape ``(s, b, h * d)``.
    """
    from xformers import ops as xops
    from xformers.ops.fmha.attn_bias import BlockDiagonalMask

    # Host-side list of lengths; `.tolist()` is one of the ops torch.compile
    # cannot trace, which is why this runs inside a custom op.
    segment_lengths = seqlens.tolist()
    block_mask = BlockDiagonalMask.from_seqlens(
        q_seqlen=segment_lengths, kv_seqlen=None, device=q.device
    )
    attn_out = xops.memory_efficient_attention_forward(
        q, k, v, attn_bias=block_mask, p=0, scale=None
    )
    merged = einops.rearrange(attn_out, "b s h d -> s b (h d)")
    return merged.contiguous()

xformers_attn_seqlens_wrapper_fake

xformers_attn_seqlens_wrapper_fake(
    q: Tensor, k: Tensor, v: Tensor, seqlens: Tensor
) -> Tensor
Source code in vllm/attention/ops/vit_attn_wrappers.py
def xformers_attn_seqlens_wrapper_fake(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor
) -> torch.Tensor:
    """Shape-only counterpart of ``xformers_attn_seqlens_wrapper``.

    Presumably registered elsewhere as the op's fake implementation. Given a
    rank-4 ``q`` of shape ``(b, s, h, d)``, returns an uninitialized
    ``(s, b, h * d)`` tensor with ``q``'s dtype and device.
    """
    batch, seq_len, num_heads, head_dim = q.shape
    out_shape = (seq_len, batch, num_heads * head_dim)
    return torch.empty(out_shape, device=q.device, dtype=q.dtype)