
vllm.model_executor.layers.mamba.mamba_mixer

MambaMixer

Bases: MambaBase, CustomOp

Compute ∆, A, B, C, and D, the state space parameters, and compute the contextualized_states. A and D are input-independent (see the Mamba paper [1], Section 3.5.2 "Interpretation of A", for why A isn't selective); ∆, B, and C are input-dependent (this is a key difference between Mamba and the linear time-invariant S4, and is why Mamba is called selective state spaces).
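For intuition, the recurrence that the selective-scan kernels in this layer compute can be written, per channel, as h_t = exp(∆_t·A)·h_{t-1} + ∆_t·B_t·x_t and y_t = C_t·h_t + D·x_t. Below is a minimal, unoptimized reference loop in plain PyTorch; it is an illustrative sketch of that recurrence only, not the fused kernel used by this layer.

import torch

def selective_ssm_reference(x, delta, A, B, C, D):
    # x:     (dim, seqlen)   input sequence for a single request
    # delta: (dim, seqlen)   input-dependent step sizes (already softplus-ed)
    # A:     (dim, state)    input-independent state matrix (negative entries)
    # B, C:  (seqlen, state) input-dependent projections
    # D:     (dim,)          skip connection
    dim, state = A.shape
    h = torch.zeros(dim, state)
    ys = []
    for t in range(x.shape[-1]):
        dA = torch.exp(delta[:, t, None] * A)                     # discretized A
        dBx = delta[:, t, None] * B[t][None, :] * x[:, t, None]   # discretized B * x_t
        h = dA * h + dBx                                          # state update
        ys.append(h @ C[t] + D * x[:, t])                         # readout + skip
    return torch.stack(ys, dim=-1)                                # (dim, seqlen)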

Source code in vllm/model_executor/layers/mamba/mamba_mixer.py
@CustomOp.register("mamba_mixer")
class MambaMixer(MambaBase, CustomOp):
    """
    Compute ∆, A, B, C, and D the state space parameters and compute
    the `contextualized_states`. A, D are input independent
    (see Mamba paper [1] Section 3.5.2 "Interpretation of A"
    for why A isn't selective) ∆, B, C are input-dependent
    (this is a key difference between Mamba and the linear time
    invariant S4, and is why Mamba is called
    **selective** state spaces)
    """

    def __init__(
        self,
        hidden_size: int,
        ssm_state_size: int,
        conv_kernel_size: int,
        intermediate_size: int,
        time_step_rank: int,
        use_conv_bias: bool,
        use_bias: bool,
        use_rms_norm: bool,
        rms_norm_has_weight: bool = True,
        rms_norm_eps: float = 1e-5,
        activation="silu",
        is_lora_enabled: bool = False,
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.time_step_rank = time_step_rank
        self.ssm_state_size = ssm_state_size
        self.use_rms_norm = use_rms_norm
        self.activation = activation
        self.is_lora_enabled = is_lora_enabled
        self.conv_kernel_size = conv_kernel_size
        self.intermediate_size = intermediate_size

        self.conv1d = ColumnParallelLinear(
            input_size=conv_kernel_size,
            output_size=intermediate_size,
            bias=use_conv_bias,
        )
        # unsqueeze to fit conv1d weights shape into the linear weights shape.
        # Can't do this in `weight_loader` since it already exists in
        # `ColumnParallelLinear` and `set_weight_attrs`
        # doesn't allow to override it
        self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)

        self.in_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2, bias=use_bias
        )

        # selective projection used to make dt, B and C input dependent
        self.x_proj = RowParallelLinear(
            intermediate_size,
            time_step_rank + ssm_state_size * 2,
            bias=False,
        )
        # time step projection (discretization) -
        # In the forward we need to apply dt_proj without the bias,
        # as the bias is added in the selective scan kernel.
        self.dt_proj = ColumnParallelLinear(
            time_step_rank, intermediate_size, bias=True, skip_bias_add=True
        )

        def weight_loader(param: Parameter, loaded_weight: torch.Tensor):
            tp_rank = get_tensor_model_parallel_rank()
            tp_size = get_tensor_model_parallel_world_size()
            param.data.copy_(
                loaded_weight.data.split(loaded_weight.shape[0] // tp_size, dim=0)[
                    tp_rank
                ]
            )

        def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor):
            weight_loader(param, -torch.exp(loaded_weight.float()))

        tp_size = get_tensor_model_parallel_world_size()
        self.A = nn.Parameter(
            torch.empty(
                intermediate_size // tp_size,
                ssm_state_size,
                dtype=torch.float32,
            )
        )
        self.D = nn.Parameter(torch.ones(intermediate_size // tp_size))

        set_weight_attrs(self.D, {"weight_loader": weight_loader})
        set_weight_attrs(self.A, {"weight_loader": A_weight_loader})

        self.out_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=use_bias,
            input_is_parallel=True,
        )

        self.dt_layernorm = (
            RMSNorm(
                time_step_rank,
                eps=rms_norm_eps,
                has_weight=rms_norm_has_weight,
            )
            if use_rms_norm
            else None
        )

        self.b_layernorm = (
            RMSNorm(
                ssm_state_size,
                eps=rms_norm_eps,
                has_weight=rms_norm_has_weight,
            )
            if use_rms_norm
            else None
        )

        self.c_layernorm = (
            RMSNorm(
                ssm_state_size,
                eps=rms_norm_eps,
                has_weight=rms_norm_has_weight,
            )
            if use_rms_norm
            else None
        )

        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
        # The inner tuple is (conv_state, ssm_state)
        self.kv_cache = (torch.tensor([]), torch.tensor([]))

        self.model_config = model_config
        self.cache_config = cache_config
        self.prefix = prefix

    def _ssm_transform(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        if self.is_lora_enabled:
            #  Lora kernel requires contiguous tensor.
            ssm_params = self.x_proj(x.contiguous())[0]
        else:
            ssm_params = self.x_proj(x)[0]
        time_step, B, C = torch.split(
            ssm_params,
            [self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
            dim=-1,
        )
        if self.use_rms_norm:
            assert self.dt_layernorm is not None
            assert self.b_layernorm is not None
            assert self.c_layernorm is not None
            time_step = self.dt_layernorm(time_step.contiguous())
            B = self.b_layernorm(B.contiguous())
            C = self.c_layernorm(C.contiguous())
        discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
        return discrete_time_step, B, C

    def forward(self, hidden_states: torch.Tensor, output: torch.Tensor):
        torch.ops.vllm.mamba_mixer(
            hidden_states,
            output,
            self.prefix,
        )

    def forward_native(self, hidden_states: torch.Tensor, output: torch.Tensor):
        pass

    def forward_cuda(self, hidden_states: torch.Tensor, output: torch.Tensor):
        """
        Run the Mamba-1 SSM pipeline.

        Steps
        -----
        1. Apply the gated-MLP linear projection to the raw input.
        2. Pass the projected sequence through the convolutional mixing layer.
        3. Feed the result into the State-Space Model (SSM) blocks.
        4. Perform the recurrence y ← SSM(A, B, C, Δ)(x)
           to produce contextual representations.
        5. Project the contextualised sequence back
           to the output embedding dimension.

        Batch handling
        --------------
        Prefill and decode tokens are processed by dedicated CUDA
        kernels for both the convolutional (conv1d) and SSM stages.
        In the case of a mixed batch (containing both prefill and
        decode tokens), both sets of kernels are executed independently
        and their outputs are concatenated before the final output projection.
        """

        forward_context: ForwardContext = get_forward_context()
        attn_metadata = forward_context.attn_metadata

        assert self.cache_config is not None
        mamba_block_size = self.cache_config.mamba_block_size
        prefix_caching_enabled = self.cache_config.enable_prefix_caching

        if attn_metadata is not None:
            assert isinstance(attn_metadata, dict)
            attn_metadata = attn_metadata[self.prefix]
            assert isinstance(attn_metadata, Mamba1AttentionMetadata)
            query_start_loc_p = attn_metadata.query_start_loc_p
            state_indices_tensor = attn_metadata.state_indices_tensor
            self_kv_cache = self.kv_cache[forward_context.virtual_engine]
            conv_state = self_kv_cache[0].transpose(-1, -2)
            ssm_state = self_kv_cache[1]
            has_initial_states_p = attn_metadata.has_initial_states_p
            num_padded_decodes = attn_metadata.num_padded_decodes

        # 1. Gated MLP's linear projection
        projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
        hidden_states_BC, gate = projected_states.chunk(2, dim=-2)

        conv_weights = self.conv1d.weight.view(
            self.conv1d.weight.size(0), self.conv1d.weight.size(2)
        )

        if attn_metadata is None:
            # V1 profile run
            hidden_states_BC = hidden_states_BC.contiguous()
            return self.out_proj(hidden_states_BC.transpose(-2, -1))[0]

        num_prefill_tokens = attn_metadata.num_prefill_tokens  # token count
        num_decode_tokens = attn_metadata.num_decode_tokens
        num_prefills = attn_metadata.num_prefills  # request count
        num_decodes = attn_metadata.num_decode_tokens  # token count (=request)
        has_prefill = num_prefill_tokens > 0
        has_decode = num_decode_tokens > 0
        num_actual_tokens = num_prefill_tokens + num_decode_tokens

        prefill_decode_split = split_batch_to_prefill_and_decode(
            hidden_states_BC,
            gate,
            state_indices_tensor,
            num_prefill_tokens,
            num_prefills,
            num_padded_decodes,
        )
        hidden_states_BC_p = prefill_decode_split.hidden_states_BC_p
        hidden_states_BC_d = prefill_decode_split.hidden_states_BC_d
        gate_p = prefill_decode_split.gate_p
        gate_d = prefill_decode_split.gate_d
        state_indices_tensor_p = prefill_decode_split.state_indices_tensor_p
        state_indices_tensor_d = prefill_decode_split.state_indices_tensor_d

        if prefix_caching_enabled:
            block_idx_last_computed_token_d, block_idx_last_computed_token_p = (
                torch.split(
                    attn_metadata.block_idx_last_computed_token,
                    [num_decodes, num_prefills],
                    dim=0,
                )
            )
            block_idx_last_scheduled_token_d, block_idx_last_scheduled_token_p = (
                torch.split(
                    attn_metadata.block_idx_last_scheduled_token,
                    [num_decodes, num_prefills],
                    dim=0,
                )
            )

            block_idx_first_scheduled_token_p = (
                attn_metadata.block_idx_first_scheduled_token_p
            )
            num_computed_tokens_p = attn_metadata.num_computed_tokens_p
        else:
            block_idx_last_computed_token_d = None
            block_idx_last_computed_token_p = None
            block_idx_last_scheduled_token_d = None
            block_idx_last_scheduled_token_p = None
            block_idx_first_scheduled_token_p = None
            num_computed_tokens_p = None

        ssm_outputs = []

        if has_prefill:
            # 2. Convolution sequence transformation
            conv_out_p = causal_conv1d_fn(
                hidden_states_BC_p,
                conv_weights,
                self.conv1d.bias,
                activation=self.activation,
                conv_states=conv_state,
                has_initial_state=has_initial_states_p,
                cache_indices=state_indices_tensor_p,
                query_start_loc=query_start_loc_p,
                block_idx_first_scheduled_token=block_idx_first_scheduled_token_p,
                block_idx_last_scheduled_token=block_idx_last_scheduled_token_p,
                initial_state_idx=block_idx_last_computed_token_p,
                num_computed_tokens=num_computed_tokens_p,
                block_size_to_align=mamba_block_size,
            )
            # 3. State Space Model sequence transformations.
            discrete_time_step_p, B_p, C_p = self._ssm_transform(
                conv_out_p.transpose(-2, -1)
            )
            time_proj_bias = self._time_proj_bias()

            # 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x)
            scan_out_p = selective_scan_fn(
                conv_out_p,
                ssm_state,
                discrete_time_step_p,
                self.A,
                B_p.transpose(-2, -1),
                C_p.transpose(-2, -1),
                self.D.float(),
                gate_p,
                time_proj_bias,
                delta_softplus=True,
                cache_indices=state_indices_tensor_p,
                has_initial_state=has_initial_states_p,
                query_start_loc=query_start_loc_p,
                block_size=mamba_block_size,
                block_idx_first_scheduled_token=block_idx_first_scheduled_token_p,
                block_idx_last_scheduled_token=block_idx_last_scheduled_token_p,
                initial_state_idx=block_idx_last_computed_token_p,
            )
            ssm_outputs.append(scan_out_p)

        if has_decode:
            if prefix_caching_enabled:
                state_indices_tensor_d_input = state_indices_tensor_d.gather(
                    1, block_idx_last_computed_token_d.unsqueeze(1)
                ).squeeze(1)
                state_indices_tensor_d_output = state_indices_tensor_d.gather(
                    1, block_idx_last_scheduled_token_d.unsqueeze(1)
                ).squeeze(1)
            else:
                state_indices_tensor_d_input = state_indices_tensor_d
                state_indices_tensor_d_output = state_indices_tensor_d
            # 2. Convolution sequence transformation
            conv_out_d = causal_conv1d_update(
                hidden_states_BC_d.transpose(0, 1),
                conv_state,
                conv_weights,
                self.conv1d.bias,
                self.activation,
                conv_state_indices=state_indices_tensor_d,
                block_idx_last_scheduled_token=block_idx_last_scheduled_token_d,
                initial_state_idx=block_idx_last_computed_token_d,
            ).transpose(0, 1)

            # 3. State Space Model sequence transformation.
            discrete_time_step_d, B_d, C_d = self._ssm_transform(
                conv_out_d.transpose(-2, -1)
            )
            time_proj_bias = self._time_proj_bias()

            # 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x)
            scan_outputs_d = torch.empty_like(hidden_states_BC_d.transpose(0, 1))
            selective_state_update(
                ssm_state,
                conv_out_d.transpose(0, 1),
                discrete_time_step_d.transpose(0, 1),
                self.A,
                B_d,
                C_d,
                self.D,
                gate_d.transpose(0, 1),
                time_proj_bias,
                dt_softplus=True,
                state_batch_indices=state_indices_tensor_d_input,
                dst_state_batch_indices=state_indices_tensor_d_output,
                out=scan_outputs_d,
            )
            scan_outputs_d = scan_outputs_d.transpose(0, 1)

            ssm_outputs.insert(0, scan_outputs_d)

        scan_outputs_combined = (
            ssm_outputs[0] if len(ssm_outputs) == 1 else torch.cat(ssm_outputs, dim=-1)
        )

        # 5. Final output projection
        if self.is_lora_enabled:  # Lora kernel requires contiguous tensor.
            scan_outputs_combined = scan_outputs_combined.transpose(-2, -1).contiguous()
            out = self.out_proj(scan_outputs_combined)[0]
        else:
            out = self.out_proj(scan_outputs_combined.transpose(-2, -1))[0]

        output[:num_actual_tokens] = out

    def get_state_dtype(self) -> tuple[torch.dtype]:
        assert self.model_config is not None
        assert self.cache_config is not None
        return MambaStateDtypeCalculator.mamba1_state_dtype(
            self.model_config.dtype,
            self.cache_config.mamba_cache_dtype,
            self.cache_config.mamba_ssm_cache_dtype,
        )

    def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
        return MambaStateShapeCalculator.mamba1_state_shape(
            tp_world_size=get_tensor_model_parallel_world_size(),
            intermediate_size=self.intermediate_size,
            state_size=self.ssm_state_size,
            conv_kernel=self.conv_kernel_size,
        )

    @property
    def mamba_type(self) -> str:
        return "mamba1"

    def get_attn_backend(self) -> type["AttentionBackend"]:
        from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend

        return Mamba1AttentionBackend

    def _time_proj_bias(self) -> torch.Tensor | None:
        if hasattr(self.dt_proj, "bias") and self.dt_proj.bias is not None:
            return self.dt_proj.bias.float()
        return None

A instance-attribute

A = Parameter(
    empty(
        intermediate_size // tp_size,
        ssm_state_size,
        dtype=float32,
    )
)

D instance-attribute

D = Parameter(ones(intermediate_size // tp_size))

activation instance-attribute

activation = activation

b_layernorm instance-attribute

b_layernorm = (
    RMSNorm(
        ssm_state_size,
        eps=rms_norm_eps,
        has_weight=rms_norm_has_weight,
    )
    if use_rms_norm
    else None
)

c_layernorm instance-attribute

c_layernorm = (
    RMSNorm(
        ssm_state_size,
        eps=rms_norm_eps,
        has_weight=rms_norm_has_weight,
    )
    if use_rms_norm
    else None
)

cache_config instance-attribute

cache_config = cache_config

conv1d instance-attribute

conv1d = ColumnParallelLinear(
    input_size=conv_kernel_size,
    output_size=intermediate_size,
    bias=use_conv_bias,
)

conv_kernel_size instance-attribute

conv_kernel_size = conv_kernel_size

dt_layernorm instance-attribute

dt_layernorm = (
    RMSNorm(
        time_step_rank,
        eps=rms_norm_eps,
        has_weight=rms_norm_has_weight,
    )
    if use_rms_norm
    else None
)

dt_proj instance-attribute

dt_proj = ColumnParallelLinear(
    time_step_rank,
    intermediate_size,
    bias=True,
    skip_bias_add=True,
)

in_proj instance-attribute

in_proj = MergedColumnParallelLinear(
    hidden_size, [intermediate_size] * 2, bias=use_bias
)

intermediate_size instance-attribute

intermediate_size = intermediate_size

is_lora_enabled instance-attribute

is_lora_enabled = is_lora_enabled

kv_cache instance-attribute

kv_cache = (tensor([]), tensor([]))

mamba_type property

mamba_type: str

model_config instance-attribute

model_config = model_config

out_proj instance-attribute

out_proj = RowParallelLinear(
    intermediate_size,
    hidden_size,
    bias=use_bias,
    input_is_parallel=True,
)

prefix instance-attribute

prefix = prefix

ssm_state_size instance-attribute

ssm_state_size = ssm_state_size

time_step_rank instance-attribute

time_step_rank = time_step_rank

use_rms_norm instance-attribute

use_rms_norm = use_rms_norm

x_proj instance-attribute

x_proj = RowParallelLinear(
    intermediate_size,
    time_step_rank + ssm_state_size * 2,
    bias=False,
)

__init__

__init__(
    hidden_size: int,
    ssm_state_size: int,
    conv_kernel_size: int,
    intermediate_size: int,
    time_step_rank: int,
    use_conv_bias: bool,
    use_bias: bool,
    use_rms_norm: bool,
    rms_norm_has_weight: bool = True,
    rms_norm_eps: float = 1e-05,
    activation="silu",
    is_lora_enabled: bool = False,
    model_config: ModelConfig | None = None,
    cache_config: CacheConfig | None = None,
    prefix: str = "",
)
Source code in vllm/model_executor/layers/mamba/mamba_mixer.py
def __init__(
    self,
    hidden_size: int,
    ssm_state_size: int,
    conv_kernel_size: int,
    intermediate_size: int,
    time_step_rank: int,
    use_conv_bias: bool,
    use_bias: bool,
    use_rms_norm: bool,
    rms_norm_has_weight: bool = True,
    rms_norm_eps: float = 1e-5,
    activation="silu",
    is_lora_enabled: bool = False,
    model_config: ModelConfig | None = None,
    cache_config: CacheConfig | None = None,
    prefix: str = "",
):
    super().__init__()
    self.time_step_rank = time_step_rank
    self.ssm_state_size = ssm_state_size
    self.use_rms_norm = use_rms_norm
    self.activation = activation
    self.is_lora_enabled = is_lora_enabled
    self.conv_kernel_size = conv_kernel_size
    self.intermediate_size = intermediate_size

    self.conv1d = ColumnParallelLinear(
        input_size=conv_kernel_size,
        output_size=intermediate_size,
        bias=use_conv_bias,
    )
    # unsqueeze to fit conv1d weights shape into the linear weights shape.
    # Can't do this in `weight_loader` since it already exists in
    # `ColumnParallelLinear` and `set_weight_attrs`
    # doesn't allow to override it
    self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)

    self.in_proj = MergedColumnParallelLinear(
        hidden_size, [intermediate_size] * 2, bias=use_bias
    )

    # selective projection used to make dt, B and C input dependent
    self.x_proj = RowParallelLinear(
        intermediate_size,
        time_step_rank + ssm_state_size * 2,
        bias=False,
    )
    # time step projection (discretization) -
    # In the forward we need to apply dt_proj without the bias,
    # as the bias is added in the selective scan kernel.
    self.dt_proj = ColumnParallelLinear(
        time_step_rank, intermediate_size, bias=True, skip_bias_add=True
    )

    def weight_loader(param: Parameter, loaded_weight: torch.Tensor):
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        param.data.copy_(
            loaded_weight.data.split(loaded_weight.shape[0] // tp_size, dim=0)[
                tp_rank
            ]
        )

    def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor):
        weight_loader(param, -torch.exp(loaded_weight.float()))

    tp_size = get_tensor_model_parallel_world_size()
    self.A = nn.Parameter(
        torch.empty(
            intermediate_size // tp_size,
            ssm_state_size,
            dtype=torch.float32,
        )
    )
    self.D = nn.Parameter(torch.ones(intermediate_size // tp_size))

    set_weight_attrs(self.D, {"weight_loader": weight_loader})
    set_weight_attrs(self.A, {"weight_loader": A_weight_loader})

    self.out_proj = RowParallelLinear(
        intermediate_size,
        hidden_size,
        bias=use_bias,
        input_is_parallel=True,
    )

    self.dt_layernorm = (
        RMSNorm(
            time_step_rank,
            eps=rms_norm_eps,
            has_weight=rms_norm_has_weight,
        )
        if use_rms_norm
        else None
    )

    self.b_layernorm = (
        RMSNorm(
            ssm_state_size,
            eps=rms_norm_eps,
            has_weight=rms_norm_has_weight,
        )
        if use_rms_norm
        else None
    )

    self.c_layernorm = (
        RMSNorm(
            ssm_state_size,
            eps=rms_norm_eps,
            has_weight=rms_norm_has_weight,
        )
        if use_rms_norm
        else None
    )

    compilation_config = get_current_vllm_config().compilation_config
    if prefix in compilation_config.static_forward_context:
        raise ValueError(f"Duplicate layer name: {prefix}")
    compilation_config.static_forward_context[prefix] = self
    # The inner tuple is (conv_state, ssm_state)
    self.kv_cache = (torch.tensor([]), torch.tensor([]))

    self.model_config = model_config
    self.cache_config = cache_config
    self.prefix = prefix

_ssm_transform

_ssm_transform(x: Tensor) -> tuple[Tensor, Tensor, Tensor]
Source code in vllm/model_executor/layers/mamba/mamba_mixer.py
def _ssm_transform(
    self, x: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    if self.is_lora_enabled:
        #  Lora kernel requires contiguous tensor.
        ssm_params = self.x_proj(x.contiguous())[0]
    else:
        ssm_params = self.x_proj(x)[0]
    time_step, B, C = torch.split(
        ssm_params,
        [self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
        dim=-1,
    )
    if self.use_rms_norm:
        assert self.dt_layernorm is not None
        assert self.b_layernorm is not None
        assert self.c_layernorm is not None
        time_step = self.dt_layernorm(time_step.contiguous())
        B = self.b_layernorm(B.contiguous())
        C = self.c_layernorm(C.contiguous())
    discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
    return discrete_time_step, B, C
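For a sense of the tensor shapes involved, here is an illustrative sketch with made-up sizes (the real tensors come from x_proj and dt_proj, as shown above):

import torch

# Hypothetical sizes for illustration only.
seqlen, intermediate_size, time_step_rank, ssm_state_size = 8, 64, 4, 16

# x_proj produces [dt | B | C] per position, which torch.split separates.
ssm_params = torch.randn(seqlen, time_step_rank + 2 * ssm_state_size)
time_step, B, C = torch.split(
    ssm_params, [time_step_rank, ssm_state_size, ssm_state_size], dim=-1
)
# time_step: (8, 4), B: (8, 16), C: (8, 16)

# dt_proj then expands the low-rank dt to the intermediate size, and the
# result is transposed to (intermediate_size, seqlen) for the scan kernels.
discrete_time_step = (
    time_step @ torch.randn(time_step_rank, intermediate_size)
).transpose(-2, -1)
# discrete_time_step: (64, 8)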

_time_proj_bias

_time_proj_bias() -> Tensor | None
Source code in vllm/model_executor/layers/mamba/mamba_mixer.py
def _time_proj_bias(self) -> torch.Tensor | None:
    if hasattr(self.dt_proj, "bias") and self.dt_proj.bias is not None:
        return self.dt_proj.bias.float()
    return None

forward

forward(hidden_states: Tensor, output: Tensor)
Source code in vllm/model_executor/layers/mamba/mamba_mixer.py
def forward(self, hidden_states: torch.Tensor, output: torch.Tensor):
    torch.ops.vllm.mamba_mixer(
        hidden_states,
        output,
        self.prefix,
    )

forward_cuda

forward_cuda(hidden_states: Tensor, output: Tensor)

Run the Mamba-1 SSM pipeline.

Steps
  1. Apply the gated-MLP linear projection to the raw input.
  2. Pass the projected sequence through the convolutional mixing layer.
  3. Feed the result into the State-Space Model (SSM) blocks.
  4. Perform the recurrence y ← SSM(A, B, C, Δ)(x) to produce contextual representations.
  5. Project the contextualised sequence back to the output embedding dimension.
Batch handling

Prefill and decode tokens are processed by dedicated CUDA kernels for both the convolutional (conv1d) and SSM stages. In the case of a mixed batch (containing both prefill and decode tokens), both sets of kernels are executed independently and their outputs are concatenated before the final output projection.
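Because decode tokens come first in the flattened batch, the decode outputs are placed ahead of the prefill outputs when the two results are joined. An illustrative sketch with made-up shapes:

import torch

# Hypothetical mixed batch: 3 decode tokens and 5 prefill tokens, 4 channels.
scan_outputs_d = torch.zeros(4, 3)   # decode-kernel output
scan_out_p = torch.ones(4, 5)        # prefill-kernel output

# Decode outputs are concatenated first, matching the decode-first token
# layout, before the combined tensor is fed to out_proj.
scan_outputs_combined = torch.cat([scan_outputs_d, scan_out_p], dim=-1)  # (4, 8)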

Source code in vllm/model_executor/layers/mamba/mamba_mixer.py
def forward_cuda(self, hidden_states: torch.Tensor, output: torch.Tensor):
    """
    Run the Mamba-1 SSM pipeline.

    Steps
    -----
    1. Apply the gated-MLP linear projection to the raw input.
    2. Pass the projected sequence through the convolutional mixing layer.
    3. Feed the result into the State-Space Model (SSM) blocks.
    4. Perform the recurrence y ← SSM(A, B, C, Δ)(x)
       to produce contextual representations.
    5. Project the contextualised sequence back
       to the output embedding dimension.

    Batch handling
    --------------
    Prefill and decode tokens are processed by dedicated CUDA
    kernels for both the convolutional (conv1d) and SSM stages.
    In the case of a mixed batch (containing both prefill and
    decode tokens), both sets of kernels are executed independently
    and their outputs are concatenated before the final output projection.
    """

    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata

    assert self.cache_config is not None
    mamba_block_size = self.cache_config.mamba_block_size
    prefix_caching_enabled = self.cache_config.enable_prefix_caching

    if attn_metadata is not None:
        assert isinstance(attn_metadata, dict)
        attn_metadata = attn_metadata[self.prefix]
        assert isinstance(attn_metadata, Mamba1AttentionMetadata)
        query_start_loc_p = attn_metadata.query_start_loc_p
        state_indices_tensor = attn_metadata.state_indices_tensor
        self_kv_cache = self.kv_cache[forward_context.virtual_engine]
        conv_state = self_kv_cache[0].transpose(-1, -2)
        ssm_state = self_kv_cache[1]
        has_initial_states_p = attn_metadata.has_initial_states_p
        num_padded_decodes = attn_metadata.num_padded_decodes

    # 1. Gated MLP's linear projection
    projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
    hidden_states_BC, gate = projected_states.chunk(2, dim=-2)

    conv_weights = self.conv1d.weight.view(
        self.conv1d.weight.size(0), self.conv1d.weight.size(2)
    )

    if attn_metadata is None:
        # V1 profile run
        hidden_states_BC = hidden_states_BC.contiguous()
        return self.out_proj(hidden_states_BC.transpose(-2, -1))[0]

    num_prefill_tokens = attn_metadata.num_prefill_tokens  # token count
    num_decode_tokens = attn_metadata.num_decode_tokens
    num_prefills = attn_metadata.num_prefills  # request count
    num_decodes = attn_metadata.num_decode_tokens  # token count (=request)
    has_prefill = num_prefill_tokens > 0
    has_decode = num_decode_tokens > 0
    num_actual_tokens = num_prefill_tokens + num_decode_tokens

    prefill_decode_split = split_batch_to_prefill_and_decode(
        hidden_states_BC,
        gate,
        state_indices_tensor,
        num_prefill_tokens,
        num_prefills,
        num_padded_decodes,
    )
    hidden_states_BC_p = prefill_decode_split.hidden_states_BC_p
    hidden_states_BC_d = prefill_decode_split.hidden_states_BC_d
    gate_p = prefill_decode_split.gate_p
    gate_d = prefill_decode_split.gate_d
    state_indices_tensor_p = prefill_decode_split.state_indices_tensor_p
    state_indices_tensor_d = prefill_decode_split.state_indices_tensor_d

    if prefix_caching_enabled:
        block_idx_last_computed_token_d, block_idx_last_computed_token_p = (
            torch.split(
                attn_metadata.block_idx_last_computed_token,
                [num_decodes, num_prefills],
                dim=0,
            )
        )
        block_idx_last_scheduled_token_d, block_idx_last_scheduled_token_p = (
            torch.split(
                attn_metadata.block_idx_last_scheduled_token,
                [num_decodes, num_prefills],
                dim=0,
            )
        )

        block_idx_first_scheduled_token_p = (
            attn_metadata.block_idx_first_scheduled_token_p
        )
        num_computed_tokens_p = attn_metadata.num_computed_tokens_p
    else:
        block_idx_last_computed_token_d = None
        block_idx_last_computed_token_p = None
        block_idx_last_scheduled_token_d = None
        block_idx_last_scheduled_token_p = None
        block_idx_first_scheduled_token_p = None
        num_computed_tokens_p = None

    ssm_outputs = []

    if has_prefill:
        # 2. Convolution sequence transformation
        conv_out_p = causal_conv1d_fn(
            hidden_states_BC_p,
            conv_weights,
            self.conv1d.bias,
            activation=self.activation,
            conv_states=conv_state,
            has_initial_state=has_initial_states_p,
            cache_indices=state_indices_tensor_p,
            query_start_loc=query_start_loc_p,
            block_idx_first_scheduled_token=block_idx_first_scheduled_token_p,
            block_idx_last_scheduled_token=block_idx_last_scheduled_token_p,
            initial_state_idx=block_idx_last_computed_token_p,
            num_computed_tokens=num_computed_tokens_p,
            block_size_to_align=mamba_block_size,
        )
        # 3. State Space Model sequence transformations.
        discrete_time_step_p, B_p, C_p = self._ssm_transform(
            conv_out_p.transpose(-2, -1)
        )
        time_proj_bias = self._time_proj_bias()

        # 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x)
        scan_out_p = selective_scan_fn(
            conv_out_p,
            ssm_state,
            discrete_time_step_p,
            self.A,
            B_p.transpose(-2, -1),
            C_p.transpose(-2, -1),
            self.D.float(),
            gate_p,
            time_proj_bias,
            delta_softplus=True,
            cache_indices=state_indices_tensor_p,
            has_initial_state=has_initial_states_p,
            query_start_loc=query_start_loc_p,
            block_size=mamba_block_size,
            block_idx_first_scheduled_token=block_idx_first_scheduled_token_p,
            block_idx_last_scheduled_token=block_idx_last_scheduled_token_p,
            initial_state_idx=block_idx_last_computed_token_p,
        )
        ssm_outputs.append(scan_out_p)

    if has_decode:
        if prefix_caching_enabled:
            state_indices_tensor_d_input = state_indices_tensor_d.gather(
                1, block_idx_last_computed_token_d.unsqueeze(1)
            ).squeeze(1)
            state_indices_tensor_d_output = state_indices_tensor_d.gather(
                1, block_idx_last_scheduled_token_d.unsqueeze(1)
            ).squeeze(1)
        else:
            state_indices_tensor_d_input = state_indices_tensor_d
            state_indices_tensor_d_output = state_indices_tensor_d
        # 2. Convolution sequence transformation
        conv_out_d = causal_conv1d_update(
            hidden_states_BC_d.transpose(0, 1),
            conv_state,
            conv_weights,
            self.conv1d.bias,
            self.activation,
            conv_state_indices=state_indices_tensor_d,
            block_idx_last_scheduled_token=block_idx_last_scheduled_token_d,
            initial_state_idx=block_idx_last_computed_token_d,
        ).transpose(0, 1)

        # 3. State Space Model sequence transformation.
        discrete_time_step_d, B_d, C_d = self._ssm_transform(
            conv_out_d.transpose(-2, -1)
        )
        time_proj_bias = self._time_proj_bias()

        # 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x)
        scan_outputs_d = torch.empty_like(hidden_states_BC_d.transpose(0, 1))
        selective_state_update(
            ssm_state,
            conv_out_d.transpose(0, 1),
            discrete_time_step_d.transpose(0, 1),
            self.A,
            B_d,
            C_d,
            self.D,
            gate_d.transpose(0, 1),
            time_proj_bias,
            dt_softplus=True,
            state_batch_indices=state_indices_tensor_d_input,
            dst_state_batch_indices=state_indices_tensor_d_output,
            out=scan_outputs_d,
        )
        scan_outputs_d = scan_outputs_d.transpose(0, 1)

        ssm_outputs.insert(0, scan_outputs_d)

    scan_outputs_combined = (
        ssm_outputs[0] if len(ssm_outputs) == 1 else torch.cat(ssm_outputs, dim=-1)
    )

    # 5. Final output projection
    if self.is_lora_enabled:  # Lora kernel requires contiguous tensor.
        scan_outputs_combined = scan_outputs_combined.transpose(-2, -1).contiguous()
        out = self.out_proj(scan_outputs_combined)[0]
    else:
        out = self.out_proj(scan_outputs_combined.transpose(-2, -1))[0]

    output[:num_actual_tokens] = out

forward_native

forward_native(hidden_states: Tensor, output: Tensor)
Source code in vllm/model_executor/layers/mamba/mamba_mixer.py
def forward_native(self, hidden_states: torch.Tensor, output: torch.Tensor):
    pass

get_attn_backend

get_attn_backend() -> type[AttentionBackend]
Source code in vllm/model_executor/layers/mamba/mamba_mixer.py
def get_attn_backend(self) -> type["AttentionBackend"]:
    from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend

    return Mamba1AttentionBackend

get_state_dtype

get_state_dtype() -> tuple[dtype]
Source code in vllm/model_executor/layers/mamba/mamba_mixer.py
def get_state_dtype(self) -> tuple[torch.dtype]:
    assert self.model_config is not None
    assert self.cache_config is not None
    return MambaStateDtypeCalculator.mamba1_state_dtype(
        self.model_config.dtype,
        self.cache_config.mamba_cache_dtype,
        self.cache_config.mamba_ssm_cache_dtype,
    )

get_state_shape

get_state_shape() -> tuple[
    tuple[int, ...], tuple[int, ...]
]
Source code in vllm/model_executor/layers/mamba/mamba_mixer.py
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
    return MambaStateShapeCalculator.mamba1_state_shape(
        tp_world_size=get_tensor_model_parallel_world_size(),
        intermediate_size=self.intermediate_size,
        state_size=self.ssm_state_size,
        conv_kernel=self.conv_kernel_size,
    )

PrefillDecodeSplit

Bases: NamedTuple

Source code in vllm/model_executor/layers/mamba/mamba_mixer.py
class PrefillDecodeSplit(NamedTuple):
    hidden_states_BC_p: torch.Tensor
    hidden_states_BC_d: torch.Tensor
    gate_p: torch.Tensor
    gate_d: torch.Tensor
    state_indices_tensor_p: torch.Tensor
    state_indices_tensor_d: torch.Tensor

gate_d instance-attribute

gate_d: Tensor

gate_p instance-attribute

gate_p: Tensor

hidden_states_BC_d instance-attribute

hidden_states_BC_d: Tensor

hidden_states_BC_p instance-attribute

hidden_states_BC_p: Tensor

state_indices_tensor_d instance-attribute

state_indices_tensor_d: Tensor

state_indices_tensor_p instance-attribute

state_indices_tensor_p: Tensor

mamba_mixer

mamba_mixer(
    hidden_states: Tensor, output: Tensor, layer_name: str
) -> None
Source code in vllm/model_executor/layers/mamba/mamba_mixer.py
def mamba_mixer(
    hidden_states: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
) -> None:
    forward_context: ForwardContext = get_forward_context()
    self = forward_context.no_compile_layers[layer_name]
    self.forward_cuda(hidden_states=hidden_states, output=output)

mamba_mixer_fake

mamba_mixer_fake(
    hidden_states: Tensor, output: Tensor, layer_name: str
) -> None
Source code in vllm/model_executor/layers/mamba/mamba_mixer.py
def mamba_mixer_fake(
    hidden_states: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
) -> None:
    return

split_batch_to_prefill_and_decode

split_batch_to_prefill_and_decode(
    hidden_states_BC: Tensor,
    gate: Tensor,
    state_indices_tensor: Tensor,
    num_prefill_tokens: int,
    num_prefills: int,
    num_padded_decodes: int,
) -> PrefillDecodeSplit
Source code in vllm/model_executor/layers/mamba/mamba_mixer.py
def split_batch_to_prefill_and_decode(
    hidden_states_BC: torch.Tensor,
    gate: torch.Tensor,
    state_indices_tensor: torch.Tensor,
    num_prefill_tokens: int,
    num_prefills: int,
    num_padded_decodes: int,
) -> PrefillDecodeSplit:
    num_actual_tokens = num_prefill_tokens + num_padded_decodes

    # In v1, decode tokens come first, then prefill tokens.
    hidden_states_BC_d, hidden_states_BC_p = torch.split(
        hidden_states_BC[..., :num_actual_tokens],
        [num_padded_decodes, num_prefill_tokens],
        dim=-1,
    )
    gate_d, gate_p = torch.split(
        gate[..., :num_actual_tokens], [num_padded_decodes, num_prefill_tokens], dim=-1
    )

    # num_padded_decodes accounts for CUDA graph padding when applicable
    state_indices_tensor_d, state_indices_tensor_p = torch.split(
        state_indices_tensor[: num_padded_decodes + num_prefills],
        [num_padded_decodes, num_prefills],
        dim=0,
    )

    return PrefillDecodeSplit(
        hidden_states_BC_p=hidden_states_BC_p,
        hidden_states_BC_d=hidden_states_BC_d,
        gate_p=gate_p,
        gate_d=gate_d,
        state_indices_tensor_p=state_indices_tensor_p,
        state_indices_tensor_d=state_indices_tensor_d,
    )
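A minimal usage sketch with made-up shapes (not taken from the source), assuming the function is imported from vllm.model_executor.layers.mamba.mamba_mixer:

import torch
from vllm.model_executor.layers.mamba.mamba_mixer import (
    split_batch_to_prefill_and_decode,
)

# Hypothetical mixed batch: 2 decode requests (one token each) followed by one
# prefill request of 5 tokens, with 4 channels and some trailing padding.
hidden_states_BC = torch.randn(4, 16)
gate = torch.randn(4, 16)
state_indices_tensor = torch.arange(8)

split = split_batch_to_prefill_and_decode(
    hidden_states_BC,
    gate,
    state_indices_tensor,
    num_prefill_tokens=5,
    num_prefills=1,
    num_padded_decodes=2,
)
# split.hidden_states_BC_d: (4, 2), split.hidden_states_BC_p: (4, 5)
# split.state_indices_tensor_d: (2,), split.state_indices_tensor_p: (1,)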