vllm.v1.attention.backends.mamba_attn

M module-attribute

M = TypeVar('M')
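
`M` is the type variable that concrete builders bind to their metadata class. A minimal sketch of how a subclass might bind it (the metadata class and the `build` body here are hypothetical placeholders, not vLLM code):

from dataclasses import dataclass

@dataclass
class MyMambaAttentionMetadata:  # hypothetical metadata container
    num_reqs: int
    max_query_len: int

class MyMambaAttentionMetadataBuilder(
    BaseMambaAttentionMetadataBuilder[MyMambaAttentionMetadata]
):
    def build(self, common_prefix_len, common_attn_metadata):
        # Construct and return an instance of the bound metadata type (M).
        return MyMambaAttentionMetadata(
            num_reqs=common_attn_metadata.num_reqs,
            max_query_len=common_attn_metadata.max_query_len,
        )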

BaseMambaAttentionMetadataBuilder

Bases: AttentionMetadataBuilder[M], ABC

Source code in vllm/v1/attention/backends/mamba_attn.py
class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
    reorder_batch_threshold: int = 1
    cudagraph_support: ClassVar[AttentionCGSupport] = (
        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
    )

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)

        assert isinstance(kv_cache_spec, MambaSpec)
        self.compilation_config = vllm_config.compilation_config
        self.decode_cudagraph_max_bs = min(
            self.vllm_config.scheduler_config.max_num_seqs,
            self.compilation_config.max_cudagraph_capture_size,
        )

        if self.vllm_config.cache_config.enable_prefix_caching:
            self.state_indices_tensor = torch.empty(
                (
                    self.decode_cudagraph_max_bs,
                    cdiv(
                        self.vllm_config.model_config.max_model_len,
                        self.kv_cache_spec.block_size,
                    ),
                ),
                dtype=torch.int32,
                device=device,
            )
            self.block_idx_last_scheduled_token = torch.empty(
                (self.decode_cudagraph_max_bs,),
                dtype=torch.int32,
                device=device,
            )
            self.block_idx_last_computed_token = torch.empty(
                (self.decode_cudagraph_max_bs,),
                dtype=torch.int32,
                device=device,
            )
        else:
            self.state_indices_tensor = torch.empty(
                (self.decode_cudagraph_max_bs,),
                dtype=torch.int32,
                device=device,
            )

    def build_for_cudagraph_capture(
        self, common_attn_metadata: CommonAttentionMetadata
    ) -> M:
        """
        This method builds the metadata for full cudagraph capture.
        Currently, only decode is supported for full cudagraphs with Mamba.
        """
        m = common_attn_metadata

        assert m.num_reqs == m.num_actual_tokens, (
            "Mamba only supports decode-only full CUDAGraph capture. "
            "Make sure all cudagraph capture sizes are <= max_num_seqs."
        )

        m.max_query_len = 1  # decode-only

        return self.build(0, m)

    def _compute_prefix_caching_block_indices(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        mamba_block_size: int,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
            self.device
        )
        # Block index of the last computed token,
        block_idx_last_computed_token = cdiv(num_computed_tokens, mamba_block_size) - 1
        # which is <= the block index of the first scheduled token,
        block_idx_first_scheduled_token = (
            cdiv(num_computed_tokens + 1, mamba_block_size) - 1
        )
        # which is <= the block index of the last scheduled token.
        block_idx_last_scheduled_token = (
            cdiv(common_attn_metadata.seq_lens, mamba_block_size) - 1
        )
        # With zero computed tokens the index above is -1; clamp to 0 so it
        # remains a valid index downstream.
        block_idx_last_computed_token = block_idx_last_computed_token.clamp(min=0)

        return (
            block_idx_last_computed_token,
            block_idx_first_scheduled_token,
            block_idx_last_scheduled_token,
        )

block_idx_last_computed_token instance-attribute

block_idx_last_computed_token = empty(
    (decode_cudagraph_max_bs,), dtype=int32, device=device
)

block_idx_last_scheduled_token instance-attribute

block_idx_last_scheduled_token = empty(
    (decode_cudagraph_max_bs,), dtype=int32, device=device
)

compilation_config instance-attribute

compilation_config = compilation_config

cudagraph_support class-attribute

cudagraph_support: ClassVar[AttentionCGSupport] = (
    AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
)

decode_cudagraph_max_bs instance-attribute

decode_cudagraph_max_bs = min(
    max_num_seqs, max_cudagraph_capture_size
)
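
For example, with max_num_seqs = 128 and max_cudagraph_capture_size = 512 (assumed values), decode CUDA graphs are sized for at most min(128, 512) = 128 requests.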

reorder_batch_threshold class-attribute instance-attribute

reorder_batch_threshold: int = 1

state_indices_tensor instance-attribute

state_indices_tensor = empty(
    (
        decode_cudagraph_max_bs,
        cdiv(max_model_len, block_size),
    ),
    dtype=int32,
    device=device,
)
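
With prefix caching enabled, each of the decode_cudagraph_max_bs rows holds one state-block index per Mamba block of a maximally long sequence. A small shape check under assumed config values:

def cdiv(a: int, b: int) -> int:
    # Ceiling division, as in vllm.utils.cdiv.
    return -(a // -b)

max_model_len = 8192           # assumed model_config.max_model_len
mamba_block_size = 512         # assumed kv_cache_spec.block_size
decode_cudagraph_max_bs = 256  # assumed capture batch size

shape = (decode_cudagraph_max_bs, cdiv(max_model_len, mamba_block_size))
assert shape == (256, 16)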

__init__

__init__(
    kv_cache_spec: AttentionSpec,
    layer_names: list[str],
    vllm_config: VllmConfig,
    device: device,
)
Source code in vllm/v1/attention/backends/mamba_attn.py
def __init__(
    self,
    kv_cache_spec: AttentionSpec,
    layer_names: list[str],
    vllm_config: VllmConfig,
    device: torch.device,
):
    super().__init__(kv_cache_spec, layer_names, vllm_config, device)

    assert isinstance(kv_cache_spec, MambaSpec)
    self.compilation_config = vllm_config.compilation_config
    self.decode_cudagraph_max_bs = min(
        self.vllm_config.scheduler_config.max_num_seqs,
        self.compilation_config.max_cudagraph_capture_size,
    )

    if self.vllm_config.cache_config.enable_prefix_caching:
        self.state_indices_tensor = torch.empty(
            (
                self.decode_cudagraph_max_bs,
                cdiv(
                    self.vllm_config.model_config.max_model_len,
                    self.kv_cache_spec.block_size,
                ),
            ),
            dtype=torch.int32,
            device=device,
        )
        self.block_idx_last_scheduled_token = torch.empty(
            (self.decode_cudagraph_max_bs,),
            dtype=torch.int32,
            device=device,
        )
        self.block_idx_last_computed_token = torch.empty(
            (self.decode_cudagraph_max_bs,),
            dtype=torch.int32,
            device=device,
        )
    else:
        self.state_indices_tensor = torch.empty(
            (self.decode_cudagraph_max_bs,),
            dtype=torch.int32,
            device=device,
        )

_compute_prefix_caching_block_indices

_compute_prefix_caching_block_indices(
    common_attn_metadata: CommonAttentionMetadata,
    mamba_block_size: int,
) -> tuple[Tensor, Tensor, Tensor]
Source code in vllm/v1/attention/backends/mamba_attn.py
def _compute_prefix_caching_block_indices(
    self,
    common_attn_metadata: CommonAttentionMetadata,
    mamba_block_size: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
        self.device
    )
    # Block index of the last computed token,
    block_idx_last_computed_token = cdiv(num_computed_tokens, mamba_block_size) - 1
    # which is <= the block index of the first scheduled token,
    block_idx_first_scheduled_token = (
        cdiv(num_computed_tokens + 1, mamba_block_size) - 1
    )
    # which is <= the block index of the last scheduled token.
    block_idx_last_scheduled_token = (
        cdiv(common_attn_metadata.seq_lens, mamba_block_size) - 1
    )
    # With zero computed tokens the index above is -1; clamp to 0 so it
    # remains a valid index downstream.
    block_idx_last_computed_token = block_idx_last_computed_token.clamp(min=0)

    return (
        block_idx_last_computed_token,
        block_idx_first_scheduled_token,
        block_idx_last_scheduled_token,
    )
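
A worked example of the three indices, assuming cdiv is ceiling division (as in vllm.utils.cdiv) and a Mamba block size of 256:

import torch

def cdiv(a, b):
    # Ceiling division, as in vllm.utils.cdiv.
    return (a + b - 1) // b

mamba_block_size = 256
num_computed_tokens = torch.tensor([0, 255, 256, 512])
seq_lens = torch.tensor([10, 265, 266, 522])  # computed + scheduled tokens

last_computed = cdiv(num_computed_tokens, mamba_block_size) - 1
first_scheduled = cdiv(num_computed_tokens + 1, mamba_block_size) - 1
last_scheduled = cdiv(seq_lens, mamba_block_size) - 1
last_computed = last_computed.clamp(min=0)  # -1 -> 0 for the fresh request

print(last_computed)    # tensor([0, 0, 0, 1])
print(first_scheduled)  # tensor([0, 0, 1, 2])
print(last_scheduled)   # tensor([0, 1, 1, 2])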

build_for_cudagraph_capture

build_for_cudagraph_capture(
    common_attn_metadata: CommonAttentionMetadata,
) -> M

This method builds the metadata for full cudagraph capture. Currently, only decode is supported for full cudagraphs with Mamba.
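
The decode-only assertion below holds because in decode each request schedules exactly one token, so the token count equals the request count; any prefill request breaks the equality. A small illustration (plain Python, not vLLM API):

num_reqs = 8
query_lens = [1] * num_reqs         # decode-only batch
assert sum(query_lens) == num_reqs  # num_actual_tokens == num_reqs, capture proceeds

query_lens = [1] * (num_reqs - 1) + [100]  # one prefill request in the batch
assert sum(query_lens) != num_reqs         # the assertion in build_for_cudagraph_capture would fire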

Source code in vllm/v1/attention/backends/mamba_attn.py
def build_for_cudagraph_capture(
    self, common_attn_metadata: CommonAttentionMetadata
) -> M:
    """
    This method builds the metadata for full cudagraph capture.
    Currently, only decode is supported for full cudagraphs with Mamba.
    """
    m = common_attn_metadata

    assert m.num_reqs == m.num_actual_tokens, (
        "Mamba only supports decode-only full CUDAGraph capture. "
        "Make sure all cudagraph capture sizes are <= max_num_seqs."
    )

    m.max_query_len = 1  # decode-only

    return self.build(0, m)