vllm.model_executor.layers.kda

logger module-attribute

logger = init_logger(__name__)

KimiDeltaAttention

Bases: Module, MambaBase

Source code in vllm/model_executor/layers/kda.py
class KimiDeltaAttention(nn.Module, MambaBase):
    @property
    def mamba_type(self) -> str:
        return "linear_attention"

    def get_attn_backend(self) -> type["AttentionBackend"]:
        from vllm.v1.attention.backends.gdn_attn import GDNAttentionBackend

        return GDNAttentionBackend

    def get_state_dtype(
        self,
    ) -> tuple[torch.dtype, torch.dtype, torch.dtype, torch.dtype]:
        if self.model_config is None or self.cache_config is None:
            raise ValueError("model_config and cache_config must be set")
        return MambaStateDtypeCalculator.kda_state_dtype(
            self.model_config.dtype, self.cache_config.mamba_cache_dtype
        )

    def get_state_shape(
        self,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        return MambaStateShapeCalculator.kda_state_shape(
            self.tp_size, self.num_heads, self.head_dim, conv_kernel_size=self.conv_size
        )

    def __init__(
        self,
        layer_idx: int,
        hidden_size: int,
        quant_config: QuantizationConfig | None = None,
        cache_config: CacheConfig | None = None,
        model_config: ModelConfig | None = None,
        rms_norm_eps: float = 1e-5,
        prefix: str = "",
        **kwargs,
    ) -> None:
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.hidden_size = hidden_size
        self.model_config = model_config
        self.cache_config = cache_config
        if model_config is None:
            raise ValueError("model_config must be provided")
        kda_config = model_config.linear_attn_config
        self.head_dim = kda_config["head_dim"]
        self.num_heads = kda_config["num_heads"]
        self.layer_idx = layer_idx
        self.prefix = prefix
        assert self.num_heads % self.tp_size == 0
        self.local_num_heads = divide(self.num_heads, self.tp_size)

        projection_size = self.head_dim * self.num_heads
        self.conv_size = kda_config["short_conv_kernel_size"]

        self.q_proj = ColumnParallelLinear(
            self.hidden_size,
            projection_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.q_proj",
        )
        self.k_proj = ColumnParallelLinear(
            self.hidden_size,
            projection_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.k_proj",
        )
        self.v_proj = ColumnParallelLinear(
            self.hidden_size,
            projection_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.v_proj",
        )

        self.f_a_proj = ReplicatedLinear(
            self.hidden_size,
            self.head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.f_a_proj",
        )

        self.f_b_proj = ColumnParallelLinear(
            self.head_dim,
            projection_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.f_b_proj",
        )
        self.dt_bias = nn.Parameter(
            torch.empty(divide(projection_size, self.tp_size), dtype=torch.float32)
        )

        set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})

        self.b_proj = ColumnParallelLinear(
            self.hidden_size,
            self.num_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.b_proj",
        )

        self.q_conv1d = ColumnParallelLinear(
            input_size=self.conv_size,
            output_size=projection_size,
            bias=False,
            params_dtype=torch.float32,
            prefix=f"{prefix}.q_conv1d",
        )
        self.k_conv1d = ColumnParallelLinear(
            input_size=self.conv_size,
            output_size=projection_size,
            bias=False,
            params_dtype=torch.float32,
            prefix=f"{prefix}.k_conv1d",
        )
        self.v_conv1d = ColumnParallelLinear(
            input_size=self.conv_size,
            output_size=projection_size,
            bias=False,
            params_dtype=torch.float32,
            prefix=f"{prefix}.v_conv1d",
        )
        # Unsqueeze so the linear-style weight matches the conv1d weight shape.
        # This can't be done in `weight_loader` since `ColumnParallelLinear`
        # already defines one and `set_weight_attrs` doesn't allow
        # overriding it.
        self.q_conv1d.weight.data = self.q_conv1d.weight.data.unsqueeze(1)
        self.k_conv1d.weight.data = self.k_conv1d.weight.data.unsqueeze(1)
        self.v_conv1d.weight.data = self.v_conv1d.weight.data.unsqueeze(1)

        self.A_log = nn.Parameter(
            torch.empty(1, 1, self.local_num_heads, 1, dtype=torch.float32)
        )
        set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(2)})

        self.g_a_proj = ReplicatedLinear(
            self.hidden_size,
            self.head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.g_a_proj",
        )
        self.g_b_proj = ColumnParallelLinear(
            self.head_dim,
            projection_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.g_b_proj",
        )
        self.o_norm = FusedRMSNormGated(
            self.head_dim, eps=rms_norm_eps, activation="sigmoid"
        )
        self.o_proj = RowParallelLinear(
            projection_size,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self

    def forward(
        self,
        hidden_states: torch.Tensor,
        positions: torch.Tensor,
        output: torch.Tensor,
    ) -> None:
        num_tokens = hidden_states.size(0)
        q = self.q_proj(hidden_states)[0]
        k = self.k_proj(hidden_states)[0]
        v = self.v_proj(hidden_states)[0]

        beta = self.b_proj(hidden_states)[0].float().sigmoid()
        g1 = self.f_b_proj(self.f_a_proj(hidden_states)[0])[0]
        g1 = fused_kda_gate(g1, self.A_log, self.head_dim, g_bias=self.dt_bias)
        beta = beta.unsqueeze(0)
        g1 = g1.unsqueeze(0)

        g_proj_states = self.g_b_proj(self.g_a_proj(hidden_states)[0])[0]
        g2 = rearrange(g_proj_states, "... (h d) -> ... h d", d=self.head_dim)

        core_attn_out = torch.zeros(
            (1, num_tokens, self.local_num_heads, self.head_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        torch.ops.vllm.kda_attention(
            q,
            k,
            v,
            g1,
            g2,
            beta,
            core_attn_out,
            self.prefix,
        )
        core_attn_out = self.o_norm(core_attn_out, g2)
        core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)")
        output[:] = self.o_proj(core_attn_out)[0]

    def _forward(
        self,
        q_proj_states: torch.Tensor,
        k_proj_states: torch.Tensor,
        v_proj_states: torch.Tensor,
        g1: torch.Tensor,
        g2: torch.Tensor,
        beta: torch.Tensor,
        core_attn_out: torch.Tensor,
    ) -> None:
        forward_context = get_forward_context()
        attn_metadata: AttentionMetadata = forward_context.attn_metadata

        if attn_metadata is None:
            # V1 profile run
            return

        assert isinstance(attn_metadata, dict)
        attn_metadata = attn_metadata[self.prefix]
        assert isinstance(attn_metadata, GDNAttentionMetadata)
        has_initial_state = attn_metadata.has_initial_state
        non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
        non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor  # noqa: E501
        constant_caches = self.kv_cache[forward_context.virtual_engine]

        (conv_state_q, conv_state_k, conv_state_v, recurrent_state) = constant_caches
        # deal with strides
        conv_state_q = conv_state_q.transpose(-1, -2)
        conv_state_k = conv_state_k.transpose(-1, -2)
        conv_state_v = conv_state_v.transpose(-1, -2)

        q_conv_weights = self.q_conv1d.weight.view(
            self.q_conv1d.weight.size(0), self.q_conv1d.weight.size(2)
        )
        k_conv_weights = self.k_conv1d.weight.view(
            self.k_conv1d.weight.size(0), self.k_conv1d.weight.size(2)
        )
        v_conv_weights = self.v_conv1d.weight.view(
            self.v_conv1d.weight.size(0), self.v_conv1d.weight.size(2)
        )
        if attn_metadata.num_prefills > 0:
            q_proj_states = q_proj_states.transpose(0, 1)
            k_proj_states = k_proj_states.transpose(0, 1)
            v_proj_states = v_proj_states.transpose(0, 1)
            q = causal_conv1d_fn(
                q_proj_states,
                q_conv_weights,
                self.q_conv1d.bias,
                activation="silu",
                conv_states=conv_state_q,
                has_initial_state=has_initial_state,
                cache_indices=non_spec_state_indices_tensor,
                query_start_loc=non_spec_query_start_loc,
                metadata=attn_metadata,
            ).transpose(0, 1)
            k = causal_conv1d_fn(
                k_proj_states,
                k_conv_weights,
                self.k_conv1d.bias,
                activation="silu",
                conv_states=conv_state_k,
                has_initial_state=has_initial_state,
                cache_indices=non_spec_state_indices_tensor,
                query_start_loc=non_spec_query_start_loc,
                metadata=attn_metadata,
            ).transpose(0, 1)
            v = causal_conv1d_fn(
                v_proj_states,
                v_conv_weights,
                self.v_conv1d.bias,
                activation="silu",
                conv_states=conv_state_v,
                has_initial_state=has_initial_state,
                cache_indices=non_spec_state_indices_tensor,
                query_start_loc=non_spec_query_start_loc,
                metadata=attn_metadata,
            ).transpose(0, 1)
        else:
            decode_conv_indices = non_spec_state_indices_tensor[
                : attn_metadata.num_decodes
            ]
            q = causal_conv1d_update(
                q_proj_states,
                conv_state_q,
                q_conv_weights,
                self.q_conv1d.bias,
                activation="silu",
                conv_state_indices=decode_conv_indices,
                validate_data=True,
            )
            k = causal_conv1d_update(
                k_proj_states,
                conv_state_k,
                k_conv_weights,
                self.k_conv1d.bias,
                activation="silu",
                conv_state_indices=decode_conv_indices,
                validate_data=True,
            )
            v = causal_conv1d_update(
                v_proj_states,
                conv_state_v,
                v_conv_weights,
                self.v_conv1d.bias,
                activation="silu",
                conv_state_indices=decode_conv_indices,
                validate_data=True,
            )

        q, k, v = map(
            lambda x: rearrange(x, "n (h d) -> 1 n h d", d=self.head_dim), (q, k, v)
        )

        if attn_metadata.num_prefills > 0:
            zero_idx = non_spec_state_indices_tensor[~has_initial_state]
            recurrent_state[zero_idx] = 0
            initial_state = recurrent_state[non_spec_state_indices_tensor].contiguous()
            (
                core_attn_out_non_spec,
                last_recurrent_state,
            ) = chunk_kda(
                q=q,
                k=k,
                v=v,
                g=g1,
                beta=beta,
                initial_state=initial_state,
                output_final_state=True,
                use_qk_l2norm_in_kernel=True,
                cu_seqlens=non_spec_query_start_loc,
            )
            # Write the final recurrent state back into the cache
            recurrent_state[non_spec_state_indices_tensor] = last_recurrent_state
        else:
            (
                core_attn_out_non_spec,
                last_recurrent_state,
            ) = fused_recurrent_kda(
                q=q,
                k=k,
                v=v,
                g=g1,
                beta=beta,
                initial_state=recurrent_state,
                use_qk_l2norm_in_kernel=True,
                cu_seqlens=non_spec_query_start_loc,
                ssm_state_indices=non_spec_state_indices_tensor,
            )
        assert core_attn_out_non_spec.shape == core_attn_out.shape
        core_attn_out[:] = core_attn_out_non_spec

A_log instance-attribute

A_log = Parameter(
    empty(1, 1, local_num_heads, 1, dtype=float32)
)

b_proj instance-attribute

b_proj = ColumnParallelLinear(
    hidden_size,
    num_heads,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.b_proj",
)

cache_config instance-attribute

cache_config = cache_config

conv_size instance-attribute

conv_size = kda_config['short_conv_kernel_size']

dt_bias instance-attribute

dt_bias = Parameter(
    empty(divide(projection_size, tp_size), dtype=float32)
)

f_a_proj instance-attribute

f_a_proj = ReplicatedLinear(
    hidden_size,
    head_dim,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.f_a_proj",
)

f_b_proj instance-attribute

f_b_proj = ColumnParallelLinear(
    head_dim,
    projection_size,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.f_b_proj",
)

g_a_proj instance-attribute

g_a_proj = ReplicatedLinear(
    hidden_size,
    head_dim,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.g_a_proj",
)

g_b_proj instance-attribute

g_b_proj = ColumnParallelLinear(
    head_dim,
    projection_size,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.g_b_proj",
)

head_dim instance-attribute

head_dim = kda_config['head_dim']

hidden_size instance-attribute

hidden_size = hidden_size

k_conv1d instance-attribute

k_conv1d = ColumnParallelLinear(
    input_size=conv_size,
    output_size=projection_size,
    bias=False,
    params_dtype=float32,
    prefix=f"{prefix}.k_conv1d",
)

k_proj instance-attribute

k_proj = ColumnParallelLinear(
    hidden_size,
    projection_size,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.k_proj",
)

layer_idx instance-attribute

layer_idx = layer_idx

local_num_heads instance-attribute

local_num_heads = divide(num_heads, tp_size)

mamba_type property

mamba_type: str

model_config instance-attribute

model_config = model_config

num_heads instance-attribute

num_heads = kda_config['num_heads']

o_norm instance-attribute

o_norm = FusedRMSNormGated(
    head_dim, eps=rms_norm_eps, activation="sigmoid"
)

o_proj instance-attribute

o_proj = RowParallelLinear(
    projection_size,
    hidden_size,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.o_proj",
)

prefix instance-attribute

prefix = prefix

q_conv1d instance-attribute

q_conv1d = ColumnParallelLinear(
    input_size=conv_size,
    output_size=projection_size,
    bias=False,
    params_dtype=float32,
    prefix=f"{prefix}.q_conv1d",
)

q_proj instance-attribute

q_proj = ColumnParallelLinear(
    hidden_size,
    projection_size,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.q_proj",
)

tp_rank instance-attribute

tp_rank = get_tensor_model_parallel_rank()

tp_size instance-attribute

tp_size = get_tensor_model_parallel_world_size()

v_conv1d instance-attribute

v_conv1d = ColumnParallelLinear(
    input_size=conv_size,
    output_size=projection_size,
    bias=False,
    params_dtype=float32,
    prefix=f"{prefix}.v_conv1d",
)

v_proj instance-attribute

v_proj = ColumnParallelLinear(
    hidden_size,
    projection_size,
    bias=False,
    quant_config=quant_config,
    prefix=f"{prefix}.v_proj",
)

__init__

__init__(
    layer_idx: int,
    hidden_size: int,
    quant_config: QuantizationConfig | None = None,
    cache_config: CacheConfig | None = None,
    model_config: ModelConfig | None = None,
    rms_norm_eps: float = 1e-05,
    prefix: str = "",
    **kwargs,
) -> None
Source code in vllm/model_executor/layers/kda.py
def __init__(
    self,
    layer_idx: int,
    hidden_size: int,
    quant_config: QuantizationConfig | None = None,
    cache_config: CacheConfig | None = None,
    model_config: ModelConfig | None = None,
    rms_norm_eps: float = 1e-5,
    prefix: str = "",
    **kwargs,
) -> None:
    super().__init__()
    self.tp_size = get_tensor_model_parallel_world_size()
    self.tp_rank = get_tensor_model_parallel_rank()
    self.hidden_size = hidden_size
    self.model_config = model_config
    self.cache_config = cache_config
    if model_config is None:
        raise ValueError("model_config must be provided")
    kda_config = model_config.linear_attn_config
    self.head_dim = kda_config["head_dim"]
    self.num_heads = kda_config["num_heads"]
    self.layer_idx = layer_idx
    self.prefix = prefix
    assert self.num_heads % self.tp_size == 0
    self.local_num_heads = divide(self.num_heads, self.tp_size)

    projection_size = self.head_dim * self.num_heads
    self.conv_size = kda_config["short_conv_kernel_size"]

    self.q_proj = ColumnParallelLinear(
        self.hidden_size,
        projection_size,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.q_proj",
    )
    self.k_proj = ColumnParallelLinear(
        self.hidden_size,
        projection_size,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.k_proj",
    )
    self.v_proj = ColumnParallelLinear(
        self.hidden_size,
        projection_size,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.v_proj",
    )

    self.f_a_proj = ReplicatedLinear(
        self.hidden_size,
        self.head_dim,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.f_a_proj",
    )

    self.f_b_proj = ColumnParallelLinear(
        self.head_dim,
        projection_size,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.f_b_proj",
    )
    self.dt_bias = nn.Parameter(
        torch.empty(divide(projection_size, self.tp_size), dtype=torch.float32)
    )

    set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})

    self.b_proj = ColumnParallelLinear(
        self.hidden_size,
        self.num_heads,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.b_proj",
    )

    self.q_conv1d = ColumnParallelLinear(
        input_size=self.conv_size,
        output_size=projection_size,
        bias=False,
        params_dtype=torch.float32,
        prefix=f"{prefix}.q_conv1d",
    )
    self.k_conv1d = ColumnParallelLinear(
        input_size=self.conv_size,
        output_size=projection_size,
        bias=False,
        params_dtype=torch.float32,
        prefix=f"{prefix}.k_conv1d",
    )
    self.v_conv1d = ColumnParallelLinear(
        input_size=self.conv_size,
        output_size=projection_size,
        bias=False,
        params_dtype=torch.float32,
        prefix=f"{prefix}.v_conv1d",
    )
    # Unsqueeze so the linear-style weight matches the conv1d weight shape.
    # This can't be done in `weight_loader` since `ColumnParallelLinear`
    # already defines one and `set_weight_attrs` doesn't allow
    # overriding it.
    self.q_conv1d.weight.data = self.q_conv1d.weight.data.unsqueeze(1)
    self.k_conv1d.weight.data = self.k_conv1d.weight.data.unsqueeze(1)
    self.v_conv1d.weight.data = self.v_conv1d.weight.data.unsqueeze(1)

    self.A_log = nn.Parameter(
        torch.empty(1, 1, self.local_num_heads, 1, dtype=torch.float32)
    )
    set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(2)})

    self.g_a_proj = ReplicatedLinear(
        self.hidden_size,
        self.head_dim,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.g_a_proj",
    )
    self.g_b_proj = ColumnParallelLinear(
        self.head_dim,
        projection_size,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.g_b_proj",
    )
    self.o_norm = FusedRMSNormGated(
        self.head_dim, eps=rms_norm_eps, activation="sigmoid"
    )
    self.o_proj = RowParallelLinear(
        projection_size,
        self.hidden_size,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.o_proj",
    )

    compilation_config = get_current_vllm_config().compilation_config
    if prefix in compilation_config.static_forward_context:
        raise ValueError(f"Duplicate layer name: {prefix}")
    compilation_config.static_forward_context[prefix] = self
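
The conv1d weight handling above is a common stumbling block: the weights are created as 2D ColumnParallelLinear parameters of shape (output_size, conv_size), unsqueezed here to the 3D conv1d layout, and later viewed back to 2D in `_forward` before being passed to the causal conv kernels. A minimal plain-torch sketch of that round trip (shapes are made up for illustration; nothing here imports vLLM):

import torch

out_channels, conv_size = 8, 4

w = torch.randn(out_channels, conv_size)    # 2D weight as created by the linear layer
w3d = w.unsqueeze(1)                        # (out_channels, 1, conv_size): conv1d layout
w2d = w3d.view(w3d.size(0), w3d.size(2))    # what _forward hands to the causal conv kernels

assert torch.equal(w, w2d)                  # the view is a pure layout change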

_forward

_forward(
    q_proj_states: Tensor,
    k_proj_states: Tensor,
    v_proj_states: Tensor,
    g1: Tensor,
    g2: Tensor,
    beta: Tensor,
    core_attn_out: Tensor,
) -> None
Source code in vllm/model_executor/layers/kda.py
def _forward(
    self,
    q_proj_states: torch.Tensor,
    k_proj_states: torch.Tensor,
    v_proj_states: torch.Tensor,
    g1: torch.Tensor,
    g2: torch.Tensor,
    beta: torch.Tensor,
    core_attn_out: torch.Tensor,
) -> None:
    forward_context = get_forward_context()
    attn_metadata: AttentionMetadata = forward_context.attn_metadata

    if attn_metadata is None:
        # V1 profile run
        return

    assert isinstance(attn_metadata, dict)
    attn_metadata = attn_metadata[self.prefix]
    assert isinstance(attn_metadata, GDNAttentionMetadata)
    has_initial_state = attn_metadata.has_initial_state
    non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
    non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor  # noqa: E501
    constant_caches = self.kv_cache[forward_context.virtual_engine]

    (conv_state_q, conv_state_k, conv_state_v, recurrent_state) = constant_caches
    # deal with strides
    conv_state_q = conv_state_q.transpose(-1, -2)
    conv_state_k = conv_state_k.transpose(-1, -2)
    conv_state_v = conv_state_v.transpose(-1, -2)

    q_conv_weights = self.q_conv1d.weight.view(
        self.q_conv1d.weight.size(0), self.q_conv1d.weight.size(2)
    )
    k_conv_weights = self.k_conv1d.weight.view(
        self.k_conv1d.weight.size(0), self.k_conv1d.weight.size(2)
    )
    v_conv_weights = self.v_conv1d.weight.view(
        self.v_conv1d.weight.size(0), self.v_conv1d.weight.size(2)
    )
    if attn_metadata.num_prefills > 0:
        q_proj_states = q_proj_states.transpose(0, 1)
        k_proj_states = k_proj_states.transpose(0, 1)
        v_proj_states = v_proj_states.transpose(0, 1)
        q = causal_conv1d_fn(
            q_proj_states,
            q_conv_weights,
            self.q_conv1d.bias,
            activation="silu",
            conv_states=conv_state_q,
            has_initial_state=has_initial_state,
            cache_indices=non_spec_state_indices_tensor,
            query_start_loc=non_spec_query_start_loc,
            metadata=attn_metadata,
        ).transpose(0, 1)
        k = causal_conv1d_fn(
            k_proj_states,
            k_conv_weights,
            self.k_conv1d.bias,
            activation="silu",
            conv_states=conv_state_k,
            has_initial_state=has_initial_state,
            cache_indices=non_spec_state_indices_tensor,
            query_start_loc=non_spec_query_start_loc,
            metadata=attn_metadata,
        ).transpose(0, 1)
        v = causal_conv1d_fn(
            v_proj_states,
            v_conv_weights,
            self.v_conv1d.bias,
            activation="silu",
            conv_states=conv_state_v,
            has_initial_state=has_initial_state,
            cache_indices=non_spec_state_indices_tensor,
            query_start_loc=non_spec_query_start_loc,
            metadata=attn_metadata,
        ).transpose(0, 1)
    else:
        decode_conv_indices = non_spec_state_indices_tensor[
            : attn_metadata.num_decodes
        ]
        q = causal_conv1d_update(
            q_proj_states,
            conv_state_q,
            q_conv_weights,
            self.q_conv1d.bias,
            activation="silu",
            conv_state_indices=decode_conv_indices,
            validate_data=True,
        )
        k = causal_conv1d_update(
            k_proj_states,
            conv_state_k,
            k_conv_weights,
            self.k_conv1d.bias,
            activation="silu",
            conv_state_indices=decode_conv_indices,
            validate_data=True,
        )
        v = causal_conv1d_update(
            v_proj_states,
            conv_state_v,
            v_conv_weights,
            self.v_conv1d.bias,
            activation="silu",
            conv_state_indices=decode_conv_indices,
            validate_data=True,
        )

    q, k, v = map(
        lambda x: rearrange(x, "n (h d) -> 1 n h d", d=self.head_dim), (q, k, v)
    )

    if attn_metadata.num_prefills > 0:
        zero_idx = non_spec_state_indices_tensor[~has_initial_state]
        recurrent_state[zero_idx] = 0
        initial_state = recurrent_state[non_spec_state_indices_tensor].contiguous()
        (
            core_attn_out_non_spec,
            last_recurrent_state,
        ) = chunk_kda(
            q=q,
            k=k,
            v=v,
            g=g1,
            beta=beta,
            initial_state=initial_state,
            output_final_state=True,
            use_qk_l2norm_in_kernel=True,
            cu_seqlens=non_spec_query_start_loc,
        )
        # Write the final recurrent state back into the cache
        recurrent_state[non_spec_state_indices_tensor] = last_recurrent_state
    else:
        (
            core_attn_out_non_spec,
            last_recurrent_state,
        ) = fused_recurrent_kda(
            q=q,
            k=k,
            v=v,
            g=g1,
            beta=beta,
            initial_state=recurrent_state,
            use_qk_l2norm_in_kernel=True,
            cu_seqlens=non_spec_query_start_loc,
            ssm_state_indices=non_spec_state_indices_tensor,
        )
    assert core_attn_out_non_spec.shape == core_attn_out.shape
    core_attn_out[:] = core_attn_out_non_spec
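
The prefill branch's cache handling is worth calling out: slots belonging to sequences without prior state are zeroed, a per-sequence initial_state is gathered for the kernel, and the kernel's final state is scattered back into the same slots. A minimal plain-torch sketch of that gather/scatter pattern (all shapes and values are made up; the KDA kernels themselves are not involved):

import torch

# Hypothetical cache with 6 slots of (heads=2, k_dim=3, v_dim=3) recurrent state.
recurrent_state = torch.randn(6, 2, 3, 3)

# Two sequences in the batch map to cache slots 4 and 1; only the first has history.
state_indices = torch.tensor([4, 1])
has_initial_state = torch.tensor([True, False])

# Zero the slots of sequences without prior state, mirroring _forward.
recurrent_state[state_indices[~has_initial_state]] = 0

# Gather per-sequence initial states for the kernel ...
initial_state = recurrent_state[state_indices].contiguous()

# ... and scatter the final states back afterwards.
last_recurrent_state = torch.randn_like(initial_state)  # stand-in for the kernel output
recurrent_state[state_indices] = last_recurrent_state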

forward

forward(
    hidden_states: Tensor, positions: Tensor, output: Tensor
) -> None
Source code in vllm/model_executor/layers/kda.py
def forward(
    self,
    hidden_states: torch.Tensor,
    positions: torch.Tensor,
    output: torch.Tensor,
) -> None:
    num_tokens = hidden_states.size(0)
    q = self.q_proj(hidden_states)[0]
    k = self.k_proj(hidden_states)[0]
    v = self.v_proj(hidden_states)[0]

    beta = self.b_proj(hidden_states)[0].float().sigmoid()
    g1 = self.f_b_proj(self.f_a_proj(hidden_states)[0])[0]
    g1 = fused_kda_gate(g1, self.A_log, self.head_dim, g_bias=self.dt_bias)
    beta = beta.unsqueeze(0)
    g1 = g1.unsqueeze(0)

    g_proj_states = self.g_b_proj(self.g_a_proj(hidden_states)[0])[0]
    g2 = rearrange(g_proj_states, "... (h d) -> ... h d", d=self.head_dim)

    core_attn_out = torch.zeros(
        (1, num_tokens, self.local_num_heads, self.head_dim),
        dtype=hidden_states.dtype,
        device=hidden_states.device,
    )
    torch.ops.vllm.kda_attention(
        q,
        k,
        v,
        g1,
        g2,
        beta,
        core_attn_out,
        self.prefix,
    )
    core_attn_out = self.o_norm(core_attn_out, g2)
    core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)")
    output[:] = self.o_proj(core_attn_out)[0]
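
Most of the tensor bookkeeping in forward is einops.rearrange: the flat (h * d) gate projection is split into per-head vectors, and the attention output (which carries a leading batch dimension of 1 through the custom op) is flattened back before o_proj. A short sketch of those two reshapes with made-up sizes (requires only torch and einops, nothing from vLLM):

import torch
from einops import rearrange

num_tokens, num_heads, head_dim = 5, 4, 16

# g_proj_states -> g2: split the flat (h * d) projection into per-head vectors.
g_proj_states = torch.randn(num_tokens, num_heads * head_dim)
g2 = rearrange(g_proj_states, "... (h d) -> ... h d", d=head_dim)   # (5, 4, 16)

# core_attn_out is flattened back to (num_tokens, h * d) before o_proj.
core_attn_out = torch.randn(1, num_tokens, num_heads, head_dim)
flat = rearrange(core_attn_out, "1 n h d -> n (h d)")               # (5, 64)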

get_attn_backend

get_attn_backend() -> type[AttentionBackend]
Source code in vllm/model_executor/layers/kda.py
def get_attn_backend(self) -> type["AttentionBackend"]:
    from vllm.v1.attention.backends.gdn_attn import GDNAttentionBackend

    return GDNAttentionBackend

get_state_dtype

get_state_dtype() -> tuple[dtype, dtype, dtype, dtype]
Source code in vllm/model_executor/layers/kda.py
def get_state_dtype(
    self,
) -> tuple[torch.dtype, torch.dtype, torch.dtype, torch.dtype]:
    if self.model_config is None or self.cache_config is None:
        raise ValueError("model_config and cache_config must be set")
    return MambaStateDtypeCalculator.kda_state_dtype(
        self.model_config.dtype, self.cache_config.mamba_cache_dtype
    )

get_state_shape

get_state_shape() -> tuple[
    tuple[int, ...],
    tuple[int, ...],
    tuple[int, ...],
    tuple[int, ...],
]
Source code in vllm/model_executor/layers/kda.py
def get_state_shape(
    self,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
    return MambaStateShapeCalculator.kda_state_shape(
        self.tp_size, self.num_heads, self.head_dim, conv_kernel_size=self.conv_size
    )

kda_attention

kda_attention(
    q_proj_states: Tensor,
    k_proj_states: Tensor,
    v_proj_states: Tensor,
    g1: Tensor,
    g2: Tensor,
    beta: Tensor,
    core_attn_out: Tensor,
    layer_name: str,
) -> None
Source code in vllm/model_executor/layers/kda.py
def kda_attention(
    q_proj_states: torch.Tensor,
    k_proj_states: torch.Tensor,
    v_proj_states: torch.Tensor,
    g1: torch.Tensor,
    g2: torch.Tensor,
    beta: torch.Tensor,
    core_attn_out: torch.Tensor,
    layer_name: str,
) -> None:
    forward_context: ForwardContext = get_forward_context()
    self = forward_context.no_compile_layers[layer_name]
    self._forward(
        q_proj_states=q_proj_states,
        k_proj_states=k_proj_states,
        v_proj_states=v_proj_states,
        g1=g1,
        g2=g2,
        beta=beta,
        core_attn_out=core_attn_out,
    )

kda_attention_fake

kda_attention_fake(
    q_proj_states: Tensor,
    k_proj_states: Tensor,
    v_proj_states: Tensor,
    g1: Tensor,
    g2: Tensor,
    beta: Tensor,
    core_attn_out: Tensor,
    layer_name: str,
) -> None
Source code in vllm/model_executor/layers/kda.py
def kda_attention_fake(
    q_proj_states: torch.Tensor,
    k_proj_states: torch.Tensor,
    v_proj_states: torch.Tensor,
    g1: torch.Tensor,
    g2: torch.Tensor,
    beta: torch.Tensor,
    core_attn_out: torch.Tensor,
    layer_name: str,
) -> None:
    return
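
kda_attention and kda_attention_fake form the usual real/fake pair for a mutating custom op: the real implementation resolves the layer by layer_name from the forward context and writes into core_attn_out in place, while the fake implementation does nothing so shape and dtype propagation still works under compilation. The sketch below reproduces that pattern with torch.library.custom_op on toy tensors; the names are illustrative and this is not the registration code vLLM actually uses for torch.ops.vllm.kda_attention:

import torch

@torch.library.custom_op("demo::kda_attention", mutates_args=("core_attn_out",))
def demo_kda_attention(q: torch.Tensor, core_attn_out: torch.Tensor, layer_name: str) -> None:
    # Real impl: kda.py resolves the layer via
    # get_forward_context().no_compile_layers[layer_name] and calls _forward;
    # here we just copy so the toy op is runnable.
    core_attn_out.copy_(q)

@demo_kda_attention.register_fake
def _(q: torch.Tensor, core_attn_out: torch.Tensor, layer_name: str) -> None:
    # Fake impl: no computation; only used for tracing/compilation.
    return

q = torch.randn(2, 3)
out = torch.zeros_like(q)
torch.ops.demo.kda_attention(q, out, "model.layers.0.self_attn")
assert torch.equal(out, q)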