vllm.v1.spec_decode.suffix_decoding

SuffixDecodingProposer

Speculative decoding proposer for Suffix Decoding (https://arxiv.org/pdf/2411.04975). This class imports and uses the official implementation from Arctic Inference (https://github.com/snowflakedb/ArcticInference).
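As a usage sketch (not part of this module's source), Suffix Decoding is configured through vLLM's speculative_config. The num_speculative_tokens and suffix_decoding_* keys below mirror the fields read in __init__; the "suffix" method name and all values are illustrative assumptions and may differ in your vLLM version.

from vllm import LLM

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # any served model; illustrative
    speculative_config={
        "method": "suffix",  # assumed registration name for Suffix Decoding
        "num_speculative_tokens": 8,
        "suffix_decoding_max_tree_depth": 24,
        "suffix_decoding_max_spec_factor": 1.0,
        "suffix_decoding_min_token_prob": 0.1,
        "suffix_decoding_max_cached_requests": 1000,
    },
)
outputs = llm.generate(["Suffix decoding reuses previously generated text ..."])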

Source code in vllm/v1/spec_decode/suffix_decoding.py
class SuffixDecodingProposer:
    """
    Speculative decoding proposer for Suffix Decoding (https://arxiv.org/pdf/2411.04975).
    This class imports and uses the official implementation from Arctic Inference
    (https://github.com/snowflakedb/ArcticInference).
    """

    def __init__(self, vllm_config: VllmConfig):
        config = vllm_config.speculative_config
        self.num_speculative_tokens = config.num_speculative_tokens
        self.max_tree_depth = config.suffix_decoding_max_tree_depth
        self.max_spec_factor = config.suffix_decoding_max_spec_factor
        self.min_token_prob = config.suffix_decoding_min_token_prob
        self.max_model_len = vllm_config.model_config.max_model_len

        # Lazy import to avoid error when Suffix Decoding is not used.
        from arctic_inference.suffix_decoding import SuffixDecodingCache

        # Initialize an empty cache. This object takes care of caching request
        # outputs, evicting old requests, and managing the per-prompt suffix trees.
        self.suffix_cache = SuffixDecodingCache(
            max_tree_depth=config.suffix_decoding_max_tree_depth,
            max_cached_requests=config.suffix_decoding_max_cached_requests,
        )

    def propose(
        self,
        input_batch: InputBatch,
        sampled_token_ids: list[list[int]],
    ) -> list[list[int]]:
        """
        Propose speculative tokens for each request in the input batch. Suffix Decoding
        speculates a dynamic number of tokens for each request at every decoding step,
        so each entry in the returned list may have a different length.
        """
        draft_token_ids: list[list[int]] = []
        for i, sampled_ids in enumerate(sampled_token_ids):
            if not sampled_ids:
                # Skip speculative decoding for partial prefills.
                draft_token_ids.append([])
                continue

            # Skip requests that require sampling parameters that are not
            # supported with speculative decoding.
            req_id = input_batch.req_ids[i]
            if req_id in input_batch.spec_decode_unsupported_reqs:
                draft_token_ids.append([])
                continue

            num_tokens = input_batch.num_tokens_no_spec[i]
            if num_tokens >= self.max_model_len:
                # Skip requests that have already reached the max model length.
                draft_token_ids.append([])
                continue

            index = input_batch.req_id_to_index[req_id]
            if req_id not in self.suffix_cache.active_requests:
                if req_id in self.suffix_cache.cached_requests:
                    # Reset the suffix cache for this request.
                    self.suffix_cache.evict_cached_response(req_id)
                num_prompt_tokens = input_batch.num_prompt_tokens[index]
                prompt_token_ids = input_batch.token_ids_cpu[index, :num_prompt_tokens]
                # Start a new request; this builds the suffix tree for its prompt.
                self.suffix_cache.start_request(req_id, prompt_token_ids)

            # Append the newly sampled ids to the suffix cache for this request.
            self.suffix_cache.add_active_response(req_id, sampled_ids)

            # Suffix decoding only uses the most recent max_tree_depth tokens, so we
            # extract the pattern from the end of the input.
            start = max(0, num_tokens - self.max_tree_depth)
            pattern = input_batch.token_ids_cpu[i, start:num_tokens]
            draft = self.suffix_cache.speculate(
                req_id,
                pattern,
                max_spec_tokens=min(
                    self.num_speculative_tokens, self.max_model_len - num_tokens - 1
                ),
                max_spec_factor=self.max_spec_factor,
                min_token_prob=self.min_token_prob,
            )

            draft_token_ids.append(draft.token_ids)

        # Stop active requests that no longer appear in the input batch.
        for req_id in (
            self.suffix_cache.active_requests - input_batch.req_id_to_index.keys()
        ):
            self.suffix_cache.stop_request(req_id)

        return draft_token_ids

    def load_model(self, *args, **kwargs):
        # No model to load.
        pass
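
The per-request lifecycle that propose drives on the Arctic Inference cache can also be sketched in isolation. This is a minimal sketch using only the calls exercised above (start_request, add_active_response, speculate, stop_request); the request id, token ids, and parameter values are made up for illustration, and token ids are shown as plain Python lists even though the proposer itself passes NumPy slices.

from arctic_inference.suffix_decoding import SuffixDecodingCache

cache = SuffixDecodingCache(max_tree_depth=24, max_cached_requests=1000)

# First time the request is seen: build the suffix tree over its prompt.
cache.start_request("req-0", [1, 5, 9, 5, 9])

# After each decoding step: append the newly sampled token ids.
cache.add_active_response("req-0", [5, 9])

# Speculate continuations of the most recent tokens (the "pattern").
draft = cache.speculate(
    "req-0",
    [9, 5, 9, 5, 9],
    max_spec_tokens=8,
    max_spec_factor=1.0,
    min_token_prob=0.1,
)
print(draft.token_ids)  # variable-length list of draft token ids

# Once the request leaves the batch: stop tracking it as active.
cache.stop_request("req-0")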

max_model_len instance-attribute

max_model_len = max_model_len

max_spec_factor instance-attribute

max_spec_factor = suffix_decoding_max_spec_factor

max_tree_depth instance-attribute

max_tree_depth = suffix_decoding_max_tree_depth

min_token_prob instance-attribute

min_token_prob = suffix_decoding_min_token_prob

num_speculative_tokens instance-attribute

num_speculative_tokens = num_speculative_tokens

suffix_cache instance-attribute

suffix_cache = SuffixDecodingCache(
    max_tree_depth=suffix_decoding_max_tree_depth,
    max_cached_requests=suffix_decoding_max_cached_requests,
)

__init__

__init__(vllm_config: VllmConfig)
Source code in vllm/v1/spec_decode/suffix_decoding.py
def __init__(self, vllm_config: VllmConfig):
    config = vllm_config.speculative_config
    self.num_speculative_tokens = config.num_speculative_tokens
    self.max_tree_depth = config.suffix_decoding_max_tree_depth
    self.max_spec_factor = config.suffix_decoding_max_spec_factor
    self.min_token_prob = config.suffix_decoding_min_token_prob
    self.max_model_len = vllm_config.model_config.max_model_len

    # Lazy import to avoid error when Suffix Decoding is not used.
    from arctic_inference.suffix_decoding import SuffixDecodingCache

    # Initialize an empty cache. This object takes care of caching request
    # outputs, evicting old requests, and managing the per-prompt suffix trees.
    self.suffix_cache = SuffixDecodingCache(
        max_tree_depth=config.suffix_decoding_max_tree_depth,
        max_cached_requests=config.suffix_decoding_max_cached_requests,
    )

load_model

load_model(*args, **kwargs)
Source code in vllm/v1/spec_decode/suffix_decoding.py
def load_model(self, *args, **kwargs):
    # No model to load.
    pass

propose

propose(
    input_batch: InputBatch,
    sampled_token_ids: list[list[int]],
) -> list[list[int]]

Propose speculative tokens for each request in the input batch. Suffix Decoding speculates a dynamic number of tokens for each request at every decoding step, so each entry in the returned list may have a different length.
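
For example, a hypothetical batch of three requests could yield a ragged result like the one below (the proposer, input_batch, sampled_token_ids, and token ids are placeholders):

drafts = proposer.propose(input_batch, sampled_token_ids)
assert len(drafts) == len(sampled_token_ids)
# e.g. drafts == [[42, 7, 7, 13], [], [91]]
#   request 0: four draft tokens found in its suffix trees
#   request 1: skipped (partial prefill, unsupported sampling params,
#              or max_model_len already reached), hence an empty list
#   request 2: a single draft token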

Source code in vllm/v1/spec_decode/suffix_decoding.py
def propose(
    self,
    input_batch: InputBatch,
    sampled_token_ids: list[list[int]],
) -> list[list[int]]:
    """
    Propose speculative tokens for each request in the input batch. Suffix Decoding
    speculates a dynamic number of tokens for each request at every decoding step,
    so each entry in the returned list may have a different length.
    """
    draft_token_ids: list[list[int]] = []
    for i, sampled_ids in enumerate(sampled_token_ids):
        if not sampled_ids:
            # Skip speculative decoding for partial prefills.
            draft_token_ids.append([])
            continue

        # Skip requests that require sampling parameters that are not
        # supported with speculative decoding.
        req_id = input_batch.req_ids[i]
        if req_id in input_batch.spec_decode_unsupported_reqs:
            draft_token_ids.append([])
            continue

        num_tokens = input_batch.num_tokens_no_spec[i]
        if num_tokens >= self.max_model_len:
            # Skip requests that have already reached the max model length.
            draft_token_ids.append([])
            continue

        index = input_batch.req_id_to_index[req_id]
        if req_id not in self.suffix_cache.active_requests:
            if req_id in self.suffix_cache.cached_requests:
                # Reset the suffix cache for this request.
                self.suffix_cache.evict_cached_response(req_id)
            num_prompt_tokens = input_batch.num_prompt_tokens[index]
            prompt_token_ids = input_batch.token_ids_cpu[index, :num_prompt_tokens]
            # Start a new request; this builds the suffix tree for its prompt.
            self.suffix_cache.start_request(req_id, prompt_token_ids)

        # Append the newly sampled ids to the suffix cache for this request.
        self.suffix_cache.add_active_response(req_id, sampled_ids)

        # Suffix decoding only uses the most recent max_tree_depth tokens, so we
        # extract the pattern from the end of the input.
        start = max(0, num_tokens - self.max_tree_depth)
        pattern = input_batch.token_ids_cpu[i, start:num_tokens]
        draft = self.suffix_cache.speculate(
            req_id,
            pattern,
            max_spec_tokens=min(
                self.num_speculative_tokens, self.max_model_len - num_tokens - 1
            ),
            max_spec_factor=self.max_spec_factor,
            min_token_prob=self.min_token_prob,
        )

        draft_token_ids.append(draft.token_ids)

    # Stop active requests that no longer appear in the input batch.
    for req_id in (
        self.suffix_cache.active_requests - input_batch.req_id_to_index.keys()
    ):
        self.suffix_cache.stop_request(req_id)

    return draft_token_ids