import copy
import importlib.metadata
import json
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from packaging import version


# from .configuration_utils import PretrainedConfig
# from .utils import (
#     is_hqq_available,
#     is_optimum_quanto_available,
#     is_torchdynamo_compiling,
#     logging,
# )
# from .utils.deprecation import deprecate_kwarg


# if is_hqq_available():
#     from hqq.core.quantize import Quantizer as HQQQuantizer

# logger = logging.get_logger(__name__)


class Cache(torch.nn.Module):
    """
    Base, abstract class for all caches. The actual data structure is specific to each subclass.
    """

    def __init__(self):
        super().__init__()

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
                cache to be created.

        Return:
            A tuple containing the updated key and value states.
        """
        raise NotImplementedError("Make sure to implement `update` in a subclass.")

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # TODO: deprecate this function in favor of `cache_position`
        raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")

    # Deprecated in favor of `get_max_cache_shape`, because we want to be specific about what we mean by "max_length".
    # Previously, some cache objects (SlidingWindowCache or SinkCache) didn't have a "max_length", because those caches
    # can technically handle an infinite number of tokens. What the codebase really needs to check is the maximum
    # capacity of certain cache instances, so the naming was changed to be more explicit.
    def get_max_length(self) -> Optional[int]:
        # logger.warning_once(
        #     "`get_max_length()` is deprecated for all Cache classes. Use `get_max_cache_shape()` instead. "
        #     "Calling `get_max_length()` will raise an error from v4.48."
        # )
        return self.get_max_cache_shape()

    def get_max_cache_shape(self) -> Optional[int]:
        """Returns the maximum sequence length (i.e. max capacity) of the cache object."""
        raise NotImplementedError("Make sure to implement `get_max_cache_shape` in a subclass.")

    def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
        """Given the sequence length of the new inputs, returns the usable length of the cache."""
        # Cache without size limit -> all of the cache is usable
        # Cache with size limit -> if the cache length plus the length of the new inputs is larger than the maximum
        # cache length, part of the cache will need to be evicted (and thus not all of the cache is usable)
        max_length = self.get_max_cache_shape()
        previous_seq_length = self.get_seq_length(layer_idx)
        if max_length is not None and previous_seq_length + new_seq_length > max_length:
            return max_length - new_seq_length
        return previous_seq_length

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        for layer_idx in range(len(self.key_cache)):
            # Layers without cached states are stored as empty lists, so only reorder entries that hold a tensor
            if self.key_cache[layer_idx] != []:
                device = self.key_cache[layer_idx].device
                self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
            if self.value_cache[layer_idx] != []:
                device = self.value_cache[layer_idx].device
                self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

    @property
    def seen_tokens(self):
        # logger.warning_once(
        #     "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` "
        #     "model input instead."
        # )
        if hasattr(self, "_seen_tokens"):
            return self._seen_tokens
        else:
            return None
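

# ---------------------------------------------------------------------------
# Usage sketch: a minimal, hypothetical subclass showing how the abstract API
# above fits together. `ToyCache`, its `max_len` argument, and the tensor
# shapes used here are illustrative assumptions only; the library's real
# subclasses (dynamic, static, sliding-window caches, etc.) are more involved.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class ToyCache(Cache):
        """Hypothetical cache that concatenates new states along the sequence dimension."""

        def __init__(self, num_layers: int, max_len: int):
            super().__init__()
            # Per-layer storage; empty lists mark layers with no cached states yet
            self.key_cache: List[Any] = [[] for _ in range(num_layers)]
            self.value_cache: List[Any] = [[] for _ in range(num_layers)]
            self._max_len = max_len

        def update(
            self,
            key_states: torch.Tensor,
            value_states: torch.Tensor,
            layer_idx: int,
            cache_kwargs: Optional[Dict[str, Any]] = None,
        ) -> Tuple[torch.Tensor, torch.Tensor]:
            # Append along the sequence dimension (dim=-2); a real bounded cache
            # would also evict or roll entries once `_max_len` is exceeded
            if not isinstance(self.key_cache[layer_idx], torch.Tensor):
                self.key_cache[layer_idx] = key_states
                self.value_cache[layer_idx] = value_states
            else:
                self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
                self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
            return self.key_cache[layer_idx], self.value_cache[layer_idx]

        def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
            if not isinstance(self.key_cache[layer_idx], torch.Tensor):
                return 0
            return self.key_cache[layer_idx].shape[-2]

        def get_max_cache_shape(self) -> Optional[int]:
            return self._max_len

    cache = ToyCache(num_layers=2, max_len=8)
    key = torch.randn(1, 2, 4, 16)  # (batch, num_heads, seq_len, head_dim)
    value = torch.randn(1, 2, 4, 16)
    cache.update(key, value, layer_idx=0)
    print(cache.get_seq_length(0))        # 4 positions cached so far
    print(cache.get_usable_length(2, 0))  # 4: 4 + 2 <= 8, so the whole cache is usable
    print(cache.get_usable_length(6, 0))  # 2: 4 + 6 > 8, only 8 - 6 = 2 positions stay usable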