mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-10-28 04:36:20 +08:00
119 lines
2.6 KiB
Python
119 lines
2.6 KiB
Python
"""
|
|
---
|
|
title: Cache for Intermediate Activations
|
|
summary: >
|
|
Cache for intermediate activations for faster inference.
|
|
---
|
|
|
|
# Cache for Intermediate Activations
|
|
|
|
During inference the model outputs token by token.
|
|
We use this simple cache to store key's and value's attention layers,
|
|
so that we don't have to recompute them for previous tokens.
|
|
"""
|
|
|
|
from typing import Any
|
|
|
|
|
|
class Cache:
|
|
"""
|
|
## Cache
|
|
|
|
This maintains a key-value cache and queues push values and pop them in the same order.
|
|
The queues are useful since we have multiple attention layers.
|
|
"""
|
|
|
|
def __init__(self):
|
|
self._cache = {}
|
|
|
|
def clear_all(self):
|
|
"""
|
|
### Clear cache
|
|
"""
|
|
self._cache = {}
|
|
|
|
def push(self, name: str, value: Any):
|
|
"""
|
|
### Push a value to a queue
|
|
|
|
:param name: is the name of the queue
|
|
:param value: is the value to be pushed
|
|
"""
|
|
|
|
# Create an empty queue if it's not present
|
|
if name not in self._cache:
|
|
self._cache[name] = []
|
|
|
|
# Push to the queue
|
|
self._cache[name].append(value)
|
|
|
|
def q_size(self, name):
|
|
"""
|
|
### Return the size of the queue
|
|
|
|
:param name: is the name of the queue
|
|
:return: size of the queue if exists else None
|
|
"""
|
|
|
|
if name not in self._cache:
|
|
return None
|
|
|
|
if type(self._cache[name]) != list:
|
|
return None
|
|
|
|
return len(self._cache[name])
|
|
|
|
def pop(self, name: str):
|
|
"""
|
|
### Pop from a queue
|
|
|
|
:param name: is the name of the queue
|
|
:return: the value
|
|
"""
|
|
return self._cache[name].pop(0)
|
|
|
|
def set(self, key: str, value: Any):
|
|
"""
|
|
### Cache a value
|
|
|
|
:param key: is the name of the value to be cached
|
|
:param value: is the value
|
|
"""
|
|
self._cache[key] = value
|
|
|
|
def get(self, key: str, default: Any = None):
|
|
"""
|
|
### Retrieve a value from cache
|
|
|
|
:param key: is the name used when caching
|
|
:param default: is the default value if the cache is empty
|
|
:return: the cached value
|
|
"""
|
|
return self._cache.get(key, default)
|
|
|
|
def clear(self, key: str):
|
|
"""
|
|
### Clear a cache value
|
|
|
|
:param key: is the name used when caching
|
|
"""
|
|
del self._cache[key]
|
|
|
|
|
|
# Singleton for cache
|
|
_INSTANCE = None
|
|
|
|
|
|
def get_cache() -> Cache:
|
|
"""
|
|
### Get the cache instance
|
|
|
|
:return: the cache instance
|
|
"""
|
|
global _INSTANCE
|
|
|
|
if _INSTANCE is None:
|
|
_INSTANCE = Cache()
|
|
|
|
return _INSTANCE
|